dunstantom

Reputation: 571

Passing data to TensorFlow model in Java

I'm trying to use a TensorFlow model that I trained in Python to score data in Scala (using the TF Java API). For the model, I've used this regression example, with the only change being that I dropped as_text=True from export_savedmodel.

My snippet of Scala:

  val b = SavedModelBundle.load("/tensorflow/tf-estimator-tutorials/trained_models/reg-model-01/export/1531933435/", "serve")
  val s = b.session()

  // output = predictor_fn({'csv_rows': ["0.5,1,ax01,bx02", "-0.5,-1,ax02,bx02"]})
  val input = "0.5,1,ax01,bx02"

  val inputTensor = Tensor.create(input.getBytes("UTF-8"))

  val result = s.runner()
    .feed("csv_rows", inputTensor)
    .fetch("dnn/logits/BiasAdd")
    .run()
    .get(0)

When I run, I get the following error:

Exception in thread "main" java.lang.IllegalArgumentException: Input to reshape is a tensor with 2 values, but the requested shape has 4
 [[Node: dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _output_shapes=[[?,2]], _device="/job:localhost/replica:0/task:0/device:CPU:0"](dnn/input_from_feature_columns/input_layer/alpha_indicator/Sum, dnn/input_from_feature_columns/input_layer/alpha_indicator/Reshape/shape)]]
at org.tensorflow.Session.run(Native Method)
at org.tensorflow.Session.access$100(Session.java:48)
at org.tensorflow.Session$Runner.runHelper(Session.java:298)
at org.tensorflow.Session$Runner.run(Session.java:248)

I figure that there's a problem with how I've prepared my input Tensor, but I'm stuck on how to best debug this.

Upvotes: 0

Views: 762

Answers (1)

ash

Reputation: 6751

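The error message suggests that the tensor reaching the Reshape op named in the trace (under alpha_indicator) doesn't have the shape the graph expects: it holds 2 values, while the requested shape needs 4.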

Looking at the Python notebook you linked to (particularly sections 8a and 8c), it seems that the input tensor is expected to be a "batch" of string tensors, not a single string tensor.

You can observe this by comparing the shapes of the tensors in your Scala and Python programs (inputTensor.shape() in Scala vs. the shape of csv_rows passed to predictor_fn in the Python notebook).

From that, it seems what you want is for inputTensor to be a vector of strings, not a single scalar string. To do that, you'd want to do something like:

val input = Array("0.5,1,ax01,bx02")
val inputTensor = Tensor.create(input.map(x => x.getBytes("UTF-8")))
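
If you want to score several rows at once (like the two-row call in the Python comment in your snippet), the same idea extends to a batch. Below is a minimal sketch that uses only the standard library; CsvBatch and encodeRows are hypothetical names I made up for illustration, not part of the TensorFlow API. It only builds the array of byte arrays, one per CSV row.

```scala
import java.nio.charset.StandardCharsets

object CsvBatch {
  // Hypothetical helper (not part of TensorFlow): encode each CSV row as
  // UTF-8 bytes, producing one byte array per row.
  def encodeRows(rows: Seq[String]): Array[Array[Byte]] =
    rows.map(_.getBytes(StandardCharsets.UTF_8)).toArray

  def main(args: Array[String]): Unit = {
    // The two rows from the commented-out Python call in the question.
    val batch = encodeRows(Seq("0.5,1,ax01,bx02", "-0.5,-1,ax02,bx02"))

    // The outer array length is the batch size.
    println(s"rows in batch: ${batch.length}")
  }
}
```

Assuming the TF Java 1.x API from your snippet, you would then pass the result to Tensor.create and feed it to "csv_rows". Checking shape() on that tensor should show a single dimension equal to the number of rows, whereas a tensor built from one byte array is a scalar.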

Hope that helps

Upvotes: 1
