Madara

Reputation: 368

How do I connect two keras models into one model?

Let's say I have a ResNet50 model and I wish to connect the output layer of this model to the input layer of a VGG model.

Here is the ResNet50 model and the shape of its output tensor:

from tensorflow.keras.applications import ResNet50

img_shape = (164, 164, 3)
resnet50_model = ResNet50(include_top=False, input_shape=img_shape, weights=None)

print(resnet50_model.output.shape)

I get the output:

TensorShape([Dimension(None), Dimension(6), Dimension(6), Dimension(2048)])

Next, I want a new layer that reshapes this output tensor to (64, 64, 18).
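A Reshape layer only rearranges values, so it can only work if the number of elements per sample is unchanged. A quick arithmetic check in plain Python (the variable names are illustrative, not from the question) suggests the requested target shape is compatible:

```python
import math

resnet_out = (6, 6, 2048)  # ResNet50 feature map for a 164x164x3 input (batch dim excluded)
target = (64, 64, 18)      # shape requested for the VGG side

# Total number of values per sample before and after the reshape.
n_in = math.prod(resnet_out)
n_out = math.prod(target)

print(n_in, n_out)  # 73728 73728
assert n_in == n_out, "Reshape would fail: element counts differ"
```

Any other target shape whose dimensions multiply to 73,728 would also be accepted by `Reshape`.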

Then I have a VGG16 model:

from tensorflow.keras.applications import VGG16

VGG_model = VGG16(include_top=False, weights=None)

I want the output of ResNet50 to be reshaped into the desired tensor and fed as the input to the VGG model. So essentially I want to chain the two models into a single model. Can someone help me do that? Thank you!

Upvotes: 2

Views: 2205

Answers (1)

Vishnuvardhan Janapati

Reputation: 3278

There are multiple ways to do this. Here is one way using the Sequential model API:

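import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16

model = tf.keras.Sequential()
img_shape = (164, 164, 3)
# ResNet50 backbone without its classification head: outputs (None, 6, 6, 2048)
model.add(ResNet50(include_top=False, input_shape=img_shape, weights=None))

# 6*6*2048 == 64*64*18, so the features can be reshaped without losing values
model.add(tf.keras.layers.Reshape(target_shape=(64, 64, 18)))
# Map the 18 channels down to 3, the channel count VGG16 expects by default
model.add(tf.keras.layers.Conv2D(3, kernel_size=(3, 3), name='Conv2d'))

VGG_model = VGG16(include_top=False, weights=None)
model.add(VGG_model)

# Note: the output is still a VGG16 feature map; a classification head would
# be needed before categorical_crossentropy is meaningful for training.
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.summary()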

The model summary is as follows:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
resnet50 (Model)             (None, 6, 6, 2048)        23587712  
_________________________________________________________________
reshape (Reshape)            (None, 64, 64, 18)        0         
_________________________________________________________________
Conv2d (Conv2D)              (None, 62, 62, 3)         489       
_________________________________________________________________
vgg16 (Model)                multiple                  14714688  
=================================================================
Total params: 38,302,889
Trainable params: 38,249,769
Non-trainable params: 53,120
_________________________________________________________________
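The `Conv2d` row can be reproduced by hand. The layer appears to be there to map the 18 reshaped channels down to the 3 channels that VGG16 expects when built without an explicit `input_shape`. The sketch below assumes the Keras `Conv2D` defaults used in the answer (`padding="valid"`, `strides=1`, `use_bias=True`) and also re-adds the per-model parameter counts copied from the summary:

```python
# Input to the Conv2D layer: the output of the Reshape layer.
in_h, in_w, in_c = 64, 64, 18
k_h, k_w = 3, 3   # kernel_size=(3, 3)
filters = 3       # Conv2D(3, ...)

# With "valid" padding and stride 1, each spatial side shrinks by (kernel - 1).
out_h = in_h - k_h + 1
out_w = in_w - k_w + 1

# One k_h x k_w x in_c kernel per filter, plus one bias per filter.
params = k_h * k_w * in_c * filters + filters

# Parameter counts for the two backbones, taken from the summary above.
resnet50_params = 23_587_712
vgg16_params = 14_714_688
total = resnet50_params + params + vgg16_params

print((out_h, out_w, filters), params, total)  # (62, 62, 3) 489 38302889
```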

Full code is here.
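Since there are multiple ways to do this, here is a sketch of the same chain built with the Keras functional API instead of `Sequential`. It assumes TensorFlow 2.x with `tf.keras`, keeps `weights=None` as in the question, and uses a hypothetical model name `resnet50_to_vgg16`. The functional form makes the input explicit and lets you attach extra heads or branches later:

```python
import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16

img_shape = (164, 164, 3)

# Both backbones are built without pretrained weights, as in the question.
resnet50_model = ResNet50(include_top=False, input_shape=img_shape, weights=None)
vgg_model = VGG16(include_top=False, weights=None)

inputs = tf.keras.Input(shape=img_shape)
x = resnet50_model(inputs)                    # (None, 6, 6, 2048)
x = tf.keras.layers.Reshape((64, 64, 18))(x)  # same 73,728 values per sample
# 3x3 "valid" convolution: 18 channels -> 3 channels, 64x64 -> 62x62
x = tf.keras.layers.Conv2D(3, kernel_size=(3, 3), name="Conv2d")(x)
outputs = vgg_model(x)                        # VGG16 feature maps

combined = tf.keras.Model(inputs, outputs, name="resnet50_to_vgg16")
combined.summary()
```

By my arithmetic, VGG16's five 2x2 pooling stages reduce the 62x62 input to a 1x1x512 feature map (62, 31, 15, 7, 3, 1), and the total parameter count should match the Sequential version.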

Upvotes: 2
