Reputation: 35
I am not sure about the correctness of the question itself. The solutions I've found for SQL either do not work in Hive SQL or rely on recursion, which is prohibited. Thus, I'd like to solve the problem in PySpark and need a solution, or at least ideas on how to tackle the problem.
I have an original table which looks like this:
+--------+----------+
|customer|nr_tickets|
+--------+----------+
|       A|         3|
|       B|         1|
|       C|         2|
+--------+----------+
This is how I want the table to look:
+--------+
|customer|
+--------+
|       A|
|       A|
|       A|
|       B|
|       C|
|       C|
+--------+
Do you have any suggestions?
Thank you very much in advance!
Upvotes: 2
Views: 231
Reputation: 35
In the meantime, I have also found a solution by myself:
# 'test' is an untouched copy of the original DataFrame; max_nr_of_tickets is its highest nr_tickets
for i in range(1, max_nr_of_tickets):
    table = table.withColumn('nr_tickets', F.col('nr_tickets') - 1)
    table = table.filter(F.col('nr_tickets') >= 1).union(test)
Explanation: The DFs "table" and "test" are the same at the beginning. So "max_nr_of_tickets" is just the highest "nr_tickets". It works. I am only struggling with the format of the max number:
max_nr_of_tickets = df.select(F.max('nr_tickets')).collect()
I cannot use the result in the for loop's range, as it is a list. So I manually enter the highest number. Any ideas how I could get max_nr_of_tickets into the right format so the loop's range will accept it?
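One possible way (just a sketch, assuming df is the original table) would be to take the first Row of the collected result and cast its single value to an int:
from pyspark.sql import functions as F
# collect() returns a list of Row objects; take the first row's only column and cast it
max_nr_of_tickets = int(df.select(F.max('nr_tickets')).collect()[0][0])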
Thanks
Upvotes: 0
Reputation: 8410
For Spark 2.4+, use array_repeat with explode.
from pyspark.sql import functions as F
df.selectExpr("""explode(array_repeat(customer,cast(nr_tickets as int))) as customer""").show()
#+--------+
#|customer|
#+--------+
#|       A|
#|       A|
#|       A|
#|       B|
#|       C|
#|       C|
#+--------+
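The same can also be written with DataFrame API functions instead of selectExpr; this is only a sketch, and passing a column as the count to array_repeat may need Spark 3.0+:
from pyspark.sql import functions as F
# build an array with customer repeated nr_tickets times, then explode it into rows
df.select(
    F.explode(
        F.array_repeat("customer", F.col("nr_tickets").cast("int"))
    ).alias("customer")
).show()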
Upvotes: 1
Reputation: 20445
You can make a new dataframe by iterating over the collected rows (groups).
First, make a list of Rows with the customer value (Row(customer=a["customer"])) repeated nr_tickets times for that customer, using range(int(a["nr_tickets"])):
df_list + [Row(customer=a["customer"]) for _ in range(int(a["nr_tickets"]))]
You can store and append these in a list and later make a dataframe from it:
df = spark.createDataFrame(df_list)
Overall,
from pyspark.sql import Row

df_list = []
for a in df.select(["customer", "nr_tickets"]).collect():
    df_list = df_list + [Row(customer=a["customer"]) for _ in range(int(a["nr_tickets"]))]

df = spark.createDataFrame(df_list)
df.show()
You can also do it with a list comprehension:
from pyspark.sql import Row
from functools import reduce  # Python 3

df_list = [
    [Row(customer=a["customer"])] * int(a["nr_tickets"])
    for a in df.select(["customer", "nr_tickets"]).collect()
]
df = spark.createDataFrame(reduce(lambda x, y: x + y, df_list))
df.show()
Produces
+--------+
|customer|
+--------+
|       A|
|       A|
|       A|
|       B|
|       C|
|       C|
+--------+
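Note that both variants collect the whole table to the driver first. As a rough sketch of an alternative (not part of this answer) that keeps the work distributed, you could flatMap over the underlying RDD:
# repeat each customer value nr_tickets times without collecting the table to the driver
exploded = (
    df.rdd
      .flatMap(lambda row: [(row["customer"],)] * int(row["nr_tickets"]))
      .toDF(["customer"])
)
exploded.show()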
Upvotes: 0