Mathew

Reputation: 307

Usage of argmax from tf.nn.max_pool_with_argmax tensorflow

I am trying to use the argmax result of tf.nn.max_pool_with_argmax() to index another tensor. For simplicity, let's say I am trying to implement the following:

output, argmax = tf.nn.max_pool_with_argmax(input, ksize, strides, padding)
tf.assert_equal(input[argmax],output)

Now my question is how do I implement the necessary indexing operation input[argmax] to achieve the desired result? I am guessing this involves some usage of tf.gather_nd() and related calls, but I cannot figure it out. If necessary, we could assume that input has [BatchSize, Height, Width, Channel] dimensions.
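For reference, the argmax returned by the op indexes into the flattened input: for an input of shape [B, H, W, C], element (b, y, x, c) has flat index ((b*H + y)*W + x)*C + c (with include_batch_in_index=True; older versions omit the batch term). A minimal pure-Python sketch of this convention, assuming a toy 1x4x4x1 input with 2x2 pooling, stride 2, VALID padding:

```python
# Sketch of max-pool-with-argmax index arithmetic (no TensorFlow needed).
# Assumed toy shape [B, H, W, C] = [1, 4, 4, 1], 2x2 pooling, stride 2, VALID.
H, W = 4, 4
inp = [[1, 3, 2, 0],
       [4, 6, 5, 7],
       [8, 9, 1, 2],
       [3, 4, 6, 5]]  # single batch, single channel

def pool_with_argmax(grid, k=2):
    out, argmax = [], []
    for by in range(0, H, k):
        for bx in range(0, W, k):
            # find the max in the k x k window and its flat index y*W + x
            best_val, best_idx = None, None
            for y in range(by, by + k):
                for x in range(bx, bx + k):
                    if best_val is None or grid[y][x] > best_val:
                        best_val, best_idx = grid[y][x], y * W + x
            out.append(best_val)
            argmax.append(best_idx)
    return out, argmax

out, argmax = pool_with_argmax(inp)
flat = [v for row in inp for v in row]
# the property the question asks for: flattened input, indexed by argmax, equals output
assert all(flat[i] == o for i, o in zip(argmax, out))
```

This is exactly why the gather-based answers below work: `input[argmax]` amounts to a gather on the flattened input.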

Thx for your help!

Mat

Upvotes: 2

Views: 1846

Answers (3)

rishav09

Reputation: 19

This small snippet works:

def get_results(data, other_tensor, ksize, stride):
    # include_batch_in_index=True makes the flat indices valid across the whole batch
    pooled_data, indices = tf.nn.max_pool_with_argmax(
        data,
        ksize=[1, ksize, ksize, 1],
        strides=[1, stride, stride, 1],
        padding='VALID',
        include_batch_in_index=True)
    b, w, h, c = other_tensor.get_shape().as_list()
    # gather from the flattened [b*w*h*c] view using the flat argmax indices
    other_tensor_pooled = tf.gather(tf.reshape(other_tensor, shape=[b * w * h * c]), indices)
    return other_tensor_pooled

The indices above can be used to index the tensor. The op returns flattened indices, so for anything with batch_size > 1 you need to pass include_batch_in_index=True in order to get correct results. I am assuming here that other_tensor has the same batch size as data.
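The reshape-and-gather idea can be checked without TensorFlow: with include_batch_in_index=True the flat indices address the fully flattened [B*H*W*C] view, so the same indices pick out corresponding elements of any same-shaped tensor. A hypothetical sketch with made-up data (B=2, each batch a flattened 2x2 map, one pooling window per batch):

```python
# Flat argmax indices (batch included) select from any same-layout tensor.
data  = [5, 1, 2, 3,   4, 9, 7, 8]   # two 2x2 maps, flattened
other = [50, 10, 20, 30, 40, 90, 70, 80]  # a second tensor, same layout

# flat index of each batch's max, as returned with include_batch_in_index=True:
# batch 0's max (5) is at flat index 0, batch 1's max (9) at flat index 5
argmax = [0, 5]

# mirrors tf.gather(tf.reshape(other_tensor, [-1]), indices)
pooled_other = [other[i] for i in argmax]
```

The gathered values come from the positions of data's maxima, which is the pooling-guided lookup the answer describes.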

Upvotes: 1

Tofik Ali

Reputation: 1

I am doing it in this way:

def max_pool(input, ksize, strides, padding):
    output, arg_max = tf.nn.max_pool_with_argmax(
        input=input, ksize=ksize, strides=strides, padding=padding)
    shape = tf.shape(output)
    # arg_max holds flat indices into the flattened input, so a plain
    # gather on the [-1] reshape recovers the pooled values
    output1 = tf.reshape(tf.gather(tf.reshape(input, [-1]), arg_max), shape)

    # err should be 0: the gathered values must equal the pooled output
    err = tf.reduce_sum(tf.square(tf.subtract(output, output1)))
    return output1, err

Upvotes: 0

Mathew

Reputation: 307

I found a solution using tf.gather_nd and it works, although it seems not so elegant. I used the function unravel_argmax that was posted here.

def unravel_argmax(argmax, shape):
    # shape = [B, H, W, C]; recover the (row, col) coordinates from the flat index
    output_list = []
    output_list.append(argmax // (shape[2] * shape[3]))             # row
    output_list.append(argmax % (shape[2] * shape[3]) // shape[3])  # col
    return tf.stack(output_list)
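The integer arithmetic in unravel_argmax can be sanity-checked by hand: for layout [B, H, W, C], argmax // (W*C) recovers the row and (argmax % (W*C)) // C the column. A standalone check of the same formulas, assuming a toy shape with H = W = 4 and C = 1:

```python
# Decompose a flat argmax (layout [B, H, W, C]) into (row, col).
# Assumed toy shape: B=1, H=4, W=4, C=1.
B, H, W, C = 1, 4, 4, 1

def unravel(flat):
    row = flat // (W * C)
    col = (flat % (W * C)) // C
    return row, col

# e.g. flat index 9 corresponds to row 2, col 1, since 9 = 2*4 + 1
```

With C > 1 the same divisions still hold, because consecutive channel entries share one (row, col) position.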

def max_pool(input, ksize, strides, padding):
    output, arg_max = tf.nn.max_pool_with_argmax(
        input=input, ksize=ksize, strides=strides, padding=padding)
    shape = input.get_shape()
    arg_max = tf.cast(arg_max, tf.int32)
    unraveled = unravel_argmax(arg_max, shape)
    indices = tf.transpose(unraveled, (1, 2, 3, 4, 0))
    channels = shape[-1]
    bs = tf.shape(input)[0]
    # build the channel coordinate for every output position
    t1 = tf.range(channels, dtype=arg_max.dtype)[None, None, None, :, None]
    t2 = tf.tile(t1, multiples=(bs,) + tuple(indices.get_shape()[1:-2].as_list()) + (1, 1))
    t3 = tf.concat((indices, t2), axis=-1)
    # build the batch coordinate
    t4 = tf.range(tf.cast(bs, dtype=arg_max.dtype))
    t5 = tf.tile(t4[:, None, None, None, None],
                 (1,) + tuple(indices.get_shape()[1:-2].as_list()) + (channels, 1))
    # full (batch, row, col, channel) index for every pooled element
    t6 = tf.concat((t5, t3), -1)
    return tf.gather_nd(input, t6)

In case anyone has a more elegant solution, I'd still be curious to know.

Mat

Upvotes: 2
