Reputation: 2168
Let's say I have a (None, 2)-shape tensor indices and a (None,)-shape tensor values. The actual number of rows and the values will be determined at runtime.
I would like to build a 4x5 tensor t in which each coordinate listed in indices is assigned the corresponding element of values. I found that I can use tf.scatter_nd like this:
t = tf.scatter_nd(indices, values, [4, 5])
# E.g., indices = [[1,2],[2,3]], values = [100, 200]
# t[1,2] <-- 100; t[2,3] <-- 200
My problem is that when indices contains duplicates, the values are accumulated:
# E.g., indices = [[1,2],[1,2]], values = [100, 200]
# t[1,2] <-- 300
I would like only one value to be assigned, i.e. either ignoring the duplicate (keeping the first value) or overwriting it (keeping the last value).
I feel like I need to check for duplicates in indices, or I need to use a TensorFlow loop. Could anyone please advise? (Hopefully with a minimal code example?)
Upvotes: 3
Views: 2539
Reputation: 21
Apply tf.scatter_nd to a matrix of ones as well. That will give you the number of values accumulated into each cell, and you can simply divide your result by that count to get an average. (But watch out for zeros; for those you should divide by one instead.)
counter = tf.ones(tf.shape(values))                 # one "hit" per update
t = tf.scatter_nd(indices, values, shape)           # accumulated values
t_counter = tf.scatter_nd(indices, counter, shape)  # number of updates per cell
Then divide t by t_counter (but only where t_counter is not zero).
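A minimal end-to-end sketch of this recipe in TF 1.x (the names shape, safe_counter and t_mean are my own, not from the question); the two duplicate updates at [1, 2] average out to 150:
import tensorflow as tf

indices = tf.constant([[1, 2], [1, 2]], dtype=tf.int32)
values = tf.constant([100., 200.])
shape = [4, 5]

counter = tf.ones(tf.shape(values))
t = tf.scatter_nd(indices, values, shape)
t_counter = tf.scatter_nd(indices, counter, shape)

# Divide only where the counter is non-zero; empty cells keep their zero value
safe_counter = tf.where(tf.equal(t_counter, 0), tf.ones_like(t_counter), t_counter)
t_mean = t / safe_counter

with tf.Session() as sess:
    print(sess.run(t_mean))  # t_mean[1, 2] == 150.0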
Upvotes: 2
Reputation: 5813
This may not be the best solution, but I used tf.unsorted_segment_max to avoid accumulation:
import tensorflow as tf

with tf.Session() as sess:
    # #########
    # Examples
    # #########
    width, height, depth = [3, 3, 2]
    indices = tf.cast([[0, 1, 0], [0, 1, 0], [1, 1, 1]], tf.int32)
    values = tf.cast([1, 2, 3], tf.int32)
    # ##########################
    # Filter duplicated indices
    # ##########################
    # Flatten each 3-D index into a single scalar
    flatten = tf.matmul(indices, [[height * depth], [depth], [1]])
    filtered, idx = tf.unique(tf.squeeze(flatten))
    # ######################
    # Obtain updated result
    # ######################
    def reverse(index):
        """Map a flat index back to 3-D coordinates."""
        x = index // (height * depth)
        y = (index - x * height * depth) // depth
        z = index - x * height * depth - y * depth
        return tf.stack([x, y, z], -1)
    # This will pick the maximum value instead of accumulating the result
    updated_values = tf.unsorted_segment_max(values, idx, tf.shape(filtered)[0])
    updated_indices = tf.map_fn(fn=lambda i: reverse(i), elems=filtered)
    # Now you can scatter_nd without accumulation
    result = tf.scatter_nd(updated_indices,
                           updated_values,
                           [width, height, depth])
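To inspect the de-duplicated result, you can evaluate it at the end of the same with block (a one-line sketch; with the example above, result[0, 1, 0] is 2 and result[1, 1, 1] is 3):
    print(sess.run(result))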
Upvotes: 1
Reputation: 27050
You can use tf.unique: the only issue is that this op requires a 1-D tensor.
Thus, to overcome this I decided to use the Cantor pairing function.
In short, there exists a bijective function that maps a tuple (in this case a pair of values, but it works for any N-dimensional tuple) to a single value.
Once the coordinates have been reduced to a 1-D tensor of scalars, tf.unique can be used to find the indices of the unique numbers.
The Cantor pairing function is invertible, thus we not only know the indices of the non-repeated values within the 1-D tensor, but we can also go back to the 2-D space of the coordinates and use scatter_nd to perform the update without the accumulation problem.
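As a quick hand check with the dummy indices below: the pair (1, 2) maps to (1+2)(1+2+1)/2 + 2 = 8 and (2, 3) maps to (2+3)(2+3+1)/2 + 3 = 18, so duplicate rows in indices collapse to equal scalars that tf.unique can detect.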
import tensorflow as tf
import numpy as np
# Dummy values
indices = np.array([[1, 2], [2, 3]])
values = np.array([100, 200])
# Placeholders
indices_ = tf.placeholder(tf.int32, shape=(2, 2))
values_ = tf.placeholder(tf.float32, shape=(2))
# Use the Cantor tuple to create a one-to-one correspondence between the coordinates
# and a single value
x = tf.cast(indices_[:, 0], tf.float32)
y = tf.cast(indices_[:, 1], tf.float32)
z = (x + y) * (x + y + 1) / 2 + y # shape = (2)
# Collect unique indices, treated as single values
# Drop the indices position into z because are useless
unique_cantor, _ = tf.unique(z)
# Go back from cantor numbers to pairs of values
w = tf.floor((tf.sqrt(8 * unique_cantor + 1) - 1) / 2)
t = (tf.pow(w, 2) + w) / 2
y = unique_cantor - t  # invert the pairing on the de-duplicated scalars, not on z
x = w - y
# Recreate a batch of coordinates that are uniques
unique_indices = tf.cast(tf.stack([x, y], axis=1), tf.int32)
# Update without accumulator
go = tf.scatter_nd(unique_indices, values_, [4, 5])
with tf.Session() as sess:
    print(sess.run(go, feed_dict={indices_: indices, values_: values}))  # 100.0 at [1, 2], 200.0 at [2, 3]
Upvotes: 3