mCalado

Reputation: 131

Can you change the input shape of a trained model in TensorFlow?

I trained a model with the input shape of (224, 224, 3) and I'm trying to change it to (300, 300, 3). For instance:

resnet50 = tf.keras.models.load_model(path_to_model)

model = tf.keras.models.Model([Input(shape=(300, 300, 3))], [resnet50.output])
# or
resnet50.inputs[0].set_shape([None, 300, 300, 3])

doesn't work.

I saw that the pretrained model allows for different input shapes but adjusts the whole network architecture, for example, the size of the convolutional channels. I was wondering whether I need to do something similar, or whether it is impossible to change the input shape of a trained model.

Upvotes: 0

Views: 4515

Answers (2)

user19676560

tf.keras.applications.ResNet50(include_top=False, input_shape=[300,300,3])

input_shape: optional shape tuple, only to be specified if include_top is False (otherwise the input shape has to be (224, 224, 3) (with 'channels_last' data format) or (3, 224, 224) (with 'channels_first' data format)). It should have 3 input channels, and the width and height should be no smaller than 32. E.g. (200, 200, 3) would be one valid value.
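One way to apply this to an already trained network is to build the same architecture with the new input shape and copy the weights across. The following is only a sketch under assumptions: `trained` stands in for the model you would get from `load_model`, it is assumed to be a `ResNet50` built with `include_top=False`, and `weights=None` is used here just to avoid a download.

```python
import numpy as np
import tensorflow as tf

# Stand-in for the trained model (assumption: ResNet50 without the top).
# In practice this would be the model returned by tf.keras.models.load_model(...).
trained = tf.keras.applications.ResNet50(
    weights=None, include_top=False, input_shape=(224, 224, 3)
)

# Same architecture, new spatial size. The weight shapes are identical
# because none of them depend on the input height or width.
resized = tf.keras.applications.ResNet50(
    weights=None, include_top=False, input_shape=(300, 300, 3)
)
resized.set_weights(trained.get_weights())

print(resized.input_shape)
print(resized.output_shape)
```

Because the whole graph is rebuilt, this also covers architectures with skip connections, where re-applying layers one by one in a loop would not reproduce the topology.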

Upvotes: 0

Frightera

Reputation: 5079

This only works for convolutional layers: they are just sliding filters, so their weights do not depend on the spatial size of the input. However, if your model was trained on RGB images, the new input shape must still have 3 channels.

Example:

import tensorflow as tf
from tensorflow.keras.applications import VGG16

first_model = VGG16(weights=None, input_shape=(224, 224, 3), include_top=False)
first_model.summary()

>>   input_6 (InputLayer)         [(None, 224, 224, 3)]     0         

And second model:

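new_input = tf.keras.Input((300, 300, 3))

# Skip layers[0] (the original InputLayer) and re-apply every remaining layer,
# with its existing weights, to the new input. VGG16 is a plain stack of layers,
# so chaining them in order with the Functional API reproduces the network.
x = new_input
for layer in first_model.layers[1:]:
    x = layer(x)

second_model = tf.keras.Model(inputs=new_input, outputs=x)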

second_model.summary()

>> 

Layer (type)                 Output Shape              Param #   
=================================================================
input_9 (InputLayer)         [(None, 300, 300, 3)]     0         
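The claim that convolutional layers do not care about the input size, while a Flatten followed by Dense does, can be checked by comparing kernel shapes. This is a minimal sketch: the tiny model and the helper `kernel_shapes` are made up for illustration and are not part of VGG16 or ResNet50.

```python
import tensorflow as tf


def kernel_shapes(height, width):
    """Build a tiny conv -> flatten -> dense model and return both kernel shapes."""
    inputs = tf.keras.Input((height, width, 3))
    x = tf.keras.layers.Conv2D(8, 3, strides=2, padding="same", name="conv")(inputs)
    x = tf.keras.layers.Flatten()(x)
    outputs = tf.keras.layers.Dense(10, name="dense")(x)
    model = tf.keras.Model(inputs, outputs)
    conv_kernel = tuple(model.get_layer("conv").kernel.shape)
    dense_kernel = tuple(model.get_layer("dense").kernel.shape)
    return conv_kernel, dense_kernel


small = kernel_shapes(224, 224)
large = kernel_shapes(300, 300)

print("conv kernels: ", small[0], large[0])   # same shape for both input sizes
print("dense kernels:", small[1], large[1])   # first dimension differs
```

The convolution kernel is identical for both input sizes, so its trained weights can be reused. The Dense kernel's first dimension depends on the flattened feature map, so a head built on Flatten would have to be retrained or replaced, for example with global pooling.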

Upvotes: 1
