Rocket Pingu

Passing CNN outputs to an LSTM in TensorFlow?

Given that the output of a CNN has the shape [batch_size, height, width, number_of_channels] (assuming the channels_last format), I use the following function to reshape CNN outputs into the [batch_size, time_steps, features] layout an RNN expects:

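import tensorflow as tf

def collapse_to_rnn_dims(inputs):
    # inputs: [batch_size, height, width, num_channels] (channels_last)
    batch_size, height, width, num_channels = inputs.get_shape().as_list()
    if batch_size is None:
        batch_size = -1  # let tf.reshape infer a dynamic batch dimension
    # tf.reshape keeps the row-major element order; it does not move width ahead of height
    return tf.reshape(inputs, [batch_size, width, height * num_channels])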

It does work. However, I would like to ask whether this is really the proper way to reshape CNN outputs so that they can be passed to an LSTM layer.
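One way to check what this reshape does to the data is a tiny NumPy experiment. The sketch assumes NumPy is available and relies on `np.reshape`, like `tf.reshape`, keeping the row-major (C-order) element order without moving any axes. With height=2 and width=3, the reshape-only collapse does not give one image column per time step, whereas transposing width ahead of height first does.

```python
import numpy as np

# Toy CNN output, channels_last: [batch, height, width, channels].
batch, height, width, channels = 1, 2, 3, 1
cnn_out = np.arange(batch * height * width * channels).reshape(batch, height, width, channels)
# cnn_out[0, :, :, 0] is [[0, 1, 2],
#                         [3, 4, 5]]  (rows = height, columns = width)

# Reshape only, as in collapse_to_rnn_dims: reinterprets the flat order.
reshape_only = cnn_out.reshape(batch, width, height * channels)

# Move width ahead of height first, then reshape.
transposed = np.transpose(cnn_out, (0, 2, 1, 3)).reshape(batch, width, height * channels)

print(reshape_only[0])  # time steps are [0 1], [2 3], [4 5]: not image columns
print(transposed[0])    # time steps are [0 3], [1 4], [2 5]: one column each
```

Both versions produce the same shape, which is why the reshape-only version "works" without raising an error; only the transposed one keeps each time step aligned with a single column of the image.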

Answers (1)

Rocket Pingu

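I found an answer here that does essentially what I'm doing, in the context of handwritten text recognition, with two differences: it transposes the tensor to [batch, width, height, features] before reshaping, and it assumes that the number of time steps (width) is dynamic rather than the batch size.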

import tensorflow as tf

# cnn_net: the CNN output tensor; its batch size must be static here
shape = cnn_net.get_shape().as_list()  # [batch, height, width, features]
transposed = tf.transpose(cnn_net, perm=[0, 2, 1, 3],
                          name='transposed')  # [batch, width, height, features]
conv_reshaped = tf.reshape(transposed, [shape[0], -1, shape[1] * shape[3]],
                           name='reshaped')  # [batch, width, height x features]
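If both the batch size and the width need to stay dynamic, one option is to read them from `tf.shape` at run time and keep only height and channels static. The sketch below assumes TensorFlow 2.x in eager mode; the helper name `cnn_to_rnn`, the input sizes, and the LSTM size of 64 are illustrative choices, not part of the original answer.

```python
import tensorflow as tf

def cnn_to_rnn(cnn_out):
    """Hypothetical helper: turn a channels_last CNN output
    [batch, height, width, channels] into LSTM input
    [batch, width, height * channels].

    Height and channels must be static; batch and width may be None.
    """
    height, channels = cnn_out.shape[1], cnn_out.shape[3]
    # Move width ahead of height so each time step is one image column.
    x = tf.transpose(cnn_out, perm=[0, 2, 1, 3])  # [batch, width, height, channels]
    # Read batch and width at run time so either can vary between calls.
    dyn = tf.shape(x)
    return tf.reshape(x, [dyn[0], dyn[1], height * channels])

# Usage: feed the collapsed tensor to a Keras LSTM.
features = tf.random.normal([4, 8, 32, 16])  # [batch, height, width, channels]
seq = cnn_to_rnn(features)                   # [4, 32, 128]
lstm_out = tf.keras.layers.LSTM(64, return_sequences=True)(seq)  # [4, 32, 64]
```

Inside a Keras model, the same collapse can be expressed with layers: `tf.keras.layers.Permute((2, 1, 3))` swaps height and width (its indices exclude the batch axis), and `tf.keras.layers.Reshape((-1, height * channels))` then merges height and channels while leaving the number of time steps to be inferred.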
