I'm trying to code the following function:
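Reconstructed from the numpy code below, so treat it as an inference, the function appears to be the piecewise f(t) used in the CIELAB colour conversion, with δ = 6/29:

```latex
f(t) =
\begin{cases}
  t^{1/3}, & t > \delta^{3} \\
  \dfrac{t}{3\delta^{2}} + \dfrac{4}{29}, & \text{otherwise}
\end{cases}
\qquad \delta = \frac{6}{29}
```

The two branches meet at t = δ³, where both evaluate to δ, so the function is continuous.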
I've already implemented it in Python using numpy, but I need to rewrite it in TensorFlow:
import numpy as np
input_array = np.array([0.2, 2.1, 4.5, 6.7, 8.1, 10.0])
def f_funct(input_array):
    arr = np.ones(len(input_array))
    delta_vec = np.multiply(arr, 6.0/29.0)
    # cube root above the threshold delta^3, linear segment otherwise
    res = np.where(input_array > np.power(delta_vec, 3), np.cbrt(input_array),
                   input_array/(3*np.power(delta_vec, 2)) + 4.0/29.0)
    return res
print(f_funct(input_array))
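Since δ = 6/29 is a constant, the ones-vector is not strictly needed in numpy: comparisons and arithmetic broadcast a Python scalar across the whole array. A minimal sketch of the same rule with a scalar δ (the names `DELTA` and `f_funct_scalar` are hypothetical):

```python
import numpy as np

DELTA = 6.0 / 29.0  # constant threshold parameter from the original code


def f_funct_scalar(x):
    """Same piecewise rule as f_funct, but with a scalar delta.

    numpy broadcasts the scalar DELTA against every element of x,
    so no ones-vector is needed.
    """
    x = np.asarray(x, dtype=np.float64)
    return np.where(x > DELTA ** 3,
                    np.cbrt(x),                           # cube-root branch
                    x / (3.0 * DELTA ** 2) + 4.0 / 29.0)  # linear branch


print(f_funct_scalar([0.2, 2.1, 4.5, 6.7, 8.1, 10.0]))
```

Note that np.where evaluates both branches for every element and only then selects, which is also how the tensor version behaves.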
I've tried this:
import tensorflow as tf  # TF 1.x (uses tf.Session)
# convert input from np.array to tensor
def np_array_to_tensor(input_array):
    input_array = tf.convert_to_tensor(input_array, dtype=tf.float64)
    return input_array
input_tensor = np_array_to_tensor(tf.constant(input_array))
arr = tf.ones([tf.size(input_tensor)], tf.float64)
delta_vec = tf.math.multiply(arr, tf.math.divide(6.0, 29.0))
res1 = tf.math.divide(4.0, 29.0)
res2 = tf.divide(input_tensor, (3*tf.pow(delta_vec, 2)))
res = tf.cond(input_tensor > tf.pow(delta_vec, 3),
              lambda: tf.pow(input_tensor, 1.0/3.0),
              lambda: tf.add(res2, res1))
with tf.Session() as sess:
    print(input_tensor.eval())
    print(arr.eval())
    print(delta_vec.eval())
    print(res1)
    print(res2.eval())
    print(res.eval())
But this throws an error, because the condition passed to tf.cond must be a single scalar (rank-0) boolean:
ValueError: Shape must be rank 0 but is rank 1 for 'cond_13/Switch' (op: 'Switch') with input shapes: [6], [6].
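The message is about the predicate's rank: tf.cond takes one boolean and runs one of two whole branches, but input_tensor > tf.pow(delta_vec, 3) compares two length-6 vectors and so yields six booleans (shape [6], rank 1). The same shapes can be sketched in numpy, used here only as a stand-in for the tensors:

```python
import numpy as np

x = np.array([0.2, 2.1, 4.5, 6.7, 8.1, 10.0])
delta_vec = np.full(x.shape, 6.0 / 29.0)

# Element-wise comparison: one boolean per element, i.e. a rank-1 mask.
pred = x > delta_vec ** 3
print(pred.shape, pred.ndim)

# A branch-selecting construct like tf.cond needs a single (rank-0) boolean.
# Collapsing the mask, e.g. with .all(), gives one, but then a single branch
# is chosen for the whole array, which is not the element-wise rule wanted.
scalar_pred = pred.all()
print(np.ndim(scalar_pred))
```

So making the predicate scalar would silence the error but change the meaning; what is needed is a per-element select.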
Since my input array is rather large, I would prefer to stick to vectorized array/tensor operations. Is it possible to select tensor elements element-wise based on a condition?
I'd appreciate any suggestions. Thank you.
What you want is not tf.cond, but tf.where(condition, x, y), which returns elements from either x or y depending on condition, exactly as in numpy. (In TF 1.x, tf.where does not broadcast like numpy: x and y must have the same shape, and condition must either have that shape too or be a vector matching their first dimension. tf.where_v2 follows numpy-style broadcasting, so you might prefer it if it's available in your TensorFlow 1.x version.)
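The selection rule can be modelled in plain Python to make the semantics concrete. This is only an illustration, not TensorFlow's implementation: the hypothetical where_same_shape requires all three arguments to have the same length (roughly like TF 1.x tf.where for rank-1 inputs), while the hypothetical where_broadcast also accepts a scalar branch, the kind of broadcasting that where_v2 and numpy allow.

```python
from numbers import Number


def where_same_shape(cond, x, y):
    """Pick x[i] where cond[i] is True, else y[i]; lengths must match."""
    if not (len(cond) == len(x) == len(y)):
        raise ValueError("cond, x and y must have the same length")
    return [xi if ci else yi for ci, xi, yi in zip(cond, x, y)]


def where_broadcast(cond, x, y):
    """Like where_same_shape, but a scalar x or y is repeated for every element."""
    n = len(cond)
    x = [x] * n if isinstance(x, Number) else x
    y = [y] * n if isinstance(y, Number) else y
    return where_same_shape(cond, x, y)


cond = [True, False, True]
print(where_same_shape(cond, [1, 2, 3], [10, 20, 30]))
print(where_broadcast(cond, [1, 2, 3], 0))
```

With a broadcasting select, the scalar 4.0/29.0 term could be passed directly instead of being expanded to the full length first.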
res = tf.where(input_tensor > tf.pow(delta_vec, 3),
               tf.pow(input_tensor, 1.0/3.0),
               tf.add(res2, res1))
Running inside a session with your input_tensor:
>>> sess.run(res)
array([0.58480355, 1.28057916, 1.65096362, 1.88520363, 2.00829885,
       2.15443469])
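Worth noting: δ³ = (6/29)³ ≈ 0.00886, so every element of the sample input takes the cube-root branch, and the output above never exercises the linear branch. A plain-Python cross-check (the helper f_ref and the constant DELTA are hypothetical) that also tries a value below the threshold:

```python
DELTA = 6.0 / 29.0


def f_ref(t):
    """Scalar reference for the piecewise rule.

    The cube-root branch only runs for t > DELTA**3 > 0, so t ** (1/3) is real.
    """
    if t > DELTA ** 3:
        return t ** (1.0 / 3.0)
    return t / (3.0 * DELTA ** 2) + 4.0 / 29.0


sample = [0.2, 2.1, 4.5, 6.7, 8.1, 10.0]
print([round(f_ref(t), 8) for t in sample])
print(f_ref(0.001))  # below DELTA**3, so the linear branch is used
```

Adding one or two inputs below δ³ to the test array is a cheap way to confirm that the tensor version's second branch is wired correctly.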