mr blick

Reputation: 139

Index of list within a numpy array

So I'm writing code that uses exact diagonalization to study the Lieb-Liniger model. The first step is building a 2-D NumPy array whose rows describe particle occupations (how many particles sit in each mode). The array looks something like

array([[2, 0, 0],
       [1, 1, 0],
       [1, 0, 1],
       [0, 2, 0],
       [0, 1, 1],
       [0, 0, 2]])

for the case of 2 particles in 3 modes. My question is: is it possible to get the index of a particular row in this array, similar to how you get an index in a regular list with the index method? For instance, with a list of lists A, I was able to use A.index(some_list_in_A) to get the index of that list. I have tried using numpy.where(HS==[2,0,0]) to get the index of [2,0,0] (and so on), but to no avail. For large numbers of particles and modes I'm looking for an efficient way to obtain these indices, and I figured NumPy arrays would be quite efficient, but I have hit this block and haven't found a solution. Any suggestions?
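The question doesn't show how the array is built. As a sketch under that caveat, a recursive generator (the name `occupation_basis` is hypothetical, not from the question) reproduces the ordering shown above for N particles in M modes:

```python
import numpy as np


def occupation_basis(n_particles, n_modes):
    """Yield occupation tuples (n_1, ..., n_M) summing to n_particles,
    in the same descending order as the array in the question."""
    if n_modes == 1:
        # All remaining particles must go into the last mode.
        yield (n_particles,)
        return
    # Put `first` particles in the leftmost mode, largest count first,
    # then distribute the remainder over the other modes recursively.
    for first in range(n_particles, -1, -1):
        for rest in occupation_basis(n_particles - first, n_modes - 1):
            yield (first,) + rest


HS = np.array(list(occupation_basis(2, 3)))
print(HS)
```

The number of rows is the binomial coefficient C(N+M-1, M-1), so the basis (and the cost of any lookup scheme) grows quickly with N and M.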

Upvotes: 1

Views: 529

Answers (2)

hpaulj

Reputation: 231665

Here are several ways of doing this lookup:

In [36]: A=np.array([[2,0,0],[1,1,0],[1,0,1],[0,2,0],[0,1,1],[0,0,2]])
In [37]: pattern = [0,2,0]

In [38]: np.where(np.all(pattern==A,1))  # Saullo's where
Out[38]: (array([3]),)

In [39]: A.tolist().index(pattern)  # your list find
Out[39]: 3

In [40]: D={tuple(a):i for i,a in enumerate(A.tolist())}  # dictionary
In [41]: D[tuple(pattern)]
Out[41]: 3

I am using tuples as the dictionary keys because lists are mutable and therefore unhashable; a tuple is essentially an immutable list.
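A small sketch of why the conversion to tuples is needed: a tuple key works, while the same values as a list raise `TypeError` because lists cannot be hashed.

```python
import numpy as np

A = np.array([[2, 0, 0], [1, 1, 0], [1, 0, 1], [0, 2, 0], [0, 1, 1], [0, 0, 2]])

# Tuples are hashable, so each row (as a tuple) can be a dictionary key.
D = {tuple(row): i for i, row in enumerate(A.tolist())}
print(D[(0, 2, 0)])  # 3

# Lists (and NumPy rows) are not hashable, so they cannot be used as keys.
try:
    D[[0, 2, 0]]
except TypeError as err:
    print("TypeError:", err)
```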

For this small size the dictionary approach is the fastest, especially if the dictionary can be built once and used repeatedly. Even when it is constructed on the fly, it is faster than np.where. But you should test it with more realistic sizes.
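A minimal timing harness using the standard-library `timeit` module might look like the following. It is only a sketch: substitute a realistically sized `A` and `pattern`, and treat the printed times as machine-dependent.

```python
import timeit

import numpy as np

A = np.array([[2, 0, 0], [1, 1, 0], [1, 0, 1], [0, 2, 0], [0, 1, 1], [0, 0, 2]])
pattern = [0, 2, 0]
D = {tuple(row): i for i, row in enumerate(A.tolist())}  # built once, reused

# Each candidate returns the row index of `pattern` in A.
candidates = {
    "np.where": lambda: int(np.where(np.all(A == pattern, axis=1))[0][0]),
    "list.index": lambda: A.tolist().index(pattern),
    "dict (prebuilt)": lambda: D[tuple(pattern)],
    "dict (on the fly)": lambda: {tuple(r): i for i, r in enumerate(A.tolist())}[tuple(pattern)],
}

for name, fn in candidates.items():
    seconds = timeit.timeit(fn, number=10_000)
    print(f"{name:18s} {seconds:.4f} s per 10,000 lookups")
```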

Python dictionaries are tuned for speed, since they are fundamental to the language's operation.

The pieces in the np.where expression are all fast, since they run compiled code. Still, it has to compare every element of A with the pattern, which is a lot more work than a single dictionary hash lookup.
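To make the "compare every element" point concrete, the np.where expression can be split into its three steps; the intermediate variable names are only for illustration. Step 1 touches all N×M entries of A on every lookup.

```python
import numpy as np

A = np.array([[2, 0, 0], [1, 1, 0], [1, 0, 1], [0, 2, 0], [0, 1, 1], [0, 0, 2]])
pattern = np.array([0, 2, 0])

# Step 1: broadcasting compares every element of A with pattern -> (6, 3) bools.
elementwise = A == pattern
# Step 2: a row matches only if all of its entries match -> (6,) bools.
row_matches = np.all(elementwise, axis=1)
# Step 3: np.where returns a tuple of index arrays, one per axis; unpack the only one.
(indices,) = np.where(row_matches)

print(elementwise.shape)  # (6, 3)
print(row_matches)        # only position 3 is True
print(indices)            # [3]
```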

Upvotes: 1

Saullo G. P. Castro

Reputation: 58985

You can use np.where() like this:

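import numpy as np
a = np.array([[2, 0, 0], [1, 1, 0], [1, 0, 1], [0, 2, 0], [0, 1, 1], [0, 0, 2]])
pattern = [2, 0, 0]
index = np.where(np.all(a == pattern, axis=1))[0]  # array of rows where every entry matches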
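If many patterns need to be looked up at once, the same idea can be broadcast over a stack of patterns. This is a sketch of an extension, not part of the answer above; the -1 "not found" convention is an assumption. Note that the intermediate table has P×N×M booleans, so for large bases the dictionary approach from the other answer is likely the better choice.

```python
import numpy as np

a = np.array([[2, 0, 0], [1, 1, 0], [1, 0, 1], [0, 2, 0], [0, 1, 1], [0, 0, 2]])
patterns = np.array([[2, 0, 0], [0, 1, 1], [3, 0, 0]])  # last one is not in `a`

# (P, 1, M) == (1, N, M) -> (P, N, M); all(axis=-1) -> (P, N) row-match table.
matches = np.all(patterns[:, None, :] == a[None, :, :], axis=-1)
found = matches.any(axis=1)
# argmax gives the first True in each row; rows with no match are set to -1.
indices = np.where(found, matches.argmax(axis=1), -1)
print(indices.tolist())  # [0, 4, -1]
```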

Upvotes: 5
