xuhai

Reputation: 77

Loop over a tensor and apply function to each element

I want to loop over a tensor that contains a list of integers and apply a function to each element. Inside the function, each element should be replaced by the value it maps to in a Python dict. I first tried tf.map_fn, which works for a simple addition function, as in the following code:

import tensorflow as tf

def trans_1(x):
    return x+10

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_1, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))
# output: [11 12 13]

But the following code throws the exception KeyError: <tf.Tensor 'map_8/while/TensorArrayReadV3:0' shape=() dtype=int32>:

import tensorflow as tf

kv_dict = {1:11, 2:12, 3:13}

def trans_2(x):
    return kv_dict[x]

a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_2, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))

My TensorFlow version is 1.13.1. Thanks in advance.

Upvotes: 6

Views: 4407

Answers (2)

Addy

Reputation: 1450

There is a simple way to achieve what you are trying to do.

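The problem is that the function passed to map_fn must take tensors as its parameters and return a tensor. However, your function trans_2 is written for a plain Python int: it uses its argument as a dictionary key and returns another Python int. When map_fn calls it with a tensor, the tensor is not a key in kv_dict, so the lookup raises KeyError. That's why your code doesn't work.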
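To make the failure concrete without running TensorFlow, here is a minimal sketch. SymbolicValue is a hypothetical stand-in for a graph-mode tensor, not a TensorFlow class; the only assumption is that, like such a tensor, it names a value that will exist later and is not equal to any integer key.

```python
class SymbolicValue:
    """Hypothetical stand-in for a graph-mode tensor: it names a future
    value but does not hold the integer itself."""

    def __init__(self, name):
        self.name = name


kv_dict = {1: 11, 2: 12, 3: 13}


def trans_2(x):
    return kv_dict[x]


# The dictionary compares keys by hash and equality, and the stand-in
# object is not equal to 1, 2 or 3, so the lookup cannot succeed.
placeholder = SymbolicValue("map/while/TensorArrayReadV3:0")
try:
    trans_2(placeholder)
    lookup_failed = False
except KeyError:
    lookup_failed = True

print(lookup_failed)  # prints: True
```

The dictionary is consulted while the graph is being built, before any actual integers flow through it, which is why the error shows a tensor name rather than a number.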

However, TensorFlow provides a simple way to wrap ordinary Python functions: tf.py_func. You can use it in your case as follows:

import tensorflow as tf

kv_dict = {1:11, 2:12, 3:13}

def trans_2(x):
    return kv_dict[x]

def wrapper(x):
    return tf.cast(tf.py_func(trans_2, [x], tf.int64), tf.int32)

a = tf.constant([1, 2, 3])
b = tf.map_fn(wrapper, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))

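You can see I have added a wrapper function, which takes a tensor and returns a tensor, so it can be used in map_fn. The cast is needed because the Python int returned by trans_2 is converted to a 64-bit integer on most platforms (hence tf.int64 as the output type of tf.py_func), whereas the input tensor is int32 and map_fn expects the output dtype to match the input dtype unless its dtype argument says otherwise.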
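If you would rather not depend on the platform's default integer width, one possible variant is to make the Python function return a NumPy int32 explicitly. The sketch below uses NumPy only; trans_2_int32 is a hypothetical name, and it assumes the argument arrives from tf.py_func as a NumPy scalar or zero-dimensional array.

```python
import numpy as np

kv_dict = {1: 11, 2: 12, 3: 13}


def trans_2_int32(x):
    # int(x) makes the dictionary key an ordinary Python int;
    # np.int32(...) pins the result to 32 bits on every platform.
    return np.int32(kv_dict[int(x)])


# Simulate the kind of value the wrapped function is assumed to receive.
out = trans_2_int32(np.int32(2))
print(out, out.dtype)  # prints: 12 int32
```

With a function like this, tf.int32 could be passed as the output type of tf.py_func and the cast in wrapper would no longer be needed.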

Upvotes: 1

javidcf

Reputation: 59691

You cannot use a function like that, because the parameter x is a TensorFlow tensor, not a Python value. So, in order for that to work, you would have to turn your dictionary into a tensor as well, but it's not so simple because keys in the dictionary may not be sequential.

You can solve this problem without mapping, by doing something similar to an approach proposed for NumPy. In TensorFlow, you could implement it like this:

import tensorflow as tf

def replace_by_dict(x, d):
    # Get keys and values from dictionary
    keys, values = zip(*d.items())
    keys = tf.constant(keys, x.dtype)
    values = tf.constant(values, x.dtype)
    # Make a sequence for the range of values in the input
    v_min = tf.reduce_min(x)
    v_max = tf.reduce_max(x)
    r = tf.range(v_min, v_max + 1)
    r_shape = tf.shape(r)
    # Mask replacements that are out of the input range
    mask = (keys >= v_min) & (keys <= v_max)
    keys = tf.boolean_mask(keys, mask)
    values = tf.boolean_mask(values, mask)
    # Replace values in the sequence with the corresponding replacements
    scatter_idx = tf.expand_dims(keys, 1) - v_min
    replace_mask = tf.scatter_nd(
        scatter_idx, tf.ones_like(values, dtype=tf.bool), r_shape)
    replace_values = tf.scatter_nd(scatter_idx, values, r_shape)
    replacer = tf.where(replace_mask, replace_values, r)
    # Gather the replacement value or the same value if it was not modified
    return tf.gather(replacer, x - v_min)

# Test
kv_dict = {1: 11, 2: 12, 3: 13}
with tf.Graph().as_default(), tf.Session() as sess:
    a = tf.constant([1, 2, 3])
    print(sess.run(replace_by_dict(a, kv_dict)))
    # [11 12 13]

This allows the input tensor to contain values that have no replacement (they are left as they are), and it does not require every key in the dictionary to appear in the tensor. Because it builds a sequence covering the whole range from the minimum to the maximum input value, it should be efficient unless those two values are very far apart.
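For comparison, the same lookup-table idea can be sketched in plain NumPy. This is an illustration under assumptions, not the code from the NumPy answer referred to above: replace_by_dict_np is a hypothetical name, and the input is assumed to be a non-empty array of integers.

```python
import numpy as np


def replace_by_dict_np(x, d):
    x = np.asarray(x)
    keys = np.array(list(d.keys()), dtype=x.dtype)
    values = np.array(list(d.values()), dtype=x.dtype)
    v_min, v_max = x.min(), x.max()
    # Identity table covering every value from v_min to v_max
    replacer = np.arange(v_min, v_max + 1, dtype=x.dtype)
    # Ignore dictionary keys that fall outside the input range
    mask = (keys >= v_min) & (keys <= v_max)
    # Overwrite the table entries that have a replacement
    replacer[keys[mask] - v_min] = values[mask]
    # Look up each input value in the table
    return replacer[x - v_min]


result = replace_by_dict_np([1, 2, 3, 5], {1: 11, 2: 12, 3: 13, 9: 19})
print(result.tolist())  # prints: [11, 12, 13, 5]
```

Here 5 has no entry in the dictionary and is passed through unchanged, while the key 9 lies outside the input range and is ignored. The table has v_max - v_min + 1 entries, which is where the memory cost for widely spread inputs comes from.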

Upvotes: 0
