Reputation: 77
I have retrained the Inception model on my own data set. The model is built in Python, and I now have the saved graph as a .pb file and the labels as a .txt file. Now I need to use this model from Java to predict the class of an image. Can anyone please help me?
Upvotes: 2
Views: 2066
Reputation: 11551
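The following code worked for me. It loads a model exported as a SavedModel: SavedModelBundle.load takes the export directory (here /tmp/model), which contains the saved_model.pb protobuf file and a variables/ subdirectory, together with the tag the model was saved under.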
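import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;

try (SavedModelBundle b = SavedModelBundle.load("/tmp/model", "serve")) {
    Session sess = b.session();
    // ... create the input and keep_prob Tensors ...
    float[][] matrix = sess.runner()
            .feed("x", input)               // input placeholder "x"
            .feed("keep_prob", keep_prob)   // dropout keep probability
            .fetch("y_conv")                // output operation
            .run()
            .get(0)
            .copyTo(new float[1][10]);      // one row of 10 output values
    // ...
}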
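The fetched matrix holds one row of scores for the input image. To turn it into a prediction with the .txt label file from the question, one option is to take the index of the highest score and look it up in the labels. A minimal, stdlib-only sketch; it assumes the label file has one label per line in the same order as the model's output classes, and the class and method names are hypothetical:

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Arrays;
import java.util.List;

public class LabelLookup {
    // Index of the largest score (argmax); assumes scores is non-empty.
    static int argMax(float[] scores) {
        int best = 0;
        for (int i = 1; i < scores.length; i++) {
            if (scores[i] > scores[best]) {
                best = i;
            }
        }
        return best;
    }

    // Reads the label file; assumes one label per line, in output-class order.
    static List<String> readLabels(Path labelFile) throws IOException {
        return Files.readAllLines(labelFile, StandardCharsets.UTF_8);
    }

    // Label of the highest-scoring class in the first (only) row of the output matrix.
    static String topLabel(float[][] matrix, List<String> labels) {
        return labels.get(argMax(matrix[0]));
    }

    public static void main(String[] args) {
        // Hypothetical scores and labels standing in for the model output and the label file
        float[][] matrix = {{0.1f, 0.7f, 0.2f}};
        List<String> labels = Arrays.asList("daisy", "rose", "tulip");
        System.out.println(topLabel(matrix, labels)); // prints "rose"
    }
}
```

The array passed to copyTo needs as many columns as the model has output classes (10 in the code above); for a retrained classifier that should match the number of lines in the label file.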
The Python code I used to save it was (the "serve" tag used in Java is the value of tf.saved_model.tag_constants.SERVING):
# x, y_conv and sess are the input tensor, output tensor and session from training
signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs={'x': tf.saved_model.utils.build_tensor_info(x)},
    outputs={'y_conv': tf.saved_model.utils.build_tensor_info(y_conv)},
)
builder = tf.saved_model.builder.SavedModelBuilder("/tmp/model")
builder.add_meta_graph_and_variables(
    sess,
    [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature
    },
)
builder.save()
Upvotes: 0
Reputation: 59701
The TensorFlow team is developing a Java interface, but it is not stable yet. You can find the existing code at https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java and follow its development at https://github.com/tensorflow/tensorflow/issues/5. Take a look at GraphTest.java, SessionTest.java and TensorTest.java to see how it is currently used (although, as explained, this may change in the future).

Basically, you need to load the binary graph definition (a serialized GraphDef) into a Graph object, create a Session with it, and run it with the appropriate input values (as Tensors) to receive a List<Tensor> with the outputs. Put together from the examples in the source:
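import java.nio.file.Files;
import java.nio.file.Paths;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class LoadGraphExample {
    public static void main(String[] args) throws Exception {
        try (Graph graph = new Graph()) {
            // importGraphDef expects a serialized GraphDef (e.g. a frozen .pb graph);
            // "graph.pb" is a placeholder path. A SavedModel export is loaded with SavedModelBundle instead.
            graph.importGraphDef(Files.readAllBytes(Paths.get("graph.pb")));
            try (Session sess = new Session(graph);
                 // Scalar input fed to the operation named "x"
                 Tensor x = Tensor.create(1.0f);
                 // Run the graph and take the first fetched output ("y")
                 Tensor y = sess.runner().feed("x", x).fetch("y").run().get(0)) {
                System.out.println(y.floatValue());
            }
        }
    }
}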
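Since the question is about classifying an image, the input Tensor has to be built from the image's pixels. A minimal sketch, assuming the retrained graph takes a float input of shape [1, height, width, 3] with RGB values normalized as (value - mean) / std. The class name, the 299x299 size and mean = std = 128 are assumptions to check against your own graph (which input operation it exposes and what preprocessing was used in training). Resizing is done with simple nearest-neighbour sampling, which may differ from the resizing used during training:

```java
import java.awt.image.BufferedImage;

public class ImagePreprocessing {
    // Converts an image into a [1][height][width][3] float array of RGB values,
    // each normalized as (value - mean) / std, resizing by nearest-neighbour sampling.
    static float[][][][] toInputArray(BufferedImage image, int height, int width,
                                      float mean, float std) {
        float[][][][] input = new float[1][height][width][3];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                // Nearest source pixel for this target position
                int srcX = x * image.getWidth() / width;
                int srcY = y * image.getHeight() / height;
                int rgb = image.getRGB(srcX, srcY); // packed as 0xAARRGGBB
                int red = (rgb >> 16) & 0xFF;
                int green = (rgb >> 8) & 0xFF;
                int blue = rgb & 0xFF;
                input[0][y][x][0] = (red - mean) / std;
                input[0][y][x][1] = (green - mean) / std;
                input[0][y][x][2] = (blue - mean) / std;
            }
        }
        return input;
    }

    public static void main(String[] args) {
        // Synthetic 1x1 image standing in for ImageIO.read(new File(...))
        BufferedImage img = new BufferedImage(1, 1, BufferedImage.TYPE_INT_RGB);
        img.setRGB(0, 0, 0xFF8000);
        // Assumed input size and normalization; check them against your graph
        float[][][][] input = toInputArray(img, 299, 299, 128f, 128f);
        System.out.println(input[0][0][0][0]); // red: (255 - 128) / 128 = 0.9921875
    }
}
```

The resulting array can be wrapped with Tensor.create(input) and fed through sess.runner().feed(...) as in the example above; the operation name to feed depends on how your graph was built.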
Upvotes: 3