Jame

Reputation: 3854

variables_to_train flag in Tf-slim

I am fine-tuning a model from a pretrained checkpoint using TF-Slim. When I used create_train_op, I noticed it has a parameter called variables_to_train. One tutorial uses the flag as follows:

   all_trainable = [v for v in tf.trainable_variables()]
   trainable     = [v for v in all_trainable]
   train_op      = slim.learning.create_train_op(
        opt,
        global_step=global_step,
        variables_to_train=trainable,
        summarize_gradients=True)

But the official TF-Slim examples do not use it:

   train_op = slim.learning.create_train_op(
        opt,
        global_step=global_step,
        summarize_gradients=True)

So, what is the difference between using and not using variables_to_train?

Upvotes: 1

Views: 388

Answers (1)

mckay

Reputation: 111

Your two examples do the same thing: both train every trainable variable in the graph. With the variables_to_train parameter you can restrict which variables are updated during training.

A use case for this is when your model contains pre-trained components, such as word embeddings, that you don't want to update further. With

train_vars = [v for v in tf.trainable_variables() if "embeddings" not in v.name]
train_op      = slim.learning.create_train_op(
    opt,
    global_step=global_step,
    variables_to_train=train_vars,
    summarize_gradients=True)

you exclude every variable whose name contains "embeddings" from training. If you simply want to train all variables, you don't need to define train_vars and can create the train op without the variables_to_train parameter.
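To see the effect, here is a minimal sketch using core TensorFlow's optimizer var_list argument, which implements the same selective-update idea as variables_to_train (the variable names and shapes are made up for illustration; written against the TF 1.x-style API via tf.compat.v1):

```python
import numpy as np
import tensorflow as tf

# TF 1.x-style graph execution, as used by TF-Slim.
tf.compat.v1.disable_eager_execution()

g = tf.Graph()
with g.as_default():
    # A "pre-trained" embedding table and an ordinary dense weight.
    embeddings = tf.compat.v1.get_variable("embeddings", initializer=tf.ones([3, 2]))
    dense_w = tf.compat.v1.get_variable("dense_w", initializer=tf.ones([2, 2]))

    # A toy loss that depends on both variables.
    loss = tf.reduce_sum(embeddings) + tf.reduce_sum(dense_w)

    # Same name-based filter as in the answer above.
    train_vars = [v for v in tf.compat.v1.trainable_variables()
                  if "embeddings" not in v.name]

    # var_list plays the role of variables_to_train here.
    opt = tf.compat.v1.train.GradientDescentOptimizer(0.1)
    train_step = opt.minimize(loss, var_list=train_vars)

    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        sess.run(train_step)
        emb, w = sess.run([embeddings, dense_w])

# The embeddings receive a nonzero gradient but are excluded from
# the update, so they stay at their initial value; dense_w moves.
print(emb)  # still all ones
print(w)    # updated by one gradient step
```

After one step, emb is unchanged at 1.0 everywhere while w has been reduced to 0.9, which is exactly what excluding a variable via variables_to_train does in create_train_op.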

Upvotes: 2
