Reputation: 651
I'm implementing a WGAN and need to clip the weight variables.
I'm currently using TensorFlow with Keras as the high-level API, and I build layers with Keras to avoid manually creating and initializing variables.
The problem is that WGAN needs to clip the weight variables. This can be done with tf.clip_by_value(x, v0, v1) once I have those weight variable tensors, but I don't know how to get them safely.
One possible solution may be using tf.get_collection() to get all trainable variables, but I don't know how to get only the weight variables without the bias variables.
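For example, something like this sketch, which filters the collection by variable name (the 'kernel' substring and the 0.01 clip value are just illustrative assumptions, and relying on names is exactly what I'd like to avoid):
import tensorflow as tf

# Sketch: keep only variables whose name suggests a kernel/weight matrix.
trainable = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
weights = [v for v in trainable if 'kernel' in v.name]
# Build one op that clips all of them; run it after each train step.
clip_ops = [tf.assign(v, tf.clip_by_value(v, -0.01, 0.01)) for v in weights]
clip_weights = tf.group(*clip_ops)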
Another solution is layer.get_weights(), but it returns numpy arrays. Although I could clip them with numpy APIs and set them back with layer.set_weights(), this needs CPU-GPU round-trips and may not be a good choice, since the clip operation has to be performed on every train step.
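That approach would look roughly like this sketch (again, 0.01 is just an illustrative clip value; note it would clip the biases too unless they are filtered out):
import numpy as np

# Sketch: clip on the host, then copy back to the device every step.
for layer in model.layers:
    clipped = [np.clip(w, -0.01, 0.01) for w in layer.get_weights()]
    layer.set_weights(clipped)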
The only way I know of is to access them directly using their exact variable names, which I can get from the lower-level TF APIs or from TensorBoard, but this may not be safe since the naming rules of Keras are not guaranteed to be stable.
Is there any clean way to perform clip_by_value only on those Ws with TensorFlow and Keras?
Upvotes: 0
Views: 2970
Reputation: 9099
You can use the Constraint class (keras.constraints) to implement new constraints on parameters.
Here is how you can easily implement clipping on the weights and use it in your model.
from keras.constraints import Constraint
from keras import backend as K

class WeightClip(Constraint):
    '''Clips the weights incident to each hidden unit to be inside a range.'''
    def __init__(self, c=2):
        self.c = c

    def __call__(self, p):
        # Called on the parameter tensor after each optimizer update.
        return K.clip(p, -self.c, self.c)

    def get_config(self):
        return {'name': self.__class__.__name__,
                'c': self.c}

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

model = Sequential()
# Attach the constraint to this layer's weight matrix only.
model.add(Dense(30, input_dim=100, W_constraint=WeightClip(2)))
model.add(Dense(1))
model.compile(loss='mse', optimizer='rmsprop')

X = np.random.random((1000, 100))
Y = np.random.random((1000, 1))

model.fit(X, Y)
I have tested that the above code runs, but not the validity of the constraints. You can do so by getting the model weights after training using model.get_weights() or model.layers[idx].get_weights() and checking whether they abide by the constraints.
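For example, a quick check along those lines (the layer index 0 and the bound 2 match the model above):
import numpy as np

W, b = model.layers[0].get_weights()  # kernel and bias of the first Dense layer
assert np.all(np.abs(W) <= 2)         # the kernel should lie inside [-2, 2]
# No assertion on b: the bias is unconstrained in this model.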
Note: the constraint is not added to all the model weights, but just to the weights of the specific layer where it is used. Also, W_constraint adds the constraint to the W param and b_constraint to the b (bias) param.
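So if you wanted to clip the biases as well, you would pass both arguments, e.g. (using the same Keras 1 keyword names as above; in Keras 2 they were renamed to kernel_constraint and bias_constraint):
model.add(Dense(30, input_dim=100,
                W_constraint=WeightClip(2),   # clip the weight matrix
                b_constraint=WeightClip(2)))  # clip the bias vector too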
Upvotes: 4