LucyB

Reputation: 31

Groupby and UDF/UDAF in PySpark while maintaining DataFrame structure

I am new to PySpark and struggling with a simple dataframe manipulation. I have a dataframe similar to:

product    period     rating   product_Desc1   product_Desc2 ..... more columns 
a            1         60          foo              xx
a            2         70          foo              xx
a            3         59          foo              xx
b            1         50          bar              yy
b            2         55          bar              yy
c            1         90          foo bar          xy
c            2         100         foo bar          xy

I would like to groupBy product, add columns to calculate arithmetic, geometric and harmonic means of ratings while also maintaining the rest of the columns in the dataframe, which are all consistent across each product.
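For example, for product a the three new columns would hold 63.0 (arithmetic), roughly 62.81 (geometric) and roughly 62.63 (harmonic), repeated on each of a's three rows.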

I have tried to do so with a combination of built in functions and UDF. For example:

from pyspark.sql.functions import mean

a_means = df.groupBy("product").agg(mean("rating").alias("a_mean"))
g_means = df.groupBy("product").agg(udf_gmean("rating").alias("g_mean"))

where:

from functools import reduce
from operator import mul
from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType

def g_mean(x):
    gm = reduce(mul, x) ** (1 / len(x))
    return gm

udf_gmean = udf(g_mean, FloatType())

I would then join the a_means and g_means output back onto the original dataframe on product and drop duplicates. However, this method returns an error for g_means, stating that "rating" is neither involved in the groupBy nor a user-defined aggregate function. (As far as I can tell, udf creates a plain row-wise function, and the Spark 2.2 Python API has no way to register a true UDAF, so it can't be used inside agg.)

I have also tried using SciPy's gmean function, but the error message I get states that the ufunc 'log' is not suitable for the input types, despite the rating column being integer type as far as I can see.

There are similar questions on the site, but nothing I can find seems to fix this issue. I would really appreciate the help, as it's driving me mad!

Thanks in advance and I should be able to provide any further info quickly today if I haven't provided enough.

It's worth noting that, for efficiency reasons, I am unable to simply convert to Pandas and transform as I would with a Pandas dataframe...and I am on Spark 2.2 and unable to upgrade!

Upvotes: 2

Views: 3682

Answers (2)

LucyB

Reputation: 31

A slightly easier way than the above, using gapply from spark-sklearn:

from spark_sklearn.group_apply import gapply
from scipy.stats.mstats import gmean
from pyspark.sql.types import StructType, FloatType
import pandas as pd

def g_mean(_, vals):
    gm = gmean(vals["rating"])
    return pd.DataFrame(data=[gm])

geoSchema = StructType().add("geo_mean", FloatType())

gMeans = gapply(df.groupby("product"), g_mean, geoSchema)

This returns a dataframe which can then be sorted and joined onto the original using:

df_withGeo = df.join(gMeans, ["product"])

And repeat the process for any aggregation type function columns to be added to the original DataFrame...
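For example, a harmonic mean column could be added the same way. A sketch following the same pattern (scipy.stats.mstats.hmean is a real SciPy function, but the names h_mean, harmSchema and df_withGeoHarm are just illustrative):

from spark_sklearn.group_apply import gapply
from scipy.stats.mstats import hmean
from pyspark.sql.types import StructType, FloatType
import pandas as pd

# same pattern as g_mean above, with SciPy's harmonic mean
def h_mean(_, vals):
    return pd.DataFrame(data=[hmean(vals["rating"])])

harmSchema = StructType().add("harm_mean", FloatType())

hMeans = gapply(df.groupby("product"), h_mean, harmSchema)
df_withGeoHarm = df_withGeo.join(hMeans, ["product"])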

Upvotes: 1

sramalingam24

Reputation: 1337

How about something like this:

from pyspark.sql.functions import avg, pow

df1 = (df.select("product", "rating").rdd
         .map(lambda x: (x[0], (1.0, x[1] * 1.0)))              # (count, rating)
         .reduceByKey(lambda x, y: (x[0] + y[0], x[1] * y[1]))  # (n, product of ratings)
         .toDF(['product', 'g_mean']))
gdf = df1.select(df1['product'], pow(df1['g_mean._2'], 1.0 / df1['g_mean._1']).alias("rating_g_mean"))
display(gdf)

+-------+-----------------+
|product|    rating_g_mean|
+-------+-----------------+
|      a|62.81071936240795|
|      b|52.44044240850758|
|      c|94.86832980505137|
+-------+-----------------+
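As an aside, the geometric mean can also be computed without leaving the DataFrame API, using the identity geomean(x) = exp(mean(log(x))) with the built-in log, exp and avg functions, all available in Spark 2.2. A sketch (gdf2 is just an illustrative name):

from pyspark.sql.functions import exp, avg, log

# exp of the average log-rating is the geometric mean
gdf2 = df.groupBy("product").agg(exp(avg(log(df["rating"]))).alias("rating_g_mean"))

The log form also avoids overflow when many ratings get multiplied together.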


# harmonic mean = reciprocal of the mean of reciprocals
df1 = df.withColumn("h_mean", 1.0 / df["rating"])
hdf = df1.groupBy("product").agg(avg(df1["rating"]).alias("rating_mean"),
                                 (1.0 / avg(df1["h_mean"])).alias("rating_h_mean"))
sdf = hdf.join(gdf, ['product'])
display(sdf)

+-------+-----------+-----------------+-----------------+
|product|rating_mean|    rating_h_mean|    rating_g_mean|
+-------+-----------+-----------------+-----------------+
|      a|       63.0|62.62847514743051|62.81071936240795|
|      b|       52.5|52.38095238095239|52.44044240850758|
|      c|       95.0|94.73684210526315|94.86832980505137|
+-------+-----------+-----------------+-----------------+


fdf = df.join(sdf, ['product'])
display(fdf.sort("product"))


+-------+------+------+-------------+-------------+-----------+-----------------+-----------------+
|product|period|rating|product_Desc1|product_Desc2|rating_mean|    rating_h_mean|    rating_g_mean|
+-------+------+------+-------------+-------------+-----------+-----------------+-----------------+
|      a|     3|    59|          foo|           xx|       63.0|62.62847514743051|62.81071936240795|
|      a|     2|    70|          foo|           xx|       63.0|62.62847514743051|62.81071936240795|
|      a|     1|    60|          foo|           xx|       63.0|62.62847514743051|62.81071936240795|
|      b|     2|    55|          bar|           yy|       52.5|52.38095238095239|52.44044240850758|
|      b|     1|    50|          bar|           yy|       52.5|52.38095238095239|52.44044240850758|
|      c|     2|   100|      foo bar|           xy|       95.0|94.73684210526315|94.86832980505137|
|      c|     1|    90|      foo bar|           xy|       95.0|94.73684210526315|94.86832980505137|
+-------+------+------+-------------+-------------+-----------+-----------------+-----------------+
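As a side note, the joins can be avoided entirely with window functions, which Spark 2.2 supports. A sketch using the same identities as above, exp(avg(log x)) for the geometric mean and 1/avg(1/x) for the harmonic mean (fdf2 is an illustrative name):

from pyspark.sql import Window
from pyspark.sql.functions import avg, exp, log

w = Window.partitionBy("product")

# attach all three means directly to every row, no join needed
fdf2 = (df
        .withColumn("rating_mean", avg("rating").over(w))
        .withColumn("rating_g_mean", exp(avg(log("rating")).over(w)))
        .withColumn("rating_h_mean", 1.0 / avg(1.0 / df["rating"]).over(w)))
display(fdf2.sort("product"))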

Upvotes: 1
