TedM

Reputation: 1

How to constrain weights in TensorFlow from changing sign?

I've seen several posts about adding simple constraints (i.e. non-negativity) to TensorFlow weight variables, but none about how to prevent a weight from changing signs. For example, if I have W = tf.get_variable('W1', [512, 256], initializer=initializer, dtype=tf.float32) how do I add a constraint such that after initialization W[i,j] cannot change sign? I don't see a clear way to use the "constraint" option in tf.get_variable().

Upvotes: 0

Views: 402

Answers (1)

Abhishek Mishra

Reputation: 1994

My approach to this problem is as follows.

For each weight, store its initial sign. This can be done with the following code:

w1 = tf.get_variable('W1', [512, 256], initializer=initializer, dtype=tf.float32)
# w1_sign must be a variable (not a plain tensor) so tf.assign can write to it;
# trainable=False keeps it out of the optimizer's update list
w1_sign = tf.Variable(tf.zeros_like(w1), trainable=False)
store_sign = tf.assign(w1_sign, tf.sign(w1))

You can then use the following op to set a weight to 0 whenever it breaks the sign constraint:

# zero out any entry whose current sign disagrees with its stored sign
constraint_op = tf.assign(w1, tf.where(w1_sign * w1 >= 0, w1, tf.zeros_like(w1)))

Now you can run these ops as follows:

sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(store_sign)
for _ in range(train_itr):
    sess.run(some_train_op)
    sess.run(constraint_op)

Note that in the above code, store_sign runs only once (right after initialization), while constraint_op runs after every execution of the training op.
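The projection itself is plain elementwise logic. Here is a minimal NumPy sketch of what constraint_op computes, using small hypothetical arrays (no TensorFlow session needed):

```python
import numpy as np

# Hypothetical current weights and the signs stored at initialization.
w1 = np.array([0.5, -0.3, 0.2, -0.7])
w1_sign = np.sign(np.array([0.5, -0.3, -0.2, 0.7]))  # [1., -1., -1., 1.]

# Keep a weight only where its current sign agrees with the stored sign,
# mirroring tf.where(w1_sign * w1 >= 0, w1, tf.zeros_like(w1)).
projected = np.where(w1_sign * w1 >= 0, w1, 0.0)
# → [0.5, -0.3, 0.0, 0.0]: the last two entries flipped sign and are zeroed.
```

Once a weight is zeroed it stays zero under this rule, since 0 satisfies the `>= 0` check for either stored sign.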

The same idea can also be applied via the constraint argument of tf.get_variable.
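That variant amounts to packaging the projection as a function closed over the stored signs. A NumPy sketch of the idea (make_sign_constraint is a hypothetical helper; the real constraint argument would receive the variable's tensor after each optimizer update):

```python
import numpy as np

def make_sign_constraint(initial_value):
    """Return a projection that zeroes any entry whose sign has flipped
    relative to initial_value (hypothetical helper, NumPy stand-in)."""
    sign = np.sign(initial_value)
    def project(w):
        return np.where(sign * w >= 0, w, np.zeros_like(w))
    return project

init = np.array([1.0, -2.0, 3.0])
constraint = make_sign_constraint(init)

updated = np.array([0.5, 1.5, -0.1])  # e.g. weights after a gradient step
result = constraint(updated)
# → [0.5, 0.0, 0.0]: the second and third entries changed sign and are zeroed.
```

The closure over `sign` is what replaces the separate w1_sign variable and store_sign op in the session-based version above.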

Upvotes: 1
