RDK

Reputation: 953

Saving an H2O model directly from Java

I'm trying to create and save a generated model directly from Java. The documentation explains how to do this in R and Python, but not in Java. A similar question was asked before, but it got no real answer beyond a link to the H2O docs, which don't contain a code example.

For my present purpose it would be enough to get some pointers on translating the following reference code to Java. I'm mainly looking for guidance on which JAR(s) to import from the Maven repository.

import h2o
h2o.init()
path = h2o.system_file("prostate.csv")
h2o_df = h2o.import_file(path)
h2o_df['CAPSULE'] = h2o_df['CAPSULE'].asfactor()
model = h2o.glm(y = "CAPSULE",
            x = ["AGE", "RACE", "PSA", "GLEASON"],
            training_frame = h2o_df,
            family = "binomial")
h2o.download_pojo(model)

Upvotes: 1

Views: 507

Answers (2)

user3030851

Reputation: 123

Would something like this do the trick?

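// Needs: java.net.URI, java.io.OutputStream, water.H2O, water.Keyed, water.AutoBuffer, water.persist.Persist
// Serializes any Keyed object (e.g. a trained Model) in H2O's binary format
public void saveModel(URI uri, Keyed<?> model)
{
    // Pick the persistence backend (local disk, HDFS, ...) that matches the URI
    Persist p = H2O.getPM().getPersistForURI(uri);
    OutputStream os = p.create(uri.toString(), true); // true = overwrite an existing file
    model.writeAll(new AutoBuffer(os, true)).close();
}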

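Make sure the URI is well-formed; otherwise H2O will fail with a NullPointerException. As for Maven, h2o-core alone should be enough for saving a model (note that the GLM classes in the hex.glm package live in the separate h2o-algos artifact, so training a GLM needs that one too):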

    <dependency>
        <groupId>ai.h2o</groupId>
        <artifactId>h2o-core</artifactId>
        <version>3.14.0.2</version>
    </dependency>
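Regarding the well-formed URI mentioned above: for a local target, one way to avoid hand-building the string is to let java.io.File produce it. The sketch below is plain JDK code; the class ModelUris, the helper toModelUri and the example path are hypothetical, and the helper only guarantees a syntactically absolute file URI, not that H2O accepts every such path.

```java
import java.io.File;
import java.net.URI;

// Hypothetical helper (not part of H2O): builds an absolute URI with the
// "file" scheme for a local path, e.g. to pass to saveModel(uri, model).
public class ModelUris {
    public static URI toModelUri(String localPath) {
        // getAbsoluteFile() resolves relative paths against the working directory;
        // toURI() then yields a URI whose scheme is "file".
        return new File(localPath).getAbsoluteFile().toURI();
    }

    public static void main(String[] args) {
        URI good = toModelUri("models/glm_prostate.bin");
        System.out.println(good.getScheme());  // prints "file"
        System.out.println(good.isAbsolute()); // prints "true"

        // By contrast, a bare relative string parses into a URI with no scheme at all
        URI bare = URI.create("models/glm_prostate.bin");
        System.out.println(bare.getScheme());  // prints "null"
    }
}
```

The resulting URI can then be handed to saveModel(...) in place of a manually assembled string.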

Upvotes: 1

RDK

Reputation: 953

I think I've figured out an answer to my question; self-contained sample code follows. I'd still appreciate input from the community, though, since I don't know whether this is the best or most idiomatic way to do it.

package org.name.company;

import hex.glm.GLM;
import hex.glm.GLMModel;
import hex.glm.GLMModel.GLMParameters;
import hex.glm.GLMModel.GLMParameters.Family;
import water.H2O;
import water.Key;
import water.api.StreamWriter;
import water.api.StreamingSchema;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.parser.ParseDataset;
import water.util.JCodeGen;

import java.io.File;
import java.io.FileOutputStream;
import java.io.OutputStream;
import java.util.Map;

public class Launcher
{
    // Start a single-node H2O cloud inside this JVM and wait up to 10 s for it to form
    public static void initCloud() {
        String[] args = new String[] {"-name", "h2o_test_cloud"};
        H2O.main(args);
        H2O.waitForCloudSize(1, 10 * 1000);
    }

    public static void main(String[] args) throws Exception {
        // Initialize the cloud
        initCloud();

        // Create a Frame object from CSV
        File f = new File("/path/to/data.csv");
        NFSFileVec nfs = NFSFileVec.make(f);
        Key<Frame> frameKey = Key.make("frameKey");
        Frame fr = ParseDataset.parse(frameKey, nfs._key);

        GLMModel model = null;
        try {
            // Unregularized (lambda = 0) Gaussian GLM; column 1 of the CSV is the response
            GLMParameters params = new GLMParameters();
            params._train = frameKey;
            params._response_column = fr.names()[1];
            params._intercept = true;
            params._lambda = new double[]{0};
            params._family = Family.gaussian;

            // Train (blocks until done) and print the coefficients
            model = new GLM(params).trainModel().get();
            Map<String, Double> coefs = model.coefficients();
            for (Map.Entry<String, Double> entry : coefs.entrySet()) {
                System.out.format("%s: %f%n", entry.getKey(), entry.getValue());
            }

            // Write the model as a POJO (Java source), the Java counterpart of h2o.download_pojo()
            String filename = JCodeGen.toJavaId(model._key.toString()) + ".java";
            StreamingSchema ss = new StreamingSchema(model.new JavaModelStreamWriter(false), filename);
            StreamWriter sw = ss.getStreamWriter();
            try (OutputStream os = new FileOutputStream("/base/path/" + filename)) {
                sw.writeTo(os);
            }
        } finally {
            // Free the model and the frame from the H2O key-value store
            if (model != null) {
                model.remove();
            }
            if (fr != null) {
                fr.remove();
            }
        }
    }
}
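Why the POJO file name goes through JCodeGen.toJavaId: the generated POJO is a public Java class whose name is derived from the model key, and javac requires a public top-level class to be declared in a file named after it, so the key has to become a legal Java identifier. The exact rules JCodeGen.toJavaId applies are not reproduced here; the following is a hypothetical stand-in (JavaIds and toJavaClassName are not H2O classes or methods) that shows the idea using the JDK's Character.isJavaIdentifierStart/isJavaIdentifierPart.

```java
// Hypothetical illustration only; NOT H2O's JCodeGen.toJavaId.
public class JavaIds {
    public static String toJavaClassName(String key) {
        StringBuilder sb = new StringBuilder(key.length() + 1);
        // An identifier must start with a letter, '_' or '$' (not e.g. a digit)
        if (key.isEmpty() || !Character.isJavaIdentifierStart(key.charAt(0))) {
            sb.append('_');
        }
        for (int i = 0; i < key.length(); i++) {
            char c = key.charAt(i);
            // Replace characters such as '-', '.' or ' ' that are illegal in identifiers
            sb.append(Character.isJavaIdentifierPart(c) ? c : '_');
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        String cls = toJavaClassName("GLM_model-1.prostate");
        System.out.println(cls + ".java"); // prints "GLM_model_1_prostate.java"
    }
}
```

This sketch does not guard against Java keywords (a key named "class" would pass through unchanged); a production version would need to handle that as well.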

Upvotes: 1
