I have a tensor that is reduced to a scalar and then used as a denominator. I want to catch the case where the denominator becomes 0 before the division happens.
Here is my code so far:
# Some computation happens before this that produces the numerator and denominator
# For example, set the denominator to 1
def check(): return denumerator= tf.constant([1])
tf.cond(tf.equal(denumerator, tf.zeros([1])), check)
res = tf.divide(numerator,denumerator)
I looked at the documentation at https://www.tensorflow.org/api_docs/python/tf/cond, but it only explains how to execute parts of a graph conditionally, not how to set a tensor to a specific value when a certain condition is met.
The answer to your question
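import tensorflow as tf
# Replace an exactly-zero scalar denominator with 1 (of the same dtype) before dividing
denumerator = tf.cond(tf.equal(denumerator, 0), lambda: tf.ones((), dtype=denumerator.dtype), lambda: denumerator)
res = tf.divide(numerator, denumerator)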
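`tf.cond` expects a scalar boolean predicate, while the question builds its tensors with shape `[1]`. If the denominator is not a true scalar, an elementwise alternative is `tf.where`. The sketch below assumes TensorFlow 2.x with eager execution, and the input values are made up for illustration. Because `tf.where` evaluates both of its branches, it masks the denominator *before* dividing rather than masking the quotient afterwards, so no division by zero is ever computed.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Hypothetical inputs with a non-scalar shape, like the question's [1]-shaped tensors.
numerator = tf.constant([3.0, 5.0])
denumerator = tf.constant([0.0, 2.0])

# tf.where works elementwise: every entry equal to 0 is replaced by 1,
# all other entries are kept unchanged.
safe_denominator = tf.where(tf.equal(denumerator, 0.0),
                            tf.ones_like(denumerator),
                            denumerator)
res = tf.divide(numerator, safe_denominator)
print(res.numpy())
```

This still has the same weaknesses as the `tf.cond` version: only an exact 0 is caught.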
The answer that you should probably use
Don't compare your divisor with zero: it is not robust (you may still run into problems with tiny values), and the result is not continuous as a function of the denominator, which usually works poorly in an optimization setting.
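To make both problems concrete, here is a plain-Python sketch with no TensorFlow involved. `guarded_divide` is a hypothetical helper that mirrors the scalar `tf.cond` logic: only an exact 0 is replaced by 1. A tiny nonzero denominator slips past the check, so the result explodes, and it jumps from a huge negative value to 1 to a huge positive value as the denominator passes through 0.

```python
def guarded_divide(numerator, denominator):
    """Hypothetical scalar mirror of the tf.cond approach: only an exact 0 is replaced by 1."""
    if denominator == 0:
        denominator = 1.0
    return numerator / denominator

# An exact zero is caught and the result is well-behaved.
at_zero = guarded_divide(1.0, 0.0)
# Tiny nonzero denominators are not caught: the result explodes,
# and its sign depends on which side of 0 the denominator lies.
just_above = guarded_divide(1.0, 1e-12)
just_below = guarded_divide(1.0, -1e-12)
print(at_zero, just_above, just_below)
```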
Use a more proven technique instead. For example, if you know that it is positive, you could take its maximum with 1:
res = tf.divide(numerator, tf.maximum(denumerator, 1))
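For comparison, the same plain-Python style (no TensorFlow) shows how the clamp behaves. `clamped_divide` is a hypothetical scalar mirror of `tf.divide(numerator, tf.maximum(denumerator, 1))`. Near 0 the result changes continuously instead of jumping or exploding, and for denominators of at least 1 it equals ordinary division.

```python
def clamped_divide(numerator, denominator):
    """Hypothetical scalar mirror of tf.divide(numerator, tf.maximum(denumerator, 1))."""
    return numerator / max(denominator, 1.0)

# Every denominator below 1 is treated as 1, so the result stays bounded near 0;
# denominators >= 1 are divided by as usual.
results = {d: clamped_divide(2.0, d) for d in [0.0, 1e-12, 0.5, 1.0, 4.0]}
print(results)
```

The trade-off is that all denominators below the floor are treated as equal. A floor of 1 only makes sense if the denominator is known to be positive and of that order of magnitude; a smaller floor keeps more values exact at the cost of larger quotients near 0.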