Hypnoz

Reputation: 1136

How to get a batched reduce_sum over different row ranges of a big matrix?

import tensorflow as tf

tf.enable_eager_execution()

emb = tf.ones([100,16])
start_pos = tf.constant([1,2])
end_pos = tf.constant([11,31])

Given a big matrix emb, a tensor of start positions start_pos, and a tensor of end positions end_pos, how can I get the reduce_sum over each row range of emb? The result should have shape (2, 16), where the first row is the sum of rows 1 through 11 of emb and the second row is the sum of rows 2 through 31.

Note: I am trying to run this on the GPU (tf.py_func works, but it runs on the CPU).

Update: I have one solution, but it is not matrix-based. I use tf.while_loop to iterate over each position in start_pos / end_pos and compute the sum for that range.

Upvotes: 1

Views: 63

Answers (1)

javidcf

Reputation: 59731

EDIT:

Actually, it is not so hard to do this better in a vectorized way. It takes more memory, because it builds an intermediate tensor of shape (number of slices, rows, columns), but it should be much faster:

import tensorflow as tf

tf.enable_eager_execution()

emb = tf.ones([100,16])
start_pos = tf.constant([1,2])
end_pos = tf.constant([11,31])
# Select indices of each slice
r = tf.range(tf.shape(emb)[0])[tf.newaxis]
m = (r >= start_pos[:, tf.newaxis]) & (r <= end_pos[:, tf.newaxis])
# Broadcast and tile original matrix
s = tf.cast(m, emb.dtype)[..., tf.newaxis] * emb[tf.newaxis]
# Compute sums
result = tf.reduce_sum(s, axis=1)
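
If the intermediate tensor becomes too large, one possible variation is to multiply the mask by the matrix instead of broadcasting and reducing. The following is only a sketch in NumPy rather than TensorFlow, and range_sums is a hypothetical helper name; it builds the same inclusive mask and checks the result against a plain per-slice loop:

```python
import numpy as np

def range_sums(emb, start_pos, end_pos):
    """Sum rows start_pos[k]..end_pos[k] (inclusive) of emb for every k.

    Hypothetical helper for illustration; not part of any library.
    """
    rows = np.arange(emb.shape[0])[np.newaxis, :]        # shape (1, n_rows)
    mask = (rows >= start_pos[:, np.newaxis]) & (rows <= end_pos[:, np.newaxis])
    # mask has shape (batch, n_rows). The matrix product adds up the selected
    # rows without materializing a (batch, n_rows, dim) intermediate.
    return mask.astype(emb.dtype) @ emb

# Distinct values per row so that a wrong range would be detected.
emb = np.arange(100 * 16, dtype=np.float64).reshape(100, 16)
start_pos = np.array([1, 2])
end_pos = np.array([11, 31])

result = range_sums(emb, start_pos, end_pos)

# Reference: slice and sum each range separately.
expected = np.stack([emb[s:e + 1].sum(axis=0) for s, e in zip(start_pos, end_pos)])
assert np.allclose(result, expected)
```

The same idea should carry over to TensorFlow with tf.matmul on the cast mask, although I have not measured whether it is faster there.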

Unfortunately, I don't think there is any way to extract multiple slices to be summed in a single operation. If the number of slices is fixed, one possibility is to do it in a regular Python loop:

import tensorflow as tf

tf.enable_eager_execution()

emb = tf.ones([100,16])
start_pos = tf.constant([1,2])
end_pos = tf.constant([11,31])
batch_size = 2

result = []
for i in range(batch_size):
    # Slice from start_pos[i] to end_pos[i]; end_pos is inclusive, hence the + 1
    result.append(tf.reduce_sum(emb[start_pos[i]:end_pos[i] + 1], axis=0))
result = tf.stack(result, axis=0)

If the number of slices is only known at graph execution time, or if it is so large that you do not want that many nodes in the graph, you can use tf.while_loop:

import tensorflow as tf

tf.enable_eager_execution()

emb = tf.ones([100,16])
start_pos = tf.constant([1,2])
end_pos = tf.constant([11,31])
batch_size = 2

result = tf.TensorArray(emb.dtype, batch_size)
_, result = tf.while_loop(lambda i, _: i < batch_size,
                          # end_pos is inclusive, hence the + 1 on the slice end
                          lambda i, result: (i + 1, result.write(i, tf.reduce_sum(emb[start_pos[i]:end_pos[i] + 1], axis=0))),
                          [0, result])
# TensorArray.stack() already returns a single tensor of shape (batch_size, 16)
result = result.stack()

Upvotes: 2
