curtank

Reputation: 700

TensorFlow Java API error: java.lang.IllegalStateException: Tensor is not a scalar

I am trying to load a model that was trained and saved in Python into a Java project.

The problem is this exception:

 Exception in thread "Thread-9" java.lang.IllegalStateException: Tensor is not a scalar
    at org.tensorflow.Tensor.scalarFloat(Native Method)
    at org.tensorflow.Tensor.floatValue(Tensor.java:279)

Code

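    float[] arr = context.csvintarr(context.getPlayer(playerId));
    float[][] matrix = {arr}; // one row, i.e. shape [1, arr.length]
    try (Graph g = model.graph()) {
        // Closing the input tensor as well, so its native memory is released
        try (Session s = model.session();
             Tensor input = Tensor.create(matrix)) {
            Tensor y = s.runner().feed("input/input", input)
                    .fetch("out/predict").run().get(0);
            logger.info("a {}", y.floatValue()); // IllegalStateException thrown here
        }
    }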

The Python code that trains and saves the model:

    with tf.Session() as sess:
        with tf.name_scope('input'):
            x = tf.placeholder(tf.float32, [None, bucketlen], name="input")
        # ......
        with tf.name_scope('out'):
            y = tf.tanh(tf.matmul(h, hW) + hb, name="predict")
        builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
        builder.add_meta_graph_and_variables(sess, ['foo-tag'])
        # ...... after the train process
        builder.save()

It seems that I have successfully loaded the model and the graph, because this code

    try (Graph g = model.graph()) {
        try (Session s = model.session()) {
            Operation operation = g.operation("input/input");
            logger.info(operation.name());
        }
    }

prints out the name successfully.

Upvotes: 0

Views: 977

Answers (1)

ash

Reputation: 6751

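The error message indicates that the output tensor isn't a float-valued scalar, so it's probably a higher-dimensional tensor (a vector or a matrix). That fits your graph: x has shape [None, bucketlen], and tf.matmul(h, hW) + hb keeps a leading batch dimension, so even when you feed a single row, out/predict most likely comes back as a matrix of shape [1, k] rather than a scalar.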

You can inspect the shape of the tensor with System.out.println(y.toString()), or get it directly as a long[] from y.shape(). In your Python code, the equivalent is y.shape.
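The rule behind the exception is about rank: floatValue() only accepts a tensor whose shape has zero dimensions. The plain-Java sketch below (no TensorFlow dependency) walks the question's shapes through that rule. matmulShape and isScalar are hypothetical helpers, not TensorFlow API, and the widths of h and hW are assumptions, since the question elides them. Even with a single output column, a one-row batch yields a 1 x 1 matrix, which is still not a scalar.

```java
import java.util.Arrays;

public class ShapeCheck {
    // Output shape of a 2-D matrix multiplication: [m, n] x [n, k] -> [m, k].
    static long[] matmulShape(long[] a, long[] b) {
        if (a.length != 2 || b.length != 2 || a[1] != b[0]) {
            throw new IllegalArgumentException(
                "incompatible shapes " + Arrays.toString(a) + " x " + Arrays.toString(b));
        }
        return new long[] {a[0], b[1]};
    }

    // A tensor is a scalar only when its shape has zero dimensions.
    static boolean isScalar(long[] shape) {
        return shape.length == 0;
    }

    public static void main(String[] args) {
        long hidden = 4;               // hypothetical width of h
        long outputs = 1;              // hypothetical number of columns of hW
        long[] h = {1, hidden};        // one input row, as fed with {arr}
        long[] hW = {hidden, outputs};
        long[] y = matmulShape(h, hW); // adding hb and applying tanh keep this shape
        System.out.println(Arrays.toString(y) + " scalar? " + isScalar(y));
        // prints: [1, 1] scalar? false
    }
}
```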

For non-scalars, use y.copyTo() to copy the values into a float[] (for a vector), a float[][] (for a matrix), and so on.

For example, something like:

    System.out.println(y);
    // If the above printed something like:
    // "FLOAT tensor with shape [1]"
    // then you can get the values using:
    float[] vector = y.copyTo(new float[1]);

    // If the shape was something like [2, 3]
    // then you can get the values using:
    float[][] matrix = y.copyTo(new float[2][3]);
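Rather than hard-coding [1] or [2, 3], the destination array can be sized from the shape at runtime. A minimal sketch in plain Java, assuming the shape arrives as the long[] returned by y.shape(); newFloatArray is a hypothetical helper, not part of the TensorFlow API. It relies only on java.lang.reflect.Array.newInstance, which builds a multi-dimensional primitive array from a list of dimensions.

```java
import java.lang.reflect.Array;
import java.util.Arrays;

public class DestinationArrays {
    // Allocates a float array (float[], float[][], ...) whose dimensions match
    // a tensor shape such as the long[] returned by Tensor.shape().
    static Object newFloatArray(long[] shape) {
        if (shape.length == 0) {
            throw new IllegalArgumentException("rank-0 tensor: use floatValue() instead");
        }
        int[] dims = new int[shape.length];
        for (int i = 0; i < shape.length; i++) {
            dims[i] = Math.toIntExact(shape[i]); // Java arrays are int-indexed
        }
        return Array.newInstance(float.class, dims);
    }

    public static void main(String[] args) {
        float[][] m = (float[][]) newFloatArray(new long[] {2, 3});
        System.out.println(m.length + " x " + m[0].length); // prints: 2 x 3
        float[] v = (float[]) newFloatArray(new long[] {1});
        System.out.println(Arrays.toString(v)); // prints: [0.0]
    }
}
```

With the TensorFlow 1.x Java API this would be used as float[][] out = (float[][]) y.copyTo(newFloatArray(y.shape())); for a rank-2 FLOAT tensor. The explicit cast also compiles when y is declared with the raw Tensor type, as in the question. If the output really is [1, 1], the single prediction is then out[0][0].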

See the Tensor javadoc for more information on floatValue() vs. copyTo() vs. writeTo().
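writeTo(FloatBuffer) copies the values into a single flat buffer instead of a nested array. TensorFlow lays tensor data out in row-major order, so element [i][j] of a [rows, cols] tensor ends up at flat index i * cols + j. The sketch below is plain Java: the loop that fills the buffer only stands in for y.writeTo(buf) so the example runs without TensorFlow, and flatIndex is a hypothetical helper.

```java
import java.nio.FloatBuffer;

public class RowMajor {
    // Flat offset of element [i][j] in a row-major [rows, cols] buffer.
    static int flatIndex(int i, int j, int cols) {
        return i * cols + j;
    }

    public static void main(String[] args) {
        int rows = 2, cols = 3;
        FloatBuffer buf = FloatBuffer.allocate(rows * cols);
        // Stand-in for y.writeTo(buf): write 0.0, 1.0, ..., 5.0 row by row.
        for (int k = 0; k < rows * cols; k++) {
            buf.put(k);
        }
        buf.flip(); // switch the buffer from writing to reading
        // Absolute get() ignores the position, so this reads element [1][2].
        System.out.println(buf.get(flatIndex(1, 2, cols))); // prints: 5.0
    }
}
```

In real code the buffer would typically be sized with FloatBuffer.allocate(y.numElements()) before calling y.writeTo(buf).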

Hope that helps.

Upvotes: 1
