Reputation: 3636
I would like to be able to reset the weights of my entire Keras model so that I do not have to compile it again. Compiling the model is currently the main bottleneck of my code. Here is an example of what I mean:
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(16, activation='relu'),
    tf.keras.layers.Dense(10)
])

model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

data = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = data.load_data()

model.fit(x=x_train, y=y_train, epochs=10)

# Reset all weights of model here
# model.reset_all_weights() <----- something like that

model.fit(x=x_train, y=y_train, epochs=10)
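(For reference, a minimal sketch of one workaround, assuming it is acceptable to return to the same initial values rather than draw brand-new random ones: snapshot the weights right after compiling and restore them later, which avoids a second compile. initial_weights is just a name introduced here.)

initial_weights = model.get_weights()   # snapshot taken right after compile

model.fit(x=x_train, y=y_train, epochs=10)

model.set_weights(initial_weights)      # back to the starting values, no recompilation

model.fit(x=x_train, y=y_train, epochs=10)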
Upvotes: 9
Views: 8181
Reputation: 33
In case anyone has nested models (sub-models used as layers), the answer from Moritz requires a slight modification so that it recurses into them:
def reinitialize(model):
    for l in model.layers:
        # Recurse into nested models so their layers get reset too
        if isinstance(l, tf.keras.Model):
            reinitialize(l)
            continue
        # Re-draw each weight tensor from the layer's own initializer, in place
        if hasattr(l, "kernel_initializer"):
            l.kernel.assign(l.kernel_initializer(tf.shape(l.kernel)))
        if hasattr(l, "bias_initializer"):
            l.bias.assign(l.bias_initializer(tf.shape(l.bias)))
        if hasattr(l, "recurrent_initializer"):
            l.recurrent_kernel.assign(l.recurrent_initializer(tf.shape(l.recurrent_kernel)))
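A quick sketch of how this behaves with a nested model (the inner/outer names below are made up for illustration; the sub-model's Dense layer gets re-drawn via the recursive call):

inner = tf.keras.Sequential([tf.keras.layers.Dense(8, activation='relu')])
outer = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),  # input_shape builds the weights up front
    inner,
    tf.keras.layers.Dense(10)
])

reinitialize(outer)  # recurses into inner and reassigns its kernel and bias as well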
Upvotes: 1
Reputation: 101
I wrote a function that reinitializes a model's weights in TensorFlow 2:
def reinitialize(model):
    for l in model.layers:
        # Re-draw each weight tensor from the layer's own initializer, in place
        if hasattr(l, "kernel_initializer"):
            l.kernel.assign(l.kernel_initializer(tf.shape(l.kernel)))
        if hasattr(l, "bias_initializer"):
            l.bias.assign(l.bias_initializer(tf.shape(l.bias)))
        if hasattr(l, "recurrent_initializer"):
            l.recurrent_kernel.assign(l.recurrent_initializer(tf.shape(l.recurrent_kernel)))
It took me longer than it should have to come up with this, and I tried many approaches that failed in my specific use case. IMO this should be a standard TF feature.
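For example, dropped into the question's training script (nothing new here, just calling the function between the two fit calls):

model.fit(x=x_train, y=y_train, epochs=10)

reinitialize(model)  # re-draws kernels and biases in place, model stays compiled

# the second run starts again from fresh random weights
model.fit(x=x_train, y=y_train, epochs=10)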
Upvotes: 10
Reputation: 36704
You can use this loop:
for ix, layer in enumerate(model.layers):
    if hasattr(model.layers[ix], 'kernel_initializer') and \
            hasattr(model.layers[ix], 'bias_initializer'):
        weight_initializer = model.layers[ix].kernel_initializer
        bias_initializer = model.layers[ix].bias_initializer

        old_weights, old_biases = model.layers[ix].get_weights()

        model.layers[ix].set_weights([
            weight_initializer(shape=old_weights.shape),
            bias_initializer(shape=len(old_biases))])
Original weights:

model.layers[1].get_weights()[0][0]
array([ 0.4450057 , -0.13564804,  0.35884023,  0.41411972,  0.24866664,
        0.07641453,  0.45726687, -0.04410008,  0.33194816, -0.1965386 ,
       -0.38438258, -0.13263905, -0.23807487,  0.40130925, -0.07339832,
        0.20535922], dtype=float32)

New weights:

model.layers[1].get_weights()[0][0]
array([-0.4607593 , -0.13104361, -0.0372932 , -0.34242013,  0.12066692,
       -0.39146423,  0.3247317 ,  0.2635846 , -0.10496247, -0.40134245,
        0.19276887,  0.2652442 , -0.18802321, -0.18488845,  0.0826562 ,
       -0.23322225], dtype=float32)
Upvotes: 5