batuman

Add new layers to a TensorFlow frozen graph (freeze_graph)?

These discussions (1,2) cover adding new layers to a TensorFlow graph and retraining the model.

The following code shows how to add a new layer to a restored, trainable model.

import tensorflow as tf

sess = tf.Session()
# First, load the meta graph and restore the weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Now, access the placeholder tensors and
# create a feed_dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Now, access the op that you want to run
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

# Add more ops to the current graph
add_on_op = tf.multiply(op_to_restore, 2)

print(sess.run(add_on_op, feed_dict))
# This will print 120.

However, I would like to add layers to a restored frozen graph.

I only have the frozen model for an application. I would like to add layers to the model and freeze it again. These layers are for post-processing and do not need to be trained, so they are not part of the trained model.

The reason is that I am converting the frozen graph to TensorRT, and I would like to include those layers in the INT8 engine.
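A frozen graph is essentially a list of nodes whose weights are already constants, so post-processing layers can be added by appending nodes that consume an existing output, with no training involved. The sketch below models only that idea in plain Python; `Node`, `append_postprocessing`, and all node names are hypothetical stand-ins, not TensorFlow's GraphDef API, and the op strings are just labels.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Node:
    """Toy stand-in for a GraphDef NodeDef (hypothetical, NOT the TensorFlow protobuf)."""
    name: str
    op: str
    inputs: List[str] = field(default_factory=list)  # "producer" or "producer:index"


def append_postprocessing(nodes: List[Node], source: str,
                          layers: List[Tuple[str, str]]) -> List[Node]:
    """Return a new node list with `layers` chained after the tensor `source`.

    The original nodes are not modified and nothing is retrained: the new
    nodes only consume an existing output, like post-processing ops
    appended to a frozen graph.
    """
    names = {n.name for n in nodes}
    if source.split(":")[0] not in names:
        raise KeyError(f"source node {source!r} not found")
    extended = list(nodes)  # copy, so the input list is left untouched
    prev = source
    for name, op in layers:
        if name in names:
            raise ValueError(f"node name {name!r} already exists")
        extended.append(Node(name, op, [prev]))
        names.add(name)
        prev = name  # the next layer consumes this one
    return extended


# Hypothetical frozen graph: the weight "w" is already a constant.
frozen = [
    Node("input", "Placeholder"),
    Node("w", "Const"),
    Node("logits", "MatMul", ["input", "w"]),
]
extended = append_postprocessing(
    frozen, "logits:0", [("probs", "Softmax"), ("postprocessed", "Identity")])
print([n.name for n in extended])
# ['input', 'w', 'logits', 'probs', 'postprocessed']
```

Because the appended nodes have no trainable parameters of their own, re-freezing only has to serialize the extended node list again.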

Answers (1)

Gaurav Joshi

I hope the code below helps. I had a custom op that needed to be added to an existing graph, which I loaded from a .pb file (a frozen model file). With this approach I was able to append new nodes to the existing graph.
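When appending a node, the new op needs an existing tensor as its input. The script that follows uses `all_tensors[-1]`, the last output of the last operation the graph returns, which is the model output only if that node happens to come last. A minimal sketch, assuming a toy `{node_name: [inputs]}` mapping rather than a real GraphDef, of listing the "sink" nodes (nodes whose outputs nothing consumes) so you can check which one to attach to; the helper names and graph contents are hypothetical.

```python
from typing import Dict, List


def node_name(inp: str) -> str:
    """Map an input string such as '^foo' or 'foo:1' to the node name 'foo'."""
    return inp.lstrip("^").split(":")[0]


def sink_nodes(graph: Dict[str, List[str]]) -> List[str]:
    """Return nodes whose outputs no other node consumes, in graph order."""
    consumed = {node_name(i) for inputs in graph.values() for i in inputs}
    return [name for name in graph if name not in consumed]


# Hypothetical node -> inputs mapping, in GraphDef order.
graph = {
    "input": [],
    "w": [],
    "logits": ["input", "w"],
    "probs": ["logits:0"],
    "summary": ["^logits"],  # hypothetical side node that happens to come last
}
print(sink_nodes(graph))  # ['probs', 'summary']
print(list(graph)[-1])    # summary (the last node, not the model output)
```

Once you know the right node, selecting it explicitly with `frozen_graph.get_tensor_by_name("<node>:0")` (where `<node>` is your model's output node name) is less fragile than relying on its position.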

Source code below: 

import tensorflow as tf
from tensorflow.python.framework import load_library
from tensorflow.python.platform import resource_loader


# Utility functions for loading and freezing graphs


def load_graph(frozen_graph_filename):
    # Read the serialized GraphDef from the frozen .pb file
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import it into a fresh graph; name="" keeps the original node names
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")

    return graph


def freeze_graph(sess, output_graph):
    # Nodes to keep as outputs; nodes they do not depend on are pruned
    output_node_names = ["custom_op_zero", "custom_op_zero_1"]

    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        sess.graph.as_graph_def(),
        output_node_names
    )

    with tf.gfile.GFile(output_graph, "wb") as f:
        f.write(output_graph_def.SerializeToString())

    print("{} ops written to {}.".format(len(output_graph_def.node), output_graph))


## Load the custom op's shared object file

zero_out_ops = load_library.load_op_library(
    resource_loader.get_path_to_datafile('my-op/tensorflow_zero_out/python/ops/_zero_out_ops.so'))
zero_out = zero_out_ops.zero_out

frozen_graph = load_graph("frozen_model.pb")
all_tensors = [tensor for op in frozen_graph.get_operations() for tensor in op.values()]
# print(all_tensors[29])

# The input to the new nodes is the output of the last node.
# Create them inside frozen_graph so they are added to that graph.
with frozen_graph.as_default():
    zero_out_custom = zero_out(all_tensors[-1], name="custom_op_zero")
    zero_out_custom1 = zero_out(all_tensors[-1], name="custom_op_zero_1")
# print(zero_out_custom)

# Save the new frozen model file
with tf.Session(graph=frozen_graph) as persisted_sess:
    for op in persisted_sess.graph.get_operations():
        print(op)
    freeze_graph(persisted_sess, "new_model.pb")
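As far as I know, `tf.graph_util.convert_variables_to_constants` in TF 1.x also prunes the graph down to the nodes that `output_node_names` depend on (it calls `extract_sub_graph` internally), so any node that is neither listed nor an ancestor of a listed node is dropped from `new_model.pb`. The toy sketch below imitates that pruning on a plain `{node_name: [inputs]}` dict; it is an illustration under that assumption, not TensorFlow code, and `unused_metric` is a hypothetical node.

```python
from collections import deque
from typing import Dict, List


def node_name(inp: str) -> str:
    """Strip a control-dependency '^' prefix and any ':index' suffix."""
    return inp.lstrip("^").split(":")[0]


def extract_subgraph(graph: Dict[str, List[str]],
                     outputs: List[str]) -> Dict[str, List[str]]:
    """Keep only `outputs` and the nodes they transitively depend on."""
    missing = [o for o in outputs if o not in graph]
    if missing:
        raise KeyError(f"output nodes not in graph: {missing}")
    keep = set()
    queue = deque(outputs)
    while queue:  # breadth-first walk from the outputs back through the inputs
        name = queue.popleft()
        if name in keep:
            continue
        keep.add(name)
        queue.extend(node_name(i) for i in graph[name])
    # Preserve the original node order.
    return {name: ins for name, ins in graph.items() if name in keep}


graph = {
    "input": [],
    "w": [],
    "logits": ["input", "w"],
    "unused_metric": ["logits"],  # hypothetical branch nobody lists as an output
    "custom_op_zero": ["logits:0"],
    "custom_op_zero_1": ["logits:0"],
}
pruned = extract_subgraph(graph, ["custom_op_zero", "custom_op_zero_1"])
print(list(pruned))
# ['input', 'w', 'logits', 'custom_op_zero', 'custom_op_zero_1']
```

Listing a node in `output_node_names` is what guarantees it survives; a branch like `unused_metric` that feeds none of the listed outputs disappears.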
