Bastiaan

Reputation: 4672

How to loop through the elements of a tensor in TensorFlow?

I want to create a function batch_rot90(batch_of_images) using TensorFlow's tf.image.rot90(). The latter only takes one image at a time, while the former should take a batch of n images at once (shape = [n, x, y, f]).

So naturally, one should just iterate through all images in the batch and rotate them one by one. In NumPy this would look like:

import numpy as np

def batch_rot90(batch):
  for i in range(batch.shape[0]):
    # In-place assignment only works for square images (x == y)
    batch[i] = np.rot90(batch[i, :, :, :])
  return batch
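For reference, NumPy itself does not need the loop: np.rot90 takes an axes argument, so the whole batch can be rotated in the plane of axes 1 and 2 in a single call. A minimal sketch (the name batch_rot90_np is hypothetical):

```python
import numpy as np

def batch_rot90_np(batch, k=1):
    # Rotate every image in a [n, x, y, f] batch by k * 90 degrees.
    # axes=(1, 2) rotates in the plane of the spatial axes x and y,
    # leaving the batch axis (0) and feature axis (3) untouched.
    return np.rot90(batch, k, axes=(1, 2))

# Non-square images are fine here, because a new array is returned.
batch = np.arange(2 * 3 * 4).reshape(2, 3, 4, 1)
rotated = batch_rot90_np(batch)

# Same result as rotating each image separately.
assert all(np.array_equal(rotated[i], np.rot90(batch[i])) for i in range(2))
```

Because the result is a new array, this version also avoids the square-image restriction of the in-place loop.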

How is this done in TensorFlow? Using tf.while_loop, I got this far:

import tensorflow as tf

batch = tf.placeholder(tf.float32, shape=[2, 256, 256, 4])

def batch_rot90(batch, k, name=''):
  i = tf.constant(0)
  def cond(batch, i):
    return tf.less(i, tf.shape(batch)[0])
  def body(batch, i):
    batch[i] = tf.image.rot90(batch[i], k)  # not allowed: tensors are immutable
    i = tf.add(i, 1)
    return batch, i
  r = tf.while_loop(cond, body, [batch, i])
  return r

But the assignment to batch[i] is not allowed, and I'm confused about what ends up in r.
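On the second point: tf.while_loop returns the final values of all loop variables, in the same structure as loop_vars, so in the code above r would hold the final batch and the final i. A minimal sketch with a toy loop, assuming TensorFlow 2.x eager execution (the question uses the 1.x tf.placeholder API):

```python
import tensorflow as tf

# A toy loop with two loop variables: a running total and a counter.
def cond(total, i):
    return tf.less(i, 5)

def body(total, i):
    # body must return new values for *all* loop variables,
    # in the same order as loop_vars.
    return total + i, i + 1

r = tf.while_loop(cond, body, (tf.constant(0), tf.constant(0)))
total, i = r  # r holds the final value of each loop variable
print(total.numpy(), i.numpy())  # 10 5
```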

I realize there might be a workaround for this particular case using tf.batch_to_space(), but I believe it should be possible with a loop of some kind too.
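A different loop-free option (swapped in here; it is not tf.batch_to_space()) is to build the rotation from tf.reverse and tf.transpose, which act on the whole batch at once. In TensorFlow 2.x, tf.image.rot90 also accepts a 4-D [batch, height, width, channels] tensor directly. A minimal sketch, assuming TensorFlow 2.x eager execution; batch_rot90_no_loop is a hypothetical name:

```python
import numpy as np
import tensorflow as tf

def batch_rot90_no_loop(batch):
    # Rotate every [x, y, f] image in a [n, x, y, f] batch by 90 degrees
    # counterclockwise, with no loop: reverse the y axis (axis 2), then
    # swap the x and y axes, matching np.rot90 with k=1.
    return tf.transpose(tf.reverse(batch, axis=[2]), perm=[0, 2, 1, 3])

batch = tf.random.uniform([2, 4, 5, 3])  # non-square images work too
rotated = batch_rot90_no_loop(batch)
print(rotated.shape.as_list())  # [2, 5, 4, 3]

# Agrees with rotating each image in NumPy.
assert np.array_equal(rotated.numpy(), np.rot90(batch.numpy(), 1, axes=(1, 2)))
```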

Upvotes: 4

Views: 8654

Answers (2)

Da Tong

Reputation: 2026

Updated Answer:

import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[2, 3])

def cond(batch, output, i):
    return tf.less(i, tf.shape(batch)[0])

def body(batch, output, i):
    output = output.write(i, tf.add(batch[i], 10))
    return batch, output, i + 1

# TensorArray is a data structure that supports dynamic writing
output_ta = tf.TensorArray(dtype=tf.float32,
                           size=0,
                           dynamic_size=True,
                           element_shape=(x.get_shape()[1],))
_, output_op, _ = tf.while_loop(cond, body, [x, output_ta, 0])
output_op = output_op.stack()

with tf.Session() as sess:
    print(sess.run(output_op, feed_dict={x: [[1, 2, 3], [0, 0, 0]]}))
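Applied to the original rot90 problem, the same TensorArray pattern might look like the sketch below. It assumes TensorFlow 2.x, where tf.function traces the loop into a graph much like the 1.x session code above; batch_rot90_ta is a hypothetical name:

```python
import tensorflow as tf

@tf.function  # traces the while_loop into a graph
def batch_rot90_ta(batch, k=1):
    # Hypothetical helper: rotate each [x, y, f] image of a [n, x, y, f]
    # batch and collect the results in a TensorArray.
    n = tf.shape(batch)[0]
    output_ta = tf.TensorArray(dtype=batch.dtype, size=n)

    def cond(i, ta):
        return tf.less(i, n)

    def body(i, ta):
        # write() returns the updated TensorArray; keep the returned value.
        ta = ta.write(i, tf.image.rot90(batch[i], k))
        return i + 1, ta

    _, output_ta = tf.while_loop(cond, body, (tf.constant(0), output_ta))
    # stack() packs the n rotated images into one [n, y, x, f] tensor.
    return output_ta.stack()

batch = tf.random.uniform([2, 4, 5, 3])  # non-square images are fine here
rotated = batch_rot90_ta(batch)
print(rotated.shape.as_list())  # [2, 5, 4, 3]
```

Writing into a fresh TensorArray sidesteps the item-assignment problem entirely, since the input batch is never modified.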

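I think you should consider using tf.scatter_update to update one image in the batch, instead of using batch[i] = .... Refer to this link for details. In your case, I suggest changing the first line of body to the following (note that tf.scatter_update only works when batch is a tf.Variable, not a tf.placeholder):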

tf.scatter_update(batch, i, tf.image.rot90(batch[i], k))
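Because tf.scatter_update needs a tf.Variable, a functional alternative (swapped in here, assuming TensorFlow 2.x) is tf.tensor_scatter_nd_update, which returns a new tensor with the selected slice replaced. A minimal sketch; batch_rot90_scatter is a hypothetical name, and it assumes square images, because an in-place style update keeps the batch shape fixed while rot90 swaps height and width:

```python
import tensorflow as tf

def batch_rot90_scatter(batch, k=1):
    # Hypothetical helper. Assumes square images (x == y).
    def cond(i, b):
        return tf.less(i, tf.shape(b)[0])

    def body(i, b):
        # tensor_scatter_nd_update returns a new tensor with image i replaced;
        # it works on plain tensors, so no tf.Variable is needed.
        b = tf.tensor_scatter_nd_update(b, [[i]], [tf.image.rot90(b[i], k)])
        return i + 1, b

    _, result = tf.while_loop(cond, body, (tf.constant(0), batch))
    return result

batch = tf.reshape(tf.range(2 * 3 * 3, dtype=tf.float32), [2, 3, 3, 1])
rotated = batch_rot90_scatter(batch)
```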

Upvotes: 4

Bastiaan

Reputation: 4672

TensorFlow has a map function, tf.map_fn, that does this:

import tensorflow as tf

def batch_rot90(batch, k, name=None):
  fun = lambda x: tf.image.rot90(x, k=k)
  return tf.map_fn(fun, batch, name=name)
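To check the tf.map_fn approach eagerly (assuming TensorFlow 2.x; the thread uses the 1.x API), it can be run on a batch of non-square images. tf.map_fn calls the function once per element along axis 0 and stacks the results, so the per-image shape change is handled:

```python
import tensorflow as tf

def batch_rot90(batch, k=1):
    # tf.map_fn applies the lambda to one [x, y, f] image at a time
    # (slicing along axis 0) and stacks the rotated results.
    return tf.map_fn(lambda image: tf.image.rot90(image, k=k), batch)

# Each [4, 5, 3] image becomes [5, 4, 3].
batch = tf.random.uniform([2, 4, 5, 3])
rotated = batch_rot90(batch)
print(rotated.shape.as_list())  # [2, 5, 4, 3]
```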

Upvotes: 3
