Reputation: 1
I've seen several posts about adding simple constraints (e.g., non-negativity) to TensorFlow weight variables, but none about how to prevent a weight from changing sign. For example, if I have
W = tf.get_variable('W1', [512, 256], initializer=initializer, dtype=tf.float32)
how do I add a constraint such that, after initialization, W[i,j] cannot change sign? I don't see a clear way to use the "constraint" option in tf.get_variable().
Upvotes: 0
Views: 402
Reputation: 1994
My approach to solving this problem would be as follows.
For each weight, store its initial sign. This can be done with the following code:
w1 = tf.get_variable('W1', [512, 256], initializer=initializer, dtype=tf.float32)
# The sign snapshot must itself be a variable so that it can be assigned to.
w1_sign = tf.get_variable('W1_sign', [512, 256], initializer=tf.zeros_initializer(),
                          dtype=tf.float32, trainable=False)
store_sign = tf.assign(w1_sign, tf.sign(w1))
You can then use the following op to reset a weight to 0 whenever it breaks the sign constraint (tf.where needs both branches to have the same shape, hence tf.zeros_like rather than a scalar 0):
constraint_op = tf.assign(w1, tf.where(w1_sign * w1 >= 0, w1, tf.zeros_like(w1)))
Now you can run these ops as follows:
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(store_sign)              # snapshot the initial signs once
for _ in range(train_itr):
    sess.run(some_train_op)       # your usual training step
    sess.run(constraint_op)       # zero out any weight that flipped sign
Note that in the above code, store_sign is run only once, while constraint_op is run after every execution of the train op.
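If you would rather issue a single sess.run call per step, one possible refinement (a sketch reusing the ops above; the name train_and_constrain is mine) is to chain the constraint to the train op with a control dependency:

with tf.control_dependencies([some_train_op]):
    # Created inside the block, so this assign runs only after the train step.
    train_and_constrain = tf.assign(w1, tf.where(w1_sign * w1 >= 0, w1, tf.zeros_like(w1)))

for _ in range(train_itr):
    sess.run(train_and_constrain)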
The same idea can also be applied via the constraint argument of tf.get_variable, sketched below.
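Here is a minimal sketch, assuming TF 1.x; the function name keep_initial_sign and the random-normal initializer are placeholders of mine. The constraint function only receives the updated weight tensor, so the sign snapshot still has to live in a separate non-trainable variable captured by the closure:

import tensorflow as tf

w1_sign = tf.get_variable('W1_sign', [512, 256],
                          initializer=tf.zeros_initializer(), trainable=False)

def keep_initial_sign(w):
    # Projection applied after an optimizer update: entries whose sign
    # flipped relative to the stored snapshot are set back to zero.
    return tf.where(w1_sign * w >= 0, w, tf.zeros_like(w))

w1 = tf.get_variable('W1', [512, 256],
                     initializer=tf.random_normal_initializer(),
                     dtype=tf.float32,
                     constraint=keep_initial_sign)

store_sign = tf.assign(w1_sign, tf.sign(w1))  # still run this once after initialization

One caveat: not every optimizer applies a variable's constraint automatically (Keras optimizers do), so check that it is actually enforced in your setup, or fall back to the explicit constraint_op above.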
Upvotes: 1