Achilles

Reputation: 1129

How to replace a value within a tensor by indices?

The code below adds a value at a specific location in a tensor, given its indices (thanks to @mrry's answer here).

import tensorflow as tf

indices = [[1, 1]]  # A list of coordinates to update.
values = [1.0]      # A list of values corresponding to the respective
                    # coordinate in indices.
shape = [3, 3]      # The shape of the corresponding dense tensor, same as `c`.
delta = tf.SparseTensor(indices, values, shape)

For example, given this:

c = tf.constant([[0.0, 0.0, 0.0],
                 [0.0, 0.0, 0.0],
                 [0.0, 0.0, 0.0]])

Adding delta to c puts 1 at [1, 1], resulting in:

[[0.0, 0.0, 0.0],
 [0.0, 1.0, 0.0],
 [0.0, 0.0, 0.0]]
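The snippet above only builds delta; the addition itself is not shown. A minimal sketch of that step, assuming TensorFlow 2.x with eager execution: tf.sparse.to_dense turns delta into a dense tensor with zeros everywhere except [1, 1], and that dense tensor is summed element-wise with c. This is why the technique can only add, never overwrite.

```python
import tensorflow as tf

c = tf.constant([[0.0, 0.0, 0.0],
                 [0.0, 0.0, 0.0],
                 [0.0, 0.0, 0.0]])

# One coordinate to update and the value to add there.
delta = tf.SparseTensor(indices=[[1, 1]], values=[1.0], dense_shape=[3, 3])

# Densify the sparse delta (zeros everywhere else) and add it to c.
result = c + tf.sparse.to_dense(delta)
print(result.numpy())
```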

Question: is it possible to replace the value at a specific location instead of adding to it? If it's not possible in TensorFlow, is it possible in any similar library?

For example,

Given this -

[[4.0, 43.0, 45.0],
 [2.0, 22.0, 6664.0],
 [-4543.0, 0.0, 43.0]]

Is there a way to replace the 22 at [1, 1] with (say) 45, resulting in the following?

[[4.0, 43.0, 45.0],
 [2.0, 45.0, 6664.0],
 [-4543.0, 0.0, 43.0]]

Upvotes: 7

Views: 7177

Answers (2)

AloneTogether

Reputation: 26718

A simple way to update a tensor based on its indices or on its own values is to use tf.tensor_scatter_nd_update or tf.where, respectively:

import tensorflow as tf

x = tf.constant([[4.0, 43.0, 45.0],
                 [2.0, 22.0, 6664.0],
                 [-4543.0, 0.0, 43.0]])
value = 45.0
indices = [1, 1]

by_indices = tf.tensor_scatter_nd_update(x, [indices], [value])
tf.print('Using indices\n', by_indices, '\n')

by_value = tf.where(tf.equal(x, 22.0), value, x)
tf.print('Using value\n', by_value)

Output:

Using indices
 [[4 43 45]
 [2 45 6664]
 [-4543 0 43]] 

Using value
 [[4 43 45]
 [2 45 6664]
 [-4543 0 43]]
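tf.tensor_scatter_nd_update also accepts several coordinates at once: for a rank-2 tensor, indices has shape [N, 2] (one row per element) and updates has shape [N]. A minimal sketch, assuming TensorFlow 2.x; the second coordinate and its value (-1.0 at [0, 2]) are illustrative only. Note that the tf.where variant above replaces every element equal to 22.0, not just the one at [1, 1].

```python
import tensorflow as tf

x = tf.constant([[4.0, 43.0, 45.0],
                 [2.0, 22.0, 6664.0],
                 [-4543.0, 0.0, 43.0]])

# One row of coordinates per element to replace: shape [N, 2].
indices = tf.constant([[1, 1], [0, 2]])
# One new value per row of indices: shape [N].
updates = tf.constant([45.0, -1.0])

# Returns a new tensor; x itself is left unchanged.
updated = tf.tensor_scatter_nd_update(x, indices, updates)
print(updated.numpy())
```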

Upvotes: 4

Nico

Reputation: 121

This is clunky, but it does replace values in a tensor. It's based on the answer you mentioned.

import numpy as np  # used by the check below
import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.placeholder, tf.Session)

# inputs
inputs = tf.placeholder(shape=[None, None], dtype=tf.float32)  # tensor with values to replace
indices = tf.placeholder(shape=[None, 2], dtype=tf.int64)  # coordinates to be updated
values = tf.placeholder(shape=[None], dtype=tf.float32)  # values corresponding to respective coordinates in "indices"

# set elements in "indices" to 0's
# note: with several coordinates, "indices" must be sorted (row-major) and free of repeats
maskValues = tf.tile([0.0], [tf.shape(indices)[0]])  # one 0 for each element in "indices"
mask = tf.SparseTensor(indices, maskValues, tf.shape(inputs, out_type=tf.int64))
maskedInput = tf.multiply(inputs, tf.sparse_tensor_to_dense(mask, default_value=1.0))  # set values at coordinates in "indices" to 0's, leave everything else intact

# replace elements in "indices" with "values"
delta = tf.SparseTensor(indices, values, tf.shape(inputs, out_type=tf.int64))
outputs = tf.add(maskedInput, tf.sparse_tensor_to_dense(delta))  # add "values" to elements in "indices" (which are 0's so far)

What it does:

  1. Set the input elements at the positions to be replaced to 0.
  2. Add the desired values to these zeros (this step is taken straight from here).
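The two steps are not specific to TensorFlow. As an illustration, here is the same mask-then-add idea in NumPy; it is a sketch, and the function name replace_by_mask_and_add is hypothetical. One caveat worth knowing: because step 1 multiplies by 0, a replaced position that holds inf or nan ends up as nan rather than the new value.

```python
import numpy as np

def replace_by_mask_and_add(inputs, indices, values):
    """Hypothetical helper mirroring the two steps above (not a library API)."""
    inputs = np.asarray(inputs, dtype=np.float64)
    rows, cols = np.asarray(indices).T  # split [N, 2] coordinates into row/col arrays

    # Step 1: a mask of ones with 0 at every coordinate to be replaced.
    mask = np.ones_like(inputs)
    mask[rows, cols] = 0.0
    masked = inputs * mask

    # Step 2: a dense "delta" with the new values at those coordinates, zeros elsewhere.
    delta = np.zeros_like(inputs)
    delta[rows, cols] = values
    return masked + delta

x = [[4.0, 43.0, 45.0], [2.0, 22.0, 6664.0], [-4543.0, 0.0, 43.0]]
out = replace_by_mask_and_add(x, [[1, 1]], [45.0])
print(out)
```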

Check by running:

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ins = np.array([[4.0, 43.0, 45.0], [2.0, 22.0, 6664.0], [-4543.0, 0.0, 43.0]])
    ind = [[1, 1]]
    vals = [45]
    outs = sess.run(outputs, feed_dict = { inputs: ins, indices: ind, values: vals })
    print(outs)

Output:

[[ 4.000e+00  4.300e+01  4.500e+01]
 [ 2.000e+00  4.500e+01  6.664e+03]
 [-4.543e+03  0.000e+00  4.300e+01]]

Unlike many otherwise great answers, this one works on any tensor, not just on tf.Variable objects.
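In TensorFlow 2.x, tf.placeholder and tf.Session are only available under tf.compat.v1, so here is a sketch of the same idea written eagerly. The function name replace_at is hypothetical, and it swaps the multiply-by-zero mask for a boolean mask plus tf.where, so non-finite values at the replaced positions do not turn into nan. It assumes rank-2 inputs and an [N, 2] indices list without duplicates (tf.scatter_nd sums repeated coordinates).

```python
import tensorflow as tf

def replace_at(inputs, indices, values):
    """Hypothetical helper: replace inputs[indices[i]] with values[i] (rank-2 inputs)."""
    indices = tf.convert_to_tensor(indices, dtype=tf.int64)
    values = tf.convert_to_tensor(values, dtype=inputs.dtype)
    shape = tf.shape(inputs, out_type=tf.int64)  # must match the dtype of indices

    # Boolean mask that is True exactly at the coordinates to replace.
    hits = tf.scatter_nd(indices, tf.ones_like(values, dtype=tf.int32), shape)
    mask = hits > 0
    # Dense tensor holding the new values at those coordinates (zeros elsewhere).
    delta = tf.scatter_nd(indices, values, shape)
    # Take the new value where the mask is set, the original value everywhere else.
    return tf.where(mask, delta, inputs)

x = tf.constant([[4.0, 43.0, 45.0],
                 [2.0, float("inf"), 6664.0],
                 [-4543.0, 0.0, 43.0]])
r = replace_at(x, [[1, 1]], [45.0])
print(r.numpy())
```

For a plain replacement like this, tf.tensor_scatter_nd_update does the same job in a single call.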

Upvotes: 1
