Reputation: 37474
Following TensorFlow's performance best practices, I am using the NCHW data format, but I am not sure which filter shape to pass to tf.nn.conv2d.
The docs say to use [filter_height, filter_width, in_channels, out_channels] for the NHWC format, but are not clear about what to do with NCHW.
Should the same shape be used?
Upvotes: 5
Views: 2615
Reputation: 19
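Using the same filter shape should work: tf.nn.conv2d expects the filter as [filter_height, filter_width, in_channels, out_channels] regardless of data_format. The only arguments that change are strides, whose entries must follow the order of the input dimensions, and data_format itself. As an example, let's say you wanted your architecture to work with both formats, which is also recommended: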
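import tensorflow as tf

# input  -> Tensor in NCHW format
# filter -> [filter_height, filter_width, in_channels, out_channels] in both branches
if use_nchw:
    result = tf.nn.conv2d(
        input=input,
        filter=filter,  # this keyword is named `filters` in TF 2.x
        strides=[1, 1, stride, stride],
        padding='SAME',  # padding is required; 'SAME' is an assumption here
        data_format='NCHW')
else:
    input_t = tf.transpose(input, [0, 2, 3, 1])  # NCHW to NHWC
    result = tf.nn.conv2d(
        input=input_t,
        filter=filter,
        strides=[1, stride, stride, 1],
        padding='SAME')
    result = tf.transpose(result, [0, 3, 1, 2])  # NHWC to NCHW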
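To make explicit which pieces depend on the layout, here is a minimal pure-Python sketch that needs no TensorFlow. The helper names (conv2d_strides, conv2d_filter_shape, permute) are hypothetical and not part of TensorFlow; they only mirror the strides lists and transpose permutations used in the snippet above.

```python
# Permutations as passed to tf.transpose to switch between the two layouts.
NCHW_TO_NHWC = [0, 2, 3, 1]
NHWC_TO_NCHW = [0, 3, 1, 2]


def conv2d_strides(stride, data_format="NHWC"):
    """Return the 4-element strides list for the given layout (hypothetical helper)."""
    if data_format == "NCHW":
        # Input dimensions are [batch, channels, height, width].
        return [1, 1, stride, stride]
    if data_format == "NHWC":
        # Input dimensions are [batch, height, width, channels].
        return [1, stride, stride, 1]
    raise ValueError("unsupported data_format: %r" % (data_format,))


def conv2d_filter_shape(filter_height, filter_width, in_channels, out_channels):
    """The filter shape takes no data_format argument: it is the same for both."""
    return [filter_height, filter_width, in_channels, out_channels]


def permute(shape, perm):
    """Apply a transpose permutation to a shape: output dim i is input dim perm[i]."""
    return [shape[i] for i in perm]


if __name__ == "__main__":
    nchw_shape = [8, 3, 32, 32]  # example batch of 8 RGB 32x32 images
    print(conv2d_strides(2, "NCHW"))
    print(conv2d_strides(2, "NHWC"))
    print(conv2d_filter_shape(3, 3, 3, 16))
    print(permute(nchw_shape, NCHW_TO_NHWC))
```

Only the strides list and the data_format string differ between the two branches; the filter shape is built the same way in both.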
Upvotes: 0