Reputation: 1671
I have a checkpoint from a pre-trained model A but not its graph-building code. Now I want to reload model A and add a sub-graph B to it to get a final model C, i.e. C = A + B. However, since model A is already well trained, I don't want to train it further inside C; I only want to train sub-graph B. In other words, sub-graph A should only participate in the prediction phase (forward propagation), not the training phase (backward propagation), while sub-graph B participates in both phases and is what I want to train. The goal is that, with the help of sub-graph B, model C will outperform model A.
How can this be achieved? I guess this is related to saver/restore, but I don't know how to put everything together. Any code snippet would be greatly appreciated.
I'm using TensorFlow 1.12.
Upvotes: 0
Views: 299
Reputation: 182
Well, what you want to do is quite a typical machine learning use case, and it can be achieved in a few ways.
If you only have a saved checkpoint but not the source code of the pre-trained model, then after loading the model from the checkpoint you need to:

- clear the tf.GraphKeys.TRAINABLE_VARIABLES collection, so that the variables restored from the ckpt are not trainable in the fine-tuning model
- clear the tf.GraphKeys.GLOBAL_VARIABLES collection, so that the variables restored from the ckpt are not re-initialized by the fine-tuning model

Then build and train the fine-tuning model as usual. A condensed sketch of the two collection steps is shown right below, followed by a full tested example.
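In condensed form (a minimal sketch, not the tested script; the checkpoint path and step number are hypothetical placeholders that depend on how your checkpoint was saved):

import tensorflow as tf

# import the graph definition and restore the weights from the checkpoint
saver = tf.train.import_meta_graph('/tmp/pre_train_model/ckpt-100.meta')
sess = tf.Session()
saver.restore(sess, '/tmp/pre_train_model/ckpt-100')

# 1) variables restored from the ckpt are no longer trainable
tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES).clear()
# 2) variables restored from the ckpt will not be re-initialized later
tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES).clear()

# ... build sub-graph B on top of A's output tensors and create the train_op;
# tf.global_variables_initializer() will then only initialize B's variables.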
The complete example below is tested with TensorFlow 1.12.
import os
import shutil
import tensorflow as tf
import numpy as np
from absl import logging, app, flags
flags.DEFINE_string('pre_train_model_checkpoint_path', '/tmp/pre_train_model', '')
FLAGS = flags.FLAGS
def train_and_save_pre_train_model():
    # pre-trained model A: y = w*x + b
    x = tf.placeholder(tf.float32, shape=(None), name='x')
    y_true = tf.placeholder(tf.float32, shape=(None), name='y')
    w = tf.get_variable('w', shape=())
    b = tf.get_variable('b', shape=())
    y = w*x + b
    y_pred = tf.identity(y, 'y_pred')
    loss = tf.losses.mean_squared_error(y_true, y_pred)
    train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss, global_step=tf.train.get_or_create_global_step())
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    for _ in range(2000):
        x_val = np.random.rand(128)
        y_val = x_val * 2.0 - 1.0  # to make w=2.0 and b=-1.0
        sess.run(train_op, {x: x_val, y_true: y_val})
    saver = tf.train.Saver()
    # start from a clean checkpoint directory (ignore_errors avoids failing
    # if the directory does not exist yet on the first run)
    shutil.rmtree(FLAGS.pre_train_model_checkpoint_path, ignore_errors=True)
    os.makedirs(FLAGS.pre_train_model_checkpoint_path)
    save_path = os.path.join(FLAGS.pre_train_model_checkpoint_path, 'ckpt')
    saver.save(sess, save_path, global_step=100)
    return sess.run((w, b))
def main(_):
    w_val, b_val = train_and_save_pre_train_model()
    # check that the pre-trained model was trained as expected
    logging.info('w_val={}, b_val={}'.format(w_val, b_val))

    # load the pre-trained model from its checkpoint
    tf.reset_default_graph()
    meta_file = os.path.join(FLAGS.pre_train_model_checkpoint_path, 'ckpt-100.meta')
    saver = tf.train.import_meta_graph(meta_file)
    save_path = os.path.join(FLAGS.pre_train_model_checkpoint_path, 'ckpt-100')
    sess = tf.Session()
    saver.restore(sess, save_path)
    x = tf.get_default_graph().get_tensor_by_name('x:0')
    y_pred = tf.get_default_graph().get_tensor_by_name('y_pred:0')
    w = tf.get_default_graph().get_tensor_by_name('w:0')
    b = tf.get_default_graph().get_tensor_by_name('b:0')

    # make variables from the ckpt non-trainable by the fine-tuning model
    tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES).clear()
    # make variables from the ckpt not re-initialized by the fine-tuning model
    tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES).clear()

    # build fine-tuning model: y2 = w2*y + b2
    w2 = tf.get_variable('w2', shape=())
    b2 = tf.get_variable('b2', shape=())
    y2_pred = w2*y_pred + b2
    y2_true = tf.placeholder(tf.float32, shape=(None), name='y2')
    loss = tf.losses.mean_squared_error(y2_true, y2_pred)
    train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss, global_step=tf.train.get_or_create_global_step())
    sess.run(tf.global_variables_initializer())
    for _ in range(2000):
        x_val = np.random.rand(128)
        y2_val = (x_val * w_val + b_val) * 10.0 + 1.0  # to make w2=10.0 and b2=1.0
        sess.run(train_op, {x: x_val, y2_true: y2_val})
    w2_val, b2_val = sess.run((w2, b2))
    logging.info('w2_val={}, b2_val={}'.format(w2_val, b2_val))

    # assert that w and b were not trained during fine-tuning
    w_val_after_fine_tuning, b_val_after_fine_tuning = sess.run((w, b))
    logging.info('w_val_after_fine_tuning={}, b_val_after_fine_tuning={}'.format(w_val_after_fine_tuning, b_val_after_fine_tuning))
    assert w_val == w_val_after_fine_tuning
    assert b_val == b_val_after_fine_tuning
    logging.info('all good')
if __name__ == '__main__':
    app.run(main)
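For completeness, since this can be done in a few ways: if you have Python handles to the new variables, another common TF 1.x pattern is to leave the collections untouched and instead pass var_list to the optimizer's minimize() and initialize only the new variables. A sketch of how the fine-tuning part of main() above would change under that approach (not part of the tested script):

# Alternative sketch: keep the collections intact and instead tell the
# optimizer explicitly which variables to train. x, y_pred and sess are the
# same handles obtained after restoring the checkpoint above.
w2 = tf.get_variable('w2', shape=())
b2 = tf.get_variable('b2', shape=())
y2_pred = w2 * y_pred + b2
y2_true = tf.placeholder(tf.float32, shape=(None), name='y2')
loss = tf.losses.mean_squared_error(y2_true, y2_pred)

# var_list restricts gradient updates to sub-graph B's variables only
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(
    loss, var_list=[w2, b2])

# initialize only B's variables; the restored weights of A stay untouched
sess.run(tf.variables_initializer([w2, b2]))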
Upvotes: 1