Reputation: 15
I am trying to figure out how to translate my Pandas-based function to PySpark.
I have a Pandas DataFrame like this:
+---+----+
|num| val|
+---+----+
| 1| 0.0|
| 2| 0.0|
| 3|48.6|
| 4|49.0|
| 5|48.7|
| 6|49.1|
| 7|74.5|
| 8|48.7|
| 9| 0.0|
| 10|49.0|
| 11| 0.0|
| 12| 0.0|
+---+----+
The code in the snippet below is fairly simple: it goes forward until it finds a non-zero value, and if there is none, it goes backwards for the same purpose.
def next_non_zero(data, i, column):
    # look forward for the next non-zero value
    for j in range(i + 1, len(data[column])):
        res = data[column].iloc[j]
        if res != 0:
            return res
    # nothing non-zero ahead, so look backwards instead
    for j in range(i - 1, 0, -1):
        res = data[column].iloc[j]
        if res != 0:
            return res

def fix_zero(data, column):
    for i, row in data.iterrows():
        if row[column] == 0:
            data.at[i, column] = next_non_zero(data, i, column)
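For reference, this is how the function is applied to the sample data above (pdf is just an illustrative name):

import pandas as pd

# recreate the sample frame from the table above and fill the zeros in place
pdf = pd.DataFrame({'num': list(range(1, 13)),
                    'val': [0.0, 0.0, 48.6, 49.0, 48.7, 49.1, 74.5, 48.7, 0.0, 49.0, 0.0, 0.0]})
fix_zero(pdf, 'val')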
So as a result I expect to see
+---+----+
|num| val|
+---+----+
| 1|48.6|
| 2|48.6|
| 3|48.6|
| 4|49.0|
| 5|48.7|
| 6|49.1|
| 7|74.5|
| 8|48.7|
| 9|49.0|
| 10|49.0|
| 11|49.0|
| 12|49.0|
+---+----+
So I do understand that in PySpark I have to create a new column with the desired result and replace an existing column using withColumn() for example. However, I do not understand how to properly iterate through a DataFrame.
I am trying to use functions over Window:
my_window = Window.partitionBy().orderBy('num')
df = df.withColumn('new_val',
                   F.when(df.val == 0, F.lead(df.val).over(my_window))
                    .otherwise(F.lag(df.val).over(my_window)))
Obviously, it does not provide me with the desired result, as it only iterates once. So I tried to write a recursive udf like this:
def fix_zero(param):
    return (F.when(F.lead(param).over(my_window) != 0, F.lead(param).over(my_window))
             .otherwise(fix_zero(F.lead(param).over(my_window))))

spark_udf = udf(fix_zero, DoubleType())
df = df.withColumn('new_val', F.when(df.val != 0, df.val).otherwise(fix_zero('val')))
I got
RecursionError: maximum recursion depth exceeded in comparison
I suspect that this is because I pass into the recursion not a row but the result of lead(). Anyway, I am totally stuck on this hurdle at the moment and would deeply appreciate any advice.
Upvotes: 1
Views: 161
Reputation: 4284
There is a way with Window to go through all preceding (or all following) rows until you reach a non-null value.
So my first step was to replace all 0 values with null.
Recreating your dataframe:
values = [
    (1, 0.0),
    (2, 0.0),
    (3, 48.6),
    (4, 49.0),
    (5, 48.7),
    (6, 49.1),
    (7, 74.5),
    (8, 48.7),
    (9, 0.0),
    (10, 49.0),
    (11, 0.0),
    (12, 0.0)
]
df = spark.createDataFrame(values, ['num', 'val'])
Replacing 0s with null:
from pyspark.sql.functions import when, lit, col
df = df.withColumn('val_null', when(col('val') != 0.0, col('val')))
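Note that when() without an otherwise() clause already gives null for the rows where the condition is false, which is exactly what we rely on here. Written out explicitly (just a sketch of the equivalent form), it would be:

# equivalent, with the null made explicit
df = df.withColumn('val_null', when(col('val') != 0.0, col('val')).otherwise(lit(None)))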
Then define the windows which, combined with first and last and ignorenulls=True, will allow us to get the last non-null value before each row and the first non-null value after each row.
from pyspark.sql import Window
from pyspark.sql.functions import last, first, coalesce

# last non-null value in all rows up to and including the current row (forward fill)
windowForward = Window.rowsBetween(Window.unboundedPreceding, Window.currentRow)
ffilled_column = last(df['val_null'], ignorenulls=True).over(windowForward)

# first non-null value in all rows from the current row onwards (backward fill)
windowBackward = Window.rowsBetween(Window.currentRow, Window.unboundedFollowing)
bfilled_column = first(df['val_null'], ignorenulls=True).over(windowBackward)

# creating new columns in df
df = df.withColumn('ffill', ffilled_column).withColumn('bfill', bfilled_column)

# prefer bfill (the first non-null value after the row), mirroring the question's
# forward-then-backward search, and fall back to ffill when there is none
df = df.withColumn('val_full', coalesce('bfill', 'ffill'))
Using this technique, we arrive at your expected output in the column 'val_full':
+---+----+--------+-----+-----+--------+
|num| val|val_null|ffill|bfill|val_full|
+---+----+--------+-----+-----+--------+
| 1| 0.0| null| null| 48.6| 48.6|
| 2| 0.0| null| null| 48.6| 48.6|
| 3|48.6| 48.6| 48.6| 48.6| 48.6|
| 4|49.0| 49.0| 49.0| 49.0| 49.0|
| 5|48.7| 48.7| 48.7| 48.7| 48.7|
| 6|49.1| 49.1| 49.1| 49.1| 49.1|
| 7|74.5| 74.5| 74.5| 74.5| 74.5|
| 8|48.7| 48.7| 48.7| 48.7| 48.7|
| 9| 0.0| null| 48.7| 49.0| 49.0|
| 10|49.0| 49.0| 49.0| 49.0| 49.0|
| 11| 0.0| null| 49.0| null| 49.0|
| 12| 0.0| null| 49.0| null| 49.0|
+---+----+--------+-----+-----+--------+
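If you want to get back to the two-column shape of your expected output, a possible final step (just a sketch, reusing the column names from above) is:

# keep only the original columns, with the filled values as 'val'
df_result = df.select('num', col('val_full').alias('val'))
df_result.show()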
Upvotes: 2