Maruf

Reputation: 790

How to add new variables to a loaded checkpoint in TensorFlow?

I have trained a big graph in TensorFlow and saved it to a checkpoint with the following function:

import os

import tensorflow as tf


def save_model(sess, saver, param_folder, saved_ckpt):
    print("Saving model to disk...")
    # make sure <param_folder>/model exists
    address = os.path.join(param_folder, 'model')
    if not os.path.isdir(address):
        os.makedirs(address)
    address = os.path.join(address, saved_ckpt)
    save_path = saver.save(sess, address)
    # also write the graph definition next to the weights
    saver.export_meta_graph(filename=address + '.meta')
    print("Model saved in file: %s" % save_path)

Now, to load the graph, I used the following function.

def load_model(sess, saver, param_folder, saved_ckpt):
    print("loading model from disk...")
    address = os.path.join(param_folder, 'model')
    if not os.path.isdir(address):
        os.makedirs(address)
    address = os.path.join(address, saved_ckpt)
    print("meta graph address :", address)
    # rebuild the saved graph from the .meta file, then restore its weights
    saver = tf.train.import_meta_graph(address + '.meta')
    saver.restore(sess, address)

It's a great feature of TensorFlow that it automatically assigns the saved weights from the checkpoint to the matching variables in the graph. But the problem occurs when I want to load the saved graph into a slightly different/extended graph. For example, suppose I have added an additional neural network on top of the previous graph and want to load the weights from the previous checkpoint so that I don't have to train the model from scratch.
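
For instance, a minimal sketch of what I mean (the names here are hypothetical): the checkpoint holds the weights of old_net, and the new graph adds new_head on top:

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 128], name='x')
# layers that existed when the checkpoint was written
with tf.variable_scope('old_net'):
    h = tf.layers.dense(x, 64, activation=tf.nn.relu)
# layers added afterwards; their variables are not in the checkpoint
with tf.variable_scope('new_head'):
    y = tf.layers.dense(h, 10)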

So in short, my question is: how do I load a previously saved sub-graph into a larger (parent) graph?

Upvotes: 1

Views: 1569

Answers (1)

Idan Azuri

Reputation: 721

I also encountered this issue and solved it using @rvinas' comment, so I am spelling it out here to make it easier for the next readers.

When you are loading the saved variables, you can add/remove/edit them in restore_dict as shown below:

def load_model(sess, saver, param_folder, saved_ckpt):
    print("loading model from disk...")
    address = os.path.join(param_folder, 'model')
    if not os.path.isdir(address):
        os.makedirs(address)
    address = os.path.join(address, saved_ckpt)
    print("meta graph address :", address)
    # instead of import_meta_graph(...) + restore(...), build the
    # restore mapping by hand from the checkpoint contents:
    reader = tf.train.NewCheckpointReader(address)
    restore_dict = dict()
    for v in tf.trainable_variables():
        tensor_name = v.name.split(':')[0]
        if reader.has_tensor(tensor_name):
            print('has tensor ', tensor_name)
            restore_dict[tensor_name] = v
    # put the logic for new/modified variables here and add them to
    # restore_dict, i.e.
    # restore_dict['my_var_scope/my_var'] = get_my_variable()
    # then restore only the variables listed in restore_dict
    saver = tf.train.Saver(restore_dict)
    saver.restore(sess, address)
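
Note that variables which exist only in the new, extended graph are not touched by the restore above, so they still have to be initialized separately before use. A minimal sketch, reusing restore_dict from the function above:

# initialize only the variables that were NOT restored from the checkpoint
restored = set(restore_dict.values())
new_vars = [v for v in tf.global_variables() if v not in restored]
sess.run(tf.variables_initializer(new_vars))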

Hope that helps.

Upvotes: 2
