Gael

Reputation: 13

How to process each row of a tensor differently based on its first column value in TensorFlow

Let's say I have a 4-by-3 tensor:

sample = [[10, 15, 25], [1, 2, 3], [4, 4, 10], [5, 9, 8]]

I would like to return another tensor of shape (4,), [r1, r2, r3, r4], where each r is tf.reduce_sum(row) if row[0] is less than 5, and tf.reduce_mean(row) if row[0] is greater than or equal to 5. The expected output is:

output = [16.67, 6, 18, 7.33]

I'm not very experienced with TensorFlow, so please help me achieve the above in Python 3 without a for loop. Thank you.
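As a reference for checking results, the rule can be written in plain Python. This version loops over the rows, so it is only a specification of the expected values, not the loop-free TensorFlow solution being asked for; the helper name `expected_row_values` is hypothetical.

```python
sample = [[10, 15, 25], [1, 2, 3], [4, 4, 10], [5, 9, 8]]

def expected_row_values(rows, threshold=5):
    """Hypothetical helper: sum the row if row[0] < threshold, else take its mean."""
    out = []
    for row in rows:
        if row[0] < threshold:
            out.append(float(sum(row)))      # e.g. [1, 2, 3] -> 6.0
        else:
            out.append(sum(row) / len(row))  # e.g. [10, 15, 25] -> 50 / 3
    return out

print([round(v, 2) for v in expected_row_values(sample)])
# [16.67, 6.0, 18.0, 7.33]
```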

UPDATES:

I've tried to adapt the answer given by @Onyambu so that the function takes two samples, but every variant I tried raises an error. Here is the answer for the single-sample case:

def f(x):
    c = tf.constant(5,tf.float32)
    def fun1():
        return tf.reduce_sum(x)
    def fun2():
        return tf.reduce_mean(x)
    return tf.cond(tf.less(x[0],c),fun1,fun2)
a = tf.map_fn(f,tf.constant(sample,tf.float32))

The above works well.

Then, for two samples:

sample1 = [[10, 15, 25], [1, 2, 3], [4, 4, 10], [5, 9, 8]]
sample2 = [[0, 15, 25], [1, 2, 3], [0, 4, 10], [1, 9, 8]]

def f2(x1,x2):
    c = tf.constant(1,tf.float32)
    def fun1():
        return tf.reduce_sum(x1[:,0] - x2[:,0])
    def fun2():
        return tf.reduce_mean(x1 - x2)
    return tf.cond(tf.less(x2[0],c),fun1,fun2)
a = tf.map_fn(f2,tf.constant(sample1,tf.float32), tf.constant(sample2,tf.float32))

The adaptation raises errors, although the principle is the same.
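Two things break in the two-sample attempt: `tf.map_fn` calls its function with a single argument, and its third positional parameter is the output dtype, not a second input tensor. Inside the mapped function each element is also a 1-D row, so `x1[:, 0]` cannot work. A minimal sketch, assuming TensorFlow 2.x and assuming the intent is row-wise (so `x1[:, 0]` becomes `x1[0]`): pass both tensors as one tuple in `elems`, unpack it inside the function, and declare the output type with `fn_output_signature`, because the output (one scalar per row) no longer has the same structure as the input (two tensors).

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

sample1 = [[10, 15, 25], [1, 2, 3], [4, 4, 10], [5, 9, 8]]
sample2 = [[0, 15, 25], [1, 2, 3], [0, 4, 10], [1, 9, 8]]

def f2(rows):
    # map_fn calls f2 with ONE argument; because elems is a tuple,
    # that argument is a tuple holding the matching row of each tensor.
    x1, x2 = rows
    c = tf.constant(1, tf.float32)
    # x1 and x2 are 1-D rows here, so index with [0] rather than [:, 0].
    return tf.cond(tf.less(x2[0], c),
                   lambda: x1[0] - x2[0],
                   lambda: tf.reduce_mean(x1 - x2))

a = tf.map_fn(f2,
              (tf.constant(sample1, tf.float32), tf.constant(sample2, tf.float32)),
              fn_output_signature=tf.float32)  # one float32 scalar per row
print(a.numpy())
```

For the sample data this gives 10, 0, 4 and 4/3. In TensorFlow 1.x, where `fn_output_signature` does not exist, the equivalent keyword is `dtype=tf.float32`.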

Thank you for your assistance in advance!

Upvotes: 1

Views: 162

Answers (1)

Onyambu

Reputation: 79218

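import tensorflow as tf  # TensorFlow 1.x API (tf.Session)

sample = [[10, 15, 25], [1, 2, 3], [4, 4, 10], [5, 9, 8]]

def f(x):
    y = tf.constant(5, tf.float32)
    def fun1():
        return tf.reduce_sum(x)   # first element < 5: sum of the row
    def fun2():
        return tf.reduce_mean(x)  # otherwise: mean of the row
    return tf.cond(tf.less(x[0], y), fun1, fun2)

# map_fn applies f to each row (each slice along axis 0)
a = tf.map_fn(f, tf.constant(sample, tf.float32))

with tf.Session() as sess: print(sess.run(a))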

 [16.666666   6.        18.         7.3333335]

If you want to shorten it:

y = tf.constant(5,tf.float32)
f=lambda x: tf.cond(tf.less(x[0], y), lambda: tf.reduce_sum(x),lambda: tf.reduce_mean(x))

a = tf.map_fn(f,tf.constant(sample,tf.float32))
with tf.Session() as sess: print(sess.run(a))
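The code above targets TensorFlow 1.x; `tf.Session` is not available at the top level in TensorFlow 2.x, where ops run eagerly. A minimal sketch, assuming TensorFlow 2.x, of both the same `tf.map_fn` approach and a fully vectorized alternative that uses `tf.where` to pick, for every row, either the row sum or the row mean, with no per-row `tf.cond` at all:

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

sample = tf.constant([[10, 15, 25], [1, 2, 3], [4, 4, 10], [5, 9, 8]], tf.float32)
threshold = tf.constant(5, tf.float32)

# 1) Same idea as the answer: tf.cond per row, mapped with tf.map_fn.
#    In eager mode map_fn returns a tensor directly; no Session is needed.
a = tf.map_fn(
    lambda x: tf.cond(x[0] < threshold,
                      lambda: tf.reduce_sum(x),
                      lambda: tf.reduce_mean(x)),
    sample)

# 2) Vectorized: compute both reductions for every row, then select per row.
row_sums = tf.reduce_sum(sample, axis=1)    # shape (4,)
row_means = tf.reduce_mean(sample, axis=1)  # shape (4,)
result = tf.where(sample[:, 0] < threshold, row_sums, row_means)

print(a.numpy())
print(result.numpy())
```

The vectorized version computes both the sum and the mean of every row and discards one of them per row; for reductions this cheap, that is typically faster than running `tf.cond` once per row inside `tf.map_fn`.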

Upvotes: 1
