How can I implement the same operation with tf 1.15?
import torch
B, T, N, K = 2, 3, 4, 2
# a is a counter table where T is the number of groups
a = torch.zeros(T, N, dtype=torch.long)
# x is a batch of data where K elements are selected for each group
x = torch.randint(0, N, (B, T, K))
# the counter should record all data within the batch (per group)
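y = x.permute(1, 2, 0).reshape(T, -1)  # shape [T, K*B]
# update counter in place (scatter_add without the trailing underscore returns a new tensor and leaves a unchanged)
a.scatter_add_(dim=-1, index=y, src=torch.ones_like(y))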
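Here is a minimal sketch of the same counting step in TF 1.15 graph mode. It assumes the counter may be produced as a new tensor instead of being updated in place. It relies on tf.scatter_nd, which sums updates whose indices collide, so pairing every value in y with its group (row) index yields the per-group counts. The helper count_per_group and the seeded NumPy input are hypothetical and only for illustration.

```python
import numpy as np
import tensorflow as tf  # sketch targets TF 1.15; the tf.compat.v1 names used here also exist in TF 2.x

B, T, N, K = 2, 3, 4, 2


def count_per_group(x, T, N):
    """Hypothetical helper: x is an int32 tensor [B, T, K] with values in [0, N).

    Returns an int32 tensor [T, N] where entry (t, n) counts how often
    value n occurs in group t across the whole batch.
    """
    # Same layout as the PyTorch code: [B, T, K] -> [T, K, B] -> [T, K*B]
    y = tf.reshape(tf.transpose(x, [1, 2, 0]), [T, -1])
    # Group (row) index for every entry of y, broadcast to y's shape
    rows = tf.zeros_like(y) + tf.range(T)[:, None]
    # Each update targets position (group, value) of the [T, N] counter
    indices = tf.stack([rows, y], axis=-1)  # shape [T, K*B, 2]
    # tf.scatter_nd sums updates that hit the same index, like scatter_add
    return tf.scatter_nd(indices, tf.ones_like(y), [T, N])


# Assumption: fixed, seeded input so the result is reproducible
x_np = np.random.RandomState(0).randint(0, N, size=(B, T, K)).astype(np.int32)
a_np = np.zeros((T, N), dtype=np.int32)  # existing counter table

with tf.Graph().as_default():
    x = tf.constant(x_np)
    a = tf.constant(a_np)
    # Out-of-place equivalent of the scatter_add update in the question
    a_updated = a + count_per_group(x, T, N)
    with tf.compat.v1.Session() as sess:
        result = sess.run(a_updated)

print(result)
```

Because tf.scatter_nd returns a new tensor, its output is added to a. If the counter must persist across session runs, it would instead live in a tf.Variable.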
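As mentioned by @mhenning, you can use tf.scatter_add() in TensorFlow 1.15. Keep in mind that it adds updates to a tf.Variable along the variable's first dimension only, so it does not map one-to-one onto PyTorch's scatter_add(dim=-1). Also note that the functions below are deprecated: TensorFlow 2.x no longer exposes them at the top level, and there they are reachable only through tf.compat.v1 (for example tf.compat.v1.scatter_add()):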
tf.scatter_add()
tf.scatter_div()
tf.scatter_min()
tf.scatter_max()
tf.scatter_mul()
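Since tf.scatter_add() indexes only the first dimension of its variable, one way to use it for the per-group counter is to keep the table flattened to shape [T * N]. Each (group t, value v) pair then maps to the flat index t * N + v. Repeated indices accumulate, so every occurrence adds 1. The sketch below works under that assumption and is written with tf.compat.v1.scatter_add, which is the same function as tf.scatter_add in TF 1.15. The flattened layout and the seeded NumPy input are illustrative choices, not part of the original answer.

```python
import numpy as np
import tensorflow as tf  # sketch targets TF 1.15, where tf.compat.v1.scatter_add is tf.scatter_add

B, T, N, K = 2, 3, 4, 2
# Assumption: fixed, seeded input so the result is reproducible
x_np = np.random.RandomState(0).randint(0, N, size=(B, T, K)).astype(np.int32)

with tf.Graph().as_default():
    # Assumption: counter kept flat, because scatter_add only indexes the first dimension
    a_flat = tf.compat.v1.Variable(tf.zeros([T * N], dtype=tf.int32))
    x = tf.constant(x_np)
    y = tf.reshape(tf.transpose(x, [1, 2, 0]), [T, -1])  # [T, K*B], as in the question
    # Flat position of (group t, value y[t, j]) in the [T * N] counter
    flat_idx = tf.reshape(tf.range(T)[:, None] * N + y, [-1])
    # Duplicate indices accumulate, so each occurrence adds 1
    update = tf.compat.v1.scatter_add(a_flat, flat_idx, tf.ones_like(flat_idx))
    # View the updated counter as the [T, N] table again
    a = tf.reshape(update, [T, N])
    with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        counts = sess.run(a)

print(counts)
```

Each sess.run(a) call performs one more scatter_add, so the variable keeps accumulating across runs, which is usually what a running counter needs.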