I'm currently trying to make this repo work.
I'm trying to save the trained model on my local machine so it can be applied later. From TensorFlow's docs, saving a model seems pretty intuitive: just call tf.saved_model.save(obj, export_dir). But I'm not sure how to apply it to this model.
The original code is here: model.py. Following are my changes:
import tensorflow as tf

class ICON(tf.Module):  # make it a TensorFlow module
    def __init__(self, config, embeddingMatrix, session=None):
        ...  # body unchanged from model.py (omitted here)
    def _build_inputs(self):
        ...
    def _build_vars(self):
        ...
    def _convolution(self, input_to_conv):
        ...
    def _inference(self):
        ...
    def batch_fit(self, queries, ownHistory, otherHistory, labels):
        feed_dict = {self._input_queries: queries, self._own_histories: ownHistory,
                     self._other_histories: otherHistory, self._labels: labels}
        loss, _ = self._sess.run([self.loss_op, self.train_op], feed_dict=feed_dict)
        return loss
    def predict(self, queries, ownHistory, otherHistory):
        feed_dict = {self._input_queries: queries, self._own_histories: ownHistory,
                     self._other_histories: otherHistory}
        return self._sess.run(self.predict_op, feed_dict=feed_dict)
    def save(self):  # attempt to save the model
        tf.saved_model.save(self, './output/model')
The code above produces the following ValueError:
ValueError: Tensor("ICON/CNN/embedding_matrix:0", shape=(16832, 300), dtype=float32_ref) must be from the same graph as Tensor("saver_filename:0", shape=(), dtype=string).
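For comparison, here is a minimal sketch of the case tf.saved_model.save is designed for, assuming TensorFlow 2.x: a tf.Module whose state lives in tf.Variable attributes and whose methods are tf.function objects with an input signature. TinyModel and its predict method are hypothetical and not part of the repo. The ICON class instead builds a TF1 graph and runs it through self._sess with feed_dict, which appears to be why tf.saved_model.save cannot save it directly.

```python
import tempfile

import tensorflow as tf


class TinyModel(tf.Module):
    """Hypothetical TF2-style model, not the ICON class from the repo."""

    def __init__(self):
        super().__init__()
        # tf.Variable attributes are tracked automatically by tf.Module
        self.w = tf.Variable([[2.0], [3.0]], name="w")

    # A fixed input signature lets the traced function be saved and reloaded
    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 2], dtype=tf.float32)])
    def predict(self, x):
        return tf.matmul(x, self.w)


model = TinyModel()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(model, export_dir)

# Later (possibly in another process): reload and call the saved function
restored = tf.saved_model.load(export_dir)
result = restored.predict(tf.constant([[1.0, 1.0]]))
print(result.numpy())  # [[5.]]
```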
I believe you can use the tf.train.Saver class for this:
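def save(self):  # save a checkpoint of the model's variables
    # Build the Saver (and save) inside the graph that owns the session's variables
    with self._sess.graph.as_default():
        saver = tf.train.Saver()
        saver.save(self._sess, './output/model')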
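The ValueError in the question says the embedding matrix and the saver's filename tensor belong to different graphs. The sketch below illustrates that layout with a hypothetical one-variable model, assuming TensorFlow 2.x with the TF1 APIs under tf.compat.v1: the variable lives in a dedicated graph, the default graph outside the `with` block is a different one, and creating the Saver under sess.graph.as_default() keeps the saver and the variables in the same graph.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical minimal model: its variable lives in a dedicated graph, much as
# Session-based training scripts often build models inside tf.Graph().as_default().
model_graph = tf.Graph()
with model_graph.as_default():
    emb = tf.compat.v1.get_variable(
        "embedding_matrix", shape=[4, 3],
        initializer=tf.compat.v1.zeros_initializer())
    init = tf.compat.v1.global_variables_initializer()
sess = tf.compat.v1.Session(graph=model_graph)
sess.run(init)

# Outside the `with` block the default graph is a different graph, so ops
# created here (such as a saver's filename tensor) cannot be mixed with `emb`.
print(emb.graph is model_graph)                         # True
print(tf.compat.v1.get_default_graph() is model_graph)  # False

# Re-entering the session's graph puts the Saver and the variables together.
prefix = os.path.join(tempfile.mkdtemp(), "model")
with sess.graph.as_default():
    tf.compat.v1.train.Saver().save(sess, prefix)
print(os.path.exists(prefix + ".index"))                # True
```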
You can then restore the model like this:
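import tensorflow as tf
# Rebuild the graph structure from the .meta file written by saver.save()
saver = tf.train.import_meta_graph('./output/model.meta')
with tf.Session() as sess:
    # Load the trained variable values from the latest checkpoint in ./output
    saver.restore(sess, tf.train.latest_checkpoint('./output'))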
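A self-contained round trip of the save/restore pattern above, assuming TensorFlow 2.x with the TF1 APIs under tf.compat.v1. The variable w and the tensor named doubled are hypothetical. Naming the output tensor (here with tf.identity) is one way to fetch it after import_meta_graph, since the restored graph gives you no Python handles to the original tensors.

```python
import os
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()
ckpt_prefix = os.path.join(ckpt_dir, "model")

# --- Training side: build a graph, initialize a variable, save a checkpoint ---
train_graph = tf.Graph()
with train_graph.as_default():
    w = tf.compat.v1.get_variable("w", initializer=tf.constant([1.0, 2.0]))
    doubled = tf.identity(w * 2.0, name="doubled")  # named so it can be fetched later
    saver = tf.compat.v1.train.Saver()
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        saver.save(sess, ckpt_prefix)  # writes model.meta, model.index, model.data-*

# --- Restoring side: fresh graph, rebuild it from .meta, load the weights ---
serve_graph = tf.Graph()
with serve_graph.as_default():
    restorer = tf.compat.v1.train.import_meta_graph(ckpt_prefix + ".meta")
    with tf.compat.v1.Session() as sess:
        restorer.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
        result = sess.run("doubled:0")  # fetch the restored tensor by name
print(result)  # [2. 4.]
```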
You might also find this tutorial helpful for understanding this in more detail.
Edit: if you want to use SavedModel instead:
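def save(self):
    inputs = {'input_queries': self._input_queries, 'own_histories': self._own_histories,
              'other_histories': self._other_histories}
    outputs = {'output': self.predict_op}
    # Run inside the session's graph so the export sees the model's variables;
    # simple_save fails if ./output/model already exists and is not empty
    with self._sess.graph.as_default():
        tf.saved_model.simple_save(self._sess, './output/model', inputs, outputs)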
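simple_save is a convenience wrapper (deprecated in later releases and kept as tf.compat.v1.saved_model.simple_save). As far as I know it is roughly equivalent to the SavedModelBuilder sketch below, which registers a predict signature under the default serving key with the serving tag. The placeholder, variable, and names are hypothetical stand-ins for self._input_queries and self.predict_op, assuming TensorFlow 2.x with tf.compat.v1.

```python
import os
import tempfile

import tensorflow as tf

export_dir = os.path.join(tempfile.mkdtemp(), "model")  # must not exist yet

graph = tf.Graph()
with graph.as_default():
    # Hypothetical stand-ins for self._input_queries and self.predict_op
    queries = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name="queries")
    w = tf.compat.v1.get_variable("w", initializer=tf.constant([[2.0], [3.0]]))
    predict_op = tf.matmul(queries, w, name="predict")

    # Map public input/output keys to the graph's tensors
    signature = tf.compat.v1.saved_model.predict_signature_def(
        inputs={"input_queries": queries}, outputs={"output": predict_op})

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        builder = tf.compat.v1.saved_model.Builder(export_dir)
        builder.add_meta_graph_and_variables(
            sess,
            tags=[tf.saved_model.SERVING],
            signature_def_map={
                tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature})
        builder.save()  # writes saved_model.pb plus a variables/ directory

print(os.path.exists(os.path.join(export_dir, "saved_model.pb")))  # True
```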
You can then use tf.contrib.predictor.from_saved_model to load the SavedModel and run predictions with it:
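from tensorflow.contrib.predictor import from_saved_model
predictor = from_saved_model('./output/model')  # uses the default serving signature
# Feed NumPy batches; the result is a dict keyed by output name ('output' here)
predictions = predictor({'input_queries': input_queries,
                         'own_histories': own_histories,
                         'other_histories': other_histories})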
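tf.contrib was removed in TensorFlow 2.x, so from_saved_model is not available there. Below is a sketch of one alternative, swapping in the TF1 session-based loader tf.compat.v1.saved_model.load: load the SavedModel into a fresh session, look up the tensor names recorded in the serving signature, and feed them directly. The exported one-input model is a hypothetical stand-in for ICON.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

export_dir = os.path.join(tempfile.mkdtemp(), "model")

# Export a tiny hypothetical model with simple_save (stand-in for ICON.save()).
with tf.Graph().as_default():
    queries = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name="queries")
    w = tf.compat.v1.get_variable("w", initializer=tf.constant([[2.0], [3.0]]))
    predict_op = tf.matmul(queries, w)
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        tf.compat.v1.saved_model.simple_save(
            sess, export_dir, inputs={"input_queries": queries},
            outputs={"output": predict_op})

# Load it into a fresh graph and run it through the serving signature.
with tf.Graph().as_default():
    with tf.compat.v1.Session() as sess:
        meta_graph = tf.compat.v1.saved_model.load(
            sess, [tf.saved_model.SERVING], export_dir)
        sig = meta_graph.signature_def[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
        # The signature records which graph tensors back each public key
        x = sess.graph.get_tensor_by_name(sig.inputs["input_queries"].name)
        y = sess.graph.get_tensor_by_name(sig.outputs["output"].name)
        predictions = sess.run(y, feed_dict={x: np.array([[1.0, 1.0]], dtype=np.float32)})
print(predictions)  # [[5.]]
```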