I'm trying to look at the list of variables in a loaded .pb file, but for some reason it's empty.
Here is the code:
import tensorflow as tf

tf_model_path = './tf_coreml_ssd_resources/ssd_mobilenet_v1_android_export.pb'
with open(tf_model_path, 'rb') as f:
    serialized = f.read()

tf.reset_default_graph()
original_gdef = tf.GraphDef()
original_gdef.ParseFromString(serialized)

# V1
with tf.Graph().as_default() as g:
    print('type(g)', type(g))  # type(g) <class 'tensorflow.python.framework.ops.Graph'>
    tf.import_graph_def(original_gdef, name='')
    model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    print('type(model_vars)', type(model_vars))
    print('model_vars', model_vars)

# V2
graph = tf.import_graph_def(original_gdef, name='')
print('type(graph)', type(graph))  # why type(graph) <class 'NoneType'> ?
model_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
print('type(model_vars)', type(model_vars))
print('model_vars', model_vars)
Also, why do I get type(graph) <class 'NoneType'> in the V2 case?
The GraphDef object serialized to the .pb file does not contain collections information. If you want to store a graph along with its metadata (including the collections), you should save a MetaGraphDef instead (see tf.train.export_meta_graph / tf.train.import_meta_graph).
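To illustrate the difference, here is a minimal sketch that builds a tiny throwaway graph with one variable (the graph and the variable name `w` are hypothetical, not taken from the SSD model) and round-trips it both ways. It uses the `tf.compat.v1` graph-mode API so it should run under TF 1.15 and 2.x; re-importing the plain GraphDef leaves GLOBAL_VARIABLES empty, while the MetaGraphDef round trip restores it.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API, also available in TF 2.x


def build_graph():
    """Build a hypothetical graph containing a single variable named 'w'."""
    g = tf1.Graph()
    with g.as_default():
        tf1.Variable(tf1.zeros([2]), name='w')
    return g


def roundtrip_graphdef():
    """Serialize only the GraphDef, re-import it, and read the collection."""
    gdef = build_graph().as_graph_def()
    new_g = tf1.Graph()
    with new_g.as_default():
        tf1.import_graph_def(gdef, name='')
        # The 'w' node exists in new_g, but no collection entry was stored.
        return tf1.get_collection(tf1.GraphKeys.GLOBAL_VARIABLES)


def roundtrip_metagraph(path):
    """Export a MetaGraphDef (graph + collections), then re-import it."""
    with build_graph().as_default():
        tf1.train.export_meta_graph(filename=path)
    new_g = tf1.Graph()
    with new_g.as_default():
        tf1.train.import_meta_graph(path)
        variables = tf1.get_collection(tf1.GraphKeys.GLOBAL_VARIABLES)
        return [v.op.name for v in variables]


meta_path = os.path.join(tempfile.mkdtemp(), 'model.meta')
print(roundtrip_graphdef())
print(roundtrip_metagraph(meta_path))
```

Note that a .pb exported for inference (like the Android export in the question) may no longer contain variables at all, so re-exporting a MetaGraphDef only helps if you still have the original training graph.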
In your V2 code, graph is None because tf.import_graph_def does not return a graph: unless you pass return_elements, it returns None. It just imports the nodes in the given graph definition into the current default graph.
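A small sketch of that behavior, using a throwaway GraphDef with one constant node named `x` (hypothetical, for illustration only) and the `tf.compat.v1` API: the first import returns None and adds `x` to the enclosing default graph, while the second import, with return_elements, returns the requested tensor.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API, also available in TF 2.x


def make_graph_def():
    """Build a throwaway GraphDef with a single constant node named 'x'."""
    g = tf1.Graph()
    with g.as_default():
        tf1.constant(1.0, name='x')
    return g.as_graph_def()


gdef = make_graph_def()

graph = tf1.Graph()
with graph.as_default():
    # No return_elements: nodes go into the current default graph
    # (here `graph`) and the call itself returns None.
    result = tf1.import_graph_def(gdef, name='')
    # With return_elements, the requested tensors/ops come back as a list.
    (x_copy,) = tf1.import_graph_def(gdef, name='copy', return_elements=['x:0'])

print(result)  # None
print([op.name for op in graph.get_operations()])
```

So to keep a handle on the imported graph, hold on to the Graph object you made default (the `g` in your V1 code) rather than the return value of import_graph_def.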
As a side comment, note that graph collections are being deprecated in TensorFlow 2.x.
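In TF 2.x, variables are instead tracked by the objects that own them. As a hedged illustration (the `TinyModel` class is hypothetical), a `tf.Module` subclass finds its variables by walking its attributes, with no collection involved:

```python
import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical module; variables assigned as attributes are tracked."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable(tf.zeros([2]), name='w')


model = TinyModel()
# tf.Module.variables collects tf.Variable attributes recursively.
print([v.name for v in model.variables])
```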