cerebrou

Reputation: 5540

List of restored variables in TensorFlow

Before I restore variable values from a ckpt file, I need to create those variables in the script that does the restoring. However, that script can also define other variables that are not in the ckpt file. Is it possible to print only the list of variables that are actually restored? (tf.all_variables would not work in this case, because it also returns the variables that are not in the checkpoint.)

Upvotes: 0

Views: 2744

Answers (3)

user1670642

Reputation: 121

Use print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors) from tensorflow.python.tools.inspect_checkpoint.

Args:
file_name: Name of the checkpoint file.
tensor_name: Name of the tensor in the checkpoint file to print.
all_tensors: Boolean indicating whether to print all tensors.

Upvotes: 1

sagunms

Reputation: 8515

You can use the inspect_checkpoint.py tool to list the variables in the checkpoint you are restoring.

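from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

# List the names and shapes of all tensors (values are not printed).
# all_tensors is a required argument; pass all_tensors=True to also print every value.
print_tensors_in_checkpoint_file(file_name='./model.ckpt', tensor_name='', all_tensors=False)

# Print the contents of a specific tensor.
print_tensors_in_checkpoint_file(file_name='./model.ckpt', tensor_name='conv1_w', all_tensors=False)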

Another method is to read the checkpoint directly with a checkpoint reader:

from tensorflow.python import pywrap_tensorflow
reader = pywrap_tensorflow.NewCheckpointReader('./model.ckpt')
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print(reader.get_tensor(key))  # Remove this line if you want to print only variable names
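pywrap_tensorflow is an internal module, and NewCheckpointReader may not be importable from it in newer TensorFlow 2.x releases. Below is a minimal sketch of the same loop that accepts any reader object exposing get_variable_to_shape_map() and get_tensor(). One such object is the reader returned by the public tf.train.load_checkpoint, assuming your TensorFlow version exports it. The helper name checkpoint_summary is hypothetical.

```python
def checkpoint_summary(reader, with_values=False):
    """List the variables stored in a checkpoint, sorted by name.

    `reader` is any object with get_variable_to_shape_map() and get_tensor(),
    e.g. (assumption) the result of tf.train.load_checkpoint('./model.ckpt').
    Returns (name, shape) pairs, or (name, shape, value) triples when
    with_values is True.
    """
    shapes = reader.get_variable_to_shape_map()
    names = sorted(shapes)
    if with_values:
        # get_tensor() loads the full value, which can be slow for large models.
        return [(name, shapes[name], reader.get_tensor(name)) for name in names]
    return [(name, shapes[name]) for name in names]
```

For example, iterating over checkpoint_summary(tf.train.load_checkpoint('./model.ckpt')) and printing each (name, shape) pair prints only the names and shapes. Taking the reader as an argument keeps the listing logic independent of where the reader comes from.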

List all global variables in the current graph (this includes variables that are not in the checkpoint):

import tensorflow as tf

# TF 1.x graph API; tf.global_variables() returns the same list.
for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    print(v)
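Neither listing on its own tells you which variables are actually restored. Below is a minimal sketch that compares the two. It assumes TensorFlow 1.x graph mode, non-partitioned variables with fully defined shapes, and the Saver's default naming, under which a graph variable v is stored under v.op.name (v.name without the ':0' suffix). It restores only the variables whose name and shape match the checkpoint, initializes the rest, and returns the restored names. The helper names split_by_checkpoint and restore_matching are hypothetical.

```python
def split_by_checkpoint(graph_vars, ckpt_shapes):
    """Split graph variables into (restorable, not_restorable) name lists.

    graph_vars: iterable of (name, shape) pairs, e.g. (v.op.name, v.shape.as_list()).
    ckpt_shapes: dict mapping checkpoint names to shapes, e.g.
        reader.get_variable_to_shape_map().
    A variable counts as restorable only if both its name and shape match.
    """
    restorable, not_restorable = [], []
    for name, shape in graph_vars:
        if name in ckpt_shapes and list(ckpt_shapes[name]) == list(shape):
            restorable.append(name)
        else:
            not_restorable.append(name)
    return restorable, not_restorable


def restore_matching(sess, ckpt_path):
    """Restore the variables found in ckpt_path, initialize the others,
    and return the names of the restored variables (TF 1.x graph mode)."""
    import tensorflow as tf  # imported here so split_by_checkpoint works without TF

    by_name = {v.op.name: v for v in tf.global_variables()}
    ckpt_shapes = tf.train.NewCheckpointReader(ckpt_path).get_variable_to_shape_map()
    restorable, not_restorable = split_by_checkpoint(
        ((name, v.shape.as_list()) for name, v in by_name.items()), ckpt_shapes)

    # Variables missing from the checkpoint (or with a different shape) get fresh values.
    sess.run(tf.variables_initializer([by_name[n] for n in not_restorable]))
    if restorable:  # tf.train.Saver raises ValueError when var_list is empty
        saver = tf.train.Saver(var_list=[by_name[n] for n in restorable])
        saver.restore(sess, ckpt_path)
    return restorable
```

For example, print(restore_matching(sess, './model.ckpt')) prints only the restored variable names. A default tf.train.Saver() would try to restore every global variable and fail on the ones missing from the checkpoint, which is why the sketch builds a Saver restricted to the matches. In TensorFlow 2.x these calls should be available under tf.compat.v1.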

Hope it helps.

Upvotes: 3

Neal

Reputation: 942

If you want a list of variables that can be saved, I believe you can use this:

tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) + tf.get_collection(tf.GraphKeys.SAVEABLE_OBJECTS)

Upvotes: 1
