I want to use an H2O autoencoder (anomaly detection) for inference/prediction in a Java class.
I built the autoencoder example "ECG Heartbeats" from the H2O DeepLearningBooklet with R and saved it. I can successfully import the generated Java class and its related h2o-genmodel.jar into my Java project.
Unfortunately, I cannot find an example or documentation on how to use it there.
Here is my first try, with some code and some guesses based on my experience with other H2O models used for inference in Java code:
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.prediction.AutoEncoderModelPrediction;

private static String modelClassName = "machinelearning.DeepLearning_model_R_1509973865970_1";

public static void main(String[] args) throws Exception {
    hex.genmodel.GenModel rawModel;
    rawModel = (hex.genmodel.GenModel) Class.forName(modelClassName).newInstance();
    EasyPredictModelWrapper model = new EasyPredictModelWrapper(rawModel);
    RowData row = new RowData();
    // row.put(key, value); // TODO Add new line of input data, e.g.:
    // 2.10,2.13,2.19,2.28,2.44,2.62,2.80,3.04,3.36,3.69,3.97,4.24,4.53,4.80,5.02,5.21,5.40,5.57,5.71,5.79,5.86,5.92,5.98,6.02,6.06,6.08,6.14,6.18,6.22,6.27,6.32,6.35,6.38,6.45,6.49,6.53,6.57,6.64,6.70,6.73,6.78,6.83,6.88,6.92,6.94,6.98,7.01,7.03,7.05,7.06,7.07,7.08,7.06,7.04,7.03,6.99,6.94,6.88,6.83,6.77,6.69,6.60,6.53,6.45,6.36,6.27,6.19,6.11,6.03,5.94,5.88,5.81,5.75,5.68,5.62,5.61,5.54,5.49,5.45,5.42,5.38,5.34,5.31,5.30,5.29,5.26,5.23,5.23,5.22,5.20,5.19,5.18,5.19,5.17,5.15,5.14,5.17,5.16,5.15,5.15,5.15,5.14,5.14,5.14,5.15,5.14,5.14,5.13,5.15,5.15,5.15,5.14,5.16,5.15,5.15,5.14,5.14,5.15,5.15,5.14,5.13,5.14,5.14,5.11,5.12,5.12,5.12,5.09,5.09,5.09,5.10,5.08,5.08,5.08,5.08,5.06,5.05,5.06,5.07,5.05,5.03,5.03,5.04,5.03,5.01,5.01,5.02,5.01,5.01,5.00,5.00,5.02,5.01,4.98,5.00,5.00,5.00,4.99,5.00,5.01,5.02,5.01,5.03,5.03,5.02,5.02,5.04,5.04,5.04,5.02,5.02,5.01,4.99,4.98,4.96,4.96,4.96,4.94,4.93,4.93,4.93,4.93,4.93,5.02,5.27,5.80,5.94,5.58,5.39,5.32,5.25,5.21,5.13,4.97,4.71,4.39,4.05,3.69,3.32,3.05,2.99,2.74,2.61,2.47,2.35,2.26,2.20,2.15,2.10,2.08
    AutoEncoderModelPrediction p = model.predictAutoEncoder(row);
    System.out.println(p.reconstructedRowData);
    System.out.println(p.reconstructed[0]);
    // TODO How do I get the MSE from object 'p'?
}
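This code actually compiles and runs. However, I do not really understand how to instantiate the autoencoder model correctly, what to pass as key and value to row.put(), and how to get the MSE from the prediction object.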
I assume the answer is simple, but it is not easy to find without documentation :-)
Thanks for your help.
(Code example for main.java at the end.)
You have the model instantiation right; it happens in this line: rawModel = (hex.genmodel.GenModel) Class.forName(modelClassName).newInstance();
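As a side note, Class.newInstance() has been deprecated since Java 9. The sketch below shows the same load-by-name pattern using getDeclaredConstructor().newInstance() instead. The helper name instantiate is hypothetical, and java.util.ArrayList stands in for the generated POJO class so the snippet runs without h2o-genmodel.jar; in the real code you would call something like instantiate(modelClassName, hex.genmodel.GenModel.class).

```java
import java.util.List;

public class LoadByName {
    // Hypothetical helper: load a class by its fully qualified name, create an
    // instance through its public no-arg constructor, and check that it has the
    // expected type. Uses getDeclaredConstructor().newInstance() because
    // Class.newInstance() is deprecated since Java 9.
    public static <T> T instantiate(String className, Class<T> expectedType)
            throws ReflectiveOperationException {
        Object instance = Class.forName(className).getDeclaredConstructor().newInstance();
        // Throws ClassCastException if the loaded class is not of the expected type.
        return expectedType.cast(instance);
    }

    public static void main(String[] args) throws Exception {
        // java.util.ArrayList stands in for the generated POJO class here.
        List<?> list = instantiate("java.util.ArrayList", List.class);
        System.out.println(list.getClass().getName()); // prints java.util.ArrayList
    }
}
```

Passing the expected type makes a wrong class name fail at the cast with a clear ClassCastException rather than later, when the object is first used.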
For row.put(key, value), the key is the column header and the value is the actual value. If the H2OFrame doesn't have column headers, H2O automatically assigns them the names C1, C2, etc. You can write the keys out manually or fill them in with a loop over rawModel.getNames(); to see the names, use System.out.println(java.util.Arrays.toString(rawModel.getNames())); (see the code snippet for an example of this).
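RowData is a HashMap<String, Object>, so a comma-separated line like the one in the question can be turned into a row with a simple loop. The sketch below assumes the training frame had no header row, so the keys follow H2O's default names C1, C2, etc. in file order. The helper name toRow is hypothetical, and a plain HashMap is used so the snippet runs without h2o-genmodel.jar; in real code you would put the values into a RowData, and if your model has real column names you would take the keys from rawModel.getNames() instead of generating them.

```java
import java.util.HashMap;
import java.util.Map;

public class CsvRowToMap {
    // Hypothetical helper: turn one comma-separated line of numbers into a map
    // keyed by column name. Assumes the columns carry H2O's default names
    // C1, C2, C3, ... in the same order as the fields in the line.
    public static Map<String, Object> toRow(String csvLine) {
        String[] fields = csvLine.split(",");
        Map<String, Object> row = new HashMap<>();
        for (int i = 0; i < fields.length; i++) {
            // Keys are 1-based: the first field becomes "C1".
            row.put("C" + (i + 1), Double.parseDouble(fields[i].trim()));
        }
        return row;
    }

    public static void main(String[] args) {
        // First three values of the example heartbeat from the question.
        Map<String, Object> row = toRow("2.10,2.13,2.19");
        System.out.println(row.get("C1") + " " + row.get("C3")); // prints 2.1 2.19
    }
}
```

The values are stored as Double, matching the answer's own example, which puts rng.nextDouble() values into the row.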
As for the MSE, there is currently no method for it, but you can get the original and reconstructed values and calculate the MSE from them (see the code snippet below; its last few lines calculate the MSE using the original and reconstructed arrays).
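The MSE calculation described above is the mean of the squared element-wise differences between the two arrays, (1/n) * sum of (original[i] - reconstructed[i])^2. A minimal standalone version, with a hypothetical method name and an added length check, could look like this; you would pass it p.original and p.reconstructed.

```java
public class ReconstructionMse {
    // Hypothetical helper: mean squared error between an input row and its
    // reconstruction, i.e. (1/n) * sum over i of (original[i] - reconstructed[i])^2.
    public static double mse(double[] original, double[] reconstructed) {
        if (original.length == 0 || original.length != reconstructed.length) {
            throw new IllegalArgumentException("arrays must be non-empty and of equal length");
        }
        double sum = 0;
        for (int i = 0; i < original.length; i++) {
            double diff = original[i] - reconstructed[i];
            sum += diff * diff;
        }
        return sum / original.length;
    }

    public static void main(String[] args) {
        double[] original = {1.0, 2.0, 3.0};
        double[] reconstructed = {1.0, 2.0, 5.0};
        // Only the last element differs, by 2, so the MSE is 4 / 3.
        System.out.println(mse(original, reconstructed));
    }
}
```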
When I created my model, I called it anomaly_model (see the code directly below; model_id is one of the parameters). You will see that name used in the last code snippet below, so if you use a different name you will need to update that part.
anomaly_model <- h2o.deeplearning(x = names(train_ecg), training_frame = train_ecg, activation = "Tanh",
                                  autoencoder = TRUE, hidden = c(50, 20, 50), sparse = TRUE, l1 = 1e-4,
                                  epochs = 100, model_id = 'anomaly_model')
Here is example code showing how to create the main.java file, pass in the column names as keys, and calculate the MSE from the results of the built-in methods.
(Note: I generated random values for row.put(key, value); you can put whatever values you want there instead.)
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.prediction.*;

public class main {
    // The POJO class name is the model_id used when the model was built in R.
    private static String modelClassName = "anomaly_model";

    public static void main(String[] args) throws Exception {
        // Load the generated POJO by name and wrap it in the easy-prediction API.
        hex.genmodel.GenModel rawModel;
        rawModel = (hex.genmodel.GenModel) Class.forName(modelClassName).newInstance();
        EasyPredictModelWrapper model = new EasyPredictModelWrapper(rawModel);

        // Fill one input row with random values, keyed by the model's column names.
        java.util.Random rng = new java.util.Random();
        RowData row = new RowData();
        for (String colName : rawModel.getNames()) {
            row.put(colName, rng.nextDouble());
        }

        // Score the row: the prediction holds both the input and its reconstruction.
        AutoEncoderModelPrediction p = model.predictAutoEncoder(row);
        System.out.println("original: " + java.util.Arrays.toString(p.original));
        System.out.println("reconstructedRowData: " + p.reconstructedRowData);
        System.out.println("reconstructed: " + java.util.Arrays.toString(p.reconstructed));

        // Reconstruction MSE: mean of the squared element-wise differences.
        double sum = 0;
        for (int i = 0; i < p.original.length; i++) {
            double diff = p.original[i] - p.reconstructed[i];
            sum += diff * diff;
        }
        double mse = sum / p.original.length;
        System.out.println("MSE: " + mse);
    }
}
Hope this helps!