Reputation: 2409
I'm doing some experimentation with TensorFlow, and have run into a snag. I'm trying to use TF to evaluate a change in a model, then either retain or revert the model, based on the resultant change in loss function. I've got the hard part (conditional control) figured out, but I'm stuck on something that should be fairly straightforward: I can't seem to store tf.trainable_variables for an iteration, then restore it if needed.
Let's say I build an Op:
...
store_trainable_vars = []
for v in tf.trainable_variables():
    store_trainable_vars.append(v)
...
Then later, I want to restore tf.trainable_variables to the value it had when this Op was last run. I'd want to do something like:
def reject_move():
    revert_state = []
    for (v, s) in zip(tf.trainable_variables(), store_trainable_vars):
        revert_state.append(tf.assign(v, s, name="revert_state"))
    return(revert_state)
Obviously, this will re-evaluate store_trainable_vars, which in turn links to the present value of tf.trainable_variables(), obviating the revert_state Op. I need some way to store and retrieve the value of Tensors without calling back to the present value of those Tensors. Something like
...
store_trainable_vars = []
for v in tf.trainable_variables():
    store_trainable_vars.append(v.value_right_now())
...
where v.value_right_now() returns a constant that won't change until overwritten.
I know I could use Saver, but that solution writes to the disk, which is not acceptable for this application as it will run inside a training loop.
I'm probably missing something obvious - any guidance would be appreciated.
Upvotes: 8
Views: 347
Reputation: 2409
It wasn't my original intent to answer this question myself, but I've come up with a method that works fairly well, so I thought I'd share it. The key insight came from this very clever answer. The approach is to reuse the assignment nodes created for the initial variable assignment. A complete class implementing that approach is given below.
import tensorflow as tf

class TensorFlowState(object):

    def __init__(self):
        # Get the graph.
        graph = tf.get_default_graph()
        # Extract the global variables from the graph.
        self.gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
        # Extract the Assign operations for later use.
        self.assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign")
                           for v in self.gvars]
        # Extract the initial-value tensors from each Assign op for later use.
        self.init_values = [op.inputs[1] for op in self.assign_ops]

    def start(self, sess):
        self.sess = sess

    def store(self):
        # Record the current state of the TF global variables.
        self.state = self.sess.run(self.gvars)

    def restore(self):
        # Create a dictionary mapping the initial-value tensors to the stored state.
        feed_dict = {init_value: val
                     for init_value, val in zip(self.init_values, self.state)}
        # Use the Assign ops for each variable to load the stored values.
        return self.sess.run(self.assign_ops, feed_dict=feed_dict)
To use, simply instantiate the class, call the start method to pass it a tf.Session, and call the store and restore methods as needed inside your imperative training loop. I've used this implementation to build an optimizer, which runs about as fast as the gradient descent optimizers included with TensorFlow.
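For illustration, here is a rough sketch of how the class could sit inside such a loop; loss, candidate_update, and num_steps are hypothetical stand-ins for your own model's loss tensor, proposal op, and iteration count, not part of the class itself:
state = TensorFlowState()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    state.start(sess)
    best_loss = sess.run(loss)          # loss: assumed to be your model's loss tensor
    for step in range(num_steps):
        state.store()                   # snapshot all global variables in memory
        sess.run(candidate_update)      # apply a proposed change to the model
        new_loss = sess.run(loss)
        if new_loss < best_loss:        # keep the change
            best_loss = new_loss
        else:                           # reject the change and roll the variables back
            state.restore()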
Upvotes: 1
Reputation: 53758
To restore a graph state manually you can use a tf.tuple or tf.group operation, which groups the variable updates into one bulk change:
This creates a tuple of tensors with the same values as the tensors argument, except that the value of each tensor is only returned after the values of all tensors have been computed.
[Update] Here's how I would do it:
import numpy as np
import tensorflow as tf

x = tf.placeholder(shape=[None, 5], dtype=tf.float32, name='x')
W = tf.Variable(np.zeros([5, 5]), dtype=tf.float32, name='W')
b = tf.Variable(np.zeros([5]), dtype=tf.float32, name='b')
y = tf.add(tf.matmul(x, W), b)

with tf.Session() as session:
    batch = np.ones([2, 5])
    session.run(tf.global_variables_initializer())
    print(session.run(y, feed_dict={x: batch}))  # prints [2, 5] zeros

    # store the current value
    store = {v.name: v.eval(session) for v in tf.trainable_variables()}
    print(store)  # prints [5, 5] and [5] zeros

    # update
    new = {'W:0': np.ones([5, 5]), 'b:0': np.ones([5])}
    session.run(tf.tuple([tf.assign(var, new[var.name]) for var in tf.trainable_variables()]))
    print(session.run(y, feed_dict={x: batch}))  # prints [2, 5] sixes

    # restore
    session.run(tf.tuple([tf.assign(var, store[var.name]) for var in tf.trainable_variables()]))
    print(session.run(y, feed_dict={x: batch}))  # prints [2, 5] zeros again
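For comparison, the same restore step written with tf.group (a rough sketch, placed inside the same with tf.Session() block) would look like this; tf.group returns a single op with no outputs, so there is nothing to fetch:
    # restore in bulk via tf.group instead of tf.tuple
    restore_op = tf.group(*[tf.assign(var, store[var.name])
                            for var in tf.trainable_variables()])
    session.run(restore_op)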
But I really think you should reconsider your decision about Saver, because it was designed to be used inside a training loop as well. Internally, Saver does all the tricky work for you (in particular, its restore op calls tf.group and tf.control_dependencies if needed), which may otherwise become the source of pretty nasty bugs. Besides, the disk is (almost) always bigger than your GPU and main memory, so if you can afford to store the model in memory, you should be able to store it on disk as well.
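For concreteness, a rough sketch of in-loop checkpointing with Saver might look like the following; the checkpoint path, candidate_update op, loss tensor, should_revert condition, and num_steps are hypothetical placeholders, not part of the original answer:
saver = tf.train.Saver(max_to_keep=5)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(num_steps):
        # Snapshot the variables to disk before trying a candidate update.
        ckpt_path = saver.save(sess, '/tmp/model.ckpt', global_step=step)
        sess.run(candidate_update)
        if should_revert(sess.run(loss)):
            # Roll the variables back from the checkpoint just written.
            saver.restore(sess, ckpt_path)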
Here are some parameters that help to control the proliferation of checkpoint files on disk:
max_to_keep indicates the maximum number of recent checkpoint files to keep. As new files are created, older files are deleted. If None or 0, all checkpoint files are kept. Defaults to 5 (that is, the 5 most recent checkpoint files are kept).
keep_checkpoint_every_n_hours: in addition to keeping the most recent max_to_keep checkpoint files, you might want to keep one checkpoint file for every N hours of training. This can be useful if you want to later analyze how a model progressed during a long training session. For example, passing keep_checkpoint_every_n_hours=2 ensures that you keep one checkpoint file for every 2 hours of training. The default value of 10,000 hours effectively disables the feature.
[Update] As clarified in the comments, the main concern is disk latency, which may slow down training if the disk is accessed too often. If you're using Linux, it caches frequently used disk pages; Windows does as well. But if you want to be absolutely sure, consider using tmpfs.
Upvotes: 5