mirt

How to combine tf.map_fn and tf.split

In pseudocode, what I want is:

splitted_outputs = [tf.split(output, rate, axis=0) for output in outputs]

where outputs is a Tensor of shape (512, ?, 128), and splitted_outputs should be a list of lists of Tensors (or a Tensor with 3 dimensions), so that I can iterate over it in TensorFlow.

I've tried to use tf.map_fn:

splitted_outputs = tf.map_fn(
    lambda output: tf.split(output, rate, axis=0),
    outputs,
    dtype=list
)

but that doesn't work, because list is not a valid TensorFlow dtype.
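The error comes from asking tf.map_fn to produce a Python list. Below is a minimal sketch of one way around it, assuming TensorFlow 2.x with eager execution and a small hypothetical stand-in for outputs. The mapped function returns a single Tensor by stacking the pieces from tf.split, so map_fn yields one Tensor with an extra "piece" dimension instead of a list of lists:

```python
import tensorflow as tf

rate = 3
# Hypothetical small stand-in for `outputs` (the question uses shape (512, ?, 128)).
# tf.split with an integer requires the split axis (here 6) to be divisible by `rate`.
outputs = tf.random.normal([4, 6, 3])

# Each element has shape (6, 3); tf.split gives `rate` pieces of shape (2, 3),
# and tf.stack turns them into one (3, 2, 3) Tensor that map_fn can return.
# The result dtype equals the input dtype (float32), so no dtype argument is needed.
splitted = tf.map_fn(lambda output: tf.stack(tf.split(output, rate, axis=0)), outputs)

print(splitted.shape)  # (4, 3, 2, 3)
```

Here splitted[b, j] holds the j-th piece of outputs[b], so it can be indexed or iterated like the list of lists from the pseudocode.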

Answers (1)

javidcf

You can use tf.unstack on outputs to get a list of "subtensors", then use tf.split on each of those:

splitted_outputs = [tf.split(output, rate, axis=0) for output in tf.unstack(outputs, axis=0)]

Note that tf.unstack can only be used like this when the size of the given axis is known statically; otherwise you need to pass the num parameter.
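A runnable sketch of the approach above, assuming TensorFlow 2.x with eager execution. The middle dimension, which is unknown (?) in the question, is fixed to 8 here purely for illustration, and rate = 4 is a hypothetical value. num is passed to tf.unstack explicitly, as the note describes. Because tf.split with an integer requires the split dimension to be divisible by rate, the same pieces can also be obtained as one 4-D Tensor with tf.reshape, without building Python lists:

```python
import tensorflow as tf

rate = 4  # hypothetical number of pieces per output; must divide the middle dimension
# Hypothetical stand-in for `outputs`: shape (512, 8, 128) instead of (512, ?, 128).
outputs = tf.random.normal([512, 8, 128])

# Unstack along axis 0 into 512 slices of shape (8, 128), then split each slice
# along its own axis 0 into `rate` pieces of shape (2, 128).
splitted_outputs = [tf.split(output, rate, axis=0)
                    for output in tf.unstack(outputs, num=512, axis=0)]

print(len(splitted_outputs), len(splitted_outputs[0]))  # 512 4
print(splitted_outputs[0][0].shape)                     # (2, 128)

# Same pieces as a single Tensor: stacked[b, j] equals splitted_outputs[b][j].
# -1 lets TensorFlow infer the piece length, so this also works when the middle
# dimension is only known at run time.
stacked = tf.reshape(outputs, [512, rate, -1, 128])
```

The list version matches the pseudocode in the question; the reshape version avoids creating 512 × rate separate Tensors, which may matter when the graph is large.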