right_here

Reputation: 97

How to use a list of aggregate expressions with groupby in pyspark?

I have a dataframe with 2 columns like this:

+----+---+
|ptyp|sID|
+----+---+
|  CO|111|
|  CO|222|
|  CO|222|
|  CO|222|
|  CO|111|
|  CD|111|
|  CD|222|
|  CD|222|
|  CD|333|
|  CD|333|
|  CD|333|
|  AG|111|
|  AG|111|
|  AG|111|
|  AG|222|
+----+---+

Given an input n, for each ptyp, I want to add columns that show the top n sIDs (ranked by the number of times they appear for that ptyp). I also want to show the number of times each sID occurs in a column sIDval for each group.

For example, if n = 2, I want the output to be like this:

+----+-------+-----------+-------+-----------+
|ptyp|topSID1|topSID1_val|topSID2|topSID2_val|
+----+-------+-----------+-------+-----------+
|  AG|    111|          3|    222|          1|
|  CO|    222|          3|    111|          2|
|  CD|    333|          3|    222|          2|
+----+-------+-----------+-------+-----------+

I am using UDFs to calculate this:

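from collections import Counter
from pyspark.sql import functions as F

# Returns the sID at the given rank (1 = most frequent) within the group.
@F.udf
def mode(x, top_rank):
    c = Counter(x).most_common(top_rank)
    sz = len(c)
    return c[min(top_rank-1, sz-1)][0]

# Returns the number of occurrences of the sID at the given rank.
@F.udf
def modeval(x, top_rank):
    c = Counter(x).most_common(top_rank)
    sz = len(c)
    return c[min(top_rank-1, sz-1)][1]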

And I am storing the aggregate expressions required for each new column in a list newcols:

newcols = []
n = 3

for r in range(1, n+1):
    newcols.append([mode(F.collect_list('sID'), F.lit(r)).alias('topSID' + str(r))])
    newcols.append([modeval(F.collect_list('sID'), F.lit(r)).alias('topSID' + str(r) +'_val')])

Now if I know that n=3, I can do it in this way:

df.groupBy('ptyp').agg(*newcols[0], *newcols[1], *newcols[2], \
                       *newcols[3], *newcols[4], *newcols[5])

Is there a way I can generalize this for any value of n? I tried

df.groupBy('ptyp').agg([*e for e in newcols])

and

df.groupBy('ptyp').agg((*e for e in newcols))

and many more variations, but all of them give errors.

For now I have resorted to aggregating one column at a time and joining the results, but that is very expensive.

Is there a way to do this in the manner I've tried above?

Upvotes: 1

Views: 1884

Answers (1)

cronoik

Reputation: 19345

A list comprehension is the right way to go, but you can't unpack the sublists with * inside the comprehension, because a starred expression there has no target to unpack into. When you call:

print(*newcols[0])

You get an output like:

Column<b'mode(collect_list(sID, 0, 0), 1) AS `topSID1`'>

newcols is a list of lists, and you can flatten it with another list comprehension:

print([item for sublist in newcols for item in sublist][0])

Which returns the same output:

Column<b'mode(collect_list(sID, 0, 0), 1) AS `topSID1`'>

Therefore you can do the following to get the expected behavior:

df.groupBy('ptyp').agg(*[item for sublist in newcols for item in sublist])
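
To see why this works without needing a Spark session, here is a minimal sketch in plain Python. It assumes only that agg takes its expressions as separate positional arguments; fake_agg and the string placeholders are hypothetical stand-ins for the real method and for the Column expressions. itertools.chain.from_iterable is shown as an equivalent alternative to the nested comprehension.

```python
from itertools import chain


def fake_agg(*exprs):
    """Hypothetical stand-in for a variadic method such as agg(*exprs)."""
    return list(exprs)


n = 2
newcols = []
for r in range(1, n + 1):
    # Each entry is a one-element list, mirroring the question's newcols.
    # The strings are placeholders for the aliased Column expressions.
    newcols.append(["topSID" + str(r)])
    newcols.append(["topSID" + str(r) + "_val"])

# Flatten the list of lists with a nested comprehension ...
flat_comprehension = [item for sublist in newcols for item in sublist]
# ... or, equivalently, with itertools.
flat_chain = list(chain.from_iterable(newcols))

# Unpack the flat list into separate positional arguments.
result = fake_agg(*flat_comprehension)
print(result)
```

If the expressions are appended directly, as in newcols.append(expr) rather than newcols.append([expr]), the list is already flat and agg(*newcols) works without the flattening step.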

Upvotes: 3
