P1t_

Reputation: 165

Export Tensorflow Estimator

I'm trying to build a CNN with TensorFlow (r1.4) based on the tf.estimator API, using a canned model. The idea is to train and evaluate the network with the Estimator in Python, then run prediction in C++ without the Estimator by loading a .pb file generated after training.

My first question is, is it possible?

If yes: the training part works, and the prediction part works too with a .pb file generated without the Estimator, but it fails when I load a .pb file exported from the Estimator.

I got this error: "Data loss: Can't parse saved_model.pb as binary proto". My Python code to export my model:

import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import parsing_ops

# Serving input: parse serialized tf.Example protos into the feature dict
feature_spec = {'input_image': parsing_ops.FixedLenFeature(dtype=dtypes.float32, shape=[1, 48 * 48])}
export_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)

input_fn = tf.estimator.inputs.numpy_input_fn(self.eval_features,
                                              self.eval_label,
                                              shuffle=False,
                                              num_epochs=1)
eval_result = self.model.evaluate(input_fn=input_fn, name='eval')
exporter = tf.estimator.FinalExporter('save_model', export_input_fn)
exporter.export(estimator=self.model, export_path=MODEL_DIR,
                checkpoint_path=self.model.latest_checkpoint(),
                eval_result=eval_result,
                is_the_final_export=True)

It doesn't work with tf.estimator.Estimator.export_savedmodel() either.

If anyone knows an explicit tutorial on Estimators with canned models and how to export them, I'm interested.

Upvotes: 4

Views: 1338

Answers (1)

gdelab

Reputation: 6220

Please look at this issue on GitHub; it looks like you have the same problem. Apparently (at least when using estimator.export_savedmodel()) you should load the model with LoadSavedModel instead of ReadBinaryProto, because it's not saved as a plain GraphDef binary proto.

You'll find a few more instructions here on how to use it:

const string export_dir = ...
SavedModelBundle bundle;
...
LoadSavedModel(session_options, run_options, export_dir,
               {kSavedModelTagTrain}, &bundle);

I can't seem to find the C++ documentation for SavedModelBundle to use it afterwards, but it's likely close to the same class in Java, in which case it basically contains the session and the graph you'll be using.

Upvotes: 2
