Reputation: 1493
So the pseudocode of what I want is:
splitted_outputs = [tf.split(output, rate, axis=0) for output in outputs]
where outputs is a Tensor of shape (512, ?, 128), and splitted_outputs is a list of lists of Tensors, or a Tensor with 3 dimensions, so that I can iterate over it in TensorFlow.
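If the single-Tensor form is acceptable, a reshape can produce it without a Python-level loop. This is a minimal sketch, assuming TensorFlow 2 with eager execution, a hypothetical concrete middle dimension of 8 standing in for "?", and a middle dimension divisible by rate (tf.split with an integer rate requires the same). Note that the result has four dimensions, (512, rate, ?/rate, 128):

```python
import tensorflow as tf

rate = 4
# Hypothetical stand-in for the (512, ?, 128) tensor, with ? = 8.
outputs = tf.random.normal([512, 8, 128])

# Row-major reshape: -1 lets TensorFlow infer ?/rate at run time, and
# stacked[i, j] holds the same rows as tf.split(outputs[i], rate, axis=0)[j].
stacked = tf.reshape(outputs, [outputs.shape[0], rate, -1, outputs.shape[2]])
```

Indexing stacked[i] then gives the rate pieces of slice i as one (rate, ?/rate, 128) tensor.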
I've tried to use tf.map_fn:
splitted_outputs = tf.map_fn(
lambda output: tf.split(output, rate, axis=0),
outputs,
dtype=list
)
but it's not possible because list is not a legal TensorFlow dtype.
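tf.map_fn does accept a structure of dtypes, so one possible fix for this attempt is to declare one float32 output per piece instead of list. A sketch assuming TensorFlow 2.3 or later (where fn_output_signature replaced dtype; on older versions dtype=[tf.float32] * rate plays the same role), float32 data, and a hypothetical middle dimension of 8:

```python
import tensorflow as tf

rate = 4
# Hypothetical stand-in for the question's (512, ?, 128) tensor, with ? = 8.
outputs = tf.random.normal([512, 8, 128])

# fn returns a Python list of `rate` tensors, so the output signature is a
# list of `rate` dtypes instead of the (illegal) `list` dtype.
splitted = tf.map_fn(
    lambda output: tf.split(output, rate, axis=0),
    outputs,
    fn_output_signature=[tf.float32] * rate,
)
```

The result is grouped by piece rather than by slice: splitted[j][i] equals tf.split(outputs[i], rate, axis=0)[j], and each splitted[j] has shape (512, ?/rate, 128).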
Upvotes: 1
Views: 325
Reputation: 59731
You can use tf.unstack on outputs to get a list of "subtensors", then use tf.split on each of those:
splitted_outputs = [tf.split(output, rate, axis=0) for output in tf.unstack(outputs, axis=0)]
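A runnable version of the snippet above, assuming TensorFlow 2 eager execution, a hypothetical rate of 4, and random data whose middle dimension is 8:

```python
import tensorflow as tf

rate = 4
outputs = tf.random.normal([512, 8, 128])  # hypothetical data; ? = 8 here

# tf.unstack(axis=0) returns 512 tensors of shape (8, 128); tf.split then
# cuts each one along its first axis into `rate` pieces of shape (2, 128).
splitted_outputs = [tf.split(output, rate, axis=0)
                    for output in tf.unstack(outputs, axis=0)]

# The result is a plain nested Python list, so it iterates like one.
for i, pieces in enumerate(splitted_outputs[:2]):
    print(i, [p.shape.as_list() for p in pieces])
```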
Note that tf.unstack can only be used like that when the size of the given axis is known statically; otherwise you need to provide the num parameter.
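To illustrate the num requirement: inside a tf.function whose input signature leaves axis 0 as None, the static size is unknown, so tf.unstack has to be told how many slices to expect. This is a sketch; the signature, RATE = 4 and NUM_SLICES = 512 are assumptions based on the question's shape, and at run time the actual size of axis 0 must match num:

```python
import tensorflow as tf

RATE = 4
NUM_SLICES = 512  # assumed known to the caller, even though the static shape is not

# Hypothetical signature: axis 0 is None, so its size is unknown at trace time.
@tf.function(input_signature=[tf.TensorSpec([None, None, 128], tf.float32)])
def split_unknown_batch(outputs):
    # Without num=..., tf.unstack cannot infer how many tensors to return
    # and raises ValueError while the function is traced.
    slices = tf.unstack(outputs, num=NUM_SLICES, axis=0)
    # tf.split with an integer count works even though each slice's
    # first dimension is only known at run time.
    return [tf.split(s, RATE, axis=0) for s in slices]

result = split_unknown_batch(tf.random.normal([NUM_SLICES, 8, 128]))
```

Leaving out num=NUM_SLICES in this function makes tracing fail with a ValueError, because the size of axis 0 cannot be read from the shape (None, None, 128).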
Upvotes: 1