Diep N.

Reputation: 23

Limit the number of elements in each row of a matrix using NumPy in Python

I have the following NumPy matrix:

import numpy as np
matrix = np.array([[1,2,3],[4,5,6]])

and a NumPy vector:

vector = np.array([1,2])

Each element of the vector gives, for the corresponding row of the matrix, the number of leading elements I want to retain. I would like to replace all other elements in that row with 0.

The final matrix should look like:

matrix_output = np.array([[1,0,0],[4,5,0]])

What is the quickest way to do this?

Upvotes: 1

Views: 299

Answers (2)

hpaulj

Reputation: 231665

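Try a boolean mask built with broadcasting:

# True where the column index is at or past the row's limit
mask = vector[:,None] <= np.arange(matrix.shape[1])
matrix[mask] = 0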
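
To see what the mask is doing, here is a minimal sketch of the same broadcasting idea that keeps the original array intact. The names `cols`, `keep`, and `result` are mine, and using np.where instead of in-place assignment is just one possible variation, not part of the answer above.

```python
import numpy as np

matrix = np.array([[1, 2, 3], [4, 5, 6]])
vector = np.array([1, 2])

# Column indices taken from the matrix shape, so any width works.
cols = np.arange(matrix.shape[1])

# Broadcasting (3,) against (2, 1) gives a (2, 3) boolean array;
# True marks the positions to keep in each row.
keep = cols < vector[:, None]

# np.where builds a new array, so the original matrix is left untouched.
result = np.where(keep, matrix, 0)
print(result)
```

Use the in-place assignment if you do not need the original values afterwards; the np.where form is handy when you do.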

Upvotes: 1

flakes

Reputation: 23674

You could do something simple, like looping over the rows and zeroing everything from each row's limit onward:

import numpy as np

matrix = np.array([[1,2,3],[4,5,6]])
vector = np.array([1,2])

for row, index in enumerate(vector):
    matrix[row, index:] = 0

print(matrix)
# [[1 0 0]
#  [4 5 0]]
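
If you want to check that the loop agrees with a vectorized mask on larger inputs, a rough sketch could look like the following. The function names, the random test data, and the array sizes are all made up for illustration; both functions work on copies, which is an assumption about what you want rather than what the snippet above does.

```python
import numpy as np

def truncate_rows_loop(matrix, vector):
    # Same loop as above, but on a copy so the input is not modified.
    out = matrix.copy()
    for row, index in enumerate(vector):
        out[row, index:] = 0
    return out

def truncate_rows_mask(matrix, vector):
    # Vectorized variant: keep column j of row i when j < vector[i].
    keep = np.arange(matrix.shape[1]) < np.asarray(vector)[:, None]
    return np.where(keep, matrix, 0)

# Hypothetical test data: 1000 rows, 50 columns, limits from 0 to 50.
rng = np.random.default_rng(0)
big = rng.integers(1, 10, size=(1000, 50))
counts = rng.integers(0, 51, size=1000)

same = np.array_equal(truncate_rows_loop(big, counts),
                      truncate_rows_mask(big, counts))
print(same)
```

For actual timings, wrapping each call in timeit on your own data is the safest way to decide which is quicker.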

Upvotes: 2
