moono
moono

Reputation: 41

Tensorflow, use a tf.estimator trained model within another tf.estimator model_fn

Is there a way to use tf.estimator trained model A in another model B?

Here is the situation. Let's say I have a trained 'Model A' defined by model_a_fn(). 'Model A' takes images as input and outputs a vector of floating-point values, similar to an MNIST classifier. There is another model, 'Model B', defined in model_b_fn(). It also takes images as input, and it needs the vector output of 'Model A' while 'Model B' is being trained.

So basically I want to train 'Model B', which needs both the images and the prediction output of 'Model A' as inputs. (There is no need to train 'Model A' any further; I only need its prediction output while training 'Model B'.)

I've tried three cases:

  1. Use the estimator object ('Model A') inside model_b_fn().
  2. Export 'Model A' with tf.estimator.Estimator.export_savedmodel(), create a prediction function from it, and pass that function to model_b_fn() through the params dict.
  3. Same as 2, but restore 'Model A' inside model_b_fn().

But every case fails with an error:

  1. ... must be from the same graph as ...
  2. TypeError: can't pickle _thread.RLock objects
  3. TypeError: The value of a feed cannot be a tf.Tensor object.

Here is the code I used (only the important parts are attached):

train_model_a.py

def model_a_fn(features, labels, mode, params):
    """Model function for 'Model A'.

    The post elides the actual implementation; only the signature that is
    handed to tf.estimator.Estimator as ``model_fn`` is shown.
    """
    # ...
    # (model definition omitted in the post)
    # ...
    return None

def main():
    """Train 'Model A' and export it as a SavedModel."""
    # Checkpoints and the export both go under this directory.
    checkpoint_dir = './model_a'

    estimator_a = tf.estimator.Estimator(
        model_fn=model_a_fn,
        model_dir=checkpoint_dir,
    )

    # NOTE(review): the lambda returns `input_fn_a` without calling it; if
    # input_fn_a is itself a function this should probably be
    # `input_fn=input_fn_a` -- confirm against its definition.
    estimator_a.train(input_fn=lambda: input_fn_a)
    # ...
    # ...
    # ...

    # Export Model A; the post reports the result at ./model_a/123456789.
    estimator_a.export_savedmodel(
        checkpoint_dir,
        serving_input_receiver_fn=serving_input_receiver_fn,
    )

if __name__ == '__main__':
    main()

train_model_b_case_1.py

# follows model_a's input format
def bypass_input_fn(x):
    """Wrap ``x`` in a features dict under the single key 'x'."""
    return dict(x=x)

def model_b_fn(features, labels, mode, params):
    """Model function for 'Model B', case 1: use Model A's estimator directly.

    Reads ``features['x']`` (reshaped to [-1, 28, 28, 1]) and
    ``params['model_a']``, the estimator object for Model A.

    This is the first failing attempt described in the post; the call into
    ``model_a.predict()`` is intentionally left as posted, and the line
    marked "Error occurs!!!" is where the failure is reported (the post
    lists "... must be from the same graph as ..." for this case).
    """
    # parse input
    inputs = tf.reshape(features['x'], shape=[-1, 28, 28, 1])

    # get Model A's response
    model_a = params['model_a']
    predictions = model_a.predict(
        input_fn=lambda: bypass_input_fn(inputs)
    )
    for results in predictions:
        # Error occurs!!!
        model_a_output = results['class_id']

    # build Model B
    # Padding must be the string 'same' passed by keyword: the fourth
    # positional parameter of tf.layers.conv2d is `strides`, not `padding`.
    layer1 = tf.layers.conv2d(inputs, 32, 5, padding='same', activation=tf.nn.relu)
    layer1 = tf.layers.max_pooling2d(layer1, pool_size=[2, 2], strides=2)

    # ...
    # some layers added...
    # ...

    flatten = tf.layers.flatten(prev_layer)
    # tf.layers.dense takes the input tensor first, then the unit count.
    layern = tf.layers.dense(flatten, 10)

    # let's say layern's output shape and model_a_output's shape are the same
    add_layer = tf.add(layern, model_a_output)

    # ...
    # do more... stuff
    # ...
    return

def main():
    """Create Model B's estimator around a Model A estimator and train it."""
    # Estimator for the pretrained Model A, pointing at its model_dir.
    pretrained = tf.estimator.Estimator(
        model_fn=model_a_fn,
        model_dir='./model_a',
    )

    # Model A's estimator is handed to model_b_fn through the params dict.
    hyper = {
        'model_a': pretrained,
    }

    # Estimator for Model B, with its own checkpoint location.
    estimator_b = tf.estimator.Estimator(
        model_fn=model_b_fn,
        model_dir='./model_b/',
        params=hyper,
    )

    # NOTE(review): the lambda returns `input_fn_b` without calling it; if
    # input_fn_b is itself a function this should probably be
    # `input_fn=input_fn_b` -- confirm against its definition.
    estimator_b.train(input_fn=lambda: input_fn_b)

if __name__ == '__main__':
    main()

train_model_b_case_2.py

def model_b_fn(features, labels, mode, params):
    """Model function for 'Model B', case 2: predictor passed in via params.

    Reads ``features['x']`` (reshaped to [-1, 28, 28, 1]) and
    ``params['model_a_predict_fn']``, a callable that is invoked with
    ``{'x': inputs}`` and whose result is indexed with the key 'output'.

    This is the second failing attempt described in the post. The way
    Model A is invoked is intentionally left as posted.
    """
    # parse input
    inputs = tf.reshape(features['x'], shape=[-1, 28, 28, 1])

    # get Model A's response
    model_a_predict_fn = params['model_a_predict_fn']
    model_a_prediction = model_a_predict_fn(
        {
            'x': inputs
        }
    )
    model_a_output = model_a_prediction['output']

    # build Model B
    # Padding must be the string 'same' passed by keyword: the fourth
    # positional parameter of tf.layers.conv2d is `strides`, not `padding`.
    layer1 = tf.layers.conv2d(inputs, 32, 5, padding='same', activation=tf.nn.relu)
    layer1 = tf.layers.max_pooling2d(layer1, pool_size=[2, 2], strides=2)

    # ...
    # some layers added...
    # ...

    flatten = tf.layers.flatten(prev_layer)
    # tf.layers.dense takes the input tensor first, then the unit count.
    layern = tf.layers.dense(flatten, 10)

    # let's say layern's output shape and model_a_output's shape are the same
    add_layer = tf.add(layern, model_a_output)

    # ...
    # do more... stuff
    # ...
    return

def main():
    """Load Model A as a SavedModel predictor, then build and train Model B."""
    # Predictor built from the exported SavedModel of the pretrained Model A.
    predictor_a = tf.contrib.predictor.from_saved_model(
        export_dir='./model_a/123456789',
    )

    # The predictor is handed to model_b_fn through the params dict.
    hyper = {
        'model_a_predict_fn': predictor_a,
    }

    # Estimator for Model B, with its own checkpoint location.
    # Error occurs!!!
    estimator_b = tf.estimator.Estimator(
        model_fn=model_b_fn,
        model_dir='./model_b/',
        params=hyper,
    )

    # NOTE(review): the lambda returns `input_fn_b` without calling it; if
    # input_fn_b is itself a function this should probably be
    # `input_fn=input_fn_b` -- confirm against its definition.
    estimator_b.train(input_fn=lambda: input_fn_b)

if __name__ == '__main__':
    main()

train_model_b_case_3.py

def model_b_fn(features, labels, mode, params):
    """Model function for 'Model B', case 3: restore Model A inside model_fn.

    Reads ``features['x']`` (reshaped to [-1, 28, 28, 1]) and
    ``params['model_a_dir']``, the SavedModel export directory from which a
    predictor for Model A is created.

    This is the third failing attempt described in the post; calling the
    predictor with the ``inputs`` tensor is intentionally left as posted,
    and the line marked "Error occurs!!!" is where the failure is reported
    (the post lists "TypeError: The value of a feed cannot be a tf.Tensor
    object." for this case).
    """
    # parse input
    inputs = tf.reshape(features['x'], shape=[-1, 28, 28, 1])

    # get Model A's response
    model_a_predict_fn = tf.contrib.predictor.from_saved_model(export_dir=params['model_a_dir'])
    # Error occurs!!!
    model_a_prediction = model_a_predict_fn(
        {
            'x': inputs
        }
    )
    model_a_output = model_a_prediction['output']

    # build Model B
    # Padding must be the string 'same' passed by keyword: the fourth
    # positional parameter of tf.layers.conv2d is `strides`, not `padding`.
    layer1 = tf.layers.conv2d(inputs, 32, 5, padding='same', activation=tf.nn.relu)
    layer1 = tf.layers.max_pooling2d(layer1, pool_size=[2, 2], strides=2)

    # ...
    # some layers added...
    # ...

    flatten = tf.layers.flatten(prev_layer)
    # tf.layers.dense takes the input tensor first, then the unit count.
    layern = tf.layers.dense(flatten, 10)

    # let's say layern's output shape and model_a_output's shape are the same
    add_layer = tf.add(layern, model_a_output)

    # ...
    # do more... stuff
    # ...
    return

def main():
    """Build and train Model B, passing Model A's export directory via params."""
    # Only the SavedModel directory of the pretrained Model A is passed on;
    # model_b_fn reads it from the params dict.
    hyper = {
        'model_a_dir': './model_a/123456789',
    }

    # Estimator for Model B, with its own checkpoint location.
    # Error occurs!!!
    estimator_b = tf.estimator.Estimator(
        model_fn=model_b_fn,
        model_dir='./model_b/',
        params=hyper,
    )

    # NOTE(review): the lambda returns `input_fn_b` without calling it; if
    # input_fn_b is itself a function this should probably be
    # `input_fn=input_fn_b` -- confirm against its definition.
    estimator_b.train(input_fn=lambda: input_fn_b)

if __name__ == '__main__':
    main()

So, does anyone have an idea how to use a trained custom tf.estimator model inside another tf.estimator?

Upvotes: 3

Views: 695

Answers (1)

moono
moono

Reputation: 41

I've figured out one solution to this problem.

Anyone struggling with the same problem can use this method:

  1. Create a function that runs tensorflow.contrib.predictor.from_saved_model(); call it 'pretrained_predictor()'.
  2. Inside Model B's model_fn(), call the 'pretrained_predictor()' defined above, wrapped in tensorflow.py_func().

For a simple working example, see https://github.com/moono/tf-cnn-mnist/blob/master/4_3_estimator_within_estimator.py.

Upvotes: 1

Related Questions