Reputation: 1958
I am trying to find all of the distinct values in each column of a DataFrame and show them in one table.
Example data:
|-----------|-----------|-----------|
|   COL_1   |   COL_2   |   COL_3   |
|-----------|-----------|-----------|
|     A     |     C     |     D     |
|     A     |     C     |     D     |
|     A     |     C     |     E     |
|     B     |     C     |     E     |
|     B     |     C     |     F     |
|     B     |     C     |     F     |
|-----------|-----------|-----------|
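For reproducibility, the example data can be built as a small PySpark DataFrame along these lines (a quick sketch, assuming an active SparkSession):
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# The example data above as a DataFrame
df = spark.createDataFrame(
    [("A", "C", "D"), ("A", "C", "D"), ("A", "C", "E"),
     ("B", "C", "E"), ("B", "C", "F"), ("B", "C", "F")],
    ["COL_1", "COL_2", "COL_3"],
)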
Example output:
|-----------|-----------|-----------|
|   COL_1   |   COL_2   |   COL_3   |
|-----------|-----------|-----------|
|     A     |     C     |     D     |
|     B     |           |     E     |
|           |           |     F     |
|-----------|-----------|-----------|
Is this even possible? I have been able to do it in separate tables, but it would be much better to have it all in one table.
Any ideas?
Upvotes: 1
Views: 7225
Reputation: 43544
The simplest thing here would be to use pyspark.sql.functions.collect_set on all of the columns:
import pyspark.sql.functions as f
# Each column's distinct values are gathered into one array, so the result is a single row
df.select(*[f.collect_set(c).alias(c) for c in df.columns]).show()
#+------+-----+---------+
#| COL_1|COL_2| COL_3|
#+------+-----+---------+
#|[B, A]| [C]|[F, E, D]|
#+------+-----+---------+
Obviously, this returns the data as one row.
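As an aside, if a plain Python dict mapping each column to its distinct values is all you need, you can collect that single row to the driver (a small sketch; the order of elements inside each list is not guaranteed):
# Collect the single row of arrays back to the driver as a dict, e.g.
# {'COL_1': ['B', 'A'], 'COL_2': ['C'], 'COL_3': ['F', 'E', 'D']}
distinct_values = df.select(*[f.collect_set(c).alias(c) for c in df.columns])\
    .first()\
    .asDict()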
If instead you want the output as you wrote in your question (one row per unique value for each column), it's doable but requires quite a bit of pyspark gymnastics (and any solution will likely be much less efficient).
Nevertheless, here are some options:
Option 1: Explode and Join
You can use pyspark.sql.functions.posexplode to explode the elements in the set of values for each column, along with the index in the array. Do this for each column separately and then outer join the resulting list of DataFrames together using functools.reduce:
from functools import reduce

# One row whose cells are the arrays of distinct values per column
unique_row = df.select(*[f.collect_set(c).alias(c) for c in df.columns])

# posexplode each array into (pos, value) rows, then outer-join the
# per-column DataFrames on pos and drop the position column
final_df = reduce(
    lambda a, b: a.join(b, how="outer", on="pos"),
    (unique_row.select(f.posexplode(c).alias("pos", c)) for c in unique_row.columns)
).drop("pos")
final_df.show()
#+-----+-----+-----+
#|COL_1|COL_2|COL_3|
#+-----+-----+-----+
#| A| null| E|
#| null| null| D|
#| B| C| F|
#+-----+-----+-----+
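Note that collect_set makes no ordering guarantees, which is why the non-null values above don't pair up the way they do in your desired output. If you want a deterministic, alphabetical pairing, you can sort each set before exploding, for example:
# Sort each collected set so the exploded values line up alphabetically
unique_row = df.select(*[f.sort_array(f.collect_set(c)).alias(c) for c in df.columns])
The rest of Option 1 stays the same; with sorted arrays the rows pair up as A/C/D, B/null/E, null/null/F (though the row order printed by show() may still vary).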
Option 2: Select by position
First compute the maximum array size and store it in a new column, max_length. Then select the element from each array if a value exists at that index.
Once again we use pyspark.sql.functions.posexplode, but this time it's only to create a column representing the index to extract from each array.
Finally we use the trick of referencing column values (max_length and pos) inside expr, which lets a column value act as a parameter to SQL functions.
# Collect the distinct values per column, find the longest array, then
# generate one row per index position and pick the element at that
# position from each array (null when the array is shorter).
final_df = df.select(*[f.collect_set(c).alias(c) for c in df.columns])\
    .withColumn("max_length", f.greatest(*[f.size(c) for c in df.columns]))\
    .select("*", f.expr("posexplode(split(repeat(',', max_length-1), ','))"))\
    .select(
        *[
            f.expr(
                "case when size({c}) > pos then {c}[pos] else null end AS {c}".format(c=c))
            for c in df.columns
        ]
    )
final_df.show()
#+-----+-----+-----+
#|COL_1|COL_2|COL_3|
#+-----+-----+-----+
#| B| C| F|
#| A| null| E|
#| null| null| D|
#+-----+-----+-----+
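As a side note, if you're on Spark 2.4+, the repeat/split trick can be replaced by generating the index column directly with sequence; a sketch of the same pipeline:
# Generate pos = 0 .. max_length-1 with sequence instead of exploding a split string
final_df = df.select(*[f.collect_set(c).alias(c) for c in df.columns])\
    .withColumn("max_length", f.greatest(*[f.size(c) for c in df.columns]))\
    .select("*", f.expr("explode(sequence(0, max_length - 1)) AS pos"))\
    .select(
        *[
            f.expr(
                "case when size({c}) > pos then {c}[pos] else null end AS {c}".format(c=c))
            for c in df.columns
        ]
    )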
Upvotes: 7