Reputation: 7997
I'm trying to create a new column on a dataframe based on the values of some columns. It returns null in all cases. Does anyone know what's going wrong with this simple example?
import pandas as pd
from pyspark.sql.functions import lit

df = pd.DataFrame([[0, 1, 0], [1, 0, 0], [1, 1, 1]], columns=['Foo', 'Bar', 'Baz'])
spark_df = spark.createDataFrame(df)

def get_profile():
    if 'Foo' == 1:
        return 'Foo'
    elif 'Bar' == 1:
        return 'Bar'
    elif 'Baz' == 1:
        return 'Baz'

spark_df = spark_df.withColumn('get_profile', lit(get_profile()))
spark_df.show()
+---+---+---+-----------+
|Foo|Bar|Baz|get_profile|
+---+---+---+-----------+
|  0|  1|  0|       null|
|  1|  0|  0|       null|
|  1|  1|  1|       null|
+---+---+---+-----------+
I would expect that the get_profile column would be filled out for all rows.
I also tried this:
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

spark_udf = udf(get_profile, StringType())
spark_df = spark_df.withColumn('get_profile', spark_udf())
print(spark_df.toPandas())
to the same effect.
Upvotes: 2
Views: 10440
Reputation: 43504
The udf has no knowledge of what the column names are. It compares the string literals 'Foo', 'Bar', and 'Baz' against the integer 1 in your if/elif block, and since a string never equals an integer, all of the conditions evaluate to False. Thus the function falls through every branch and returns None.
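You can see this in plain Python, with no Spark involved at all. The snippet below is just an illustration of the always-False comparisons, not part of the fix:

# Comparing a string literal to an integer is always False --
# no column values are involved in these checks.
print('Foo' == 1)  # False
print('Bar' == 1)  # False

def get_profile():
    if 'Foo' == 1:
        return 'Foo'

print(get_profile())  # None -- every branch is skipped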
You'd have to rewrite your udf to take in the columns you want to check:
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
def get_profile(foo, bar, baz):
    if foo == 1:
        return 'Foo'
    elif bar == 1:
        return 'Bar'
    elif baz == 1:
        return 'Baz'
spark_udf = udf(get_profile, StringType())
spark_df = spark_df.withColumn('get_profile', spark_udf('Foo', 'Bar', 'Baz'))
spark_df.show()
#+---+---+---+-----------+
#|Foo|Bar|Baz|get_profile|
#+---+---+---+-----------+
#| 0| 1| 0| Bar|
#| 1| 0| 0| Foo|
#| 1| 1| 1| Foo|
#+---+---+---+-----------+
If you have a lot of columns and want to pass them all (in order):
spark_df = spark_df.withColumn('get_profile', spark_udf(*spark_df.columns))
More generally, you can unpack any ordered list of columns:
cols_to_pass_to_udf = ['Foo', 'Bar', 'Baz']
spark_df = spark_df.withColumn('get_profile', spark_udf(*cols_to_pass_to_udf))
But this particular operation does not require a udf. I would do it this way:
from pyspark.sql.functions import coalesce, when, col, lit
spark_df.withColumn(
    "get_profile",
    coalesce(*[when(col(c) == 1, lit(c)) for c in spark_df.columns])
).show()
#+---+---+---+-----------+
#|Foo|Bar|Baz|get_profile|
#+---+---+---+-----------+
#| 0| 1| 0| Bar|
#| 1| 0| 0| Foo|
#| 1| 1| 1| Foo|
#+---+---+---+-----------+
This works because pyspark.sql.functions.when() returns null by default if the condition evaluates to False and no otherwise is specified. pyspark.sql.functions.coalesce then picks the first non-null value out of the columns built by the list comprehension.
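As a small sketch of that null-by-default behavior in isolation (reusing spark_df from the example above):

from pyspark.sql.functions import when, col, lit

# With no otherwise(), when() yields null wherever the condition is False
spark_df.select(
    'Foo',
    when(col('Foo') == 1, lit('Foo')).alias('only_foo')
).show()
#+---+--------+
#|Foo|only_foo|
#+---+--------+
#|  0|    null|
#|  1|     Foo|
#|  1|     Foo|
#+---+--------+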
Note this is equivalent to the udf ONLY if the order of the columns is the same as the sequence that's evaluated in the get_profile function. To be more explicit, you should do:
spark_df.withColumn(
    "get_profile",
    coalesce(*[when(col(c) == 1, lit(c)) for c in ['Foo', 'Bar', 'Baz']])
).show()
Upvotes: 6