I have a batch of images and would like to extract patches at different positions, which are only specified when the session runs. The patch size is always the same.
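If I wanted to extract the same patch from all images, I could simply use tf.slice(images, [0, py, px, 0], [-1, size, size, -1]). The begin offsets are [batch, row, column, channel], and a size of -1 takes everything along that axis.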
But I want to slice each image at a different position, so I would like px and py to be vectors.
In NumPy, I am not sure how to do this without a loop. I would have done something like this:
result = np.array([image[y:y+size, x:x+size] for image, x, y in zip(images, px, py)])
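For reference, NumPy can avoid the loop with broadcast advanced indexing: build a row-index grid and a column-index grid for every image and index the whole batch at once. This is a minimal sketch, assuming images has shape (N, H, W, C), px and py are integer arrays of length N, and every patch lies fully inside its image; the name extract_patches_np is hypothetical.

```python
import numpy as np

def extract_patches_np(images, px, py, size):
    """Vectorized equivalent of
    np.array([im[y:y+size, x:x+size] for im, x, y in zip(images, px, py)])."""
    px = np.asarray(px)
    py = np.asarray(py)
    offsets = np.arange(size)
    # Row indices with shape (N, size, 1), column indices with shape (N, 1, size).
    rows = (py[:, None] + offsets)[:, :, None]
    cols = (px[:, None] + offsets)[:, None, :]
    # Batch index with shape (N, 1, 1); broadcasts against rows and cols.
    batch = np.arange(len(images))[:, None, None]
    # Advanced indexing on the first three axes gives shape (N, size, size, C).
    return images[batch, rows, cols]

rng = np.random.default_rng(0)
images = rng.uniform(size=(2, 256, 256, 3))
px = np.array([80, 10])
py = np.array([40, 20])
size = 10

looped = np.array([image[y:y+size, x:x+size] for image, x, y in zip(images, px, py)])
vectorized = extract_patches_np(images, px, py, size)
print(vectorized.shape)                    # (2, 10, 10, 3)
print(np.array_equal(looped, vectorized))  # True
```

Element [n, i, j] of the result is images[n, py[n] + i, px[n] + j], which is exactly what the list comprehension computes, so the two arrays match.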
Inspired by that, the TensorFlow solution I came up with was to re-implement tf.slice with a loop (tf.while_loop), so that the begin argument becomes a per-image begin_vector:
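import tensorflow as tf

def my_slice(input_, begin_vector, size):
    # Keep looping while i is smaller than the batch size.
    def condition(i, _):
        return tf.less(i, tf.shape(input_)[0])

    # Slice image i at its own begin offsets and append it to the result.
    def body(i, r):
        sliced = tf.slice(input_[i], begin_vector[i], size)
        sliced = tf.expand_dims(sliced, 0)
        return i + 1, tf.concat((r, sliced), 0)

    i = tf.constant(0)
    # Start from an empty (0, *size) tensor and grow it along axis 0.
    empty_result = tf.zeros((0, *size), tf.float32)
    loop = tf.while_loop(
        condition, body, [i, empty_result],
        # The first dimension of the result changes every iteration, so relax it to None.
        [i.get_shape(), tf.TensorShape([None, *size])])
    return loop[1]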
Then I can run this with my positions vector, here called ix:
import numpy as np

sess = tf.Session()
images = tf.placeholder(tf.float32, (None, 256, 256, 1))
# One [row, col, channel] begin offset per image.
ix = tf.placeholder(tf.int32, (None, 3))
res = sess.run(
    my_slice(images, ix, [10, 10, 1]),
    {images: np.random.uniform(size=(2, 256, 256, 1)), ix: [[40, 80, 0], [20, 10, 0]]})
print(res.shape)  # (2, 10, 10, 1)
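To check what my_slice computes without TensorFlow, here is a loop-based NumPy reference that applies the same per-image begin_vector[i] and size semantics as tf.slice, for inputs of any rank. It is a sketch for testing only: slice_each is a hypothetical name, it does not support tf.slice's size of -1 convention, and because NumPy silently truncates out-of-range slices it checks the bounds explicitly.

```python
import numpy as np

def slice_each(input_, begin_vector, size):
    """Loop-based reference: for every i, slice input_[i] starting at begin_vector[i]
    with extent size along each axis, then stack the slices along a new axis 0."""
    out = []
    for item, begin in zip(input_, begin_vector):
        # One slice object per axis, mirroring tf.slice(item, begin, size).
        index = tuple(slice(b, b + s) for b, s in zip(begin, size))
        patch = item[index]
        # NumPy would silently truncate an out-of-range slice, so check explicitly.
        if patch.shape != tuple(size):
            raise ValueError(f"begin {list(begin)} with size {list(size)} "
                             f"does not fit in shape {item.shape}")
        out.append(patch)
    # Same (N, *size) layout as the tf.while_loop + tf.concat version.
    return np.stack(out)

images = np.random.uniform(size=(2, 256, 256, 1))
ix = np.array([[40, 80, 0], [20, 10, 0]])
res = slice_each(images, ix, [10, 10, 1])
print(res.shape)  # (2, 10, 10, 1)
```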
I just wonder if there is a prettier way to do this.
PS: I am aware that similar questions have been asked, for example "Slicing tensor with list - TensorFlow". However, I want to specify the slice positions through placeholders, so everything must stay dynamic during training, and none of the solutions I have seen work for me. I cannot use a Python for loop, and I don't want to turn on eager execution.
Here is a function to do that without a loop:
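import tensorflow as tf

def extract_patches(images, px, py, w, h):
    # images: (batch, height, width, channels); px, py: (batch,) top-left x/y of each patch.
    s = tf.shape(images)
    # Index grids of shape (batch, h, w).
    ii, yy, xx = tf.meshgrid(tf.range(s[0]), tf.range(h), tf.range(w), indexing='ij')
    # Shift each image's grid by its own patch origin.
    xx2 = xx + px[:, tf.newaxis, tf.newaxis]
    yy2 = yy + py[:, tf.newaxis, tf.newaxis]
    # Optional: ensure indices do not go out of bounds
    xx2 = tf.clip_by_value(xx2, 0, s[2] - 1)
    yy2 = tf.clip_by_value(yy2, 0, s[1] - 1)
    # (batch, h, w, 3) indices; gather_nd returns (batch, h, w, channels).
    idx = tf.stack([ii, yy2, xx2], axis=-1)
    return tf.gather_nd(images, idx)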
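Because of the clipping step, a patch that extends past the image border does not fail: out-of-range coordinates are clamped, so the last row or column is repeated (similar to edge padding). The NumPy sketch below mirrors extract_patches line by line (np.meshgrid for tf.meshgrid, np.clip for tf.clip_by_value, and advanced indexing with the index components in place of tf.gather_nd) to illustrate this. It assumes a batch of shape (N, H, W, C); extract_patches_np is a hypothetical name.

```python
import numpy as np

def extract_patches_np(images, px, py, w, h):
    n, height, width = images.shape[:3]
    # Index grids of shape (n, h, w), as tf.meshgrid(..., indexing='ij') produces.
    ii, yy, xx = np.meshgrid(np.arange(n), np.arange(h), np.arange(w), indexing='ij')
    xx2 = xx + np.asarray(px)[:, np.newaxis, np.newaxis]
    yy2 = yy + np.asarray(py)[:, np.newaxis, np.newaxis]
    # Clamp coordinates into the image, as tf.clip_by_value does.
    xx2 = np.clip(xx2, 0, width - 1)
    yy2 = np.clip(yy2, 0, height - 1)
    # images[ii, yy2, xx2] picks the same elements as
    # tf.gather_nd(images, tf.stack([ii, yy2, xx2], axis=-1)).
    return images[ii, yy2, xx2]

# One 3x4 single-channel image with values 0..11.
images = np.arange(12, dtype=np.float32).reshape(1, 3, 4, 1)
# A 3-wide, 2-tall patch starting at column 2, row 2 runs past the right and bottom edges.
patch = extract_patches_np(images, px=[2], py=[2], w=3, h=2)
print(patch[..., 0])
# [[[10. 11. 11.]
#   [10. 11. 11.]]]
```

Without the clipping the same call would raise an IndexError in NumPy, so whether to clip depends on whether repeated border pixels are acceptable for your patches.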
Here is an example:
import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    # Works for images with any size and number of channels
    images = tf.placeholder(tf.float32, (None, None, None, None))
    patch_xy = tf.placeholder(tf.int32, (None, 2))
    patch_size = tf.placeholder(tf.int32, (2,))
    px = patch_xy[:, 0]
    py = patch_xy[:, 1]
    w = patch_size[0]
    h = patch_size[1]
    patches = extract_patches(images, px, py, w, h)
    test = sess.run(patches, {
        images: [
            # Image 0
            [[[ 0], [ 1], [ 2], [ 3]],
             [[ 4], [ 5], [ 6], [ 7]],
             [[ 8], [ 9], [10], [11]]],
            # Image 1
            [[[50], [51], [52], [53]],
             [[54], [55], [56], [57]],
             [[58], [59], [60], [61]]]
        ],
        patch_xy: [[1, 0],
                   [0, 1]],
        patch_size: [3, 2]})
    print(test[..., 0])
    # [[[ 1.  2.  3.]
    #   [ 5.  6.  7.]]
    #
    #  [[54. 55. 56.]
    #   [58. 59. 60.]]]
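Note that this answer orders coordinates as [x, y] (column, row) and the patch size as [w, h], whereas the question's ix holds [row, col, channel] begin offsets. The NumPy sketch below checks only the indexing, not the TensorFlow graph: it converts between the two conventions and confirms that the index-grid approach selects the same patches as the question's per-image slicing. It assumes the question's channel offset is 0 and all channels are kept; the random data is illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
images = rng.uniform(size=(2, 256, 256, 1)).astype(np.float32)
ix = np.array([[40, 80, 0], [20, 10, 0]])  # question: [row, col, channel] per image
size = [10, 10, 1]

# Question-style result: slice each image at begin = ix[i].
expected = np.stack([
    im[r:r + size[0], c:c + size[1], ch:ch + size[2]]
    for im, (r, c, ch) in zip(images, ix)
])

# Answer-style inputs: patch_xy = [x, y] = [col, row], patch_size = [w, h].
patch_xy = ix[:, [1, 0]]
w, h = size[1], size[0]
px, py = patch_xy[:, 0], patch_xy[:, 1]
ii, yy, xx = np.meshgrid(np.arange(len(images)), np.arange(h), np.arange(w), indexing='ij')
got = images[ii, yy + py[:, None, None], xx + px[:, None, None]]

print(patch_xy.tolist())              # [[80, 40], [10, 20]]
print(np.array_equal(expected, got))  # True
```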