kanimbla

Reputation: 890

Spark OneHotEncoder: how to exclude a user-defined category?

Consider the following Spark DataFrame:

df.printSchema()

     |-- predictor: double (nullable = true)
     |-- label: double (nullable = true)
     |-- date: string (nullable = true)

df.show(6)

    predictor      label              date    
    4.23           6.33               20160510
    4.77           7.18               20160510
    4.09           5.94               20160511
    4.23           6.33               20160511
    4.77           7.18               20160512
    4.09           5.94               20160512

My DataFrame contains daily data. I need to map the date column to a column of binary vectors, which is simple to implement using StringIndexer and OneHotEncoder:

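import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

// Map each distinct date string to a numeric index
val dateIndexer = new StringIndexer()
  .setInputCol("date")
  .setOutputCol("dateIndex")
  .fit(df)
val indexed = dateIndexer.transform(df)

// Encode the index as a binary vector (the last category is dropped by default)
val encoder = new OneHotEncoder()
  .setInputCol("dateIndex")
  .setOutputCol("date_codeVec")

val encoded = encoder.transform(indexed)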

My problem is that OneHotEncoder drops the last category by default. However, I need to drop the category that corresponds to the first date in my DataFrame (20160510 in the example above), because I need to compute a time trend relative to the first date.

How can I achieve this for the above example (note that I have more than 100 dates in my dataframe)?

Upvotes: 1

Views: 599

Answers (1)

zero323

Reputation: 330283

You can try setting dropLast to false:

val encoder = new OneHotEncoder()
  .setInputCol("dateIndex")
  .setOutputCol("date_codeVec")
  .setDropLast(false)

val encoded = encoder.transform(indexed)

and then dropping the level of your choice manually, using VectorSlicer:

import org.apache.spark.ml.feature.VectorSlicer

val slicer = new VectorSlicer()
  .setInputCol("date_codeVec")
  .setOutputCol("data_codeVec_selected")
  .setNames(dateIndexer.labels.diff(Seq(dateIndexer.labels.min)))

slicer.transform(encoded).show()
+---------+-----+--------+---------+-------------+---------------------+
|predictor|label|    date|dateIndex| date_codeVec|data_codeVec_selected|
+---------+-----+--------+---------+-------------+---------------------+
|     4.23| 6.33|20160510|      0.0|(3,[0],[1.0])|            (2,[],[])|
|     4.77| 7.18|20160510|      0.0|(3,[0],[1.0])|            (2,[],[])|
|     4.09| 5.94|20160511|      2.0|(3,[2],[1.0])|        (2,[1],[1.0])|
|     4.23| 6.33|20160511|      2.0|(3,[2],[1.0])|        (2,[1],[1.0])|
|     4.77| 7.18|20160512|      1.0|(3,[1],[1.0])|        (2,[0],[1.0])|
|     4.09| 5.94|20160512|      1.0|(3,[1],[1.0])|        (2,[0],[1.0])|
+---------+-----+--------+---------+-------------+---------------------+
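
To make it clearer which positions survive the slice, here is a minimal plain-Scala sketch (no Spark needed) of the same selection expressed as indices rather than names. The helper indicesWithoutFirst is hypothetical, not part of Spark, and the labels array is hard-coded to match the index order in the output above. It assumes the dates are yyyyMMdd strings, so the smallest string is also the earliest date.

```scala
// Hypothetical helper (not part of Spark): positions of all labels
// except the smallest one, which for yyyyMMdd strings is the earliest date.
def indicesWithoutFirst(labels: Array[String]): Array[Int] = {
  require(labels.nonEmpty, "labels must not be empty")
  val dropIdx = labels.indexOf(labels.min)
  labels.indices.filter(_ != dropIdx).toArray
}

// Labels in the order StringIndexer assigned them in the output above:
// index 0 -> 20160510, index 1 -> 20160512, index 2 -> 20160511
val labels = Array("20160510", "20160512", "20160511")
val keepIndices = indicesWithoutFirst(labels)

println(keepIndices.mkString(", "))  // prints "1, 2"
```

With the real indexer you would pass dateIndexer.labels instead of the hard-coded array, and the result could be given to the slicer through setIndices(keepIndices) in place of setNames. Both routes should select the same columns; the index version just avoids relying on the attribute names attached to the encoded vector.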

Upvotes: 1
