James Nguyen

Reputation: 1171

How does NumPy Sum (with axis) work?

I've taken it upon myself to learn how NumPy works for my own curiosity.

It seems that the simplest function is the hardest one to translate into code (which is how I understand things). It's easy to hard-code each axis for each case, but I want a general algorithm that can sum along any axis of an n-dimensional array. The documentation on the official website is not helpful (it only shows the result, not the process), and the Python/C source code is hard to navigate.

Note: I did figure out that when an array is summed, the specified axis is "removed". For example, summing an array of shape (4, 3, 2) along axis 1 yields an array of shape (4, 2).
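A quick way to check that shape rule, assuming only NumPy. keepdims=True is an existing np.sum parameter that keeps the reduced axis with length 1 instead of dropping it, which makes the "removed" axis visible:

```python
import numpy as np

x = np.ones((4, 3, 2))

# Summing over axis 1 removes that axis from the shape.
reduced = np.sum(x, axis=1)
print(reduced.shape)  # (4, 2)

# With keepdims=True the reduced axis stays, but with length 1.
kept = np.sum(x, axis=1, keepdims=True)
print(kept.shape)  # (4, 1, 2)

# Each output entry adds up the 3 values that lay along axis 1.
print(reduced[0, 0])  # 3.0
```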

Upvotes: 25

Views: 39747

Answers (3)

Faitus Joseph

Reputation: 59

Assume that our array has 2 rows and 3 columns:

import numpy as np
a = np.array([[1,2,3],[3,4,6]])

print(a.shape)
# prints: (2, 3), i.e. 2 rows and 3 columns

Below are the 3 different possibilities:

print(np.sum(a))          # sums all the elements; prints: 19
print(np.sum(a, axis=0))  # sums down each column (collapses the rows); prints: [4 6 9]
print(np.sum(a, axis=1))  # sums across each row (collapses the columns); prints: [ 6 13]
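The same two results written as plain loops, as a sketch of what each axis means for this 2-D array (the names col_sums and row_sums are just for illustration): axis=0 walks over the rows for each fixed column, and axis=1 walks over the columns for each fixed row.

```python
import numpy as np

a = np.array([[1, 2, 3], [3, 4, 6]])
rows, cols = a.shape

# axis=0: for each column j, add up that column's value in every row.
col_sums = [int(sum(a[i, j] for i in range(rows))) for j in range(cols)]

# axis=1: for each row i, add up that row's value in every column.
row_sums = [int(sum(a[i, j] for j in range(cols))) for i in range(rows)]

print(col_sums)  # [4, 6, 9]  (same values as np.sum(a, axis=0))
print(row_sums)  # [6, 13]    (same values as np.sum(a, axis=1))
```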

Upvotes: 0

tinyhare

Reputation: 2401

I'll use nested loops to explain it.

import numpy as np

n = np.array(
[[[1, 2, 3],
 [4, 5, 6],
 [7, 8, 9]],

 [[2, 4, 6],
 [8, 10, 12],
 [14, 16, 18]],

 [[1, 3, 5],
 [7, 9, 11],
 [13, 15, 17]]])

print(n)

print("============ sum axis=None=============")

sum = 0
for i in range(3):
  for j in range(3): 
    for k in range(3):
      sum += n[k][i][j]
print(sum) # 216

print('------------------')
print(np.sum(n))  # 216
print("============ sum axis=0 =============") 
for i in range(3):
  for j in range(3):
    sum = 0
    for axis in range(3):
      sum += n[axis][i][j]
    print(sum,end=' ')
  print()

print('------------------')
print("sum[0][0] = %d" % (n[0][0][0] + n[1][0][0] + n[2][0][0]))
print("sum[1][1] = %d" % (n[0][1][1] + n[1][1][1] + n[2][1][1]))
print("sum[2][2] = %d" % (n[0][2][2] + n[1][2][2] + n[2][2][2]))
print('------------------')
print(np.sum(n, axis=0)) 
print("============ sum axis=1 =============") 
for i in range(3):
  for j in range(3):
    sum = 0
    for axis in range(3):
      sum += n[i][axis][j]
    print(sum,end=' ')
  print()
print('------------------')
print("sum[0][0] = %d" % (n[0][0][0] + n[0][1][0] + n[0][2][0]))
print("sum[1][1] = %d" % (n[1][0][1] + n[1][1][1] + n[1][2][1]))
print("sum[2][2] = %d" % (n[2][0][2] + n[2][1][2] + n[2][2][2]))
print('------------------')
print(np.sum(n, axis=1))  
print("============ sum axis=2 =============") 
for i in range(3):
  for j in range(3):
    sum = 0
    for axis in range(3):
      sum += n[i][j][axis]
    print(sum,end=' ')
  print()
print('------------------')
print("sum[0][0] = %d" % (n[0][0][0] + n[0][0][1] + n[0][0][2]))
print("sum[1][1] = %d" % (n[1][1][0] + n[1][1][1] + n[1][1][2]))
print("sum[2][2] = %d" % (n[2][2][0] + n[2][2][1] + n[2][2][2]))
print('------------------')
print(np.sum(n, axis=2))
print("============ sum axis=(0,1) =============") 
for i in range(3):
  sum = 0
  for axis1 in range(3):   
    for axis2 in range(3):
      sum += n[axis1][axis2][i]
  print(sum,end=' ')

print()
print('------------------')
print("sum[1] = %d" % (n[0][0][1] + n[0][1][1] + n[0][2][1] +
              n[1][0][1] + n[1][1][1] + n[1][2][1] +
              n[2][0][1] + n[2][1][1] + n[2][2][1] ))
print('------------------')
print(np.sum(n, axis=(0,1)))

result:

[[[ 1  2  3]
  [ 4  5  6]
  [ 7  8  9]]

 [[ 2  4  6]
  [ 8 10 12]
  [14 16 18]]

 [[ 1  3  5]
  [ 7  9 11]
  [13 15 17]]]
============ sum axis=None=============
216
------------------
216
============ sum axis=0 =============
4 9 14 
19 24 29 
34 39 44 
------------------
sum[0][0] = 4
sum[1][1] = 24
sum[2][2] = 44
------------------
[[ 4  9 14]
 [19 24 29]
 [34 39 44]]
============ sum axis=1 =============
12 15 18 
24 30 36 
21 27 33 
------------------
sum[0][0] = 12
sum[1][1] = 30
sum[2][2] = 33
------------------
[[12 15 18]
 [24 30 36]
 [21 27 33]]
============ sum axis=2 =============
6 15 24 
12 30 48 
9 27 45 
------------------
sum[0][0] = 6
sum[1][1] = 30
sum[2][2] = 45
------------------
[[ 6 15 24]
 [12 30 48]
 [ 9 27 45]]
============ sum axis=(0,1) =============
57 72 87 
------------------
sum[1] = 72
------------------
[57 72 87]
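The nested loops in this answer hard-code three dimensions. A minimal sketch of the same idea for any number of dimensions and any single axis, assuming a helper named sum_along_axis (hypothetical, not part of NumPy): loop over every index of the output shape (the input shape with the axis removed), and for each one run an inner loop along the summed axis. np.ndindex is an existing NumPy iterator over all index tuples of a shape. For simplicity the helper keeps the input dtype, whereas np.sum may upcast small integer or boolean types.

```python
import numpy as np

def sum_along_axis(arr, axis):
    """Hypothetical helper: sum `arr` over one axis using plain loops."""
    arr = np.asarray(arr)
    axis = axis % arr.ndim  # allow negative axes such as -1
    # The output shape is the input shape with `axis` removed.
    out_shape = arr.shape[:axis] + arr.shape[axis + 1:]
    out = np.zeros(out_shape, dtype=arr.dtype)

    # Outer loops: every position in the output array.
    for out_idx in np.ndindex(*out_shape):
        total = 0
        # Inner loop: walk along the summed axis, other indices held fixed.
        for k in range(arr.shape[axis]):
            full_idx = out_idx[:axis] + (k,) + out_idx[axis:]
            total += arr[full_idx]
        out[out_idx] = total
    return out

n = np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
              [[2, 4, 6], [8, 10, 12], [14, 16, 18]],
              [[1, 3, 5], [7, 9, 11], [13, 15, 17]]])

# Compare against NumPy for every axis of the 3-D example.
for ax in range(n.ndim):
    assert np.array_equal(sum_along_axis(n, ax), np.sum(n, axis=ax))
```

The only thing that changes from axis to axis is where the inner loop's index k is spliced into the full index tuple, which is exactly the "removed" axis from the question.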

Upvotes: 4

piRSquared

Reputation: 294228

Setup

Consider the NumPy array a:

import numpy as np

a = np.arange(30).reshape(2, 3, 5)
print(a)

[[[ 0  1  2  3  4]
  [ 5  6  7  8  9]
  [10 11 12 13 14]]

 [[15 16 17 18 19]
  [20 21 22 23 24]
  [25 26 27 28 29]]]

Where are the dimensions?

The following diagram highlights the dimensions and the positions within each:

            p  p  p  p  p
            o  o  o  o  o
            s  s  s  s  s

     dim 2  0  1  2  3  4

            |  |  |  |  |
  dim 0     ↓  ↓  ↓  ↓  ↓
  ----> [[[ 0  1  2  3  4]   <---- dim 1, pos 0
  pos 0   [ 5  6  7  8  9]   <---- dim 1, pos 1
          [10 11 12 13 14]]  <---- dim 1, pos 2
  dim 0
  ---->  [[15 16 17 18 19]   <---- dim 1, pos 0
  pos 1   [20 21 22 23 24]   <---- dim 1, pos 1
          [25 26 27 28 29]]] <---- dim 1, pos 2
            ↑  ↑  ↑  ↑  ↑
            |  |  |  |  |

     dim 2  p  p  p  p  p
            o  o  o  o  o
            s  s  s  s  s

            0  1  2  3  4
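In code, the diagram corresponds to a.ndim (the number of dimensions) and a.shape (how many positions each dimension has). A short sketch using the same array; an element is addressed by giving one position per dimension:

```python
import numpy as np

a = np.arange(30).reshape(2, 3, 5)

# One entry in a.shape per dimension; each entry is that dimension's length.
print(a.ndim)   # 3
print(a.shape)  # (2, 3, 5)

# One position per dimension: (dim 0, dim 1, dim 2).
print(a[1, 2, 3])  # 28
```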

Dimension examples:

This becomes clearer with a few examples:

a[0, :, :] # dim 0, pos 0

[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]]

a[:, 1, :] # dim 1, pos 1

[[ 5  6  7  8  9]
 [20 21 22 23 24]]

a[:, :, 3] # dim 2, pos 3

[[ 3  8 13]
 [18 23 28]]
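In each of these slices the fixed position moves to a different slot in the index. np.take(a, pos, axis=dim) is an existing NumPy function that selects the same slice when the dimension is only known at runtime; this sketch just checks it against the three examples:

```python
import numpy as np

a = np.arange(30).reshape(2, 3, 5)

# np.take(a, pos, axis=dim) fixes `dim` at `pos` and keeps every other dimension.
assert np.array_equal(np.take(a, 0, axis=0), a[0, :, :])
assert np.array_equal(np.take(a, 1, axis=1), a[:, 1, :])
assert np.array_equal(np.take(a, 3, axis=2), a[:, :, 3])

print(np.take(a, 3, axis=2))
```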

sum

Explanation of sum and axis:
a.sum(0) is the sum of all slices along dim 0

a.sum(0)

[[15 17 19 21 23]
 [25 27 29 31 33]
 [35 37 39 41 43]]

Same as:

a[0, :, :] + \
a[1, :, :]

[[15 17 19 21 23]
 [25 27 29 31 33]
 [35 37 39 41 43]]

a.sum(1) is the sum of all slices along dim 1

a.sum(1)

[[15 18 21 24 27]
 [60 63 66 69 72]]

Same as:

a[:, 0, :] + \
a[:, 1, :] + \
a[:, 2, :]

[[15 18 21 24 27]
 [60 63 66 69 72]]

a.sum(2) is the sum of all slices along dim 2

a.sum(2)

[[ 10  35  60]
 [ 85 110 135]]

Same as:

a[:, :, 0] + \
a[:, :, 1] + \
a[:, :, 2] + \
a[:, :, 3] + \
a[:, :, 4]

[[ 10  35  60]
 [ 85 110 135]]
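The three cases follow one rule: a.sum(axis) adds up every slice taken along that axis. A sketch of that rule for any axis, assuming a helper named sum_of_slices (hypothetical, for illustration only) built on np.take:

```python
import numpy as np

a = np.arange(30).reshape(2, 3, 5)

def sum_of_slices(arr, axis):
    """Hypothetical helper: add up every slice taken along `axis`."""
    total = np.take(arr, 0, axis=axis)
    for pos in range(1, arr.shape[axis]):
        # Each slice has the input shape minus `axis`, so they add elementwise.
        total = total + np.take(arr, pos, axis=axis)
    return total

for axis in range(a.ndim):
    assert np.array_equal(sum_of_slices(a, axis), a.sum(axis))

print(sum_of_slices(a, 2))
```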

The default is axis=None, which means all axes, i.e. sum all the numbers. (axis=-1 is different: it means the last axis, the same as axis=2 here.)

a.sum()

435
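A sketch contrasting the default with a negative axis and with a tuple of axes (both of which np.sum accepts), using the same array:

```python
import numpy as np

a = np.arange(30).reshape(2, 3, 5)

total_all = a.sum()            # axis=None (the default): every element
total_none = a.sum(axis=None)  # the same thing, written explicitly
last_axis = a.sum(axis=-1)     # -1 counts from the end: the last axis (dim 2)
two_axes = a.sum(axis=(0, 1))  # a tuple sums over several axes at once

print(total_all, total_none)   # 435 435
print(last_axis)
print(two_axes)
```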

Upvotes: 94
