f9786e94
f9786e94

Reputation: 11

Group Normalization and Weight Standardization in Keras

I am implementing weight standardization and Group normalization in tensorflow using keras on a resnet 50 following the original paper https://arxiv.org/pdf/1903.10520v1.pdf.

While weight standardization works on all convolutional layers, there seems to be an issue with group normalization after some Conv2d layers. In those cases no loss reduction or accuracy gain is achieved. I train the returned model with model.fit on CIFAR10 with different batch sizes (16-512).The positions of group norm which cause issues are marked. Can anyone provide me with some insight why this might be a problem?

!pip install -q -U tensorflow-addons

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.layers import Conv2D, Activation, BatchNormalization, Dense, Add, Concatenate
from tensorflow.keras.layers import ZeroPadding2D, Input, MaxPooling2D, AveragePooling2D, Flatten
from tensorflow.keras import activations
from tensorflow.keras import Model
from tensorflow_addons.layers import GroupNormalization

def ws_reg(kernel):
    kernel_mean = tf.math.reduce_mean(kernel, axis=[0, 1, 2], keepdims=True, name='kernel_mean')
    kernel = kernel - kernel_mean
    kernel_std = tf.keras.backend.std(kernel, axis=[0, 1, 2], keepdims=True)
    kernel = kernel / (kernel_std + 1e-5)
    #return kernel
    
def res_identity(x, filters): 
  #renet block where dimension doesnot change.
  #The skip connection is just simple identity conncection
  #we will have 3 blocks and then input will be added

  x_skip = x # this will be used for addition with the residual block 
  f1, f2 = filters

  #first block 
  x = Conv2D(f1, kernel_size=(1, 1), strides=(1, 1), padding='valid', kernel_regularizer=ws_reg)(x)
  x = GroupNormalization(groups=16, axis=-1)(x)
  #x = BatchNormalization()(x)
  x = Activation(activations.relu)(x)

  #second block # bottleneck (but size kept same with padding)
  x = Conv2D(f1, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_regularizer=ws_reg)(x)
  x = GroupNormalization(groups=16, axis=-1)(x)
  #x = BatchNormalization()(x)
  x = Activation(activations.relu)(x)

  # third block activation used after adding the input
  x = Conv2D(f2, kernel_size=(1, 1), strides=(1, 1), padding='valid', kernel_regularizer=ws_reg)(x)
  x = GroupNormalization(groups=16, axis=-1)(x) ###ISSUE
  #x = BatchNormalization()(x)

  # add the input 
  x = Add()([x, x_skip])
  x = Activation(activations.relu)(x)

  return x


def res_conv(x, s, filters):
  '''
  here the input size changes''' 
  x_skip = x
  f1, f2 = filters

  # first block
  x = Conv2D(f1, kernel_size=(1, 1), strides=(s, s), padding='valid', kernel_regularizer=ws_reg)(x)
  # when s = 2 then it is like downsizing the feature map
  x = GroupNormalization(groups=16, axis=-1)(x)
  #x = BatchNormalization()(x)
  x = Activation(activations.relu)(x)

  # second block
  x = Conv2D(f1, kernel_size=(3, 3), strides=(1, 1), padding='same', kernel_regularizer=ws_reg)(x)
  x = GroupNormalization(groups=16, axis=-1)(x)
  #x = BatchNormalization()(x)
  x = Activation(activations.relu)(x)

  #third block
  x = Conv2D(f2, kernel_size=(1, 1), strides=(1, 1), padding='valid', kernel_regularizer=ws_reg)(x)
  x = GroupNormalization(groups=16, axis=-1)(x) ###ISSUE
  #x = BatchNormalization()(x)

  # shortcut 
  x_skip = Conv2D(f2, kernel_size=(1, 1), strides=(s, s), padding='valid', kernel_regularizer=ws_reg)(x_skip)
  x_skip = GroupNormalization(groups=16, axis=-1)(x_skip) ###ISSUE
  #x_skip = BatchNormalization()(x_skip)

  # add 
  x = Add()([x, x_skip])
  x = Activation(activations.relu)(x)

  return x

  
def resnet50(train_im):

  input_im = Input(shape=(train_im[0], train_im[1], train_im[2])) # cifar 10 images size
  x = ZeroPadding2D(padding=(3, 3))(input_im)

  # 1st stage
  # here we perform maxpooling, see the figure above

  x = Conv2D(64, kernel_size=(7, 7), strides=(2, 2), kernel_regularizer=ws_reg)(x)
  x = GroupNormalization(groups=16, axis=-1)(x) ###ISSUE
  #x = BatchNormalization()(x)
  x = Activation(activations.relu)(x)
  x = MaxPooling2D((3, 3), strides=(2, 2))(x)

  #2nd stage 
  # from here on only conv block and identity block, no pooling

  x = res_conv(x, s=1, filters=(64, 256))
  x = res_identity(x, filters=(64, 256))
  x = res_identity(x, filters=(64, 256))

  # 3rd stage

  x = res_conv(x, s=2, filters=(128, 512))
  x = res_identity(x, filters=(128, 512))
  x = res_identity(x, filters=(128, 512))
  x = res_identity(x, filters=(128, 512))

  # 4th stage

  x = res_conv(x, s=2, filters=(256, 1024))
  x = res_identity(x, filters=(256, 1024))
  x = res_identity(x, filters=(256, 1024))
  x = res_identity(x, filters=(256, 1024))
  x = res_identity(x, filters=(256, 1024))
  x = res_identity(x, filters=(256, 1024))

  # 5th stage

  x = res_conv(x, s=2, filters=(512, 2048))
  x = res_identity(x, filters=(512, 2048))
  x = res_identity(x, filters=(512, 2048))

  # ends with average pooling and dense connection

  x = AveragePooling2D((2, 2), padding='same')(x)

  x = Flatten()(x)
  x = Dense(10, activation='softmax', kernel_initializer='he_normal')(x) #multi-class

  # define the model 

  model = Model(inputs=input_im, outputs=x, name='Resnet50')

  return model

Upvotes: 1

Views: 1669

Answers (1)

Jamil
Jamil

Reputation: 11

It seems to me, that passing ws_reg function to kernel_regularizer in Conv2D is incorrect way of doing WeightStandartization.

According to TF and Keras docs output of kernel_regularizer is added to loss, not applied to kernel. Proper way can look like that:

class WSConv2D(tf.keras.layers.Conv2D):
    def __init__(self, *args, **kwargs):
        super(WSConv2D, self).__init__(kernel_initializer="he_normal", *args, **kwargs)

    def standardize_weight(self, weight, eps):

        mean = tf.math.reduce_mean(weight, axis=(0, 1, 2), keepdims=True)
        var = tf.math.reduce_variance(weight, axis=(0, 1, 2), keepdims=True)
        fan_in = np.prod(weight.shape[:-1])
        gain = self.add_weight(
            name="gain",
            shape=(weight.shape[-1],),
            initializer="ones",
            trainable=True,
            dtype=self.dtype,
        )
        scale = (
            tf.math.rsqrt(
                tf.math.maximum(var * fan_in, tf.convert_to_tensor(eps, dtype=self.dtype))
            )
            * gain
        )
        return weight * scale - (mean * scale)

    def call(self, inputs, eps=1e-4):
        self.kernel.assign(self.standardize_weight(self.kernel, eps))
        return super().call(inputs)

Upvotes: 1

Related Questions