sideny_bishop

Reputation: 31

Get the distinct elements of a column grouped by another column on a PySpark Dataframe

I have a PySpark DataFrame of ids and purchases which I'm trying to transform for use with FP-growth. Currently I have multiple rows for a given id, with each row relating to a single purchase.

I'd like to transform this dataframe into one with two columns: one for the id (a single row per id) and a second column containing the list of distinct purchases for that id.

I've tried to use a User Defined Function (UDF) to map the distinct purchases onto the distinct ids, but I get a "py4j.Py4JException: Method __getstate__([]) does not exist". Thanks to @Mithril I see that "You can't use sparkSession object, spark.DataFrame object or other Spark distributed objects in udf and pandas_udf, because they are unpickled."
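For context, what I was attempting looked roughly like the sketch below (the function name and return type are illustrative, not my exact code); the Spark DataFrame captured inside the UDF is what can't be pickled:

import pyspark.sql.functions as f
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

# Illustrative sketch: spk_df_1 is a Spark DataFrame, so capturing it in
# the UDF's closure makes the function unpicklable and raises the
# py4j.Py4JException above when the UDF is evaluated.
@udf(returnType=ArrayType(StringType()))
def distinct_items(id):
    return (spk_df_1.filter(f.col("id") == id)
                    .select("item").distinct()
                    .rdd.map(lambda row: row[0]).collect())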

So I've implemented the TERRIBLE approach below (which will work but is not scalable):

#Lets create some fake transactions
customers  = [1, 2, 3, 1, 1]
purschases = ['cake', 'tea', 'beer', 'fruit', 'cake']

# Lets create a spark DF to capture the transactions
transactions = zip(customers, purschases)
spk_df_1 = spark.createDataFrame(list(transactions), ["id", "item"])

# Lets have a look at the resulting spark dataframe
spk_df_1.show()

# Register the dataframe as a temp view so the SQL query below can see it
spk_df_1.createOrReplaceTempView("TBLdf")

# Lets capture the ids and the list of their distinct purchases in a
# list of tuples
import pandas as pd
import pyspark.sql.functions as f

nums1 = []

# for each distinct id lets get the list of their distinct purchases
for id in spark.sql("SELECT DISTINCT id FROM TBLdf").rdd.map(lambda row: row[0]).collect():
    purschase = spk_df_1.filter(f.col("id") == id).select("item").distinct().rdd.map(lambda row: row[0]).collect()
    nums1.append((id, purschase))

# Lets see what our list of transaction tuples looks like
print(nums1)
print("\n")

# lets turn the list of transaction tuples into a pandas dataframe
df_pd = pd.DataFrame(nums1)

# Finally lets turn our pandas dataframe into a pyspark Dataframe
df2 = spark.createDataFrame(df_pd)
df2.show()

Output:

+---+-----+
| id| item|
+---+-----+
|  1| cake|
|  2|  tea|
|  3| beer|
|  1|fruit|
|  1| cake|
+---+-----+

[(1, ['fruit', 'cake']), (3, ['beer']), (2, ['tea'])]


+---+-------------+
|  0|            1|
+---+-------------+
|  1|[fruit, cake]|
|  3|       [beer]|
|  2|        [tea]|
+---+-------------+

If anybody has any suggestions I'd greatly appreciate it.

Upvotes: 3

Views: 2003

Answers (1)

cronoik

Reputation: 19365

That is a task for collect_set, which creates a set of items without duplicates:

import pyspark.sql.functions as F

#Lets create some fake transactions
customers  = [1,2,3,1,1]
purschases = ['cake','tea','beer','fruit','cake']

# Lets create a spark DF to capture the transactions
transactions = zip(customers,purschases)
spk_df_1 = spark.createDataFrame(list(transactions) , ["id", "item"])
spk_df_1.show()

spk_df_1.groupby('id').agg(F.collect_set('item')).show()

Output:

+---+-----+
| id| item|
+---+-----+
|  1| cake|
|  2|  tea|
|  3| beer|
|  1|fruit|
|  1| cake|
+---+-----+

+---+-----------------+
| id|collect_set(item)|
+---+-----------------+
|  1|    [fruit, cake]|
|  3|           [beer]|
|  2|            [tea]|
+---+-----------------+
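Since you mentioned FP-growth: building on the snippet above, the aggregated column can be renamed and fed straight into pyspark.ml.fpm.FPGrowth. This is just a sketch; the minSupport/minConfidence values are placeholders you'd tune for your data.

from pyspark.ml.fpm import FPGrowth

# Alias the aggregated column so FPGrowth can find it via itemsCol
baskets = spk_df_1.groupby('id').agg(F.collect_set('item').alias('items'))

# Placeholder thresholds - adjust for your dataset
fp = FPGrowth(itemsCol='items', minSupport=0.2, minConfidence=0.5)
model = fp.fit(baskets)
model.freqItemsets.show()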

Upvotes: 2
