Gili

Reputation: 90013

How to invoke model from TensorFlow Java?

The following Python code passes ["hello", "world"] into the Universal Sentence Encoder and returns an array of floats representing their encoded representations.

import tensorflow as tf
import tensorflow_hub as hub

module = hub.KerasLayer("https://tfhub.dev/google/universal-sentence-encoder/4")
model = tf.keras.Sequential(module)
print("model: ", model(["hello", "world"]))

This code works but I'd now like to do the same thing using the Java API. I've successfully loaded the module, but I am unable to pass inputs into the model and extract the output. Here is what I've got so far:

import org.tensorflow.Graph;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.util.SaverDef;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

public final class NaiveBayesClassifier
{
    public static void main(String[] args)
    {
        new NaiveBayesClassifier().run();
    }

    protected SavedModelBundle loadModule(Path source, String... tags) throws IOException
    {
        return SavedModelBundle.load(source.toAbsolutePath().normalize().toString(), tags);
    }

    public void run()
    {
        try (SavedModelBundle module = loadModule(Paths.get("universal-sentence-encoder"), "serve"))
        {
            Graph graph = module.graph();
            try (Session session = new Session(graph, ConfigProto.newBuilder().
                setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true)).
                setAllowSoftPlacement(true).
                build().toByteArray()))
            {
                Tensor<String> input = Tensors.create(new byte[][]
                    {
                        "hello".getBytes(StandardCharsets.UTF_8),
                        "world".getBytes(StandardCharsets.UTF_8)
                    });
                List<Tensor<?>> result = session.runner().feed("serving_default_inputs", input).
                    addTarget("???").run();
            }
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }
}

I used https://stackoverflow.com/a/51952478/14731 to scan the model for possible input/output nodes. I believe the input node is "serving_default_inputs", but I can't figure out the output node. More importantly, I don't have to specify any of these values when invoking the code in Python through Keras, so is there a way to do the same using the Java API?

UPDATE: Thanks to roywei, I can now confirm that the input node is serving_default_input and the output node is StatefulPartitionedCall_1, but when I plug these names into the aforementioned code I get:

2020-05-22 22:13:52.266287: W tensorflow/core/framework/op_kernel.cc:1651] OP_REQUIRES failed at lookup_table_op.cc:809 : Failed precondition: Table not initialized.
Exception in thread "main" java.lang.IllegalStateException: [_Derived_]{{function_node __inference_pruned_6741}} {{function_node __inference_pruned_6741}} Error while reading resource variable EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25 from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/EncoderDNN/DNN/ResidualHidden_0/dense/kernel/part_25/class tensorflow::Var does not exist.
     [[{{node EncoderDNN/DNN/ResidualHidden_0/dense/kernel/ConcatPartitions/concat/ReadVariableOp_25}}]]
     [[StatefulPartitionedCall_1/StatefulPartitionedCall]]
    at [email protected]/org.tensorflow.Session.run(Native Method)
    at [email protected]/org.tensorflow.Session.access$100(Session.java:48)
    at [email protected]/org.tensorflow.Session$Runner.runHelper(Session.java:326)
    at [email protected]/org.tensorflow.Session$Runner.run(Session.java:276)

Meaning, I still cannot invoke the model. What am I missing?

Upvotes: 0

Views: 2682

Answers (4)

user3232470

Reputation: 1

I need to do the same, but there still seem to be a lot of missing pieces regarding DJL usage. E.g., what to do after this?

ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);

I finally found an example in the DJL source code. The key takeaway is not to use NDList for the input/output at all:

Criteria<String[], float[][]> criteria =
        Criteria.builder()
                .optApplication(Application.NLP.TEXT_EMBEDDING)
                .setTypes(String[].class, float[][].class)
                .optModelUrls(modelUrl)
                .build();
try (ZooModel<String[], float[][]> model = ModelZoo.loadModel(criteria);
        Predictor<String[], float[][]> predictor = model.newPredictor()) {
    return predictor.predict(inputs.toArray(new String[0]));
}

See https://github.com/awslabs/djl/blob/master/examples/src/main/java/ai/djl/examples/inference/UniversalSentenceEncoder.java for the complete example.

Upvotes: 0

Gili

Reputation: 90013

I figured it out after roywei pointed me in the right direction.

  • I needed to use SavedModelBundle.session() instead of constructing my own instance, because the loader initializes the graph variables.
  • Instead of passing a ConfigProto to the Session constructor, I passed it into the SavedModelBundle loader.
  • I needed to use fetch() instead of addTarget() to retrieve the output tensor.

Here is the working code:

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlowException;
import org.tensorflow.Tensors;
import org.tensorflow.framework.ConfigProto;
import org.tensorflow.framework.GPUOptions;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;
import org.tensorflow.framework.TensorShapeProto;
import org.tensorflow.framework.TensorShapeProto.Dim;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public final class NaiveBayesClassifier
{
    public static void main(String[] args)
    {
        new NaiveBayesClassifier().run();
    }

    public void run()
    {
        try (SavedModelBundle module = loadModule(Paths.get("universal-sentence-encoder"), "serve"))
        {
            try (Tensor<String> input = Tensors.create(new byte[][]
                {
                    "hello".getBytes(StandardCharsets.UTF_8),
                    "world".getBytes(StandardCharsets.UTF_8)
                }))
            {
                MetaGraphDef metadata = MetaGraphDef.parseFrom(module.metaGraphDef());
                Map<String, Shape> nameToInput = getInputToShape(metadata);
                String firstInput = nameToInput.keySet().iterator().next();

                Map<String, Shape> nameToOutput = getOutputToShape(metadata);
                String firstOutput = nameToOutput.keySet().iterator().next();

                System.out.println("input: " + firstInput);
                System.out.println("output: " + firstOutput);
                System.out.println();

                List<Tensor<?>> result = module.session().runner().feed(firstInput, input).
                    fetch(firstOutput).run();
                for (Tensor<?> tensor : result)
                {
                    float[][] array = new float[tensor.numDimensions()][tensor.numElements() /
                        tensor.numDimensions()];
                    tensor.copyTo(array);
                    System.out.println(Arrays.deepToString(array));
                }
            }
        }
        catch (IOException e)
        {
            e.printStackTrace();
        }
    }

    /**
     * Loads a graph from a file.
     *
     * @param source the directory containing the model to load
     * @param tags   the model variant(s) to load
     * @return the graph
     * @throws NullPointerException if any of the arguments are null
     * @throws IOException          if an error occurs while reading the file
     */
    protected SavedModelBundle loadModule(Path source, String... tags) throws IOException
    {
        // https://stackoverflow.com/a/43526228/14731
        try
        {
            return SavedModelBundle.loader(source.toAbsolutePath().normalize().toString()).
                withTags(tags).
                withConfigProto(ConfigProto.newBuilder().
                    setGpuOptions(GPUOptions.newBuilder().setAllowGrowth(true)).
                    setAllowSoftPlacement(true).
                    build().toByteArray()).
                load();
        }
        catch (TensorFlowException e)
        {
            throw new IOException(e);
        }
    }

    /**
     * @param metadata the graph metadata
     * @return the first signature, or null
     */
    private SignatureDef getFirstSignature(MetaGraphDef metadata)
    {
        Map<String, SignatureDef> nameToSignature = metadata.getSignatureDefMap();
        if (nameToSignature.isEmpty())
            return null;
        return nameToSignature.get(nameToSignature.keySet().iterator().next());
    }

    /**
     * @param metadata the graph metadata
     * @return the serving signature, or the first signature if none is named serving_default
     */
    private SignatureDef getServingSignature(MetaGraphDef metadata)
    {
        return metadata.getSignatureDefOrDefault("serving_default", getFirstSignature(metadata));
    }

    /**
     * @param metadata the graph metadata
     * @return a map from an output name to its shape
     */
    protected Map<String, Shape> getOutputToShape(MetaGraphDef metadata)
    {
        Map<String, Shape> result = new HashMap<>();
        SignatureDef servingDefault = getServingSignature(metadata);
        for (Map.Entry<String, TensorInfo> entry : servingDefault.getOutputsMap().entrySet())
        {
            TensorShapeProto shapeProto = entry.getValue().getTensorShape();
            List<Dim> dimensions = shapeProto.getDimList();
            long firstDimension = dimensions.get(0).getSize();
            long[] remainingDimensions = dimensions.stream().skip(1).mapToLong(Dim::getSize).toArray();
            Shape shape = Shape.make(firstDimension, remainingDimensions);
            result.put(entry.getValue().getName(), shape);
        }
        return result;
    }

    /**
     * @param metadata the graph metadata
     * @return a map from an input name to its shape
     */
    protected Map<String, Shape> getInputToShape(MetaGraphDef metadata)
    {
        Map<String, Shape> result = new HashMap<>();
        SignatureDef servingDefault = getServingSignature(metadata);
        for (Map.Entry<String, TensorInfo> entry : servingDefault.getInputsMap().entrySet())
        {
            TensorShapeProto shapeProto = entry.getValue().getTensorShape();
            List<Dim> dimensions = shapeProto.getDimList();
            long firstDimension = dimensions.get(0).getSize();
            long[] remainingDimensions = dimensions.stream().skip(1).mapToLong(Dim::getSize).toArray();
            Shape shape = Shape.make(firstDimension, remainingDimensions);
            result.put(entry.getValue().getName(), shape);
        }
        return result;
    }
}

Upvotes: 4

roywei

Reputation: 11

There are two ways to get the names:

1) Using Java:

You can read the input and output names from the org.tensorflow.proto.framework.MetaGraphDef stored in the saved model bundle.

Here is an example of how to extract the information:

https://github.com/awslabs/djl/blob/master/tensorflow/tensorflow-engine/src/main/java/ai/djl/tensorflow/engine/TfSymbolBlock.java#L149

2) Using Python:

Load the saved model in TensorFlow's Python API and print the names:

loaded = tf.saved_model.load("path/to/model/")
print(list(loaded.signatures.keys()))
infer = loaded.signatures["serving_default"]
print(infer.structured_outputs)
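
If a local TensorFlow installation is available, the saved_model_cli tool that ships with it prints the same signature information from the command line, including the exact tensor names to feed and fetch (the model directory below is a placeholder for wherever the SavedModel was extracted):

```
saved_model_cli show --dir universal-sentence-encoder --tag_set serve --signature_def serving_default
```

This lists each input and output of the serving_default signature with its dtype, shape, and tensor name, which is handy when the Java side needs concrete node names.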

I recommend taking a look at the Deep Java Library; it automatically handles the input and output names. It supports TensorFlow 2.1.0 and allows you to load Keras models as well as TF Hub SavedModels. Take a look at the documentation here and here.

Feel free to open an issue if you have problem loading your model.

Upvotes: 1

Frank Liu

Reputation: 336

You can load the TF model with the Deep Java Library:

System.setProperty("ai.djl.repository.zoo.location", "https://storage.googleapis.com/tfhub-modules/google/universal-sentence-encoder/1.tar.gz?artifact_id=encoder");

Criteria<NDList, NDList> criteria =
        Criteria.builder()
                .setTypes(NDList.class, NDList.class)
                .optArtifactId("ai.djl.localmodelzoo:encoder")
                .build();
ZooModel<NDList, NDList> model = ModelZoo.loadModel(criteria);

See https://github.com/awslabs/djl/blob/master/docs/load_model.md#load-model-from-a-url for detail

Upvotes: 0
