Freeze specific filters of a layer for pruning in Keras

During the pruning step, I zero out some filters of a Depthwise convolution. After doing this, I need to retrain the network but those weights that have been zeroed out (I have the list of indexes) should not be updated during the training, their values need to remain equal to zero. So if I have 150 filters in the Depthwise layer (I'm not counting the bias), is there any way to freeze just a subset of them?

For example, the weights of the filters are x:

x=model.layers[4].get_weights()[0]

And x is an ndarray of 150 numbers. Ideally, if I have the list of zeroed-out indices pruned_filters, I would like to do something like:

x[pruned_filters].trainable = False # I know this is wrong, it's just an example

Or move them to the layer's non_trainable_weights.
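
For illustration, a minimal sketch assuming TensorFlow 2.x Keras and a DepthwiseConv2D layer with 150 input channels and depth_multiplier=1 (the model, input size and index list are hypothetical). In that case get_weights()[0] is a 4-D kernel of shape (kernel_h, kernel_w, in_channels, depth_multiplier), so each of the 150 filters is a slice along the third axis, and zeroing them by index looks like this:

```python
import numpy as np
from tensorflow.keras.layers import Input, DepthwiseConv2D
from tensorflow.keras.models import Model

# Hypothetical model: 150 input channels, one depthwise filter per channel
inp = Input((32, 32, 150))
dw = DepthwiseConv2D(kernel_size=(3, 3), depth_multiplier=1)
model = Model(inp, dw(inp))

kernel, bias = dw.get_weights()
kernel = kernel.copy()  # make sure we modify a writable copy
# depthwise kernel shape: (kernel_h, kernel_w, in_channels, depth_multiplier)
print(kernel.shape)  # (3, 3, 150, 1)

# Hypothetical list of pruned filter (channel) indices
pruned_filters = [1, 5, 9]

# Zero out the whole 3x3 kernel of each pruned channel (bias left untouched)
kernel[:, :, pruned_filters, :] = 0
dw.set_weights([kernel, bias])
```

This only sets the values once: during retraining the optimizer updates these weights like any others, which is exactly the problem described above.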

Answers (1)

Marco Cerliani

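You can't freeze only specific filters: trainable is a per-layer setting. What you can do, if you consider it worthwhile, is set the pruned filters to 0 and then freeze the layer, but then all the other filters of that layer become non-trainable as well. Here is an example: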

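from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense
from tensorflow.keras.models import Model

inp = Input((10, 10, 3))
c = Conv2D(32, kernel_size=(3, 3),
           activation='relu')
f = Flatten()
d = Dense(10, activation='softmax')

x = c(inp)
x = f(x)
out = d(x)
model = Model(inp, out)
model.summary()  # summary() prints by itself and returns None

# model.fit(.....)

# zero out the pruned filters (the last kernel axis indexes the output filters)
pruned_filters = [1, 5, 9]
w, b = c.get_weights()
w[:, :, :, pruned_filters] = 0
c.set_weights([w, b])

# freeze the conv layer (model.layers[0] is the InputLayer, model.layers[1] is c);
# this freezes ALL 32 filters, not only the pruned ones
model.layers[1].trainable = False
# re-compile so that the change to `trainable` is taken into account
# model.compile(.....)

# model.fit(.....)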
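
The question also asks about moving weights to non_trainable_weights. A minimal sketch, assuming TensorFlow 2.x Keras and the same toy Conv2D as above, showing why that only works at layer granularity: setting trainable to False moves the layer's entire kernel and its bias, never a subset of filters.

```python
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense
from tensorflow.keras.models import Model

inp = Input((10, 10, 3))
c = Conv2D(32, kernel_size=(3, 3), activation='relu')
out = Dense(10, activation='softmax')(Flatten()(c(inp)))
model = Model(inp, out)

# Before freezing: kernel and bias are both trainable
before = (len(c.trainable_weights), len(c.non_trainable_weights))
print(before)  # (2, 0)

# `trainable` is a per-layer flag: setting it moves the whole kernel
# (all 32 filters) and the bias to non_trainable_weights
c.trainable = False
after = (len(c.trainable_weights), len(c.non_trainable_weights))
print(after)  # (0, 2)
```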

Otherwise, you can apply a mask. The idea is that values equal to the mask value are not taken into account during backpropagation, so in your case the zeroed filters should remain unmodified:

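import numpy as np
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Masking
from tensorflow.keras.models import Model

inp = Input((10, 10, 3))
c = Conv2D(32, kernel_size=(3, 3),
           activation='relu')
f = Flatten()
d = Dense(10)

# first model: train normally
x = c(inp)
x = f(x)
out = d(x)
model1 = Model(inp, out)
model1.compile('adam', 'mse')
model1.fit(np.random.uniform(0, 1, (5, 10, 10, 3)), np.random.uniform(0, 1, (5, 10)))

# zero out the pruned filters
pruned_filters = [1, 5, 9]
w, b = c.get_weights()
w[:, :, :, pruned_filters] = 0
c.set_weights([w, b])
print(w)

# second model: reuses the same layers c, f, d (and therefore the same weights),
# with a Masking layer applied to the conv output
mask = Masking(mask_value=0)
x = c(inp)
x = mask(x)
x = f(x)
out = d(x)
model2 = Model(inp, out)
model2.compile('adam', 'mse')
model2.fit(np.random.uniform(0, 1, (5, 10, 10, 3)), np.random.uniform(0, 1, (5, 10)))

# inspect the kernel after retraining
w, b = c.get_weights()
print(w)
print(w[:, :, :, pruned_filters])  # the pruned filters only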
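
A hedged alternative sketch, not part of the original answer. Keras's Masking layer is intended for sequence data; applied to a conv output it only zeroes positions where all channels equal mask_value, so by itself it does not stop the optimizer from updating the kernel. A more direct way to keep pruned filters at exactly zero is a weight constraint, which Keras optimizers apply to a variable after each update. The FixedMask class, the toy input shape and the pruned index list below are hypothetical; the sketch assumes TensorFlow 2.x Keras and uses DepthwiseConv2D's depthwise_constraint argument. As in the question, the bias is not masked.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.layers import Input, DepthwiseConv2D, Flatten, Dense
from tensorflow.keras.models import Model


class FixedMask(Constraint):
    """Hypothetical helper: multiplies the kernel by a fixed 0/1 mask."""

    def __init__(self, mask):
        self.mask = tf.constant(mask, dtype=tf.float32)

    def __call__(self, w):
        # Called by the optimizer after every update of the kernel
        return w * self.mask


inp = Input((10, 10, 3))
pruned_filters = [1]  # hypothetical pruned channel indices
# Mask with the depthwise kernel's shape: (kernel_h, kernel_w, in_channels, depth_multiplier)
mask = np.ones((3, 3, 3, 1), dtype=np.float32)
mask[:, :, pruned_filters, :] = 0

dw = DepthwiseConv2D((3, 3), depthwise_constraint=FixedMask(mask))
out = Dense(10)(Flatten()(dw(inp)))
model = Model(inp, out)
model.compile('adam', 'mse')

# Zero the pruned filters once before retraining
k, b = dw.get_weights()
dw.set_weights([k * mask, b])

model.fit(np.random.uniform(0, 1, (5, 10, 10, 3)),
          np.random.uniform(0, 1, (5, 10)), epochs=2, verbose=0)

k, _ = dw.get_weights()
print(np.abs(k[:, :, pruned_filters, :]).max())  # 0.0
```

Because the mask is re-applied after every step, whatever update the pruned weights receive (including from Adam's moment estimates) is overwritten with zero, while the remaining filters train normally. The same idea works for a regular Conv2D by passing the constraint as kernel_constraint with a mask of that kernel's shape.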
