Reputation: 41
I keep running into an AssertionError when trying to fit a model. I did some reading on when Python raises an AssertionError. The traceback is as follows:
File "G:/test3/main.py", line 167, in <module>
model.fit([images, captions], next_words, batch_size=128, epochs=50)
File "C:\Users\Acer\Anaconda3\lib\site-packages\keras\engine\training.py", line 950, in fit
batch_size=batch_size)
File "C:\Users\Acer\Anaconda3\lib\site-packages\keras\engine\training.py", line 671, in _standardize_user_data
self._set_inputs(x)
File "C:\Users\Acer\Anaconda3\lib\site-packages\keras\engine\training.py", line 575, in _set_inputs
assert len(inputs) == 1
AssertionError
My code is as follows:
model=Sequential()
model.add(Concatenate([image_model, language_model]))
model.add(LSTM(1000, return_sequences=False))
model.add(Dense(vocab_size))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer=Nadam(), metrics=['accuracy'])
model.fit([images, captions], next_words, batch_size=5, epochs=50)
model.summary()
model.save_weights("./models/vgg16_weights_tf_dim_ordering_tf_kernels.h5")
images has a shape of (18724, 1000) and captions has a shape of (18724, 43).
Upvotes: 1
Views: 8683
Reputation: 7129
You are getting this error because you did not specify any inputs to your model, so Keras tries to set them when you call model.fit(). The assertion is there because a model wrapped in a Sequential container should take only one input, and you are passing two.
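To make the failure mode concrete, here is a minimal pure-Python sketch of the kind of check the traceback points at. The function name set_inputs_sketch is hypothetical and this is not Keras' actual implementation; it only mirrors the `assert len(inputs) == 1` line shown in the traceback.

```python
def set_inputs_sketch(inputs):
    """Hypothetical stand-in for the check in the traceback (not Keras code)."""
    # Treat a bare value as a one-element list of inputs.
    if not isinstance(inputs, (list, tuple)):
        inputs = [inputs]
    # A Sequential-style model accepts exactly one input.
    assert len(inputs) == 1
    return inputs[0]


# One input passes the check.
single_ok = set_inputs_sketch(["images"])

# Two inputs, as in fit([images, captions], ...), trip the assertion.
try:
    set_inputs_sketch(["images", "captions"])
    raised = False
except AssertionError:
    raised = True
```

A bare assert carries no message, which is why the traceback ends in a plain AssertionError with no explanation.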
To implement what you want, use Keras' Functional API instead of the Sequential API. Something along these lines:
from keras.models import Model
from keras.layers import Concatenate, Input, Dense
# First model
first_input = Input((2, ))
first_dense = Dense(128)(first_input)
# Second model
second_input = Input((10, ))
second_dense = Dense(64)(second_input)
# Concatenate both
merged = Concatenate()([first_dense, second_dense])
output_layer = Dense(1)(merged)
model = Model(inputs=[first_input, second_input], outputs=output_layer)
model.compile(optimizer='sgd', loss='mse')
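Applied to the shapes in the question, the same pattern might look roughly like the sketch below. The question does not show image_model or language_model, so the branch layers here (a Dense layer for the image features, an Embedding followed by an LSTM for the captions), the layer sizes, the small vocab_size, and the random stand-in data are all assumptions for illustration only.

```python
import numpy as np
from keras.models import Model
from keras.layers import Input, Dense, Embedding, LSTM, Concatenate

vocab_size = 50   # assumed; use your real vocabulary size
max_len = 43      # caption length from the question
n_samples = 8     # tiny stand-in for the 18724 real samples

# Image branch: one 1000-dimensional feature vector per sample (assumed layers).
image_input = Input(shape=(1000,))
image_features = Dense(128, activation='relu')(image_input)

# Caption branch: a sequence of word indices per sample (assumed layers).
caption_input = Input(shape=(max_len,))
caption_embedded = Embedding(vocab_size, 64)(caption_input)
caption_features = LSTM(128)(caption_embedded)

# Merge both branches and predict the next word.
merged = Concatenate()([image_features, caption_features])
output = Dense(vocab_size, activation='softmax')(merged)

model = Model(inputs=[image_input, caption_input], outputs=output)
model.compile(loss='categorical_crossentropy', optimizer='nadam',
              metrics=['accuracy'])

# Random data with the same layout as images, captions and next_words.
images = np.random.rand(n_samples, 1000).astype('float32')
captions = np.random.randint(0, vocab_size, size=(n_samples, max_len))
next_words = np.eye(vocab_size, dtype='float32')[
    np.random.randint(0, vocab_size, size=n_samples)]

model.fit([images, captions], next_words, batch_size=4, epochs=1, verbose=0)
preds = model.predict([images, captions], verbose=0)
```

Because the model now declares two Input tensors, passing [images, captions] to fit() matches the model's inputs and no longer reaches the single-input assertion.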
Upvotes: 1