Reputation: 727
I have a PySpark DataFrame:
df1 = spark.createDataFrame([
("u1", 1),
("u1", 2),
("u2", 1),
("u2", 1),
("u2", 1),
("u3", 3),
],
['user_id', 'var1'])
df1.printSchema()
df1.show(truncate=False)
Output:
root
|-- user_id: string (nullable = true)
|-- var1: long (nullable = true)
+-------+----+
|user_id|var1|
+-------+----+
|u1 |1 |
|u1 |2 |
|u2 |1 |
|u2 |1 |
|u2 |1 |
|u3 |3 |
+-------+----+
Now I want to group the unique users and show the number of unique var1 values for each of them in a new column. The desired output would look like this:
+-------+---------------+
|user_id|num_unique_var1|
+-------+---------------+
|u1 |2 |
|u2 |1 |
|u3 |1 |
+-------+---------------+
I can use collect_set and write a UDF to find the set's length, but I think there must be a better way to do it. How do I achieve this in one line of code?
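For reference, the collect_set + UDF approach I had in mind looks roughly like this (just a sketch; set_size is an illustrative name):
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

# illustrative UDF that returns the length of the collected set
set_size = F.udf(lambda s: len(s), IntegerType())
df1.groupBy('user_id').agg(set_size(F.collect_set('var1')).alias('num_unique_var1')).show()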
Upvotes: 0
Views: 302
Reputation: 42332
countDistinct is surely the best way to do it, but for the sake of completeness, what you described in your question is also possible without a UDF. You can use size to get the length of the collect_set:
from pyspark.sql import functions as F
df1.groupBy('user_id').agg(F.size(F.collect_set('var1')).alias('num'))
This is helpful if you want to use it in a window function, because countDistinct is not supported in window functions. For example,
df1.withColumn('num', F.countDistinct('var1').over(Window.partitionBy('user_id')))
would fail, but
df1.withColumn('num', F.size(F.collect_set('var1').over(Window.partitionBy('user_id'))))
would work.
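Putting it together, a minimal self-contained sketch of the window-function variant (the column name num is arbitrary):
from pyspark.sql import functions as F
from pyspark.sql.window import Window

w = Window.partitionBy('user_id')
# attaches the per-user distinct count to every row for that user
df1.withColumn('num', F.size(F.collect_set('var1').over(w))).show()
Unlike the groupBy version, this keeps one row per original record, with the distinct count repeated on each of a user's rows.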
Upvotes: 2
Reputation: 727
countDistinct is exactly what I needed.
df1.groupBy('user_id').agg(F.countDistinct('var1').alias('num')).show()
Output:
+-------+---+
|user_id|num|
+-------+---+
| u3| 1|
| u2| 1|
| u1| 2|
+-------+---+
Upvotes: 4