TamTam

Reputation: 567

Unable to load a trained checkpoint

I'm following the code from here to learn the text summarization task with a Transformer model. It can be found here.

However, the code doesn't provide a way to load the model after training, which is inconvenient, so I decided to write that function myself.

Here is my code:

model = Transformer(
    num_layers,
    d_model,
    num_heads,
    dff,
    encoder_vocab_size,
    decoder_vocab_size,
    pe_input=max_len_news,
    pe_target=max_len_summary,
)

model.load_weights('checkpoints/ckpt-5.data-00000-of-00001') 

It throws an error:

ValueError: Unable to load weights saved in HDF5 format into a subclassed Model which has not created its variables yet. Call the Model first, then load the weights.

I'm completely new to machine learning and TensorFlow. I roughly understand what the error is saying, but I don't know how to fix it. Please help.

Upvotes: 1

Views: 1516

Answers (1)

Andrey

Reputation: 6377

A subclassed model doesn't create its variables until it is called for the first time, so you have to call it once with dummy input before loading the weights.
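The lazy variable creation is easy to see on a small scale. The sketch below uses a hypothetical one-layer subclassed model (TinyModel is an illustration, not the tutorial's Transformer): its weight list stays empty until the first call builds the inner Dense layer.

```python
import tensorflow as tf


class TinyModel(tf.keras.Model):
    """Hypothetical stand-in for a subclassed model such as the tutorial's Transformer."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    def call(self, x):
        return self.dense(x)


model = TinyModel()
n_before = len(model.weights)  # the Dense layer is not built yet, so no variables exist
print(n_before)  # 0

_ = model(tf.zeros([1, 3]))  # the first call builds the Dense layer from the input shape
n_after = len(model.weights)  # kernel (3x4) and bias (4)
print(n_after)  # 2
```

Until that first call there is nothing for the saved values to be loaded into, which is what the error message is complaining about.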

Try this:

import tensorflow as tf

# Transformer and create_masks are defined in the tutorial code.
model = Transformer(
    num_layers,
    d_model,
    num_heads,
    dff,
    encoder_vocab_size,
    decoder_vocab_size,
    pe_input=max_len_news,
    pe_target=max_len_summary,
)

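# Dummy batch of token ids: batch size 1, sequence length 12, ids in [0, 100).
dummy_input = tf.random.uniform([1, 12], 0, 100, dtype=tf.int32)
# Build the masks the model expects for this input.
enc_padding_mask, look_ahead_mask, dec_padding_mask = create_masks(dummy_input, dummy_input)
# Call the model once so it creates its variables before loading weights.
# The arguments must match the signature of Transformer.call in your code.
_ = model(dummy_input, dummy_input, enc_padding_mask, look_ahead_mask, dec_padding_mask)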

model.load_weights('checkpoints/ckpt-5.data-00000-of-00001')
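Files named like ckpt-5.data-00000-of-00001 are TensorFlow checkpoint files (not HDF5), and such a checkpoint is addressed by its prefix (for example checkpoints/ckpt-5), not by the .data shard. If the training code wrote them with tf.train.CheckpointManager, one way to reload them is through tf.train.Checkpoint. The sketch below is a minimal round trip under stated assumptions: TinyModel is a hypothetical stand-in for the Transformer, and the keyword transformer is assumed to match the one used when saving; it still calls the model once before restoring, as above.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class TinyModel(tf.keras.Model):
    """Hypothetical stand-in for the tutorial's Transformer."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(4)

    def call(self, x):
        return self.dense(x)


ckpt_dir = tempfile.mkdtemp()
dummy = tf.ones([1, 3])

# Training side: build the model and save it through a CheckpointManager.
trained = TinyModel()
trained(dummy)
manager = tf.train.CheckpointManager(
    tf.train.Checkpoint(transformer=trained), ckpt_dir, max_to_keep=5
)
saved_path = manager.save()  # returns a prefix ending in "ckpt-1", not a .data file
files = sorted(os.listdir(ckpt_dir))  # the .index and .data files share that prefix

# Inference side: new model, call it once so its variables exist, then restore
# from the prefix that tf.train.latest_checkpoint finds in the directory.
restored = TinyModel()
restored(dummy)
status = tf.train.Checkpoint(transformer=restored).restore(
    tf.train.latest_checkpoint(ckpt_dir)
)
status.assert_existing_objects_matched()  # raises if the model's variables were not restored

same_output = np.allclose(trained(dummy).numpy(), restored(dummy).numpy())
print(saved_path)
print(same_output)
```

Applied to the question, this would mean restoring from tf.train.latest_checkpoint('checkpoints') (or the prefix 'checkpoints/ckpt-5'). If the training checkpoint also stored the optimizer, calling expect_partial() on the restore status silences warnings about the values that were not restored.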

Upvotes: 1
