Reputation: 149
I am trying to continue training from a saved checkpoint using the Colab setup for gpt-2-simple at:
https://colab.research.google.com/drive/1SvQne5O_7hSdmPvUXl5UzPeG5A6csvRA#scrollTo=aeXshJM-Cuaf
But I just can't get it to work. Loading the saved checkpoint from my Google Drive works fine, and I can use it to generate text, but I can't continue training from that checkpoint. In gpt2.finetune() I am passing restore_from='latest' and overwrite=True. I have tried using both the same run_name and a different one, with and without overwrite=True, and I have also tried restarting the runtime in between, as was suggested, but it doesn't help. I keep getting the following error:
"ValueError: Variable model/wpe already exists, disallowed. Did you mean to set reuse=True
or reuse=tf.AUTO_REUSE in VarScope?"
I assume that I need to run gpt2.load_gpt2(sess, run_name='myRun') before continuing training, but whenever I run this first, gpt2.finetune() throws this error.
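For reference, this is roughly the flow I am using (the run name and dataset file name below are just placeholders for my own values):

import gpt_2_simple as gpt2

# copy the previously saved checkpoint from Google Drive into the Colab runtime
gpt2.mount_gdrive()
gpt2.copy_checkpoint_from_gdrive(run_name='myRun')

sess = gpt2.start_tf_sess()

# loading the checkpoint and generating from it works fine
gpt2.load_gpt2(sess, run_name='myRun')
gpt2.generate(sess, run_name='myRun')

# ...but calling finetune() afterwards raises the "Variable model/wpe already exists" error
gpt2.finetune(sess,
              dataset='myData.txt',
              run_name='myRun',
              restore_from='latest',
              overwrite=True,
              steps=100)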
Upvotes: 2
Views: 4902
Reputation: 11
I've tried the following and it works fine:
tf.reset_default_graph()
sess = gpt2.start_tf_sess()
gpt2.finetune(sess,
              steps=n,
              dataset=file_name,
              model_name='model',
              print_every=z,
              run_name='run_name',
              restore_from='latest',
              sample_every=x,
              save_every=y)
You must pass the same run_name as the run whose training you want to resume, together with the argument restore_from='latest'.
Upvotes: 1
Reputation: 1935
You don't need to (and can't) run load_gpt2() before finetuning. You instead simply need to give run_name to finetune(). I agree that this is confusing as hell; I had the same trouble.
sess = gpt2.start_tf_sess()
gpt2.finetune(sess,
              file_name,
              model_name=model_name,
              checkpoint_dir=checkpoint_dir,
              run_name=run_name,
              steps=25)
This will automatically grab the latest checkpoint from your checkpoint/run-name folder, load its weights, and continue training where it left off. You can confirm this by checking the step counter - it doesn't start again from 0. E.g., if you'd previously trained for 25 steps, it'll start at 26:
Training...
[26 | 7.48] loss=0.49 avg=0.49
Also note that to run finetuning multiple times (or to load another model) you normally have to restart the Python runtime. You can instead run this before each finetune command:
tf.reset_default_graph()
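For example, a second finetuning pass in the same runtime could look roughly like this (the run name and dataset file are illustrative placeholders, and restore_from='latest' is the default anyway):

import gpt_2_simple as gpt2
import tensorflow as tf

# wipe the old graph so finetune() can rebuild the model variables from scratch
tf.reset_default_graph()
sess = gpt2.start_tf_sess()

# resume from the latest checkpoint under checkpoint/run_name
gpt2.finetune(sess,
              'my_dataset.txt',
              run_name='run_name',
              restore_from='latest',
              steps=25)

Depending on your gpt-2-simple version, gpt2.reset_session(sess) may also be available to reset the graph and hand back a fresh session in one call.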
Upvotes: 3