Ricardo Cruz

Reputation: 3593

TensorFlow: slicing tensor with a list using placeholders

I have a batch of images and would like to extract patches at different positions that are specified at session run time. The patch size is always the same.

If I wanted to extract the same patch from every image, I could of course just use tf.slice(images, [0, py, px, 0], [-1, size, size, 3]).

[Image: slicing the same position in every image]

But I want to slice at different positions, so I would like px and py to be vectors.

[Image: slicing a different position in each image]

In NumPy, I am not sure how to do this without a loop. I would have done something like this:

result = np.array([image[y:y+size, x:x+size] for image, x, y in zip(images, px, py)])

Inspired by that, the TensorFlow solution I came up with also re-implements tf.slice with a loop, so that begin becomes begin_vector (one begin position per image):

import numpy as np
import tensorflow as tf

def my_slice(input_, begin_vector, size):
    def condition(i, _):
        return tf.less(i, tf.shape(input_)[0])
    def body(i, r):
        sliced = tf.slice(input_[i], begin_vector[i], size)
        sliced = tf.expand_dims(sliced, 0)
        return i+1, tf.concat((r, sliced), 0)

    i = tf.constant(0)
    empty_result = tf.zeros((0, *size), tf.float32)
    loop = tf.while_loop(
        condition, body, [i, empty_result],
        [i.get_shape(), tf.TensorShape([None, *size])])
    return loop[1]

Then, I can just run this using my positions vector, here called ix:

sess = tf.Session()
images = tf.placeholder(tf.float32, (None, 256, 256, 1))
ix = tf.placeholder(tf.int32, (None, 3))
res = sess.run(
  my_slice(images, ix, [10, 10, 1]),
  {images: np.random.uniform(size=(2, 256, 256, 1)), ix: [[40, 80, 0], [20, 10, 0]]})
print(res.shape)

I just wonder if there is a prettier way to do this.

PS: I am aware that people have asked similar questions, for example "Slicing tensor with list - TensorFlow". But note that I want to do the slicing using placeholders, so none of the solutions I have seen work for me. Everything needs to be dynamic during training: the slice positions must come from placeholders, I cannot use a Python for loop, and I do not want to turn on eager execution.

Upvotes: 3

Views: 398

Answers (1)

javidcf

Reputation: 59731

Here is a function to do that without a loop:

import tensorflow as tf

def extract_patches(images, px, py, w, h):
    s = tf.shape(images)
    ii, yy, xx = tf.meshgrid(tf.range(s[0]), tf.range(h), tf.range(w), indexing='ij')
    xx2 = xx + px[:, tf.newaxis, tf.newaxis]
    yy2 = yy + py[:, tf.newaxis, tf.newaxis]
    # Optional: ensure indices do not go out of bounds
    xx2 = tf.clip_by_value(xx2, 0, s[2] - 1)
    yy2 = tf.clip_by_value(yy2, 0, s[1] - 1)
    idx = tf.stack([ii, yy2, xx2], axis=-1)
    return tf.gather_nd(images, idx)
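A note on the optional clipping step: because the indices are clamped, a patch that overruns the image repeats the edge pixels instead of raising an error or padding with zeros. The following is a minimal NumPy sketch of that effect on a single image, not TensorFlow code; the values of `px`, `py`, `w` and `h` are chosen only for illustration.

```python
import numpy as np

# A 3x4 single-channel image holding the values 0..11.
image = np.arange(12).reshape(3, 4)

# Illustrative values: the patch would need columns 2..4, but only 0..3 exist.
px, py, w, h = 2, 1, 3, 2

# Same clamping as tf.clip_by_value in extract_patches.
xx = np.clip(px + np.arange(w), 0, image.shape[1] - 1)
yy = np.clip(py + np.arange(h), 0, image.shape[0] - 1)

# Gather the clamped rows and columns; the last column appears twice.
patch = image[yy[:, None], xx[None, :]]
print(xx)
print(patch)
```

If repeated edge pixels are not acceptable, the positions would have to be validated before they are fed in.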

Here is an example that runs extract_patches with placeholders:

import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    # Works for images with any size and number of channels
    images = tf.placeholder(tf.float32, (None, None, None, None))
    patch_xy = tf.placeholder(tf.int32, (None, 2))
    patch_size = tf.placeholder(tf.int32, (2,))
    px = patch_xy[:, 0]
    py = patch_xy[:, 1]
    w = patch_size[0]
    h = patch_size[1]
    patches = extract_patches(images, px, py, w, h)
    test = sess.run(patches, {
        images: [
            # Image 0
            [[[ 0], [ 1], [ 2], [ 3]],
             [[ 4], [ 5], [ 6], [ 7]],
             [[ 8], [ 9], [10], [11]]],
            # Image 1
            [[[50], [51], [52], [53]],
             [[54], [55], [56], [57]],
             [[58], [59], [60], [61]]]
        ],
        patch_xy: [[1, 0],
                   [0, 1]],
        patch_size: [3, 2]})
    print(test[..., 0])
    # [[[ 1.  2.  3.]
    #   [ 5.  6.  7.]]
    #
    #  [[54. 55. 56.]
    #   [58. 59. 60.]]]
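To check the indexing logic without a session, the same idea can be sketched in NumPy and compared against the loop from the question. This is only an illustration under assumptions: `extract_patches_np` is a hypothetical name, it uses broadcasting of three index arrays in place of `tf.meshgrid` and `tf.gather_nd`, and it assumes images laid out as (batch, height, width, channels).

```python
import numpy as np

def extract_patches_np(images, px, py, w, h):
    # Hypothetical NumPy counterpart of extract_patches, for illustration only.
    images = np.asarray(images)
    px = np.asarray(px)
    py = np.asarray(py)
    n, height, width = images.shape[:3]
    # Per-image row and column indices, clipped to the image bounds.
    yy = np.clip(py[:, None] + np.arange(h), 0, height - 1)  # shape (n, h)
    xx = np.clip(px[:, None] + np.arange(w), 0, width - 1)   # shape (n, w)
    ii = np.arange(n)[:, None, None]                         # shape (n, 1, 1)
    # The three index arrays broadcast to (n, h, w); channels are kept whole.
    return images[ii, yy[:, :, None], xx[:, None, :]]

rng = np.random.default_rng(0)
images = rng.uniform(size=(2, 256, 256, 1))
px = np.array([40, 20])
py = np.array([80, 10])

patches = extract_patches_np(images, px, py, 10, 10)
# Reference result using the plain loop from the question.
looped = np.array([im[y:y + 10, x:x + 10] for im, x, y in zip(images, px, py)])

print(patches.shape)
print(np.array_equal(patches, looped))
```

The two results agree here because none of the patches touches the image border, so the clipping never changes an index.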

Upvotes: 2
