schil
schil

Reputation: 322

Re-train a frozen *.pb model in TensorFlow

How do I import a frozen protobuf to enable it for re-training?

All the methods I've found online expect checkpoints. Is there a way to read a protobuf such that the kernel and bias constants are converted to variables?


Edit 1: This is similar to the following question: How to retrain model in graph (.pb)?

I looked at DeepSpeech, which was recommended in the answers to that question. They seem to have removed support for initialize_from_frozen_model. I couldn't find the reason.


Edit 2: I tried creating a new GraphDef object where I replace the kernels and biases with Variables:

probable_variables = [...] # kernels and biases of Conv2D and MatMul

# Rebuild the frozen graph node by node: every node named in
# probable_variables is re-created as a VariableV2 node with the same name,
# every other node is copied unchanged into the same new GraphDef.
new_graph_def = tf.GraphDef()

with tf.Session(graph=graph) as sess:
    for n in sess.graph_def.node:

        if n.name in probable_variables:
            # create variable op
            nn = new_graph_def.node.add()
            nn.name = n.name
            nn.op = 'VariableV2'
            # Take dtype and shape from the node being replaced: each
            # kernel/bias has its own shape, so one shared value cannot work.
            # NOTE(review): assumes every name in probable_variables is a
            # Const node, i.e. carries 'dtype' and 'value' attrs -- confirm.
            nn.attr['dtype'].CopyFrom(n.attr['dtype'])
            nn.attr['shape'].CopyFrom(attr_value_pb2.AttrValue(
                shape=n.attr['value'].tensor.tensor_shape))

        else:
            # Everything else goes into the same GraphDef, untouched.
            nn = new_graph_def.node.add()
            nn.CopyFrom(n)

Not sure if I am on the right path. Don't know how to set trainable=True in a NodeDef object.

Upvotes: 12

Views: 6304

Answers (5)

Someshwar Kale
Someshwar Kale

Reputation: 1

def protobuf_to_checkpoint_conversion(pb_model, ckpt_dir):
    """Convert a frozen ``.pb`` graph into a checkpoint.

    One variable is created per ``Const`` op in the frozen graph (same name,
    shape and dtype, zero-initialised). Only constants whose name contains
    ``'FeatureExtractor'`` or ``'BoxPredictor'`` then have their value copied
    into the matching variable; the remaining variables keep their zero
    initial value. All variables are saved to ``<ckpt_dir>/model.ckpt``.

    Args:
        pb_model: Path of the frozen protobuf file to read.
        ckpt_dir: Directory the checkpoint files are written into.
    """
    graph = tf.Graph()
    with graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(pb_model, 'rb') as fid:
            od_graph_def.ParseFromString(fid.read())
        tf.import_graph_def(od_graph_def, name='')

        with tf.Session(graph=graph) as sess:
            constant_ops = [op for op in graph.get_operations() if op.type == "Const"]
            vars_dict = {}
            for constant_op in constant_ops:
                name = constant_op.name
                const = constant_op.outputs[0]
                vars_dict[name] = tf.get_variable(
                    name, const.shape, dtype=const.dtype,
                    initializer=tf.zeros_initializer())
            print('INFO:Initializing variables')
            sess.run(tf.global_variables_initializer())
            print('INFO: Loading vars')
            for constant_op in tqdm(constant_ops):
                name = constant_op.name
                if 'FeatureExtractor' in name or 'BoxPredictor' in name:
                    # A Const op has no inputs, so it can be evaluated
                    # without feeding any placeholder of the model.
                    vars_dict[name].load(sess.run(constant_op.outputs[0]), sess)
            # Keyed by the original constant names, so the checkpoint entries
            # carry those names.
            saver = tf.train.Saver(var_list=vars_dict)
            ckpt_path = os.path.join(ckpt_dir, 'model.ckpt')
            saver.save(sess, ckpt_path)

reference: https://github.com/yeephycho/tensorflow-face-detection/issues/42#issuecomment-455325984

Upvotes: 0

Someshwar Kale
Someshwar Kale

Reputation: 1

def protobuf_to_checkpoint_conversion(pb_model, ckpt_dir):
    """Convert a frozen ``.pb`` graph into a checkpoint.

    One variable is created per ``Const`` op in the frozen graph (same name,
    shape and dtype, zero-initialised). Only constants whose name contains
    ``'FeatureExtractor'`` or ``'BoxPredictor'`` then have their value copied
    into the matching variable; the remaining variables keep their zero
    initial value. All variables are saved to ``<ckpt_dir>/model.ckpt``.

    Args:
        pb_model: Path of the frozen protobuf file to read.
        ckpt_dir: Directory the checkpoint files are written into.
    """
    graph = tf.Graph()
    with graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(pb_model, 'rb') as fid:
            od_graph_def.ParseFromString(fid.read())
        tf.import_graph_def(od_graph_def, name='')

        with tf.Session(graph=graph) as sess:
            constant_ops = [op for op in graph.get_operations() if op.type == "Const"]
            vars_dict = {}
            for constant_op in constant_ops:
                name = constant_op.name
                const = constant_op.outputs[0]
                vars_dict[name] = tf.get_variable(
                    name, const.shape, dtype=const.dtype,
                    initializer=tf.zeros_initializer())

            print('INFO:Initializing variables')
            sess.run(tf.global_variables_initializer())

            print('INFO: Loading vars')
            for constant_op in tqdm(constant_ops):
                name = constant_op.name
                if 'FeatureExtractor' in name or 'BoxPredictor' in name:
                    # A Const op has no inputs, so it can be evaluated
                    # without feeding any placeholder of the model.
                    vars_dict[name].load(sess.run(constant_op.outputs[0]), sess)

            # Keyed by the original constant names, so the checkpoint entries
            # carry those names.
            saver = tf.train.Saver(var_list=vars_dict)
            ckpt_path = os.path.join(ckpt_dir, 'model.ckpt')
            saver.save(sess, ckpt_path)

reference: https://github.com/yeephycho/tensorflow-face-detection/issues/42#issuecomment-455325984

Upvotes: 0

chunlei chen
chunlei chen

Reputation: 31

Thanks to @FalconUA and @Max Wu. I added a way to quickly get the variables' names.

import tensorflow as tf


# Load protobuf as graph, given filepath
def load_pb(path_to_pb):
    """Parse the GraphDef stored at ``path_to_pb`` and return it as a tf.Graph.

    The nodes are imported with an empty name prefix, so operation names in
    the returned graph are the same as in the file.
    """
    parsed_def = tf.GraphDef()
    with tf.gfile.GFile(path_to_pb, 'rb') as pb_file:
        parsed_def.ParseFromString(pb_file.read())

    loaded_graph = tf.Graph()
    with loaded_graph.as_default():
        tf.import_graph_def(parsed_def, name='')
    return loaded_graph


# Load the frozen model and list the name of every Const op in it; these
# names are the candidates for being turned back into variables.
tf_graph = load_pb('mobilenet_v1_1.0_224_frozen_ccl.pb')
variables = [
    operation.name
    for operation in tf_graph.get_operations()
    if operation.type == "Const"
]
print(variables)

Upvotes: 3

Max Wu
Max Wu

Reputation: 146

I have verified @FalconUA's solution with tested code. Slight modifications were needed (notably, I use the initializer option in get_variable to properly initialize the Variables). Here it is!

Assuming your frozen model is stored in frozen_graph.pb:

import tensorflow.contrib.graph_editor as ge

probable_variables = [...] # kernels and biases of Conv2D and MatMul
tf_graph = load_pb('frozen_graph.pb')

const_var_name_pairs = []
with tf_graph.as_default() as g:

    # Fetch the value of every constant with a single session run instead of
    # opening a new session per constant.
    const_tensors = [g.get_tensor_by_name('{}:0'.format(name))
                     for name in probable_variables]
    with tf.Session() as sess:
        const_values = sess.run(const_tensors)

    for name, tensor, tensor_as_numpy_array in zip(
            probable_variables, const_tensors, const_values):
        var_shape = tensor.get_shape()
        # Give each variable a name that doesn't already exist in the graph
        var_name = '{}_turned_var'.format(name)
        # Create TensorFlow variable initialized by values of original const.
        # (The call is already inside parentheses, so no line-continuation
        # backslash is needed.)
        var = tf.get_variable(name=var_name, dtype='float32', shape=var_shape,
                              initializer=tf.constant_initializer(tensor_as_numpy_array))
        # We want to keep track of our variables names for later.
        const_var_name_pairs.append((name, var_name))

    # At this point, we added a bunch of tf.Variables to the graph, but they're
    # not connected to anything.

    # The magic: we use TF Graph Editor to swap the Constant nodes' outputs with
    # the outputs of our newly created Variables.

    for const_name, var_name in const_var_name_pairs:
        const_op = g.get_operation_by_name(const_name)
        var_reader_op = g.get_operation_by_name(var_name + '/read')
        ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))

Note: if you save the converted model and view it in Tensorboard or Netron, you will see that Variables have taken the Constants' places. You will also see a bunch of dangling Constants, which you can optionally remove.

I have verified that the weight values are the same between the frozen and unfrozen versions.

Here is the load_pb function:

import tensorflow as tf
# Load protobuf as graph, given filepath
def load_pb(path_to_pb):
    """Return a new tf.Graph holding the GraphDef read from ``path_to_pb``.

    The import uses ``name=''``, so node names are kept exactly as stored.
    """
    frozen_def = tf.GraphDef()
    with tf.gfile.GFile(path_to_pb, 'rb') as handle:
        serialized = handle.read()
    frozen_def.ParseFromString(serialized)

    result = tf.Graph()
    with result.as_default():
        tf.import_graph_def(frozen_def, name='')
    return result

Upvotes: 4

Chan Kha Vu
Chan Kha Vu

Reputation: 10400

You were actually in the right direction with the snippet you provided :)


Step 1: get the name of previously trainable variables

The trickiest part is getting the names of the previously trainable variables. Hopefully the model was created with a high-level framework, like keras or tf.slim - they wrap their variables nicely in something like conv2d_1/kernel, dense_1/bias, batch_normalization/gamma, etc.

If you're not sure, the most useful thing to do is to visualize the graph...

# parse the serialized graph definition from disk
graph_def = tf.GraphDef()
with tf.gfile.GFile('frozen.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# import it into a fresh in-memory graph (nodes are imported under the
# "prefix" name) and write that graph to the 'out' directory for TensorBoard
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name="prefix")
    writer = tf.summary.FileWriter('out', graph)
    writer.close()

... with tensorboard:

$ tensorboard --logdir out/

and see for yourself what the graph looks like and what the naming is.


Step 2: replace constants with variables (the fun part :D)

All you need is the magical library called tf.contrib.graph_editor. Now let's say you've stored the names of previously trainable ops (that previously were variables but now they are Const) in probable_variables (as in your Edit 2).

Note: remember the difference between ops, tensors, and variables. Ops are elements of the graph, tensor is a buffer that contains results of ops, and variables are wrappers around tensors, with 3 ops: assign (to be called when you initialize the variable), read (called by other ops, e.g. conv2d), and ref tensor (which holds the values).

Note 2: graph_editor can only be run outside a session – you cannot make any graph modification online!

import numpy as np
import tensorflow.contrib.graph_editor as ge

import tensorflow as tf

# load the graphdef into memory, just as in Step 1 (but without a name
# prefix, so node names match the entries of probable_variables)
with tf.gfile.GFile('frozen.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')

# create a variable for each constant, beware the naming
# The variables have to be created inside `graph`; otherwise they are added
# to the global default graph and their '/read' ops are not found below.
const_var_name_pairs = []
name_to_var = {}
with graph.as_default():
    for name in probable_variables:
        var_shape = graph.get_tensor_by_name('{}:0'.format(name)).get_shape()
        var_name = '{}_a'.format(name)
        # keep the Variable object: it is needed again when copying values
        name_to_var[var_name] = tf.get_variable(
            name=var_name, shape=var_shape, dtype='float32')
        const_var_name_pairs.append((name, var_name))

# graph_editor works on tf.Operation objects, so index the operations of the
# live graph by name (the NodeDefs of its GraphDef would not be accepted)
name_to_op = {op.name: op for op in graph.get_operations()}

# magic: now we swap the outputs of const and created variable
for const_name, var_name in const_var_name_pairs:
    const_op = name_to_op[const_name]
    var_reader_op = name_to_op[var_name + '/read']
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))

# Now we can safely create a session and copy the values
# The session is deliberately left open so the loaded values can be used for
# the checks mentioned below.
sess = tf.Session(graph=graph)
for const_name, var_name in const_var_name_pairs:
    ts = graph.get_tensor_by_name('{}:0'.format(const_name))
    # eval() takes feed_dict first, so the session must be passed by keyword;
    # load() needs the session explicitly because none is set as default.
    name_to_var[var_name].load(ts.eval(session=sess), sess)

# All done! Now you can make sure everything is correct by visualizing
# and calculate outputs for some inputs.

PS: this code was not tested; however, I've been using graph_editor and performing network surgery quite often lately, so I think it should mostly be correct :)

Upvotes: 8

Related Questions