Rasoul

Reputation: 73

Limit the digits after decimal point (rounding) without breaking the gradient in TensorFlow

I am training a TensorFlow model and need high precision in the output. My output format is:

U = X1.Y1Y2Y3Y4Y5Y6
V = X1.Y1Y2Y3Y4Y5Y6

where X1 is the digit before the decimal point and Y1, ..., Y6 are the digits after it. Obviously, the round operation cannot be used because it breaks the gradient. I came up with the following idea:

U = tf.cast(tf.cast(U, 'float16'), 'float32')
W = U + 1e-4 * V

This way, the different digits could be controlled by different TensorFlow variables, and training should be more efficient. I was expecting to get an output like:

U = X1.Y1Y2Y3000

with Y4 = Y5 = Y6 = 0. However, the digits Y4, Y5, and Y6 took on seemingly random values instead.

My questions:

  1. Is such behavior expected when upconverting float16 to float32?
  2. Can I modify the tf.cast behavior?

Python code:

import tensorflow as tf

x = tf.constant(1.222222222222222222222, 'float32')
print(x.numpy())
x_ = tf.cast(tf.cast(x, 'float16'), 'float32')
print(x_.numpy())

Output:

1.2222222
1.2226562

Upvotes: 1

Views: 768

Answers (1)

Pascal Getreuer

Reputation: 3256

Converting to a lower bit depth, such as casting float32 to float16, is effectively rounding in base 2: the lower significand bits are discarded and come back as zeros when you cast up again. This is not the same as rounding in base 10, so it won't necessarily replace the lower base-10 decimal digits with zeros.
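Here is a small illustration with the value from the question (this snippet is just a sketch of the arithmetic, not part of the original answer): float16 keeps 10 fraction bits, so for values in [1, 2) the cast rounds to the nearest multiple of 2**-10.

import numpy as np

# float16 has 10 fraction bits, so casting a value in [1, 2) rounds it to
# the nearest multiple of 2**-10: clean base-2 digits, but arbitrary-looking
# base-10 digits.
x = np.float32(1.2222222)
print(np.float32(np.float16(x)))        # 1.2226562
print(round(float(x) * 2**10) / 2**10)  # 1.22265625, i.e. the same value, rounded in base 2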

Assuming "base-2 rounding" is enough, TensorFlow's "fake_quant" ops are useful for this purpose, for instance tf.quantization.fake_quant_with_min_max_args. They simulate the effect of converting to lower bit-depth, yet are differentiable. The Post-training quantization guide might also be helpful.

Another thought: if you need to hack the gradient of something, check out the utilities tf.custom_gradient and tf.stop_gradient.
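For base-10 rounding specifically, a straight-through estimator can be sketched with tf.custom_gradient; the helper round_through and the 6-decimal scale below are my own naming for illustration, not a TensorFlow API.

import tensorflow as tf

def round_through(decimals=6):
    # Round to `decimals` base-10 digits in the forward pass while
    # passing the gradient through unchanged in the backward pass.
    scale = 10.0 ** decimals

    @tf.custom_gradient
    def op(x):
        y = tf.round(x * scale) / scale
        def grad(dy):
            return dy  # straight-through: treat the rounding as identity
        return y, grad

    return op

# Usage sketch
x = tf.Variable(1.2222222)
with tf.GradientTape() as tape:
    y = round_through(6)(x)
print(y.numpy())                    # ~1.222222
print(tape.gradient(y, x).numpy())  # 1.0, i.e. the gradient is not broken

Keep in mind that float32 itself only carries about 7 significant decimal digits, so six digits after the decimal point is already close to the precision limit for values around 1.x.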

Upvotes: 1
