I have two TensorFlow models that share the same architecture (3D U-Net). My current flow is:
Pre-processing -> Prediction from Model 1 -> Some operations -> Prediction from Model 2 -> Post-processing
The operations in between can be done in TensorFlow. Can both models, together with the operations in between, be combined into a single TF graph so that the flow looks something like this:
Pre-processing -> Model 1+2 -> Post-processing
Thanks.
You can use the tf.keras functional API to achieve this. Here is a toy example:
import tensorflow as tf
print('TensorFlow:', tf.__version__)

def preprocessing(tensor):
    # perform your operations
    return tensor

def some_operations(model_1_prediction):
    # perform your operations
    # assuming your operations result in a tensor
    # whose shape matches model_2's input
    tensor = model_1_prediction
    return tensor

def post_processing(tensor):
    # perform your operations
    return tensor

def get_model(name):
    inp = tf.keras.Input(shape=[256, 256, 3])
    x = tf.keras.layers.Conv2D(64, 3, 1, 'same')(inp)
    x = tf.keras.layers.Conv2D(256, 3, 1, 'same')(x)
    x = tf.keras.layers.Conv2D(512, 3, 1, 'same')(x)
    x = tf.keras.layers.Conv2D(64, 3, 1, 'same')(x)
    x = tf.keras.layers.Conv2D(3, 3, 1, 'same')(x)
    # num_filters is set to 3 to make sure model_1's output
    # matches model_2's input.
    output = tf.keras.layers.Activation('sigmoid')(x)
    return tf.keras.Model(inp, output, name=name)

model_1 = get_model('model-1')
model_2 = get_model('model-2')

# Chain the graphs: model_1's symbolic output -> in-between ops -> model_2.
x = some_operations(model_1.output)
out = model_2(x)
# One Keras model spanning model_1's input to model_2's output.
model_1_2 = tf.keras.Model(model_1.input, out, name='model-1+2')
model_1_2.summary()
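Output (model_1's layers appear inline because the combined model starts from model_1.input, while model_2 appears as a single nested layer):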
TensorFlow: 2.1.0-rc0
Model: "model-1+2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 256, 256, 3)]     0
_________________________________________________________________
conv2d (Conv2D)              (None, 256, 256, 64)      1792
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 256, 256, 256)     147712
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 256, 256, 512)     1180160
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 256, 256, 64)      294976
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 256, 256, 3)       1731
_________________________________________________________________
activation (Activation)      (None, 256, 256, 3)       0
_________________________________________________________________
model-2 (Model)              (None, 256, 256, 3)       1626371
=================================================================
Total params: 3,252,742
Trainable params: 3,252,742
Non-trainable params: 0
_________________________________________________________________