paolof89

Reputation: 1369

Pyspark: Extract Multiclass Classification results as different columns

I'm using RandomForestClassifier for a multiclass classification problem. In the prediction output DataFrame, the 'probability' column is a vector:

df.select('probability').printSchema()
root
 |-- probability: vector (nullable = true)

Each row holds a vector of 4 elements:

df.select('probability').show(3)
+--------------------+
|         probability|
+--------------------+
|[0.02753394443688...|
|[7.95347766409877...|
|[0.02264704615632...|
+--------------------+

I would like to add 4 columns to my df, each holding one Double value.

A similar question suggests this solution:

from pyspark.sql.functions import udf
from pyspark.sql.types import FloatType

firstelement = udf(lambda v: float(v[0]), FloatType())
df.select(firstelement('probability'))

The solution works, but when I try to assign the value to a new column with

df.withColumn('prob_SELF', df.select(firstelement('probability'))['<lambda>(probability)'])

I get the following error:

AnalysisException: 'resolved attribute(s) <lambda>(probability)#26116 missing from prediction#25521

Upvotes: 1

Views: 291

Answers (1)

mellowonpsx

Reputation: 56

Short answer

To use a UDF with withColumn, pass the UDF call directly as the column expression:

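# udf and FloatType are imported as in the question
firstelement = udf(lambda v: float(v[0]), FloatType())
# withColumn returns a new DataFrame, so assign the result
df = df.withColumn('prob_SELF', firstelement('probability'))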

Long answer

The problem is that df.select(firstelement('probability'))['<lambda>(probability)'] refers to a column of a new, separate DataFrame created by select, not to a column of df.

withColumn cannot reference columns that belong to a different DataFrame. To combine separate DataFrames, you must use join.

Here is a simple demonstration:

df_a = spark.sql("""
SELECT CAST(1.0 AS FLOAT) AS A
""")

df_b = spark.sql("""
SELECT CAST(1.0 AS FLOAT) AS B
""")

df_a.withColumn('B', df_b['B'])

This raises:

AnalysisException: u'Resolved attribute(s) B#2465 missing from A#2463 in operator !Project [A#2463, B#2465 AS B#2468].;;\n!Project [A#2463, B#2465 AS B#2468]\n+- Project [cast(1.0 as float) AS A#2463]\n   +- OneRowRelation\n'

Upvotes: 1
