random40154443

Reputation: 1140

Batch training in Tensorflow Slim

I am looking at the TF Slim introductory document, and from what I understand it only takes in one batch of image data at each run (32 images). Obviously one wants to loop over this and train on many different batches; the intro does not cover that. How can this be done properly? I imagine there should be some way to specify a batch-loading function that gets called automatically when a training step starts, but I can't find a simple example of this in the intro.
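(For reference, load_batch in the snippet below is a helper defined earlier in the walkthrough notebook rather than in the intro itself. A minimal sketch of such a helper, assuming slim's DatasetDataProvider and tf.train.batch APIs, looks roughly like this:)

def load_batch(dataset, batch_size=32, height=299, width=299, is_training=False):
    # A DatasetDataProvider reads examples from the dataset's TFRecords
    # and feeds them through a queue.
    provider = slim.dataset_data_provider.DatasetDataProvider(
        dataset, common_queue_capacity=32, common_queue_min=8)
    image_raw, label = provider.get(['image', 'label'])

    # Preprocess the image into the form the network expects.
    image = inception_preprocessing.preprocess_image(
        image_raw, height, width, is_training=is_training)

    # Keep a resized copy of the raw image for visualization; this is why
    # the training snippet unpacks three return values.
    image_raw = tf.expand_dims(image_raw, 0)
    image_raw = tf.image.resize_images(image_raw, [height, width])
    image_raw = tf.squeeze(image_raw)

    # tf.train.batch groups single examples into batches of batch_size.
    # Queue runners (started by slim.learning.train) keep the queue full,
    # so every run of the train op dequeues a fresh batch.
    images, images_raw, labels = tf.train.batch(
        [image, image_raw, label],
        batch_size=batch_size, num_threads=1, capacity=2 * batch_size)
    return images, images_raw, labels

Here is the training snippet from the intro: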

# Note that this may take several minutes.

import os
import tensorflow as tf

from datasets import flowers
from nets import inception
from preprocessing import inception_preprocessing

slim = tf.contrib.slim
image_size = inception.inception_v1.default_image_size


def get_init_fn():
    """Returns a function run by the chief worker to warm-start the training."""
    checkpoint_exclude_scopes=["InceptionV1/Logits", "InceptionV1/AuxLogits"]

    exclusions = [scope.strip() for scope in checkpoint_exclude_scopes]

    variables_to_restore = []
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        if not excluded:
            variables_to_restore.append(var)

    # checkpoints_dir is assumed to be defined earlier in the notebook
    # (the directory the inception_v1 checkpoint was downloaded to).
    return slim.assign_from_checkpoint_fn(
      os.path.join(checkpoints_dir, 'inception_v1.ckpt'),
      variables_to_restore)


train_dir = '/tmp/inception_finetuned/'

with tf.Graph().as_default():
    tf.logging.set_verbosity(tf.logging.INFO)

    # flowers_data_dir is assumed to be defined earlier in the notebook
    # (the directory the flowers TFRecord data was downloaded to).
    dataset = flowers.get_split('train', flowers_data_dir)
    images, _, labels = load_batch(dataset, height=image_size, width=image_size)

    # Create the model, use the default arg scope to configure the batch norm parameters.
    with slim.arg_scope(inception.inception_v1_arg_scope()):
        logits, _ = inception.inception_v1(images, num_classes=dataset.num_classes, is_training=True)

    # Specify the loss function:
    one_hot_labels = slim.one_hot_encoding(labels, dataset.num_classes)
    slim.losses.softmax_cross_entropy(logits, one_hot_labels)
    total_loss = slim.losses.get_total_loss()

    # Create some summaries to visualize the training process:
    tf.scalar_summary('losses/Total Loss', total_loss)

    # Specify the optimizer and create the train op:
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
    train_op = slim.learning.create_train_op(total_loss, optimizer)

    # Run the training:
    final_loss = slim.learning.train(
        train_op,
        logdir=train_dir,
        init_fn=get_init_fn(),
        number_of_steps=2)


print('Finished training. Last batch loss %f' % final_loss)

Upvotes: 2

Views: 3451

Answers (1)

T.K. Bartel

Reputation: 1365

The slim.learning.train function contains a training loop, so the code you've given does in fact train on multiple batches of images.

See here in the source code, where train_step_fn is called within a while loop. train_step (the default value of train_step_fn) contains the line sess.run([train_op, global_step]...), which actually runs the training operation on a single batch of images.
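In other words, to train on more batches you just raise number_of_steps; each step dequeues a fresh batch from the queue built by load_batch. Conceptually, the loop inside slim.learning.train does something like this (a simplified sketch, not the actual source):

while not should_stop:
    # train_step runs sess.run([train_op, global_step]), training on one batch.
    total_loss, should_stop = train_step_fn(
        sess, train_op, global_step, train_step_kwargs)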

Upvotes: 1
