Amir

Reputation: 1428

Restoring weights of an already saved TensorFlow .pb model

I have seen many posts here about restoring already saved TF models, but none of them answered my question. I am using TF 1.0.0.

Specifically, I am interested in seeing the weights of the Inception v3 model, which is publicly available as a .pb file here. I managed to load it with a small chunk of Python code and can see the graph's high-level view in TensorBoard:

import os

import tensorflow as tf
from tensorflow.python.platform import gfile

INCEPTION_LOG_DIR = '/tmp/inception_v3_log'

if not os.path.exists(INCEPTION_LOG_DIR):
    os.makedirs(INCEPTION_LOG_DIR)
with tf.Session() as sess:
    model_filename = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb'
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')
    # tf.summary.FileWriter replaces the older tf.train.SummaryWriter.
    writer = tf.summary.FileWriter(INCEPTION_LOG_DIR, sess.graph)
    writer.close()

However, I failed to access any layers' weights.

tensors = tf.import_graph_def(graph_def, name='')

returns nothing, even if I add an arbitrary return_elements= argument. Does the file contain any weights at all? If so, what is the appropriate procedure here? Thanks.

Upvotes: 2

Views: 4767

Answers (4)

mrgloom

Reputation: 21612

Here is a small utility that prints the weights stored in a .pb model:

import argparse

import tensorflow as tf
from tensorflow.python.framework import tensor_util


def print_pb_weights(pb_filepath):
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(pb_filepath, "rb") as f:
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')

    for node in graph_def.node:
        if node.op == 'Const':
            print('-' * 60)
            print('op:', node.op)
            print('name:', node.name)
            arr = tensor_util.MakeNdarray(node.attr['value'].tensor)
            print('shape:', list(arr.shape))
            print(arr)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('pb_filepath')
    args = parser.parse_args()

    print_pb_weights(args.pb_filepath)

Upvotes: 0

Krist

Reputation: 477

There is a difference between restoring weights and printing them. Restoring means importing weight values from saved ckpt files for retraining or inference, while printing is only for inspection. Also, a .pb file encodes model parameters as tf.constant() ops, so they do not appear in tf.trainable_variables() and you cannot use the .pb file directly to restore the weights. From your question I take it that you just want to 'see' the weights for inspection.

Let us first load the graph from the .pb file.
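To make the point about constants concrete, here is a minimal sketch rather than a definitive recipe. It assumes a TensorFlow release that ships tf.compat.v1 (so not TF 1.0.0 itself), and it uses a hypothetical toy graph with a single constant named toy_weights in place of the real Inception file. After importing, the parameters show up as Const ops while the list of trainable variables stays empty.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: a TensorFlow release that ships tf.compat.v1

# Build a toy "frozen" GraphDef whose only parameter is a constant.
# 'toy_weights' is a hypothetical name used only for illustration.
source = tf.Graph()
with source.as_default():
    tf.constant([[0.5, -0.5]], name='toy_weights')
frozen_def = source.as_graph_def()

# Import it into a fresh graph, as one would with a parsed .pb file.
imported = tf.Graph()
with imported.as_default():
    tf1.import_graph_def(frozen_def, name='')
    trainable = tf1.trainable_variables()
    const_ops = [op.name for op in imported.get_operations() if op.type == 'Const']

print('trainable variables:', trainable)
print('Const ops:', const_ops)
```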

import tensorflow as tf
from tensorflow.python.platform import gfile

GRAPH_PB_PATH = './model/tensorflow_inception_v3_stripped_optimized_quantized.pb'  # path to your .pb file
with tf.Session() as sess:
    print("load graph")
    with gfile.FastGFile(GRAPH_PB_PATH, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Inside this block the session's graph is already the default graph.
    tf.import_graph_def(graph_def, name='')
    graph_nodes = [n for n in graph_def.node]

When you freeze a graph to a .pb file, your variables are converted to Const nodes, so the weights that were trainable variables are also stored as Const in the .pb file. graph_nodes contains all the nodes in the graph, but we are only interested in the Const nodes (which hold the weights along with any other constants in the graph).

wts = [n for n in graph_nodes if n.op == 'Const']

Each element of wts is of type NodeDef. It has several attributes such as name and op. The values can be extracted as follows:

from tensorflow.python.framework import tensor_util

for n in wts:
    print("Name of the node - %s" % n.name)
    print("Value - ")
    print(tensor_util.MakeNdarray(n.attr['value'].tensor))
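The same extraction can be wrapped in a helper that returns the values instead of printing them. This is only a sketch under stated assumptions: const_arrays is a hypothetical helper name, the toy graph with toy_weights stands in for a real frozen model, and tf.compat.v1 is assumed to be available (it is not part of TF 1.0.0).

```python
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import tensor_util

tf1 = tf.compat.v1  # assumption: a TensorFlow release that ships tf.compat.v1


def const_arrays(graph_def):
    """Map each Const node name in a GraphDef to its value as a NumPy array."""
    return {
        node.name: tensor_util.MakeNdarray(node.attr['value'].tensor)
        for node in graph_def.node
        if node.op == 'Const'
    }


# Toy stand-in for a frozen model: one constant "weight" and one matmul.
toy_graph = tf.Graph()
with toy_graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 2], name='input')
    w = tf.constant(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32),
                    name='toy_weights')
    tf.matmul(x, w, name='output')

weights = const_arrays(toy_graph.as_graph_def())
for name, arr in weights.items():
    print(name, arr.shape, arr.dtype)
```

With a real model, the graph_def parsed from the .pb file would be passed to const_arrays in place of the toy graph. Keeping the arrays in a dict makes it easier to look at shapes first and print only the layers of interest.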

Hope this solves your concern.

Upvotes: 3

vahid mahmoodi

Reputation: 118

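Use this code to print a tensor's value. The name must refer to a tensor rather than an op, so it includes the output index (for example ':0'):

with tf.Session() as sess:
    print(sess.run('your_tensor_name:0'))

You can use this code, inside the same session block, to retrieve the tensor names:

    ops = sess.graph.get_operations()
    for m in ops:
        print(m.values())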
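As a small illustration of how op names and tensor names relate, the sketch below builds a toy graph with one constant and fetches it by its tensor name. It assumes a TensorFlow release that ships tf.compat.v1, and toy_weights is a hypothetical name chosen for the example.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumption: a TensorFlow release that ships tf.compat.v1

# Toy graph with a single constant; 'toy_weights' is a hypothetical name.
graph = tf.Graph()
with graph.as_default():
    tf.constant([1.0, 2.0, 3.0], name='toy_weights')

# An op is named 'toy_weights'; its first output tensor is 'toy_weights:0'.
op_names = [op.name for op in graph.get_operations()]
tensor_names = [t.name for op in graph.get_operations() for t in op.outputs]

with tf1.Session(graph=graph) as sess:
    # Fetch by tensor name (with the ':0' suffix) to get the value back.
    value = sess.run('toy_weights:0')

print(op_names, tensor_names, value)
```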

Upvotes: 4

Beta

Reputation: 1746

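You can use this code to get the names of the nodes in the graph (append an output index such as ':0' to a node name to get the corresponding tensor name):

[node.name for node in tf.get_default_graph().as_graph_def().node]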

Upvotes: 1
