Pysdm

Reputation: 45

How to get the mean in pyspark?

I have a spark dataframe:

df = spark.createDataFrame(
    [(10, 'Hyundai'), (20, 'alpha'), (70, 'Audio'), (1000, 'benz'),
     (50, 'Suzuki'), (60, 'Lambo'), (30, 'Bmw')],
    ['Cars', 'Brand'])

Now I want to find the outliers. For that I used the IQR method, got the upper and lower bounds below, and filtered out the outlier:

lower, upper = -55.0, 145.0
outliers = df.filter((df['Cars'] > upper) | (df['Cars'] < lower))

Cars    Brand
1000    benz
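
For reference, bounds like these can be computed with approxQuantile and the usual 1.5 * IQR fences; a minimal sketch (relativeError=0 requests exact quantiles):

# q1=20 and q3=70 for this data, giving lower=-55.0 and upper=145.0
q1, q3 = df.approxQuantile('Cars', [0.25, 0.75], 0.0)
iqr = q3 - q1
lower, upper = q1 - 1.5 * iqr, q3 + 1.5 * iqr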

Now I want to find the mean excluding the outlier. For that I used functions.when, but I am getting an error:

"TypeError: 'Column' object is not callable"

from pyspark.sql import functions as fun
mean = df.select(fun.when((df['Cars'] > upper) | (df['Cars'] < lower), fun.mean(df['Cars'].alias('mean')).collect()[0]['mean']))
print(mean)

Is my code wrong, or is there a better way to do this?

Upvotes: 1

Views: 2591

Answers (1)

mck

Reputation: 42352

I don't think you need when here. The TypeError comes from calling .collect() on a Column (the result of fun.mean) rather than on a DataFrame. You can just filter out the outliers and aggregate the mean:

import pyspark.sql.functions as F

mean = df.filter((df['Cars'] <= upper) & (df['Cars'] >= lower)).agg(F.mean('Cars').alias('mean'))

mean.show()
+----+
|mean|
+----+
|40.0|
+----+

If you want to use when, you can use conditional aggregation: without an otherwise clause, rows outside the bounds become null, and mean ignores nulls:

mean = df.agg(F.mean(F.when((df['Cars'] <= upper) & (df['Cars'] >= lower), df['Cars'])).alias('mean'))

mean.show()
+----+
|mean|
+----+
|40.0|
+----+

To get the result into a Python variable, use collect:

mean_collected = mean.collect()[0][0]
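
Since the column is aliased, you can also index the row by name:

mean_collected = mean.collect()[0]['mean']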

Upvotes: 2
