user3494047

Reputation: 1693

What is the input format of org.apache.spark.ml.classification.LogisticRegression fit()?

In this example of training a LogisticRegression model, they use an RDD[LabeledPoint] as input to the fit() method, but they write: "// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes // into SchemaRDDs, where it uses the case class metadata to infer the schema."

Where is this conversion happening? When I try this code:

val sqlContext = new SQLContext(sc)
import sqlContext._
val model = lr.fit(training);

where training is of type RDD[LabeledPoint], it gives a compilation error stating that fit() expects a DataFrame. When I convert the RDD to a DataFrame, I get this exception:

An exception occured while executing the Java class. null: InvocationTargetException: requirement failed: Column features must be of type org.apache.spark.mllib.linalg.VectorUDT@f71b0bce but was actually StructType(StructField(label,DoubleType,false), StructField(features,org.apache.spark.mllib.linalg.VectorUDT@f71b0bce,true))

But this is confusing to me. Why would it expect a Vector? It also needs labels. So I am wondering: what is the correct format?

The reason I am using the ML LogisticRegression and not MLlib's LogisticRegressionWithLBFGS is that I want an elastic-net implementation.

Upvotes: 1

Views: 857

Answers (1)

Umberto Griffo

Reputation: 931

The exception says that the DataFrame is expected to have the following structure:

StructType(StructField(label,DoubleType,false), 
StructField(features,org.apache.spark.mllib.linalg.VectorUDT@f71b0bce,true))

So prepare the training data from a list of (label, features) tuples, like this:

import org.apache.spark.mllib.linalg.Vectors

// Each row is a (label, features) tuple; toDF names the two columns accordingly.
val training = sqlContext.createDataFrame(Seq(
  (1.0, Vectors.dense(0.0, 1.1, 0.1)),
  (0.0, Vectors.dense(2.0, 1.0, -1.0)),
  (0.0, Vectors.dense(2.0, 1.3, 1.0)),
  (1.0, Vectors.dense(0.0, 1.2, -0.5))
)).toDF("label", "features")
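As a minimal follow-up sketch (assuming Spark 1.4+ and the training DataFrame prepared above), you can then fit the estimator directly on that DataFrame, including the elastic-net penalty you mentioned. The parameter values below are arbitrary examples, not tuned settings:

import org.apache.spark.ml.classification.LogisticRegression

training.printSchema() // should show: label: double, features: vector

val lr = new LogisticRegression()
  .setMaxIter(10)          // example value
  .setRegParam(0.01)       // example value
  .setElasticNetParam(0.5) // mixes L1 and L2 regularization

val model = lr.fit(training) // fit() now accepts the DataFrame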

Upvotes: 4
