Reputation: 571
I'm trying to use a TensorFlow model that I trained in Python to score data in Scala (using the TF Java API). For the model, I've used this regression example, with the only change being that I dropped as_text=True from export_savedmodel.
My snippet of Scala:
val b = SavedModelBundle.load("/tensorflow/tf-estimator-tutorials/trained_models/reg-model-01/export/1531933435/", "serve")
val s = b.session()
// output = predictor_fn({'csv_rows': ["0.5,1,ax01,bx02", "-0.5,-1,ax02,bx02"]})
val input = "0.5,1,ax01,bx02"
val inputTensor = Tensor.create(input.getBytes("UTF-8"))
val result = s.runner()
  .feed("csv_rows", inputTensor)
  .fetch("dnn/logits/BiasAdd")
  .run()
  .get(0)
When I run, I get the following error:
Exception in thread "main" java.lang.IllegalArgumentException: Input to reshape is a tensor with 2 values, but the requested shape has 4
[[Node: dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _output_shapes=[[?,2]], _device="/job:localhost/replica:0/task:0/device:CPU:0"](dnn/input_from_feature_columns/input_layer/alpha_indicator/Sum, dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape/shape)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:298)
at org.tensorflow.Session$Runner.run(Session.java:248)
I figure that there's a problem with how I've prepared my input Tensor, but I'm stuck on how to best debug this.
Upvotes: 0
Views: 762
Reputation: 6751
The error message suggests that the shape of the input tensor in some operation isn't what is expected.
Looking at the Python notebook you linked to (particularly sections 8a and 8c), it seems that the input tensor is expected to be a "batch" of string tensors, not a single string tensor.
You can observe this by comparing the shapes of the tensors in your Scala and Python programs (inputTensor.shape() in Scala vs. the shape of csv_rows provided to predict_fn in the Python notebook).
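As a rough way to see the difference on the Scala side, the sketch below builds both variants and prints their rank and shape. It assumes the TensorFlow 1.x Java API (org.tensorflow.Tensor with the single-argument Tensor.create), where a byte[] is read as one string element and a byte[][] as a 1-D tensor of strings. The object name InputShapes is made up for illustration.

```scala
import org.tensorflow.Tensor

// Hypothetical helper object, not part of the original question.
object InputShapes {
  def main(args: Array[String]): Unit = {
    val row = "0.5,1,ax01,bx02"

    // byte[] is interpreted as a single string element (a scalar tensor).
    val scalar = Tensor.create(row.getBytes("UTF-8"))
    // byte[][] is interpreted as a 1-D tensor holding one string per row.
    val batch = Tensor.create(Array(row.getBytes("UTF-8")))

    try {
      val scalarShape = scalar.shape().mkString("[", ", ", "]")
      val batchShape = batch.shape().mkString("[", ", ", "]")
      println("scalar: rank=" + scalar.numDimensions() + " shape=" + scalarShape)
      println("batch:  rank=" + batch.numDimensions() + " shape=" + batchShape)
    } finally {
      // Tensors hold native memory and should be released explicitly.
      scalar.close()
      batch.close()
    }
  }
}
```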
From that, it seems what you want is for inputTensor to be a vector of strings, not a single scalar string. To do that, you'd want to do something like:
val input = Array("0.5,1,ax01,bx02")
val inputTensor = Tensor.create(input.map(x => x.getBytes("UTF-8")))
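Putting this together with the snippet from the question, a minimal end-to-end sketch might look like the following. The model path, the "csv_rows" feed name, and the "dnn/logits/BiasAdd" fetch name are taken from the question. The object name ScoreBatch, the helper encodeRows, and the assumption that the fetched tensor is a rank-2 float tensor are illustrative and not confirmed by the notebook.

```scala
import org.tensorflow.{SavedModelBundle, Tensor}

// Hypothetical wrapper object for illustration only.
object ScoreBatch {
  // Encode each CSV row as UTF-8 bytes; the resulting byte[][] is fed as a 1-D string tensor.
  def encodeRows(rows: Seq[String]): Array[Array[Byte]] =
    rows.map(_.getBytes("UTF-8")).toArray

  def main(args: Array[String]): Unit = {
    val bundle = SavedModelBundle.load(
      "/tensorflow/tf-estimator-tutorials/trained_models/reg-model-01/export/1531933435/",
      "serve")
    val inputTensor = Tensor.create(encodeRows(Seq("0.5,1,ax01,bx02", "-0.5,-1,ax02,bx02")))
    try {
      val result = bundle.session().runner()
        .feed("csv_rows", inputTensor)
        .fetch("dnn/logits/BiasAdd")
        .run()
        .get(0)
      try {
        val shape = result.shape()
        // Assumption: the fetched op yields a rank-2 float tensor (one row per input).
        val out = Array.ofDim[Float](shape(0).toInt, shape(1).toInt)
        result.copyTo(out)
        out.foreach(r => println(r.mkString(", ")))
      } finally {
        result.close()
      }
    } finally {
      // Release native resources held by the tensor and the loaded model.
      inputTensor.close()
      bundle.close()
    }
  }
}
```

Feeding several rows at once, as in the commented-out predictor_fn call in the question, only requires passing more strings to encodeRows.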
Hope that helps.
Upvotes: 1