Reputation: 13475
I want to paste an existing tensorflow graph into a new graph.
Suppose I create a graph computing y = tanh(x @ w)
import tensorflow as tf
import numpy as np
def some_function(x):
    w = tf.Variable(initial_value=np.random.randn(4, 5), dtype=tf.float32)
    return tf.tanh(x @ w)
x = tf.placeholder(shape=(None, 4), dtype = tf.float32)
y = some_function(x)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
val_x = np.random.randn(3, 4)
val_y, = sess.run([y], feed_dict={x: val_x})
Great. Now suppose I've lost the code that generated that graph, but I still have access to the tensors (x, y). I want to take this graph (using the current value of w) and copy it twice into a new graph (the two paths should share the same w), so that I now compute d = tf.reduce_sum((tanh(x1 @ w)-tanh(x2 @ w))**2) by adding the line:
# Starting with access to tensors: x, y
<SOMETHING HERE>
d = tf.reduce_sum((y1-y2)**2)
val_x1 = np.random.randn(3, 4)
val_x2 = np.random.randn(3, 4)
val_d = sess.run([d], feed_dict = {x1: val_x1, x2: val_x2})
What do I fill in for <SOMETHING HERE> to make this work? (Obviously, without recreating the first graph.)
Upvotes: 1
Views: 994
Reputation: 59681
There is the Graph Editor module (tf.contrib.graph_editor) to help with this sort of operation. Its main disadvantage is that you cannot have a running session while you modify the graph. However, you can checkpoint the session, modify the graph and then restore the session if you need to.
The problem with what you want is that you basically need to replicate a subgraph, except that you do not want to replicate the variables. So you can simply exclude the variable op types (mainly Variable, VariableV2 and maybe VarHandleOp, although I threw in a few more that I found in the TensorFlow code). You can do that with a function like this:
import tensorflow as tf
# Receives the outputs to recalculate and the input replacements
def replicate_subgraph(outputs, mappings):
    # Types of operation that should not be replicated
    # Taken from tensorflow/python/training/device_setter.py
    NON_REPLICABLE = {'Variable', 'VariableV2', 'AutoReloadVariable',
                      'MutableHashTable', 'MutableHashTableV2',
                      'MutableHashTableOfTensors', 'MutableHashTableOfTensorsV2',
                      'MutableDenseHashTable', 'MutableDenseHashTableV2',
                      'VarHandleOp', 'BoostedTreesEnsembleResourceHandleOp'}
    # Find subgraph ops, walking back from the outputs up to the replaced inputs
    ops = tf.contrib.graph_editor.get_backward_walk_ops(outputs, stop_at_ts=mappings.keys())
    # Exclude non-replicable operations
    ops_replicate = [op for op in ops if op.type not in NON_REPLICABLE]
    # Make subgraph view
    sgv = tf.contrib.graph_editor.make_view(*ops_replicate)
    # Make the copy
    _, info = tf.contrib.graph_editor.copy_with_input_replacements(sgv, mappings)
    # Return new outputs
    return info.transformed(outputs)
Here is an example similar to yours (I edited it a bit so that the output is easy to check: the second value should be ten times the first one):
import tensorflow as tf
def some_function(x):
    w = tf.Variable(initial_value=tf.random_normal((5,)), dtype=tf.float32)
    return 2 * (x * w)
x1 = tf.placeholder(shape=(), dtype=tf.float32, name='X1')
x2 = tf.placeholder(shape=(), dtype=tf.float32, name='X2')
y1 = some_function(x1)
y2, = replicate_subgraph([y1], {x1: x2})
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(*sess.run([y1, y2], feed_dict={x1: 1, x2: 10}), sep='\n')
Output:
[ 2.3356955 2.277849 0.58513653 2.0919807 -0.15102367]
[23.356955 22.77849 5.851365 20.919807 -1.5102367]
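To see why skipping the variable ops makes both copies share the same w, here is a toy model of the idea in plain Python. It is only a sketch: the Node class and the replicate and evaluate helpers are made up for illustration and are not TensorFlow APIs. The copy walks back from the output, stops at the inputs being replaced, and reuses (rather than clones) any node whose kind is in the non-replicable set.

```python
# Toy model only: Node, evaluate and replicate are hypothetical helpers,
# not part of TensorFlow or the Graph Editor.
class Node:
    def __init__(self, kind, inputs=(), value=None, fn=None):
        self.kind = kind            # e.g. "Placeholder", "Variable", "Mul"
        self.inputs = list(inputs)  # upstream nodes
        self.value = value          # only meaningful for "Variable" nodes
        self.fn = fn                # only meaningful for computing nodes

NON_REPLICABLE = {"Variable"}

def evaluate(node, feed):
    """Compute the value of `node`, reading fed nodes from `feed`."""
    if node in feed:
        return feed[node]
    if node.kind == "Variable":
        return node.value
    if node.kind == "Placeholder":
        raise KeyError("placeholder was not fed")
    return node.fn(*(evaluate(i, feed) for i in node.inputs))

def replicate(output, mappings):
    """Copy the subgraph ending at `output`, replacing inputs per `mappings`.

    Nodes whose kind is in NON_REPLICABLE are shared, not copied.
    """
    copies = {}

    def copy(node):
        if node in mappings:             # replaced input: stop the walk here
            return mappings[node]
        if node.kind in NON_REPLICABLE:  # shared state: reuse the original
            return node
        if node not in copies:
            copies[node] = Node(node.kind, [copy(i) for i in node.inputs],
                                node.value, node.fn)
        return copies[node]

    return copy(output)

x1 = Node("Placeholder")
w = Node("Variable", value=3.0)
y1 = Node("Mul", [x1, w], fn=lambda a, b: 2 * (a * b))

x2 = Node("Placeholder")
y2 = replicate(y1, {x1: x2})

print(evaluate(y1, {x1: 1.0}), evaluate(y2, {x2: 10.0}))  # 6.0 60.0
```

Because y2 points at the very same w node, changing w.value afterwards changes both outputs, which is the behaviour the exclusion list is meant to give you in the real graph.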
EDIT:
Here is another solution using tf.make_template. This requires you to actually have the code for the function, but it is a cleaner and "more official" way of supporting subgraph reuse.
import tensorflow as tf
def some_function(x):
    w = tf.get_variable('W', (5,), initializer=tf.random_normal_initializer())
    # Or if the variable is only local and not trainable
    # w = tf.Variable(initial_value=tf.random_normal((5,)), dtype=tf.float32, trainable=False)
    return 2 * (x * w)
x1 = tf.placeholder(shape=(), dtype=tf.float32, name='X1')
x2 = tf.placeholder(shape=(), dtype=tf.float32, name='X2')
some_function_tpl = tf.make_template('some_function', some_function)
y1 = some_function_tpl(x1)
y2 = some_function_tpl(x2)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print(*sess.run([y1, y2], feed_dict={x1: 1, x2: 10}), sep='\n')
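If it helps to picture what the template is doing, here is a rough plain-Python sketch of the idea. It is not how tf.make_template is implemented (the real thing works through variable scopes), and make_template_sketch and its get_variable argument are names I made up for illustration. The point is only that the first call creates the variable and every later call reuses it.

```python
import math
import random

def make_template_sketch(func):
    """Hypothetical stand-in for the idea behind a template:
    variables requested by `func` are created once and then reused."""
    store = {}  # variable name -> value, shared by every call

    def get_variable(name, initializer):
        if name not in store:      # first call: create the variable
            store[name] = initializer()
        return store[name]         # later calls: reuse it

    def template(*args):
        return func(get_variable, *args)

    template.variables = store
    return template

def some_function(get_variable, x):
    # Five weights drawn from a standard normal, created on first use only
    w = get_variable("W", lambda: [random.gauss(0.0, 1.0) for _ in range(5)])
    return [2 * (x * wi) for wi in w]

some_function_tpl = make_template_sketch(some_function)
y1 = some_function_tpl(1.0)
y2 = some_function_tpl(10.0)

# Both calls used the same W, so y2 is ten times y1 element by element
print(all(math.isclose(b, 10 * a) for a, b in zip(y1, y2)))
```

Calling plain some_function twice with separate stores would instead draw two different W values, which is the same problem you get in TensorFlow when the function is called twice without a template.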
Upvotes: 1