Reputation: 19
I'm following the TensorFlow Keras tutorial for text generation. The training part works perfectly, but when I try to predict the next token, I get an error. Here's all the important code:
vocab = sorted(set(text))
char2index = { c:i for i, c in enumerate(vocab) }
index2char = np.array(vocab)
chars_to_int = np.array([char2index[c] for c in text])
char_dataset = tf.data.Dataset.from_tensor_slices(chars_to_int)
sequences = char_dataset.batch(seq_length + 1, drop_remainder=True)
def split_input_and_target(sequence):
    input_ = sequence[:-1]
    target_ = sequence[1:]
    return input_, target_
dataset = sequences.map(split_input_and_target)
dataset = dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
model = tf.keras.Sequential()
model.add(tf.keras.layers.Embedding(len(vocab), EMBEDDING_DIM,
                                    batch_input_shape=[BATCH_SIZE, None]))
# here are a few more layers
model.compile(loss="sparse_categorical_crossentropy", optimizer="adam")
model.fit(dataset, epochs=EPOCHS)
num_tokens = 100
seed = "some text"
input_eval = [char2index[c] for c in seed]
input_eval = tf.expand_dims(input_eval, 0)
text_generated = []
model.reset_states()
for i in range(num_tokens):
    predictions = model(input_eval)
    predictions = tf.squeeze(predictions, 0)
    # more stuff
When I run this, I first get a warning:
WARNING:tensorflow:Model was constructed with shape (64, None) for input Tensor("embedding_14_input:0", shape=(64, None), dtype=float32), but it was called on an input with incompatible shape (1, 9).
Then it gives me an error:
---->3 predictions = model(input_eval)
...
ValueError: Tensor's shape (9, 64, 256) is not compatible with supplied shape [9, 1, 256]
The second number, 64, is my batch size. If I change BATCH_SIZE to 1, everything works, but that is obviously not the solution I am hoping for.
Upvotes: 0
Views: 2587
Reputation: 19
(I somehow managed to miss a step in the tutorial despite reading it several times over the past few hours.)
Here's the relevant passage:
To keep this prediction step simple, use a batch size of 1.
Because of the way the RNN state is passed from timestep to timestep, the model only accepts a fixed batch size once built.
To run the model with a different batch_size, we need to rebuild the model and restore the weights from the checkpoint.
tf.train.latest_checkpoint(checkpoint_dir)
model = build_model(vocab_size, embedding_dim, rnn_units, batch_size=1)
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
model.build(tf.TensorShape([1, None]))
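To illustrate why restoring the weights into a model with a different batch size works, here is a minimal sketch of a toy stateful recurrence in NumPy. It is not the Keras implementation: the class `ToyStatefulRNN`, its explicit batch check, and the dimensions used are assumptions made up for illustration. It only shows that the carried-over state has one row per sequence in the batch, while the weight shapes do not depend on the batch size at all.

```python
import numpy as np


class ToyStatefulRNN:
    """Hypothetical stand-in for a stateful RNN layer (not Keras code)."""

    def __init__(self, batch_size, input_dim, units, weights=None):
        if weights is None:
            rng = np.random.default_rng(0)
            weights = (rng.normal(size=(input_dim, units)),
                       rng.normal(size=(units, units)))
        # Weight shapes depend only on input_dim and units, never on batch_size.
        self.w_in, self.w_rec = weights
        # The state kept between calls has one row per sequence in the batch.
        self.state = np.zeros((batch_size, units))

    def step(self, x):
        # Explicit check added for illustration; without it NumPy would
        # silently broadcast a batch of 1 against the stored state.
        if x.shape[0] != self.state.shape[0]:
            raise ValueError(
                f"input batch {x.shape[0]} does not match "
                f"state batch {self.state.shape[0]}"
            )
        self.state = np.tanh(x @ self.w_in + self.state @ self.w_rec)
        return self.state


# "Trained" model built for a batch of 64 sequences.
trained = ToyStatefulRNN(batch_size=64, input_dim=8, units=16)
single = np.ones((1, 8))

try:
    trained.step(single)
    mismatch_raised = False
except ValueError:
    mismatch_raised = True

# Rebuild with batch_size=1 and reuse the same weights.
rebuilt = ToyStatefulRNN(batch_size=1, input_dim=8, units=16,
                         weights=(trained.w_in, trained.w_rec))
output = rebuilt.step(single)
```

In this toy version the batch-64 object rejects a single sequence, while the rebuilt batch-1 object accepts it with the very same weight arrays, which mirrors the rebuild-then-`load_weights` steps from the tutorial.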
I hope my silly mistake helps somebody remember to rebuild the model with a batch size of 1 and reload the weights before predicting!
Upvotes: 1