Pumpkin

How can I use TensorFlow's scatter_nd with slices?

I am trying to optimize only parts of a variable. I found this seemingly useful answer.

However, my variable is an image and I want to change only parts of it, so I am trying to extend the code to more dimensions. This seems to work fine:

import tensorflow as tf
import tensorflow.contrib.opt as opt

X = tf.Variable([[1.0, 2.0], [3.0, 4.0]])

# the next two lines need to change because
# manually specifying the values is not feasible
indexes = tf.constant([[0, 0], [1, 0]])
updates = [X[0, 0], X[1, 0]]

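# part_X holds the selected entries of X and zeros everywhere else
part_X = tf.scatter_nd(indexes, updates, [2, 2])
# Same value as X, but gradients only flow through part_X
X_2 = part_X + tf.stop_gradient(-part_X + X)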
Y = tf.constant([[2.5, -3.5], [5.5, -7.5]])
loss = tf.reduce_sum(tf.squared_difference(X_2, Y))
# Use a separate name so the `opt` module alias is not shadowed
optimizer = opt.ScipyOptimizerInterface(loss, [X])

init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    optimizer.minimize(sess)
    print("X: {}".format(X.eval()))
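
The `X_2` line works because `part_X` contains the selected entries of `X` and zeros elsewhere, so `part_X + stop_gradient(X - part_X)` evaluates to exactly `X` in the forward pass, while gradients only reach `X` through `part_X`, i.e. only at the selected indices. The plain-Python sketch below mimics that effect on the same numbers. It assumes a hand-written gradient of the squared-error loss and uses plain gradient descent rather than the SciPy-based optimizer above; `mask`, `masked_gradient_descent`, the learning rate and the step count are illustrative choices, not part of the original code.

```python
# Plain-Python sketch of the "filtered gradient" idea from the question.
# Assumption: loss = sum((X - Y)**2), so dLoss/dX = 2 * (X - Y).
# Multiplying the gradient by a 0/1 mask mirrors how
# part_X + stop_gradient(X - part_X) only lets gradient reach part_X.

X = [[1.0, 2.0], [3.0, 4.0]]
Y = [[2.5, -3.5], [5.5, -7.5]]

# 1.0 at the selected indexes [[0, 0], [1, 0]], 0.0 elsewhere
selected = [(0, 0), (1, 0)]
mask = [[1.0 if (i, j) in selected else 0.0 for j in range(2)] for i in range(2)]


def masked_gradient_descent(x, y, mask, lr=0.1, steps=100):
    """Hypothetical helper: gradient descent with the gradient multiplied by mask."""
    x = [row[:] for row in x]  # work on a copy, leave the input untouched
    for _ in range(steps):
        for i in range(len(x)):
            for j in range(len(x[0])):
                grad = 2.0 * (x[i][j] - y[i][j])  # gradient of the squared error
                x[i][j] -= lr * mask[i][j] * grad  # unselected entries get zero update
    return x


result = masked_gradient_descent(X, Y, mask)
print(result)  # selected entries approach Y, the others keep their initial values
```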

However, since my image dimensions and the area I would like to select are much bigger, manually specifying all the indices is not feasible. I would like to know how to use slices or range assignments to do so.
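
For a rectangular region, the index list does not have to be typed by hand: it can be generated from the row and column ranges. A minimal plain-Python sketch, with `region_indices` as a hypothetical helper name; the result has the `[num_indices, 2]` layout that a 2-D `tf.scatter_nd` call accepts. For large images, building the indices inside the graph (for example from `tf.range`) may be preferable to embedding a big Python list as a constant.

```python
from itertools import product


def region_indices(row_start, row_end, col_start, col_end):
    """Hypothetical helper: every [row, col] pair in the half-open box
    [row_start, row_end) x [col_start, col_end), in row-major order."""
    return [[r, c] for r, c in product(range(row_start, row_end),
                                       range(col_start, col_end))]


# Reproduces the hand-written indexes from the question (rows 0-1 of column 0)
print(region_indices(0, 2, 0, 1))  # [[0, 0], [1, 0]]

# A larger region: rows 10..29, columns 20..49
big = region_indices(10, 30, 20, 50)
print(len(big))  # 20 * 30 = 600 index pairs
```

With such a flat index list, the matching updates have to be gathered in the same order; one option is `tf.gather_nd(X, indexes)`.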


javidcf

You can do that like this:

import tensorflow as tf

# Input with size (50, 100)
# Use floats: an integer variable would not receive gradients
X = tf.Variable([[0.0] * 100] * 50)
# Selected slice
row_start = 10
row_end = 30
col_start = 20
col_end = 50
# Make indices from meshgrid
indexes = tf.meshgrid(tf.range(row_start, row_end),
                      tf.range(col_start, col_end), indexing='ij')
indexes = tf.stack(indexes, axis=-1)
# Take slice
updates = X[row_start:row_end, col_start:col_end]
# Build tensor with "filtered" gradient
part_X = tf.scatter_nd(indexes, updates, tf.shape(X))
X_2 = part_X + tf.stop_gradient(-part_X + X)
# Continue as before...
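
To see what the index tensor looks like, here is a plain-Python emulation of `tf.meshgrid(..., indexing='ij')` followed by `tf.stack(..., axis=-1)` and `tf.scatter_nd` on a small 4x5 example. The helpers `meshgrid_ij_stacked` and `scatter_nd_2d` are hypothetical stand-ins written for illustration, not TensorFlow functions. The point is that `indexes[i][j]` is simply `[row_start + i, col_start + j]`, so it lines up element by element with the slice `X[row_start:row_end, col_start:col_end]`, and `part_X` ends up equal to `X` inside the slice and 0 outside. `tf.scatter_nd` sums updates for duplicate indices, but a rectangular grid contains none.

```python
def meshgrid_ij_stacked(rows, cols):
    """Emulates tf.stack(tf.meshgrid(rows, cols, indexing='ij'), axis=-1):
    element [i][j] is the pair [rows[i], cols[j]]."""
    return [[[r, c] for c in cols] for r in rows]


def scatter_nd_2d(indexes, updates, shape):
    """Emulates tf.scatter_nd for a 2-D output when indexes and updates
    share the same outer (grid) shape; duplicate indices are summed."""
    out = [[0.0] * shape[1] for _ in range(shape[0])]
    for idx_row, upd_row in zip(indexes, updates):
        for (r, c), value in zip(idx_row, upd_row):
            out[r][c] += value
    return out


# Small stand-in for the (50, 100) variable: X[r][c] = 10 * r + c
X = [[float(10 * r + c) for c in range(5)] for r in range(4)]
row_start, row_end, col_start, col_end = 1, 3, 2, 4

indexes = meshgrid_ij_stacked(range(row_start, row_end), range(col_start, col_end))
updates = [row[col_start:col_end] for row in X[row_start:row_end]]  # the slice
part_X = scatter_nd_2d(indexes, updates, (len(X), len(X[0])))

print(indexes)  # [[[1, 2], [1, 3]], [[2, 2], [2, 3]]]
for row in part_X:
    print(row)  # X's values inside the slice, 0.0 elsewhere
```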
