Reputation: 963
I trained my model with TensorFlow's Estimator API. It seems that export_savedmodel should be used to produce a .pb file, but I don't really know how to construct the serving_input_receiver_fn. Does anybody have any ideas? Example code would be welcome.
Extra questions:
- Is .pb the only file I need when I want to reload the model? Are the Variable files unnecessary?
- How much does a .pb file reduce the model size compared with the .ckpt files saved with the Adam optimizer?
Upvotes: 2
Views: 5214
Reputation: 14096
You can use freeze_graph.py to produce a .pb from .ckpt + .pbtxt. If you're using tf.estimator.Estimator, you'll find these two files in the model_dir:
python freeze_graph.py \
    --input_graph=graph.pbtxt \
    --input_checkpoint=model.ckpt-308 \
    --output_graph=output_graph.pb \
    --output_node_names=<output_node>
- Is .pb the only file I need when I want to reload the model? Are the Variable files unnecessary?
Yes, but you'll also need to know your model's input and output node names. Then use import_graph_def to load the .pb file and get the input and output operations with get_operation_by_name, as in the sketch below.
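A minimal loading sketch; the file name output_graph.pb and the node names 'input' / 'output' are placeholders for whatever your own graph uses:
import tensorflow as tf

# Read the frozen graph produced by freeze_graph.py above.
with tf.gfile.GFile('output_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import it into a fresh graph and look up the ops by name.
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
    input_op = graph.get_operation_by_name('input')    # your input node name
    output_op = graph.get_operation_by_name('output')  # your output node name

with tf.Session(graph=graph) as sess:
    # my_batch stands for whatever input data your model expects.
    result = sess.run(output_op.outputs[0],
                      feed_dict={input_op.outputs[0]: my_batch})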
- How much does .pb reduce the model file size compared with .ckpt with the Adam optimizer?
A .pb file is not a compressed .ckpt file, so there is no "compression rate".
However, there is a way to optimize your .pb file for inference, and this optimization may reduce the file size as it removes parts of the graph that are training-only operations (see the complete description here).
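If you want to try that, one possible invocation of the optimize_for_inference tool that ships with TensorFlow looks like this (the node names are placeholders you have to fill in yourself):
python -m tensorflow.python.tools.optimize_for_inference \
    --input=output_graph.pb \
    --output=optimized_graph.pb \
    --frozen_graph=True \
    --input_names=<input_node> \
    --output_names=<output_node>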
[comment] how can I get the input and output node names?
You can set the input and output node names yourself by passing the name parameter when you create the ops, as shown below.
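A hypothetical sketch; the names 'input' and 'output' are arbitrary, they just have to be set explicitly so you can refer to them later:
import tensorflow as tf

# Name the placeholder that feeds the model and wrap the final tensor in an
# identity op so both have predictable names.
x = tf.placeholder(tf.float32, shape=[None, 28, 28], name='input')
logits = tf.layers.dense(tf.layers.flatten(x), 10)
predictions = tf.identity(tf.nn.softmax(logits), name='output')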
To list the node names in your .pbtxt
file, use the following script.
import tensorflow as tf
from google.protobuf import text_format

# Parse the text-format graph and print every node name.
with open('graph.pbtxt') as f:
    graph_def = text_format.Parse(f.read(), tf.GraphDef())

print([n.name for n in graph_def.node])
[comment] I found that there is a tf.estimator.Estimator.export_savedmodel(), is that the function to store the model as a .pb directly? And I'm struggling with its parameter serving_input_receiver_fn. Any ideas?
export_savedmodel() generates a SavedModel, which is a universal serialization format for TensorFlow models. It should contain everything needed to serve the model with the TensorFlow Serving APIs.
serving_input_receiver_fn() is one of the things you have to provide in order to generate a SavedModel; it determines the input signature of your model by adding placeholders to the graph.
From the doc:
This function has the following purposes:
- To add placeholders to the graph that the serving system will feed with inference requests.
- To add any additional ops needed to convert data from the input format into the feature Tensors expected by the model.
If you're receiving your inference requests in the form of serialized tf.Examples (which is a typical pattern), then you can use the example provided in the doc:
feature_spec = {'foo': tf.FixedLenFeature(...),
                'bar': tf.VarLenFeature(...)}

def serving_input_receiver_fn():
    """An input receiver that expects a serialized tf.Example."""
    serialized_tf_example = tf.placeholder(dtype=tf.string,
                                           shape=[default_batch_size],
                                           name='input_example_tensor')
    receiver_tensors = {'examples': serialized_tf_example}
    features = tf.parse_example(serialized_tf_example, feature_spec)
    return tf.estimator.export.ServingInputReceiver(features, receiver_tensors)
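With that function defined, exporting is a single call. A sketch, assuming estimator is your trained tf.estimator.Estimator and export_base is any writable directory:
# A timestamped SavedModel directory is written under export_base
# and its path is returned.
export_dir = estimator.export_savedmodel(
    export_dir_base='export_base',
    serving_input_receiver_fn=serving_input_receiver_fn)
print(export_dir)  # e.g. b'export_base/1519232535'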
[comment] Any idea how to list the node names in a '.pb'?
It depends on how it was generated.
If it's a SavedModel, then use:
import tensorflow as tf

with tf.Session() as sess:
    meta_graph_def = tf.saved_model.loader.load(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        './saved_models/1519232535')
    print([n.name for n in meta_graph_def.graph_def.node])
If it's a plain GraphDef (e.g. a frozen graph), then use:
import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    # Read the binary .pb file and parse it into a GraphDef.
    with gfile.FastGFile('model.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import the graph into the current (default) session graph.
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')
    print([n.name for n in graph_def.node])
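Once the graph is imported, you can also run inference from it by adding something like this inside the with tf.Session() block above (a hedged sketch: 'input:0' / 'output:0' stand for your own tensor names, i.e. the node name plus ':0'):
    # Look up tensors by name in the imported graph and run them.
    input_tensor = sess.graph.get_tensor_by_name('input:0')
    output_tensor = sess.graph.get_tensor_by_name('output:0')
    # my_batch is whatever input data your model expects.
    result = sess.run(output_tensor, feed_dict={input_tensor: my_batch})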
Upvotes: 11