Michael Lempart

Reputation: 51

Time distributed layer keras

I am trying to understand the time distributed layer in Keras/TensorFlow. As far as I understand, it is a kind of wrapper that makes it possible to, for example, process a sequence of images.

Now I am wondering how one would design a time distributed network without using the time distributed layer.

For example, suppose I have a sequence of 3 images, each with 1 channel and 256x256 pixels, that should first be processed by a CNN and then by LSTM cells. The input to the time distributed layer would then have shape (N, 3, 256, 256, 1), where N is the batch size.

The CNN would then produce 3 outputs (one per image), which are fed to the LSTM.

Now, without using the time distributed layer, would it be possible to accomplish the same by setting up a network with 3 different inputs and 3 similar CNNs? The outputs of the 3 CNNs could then be flattened and concatenated.

Is that any different from the time distributed approach?

Thanks in advance,

M

Upvotes: 0

Views: 1357

Answers (1)

Bashir Kazimi

Reputation: 1377

I created a prototype for you. I used the fewest possible layers and arbitrary units/kernels/filters; change them as you like. It first creates a CNN model that takes inputs of size (256,256,1). It then uses that same CNN model 3 times (once for each of the three images in the sequence) to extract features, so all three images are processed with the same weights. It stacks the features with a Lambda layer to put them back into a sequence, which then goes through an LSTM layer. I have chosen for the LSTM to return a single feature vector per example, but if you want the output to be a sequence as well, you can set return_sequences=True. You could also add further layers at the end to adapt it to your needs.

from tensorflow.keras.layers import Input, LSTM, Conv2D, Flatten, Lambda
from tensorflow.keras import Model
import tensorflow.keras.backend as K

def create_cnn_model():
  # Feature extractor for a single (256, 256, 1) image
  inp = Input(shape=(256,256,1))
  x = Conv2D(filters=16, kernel_size=5, strides=2)(inp)
  x = Flatten()(x)
  model = Model(inputs=inp, outputs=x, name='cnn_Model')
  return model


def combined_model():
  # A single CNN instance; calling it on all three inputs shares its weights
  cnn_model = create_cnn_model()
  # One input per image in the sequence
  inp_1 = Input(shape=(256,256,1))
  inp_2 = Input(shape=(256,256,1))
  inp_3 = Input(shape=(256,256,1))

  out_1 = cnn_model(inp_1)
  out_2 = cnn_model(inp_2)
  out_3 = cnn_model(inp_3)

  # Stack the three feature vectors into a sequence: (batch, 3, features)
  lstm_inp = [out_1, out_2, out_3]
  lstm_inp = Lambda(lambda x: K.stack(x, axis=-2))(lstm_inp)
  x = LSTM(units=32, return_sequences=False)(lstm_inp)

  model = Model(inputs=[inp_1, inp_2, inp_3], outputs=x)
  return model

Now create the model:

model = combined_model()

Check the summary:

model.summary()

which will print:

Model: "model_14"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_53 (InputLayer)           [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
input_54 (InputLayer)           [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
input_55 (InputLayer)           [(None, 256, 256, 1) 0                                            
__________________________________________________________________________________________________
cnn_Model (Model)               (None, 254016)       416         input_53[0][0]                   
                                                                 input_54[0][0]                   
                                                                 input_55[0][0]                   
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, 3, 254016)    0           cnn_Model[1][0]                  
                                                                 cnn_Model[2][0]                  
                                                                 cnn_Model[3][0]                  
__________________________________________________________________________________________________
lstm_13 (LSTM)                  (None, 32)           32518272    lambda_3[0][0]                   
==================================================================================================
Total params: 32,518,688
Trainable params: 32,518,688
Non-trainable params: 0
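Almost all of these parameters sit in the LSTM, because each time step feeds it a flattened 254016-dimensional vector. As a quick sanity check, assuming the standard Keras LSTM layout (a kernel of shape (input_dim, 4*units), a recurrent kernel of shape (units, 4*units) and a bias of length 4*units), the numbers in the summary can be reproduced by hand. The helper name lstm_param_count is just for illustration:

```python
def lstm_param_count(input_dim, units):
    # Four gates, each with input weights, recurrent weights and a bias
    return 4 * units * (input_dim + units + 1)

conv_params = 416  # taken from the cnn_Model row of the summary
lstm_params = lstm_param_count(input_dim=254016, units=32)
total = conv_params + lstm_params
print(lstm_params, total)  # 32518272 32518688
```

Since the input_dim term dominates, shrinking the feature vector before the LSTM (for example with pooling) reduces this count directly.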

The summary of the inner CNN model can be printed with:

model.get_layer('cnn_Model').summary()

which prints:

Model: "cnn_Model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_52 (InputLayer)        [(None, 256, 256, 1)]     0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 126, 126, 16)      416       
_________________________________________________________________
flatten_6 (Flatten)          (None, 254016)            0         
=================================================================
Total params: 416
Trainable params: 416
Non-trainable params: 0
_________________________
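The 126 x 126 x 16 output shape and the 416 parameters follow from the Conv2D settings. A small sketch of the arithmetic, assuming the default padding='valid' (the helper name conv2d_output_size is hypothetical):

```python
def conv2d_output_size(size, kernel_size, stride):
    # Output length along one spatial axis for padding='valid' (the Conv2D default)
    return (size - kernel_size) // stride + 1

side = conv2d_output_size(256, kernel_size=5, stride=2)
flat_features = side * side * 16  # 16 filters, flattened
conv_params = 5 * 5 * 1 * 16 + 16  # 5x5x1 kernel per filter plus one bias per filter
print(side, flat_features, conv_params)  # 126 254016 416
```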

Your model expects a list of length 3 as input (one element per image in the sequence). Each element of the list should be a NumPy array of shape (batch_size, 256, 256, 1). Below is a dummy example with a batch size of 1:

import numpy as np

a = np.zeros((256,256,1)) # first image filled with zeros
b = np.zeros((256,256,1)) # second image filled with zeros
c = np.zeros((256,256,1)) # third image filled with zeros

a = np.expand_dims(a, 0) # adding batch dimension to make it (1, 256, 256, 1)
b = np.expand_dims(b, 0) # same here
c = np.expand_dims(c, 0) # same here


model.compile(loss='mse', optimizer='adam')
# train your model with model.fit(....)

e = model.predict([a,b,c]) # a,b and c have shape of (1, 256, 256, 1) where the first 1 is the batch size
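For comparison, here is a sketch of what should be the equivalent TimeDistributed version, assuming tf.keras from TensorFlow 2.x. Because combined_model() calls one cnn_Model instance on all three inputs, the weights are shared across time steps, which is also what TimeDistributed does, so the parameter count should match the 32,518,688 in the summary above. The names create_time_distributed_model, td_model, seq and split are only for illustration, and a, b, c are redefined here so the snippet runs on its own:

```python
import numpy as np
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, Flatten, Input, LSTM, TimeDistributed

def create_time_distributed_model(seq_len=3, img_shape=(256, 256, 1)):
    # One input holding the whole sequence: (batch, seq_len, 256, 256, 1)
    inp = Input(shape=(seq_len,) + img_shape)
    # TimeDistributed applies the same Conv2D (same weights) to every time step
    x = TimeDistributed(Conv2D(filters=16, kernel_size=5, strides=2))(inp)
    # Flatten each time step separately -> (batch, seq_len, 254016)
    x = TimeDistributed(Flatten())(x)
    x = LSTM(units=32, return_sequences=False)(x)
    return Model(inputs=inp, outputs=x, name='td_model')

td_model = create_time_distributed_model()

# Three dummy images, each already with a batch dimension: (1, 256, 256, 1)
a = np.zeros((1, 256, 256, 1), dtype=np.float32)
b = np.zeros((1, 256, 256, 1), dtype=np.float32)
c = np.zeros((1, 256, 256, 1), dtype=np.float32)

# Stack along a new time axis -> (1, 3, 256, 256, 1) for the TimeDistributed model
seq = np.stack([a, b, c], axis=1)
e_td = td_model.predict(seq, verbose=0)

# Going the other way: split a 5D batch into the list the combined model expects
split = [seq[:, t] for t in range(seq.shape[1])]
```

The practical difference is mainly the input format: this version takes one 5D array, while combined_model() takes a list of three 4D arrays. If you instead built three separate CNNs (three calls to create_cnn_model()), each would get its own weights, so the conv parameters would triple and the time steps would no longer share one feature extractor.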

Upvotes: 1
