Reputation: 790
I have trained a big graph in TensorFlow and saved it to a checkpoint with the following function:
import os
import tensorflow as tf

def save_model(sess, saver, param_folder, saved_ckpt):
    print("Saving model to disk...")
    address = os.path.join(param_folder, 'model')
    if not os.path.isdir(address):
        os.makedirs(address)
    address = os.path.join(address, saved_ckpt)
    save_path = saver.save(sess, address)
    saver.export_meta_graph(filename=address + '.meta')
    print("Model saved in file: %s" % save_path)
Now, to load the graph, I used the following function.
def load_model(sess, saver, param_folder, saved_ckpt):
    print("loading model from disk...")
    address = os.path.join(param_folder, 'model')
    if not os.path.isdir(address):
        os.makedirs(address)
    address = os.path.join(address, saved_ckpt)
    print("meta graph address :", address)
    saver = tf.train.import_meta_graph(address + '.meta')
    saver.restore(sess, address)
It's a great feature of TensorFlow that it automatically assigns the saved weights from the checkpoint to the matching variables in the graph. But the problem occurs when I want to load the graph saved in the checkpoint into a slightly different/extended graph. For example, assume I have added an additional neural network on top of the previous graph and want to load the weights from the previous checkpoint so that I don't have to train the model from scratch.
So, in short, my question is: how do I load a previously saved sub-graph into a larger (parent) graph?
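To make the scenario concrete, here is a minimal sketch of what I mean (the scope names and shapes are made up purely for illustration): the variables under old_net exist in the checkpoint, while the ones under extra_net are new and have no saved weights.

import tensorflow as tf

# Part of the graph that was trained and saved in the checkpoint
# (hypothetical names, only to illustrate the situation).
with tf.variable_scope('old_net'):
    w1 = tf.get_variable('w1', shape=[784, 128])
    b1 = tf.get_variable('b1', shape=[128])

# Part added afterwards; these variables do NOT exist in the checkpoint,
# so a plain saver.restore() on the full graph fails with a
# "not found in checkpoint" error.
with tf.variable_scope('extra_net'):
    w2 = tf.get_variable('w2', shape=[128, 10])
    b2 = tf.get_variable('b2', shape=[10])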
Upvotes: 1
Views: 1569
Reputation: 721
I also encountered this issue and used the approach from @rvinas' comment, so just to make it easier for the next readers:
When you are loading the saved variables, you can add/remove/edit them in the restore_dict as shown below:
def load_model(sess, saver, param_folder, saved_ckpt):
    print("loading model from disk...")
    address = os.path.join(param_folder, 'model')
    if not os.path.isdir(address):
        os.makedirs(address)
    address = os.path.join(address, saved_ckpt)
    print("meta graph address :", address)
    # remove the next two lines
    # saver = tf.train.import_meta_graph(address + '.meta')
    # saver.restore(sess, address)
    # instead put this block:
    reader = tf.train.NewCheckpointReader(address)
    restore_dict = dict()
    for v in tf.trainable_variables():
        tensor_name = v.name.split(':')[0]
        if reader.has_tensor(tensor_name):
            print('has tensor ', tensor_name)
            restore_dict[tensor_name] = v
    # put the logic of the new/modified variable here and assign to the restore_dict, i.e.
    # restore_dict['my_var_scope/my_var'] = get_my_variable()
    # finally, build a saver from the dict and restore only those variables
    saver = tf.train.Saver(restore_dict)
    saver.restore(sess, address)
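As a rough usage sketch (the folder and checkpoint names are just placeholders), you build the full extended graph first, restore the overlapping variables with the function above, and then initialize only the variables that the checkpoint did not cover:

# Build the extended graph first, then:
with tf.Session() as sess:
    # restores every trainable variable whose name exists in the checkpoint;
    # the saver argument is unused in this version of load_model
    load_model(sess, None, 'params', 'model.ckpt')
    # initialize only the variables that were not restored from the checkpoint
    missing = {name.decode() for name in sess.run(tf.report_uninitialized_variables())}
    sess.run(tf.variables_initializer(
        [v for v in tf.global_variables() if v.name.split(':')[0] in missing]))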
Hope that helps.
Upvotes: 2