Moondra

Reputation: 4511

tf.GraphKeys.TRAINABLE_VARIABLES on output_graph.pb resulting in empty list

I'm trying to extract all the weights/biases from a saved model output_graph.pb.

I read the model:

import tensorflow as tf

def create_graph(modelFullPath):
    """Creates a graph from a saved GraphDef file."""
    # Creates graph from saved graph_def.pb.
    with tf.gfile.FastGFile(modelFullPath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

GRAPH_DIR = r'C:\tmp\output_graph.pb'
create_graph(GRAPH_DIR)

Then I attempted the following, hoping to extract all the weights/biases within each layer:

with tf.Session() as sess:
    all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    print(len(all_vars))

However, len(all_vars) comes back as 0.

My final goal is to extract the weights and biases and save them to a text file or NumPy arrays.

Upvotes: 7

Views: 5879

Answers (1)

mrry

Reputation: 126154

The tf.import_graph_def() function doesn't have enough information to reconstruct the tf.GraphKeys.TRAINABLE_VARIABLES collection (for that, you would need a MetaGraphDef). However, if output_graph.pb contains a "frozen" GraphDef, then all of the weights are stored in tf.constant() nodes in the graph. To extract them, you can do something like the following:

create_graph(GRAPH_DIR)

constant_values = {}

with tf.Session() as sess:
  # In a frozen graph, every "Const" op holds a baked-in tensor value.
  constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
  for constant_op in constant_ops:
    # Evaluating the op's output yields the stored value as a NumPy array.
    constant_values[constant_op.name] = sess.run(constant_op.outputs[0])

Note that constant_values will probably contain more values than just the weights, so you may need to filter further by op.name or some other criterion.
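For example, once you have constant_values, you can filter by name substring and dump the surviving arrays to disk with plain NumPy. This is only a sketch: the "weights"/"biases" substrings and the sample op names below are assumptions -- inspect the op names in your own graph to pick the right filter.

```python
import numpy as np

# Hypothetical stand-in for the dict built in the snippet above;
# the op names here are made up for illustration.
constant_values = {
    "conv1/weights": np.ones((3, 3), dtype=np.float32),
    "conv1/biases": np.zeros(3, dtype=np.float32),
    "global_step": np.array(0),  # not a weight -- should be filtered out
}

# Keep only ops whose names look like weights or biases (assumed naming).
params = {name: value for name, value in constant_values.items()
          if "weights" in name or "biases" in name}

# Option 1: save everything into one .npz archive
# ('/' is replaced so the keys are also usable as filenames).
np.savez("params.npz", **{name.replace("/", "_"): v
                          for name, v in params.items()})

# Option 2: dump each array to its own text file
# (savetxt wants a 1-D or 2-D array, hence atleast_2d).
for name, value in params.items():
    np.savetxt(name.replace("/", "_") + ".txt", np.atleast_2d(value))
```

Reloading is then just np.load("params.npz"), which gives you the filtered arrays back by name.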

Upvotes: 7
