mrgloom

tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) is empty for a loaded .pb file

I'm trying to look at the variable list of a loaded .pb file, but for some reason it's empty.

Here is the code:

import tensorflow as tf

tf_model_path = './tf_coreml_ssd_resources/ssd_mobilenet_v1_android_export.pb'

with open(tf_model_path, 'rb') as f:
    serialized = f.read()

tf.reset_default_graph()

original_gdef = tf.GraphDef()
original_gdef.ParseFromString(serialized)

# V1
with tf.Graph().as_default() as g:
    print('type(g)', type(g)) # type(g) <class 'tensorflow.python.framework.ops.Graph'>

    tf.import_graph_def(original_gdef, name='')

    model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    print('type(model_vars)', type(model_vars))
    print('model_vars', model_vars)

# V2
graph = tf.import_graph_def(original_gdef, name='')

print('type(graph)', type(graph)) # why type(graph) <class 'NoneType'> ?

model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print('type(model_vars)', type(model_vars))
print('model_vars', model_vars)

Also, why do I get type(graph) <class 'NoneType'> in the V2 case?

Answers (1)

javidcf

The GraphDef object serialized to the .pb file does not contain any collection information. If you want to store a graph along with its metadata (including the collections), you should save a MetaGraphDef instead (see tf.train.export_meta_graph / tf.train.import_meta_graph).
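A minimal sketch of the difference, assuming TensorFlow 2.x with the tf.compat.v1 API (on TF 1.x, drop the compat.v1 prefix). The one-variable model named "w" is a hypothetical stand-in for a real network: round-tripping a plain GraphDef brings the nodes back but leaves GLOBAL_VARIABLES empty, while round-tripping a MetaGraphDef restores the collection.

```python
import tensorflow as tf

tf1 = tf.compat.v1

# Hypothetical tiny model: a single variable named "w".
source_graph = tf.Graph()
with source_graph.as_default():
    tf1.get_variable("w", shape=[2], initializer=tf1.zeros_initializer())
    # What a plain .pb file stores: the nodes only, no collections.
    graph_def = source_graph.as_graph_def()
    # A MetaGraphDef bundles the GraphDef together with the collections.
    meta_graph_def = tf1.train.export_meta_graph()

# Importing the plain GraphDef: nodes come back, collections do not.
g_plain = tf.Graph()
with g_plain.as_default():
    tf1.import_graph_def(graph_def, name="")
    plain_vars = tf1.get_collection(tf1.GraphKeys.GLOBAL_VARIABLES)

# Importing the MetaGraphDef: the GLOBAL_VARIABLES collection is rebuilt too.
g_meta = tf.Graph()
with g_meta.as_default():
    tf1.train.import_meta_graph(meta_graph_def)
    meta_vars = tf1.get_collection(tf1.GraphKeys.GLOBAL_VARIABLES)

print(len(plain_vars), [v.name for v in meta_vars])
```

In practice you would pass a filename to export_meta_graph when saving and to import_meta_graph when loading; the in-memory proto is used here only to keep the example self-contained.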

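In your V2 code, graph is None because tf.import_graph_def does not return a graph. It returns None unless you pass return_elements (in which case it returns the requested operations/tensors); it simply imports the nodes in the given graph definition into the current default graph.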
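A small sketch of that behaviour, assuming TensorFlow 2.x with tf.compat.v1 (the Const node "x" is a hypothetical stand-in for the parsed .pb). Keep a handle to the graph you imported into (the `as g` from your V1 code) and look nodes up there; use return_elements if you want the call itself to hand back specific tensors.

```python
import tensorflow as tf

tf1 = tf.compat.v1

# Hypothetical stand-in for the GraphDef parsed from the .pb file.
src = tf.Graph()
with src.as_default():
    tf.constant(1.0, name="x")
graph_def = src.as_graph_def()

with tf.Graph().as_default() as g:
    # Without return_elements the call returns None ...
    result = tf1.import_graph_def(graph_def, name="")
    # ... but the nodes are now part of g, the current default graph.
    x_op = g.get_operation_by_name("x")
    # With return_elements the requested tensors/ops are returned as a list.
    (x_tensor,) = tf1.import_graph_def(
        graph_def, name="copy", return_elements=["x:0"]
    )

print(result, x_op.name, x_tensor.name)
```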

As a side comment, note that graph collections are being deprecated in TensorFlow 2.x.
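Since the GraphDef itself carries no variable list, one way to see what a .pb does contain is to walk its nodes directly. This is a sketch under an assumption the question does not confirm: if the file is a frozen inference graph (as `_export.pb` names often are), its weights are stored as Const nodes rather than variables, so even the node list shows no Variable/VariableV2/VarHandleOp ops. The helper summarize_graph_def and the demo graph are hypothetical.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def summarize_graph_def(graph_def):
    """Count op types and list Const nodes and variable-like nodes (hypothetical helper)."""
    op_counts = {}
    for node in graph_def.node:
        op_counts[node.op] = op_counts.get(node.op, 0) + 1
    variable_ops = {"Variable", "VariableV2", "VarHandleOp"}
    consts = [n.name for n in graph_def.node if n.op == "Const"]
    variables = [n.name for n in graph_def.node if n.op in variable_ops]
    return op_counts, consts, variables


# Stand-in for a frozen graph: the "weights" live in a Const node.
with tf.Graph().as_default() as g:
    w = tf.constant([[1.0, 2.0]], name="weights")
    x = tf1.placeholder(tf.float32, shape=[None, 1], name="input")
    tf.matmul(x, w, name="output")
demo_def = g.as_graph_def()

# For the real file you would use the GraphDef parsed in the question instead:
# op_counts, consts, variables = summarize_graph_def(original_gdef)
op_counts, consts, variables = summarize_graph_def(demo_def)
print(op_counts, consts, variables)
```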
