Reputation: 3233
I need to apply something like tf.map_fn to each channel of a batch of images with shape (512, 32, 32, 3), where the last dimension is the channel (RGB), e.g.:
x = tf.map_fn(lambda channel: func(channel), x)
where func is a function applied to each channel matrix, i.e. a tensor of shape (512, 32, 32).
Is there a way to do this?
Or maybe I can do something like this:
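import tensorflow as tf

outs = []
for ch in range(3):
    cp = x[:, :, :, ch]  # slicing already returns a new tensor, so no explicit copy is needed
    cp = tf.reshape(cp, [x.shape[0], -1])  # flatten each image: (512, 1024)
    out = func(cp)
    outs.append(tf.reshape(out, [x.shape[0], 32, 32]))  # unflatten: (512, 32, 32)
x = tf.stack(outs, axis=-1)  # tensors are immutable, so rebuild x instead of assigning to x[:, :, :, ch]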
In other words, I need to apply func to the flattened images of each channel. I am completely new to TF, so I have no idea how to accomplish this.
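To make the loop idea concrete, here is a minimal self-contained sketch, assuming TensorFlow 2.x in eager mode. The func below is a hypothetical stand-in (it standardizes each flattened image); replace it with your own function that maps (batch, 1024) to (batch, 1024). The helper name apply_per_channel is also made up for this example.

```python
import tensorflow as tf


def func(flat):
    """Hypothetical per-channel function: standardizes each flattened image.

    flat has shape (batch, height * width); the output keeps that shape.
    """
    mean = tf.reduce_mean(flat, axis=1, keepdims=True)
    std = tf.math.reduce_std(flat, axis=1, keepdims=True)
    return (flat - mean) / (std + 1e-6)


def apply_per_channel(x, fn):
    """Applies fn to the flattened images of each channel of an NHWC batch."""
    batch, height, width, channels = x.shape
    outs = []
    for ch in range(channels):
        # Slice one channel and flatten each image: (batch, H * W).
        cp = tf.reshape(x[:, :, :, ch], [batch, height * width])
        out = fn(cp)
        # Restore the image layout: (batch, H, W).
        outs.append(tf.reshape(out, [batch, height, width]))
    # Tensors are immutable, so stack the results into a new tensor instead of
    # assigning into x[:, :, :, ch].
    return tf.stack(outs, axis=-1)  # (batch, H, W, channels)


x = tf.random.uniform([512, 32, 32, 3])
y = apply_per_channel(x, func)
print(y.shape)  # (512, 32, 32, 3)
```

With only three channels, a plain Python loop like this is easy to read; the per-channel results are collected in a list and combined once with tf.stack.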
Thanks!
Upvotes: 0
Views: 360
Reputation: 27042
You can just transpose the dimensions before calling map_fn:
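x = tf.transpose(x, perm=[3, 0, 1, 2])  # shape (3, 512, 32, 32): channels first, since map_fn iterates over axis 0
x = tf.map_fn(lambda channel: func(channel), x)  # func receives one (512, 32, 32) channel at a time
x = tf.transpose(x, perm=[1, 2, 3, 0])  # back to (512, 32, 32, 3)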
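A complete sketch of the transpose approach, assuming TensorFlow 2.x. It also flattens each image inside the mapped function, as the question asks, and transposes the result back to (512, 32, 32, 3). The names func and map_channels are hypothetical; func just standardizes each flattened image for illustration.

```python
import tensorflow as tf


def func(flat):
    """Hypothetical per-channel function: standardizes each flattened image.

    Input and output both have shape (batch, height * width).
    """
    mean = tf.reduce_mean(flat, axis=1, keepdims=True)
    std = tf.math.reduce_std(flat, axis=1, keepdims=True)
    return (flat - mean) / (std + 1e-6)


def map_channels(x, fn):
    """Runs fn once per channel of an NHWC batch using tf.map_fn."""
    batch, height, width, _ = x.shape
    # map_fn iterates over axis 0, so move channels there: (C, N, H, W).
    per_channel = tf.transpose(x, perm=[3, 0, 1, 2])

    def on_channel(channel):
        # channel has shape (N, H, W): flatten each image, apply fn, unflatten.
        flat = tf.reshape(channel, [batch, height * width])
        return tf.reshape(fn(flat), [batch, height, width])

    out = tf.map_fn(on_channel, per_channel)  # (C, N, H, W)
    # Move channels back to the last axis: (N, H, W, C).
    return tf.transpose(out, perm=[1, 2, 3, 0])


x = tf.random.uniform([512, 32, 32, 3])
y = map_channels(x, func)
print(y.shape)  # (512, 32, 32, 3)
```

Because map_fn stacks the per-channel outputs along axis 0, the final transpose with perm=[1, 2, 3, 0] is what restores the channel-last layout.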
Upvotes: 2