Lior

Reputation: 2019

TensorFlow: Applying a batch of batches of filters to a batch of batches of images

In TensorFlow, I have a tensor x with shape [batch_of_batches_size,batch_of_images_size,image_height,image_width,nchannels]. It represents a batch of batches of images. The first index of x, which I'll refer to as the "batch index", points to a specific batch, and the second index points to a specific image inside the batch. The three remaining indexes of x represent the image itself, which has dimensions image_height-by-image_width and nchannels channels. I want to apply 2d convolutions to the images. The filters have height filter_height and width filter_width, and I want to use 'SAME' padding and all strides equal to 1. But for each batch index i, I want to apply a different set of filters. The filters tensor is named w and its shape is [batch_of_batches_size,filter_height,filter_width,nchannels,nfilters]. For each batch index i, what I want is to apply to the images x[i,:,:,:,:] the filter w[i,:,:,:,:] (similarly to how I would apply the function conv2d). I want all the results to be held in the tensor y with shape [batch_of_batches_size,batch_of_images_size,image_height,image_width,nfilters], such that the result for this i would be in y[i,:,:,:,:].

Mathematically, what I want is:

y[i,j,k,l,m] = SUM_{a,b,u} x[i,j,k+a,l+b,u] * w[i,a,b,u,m]

This is identical to conv2d, besides the first index i.
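To make the indexing concrete, here is a reference implementation of the sum above as a sketch in NumPy (assuming TensorFlow-style 'SAME' padding with stride 1; the function name is mine):

```python
import numpy as np

def batched_batch_conv2d(x, w):
    """Reference implementation of the formula above (stride 1, 'SAME' padding).

    x: [batch_of_batches_size, batch_of_images_size, H, W, nchannels]
    w: [batch_of_batches_size, fh, fw, nchannels, nfilters]
    returns y: [batch_of_batches_size, batch_of_images_size, H, W, nfilters]
    """
    B, N, H, W, C = x.shape
    _, fh, fw, _, F = w.shape
    # 'SAME' padding for stride 1: fh-1 total rows, split as evenly as possible
    pt, pl = (fh - 1) // 2, (fw - 1) // 2
    xp = np.pad(x, ((0, 0), (0, 0), (pt, fh - 1 - pt), (pl, fw - 1 - pl), (0, 0)))
    y = np.zeros((B, N, H, W, F))
    for a in range(fh):
        for b in range(fw):
            # y[i,j,k,l,m] += SUM_u xp[i,j,k+a,l+b,u] * w[i,a,b,u,m]
            y += np.einsum('ijklu,ium->ijklm',
                           xp[:, :, a:a + H, b:b + W, :], w[:, a, b, :, :])
    return y
```

For a 1x1 filter this reduces to a per-batch channel mixing, which is an easy way to sanity-check the index placement.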

I'd like to know if there is a way of doing this using built-in functions in TensorFlow. I know I can use a for loop over the first dimension (the batch index), and call conv2d on reshaped slices of x and w on each iteration. But is there a simpler, more efficient or more elegant way, that does not require slicing and indexing into tensors?

Some ideas I had that didn't lead me to a solution: (1) reshape/transpose x and w, use conv2d or depthwise_conv2d, and then reshape/transpose again; (2) use conv3d on x and on a padded version of w.

Upvotes: 1

Views: 574

Answers (1)

Lior

Reputation: 2019

I found a solution:

y = tf.map_fn(lambda u: tf.nn.conv2d(u[0], u[1], strides=[1, 1, 1, 1], padding='SAME'),
              elems=[x, w], dtype=tf.float32)

The map_fn function applies the conv2d operator to each pair of first-dimension slices of x and w, and stacks the results back into a single tensor.
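Since map_fn here just zips the axis-0 slices of x and w, applies the function to each pair, and stacks the per-slice results, the pattern can be sketched in plain NumPy (the sizes and helper name below are illustrative, with a trivial 1x1 "convolution" as the per-slice function):

```python
import numpy as np

# NumPy stand-in for the map_fn pattern: zip axis-0 slices of x and w,
# apply a per-slice function, and stack the results back together.
def map_over_first_axis(fn, x, w):
    return np.stack([fn(xi, wi) for xi, wi in zip(x, w)])

x = np.ones((4, 6, 8, 8, 3))   # [batch_of_batches, batch_of_images, H, W, C]
w = np.ones((4, 1, 1, 3, 5))   # one 1x1 filter bank per batch index

# Per-slice function: a 1x1 convolution is just channel mixing via einsum.
y = map_over_first_axis(
    lambda xi, wi: np.einsum('jklu,um->jklm', xi, wi[0, 0]), x, w)
print(y.shape)  # (4, 6, 8, 8, 5)
```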

Upvotes: 1
