Michael Moretti
Michael Moretti

Reputation: 257

ValueError in model subclassing with tensorflow 2

I'm trying to implement a WideResnet using Model subclassing in keras. I cannot understand what's wrong in my code:

class ResidualBlock(layers.Layer):
  def __init__(self, filters, kernel_size, dropout, dropout_percentage, strides=1, **kwargs):
    super(ResidualBlock, self).__init__(**kwargs)
          
    self.conv_1 = layers.Conv2D(filters, (1, 1), strides=strides)
    self.bn_1 = layers.BatchNormalization()
    self.rel_1 = layers.ReLU()
    self.conv_2 = layers.Conv2D(filters, kernel_size, padding="same", strides=strides)
    self.dropout = layers.Dropout(dropout_percentage)
    self.bn_2 = layers.BatchNormalization()
    self.rel_2 = layers.ReLU()
    self.conv_3 = layers.Conv2D(filters, kernel_size, padding="same")
    
    self.add = layers.Add()
    self.dropout = dropout
    self.strides = strides

  def call(self, inputs):
    x = inputs

    if self.strides > 1:
      x = self.conv_1(x)
    res_x = self.bn_1(x)
    res_x = self.rel_1(x)
    res_x = self.conv_2(x)
    if self.dropout:
      res_x = self.dropout(x)
    res_x = self.bn_2(x)
    res_x = self.rel_2(x)
    res_x = self.conv_3(x)
    inputs = self.add([x, res_x])
    return inputs

class WideResidualNetwork(models.Model):
  def __init__(self, input_shape, n_classes, d, k, kernel_size=(3, 3), dropout=False, dropout_percentage=0.3, strides=1, **kwargs):
    
    super(WideResidualNetwork, self).__init__(**kwargs)

    if (d-4)%6 != 0:
      raise ValueError('Please choose a correct depth!')

    self.rel_1 = layers.ReLU()
    self.conv_1 = layers.Conv2D(16, (3, 3), padding='same')
    self.conv_2 = layers.Conv2D(16*k, (1, 1))
    self.dense = layers.Dense(n_classes)

    self.dropout = dropout
    self.dropout_percentage = dropout_percentage
    self.N = int((d - 4) / 6)
    self.k = k
    self.d = d
    self.kernel_size = kernel_size

  def build(self, input_shape):
    self.bn_1 = layers.BatchNormalization(input_shape=input_shape)

  def call(self, inputs):
    x = self.bn_1(inputs)
    x = self.rel_1(x)
    x = self.conv_1(x)
    x = self.conv_2(x)

    for _ in range(self.N):
      x = ResidualBlock(16*self.k, self.kernel_size, self.dropout, self.dropout_percentage)(x)
    
    x = ResidualBlock( 32*self.k, self.kernel_size, self.dropout, self.dropout_percentage, strides=2)(x)

    for _ in range(self.N-1):
      x = ResidualBlock( 32*self.k, self.kernel_size, self.dropout, self.dropout_percentage)(x)

    x = ResidualBlock( 64*self.k, self.kernel_size, self.dropout, self.dropout_percentage, strides=2)(x)
    
    for _ in range(self.N-1):
      x = ResidualBlock( 64*self.k, self.kernel_size, self.dropout, self.dropout_percentage)(x)
    
    x = layers.GlobalAveragePooling2D()(x)
    x = self.dense(x)
    x = layers.Activation("softmax")(x)

    return x

When i try to fit the model in this way:

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
model = WideResidualNetwork(x_train[0].shape, 10, 28, 1)
x_train, x_test = x_train/255. , x_test/255.
model = WideResidualNetwork(x_train[0].shape, 10, 28, 1)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

epochs = 40
batch_size = 64
validation_split = 0.2
h = model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_split=validation_split)

I got the following error:

...
 <ipython-input-26-61c1bdb3546c>:31 call  *
        x = ResidualBlock(16*self.k, self.kernel_size, self.dropout, self.dropout_percentage)(x)
    <ipython-input-9-3fea1e77cb6e>:23 call  *
        res_x = self.bn_1(x)
...
ValueError: tf.function-decorated function tried to create variables on non-first call.

So I didn't understand where is the problem, I also tried to move the initialization into the build, but without results, the error persists. Probably I have some gaps in my knowledge Thank you in advance

Upvotes: 2

Views: 116

Answers (1)

Federico A.
Federico A.

Reputation: 266

You are initializing ResidualBlocks, GlobalAveragePooling2D, and Activation layers into the call method. Try to move them into the init, as you did for other layers, and it shouldn't give you that error.

Upvotes: 2

Related Questions