Reputation: 2897
I am trying to retrain the last layer of inception-resnet-v2. Here's what I came up with:
1. Get the variables of the final layer to retrain.
2. Create a train_op to minimise only these variables wrt the loss.
3. Restore all layers except the final one from the pretrained checkpoint.
And I implemented that as follows:
with slim.arg_scope(arg_scope):
    logits = model(images_ph, is_training=True, reuse=None)

loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels_ph))
accuracy = tf.contrib.metrics.accuracy(tf.argmax(logits, 1), labels_ph)

train_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'InceptionResnetV2/Logits')
optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
train_op = optimizer.minimize(loss, var_list=train_list)

# restore all variables whose names don't contain 'Logits'
restore_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='^((?!Logits).)*$')
saver = tf.train.Saver(restore_list, write_version=tf.train.SaverDef.V2)

with tf.Session() as session:
    init_op = tf.group(tf.local_variables_initializer(), tf.global_variables_initializer())
    session.run(init_op)
    saver.restore(session, '../models/inception_resnet_v2_2016_08_30.ckpt')
    # followed by code for running train_op
This doesn't seem to work (training loss and error don't improve much from their initial values). Is there a better or more elegant way to do this? It would be good learning for me if you could also tell me what's going wrong here.
Upvotes: 9
Views: 2422
Reputation: 1143
There are several things to check:
1. Is the arg_scope set correctly? If you don't use the L2 regularization and batch normalization that the scope provides, the network can fall into a local minimum very early and be unable to learn:
import tensorflow.contrib.slim as slim
from nets import inception_resnet_v2 as net

# inception_resnet_v2_arg_scope() returns an arg_scope dict; it must be
# passed to slim.arg_scope rather than used as a context manager directly
with slim.arg_scope(net.inception_resnet_v2_arg_scope()):
    logits, end_points = net.inception_resnet_v2(images_ph, num_classes=num_classes,
                                                 is_training=True)
2. You should add the regularization losses to your loss (or at least the ones of the last layer):
regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
all_losses = [loss] + regularization_losses
total_loss = tf.add_n(all_losses, name='total_loss')
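As a side note, a minimal sketch of an alternative: if you register the base loss through the tf.losses module instead of computing it by hand, tf.losses.get_total_loss can build the same sum for you (this assumes your cross-entropy is created via tf.losses so it lands in the losses collection):

# register the base loss in the losses collection, then let TF sum it with
# all regularization losses in one call
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels_ph, logits=logits)
total_loss = tf.losses.get_total_loss(add_regularization_losses=True)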
3. Training only the fully connected layer might not be a good idea. I would train the whole network, as the features you need for your classes aren't necessarily defined in the last layer but a few layers before, and you need to change them; see the sketch below.
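A minimal sketch, assuming the same graph as in the question (the learning rate of 1e-4 is an illustrative choice, not a prescribed value):

# fine-tune every trainable variable instead of only the Logits scope,
# typically with a smaller learning rate than when training from scratch
optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
train_op = optimizer.minimize(total_loss)  # no var_list: updates all trainable variables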
4. Double check that the train_op runs after the loss:

from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops

with ops.name_scope('train_op'):
    train_op = control_flow_ops.with_dependencies([train_op], total_loss)
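Alternatively, a sketch using slim's helper, which wires this dependency for you and returns an op that evaluates to total_loss when run (variables_to_train restricts the update to the last layer, as in the question):

# equivalent wiring via slim.learning.create_train_op
train_op = slim.learning.create_train_op(total_loss, optimizer,
                                         variables_to_train=train_list)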
Upvotes: 1