I'm trying to extract all the weights/biases from a saved model, output_graph.pb.
I load the model with:
import tensorflow as tf

def create_graph(modelFullPath):
    """Creates a graph from a saved GraphDef file and imports it into the default graph."""
    # Parse the serialized GraphDef stored in the .pb file.
    with tf.gfile.FastGFile(modelFullPath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # name='' keeps the original node names (no "import/" prefix).
        tf.import_graph_def(graph_def, name='')

GRAPH_DIR = r'C:\tmp\output_graph.pb'
create_graph(GRAPH_DIR)
I then tried the following, hoping to extract the weights and biases of every layer:
with tf.Session() as sess:
    all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    print(len(all_vars))
However, len(all_vars) is 0.
My final goal is to extract the weights and biases and save them to a text file or as NumPy arrays.
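A minimal sketch of the saving step, assuming the weights have already been collected into a dict that maps op names to NumPy arrays (or plain nested lists). It uses only the standard library: each array is converted with .tolist() and written as indented JSON text. The function names save_weights_as_text and load_weights_from_text are hypothetical; numpy.savez is the usual alternative if you want binary NumPy files instead of text.

```python
import json


def save_weights_as_text(arrays, path):
    """Write a {op_name: array} dict to `path` as human-readable JSON.

    Values may be NumPy arrays/scalars (converted via .tolist()) or plain
    Python lists/numbers, which are written unchanged.
    """
    serializable = {
        name: value.tolist() if hasattr(value, "tolist") else value
        for name, value in arrays.items()
    }
    with open(path, "w") as f:
        json.dump(serializable, f, indent=2)


def load_weights_from_text(path):
    """Read the JSON file back; arrays come back as nested Python lists."""
    with open(path) as f:
        return json.load(f)
```

With NumPy installed, np.array(loaded[name]) turns each loaded entry back into an array.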
A GraphDef stores only the graph's nodes, so the tf.import_graph_def() function doesn't have enough information to reconstruct the tf.GraphKeys.TRAINABLE_VARIABLES collection (for that, you would need a MetaGraphDef). However, if output_graph.pb contains a "frozen" GraphDef, then all of the weights will be stored in tf.constant() nodes (op type "Const") in the graph. To extract them, you can do something like the following:
create_graph(GRAPH_DIR)

constant_values = {}
with tf.Session() as sess:
    constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"]
    for constant_op in constant_ops:
        constant_values[constant_op.name] = sess.run(constant_op.outputs[0])
Note that constant_values will probably contain more values than just the weights (for example, shape constants used by reshape ops), so you may need to filter further by op.name or some other criterion.
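As a minimal sketch of that filtering step, the helper below keeps only entries whose last name component contains one of the given keywords and groups them by layer scope (everything before the last "/"). The default keywords ("weights", "biases") are an assumption: op names depend entirely on how the graph was built, so print constant_values.keys() first and adjust them. The name filter_by_keywords is hypothetical.

```python
from collections import defaultdict


def filter_by_keywords(constant_values, keywords=("weights", "biases")):
    """Keep entries whose leaf op name contains any keyword, grouped by scope.

    `constant_values` maps op names (e.g. "layer1/weights") to arrays, as
    built by the loop over Const ops. The default keywords are an assumption;
    check the real op names in your graph and change them as needed.
    """
    grouped = defaultdict(dict)
    for name, value in constant_values.items():
        # Split "scope/sub/leaf" into the scope ("scope/sub") and the leaf.
        layer, _, leaf = name.rpartition("/")
        if any(k in leaf for k in keywords):
            grouped[layer][leaf] = value
    return dict(grouped)
```

For example, layers = filter_by_keywords(constant_values) after the session loop gives one small dict per layer scope, which can then be saved or inspected layer by layer.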