sutradhar

Reputation: 93

Keras freeze specific weights with mask

I am new to Keras. I want to implement a layer in which not all the weights update. For example, in the following code, I want the dilation layer to update in such a way that some center weights are never updated. Say the shape of each feature map (out of 1024) in the dilation layer is 448x448; then an 8x8 block at the center of every feature map should never be updated, i.e. the 8x8 block acts as a (non-trainable) mask on the feature maps.

input_layer=Input(shape=(896,896,3))
new_layer = Conv2D(32, kernel_size=(3,3), padding="same", activation='relu', kernel_initializer='he_normal')(input_layer)
new_layer = MaxPooling2D(pool_size=(2, 2), strides=(2,2), padding='same', data_format=None)(new_layer)
new_layer = Conv2D(64, kernel_size=(3,3), padding='same', activation='relu', kernel_initializer='he_normal')(new_layer)
new_layer = Conv2D(1024, kernel_size=(7,7), dilation_rate=8, padding="same", activation='relu', kernel_initializer='he_normal', name='dialation')(new_layer)
new_layer = Conv2D(32, kernel_size=(1,1), padding="same", activation='relu', kernel_initializer='he_normal')(new_layer)
new_layer = Conv2D(32, kernel_size=(1,1), padding="same", activation='relu', kernel_initializer='he_normal')(new_layer)

model = Model(input_layer, new_layer)

I was trying with Keras's custom layer [link], but it was difficult for me to understand. Could anyone please help?

UPDATE: I added the following figure for a better understanding. The dilation layer contains 1024 feature maps. I want the middle region of each feature map to be non-trainable (static).

[Figure: the dilation layer's feature maps, with the static central region marked]
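Roughly, the mask I have in mind (my sketch; each feature map is 448x448, so the central 8x8 block spans indices 220:228 in both spatial dimensions) would look like this:

import numpy as np

# 1 marks the central 8x8 block that should stay static, 0 everywhere else
mask = np.zeros((448, 448), dtype='float32')
mask[220:228, 220:228] = 1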

Upvotes: 3

Views: 1228

Answers (2)

Priyanka Srs

Reputation: 21

I am working on clustering weights, freezing specific clusters, and then training the network.

I am trying to freeze specific weights in this network using the example above, but I am not sure how to set the shape of the mask and the custom layer in run_certain_weights() (see the mask-shape sketch after the code).

This is the code I'm using:

from keras.layers import Dense, Flatten, Lambda
from keras.utils import to_categorical
from keras.models import Sequential, load_model
from keras.datasets import mnist
from keras.losses import categorical_crossentropy
from keras.backend import stop_gradient
import numpy as np

def load_data():
    
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    x_train /= 255
    x_test /= 255
    y_train = to_categorical(y_train, num_classes=10)
    y_test = to_categorical(y_test, num_classes=10)
    
    return x_train, y_train, x_test, y_test

def run():
    x_train, y_train, x_test, y_test = load_data()

    model = Sequential([Flatten(input_shape=(28, 28))])

    layer = Dense(300, name='dense1', activation='relu')
    layer.trainable=True
    model.add(layer)
    layer2 = Dense(100, name='dense2', activation='relu')
    layer2.trainable=False
    model.add(layer2)
    layer3 = Dense(10, name='dense3', activation='softmax')
    model.add(layer3)

    model.compile(loss=categorical_crossentropy, optimizer='Adam', metrics=['accuracy'])


    print(model.summary())
    print("x_train.shape():",x_train.shape)
    print("y_train.shape()",y_train.shape)

    model.fit(x_train, y_train, epochs=5, verbose=2)
    print(model.evaluate(x_test, y_test))
    
    return model

def stopBackprop(x):
    # pass the masked positions through stop_gradient so no gradient flows back through them;
    # `mask` is set globally in run_certain_weights() before this function is called
    stopped = stop_gradient(x)
    return x*(1-mask) + stopped*mask

def run_certain_weights():
    global mask  # make the mask visible to stopBackprop() above

    x_train, y_train, x_test, y_test = load_data()
    model = Sequential([Flatten(input_shape=(28, 28))])
    mask = np.zeros((300,), dtype='float32')  # one entry per dense1 unit
    print(mask.shape)
    mask[220:228] = 1  # block gradients for units 220-227

    layer = Dense(300, name='dense1', activation='relu')
    layer.trainable=False
    model.add(layer)

    # apply the gradient mask before the dense2 layer
    model.add(Lambda(stopBackprop))

    layer2 = Dense(300, name='dense2', activation='relu')
    layer2.trainable=True
    model.add(layer2)
    layer3 = Dense(10, name='dense3', activation='softmax')
    model.add(layer3)

    model.compile(loss=categorical_crossentropy, optimizer='Adam',metrics = ['accuracy'])


    print(model.summary())
    print("x_train.shape():",x_train.shape)
    print("y_train.shape()",y_train.shape)

    model.fit(x_train, y_train, epochs=5, verbose=2)
    print(model.evaluate(x_test, y_test))
    return model


def freeze(model):
    x_train, y_train, x_test, y_test = load_data()

    name = 'dense2'

    weightsAndBias = model.get_layer(name=name).get_weights()

    # freeze the weights of this layer and recompile so the change takes effect
    model.get_layer(name=name).trainable = False
    model.compile(loss=categorical_crossentropy, optimizer='Adam', metrics=['accuracy'])

    # record the weights before retrain
    weights_before = weightsAndBias[0]
    # retrain
    print("x_train.shape():",x_train.shape)
    print("y_train.shape()",y_train.shape)

    model.fit(x_train, y_train, verbose=2, epochs=1)
    weights_after = model.get_layer(name=name).get_weights()[0]

    if (weights_before == weights_after).all():
        print('the weights did not change!!!')
    else:
        print('the weights changed!!!!')

if __name__ == '__main__':
    
    model = run()
    freeze(model)

    model = run_certain_weights()
    freeze(model)
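For reference, here is how the mask shape lines up with the layer it masks (my own sketch, not part of the code above): the mask must match the output of the layer whose activations it multiplies.

import numpy as np

# dense1 has 300 units, so its mask is a length-300 vector (units 220-227 frozen)
dense_mask = np.zeros((300,), dtype='float32')
dense_mask[220:228] = 1

# the convolutional case from the question needs a spatial mask that
# broadcasts over the batch and channel dimensions
conv_mask = np.zeros((1, 448, 448, 1), dtype='float32')
conv_mask[:, 220:228, 220:228] = 1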

Upvotes: 0

Daniel Möller

Reputation: 86600

Use this mask for both cases:

mask = np.zeros((1,448,448,1), dtype='float32')  # broadcasts over batch and channels
mask[:,220:228,220:228] = 1

Replacing part of the feature

If you replace part of the feature with constant values, that part of the feature will be static, but it will still participate in backpropagation (the weights are still multiplied and summed for this part of the image, so the connection remains).

constant = 0 (this zeroes out the kernel's contribution, but not the bias)

def replace(x):
    return x*(1-mask) + constant*mask

#before the dilation layer
new_layer=Lambda(replace)(new_layer) 
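As a quick sanity check (my own sketch, not from the original answer), you can run a dummy batch through a tiny model containing only the replace Lambda and confirm that the central block is pinned to the constant while everything else passes through unchanged:

import numpy as np
from keras.layers import Input, Lambda
from keras.models import Model

inp = Input(shape=(448, 448, 64))
out = Lambda(replace)(inp)
check = Model(inp, out)

x = np.random.rand(1, 448, 448, 64).astype('float32')
y = check.predict(x)
print(np.allclose(y[:, 220:228, 220:228], 0))  # True: central block replaced by constant
print(np.allclose(y[:, :220], x[:, :220]))     # True: region outside the mask unchanged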

Keeping the feature value, but stopping backpropagation

Here, the weights of the dilation layer and the layers after it will be updated normally, but the weights before the dilation layer will not receive any influence from the central region.

def stopBackprop(x):
    stopped = K.stop_gradient(x)  # from keras import backend as K
    return x*(1-mask) + stopped*mask

#before the dilation layer
new_layer=Lambda(stopBackprop)(new_layer) 
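Putting this together with the model from the question (a sketch under my own assumptions: the Lambda is inserted right before the dilation layer, where the feature maps are 448x448 after the 2x2 max pooling):

import numpy as np
from keras import backend as K
from keras.layers import Input, Conv2D, MaxPooling2D, Lambda
from keras.models import Model

# 448x448 spatial mask, broadcast over batch and channels; 1 marks the frozen 8x8 center
mask = np.zeros((1, 448, 448, 1), dtype='float32')
mask[:, 220:228, 220:228] = 1

def stopBackprop(x):
    stopped = K.stop_gradient(x)
    return x*(1-mask) + stopped*mask

input_layer = Input(shape=(896, 896, 3))
new_layer = Conv2D(32, kernel_size=(3,3), padding='same', activation='relu', kernel_initializer='he_normal')(input_layer)
new_layer = MaxPooling2D(pool_size=(2, 2), strides=(2,2), padding='same')(new_layer)
new_layer = Conv2D(64, kernel_size=(3,3), padding='same', activation='relu', kernel_initializer='he_normal')(new_layer)
new_layer = Lambda(stopBackprop)(new_layer)  # block gradients through the central region
new_layer = Conv2D(1024, kernel_size=(7,7), dilation_rate=8, padding='same', activation='relu', kernel_initializer='he_normal', name='dialation')(new_layer)
new_layer = Conv2D(32, kernel_size=(1,1), padding='same', activation='relu', kernel_initializer='he_normal')(new_layer)
new_layer = Conv2D(32, kernel_size=(1,1), padding='same', activation='relu', kernel_initializer='he_normal')(new_layer)

model = Model(input_layer, new_layer)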

Upvotes: 1
