Harsh Vardhan

How can I perform an index-wise convolution in TensorFlow using Conv2D?

I have 64 images (feature maps) of size 512x512, which together form a 512x512x64 cube. I want to convolve them with 64 kernels index-wise, i.e. each image with its own kernel only.


Example:
1st image  -> 1st kernel
2nd image  -> 2nd kernel
3rd image  -> 3rd kernel
...
64th image -> 64th kernel
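To pin down the intended semantics, here is a naive NumPy reference. The helper name `indexwise_conv` is hypothetical, and it assumes stride 1, no padding ("VALID"), and TensorFlow's convention of computing a cross-correlation (the kernel is not flipped). It is only meant to make the index-wise mapping concrete, not to be fast.

```python
import numpy as np


def indexwise_conv(images, kernels):
    """Hypothetical naive reference for index-wise convolution.

    images:  (H, W, C) stack of C feature maps
    kernels: (kh, kw, C) stack of C single-channel kernels
    Returns: (H - kh + 1, W - kw + 1, C); output channel k depends only on
             images[..., k] and kernels[..., k] (stride 1, no padding).
    """
    H, W, C = images.shape
    kh, kw, kc = kernels.shape
    assert kc == C, "need exactly one kernel per feature map"
    out = np.zeros((H - kh + 1, W - kw + 1, C), dtype=images.dtype)
    for k in range(C):  # k-th image -> k-th kernel, nothing else
        for r in range(H - kh + 1):
            for c in range(W - kw + 1):
                patch = images[r:r + kh, c:c + kw, k]
                # cross-correlation, as in TensorFlow: no kernel flip
                out[r, c, k] = np.sum(patch * kernels[:, :, k])
    return out
```

For the question's data this would be called with `images` of shape (512, 512, 64) and `kernels` of shape (kh, kw, 64).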

I want to do this with Conv2D in TensorFlow, but as far as I know, Conv2D takes each image and convolves it with every kernel:
1st image -> all 64 kernels
2nd image -> all 64 kernels
I don't want this.
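Strictly speaking, a standard Conv2D with a filter of shape (kh, kw, 64, 64) does slightly more than that: each output channel j is the *sum* over all 64 input channels i of input i correlated with `filters[:, :, i, j]`, so every output mixes all input maps. A minimal NumPy sketch of this, assuming stride 1 and no padding (the helper `conv2d_valid` is hypothetical), also shows that the index-wise operation is the special case of a filter that is zero whenever i != j:

```python
import numpy as np


def conv2d_valid(images, filters):
    """Hypothetical sketch of what a standard 2-D convolution computes.

    images:  (H, W, C_in)
    filters: (kh, kw, C_in, C_out), the same layout tf.nn.conv2d uses
    Output channel j sums the correlations of *every* input channel i
    with filters[:, :, i, j] (stride 1, no padding).
    """
    H, W, c_in = images.shape
    kh, kw, f_in, c_out = filters.shape
    assert f_in == c_in
    out = np.zeros((H - kh + 1, W - kw + 1, c_out))
    for r in range(H - kh + 1):
        for c in range(W - kw + 1):
            patch = images[r:r + kh, c:c + kw, :]  # (kh, kw, C_in)
            # contract over kh, kw and C_in at once -> one value per output channel
            out[r, c, :] = np.tensordot(patch, filters, axes=([0, 1, 2], [0, 1, 2]))
    return out


# Index-wise convolution as a special case: a "diagonal" filter bank.
rng = np.random.default_rng(0)
x = rng.standard_normal((5, 5, 3))
k = rng.standard_normal((2, 2, 3))  # one 2x2 kernel per channel
diag = np.zeros((2, 2, 3, 3))
for i in range(3):
    diag[:, :, i, i] = k[:, :, i]  # zero everywhere off the channel diagonal
y = conv2d_valid(x, diag)  # output channel i now only sees input channel i
```

Building such a diagonal filter in TensorFlow would work but wastes 63/64 of the multiplications, which is why a dedicated per-channel operation is preferable.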


Answers (1)

Susmit Agrawal

One (inefficient but relatively simple) way to do this would be to use a custom layer:

import tensorflow as tf
from tensorflow.keras.layers import Layer


class IndexConv(Layer):
    def __init__(self, kernel_size=(3, 3), strides=1, padding="SAME", **kwargs):
        super().__init__(**kwargs)
        # Kernel size, strides and padding shared by all per-channel kernels
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding

    def build(self, input_shape):
        # Store the number of input channels (= number of kernels) statically,
        # so it does not have to be computed at runtime
        self.num_kernels = int(input_shape[-1])

        # If the input has n channels, you need n separate kernels.
        # Since each kernel convolves a single channel, its input and output channels are both 1
        kernel_h, kernel_w = self.kernel_size
        self.kernels = [
            self.add_weight(name=f"k{i}", shape=(kernel_h, kernel_w, 1, 1),
                            initializer="glorot_uniform", trainable=True)
            for i in range(self.num_kernels)
        ]

    def call(self, inputs):
        # Split the input tensor into separate tensors, one per channel, each (batch, h, w)
        channels = tf.unstack(inputs, axis=-1)

        # Convolve each input channel with its corresponding kernel.
        # This Python loop is the "inefficient" part I mentioned;
        # more complex but more efficient versions can make use of tf.map_fn
        outputs = [
            tf.nn.conv2d(channels[i][:, :, :, None], self.kernels[i],
                         strides=self.strides, padding=self.padding)
            for i in range(self.num_kernels)
        ]

        # Concatenate the single-channel outputs back along the channel axis
        return tf.concat(outputs, axis=-1)
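For comparison, TensorFlow also has a built-in op for this pattern: a depthwise convolution with a channel multiplier of 1 (`tf.nn.depthwise_conv2d`) convolves channel k only with kernel k, in a single call and without a Python loop. This is a different technique from the custom layer above, offered as a likely faster alternative. The sketch assumes TF 2.x and uses small hypothetical shapes (1x8x8x4 instead of 512x512x64) to check that it matches the per-channel `tf.nn.conv2d` loop:

```python
import numpy as np
import tensorflow as tf

# Hypothetical small shapes to keep the check fast (the question uses 512x512x64)
batch, h, w, c = 1, 8, 8, 4
kh, kw = 3, 3

x = tf.random.normal((batch, h, w, c))
# Depthwise filter layout: (kh, kw, in_channels, channel_multiplier);
# channel_multiplier = 1 means exactly one kernel per input channel
depthwise_filter = tf.random.normal((kh, kw, c, 1))

# One call: channel k of x is convolved only with depthwise_filter[:, :, k, 0]
y_depthwise = tf.nn.depthwise_conv2d(
    x, depthwise_filter, strides=[1, 1, 1, 1], padding="SAME"
)

# Reference: the same per-channel loop that the custom layer performs
y_loop = tf.concat(
    [
        tf.nn.conv2d(x[:, :, :, k:k + 1], depthwise_filter[:, :, k:k + 1, :],
                     strides=1, padding="SAME")
        for k in range(c)
    ],
    axis=-1,
)
```

Inside a Keras model, `tf.keras.layers.DepthwiseConv2D(kernel_size=3, depth_multiplier=1, padding="same")` should give the same behaviour with a trainable kernel; a `depth_multiplier` greater than 1 would instead produce several output maps per input channel.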
