Munichong

Compute values in one vector based on the elements of another vector in TensorFlow

I have two vectors: time and event. Where an element of event is 1, the time at the same index should be passed to func_for_event1; otherwise it should be passed to func_for_event0.
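To pin down the intended per-element behaviour, here is a plain-Python reference that needs no TensorFlow. It is only a sketch for checking expected values; the helper name dispatch_reference is hypothetical and not part of the question.

```python
def func_for_event1(t):
    return t + 1

def func_for_event0(t):
    return t - 1

def dispatch_reference(time, event):
    """Hypothetical helper: apply func_for_event1 where event == 1, func_for_event0 otherwise."""
    if len(time) != len(event):
        raise ValueError("time and event must have the same length")
    # Pair each time value with the event at the same index and pick the function per element.
    return [func_for_event1(t) if e == 1 else func_for_event0(t)
            for t, e in zip(time, event)]

result = dispatch_reference([3.2, 4.2, 1.0, 1.05, 1.8], [0, 1, 1, 0, 1])
print([round(x, 2) for x in result])  # [2.2, 5.2, 2.0, 0.05, 2.8]
```

Any TensorFlow implementation should agree with this list up to floating-point rounding.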

import tensorflow as tf

def func_for_event1(t):
    return t + 1

def func_for_event0(t):
    return t - 1

time = tf.placeholder(tf.float32, shape=[None])  # [3.2, 4.2, 1.0, 1.05, 1.8]
event = tf.placeholder(tf.int32, shape=[None])  # [0, 1, 1, 0, 1]

# result: [2.2, 5.2, 2.0, 0.05, 2.8]
# For example, 3.2 should be sent to func_for_event0 because the first element in event is 0.

How should I implement this logic in TensorFlow? With tf.cond or tf.where?

Answers (1)

Peter Szoldan

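This is exactly what tf.where() is for: it selects element by element from two tensors according to a boolean mask. tf.cond, by contrast, takes a single scalar predicate and returns one branch for the whole tensor, so it does not fit here. This code (tested; it uses the TensorFlow 1.x placeholder/Session API):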

import tensorflow as tf
import numpy as np

def func_for_event1(t):
    return t + 1

def func_for_event0(t):
    return t - 1

time = tf.placeholder(tf.float32, shape=[None])  # [3.2, 4.2, 1.0, 1.05, 1.8]
event = tf.placeholder(tf.int32, shape=[None])  # [0, 1, 1, 0, 1]

# tf.where keeps func_for_event1(time) where event == 1 and func_for_event0(time) elsewhere
result = tf.where( tf.equal( 1, event ), func_for_event1( time ), func_for_event0( time ) )
# result: [2.2, 5.2, 2.0, 0.05, 2.8]
# For example, 3.2 should be sent to func_for_event0 because the first element in event is 0.

with tf.Session() as sess:
    res = sess.run( result, feed_dict = {
        time : np.array( [3.2, 4.2, 1.0, 1.05, 1.8] ),
        event : np.array( [0, 1, 1, 0, 1] )
    } )
    print ( res )

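It outputs:

[2.2 5.2 2. 0.04999995 2.8 ]

as desired. The 0.04999995 instead of 0.05 is ordinary float32 rounding: 1.05 has no exact float32 representation.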
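The code above relies on tf.placeholder and tf.Session, which exist only in TensorFlow 1.x (in 2.x they are reachable only through tf.compat.v1). A minimal sketch of the same tf.where approach in TensorFlow 2.x eager mode, assuming TF 2.x is installed; dispatch is a hypothetical helper name:

```python
import tensorflow as tf  # assumes TensorFlow 2.x, where eager execution is the default

def func_for_event1(t):
    return t + 1

def func_for_event0(t):
    return t - 1

def dispatch(time, event):
    # Both branch functions are evaluated on the whole tensor; tf.where then keeps
    # func_for_event1's value where event == 1 and func_for_event0's value elsewhere.
    return tf.where(tf.equal(event, 1), func_for_event1(time), func_for_event0(time))

# Plain tensors replace the 1.x placeholders and feed_dict.
time = tf.constant([3.2, 4.2, 1.0, 1.05, 1.8], dtype=tf.float32)
event = tf.constant([0, 1, 1, 0, 1], dtype=tf.int32)

result = dispatch(time, event)
print(result.numpy())
```

Because both branches are computed for every element and only the selection happens per element, this suits cheap functions like these. Wrapping dispatch in tf.function runs it as a graph if that is needed.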
