canonball

Reputation: 515

How do you obtain the output and input values of a model in tensorflow?

I am working on GANs and decided to implement my algorithm using HyperGAN. It's a wrapper around DCGANs built on TensorFlow. HyperGAN saves the trained model using TensorFlow's checkpoint mechanism.

Later, I tried to load the model using:

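import tensorflow as tf

sess = tf.Session()
# Rebuild the graph structure from the .meta file, then load the saved weights.
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))
# Do not run tf.global_variables_initializer() after restore():
# it would overwrite the restored weights with fresh initial values.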

However, since it's a GAN, the model takes a latent vector as input and outputs an image. Generating an image is done with:

out_image = sess.run(last_node, feed_dict={input_node: value})

But because I loaded the model from a checkpoint, I do not know the name of the last (output) node or the names of the input placeholders. How do I obtain the names that were used when the graph was first created? I tried to visualize the graph in TensorBoard, but it was so large that TensorBoard got stuck.

Upvotes: 1

Views: 1855

Answers (1)

kwotsin

Reputation: 2923

Try printing the list of tensors in the graph:

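with tf.Graph().as_default() as graph:
    ....  # load the model into this graph

count = 0
for op in graph.get_operations():
    print(op.values())  # tuple of this op's output tensors
    count += 1
    if count == 50:
        break  # stop after the first 50 operations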

This prints the output tensors of the first 50 operations in the graph. You will see something like this:

(<tf.Tensor 'import/Placeholder_only:0' shape=<unknown> dtype=float32>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_max:0' shape=() dtype=float32>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_min:0' shape=() dtype=float32>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_const:0' shape=(512,) dtype=quint8>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53:0' shape=(512,) dtype=float32>,)

The counter is there because the terminal usually prints so many tensors that the input node name, which appears near the start, scrolls out of view.
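Instead of relying on print order, you can filter operations by type: feedable inputs are usually ops of type `Placeholder` (some graphs also use `PlaceholderWithDefault`). The sketch below is a minimal illustration that runs without TensorFlow; it works on a plain-Python list of hypothetical `OpInfo` records. With a real graph you could build that list from `graph.get_operations()` using each op's `name`, `type` and `inputs` attributes, e.g. `OpInfo(op.name, op.type, [t.op.name for t in op.inputs])`. The toy graph and its op names are made up for illustration.

```python
from typing import List, NamedTuple


class OpInfo(NamedTuple):
    """Hypothetical stand-in for a tf.Operation: name, op type, input op names."""
    name: str
    type: str
    inputs: List[str]


# Op types that typically mark a graph's feedable inputs.
PLACEHOLDER_TYPES = ("Placeholder", "PlaceholderWithDefault")


def find_placeholders(ops: List[OpInfo]) -> List[str]:
    """Return the names of placeholder ops, in graph order."""
    return [op.name for op in ops if op.type in PLACEHOLDER_TYPES]


# Hypothetical toy generator graph: z -> MatMul -> Tanh
toy_graph = [
    OpInfo("z", "Placeholder", []),
    OpInfo("generator/dense/kernel", "VariableV2", []),
    OpInfo("generator/dense/MatMul", "MatMul", ["z", "generator/dense/kernel"]),
    OpInfo("generator/Tanh", "Tanh", ["generator/dense/MatMul"]),
]

print(find_placeholders(toy_graph))  # ['z']
```

Once you know an op name such as `z`, the tensor to feed is its first output, named `z:0`.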

Finally, comment out the counting lines so that every operation is printed:

#count = 0
for op in graph.get_operations():
    print(op.values())
    #count += 1
    #if count == 50:
    #    break

The last few nodes printed will usually include your output node.
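A related heuristic for finding output candidates is to list "sink" operations, i.e. ops whose outputs are not consumed by any other op. This is a minimal sketch on the same kind of hypothetical `OpInfo` records (built in a real graph from `op.name`, `op.type` and `[t.op.name for t in op.inputs]`); it ignores control dependencies (`op.control_inputs`). In a full GAN checkpoint the result will usually also contain training, summary and saver ops, so treat it as a shortlist to inspect rather than a definitive answer.

```python
from typing import List, NamedTuple


class OpInfo(NamedTuple):
    """Hypothetical stand-in for a tf.Operation: name, op type, input op names."""
    name: str
    type: str
    inputs: List[str]


def find_sink_ops(ops: List[OpInfo]) -> List[str]:
    """Return names of ops that no other op takes as a data input."""
    consumed = {inp for op in ops for inp in op.inputs}
    return [op.name for op in ops if op.name not in consumed]


# Hypothetical toy graph: the generator output plus an unrelated saver constant.
toy_graph = [
    OpInfo("z", "Placeholder", []),
    OpInfo("generator/dense/kernel", "VariableV2", []),
    OpInfo("generator/dense/MatMul", "MatMul", ["z", "generator/dense/kernel"]),
    OpInfo("generator/Tanh", "Tanh", ["generator/dense/MatMul"]),
    OpInfo("save/Const", "Const", []),
]

print(find_sink_ops(toy_graph))  # ['generator/Tanh', 'save/Const']
```

Here the saver constant shows up alongside the real output, which is why the candidates still need a manual look.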

Upvotes: 1
