Kev1n91

Reputation: 3673

Check if a tensor (scalar) is 0 and if so change its value

I have a tensor that later becomes a scalar and is used as a denominator. I want to catch the case where the denominator becomes 0 before the division happens.

Here is my code so far:

# Some stuff happens before this that produces the numerator and the denominator
# If the denominator is 0, set it to 1, for example
def check(): return denumerator= tf.constant([1]) 
tf.cond(tf.equal(denumerator,tf.zeros([1]),check)

res = tf.divide(numerator,denumerator)

I looked at the documentation here: https://www.tensorflow.org/api_docs/python/tf/cond, but it only explains how to execute parts of a graph conditionally, not how to set a tensor to a specific value when a certain condition is met.

Upvotes: 0

Views: 1305

Answers (1)

P-Gn

Reputation: 24581

The answer to your question

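# tf.cond returns the value of the branch it takes; tf.ones_like keeps the dtype and shape of denumerator
denumerator = tf.cond(tf.equal(denumerator, 0), lambda: tf.ones_like(denumerator), lambda: denumerator)
res = tf.divide(numerator, denumerator)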
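If the denominator holds several entries rather than a single scalar, tf.cond picks one branch for the whole tensor, so it cannot replace only the zero entries. Below is a minimal sketch of an element-wise alternative using tf.where, assuming TensorFlow 2 with eager execution; the tensor values are made up for illustration. tf.math.divide_no_nan is also shown as a built-in option, but note its different semantics: it returns 0 wherever the divisor is 0 instead of dividing by 1.

```python
import tensorflow as tf

# Made-up example values; the second denominator entry is zero
numerator = tf.constant([3.0, 4.0, 5.0])
denumerator = tf.constant([2.0, 0.0, 4.0])

# Element-wise guard: replace only the zero entries with 1, then divide
safe = tf.where(tf.equal(denumerator, 0.0), tf.ones_like(denumerator), denumerator)
res = tf.divide(numerator, safe)  # values: 1.5, 4.0, 1.25

# Built-in alternative: the quotient is 0 wherever the divisor is 0
res_no_nan = tf.math.divide_no_nan(numerator, denumerator)  # values: 1.5, 0.0, 1.25
```

Replacing the divisor before dividing (rather than dividing first and then masking the result) means no inf or NaN is ever computed in the forward pass.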

The answer that you should probably use

Do not compare your divisor with zero: it is not robust (you may still have problems with tiny values), and the result is not continuous as a function of the denominator, which usually works poorly in an optimization setting.
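To illustrate the robustness point, here is a small sketch, assuming TensorFlow 2 with eager execution, float32 tensors and made-up values, in which the divisor is tiny but nonzero. The zero check does not fire, yet the quotient still overflows.

```python
import tensorflow as tf

numerator = tf.constant(1e3)
# Tiny but nonzero divisor (float32 by default): the zero check below does not fire
denumerator = tf.constant(1e-37)

guarded = tf.cond(tf.equal(denumerator, 0.0),
                  lambda: tf.ones_like(denumerator),
                  lambda: denumerator)
res = tf.divide(numerator, guarded)

# 1e3 / 1e-37 = 1e40 is far above the float32 maximum (about 3.4e38),
# so the division overflows to inf despite the guard
print(bool(tf.math.is_inf(res)))
```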

Use a more proven technique. For example, if you know that it is positive, you could take its maximum with 1:

res = tf.divide(numerator, tf.maximum(denumerator, 1))
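A minimal sketch of this clamping approach on a few made-up positive denominators, assuming TensorFlow 2 with eager execution:

```python
import tensorflow as tf

# Made-up example values, all non-negative as the answer assumes
numerator = tf.constant([6.0, 6.0, 6.0])
denumerator = tf.constant([0.0, 0.5, 3.0])

# Clamp the divisor from below: every value under 1 (including 0) becomes 1
res = tf.divide(numerator, tf.maximum(denumerator, 1.0))  # values: 6.0, 6.0, 2.0
```

The clamp also changes the result for every denominator between 0 and 1, not only for 0, so a floor of 1 fits when the denominator is normally expected to be at least 1; a smaller floor is a common variant otherwise. Either way, numerator / max(d, 1) is continuous in d, unlike the tf.cond version.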

Upvotes: 1
