Ilya V. Schurov

Reputation: 8047

Running subgraphs and feeding intermediate variables

I have an autoencoder model in TensorFlow that can roughly be written as follows (this is an unrealistically simplified example):

import tensorflow as tf

x = tf.placeholder(tf.float32, input_shape, name='x')

# encoder part:
W = tf.Variable(tf.random_uniform(shape, -1, 1))
z = tf.nn.relu(tf.nn.conv2d(x, W, strides=[1, 2, 2, 1], padding='SAME'))

# decoder part (tied weights: reuses W; the output has shape[2] channels):
b = tf.Variable(tf.zeros([shape[2]]))
y = tf.nn.relu(tf.nn.bias_add(
    tf.nn.conv2d_transpose(z, W, shape_tr, strides=[1, 2, 2, 1], padding='SAME'),
    b))
cost = tf.reduce_sum(tf.square(y - x))

So I have an input placeholder x, an intermediate representation z, a weight tensor W and an output y.

Then I train my model like this:

optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(optimizer, feed_dict={x: some_train_data})

Now given some test data I can check the output of the model:

recon = sess.run(y, feed_dict={x: some_test_data})

I can also get the intermediate representation for that data:

latent = sess.run(z, feed_dict={x: some_test_data})

What I want is to be able to change my intermediate representation (z) and obtain the decoded results y. Something like this:

recon = sess.run(y, feed_dict={z: some_fake_z})

Of course, this doesn't work because z is not a placeholder: I get an error like "You must feed a value for placeholder tensor 'x'". If I provide x, the result no longer depends on z at all (which, again, is to be expected).

So my question is: how can I run a subgraph that calculates y as a function of z and feed it with my own values of z?

Upvotes: 1

Views: 363

Answers (1)

Siyuan Ren

Reputation: 7844

Create another copy of the decoder subgraph that takes a new placeholder as input and uses the same variables:

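fake_z = tf.placeholder(z.dtype, shape=z.shape, name='fake_z')  # 2nd positional arg is shape, not name
# reuses W and b; shape_tr must not depend on x, or x must still be fed
fake_y = tf.nn.relu(tf.nn.bias_add(
    tf.nn.conv2d_transpose(fake_z, W, shape_tr, strides=[1, 2, 2, 1], padding='SAME'),
    b))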

Now you can call sess.run(fake_y, feed_dict={fake_z: my_values}).
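A self-contained sketch of this approach, assuming TensorFlow 2.x running the TF1 graph/session API through tf.compat.v1. The 8x8 single-channel input, the 3x3 kernel with 4 filters and the helper name decode are illustrative choices, not part of the original model. Wrapping the decoder in a function makes building the second copy a one-liner, and the final comparison checks that both copies really share W and b.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graphs and sessions (assumes TensorFlow 2.x)
tf1.disable_eager_execution()

# Illustrative sizes: 8x8 single-channel inputs, 3x3 kernel with 4 filters.
HEIGHT, WIDTH, IN_CH, OUT_CH = 8, 8, 1, 4

graph = tf.Graph()
with graph.as_default():
    x = tf1.placeholder(tf.float32, [None, HEIGHT, WIDTH, IN_CH], name='x')
    W = tf.Variable(tf.random.uniform([3, 3, IN_CH, OUT_CH], -1, 1))
    b = tf.Variable(tf.zeros([IN_CH]))

    # Encoder: stride 2 halves the spatial size, so z is [None, 4, 4, 4].
    z = tf.nn.relu(tf.nn.conv2d(x, W, strides=[1, 2, 2, 1], padding='SAME'))

    def decode(latent):
        """Hypothetical helper: the decoder as a function; every call reuses W and b."""
        # The output shape is built from `latent` only, so x is never required.
        out_shape = tf.stack([tf.shape(latent)[0], HEIGHT, WIDTH, IN_CH])
        up = tf.nn.conv2d_transpose(latent, W, out_shape,
                                    strides=[1, 2, 2, 1], padding='SAME')
        return tf.nn.relu(tf.nn.bias_add(up, b))

    y = decode(z)  # reconstruction from real input
    fake_z = tf1.placeholder(z.dtype, shape=z.shape, name='fake_z')
    fake_y = decode(fake_z)  # same weights, but z is supplied directly
    init = tf1.global_variables_initializer()

with tf1.Session(graph=graph) as sess:
    sess.run(init)
    data = np.random.rand(2, HEIGHT, WIDTH, IN_CH).astype(np.float32)
    latent = sess.run(z, feed_dict={x: data})
    recon_real = sess.run(y, feed_dict={x: data})
    # Decoding a latent code needs only fake_z; x is not fed here.
    recon_fake = sess.run(fake_y, feed_dict={fake_z: latent})

print(latent.shape, recon_fake.shape)  # (2, 4, 4, 4) (2, 8, 8, 1)
print(np.allclose(recon_real, recon_fake))  # True
```

Feeding the encoder's own output into fake_z reproduces y exactly; any other array of shape [batch, 4, 4, 4] can be fed in its place.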

If you build the model with tf.layers, also wrap both copies in the same tf.variable_scope (with reuse enabled for the second copy) to make sure they use the same layer weights.
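A minimal sketch of the variable-sharing mechanism, assuming TensorFlow 2.x with tf.compat.v1. To keep it small it calls tf.compat.v1.get_variable directly rather than tf.layers; the scope name 'decoder', the helper decoder() and all sizes are hypothetical. With reuse=AUTO_REUSE the first call creates the variables and every later call returns the existing ones, so the two decoder copies end up with a single set of weights.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graphs and sessions (assumes TensorFlow 2.x)
tf1.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():

    def decoder(latent):
        # AUTO_REUSE: the first call creates decoder/W and decoder/b,
        # later calls return the existing variables instead of new ones.
        with tf1.variable_scope('decoder', reuse=tf1.AUTO_REUSE):
            W = tf1.get_variable(
                'W', shape=[3, 3, 1, 4],
                initializer=tf1.random_uniform_initializer(-1.0, 1.0))
            b = tf1.get_variable('b', shape=[1],
                                 initializer=tf1.zeros_initializer())
            out_shape = tf.stack([tf.shape(latent)[0], 8, 8, 1])
            up = tf.nn.conv2d_transpose(latent, W, out_shape,
                                        strides=[1, 2, 2, 1], padding='SAME')
            return tf.nn.relu(tf.nn.bias_add(up, b))

    # Stand-ins for the encoder output and a hand-made latent code.
    z = tf1.placeholder(tf.float32, [None, 4, 4, 4], name='z')
    fake_z = tf1.placeholder(tf.float32, [None, 4, 4, 4], name='fake_z')
    y = decoder(z)
    fake_y = decoder(fake_z)

    var_names = [v.name for v in tf1.trainable_variables()]
    init = tf1.global_variables_initializer()

print(var_names)  # ['decoder/W:0', 'decoder/b:0']

with tf1.Session(graph=graph) as sess:
    sess.run(init)
    codes = np.random.rand(2, 4, 4, 4).astype(np.float32)
    out_y, out_fake = sess.run([y, fake_y], feed_dict={z: codes, fake_z: codes})

print(np.allclose(out_y, out_fake))  # True
```

Only two trainable variables exist even though the decoder was built twice, which is exactly the sharing the two subgraphs need.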

This is one of the limitations of static-graph libraries: you have to plan in advance everything you will need to compute.

Upvotes: 2
