Reputation: 8431
Let's say I have a DataFrame with a column for users and another column for words they've written:
Row(user='Bob', word='hello')
Row(user='Bob', word='world')
Row(user='Mary', word='Have')
Row(user='Mary', word='a')
Row(user='Mary', word='nice')
Row(user='Mary', word='day')
I would like to aggregate the word column into a vector:
Row(user='Bob', words=['hello','world'])
Row(user='Mary', words=['Have','a','nice','day'])
It seems I can't use any of Spark's grouping functions because they expect a subsequent aggregation step. My use case is that I want to feed this data into Word2Vec, not apply other Spark aggregations.
Upvotes: 29
Views: 34372
Reputation: 51
There is a native aggregate function for that, collect_set. Keep in mind that it removes duplicate values and does not guarantee any order within each set.
Then, you could use:
from pyspark.sql import functions as F
df.groupby("user").agg(F.collect_set("word"))
Upvotes: 2
Reputation: 3619
from pyspark.sql import functions as F
df.groupby("user").agg(F.collect_list("word"))
Upvotes: 22
Reputation: 434
As of the Spark 2.3 release we now have Pandas UDFs (aka vectorized UDFs). The function below accomplishes the OP's task. A benefit of using this function is that the order is guaranteed to be preserved, and order is essential in many cases, such as time series analysis.
import pandas as pd
import findspark
findspark.init()  # locates a local Spark install; skip if pyspark is already importable
import pyspark
from pyspark.sql import SparkSession, Row
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, ArrayType
spark = SparkSession.builder.appName('test_collect_array_grouped').getOrCreate()

def collect_array_grouped(df, groupbyCols, aggregateCol, outputCol):
    """
    Aggregate function: returns a new :class:`DataFrame` such that for a given column, aggregateCol,
    in a DataFrame, df, collect into an array the elements for each grouping defined by the groupbyCols list.
    The new DataFrame will have, for each row, the grouping columns and an array of the grouped
    values from aggregateCol in the outputCol.
    :param groupbyCols: list of columns to group by.
        Each element should be a column name (string) or an expression (:class:`Column`).
    :param aggregateCol: the column name of the column of values to aggregate into an array
        for each grouping.
    :param outputCol: the column name of the column to output the aggregated array to.
    """
    groupbyCols = [] if groupbyCols is None else groupbyCols
    df = df.select(groupbyCols + [aggregateCol])
    # Output schema: the grouping columns followed by an array of aggregateCol's type.
    schema = df.select(groupbyCols).schema
    aggSchema = df.select(aggregateCol).schema
    arrayField = StructField(name=outputCol, dataType=ArrayType(aggSchema[0].dataType, False))
    schema = schema.add(arrayField)

    @pandas_udf(schema, PandasUDFType.GROUPED_MAP)
    def _get_array(pd_df):
        # pd_df holds every row of one group: take the group key from the first row,
        # then append all of that group's values as a single array.
        vals = pd_df[groupbyCols].iloc[0].tolist()
        vals.append(pd_df[aggregateCol].values)
        return pd.DataFrame([vals])

    return df.groupby(groupbyCols).apply(_get_array)
rdd = spark.sparkContext.parallelize([Row(user='Bob', word='hello'),
                                      Row(user='Bob', word='world'),
                                      Row(user='Mary', word='Have'),
                                      Row(user='Mary', word='a'),
                                      Row(user='Mary', word='nice'),
                                      Row(user='Mary', word='day')])
df = spark.createDataFrame(rdd)
collect_array_grouped(df, ['user'], 'word', 'users_words').show()
+----+--------------------+
|user|         users_words|
+----+--------------------+
|Mary|[Have, a, nice, day]|
| Bob|      [hello, world]|
+----+--------------------+
Upvotes: 2
Reputation: 8431
Thanks to @titipat for giving the RDD solution. Shortly after posting, I realized there is actually a DataFrame solution using collect_set (or collect_list):
from pyspark.sql import Row
from pyspark.sql.functions import collect_set
rdd = spark.sparkContext.parallelize([Row(user='Bob', word='hello'),
                                      Row(user='Bob', word='world'),
                                      Row(user='Mary', word='Have'),
                                      Row(user='Mary', word='a'),
                                      Row(user='Mary', word='nice'),
                                      Row(user='Mary', word='day')])
df = spark.createDataFrame(rdd)
group_user = df.groupBy('user').agg(collect_set('word').alias('words'))
print(group_user.collect())
[Row(user='Mary', words=['Have', 'nice', 'day', 'a']), Row(user='Bob', words=['world', 'hello'])]
Upvotes: 47
Reputation: 5389
Here is a solution using an RDD.
from pyspark.sql import Row
rdd = spark.sparkContext.parallelize([Row(user='Bob', word='hello'),
                                      Row(user='Bob', word='world'),
                                      Row(user='Mary', word='Have'),
                                      Row(user='Mary', word='a'),
                                      Row(user='Mary', word='nice'),
                                      Row(user='Mary', word='day')])
group_user = rdd.groupBy(lambda x: x.user)
group_agg = group_user.map(lambda x: Row(**{'user': x[0], 'word': [t.word for t in x[1]]}))
Output from group_agg.collect():
[Row(user='Bob', word=['hello', 'world']),
Row(user='Mary', word=['Have', 'a', 'nice', 'day'])]
Upvotes: 6