I have a dataframe with 2 columns like this:
+----+---+
|ptyp|sID|
+----+---+
|  CO|111|
|  CO|222|
|  CO|222|
|  CO|222|
|  CO|111|
|  CD|111|
|  CD|222|
|  CD|222|
|  CD|333|
|  CD|333|
|  CD|333|
|  AG|111|
|  AG|111|
|  AG|111|
|  AG|222|
+----+---+
Given an input n, for each ptyp I want to add columns that display the top n sIDs (ranked by the number of times they appear for that ptyp). For each group, I also want to show how many times each of those sIDs occurs, in a matching _val column (topSID1_val, topSID2_val, ...).
For example, if n = 2, I want the output to look like this:
+----+-------+-----------+-------+-----------+
|ptyp|topSID1|topSID1_val|topSID2|topSID2_val|
+----+-------+-----------+-------+-----------+
|  AG|    111|          3|    222|          1|
|  CO|    222|          3|    111|          2|
|  CD|    333|          3|    222|          2|
+----+-------+-----------+-------+-----------+
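As a point of reference, the expected table can be reproduced in plain Python with collections.Counter. This is a minimal sketch, not Spark code: the sample rows are copied from the input table, and top_n_per_group is a hypothetical helper name.

```python
from collections import Counter, defaultdict

# Sample data copied from the input table: (ptyp, sID) pairs.
rows = [
    ("CO", 111), ("CO", 222), ("CO", 222), ("CO", 222), ("CO", 111),
    ("CD", 111), ("CD", 222), ("CD", 222), ("CD", 333), ("CD", 333), ("CD", 333),
    ("AG", 111), ("AG", 111), ("AG", 111), ("AG", 222),
]


def top_n_per_group(pairs, n):
    """Hypothetical helper: map each ptyp to [sID1, count1, sID2, count2, ...]
    for its n most frequent sIDs."""
    groups = defaultdict(list)
    for ptyp, sid in pairs:
        groups[ptyp].append(sid)

    result = {}
    for ptyp, sids in groups.items():
        flat = []
        for sid, count in Counter(sids).most_common(n):
            flat.extend([sid, count])
        result[ptyp] = flat
    return result


for ptyp, values in top_n_per_group(rows, 2).items():
    print(ptyp, values)
# CO [222, 3, 111, 2]
# CD [333, 3, 222, 2]
# AG [111, 3, 222, 1]
```

Unlike the UDFs in the question, this sketch simply returns fewer entries when a group has fewer than n distinct sIDs.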
I am using UDFs to calculate this:
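from collections import Counter

import pyspark.sql.functions as F
from pyspark.sql.types import IntegerType

@F.udf  # default return type is StringType
def mode(x, top_rank):
    # x: the list built by F.collect_list('sID'); top_rank: 1-based rank
    c = Counter(x).most_common(top_rank)
    sz = len(c)
    # Clamp to the last entry if the group has fewer distinct sIDs than top_rank
    return c[min(top_rank-1, sz-1)][0]

@F.udf(returnType=IntegerType())  # the count is an integer
def modeval(x, top_rank):
    c = Counter(x).most_common(top_rank)
    sz = len(c)
    return c[min(top_rank-1, sz-1)][1]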
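The ranking logic shared by both UDFs does not depend on Spark, so it can be checked on its own. A minimal sketch, assuming a plain Python list stands in for the output of F.collect_list; nth_most_common is a hypothetical name. Note the clamping: if a group has fewer distinct sIDs than the requested rank, the last available sID is returned again rather than a null.

```python
from collections import Counter


def nth_most_common(values, top_rank):
    """Plain-Python version of the body of mode/modeval: return the
    (value, count) pair at 1-based position top_rank, clamped to the last pair."""
    c = Counter(values).most_common(top_rank)
    return c[min(top_rank - 1, len(c) - 1)]


cd = [111, 222, 222, 333, 333, 333]  # the sIDs of group CD
print(nth_most_common(cd, 1))  # (333, 3)
print(nth_most_common(cd, 2))  # (222, 2)
# Only three distinct sIDs exist, so rank 5 is clamped to rank 3.
print(nth_most_common(cd, 5))  # (111, 1)
```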
And I am storing the aggregate expressions required for each new column in a list newcols:
newcols = []
n = 3
for r in range(1, n + 1):
    # each append wraps a single aggregate Column in a one-element list
    newcols.append([mode(F.collect_list('sID'), F.lit(r)).alias('topSID' + str(r))])
    newcols.append([modeval(F.collect_list('sID'), F.lit(r)).alias('topSID' + str(r) + '_val')])
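Because every append wraps one expression in a list, newcols ends up as a list of 2 * n one-element lists. The shape can be seen without a Spark session by running the same loop with the alias strings standing in for the Column objects (a stand-in for illustration only):

```python
n = 3
newcols = []
for r in range(1, n + 1):
    # Strings stand in for the Column expressions; each is wrapped in its own list.
    newcols.append(['topSID' + str(r)])
    newcols.append(['topSID' + str(r) + '_val'])

print(len(newcols))  # 6
print(newcols)
# [['topSID1'], ['topSID1_val'], ['topSID2'], ['topSID2_val'], ['topSID3'], ['topSID3_val']]
```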
Now, if I know that n = 3, I can do it this way:
df.groupBy('ptyp').agg(*newcols[0], *newcols[1], *newcols[2],
                       *newcols[3], *newcols[4], *newcols[5])
Is there a way to generalize this for any value of n?
I tried
df.groupBy('ptyp').agg([*e for e in newcols])
and
df.groupBy('ptyp').agg((*e for e in newcols))
and many more variations, but all of them give errors.
For now I have resorted to aggregating one column at a time and joining the results, but that is very expensive.
Is there a way to do this in the manner I've tried above?
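A list comprehension is the correct way to go, but you can't unpack the sublists with * inside it. Python only allows * unpacking where the unpacked items have somewhere to go, such as the arguments of a function call or the elements of a list display, and not as the element of a comprehension, so [*e for e in newcols] is a SyntaxError. When you call: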
print(*newcols[0])
You get an output like:
Column<b'mode(collect_list(sID, 0, 0), 1) AS `topSID1`'>
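The attempts in the question fail before Spark is involved at all: the interpreter rejects * unpacking as the element of a comprehension or generator expression at compile time. This can be checked with the built-in compile(), which only parses the source, so the names a, b, f and x never need to exist. is_valid_expression is a hypothetical helper name.

```python
def is_valid_expression(src):
    """Return True if src compiles as a Python expression (nothing is evaluated)."""
    try:
        compile(src, "<example>", "eval")
        return True
    except SyntaxError:
        return False


# Star-unpacking works in a list display and in a call ...
print(is_valid_expression("[*a, *b]"))         # True
print(is_valid_expression("f(*a, *b)"))        # True
# ... but not as the element of a comprehension or generator expression.
print(is_valid_expression("[*e for e in x]"))  # False
print(is_valid_expression("(*e for e in x)"))  # False
```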
newcols is a list of lists, and you can flatten it with a nested list comprehension:
print([item for sublist in newcols for item in sublist][0])
Which returns the same output:
Column<b'mode(collect_list(sID, 0, 0), 1) AS `topSID1`'>
Therefore you can do the following to get the expected behavior:
df.groupBy('ptyp').agg(*[item for sublist in newcols for item in sublist])
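itertools.chain.from_iterable from the standard library flattens one level of nesting in the same way as the nested comprehension, and it works for any n. A sketch with strings standing in for the Column objects and a hypothetical agg function standing in for GroupedData.agg, which simply echoes the positional arguments it receives:

```python
from itertools import chain


def agg(*exprs):
    """Hypothetical stand-in for GroupedData.agg: returns the arguments it received."""
    return list(exprs)


n = 3
newcols = []
for r in range(1, n + 1):
    # Strings stand in for Column objects; the nesting mirrors the question's newcols.
    newcols.append(['topSID' + str(r)])
    newcols.append(['topSID' + str(r) + '_val'])

# Two equivalent ways to flatten one level of nesting.
flat_comprehension = [item for sublist in newcols for item in sublist]
flat_chain = list(chain.from_iterable(newcols))
assert flat_comprehension == flat_chain

# Unpacking the flat list passes each expression as its own positional argument.
print(agg(*flat_chain))
# ['topSID1', 'topSID1_val', 'topSID2', 'topSID2_val', 'topSID3', 'topSID3_val']
```

With real Spark columns the equivalent call should be df.groupBy('ptyp').agg(*chain.from_iterable(newcols)). Another option is to append the Column objects to newcols directly instead of wrapping each one in a list; newcols is then already flat, and df.groupBy('ptyp').agg(*newcols) needs no flattening step.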