Justin Fletcher

Reputation: 2409

How can I restore Tensors to a past value, without saving the value to disk?

I'm doing some experimentation with TensorFlow and have run into a snag. I'm trying to use TF to evaluate a change in a model, then either retain or revert the model based on the resultant change in the loss function. I've got the hard part (conditional control) figured out, but I'm stuck on something that should be fairly straightforward: I can't seem to store the values of tf.trainable_variables() for an iteration, then restore them if needed.

Let's say I build an Op:

...
store_trainable_vars = []

for v in tf.trainable_variables():
    store_trainable_vars.append(v)
...

Then, later, I want to restore tf.trainable_variables to the values they held when this Op was last run. I'd want to do something like:

def reject_move():
    revert_state = []
    for (v, s) in zip(tf.trainable_variables(), store_trainable_vars):
        revert_state.append(tf.assign(v, s, name="revert_state"))
    return revert_state

Obviously, this re-evaluates store_trainable_vars, which still points at the present value of tf.trainable_variables(), defeating the purpose of the revert_state Op. I need some way to store and retrieve the value of Tensors without calling back to the present value of those Tensors. Something like

...
store_trainable_vars = []

for v in tf.trainable_variables():
    store_trainable_vars.append(v.value_right_now())
...

where v.value_right_now() returns a constant that won't change until overwritten.
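For concreteness, here is a tiny repro of the reference-versus-value problem (the variable and Op names are made up for illustration):

import tensorflow as tf

v = tf.Variable(1.0, name="v")
stored = v                     # stores a reference to v, not its value
bump = tf.assign(v, 5.0)       # some change we might later want to reject
revert = tf.assign(v, stored)  # assigns v to its own current value: a no-op

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(bump)
    sess.run(revert)
    print(sess.run(v))         # prints 5.0, not 1.0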

I know I could use Saver, but that solution writes to disk, which is not acceptable for this application, since it will run inside a training loop.

I'm probably missing something obvious - any guidance would be appreciated.

Upvotes: 8

Views: 347

Answers (2)

Justin Fletcher

Reputation: 2409

It wasn't my original intent to answer this question myself, but I've come up with a method that works fairly well, so I thought I'd share it. The key insight came from this very clever answer. The approach is to reuse the assignment nodes created for initial variable assignment. A complete class implementing that approach is given below.

import tensorflow as tf


class TensorFlowState(object):

    def __init__(self):

        # Get the graph.
        graph = tf.get_default_graph()

        # Extract the global variables from the graph.
        self.gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

        # Extract the Assign operations for later use.
        self.assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign")
                           for v in self.gvars]

        # Extract the initial value ops from each Assign op for later use.
        self.init_values = [op.inputs[1] for op in self.assign_ops]

    def start(self, sess):

        self.sess = sess

    def store(self):

        # Record the current state of the TF global variables.
        self.state = self.sess.run(self.gvars)

    def restore(self):

        # Create a dictionary mapping the initializers to the stored state of the globals.
        feed_dict = {init_value: val
                     for init_value, val in zip(self.init_values, self.state)}

        # Use the initializer ops for each variable to load the stored values.
        return self.sess.run(self.assign_ops, feed_dict=feed_dict)

To use it, simply instantiate the class, call the start method to pass in a tf.Session, and call the store and restore methods as needed inside your imperative training loop. I've used this implementation to build an optimizer, which runs about as fast as the gradient descent optimizers included with TensorFlow.
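For illustration, here is a rough usage sketch, assuming the TensorFlowState class above is already defined; the toy model, loss, and accept/reject test are placeholders, not part of the original optimizer:

import tensorflow as tf

# A toy model, just to make the sketch self-contained.
w = tf.Variable(0.0, name='w')
loss = tf.square(w - 3.0)
train_op = tf.train.GradientDescentOptimizer(0.5).minimize(loss)

state = TensorFlowState()  # must be constructed after the graph is built

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    state.start(sess)

    for step in range(10):
        state.store()                  # snapshot the current variable values in memory
        loss_before = sess.run(loss)
        sess.run(train_op)             # candidate update
        if sess.run(loss) > loss_before:
            state.restore()            # reject the move: roll the variables back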

Upvotes: 1

Maxim

Reputation: 53758

To restore the graph state manually, you need the tf.tuple or tf.group operation, which combines the assignments so they are applied as one bulk change:

This creates a tuple of tensors with the same values as the tensors argument, except that the value of each tensor is only returned after the values of all tensors have been computed.

[Update] Here's how I would do it:

import numpy as np
import tensorflow as tf

x = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='x')
W = tf.Variable(np.zeros([5, 5]), dtype=tf.float32, name='W')
b = tf.Variable(np.zeros([5]), dtype=tf.float32, name='b')
y = tf.add(tf.matmul(x, W), b)

with tf.Session() as session:
  batch = np.ones([2, 5])
  session.run(tf.global_variables_initializer())
  print(session.run(y, feed_dict={x: batch}))     # prints [2, 5] zeros

  # store the current value
  store = {v.name: v.eval(session) for v in tf.trainable_variables()}
  print(store)                                    # prints [5, 5] and [5] zeros

  # update
  new = {'W:0': np.ones([5, 5]), 'b:0': np.ones([5])}
  session.run(tf.tuple([tf.assign(var, new[var.name]) for var in tf.trainable_variables()]))
  print(session.run(y, feed_dict={x: batch}))     # prints [2, 5] sixes

  # restore
  session.run(tf.tuple([tf.assign(var, store[var.name]) for var in tf.trainable_variables()]))
  print(session.run(y, feed_dict={x: batch}))     # prints [2, 5] zeros again

But I really think you should reconsider your decision about Saver, because it was designed to be used inside a training loop as well. Internally, Saver does all the tricky work for you (in particular, its restore op calls tf.group and tf.control_dependencies if needed), which may otherwise become the source of pretty nasty bugs. Besides, the disk is (almost) always bigger than your GPU and main memory, so if you can afford to store the model in memory, you should be able to store it on disk as well.

Here are some parameters that help to control the proliferation of checkpoint files on disk (a minimal sketch follows the list):

  • max_to_keep indicates the maximum number of recent checkpoint files to keep. As new files are created, older files are deleted. If None or 0, all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent checkpoint files are kept).
  • keep_checkpoint_every_n_hours: In addition to keeping the most recent max_to_keep checkpoint files, you might want to keep one checkpoint file for every N hours of training. This can be useful if you want to later analyze how a model progressed during a long training session. For example, passing keep_checkpoint_every_n_hours=2 ensures that you keep one checkpoint file for every 2 hours of training. The default value of 10,000 hours effectively disables the feature.
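For example, a minimal sketch; the stand-in variable, save path, and save cadence are assumptions, not prescriptions:

import os
import tensorflow as tf

w = tf.Variable(tf.zeros([5]), name='w')  # a stand-in for your model's variables

saver = tf.train.Saver(max_to_keep=3,                    # keep only the 3 newest checkpoints
                       keep_checkpoint_every_n_hours=2)  # plus one checkpoint every 2 hours

ckpt_dir = '/tmp/my_model'
os.makedirs(ckpt_dir, exist_ok=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1000):
        # ... run a training step here ...
        if step % 100 == 0:
            saver.save(sess, os.path.join(ckpt_dir, 'ckpt'), global_step=step)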

[Update] As clarified in the comments, the main concern is disk latency, which may slow down training if the disk is accessed too often. If you're using Linux, it caches frequently used disk pages, and Windows does as well. But if you want to be absolutely sure, consider using tmpfs.
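If you go that route, pointing the Saver at a RAM-backed mount is a one-line change to the sketch above (the /dev/shm path is an assumption; it is a common tmpfs mount on Linux, so adjust for your system):

# Same Saver as in the sketch above, but checkpoints land in RAM rather than on a physical disk.
ckpt_dir = '/dev/shm/my_model'  # assumed tmpfs mount; adjust for your system
os.makedirs(ckpt_dir, exist_ok=True)
saver.save(sess, os.path.join(ckpt_dir, 'ckpt'), global_step=step)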

Upvotes: 5
