amol desai

Reputation: 257

Group column of pyspark dataframe by taking only unique values from two columns

I want to group a column based on the unique values from two columns of a PySpark dataframe. The output should be such that once a value has been used in the groupby, it should not be repeated if it appears in the other column.

    |------------------|-------------------|
    |   fruit          |     fruits        |
    |------------------|-------------------|
    |   apple          |     banana        |
    |   banana         |     apple         |
    |   apple          |     mango         |
    |   orange         |     guava         |
    |   apple          |     pineapple     |
    |   mango          |     apple         |
    |   banana         |     mango         |
    |   banana         |     pineapple     |
    |------------------|-------------------|

I have tried grouping by a single column, but it needs to be modified or some other logic is required.

from pyspark.sql.functions import collect_list

df9 = final_main.groupBy('fruit').agg(collect_list('fruits').alias('values'))

I am getting the following output from the above query:

       |------------------|--------------------------------|
       |   fruit          |     values                     | 
       |------------------|--------------------------------|
       |  apple           | ['banana','mango','pineapple'] |
       |  banana          | ['apple']                      |
       |  orange          | ['guava']                      |
       |  mango           | ['apple']                      |
       |------------------|--------------------------------|

But I want the following output:

       |------------------|--------------------------------|
       |   fruit          |     values                     | 
       |------------------|--------------------------------|
       |  apple           | ['banana','mango','pineapple'] |
       |  orange          | ['guava']                      |
       |------------------|--------------------------------|

Upvotes: 0

Views: 1635

Answers (1)

absolutelydevastated

Reputation: 1747

This looks like a connected components problem. There are a couple of ways you can go about it.

1. GraphFrames

You can use the GraphFrames package. Each row of your dataframe defines an edge, and you can just create a graph using df as edges and a dataframe of all the distinct fruits as vertices. Then call the connectedComponents method. You can then manipulate the output to get what you want.
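A minimal sketch of this approach, assuming a SparkSession named spark, your input dataframe named df, and an arbitrary checkpoint path (connectedComponents requires one); the variable names are illustrative:

import pyspark.sql.functions as F
from graphframes import GraphFrame

# connectedComponents needs a checkpoint directory; the path here is an assumption
spark.sparkContext.setCheckpointDir('/tmp/graphframes-checkpoint')

# Vertices: one row per distinct fruit, in a column named 'id' as GraphFrames expects
vertices = (df.select(F.col('fruit').alias('id'))
              .union(df.select(F.col('fruits').alias('id')))
              .distinct())

# Edges: each original row becomes an edge with 'src' and 'dst' columns
edges = df.select(F.col('fruit').alias('src'), F.col('fruits').alias('dst'))

g = GraphFrame(vertices, edges)
components = g.connectedComponents()  # adds a 'component' column per vertex

# Collect all fruits that fall in the same component into one group
grouped = components.groupBy('component').agg(F.collect_list('id').alias('group'))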

2. Just Pyspark

The second method is a bit of a hack. Create a "hash" for each row like

import pyspark.sql.functions as F

hashed_df = df.withColumn('hash', F.sort_array(F.array(F.col('fruit'), F.col('fruits'))))

Drop all non-distinct rows for that column

distinct_df = hashed_df.dropDuplicates(['hash'])

Split up the items again

revert_df = distinct_df.withColumn('fruit', F.col('hash')[0]) \
    .withColumn('fruits', F.col('hash')[1])

Group by the first column

grouped_df = revert_df.groupBy('fruit').agg(F.collect_list('fruits').alias('group'))

You might need to "stringify" your hash using F.concat_ws if PySpark complains, but the idea is the same. A sketch of that variant follows.
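For instance, assuming '|' never appears in the fruit names (the separator choice is arbitrary):

# Join the sorted pair into one string so dropDuplicates sees a plain string column
hashed_df = df.withColumn(
    'hash',
    F.concat_ws('|', F.sort_array(F.array(F.col('fruit'), F.col('fruits'))))
)
distinct_df = hashed_df.dropDuplicates(['hash'])

# Split the string back into its two parts ('|' must be escaped, since split takes a regex)
revert_df = (distinct_df
             .withColumn('fruit', F.split(F.col('hash'), '\\|')[0])
             .withColumn('fruits', F.split(F.col('hash'), '\\|')[1]))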

Upvotes: 1
