Programmer

Reputation: 185

Why does dim=1 return row indices in torch.argmax?

I am working with PyTorch's argmax function, which is defined as:

torch.argmax(input, dim=None, keepdim=False)

Consider an example:

import torch

a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))

Here, when I use dim=1, the function searches the row vectors instead of the column vectors, as shown below.

print(a) :   
tensor([[-1.7739,  0.8073,  0.0472, -0.4084],  
        [ 0.6378,  0.6575, -1.2970, -0.0625],  
        [ 1.7970, -1.3463,  0.9011, -0.8704],  
        [ 1.5639,  0.7123,  0.0385,  1.8410]])  

print(torch.argmax(a, dim=1))  
tensor([1, 1, 0, 3])

My assumption is that dim=0 represents rows and dim=1 represents columns.

Upvotes: 15

Views: 12758

Answers (3)

Jon

Reputation: 794

    |
    v
  dim-0  ---> -----> dim-1 ------> -----> --------> dim-1
    |   [[-1.7739,  0.8073,  0.0472, -0.4084],
    v    [ 0.6378,  0.6575, -1.2970, -0.0625],
    |    [ 1.7970, -1.3463,  0.9011, -0.8704],
    v    [ 1.5639,  0.7123,  0.0385,  1.8410]]
    |
    v
# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])

Taking this one step further, for those who come in later: when dim is 1, this devolves into 4 separate comparisons:

Comparison for ind=0: argmax([-1.7739,  0.8073,  0.0472, -0.4084]) = 1
Comparison for ind=1: argmax([ 0.6378,  0.6575, -1.2970, -0.0625]) = 1
Comparison for ind=2: argmax([ 1.7970, -1.3463,  0.9011, -0.8704]) = 0
Comparison for ind=3: argmax([ 1.5639,  0.7123,  0.0385,  1.8410]) = 3

This therefore answers the question: for each row, which column has the max value?

When dim is 0:

Comparison for ind=0: argmax([-1.7739,0.6378,1.7970,1.5639]) = 2
Comparison for ind=1: argmax([0.8073,0.6575,-1.3463,0.7123]) = 0
Comparison for ind=2: argmax([0.0472,-1.2970,0.9011,0.0385]) = 2
Comparison for ind=3: argmax([-0.4084,-0.0625,-0.8704,1.8410]) = 3

This answers the question: for each column, which row has the maximum value?
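The two sets of comparisons can be reproduced directly. This is a minimal sketch that hard-codes the tensor printed in the question; the names per_row and per_col are illustrative only.

```python
import torch

# The tensor printed in the question, hard-coded so the result is reproducible.
a = torch.tensor([[-1.7739,  0.8073,  0.0472, -0.4084],
                  [ 0.6378,  0.6575, -1.2970, -0.0625],
                  [ 1.7970, -1.3463,  0.9011, -0.8704],
                  [ 1.5639,  0.7123,  0.0385,  1.8410]])

# dim=1: one comparison per row, scanning across the columns of that row.
per_row = [int(torch.argmax(a[i, :])) for i in range(a.shape[0])]

# dim=0: one comparison per column, scanning down the rows of that column.
per_col = [int(torch.argmax(a[:, j])) for j in range(a.shape[1])]

print(per_row, torch.argmax(a, dim=1).tolist())  # [1, 1, 0, 3] [1, 1, 0, 3]
print(per_col, torch.argmax(a, dim=0).tolist())  # [2, 0, 2, 3] [2, 0, 2, 3]
```

The manual loops and the dim argument agree: dim names the axis that is scanned and collapsed, not the axis that is kept.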

With more than two dimensions (axes), the same rule applies: argmax scans along the specified dim only, and does so once for every combination of indices in the remaining dimensions.
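For a tensor with more than two dimensions, a small sketch may help. It assumes a random tensor of shape (2, 3, 4) with a fixed seed, and the names x, idx, manual, and flat are illustrative. With dim=1, the result holds one index for each (i, k) pair, taken over the slice x[i, :, k].

```python
import torch

torch.manual_seed(0)          # fixed seed, only so the run is repeatable
x = torch.randn(2, 3, 4)      # dims 0, 1, 2 have sizes 2, 3, 4

# Reduce along dim=1: that dimension disappears from the result.
idx = torch.argmax(x, dim=1)
print(tuple(idx.shape))       # (2, 4)

# The same thing written as explicit loops over the remaining dimensions.
manual = torch.stack([
    torch.stack([torch.argmax(x[i, :, k]) for k in range(4)])
    for i in range(2)
])
print(torch.equal(idx, manual))  # True

# With no dim given, argmax returns a single index into the flattened tensor.
flat = torch.argmax(x)
print(bool(x.flatten()[flat] == x.max()))  # True
```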

Upvotes: 1

Ismet Sahin

Reputation: 163

Dimensions are defined as shown in the above excellent answer. I have highlighted the way I understand dimensions in Torch and Numpy (dim and axis respectively) and hope that this will be helpful to others.

Notice that only the specified dimension’s index varies during the argmax operation, and the specified dimension’s index range reduces to a single index once the operation is completed. Let tensor A have M rows and N columns and consider the sum operation for simplicity. The shape of A is (M, N). If dim=0 is specified, then the vectors A[0,:], A[1,:], ..., A[M-1,:] are summed elementwise and the result is another tensor with 1 row and N columns. Notice that only the 0th dimension’s indices vary, from 0 through M-1. Similarly, if dim=1 is specified, then the vectors A[:,0], A[:,1], ..., A[:,N-1] are summed elementwise and the result is another tensor with M rows and 1 column.

An example is given below:

>>> import torch
>>> A = torch.tensor([[1,2,3], [4,5,6]])
>>> A
tensor([[1, 2, 3],
        [4, 5, 6]])
>>> S0 = torch.sum(A, dim = 0)
>>> S0
tensor([5, 7, 9])
>>> S1 = torch.sum(A, dim = 1)
>>> S1
tensor([ 6, 15])

In the above sample code, the first sum operation specifies dim=0, therefore A[0,:] and A[1,:], which are [1,2,3] and [4,5,6], are summed, resulting in [5, 7, 9]. When dim=1 is specified, the vectors A[:,0], A[:,1], and A[:,2], which are the vectors [1, 4], [2, 5], and [3, 6], are added elementwise to give [6, 15].

Note also that the specified dimension collapses. Again let A have the shape (M, N). If dim=0, then dimension 0 is reduced from M to 1, giving the shape (1, N). Similarly, if dim=1, then N is reduced to 1, giving the shape (M, 1). By default (keepdim=False), the collapsed dimension is dropped entirely, so the results are one-dimensional tensors with N and M elements respectively; passing keepdim=True retains the shapes (1, N) and (M, 1).
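The collapse of the specified dimension can be checked by printing shapes. The sketch below reuses the 2x3 tensor A from the example above; the other variable names are illustrative. It compares the default behaviour with keepdim=True for both sum and argmax.

```python
import torch

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])   # shape (M, N) = (2, 3)

# Default (keepdim=False): the reduced dimension is dropped.
s0 = torch.sum(A, dim=0)
s1 = torch.sum(A, dim=1)
print(tuple(s0.shape), tuple(s1.shape))    # (3,) (2,)

# keepdim=True: the reduced dimension is kept with size 1.
k0 = torch.sum(A, dim=0, keepdim=True)
k1 = torch.sum(A, dim=1, keepdim=True)
print(tuple(k0.shape), tuple(k1.shape))    # (1, 3) (2, 1)

# argmax follows the same shape rule as sum.
m1 = torch.argmax(A, dim=1, keepdim=True)
print(m1.tolist())                         # [[2], [2]]
```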

Upvotes: 3

kmario23

Reputation: 61325

It's time to correctly understand how the axis or dim argument works in PyTorch:

[Image: tensor dimension]


The following example should make sense once you comprehend the above picture:

    |
    v
  dim-0  ---> -----> dim-1 ------> -----> --------> dim-1
    |   [[-1.7739,  0.8073,  0.0472, -0.4084],
    v    [ 0.6378,  0.6575, -1.2970, -0.0625],
    |    [ 1.7970, -1.3463,  0.9011, -0.8704],
    v    [ 1.5639,  0.7123,  0.0385,  1.8410]]
    |
    v
# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])

Note: dim (short for 'dimension') is the torch equivalent of 'axis' in NumPy.
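To illustrate the correspondence with NumPy, here is a minimal sketch that runs the same reduction in both libraries on the values from the question. It assumes NumPy is installed alongside PyTorch; the names values, t, and n are illustrative.

```python
import numpy as np
import torch

values = [[-1.7739,  0.8073,  0.0472, -0.4084],
          [ 0.6378,  0.6575, -1.2970, -0.0625],
          [ 1.7970, -1.3463,  0.9011, -0.8704],
          [ 1.5639,  0.7123,  0.0385,  1.8410]]

t = torch.tensor(values)   # PyTorch tensor
n = np.array(values)       # NumPy array with the same contents

# torch uses dim=..., NumPy uses axis=...; the meaning is the same.
print(torch.argmax(t, dim=1).tolist())   # [1, 1, 0, 3]
print(np.argmax(n, axis=1).tolist())     # [1, 1, 0, 3]
print(torch.argmax(t, dim=0).tolist())   # [2, 0, 2, 3]
print(np.argmax(n, axis=0).tolist())     # [2, 0, 2, 3]
```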

Upvotes: 26
