RobR

Reputation: 2190

tf.image.rot90() gives error when parameter k is a tensor

I am trying to add a random 90-degree rotation to images as part of a training data pipeline. However, when I pass a scalar tensor as the k parameter of tf.image.rot90(), I get the following error: TypeError: Fetch argument None has invalid type <class 'NoneType'>. The function works as expected when k is a Python integer. The following demonstrates the problem:

import tensorflow as tf
import random
import numpy as np
from matplotlib import pyplot as plt

with tf.Session() as sess:

   image = np.reshape(np.arange(0., 4.), [2, 2, 1])
   print(image.shape)

   # this works
   k = random.randint(0, 3)
   print('k = ' + str(k))

   # this gives an error
   # k = random.randint(0, 3)
   # k = tf.convert_to_tensor(k, dtype=tf.int32)
   # k = tf.Print(k, [k], 'k = ')

   # this gives an error
   # k = tf.random_uniform([], minval=0, maxval=4, dtype=tf.int32)
   # k = tf.Print(k, [k], 'k = ')

   image2 = tf.image.rot90(image, k)
   img2 = sess.run(image2)
   plt.figure()
   plt.subplot(121)
   plt.imshow(np.squeeze(image), interpolation='nearest')
   plt.subplot(122)
   plt.imshow(np.squeeze(img2), interpolation='nearest')
   plt.show()

Is there a way to set k to a random value as part of the training pipeline? Or is this a bug in tf.image.rot90()?

Upvotes: 1

Views: 936

Answers (1)

mrry

Reputation: 126184

The current implementation of tf.image.rot90() has a bug: if you pass a value for k that is not a Python integer, the function does not return any value (that is, it returns None), which is why sess.run() complains about a None fetch argument. I created an issue about this, and will get a fix in soon. In general, you should be able to draw a random integer scalar for k, but the current implementation isn't general enough to support that.

You could try using tf.case() to implement it yourself, but I intend to implement that in the fix, so it might be easier to wait :-).
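For anyone who wants to try the tf.case() route in the meantime, here is a minimal sketch of how it could look. It is not the fix that will land in tf.image.rot90(); the helper name rot90_with_tensor_k is made up for this example, and it assumes a single height x width x channels image and a TensorFlow build that provides the tensorflow.compat.v1 module (on plain TF 1.x, `import tensorflow as tf` without the disable_eager_execution() call should do the same job).

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: a build that ships the compat.v1 module

tf.disable_eager_execution()  # keep the graph/Session style used in the question


def rot90_with_tensor_k(image, k):
    """Hypothetical helper: rotate an HWC image counter-clockwise by k * 90 degrees.

    k may be a scalar int32 tensor; it is reduced modulo 4 first.
    """
    image = tf.convert_to_tensor(image)
    k = tf.mod(k, 4)

    # Each branch builds one rotation out of transpose and reverse.
    def rot90():
        return tf.transpose(tf.reverse(image, axis=[1]), [1, 0, 2])

    def rot180():
        return tf.reverse(image, axis=[0, 1])

    def rot270():
        return tf.reverse(tf.transpose(image, [1, 0, 2]), axis=[1])

    # The default branch handles k == 0 (no rotation).
    return tf.case(
        [(tf.equal(k, 1), rot90),
         (tf.equal(k, 2), rot180),
         (tf.equal(k, 3), rot270)],
        default=lambda: image)


image = np.reshape(np.arange(0., 4.), [2, 2, 1])
k = tf.random_uniform([], minval=0, maxval=4, dtype=tf.int32)
rotated = rot90_with_tensor_k(image, k)

with tf.Session() as sess:
    # Fetch both in one run() call so k_val is the value used for the rotation.
    img2, k_val = sess.run([rotated, k])
    print('k =', k_val)
```

Fetching k together with the rotated image matters: a second sess.run(k) would draw a new random value, so it would not necessarily match the rotation that was applied.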

Upvotes: 2
