Avedis

Reputation: 447

Cleaner way to whiten each image in a batch using keras

I would like to whiten each image in a batch. The code I have to do so is this:

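from keras import backend as K

def whiten(self, x):
    shape = x.shape
    x = K.batch_flatten(x)            # (batch, 320*320*1)
    mn = K.mean(x, 0)                 # mean along axis 0, i.e. across the batch
    std = K.std(x, 0) + K.epsilon()   # epsilon avoids division by zero
    r = (x - mn) / std
    # reshape the whitened tensor r (not x) back to the input shape
    r = K.reshape(r, (-1, shape[1], shape[2], shape[3]))
    return r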

where x has shape (?, 320, 320, 1). I am not keen on calling reshape with a -1 argument. Is there a cleaner way to do this?

Upvotes: 0

Views: 75

Answers (1)

gorjan

Reputation: 5555

Let's see what the -1 does. From the TensorFlow documentation (the Keras documentation is sparse by comparison):

If one component of shape is the special value -1, the size of that dimension is computed so that the total size remains constant.

So what this means:

import tensorflow as tf
from keras import backend as K

X = tf.constant([1,2,3,4,5])
K.reshape(X, [-1, 5])
# 5 columns are requested; with 5 elements in total, the -1 resolves to 1 row
# [[1 2 3 4 5]]

X = tf.constant([1,2,3,4,5,6])
K.reshape(X, [-1, 3])
# 3 columns are requested; to keep all 6 elements, the -1 resolves to 2 rows
# [[1 2 3]
#  [4 5 6]]
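The arithmetic behind the -1 can be sketched in a few lines. This uses NumPy as a stand-in, since its reshape follows the same -1 rule; the helper name `infer_missing_dim` is hypothetical and only illustrates the calculation, it is not part of Keras or TensorFlow.

```python
import numpy as np

def infer_missing_dim(total_elements, known_dims):
    """Return the size that a single -1 entry would resolve to (illustrative helper)."""
    known = int(np.prod(known_dims))
    if total_elements % known != 0:
        # The total size cannot stay constant, so the reshape is invalid
        raise ValueError("element count is not divisible by the known dimensions")
    return total_elements // known

x = np.arange(1, 7)                     # 6 elements: 1..6
rows = infer_missing_dim(x.size, [3])   # 6 / 3
reshaped = x.reshape(-1, 3)             # NumPy infers the same row count
```

If the element count is not divisible by the known dimensions (say, 5 elements and 3 columns), no value for the -1 keeps the total size constant, and the reshape fails.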

I think it is simple enough. So what happens in your code:

# Let's assume we have 5 images, 320x320 with 3 channels
X = tf.ones((5, 320, 320, 3))
shape = X.shape

# Flatten the tensor so the rest of the computation can run on 2D data
flatten = K.batch_flatten(X)
# Per the documentation, batch_flatten turns an nD tensor into a 2D tensor with the same 0th dimension
flatten.shape
# (5, 307200)
# All other dimensions were collapsed into one, while the batch size stayed the same

# ...The rest of the stuff in your code is executed here...

# With the computation done, restore the tensor to its original shape
r = K.reshape(flatten, (-1, shape[1],shape[2],shape[3]))
r.shape
# (5, 320, 320, 3)
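For completeness, here is a rough end-to-end sketch of the flatten, normalise, reshape round trip. It is written with NumPy rather than the Keras backend so it runs on its own, and it makes two assumptions that are not in the question: statistics are taken per image (axis 1 of the flattened array), since the goal is to whiten each image, whereas the question's code reduces along axis 0, which is across the batch; and `EPSILON` is a stand-in for `K.epsilon()`. The function name `whiten_per_image` is hypothetical.

```python
import numpy as np

EPSILON = 1e-7  # assumed stand-in for K.epsilon()

def whiten_per_image(x):
    """Standardise each image in a (batch, H, W, C) array independently."""
    shape = x.shape
    flat = x.reshape(shape[0], -1)                   # (batch, H*W*C), like K.batch_flatten
    mean = flat.mean(axis=1, keepdims=True)          # one mean per image
    std = flat.std(axis=1, keepdims=True) + EPSILON  # one std per image
    whitened = (flat - mean) / std
    return whitened.reshape((-1,) + shape[1:])       # -1 resolves back to the batch size

rng = np.random.default_rng(0)
# Small images keep the demo fast; the shapes are otherwise arbitrary
batch = rng.normal(loc=3.0, scale=2.0, size=(5, 32, 32, 3))
out = whiten_per_image(batch)
```

After the call, `out` has the same shape as `batch`, and each image has approximately zero mean and unit standard deviation.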

Besides, I can't think of a cleaner way to do what you want to do. If you ask me, your code is already clear enough.

Upvotes: 1
