Alex

Reputation: 603

PySpark: How to write a Spark dataframe having a column with type SparseVector into CSV file?

I have a Spark dataframe which has one column with type spark.mllib.linalg.SparseVector:

1) how can I write it into a csv file?

2) how can I print all the vectors?
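For reference, a minimal sketch of a dataframe in this shape; the column name "features", the toy values, and the SparkSession setup are illustrative assumptions, not details from the question:

from pyspark.mllib.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Toy DataFrame with a single SparseVector column; the schema is inferred
# from the vector objects themselves (hypothetical column name "features")
df = spark.createDataFrame(
    [(Vectors.sparse(4, [0, 2], [1.0, 5.0]),),
     (Vectors.sparse(4, [1], [3.0]),)],
    ["features"])

df.printSchema()  # the vector column shows up as a UDT, not a primitive CSV-friendly type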

Upvotes: 5

Views: 7287

Answers (2)

cindyxiaoxiaoli

Reputation: 828

To write the dataframe to a CSV file you can use the standard df.write.csv(output_path).

However, if you just do that you are likely to get a java.lang.UnsupportedOperationException: CSV data source does not support struct<type:tinyint,size:int,indices:array<int>,values:array<double>> data type error for the column with the SparseVector type.

There are two ways to render the SparseVector as a string and avoid that error: the dense format or the sparse format.

If you want to print in the dense format, you can define a UDF like this:

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
from pyspark.sql.functions import col

# Join every element of the vector (zeros included) into a comma-separated string
dense_format_udf = udf(lambda x: ','.join([str(elem) for elem in x]), StringType())

df = df.withColumn('column_name', dense_format_udf(col('column_name')))

df.write.option("delimiter", "\t").csv(output_path)

The column outputs to something like this in the dense format: 1.0,0.0,5.0,0.0

If you want to print in the sparse format, you can use the out-of-the-box __str__ method of the SparseVector class, or be creative and define your own output format. Here I am going to use the built-in one.

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
from pyspark.sql.functions import col

# str() on a SparseVector gives its sparse representation, e.g. (4,[0,2],[1.0,5.0])
sparse_format_udf = udf(lambda x: str(x), StringType())

df = df.withColumn('column_name', sparse_format_udf(col('column_name')))

df.write.option("delimiter", "\t").csv(output_path)

The column outputs to something like this in the sparse format: (4,[0,2],[1.0,5.0])

Note that I tried df = df.withColumn("column_name", col("column_name").cast("string")) before, but the column then just prints to something like [0,5,org.apache.spark.sql.catalyst.expressions.UnsafeArrayData@6988050,org.apache.spark.sql.catalyst.expressions.UnsafeArrayData@ec4ae6ab], which is not desirable.

Upvotes: 4

Kristian

Reputation: 21830

  1. To write the CSV, use the spark-csv package: https://github.com/databricks/spark-csv
  2. df2 = df1.map(lambda row: row.yourVectorCol)

    OR df1.map(lambda row: row[1])

    depending on whether you have a named column or just refer to the column by its position in the row.

    Then, to print the vectors, you can call df2.collect(). A rough end-to-end sketch of this follows below.
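A minimal sketch of the approach in item 2, assuming a vector column named "features"; note that on Spark 2.x and later you have to go through df1.rdd to call map, while on the 1.x releases this answer targets df1.map worked directly:

from pyspark.mllib.linalg import Vectors
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Hypothetical input: an id column plus one SparseVector column named "features"
df1 = spark.createDataFrame(
    [(0, Vectors.sparse(4, [0, 2], [1.0, 5.0])),
     (1, Vectors.sparse(4, [1], [3.0]))],
    ["id", "features"])

# Pull out just the vector column (by name here, or row[1] by position)
df2 = df1.rdd.map(lambda row: row.features)

# Bring the vectors back to the driver and print them
for v in df2.collect():
    print(v)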

Without more information, this may or may not be enough to help. Please elaborate a bit.

Upvotes: 2
