walkerlala

Reputation: 1680

TensorFlow: get dimension indices when using tf.map_fn() to iterate over a tensor

Say I have a tensor ts of shape [s1, s2, s3], and I want to iterate over it with nested tf.map_fn calls:

tf.map_fn(lambda dim1:
    tf.map_fn(lambda dim2:
        do_sth(dim2, idx1, idx2)
    , dim1)
, ts)

Here idx1 and idx2 are the indices along dimension 0 and dimension 1 of ts of the element that do_sth() is currently processing. How can I get them? I want the same behavior as this loop:

for idx1 in range(s1):
    for idx2 in range(s2):
        tensor = ts[idx1][idx2]
        do_sth(tensor, idx1, idx2)

I can't write it that way because s1, s2 and s3 are usually unknown when the graph is built (e.g. ts has shape (?, ?, s3) or similar).

Is that possible?
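To see why a Python loop over range(s1) cannot work here, the following sketch traces a tf.function whose input has unknown leading dimensions. It assumes TensorFlow 2.x; dynamic_sizes and traced_static_shapes are hypothetical names used only for illustration. While tracing, the static shape contains None, whereas tf.shape() returns the actual sizes as a tensor at run time.

```python
import tensorflow as tf  # assumes TensorFlow 2.x

# Hypothetical helper list: records the static shape seen while tracing
traced_static_shapes = []


@tf.function(input_signature=[tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32)])
def dynamic_sizes(ts):
    # Python code here runs only during tracing, when s1 and s2 are still None
    traced_static_shapes.append(ts.shape.as_list())
    # tf.shape() yields an int32 tensor whose values are known only at run time
    return tf.shape(ts)
```

Calling dynamic_sizes(tf.zeros([4, 5, 3])) returns [4, 5, 3], while traced_static_shapes holds [None, None, 3]. A size such as tf.shape(ts)[0] can be passed to tf.range(), but not to Python's range().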

Answers (1)

Peter Szoldan

Reputation: 4868

I'd suggest adding an extra dimension at the end and filling it with the indices before calling tf.map_fn.

This code (tested, TensorFlow 1.x) adds to each value its X (column) index multiplied by 10 and its Y (row) index multiplied by 100:

from __future__ import print_function
import tensorflow as tf  # TensorFlow 1.x (uses tf.Session)

r = 3

# a: the data, shape (r, r), values 0 .. r * r - 1
a = tf.reshape( tf.constant( range( r * r ) ), ( r, r ) )
# x[ i, j ] = j (column index), y[ i, j ] = i (row index)
x = tf.tile( tf.cast( tf.lin_space( 0.0, r - 1, r )[ None, : ], tf.int32 ), [ r, 1 ] )
y = tf.tile( tf.cast( tf.lin_space( 0.0, r - 1, r )[ :, None ], tf.int32 ), [ 1, r ] )

# Append the indices as a new last dimension: b has shape (r, r, 3)
b = tf.stack( [ a, x, y ], axis = -1 )

# Each innermost element is [ value, x_index, y_index ]
c = tf.map_fn( lambda row: tf.map_fn( lambda elem:
        elem[ 0 ] + 10 * elem[ 1 ] + 100 * elem[ 2 ]
    , row ), b )

with tf.Session() as sess:
    res = sess.run( c )
    print( res )

Outputs:

[[  0  11  22]
 [103 114 125]
 [206 217 228]]
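The same idea extends to the question's 3-D case with unknown s1 and s2 if the index grid is built from tf.shape() instead of a fixed r. Below is a minimal sketch, assuming TensorFlow 2.x, a float32 ts, and a last dimension S3 that is known statically. The helpers with_index_channels and map_with_indices are illustrative names, and this do_sth() is a hypothetical stand-in that sums the vector and adds 10 * idx2 + 100 * idx1.

```python
import tensorflow as tf  # assumes TensorFlow 2.x

S3 = 2  # assumed static size of the last dimension: ts has shape (?, ?, S3)


def with_index_channels(ts):
    # Append idx1 and idx2 along the last axis: (s1, s2, S3) -> (s1, s2, S3 + 2)
    shape = tf.shape(ts)  # dynamic shape, known only at run time
    idx1, idx2 = tf.meshgrid(tf.range(shape[0]), tf.range(shape[1]), indexing="ij")
    idx = tf.cast(tf.stack([idx1, idx2], axis=-1), ts.dtype)  # (s1, s2, 2)
    return tf.concat([ts, idx], axis=-1)


def do_sth(vec, idx1, idx2):
    # Hypothetical stand-in for the question's do_sth()
    return tf.reduce_sum(vec) + 10.0 * idx2 + 100.0 * idx1


@tf.function(input_signature=[tf.TensorSpec(shape=[None, None, S3], dtype=tf.float32)])
def map_with_indices(ts):
    b = with_index_channels(ts)
    # elem[:S3] is the original vector, elem[S3] is idx1, elem[S3 + 1] is idx2
    return tf.map_fn(
        lambda row: tf.map_fn(
            lambda elem: do_sth(elem[:S3], elem[S3], elem[S3 + 1]), row),
        b)
```

For ts = tf.reshape(tf.range(12, dtype=tf.float32), (3, 2, 2)) this gives [[1, 15], [109, 123], [217, 231]]. Because the grid comes from tf.shape(ts), the same traced function also accepts inputs of other sizes, such as (2, 3, 2). Storing indices as float32 is fine for small tensors, but float32 represents integers exactly only up to 2**24.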