Reputation: 11
I have one data frame (D1) as follows:
col1 | col2 | col3 | col4
22 | null | 23 | 56
12 | 54 | 22 | 36
48 | null | null | 45
null | 32 | 13 | 6
23 | null | 43 | 8
67 | 54 | 56 | null
null | 32 | 32 | 6
3 | 54 | 64 | 8
67 | 4 | 23 | null
The other data frame (D2):
col_name | value
col 1 | 15
col 2 | 26
col 3 | 38
col 4 | 41
I want to replace the null values in each column of D1 with the value from D2 corresponding to that column.
So the expected output would be:
col1 | col2 | col3 | col4
22 | 26 | 23 | 56
12 | 54 | 22 | 36
48 | 26 | 38 | 45
15 | 32 | 13 | 6
23 | 26 | 43 | 8
67 | 54 | 56 | 41
15 | 32 | 32 | 6
3 | 54 | 64 | 8
67 | 4 | 23 | 41
I would like to know how to achieve this in PySpark data frames. Cheers!
Upvotes: 1
Views: 1080
Reputation: 13998
IIUC, you can create a column_name:value mapping and then just do fillna() on each column:
mapping = { row.col_name.replace(' ',''):row.value for row in D2.collect() }
#{u'col1': 15.0, u'col2': 26.0, u'col3': 38.0, u'col4': 41.0}
# fillna on col1 for testing
D1.fillna(mapping['col1'], subset=['col1']).show()
+----+----+----+----+
|col1|col2|col3|col4|
+----+----+----+----+
|22.0| NaN|23.0|56.0|
|12.0|54.0|22.0|36.0|
|48.0| NaN| NaN|45.0|
|15.0|32.0|13.0| 6.0|
|23.0| NaN|43.0| 8.0|
|67.0|54.0|56.0| NaN|
|15.0|32.0|32.0| 6.0|
| 3.0|54.0|64.0| 8.0|
|67.0| 4.0|23.0| NaN|
+----+----+----+----+
# use a reduce function to handle all columns
from functools import reduce  # needed on Python 3, where reduce is no longer a builtin

df_new = reduce(lambda d, c: d.fillna(mapping[c], subset=[c]), D1.columns, D1)
Or use a list comprehension:
from pyspark.sql.functions import isnan, when, col
df_new = D1.select([ when(isnan(c), mapping[c]).otherwise(col(c)).alias(c) for c in D1.columns ])
Note: for StringType columns, replace the above isnan() with isnull().
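A minimal sketch of that isnull() variant, reusing the mapping dict built above (this snippet is an illustration, not part of the original answer):
from pyspark.sql.functions import when, col

# replace nulls (rather than NaNs) in each column with its mapped default
df_new = D1.select([when(col(c).isNull(), mapping[c]).otherwise(col(c)).alias(c) for c in D1.columns])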
Upvotes: 1
Reputation: 7607
This is one approach. Since it uses crossJoin, it may not be the most efficient, but because D2 is relatively small it should be OK. Another way could be a udf.
# Creating the DataFrame
values = [(22,None,23,56),(12,54,22,36),(48,None,None,45),
(None,32,13,6),(23,None,43,8),(67,54,56,None),
(None,32,32,6),(3,54,64,8),(67,4,23,None)]
D1 = sqlContext.createDataFrame(values,['col1','col2','col3','col4'])
D1.show()
+----+----+----+----+
|col1|col2|col3|col4|
+----+----+----+----+
| 22|null| 23| 56|
| 12| 54| 22| 36|
| 48|null|null| 45|
|null| 32| 13| 6|
| 23|null| 43| 8|
| 67| 54| 56|null|
|null| 32| 32| 6|
| 3| 54| 64| 8|
| 67| 4| 23|null|
+----+----+----+----+
We need the list of columns to iterate over, so the code below gives that.
list_columns = D1.columns
print(list_columns)
['col1', 'col2', 'col3', 'col4']
Creating the second DataFrame.
D2 = sqlContext.createDataFrame([('col1',15),('col2',26),('col3',38),('col4',41)],['col_name','value'])
D2.show()
+--------+-----+
|col_name|value|
+--------+-----+
| col1| 15|
| col2| 26|
| col3| 38|
| col4| 41|
+--------+-----+
Let's pivot the DataFrame D2 so that we can append its values as columns alongside D1.
# Pivoting and then renaming the columns
from pyspark.sql.functions import col, when

D2_new = D2.groupBy().pivot('col_name').sum('value')
D2_new = D2_new.select(*[col(c).alias(c + '_x') for c in D2_new.columns])
D2_new.show()
+------+------+------+------+
|col1_x|col2_x|col3_x|col4_x|
+------+------+------+------+
| 15| 26| 38| 41|
+------+------+------+------+
Finally, using crossJoin, we append them:
# Appending the columns
D1 = D1.crossJoin(D2_new)
D1.show()
+----+----+----+----+------+------+------+------+
|col1|col2|col3|col4|col1_x|col2_x|col3_x|col4_x|
+----+----+----+----+------+------+------+------+
| 22|null| 23| 56| 15| 26| 38| 41|
| 12| 54| 22| 36| 15| 26| 38| 41|
| 48|null|null| 45| 15| 26| 38| 41|
|null| 32| 13| 6| 15| 26| 38| 41|
| 23|null| 43| 8| 15| 26| 38| 41|
| 67| 54| 56|null| 15| 26| 38| 41|
|null| 32| 32| 6| 15| 26| 38| 41|
| 3| 54| 64| 8| 15| 26| 38| 41|
| 67| 4| 23|null| 15| 26| 38| 41|
+----+----+----+----+------+------+------+------+
Once this combined DataFrame is obtained, we can use a simple when-otherwise construct to do the replacement, running a loop over the list of columns.
# Finally doing the replacement.
for c in list_columns:
    D1 = D1.withColumn(c, when(col(c).isNull(), col(c + '_x')).otherwise(col(c))).drop(col(c + '_x'))
D1.show()
+----+----+----+----+
|col1|col2|col3|col4|
+----+----+----+----+
| 22| 26| 23| 56|
| 12| 54| 22| 36|
| 48| 26| 38| 45|
| 15| 32| 13| 6|
| 23| 26| 43| 8|
| 67| 54| 56| 41|
| 15| 32| 32| 6|
| 3| 54| 64| 8|
| 67| 4| 23| 41|
+----+----+----+----+
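For reference, the udf route mentioned at the top could look roughly like this. It is only a sketch, assuming it is applied to the original D1 (before the replacement loop) and that the per-column defaults from D2 are available as a plain Python dict; the dict literal and names below are illustrative.
from pyspark.sql.functions import udf, col
from pyspark.sql.types import LongType

# Illustrative defaults; in practice these could be collected from D2.
defaults = {'col1': 15, 'col2': 26, 'col3': 38, 'col4': 41}

def fill_with(default):
    # Returns a udf that substitutes `default` whenever the incoming value is null (None).
    return udf(lambda v: default if v is None else v, LongType())

D1_filled = D1.select([fill_with(defaults[c])(col(c)).alias(c) for c in D1.columns])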
Upvotes: 1