user10253771

Reputation: 670

How to implement pre-training in Tensorflow? How to partially use saved weights from checkpoint file?

For the convenience of discussion, the following models have been simplified.

Let's say there are around 40,000 512x512 images in my training set. I am trying to implement pre-training, and my plan is the following:

1. Train a neural network (let's call it net_1) that takes in 256x256 images, and save the trained model in TensorFlow checkpoint file format.

net_1: input -> 3 conv2d -> maxpool2d -> 2 conv2d -> rmspool -> flatten -> dense

let's call this structure net_1_kernel:

net_1_kernel: 3 conv2d -> maxpool2d -> 3 conv2d

and call the remaining part other_layers:

other_layers: rmspool -> flatten -> dense

Thus we can represent net_1 in the following form:

net_1: input -> net_1_kernel -> other_layers

2. Insert several layers into the structure of net_1, and call the result net_2. It should look like this:

net_2: input -> net_1_kernel -> maxpool2d -> 3 conv2d -> other_layers

net_2 will take 512x512 images as input.

When I train net_2, I would like to use the saved weights and biases in the checkpoint file of net_1 to initialize the net_1_kernel part in net_2. How can I do this?

I know that I can load a checkpoint to make predictions on test data. But in that case it loads everything (net_1_kernel and other_layers). What I want is to load net_1_kernel only and use it to initialize the weights and biases in net_2.

I also know that I can print the contents of a checkpoint file to a text file, then copy and paste the values to initialize the weights and biases manually. However, there are far too many numbers in those weights and biases, so this would be my last resort.

Upvotes: 2

Views: 743

Answers (1)

tg018

Reputation: 61

First of all, you can use the following code to list the names of all tensors (variables) stored in the checkpoint file you saved.

from tensorflow.python.tools import inspect_checkpoint as chkp
# With all_tensor_names=True only the tensor names are printed; tensor_name is ignored in that case.
chkp.print_tensors_in_checkpoint_file(file_name="file.ckpt", tensor_name="", all_tensors=False, all_tensor_names=True)
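If you would rather get the names as Python data than as printed output, tf.train.list_variables returns (name, shape) pairs for a checkpoint. The following is only a minimal sketch, assuming the TF1-style API is reached through tensorflow.compat.v1. The tiny two-variable graph, the scope names and the demo.ckpt path are hypothetical stand-ins so that the snippet runs on its own.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # TF1-style graphs and sessions

# Hypothetical stand-in model: one variable per scope, saved to a temp checkpoint.
with tf.Graph().as_default():
    with tf.variable_scope("net_1_kernel"):
        tf.get_variable("w", shape=[3, 3, 1, 4])
    with tf.variable_scope("other_layers"):
        tf.get_variable("w", shape=[4, 2])
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt_path = saver.save(sess, os.path.join(tempfile.mkdtemp(), "demo.ckpt"))

# list_variables returns (name, shape) pairs for everything stored in the checkpoint.
ckpt_vars = dict(tf.train.list_variables(ckpt_path))
kernel_names = sorted(n for n in ckpt_vars if n.startswith("net_1_kernel/"))
print(kernel_names)
```

The names found this way are the ones a scope-based filter has to match when you pick the variables to restore.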

Remember that a default tf.train.Saver() handles every variable in the current graph, so restoring with it expects all of those variables to be present in the checkpoint file. If you want to save and restore only specific variables, you can do so as follows:

  1. Make a list of all variables you want to save from tf.trainable_variables()

var = [v for v in tf.trainable_variables() if "net_1_kernel" in v.name]

saverAndRestore = tf.train.Saver(var)

  2. Now you can easily save or restore all the variables in the var list as follows:

saverAndRestore.save(sess_1,"net_1.ckpt")

This will only save variables in the list var to net_1.ckpt.

saverAndRestore.restore(sess_1,"net_1.ckpt")

This will only restore variables in the list var from net_1.ckpt.

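Apart from the above, all you have to do is name or scope your variables (for example with tf.variable_scope("net_1_kernel")) so that the filter in step 1 selects exactly the layers you want. When training net_2, run the variable initializer first and call restore afterwards, so that the new layers keep their fresh initialization while the net_1_kernel variables are overwritten with the saved values.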
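Put together, the whole workflow might look like the sketch below. It is not the asker's actual model: it assumes the TF1-style API via tensorflow.compat.v1, replaces each group of layers with a single small convolution, and uses tiny 8x8 and 16x16 inputs in place of 256x256 and 512x512. The helpers conv, other_layers and kernel_vars and the scope name net_2_extra are hypothetical names introduced only for illustration.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # TF1-style graphs and sessions


def conv(x, scope, in_ch, out_ch):
    # One 3x3 conv + ReLU; stands in for a stack of conv2d layers.
    with tf.variable_scope(scope):
        w = tf.get_variable("w", shape=[3, 3, in_ch, out_ch])
        b = tf.get_variable("b", shape=[out_ch], initializer=tf.zeros_initializer())
        return tf.nn.relu(tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME") + b)


def other_layers(x, in_ch):
    # Global average pool + dense; stands in for rmspool -> flatten -> dense.
    with tf.variable_scope("other_layers"):
        w = tf.get_variable("w", shape=[in_ch, 2])
        return tf.matmul(tf.reduce_mean(x, axis=[1, 2]), w)


def kernel_vars():
    # Select only the variables created under the net_1_kernel scope.
    return [v for v in tf.trainable_variables() if v.name.startswith("net_1_kernel/")]


ckpt = os.path.join(tempfile.mkdtemp(), "net_1.ckpt")

# net_1: input -> net_1_kernel -> other_layers
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 8, 8, 1])
    logits = other_layers(conv(x, "net_1_kernel", 1, 4), 4)
    saver = tf.train.Saver(kernel_vars())  # saves the kernel variables only
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # ... training of net_1 would happen here ...
        saved = {v.name: sess.run(v) for v in kernel_vars()}
        saver.save(sess, ckpt)

# net_2: input -> net_1_kernel -> maxpool -> new conv -> other_layers
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 16, 16, 1])
    h = conv(x, "net_1_kernel", 1, 4)
    h = tf.nn.max_pool(h, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
    logits = other_layers(conv(h, "net_2_extra", 4, 4), 4)
    restorer = tf.train.Saver(kernel_vars())
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())  # fresh values for every variable
        restorer.restore(sess, ckpt)  # then overwrite net_1_kernel only
        restored = {v.name: sess.run(v) for v in kernel_vars()}
        out = sess.run(logits, {x: np.zeros([1, 16, 16, 1], np.float32)})
```

Because convolution weights do not depend on the spatial size of the input, the same net_1_kernel variables can be reused even though net_2 takes larger images. The restore works because both graphs create those variables under the same scope, so their names in the checkpoint match.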

Upvotes: 2
