GPhilo

Reputation: 19123

Freeze graph with different data format

I trained a small CNN on my GPU using the NCHW data format. Now I want to export a .pb file that I can use for inference in other applications.

I wrote a small helper function that calls TensorFlow's freeze_graph function with default values, given a directory containing the checkpoint files and graph.pbtxt:

import os
import argparse
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import tensorflow as tf
dir(tf.contrib) #fix for tf.contrib undefined ops bug
from tensorflow.python.tools.freeze_graph import freeze_graph 

def my_freeze_graph_2(model_dir, output_node_names):
    """Extract the subgraph defined by the output nodes and convert
    all its variables into constants.

    Args:
        model_dir: the root folder containing the checkpoint state file
        output_node_names: a string containing all the output node names,
                           comma separated
    """
    if not tf.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exist. Please specify an export "
            "directory: %s" % model_dir)

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Retrieve the full path of the latest checkpoint
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    # Full path of the frozen graph file
    absolute_model_dir = os.path.abspath(model_dir)
    output_graph = os.path.join(absolute_model_dir, "frozen_model.pb")

    freeze_graph(input_graph=os.path.join(model_dir, 'graph.pbtxt'),
                 input_saver='',
                 input_binary=False,
                 input_checkpoint=input_checkpoint,
                 output_node_names=output_node_names,
                 restore_op_name="save/restore_all",
                 filename_tensor_name="save/Const:0",
                 output_graph=output_graph,
                 clear_devices=True,
                 initializer_nodes='')

I then have a small script that attempts to build the graph from frozen_model.pb to test that the freezing actually worked:

import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import argparse
import tensorflow as tf
from freeze_graph import load_graph
import cv2

if __name__ == '__main__':
    # Let the user pass the filename as an argument
    parser = argparse.ArgumentParser()
    parser.add_argument("--frozen_model_filename", default="model-multiple_starts/frozen_model.pb", type=str, help="Frozen model file to import")
    args = parser.parse_args()

    # Load the frozen graph with our "load_graph" function
    graph = load_graph(args.frozen_model_filename)

    # Verify that we can access the list of operations in the graph
    for op in graph.get_operations():
        print(op.name)

    # Access the input and output nodes
    x = graph.get_tensor_by_name('prefix/Reshape:0')
    y = graph.get_tensor_by_name('prefix/softmax_tensor:0')

    # Launch a Session
    with tf.Session(graph=graph, config=tf.ConfigProto(log_device_placement=True)) as sess:
        # Note: we don't need to initialize/restore anything.
        # There are no Variables in this graph, only hardcoded constants.

        # Load an image to use as a test
        im = cv2.imread('57_00000000.png', cv2.IMREAD_GRAYSCALE)
        im = im.T
        im = im / 255 - 0.5
        im = im[None,:,:,None]

        y_out = sess.run(y, feed_dict={
            x: im
        })
        print(y_out)

If I try to run my test script, I get the following error:

InvalidArgumentError: CPU BiasOp only supports NHWC. [[Node: prefix/conv2d/BiasAdd = BiasAdd[T=DT_FLOAT, data_format="NCHW", _device="/job:localhost/replica:0/task:0/cpu:0"](prefix/conv2d/convolution, prefix/conv2d/bias/read)]]

I tried different configurations:

All of them raise the same error.

The problem lies in the fact that the checkpoint which I want to freeze has operations defined with data_format='NCHW'. How do I freeze the checkpoint with NHWC data format?
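
For readers unfamiliar with the two layouts: NHWC orders a batch as (batch, height, width, channels), while NCHW orders it as (batch, channels, height, width). Below is a minimal NumPy sketch of converting between them. The helper names and the example shape are illustrative assumptions, not part of the code above. Note that this only converts data; it does not change the data_format attribute stored in the ops of a frozen graph.

```python
import numpy as np


def nhwc_to_nchw(batch):
    """Reorder axes from (N, H, W, C) to (N, C, H, W)."""
    return np.transpose(batch, (0, 3, 1, 2))


def nchw_to_nhwc(batch):
    """Reorder axes from (N, C, H, W) to (N, H, W, C)."""
    return np.transpose(batch, (0, 2, 3, 1))


# Hypothetical batch: 8 grayscale images, 28 rows by 32 columns, 1 channel.
batch_nhwc = np.arange(8 * 28 * 32 * 1, dtype=np.float32).reshape(8, 28, 32, 1)
batch_nchw = nhwc_to_nchw(batch_nhwc)
print(batch_nchw.shape)
```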

Update:

Poking around the files, I see that in graph.pbtxt for many operations data_format is hardcoded to NCHW. I guess, then, I'll need to make a new model with NHWC format, selectively load from the checkpoint the weights for the layers and use that graph to manually save out a .pb file... I'd assume there would be a process to do this already, but I can't find any documentation about this, nor examples.
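
If the weights are loaded selectively into an NHWC graph, one thing that may need attention is any dense layer fed by flattened convolution output. TensorFlow's conv2d kernels are stored as (height, width, in_channels, out_channels) regardless of data_format, but flattening an NCHW feature map orders the features as (c, h, w), while flattening an NHWC one orders them as (h, w, c), so the rows of the dense kernel would need to be permuted. The sketch below illustrates this in NumPy, assuming a model with a flatten-then-dense step (the question does not show the model); the function name and shapes are hypothetical.

```python
import numpy as np


def nchw_dense_kernel_to_nhwc(kernel, channels, height, width):
    """Permute the rows of a dense kernel trained on flattened NCHW features
    so that it gives the same result on flattened NHWC features."""
    units = kernel.shape[1]
    # In the NCHW-trained model, rows are indexed in (c, h, w) order.
    k = kernel.reshape(channels, height, width, units)
    # Reorder rows to (h, w, c) to match NHWC flattening.
    k = k.transpose(1, 2, 0, 3)
    return k.reshape(height * width * channels, units)


# Hypothetical sizes for a quick consistency check.
C, H, W, U = 3, 4, 5, 2
rng = np.random.default_rng(0)
features_nchw = rng.standard_normal((1, C, H, W))
kernel_nchw = rng.standard_normal((C * H * W, U))

# The same feature map expressed in NHWC layout.
features_nhwc = features_nchw.transpose(0, 2, 3, 1)
kernel_nhwc = nchw_dense_kernel_to_nhwc(kernel_nchw, C, H, W)

out_nchw = features_nchw.reshape(1, -1) @ kernel_nchw
out_nhwc = features_nhwc.reshape(1, -1) @ kernel_nhwc
# Reusing the unpermuted kernel on NHWC features gives a different result.
out_naive = features_nhwc.reshape(1, -1) @ kernel_nchw
print(np.allclose(out_nchw, out_nhwc))
```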

Update 2:

After trying to import the .pb files in OpenCV's DNN module, I found out the following:

It seems, then, that checkpoints are not transferable between graphs with different data formats (even if no error or warning is raised during the freezing process).

Upvotes: 4

Views: 1790

Answers (1)

Eli Bixby

Reputation: 1178

Typically, you'll want to wrap graph construction up in functions, so that you can rebuild your graph conditionally for the prediction case, because usually quite a few pieces of the graph change from training to prediction. As you've discovered, the NCHW and NHWC versions of, for example, the convolutional layers are actually different Ops in the graph proto, and they are hardcoded this way because GPU optimizations are only possible for one of the formats.

Editing graph protos is very difficult to do correctly, which is why most TensorFlow code that does this follows the pattern I described above. At a very high level:

def build_graph(data_format='NCHW'):
    # Conditionally use proper ops based on data_format arg
    pass

training_graph = tf.Graph()
with training_graph.as_default():
    build_graph(data_format='NCHW')

with tf.Session(graph=training_graph) as sess:
    # train
    # checkpoint session
    pass

prediction_graph = tf.Graph()
with prediction_graph.as_default():
    build_graph(data_format='NHWC')
    # load checkpoint
    # freeze graph

Note that the tf.estimator.Estimator framework makes this relatively easy. You can use the mode argument in your model_fn to decide between data formats and then have two different input_fns for training and prediction, and the framework will do the rest. You can find an end-to-end example of this here: https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10_estimator/cifar10_main.py#L77 (I've linked to the relevant lines)
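
As a rough illustration of choosing the format from the mode, the helper below maps a mode string to a data_format value. It is a framework-free sketch, not the linked example's code: the constants mirror the string values of tf.estimator.ModeKeys, the 'channels_first' and 'channels_last' names are the ones accepted by tf.layers, and the function name and the predict_on_cpu flag are hypothetical.

```python
# String values mirroring tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


def pick_data_format(mode, predict_on_cpu=True):
    """Return the layout to build the graph with for the given mode.

    Assumption: training and evaluation run on a GPU, where channels_first
    (NCHW) is typically preferred, while prediction may run on a CPU, where
    some ops only support channels_last (NHWC).
    """
    if mode == PREDICT and predict_on_cpu:
        return "channels_last"
    return "channels_first"


for mode in (TRAIN, EVAL, PREDICT):
    print(mode, pick_data_format(mode))
```

Inside a model_fn, the returned value could be passed as the data_format argument of each layer, so that the same function builds both the training graph and the prediction graph.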

Upvotes: 1
