Reputation: 41
Is there a way to use a model A trained with tf.estimator inside another model B?
Here is the situation. Let's say I have a trained 'Model A' defined by model_a_fn(). 'Model A' takes images as input and outputs a vector of floating-point values, similar to an MNIST classifier. There is another model, 'Model B', defined in model_b_fn(). It also takes images as input, and it needs the vector output of 'Model A' while 'Model B' is being trained.
So basically I want to train 'Model B', which needs both the images and the prediction output of 'Model A' as inputs. (There is no need to train 'Model A' any further; I only need its prediction output while training 'Model B'.)
I've tried three cases:
But all cases show errors:
Here is the code I used (only the important parts are included):
def model_a_fn(features, labels, mode, params):
    """Model function for 'Model A' (implementation elided in the post).

    Uses the ``(features, labels, mode, params)`` signature that
    ``tf.estimator.Estimator`` expects from a ``model_fn``. The body is a
    placeholder: as written it returns ``None``.
    """
    # ...
    # ...
    # ...
    return
def main():
    """Train 'Model A' with ``tf.estimator`` and export it as a SavedModel."""
    # model checkpoint location
    model_a_dir = './model_a'

    # create estimator for Model A
    model_a = tf.estimator.Estimator(model_fn=model_a_fn, model_dir=model_a_dir)

    # train Model A
    # NOTE(review): the lambda returns `input_fn_a` itself without calling it;
    # if `input_fn_a` is a function this should probably be
    # `input_fn=input_fn_a` -- confirm against its definition.
    model_a.train(input_fn=lambda: input_fn_a)

    # ...
    # ...
    # ...

    # export model a
    model_a.export_savedmodel(
        model_a_dir,
        serving_input_receiver_fn=serving_input_receiver_fn,
    )
    # exported to ./model_a/123456789
    return


if __name__ == '__main__':
    main()
# follows model_a's input format
def bypass_input_fn(x):
    """Wrap *x* in a feature dict under the key ``'x'``.

    Args:
        x: The value to pass through unchanged.

    Returns:
        A dict ``{'x': x}``.
    """
    return {'x': x}
def model_b_fn(features, labels, mode, params):
    """Model function for 'Model B' -- attempt 1: call ``model_a.predict`` inline.

    Takes the Model A estimator from ``params['model_a']`` and tries to run
    its ``predict`` on the reshaped input tensor while Model B's graph is
    being built. The layers between ``layer1`` and ``prev_layer`` are elided
    in the post, as is the final return value.
    """
    # parse input
    inputs = tf.reshape(features['x'], shape=[-1, 28, 28, 1])

    # get Model A's response
    model_a = params['model_a']
    predictions = model_a.predict(
        input_fn=lambda: bypass_input_fn(inputs)
    )
    for results in predictions:
        # Error occurs!!!
        # NOTE(review): presumably because Estimator.predict builds its own
        # graph and session, so `inputs` (a tensor belonging to Model B's
        # graph) cannot be used there -- confirm against the traceback.
        model_a_output = results['class_id']

    # build Model B
    # `padding` must be passed by keyword: the 4th positional argument of
    # tf.layers.conv2d is `strides`, and the bare name `same` was undefined.
    layer1 = tf.layers.conv2d(inputs, 32, 5, padding='same', activation=tf.nn.relu)
    layer1 = tf.layers.max_pooling2d(layer1, pool_size=[2, 2], strides=2)

    # ...
    # some layers added...
    # ...

    flatten = tf.layers.flatten(prev_layer)
    # tf.layers.dense needs its input tensor as well as the unit count.
    layern = tf.layers.dense(flatten, 10)

    # assume layern's output shape and model_a_output's output shape are the same
    add_layer = tf.add(layern, model_a_output)

    # ...
    # do more... stuff
    # ...
    return
def main():
    """Train 'Model B', handing it a Model A estimator through ``params``."""
    # load pretrained model A
    model_a_dir = './model_a'
    model_a = tf.estimator.Estimator(model_fn=model_a_fn, model_dir=model_a_dir)

    # model checkpoint location
    model_b_dir = './model_b/'

    # create estimator for Model B
    model_b = tf.estimator.Estimator(
        model_fn=model_b_fn,
        model_dir=model_b_dir,
        params={
            'model_a': model_a,
        },
    )

    # train Model B
    # NOTE(review): the lambda returns `input_fn_b` itself without calling it;
    # if `input_fn_b` is a function this should probably be
    # `input_fn=input_fn_b` -- confirm against its definition.
    model_b.train(input_fn=lambda: input_fn_b)
    return


if __name__ == '__main__':
    main()
def model_b_fn(features, labels, mode, params):
    """Model function for 'Model B' -- attempt 2: predictor passed in ``params``.

    Reads a ready-made predictor callable from ``params['model_a_predict_fn']``
    and calls it with ``{'x': inputs}``, then reads ``'output'`` from the
    result. The layers between ``layer1`` and ``prev_layer`` are elided in the
    post, as is the final return value.
    """
    # parse input
    inputs = tf.reshape(features['x'], shape=[-1, 28, 28, 1])

    # get Model A's response
    model_a_predict_fn = params['model_a_predict_fn']
    model_a_prediction = model_a_predict_fn(
        {
            'x': inputs,
        }
    )
    model_a_output = model_a_prediction['output']

    # build Model B
    # `padding` must be passed by keyword: the 4th positional argument of
    # tf.layers.conv2d is `strides`, and the bare name `same` was undefined.
    layer1 = tf.layers.conv2d(inputs, 32, 5, padding='same', activation=tf.nn.relu)
    layer1 = tf.layers.max_pooling2d(layer1, pool_size=[2, 2], strides=2)

    # ...
    # some layers added...
    # ...

    flatten = tf.layers.flatten(prev_layer)
    # tf.layers.dense needs its input tensor as well as the unit count.
    layern = tf.layers.dense(flatten, 10)

    # assume layern's output shape and model_a_output's output shape are the same
    add_layer = tf.add(layern, model_a_output)

    # ...
    # do more... stuff
    # ...
    return
def main():
    """Train 'Model B', handing it a SavedModel predictor through ``params``."""
    # load pretrained model A
    model_a_dir = './model_a/123456789'
    model_a_predict_fn = tf.contrib.predictor.from_saved_model(export_dir=model_a_dir)

    # model checkpoint location
    model_b_dir = './model_b/'

    # create estimator for Model B
    # Error occurs!!!
    # NOTE(review): presumably because Estimator copies `params` and the
    # predictor object (which holds a session) cannot be copied -- confirm
    # against the traceback.
    model_b = tf.estimator.Estimator(
        model_fn=model_b_fn,
        model_dir=model_b_dir,
        params={
            'model_a_predict_fn': model_a_predict_fn,
        },
    )

    # train Model B
    # NOTE(review): the lambda returns `input_fn_b` itself without calling it;
    # if `input_fn_b` is a function this should probably be
    # `input_fn=input_fn_b` -- confirm against its definition.
    model_b.train(input_fn=lambda: input_fn_b)
    return


if __name__ == '__main__':
    main()
def model_b_fn(features, labels, mode, params):
    """Model function for 'Model B' -- attempt 3: build the predictor inline.

    Loads a predictor from the SavedModel directory given in
    ``params['model_a_dir']`` and calls it with ``{'x': inputs}``, then reads
    ``'output'`` from the result. The layers between ``layer1`` and
    ``prev_layer`` are elided in the post, as is the final return value.
    """
    # parse input
    inputs = tf.reshape(features['x'], shape=[-1, 28, 28, 1])

    # get Model A's response
    model_a_predict_fn = tf.contrib.predictor.from_saved_model(
        export_dir=params['model_a_dir']
    )
    # Error occurs!!!
    # NOTE(review): presumably because the predictor expects concrete array
    # values to feed, while `inputs` is a symbolic tensor -- confirm against
    # the traceback.
    model_a_prediction = model_a_predict_fn(
        {
            'x': inputs,
        }
    )
    model_a_output = model_a_prediction['output']

    # build Model B
    # `padding` must be passed by keyword: the 4th positional argument of
    # tf.layers.conv2d is `strides`, and the bare name `same` was undefined.
    layer1 = tf.layers.conv2d(inputs, 32, 5, padding='same', activation=tf.nn.relu)
    layer1 = tf.layers.max_pooling2d(layer1, pool_size=[2, 2], strides=2)

    # ...
    # some layers added...
    # ...

    flatten = tf.layers.flatten(prev_layer)
    # tf.layers.dense needs its input tensor as well as the unit count.
    layern = tf.layers.dense(flatten, 10)

    # assume layern's output shape and model_a_output's output shape are the same
    add_layer = tf.add(layern, model_a_output)

    # ...
    # do more... stuff
    # ...
    return
def main():
    """Train 'Model B', handing it Model A's SavedModel path through ``params``."""
    # load pretrained model A
    model_a_dir = './model_a/123456789'

    # model checkpoint location
    model_b_dir = './model_b/'

    # create estimator for Model B
    # Error occurs!!!
    model_b = tf.estimator.Estimator(
        model_fn=model_b_fn,
        model_dir=model_b_dir,
        params={
            'model_a_dir': model_a_dir,
        },
    )

    # train Model B
    # NOTE(review): the lambda returns `input_fn_b` itself without calling it;
    # if `input_fn_b` is a function this should probably be
    # `input_fn=input_fn_b` -- confirm against its definition.
    model_b.train(input_fn=lambda: input_fn_b)
    return


if __name__ == '__main__':
    main()
So, does anyone have an idea how to use a trained custom tf.estimator inside another tf.estimator?
Upvotes: 3
Views: 695
Reputation: 41
I've figured out one solution to this problem.
You can use this method if you are struggling with the same problem.
For a simple example, see https://github.com/moono/tf-cnn-mnist/blob/master/4_3_estimator_within_estimator.py.
Upvotes: 1