T. Holmström

Reputation: 163

How to use tf.gather_nd for multi-dimensional tensor

I don't fully understand how to use tf.gather_nd() to pick elements along some axis of a multi-dimensional tensor. Let's take a small example (if I get an answer for this simple case, it also solves my more complex original problem). Say I have an RGB image and I want to pick the smallest pixel value along the channels (the last dimension if the data order is (B,H,W,C)). I know this can be done with tf.reduce_min(x, axis=-1), but I would like to know whether it is also possible to do the same thing with tf.argmin() and tf.gather_nd().

from skimage import data
import tensorflow as tf
import numpy as np

# Load RGB image from skimage, cast it to float32 and put it in order (B,H,W,C)
image = data.astronaut()
image = tf.cast(image, tf.float32)
image = tf.expand_dims(image, axis=0)

# Way 1: minimum pixel value across the channels of each pixel
min_along_channels_1 = tf.reduce_min(image, axis=-1)

# Way 2: the same minimum via tf.argmin() and tf.gather_nd()
# The goal is that min_along_channels_1 is equal to min_along_channels_2
idxs = tf.argmin(image, axis=-1)
min_along_channels_2 = tf.gather_nd(image, idxs) # This line raises an error :(
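To see why the last line fails, it helps to look at how tf.gather_nd reads its indices: the last axis of `indices` is one complete coordinate into `params`. The small sketch below assumes TensorFlow 2.x in eager mode and uses a hand-made 1x2x2x3 tensor in place of the astronaut image. For the real image, `idxs` has shape (1, 512, 512), so tf.gather_nd would try to read each length-512 row as a coordinate into a rank-4 tensor, which it rejects.

```python
import tensorflow as tf

# Tiny stand-in for a (B, H, W, C) image: B=1, H=2, W=2, C=3.
image = tf.constant([[[[5., 1., 9.], [4., 8., 2.]],
                      [[7., 3., 6.], [0., 2., 1.]]]])

# tf.gather_nd treats the LAST axis of the indices as one full coordinate,
# so a length-4 vector addresses a single scalar image[b, h, w, c].
print(tf.gather_nd(image, [0, 1, 0, 2]))  # image[0, 1, 0, 2] -> 6.0

# tf.argmin only returns the channel index c for every (b, h, w),
# so its last axis has length W, not the 4 components a coordinate needs.
idxs = tf.argmin(image, axis=-1)
print(idxs.shape)  # (1, 2, 2)
```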

Upvotes: 1

Views: 525

Answers (1)

AloneTogether

Reputation: 26718

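tf.gather_nd needs to know exactly where to extract each value, i.e. a full index along every dimension, but tf.argmin only gives you the channel index. You can use tf.meshgrid to build the missing row and column indices: it turns two one-dimensional ranges (one for the first dimension, one for the second) into a rectangular grid, so every pixel gets its (row, column) pair. Appending the argmin result as a third component then gives a complete (row, column, channel) index for every pixel. Here is a simplified example: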
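For intuition about the grid, here is what tf.meshgrid returns for a tiny 2x3 image (a sketch assuming TensorFlow 2.x in eager mode). With indexing='ij' the outputs have shape (H, W); the default indexing='xy' would give shape (W, H) instead, which would not line up with the (H, W) output of tf.argmin for non-square images.

```python
import tensorflow as tf

# Row and column coordinates for a tiny grid with H=2, W=3.
rows, cols = tf.meshgrid(tf.range(2), tf.range(3), indexing='ij')
print(rows.numpy())  # [[0 0 0]
                     #  [1 1 1]]
print(cols.numpy())  # [[0 1 2]
                     #  [0 1 2]]

# Stacking on a new last axis gives one (row, col) pair per position,
# i.e. a tensor of shape (H, W, 2) where ij[i, j] == [i, j].
ij = tf.stack([rows, cols], axis=-1)
print(ij.shape)  # (2, 3, 2)
```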

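import tensorflow as tf

# Random (1, H, W, C) tensor; drop the batch axis -> (H, W, C)
image = tf.random.normal((1, 4, 4, 3))
image = tf.squeeze(image, axis=0)
# Channel index of the minimum for every pixel, shape (H, W)
idx = tf.argmin(image, axis=-1)

# (row, col) coordinates of every pixel, shape (H, W, 2)
ij = tf.stack(tf.meshgrid(
    tf.range(image.shape[0], dtype=tf.int64),
    tf.range(image.shape[1], dtype=tf.int64),
    indexing='ij'), axis=-1)

# Full (row, col, channel) coordinates, shape (H, W, 3)
gather_indices = tf.concat([ij, tf.expand_dims(idx, axis=-1)], axis=-1)
result = tf.gather_nd(image, gather_indices)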

print('First option -->', tf.reduce_min(image, axis=-1))
print('Second option -->', result)
First option --> tf.Tensor(
[[-0.53245485 -0.29117298 -0.64434254 -0.8209638 ]
 [-0.9386176  -0.5993224  -0.597746   -1.5392851 ]
 [-0.5478666  -1.5280861  -1.0344954  -1.920418  ]
 [-0.5580688  -1.425873   -1.9276617  -1.0668412 ]], shape=(4, 4), dtype=float32)
Second option --> tf.Tensor(
[[-0.53245485 -0.29117298 -0.64434254 -0.8209638 ]
 [-0.9386176  -0.5993224  -0.597746   -1.5392851 ]
 [-0.5478666  -1.5280861  -1.0344954  -1.920418  ]
 [-0.5580688  -1.425873   -1.9276617  -1.0668412 ]], shape=(4, 4), dtype=float32)
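The code above squeezes away the batch axis, which only works when B = 1. A sketch of the same meshgrid approach for a batch of any size, assuming TensorFlow 2.x and a statically known shape, adds a third range for the batch axis so that each coordinate becomes (b, h, w, c):

```python
import tensorflow as tf

# Random stand-in for a batch of images, shape (B, H, W, C) with B = 2.
image = tf.random.normal((2, 4, 4, 3))
# Channel index of the minimum for every (b, h, w): shape (B, H, W), int64.
idx = tf.argmin(image, axis=-1)

B, H, W, _ = image.shape
# One coordinate grid per leading axis; indexing='ij' keeps the (B, H, W) order.
grids = tf.meshgrid(tf.range(B, dtype=tf.int64),
                    tf.range(H, dtype=tf.int64),
                    tf.range(W, dtype=tf.int64),
                    indexing='ij')
# Stack into (B, H, W, 3): the entry at [b, h, w] is the vector [b, h, w].
bhw = tf.stack(grids, axis=-1)

# Append the channel index to get full (b, h, w, c) coordinates: (B, H, W, 4).
gather_indices = tf.concat([bhw, tf.expand_dims(idx, axis=-1)], axis=-1)
result = tf.gather_nd(image, gather_indices)  # shape (B, H, W)

print(bool(tf.reduce_all(tf.equal(result, tf.reduce_min(image, axis=-1)))))  # True
```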

Or with your example:

from skimage import data
import tensorflow as tf

# Load the RGB image as float32 with a batch axis: (1, 512, 512, 3)
image = data.astronaut()
image = tf.cast(image, tf.float32)
image = tf.expand_dims(image, axis=0)

# Way 1: reduce_min over the channels, shape (1, 512, 512)
min_along_channels_1 = tf.reduce_min(image, axis=-1)

# Way 2: drop the batch axis, then gather at the argmin channel of each pixel
image = tf.squeeze(image, axis=0)
idx = tf.argmin(image, axis=-1)

ij = tf.stack(tf.meshgrid(
    tf.range(image.shape[0], dtype=tf.int64),
    tf.range(image.shape[1], dtype=tf.int64),
    indexing='ij'), axis=-1)

gather_indices = tf.concat([ij, tf.expand_dims(idx, axis=-1)], axis=-1)
min_along_channels_2 = tf.gather_nd(image, gather_indices)  # shape (512, 512)

# (1, 512, 512) vs (512, 512) broadcasts, hence the (1, 512, 512) output
print(tf.equal(min_along_channels_1, min_along_channels_2))
tf.Tensor(
[[[ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  ...
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]
  [ True  True  True ...  True  True  True]]], shape=(1, 512, 512), dtype=bool)
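With the `batch_dims` argument of tf.gather_nd (available in TensorFlow 2.x), the meshgrid can be skipped entirely: `batch_dims=3` tells tf.gather_nd that the first three axes of `image` and of the indices correspond one to one, so each index only needs the channel coordinate. A minimal sketch on a random (B, H, W, C) tensor, assuming eager mode:

```python
import tensorflow as tf

# Random stand-in for a batch of images, shape (B, H, W, C).
image = tf.random.normal((2, 4, 4, 3))
# Channel index of the minimum for every (b, h, w): shape (B, H, W).
idx = tf.argmin(image, axis=-1)

# batch_dims=3: the first three axes of `image` and of the indices are paired
# position by position, so each index holds only the channel coordinate.
# The indices therefore need shape (B, H, W, 1).
result = tf.gather_nd(image, tf.expand_dims(idx, axis=-1), batch_dims=3)

print(result.shape)  # (2, 4, 4)
print(bool(tf.reduce_all(tf.equal(result, tf.reduce_min(image, axis=-1)))))  # True
```

Applied to the astronaut image, this works on the (1, 512, 512, 3) tensor directly without squeezing, and the result keeps the same (1, 512, 512) shape as tf.reduce_min.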

Upvotes: 1
