Reputation: 61
I am trying to implement a custom regularizer function for distributed learning, realizing the penalty function given in the equation below.
I implemented the above function as a layer-wise regularizer; however, it throws an error. I am looking forward to help from the community.
@tf.keras.utils.register_keras_serializable(package='Custom', name='esgd')
def esgd(w, wt, mu):
    """Elastic-SGD style proximity penalty: (mu / 2) * ||w - wt||^2.

    Args:
        w:  current weight tensor.
        wt: reference weight tensor (same shape as ``w``).
        mu: penalty coefficient (scalar).

    Returns:
        A scalar tensor with the penalty value.  Note this is a *value*,
        not a callable, which is why passing it directly as a Keras
        ``kernel_regularizer`` fails (see the error below).
    """
    squared_distance = tf.math.square(tf.norm(w - wt))
    return (mu / 2) * squared_distance
def model(w, wt, mu):
    """Build the MNIST-shaped CNN used in the question.

    The first Dense layer receives ``esgd(w[0][7], wt[0][7], mu)`` as its
    kernel_regularizer — as written this passes a tensor (the already
    computed penalty), not a callable, which is exactly what triggers the
    ValueError reported below.
    """
    net = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(32, (3, 3), padding='same', activation='relu',
                               input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Conv2D(64, (3, 3), padding='same', activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation='relu',
                              kernel_initializer='ones',
                              kernel_regularizer=esgd(w[0][7], wt[0][7], mu)),
        tf.keras.layers.Dropout(0.25),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    return net
----- Error -------
---> 59 model = init_model(w, wt, mu)
60
61 # model.set_weights(wei[0])
<ipython-input-5-e0796dd9fa55> in init_model(w, wt, mu)
11 tf.keras.layers.Dropout(0.25),
12 tf.keras.layers.Flatten(),
---> 13 tf.keras.layers.Dense(128,activation='relu', kernel_initializer='ones',kernel_regularizer=esgd(w[0][7],wt[0][7],mu)
14 ),
15 tf.keras.layers.Dropout(0.25),
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/core.py in __init__(self, units, activation, use_bias, kernel_initializer, bias_initializer, kernel_regularizer, bias_regularizer, activity_regularizer, kernel_constraint, bias_constraint, **kwargs)
1137 self.kernel_initializer = initializers.get(kernel_initializer)
1138 self.bias_initializer = initializers.get(bias_initializer)
-> 1139 self.kernel_regularizer = regularizers.get(kernel_regularizer)
1140 self.bias_regularizer = regularizers.get(bias_regularizer)
1141 self.kernel_constraint = constraints.get(kernel_constraint)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/regularizers.py in get(identifier)
313 return identifier
314 else:
--> 315 raise ValueError('Could not interpret regularizer identifier:', identifier)
ValueError: ('Could not interpret regularizer identifier:', <tf.Tensor: shape=(), dtype=float32, numpy=0.00068962533>)
Upvotes: 2
Views: 928
Reputation: 13458
According to Layer weight regularizers, you must subclass tensorflow.keras.regularizers.Regularizer if you want your regularizer to take additional parameters beyond the layer's weight tensor.
And it also looks like you are trying to support serialization, so don't forget to add the get_config
method.
from tensorflow.keras import regularizers
@tf.keras.utils.register_keras_serializable(package='Custom', name='esgd')
class ESGD(regularizers.Regularizer):
    """Regularizer computing (mu / 2) * ||w - w^T||^2 on the layer kernel.

    Keras calls the instance with the weight tensor, so extra parameters
    (here ``mu``) live on the instance; ``get_config`` makes it
    serializable.
    """
    def __init__(self, mu):
        self.mu = mu
    def __call__(self, w):
        # Bug fix: use the stored self.mu, not the free variable `mu`,
        # which would silently capture a global (or raise NameError) and
        # would break models reloaded via get_config/from_config.
        # NOTE(review): tf.transpose(w) requires a square kernel here —
        # confirm this matches the intended w - wt penalty from the question.
        return (self.mu / 2) * tf.math.square(tf.norm(w - tf.transpose(w)))
    def get_config(self):
        return {'mu': self.mu}
and then you can use it with
tf.keras.layers.Dense(128, activation='relu', kernel_initializer='ones', kernel_regularizer=ESGD(mu=mu))
Upvotes: 1