I have a TensorFlow model trained in Python by following this article. After training, I generated the frozen graph. Now I need to use this graph to run recognition in a Java-based application. For this I was looking at the following example. However, I failed to understand how to collect my output. I know that I need to provide 3 inputs to the graph.
From the example in the official tutorial, I have read the following Python code:
```python
import tensorflow as tf

def run_graph(wav_data, labels, input_layer_name, output_layer_name,
              num_top_predictions):
  """Runs the audio data through the graph and prints predictions."""
  with tf.Session() as sess:
    # Feed the audio data as input to the graph.
    #   predictions will contain a two-dimensional array, where one
    #   dimension represents the input image count, and the other has
    #   predictions per class
    softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name)
    predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data})

    # Sort to show labels in order of confidence
    top_k = predictions.argsort()[-num_top_predictions:][::-1]
    for node_id in top_k:
      human_string = labels[node_id]
      score = predictions[node_id]
      print('%s (score = %.5f)' % (human_string, score))

    return 0
```
Can someone help me understand the TensorFlow Java API?
The literal translation of the Python code you listed above would be something like this:
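```java
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.Tensors;

public class LabelWav { // hypothetical wrapper class name
  public static float[][] getPredictions(Session sess, byte[] wavData, String inputLayerName, String outputLayerName) {
    // Wrap the raw WAV file bytes in a scalar string tensor, like wav_data in the Python code.
    try (Tensor<String> wavDataTensor = Tensors.create(wavData);
         Tensor<Float> predictionsTensor = sess.runner()
             .feed(inputLayerName, wavDataTensor)
             .fetch(outputLayerName)
             .run()
             .get(0)
             .expect(Float.class)) {
      // Tensor.shape() returns a long[]: here [batch size, number of classes].
      long[] shape = predictionsTensor.shape();
      float[][] predictions = new float[(int) shape[0]][(int) shape[1]];
      predictionsTensor.copyTo(predictions);
      return predictions;
    }
  }
}
```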
The returned `predictions` array will hold the "confidence" value of each class, and you'll have to implement the "top K" logic on it yourself, similar to how the Python code uses numpy's `.argsort()` on what `sess.run()` returned.
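From a cursory reading of the tutorial page and code, it seems that `predictions` will have 1 row and 12 columns (one per label; with the tutorial's default settings that is 10 keywords plus `_silence_` and `_unknown_`). I checked this with the following Python code, which prints the shape of the output tensor: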
```python
import tensorflow as tf  # TensorFlow 1.x API

graph_def = tf.GraphDef()
with open('/tmp/my_frozen_graph.pb', 'rb') as f:
  graph_def.ParseFromString(f.read())

output_layer_name = 'labels_softmax:0'
tf.import_graph_def(graph_def, name='')
print(tf.get_default_graph().get_tensor_by_name(output_layer_name).shape)
```
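To map each of those columns to a word you also need the list the Python function receives as `labels`. A minimal sketch, assuming the labels file written during training has one label per line in the same order as the output columns (in the tutorial setup this is a file such as `conv_labels.txt`; the default path and the class name `LabelLoader` below are hypothetical). The number of labels it reads should equal the second dimension of the printed shape.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.List;

public class LabelLoader {
  // Reads one label per line and trims surrounding whitespace.
  // Assumption: line i of the file names column i of the softmax output.
  public static List<String> loadLabels(Path path) throws IOException {
    List<String> labels = new ArrayList<>();
    for (String line : Files.readAllLines(path, StandardCharsets.UTF_8)) {
      labels.add(line.trim());
    }
    return labels;
  }

  public static void main(String[] args) throws IOException {
    // Hypothetical default path; pass the real labels file as the first argument.
    Path path = Paths.get(args.length > 0 ? args[0] : "/tmp/speech_commands_train/conv_labels.txt");
    List<String> labels = loadLabels(path);
    System.out.println(labels.size() + " labels: " + labels);
  }
}
```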
Hope that helps.