Reputation: 733
I have a 3-D tensor `A` of shape `[n, ?, m]` that has exactly one non-zero element (a 1) along the third axis. For example:
A[0,0,:] = [0,0,0,1,0,0]
A[1,0,:] = [0,0,1,0,0,0]
A[0,1,:] = [0,1,0,0,0,0]
A[1,1,:] = [1,0,0,0,0,0]
I also have a weight tensor `w` of shape `(1,)`.
I want to "dilate" `A` by `w`, i.e. transform `A` as follows:
A[0,0,:] = [0,0,w,1,w,0]
A[1,0,:] = [0,w,1,w,0,0]
A[0,1,:] = [w,1,w,0,0,0]
A[1,1,:] = [1,w,0,0,0,w]
Note that `w` is placed on both sides of the non-zero element `1`; if that element is at the border of the last axis, the indices wrap around.
How can I do this using TensorFlow in Python?
Upvotes: 3
Views: 176
Reputation: 59681
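EDIT:
Here is a more general version that works for a padding vector `w` with more than one element: `w[0]` is placed right next to each non-zero value, `w[1]` one step further away, and so on, on both sides: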
import tensorflow as tf
def surround_nonzero(a, w):
    """Surround every non-zero entry of ``a`` with the values of ``w`` along the last axis.

    For a non-zero element at last-axis position ``j``, ``w[i]`` is written at
    positions ``j - (i + 1)`` and ``j + (i + 1)``. Positions wrap around the
    last axis. The non-zero value itself is kept.

    Args:
        a: Tensor of rank >= 1 whose non-zero entries are the centres to surround.
        w: 1-D tensor (or scalar) of padding values with the same dtype as ``a``.
            ``w[0]`` goes next to the non-zero value, ``w[1]`` one step further, etc.

    Returns:
        A tensor with the same shape and dtype as ``a``.

    Note:
        Assumes the padding always "fits": padding written for different
        non-zero entries must not overlap each other or another non-zero entry.
        ``tf.scatter_nd`` sums updates at duplicate indices, so overlaps would
        produce summed values.
    """
    # Accept a scalar padding value as well as a 1-D vector.
    w = tf.reshape(w, [-1])
    # Find non-zero positions: int64 matrix of shape [num_nonzero, rank(a)].
    idx = tf.where(tf.not_equal(a, 0))
    # Vector [0, ..., 0, 1] that moves only the last coordinate of an index.
    shift1 = tf.concat([tf.zeros(tf.shape(idx)[-1] - 1, dtype=tf.int64), [1]], axis=0)
    # Offsets 1..len(w) along the last axis, shape [len(w), rank(a)].
    w_len = tf.shape(w, out_type=tf.int64)[0]
    shift_len = shift1 * tf.expand_dims(tf.range(1, w_len + 1), 1)
    # Shift every non-zero index by each offset in both directions.
    # Wrap with modulo the FULL shape, not just the last dimension size.
    # Only the last coordinate moves; the others are already in [0, dim), so
    # their modulo is a no-op. Using the last dim for every coordinate would
    # corrupt leading indices that are >= a.shape[-1].
    a_shape = tf.shape(a, out_type=tf.int64)
    idx_exp = tf.expand_dims(idx, 1)
    idx_prev_exp = (idx_exp - shift_len) % a_shape
    idx_next_exp = (idx_exp + shift_len) % a_shape
    # Flatten [num_nonzero, len(w), rank] -> [num_nonzero * len(w), rank].
    # Row order is (non-zero, offset), which matches the tiling of w below.
    a_rank = tf.rank(a)
    idx_prev = tf.reshape(idx_prev_exp, [-1, a_rank])
    idx_next = tf.reshape(idx_next_exp, [-1, a_rank])
    # Take the non-zero values so they are preserved in the output.
    nonzero = tf.gather_nd(a, idx)
    # One copy of w per non-zero element, for each of the two sides.
    n = tf.shape(nonzero)[0]
    w2n = tf.tile(w, [2 * n])
    # Full index and value lists: non-zero values, then previous-side
    # padding, then next-side padding.
    idx_full = tf.concat([idx, idx_prev, idx_next], axis=0)
    values_full = tf.concat([nonzero, w2n], axis=0)
    # Build the output tensor by scattering into zeros of a's shape.
    return tf.scatter_nd(idx_full, values_full, a_shape)
# Test
with tf.Graph().as_default():
    A = tf.constant([[[0, 0, 0, 0, 0, 1, 0, 0],
                      [0, 0, 1, 0, 0, 0, 0, 0]],
                     [[0, 0, 0, 0, 1, 0, 0, 0],
                      [1, 0, 0, 0, 0, 0, 0, 0]]],
                    dtype=tf.int32)
    w = tf.constant([2, 3, 4], dtype=tf.int32)
    out = surround_nonzero(A, w)
    # tf.Session was removed in TF 2.x. tf.compat.v1.Session works in TF 1.x
    # (>= 1.13) and in TF 2.x when building an explicit graph like this.
    with tf.compat.v1.Session() as sess:
        print(sess.run(out))
Output:
[[[4 0 4 3 2 1 2 3]
[3 2 1 2 3 4 0 4]]
[[0 4 3 2 1 2 3 4]
[1 2 3 4 0 4 3 2]]]
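As with the simpler version below, this assumes that the padding always "fits". If padding values would overlap (with each other or with another non-zero value), `tf.scatter_nd` sums the colliding updates, so the result would not be what you want.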
Here is a way to do that using `tf.scatter_nd`:
import tensorflow as tf
def surround_nonzero(a, w):
    """Write ``w`` on both sides of every non-zero entry of ``a`` along the last axis.

    For a non-zero element at last-axis position ``j``, ``w`` is written at
    positions ``j - 1`` and ``j + 1``. Positions wrap around the last axis.
    The non-zero value itself is kept.

    Args:
        a: Tensor of rank >= 1 whose non-zero entries are surrounded.
        w: Padding value as a one-element 1-D tensor (or scalar) with the same
            dtype as ``a``.

    Returns:
        A tensor with the same shape and dtype as ``a``.

    Note:
        Assumes each non-zero value is surrounded by zeros.
        ``tf.scatter_nd`` sums updates at duplicate indices, so overlapping
        padding would produce summed values.
    """
    # Accept a scalar padding value as well as a one-element vector.
    w = tf.reshape(w, [-1])
    # Find non-zero positions: int64 matrix of shape [num_nonzero, rank(a)].
    idx = tf.where(tf.not_equal(a, 0))
    # Vector [0, ..., 0, 1] that shifts only the last coordinate by one.
    shift1 = tf.concat([tf.zeros(tf.shape(idx)[-1] - 1, dtype=tf.int64), [1]], axis=0)
    # Wrap with modulo the FULL shape, not just the last dimension size.
    # Only the last coordinate moves; the others are already in [0, dim), so
    # their modulo is a no-op. Using the last dim for every coordinate would
    # corrupt leading indices that are >= a.shape[-1].
    a_shape = tf.shape(a, out_type=tf.int64)
    idx_prev = (idx - shift1) % a_shape
    idx_next = (idx + shift1) % a_shape
    # Take the non-zero values so they are preserved in the output.
    nonzero = tf.gather_nd(a, idx)
    # One copy of w for each neighbour: two per non-zero value.
    n = tf.shape(nonzero)[0]
    w2n = tf.tile(w, [2 * n])
    # Full index and value lists: non-zero values, then both neighbours.
    idx_full = tf.concat([idx, idx_prev, idx_next], axis=0)
    values_full = tf.concat([nonzero, w2n], axis=0)
    # Build the output tensor by scattering into zeros of a's shape.
    return tf.scatter_nd(idx_full, values_full, a_shape)
# Test
with tf.Graph().as_default():
    A = tf.constant([[[0, 0, 0, 1, 0, 0],
                      [0, 1, 0, 0, 0, 0]],
                     [[0, 0, 1, 0, 0, 0],
                      [1, 0, 0, 0, 0, 0]]],
                    dtype=tf.int32)
    w = tf.constant([2], dtype=tf.int32)
    out = surround_nonzero(A, w)
    # tf.Session was removed in TF 2.x. tf.compat.v1.Session works in TF 1.x
    # (>= 1.13) and in TF 2.x when building an explicit graph like this.
    with tf.compat.v1.Session() as sess:
        print(sess.run(out))
Output:
[[[0 0 2 1 2 0]
[2 1 2 0 0 0]]
[[0 2 1 2 0 0]
[1 2 0 0 0 2]]]
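Note that this assumes each non-zero value is surrounded by zeros (as in your case). Otherwise, `tf.scatter_nd` would receive duplicate indices. It sums the updates for duplicated indices (and for floating-point types the summation order is nondeterministic), so the output would not be what you expect.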
Upvotes: 1