PatPanda

Reputation: 5060

Java Spark ML - prediction/forecast with Spark ML 3.1+ issue

Small question regarding prediction/forecast using Spark ML 3.1+ please.

I have a very simple dataset of timestamps for when an event happened. Here is a small portion of the very large file.

        +----------+-----+
        |      time|label|
        +----------+-----+
        |1621900800|   43|
        |1619568000|   41|
        |1620432000|   41|
        |1623974400|   42|
        |1620604800|   41|
        |1622505600|   42|
       truncated
        |1624665600|   42|
        |1623715200|   41|
        |1623024000|   43|
        |1623888000|   42|
        |1621296000|   42|
        |1620691200|   44|
        |1620345600|   41|
        |1625702400|   44|
        +----------+-----+
        only showing top 20 rows

Each row is really just a timestamp representing a day, on the left, and the number of bananas sold that day, on the right. Here are the first three rows of the sample above, translated:

        +--------------+-----------------+
        |          time|            value|
        +--------------+-----------------+
        |May   25, 2021|   banana sold 43|
        |April 28, 2021|   banana sold 41|
        |May    8, 2021|   banana sold 41|
        +--------------+-----------------+

My goal is just to build a prediction model: how many bananas will be sold tomorrow, the day after, etc.

Therefore, I tried Linear Regression, but it might not be a good model for this problem:

        VectorAssembler       vectorAssembler = new VectorAssembler().setInputCols(new String[]{"time", "label"}).setOutputCol("features");
        Dataset<Row>          vectorData      = vectorAssembler.transform(dataSetBanana);
        LinearRegression      lr              = new LinearRegression(); 
        LinearRegressionModel lrModel         = lr.fit(vectorData);
        System.out.println("Coefficients: " + lrModel.coefficients() + " Intercept: " + lrModel.intercept());
        LinearRegressionTrainingSummary trainingSummary = lrModel.summary();
        System.out.println("numIterations: " + trainingSummary.totalIterations());
        System.out.println("objectiveHistory: " + Vectors.dense(trainingSummary.objectiveHistory()));
        trainingSummary.residuals().show();
        System.out.println("RMSE: " + trainingSummary.rootMeanSquaredError());
        System.out.println("r2: " + trainingSummary.r2());
        System.out.println("the magical prediction: " + lrModel.predict(new DenseVector(new double[]{1.0, 1.0})));

I see all the values printed, very happy.

Coefficients: [-1.5625735463489882E-19,1.0000000000000544] Intercept: 2.5338210784074846E-10
numIterations: 0
objectiveHistory: [0.0]

+--------------------+
|           residuals|
+--------------------+
|-1.11910480882215...|

RMSE: 3.0933584599870493E-13
r2: 1.0
the magical prediction: 1.0000000002534366

It is not giving me anything close to a prediction. I was expecting something like:

|Some time in the future|   banana sold some prediction|
| 1626414043 | 38 |

May I ask which model can produce an answer like "model predicts X bananas will be sold at time Y in the future"?

A small piece of code with result would be great.

Thank you

Upvotes: 0

Views: 570

Answers (2)

David

Reputation: 1136

Linear regression can be a good start to get familiar with MLlib before you go for more complicated models. First, let's have a look at what you have done so far.

Your VectorAssembler transforms your data frame this way:

before:

time label
1621900800 43
1620432000 41

after:

time label features
1621900800 43 [1621900800;43]
1620432000 41 [1620432000;41]

Now, when you ask LinearRegression to train its model, it expects your dataset to contain two columns (these names are the defaults):

  • one column named features and containing a vector with everything that can be used to predict the label.
  • one column named label, what you want to predict

Regression will find the a and b that minimize the errors across all records i, where:

y_i = a * x_i + b + error_i
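
To make the formula concrete, here is a minimal plain-Java sketch of the closed-form least-squares fit for a single feature. It only illustrates what "finding a and b" means; Spark's LinearRegression uses its own solvers, and the class and method names below are hypothetical.

```java
// Sketch only: closed-form least squares for y = a * x + b with one feature.
// SimpleLeastSquares, fit and predict are hypothetical names, not Spark API.
final class SimpleLeastSquares {

    /** Returns {a, b} minimizing the sum of squared errors over all records. */
    static double[] fit(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need at least two (x, y) pairs of equal length");
        }
        int n = x.length;
        double meanX = 0.0;
        double meanY = 0.0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;

        // Centering keeps the sums well-behaved even for large epoch timestamps.
        double sxy = 0.0;
        double sxx = 0.0;
        for (int i = 0; i < n; i++) {
            sxy += (x[i] - meanX) * (y[i] - meanY);
            sxx += (x[i] - meanX) * (x[i] - meanX);
        }
        if (sxx == 0.0) {
            throw new IllegalArgumentException("all x values are identical");
        }
        double a = sxy / sxx;          // slope
        double b = meanY - a * meanX;  // intercept
        return new double[] {a, b};
    }

    /** Applies y = a * x + b. */
    static double predict(double[] ab, double x) {
        return ab[0] * x + ab[1];
    }

    public static void main(String[] args) {
        // First three rows of the question's sample: (time, label).
        double[] time  = {1621900800d, 1619568000d, 1620432000d};
        double[] label = {43, 41, 41};
        double[] ab = fit(time, label);
        System.out.println("a = " + ab[0] + ", b = " + ab[1]);
        System.out.println("prediction: " + predict(ab, 1626393600d));
    }
}
```

Note that x here is the timestamp alone; the label only ever appears on the y side.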

In your particular setup, you passed the label to your VectorAssembler, which is wrong: the label is what you want to predict! Your model has simply learnt that the label perfectly predicts the label:

y = 0.0 * features[0] + 1.0 * features[1]
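
As a quick check of that reading, the sketch below plugs the coefficients and intercept printed in the question into the same linear formula. It uses no Spark at all, and the class name is hypothetical; it only redoes the arithmetic behind "the magical prediction".

```java
// Sketch only: redo the model's arithmetic by hand with the values
// printed in the question. MagicalPrediction is a hypothetical name.
final class MagicalPrediction {

    /** y = intercept + sum_i coefficients[i] * features[i] */
    static double predict(double[] coefficients, double intercept, double[] features) {
        double y = intercept;
        for (int i = 0; i < coefficients.length; i++) {
            y += coefficients[i] * features[i];
        }
        return y;
    }

    public static void main(String[] args) {
        // Values copied from the question's output.
        double[] coefficients = {-1.5625735463489882E-19, 1.0000000000000544};
        double intercept = 2.5338210784074846E-10;

        // features = {time, label}: the second slot is the label itself.
        System.out.println(predict(coefficients, intercept, new double[] {1.0, 1.0}));
        System.out.println(predict(coefficients, intercept, new double[] {1621900800d, 43d}));
    }
}
```

Whatever is placed in the second slot comes back almost unchanged, which is why the fit looks perfect (r2 of 1.0) yet says nothing about the future.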

So you should correct your VectorAssembler:

VectorAssembler vectorAssembler = new VectorAssembler().setInputCols(new String[]{"time"}).setOutputCol("features");

When you made your prediction, you passed this:

lrModel.predict(new DenseVector(new double[]{ 1.0,        1.0 }));
                                              timestamp   label

It returned 1.0, as per the formula above. If you change the VectorAssembler as proposed, you should call the prediction this way:

lrModel.predict(new DenseVector(new double[]{ timeStampIWantToPredict }));
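
Here timeStampIWantToPredict is a placeholder. Assuming the time column holds epoch seconds at midnight UTC (the sample rows are consistent with that, e.g. 1621900800 is May 25, 2021), a sketch for building the value for "tomorrow" and "the day after" could look like this. The class and method names are hypothetical.

```java
import java.time.LocalDate;
import java.time.ZoneOffset;

// Sketch only: turn a calendar day into the same kind of timestamp
// as the "time" column, assuming epoch seconds at midnight UTC.
final class FutureTimestamps {

    static long epochSecondsAtUtcMidnight(LocalDate day) {
        return day.atStartOfDay(ZoneOffset.UTC).toEpochSecond();
    }

    public static void main(String[] args) {
        // Latest day in the question's sample (1625702400).
        LocalDate lastObserved = LocalDate.of(2021, 7, 8);

        long tomorrow = epochSecondsAtUtcMidnight(lastObserved.plusDays(1));
        long dayAfter = epochSecondsAtUtcMidnight(lastObserved.plusDays(2));

        // These are the values to wrap in a one-element feature vector,
        // e.g. new double[]{ (double) tomorrow }.
        System.out.println("tomorrow:  " + tomorrow);
        System.out.println("day after: " + dayAfter);
    }
}
```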

Side notes:

  • you can pass a dataset to your model (via transform); it will return a dataset with a new column containing the prediction.
  • you should really have a closer look at the MLlib Pipeline documentation.
  • then you can try adding new features to your linear regression: seasonality, autoregressive features...
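
As an illustration of the last side note, here is a plain-Java sketch of two such features: a one-hot day-of-week (a simple form of weekly seasonality) and the previous day's sales (a lag-1 autoregressive feature). It assumes the timestamps are epoch seconds in UTC; the class, the method names, and the choice of features are hypothetical, not something prescribed by Spark.

```java
import java.time.Instant;
import java.time.ZoneOffset;

// Sketch only: derive extra features from a timestamp and the previous
// day's label. ExtraFeatures and its methods are hypothetical names.
final class ExtraFeatures {

    /** Day of week in UTC, 1 = Monday ... 7 = Sunday. */
    static int dayOfWeek(long epochSeconds) {
        return Instant.ofEpochSecond(epochSeconds)
                .atZone(ZoneOffset.UTC)
                .getDayOfWeek()
                .getValue();
    }

    /** Slots 0..6: one-hot day of week; slot 7: previous day's sales. */
    static double[] features(long epochSeconds, double previousDaySales) {
        double[] v = new double[8];
        v[dayOfWeek(epochSeconds) - 1] = 1.0;
        v[7] = previousDaySales;
        return v;
    }

    public static void main(String[] args) {
        // 1621900800 is May 25, 2021; 42.0 is a made-up previous-day value.
        double[] v = features(1621900800L, 42.0);
        System.out.println(java.util.Arrays.toString(v));
    }
}
```

When the model also fits an intercept, a full seven-slot one-hot is redundant, so dropping one slot is a common choice.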

Upvotes: 1

Onur Baştürk

Reputation: 735

The model gives you the coefficients of your variables. Then it's easy to calculate the output. If you have only one variable x1 your model will be something like:

y = a*x1 + b

Then the outputs of your model are a and b. Then you can calculate y.

Generally speaking, machine learning libraries also implement other methods that let you calculate the output. It's worth looking up how to save, load and then evaluate your model with new inputs. Check out https://spark.apache.org/docs/1.6.1/api/java/org/apache/spark/ml/regression/LinearRegressionModel.html

There's a method called predict that you can call on your model by giving the input as a Vector instance. I think that will work!

Another thing is: you are trying to solve a time-series problem with a single-variable linear regression model. I think you should use a better algorithm that is intended to deal with time-series or sequence problems such as Long Short Term Memory (LSTM).

I hope that my answer is useful for you. Keep going ;)

Upvotes: 1
