I am using this code:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Dense, LSTM, Input, Conv2D, Lambda
from tensorflow.keras import Model
def reshape_n(x):
    # Pass the tensor through unchanged, but declare a None batch dimension
    x = tf.compat.v1.placeholder_with_default(
        x, [None, 121, 240, 2])
    return x
input_shape = (121, 240, 1)
inputs = Input(shape=input_shape)
x = Conv2D(1, 1)(inputs)
x = LSTM(2, return_sequences=True)(x[0, :, :, :])
x = Lambda(reshape_n, (121, 240,2))(x[None, :, :, :])
x = Conv2D(1, 1)(x)
output = Dense(3, activation='softmax')(x)
model = Model(inputs, output)
train_x = np.random.randint(0, 30, size=(10, 121, 240))
train_y = np.random.randint(0, 3, size=(10, 121, 240))
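train_y = tf.one_hot(tf.cast(train_y, 'int32'), depth=3)
# The compile() call is not visible in the post; these arguments are an assumption
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(train_x, train_y, epochs=2)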
and I receive:
logits and labels must be broadcastable: logits_size=[29040,3] labels_size=[290400,3]
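The two sizes in the message differ by exactly a factor of 10, the number of training samples. One plausible reading, sketched below with NumPy shapes only (no real LSTM or Conv2D is run; the zero arrays are stand-ins), is that `x[0, :, :, :]` keeps only the first sample, the LSTM treats the 121 rows as its batch, and `x[None, :, :, :]` then restores a batch axis of size 1 instead of 10.

```python
import numpy as np

# Shapes from the question: 10 samples, 121 x 240, one channel after Conv2D(1, 1)
batch, h, w, n_classes = 10, 121, 240, 3
conv_out = np.zeros((batch, h, w, 1))

# x[0, :, :, :] keeps only the first sample; the LSTM then sees the
# 121 rows as its batch and the 240 columns as timesteps
lstm_in = conv_out[0, :, :, :]

# Stand-in for the LSTM(2, return_sequences=True) output (shape only)
lstm_out = np.zeros(lstm_in.shape[:-1] + (2,))

# x[None, :, :, :] adds a batch axis back, but of size 1 instead of 10
restored = lstm_out[None, :, :, :]

# Conv2D(1, 1) + Dense(3) keep the leading axes and end in 3 classes
logits = np.zeros(restored.shape[:-1] + (n_classes,))
labels = np.zeros((batch, h, w, n_classes))  # shape of the one-hot train_y

# The error message reports both tensors flattened to [rows, classes]
print(logits.reshape(-1, n_classes).shape)  # (29040, 3)
print(labels.reshape(-1, n_classes).shape)  # (290400, 3)
```

Under this reading, 29040 = 1 × 121 × 240 and 290400 = 10 × 121 × 240.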
If I just omit the LSTM layer:
x = Conv2D(1, 1)(inputs)
x = Conv2D(1, 1)(x)
output = Dense(3, activation='softmax')(x)
then the code runs without any problem!
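Without the LSTM there are no `x[0]` / `x[None]` slices, so every layer keeps the full batch axis. One way to run an LSTM along each row while still keeping that axis is to fold the rows into the batch, run the LSTM, and unfold afterwards. The NumPy sketch below shows only the shape bookkeeping; `np.repeat` is a placeholder for the LSTM, chosen so the round trip can be checked, and the variable names are illustrative.

```python
import numpy as np

batch, h, w, c, units = 10, 121, 240, 1, 2
conv_out = np.random.rand(batch, h, w, c)

# Fold the rows into the batch axis: every row of every sample becomes
# its own sequence of 240 timesteps, shape (10 * 121, 240, 1)
seqs = conv_out.reshape(batch * h, w, c)

# Placeholder for LSTM(units, return_sequences=True), which maps
# (N, timesteps, features) to (N, timesteps, units). Here the single
# feature is simply copied into each unit.
lstm_out = np.repeat(seqs, units, axis=-1)

# Unfold back to (10, 121, 240, 2): the batch size is still 10
restored = lstm_out.reshape(batch, h, w, units)
print(restored.shape)  # (10, 121, 240, 2)
```

In Keras, wrapping the layer as `TimeDistributed(LSTM(2, return_sequences=True))` should apply the same idea to a `(None, 121, 240, 1)` input, though this has not been checked here against the exact model in the question.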
Using tensorflow-gpu==2.3.0 and numpy==1.19.5, when I run your code I observe no errors, and the exit code is 0. My Python version is 3.8.6, in case that matters as well.
The displayed model summary is:
Model: "functional_1"
Layer (type) Output Shape Param #
input_1 (InputLayer) [(None, 121, 240, 1)] 0
conv2d (Conv2D) (None, 121, 240, 1) 2
tf_op_layer_strided_slice (T [(121, 240, 1)] 0
lstm (LSTM) (121, 240, 2) 32
tf_op_layer_strided_slice_1 [(1, 121, 240, 2)] 0
lambda (Lambda) (None, 121, 240, 2) 0
conv2d_1 (Conv2D) (None, 121, 240, 1) 3
dense (Dense) (None, 121, 240, 3) 6
Total params: 43
Trainable params: 43
Non-trainable params: 0
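The parameter counts above can be reproduced with the usual formulas for Keras layers with biases enabled (the default). The summary also shows that the two slice ops and the LSTM have no `None` batch axis: the LSTM runs on a fixed batch of 121 rows, and the slice afterwards produces a batch of 1, before the Lambda declares `None` again. The helper names below are made up for illustration.

```python
def conv2d_params(kernel_h, kernel_w, in_channels, filters):
    # One weight per (kh, kw, in_channel, filter), plus one bias per filter
    return kernel_h * kernel_w * in_channels * filters + filters


def lstm_params(units, input_dim):
    # Four gates, each with an input kernel, a recurrent kernel and a bias
    return 4 * (units * (input_dim + units) + units)


def dense_params(input_dim, units):
    # Dense acts on the last axis only: weight matrix plus bias
    return input_dim * units + units


counts = {
    "conv2d": conv2d_params(1, 1, 1, 1),    # 1 input channel, 1 filter
    "lstm": lstm_params(2, 1),              # 2 units, 1 input feature
    "conv2d_1": conv2d_params(1, 1, 2, 1),  # 2 input channels, 1 filter
    "dense": dense_params(1, 3),            # 1 input feature, 3 classes
}
total = sum(counts.values())
print(counts, total)
```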
The training phase:
Epoch 1/2
2021-07-14 13:42:20.645002: I tensorflow/stream_executor/platform/default/] Successfully opened dynamic library
2021-07-14 13:42:20.793137: I tensorflow/stream_executor/platform/default/] Successfully opened dynamic library
1/1 [==============================] - 0s 824us/step - loss: 1.1036 - accuracy: 0.3336
Epoch 2/2
1/1 [==============================] - 0s 2ms/step - loss: 1.1033 - accuracy: 0.3336
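These figures look like those of an untrained model on random labels. Assuming no `batch_size` was passed to `fit()`, so the Keras default of 32 applies, the 10 samples fit into a single step per epoch, matching the `1/1`. A uniform prediction over 3 classes gives a cross-entropy of ln 3, close to the logged loss, and accuracy near 1/3.

```python
import math

n_samples = 10   # train_x has 10 samples
batch_size = 32  # Keras default when fit() is given no batch_size
steps = math.ceil(n_samples / batch_size)
print(steps)  # 1

# Cross-entropy of a uniform prediction over 3 classes
chance_loss = -math.log(1 / 3)
print(round(chance_loss, 4))  # 1.0986
```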