YohanRoth

Reputation: 3233

How to apply tf.map_fn for each channel in Tensorflow

I need to apply something like tf.map_fn to each channel of a batch of images with shape (512, 32, 32, 3), where the last dimension corresponds to a channel (RGB), for example:

x = tf.map_fn(lambda channel: func(channel), x)

where func is a function applied to each channel matrix, i.e. a tensor of shape (512, 32, 32).

Is there some way of doing it?
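For reference, tf.map_fn unstacks its input along axis 0 and calls the function once per slice. Applied directly to a (512, 32, 32, 3) batch, it therefore iterates over the 512 images, not over the 3 channels. A minimal sketch that makes this visible, assuming TF 2.x eager mode; the random batch and the use of tf.reduce_mean as the per-slice function are placeholders for illustration only:

```python
import tensorflow as tf

x = tf.random.uniform([512, 32, 32, 3])  # placeholder batch of RGB images

# map_fn unstacks along axis 0: one call per image, each slice of shape (32, 32, 3).
per_image = tf.map_fn(lambda img: tf.reduce_mean(img), x)
print(per_image.shape)  # (512,)

# Moving the channel axis to the front gives one call per channel,
# each slice of shape (512, 32, 32).
per_channel = tf.map_fn(lambda ch: tf.reduce_mean(ch), tf.transpose(x, perm=[3, 0, 1, 2]))
print(per_channel.shape)  # (3,)
```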

Or maybe I can do something like this:

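outs = []
for ch in range(3):
    cp = x[:, :, :, ch]                       # slicing already returns a new tensor, no copy needed
    cp = tf.reshape(cp, [x.shape[0], -1])     # flatten each image: (512, 1024)
    out = func(cp)
    unsq = tf.reshape(out, [x.shape[0], 32, 32])
    outs.append(unsq)
x = tf.stack(outs, axis=-1)                   # tensors are immutable, so rebuild (512, 32, 32, 3)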

That is, I need to apply func to the flattened images of each channel. I am completely new to TF, so I have no idea how to accomplish this.

Thanks!

Upvotes: 0

Views: 360

Answers (1)

nessuno

Reputation: 27042

You can just transpose the dimensions before calling map_fn so that the channel axis comes first, since map_fn iterates over axis 0:

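x = tf.transpose(x, perm=[3, 0, 1, 2])           # shape (3, 512, 32, 32)
x = tf.map_fn(lambda channel: func(channel), x)  # func receives one (512, 32, 32) channel at a time
x = tf.transpose(x, perm=[1, 2, 3, 0])           # back to (512, 32, 32, 3)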
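An end-to-end sketch of the transpose approach that also flattens each channel's images, as the question asks, and restores the channels-last layout afterwards. It assumes TF 2.x and a fully known static input shape. `func` and `map_over_channels` are hypothetical names for illustration; here `func` just subtracts each image's mean, and with map_fn it receives one (512, 1024) tensor per channel:

```python
import tensorflow as tf

def func(flat):
    # Hypothetical per-channel function: (batch, H*W) -> (batch, H*W).
    # Subtracts the mean of each flattened image.
    return flat - tf.reduce_mean(flat, axis=1, keepdims=True)

def map_over_channels(x, fn):
    # Assumes a fully known static shape (batch, H, W, C).
    batch, height, width, channels = x.shape
    xt = tf.transpose(x, perm=[3, 0, 1, 2])                    # (C, batch, H, W)
    flat = tf.reshape(xt, [channels, batch, height * width])   # (C, batch, H*W)
    out = tf.map_fn(fn, flat)                                  # one call per channel
    out = tf.reshape(out, [channels, batch, height, width])    # (C, batch, H, W)
    return tf.transpose(out, perm=[1, 2, 3, 0])                # back to (batch, H, W, C)

x = tf.random.uniform([512, 32, 32, 3])  # placeholder batch
y = map_over_channels(x, func)
print(y.shape)  # (512, 32, 32, 3)
```

Flattening happens after the transpose, so each row of the (512, 1024) tensor passed to `func` is one image of a single channel, matching the per-channel loop in the question.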

Upvotes: 2
