Atsushi Sakai

Reputation: 1014

Getting the last elements along an axis of a NumPy array

I need a function that returns the last elements along an axis of a NumPy array.

For example, if I have an array,

a = np.array([1, 2, 3])

the function should work like this:

get_last_elements(a, axis=0)
>>> [3]
get_last_elements(a, axis=1)
>>> [1, 2, 3]

The function also needs to work for multidimensional arrays:

b = np.array([[1, 2],
              [3, 4]])

get_last_elements(b, axis=0)
>>> [[2],
     [4]]
get_last_elements(b, axis=1)
>>> [3, 4]

Does anyone have a good way to achieve this?

Upvotes: 0

Views: 1874

Answers (1)

Ehsan

Reputation: 12407

You can use np.take to pick index -1 along the given axis, then reshape the result so that axis is kept with length 1:

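import numpy as np

def get_last_elements(a, axis=0):
    # Target shape: same as the input, but with length 1 along `axis`.
    shape = list(a.shape)
    shape[axis] = 1
    # np.take with a scalar index drops `axis`; reshape restores it.
    return np.take(a, -1, axis=axis).reshape(tuple(shape))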

Output:

print(get_last_elements(b, axis=0))
[[3 4]]

print(get_last_elements(b, axis=1))
[[2]
 [4]]
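As a possible alternative, the reshape step can be avoided. The sketch below shows two variants under hypothetical helper names (last_along_axis and last_along_axis_view are not part of NumPy): passing a one-element list to np.take keeps the axis with length 1, and slicing with slice(-1, None) does the same while returning a view instead of a copy.

```python
import numpy as np

def last_along_axis(a, axis=0):
    # Hypothetical helper: a list of indices (rather than a scalar)
    # makes np.take keep `axis` in the result with length 1.
    return np.take(a, [-1], axis=axis)

def last_along_axis_view(a, axis=0):
    # Hypothetical helper: start from full slices on every axis,
    # then restrict the chosen axis to its last element.
    index = [slice(None)] * a.ndim
    index[axis] = slice(-1, None)
    return a[tuple(index)]  # basic slicing, so this is a view of `a`

b = np.array([[1, 2],
              [3, 4]])

print(last_along_axis(b, axis=0))
print(last_along_axis_view(b, axis=1))
```

Both variants follow NumPy's axis convention, as the function above does. The slicing variant shares memory with the input, so writing to its result also modifies the original array.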

Upvotes: 1
