vinayak A

Reputation: 77

Running a TensorFlow model written in Python for training and prediction from Java

I have retrained the Inception model on my own data set. The model is built in Python, and I now have the saved graph as a .pb file and the labels as a .txt file. Now I need to use this model from Java to predict the label of an image. Can anyone please help me?

Upvotes: 2

Views: 2066

Answers (2)

K.Nicholas

Reputation: 11551

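The code that worked for me loads a model exported as a SavedModel: a directory (here /tmp/model) that contains a protobuf file ending in .pb (saved_model.pb) along with the saved variables.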

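import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

// "serve" is the tag written by tag_constants.SERVING on the Python side
try (SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve")) {
    Session sess = b.session();
    // ... (create the `input` and `keep_prob` Tensors here)
    float[][] matrix = sess.runner()
        .feed("x", input)
        .feed("keep_prob", keep_prob)
        .fetch("y_conv")
        .run()
        .get(0)
        .copyTo(new float[1][10]); // copy the [1, 10] output into a Java array
    ...
}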

The Python code I used to save it was:

import tensorflow as tf

# sess, x and y_conv are the session and tensors from the training script
signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs={'x': tf.saved_model.utils.build_tensor_info(x)},
    outputs={'y_conv': tf.saved_model.utils.build_tensor_info(y_conv)},
)
builder = tf.saved_model.builder.SavedModelBuilder("/tmp/model")
builder.add_meta_graph_and_variables(
    sess,
    [tf.saved_model.tag_constants.SERVING],  # the "serve" tag loaded from Java
    signature_def_map={
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
    },
)
builder.save()
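Back to the original question: the Java snippet copies the model output into a float[1][10] array, and the retrained model also comes with a label .txt file. Below is a minimal plain-Java sketch (no TensorFlow dependency) of turning those scores into a label. It assumes the label file has one label per line, in the same order as the output scores. The class and method names TopLabel, readLabels and argmax are hypothetical helpers, and the labels and scores in main are made-up stand-ins.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;

public class TopLabel {
    // Reads one label per line (assumed format of the label .txt file)
    static List<String> readLabels(Path labelFile) throws IOException {
        return Files.readAllLines(labelFile, StandardCharsets.UTF_8);
    }

    // Returns the index of the largest score in the first (and only) batch row
    static int argmax(float[][] scores) {
        float[] row = scores[0];
        int best = 0;
        for (int i = 1; i < row.length; i++) {
            if (row[i] > row[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) throws IOException {
        // Hypothetical stand-ins: in practice `scores` is the array returned by
        // copyTo(...) and the labels come from your own label .txt file.
        Path labelFile = Files.createTempFile("labels", ".txt");
        Files.write(labelFile, Arrays.asList("cat", "dog", "bird"), StandardCharsets.UTF_8);
        float[][] scores = {{0.1f, 0.7f, 0.2f}};

        List<String> labels = readLabels(labelFile);
        int best = argmax(scores);
        System.out.println(labels.get(best) + " " + scores[0][best]); // prints "dog 0.7"
    }
}
```

With a real model, pass the float[][] produced by copyTo and the path of your label file instead of the temporary stand-ins.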

Upvotes: 0

javidcf

Reputation: 59701

The TensorFlow team is developing a Java interface, but it is not stable yet. You can find the existing code here: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java and follow updates on its development here: https://github.com/tensorflow/tensorflow/issues/5. GraphTest.java, SessionTest.java and TensorTest.java show how it is currently used (although, as explained, this may change in the future).

Basically, you load the binary graph (a serialized GraphDef .pb file) into a Graph object, create a Session with it, and run it with the appropriate input values (as Tensors) to receive a List<Tensor> with the outputs. Put together from the examples in the source:

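import java.nio.file.Files;
import java.nio.file.Paths;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class Main {
    public static void main(String[] args) throws Exception {
        try (Graph graph = new Graph()) {
            // hypothetical path: importGraphDef needs a serialized GraphDef,
            // such as the retrained graph .pb from the question
            graph.importGraphDef(Files.readAllBytes(Paths.get("output_graph.pb")));
            try (Session sess = new Session(graph)) {
                // feed a scalar into the op named "x" and fetch the op named "y"
                try (Tensor x = Tensor.create(1.0f);
                     Tensor y = sess.runner().feed("x", x).fetch("y").run().get(0)) {
                    System.out.println(y.floatValue());
                }
            }
        }
    }
}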
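For the original question (predicting an image), the image has to become a Tensor before it can be fed. Tensor.create accepts multidimensional Java arrays, so one option is to convert a BufferedImage into a float[1][height][width][3] array. This is a sketch under assumptions: the graph takes a float input in NHWC layout, and the normalization constants (mean, std), the expected image size and the input op name all depend on how your model was built. ImageToArray and toNhwc are hypothetical names.

```java
import java.awt.image.BufferedImage;

public class ImageToArray {
    // Converts an RGB image into a [1][height][width][3] float array (NHWC layout),
    // scaling each channel as (value - mean) / std.
    // mean and std are model-specific assumptions, not values taken from TensorFlow.
    static float[][][][] toNhwc(BufferedImage img, float mean, float std) {
        int h = img.getHeight();
        int w = img.getWidth();
        float[][][][] out = new float[1][h][w][3];
        for (int y = 0; y < h; y++) {
            for (int x = 0; x < w; x++) {
                int rgb = img.getRGB(x, y); // packed as 0xAARRGGBB
                out[0][y][x][0] = (((rgb >> 16) & 0xFF) - mean) / std; // red
                out[0][y][x][1] = (((rgb >> 8) & 0xFF) - mean) / std;  // green
                out[0][y][x][2] = ((rgb & 0xFF) - mean) / std;         // blue
            }
        }
        return out;
    }

    public static void main(String[] args) {
        BufferedImage img = new BufferedImage(2, 2, BufferedImage.TYPE_INT_RGB);
        img.setRGB(0, 0, 0xFF0000); // one pure-red pixel
        float[][][][] input = toNhwc(img, 128f, 128f);
        System.out.println(input[0][0][0][0] + " " + input[0][0][0][1]); // prints "0.9921875 -1.0"
    }
}
```

In practice you would load the file with javax.imageio.ImageIO.read, resize it to the size your graph expects, wrap the array with Tensor.create(...) and feed it to the input op the same way x is fed above.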

Upvotes: 3
