Reputation: 4041
I have two vectors: time and event. If an element of event is 1, the time at the same index should be passed to func_for_event1. Otherwise, it goes to func_for_event0.
import tensorflow as tf

def func_for_event1(t):
    return t + 1

def func_for_event0(t):
    return t - 1

time = tf.placeholder(tf.float32, shape=[None]) # [3.2, 4.2, 1.0, 1.05, 1.8]
event = tf.placeholder(tf.int32, shape=[None]) # [0, 1, 1, 0, 1]
# result: [2.2, 5.2, 2.0, 0.05, 2.8]
# For example, 3.2 should be sent to func_for_event0 because the first element in event is 0.
How should I implement this logic in TensorFlow? With tf.cond or tf.where?
Upvotes: 0
Views: 201
Reputation: 4868
This is exactly what tf.where() is for. This code (tested):
import tensorflow as tf
import numpy as np

def func_for_event1(t):
    return t + 1

def func_for_event0(t):
    return t - 1

time = tf.placeholder(tf.float32, shape=[None])  # [3.2, 4.2, 1.0, 1.05, 1.8]
event = tf.placeholder(tf.int32, shape=[None])  # [0, 1, 1, 0, 1]
# Take func_for_event1(time) where event == 1 and func_for_event0(time) elsewhere.
result = tf.where(tf.equal(1, event), func_for_event1(time), func_for_event0(time))
# result: [2.2, 5.2, 2.0, 0.05, 2.8]
# For example, 3.2 should be sent to func_for_event0 because the first element in event is 0.

with tf.Session() as sess:
    res = sess.run(result, feed_dict={
        time: np.array([3.2, 4.2, 1.0, 1.05, 1.8]),
        event: np.array([0, 1, 1, 0, 1]),
    })
    print(res)
outputs:
[2.2 5.2 2. 0.04999995 2.8 ]
as desired.
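A note on why tf.where fits here and tf.cond does not: tf.cond takes a single scalar predicate and picks one branch for the whole tensor, while tf.where chooses per element. With tf.where both functions are evaluated on the full time vector, and the boolean mask then selects one value at each index. The sketch below shows that selection step with NumPy's np.where standing in for tf.where purely as an illustration; the helper name select_by_event is my own and not part of either library.

```python
import numpy as np

def func_for_event1(t):
    return t + 1

def func_for_event0(t):
    return t - 1

def select_by_event(time, event):
    """Hypothetical helper: per-element choice between the two branch results."""
    time = np.asarray(time, dtype=np.float32)
    event = np.asarray(event, dtype=np.int32)
    # Both branches are computed for the whole vector first...
    branch1 = func_for_event1(time)
    branch0 = func_for_event0(time)
    # ...then the mask keeps branch1 where event == 1 and branch0 elsewhere.
    return np.where(event == 1, branch1, branch0)

result = select_by_event([3.2, 4.2, 1.0, 1.05, 1.8], [0, 1, 1, 0, 1])
print(result)
```

Because both branches run on every element, this pattern is a good match for cheap element-wise functions like the two above. If one branch were expensive or could fail on inputs meant for the other branch, it might be worth masking the inputs before calling it.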
Upvotes: 1