Odisseo

Reputation: 777

Create PySpark Dataframe from Features Vector with Label

I have a dataframe, which I created with a Pipeline object, that looks like this:

df.show()

+--------------------+-----+
|            features|label|
+--------------------+-----+
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
|[-0.0775219322931...|    0|
+--------------------+-----+

I have successfully extracted the feature vectors like this:

df_table = df.rdd.map(lambda x: [float(y) for y in x['features']]).toDF(cols)

The problem with the above is that it does not retain the label column. As a workaround, I used a join to bring the label column back, but I find it too convoluted.
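Something along these lines (a sketch, not the exact code; zipWithIndex gives both sides a matching row id to join on):

# add a row index to the extracted features
feats = (df.rdd.map(lambda x: [float(y) for y in x['features']])
           .zipWithIndex()
           .map(lambda t: t[0] + [t[1]])
           .toDF(cols + ['id']))
# add the same index to the labels, then join the two back together
labels = (df.select('label').rdd
            .zipWithIndex()
            .map(lambda t: (t[0]['label'], t[1]))
            .toDF(['label', 'id']))
df_table = feats.join(labels, on='id').drop('id')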

How would I use a one-liner like the feature extraction above to build a Spark DataFrame from the features vector and keep the label column attached at the same time?

Upvotes: 0

Views: 523

Answers (1)

linog

Reputation: 6226

You have good options here, especially if you have a version of Spark >= 3.0.0.
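For reference, a minimal sketch of the Spark >= 3.0.0 route, using vector_to_array from pyspark.ml.functions to turn the ML vector into a plain array column (the f0, f1, … column names are just illustrative):

from pyspark.ml.functions import vector_to_array  # available since Spark 3.0.0
import pyspark.sql.functions as psf

# convert the ML vector to an array column, then split it into one column per feature
n = len(df.first()['features'])
df_wide = (df
           .withColumn('arr', vector_to_array(psf.col('features')))
           .select('label', *[psf.col('arr')[i].alias(f'f{i}') for i in range(n)]))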

Assuming you don't have such a recent version, your problem comes from the fact that you lose your label in the map. You can do:

df_table = df.rdd.map(lambda row: tuple([row['label']] + [float(y) for y in row['features']])).toDF()
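Note that toDF() with no arguments gives default column names (_1, _2, …). Assuming the cols list from your question is still in scope, you can pass explicit names instead:

df_table = df.rdd.map(lambda row: tuple([row['label']] + [float(y) for y in row['features']])).toDF(['label'] + cols)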

You end up with a wide-format dataframe. If you want a long-format result, you have more options.

If you want long-format data

First, with the RDD API (flatMapValues expects key-value pairs, so map each row to a (label, features) pair first):

df.rdd.map(lambda row: (row['label'], row['features'])).flatMapValues(lambda v: [float(x) for x in v]).toDF(['label', 'feature'])

Or, even better, directly using the DataFrame API (untested; note that explode works on array columns, not ML vectors, so convert with vector_to_array, available since Spark 3.0.0):

import pyspark.sql.functions as psf
from pyspark.ml.functions import vector_to_array

df.select('label', psf.explode(vector_to_array(psf.col('features'))).alias('feature'))
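If you also want the position of each feature within the vector, posexplode returns an index column alongside the value; a minimal sketch under the same Spark >= 3.0.0 assumption:

import pyspark.sql.functions as psf
from pyspark.ml.functions import vector_to_array

# posexplode yields a (position, value) pair for each array element
df.select(
    'label',
    psf.posexplode(vector_to_array(psf.col('features'))).alias('feature_index', 'feature')
)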

Upvotes: 2
