Reputation: 670
For ease of discussion, the models below have been simplified.
Let's say there are around 40,000 512x512 images in my training set. I am trying to implement pre-training, and my plan is the following:
1. Train a neural network (let's call it net_1) that takes 256x256 images as input, and save the trained model in TensorFlow checkpoint format.
net_1: input -> 3 conv2d -> maxpool2d -> 2 conv2d -> rmspool -> flatten -> dense
Let's call the following structure net_1_kernel:
net_1_kernel: 3 conv2d -> maxpool2d -> 3 conv2d
and call the remaining part other_layers:
other_layers: rmspool -> flatten -> dense
Thus we can represent net_1 in the following form:
net_1: input -> net_1_kernel -> other_layers
2. Insert several layers into the structure of net_1 and call the result net_2. It should look like this:
net_2: input -> net_1_kernel -> maxpool2d -> 3 conv2d -> other_layers
net_2 will take 512x512 images as input.
When I train net_2, I would like to use the saved weights and biases in the checkpoint file of net_1 to initialize the net_1_kernel part in net_2. How can I do this?
I know that I can load a checkpoint to make predictions on test data, but in that case everything is loaded (net_1_kernel and other_layers). What I want is to load net_1_kernel only and use it to initialize the weights and biases in net_2.
I also know that I can print the contents of a checkpoint file to a text file and copy and paste the values to initialize the weights and biases manually. However, those weights and biases contain a huge number of values, so this would be my last resort.
Upvotes: 2
Views: 743
Reputation: 61
First of all, you can use the following code to list all the tensors stored in the checkpoint file you saved.
from tensorflow.python.tools import inspect_checkpoint as chkp
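# With all_tensor_names=True the names of all stored tensors are printed, so tensor_name can be left empty.
chkp.print_tensors_in_checkpoint_file(file_name="file.ckpt", tensor_name="", all_tensors=False, all_tensor_names=True)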
Remember that, by default, a Saver saves and restores every variable it knows about, which is all variables in the graph. If you need to save and restore specific variables only, pass them to tf.train.Saver explicitly, as follows:
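import tensorflow as tf  # TensorFlow 1.x API
# Print every trainable variable to check how the variables are named.
print(tf.trainable_variables())
# Keep only the variables whose name contains "net_1_kernel".
var = [v for v in tf.trainable_variables() if "net_1_kernel" in v.name]
saverAndRestore = tf.train.Saver(var)
saverAndRestore.save(sess_1, "net_1.ckpt")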
This saves only the variables in the list var to net_1.ckpt.
saverAndRestore.restore(sess_1,"net_1.ckpt")
This restores only the variables in the list var from net_1.ckpt.
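Apart from the above, all you have to do is name/scope your variables (for example, create the shared layers under a "net_1_kernel" scope in both networks) so that the name filter above can pick them out easily.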
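To tie this back to the question, here is a rough sketch of how the restore step could look for net_2. It is not tested against your model and rests on a few assumptions: the TF 1.x graph/session API, both networks building their shared layers inside a variable scope named "net_1_kernel", and the variable names in net_2 matching those stored in net_1.ckpt. The helper names select_scope and init_net_2_from_net_1, the FakeVar stand-ins, and the layer names in the demo list are made up for illustration.

```python
from collections import namedtuple


def select_scope(variables, scope):
    """Return the variables whose name lies under `scope` (prefix match on "scope/")."""
    prefix = scope.rstrip("/") + "/"
    return [v for v in variables if v.name.startswith(prefix)]


def init_net_2_from_net_1(sess, checkpoint_path, scope="net_1_kernel"):
    """Hypothetical helper: call it after the net_2 graph has been built.

    Assumes the TF 1.x graph/session API and that net_1 and net_2 both
    create their shared layers inside tf.variable_scope(scope).
    """
    import tensorflow.compat.v1 as tf  # imported lazily; needs TensorFlow installed

    # Give every variable of net_2 (including the new layers) an initial value.
    sess.run(tf.global_variables_initializer())
    # Then overwrite only the shared layers with the values saved from net_1.
    kernel_vars = select_scope(tf.trainable_variables(), scope)
    tf.train.Saver(var_list=kernel_vars).restore(sess, checkpoint_path)
    return kernel_vars


# Stand-in objects so the name filter can be checked without TensorFlow.
FakeVar = namedtuple("FakeVar", ["name"])
demo_vars = [
    FakeVar("net_1_kernel/conv_1/kernel:0"),
    FakeVar("net_1_kernel/conv_1/bias:0"),
    FakeVar("new_layers/conv_7/kernel:0"),
    FakeVar("other_layers/dense/kernel:0"),
]
selected = [v.name for v in select_scope(demo_vars, "net_1_kernel")]
print(selected)
```

Matching on the prefix "net_1_kernel/" rather than on a substring avoids accidentally picking up a differently named scope such as "net_1_kernel_old". Convolution weights do not depend on the spatial size of the input, so the values trained on 256x256 images can be loaded into the same layers of the 512x512 network; the inserted layers and other_layers simply keep their fresh initial values.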
Upvotes: 2