Priyatham

Reputation: 2897

Retraining the last layer of Inception-ResNet-v2

I am trying to retrain the last layer of Inception-ResNet-v2. Here's what I came up with:

  1. Get the names of the variables in the final layer
  2. Create a train_op that minimises the loss with respect to these variables only
  3. Restore the whole graph except the final layer, initialising only the last layer randomly

And I implemented that as follows:

with slim.arg_scope(arg_scope):
    logits = model(images_ph, is_training=True, reuse=None)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels_ph, logits=logits))
accuracy = tf.contrib.metrics.accuracy(tf.argmax(logits, 1), labels_ph)

train_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'InceptionResnetV2/Logits')
optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)

train_op = optimizer.minimize(loss, var_list=train_list)

# restore all variables whose names don't contain 'Logits'
restore_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='^((?!Logits).)*$')

saver = tf.train.Saver(restore_list, write_version=tf.train.SaverDef.V2)

with tf.Session() as session:
    init_op = tf.group(tf.local_variables_initializer(), tf.global_variables_initializer())
    session.run(init_op)
    saver.restore(session, '../models/inception_resnet_v2_2016_08_30.ckpt')


# followed by code for running train_op

This doesn't seem to work (the training loss and error don't improve much from their initial values). Is there a better/more elegant way to do this? It would also be good learning for me if you could tell me what's going wrong here.

Upvotes: 9

Views: 2422

Answers (1)

jorgemf

Reputation: 1143

There are several things:

  • What is the learning rate? A value that is too high can mess everything up (probably not the reason here).
  • Try using stochastic gradient descent; you should have fewer problems.
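
    For instance, a minimal sketch reusing the names from the question (the 0.01 learning rate is only an illustrative value):

    # plain SGD on the last-layer variables only
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)
    train_op = optimizer.minimize(loss, var_list=train_list)
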
  • Is the scope correctly set? If you don't use the L2 regularization and batch normalization that the arg_scope provides, you might fall into a local minimum very soon and the network is unable to learn:

    import tensorflow.contrib.slim as slim
    from nets import inception_resnet_v2 as net

    # the scope returned by inception_resnet_v2_arg_scope() must be entered via slim.arg_scope
    with slim.arg_scope(net.inception_resnet_v2_arg_scope()):
        logits, end_points = net.inception_resnet_v2(images_ph, num_classes=num_classes,
                                                     is_training=True)
    
  • You should add the regularization losses to the loss (or at least those of the last layer):

    regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    all_losses = [loss] + regularization_losses
    total_loss = tf.add_n(all_losses, name='total_loss')
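
    The train_op would then be built from total_loss rather than the plain cross-entropy loss, for example (reusing the optimizer and train_list from the question):

    # minimise the regularized loss instead of the raw cross-entropy
    train_op = optimizer.minimize(total_loss, var_list=train_list)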
    
  • Training only the fully connected layer might not be a good idea; I would train the whole network, since the features you need for your classes aren't necessarily defined in the last layer but a few layers before, and you need to change them.
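
    A minimal sketch of that, again reusing the names from the question (dropping var_list so every trainable variable is updated):

    # fine-tune all trainable variables, not just the new logits layer
    train_op = optimizer.minimize(total_loss)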

  • Double-check that the train_op runs after the loss:

    from tensorflow.python.framework import ops
    from tensorflow.python.ops import control_flow_ops

    with ops.name_scope('train_op'):
        train_op = control_flow_ops.with_dependencies([train_op], total_loss)
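
    (Fetching this train_op then returns the value of total_loss, but only after the actual update step has run.)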
    

Upvotes: 1
