Mohamed Zouga

Reputation: 17

Using Python Tensor of TensorFlow in Java

I have a TensorFlow program running in Python, and for convenience I want to run the same program in Java, so I have to save my model and load it in my Java application.

My problem is that I don't know how to save a Tensor object. Here is my code:

import tensorflow as tf

class Main:
    def __init__(self, checkpoint):
        ...
        self.g = tf.Graph()
        self.sess = tf.Session()

        self.img_placeholder = tf.placeholder(
            tf.float32, shape=(1, 679, 1024, 3), name='img_placeholder')

        # self.preds is an instance of Tensor
        self.preds = transform(self.img_placeholder)

        # Restore the trained variables from the checkpoint
        self.saver = tf.train.Saver()
        self.saver.restore(self.sess, checkpoint)

    def ffwd(...):
        ...
        # Feed the input image and fetch the network's output
        _preds = self.sess.run(self.preds,
                               feed_dict={self.img_placeholder: self.X})
        ...

Since I can't recreate my Tensor myself (the transform function builds the neural network behind the scenes), I have to save it and reload it in Java. I have found ways of saving the session, but not Tensor instances.

Could someone give me some insight into how to achieve this?

Upvotes: 0

Views: 373

Answers (1)

ash

Reputation: 6751

Python Tensor objects are symbolic references to a specific output of an operation in the graph.

An operation in a graph can be uniquely identified by its string name. A specific output of that operation is identified by an integer index into the list of outputs of that operation. That index is typically zero, since the vast majority of operations produce a single output.
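In Python, `preds.name` returns both pieces joined as `op_name:output_index` (for example `foo:1`). If you copy that single string over to Java, a small helper like the one below can split it into the operation name and output index that `Session.Runner.fetch(String, int)` takes. The `TensorName` class is hypothetical, not part of the TensorFlow API; it is a sketch that splits on the last colon, so scoped names such as `transform/conv1/Relu:0` keep their slashes, and a name without a suffix is treated as output 0.

```java
// Hypothetical helper (not part of TensorFlow): splits a tensor name such as
// "foo:1", the form returned by Tensor.name in Python, into the operation
// name and the output index expected by Session.Runner.fetch(String, int).
final class TensorName {
    final String operation;
    final int index;

    private TensorName(String operation, int index) {
        this.operation = operation;
        this.index = index;
    }

    static TensorName parse(String name) {
        int colon = name.lastIndexOf(':');
        if (colon < 0) {
            // No ":index" suffix: refer to the operation's first output.
            return new TensorName(name, 0);
        }
        String op = name.substring(0, colon);
        int index;
        try {
            index = Integer.parseInt(name.substring(colon + 1));
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Bad output index in tensor name: " + name, e);
        }
        if (op.isEmpty() || index < 0) {
            throw new IllegalArgumentException("Bad tensor name: " + name);
        }
        return new TensorName(op, index);
    }
}
```

For example, `TensorName t = TensorName.parse("foo:1");` gives `t.operation` equal to `"foo"` and `t.index` equal to `1`, which you would then pass as `fetch(t.operation, t.index)`.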

To obtain, in Python, the name of the operation and the output index that a Tensor object refers to, you could do something like:

print(preds.op.name)
print(preds.value_index)  # Most likely will be 0

Then, in Java, you can feed and fetch nodes by name. Let's say preds.op.name returned the string foo and preds.value_index returned the integer 1; then in Java you'd do the following:

// input: the org.tensorflow.Tensor holding the image to feed
session.runner().feed("img_placeholder", input).fetch("foo", 1).run()

(See javadoc for org.tensorflow.Session.Runner for details).
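Whatever tensor you feed to `img_placeholder` has to match the placeholder's shape, `(1, 679, 1024, 3)` in the question. Below is a minimal sketch of preparing that data, assuming the image has already been decoded into a `float[height][width][3]` array; the `ImageBatch` helper is hypothetical. It flattens the pixels in row-major NHWC order (channel varying fastest) and records the matching shape. With the TensorFlow 1.x Java API you could then build the input with something like `Tensor.create(batch.shape, FloatBuffer.wrap(batch.data))`, but check the javadoc for the version you use.

```java
// Hypothetical helper: packs an H x W x 3 float image into the flat,
// row-major NHWC layout that a placeholder of shape (1, H, W, 3) expects,
// together with the matching shape array.
final class ImageBatch {
    final long[] shape;
    final float[] data;

    private ImageBatch(long[] shape, float[] data) {
        this.shape = shape;
        this.data = data;
    }

    static ImageBatch fromPixels(float[][][] pixels) {
        final int channels = 3;
        int height = pixels.length;
        int width = height == 0 ? 0 : pixels[0].length;
        float[] data = new float[height * width * channels];
        int i = 0;
        for (int y = 0; y < height; y++) {
            if (pixels[y].length != width) {
                throw new IllegalArgumentException("Ragged row at y=" + y);
            }
            for (int x = 0; x < width; x++) {
                if (pixels[y][x].length != channels) {
                    throw new IllegalArgumentException(
                        "Expected 3 channels at (" + y + ", " + x + ")");
                }
                for (int c = 0; c < channels; c++) {
                    // Element (y, x, c) lands at (y * width + x) * channels + c.
                    data[i++] = pixels[y][x][c];
                }
            }
        }
        // Leading batch dimension of 1, as in shape=(1, 679, 1024, 3).
        return new ImageBatch(new long[] {1, height, width, channels}, data);
    }
}
```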

You may find the slides linked from https://github.com/tensorflow/models/tree/master/samples/languages/java useful, along with their speaker notes.

Hope that helps.

Upvotes: 2
