Sarthak

Reputation: 43

Python: Extract the indices of repeated rows corresponding to the non-zero unique rows in a matrix

Given this matrix K:

[[ 1.  2.  3.]
 [ 0.  0.  0.]
 [ 4.  5.  6.]
 [ 0.  0.  0.]
 [ 4.  5.  6.]
 [ 0.  0.  0.]]

How can I store the list/array of indices of the repeated rows that correspond to each non-zero unique row of the matrix?

In this example, [0, 2] are the indices of the non-zero unique rows.

Question: how can I store this information in a dictionary, like this:

   corresponding value for key 0: [0]
   corresponding value for key 2: [2,4]

Thanks!

Upvotes: 0

Views: 83

Answers (2)

jpp

Reputation: 164823

Here is one method using collections.defaultdict. It iterates over the rows with enumerate, converting each row to a tuple so that it is hashable and can serve as a dictionary key.

You can easily remove (0, 0, 0) from the dictionary at the end, and rename keys if necessary. The method is O(n) in the number of rows.

import numpy as np
from collections import defaultdict

A = np.array([[ 1,  2,  3],
              [ 0,  0,  0],
              [ 4,  5,  6],
              [ 0,  0,  0],
              [ 4,  5,  6],
              [ 0,  0,  0]])

d = defaultdict(list)

# Group row indices by row content; each row becomes a hashable tuple key.
for idx, row in enumerate(map(tuple, A)):
    d[row].append(idx)

Result:

print(d)

defaultdict(list, {(0, 0, 0): [1, 3, 5],
                   (1, 2, 3): [0],
                   (4, 5, 6): [2, 4]})
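
To get from this result to the dictionary asked for in the question, one option is to drop the all-zero key and re-key each group by its first index. The following is only a minimal sketch: it uses plain tuples rather than a NumPy array (rows of an array can be converted with map(tuple, A) as above), and the name by_first_index is made up for illustration.

```python
from collections import defaultdict

# Rows of the example matrix as tuples (hashable, so usable as dict keys).
rows = [(1, 2, 3), (0, 0, 0), (4, 5, 6), (0, 0, 0), (4, 5, 6), (0, 0, 0)]

# Same grouping step as in the answer: row content -> list of row indices.
d = defaultdict(list)
for idx, row in enumerate(rows):
    d[row].append(idx)

# Drop the all-zero row (the default None avoids a KeyError if it is absent),
# then key each group by the index of its first occurrence.
d.pop((0, 0, 0), None)
by_first_index = {indices[0]: indices for indices in d.values()}

print(by_first_index)  # {0: [0], 2: [2, 4]}
```

Because indices are appended in increasing order, indices[0] is always the first occurrence of that row.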

Upvotes: 2

gahooa

Reputation: 137552

Assuming your data is in a list of tuples:

data = [
  (1, 2, 3),
  (0, 0, 0),
  (4, 5, 6),
  (0, 0, 0),
  (4, 5, 6),
  (0, 0, 0),
  ]

Edit in response to comments:

Invert the data into a defaultdict, appending each index to the list attached to the corresponding key.

import collections

output = collections.defaultdict(list)
for i, v in enumerate(data):
  # Skip the all-zero rows.
  if v == (0, 0, 0):
    continue
  output[v].append(i)

# In Python 3, values() returns a view, so convert it to a list for printing.
print(list(output.values()))

Output is:

[[0], [2, 4]]

Original

A simple loop will do. This will

  • ignore (0,0,0)
  • record the index of the first instance of any sequential set of duplicates

It stores the indexes in a set() for performance, but sorts them at the end.

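output = set()
lastval = None  # most recent non-zero row seen
lasti = None    # index where the current run of lastval started

for i, val in enumerate(data):
  # Skip the all-zero rows.
  if val == (0, 0, 0):
    continue

  # A different value starts a new run; remember where it began.
  if val != lastval:
    lastval = val
    lasti = i

  # Adding to a set is idempotent, so no membership check is needed.
  output.add(lasti)

print(sorted(output))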

Output is

[0, 2]
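
Note that the loop above compares each row only with the previous non-zero row, so a value that reappears after a different value is recorded again. If the first index of each distinct non-zero row is wanted regardless of where the repeats occur, a dictionary of first occurrences is one way to do it. This is a sketch under that assumption; the function name first_indices and its skip parameter are hypothetical, not part of the original answer.

```python
def first_indices(rows, skip=(0, 0, 0)):
    """Return the sorted first-occurrence index of each distinct row, ignoring `skip`."""
    first_seen = {}
    for i, row in enumerate(rows):
        if row == skip:
            continue
        # setdefault stores i only if row has not been seen before.
        first_seen.setdefault(row, i)
    return sorted(first_seen.values())


# (1, 2, 3) and (4, 5, 6) both reappear after a different non-zero row.
interleaved = [(1, 2, 3), (0, 0, 0), (4, 5, 6), (1, 2, 3), (4, 5, 6)]
print(first_indices(interleaved))  # [0, 2]
```

On this interleaved input the run-based loop would report [0, 2, 3, 4], since every change of value starts a new run.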

Upvotes: 1
