Reputation: 889
I am setting trainable=False in all my layers, implemented through the Model API, but I want to verify whether that is working. model.count_params() returns the total number of parameters, but is there any way to get the total number of trainable parameters, other than looking at the last few lines of model.summary()?
Upvotes: 27
Views: 27634
Reputation: 1119
A compact solution:
trainable_params = sum(prod(w.shape) for w in model.trainable_weights)
(requires from math import prod, available in Python 3.8+).
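As a rough sketch of how this one-liner could be wrapped into a check that freezing worked: the helper below assumes only that each weight exposes a .shape that iterates as plain integers. The name count_params_by_shape and the stand-in objects are hypothetical, used so the example runs without Keras installed.

```python
from math import prod  # Python 3.8+
from types import SimpleNamespace


def count_params_by_shape(weights):
    """Sum the number of scalar entries over a list of weight-like objects."""
    return sum(prod(w.shape) for w in weights)


# Hypothetical stand-ins for a frozen Dense(3 -> 4) layer: kernel (3, 4), bias (4,).
kernel = SimpleNamespace(shape=(3, 4))
bias = SimpleNamespace(shape=(4,))
frozen_model = SimpleNamespace(trainable_weights=[],
                               non_trainable_weights=[kernel, bias])

trainable = count_params_by_shape(frozen_model.trainable_weights)
non_trainable = count_params_by_shape(frozen_model.non_trainable_weights)
print(trainable, non_trainable)  # 0 16
```

With a real model you would pass model.trainable_weights and model.non_trainable_weights instead; a trainable count of 0 suggests that every layer is frozen.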
Upvotes: 0
Reputation: 4906
import numpy as np
from keras import backend as K

# Deduplicate with set() so a shared weight is counted only once.
trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
non_trainable_count = int(
    np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))
The snippet above comes from the end of the layer_utils.print_summary() definition, which summary() calls.
Edit: more recent versions of Keras have a helper function, count_params(), for this purpose:
from keras.utils.layer_utils import count_params
trainable_count = count_params(model.trainable_weights)
non_trainable_count = count_params(model.non_trainable_weights)
Upvotes: 56
Reputation: 2316
For TensorFlow 2.0:
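import numpy as np
import tensorflow.keras.backend as K
# int(...) keeps the count an integer even when a weight list is empty.
trainable_count = int(np.sum([K.count_params(w) for w in model.trainable_weights]))
non_trainable_count = int(np.sum([K.count_params(w) for w in model.non_trainable_weights]))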
print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))
Upvotes: 19
Reputation: 124
For tensorflow.keras, this works for me. It's from the TensorFlow GitHub code for the function print_layer_summary_with_connections() in layer_utils.py:
import numpy as np
from tensorflow.python.util import object_identity

def count_params(weights):
    # Count each distinct weight object once, multiplying out its shape.
    return int(sum(np.prod(p.shape.as_list())
                   for p in object_identity.ObjectIdentitySet(weights)))

if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
else:
    trainable_count = count_params(model.trainable_weights)

print(trainable_count)
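The ObjectIdentitySet in that snippet deduplicates weights by object identity rather than by value, so a weight that appears twice in the list is counted once. A minimal plain-Python sketch of that idea follows, assuming each weight exposes a tuple-like .shape. The name count_unique_params and the stand-in objects are hypothetical and not part of TensorFlow.

```python
from math import prod  # Python 3.8+
from types import SimpleNamespace


def count_unique_params(weights):
    """Count parameters, counting each distinct weight object only once."""
    unique = {id(w): w for w in weights}  # key by identity, not by equality
    return sum(prod(w.shape) for w in unique.values())


shared = SimpleNamespace(shape=(5, 2))  # stand-in for a kernel shared by two layers
other = SimpleNamespace(shape=(2,))     # stand-in for a bias
print(count_unique_params([shared, other, shared]))  # 12: shared is counted once
```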
Upvotes: 1