Let's say I have a ResNet50 model and I wish to connect the output layer of this model to the input layer of a VGG model.
Here is the ResNet50 model and the shape of its output tensor:
```python
from tensorflow.keras.applications import ResNet50

img_shape = (164, 164, 3)
resnet50_model = ResNet50(include_top=False, input_shape=img_shape, weights=None)
print(resnet50_model.output.shape)
```
I get this output:
```
TensorShape([Dimension(None), Dimension(6), Dimension(6), Dimension(2048)])
```
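Where the 6 x 6 spatial size comes from: ResNet50 halves the spatial size five times (rounding up). The sketch below traces one 164-pixel side through the stride-2 stages. It assumes the layer layout of `tf.keras.applications.ResNet50` (a 7x7 stride-2 `conv1` after 3-pixel zero padding, a 3x3 stride-2 max pool after 1-pixel padding, then a stride-2 first block in each of `conv3`, `conv4` and `conv5`); `out_len` is a helper written for this illustration, not a Keras function.

```python
def out_len(n, kernel, stride, pad=0):
    # Length after explicit zero padding followed by a 'valid' conv/pool.
    return (n + 2 * pad - kernel) // stride + 1

n = 164
n = out_len(n, 7, 2, pad=3)   # conv1: 7x7, stride 2, after 3-pixel padding -> 82
n = out_len(n, 3, 2, pad=1)   # pool1: 3x3 max pool, stride 2, after 1-pixel padding -> 41
# conv2_x keeps the size (stride 1); conv3_x, conv4_x and conv5_x each halve it
# with a stride-2 1x1 convolution in their first block.
for _ in range(3):
    n = out_len(n, 1, 2)      # 41 -> 21 -> 11 -> 6
print((n, n, 2048))           # (6, 6, 2048)
```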
Now I want a new layer that reshapes this output tensor to (64, 64, 18).
Then I have a VGG16 model:
```python
from tensorflow.keras.applications import VGG16

VGG_model = VGG16(include_top=False, weights=None)
```
I want the output of ResNet50 to be reshaped into the desired tensor and fed as the input to the VGG model. So essentially I want to chain the two models, one after the other. Can someone help me do that? Thank you!
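There are multiple ways to do this. Here is one, using the Sequential model API. The Reshape layer works because both shapes hold the same number of values (6 × 6 × 2048 = 64 × 64 × 18 = 73,728). VGG16 built with `include_top=False` and no `input_shape` expects a 3-channel image, so a Conv2D layer with 3 filters is added after the reshape to map the 18 channels down to 3.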
```python
import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16

model = tf.keras.Sequential()
img_shape = (164, 164, 3)
# ResNet50 without its classifier head: (None, 164, 164, 3) -> (None, 6, 6, 2048)
model.add(ResNet50(include_top=False, input_shape=img_shape, weights=None))
# Rearrange the 73,728 values per sample into (64, 64, 18)
model.add(tf.keras.layers.Reshape(target_shape=(64, 64, 18)))
# 3x3 convolution with 3 filters: 18 channels -> the 3 channels VGG16 expects
model.add(tf.keras.layers.Conv2D(3, kernel_size=(3, 3), name='Conv2d'))
VGG_model = VGG16(include_top=False, weights=None)
model.add(VGG_model)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()
```
The model summary is as follows:
```
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
resnet50 (Model)             (None, 6, 6, 2048)        23587712
_________________________________________________________________
reshape (Reshape)            (None, 64, 64, 18)        0
_________________________________________________________________
Conv2d (Conv2D)              (None, 62, 62, 3)         489
_________________________________________________________________
vgg16 (Model)                multiple                  14714688
=================================================================
Total params: 38,302,889
Trainable params: 38,249,769
Non-trainable params: 53,120
_________________________________________________________________
```
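The shapes and parameter counts in this summary can be checked by hand. The sketch below assumes the Keras defaults the answer relies on: `Conv2D` with stride 1, `padding='valid'` and one bias per filter, and VGG16 using 'same'-padded convolutions with five 2x2, stride-2 max pools. `conv2d_valid` is a helper written for this illustration, not a Keras function.

```python
def conv2d_valid(h, w, c_in, filters, k):
    """Output shape and parameter count of a Conv2D with stride 1,
    padding='valid' and use_bias=True (the defaults used in the answer)."""
    out_shape = (h - k + 1, w - k + 1, filters)
    params = k * k * c_in * filters + filters  # kernel weights + one bias per filter
    return out_shape, params

conv_shape, conv_params = conv2d_valid(64, 64, 18, filters=3, k=3)
print(conv_shape, conv_params)           # (62, 62, 3) 489

# Totals from the summary: ResNet50 + Conv2d + VGG16.
total = 23587712 + conv_params + 14714688
print(total, total - 53120)              # 38302889 38249769

# In VGG16 only the five 2x2, stride-2 max pools change the spatial size.
side = conv_shape[0]
for _ in range(5):
    side //= 2                           # 62 -> 31 -> 15 -> 7 -> 3 -> 1
print((side, side, 512))                 # (1, 1, 512)
```

This most likely also explains the `multiple` in the vgg16 row: VGG16 was created with an unspecified spatial size, so its own output shape is (None, None, None, 512), while inside this model it produces (None, 1, 1, 512), and Keras prints `multiple` when a nested model's output shapes differ between calls.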
Full code is here.
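Another of the multiple ways is the Keras Functional API. Below is a minimal sketch of the same chain, assuming TensorFlow 2.x with `tf.keras`; the layer name `to_rgb` and the model name `resnet50_to_vgg16` are illustrative choices, not from the original answer.

```python
import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16

img_shape = (164, 164, 3)

# Both backbones without their classifier heads and without pretrained weights.
resnet = ResNet50(include_top=False, input_shape=img_shape, weights=None)
vgg = VGG16(include_top=False, weights=None)  # accepts (None, None, 3) inputs

inputs = tf.keras.Input(shape=img_shape)
x = resnet(inputs)                                  # (None, 6, 6, 2048)
x = tf.keras.layers.Reshape((64, 64, 18))(x)        # same 73,728 values per sample
x = tf.keras.layers.Conv2D(3, kernel_size=(3, 3), name="to_rgb")(x)  # (None, 62, 62, 3)
outputs = vgg(x)                                    # (None, 1, 1, 512)

model = tf.keras.Model(inputs, outputs, name="resnet50_to_vgg16")
model.summary()
print(model.output.shape)
```

The layers are the same as in the Sequential version; the difference is that every intermediate tensor (`x`) is available by name, which makes it easier to insert more layers between the two backbones, branch off a second output, or inspect shapes while debugging.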