Mina Gabriel

Reputation: 25090

Tensorflow Keras Conv2D multiple filters

I don't really understand the output of Keras Conv2D. I have a 1x2x3x3 input (I am using channels first) and 2x2x2x2 weights. Can someone help me understand the output feature map? How do the filters convolve over the input to produce the output?


Here is my code:

import os

import tensorflow as tf
import tensorflow.python.util.deprecation as deprecation
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv1D, Conv2D


data = tf.range(3 * 3 * 2)
print(data)
data = tf.reshape(data, (1, 2, 3, 3))
print(data)

print('-------')
e = tf.range(2 * 2 * 2 * 2)
print(e)
e = tf.reshape(e, (2, 2, 2, 2))
print(e)
print('-------')

model = Sequential()
model.add(Conv2D(2, (2, 2), input_shape=(2, 3, 3), data_format='channels_first'))

weights = [e, tf.constant([0.0,0.0])]
model.set_weights(weights)

print(model.get_weights())

yhat = model.predict(data)
print(yhat.shape)
print(yhat)


Upvotes: 2

Views: 838

Answers (1)

javidcf

Reputation: 59701

It is easier to understand if you change the perspective when looking at each operator. You have an input with shape 1x2x3x3. Since you are using data_format='channels_first', that means you have 1 image with 2 channels and size 3x3. You can visualize that image like this:

| [ 0  9] [ 1 10] [ 2 11] |
| [ 3 12] [ 4 13] [ 5 14] |
| [ 6 15] [ 7 16] [ 8 17] |

This is your 3x3 image where each "pixel" has two channels. The filter shape is 2x2x2x2 (kernel height, kernel width, input channels, output channels), meaning a 2x2 filter going from 2 channels to 2 channels. This can be represented like this:

|  0  1 |  |  4   5 |
|  2  3 |  |  6   7 |

|  8  9 |  | 12  13 |
| 10 11 |  | 14  15 |

This is your 2x2 filter, where each filter position contains a 2x2 matrix. The result, with shape 1x2x2x2, is 1 image with 2 channels and size 2x2:

| [456 508] [512 572] |
| [624 700] [680 764] |
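
The pixel and filter views above can be reproduced in plain Python. This is only a sketch with nested lists standing in for the tensors; the names `channels`, `pixels` and `kernel` are illustrative and do not come from the question's code. It assumes the Keras kernel layout (kernel row, kernel column, input channel, output channel).

```python
# Input: values 0..17 stored channels-first as [channel][row][col],
# mirroring tf.reshape(tf.range(18), (1, 2, 3, 3))[0].
flat = list(range(2 * 3 * 3))
channels = [[flat[c * 9 + r * 3: c * 9 + r * 3 + 3] for r in range(3)]
            for c in range(2)]

# Pixel view: pixels[row][col] is the two-element vector of channel values.
pixels = [[[channels[c][r][col] for c in range(2)] for col in range(3)]
          for r in range(3)]

# Kernel: values 0..15 laid out as [kernel_row][kernel_col][in_channel][out_channel].
kflat = list(range(2 * 2 * 2 * 2))
kernel = [[[[kflat[((i * 2 + j) * 2 + cin) * 2 + cout] for cout in range(2)]
            for cin in range(2)]
           for j in range(2)]
          for i in range(2)]

for row in pixels:
    print(row)          # one image row of two-channel pixels per line
print(kernel[0][0])     # the 2x2 matrix at the top-left filter position
```

Indexing `kernel[i][j]` gives the 2x2 matrix that sits at filter position (i, j), which is the matrix each input pixel under that position gets multiplied by.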

To understand how the operation works, I will walk through the computation of the first "pixel" of the output, [456 508]. This output is computed from the first 2x2 window in the input image:

| [ 0  9] [ 1 10] |
| [ 3 12] [ 4 13] |

What you have to do is take each of the "pixels" (the two-element vectors) and multiply them by the matrix in the corresponding position in the filter:

# Top-left
          |  0  1 |
[ 0  9] x |       | = [18 27]
          |  2  3 |
# Top-right
          |  4  5 |
[ 1 10] x |       | = [64 75]
          |  6  7 |
# Bottom-left
          |  8  9 |
[ 3 12] x |       | = [144 159]
          | 10 11 |
# Bottom-right
          | 12 13 |
[ 4 13] x |       | = [230 247]
          | 14 15 |

Then, you simply add all the resulting vectors:

[18 27] + [64 75] + [144 159] + [230 247] = [456 508]
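
The same arithmetic can be checked with a few lines of plain Python. This is a minimal sketch; `vec_mat` is a hypothetical helper written for this illustration, not a TensorFlow or Keras function.

```python
def vec_mat(vec, mat):
    """Row vector times matrix: result[o] = sum over c of vec[c] * mat[c][o]."""
    return [sum(vec[c] * mat[c][o] for c in range(len(vec)))
            for o in range(len(mat[0]))]

# First 2x2 window of the input, one two-channel pixel per position.
window = [[[0, 9], [1, 10]],
          [[3, 12], [4, 13]]]

# Filter: one 2x2 (in_channel x out_channel) matrix per filter position.
kernel = [[[[0, 1], [2, 3]], [[4, 5], [6, 7]]],
          [[[8, 9], [10, 11]], [[12, 13], [14, 15]]]]

# One partial result per filter position, then an element-wise sum.
partials = [vec_mat(window[i][j], kernel[i][j])
            for i in range(2) for j in range(2)]
first_pixel = [sum(p[o] for p in partials) for o in range(2)]

print(partials)     # [[18, 27], [64, 75], [144, 159], [230, 247]]
print(first_pixel)  # [456, 508]
```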

The rest of the outputs are computed in the same way. For example, the output [512 572] is computed by applying the filters to the next image window:

| [ 1 10] [ 2 11] |
| [ 4 13] [ 5 14] |

And so on.
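
Sliding the window over every position can be sketched as a plain nested loop. This is an illustration, not how Keras implements the layer: `conv2d_valid` is a hypothetical helper, and it assumes stride 1, no padding, zero bias and no activation, which matches the defaults used in the question.

```python
def conv2d_valid(pixels, kernel):
    """Stride-1, unpadded convolution in the Keras sense (no kernel flip).

    pixels: [row][col][in_channel]
    kernel: [kernel_row][kernel_col][in_channel][out_channel]
    """
    kh, kw = len(kernel), len(kernel[0])
    n_in, n_out = len(kernel[0][0]), len(kernel[0][0][0])
    out_h = len(pixels) - kh + 1
    out_w = len(pixels[0]) - kw + 1
    out = [[[0] * n_out for _ in range(out_w)] for _ in range(out_h)]
    for r in range(out_h):
        for c in range(out_w):
            for i in range(kh):
                for j in range(kw):
                    for cin in range(n_in):
                        for cout in range(n_out):
                            out[r][c][cout] += (pixels[r + i][c + j][cin]
                                                * kernel[i][j][cin][cout])
    return out

# Same input and filter as above, written directly in the pixel view.
pixels = [[[3 * r + c, 9 + 3 * r + c] for c in range(3)] for r in range(3)]
kernel = [[[[((i * 2 + j) * 2 + cin) * 2 + cout for cout in range(2)]
            for cin in range(2)]
           for j in range(2)]
          for i in range(2)]

output = conv2d_valid(pixels, kernel)
for row in output:
    print(row)  # one row of two-channel output pixels per line
```

The result is in the pixel view (row, column, channel). To compare it with `yhat` from the question, which is channels-first, read off one channel at a time.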

Upvotes: 3
