pettinato

Reputation: 1542

What data type does VectorAssembler require for an input?

The core problem is this:

from pyspark.ml.feature import VectorAssembler
df = spark.createDataFrame([([1, 2, 3], 0, 3)], ["a", "b", "c"])
vecAssembler = VectorAssembler(outputCol="features", inputCols=["a", "b", "c"])
vecAssembler.transform(df).show()

which fails with the error: IllegalArgumentException: Data type array<bigint> of column a is not supported.

I know this is a bit of a toy problem, but I'm trying to integrate it into a longer pipeline with multiple steps.

If I can determine the proper input data type for the VectorAssembler, I should be able to string everything together properly. I think the input type is a Vector, but I can't figure out how to build one.

Upvotes: 0

Views: 1534

Answers (1)

mck

Reputation: 42352

According to the docs,

VectorAssembler accepts the following input column types: all numeric types, boolean type, and vector type.
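To illustrate, here is a minimal sketch (assuming the same SparkSession, spark, and illustrative column names) of those accepted types being assembled directly:

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.linalg import Vectors

# numeric, boolean, and vector columns are all accepted as-is
ok_df = spark.createDataFrame(
    [(1.0, True, Vectors.dense([1.0, 2.0, 3.0]))],
    ["num_col", "bool_col", "vec_col"])
assembler = VectorAssembler(inputCols=["num_col", "bool_col", "vec_col"],
                            outputCol="features")
assembler.transform(ok_df).show(truncate=False)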

So you need to convert your array column to a vector column first (method from this question).

from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql.functions import udf

# UDF that converts an array column into a dense vector column
list_to_vector_udf = udf(lambda l: Vectors.dense(l), VectorUDT())
# replace the array column "a" with its vector equivalent
df_with_vectors = df.withColumn('a', list_to_vector_udf('a'))

Then you can use the vector assembler:

vecAssembler = VectorAssembler(outputCol="features", inputCols=["a", "b", "c"])

vecAssembler.transform(df_with_vectors).show(truncate=False)
+-------------+---+---+---------------------+
|a            |b  |c  |features             |
+-------------+---+---+---------------------+
|[1.0,2.0,3.0]|0  |3  |[1.0,2.0,3.0,0.0,3.0]|
+-------------+---+---+---------------------+
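As a side note, if you are on Spark 3.1 or later, pyspark.ml.functions.array_to_vector can do the same conversion without a UDF; a minimal sketch under that assumption:

# Spark >= 3.1 only: built-in array-to-vector conversion
from pyspark.ml.functions import array_to_vector

df_with_vectors = df.withColumn("a", array_to_vector("a"))
vecAssembler.transform(df_with_vectors).show(truncate=False)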

Upvotes: 2
