Ink

Reputation: 963

How to store model in `.pb` file with Estimator in TensorFlow?

I trained my model with a TensorFlow Estimator. It seems that export_savedmodel should be used to produce a .pb file, but I don't really know how to construct the serving_input_receiver_fn. Anybody any ideas? Example code would be welcome.

Extra questions:

  1. Is the .pb file the only file I need when I want to reload the model? Are the variables unnecessary?

  2. How much will a .pb file reduce the model size compared with a .ckpt, given the Adam optimizer?

Upvotes: 2

Views: 5214

Answers (1)

Ghilas BELHADJ

Reputation: 14096

You can use freeze_graph.py to produce a .pb from a .ckpt + .pbtxt pair. If you're using tf.estimator.Estimator, you'll find these two files in the model_dir:

python freeze_graph.py \
    --input_graph=graph.pbtxt \
    --input_checkpoint=model.ckpt-308 \
    --output_graph=output_graph.pb \
    --output_node_names=<output_node>
  1. Is the .pb file the only file I need when I want to reload the model? Are the variables unnecessary?

Yes, but you'll also need to know your model's input and output node names. Then use import_graph_def to load the .pb file and get the input and output operations using get_operation_by_name.
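As a rough sketch of that loading step (written with tf.compat.v1 so it also runs under TensorFlow 2.x; the 'input'/'output' node names and the tiny in-memory graph standing in for a real .pb are illustrative):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

# Build a tiny graph and serialize it, standing in for a frozen .pb file.
g = tf.Graph()
with g.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name='input')
    w = tf.constant([[1.0], [2.0]])
    y = tf.matmul(x, w, name='output')
graph_def = g.as_graph_def()

# In practice you would read graph_def from disk instead:
# with tf.io.gfile.GFile('output_graph.pb', 'rb') as f:
#     graph_def = tf.compat.v1.GraphDef()
#     graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as imported:
    tf.compat.v1.import_graph_def(graph_def, name='')
    # Look the operations up by the names you noted at export time.
    input_op = imported.get_operation_by_name('input')
    output_op = imported.get_operation_by_name('output')
    with tf.compat.v1.Session(graph=imported) as sess:
        result = sess.run(output_op.outputs[0],
                          feed_dict={input_op.outputs[0]: [[3.0, 4.0]]})
        print(result)  # [[11.]]
```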

  2. How much will a .pb file reduce the model size compared with a .ckpt, given the Adam optimizer?

A .pb file is not a compressed .ckpt file, so there is no "compression rate".

However, there is a way to optimize your .pb file for inference, and this optimization may reduce the file size, as it removes parts of the graph that are training-only operations (see the complete description here).
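Assuming the flag names of the stock optimize_for_inference tool (worth double-checking against your TensorFlow version), the invocation looks roughly like:

```shell
# Paths and node names are illustrative; substitute your own.
python -m tensorflow.python.tools.optimize_for_inference \
    --input=output_graph.pb \
    --output=optimized_graph.pb \
    --frozen_graph=True \
    --input_names=<input_node> \
    --output_names=<output_node>
```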

[comment] how can I get the input and output node names?

You can set the input and output node names using each op's name parameter.
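For example (a minimal sketch; the 'input'/'output' names are hypothetical, and tf.compat.v1 is used so it also runs under TensorFlow 2.x):

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    # The name= argument fixes the node name you will look up later.
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3], name='input')
    y = tf.reduce_mean(x, axis=1, name='output')

node_names = [n.name for n in g.as_graph_def().node]
print(node_names)  # includes 'input' and 'output'
```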

To list the node names in your .pbtxt file, use the following script.

import tensorflow as tf
from google.protobuf import text_format

with open('graph.pbtxt') as f:
    graph_def = text_format.Parse(f.read(), tf.GraphDef())

print([n.name for n in graph_def.node])

[comment] I found that there is a tf.estimator.Estimator.export_savedmodel(), is that the function to store the model in a .pb directly? And I'm struggling with its parameter serving_input_receiver_fn. Any ideas?

export_savedmodel() generates a SavedModel, which is a universal serialization format for TensorFlow models. It should contain everything needed to work with the TensorFlow Serving APIs.

serving_input_receiver_fn() is one of the things you have to provide in order to generate a SavedModel: it determines the input signature of your model by adding placeholders to the graph.

From the doc

This function has the following purposes:

  • To add placeholders to the graph that the serving system will feed with inference requests.
  • To add any additional ops needed to convert data from the input format into the feature Tensors expected by the model.

If you're receiving your inference requests in the form of serialized tf.Examples (which is a typical pattern) then you can use the example provided in the doc.

feature_spec = {'foo': tf.FixedLenFeature(...),
                'bar': tf.VarLenFeature(...)}

def serving_input_receiver_fn():
  """An input receiver that expects a serialized tf.Example."""
  # default_batch_size may be None to accept batches of any size.
  serialized_tf_example = tf.placeholder(dtype=tf.string,
                                         shape=[default_batch_size],
                                         name='input_example_tensor')
  receiver_tensors = {'examples': serialized_tf_example}
  features = tf.parse_example(serialized_tf_example, feature_spec)
  return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)

[comment] Any idea how to list the node names in a '.pb'?

It depends on how it was generated.

If it's a SavedModel, then use:

import tensorflow as tf

with tf.Session() as sess:
    meta_graph_def = tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        './saved_models/1519232535')
    print([n.name for n in meta_graph_def.graph_def.node])

If it's a plain GraphDef (e.g. a frozen graph), then use:

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    with gfile.FastGFile('model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
        print([n.name for n in graph_def.node])

Upvotes: 11
