I want to count the parameters in a TensorFlow model. This is similar to the following existing question:
How to count total number of trainable parameters in a tensorflow model?
However, if the model is defined by a graph loaded from a .pb file, none of the proposed answers work. I loaded the graph with the following function:
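import tensorflow as tf  # TensorFlow 1.x API

def load_graph(model_file):
    graph = tf.Graph()
    graph_def = tf.GraphDef()
    # Read the serialized GraphDef from the .pb file
    with open(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())
    # Import its nodes into a fresh graph
    with graph.as_default():
        tf.import_graph_def(graph_def)
    return graph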
One example is loading a frozen_graph.pb file for retraining in tensorflow-for-poets-2:
https://github.com/googlecodelabs/tensorflow-for-poets-2
To my understanding, a GraphDef doesn't have enough information to describe Variables. As explained here, you need a MetaGraph, which contains both a GraphDef and a map of CollectionDefs that can describe Variables. So the following code should give the correct count of trainable parameters.
Export a MetaGraph:
import tensorflow as tf  # TensorFlow 1.x API

a = tf.get_variable('a', shape=[1])                   # trainable
b = tf.get_variable('b', shape=[1], trainable=False)  # not trainable
init = tf.global_variables_initializer()
saver = tf.train.Saver([a])
with tf.Session() as sess:
    sess.run(init)
    # Writes ./test.meta (the MetaGraph) plus the checkpoint files
    saver.save(sess, './test')
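Import the MetaGraph and count the total number of trainable parameters:
import tensorflow as tf  # TensorFlow 1.x API

# Rebuild the graph, including its variable collections, from the MetaGraph
saver = tf.train.import_meta_graph('test.meta')
with tf.Session() as sess:
    saver.restore(sess, 'test')
    total_parameters = 0
    for variable in tf.trainable_variables():
        # Add the number of elements in each variable, not just 1 per variable
        total_parameters += variable.get_shape().num_elements()
    print(total_parameters)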
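The MetaGraph approach needs the saved checkpoint, which a frozen .pb on its own does not provide. As a rough alternative for a frozen graph, here is a minimal sketch, assuming the weights were folded into Const nodes during freezing: sum the element counts of all Const tensors in the GraphDef. The helper name `count_const_elements` is hypothetical, and the result is only an upper-bound approximation, because non-weight constants (reshape targets, scalars, and so on) are counted too. It only reads GraphDef protobuf fields (`node`, `op`, `attr["value"].tensor.tensor_shape.dim`), so the function itself needs no TensorFlow import.

```python
from math import prod  # Python 3.8+


def count_const_elements(graph_def):
    """Sum the element counts of every Const node in a GraphDef.

    Hypothetical helper: in a frozen graph the former variables are stored
    as Const nodes, but so are other constants (shapes, scalars), so the
    result over-counts the true number of weights.
    """
    total = 0
    for node in graph_def.node:
        if node.op != "Const":
            continue  # skip Placeholders, MatMuls, etc.
        # A Const node stores its value as a TensorProto in attr["value"]
        dims = node.attr["value"].tensor.tensor_shape.dim
        total += prod(d.size for d in dims)  # a scalar (no dims) counts as 1
    return total
```

To use it, parse the .pb file into a `tf.GraphDef` the same way `load_graph` does (`graph_def.ParseFromString(f.read())`) and pass that `graph_def` to `count_const_elements`.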