user270700

Reputation: 759

How to translate or shift batches of Tensors randomly in Tensorflow

I want to shift my input images (tensors) up/down or left/right randomly in every batch.

For example, I have a batch of grayscale images with shape [10, 48, 64, 1].

For a single image, I know I can use tf.pad and tf.slice (or other built-in functions).

But I want to apply a random shift to each of the 10 images with a single operation.

Is this possible, or should I use a loop such as tf.scan?
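For reference, the single-image pad-and-slice shift mentioned above can be sketched as follows. It uses NumPy only so it runs without a TensorFlow install; `shift_image` is a hypothetical helper name, and tf.pad/tf.slice would follow the same index arithmetic.

```python
import numpy as np


def shift_image(image, dx, dy):
    """Shift an H x W x C image by (dx, dy) pixels, filling the gap with zeros.

    Positive dx moves content right, positive dy moves content down.
    """
    h, w = image.shape[:2]
    pad = max(abs(dx), abs(dy))
    # Zero-pad all four sides so any shift up to `pad` pixels stays in bounds.
    padded = np.pad(image, ((pad, pad), (pad, pad), (0, 0)))
    # Take an H x W window whose top-left corner is offset by (-dy, -dx),
    # so output[r, c] == image[r - dy, c - dx] (or 0 outside the image).
    top, left = pad - dy, pad - dx
    return padded[top:top + h, left:left + w]


img = np.arange(9).reshape(3, 3, 1)
print(shift_image(img, 1, 0)[..., 0])  # each row moves one pixel to the right
```

Doing this for a whole batch with different (dx, dy) per image is exactly the part that needs either a loop or a batched op.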

Upvotes: 0

Views: 3092

Answers (2)

b3nk4n

Reputation: 1131

As an alternative, you could use tf.contrib.image.transform() and set its parameters a2 and b2 (the translation terms of each 8-parameter transform row) to translate the image:

import numpy as np
import tensorflow as tf

image1 = np.array([[[.1], [.1], [.1], [.1]],
                  [[.2], [.2], [.2], [.2]],
                  [[.3], [.3], [.3], [.3]],
                  [[.4], [.4], [.4], [.4]]])
image2 = np.array([[[.1], [.2], [.3], [.4]],
                  [[.1], [.2], [.3], [.4]],
                  [[.1], [.2], [.3], [.4]],
                  [[.1], [.2], [.3], [.4]]])
images = np.stack([image1, image2])
images_ = tf.convert_to_tensor(images, dtype=tf.float32)

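# Shifts in pixels: positive x moves content right, positive y moves it down.
shift1_x = 1
shift1_y = 2
shift2_x = -1
shift2_y = 0
# One row [a0, a1, a2, b0, b1, b2, c0, c1] per image; each row maps an output
# pixel (x, y) to the input pixel (x + a2, y + b2), hence the minus signs.
transforms_ = tf.convert_to_tensor([[1, 0, -shift1_x, 0, 1, -shift1_y, 0, 0],
                                   [1, 0, -shift2_x, 0, 1, -shift2_y, 0, 0]],
                                   tf.float32)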
shifted_ = tf.contrib.image.transform(images=images_,
                                      transforms=transforms_)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    shifted = sess.run([shifted_])
    print(shifted)

The transforms argument can be either a single row of 8 parameters or an N x 8 tensor with one row per image, so every image in a batch can be shifted differently. Combined with tf.random_uniform(), this makes it easy to randomize the x/y shift of each image.
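To see what the N x 8 transforms do, here is a NumPy reference sketch with nearest-neighbour sampling. It assumes the documented convention that a row [a0, a1, a2, b0, b1, b2, c0, c1] maps an output pixel (x, y) to the input point ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k) with k = c0 x + c1 y + 1, and it assumes out-of-range samples are filled with zeros. `project_batch` is a hypothetical name; it loops over images only for readability, not as a suggestion for the TensorFlow version.

```python
import numpy as np


def project_batch(images, transforms):
    """Nearest-neighbour sketch of a per-image projective transform.

    images: N x H x W x C array; transforms: N x 8 array, one row per image.
    Each row maps an *output* pixel (x, y) to the *input* pixel it samples;
    samples that fall outside the input are left at 0.
    """
    n, h, w, _ = images.shape
    ys, xs = np.mgrid[0:h, 0:w]  # output pixel grid (rows, columns)
    out = np.zeros_like(images)
    for i in range(n):
        a0, a1, a2, b0, b1, b2, c0, c1 = transforms[i]
        k = c0 * xs + c1 * ys + 1.0
        src_x = np.round((a0 * xs + a1 * ys + a2) / k).astype(int)
        src_y = np.round((b0 * xs + b1 * ys + b2) / k).astype(int)
        valid = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
        out[i][valid] = images[i][src_y[valid], src_x[valid]]
    return out


# Same two 4 x 4 test images as in the answer: image1 varies by row, image2 by column.
values = np.array([.1, .2, .3, .4])
image1 = np.tile(values[:, None, None], (1, 4, 1))
image2 = np.tile(values[None, :, None], (4, 1, 1))
images = np.stack([image1, image2])

shifts = [(1, 2), (-1, 0)]  # (shift_x, shift_y) per image
transforms = np.array([[1, 0, -sx, 0, 1, -sy, 0, 0] for sx, sy in shifts], dtype=float)
shifted = project_batch(images, transforms)
print(shifted[..., 0])
```

With these rows, image1 moves 1 pixel right and 2 pixels down, and image2 moves 1 pixel left, which is what the TensorFlow example should produce as well.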

Edit: To use random shifts for every image of the batch:

...
images_ = tf.convert_to_tensor(images, dtype=tf.float32)

num_imgs = images.shape[0]
# Identity transform, one row per image.
base_ = tf.convert_to_tensor(np.tile([1, 0, 0, 0, 1, 0, 0, 0], [num_imgs, 1]), dtype=tf.float32)
# Mask that keeps only the translation entries a2 and b2.
mask_ = tf.convert_to_tensor(np.tile([0, 0, 1, 0, 0, 1, 0, 0], [num_imgs, 1]), dtype=tf.float32)
# Uniform offsets in [-2.49, 2.49) pixels; the mask zeroes all other entries.
random_shift_ = tf.random_uniform([num_imgs, 8], minval=-2.49, maxval=2.49, dtype=tf.float32)
transforms_ = base_ + random_shift_ * mask_

shifted_ = tf.contrib.image.transform(images=images_,
                                      transforms=transforms_)
...
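The base_ + random_shift_ * mask_ trick can be checked in isolation. This NumPy sketch, with the hypothetical helper name `random_shift_transforms`, builds the same kind of N x 8 array and shows that only columns 2 and 5 (a2 and b2) end up random:

```python
import numpy as np


def random_shift_transforms(num_imgs, max_shift, rng=None):
    """NumPy sketch of the base_ + random_shift_ * mask_ construction."""
    rng = np.random.default_rng() if rng is None else rng
    # Identity transform [1, 0, 0, 0, 1, 0, 0, 0], repeated once per image.
    base = np.tile([1., 0, 0, 0, 1, 0, 0, 0], (num_imgs, 1))
    # Mask keeping only the translation entries a2 (index 2) and b2 (index 5).
    mask = np.tile([0., 0, 1, 0, 0, 1, 0, 0], (num_imgs, 1))
    # Uniform draws in [-max_shift, max_shift); six of the eight per row are masked away.
    random_shift = rng.uniform(-max_shift, max_shift, size=(num_imgs, 8))
    return base + random_shift * mask


print(random_shift_transforms(3, 2.49, np.random.default_rng(0)))
```

Drawing 8 values per row and masking 6 of them is wasteful but keeps the code short. If the sampled offsets are rounded to the nearest pixel (as with nearest-neighbour sampling), a bound of 2.49 keeps the effective integer shift within [-2, 2].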

Edit 2: For the sake of completeness, here is another helper function that applies a random rotation and shift to each image of a batch:

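import math

import tensorflow as tf


def augment_data(input_data, angle, shift):
    """Randomly rotate (up to +/- angle degrees) and shift (up to +/- shift pixels) each image."""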
    num_images_ = tf.shape(input_data)[0]
    # random rotate
    processed_data = tf.contrib.image.rotate(input_data,
                                             tf.random_uniform([num_images_],
                                                               maxval=math.pi / 180 * angle,
                                                               minval=math.pi / 180 * -angle))
    # random shift
    base_row = tf.constant([1, 0, 0, 0, 1, 0, 0, 0], shape=[1, 8], dtype=tf.float32)
    base_ = tf.tile(base_row, [num_images_, 1])
    mask_row = tf.constant([0, 0, 1, 0, 0, 1, 0, 0], shape=[1, 8], dtype=tf.float32)
    mask_ = tf.tile(mask_row, [num_images_, 1])
    random_shift_ = tf.random_uniform([num_images_, 8], minval=-shift, maxval=shift, dtype=tf.float32)
    transforms_ = base_ + random_shift_ * mask_

    processed_data = tf.contrib.image.transform(images=processed_data,
                                                transforms=transforms_)
    return processed_data
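augment_data resamples each image twice (once in rotate, once in transform). Assuming both steps use the same output-to-input convention as the 8-parameter rows above, the two could in principle be folded into one transform by multiplying their 3 x 3 homogeneous matrices. This NumPy sketch shows only that matrix arithmetic; `to_matrix`, `to_params` and `compose` are hypothetical helper names, and turning the sampled angles into rotation rows is not covered here.

```python
import numpy as np


def to_matrix(t):
    """8-parameter row [a0, a1, a2, b0, b1, b2, c0, c1] -> 3 x 3 homogeneous matrix."""
    return np.append(np.asarray(t, dtype=float), 1.0).reshape(3, 3)


def to_params(m):
    """3 x 3 matrix -> 8-parameter row, normalised so the bottom-right entry is 1."""
    return (m / m[2, 2]).reshape(-1)[:8]


def compose(first, second):
    """One transform equivalent to applying `first` to an image, then `second`.

    Each transform maps output points to input points, so sampling second's
    output at p means sampling first's output at M2 p, i.e. the original
    image at M1 @ M2 p.
    """
    return to_params(to_matrix(first) @ to_matrix(second))


def shift(sx, sy):
    """Translation row in the same form as the answer's examples."""
    return [1, 0, -sx, 0, 1, -sy, 0, 0]


print(compose(shift(1, 2), shift(3, -1)))  # two translations add up
```

For pure translations the composition just adds the shifts; for a rotation followed by a shift the product also mixes the shift into the rotated axes, which is why the order of the matrices matters.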

Upvotes: 5

soloice

Reputation: 1040

Are you looking for tf.random_crop and tf.pad?

If you call tf.random_crop on the whole (padded) batch tensor, one random offset is applied to all images in it: the shift is the same within a batch but can differ between batches.
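To make the difference concrete, here is a NumPy sketch contrasting one crop offset for the whole padded batch (the behaviour described above) with an independent offset per image gathered through advanced indexing in a single operation. Both function names are hypothetical, and the per-image version is an alternative shown for comparison, not the approach this answer recommends.

```python
import numpy as np


def crop_batch_same_offset(padded, h, w, rng):
    """Crop an N x H' x W' x C batch with one random offset shared by every image."""
    top = rng.integers(0, padded.shape[1] - h + 1)
    left = rng.integers(0, padded.shape[2] - w + 1)
    return padded[:, top:top + h, left:left + w]


def crop_batch_per_image(padded, h, w, rng):
    """Crop with an independent random offset per image, without a Python loop."""
    n = padded.shape[0]
    tops = rng.integers(0, padded.shape[1] - h + 1, size=n)
    lefts = rng.integers(0, padded.shape[2] - w + 1, size=n)
    rows = tops[:, None] + np.arange(h)[None, :]   # N x h row indices
    cols = lefts[:, None] + np.arange(w)[None, :]  # N x w column indices
    batch_idx = np.arange(n)[:, None, None]
    # Broadcasting (N,1,1), (N,h,1), (N,1,w) gathers an h x w window per image.
    return padded[batch_idx, rows[:, :, None], cols[:, None, :]]
```

Padding a batch by k pixels on each side and cropping back to the original size gives shifts in [-k, k]; with the first function all images move together, with the second each image moves on its own.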

If you want different shifts within a batch, I think it's better to use a queue-based input pipeline; see https://www.tensorflow.org/programmers_guide/reading_data for more. Here's example code from part of my own project, where self.image_names is a Python list containing the paths to all training images.

In an input pipeline, the data flows like a stream: you only need to write the processing for a single image, and the queues take care of scheduling (some threads read the data, some process it, some group single images into batches, others feed the data to the GPU, etc., to keep the whole pipeline busy). In the code below, images and labels come from queues. That is, when you process images (as in self.data_augmentation), you can treat it as a single image, but the queue actually applies the processing to every item (like an implicit loop). tf.train.shuffle_batch then shuffles the training data in the queue and groups it into batches.

def data_augmentation(self, images):
    # Called on one decoded image at a time, so every image gets its own random draw.
    if FLAGS.random_flip_up_down:
        images = tf.image.random_flip_up_down(images)
    if FLAGS.random_brightness:
        images = tf.image.random_brightness(images, max_delta=0.3)
    if FLAGS.random_contrast:
        images = tf.image.random_contrast(images, 0.8, 1.2)
    return images

def input_pipeline(self, batch_size, num_epochs=None, aug=False):
    images_tensor = tf.convert_to_tensor(self.image_names, dtype=tf.string)
    labels_tensor = tf.convert_to_tensor(self.labels, dtype=tf.int64)
    input_queue = tf.train.slice_input_producer([images_tensor, labels_tensor], num_epochs=num_epochs)

    labels = input_queue[1]
    images_content = tf.read_file(input_queue[0])
    images = tf.image.convert_image_dtype(tf.image.decode_png(images_content, channels=1), tf.float32)
    if aug:
        images = self.data_augmentation(images)
    new_size = tf.constant([FLAGS.image_size, FLAGS.image_size], dtype=tf.int32)
    images = tf.image.resize_images(images, new_size)
    image_batch, label_batch = tf.train.shuffle_batch([images, labels], batch_size=batch_size, capacity=50000,
                                                      min_after_dequeue=10000)
    # print('image_batch', image_batch.get_shape())
    return image_batch, label_batch
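The data_augmentation function above only flips and adjusts brightness/contrast; to get the random shift the question asks about, the per-image step would presumably be the pad + random-crop combination from the first line of this answer. The per-example-then-batch structure of the queue pipeline can be sketched without TensorFlow as a plain Python generator (hypothetical names `random_shift` and `batches`; shuffling and threading are omitted):

```python
import numpy as np


def random_shift(image, max_shift, rng):
    """Pad by max_shift on every side, then crop back to H x W at a random offset."""
    h, w = image.shape[:2]
    padded = np.pad(image, ((max_shift, max_shift), (max_shift, max_shift), (0, 0)))
    top = rng.integers(0, 2 * max_shift + 1)
    left = rng.integers(0, 2 * max_shift + 1)
    return padded[top:top + h, left:left + w]


def batches(images, batch_size, max_shift, seed=0):
    """Augment one example at a time, then group the results into batches.

    Because the shift is drawn per example before batching, every image in
    a batch gets its own random offset.
    """
    rng = np.random.default_rng(seed)
    buffer = []
    for image in images:
        buffer.append(random_shift(image, max_shift, rng))
        if len(buffer) == batch_size:
            yield np.stack(buffer)
            buffer = []
    if buffer:  # final, possibly smaller batch
        yield np.stack(buffer)


data = np.ones((10, 48, 64, 1))  # shape from the question
for batch in batches(data, batch_size=4, max_shift=2):
    print(batch.shape)
```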

Upvotes: 1
