Reputation: 863
I have a PySpark DataFrame with 4 columns.
+-----+-----+-----+-----+
|col1 |col2 |col3 |col4 |
+-----+-----+-----+-----+
|10   | 5.0 | 5.0 | 5.0 |
|20   | 5.0 | 5.0 | 5.0 |
|null | 5.0 | 5.0 | 5.0 |
|30   | 5.0 | 5.0 | 6.0 |
|40   | 5.0 | 5.0 | 7.0 |
|null | 5.0 | 5.0 | 8.0 |
|50   | 5.0 | 6.0 | 9.0 |
|60   | 5.0 | 7.0 |10.0 |
|null | 5.0 | 8.0 |11.0 |
|70   | 6.0 | 9.0 |12.0 |
|80   | 7.0 |10.0 |13.0 |
|null | 8.0 |11.0 |14.0 |
+-----+-----+-----+-----+
Some values in col1 are missing, and I want to fill each missing value using the following fallback sequence:
1. Use the average of col1 over the records that have the same col2, col3, and col4 values.
2. If no such record exists, use the average of col1 over the records that have the same col2 and col3 values.
3. If there is still no such record, use the average of col1 over the records that have the same col2 value.
4. If none of the above can be found, use the average of all other non-missing values in col1.
For example, in the dataframe above, only the first two rows have the same col2, col3, and col4 values as row 3, so the null in col1 for row 3 should be replaced by the average of the col1 values in rows 1 and 2: (10 + 20) / 2 = 15. For the null in col1 in row 6, it will be the average of the col1 values in rows 4 and 5, because only those rows have the same col2 and col3 values as row 6, and no row also matches on col4. And so on, giving:
+-----+-----+-----+-----+
|col1 |col2 |col3 |col4 |
+-----+-----+-----+-----+
|10   | 5.0 | 5.0 | 5.0 |
|20   | 5.0 | 5.0 | 5.0 |
|15   | 5.0 | 5.0 | 5.0 |
|30   | 5.0 | 5.0 | 6.0 |
|40   | 5.0 | 5.0 | 7.0 |
|25   | 5.0 | 5.0 | 8.0 |
|50   | 5.0 | 6.0 | 9.0 |
|60   | 5.0 | 7.0 |10.0 |
|35   | 5.0 | 8.0 |11.0 |
|70   | 6.0 | 9.0 |12.0 |
|80   | 7.0 |10.0 |13.0 |
|45   | 8.0 |11.0 |14.0 |
+-----+-----+-----+-----+
What's the best way to do this?
Upvotes: 0
Views: 316
Reputation: 15258
I do not get exactly the same row order as you do (the values themselves match), but based on what you said, the code would be something like this:
from pyspark.sql import functions as F

# Average of col1 for each (col2, col3, col4) combination
df_2_3_4 = df.groupBy("col2", "col3", "col4").agg(
    F.avg("col1").alias("avg_col1_by_2_3_4")
)

# Fallback averages over progressively coarser groupings
df_2_3 = df.groupBy("col2", "col3").agg(F.avg("col1").alias("avg_col1_by_2_3"))
df_2 = df.groupBy("col2").agg(F.avg("col1").alias("avg_col1_by_2"))

# Global average of all non-null col1 values
avg_value = df.groupBy().agg(F.avg("col1").alias("avg_col1")).first().avg_col1

# Attach each level of average to every row
df_out = (
    df.join(df_2_3_4, how="left", on=["col2", "col3", "col4"])
    .join(df_2_3, how="left", on=["col2", "col3"])
    .join(df_2, how="left", on=["col2"])
)

# coalesce picks the first non-null value, i.e. the most specific average available
df_out.select(
    F.coalesce(
        F.col("col1"),
        F.col("avg_col1_by_2_3_4"),
        F.col("avg_col1_by_2_3"),
        F.col("avg_col1_by_2"),
        F.lit(avg_value),
    ).alias("col1"),
    "col2",
    "col3",
    "col4",
).show()
+----+----+----+----+
|col1|col2|col3|col4|
+----+----+----+----+
|10.0| 5.0| 5.0| 5.0|
|15.0| 5.0| 5.0| 5.0|
|20.0| 5.0| 5.0| 5.0|
|30.0| 5.0| 5.0| 6.0|
|40.0| 5.0| 5.0| 7.0|
|25.0| 5.0| 5.0| 8.0|
|50.0| 5.0| 6.0| 9.0|
|60.0| 5.0| 7.0|10.0|
|35.0| 5.0| 8.0|11.0|
|70.0| 6.0| 9.0|12.0|
|80.0| 7.0|10.0|13.0|
|45.0| 8.0|11.0|14.0|
+----+----+----+----+
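For what it's worth, the same fallback logic can also be expressed with window functions instead of joins. This is just a sketch of that alternative, not the answer's original code:

from pyspark.sql import Window
from pyspark.sql import functions as F

# Each window computes avg("col1") over a progressively coarser partition.
# avg ignores nulls and returns null when the whole partition is null in col1,
# so coalesce falls through to the next, coarser average.
w_234 = Window.partitionBy("col2", "col3", "col4")
w_23 = Window.partitionBy("col2", "col3")
w_2 = Window.partitionBy("col2")
w_all = Window.partitionBy()  # empty partition: the whole DataFrame

df_out = df.withColumn(
    "col1",
    F.coalesce(
        F.col("col1"),
        F.avg("col1").over(w_234),
        F.avg("col1").over(w_23),
        F.avg("col1").over(w_2),
        F.avg("col1").over(w_all),
    ),
)

This avoids the three joins, but note that the empty partitionBy() moves all rows into a single partition to compute the global average, so for large data the scalar avg_value used above is the safer choice for the last fallback.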
Upvotes: 1