carpediem

Reputation: 437

Using tf.cond with a condition comparing two tensors

I'm trying to code the following function:

f(t) = t^(1/3)                if t > delta^3
f(t) = t/(3*delta^2) + 4/29   otherwise,
where delta = 6/29

I've already implemented it in Python using NumPy, but I need to rewrite it in TensorFlow:

import numpy as np

input_array = np.array([0.2, 2.1, 4.5, 6.7, 8.1, 10.0])

def f_funct(input_array):
    arr = np.ones(len(input_array))
    delta_vec = np.multiply(arr, 6.0/29.0)
    res = np.where(input_array > np.power(delta_vec, 3), np.cbrt(input_array),
                   input_array/(3*np.power(delta_vec, 2)) + 4.0/29.0)
    return res

print(f_funct(input_array))
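As an aside, the vector of ones is not needed: NumPy broadcasts a scalar threshold against the whole array, so the same elementwise selection can be written with scalar constants. A minimal sketch (DELTA is a name introduced here for illustration):

```python
import numpy as np

DELTA = 6.0 / 29.0  # scalar parameter; broadcasting replaces the vector of ones

def f_funct(input_array):
    x = np.asarray(input_array, dtype=np.float64)
    # Scalars broadcast against x, so the comparison and both branches are elementwise.
    return np.where(x > DELTA ** 3,
                    np.cbrt(x),
                    x / (3 * DELTA ** 2) + 4.0 / 29.0)

print(f_funct(np.array([0.2, 2.1, 4.5, 6.7, 8.1, 10.0])))
```

The two branches meet at x = delta^3, where both give delta (delta^3/(3*delta^2) + 4/29 = 2/29 + 4/29 = 6/29), so the function is continuous at the threshold.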

I've tried this:

import tensorflow as tf

# convert input from np.array to tensor
def np_array_to_tensor(input_array):
    input_array = tf.convert_to_tensor(input_array, dtype=tf.float64)
    return input_array

input_tensor = np_array_to_tensor(tf.constant(input_array)) 
arr = tf.ones([tf.size(input_tensor)], tf.float64)
delta_vec = tf.math.multiply(arr,tf.math.divide(6.0,29.0))

res1 = tf.math.divide(4.0,29.0) 
res2 = tf.divide(input_tensor,(3*tf.pow(delta_vec,2)))
    
res = tf.cond(input_tensor>tf.pow(delta_vec,3), 
             lambda: tf.pow(input_tensor,1.0/3.0), 
             lambda: tf.add(res2,res1))


with tf.Session() as sess:  
    print(input_tensor.eval())
    print(arr.eval())
    print(delta_vec.eval())
    print(res1)
    print(res2.eval())
    print(res.eval())

But this raises an error, because tf.cond expects a single scalar (rank-0) boolean predicate, whereas input_tensor > tf.pow(delta_vec,3) is a rank-1 boolean tensor:

ValueError: Shape must be rank 0 but is rank 1 for 'cond_13/Switch' (op: 'Switch') with input shapes: [6], [6].
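For contrast, tf.cond is meant for choosing between two whole computations based on one boolean. A minimal sketch, assuming TensorFlow 2.x with eager execution (not the TF 1.x session used above): reducing the elementwise comparison with tf.reduce_all gives a rank-0 predicate that tf.cond accepts, but then a single branch is applied to every element, which is not what a piecewise function needs.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

delta = 6.0 / 29.0
x = tf.constant([0.2, 2.1, 4.5, 6.7, 8.1, 10.0], dtype=tf.float64)

# x > delta ** 3 has shape [6]; tf.reduce_all collapses it to a rank-0 boolean.
pred = tf.reduce_all(x > delta ** 3)

# tf.cond runs exactly ONE branch, and that branch covers the whole tensor.
res = tf.cond(pred,
              lambda: tf.pow(x, 1.0 / 3.0),
              lambda: x / (3.0 * delta ** 2) + 4.0 / 29.0)
print(res.numpy())
```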

Because my input array is rather large, I'd prefer to stay with vectorized array/tensor operations. Is it possible to select tensor elements based on an elementwise condition?

I'd appreciate any suggestions. Thank you.

Upvotes: 1

Views: 924

Answers (1)

Lescurel

Reputation: 11651

What you want is not tf.cond but tf.where(condition, x, y), which returns each element from either x or y depending on condition, exactly as np.where does. (In TF 1.x, tf.where does not broadcast like NumPy: x and y must have the same shape, and condition must either have that shape or be a vector matching their first dimension. If your TF 1.x version provides tf.where_v2, it follows NumPy-style broadcasting, so you might prefer it.)

res = tf.where(input_tensor>tf.pow(delta_vec,3),
               tf.pow(input_tensor,1.0/3.0),
               tf.add(res2,res1))
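If you are on TensorFlow 2.x instead, the same elementwise selection runs eagerly without a session, and tf.where there broadcasts like NumPy, so scalar constants suffice. A minimal sketch under that assumption; f_funct_tf and DELTA are hypothetical names introduced for illustration:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x: eager by default, tf.where broadcasts

DELTA = 6.0 / 29.0

def f_funct_tf(x):
    """Elementwise piecewise function mirroring the NumPy f_funct."""
    x = tf.convert_to_tensor(x, dtype=tf.float64)
    # NaN for negative x, but those elements fail the test below and take the linear branch.
    cube_root = tf.pow(x, 1.0 / 3.0)
    linear = x / (3.0 * DELTA ** 2) + 4.0 / 29.0
    # tf.where picks each element from cube_root or linear independently.
    return tf.where(x > DELTA ** 3, cube_root, linear)

input_array = np.array([0.2, 2.1, 4.5, 6.7, 8.1, 10.0])
print(f_funct_tf(input_array).numpy())
```

Like np.where, tf.where computes both branches for every element and then selects, which is why the NaN from tf.pow on negative inputs never reaches the result.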

Running inside a session with your input_tensor:

>>> sess.run(res)
array([0.58480355, 1.28057916, 1.65096362, 1.88520363, 2.00829885,
       2.15443469])
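A quick NumPy cross-check of that output, using the question's input: the threshold delta^3 = (6/29)^3 is about 0.008856, so every input element takes the cube-root branch and the result is simply the elementwise cube root.

```python
import numpy as np

input_array = np.array([0.2, 2.1, 4.5, 6.7, 8.1, 10.0])
threshold = (6.0 / 29.0) ** 3  # about 0.008856

# Every element exceeds the threshold, so every element takes the cube-root branch.
print(np.all(input_array > threshold))  # True
print(np.cbrt(input_array))
```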

Upvotes: 1
