Reputation: 2019
In TensorFlow, I have a tensor `x` with shape `[batch_of_batches_size, batch_of_images_size, image_height, image_width, nchannels]`. It represents a batch of batches of images. The first index of `x`, which I'll refer to as the "batch index", points to a specific batch, and the second index points to a specific image inside the batch. The three remaining indexes of `x` represent the image itself, which has dimensions `image_height`-by-`image_width` and `nchannels` channels.
I want to apply 2D convolutions to the images. The filters have height `filter_height` and width `filter_width`, and I want to use `'SAME'` padding and all strides equal to `1`. But for each batch index `i`, I want to apply a different set of filters. The filters tensor is named `w` and its shape is `[batch_of_batches_size, filter_height, filter_width, nchannels, nfilters]`. For each batch index `i`, what I want is to apply the filter `w[i,:,:,:,:]` to the images `x[i,:,:,:,:]` (similarly to how I would apply the function `conv2d`). I want all the results to be held in the tensor `y` with shape `[batch_of_batches_size, batch_of_images_size, image_height, image_width, nfilters]`, such that the result for this `i` would be in `y[i,:,:,:,:]`.
Mathematically, what I want is:
y[i,j,k,l,m] = SUM_{a,b,u} x[i,j,k+a,l+b,u]*w[i,a,b,u,m]
This is identical to `conv2d`, apart from the first index `i`.
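To make the intended result concrete, here is a naive NumPy reference, offered only as a sketch. It reads the sum the way `conv2d` does (the image is indexed at the shifted position, the filter at the offset `a, b`) and assumes stride 1 with `'SAME'` zero padding, where the smaller half of the padding goes before the image. The function name `batched_conv2d_reference` is made up for this illustration.

```python
import numpy as np

def batched_conv2d_reference(x, w):
    """Naive reference: y[i] is x[i] filtered with w[i] (stride 1, 'SAME' padding)."""
    n_batches, n_images, height, width, _ = x.shape
    _, f_h, f_w, _, n_filters = w.shape
    # 'SAME' padding for stride 1: filter_size - 1 zeros in total per axis,
    # with the smaller half placed before the image (top/left).
    pad_top, pad_left = (f_h - 1) // 2, (f_w - 1) // 2
    pad_bottom, pad_right = f_h - 1 - pad_top, f_w - 1 - pad_left
    x_pad = np.pad(
        x, ((0, 0), (0, 0), (pad_top, pad_bottom), (pad_left, pad_right), (0, 0))
    )
    y = np.zeros(
        (n_batches, n_images, height, width, n_filters),
        dtype=np.result_type(x, w),
    )
    # Accumulate one filter tap (a, b) at a time.
    for a in range(f_h):
        for b in range(f_w):
            patch = x_pad[:, :, a:a + height, b:b + width, :]  # axes: i, j, k, l, u
            # Sum over input channels u; the batch index i is shared by x and w.
            y += np.einsum('ijklu,ium->ijklm', patch, w[:, a, b, :, :])
    return y
```

A reference like this is slow, but it is handy for checking any faster TensorFlow formulation on small inputs.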
I'd like to know if there is a way of doing this using built-in functions in TensorFlow. I know I can use a `for` loop over the first dimension (the batch index) and call `conv2d` on reshaped slices of `x` and of the filters on each iteration. But is there a simpler, more efficient, or more elegant way that does not require slicing and indexing into tensors?
Some ideas that I had, but that didn't lead me to a solution, were: (1) to reshape/transpose `x` and `w`, use `conv2d` or `depthwise_conv2d`, and then reshape/transpose again; (2) to use `conv3d` on `x` and on a padded version of `w`.
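For what it's worth, idea (1) can be made to work with `depthwise_conv2d`. The sketch below is one possible way, not a tested recommendation: it folds the batch index into the channel axis, filters every channel separately, and then sums over the original input channels. It assumes TensorFlow 2.x in eager mode with fully static shapes; the function name `batched_conv2d_depthwise` and the short dimension names are invented for this example.

```python
import tensorflow as tf

def batched_conv2d_depthwise(x, w):
    # x: [B, N, H, W, C], w: [B, fh, fw, C, F]; static shapes are assumed.
    B, N, H, W, C = x.shape
    _, fh, fw, _, F = w.shape
    # Fold the batch index into the channel axis: [N, H, W, B*C].
    x_folded = tf.reshape(tf.transpose(x, [1, 2, 3, 0, 4]), [N, H, W, B * C])
    # Matching filter layout for depthwise_conv2d: [fh, fw, B*C, F].
    w_folded = tf.reshape(tf.transpose(w, [1, 2, 0, 3, 4]), [fh, fw, B * C, F])
    # Each of the B*C channels is filtered on its own, giving B*C*F output channels.
    out = tf.nn.depthwise_conv2d(
        x_folded, w_folded, strides=[1, 1, 1, 1], padding='SAME'
    )
    # Unfold the channels and sum over C to recover an ordinary convolution.
    out = tf.reduce_sum(tf.reshape(out, [N, H, W, B, C, F]), axis=4)  # [N, H, W, B, F]
    return tf.transpose(out, [3, 0, 1, 2, 4])  # [B, N, H, W, F]

# Hypothetical small sizes, chosen only for illustration.
x = tf.random.normal([2, 3, 8, 8, 4])
w = tf.random.normal([2, 5, 5, 4, 6])
y = batched_conv2d_depthwise(x, w)
```

The trade-off is memory: the intermediate tensor has `B*C*F` channels before the sum, so this may only pay off when those sizes are modest.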
Upvotes: 1
Views: 574
Reputation: 2019
I found a solution:
y = tf.map_fn(lambda u: tf.nn.conv2d(u[0], u[1], padding='SAME', strides=[1, 1, 1, 1]), elems=[x, w], dtype=tf.float32)
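The `map_fn` function applies the `conv2d` operator to each pair of first-dimension slices of `x` and `w`. The `dtype` argument is needed here because the output (a single tensor) does not have the same structure as `elems` (a pair of tensors).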
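A self-contained version of the same approach, as a sketch with made-up sizes. It assumes TensorFlow 2.3 or later, where `fn_output_signature` is the newer name for the `dtype` argument of `tf.map_fn`; on older versions `dtype=tf.float32` plays the same role. The helper name `conv_one_batch` is hypothetical.

```python
import tensorflow as tf

# Hypothetical small sizes, chosen only for illustration.
x = tf.random.normal([2, 3, 8, 8, 4])  # [batch_of_batches, batch_of_images, H, W, nchannels]
w = tf.random.normal([2, 5, 5, 4, 6])  # [batch_of_batches, fh, fw, nchannels, nfilters]

def conv_one_batch(pair):
    # pair holds one first-dimension slice of x and the matching slice of w.
    images, filters = pair
    return tf.nn.conv2d(images, filters, strides=[1, 1, 1, 1], padding='SAME')

# map_fn unstacks x and w along axis 0 and stacks the per-slice results.
y = tf.map_fn(conv_one_batch, elems=(x, w), fn_output_signature=tf.float32)
print(y.shape)  # (2, 3, 8, 8, 6)
```

Each slice `y[i]` should match a direct call to `tf.nn.conv2d(x[i], w[i], ...)`, which is an easy sanity check on small inputs.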
Upvotes: 1