I am trying to build a very basic steering-angle prediction model for a self-driving car. I have a video recorded at 30 fps that I've converted into around 50000 images (roughly 28 minutes of footage), and the target I'm trying to predict is the steering angle.
I have tried a few basic convolutional layers, but the MSE I get is very bad. So I'm trying CNN + RNN to improve my model, since an RNN makes sense when the data is a time series of consecutive frames.
I don't know how to use the TimeDistributed layer along with LSTM for this. I'm basically using something like the code below; each image has shape (width, height, channels) = (200, 66, 3):
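from keras import Input, Model, layers
from keras.layers import Dense, Flatten
img_height = 66
img_width = 200
channels = 3
# One frame per sample: (height, width, channels)
input_l = Input(shape=(img_height, img_width, channels))
x = layers.Conv2D(128, kernel_size=(5, 5))(input_l)
x = layers.Conv2D(256, kernel_size=(5, 5))(x)
x = Flatten()(x)
out = Dense(1)(x)  # single regression output: the steering angle
model = Model(inputs=input_l, outputs=out)
model.summary()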
As far as I've learned, TimeDistributed needs 4-dimensional input to work, but each of my images has shape (200, 66, 3). How can I convert each image into four dimensions? I can't figure out exactly how to use this; I've read through a few articles, but none of them covered it.
How do I incorporate a TimeDistributed layer along with LSTM into this architecture? Can anyone provide sample code showing how to do it?
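One way to get the four-dimensional per-sample input that TimeDistributed expects is to group consecutive frames into overlapping windows of a fixed length. The sketch below assumes the frames are already loaded into a NumPy array of shape (num_frames, 66, 200, 3) with one steering angle per frame, and it labels each window with the angle of its last frame. The helper name `make_sequences` and the window length of 10 are hypothetical choices, and `sliding_window_view` requires NumPy 1.20 or newer.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def make_sequences(frames, angles, time_steps):
    """Hypothetical helper: group consecutive frames into overlapping windows.

    frames: array of shape (num_frames, height, width, channels)
    angles: array of shape (num_frames,), one steering angle per frame
    Returns X with shape (num_frames - time_steps + 1, time_steps, height, width, channels)
    and y with shape (num_frames - time_steps + 1,), where y[i] is the angle
    of the last frame in window i.
    """
    # sliding_window_view appends the window axis at the END:
    # (num_windows, height, width, channels, time_steps)
    windows = sliding_window_view(frames, window_shape=time_steps, axis=0)
    # Move the window axis to position 1 so it becomes the time axis
    X = np.moveaxis(windows, -1, 1)
    # Target for each window: steering angle of its last frame
    y = angles[time_steps - 1:]
    return X, y


# Small dummy example: 100 RGB frames of 66 x 200 pixels
frames = np.zeros((100, 66, 200, 3), dtype=np.uint8)
angles = np.linspace(-1.0, 1.0, 100)
X, y = make_sequences(frames, angles, time_steps=10)
print(X.shape, y.shape)  # (91, 10, 66, 200, 3) (91,)
```

The result is a read-only strided view, so building it does not copy the frames. Converting all of X to a dense array would need roughly time_steps times the memory of the original frames, so for 50000 frames it may be necessary to feed batches, for example from a generator.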
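The TimeDistributed layer treats dimension 1 (the axis right after the batch axis) as the time dimension (timesteps), so you need to add a time dimension to your image data: each sample becomes a sequence of frames with shape (time_steps, img_height, img_width, channels). Something like: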
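from keras import layers
from keras import models
time_steps = 10  # number of consecutive frames per sample
img_height = 66
img_width = 200
channels = 3
# Each sample is a sequence of frames: (time_steps, height, width, channels)
input_l = layers.Input(shape=(time_steps, img_height, img_width, channels))
# TimeDistributed applies the same Conv2D (shared weights) to every frame
x = layers.TimeDistributed(layers.Conv2D(32, kernel_size=(5, 5)))(input_l)
x = layers.TimeDistributed(layers.Conv2D(256, kernel_size=(5, 5)))(x)
# Flatten merges the time and spatial axes into one vector per sample
x = layers.Flatten()(x)
out = layers.Dense(1)(x)
model = models.Model(inputs=input_l, outputs=out)
model.summary()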
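The snippet above adds the time dimension, but Flatten then merges the time axis with the spatial axes, so no recurrent layer is involved. A minimal sketch of the CNN + LSTM combination the question asks about follows, assuming the same Keras API as the answer. The strides, filter counts, per-frame GlobalAveragePooling2D, LSTM size of 64, and adam/MSE compile settings are illustrative choices, not tuned values. Reducing each frame to a small vector before the LSTM matters: with valid padding, the answer's second Conv2D outputs 58 x 192 x 256, about 2.85 million values per frame, which would make an LSTM's input weight matrix enormous.

```python
from keras import layers, models

time_steps = 10  # assumed window length (frames per sample)
img_height = 66
img_width = 200
channels = 3

# Each sample is a sequence of frames: (time_steps, height, width, channels)
input_l = layers.Input(shape=(time_steps, img_height, img_width, channels))

# The same small CNN runs on every frame (weights shared across time).
# Illustrative choices: strided convolutions with ReLU to shrink the maps.
x = layers.TimeDistributed(
    layers.Conv2D(32, kernel_size=(5, 5), strides=(2, 2), activation="relu"))(input_l)
x = layers.TimeDistributed(
    layers.Conv2D(64, kernel_size=(5, 5), strides=(2, 2), activation="relu"))(x)
# Collapse each frame's feature map to a vector: (batch, time_steps, 64)
x = layers.TimeDistributed(layers.GlobalAveragePooling2D())(x)

# The LSTM reads the per-frame vectors in order and returns its final output
x = layers.LSTM(64)(x)
out = layers.Dense(1)(x)  # predicted steering angle (regression)

model = models.Model(inputs=input_l, outputs=out)
model.compile(optimizer="adam", loss="mse")
model.summary()
```

Training would then look like model.fit(X, y, ...) with X of shape (num_sequences, 10, 66, 200, 3) and y holding one steering angle per sequence. Pooling inside TimeDistributed keeps the LSTM input small; flattening each frame instead would also work in principle but multiplies the parameter count.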
I hope this helps.