Reputation: 45
I have a Spark DataFrame:
df = spark.createDataFrame([(10, "Hyundai"), (20, "alpha"), (70, "Audio"), (1000, "benz"), (50, "Suzuki"), (60, "Lambo"), (30, "Bmw")], ["Cars", "Brand"])
Now I want to find the outliers. For that I used the IQR method, got the upper and lower bounds below, and filtered for the outlier:
lower, upper = -55.0, 145.0
outliers = df.filter((df['Cars'] > upper) | (df['Cars'] < lower))
+----+-----+
|Cars|Brand|
+----+-----+
|1000| benz|
+----+-----+
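(A sketch of how such bounds can be computed with the standard 1.5 × IQR rule; approxQuantile with relativeError 0.0 computes exact quantiles, which is fine on data this small:)
# Exact Q1 and Q3 of the Cars column
q1, q3 = df.approxQuantile('Cars', [0.25, 0.75], 0.0)
iqr = q3 - q1
# 1.5 × IQR fences; should give -55.0 and 145.0 for this data
lower, upper = q1 - 1.5 * iqr, q3 + 1.5 * iqr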
Now I want to find the mean excluding the outliers. To do that I used functions and when, but I am getting an error:
"TypeError: 'Column' object is not callable"
from pyspark.sql import functions as fun
mean = df.select(fun.when((df['Cars'] > upper) | (df['Cars'] < lower), fun.mean(df['Cars'].alias('mean')).collect()[0]['mean']))
print(mean)
Is my code wrong, or is there a better way to do it?
Upvotes: 1
Views: 2591
Reputation: 42352
The error comes from calling .collect() on a Column: fun.mean(...) returns a Column, and collect is a DataFrame method, so Spark treats .collect as a field access and then fails when you try to call it. You don't actually need when here. You can just filter and then aggregate the mean:
import pyspark.sql.functions as F
mean = df.filter((df['Cars'] <= upper) & (df['Cars'] >= lower)).agg(F.mean('Cars').alias('mean'))
mean.show()
+----+
|mean|
+----+
|40.0|
+----+
If you do want to use when, you can use conditional aggregation: rows that fail the condition become null, and mean ignores nulls, so the outliers drop out of the average:
mean = df.agg(F.mean(F.when((df['Cars'] <= upper) & (df['Cars'] >= lower), df['Cars'])).alias('mean'))
mean.show()
+----+
|mean|
+----+
|40.0|
+----+
To get the value into a Python variable, use collect:
mean_collected = mean.collect()[0][0]
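(Equivalently, since first() returns the first Row, you can also index by the alias; a small variant using only the mean DataFrame above:)
mean_collected = mean.first()['mean']
print(mean_collected)  # 40.0 with the sample data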
Upvotes: 2