Manish Sharma

Reputation: 190

How to use ConvLSTM2D followed by Conv2D in Keras python

I am trying to use the following model in Keras, where the ConvLSTM2D output is fed into a Conv2D layer to generate segmentation-like output. Both input and output should be time series of frames, each frame of size (2*WINDOW_H+1, 2*WINDOW_W+1):

from keras.models import Sequential
from keras.layers import ConvLSTM2D, Conv2D

model = Sequential()
model.add(ConvLSTM2D(3, kernel_size=3, padding="same", batch_input_shape=(1, None, 2*WINDOW_H+1, 2*WINDOW_W+1, 1), return_sequences=True, stateful=True))
model.add(Conv2D(1, kernel_size=3, padding="same"))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

However, adding the Conv2D layer raises the following error:

Input 0 is incompatible with layer conv2d_1: expected ndim=4, found ndim=5

Any pointers on where I might be going wrong would be much appreciated. Thanks!

Upvotes: 0

Views: 1367

Answers (3)

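In the ConvLSTM2D layer before Conv2D, set return_sequences=False (the default). The layer then returns only the output of the last time step, a 4D tensor (batch, rows, cols, filters), which is the ndim=4 input that Conv2D expects.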
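To make the shape change concrete, here is a NumPy stand-in for the tensors involved (these are plain arrays, not Keras layers; WINDOW_H = WINDOW_W = 2 and 7 time steps are hypothetical values chosen only for illustration):

```python
import numpy as np

# Hypothetical sizes for illustration only: WINDOW_H = WINDOW_W = 2 gives 5x5 frames.
WINDOW_H, WINDOW_W = 2, 2
rows, cols = 2 * WINDOW_H + 1, 2 * WINDOW_W + 1
batch, timesteps, filters = 1, 7, 3

# Stand-in for the ConvLSTM2D output with return_sequences=True:
# one (rows, cols, filters) feature map per time step -> 5D.
seq_out = np.random.rand(batch, timesteps, rows, cols, filters)

# With return_sequences=False the layer keeps only the last step -> 4D,
# i.e. the (batch, rows, cols, channels) layout Conv2D expects.
last_out = seq_out[:, -1]

print(seq_out.shape, seq_out.ndim)    # (1, 7, 5, 5, 3) 5
print(last_out.shape, last_out.ndim)  # (1, 5, 5, 3) 4
```

The trade-off is that the model then produces a single output frame per input sequence rather than the time series of frames the question asks for.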

Upvotes: 0

Pelonomi Moiloa

Reputation: 582

I think you would need to wrap the Conv2D layer in TimeDistributed so that the dimensions match. Something like this:

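from keras.models import Sequential
from keras.layers import ConvLSTM2D, Conv2D, TimeDistributed

model = Sequential()
model.add(ConvLSTM2D(3, kernel_size=3, padding="same", batch_input_shape=(1, None, 2*WINDOW_H+1, 2*WINDOW_W+1, 1), return_sequences=True, stateful=True))
# TimeDistributed applies the same Conv2D (shared weights) to each time step of the 5D sequence
model.add(TimeDistributed(Conv2D(1, kernel_size=3, padding="same")))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()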
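Conceptually, TimeDistributed runs the wrapped layer, with one shared set of weights, on every time step independently, so the time axis survives. A rough NumPy sketch of that idea (not Keras's actual implementation; `conv2d_same` and `time_distributed` are hypothetical helpers, and the weights are random) folds time into the batch axis, applies a 4D per-frame 3x3 convolution, then unfolds:

```python
import numpy as np

def conv2d_same(x, kernel):
    """Naive 3x3 'same' convolution (cross-correlation, as Conv2D computes)
    on a 4D batch: x is (n, rows, cols, c_in), kernel is (3, 3, c_in, c_out)."""
    n, rows, cols, _ = x.shape
    padded = np.pad(x, ((0, 0), (1, 1), (1, 1), (0, 0)))  # zero padding keeps rows/cols
    out = np.zeros((n, rows, cols, kernel.shape[-1]))
    for i in range(3):
        for j in range(3):
            # Each kernel tap mixes the channels of a shifted window of the input.
            out += padded[:, i:i + rows, j:j + cols, :] @ kernel[i, j]
    return out

def time_distributed(fn, x):
    """Apply a 4D per-frame function to every time step of a 5D tensor
    (batch, time, rows, cols, c) by folding time into the batch axis."""
    batch, time = x.shape[:2]
    frames = x.reshape((batch * time,) + x.shape[2:])  # -> 4D
    y = fn(frames)                                      # same op for every frame
    return y.reshape((batch, time) + y.shape[1:])       # -> back to 5D

seq_out = np.random.rand(1, 7, 5, 5, 3)  # stand-in for the ConvLSTM2D sequence output
kernel = np.random.rand(3, 3, 3, 1)      # 3x3 kernel mapping 3 -> 1 channels
y = time_distributed(lambda f: conv2d_same(f, kernel), seq_out)
print(y.shape)  # (1, 7, 5, 5, 1)
```

Because the output keeps one frame per time step, this approach matches the question's requirement that the output also be a time series.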

Upvotes: 1

johngull

Reputation: 819

The problem in your model is that you feed a whole sequence (a 5D tensor) into a regular convolutional layer, which expects 4D input. All you need to do is remove return_sequences=True from ConvLSTM2D, so that it returns only the output of the last time step.

So this line:

model.add(ConvLSTM2D(3, kernel_size=3, padding = "same", batch_input_shape=(1, None, 2*WINDOW_H+1, 2*WINDOW_W+1, 1), return_sequences=True, stateful=True))

should be like this:

model.add(ConvLSTM2D(3, kernel_size=3, padding = "same", batch_input_shape=(1, None, 2*WINDOW_H+1, 2*WINDOW_W+1, 1), stateful=True))
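The two fixes give differently shaped outputs, which matters here because the question asks for a time series. A minimal comparison sketch, assuming TensorFlow 2.x (`from tensorflow import keras`); the 5x5 frame size is a hypothetical WINDOW_H = WINDOW_W = 2, `build` is a hypothetical helper, stateful=True and the fixed batch size are omitted for brevity, and the sigmoid activation is an addition not present in the answers (binary_crossentropy expects probabilities in [0, 1]):

```python
from tensorflow import keras

# Hypothetical frame size: WINDOW_H = WINDOW_W = 2 -> 5x5 frames, 1 channel.
ROWS, COLS = 5, 5

def build(keep_time_axis: bool) -> keras.Model:
    # Variable-length sequences of single-channel frames (stateful setup omitted).
    inputs = keras.Input(shape=(None, ROWS, COLS, 1))
    x = keras.layers.ConvLSTM2D(
        3, kernel_size=3, padding="same", return_sequences=keep_time_axis
    )(inputs)
    # Sigmoid (an addition) keeps per-pixel outputs in [0, 1] for binary_crossentropy.
    conv = keras.layers.Conv2D(1, kernel_size=3, padding="same", activation="sigmoid")
    # Keep the time axis via TimeDistributed, or apply Conv2D to the last step only.
    outputs = keras.layers.TimeDistributed(conv)(x) if keep_time_axis else conv(x)
    model = keras.Model(inputs, outputs)
    model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
    return model

seq_model = build(keep_time_axis=True)    # one mask per time step
last_model = build(keep_time_axis=False)  # one mask per sequence
print(tuple(seq_model.outputs[0].shape))   # (None, None, 5, 5, 1)
print(tuple(last_model.outputs[0].shape))  # (None, 5, 5, 1)
```

If only the final frame of each sequence needs a segmentation mask, removing return_sequences is the simpler model; if every frame needs one, the TimeDistributed variant is the one that fits.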

Upvotes: 0
