guillaumefrd

Reputation: 127

Memory leak with Keras Lambda layer

I need to split the channels of a Tensor to apply different normalizations for each split. To do so, I use the Lambda layer from Keras:

# split the channels in two (first part for IN, second for BN)
x_in = Lambda(lambda x: x[:, :, :, :split_index])(x)
x_bn = Lambda(lambda x: x[:, :, :, split_index:])(x)

# apply IN and BN on their respective group of channels
x_in = InstanceNormalization(axis=3)(x_in)
x_bn = BatchNormalization(axis=3)(x_bn)

# concatenate outputs of IN and BN
x = Concatenate(axis=3)([x_in, x_bn])
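
The indexing used in the two Lambda layers can be checked outside Keras. The following is a minimal NumPy sketch, assuming `split_index = 16` (inferred from the 16/16 split in the summary) and a small made-up input shape. It mirrors only the slicing and concatenation, not the normalization layers or the leak itself:

```python
import numpy as np

# Hypothetical stand-in for the conv1 output: batch of 2, 4x4 spatial, 32 channels.
x = np.random.rand(2, 4, 4, 32).astype("float32")
split_index = 16  # assumed value, matching the 16/16 split in the summary

# Same slicing expressions as the two Lambda layers.
x_in = x[:, :, :, :split_index]
x_bn = x[:, :, :, split_index:]

# Concatenating on axis 3 restores the original channel layout.
x_out = np.concatenate([x_in, x_bn], axis=3)

print(x_in.shape, x_bn.shape, x_out.shape)
```

Since the round trip reproduces the input exactly, the split itself is not the problem. The question is only about how the slicing is wrapped as a layer.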

Everything works as expected (see model.summary() below), but the RAM usage keeps increasing at each iteration, indicating a memory leak.

Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 832, 832, 1)  0
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 832, 832, 32) 320         input_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 832, 832, 16) 0           conv1[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 832, 832, 16) 0           conv1[0][0]
__________________________________________________________________________________________________
instance_normalization_1 (Insta (None, 832, 832, 16) 32          lambda_1[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 832, 832, 16) 64          lambda_2[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 832, 832, 32) 0           instance_normalization_1[0][0]
                                                                 batch_normalization_1[0][0]
__________________________________________________________________________________________________

I am sure the leak comes from the Lambda layers, because I tried another strategy: instead of splitting, I apply the two normalizations independently on all the channels and then add the feature maps together. I didn't experience any memory leak with this code:

# apply IN and BN on the input tensor independently
x_in = InstanceNormalization(axis=3)(x)
x_bn = BatchNormalization(axis=3)(x)

# add the feature maps output by IN and BN
x = Add()([x_in, x_bn])

Any idea how to resolve this memory leak? I am using Keras 2.2.4 with TensorFlow 1.15.3, and I can't upgrade to TF 2 or tf.keras for now.

Upvotes: 3

Views: 329

Answers (2)

guillaumefrd

Reputation: 127

Thibault Bacqueyrisses' answer was right: the memory leak disappeared with a custom layer!

Here is my implementation:

import keras


class Crop(keras.layers.Layer):
    def __init__(self, dim, start, end, **kwargs):
        """
        Slice the tensor along axis `dim`, keeping what is between start
        and end.
        Args
            dim   (int)   : index of the axis to slice (0 is the batch axis)
            start (int)   : index of where to start the cropping
            end   (int)   : index of where to stop the cropping
        """
        super(Crop, self).__init__(**kwargs)
        self.dimension = dim
        self.start = start
        self.end = end

    def call(self, inputs):
        # keep every axis before `dim` whole, then crop axis `dim`
        index = (slice(None),) * self.dimension + (slice(self.start, self.end),)
        return inputs[index]

    def compute_output_shape(self, input_shape):
        # only the cropped axis changes size
        output_shape = list(input_shape)
        output_shape[self.dimension] = self.end - self.start
        return tuple(output_shape)

    def get_config(self):
        config = {
            'dim': self.dimension,
            'start': self.start,
            'end': self.end,
        }
        base_config = super(Crop, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
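
For Keras shape inference to stay consistent, the shape a cropping layer reports should match the slice it actually takes. Below is a rough NumPy sketch of that check. The helpers `crop_index` and `cropped_shape` are hypothetical (they are not part of Keras), and the sketch assumes `start` and `end` are non-negative and within bounds:

```python
import numpy as np


def crop_index(dim, start, end):
    """Index tuple that keeps start:end on axis `dim` and every other axis whole."""
    return (slice(None),) * dim + (slice(start, end),)


def cropped_shape(input_shape, dim, start, end):
    """Shape after cropping: only the entry for axis `dim` changes."""
    shape = list(input_shape)
    shape[dim] = end - start
    return tuple(shape)


# Made-up tensor shape: batch of 2, 8x8 spatial, 32 channels.
x = np.zeros((2, 8, 8, 32))
for dim, start, end in [(3, 0, 16), (3, 16, 32), (1, 2, 6)]:
    actual = x[crop_index(dim, start, end)].shape
    # The predicted shape must agree with what the slice really produces.
    assert actual == cropped_shape(x.shape, dim, start, end)
    print(dim, actual)
```

The first two cases correspond to the two channel groups from the question. The third crops a non-final axis, which is where a shape function that only adjusts the last dimension would disagree with the slice.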

Upvotes: 4

Thibault Bacqueyrisses

Reputation: 2331

You may want to consider using a custom layer instead of a Lambda layer.
It's possible that the Keras Lambda layer has some issues.

Upvotes: 1
