Reputation: 77
I want to loop over a tensor that contains a list of ints and apply a function to each of the elements. Inside the function, every element gets its value from a Python dict.
I have tried the easy way with tf.map_fn, which works for an add function, as in the following code:
import tensorflow as tf
def trans_1(x):
    return x + 10
a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_1, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))
# output: [11 12 13]
But the following code throws the exception KeyError: <tf.Tensor 'map_8/while/TensorArrayReadV3:0' shape=() dtype=int32>:
import tensorflow as tf
kv_dict = {1: 11, 2: 12, 3: 13}
def trans_2(x):
    return kv_dict[x]
a = tf.constant([1, 2, 3])
b = tf.map_fn(trans_2, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))
My TensorFlow version is 1.13.1. Thanks in advance.
Upvotes: 6
Views: 4407
Reputation: 1450
There is a simple way to achieve what you are trying to do.
The problem is that the function passed to map_fn must take tensors as its parameters and return a tensor. However, your function trans_2 expects a plain Python int as its parameter (it is used as a dict key) and returns another Python int. That's why your code doesn't work.
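To make the KeyError concrete without needing TensorFlow, here is a minimal pure-Python sketch. FakeSymbolicTensor and try_lookup are hypothetical names invented for this illustration, not TensorFlow classes or functions; the stand-in only mimics the relevant property, namely that what map_fn hands to the function is a handle to a future value rather than the integer itself.

```python
class FakeSymbolicTensor:
    """Hypothetical stand-in for a graph-mode tensor: a named handle, not a value."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return "<FakeSymbolicTensor %r>" % self.name


kv_dict = {1: 11, 2: 12, 3: 13}


def trans_2(x):
    # Same lookup as in the question: x is used directly as a dict key.
    return kv_dict[x]


def try_lookup(x):
    # Helper for this sketch only: report the KeyError instead of crashing.
    try:
        return trans_2(x)
    except KeyError:
        return "KeyError"


# A plain Python int is a valid key and is found.
plain_result = try_lookup(1)

# The handle object hashes and compares by identity, so it matches no key.
symbolic_result = try_lookup(FakeSymbolicTensor("map/while/TensorArrayReadV3:0"))

print(plain_result)
print(symbolic_result)
```

The dict never sees the values 1, 2, 3 in the second call; it only sees the handle object, which is why the lookup has to happen either inside the graph or inside a wrapped Python function that receives concrete values.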
However, TensorFlow provides a simple way to wrap ordinary Python functions, tf.py_func, which you can use in your case as follows:
import tensorflow as tf
kv_dict = {1: 11, 2: 12, 3: 13}
def trans_2(x):
    return kv_dict[x]
def wrapper(x):
    return tf.cast(tf.py_func(trans_2, [x], tf.int64), tf.int32)
a = tf.constant([1, 2, 3])
b = tf.map_fn(wrapper, a)
with tf.Session() as sess:
    res = sess.run(b)
    print(str(res))
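As you can see, I have added a wrapper function that expects a tensor parameter and returns a tensor, which is why it can be used in map_fn. The cast is needed because the Python int returned by trans_2 comes back from tf.py_func as a 64-bit integer on most platforms (hence tf.int64), whereas the input tensor a is int32, and map_fn by default expects the output dtype to match the input dtype.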
Upvotes: 1
Reputation: 59691
You cannot use a function like that, because the parameter x is a TensorFlow tensor, not a Python value. For that to work, you would have to turn your dictionary into a tensor as well, but that is not so simple because the keys in the dictionary may not be sequential.
You can instead solve this problem without mapping, by doing something similar to what is proposed here for NumPy. In TensorFlow, you could implement it like this:
import tensorflow as tf
def replace_by_dict(x, d):
    # Get keys and values from dictionary
    keys, values = zip(*d.items())
    keys = tf.constant(keys, x.dtype)
    values = tf.constant(values, x.dtype)
    # Make a sequence for the range of values in the input
    v_min = tf.reduce_min(x)
    v_max = tf.reduce_max(x)
    r = tf.range(v_min, v_max + 1)
    r_shape = tf.shape(r)
    # Mask replacements that are out of the input range
    mask = (keys >= v_min) & (keys <= v_max)
    keys = tf.boolean_mask(keys, mask)
    values = tf.boolean_mask(values, mask)
    # Replace values in the sequence with the corresponding replacements
    scatter_idx = tf.expand_dims(keys, 1) - v_min
    replace_mask = tf.scatter_nd(
        scatter_idx, tf.ones_like(values, dtype=tf.bool), r_shape)
    replace_values = tf.scatter_nd(scatter_idx, values, r_shape)
    replacer = tf.where(replace_mask, replace_values, r)
    # Gather the replacement value or the same value if it was not modified
    return tf.gather(replacer, x - v_min)
# Test
kv_dict = {1: 11, 2: 12, 3: 13}
with tf.Graph().as_default(), tf.Session() as sess:
    a = tf.constant([1, 2, 3])
    print(sess.run(replace_by_dict(a, kv_dict)))
    # [11 12 13]
This allows values in the input tensor that have no replacement (they are left as they are), and it does not require every key in the dictionary to appear in the tensor. It should be efficient unless the minimum and maximum values in your input are very far apart, since the intermediate lookup sequence has max - min + 1 entries.
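The same lookup-table idea can be followed step by step in plain Python, which may make the TensorFlow version easier to read. This is only an illustrative sketch: replace_by_dict_py is a hypothetical name, it works on ordinary lists rather than tensors, and it assumes a non-empty list of ints.

```python
def replace_by_dict_py(xs, d):
    """Pure-Python walk-through of the lookup-table idea (illustration only)."""
    # Range of values that actually occur in the input.
    v_min, v_max = min(xs), max(xs)

    # Identity table: position i initially holds the value v_min + i.
    replacer = list(range(v_min, v_max + 1))

    # Overwrite the slots that have a replacement; keys outside the
    # input range are skipped, like the boolean mask in the answer.
    for key, value in d.items():
        if v_min <= key <= v_max:
            replacer[key - v_min] = value

    # "Gather": read each input value's slot from the table.
    return [replacer[x - v_min] for x in xs]


# 7 has no replacement and is left as is; key 100 is outside the input range.
result = replace_by_dict_py([1, 2, 3, 7], {1: 11, 2: 12, 3: 13, 100: 0})
print(result)
```

Here the table covers the values 1 through 7, so it has seven entries even though only four inputs are looked up. That is the cost referred to above when the minimum and maximum are far apart.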
Upvotes: 0