namespace-Pt

Reputation: 1904

TensorFlow equivalent of torch.scatter_add

How can I implement the same operation in TensorFlow 1.15?

import torch

B, T, N, K = 2, 3, 4, 2

# a is a counter table where T is the number of groups
a = torch.zeros(T, N, dtype=torch.long)

# x is a batch of data where K elements are selected for each group
x = torch.randint(0, N, (B, T, K))

# the counter should record all data within the batch (per group)
y = x.permute(1,2,0).reshape(T, -1)

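# update counter in place (scatter_add without the trailing underscore
# returns a new tensor and leaves a unchanged)
a.scatter_add_(dim=-1, index=y, src=torch.ones_like(y))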
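To make the semantics of the scatter_add(dim=-1) call concrete, and to have a reference to check a TensorFlow port against, here is a NumPy sketch that builds the same (T, N) counter. NumPy and the helper name scatter_add_counts are assumptions for illustration, not part of the question. np.add.at is used because, like scatter_add, it accumulates repeated indices instead of overwriting them.

```python
import numpy as np


def scatter_add_counts(x, N):
    """Hypothetical helper: NumPy version of the question's counter update.

    x: integer array of shape (B, T, K) with values in [0, N).
    Returns a (T, N) int64 table where entry [t, v] counts how often
    value v occurs in group t across the whole batch.
    """
    _, T, _ = x.shape
    # Same layout as the PyTorch y: shape (T, K*B)
    y = x.transpose(1, 2, 0).reshape(T, -1)
    a = np.zeros((T, N), dtype=np.int64)
    # One row index per element of y (row-major order matches y.ravel()).
    rows = np.repeat(np.arange(T), y.shape[1])
    # np.add.at accumulates duplicates; a[rows, cols] += 1 would not.
    np.add.at(a, (rows, y.ravel()), 1)
    return a


rng = np.random.default_rng(0)
x = rng.integers(0, 4, size=(2, 3, 2))  # B=2, T=3, K=2, N=4
counts = scatter_add_counts(x, N=4)
```

Each row of the result sums to K*B, since every selected element of the batch is counted exactly once in its group.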

Upvotes: 0

Views: 48

Answers (1)

Mark Z

Reputation: 46

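As mentioned by @mhenning, you can use tf.scatter_add() in TensorFlow 1.15. Keep in mind that tf.scatter_add(ref, indices, updates) adds into a tf.Variable along the variable's first dimension, so the question's dim=-1 call needs its indices remapped before it can be expressed this way.

Note, however, that the functions below are not part of the TensorFlow 2.x API; in 2.x they are only reachable through tf.compat.v1 (for example, tf.compat.v1.scatter_add()):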

  • tf.scatter_add()
  • tf.scatter_div()
  • tf.scatter_min()
  • tf.scatter_max()
  • tf.scatter_mul()
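A minimal sketch of applying tf.compat.v1.scatter_add to the question's counter, assuming TensorFlow 2.x with eager execution. Under 1.15 the same ops exist, but you would either enable eager execution or run the returned tensor in a tf.Session after initializing the variable. The trick is to flatten the (T, N) table into a 1-D variable so that cell (t, v) sits at index t*N + v, which turns the last-axis scatter into a first-axis one. The helper name count_per_group is hypothetical.

```python
import tensorflow as tf


def count_per_group(x, T, N):
    """Hypothetical helper: per-group value counts via tf.compat.v1.scatter_add.

    x: integer tensor of shape (B, T, K) with values in [0, N).
    Returns an int64 tensor of shape (T, N), like the PyTorch counter `a`.
    """
    x = tf.cast(x, tf.int64)
    # Same layout as the question's y: shape (T, K*B)
    y = tf.reshape(tf.transpose(x, perm=[1, 2, 0]), [T, -1])

    # Flattened counter: cell (t, v) lives at index t*N + v.
    counter = tf.Variable(tf.zeros([T * N], dtype=tf.int64))
    offsets = N * tf.range(T, dtype=tf.int64)[:, tf.newaxis]  # shape (T, 1)
    flat_idx = tf.reshape(y + offsets, [-1])

    # scatter_add sums contributions from duplicate indices,
    # matching torch.scatter_add; use its return value so the read
    # happens after the update (also in graph mode).
    updated = tf.compat.v1.scatter_add(counter, flat_idx, tf.ones_like(flat_idx))
    return tf.reshape(updated, [T, N])


B, T, N, K = 2, 3, 4, 2
x = tf.random.uniform((B, T, K), minval=0, maxval=N, dtype=tf.int64)
a = count_per_group(x, T, N)
```

If you do not need a variable at all, tf.reduce_sum(tf.one_hot(y, N, dtype=tf.int64), axis=1) produces the same (T, N) counts statelessly, at the cost of materializing a (T, K*B, N) intermediate.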

Upvotes: 0
