LEJ

Reputation: 1958

Pyspark - Select the distinct values from each column

I am trying to find all of the distinct values in each column of a dataframe and show them in one table.

Example data:

|-----------|-----------|-----------|
|   COL_1   |   COL_2   |   COL_3   | 
|-----------|-----------|-----------|
|     A     |     C     |     D     |
|     A     |     C     |     D     |
|     A     |     C     |     E     |
|     B     |     C     |     E     |
|     B     |     C     |     F     |
|     B     |     C     |     F     |
|-----------|-----------|-----------|

Example output:

|-----------|-----------|-----------|
|   COL_1   |   COL_2   |   COL_3   | 
|-----------|-----------|-----------|
|     A     |     C     |     D     |
|     B     |           |     E     |
|           |           |     F     |
|-----------|-----------|-----------|

Is this even possible? I have been able to do it in separate tables, but it would be much better all in one table.

Any ideas?

Upvotes: 1

Views: 7225

Answers (1)

pault

Reputation: 43544
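
For reference, the example data can be reproduced with something like the following (a sketch, assuming an active SparkSession named spark; the snippets below refer to this DataFrame as df):

# hypothetical setup -- assumes an existing SparkSession named `spark`
df = spark.createDataFrame(
    [("A", "C", "D"), ("A", "C", "D"), ("A", "C", "E"),
     ("B", "C", "E"), ("B", "C", "F"), ("B", "C", "F")],
    ["COL_1", "COL_2", "COL_3"]
)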

The simplest thing here would be to use pyspark.sql.functions.collect_set on all of the columns:

import pyspark.sql.functions as f
df.select(*[f.collect_set(c).alias(c) for c in df.columns]).show()
#+------+-----+---------+
#| COL_1|COL_2|    COL_3|
#+------+-----+---------+
#|[B, A]|  [C]|[F, E, D]|
#+------+-----+---------+

Obviously, this returns the data as one row.
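
If a single row like this is all you need, the distinct values can also be pulled back to the driver as plain Python lists, for example via Row.asDict (a sketch):

distinct_values = df.select(
    *[f.collect_set(c).alias(c) for c in df.columns]
).first().asDict()
# e.g. {'COL_1': ['B', 'A'], 'COL_2': ['C'], 'COL_3': ['F', 'E', 'D']}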

If instead you want the output as shown in your question (one row per unique value for each column), it's doable but requires quite a bit of pyspark gymnastics (and any solution will likely be much less efficient).

Nevertheless, here are a couple of options:

Option 1: Explode and Join

You can use pyspark.sql.functions.posexplode to explode the elements in the set of values for each column along with the index in the array. Do this for each column separately and then outer join the resulting list of DataFrames together using functools.reduce:

from functools import reduce

# one row containing the set of distinct values for each column
unique_row = df.select(*[f.collect_set(c).alias(c) for c in df.columns])

# explode each column's array with its position, then outer join everything on the position
final_df = reduce(
    lambda a, b: a.join(b, how="outer", on="pos"),
    (unique_row.select(f.posexplode(c).alias("pos", c)) for c in unique_row.columns)
).drop("pos")

final_df.show()
#+-----+-----+-----+
#|COL_1|COL_2|COL_3|
#+-----+-----+-----+
#|    A| null|    E|
#| null| null|    D|
#|    B|    C|    F|
#+-----+-----+-----+
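
The rows here come back in arbitrary order, since collect_set does not guarantee any ordering. If you want something closer to the expected output in the question, one variation (a sketch under the same assumptions) is to sort each collected set with pyspark.sql.functions.sort_array and keep the rows ordered by position:

# sort each set of distinct values so shorter columns pad out with null at the bottom
unique_row = df.select(*[f.sort_array(f.collect_set(c)).alias(c) for c in df.columns])

final_df = reduce(
    lambda a, b: a.join(b, how="outer", on="pos"),
    (unique_row.select(f.posexplode(c).alias("pos", c)) for c in unique_row.columns)
).orderBy("pos").drop("pos")

This should match the question's expected table, with null in place of the blank cells.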

Option 2: Select by position

First compute the maximum array size and store it in a new column max_length. Then select the element at each position from every array, if a value exists at that position.

Once again we use pyspark.sql.functions.posexplode, but this time just to create a column representing the index to extract from each array.

Finally, we use the trick of referencing a column value as a parameter inside a SQL expression via expr, so that pos can be used as the array index.

final_df = (
    df.select(*[f.collect_set(c).alias(c) for c in df.columns])
    # max_length = size of the largest array, i.e. the number of output rows needed
    .withColumn("max_length", f.greatest(*[f.size(c) for c in df.columns]))
    # generate one row per index 0..max_length-1 (posexplode provides the pos column)
    .select("*", f.expr("posexplode(split(repeat(',', max_length-1), ','))"))
    # take the element at pos from each array, or null if the array is too short
    .select(
        *[
            f.expr(
                "case when size({c}) > pos then {c}[pos] else null end AS {c}".format(c=c)
            )
            for c in df.columns
        ]
    )
)

final_df.show()
#+-----+-----+-----+
#|COL_1|COL_2|COL_3|
#+-----+-----+-----+
#|    B|    C|    F|
#|    A| null|    E|
#| null| null|    D|
#+-----+-----+-----+

Upvotes: 7
