xmatadorx

Reputation: 35

Reversing Group By in PySpark

I am not sure whether the question itself is well posed. The solutions I've found for plain SQL either do not work in Hive SQL or rely on recursion, which is prohibited. So I'd like to solve the problem in PySpark, and I need a solution or at least ideas on how to tackle it.

I have an original table which looks like this:

+--------+----------+
|customer|nr_tickets|
+--------+----------+
|       A|         3|
|       B|         1|
|       C|         2|
+--------+----------+

This is how I want the table:

+--------+
|customer|
+--------+
|       A|
|       A|
|       A|
|       B|
|       C|
|       C|
+--------+

Do you have any suggestions?

Thank you very much in advance!

Upvotes: 2

Views: 231

Answers (3)

xmatadorx

Reputation: 35

In the meantime I have also found a solution myself:

from pyspark.sql import functions as F

# decrement, drop exhausted rows, then append a fresh copy of the original
for i in range(1, max_nr_of_tickets):
    table = table.withColumn('nr_tickets', F.col('nr_tickets') - 1)
    table = table.filter(F.col('nr_tickets') >= 1).union(test)

Explanation: The DataFrames "table" and "test" are identical at the start, and "max_nr_of_tickets" is simply the highest value of "nr_tickets". It works. I am only struggling with the format of the max number:

max_nr_of_tickets = df.select(F.max('nr_tickets')).collect()

I cannot use the result in the for loop's range as-is, because collect() returns a list. So for now I enter the highest number manually. Any ideas how I could get max_nr_of_tickets into the right format so the loop's range will accept it?
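One common way to unwrap that result into a plain int is to index into the first Row and first column (a minimal sketch, reusing the df from above):

# collect() returns a list like [Row(max(nr_tickets)=3)];
# [0][0] picks the first row's first column, int() makes range() happy
max_nr_of_tickets = int(df.select(F.max('nr_tickets')).collect()[0][0])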

Thanks

Upvotes: 0

murtihash

Reputation: 8410

For Spark 2.4+, use array_repeat with explode.

from pyspark.sql import functions as F

df.selectExpr("""explode(array_repeat(customer,cast(nr_tickets as int))) as customer""").show()

#+--------+
#|customer|
#+--------+
#|       A|
#|       A|
#|       A|
#|       B|
#|       C|
#|       C|
#+--------+
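The same thing can be written with the DataFrame API instead of a SQL expression (a sketch; this assumes a Spark version whose Python array_repeat accepts a Column for the count, e.g. Spark 3.x):

from pyspark.sql import functions as F

# build an array containing `customer` repeated nr_tickets times,
# then explode it into one row per element
df.select(
    F.explode(
        F.array_repeat(F.col("customer"), F.col("nr_tickets").cast("int"))
    ).alias("customer")
).show()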

Upvotes: 1

A.B

Reputation: 20445

You can make a new DataFrame by iterating over the rows (groups).

First, make a list of Rows having customer (Row(customer=a["customer"])) repeated nr_tickets times for that customer, using range(int(a["nr_tickets"])):

df_list + [Row(customer=a["customer"]) for _ in range(int(a["nr_tickets"]))]

You can accumulate these in a list and later make a DataFrame from it:

df = spark.createDataFrame(df_list)

Overall:

from pyspark.sql import Row

# collect() pulls every row to the driver, so this suits small tables
df_list = []
for a in df.select(["customer", "nr_tickets"]).collect():
    df_list = df_list + [Row(customer=a["customer"]) for _ in range(int(a["nr_tickets"]))]
df = spark.createDataFrame(df_list)
df.show()

You can also do it with a list comprehension:

from pyspark.sql import Row
from functools import reduce  # Python 3

# one sub-list of repeated Rows per customer
df_list = [
    [Row(customer=a["customer"])] * int(a["nr_tickets"])
    for a in df.select(["customer", "nr_tickets"]).collect()
]

# flatten the sub-lists into a single list of Rows
df = spark.createDataFrame(reduce(lambda x, y: x + y, df_list))
df.show()

Produces

+--------+
|customer|
+--------+
|       A|
|       A|
|       A|
|       B|
|       C|
|       C|
+--------+

Upvotes: 0
