Reputation: 185
I am working with PyTorch's argmax function, which is defined as:
torch.argmax(input, dim=None, keepdim=False)
Consider an example:
import torch
a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))
Here, when I use dim=1, the function searches along row vectors instead of column vectors, as shown below.
print(a):
tensor([[-1.7739, 0.8073, 0.0472, -0.4084],
[ 0.6378, 0.6575, -1.2970, -0.0625],
[ 1.7970, -1.3463, 0.9011, -0.8704],
[ 1.5639, 0.7123, 0.0385, 1.8410]])
print(torch.argmax(a, dim=1))
tensor([1, 1, 0, 3])
As far as I understand, dim=0 represents rows and dim=1 represents columns.
Upvotes: 15
Views: 12758
Reputation: 794
           |
           v
         dim-0 ---> -----> dim-1 ------> -----> --------> dim-1
           |   [[-1.7739,  0.8073,  0.0472, -0.4084],
           v    [ 0.6378,  0.6575, -1.2970, -0.0625],
           |    [ 1.7970, -1.3463,  0.9011, -0.8704],
           v    [ 1.5639,  0.7123,  0.0385,  1.8410]]
           |
           v
# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])
Taking this one step further, for those who come later:
when dim is 1, this breaks down into 4 separate comparisons:
Comparison for ind=0: argmax([-1.7739, 0.8073, 0.0472, -0.4084]) = 1
Comparison for ind=1: argmax([ 0.6378, 0.6575, -1.2970, -0.0625]) = 1
Comparison for ind=2: argmax([ 1.7970, -1.3463, 0.9011, -0.8704]) = 0
Comparison for ind=3: argmax([ 1.5639, 0.7123, 0.0385, 1.8410]) = 3
This therefore answers the question: for each row, which column holds the maximum value?
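A minimal sketch, assuming PyTorch is installed, that rebuilds the question's tensor from its printed values and checks that argmax along dim=1 gives the same result as running argmax on each row separately:

```python
import torch

# The 4x4 tensor printed in the question, copied as literal values
a = torch.tensor([[-1.7739,  0.8073,  0.0472, -0.4084],
                  [ 0.6378,  0.6575, -1.2970, -0.0625],
                  [ 1.7970, -1.3463,  0.9011, -0.8704],
                  [ 1.5639,  0.7123,  0.0385,  1.8410]])

# dim=1: one comparison per row; only the column index varies inside each one
row_wise = torch.argmax(a, dim=1)

# The same thing done by hand: argmax of each 1-D row a[i]
per_row = torch.stack([torch.argmax(a[i]) for i in range(a.shape[0])])

print(row_wise)                        # tensor([1, 1, 0, 3])
print(torch.equal(row_wise, per_row))  # True
```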
When dim is 0:
Comparison for ind=0: argmax([-1.7739,0.6378,1.7970,1.5639]) = 2
Comparison for ind=1: argmax([0.8073,0.6575,-1.3463,0.7123]) = 0
Comparison for ind=2: argmax([0.0472,-1.2970,0.9011,0.0385]) = 2
Comparison for ind=3: argmax([-0.4084,-0.0625,-0.8704,1.8410]) = 3
This answers the question: for each column, which row holds the maximum value?
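The same check for dim=0, again a sketch built from the question's printed values: each comparison now walks down one column a[:, j].

```python
import torch

# The 4x4 tensor printed in the question, copied as literal values
a = torch.tensor([[-1.7739,  0.8073,  0.0472, -0.4084],
                  [ 0.6378,  0.6575, -1.2970, -0.0625],
                  [ 1.7970, -1.3463,  0.9011, -0.8704],
                  [ 1.5639,  0.7123,  0.0385,  1.8410]])

# dim=0: one comparison per column; only the row index varies inside each one
col_wise = torch.argmax(a, dim=0)

# The same thing done by hand: argmax of each 1-D column a[:, j]
per_col = torch.stack([torch.argmax(a[:, j]) for j in range(a.shape[1])])

print(col_wise)                        # tensor([2, 0, 2, 3])
print(torch.equal(col_wise, per_col))  # True
```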
For tensors with more than two dimensions the same rule applies: argmax compares values only along the specified dim, and returns one index for every combination of indices in the remaining dimensions.
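A sketch of the 3-D case, using an arbitrary random tensor (the seed and the shape (2, 3, 4) are illustrative choices, not from the thread). argmax along dim=1 removes that dimension and yields one index per (i, k) pair; with the default dim=None it instead returns an index into the flattened tensor.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 3, 4)          # shape (2, 3, 4)

idx = torch.argmax(x, dim=1)      # only dim 1 (size 3) is searched
print(idx.shape)                  # torch.Size([2, 4]): dim 1 is removed

# Each output entry compares just the 3 values x[i, :, k]
for i in range(x.shape[0]):
    for k in range(x.shape[2]):
        assert idx[i, k] == torch.argmax(x[i, :, k])

# dim=None (the default): the result indexes into x.flatten()
flat = torch.argmax(x)
assert x.flatten()[flat] == x.max()
```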
Upvotes: 1
Reputation: 163
Dimensions are defined as shown in the excellent answer above. I have written down the way I understand dimensions in PyTorch and NumPy (dim and axis, respectively) and hope this will be helpful to others.
Notice that only the specified dimension’s index varies during the argmax operation, and once the operation is completed the specified dimension’s index range reduces to a single index. Let tensor A have M rows and N columns, and consider the sum operation for simplicity. The shape of A is (M, N). If dim=0 is specified, then the vectors A[0,:], A[1,:], ..., A[M-1,:] are summed elementwise, and the result is another tensor with 1 row and N columns. Notice that only the 0th dimension’s index varies, from 0 through M-1. Similarly, if dim=1 is specified, then the vectors A[:,0], A[:,1], ..., A[:,N-1] are summed elementwise, and the result is another tensor with M rows and 1 column.
An example is given below:
>>> import torch
>>> A = torch.tensor([[1,2,3], [4,5,6]])
>>> A
tensor([[1, 2, 3],
        [4, 5, 6]])
>>> S0 = torch.sum(A, dim=0)
>>> S0
tensor([5, 7, 9])
>>> S1 = torch.sum(A, dim=1)
>>> S1
tensor([ 6, 15])
In the sample code above, the first sum operation specifies dim=0, so A[0,:] and A[1,:], which are [1, 2, 3] and [4, 5, 6], are summed, resulting in [5, 7, 9]. When dim=1 is specified, the vectors A[:,0], A[:,1], and A[:,2], which are [1, 4], [2, 5], and [3, 6], are added elementwise to give [6, 15].
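The idea that only one index varies can be spelled out without PyTorch at all. The helpers sum_along and argmax_along below are hypothetical, written only for this illustration, and handle just 2-D lists of lists; they are a sketch of the reduction rule, not how PyTorch implements it.

```python
def sum_along(A, dim):
    """Hypothetical helper: sum a 2-D list of lists along dim 0 or 1."""
    M, N = len(A), len(A[0])
    if dim == 0:
        # For each column j, only the row index i varies (0 .. M-1)
        return [sum(A[i][j] for i in range(M)) for j in range(N)]
    if dim == 1:
        # For each row i, only the column index j varies (0 .. N-1)
        return [sum(A[i][j] for j in range(N)) for i in range(M)]
    raise ValueError("dim must be 0 or 1 for a 2-D input")


def argmax_along(A, dim):
    """Hypothetical helper: index of the largest entry along dim 0 or 1."""
    M, N = len(A), len(A[0])
    if dim == 0:
        # Which row i holds the largest value in column j?
        return [max(range(M), key=lambda i: A[i][j]) for j in range(N)]
    if dim == 1:
        # Which column j holds the largest value in row i?
        return [max(range(N), key=lambda j: A[i][j]) for i in range(M)]
    raise ValueError("dim must be 0 or 1 for a 2-D input")


A = [[1, 2, 3], [4, 5, 6]]
print(sum_along(A, 0))     # [5, 7, 9]
print(sum_along(A, 1))     # [6, 15]
print(argmax_along(A, 0))  # [1, 1, 1]
print(argmax_along(A, 1))  # [2, 2]
```

The outputs of sum_along match S0 and S1 from the PyTorch session above.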
Note also that the specified dimension collapses. Again let A have the shape (M, N). If dim=0, the reduction gives the shape (1, N), where dimension 0 is reduced from M to 1. Similarly, if dim=1, the result has the shape (M, 1), where N is reduced to 1. By default (keepdim=False), PyTorch then drops this size-1 dimension, so the results are single-dimensional tensors with N and M elements respectively, i.e. shapes (N,) and (M,); passing keepdim=True keeps them as (1, N) and (M, 1).
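A short sketch of that shape behaviour with the same A as above, assuming PyTorch is installed. keepdim=True retains the collapsed dimension with size 1, and argmax follows the same rule as sum.

```python
import torch

A = torch.tensor([[1, 2, 3], [4, 5, 6]])          # shape (M, N) = (2, 3)

print(torch.sum(A, dim=0).shape)                  # torch.Size([3])     -> (N,)
print(torch.sum(A, dim=0, keepdim=True).shape)    # torch.Size([1, 3])  -> (1, N)
print(torch.sum(A, dim=1).shape)                  # torch.Size([2])     -> (M,)
print(torch.sum(A, dim=1, keepdim=True).shape)    # torch.Size([2, 1])  -> (M, 1)

# argmax collapses the searched dimension the same way
print(torch.argmax(A, dim=1, keepdim=True).tolist())  # [[2], [2]]
```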
Upvotes: 3
Reputation: 61325
It's time to correctly understand how the axis or dim argument works in PyTorch:
The following example should make sense once you comprehend the above picture:
           |
           v
         dim-0 ---> -----> dim-1 ------> -----> --------> dim-1
           |   [[-1.7739,  0.8073,  0.0472, -0.4084],
           v    [ 0.6378,  0.6575, -1.2970, -0.0625],
           |    [ 1.7970, -1.3463,  0.9011, -0.8704],
           v    [ 1.5639,  0.7123,  0.0385,  1.8410]]
           |
           v
# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])
Note: dim (short for 'dimension') is the torch equivalent of 'axis' in NumPy.
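To see the dim/axis correspondence directly, here is a sketch (assuming both NumPy and PyTorch are installed) that runs np.argmax with axis=1 and torch.argmax with dim=1 on the question's tensor:

```python
import numpy as np
import torch

# The 4x4 tensor printed in the question, copied as literal values
a = torch.tensor([[-1.7739,  0.8073,  0.0472, -0.4084],
                  [ 0.6378,  0.6575, -1.2970, -0.0625],
                  [ 1.7970, -1.3463,  0.9011, -0.8704],
                  [ 1.5639,  0.7123,  0.0385,  1.8410]])
n = a.numpy()                          # same values as a NumPy array

np_result = np.argmax(n, axis=1)       # NumPy names the argument "axis"
torch_result = torch.argmax(a, dim=1)  # PyTorch names it "dim"

print(np_result)     # [1 1 0 3]
print(torch_result)  # tensor([1, 1, 0, 3])
```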
Upvotes: 26