Reputation: 700
I am trying to load a model that was trained in Python into a Java project.
The problem is the following exception:
Exception in thread "Thread-9" java.lang.IllegalStateException: Tensor is not a scalar
at org.tensorflow.Tensor.scalarFloat(Native Method)
at org.tensorflow.Tensor.floatValue(Tensor.java:279)
The Java code:
float[] arr = context.csvintarr(context.getPlayer(playerId));
float[][] matrix = {arr};
try (Graph g = model.graph()) {
    try (Session s = model.session()) {
        Tensor y = s.runner().feed("input/input", Tensor.create(matrix))
                .fetch("out/predict").run().get(0);
        logger.info("a {}", y.floatValue());
    }
}
The Python code that trains and saves the model:
with tf.Session() as sess:
    with tf.name_scope('input'):
        x = tf.placeholder(tf.float32, [None, bucketlen], name="input")
    ......
    with tf.name_scope('out'):
        y = tf.tanh(tf.matmul(h, hW) + hb, name="predict")
    builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    builder.add_meta_graph_and_variables(sess, ['foo-tag'])
    ......after the train process
    builder.save()
It seems that I have successfully loaded the model and the graph, because
try (Graph g = model.graph()) {
    try (Session s = model.session()) {
        Operation operation = g.operation("input/input");
        logger.info(operation.name());
    }
}
prints the operation name successfully.
Upvotes: 0
Views: 977
Reputation: 6751
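The error message indicates that the output tensor isn't a float-valued scalar, so it's probably a higher-dimensional tensor (a vector, a matrix, etc.). Since out/predict is the result of a tf.matmul on a batched input of shape [None, bucketlen], it most likely keeps the batch dimension, so feeding a single row would give a shape like [1, n] rather than a scalar.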
You can learn the shape of the tensor using System.out.println(y.toString()), or specifically using y.shape(). In your Python code, that would correspond to y.shape.
For non-scalars, use y.copyTo to get an array of floats (for a vector), an array of arrays of floats (for a matrix), etc.
For example, something like:
System.out.println(y);
// If the above printed something like:
// "FLOAT tensor with shape [1]"
// then you can get the values using:
float[] vector = y.copyTo(new float[1]);
// If the shape was something like [2, 3]
// then you can get the values using:
float[][] matrix = y.copyTo(new float[2][3]);
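If the shape is not known ahead of time, the destination array for copyTo can be allocated from the shape itself. The following is only a sketch using the Java standard library: the class name TensorArrays and the method allocateFor are hypothetical helpers, not part of the TensorFlow API, and the long[] argument stands in for whatever y.shape() returns.

```java
import java.lang.reflect.Array;

// Hypothetical helper, not part of TensorFlow: builds a nested float array
// whose dimensions match a tensor shape such as the one returned by y.shape().
final class TensorArrays {
    private TensorArrays() {}

    static Object allocateFor(long[] shape) {
        if (shape.length == 0) {
            // A rank-0 tensor is a scalar; that is the case floatValue() handles.
            throw new IllegalArgumentException(
                    "rank-0 tensor: read it with floatValue() instead");
        }
        int[] dims = new int[shape.length];
        for (int i = 0; i < shape.length; i++) {
            // Fails loudly if a dimension does not fit in an int.
            dims[i] = Math.toIntExact(shape[i]);
        }
        // For dims {1, 4} this creates a float[1][4], and so on for other ranks.
        return Array.newInstance(float.class, dims);
    }

    public static void main(String[] args) {
        // Stand-in for y.shape() when a single input row yields 4 outputs.
        float[][] dst = (float[][]) allocateFor(new long[] {1, 4});
        System.out.println(dst.length + " x " + dst[0].length);
    }
}
```

Assuming the rank is known to be 2, the result could then be passed on as something like (float[][]) y.copyTo(TensorArrays.allocateFor(y.shape())); the cast has to match the actual rank of the tensor.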
See the Tensor javadoc for more information on floatValue() vs copyTo vs writeTo.
Hope that helps.
Upvotes: 1