Albert

Reputation: 399

Using tf.map_fn when the function has multiple outputs

I can easily use tf.map_fn when the function has one output:

import tensorflow as tf
tensaki=tf.constant([[1., 2., 3.], [4., 5., 6.]])

def my_fun(x):
    return x[0]

print(tf.map_fn(my_fun,tensaki))

output:

tf.Tensor([1. 4.], shape=(2,), dtype=float32)

But, when the function has two outputs:

def my_fun(x):
    return [x[0],x[1]]

print(tf.map_fn(my_fun,tensaki))

I get an error (full traceback below) and I am not sure what is going on. I read the documentation for tf.map_fn at https://www.tensorflow.org/api_docs/python/tf/map_fn, which says the following, but I still don't see how to fix this:

map_fn also supports functions with multi-arity inputs and outputs:

If elems is a tuple (or nested structure) of tensors, then those tensors must all have the same outer-dimension size (num_elems); and fn is used to transform each tuple (or structure) of corresponding slices from elems. E.g., if elems is a tuple (t1, t2, t3), then fn is used to transform each tuple of slices (t1[i], t2[i], t3[i]) (where 0 <= i < num_elems). If fn returns a tuple (or nested structure) of tensors, then the result is formed by stacking corresponding elements from those structures.

Output:

~Users\user2\AppData\Roaming\Python\Python37\site-packages\tensorflow_core\python\util\nest.py in assert_same_structure(nest1, nest2, check_types, expand_composites)
    317     _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types,
--> 318                                            expand_composites)
    319   except (ValueError, TypeError) as e:

ValueError: The two structures don't have the same nested structure.

First structure: type=DType str=<dtype: 'float32'>

Second structure: type=list str=[<tf.Tensor: id=203, shape=(), dtype=float32, numpy=1.0>, <tf.Tensor: id=207, shape=(), dtype=float32, numpy=2.0>]

More specifically: Substructure "type=list str=[<tf.Tensor: id=203, shape=(), dtype=float32, numpy=1.0>, <tf.Tensor: id=207, shape=(), dtype=float32, numpy=2.0>]" is a sequence, while substructure "type=DType str=<dtype: 'float32'>" is not

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
<ipython-input-36-5b11c7fef461> in <module>
      5     return [x[0],x[1]]
      6 
----> 7 print(tf.map_fn(my_fun,tensaki))

~Users\user2\AppData\Roaming\Python\Python37\site-packages\tensorflow_core\python\ops\map_fn.py in map_fn(fn, elems, dtype, parallel_iterations, back_prop, swap_memory, infer_shape, name)
    266         back_prop=back_prop,
    267         swap_memory=swap_memory,
--> 268         maximum_iterations=n)
    269     results_flat = [r.stack() for r in r_a]
    270 

~Users\user2\AppData\Roaming\Python\Python37\site-packages\tensorflow_core\python\ops\control_flow_ops.py in while_loop(cond, body, loop_vars, shape_invariants, parallel_iterations, back_prop, swap_memory, name, maximum_iterations, return_same_structure)
   2712                                               list(loop_vars))
   2713       while cond(*loop_vars):
-> 2714         loop_vars = body(*loop_vars)
   2715         if try_to_pack and not isinstance(loop_vars, (list, _basetuple)):
   2716           packed = True

~Users\user2\AppData\Roaming\Python\Python37\site-packages\tensorflow_core\python\ops\control_flow_ops.py in <lambda>(i, lv)
   2703         cond = lambda i, lv: (  # pylint: disable=g-long-lambda
   2704             math_ops.logical_and(i < maximum_iterations, orig_cond(*lv)))
-> 2705         body = lambda i, lv: (i + 1, orig_body(*lv))
   2706       try_to_pack = False
   2707 

~Users\user2\AppData\Roaming\Python\Python37\site-packages\tensorflow_core\python\ops\map_fn.py in compute(i, tas)
    256       packed_values = input_pack([elem_ta.read(i) for elem_ta in elems_ta])
    257       packed_fn_values = fn(packed_values)
--> 258       nest.assert_same_structure(dtype or elems, packed_fn_values)
    259       flat_fn_values = output_flatten(packed_fn_values)
    260       tas = [ta.write(i, value) for (ta, value) in zip(tas, flat_fn_values)]

~Users\user2\AppData\Roaming\Python\Python37\site-packages\tensorflow_core\python\util\nest.py in assert_same_structure(nest1, nest2, check_types, expand_composites)
    323                   "Entire first structure:\n%s\n"
    324                   "Entire second structure:\n%s"
--> 325                   % (str(e), str1, str2))
    326 
    327 

ValueError: The two structures don't have the same nested structure.

First structure: type=DType str=<dtype: 'float32'>

Second structure: type=list str=[<tf.Tensor: id=203, shape=(), dtype=float32, numpy=1.0>, <tf.Tensor: id=207, shape=(), dtype=float32, numpy=2.0>]

More specifically: Substructure "type=list str=[<tf.Tensor: id=203, shape=(), dtype=float32, numpy=1.0>, <tf.Tensor: id=207, shape=(), dtype=float32, numpy=2.0>]" is a sequence, while substructure "type=DType str=<dtype: 'float32'>" is not
Entire first structure:
.
Entire second structure:
[., .]

Upvotes: 1

Views: 645

Answers (1)

AloneTogether

Reputation: 26698

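The error occurs because, unless told otherwise, tf.map_fn expects fn to return the same structure as elems (here a single float32 tensor), while your function returns a Python list. One fix is to return a single tensor, for example by concatenating or stacking the values: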

import tensorflow as tf
tensaki = tf.constant([[1., 2., 3.], [4., 5., 6.]])

def my_fun(x):
    # Stack the two scalars into one tensor of shape (2,) so fn returns a single tensor
    return tf.stack([x[0], x[1]], axis=0)

print(tf.map_fn(my_fun, tensaki))

output:

tf.Tensor(
[[1. 2.]
 [4. 5.]], shape=(2, 2), dtype=float32)

Of course, it all depends on the output you are expecting.
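
If you would rather keep the two outputs as separate tensors, one possible approach is to tell map_fn what structure the function returns. The sketch below assumes TensorFlow 2.3 or newer, where the argument is called fn_output_signature; on older releases such as the one in the traceback, the same structure was passed through the dtype argument instead. The names first and second are only illustrative.

```python
import tensorflow as tf

tensaki = tf.constant([[1., 2., 3.], [4., 5., 6.]])

def my_fun(x):
    # Return two scalars per row as a Python list.
    return [x[0], x[1]]

# Describe the output structure so map_fn does not assume it matches the input.
# Assumption: TF 2.3+; older releases take the same structure via `dtype=`.
first, second = tf.map_fn(
    my_fun, tensaki, fn_output_signature=[tf.float32, tf.float32])

print(first)
print(second)
```

Here map_fn stacks each output separately, so first holds x[0] for every row and second holds x[1] for every row, each with shape (2,). The signature is written as a list to match the list returned by my_fun; if the function returned a tuple, the signature would need to be a tuple as well.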

Upvotes: 1
