Reputation: 3845
I have a pyspark dataframe where I have grouped data to list with collect_list.
from pyspark.sql.functions import udf, collect_list
from itertools import combinations, chain
# Create the DataFrame
df = spark.createDataFrame([(1, 'a'), (1, 'b'), (2, 'c')], ["id", "colA"])
df.show()
>>>
+---+----+
| id|colA|
+---+----+
| 1| a|
| 1| b|
| 2| c|
+---+----+
# Group by id and collect colA into a list
df = df.groupBy(df.id).agg(collect_list("colA").alias("colAList"))
df.show()
>>>
+---+--------+
| id|colAList|
+---+--------+
| 1| [a, b]|
| 2| [c]|
+---+--------+
Then I use a UDF to build all combinations of the list elements into a new column:
allsubsets = lambda l: list(chain(*[combinations(l, n) for n in range(1, len(l)+1)]))
df = df.withColumn('colAsubsets', udf(allsubsets)(df['colAList']))
so I would expect something like:
+---+------------------+
| id|       colAsubsets|
+---+------------------+
|  1|[[a], [b], [a, b]]|
|  2|             [[c]]|
+---+------------------+
but I get:
df.show()
>>>
+---+--------+-----------------------------------------------------------------------------------------+
|id |colAList|colAsubsets |
+---+--------+-----------------------------------------------------------------------------------------+
|1 |[a, b] |[[Ljava.lang.Object;@75e2d657, [Ljava.lang.Object;@7f662637, [Ljava.lang.Object;@b572639]|
|2 |[c] |[[Ljava.lang.Object;@26f67148] |
+---+--------+-----------------------------------------------------------------------------------------+
Any ideas what to do? And then, maybe, how to flatten the list into different rows?
Upvotes: 3
Views: 1692
Reputation: 41957
All you need to do is extract the elements of the tuples created by chain and combinations into plain lists; left as tuples, they show up as the [Ljava.lang.Object;... strings in your output. So changing
allsubsets = lambda l: list(chain(*[combinations(l, n) for n in range(1, len(l)+1)]))
to the following
allsubsets = lambda l: [[z for z in y] for y in chain(*[combinations(l, n) for n in range(1, len(l)+1)])]
should give you
+---+--------+------------------+
|id |colAList|colAsubsets       |
+---+--------+------------------+
|1  |[a, b]  |[[a], [b], [a, b]]|
|2  |[c]     |[[c]]             |
+---+--------+------------------+
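To see why, here is a minimal pure-Python check (no Spark needed); the sample list ['a', 'b'] is just an illustration:

from itertools import chain, combinations

l = ['a', 'b']

# combinations yields tuples: [('a',), ('b',), ('a', 'b')]
print(list(chain(*[combinations(l, n) for n in range(1, len(l)+1)])))

# converting each tuple to a plain list gives: [['a'], ['b'], ['a', 'b']]
print([[z for z in y] for y in chain(*[combinations(l, n) for n in range(1, len(l)+1)])])

Equivalently, [list(y) for y in ...] performs the same tuple-to-list conversion.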
I hope the answer is helpful.
Upvotes: 3
Reputation: 2477
Improving on @RameshMaharjan's answer, in order to flatten the list into different rows:
You have to use explode on an array. But first you must specify the return type of your udf, so it doesn't default to StringType.
from pyspark.sql.functions import explode
from pyspark.sql.types import ArrayType, StringType

allsubsets = lambda l: [[z for z in y] for y in chain(*[combinations(l, n) for n in range(1, len(l)+1)])]

# declare the UDF with an explicit array<array<string>> return type
df = df.withColumn('colAsubsets', udf(allsubsets, ArrayType(ArrayType(StringType())))(df['colAList']))

# explode each subset into its own row
df = df.select('id', explode('colAsubsets'))
Result:
+---+------+
| id| col|
+---+------+
| 1| [a]|
| 1| [b]|
| 1|[a, b]|
| 2| [c]|
+---+------+
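As a small optional tweak, you can give the exploded column a clearer name than the default col by aliasing it (the name colAsubset here is just an illustrative choice):

# instead of df.select('id', explode('colAsubsets')), alias the exploded column
df = df.select('id', explode('colAsubsets').alias('colAsubset'))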
Upvotes: 3