Reputation: 9078
How can I use 'groupby(key).agg(...)' with a user-defined function? Specifically, I need a list of all unique values per key (not a count).
Upvotes: 2
Views: 3018
Reputation: 1369
The collect_set and collect_list functions (which drop and keep duplicate values respectively; neither guarantees the order of the collected elements) can be used to post-process groupby results. Starting out with a simple Spark dataframe:
df = sqlContext.createDataFrame(
[('first-neuron', 1, [0.0, 1.0, 2.0]),
('first-neuron', 2, [1.0, 2.0, 3.0, 4.0])],
("neuron_id", "time", "V"))
Let's say the goal is to find the length of the longest V list for each neuron (grouped by neuron_id):
from pyspark.sql import functions as F
# group the dataframe created above (df) and collect every V list per neuron
grouped_df = df.groupby('neuron_id').agg(F.collect_list('V'))
We have now grouped the V lists into a list of lists. Since we wanted the longest length we can run
import numpy as np
import pyspark.sql.types as sq_types
# v_list is a list of V lists; return the length of the longest one
len_udf = F.udf(lambda v_list: int(np.max([len(v) for v in v_list])),
                returnType=sq_types.IntegerType())
max_len_df = grouped_df.withColumn('max_len',len_udf('collect_list(V)'))
This adds a max_len column holding the length of the longest V list for each neuron.
Upvotes: 2
Reputation: 9078
I found pyspark.sql.functions.collect_set(col), which does the job I wanted.
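To make the difference between the two aggregates concrete, here is a minimal plain-Python sketch of what they compute per group. It is only a model of the semantics, not Spark code: the helper collect_per_key and the sample rows are hypothetical names made up for illustration, and unlike Spark this model happens to preserve insertion order. The comments show the PySpark calls being modelled.

```python
from collections import defaultdict

def collect_per_key(rows, unique):
    """Group (key, value) pairs by key (hypothetical helper, not a Spark API).

    unique=False keeps duplicates, like F.collect_list;
    unique=True drops them, like F.collect_set.
    """
    grouped = defaultdict(list)
    for key, value in rows:
        if unique and value in grouped[key]:
            continue  # already collected this value for this key
        grouped[key].append(value)
    return dict(grouped)

# Hypothetical sample data: (key, value) pairs
rows = [("a", 1), ("a", 1), ("a", 2), ("b", 3)]

# Models: df.groupby('key').agg(F.collect_list('value'))
as_list = collect_per_key(rows, unique=False)

# Models: df.groupby('key').agg(F.collect_set('value'))
as_set = collect_per_key(rows, unique=True)

print(as_list)
print(as_set)
```

For the original question (all unique values per key), the collect_set variant is the one that applies; no user-defined function is needed for that case.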
Upvotes: 1