Python: A fast way to erase equal/similar numpy arrays from a batch

I have a batch of 28x28 numpy arrays corresponding to grayscale images, and I would like to erase the similar ones (primitive as it looks, by "similar" I mean arrays that fulfill np.sum(np.abs(arrayA-arrayB))<50 or something like that). Does anyone know something better than loops to do this?

It was easy to erase repeated images with np.unique, but this one is tough for me. Thanks a lot!

Answers (1)

Nils Werner

As long as your dataset doesn't grow too large, broadcasting is a very helpful tool for this kind of task. It lets you do "outer" operations, such as taking the difference of every element with every other element, which gives you a pairwise difference matrix.

To get an intuition for how this works, let's look at the 1D case:

import numpy as np
import matplotlib.pyplot as plt

data = np.random.rand(12) * 256
# Make sure we have some similar elements in `data`
data[0] = data[3]
data[7] = data[10]

diff = np.abs(data[None, :] - data[:, None])
diff.shape
# (12, 12)

To get a feeling for what is happening, take a look at the output:

plt.imshow(diff)
plt.show()
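If the plot is hard to read, a tiny hand-picked array (values chosen only for illustration) makes the same kind of matrix easy to inspect by eye. Entry [i, j] is the absolute difference between element j and element i, so the matrix is symmetric and its diagonal is zero:

```python
import numpy as np

# Three values; index 0 and index 2 are identical
small = np.array([1.0, 5.0, 1.0])

# Entry [i, j] holds |small[j] - small[i]|
small_diff = np.abs(small[None, :] - small[:, None])
print(small_diff)
# [[0. 4. 0.]
#  [4. 0. 4.]
#  [0. 4. 0.]]
```

The off-diagonal zeros at [0, 2] and [2, 0] mark the duplicate pair; the zero diagonal is just every element compared with itself.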

Now that we understand how we can leverage broadcasting, let's adapt it to your 3D case:

data = np.random.rand(12, 28, 28) * 256
# Make sure we have some similar elements in `data`
data[0, ...] = data[3, ...]
data[7, ...] = data[10, ...]

diff = np.abs(data[None, ...] - data[:, None, ...])
diff.shape
# (12, 12, 28, 28)
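One caveat for real images: grayscale data is often stored as `uint8`, and subtracting unsigned integers wraps around instead of going negative, so `np.abs` cannot repair it. The random `data` here is float64 and unaffected, but with image data it is worth casting to a signed or floating-point type (for example `data.astype(np.int16)`) before taking the difference. A one-pixel illustration:

```python
import numpy as np

a = np.array([10], dtype=np.uint8)
b = np.array([12], dtype=np.uint8)

# 10 - 12 wraps around to 254 in uint8, and abs() leaves it at 254
wrapped = np.abs(a - b)
# A signed type keeps the -2, so abs() gives the intended 2
correct = np.abs(a.astype(np.int16) - b.astype(np.int16))
print(wrapped, correct)
# [254] [2]
```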

As you can see, we got a tensor that contains the difference of each pixel with the same pixel in every other tile. To get the total difference for each pair of tiles, sum over the last two (pixel) axes:

diff = np.sum(diff, axis=(-1, -2))
diff.shape
# (12, 12)

Again, take a look at what is happening:

plt.imshow(diff)
plt.show()
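The (N, N, 28, 28) intermediate is what limits "not too large": in float64 it takes N² · 784 · 8 bytes, about 6.3 GB for N = 1000. Below is a minimal sketch, assuming you only need the final (N, N) sum matrix, that builds it in chunks of rows so that only a (chunk, N, 784) block exists at once. `pairwise_sad` and `chunk` are hypothetical names for illustration, not part of NumPy:

```python
import numpy as np

def pairwise_sad(data, chunk=64):
    """Hypothetical helper: (N, N) matrix of summed absolute differences.

    Matches np.abs(data[None] - data[:, None]).sum(axis=(-1, -2)) up to
    floating-point rounding, but only builds a (chunk, N, H*W) block at once.
    """
    # One row per image; float64 so unsigned input cannot wrap around
    flat = data.reshape(len(data), -1).astype(np.float64)
    n = len(flat)
    out = np.empty((n, n))
    for start in range(0, n, chunk):
        block = flat[start:start + chunk]  # shape (c, H*W)
        # Compare this block of images with every image, then sum over pixels
        out[start:start + chunk] = np.abs(block[:, None, :] - flat[None, :, :]).sum(axis=-1)
    return out

rng = np.random.default_rng(0)
demo = rng.random((12, 28, 28)) * 256
demo[0] = demo[3]
demo[7] = demo[10]
full = np.abs(demo[None, ...] - demo[:, None, ...]).sum(axis=(-1, -2))
print(np.allclose(pairwise_sad(demo, chunk=5), full))  # True
```

The loop runs N / chunk times rather than N² times, and the per-pixel work stays vectorised. If SciPy is available, `scipy.spatial.distance.cdist(flat, flat, 'cityblock')` on the flattened images computes the same sums of absolute differences.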

To find the similar tiles, apply your condition:

diff = diff < 50

Be aware that the main diagonal is now all True (each tile is compared with itself, and that difference is obviously 0), so set it to False:

np.fill_diagonal(diff, False)

As a sanity check, let's search for the True values:

np.where(diff)
# (array([ 0,  3,  7, 10]), array([ 3,  0, 10,  7]))

Alright, those values seem reasonable.
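np.where reports every pair twice, as (i, j) and (j, i). If you would rather see each similar pair once, restricting the search to the upper triangle does that. This sketch rebuilds the boolean matrix from seeded random data (the seed, the names `demo` and `similar`, and the threshold of 50 are assumptions for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)
demo = rng.random((12, 28, 28)) * 256
demo[0] = demo[3]
demo[7] = demo[10]

similar = np.abs(demo[None, ...] - demo[:, None, ...]).sum(axis=(-1, -2)) < 50
np.fill_diagonal(similar, False)

# k=1 excludes the diagonal; each similar pair then appears once as (i, j) with i < j
pairs = np.argwhere(np.triu(similar, k=1))
print(pairs)
# [[ 0  3]
#  [ 7 10]]
```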

To get a boolean mask from diff, check each column for any True value (diff is symmetric, so checking rows would give the same result):

mask = np.any(diff, axis=0)
# array([ True, False, False,  True, False, False, False,  True, False, False,  True, False])

Then use this mask to filter data. This drops tiles 0, 3, 7 and 10:

data = data[~mask, ...]
data.shape
# (8, 28, 28)

If instead you want to keep one tile of each similar pair, start again from the original (unfiltered) data, search for True only in the upper or lower triangle of diff, and keep the rest:

mask = np.any(np.triu(diff), axis=0)
# array([False, False, False,  True, False, False, False, False, False, False,  True, False])

This drops only tiles 3 and 10, keeping 0 and 7:

data = data[~mask, ...]
data.shape
# (10, 28, 28)
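Putting the steps together, here is a minimal sketch of the whole filter as one function. `drop_similar`, `threshold` and `keep_one` are hypothetical names, and the default threshold of 50 is taken from the question:

```python
import numpy as np

def drop_similar(data, threshold=50, keep_one=True):
    """Hypothetical helper: remove tiles whose summed absolute difference
    to another tile is below `threshold`.

    keep_one=False drops every tile that has a similar partner;
    keep_one=True keeps the earlier tile of each similar pair.
    """
    # Cast so uint8 images do not wrap around when subtracted
    x = data.astype(np.float64)
    sad = np.abs(x[None, ...] - x[:, None, ...]).sum(axis=(-1, -2))
    similar = sad < threshold
    np.fill_diagonal(similar, False)  # a tile is trivially similar to itself
    if keep_one:
        similar = np.triu(similar)  # only count matches with earlier tiles
    mask = np.any(similar, axis=0)
    return data[~mask]

rng = np.random.default_rng(0)
demo = rng.random((12, 28, 28)) * 256
demo[0] = demo[3]
demo[7] = demo[10]
print(drop_similar(demo, keep_one=False).shape)  # (8, 28, 28)
print(drop_similar(demo).shape)                  # (10, 28, 28)
```

Note that `keep_one=True` is greedy: a tile is dropped if it is similar to any earlier tile, even one that is itself dropped. With a chain where A ≈ B and B ≈ C but A and C differ, only A survives.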
