Yingchao Xiong

Reputation: 255

How to quantize the values of tf.Variables in Tensorflow

I have a model to train of the form

Y = w * X + b

where Y and X are the output and input placeholders, and w and b are vectors.
I already know that the values of w can only be 0 or 1, while b is an ordinary tf.float32.

How can I quantize the range of the variable w when I define it?
Or, alternatively:
Can I use two different learning rates, with a step of 1 or -1 for w and the usual 0.0001 for b?

Upvotes: 2

Views: 2135

Answers (2)

Tony Liechty

Reputation: 19

One method I have used to limit variables to a particular range is to add a penalty term to the loss. If the variable goes outside the desired range, the loss gets bigger, and the optimizer pushes the variable back within the range.

For example:

import tensorflow as tf  # TensorFlow 1.x API (tf.random_uniform)

numOutputs = 10  # example size; the original code used self.numOutputs from its class

#initialize variable to be between 0 and 1
variable = tf.Variable(tf.random_uniform([numOutputs], 0, 1))

#Clip a copy of the variable to the range [0, 1]; the variable itself is not changed
variableClipped = tf.clip_by_value(variable, 0, 1)

#Set the loss to include the squared difference between the clipped and actual variable.
#Whenever the variable goes outside the range, the loss increases,
#and the optimizer pushes it back within the desired range.
#originalLossEquation stands for your model's existing loss.
loss = originalLossEquation + tf.reduce_sum((variable - variableClipped)**2)
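To see why the penalty pushes values back, here is a minimal pure-Python sketch (no TensorFlow) of what gradient descent does with the penalty term alone. The helper names `clip`, `range_penalty`, `range_penalty_grad` and `descend` are hypothetical, and the gradient is written out by hand in place of TensorFlow's automatic differentiation. Inside [0, 1] the gradient is zero, so in-range values are left alone; outside, it is 2 * (v - clip(v)), which points back toward the nearest bound.

```python
def clip(x, lo=0.0, hi=1.0):
    # Scalar equivalent of tf.clip_by_value
    return min(max(x, lo), hi)


def range_penalty(values, lo=0.0, hi=1.0):
    # Sum of squared distances between each value and its clipped copy
    return sum((v - clip(v, lo, hi)) ** 2 for v in values)


def range_penalty_grad(values, lo=0.0, hi=1.0):
    # d/dv (v - clip(v))**2 = 2 * (v - clip(v)): outside the range clip(v) is a
    # constant bound (derivative 0); inside the range the whole term is 0.
    return [2.0 * (v - clip(v, lo, hi)) for v in values]


def descend(values, lr=0.25, steps=20):
    # Plain gradient descent on the penalty term only (hypothetical helper)
    for _ in range(steps):
        grads = range_penalty_grad(values)
        values = [v - lr * g for v, g in zip(values, grads)]
    return values


print(descend([-0.5, 0.3, 1.4]))
```

With `lr=0.25` each step halves the distance of an out-of-range value from the nearest bound, while 0.3 is never touched. Note that this penalty only keeps values near [0, 1]; on its own it does not make them exactly 0 or 1.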

Upvotes: 1

Salvador Dali

Reputation: 222511

There is no way to restrict the variable's values as part of its definition. What you can do instead is limit it after each training iteration. Here is one way to do this with tf.where():

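import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

a = tf.random_uniform(shape=(3, 3))

# Wherever a < 0.5 pick 0, otherwise pick 1
b = tf.where(
    tf.less(a, tf.zeros_like(a) + 0.5),
    tf.zeros_like(a),
    tf.ones_like(a)
)

with tf.Session() as sess:
    # Evaluate a and b in one run so b is computed from the same random a
    A, B = sess.run([a, b])
    print(A, '\n')
    print(B)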

This converts every value of 0.5 or above to 1 and everything else to 0 (the input is random, so your numbers will differ):

[[ 0.2068541   0.12682056  0.73839438]
 [ 0.00512838  0.43465161  0.98486936]
 [ 0.32126224  0.29998791  0.31065524]] 

[[ 0.  0.  1.]
 [ 0.  0.  1.]
 [ 0.  0.  0.]]
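The same idea can be sketched on the question's own model without TensorFlow: take a gradient step, then snap w back to 0 or 1 with the same 0.5 threshold as the tf.where example. This is a minimal pure-Python sketch under stated assumptions: w and b are scalars rather than vectors, the data are synthetic (true w = 1, b = 0.5), the gradients are written out by hand, and the names `quantize` and `train` are hypothetical. It also shows that the two learning rates asked about are simply two separate step sizes in the update.

```python
def quantize(w, threshold=0.5):
    # Same rule as the tf.where example: below 0.5 -> 0, otherwise -> 1
    return 0.0 if w < threshold else 1.0


def train(xs, ys, lr_w=0.1, lr_b=0.1, steps=100):
    # Hypothetical training loop for Y = w * X + b with scalar w and b
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(steps):
        # Gradients of the mean squared error with respect to w and b
        errs = [w * x + b - y for x, y in zip(xs, ys)]
        grad_w = sum(2 * e * x for e, x in zip(errs, xs)) / n
        grad_b = sum(2 * e for e in errs) / n
        # Separate learning rates for w and b
        w -= lr_w * grad_w
        b -= lr_b * grad_b
        # Limit w after the iteration: snap it back to 0 or 1
        w = quantize(w)
    return w, b


xs = [0.0, 1.0, 2.0, 3.0]
ys = [1.0 * x + 0.5 for x in xs]  # synthetic data with w = 1, b = 0.5
w, b = train(xs, ys)
print(w, b)
```

The step size for w matters here. With a small rate such as 0.0001 (`train(xs, ys, lr_w=0.0001)`), no single step on this data moves w from 0 past 0.5, so the threshold snaps it back to 0 every iteration and w never changes.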

Upvotes: 6
