Reputation: 515
I am working on GANs and decided to implement my algorithm using HyperGAN. It's a wrapper around DCGANs built on TensorFlow. HyperGAN saves the output using TF's checkpoint method.
Later, I tried to load the model using:
import tensorflow as tf
sess=tf.Session()
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))
# Note: do not run tf.global_variables_initializer() after restore; it would overwrite the restored weights.
However, since it's a GAN, it needs an input latent vector and outputs an image. This is done using:
out_image = sess.run(last_node, feed_dict={input_node: value})
But since I only loaded the model, I do not know the name of the last node or the names of the input placeholders. How do I obtain the names that were used to create the graph in the first place? I tried to visualize the graph in TensorBoard, but it was so large that TensorBoard got stuck.
Upvotes: 1
Views: 1855
Reputation: 2923
You should try printing the list of tensors in the graph:
with tf.Graph().as_default() as graph:
    ....
count = 0
for op in graph.get_operations():
    print(op.values())
    count += 1
    if count == 50:
        break  # stop after the first 50 ops
This shows the first 50 nodes of the graph, and you will see something like this:
(<tf.Tensor 'import/Placeholder_only:0' shape=<unknown> dtype=float32>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_max:0' shape=() dtype=float32>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_min:0' shape=() dtype=float32>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53_quantized_const:0' shape=(512,) dtype=quint8>,)
(<tf.Tensor 'import/MobileNet/conv_ds_8/dw_batch_norm/batchnorm/sub/_53__cf__53:0' shape=(512,) dtype=float32>,)
I added the counter because the terminal usually prints so many tensors that the input node name, which appears near the top, scrolls out of view.
Finally, comment out the counting lines:
#count = 0
for op in graph.get_operations():
    print(op.values())
    #count += 1
    #if count == 50:
    #    break
This prints every node, so the last few lines show the final nodes of the graph (i.e. your output node).
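If scrolling through the full listing is impractical, the same loop can be narrowed with a filter. The following is only a minimal sketch: the helper names `placeholder_names` and `last_op_names` are hypothetical, and it assumes the model's input was created as a placeholder (op type "Placeholder"), which may not be the case for every HyperGAN graph. It relies only on `graph.get_operations()`, `op.name` and `op.type`.

```python
def placeholder_names(graph):
    """Return the names of all ops whose type is 'Placeholder'.

    Assumption: the model's input was created as a placeholder.
    """
    return [op.name for op in graph.get_operations() if op.type == "Placeholder"]


def last_op_names(graph, n=10):
    """Return the names of the last n ops in the graph's operation list."""
    ops = graph.get_operations()
    # max(...) keeps the slice valid when n is 0 or larger than len(ops).
    return [op.name for op in ops[max(len(ops) - n, 0):]]


# Possible usage with TensorFlow 1.x, after tf.train.import_meta_graph(...):
#   graph = tf.get_default_graph()
#   print(placeholder_names(graph))
#   print(last_op_names(graph, 20))
```

A name found this way can then be turned into a tensor with `graph.get_tensor_by_name(name + ":0")` and used in `sess.run`. Note that in a graph saved during training, the final ops may belong to the saver or optimizer rather than the generator, so the output node is not guaranteed to be among the last few entries.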
Upvotes: 1