I want to reset the weights of my convolutional neural network if a NaN is detected.
I'm not sure how to do it.
I'm also unsure whether I should change the seed as well in this case.
```python
import numpy as np

# Inside the training loop of the model class
if np.isnan(trainingLoss):
    print("..Training Loss is NaN")
    self.reset_network()
if np.isnan(validationLoss):
    print("..Validation Loss is NaN")
    self.reset_network()
```
How should I implement reset_network()?
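I'm not sure this is the intended way of resetting network weights, but here's how I did it. In the following code, `network` is a reference to a CNN with 2 convolutional layers followed by max-pooling layers. I believe it should work with other architectures as well.
The trick here is to overwrite all trainable parameters of the network with values produced by Lasagne's initializer functions: 1-D parameters (biases) are set to zero, and every parameter with 2 or more dimensions (dense weight matrices and convolution filters) gets a fresh Glorot-uniform draw.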
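```python
import lasagne

def reset_weights(network):
    # Collect every trainable shared variable (weights and biases) in the network
    params = lasagne.layers.get_all_params(network, trainable=True)
    for v in params:
        val = v.get_value()
        if val.ndim < 2:
            # 1-D parameters (biases) go back to zero
            v.set_value(lasagne.init.Constant(0.0)(val.shape))
        else:
            # Weight matrices and conv filters get a new Glorot-uniform draw
            v.set_value(lasagne.init.GlorotUniform()(val.shape))
```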
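To make the reset rule in `reset_weights` concrete, here is a NumPy-only sketch that runs without Theano or Lasagne installed. It assumes plain NumPy arrays stand in for the network's shared variables (in Lasagne you would pass each result to `v.set_value(...)`), and the helper names `glorot_uniform` and `reset_params` are hypothetical. The bound is meant to mirror the default `lasagne.init.GlorotUniform` (gain=1.0, c01b=False): the first two axes of the shape are the fan axes, any remaining axes form the receptive field, and values are drawn uniformly from ±gain·sqrt(6 / ((n1 + n2) · receptive_field_size)).

```python
import numpy as np

def glorot_uniform(shape, rng, gain=1.0):
    # Hypothetical helper: shape[:2] are the fan axes, the rest is the receptive field
    n1, n2 = shape[:2]
    receptive_field_size = int(np.prod(shape[2:]))  # 1 for a dense weight matrix
    limit = gain * np.sqrt(6.0 / ((n1 + n2) * receptive_field_size))
    return rng.uniform(-limit, limit, size=shape)

def reset_params(params, rng):
    # Hypothetical helper: params is a list of NumPy arrays standing in for
    # the network's shared variables; returns new arrays, same shapes and dtypes
    new = []
    for val in params:
        if val.ndim < 2:
            # 1-D parameters (treated as biases) are zeroed
            new.append(np.zeros(val.shape, dtype=val.dtype))
        else:
            # Weight matrices and conv filter banks get a Glorot-uniform draw
            new.append(glorot_uniform(val.shape, rng).astype(val.dtype))
    return new

# Example: a conv layer (16 filters, 3 input channels, 5x5) and a dense layer (256 -> 10)
rng = np.random.RandomState(0)
params = [
    np.ones((16, 3, 5, 5), dtype=np.float32),  # conv W
    np.ones(16, dtype=np.float32),             # conv b
    np.ones((256, 10), dtype=np.float32),      # dense W
    np.ones(10, dtype=np.float32),             # dense b
]
fresh = reset_params(params, rng)
```

One caveat with the "1-D means bias" rule: Lasagne's `BatchNormLayer` has a trainable 1-D `gamma` that starts at 1, so this reset would set it to 0; such layers need their own case. Also, the reset only touches network parameters. Update rules such as `lasagne.updates.adam` keep their own shared variables for the moment estimates, which `get_all_params` does not return, so if those have become NaN you would need to rebuild the updates and recompile the training function.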
I hope it helps!
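On the seed question: Lasagne's initializers draw from the module-level generator returned by `lasagne.random.get_rng()`, and each draw advances it, so calling `reset_weights` a second time already yields different weights without touching the seed. Re-seeding with the original value (for example via `lasagne.random.set_rng(np.random.RandomState(seed))`) would instead reproduce the original initialization, which, if the data order and the rest of training are deterministic, may simply hit the same NaN again. The sketch below illustrates both cases with plain NumPy; `draw_filters` is a hypothetical stand-in for a single initializer call.

```python
import numpy as np

def draw_filters(rng, shape=(8, 3, 3, 3)):
    # Hypothetical stand-in for one Glorot-uniform initializer call on a conv filter bank
    n1, n2 = shape[:2]
    receptive_field_size = int(np.prod(shape[2:]))
    limit = np.sqrt(6.0 / ((n1 + n2) * receptive_field_size))
    return rng.uniform(-limit, limit, size=shape)

# Re-seeding with the same seed reproduces exactly the same "fresh" weights
same_a = draw_filters(np.random.RandomState(42))
same_b = draw_filters(np.random.RandomState(42))

# Keeping one generator alive (no re-seeding) gives a new draw on every reset
rng = np.random.RandomState(42)
first_reset = draw_filters(rng)
second_reset = draw_filters(rng)
```

If you want resets to be reproducible, one reasonable option is to set a new, logged seed before each reset rather than reusing the original one.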