Tianjin Gu

Reputation: 784

How to save a tensor in checkpoint in Tensorflow?

I want to use tf.train.Saver() to checkpoint a tensor. Here is my code snippet:

import tensorflow as tf

with tf.Graph().as_default():
    var = tf.Variable(tf.zeros([10]), name="biases")
    temp = tf.add(var, 0.1)
    init_op = tf.global_variables_initializer()

    saver = tf.train.Saver({'w':temp})

    with tf.Session() as sess:
        sess.run(init_op)
        print(sess.run(temp))

But I got the following error:

Traceback (most recent call last):
  File "./test_counter.py", line 61, in <module>
    saver = tf.train.Saver({'w':temp})
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1043, in __init__
    self.build()
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 1073, in build
    restore_sequentially=self._restore_sequentially)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 649, in build
    saveables = self._ValidateAndSliceInputs(names_to_saveables)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/saver.py", line 578, in _ValidateAndSliceInputs
     variable)
   TypeError: names_to_saveables must be a dict mapping string names to Tensors/Variables. Not a variable: Tensor("TransformFeatureToIndex:0", shape=(100,), dtype=string)

One approach I can think of is to fetch the tensor's value on the client with sess.run(temp) and save that, but is there a better way?

Upvotes: 3

Views: 3068

Answers (1)

phipsgabler

Reputation: 20950

temp is not a tf.Variable, but the output of an operation. It "has" no value or state of its own; it is just a node in the graph, and tf.train.Saver can only checkpoint variables. If you want to save the result of adding to var explicitly, you can assign temp to another variable with tf.assign and save that other variable. The easier way would probably be to save var (or the whole session), and after restoring just evaluate temp again.
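As a rough illustration of the tf.assign approach, here is a minimal sketch. It makes a few assumptions that are not in the question: it imports the TF 1.x API through tensorflow.compat.v1 so it can also run on a TF 2.x install, and the names temp_var, snapshot_op, restored and ckpt_path are made up for the example.

```python
import os
import tempfile

# Assumption: TF 1.x-style API accessed via the compat module.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Hypothetical checkpoint location in a fresh temporary directory.
ckpt_path = os.path.join(tempfile.mkdtemp(), "model.ckpt")

with tf.Graph().as_default():
    var = tf.Variable(tf.zeros([10]), name="biases")
    temp = tf.add(var, 0.1)

    # Extra variable (hypothetical name) that holds a snapshot of temp.
    temp_var = tf.Variable(tf.zeros([10]), name="temp_snapshot")
    snapshot_op = tf.assign(temp_var, temp)

    init_op = tf.global_variables_initializer()
    # The Saver now receives a variable, not a plain tensor.
    saver = tf.train.Saver({"w": temp_var})

    with tf.Session() as sess:
        sess.run(init_op)
        sess.run(snapshot_op)  # copy the current value of temp into temp_var
        saver.save(sess, ckpt_path)

# Restore the saved value under the checkpoint key "w" in a new graph.
with tf.Graph().as_default():
    restored = tf.Variable(tf.zeros([10]), name="restored")
    saver = tf.train.Saver({"w": restored})

    with tf.Session() as sess:
        saver.restore(sess, ckpt_path)  # restore also initializes the variable
        restored_value = sess.run(restored)
        print(restored_value)
```

Note that snapshot_op has to be run every time you want the checkpoint to reflect the current value of temp. If temp is cheap to recompute, saving var alone and calling sess.run(temp) after restoring avoids the extra variable.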

Upvotes: 5
