Raghavendra Sugeeth

Reputation: 119

Gathering list of 2-d tensors from a 3-d tensor in Keras

I have a 3-d tensor named main_decoder of shape (None, 9, 256).

I want to extract 9 tensors of shape (None, 256).

I have tried using Keras gather; the following is my code snippet:

from keras.layers import Lambda
from keras.backend import gather
for i in range(0, 9):
    sub_decoder_input = Lambda(lambda main_decoder: gather(main_decoder, (i)), name='lambda' + str(i))(main_decoder)

The result is 9 Lambda layers, each with output shape (9, 256).
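Where the (9, 256) shape comes from: Keras `gather` selects along the first axis, which here is the batch axis, so `gather(main_decoder, i)` picks the i-th *sample* rather than the i-th timestep. The NumPy sketch below mimics that difference; NumPy stands in for the Keras backend and the batch size of 4 is an arbitrary assumption for illustration.

```python
import numpy as np

# Hypothetical batch of 4 samples, each a (9, 256) sequence, standing in for main_decoder
batch = np.random.rand(4, 9, 256).astype(np.float32)

i = 2
# Indexing axis 0 (what gather does) picks one sample: shape (9, 256)
by_batch_axis = np.take(batch, i, axis=0)
# Indexing axis 1 picks one timestep from every sample: shape (4, 256)
by_time_axis = batch[:, i, :]

print(by_batch_axis.shape)  # (9, 256)
print(by_time_axis.shape)   # (4, 256)
```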

How can I modify it so that I get 9 tensors of shape (None, 256)?

Thanks.

Upvotes: 2

Views: 523

Answers (1)

Yu-Yang

Reputation: 14619

You can slice the 3D tensor along its second (timestep) axis into 9 2D tensors and return them as a list from a single Lambda layer.

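from keras.layers import Input, Lambda
main_decoder = Input(shape=(9, 256))
# x[:, i, :] takes timestep i from every sample, giving shape (None, 256)
sub_decoder_input = Lambda(lambda x: [x[:, i, :] for i in range(9)])(main_decoder)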

print(sub_decoder_input)
[<tf.Tensor 'lambda_1/strided_slice:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_1:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_2:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_3:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_4:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_5:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_6:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_7:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_8:0' shape=(?, 256) dtype=float32>]
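As a quick sanity check outside Keras, the same list comprehension can be run on a NumPy array (an assumed stand-in for a concrete batch of 3 samples). Each slice has shape (batch, 256), and stacking the slices back along axis 1 recovers the original tensor, so no data is lost or reordered. Inside a TensorFlow-backed Lambda layer, `tf.unstack(x, axis=1)` should yield an equivalent list, though that variant is not shown here.

```python
import numpy as np

# Hypothetical input: batch of 3 samples, 9 timesteps, 256 features
x = np.arange(3 * 9 * 256, dtype=np.float32).reshape(3, 9, 256)

# Same slicing the Lambda layer applies: one (batch, 256) array per timestep
slices = [x[:, i, :] for i in range(9)]

# Reassembling the slices along axis 1 should give back the original tensor
restacked = np.stack(slices, axis=1)

print(len(slices), slices[0].shape)  # 9 (3, 256)
print(np.array_equal(restacked, x))  # True
```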

Upvotes: 3
