Hanan Shteingart

Reputation: 9078

PySpark data frame aggregation with a user-defined function

How can I use groupby(key).agg() with a user-defined function? Specifically, I need a list of all the unique values per key [not the count].

Upvotes: 2

Views: 3018

Answers (2)

kmader

Reputation: 1369

The collect_set and collect_list functions (which drop and keep duplicate values, respectively) can be used to post-process groupby results. Start with a simple Spark dataframe:

    df = sqlContext.createDataFrame(
        [('first-neuron', 1, [0.0, 1.0, 2.0]),
         ('first-neuron', 2, [1.0, 2.0, 3.0, 4.0])],
        ("neuron_id", "time", "V"))

Let's say the goal is to return the longest length of the V list for each neuron (grouped by neuron_id):

    from pyspark.sql import functions as F

    grouped_df = df.groupby('neuron_id').agg(F.collect_list('V'))

We have now grouped the V lists into a list of lists. Since we want the longest length, we can run:

    import numpy as np
    import pyspark.sql.types as sq_types

    # UDF that returns the length of the longest V list in each group
    len_udf = F.udf(lambda v_list: int(np.max([len(v) for v in v_list])),
                    returnType=sq_types.IntegerType())
    max_len_df = grouped_df.withColumn('max_len', len_udf('collect_list(V)'))

This adds a max_len column containing the maximum length of the V lists for each neuron.
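For the example dataframe above, the result should look roughly like this (the exact rendering of the nested arrays may differ across Spark versions):

    max_len_df.show()
    # +------------+--------------------+-------+
    # |   neuron_id|     collect_list(V)|max_len|
    # +------------+--------------------+-------+
    # |first-neuron|[[0.0, 1.0, 2.0],...|      4|
    # +------------+--------------------+-------+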

Upvotes: 2

Hanan Shteingart

Reputation: 9078

I found pyspark.sql.functions.collect_set(col), which does the job I wanted.
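For completeness, a minimal sketch of how collect_set answers the question, assuming a dataframe df with columns named key and value (hypothetical names, not from the original post):

    from pyspark.sql import functions as F

    # collect_set gathers the distinct values for each key into an array column
    unique_df = df.groupby('key').agg(F.collect_set('value').alias('unique_values'))
    unique_df.show()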

Upvotes: 1
