Yanjun

Reputation: 51

How to count the total number of trainable parameters in a TensorFlow model defined by a graph loaded from a .pb file?

I want to count the parameters in a TensorFlow model. This is similar to the following existing question:

How to count total number of trainable parameters in a tensorflow model?

However, if the model is defined by a graph loaded from a .pb file, none of the proposed answers work. I load the graph with the following function:

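import tensorflow as tf

def load_graph(model_file):
  graph = tf.Graph()
  graph_def = tf.GraphDef()

  # Read the serialized GraphDef protobuf from the .pb file
  with open(model_file, "rb") as f:
    graph_def.ParseFromString(f.read())

  # Import the nodes into a fresh graph (names get the default "import/" prefix)
  with graph.as_default():
    tf.import_graph_def(graph_def)

  return graph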

One example is loading a frozen_graph.pb file for retraining purposes, as in tensorflow-for-poets-2:

https://github.com/googlecodelabs/tensorflow-for-poets-2
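A small experiment shows why the usual answers come up empty for a frozen graph. This is a hedged sketch, not the asker's code: it assumes TensorFlow 1.14+ or 2.x (using the `tf.compat.v1` API), builds a hypothetical model with two variables, freezes it with `tf.compat.v1.graph_util.convert_variables_to_constants`, and re-imports the serialized GraphDef the way `load_graph` does. Freezing turns every variable into a `Const` node, so `trainable_variables()` returns nothing. Summing the element counts of the `Const` nodes is only a rough stand-in: it also counts non-weight constants (shape tensors, scalars, batch-norm moving statistics) and cannot tell which weights were originally trainable, so treat it as an upper bound.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API; assumes TF 1.14+ or 2.x

# Hypothetical tiny model: a 4x3 weight matrix and a 3-element bias (15 parameters).
train_graph = tf.Graph()
with train_graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 4], name="x")
    w = tf1.get_variable("w", shape=[4, 3])
    b = tf1.get_variable("b", shape=[3], initializer=tf1.zeros_initializer())
    y = tf.identity(tf.matmul(x, w) + b, name="y")
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        # Freezing replaces each variable with a Const node holding its current value.
        frozen_graph_def = tf1.graph_util.convert_variables_to_constants(
            sess, train_graph.as_graph_def(), ["y"])

# Re-import the serialized bytes the same way load_graph() reads a .pb file.
graph = tf.Graph()
graph_def = tf1.GraphDef()
graph_def.ParseFromString(frozen_graph_def.SerializeToString())
with graph.as_default():
    tf.import_graph_def(graph_def)
    # No variable ops survive freezing, so the trainable collection is empty.
    n_trainable = len(tf1.trainable_variables())

# Rough stand-in: total number of elements held by Const nodes.
total = 0
for op in graph.get_operations():
    if op.type == "Const":
        total += op.outputs[0].shape.num_elements()

print(n_trainable, total)
```

For a real frozen_graph.pb, the same `Const` loop can be run over the graph returned by `load_graph`.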

Upvotes: 4

Views: 1740

Answers (1)

Y. Luo

Reputation: 5722

To my understanding, a GraphDef doesn't contain enough information to describe Variables. As explained here, you need a MetaGraph, which contains both a GraphDef and a CollectionDef, a map that can describe Variables. So the following code should recover the trainable variables and give the correct count.

Export the MetaGraph:

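import tensorflow as tf

# 'a' is trainable; 'b' is explicitly excluded from the trainable collection
a = tf.get_variable('a', shape=[1])
b = tf.get_variable('b', shape=[1], trainable=False)
init = tf.global_variables_initializer()
saver = tf.train.Saver([a])

with tf.Session() as sess:
    sess.run(init)
    # Writes ./test.meta (the MetaGraph) alongside the checkpoint files
    saver.save(sess, './test')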
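To see what the CollectionDef actually records, one can export a MetaGraphDef in memory and decode its `trainable_variables` collection. This is a hedged sketch assuming TensorFlow 1.14+ or 2.x via `tf.compat.v1`; it rebuilds the same two variables `a` and `b` in a fresh graph instead of reading `test.meta` from disk. Only `a` should be listed, because `b` was created with `trainable=False`.

```python
import tensorflow as tf
from tensorflow.core.framework import variable_pb2

tf1 = tf.compat.v1  # assumes TF 1.14+ or 2.x

graph = tf.Graph()
with graph.as_default():
    tf1.get_variable("a", shape=[1])
    tf1.get_variable("b", shape=[1], trainable=False)
    # With no filename, export_meta_graph only returns the MetaGraphDef proto.
    meta_graph_def = tf1.train.export_meta_graph(graph=graph)

# Each trainable variable is stored as a serialized VariableDef in the CollectionDef.
raw_defs = meta_graph_def.collection_def["trainable_variables"].bytes_list.value
names = [variable_pb2.VariableDef.FromString(raw).variable_name for raw in raw_defs]
print(names)
```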

Import the MetaGraph and count the total number of trainable parameters:

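import tensorflow as tf

# Rebuild the graph, including its variable collections, from the MetaGraph
saver = tf.train.import_meta_graph('test.meta')

with tf.Session() as sess:
    saver.restore(sess, 'test')

# Add up the element count of each trainable variable, not just the number of variables
total_parameters = 0
for variable in tf.trainable_variables():
    total_parameters += variable.get_shape().num_elements()
print(total_parameters)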
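Incrementing a counter once per variable counts variables, not parameters; each variable contributes the product of its dimensions. A plain-Python sketch of that arithmetic, using hypothetical shapes for a small two-layer dense network (not taken from the question's model):

```python
from math import prod

# Hypothetical variable shapes for a small two-layer dense network.
shapes = {
    "dense1/kernel": [784, 128],
    "dense1/bias": [128],
    "dense2/kernel": [128, 10],
    "dense2/bias": [10],
}

num_variables = len(shapes)
# Each variable contributes the product of its dimensions; prod([]) == 1 covers scalars.
num_parameters = sum(prod(shape) for shape in shapes.values())
print(num_variables, num_parameters)
```

In TensorFlow, `TensorShape.num_elements()` returns this same product for a fully defined shape (and `None` otherwise).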

Upvotes: 1
