zplizzi

Reputation: 1006

Convert between NHWC and NCHW in TensorFlow

What is the best way to convert a tensor from NHWC format to NCHW format, and vice versa?

Is there an op specifically that does this, or will I need to use some combination of the split/concat type operations?

Upvotes: 50

Views: 59974

Answers (4)

John Pope

Reputation: 1

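One trick: subclass tf.keras.layers.Conv2D so that it defaults to data_format='channels_first', i.e. the layer takes NCHW input and produces NCHW output without you having to pass the data format every time: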

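import tensorflow as tf

# Conv2D whose default data_format is NCHW ('channels_first') instead of NHWC.
# Note: the `groups` argument requires TF 2.3 or later.
class PyConv2D(tf.keras.layers.Conv2D):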
    def __init__(self, 
                 filters, 
                 kernel_size, 
                 strides=(1, 1), 
                 padding='valid', 
                 data_format='channels_first', 
                 dilation_rate=(1, 1), 
                 groups=1,
                 activation=None, 
                 use_bias=True, 
                 kernel_initializer='glorot_uniform', 
                 bias_initializer='zeros', 
                 kernel_regularizer=None, 
                 bias_regularizer=None, 
                 activity_regularizer=None, 
                 kernel_constraint=None, 
                 bias_constraint=None, 
                 **kwargs):
        super().__init__(
            filters=filters,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            groups=groups,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs
        )
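A caveat worth checking on your setup: TensorFlow's CPU Conv2D kernel has historically supported only NHWC, so a layer built with data_format='channels_first' may raise an error when it actually runs on CPU. Below is a sketch of an alternative that keeps NCHW at the boundaries but runs the convolution in channels_last, using Keras Permute layers (their dimension indices start at 1 and skip the batch axis). The helper name nchw_conv_block is hypothetical, not part of any API.

```python
import tensorflow as tf


def nchw_conv_block(filters, kernel_size):
    """Hypothetical helper: NCHW in, NCHW out, with an NHWC (channels_last) conv inside."""
    return tf.keras.Sequential([
        # Permute indices exclude the batch axis and are 1-based:
        # (C, H, W) -> (H, W, C)
        tf.keras.layers.Permute((2, 3, 1)),
        tf.keras.layers.Conv2D(filters, kernel_size, data_format="channels_last"),
        # (H, W, C) -> (C, H, W)
        tf.keras.layers.Permute((3, 1, 2)),
    ])


block = nchw_conv_block(filters=8, kernel_size=3)
x = tf.zeros([2, 3, 32, 32])  # NCHW batch: N=2, C=3, H=W=32
y = block(x)
print(y.shape)  # (2, 8, 30, 30) with the default 'valid' padding
```

The two extra transposes cost some memory traffic, but the model's interface stays NCHW regardless of which device executes the convolution.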

Upvotes: 0

Sovik Gupta

Reputation: 157

For recent TF2 models, the tf2onnx package can do this for you while exporting to ONNX. tf2onnx.convert.from_keras(..., inputs_as_nchw=[...]) takes a list of input names and converts those inputs from NHWC to NCHW in the exported .onnx model; the command-line converter, which also accepts .pb graphs, has the equivalent --inputs-as-nchw option. https://github.com/onnx/tensorflow-onnx/blob/e896723e410a59a600d1a73657f9965a3cbf2c3b/tf2onnx/convert.py#L408

Upvotes: 0

Olivier Moindrot

Reputation: 28208

All you need to do is permute the dimensions from NHWC to NCHW (or the reverse) with tf.transpose; no split/concat is needed.

Knowing what each letter means helps:

  • N: number of images in the batch
  • H: height of the image
  • W: width of the image
  • C: number of channels of the image (e.g. 3 for RGB, 1 for grayscale)

From NHWC to NCHW

The image shape is (N, H, W, C) and we want the output to have shape (N, C, H, W). Therefore we apply tf.transpose with a well-chosen permutation perm.

Dimension i of the returned tensor corresponds to dimension perm[i] of the input:

perm[0] = 0  # output dimension 0 will be 'N', which was dimension 0 in the input
perm[1] = 3  # output dimension 1 will be 'C', which was dimension 3 in the input
perm[2] = 1  # output dimension 2 will be 'H', which was dimension 1 in the input
perm[3] = 2  # output dimension 3 will be 'W', which was dimension 2 in the input
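The rule "output dimension i takes input dimension perm[i]" means perm[i] is simply the position of the i-th output letter in the input layout string. A small sketch of that derivation; layout_perm is a hypothetical helper, not a TensorFlow function:

```python
def layout_perm(src, dst):
    """Hypothetical helper: the perm for tf.transpose that turns layout `src` into `dst`.

    Output dimension i takes input dimension perm[i], i.e. the index
    of dst[i] within src.
    """
    if sorted(src) != sorted(dst):
        raise ValueError(f"{src!r} and {dst!r} are not permutations of each other")
    return [src.index(axis) for axis in dst]


print(layout_perm("NHWC", "NCHW"))  # [0, 3, 1, 2]
print(layout_perm("NCHW", "NHWC"))  # [0, 2, 3, 1]
```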

In practice:

import tensorflow as tf

images_nhwc = tf.zeros([8, 200, 300, 3])  # input batch of 8 NHWC images (TF2, eager)
out = tf.transpose(images_nhwc, [0, 3, 1, 2])
print(out.shape)  # (8, 3, 200, 300)
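A quick sanity check that the permutation moves the data correctly and not just the shape: on a small tensor with distinct values (sizes chosen arbitrarily for illustration), element [n, c, h, w] of the result should equal element [n, h, w, c] of the input.

```python
import numpy as np
import tensorflow as tf

# Small NHWC batch filled with 0, 1, 2, ... so every element is identifiable.
n, h, w, c = 2, 4, 5, 3
images_nhwc = tf.reshape(tf.range(n * h * w * c), [n, h, w, c])
out = tf.transpose(images_nhwc, [0, 3, 1, 2])  # NHWC -> NCHW

src = images_nhwc.numpy()
dst = out.numpy()
# out[i, ch, y, x] must be the same element as images_nhwc[i, y, x, ch].
assert all(
    dst[i, ch, y, x] == src[i, y, x, ch]
    for i in range(n) for ch in range(c) for y in range(h) for x in range(w)
)
# Equivalently: channel 1 of the NCHW tensor is the slice src[..., 1].
np.testing.assert_array_equal(dst[:, 1], src[..., 1])
print(out.shape)  # (2, 3, 4, 5)
```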

From NCHW to NHWC

The image shape is (N, C, H, W) and we want the output to have shape (N, H, W, C). Therefore we apply tf.transpose with a well-chosen permutation perm.

Dimension i of the returned tensor corresponds to dimension perm[i] of the input:

perm[0] = 0  # output dimension 0 will be 'N', which was dimension 0 in the input
perm[1] = 2  # output dimension 1 will be 'H', which was dimension 2 in the input
perm[2] = 3  # output dimension 2 will be 'W', which was dimension 3 in the input
perm[3] = 1  # output dimension 3 will be 'C', which was dimension 1 in the input
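Note that [0, 2, 3, 1] is exactly the inverse of the NHWC to NCHW permutation [0, 3, 1, 2], so one can be computed from the other. A sketch using tf.math.invert_permutation, plus a round-trip check on random data:

```python
import tensorflow as tf

to_nchw = tf.constant([0, 3, 1, 2])
# The NCHW -> NHWC permutation is the inverse of the NHWC -> NCHW one.
to_nhwc = tf.math.invert_permutation(to_nchw)
print(to_nhwc.numpy().tolist())  # [0, 2, 3, 1]

x = tf.random.uniform([2, 4, 5, 3])  # arbitrary NHWC batch
round_trip = tf.transpose(tf.transpose(x, to_nchw), to_nhwc)
print(bool(tf.reduce_all(round_trip == x)))  # True
```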

In practice:

import tensorflow as tf

images_nchw = tf.zeros([8, 3, 200, 300])  # input batch of 8 NCHW images (TF2, eager)
out = tf.transpose(images_nchw, [0, 2, 3, 1])
print(out.shape)  # (8, 200, 300, 3)
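In TF2, tf.placeholder no longer exists; if you want to see the static output shape for an unknown batch size, one option is a tf.function with an input_signature whose batch dimension is None. A sketch (the function name nchw_to_nhwc is just for illustration):

```python
import tensorflow as tf

# Traced function whose batch size is left unknown (None), similar to a TF1 placeholder.
@tf.function(input_signature=[tf.TensorSpec([None, 3, 200, 300], tf.float32)])
def nchw_to_nhwc(images):
    return tf.transpose(images, [0, 2, 3, 1])

# Symbolic output of the traced graph: the batch dimension stays unknown.
graph_out = nchw_to_nhwc.get_concrete_function().structured_outputs
print(graph_out.shape)  # (None, 200, 300, 3)

# Calling it eagerly with a concrete batch works as usual.
print(nchw_to_nhwc(tf.zeros([4, 3, 200, 300])).shape)  # (4, 200, 300, 3)
```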

Upvotes: 88

Rishi

Reputation: 17

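To make Keras layers assume NHWC ('channels_last') rather than NCHW by default, set the global image data format. Note that this does not transpose any existing tensor; it only changes the default data_format of layers created afterwards.

from keras import backend
backend.set_image_data_format('channels_last')  # use 'channels_first' for NCHW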
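A sketch showing what the setting actually affects: a Conv2D created without an explicit data_format picks up the global default. Because this is global state, the example restores the usual 'channels_last' default at the end.

```python
import tensorflow as tf

tf.keras.backend.set_image_data_format("channels_first")  # NCHW default
print(tf.keras.backend.image_data_format())  # channels_first

conv = tf.keras.layers.Conv2D(8, 3)  # no explicit data_format
print(conv.data_format)  # channels_first

# Global state: restore the usual NHWC default when done.
tf.keras.backend.set_image_data_format("channels_last")
```

Tensors you already have still need tf.transpose if their layout does not match what the layers expect.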

Upvotes: 0
