Reputation: 890
Consider the following Spark DataFrame:
df.printSchema()
|-- predictor: double (nullable = true)
|-- label: double (nullable = true)
|-- date: string (nullable = true)
df.show(6)
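+---------+-----+--------+
|predictor|label|    date|
+---------+-----+--------+
|     4.23| 6.33|20160510|
|     4.77| 7.18|20160510|
|     4.09| 5.94|20160511|
|     4.23| 6.33|20160511|
|     4.77| 7.18|20160512|
|     4.09| 5.94|20160512|
+---------+-----+--------+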
Essentially, my dataframe consists of data with daily frequency. I need to map the column of dates to a column of binary vectors. This is simple to implement using StringIndexer & OneHotEncoder:
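import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}

// Map each distinct date string to a numeric index
val dateIndexer = new StringIndexer()
  .setInputCol("date")
  .setOutputCol("dateIndex")
  .fit(df)
val indexed = dateIndexer.transform(df)
// Turn the numeric index into a binary (one-hot) vector
val encoder = new OneHotEncoder()
  .setInputCol("dateIndex")
  .setOutputCol("date_codeVec")
val encoded = encoder.transform(indexed)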
My problem is that OneHotEncoder drops the last category by default. However, I need to drop the category that corresponds to the first date in my dataframe (20160510 in the above example), because I need to compute a time trend relative to the first date.
How can I achieve this for the above example (note that I have more than 100 dates in my dataframe)?
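To make the difference concrete, here is a minimal plain-Scala sketch (no Spark involved) of dropping the last versus the first category. `oneHot` is a hypothetical helper written for illustration, not a Spark API, and the index assignment in the comments is an assumption: StringIndexer orders labels by frequency by default, so the real indices need not follow date order.

```scala
// Hypothetical helper, not a Spark API: dense one-hot vector with one
// reference category (dropIndex) removed.
def oneHot(index: Int, numCategories: Int, dropIndex: Int): Vector[Double] =
  (0 until numCategories)
    .filter(_ != dropIndex)                  // remove the reference category
    .map(i => if (i == index) 1.0 else 0.0)  // 1.0 only at the row's own category
    .toVector

// Assumed indexing for illustration: 20160510 -> 0, 20160511 -> 1, 20160512 -> 2.
val dropLast  = (0 until 3).map(i => oneHot(i, 3, dropIndex = 2))
val dropFirst = (0 until 3).map(i => oneHot(i, 3, dropIndex = 0))
```

With `dropLast`, the all-zero vector belongs to 20160512; with `dropFirst`, it belongs to 20160510, which is the baseline I am after.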
Upvotes: 1
Views: 599
Reputation: 330283
You can try setting setDropLast to false:
val encoder = new OneHotEncoder()
  .setInputCol("dateIndex")
  .setOutputCol("date_codeVec")
  .setDropLast(false)
val encoded = encoder.transform(indexed)
and then dropping the level of your choice manually, using VectorSlicer:
import org.apache.spark.ml.feature.VectorSlicer

// Keep every label except the smallest one (the earliest yyyyMMdd date)
val slicer = new VectorSlicer()
  .setInputCol("date_codeVec")
  .setOutputCol("data_codeVec_selected")
  .setNames(dateIndexer.labels.diff(Seq(dateIndexer.labels.min)))
slicer.transform(encoded).show()
+---------+-----+--------+---------+-------------+---------------------+
|predictor|label|    date|dateIndex| date_codeVec|data_codeVec_selected|
+---------+-----+--------+---------+-------------+---------------------+
|     4.23| 6.33|20160510|      0.0|(3,[0],[1.0])|            (2,[],[])|
|     4.77| 7.18|20160510|      0.0|(3,[0],[1.0])|            (2,[],[])|
|     4.09| 5.94|20160511|      2.0|(3,[2],[1.0])|        (2,[1],[1.0])|
|     4.23| 6.33|20160511|      2.0|(3,[2],[1.0])|        (2,[1],[1.0])|
|     4.77| 7.18|20160512|      1.0|(3,[1],[1.0])|        (2,[0],[1.0])|
|     4.09| 5.94|20160512|      1.0|(3,[1],[1.0])|        (2,[0],[1.0])|
+---------+-----+--------+---------+-------------+---------------------+
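To see why the sliced vector has this layout, here is a minimal plain-Scala sketch (no Spark required) of the label selection passed to setNames. The `labels` array is hard-coded to the index order visible in the output above, which is an assumption about this particular run. `baseline`, `kept`, and `position` are illustrative names, not part of any API.

```scala
// Labels in the order StringIndexer assigned them in the output above
// (index 0 -> 20160510, index 1 -> 20160512, index 2 -> 20160511).
val labels = Array("20160510", "20160512", "20160511")

// yyyyMMdd strings sort lexicographically in date order,
// so the minimum string is the earliest date.
val baseline = labels.min

// Names handed to VectorSlicer.setNames: every label except the baseline.
val kept = labels.diff(Seq(baseline))

// Position of each kept date in the sliced vector (names keep the order given).
val position = kept.zipWithIndex.toMap
```

The baseline date is absent from `kept`, so its rows become the all-zero vector. Note that taking `min` only finds the earliest date because the strings are in yyyyMMdd format; with another date format you would probably need to parse the dates before comparing them.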
Upvotes: 1