I am trying to use Keras' functional API to build a recurrent neural network, but I have run into a problem with the output shape. Any help would be appreciated.
My code:
```python
import tensorflow as tf
from tensorflow.python.keras.datasets import mnist
from tensorflow.python.keras.layers import Dense, CuDNNLSTM, Dropout
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.utils import normalize
from tensorflow.python.keras.utils import np_utils

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = normalize(x_train, axis=1), normalize(x_test, axis=1)
y_train = np_utils.to_categorical(y_train, 10)
y_test = np_utils.to_categorical(y_test, 10)

feature_input = tf.keras.layers.Input(shape=(28, 28))
x = tf.keras.layers.CuDNNLSTM(128, kernel_regularizer=tf.keras.regularizers.l2(l=0.0004), return_sequences=True)(feature_input)
y = tf.keras.layers.Dense(10, activation='softmax')(x)
model = tf.keras.Model(inputs=feature_input, outputs=y)

opt = tf.keras.optimizers.Adam(lr=1e-3, decay=1e-5)
model.compile(optimizer=opt, loss="sparse_categorical_crossentropy", metrics=['accuracy'])
model.fit(x_train, y_train, epochs=3, validation_data=(x_test, y_test))
```
Error:
```
ValueError: Error when checking target: expected dense to have 3 dimensions, but got array with shape (60000, 10)
```
Answer:
Your data (targets) has shape `(60000, 10)`.
Your model's output (`dense`) has shape `(None, length, 10)`,
where `None` is the (variable) batch size, `length` is the middle dimension, which means "time steps" for an LSTM (28 here, one per image row), and 10 is the number of units in the `Dense` layer.
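A short NumPy sketch of where the extra dimension comes from, assuming the question's sizes (28 rows, 128 LSTM units, 10 classes) and an arbitrary batch of 4. It relies on the fact that a Keras `Dense` layer given a 3D input multiplies along the last axis only, so it produces one 10-way softmax per time step. The random arrays are stand-ins for real activations and labels, not anything Keras computes.

```python
import numpy as np

# Assumed sizes, taken from the question; batch size 4 is arbitrary.
batch, timesteps, lstm_units, classes = 4, 28, 128, 10
rng = np.random.default_rng(0)

# Stand-in for the LSTM output with return_sequences=True:
# one 128-dim vector per time step.
lstm_out = rng.standard_normal((batch, timesteps, lstm_units))

# A Dense layer is a matrix multiply over the LAST axis plus a bias,
# so it is applied independently at every time step.
kernel = rng.standard_normal((lstm_units, classes))
bias = np.zeros(classes)
logits = lstm_out @ kernel + bias

# Softmax over the class axis, as activation='softmax' does.
exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
dense_out = exp / exp.sum(axis=-1, keepdims=True)

# One-hot targets, equivalent to to_categorical(labels, 10).
labels = rng.integers(0, classes, size=batch)
targets = np.eye(classes)[labels]

print("model output:", dense_out.shape)  # (4, 28, 10)
print("targets:     ", targets.shape)    # (4, 10)
```

The 3D output cannot be compared against the 2D targets, which is exactly what the error message reports.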
Now, your data isn't really a sequence of time steps, so feeding it to an LSTM doesn't make much sense. The layer interprets the image rows as sequential time steps and the image columns as the features of each step. (If that was not your intention, you were simply lucky that putting an image into an LSTM didn't raise an error.)
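To make the rows-as-time-steps reading concrete, here is a minimal NumPy sketch. It uses a plain tanh RNN (an Elman-style cell) instead of an LSTM purely to keep it short; the function `simple_rnn` and its random weights are hypothetical, and its `return_sequences` flag only mimics the Keras argument of the same name.

```python
import numpy as np

def simple_rnn(images, w_in, w_rec, bias, return_sequences=True):
    """Hypothetical tanh RNN, a simplified stand-in for an LSTM.

    images: (batch, rows, cols). Each image row is one time step and the
    cols pixels in that row are the features of that step.
    """
    batch, rows, _ = images.shape
    units = w_rec.shape[0]
    h = np.zeros((batch, units))
    outputs = []
    for t in range(rows):             # walk down the image, row by row
        x_t = images[:, t, :]         # (batch, cols): features at step t
        h = np.tanh(x_t @ w_in + h @ w_rec + bias)
        outputs.append(h)
    if return_sequences:
        return np.stack(outputs, axis=1)  # (batch, rows, units)
    return h                              # (batch, units): last step only

rng = np.random.default_rng(0)
images = rng.random((4, 28, 28))      # fake 28x28 "images"
units = 128
w_in = rng.standard_normal((28, units)) * 0.1
w_rec = rng.standard_normal((units, units)) * 0.1
bias = np.zeros(units)

seq = simple_rnn(images, w_in, w_rec, bias, return_sequences=True)
last = simple_rnn(images, w_in, w_rec, bias, return_sequences=False)
print(seq.shape)   # (4, 28, 128)
print(last.shape)  # (4, 128)
```

The last-step-only result is just the final slice of the full sequence, which is why dropping the sequence leaves one vector per image.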
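Anyway, you can fix this error with `return_sequences=False`, which discards the length dimension: the LSTM then returns only its output at the last time step, with shape `(None, 128)`, and the `Dense` layer outputs `(None, 10)`. This does not mean the model is optimal for this case.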
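A sketch of the question's model with the fix applied, assuming TensorFlow 2.x and `tf.keras`; `build_model` is a hypothetical helper name. A few further changes are assumptions on my part rather than part of the answer: `tf.keras.layers.LSTM` replaces `CuDNNLSTM` (in TF 2.x the plain `LSTM` layer uses the cuDNN kernel on a GPU when left at its default settings), `lr` becomes `learning_rate`, and `decay` is left out because recent Keras optimizers no longer support it. The loss is also switched to `categorical_crossentropy`: `to_categorical` produces one-hot targets, while `sparse_categorical_crossentropy` expects integer class labels, so the original loss would fail next once the shape error is gone.

```python
import tensorflow as tf

def build_model(timesteps=28, features=28, num_classes=10):
    """Question's architecture with return_sequences=False (hypothetical helper)."""
    feature_input = tf.keras.layers.Input(shape=(timesteps, features))
    # Plain LSTM: in TF 2.x it uses the cuDNN kernel on GPU with default args.
    x = tf.keras.layers.LSTM(
        128,
        kernel_regularizer=tf.keras.regularizers.l2(0.0004),
        return_sequences=False,  # keep only the last time step: (None, 128)
    )(feature_input)
    y = tf.keras.layers.Dense(num_classes, activation="softmax")(x)  # (None, 10)
    model = tf.keras.Model(inputs=feature_input, outputs=y)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
        # Targets are one-hot (from to_categorical), so use the non-sparse loss.
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model

model = build_model()
model.summary()
```

With the arrays prepared as in the question, training is then the same call as before: `model.fit(x_train, y_train, epochs=3, validation_data=(x_test, y_test))`.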