GoodDeeds

Reputation: 8507

How can I efficiently map each pixel of a three channel image to one channel?

I am writing a Python program to preprocess images to be used as labels for a semantic segmentation task. The original images have three channels, and the vector of three values at each pixel represents a class label for that pixel. For example, a pixel of [0,0,0] could be class 1, [0,0,255] could be class 2, and so on.

I need to convert these images into a single channel image, with pixel values starting from 0 and increasing serially to represent each class. Essentially, I need to convert [0,0,0] in the old image to 0 in the new image, [0,0,255] to 1, and so on for all the classes.

The images are fairly high resolution, more than 2000 pixels in both width and height, and I need to do this for hundreds of images. My current approach iterates over each pixel and replaces the 3-dimensional value with the corresponding scalar value.

import numpy as np
# imread / imsave come from the image I/O library in use (not specified here)

filename = "file.png"
label_list = [[0,0,0], [0,0,255]] # for example. there are more classes like this
image = imread(filename)
new_image = np.empty((image.shape[0], image.shape[1]))
for i in range(image.shape[0]):
    for j in range(image.shape[1]):
        # Compare the pixel against every label until one matches
        for k, label in enumerate(label_list):
            if np.array_equal(image[i][j], label):
                new_image[i][j] = k
                break
imsave("newname.png", new_image)

The problem is that the above program is very inefficient and takes a few minutes to run for each image. That is too slow to process all my images, so I need to improve it.

Firstly, I think it might be possible to remove the innermost loop by converting label_list to a numpy array and using np.where. However, I am not sure how to do a np.where to find a 1-dimensional array inside a two-dimensional array, and whether it would improve anything.

From this thread, I tried to define a function and apply it directly on the image. However, I need to map every 3-dimensional label to a scalar. A dictionary cannot contain a list as a key. Would there be a better way to do this, and would it help?

Is there a way to improve (by a lot) the efficiency, or is there a better way, to do what the above program does?

Thank you.

Upvotes: 2

Views: 1280

Answers (1)

Divakar

Reputation: 221534

Approach #1

Here's one approach with views and np.searchsorted -

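# https://stackoverflow.com/a/45313353/ @Divakar
import numpy as np

def view1D(a, b): # a, b are 2D arrays with the same number of columns
    a = np.ascontiguousarray(a)
    # Cast b to a's dtype so that rows of both arrays span the same number of bytes
    b = np.ascontiguousarray(b, dtype=a.dtype)
    # View each row as a single opaque (void) element
    void_dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1]))
    return a.view(void_dt).ravel(),  b.view(void_dt).ravel()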

# Trace back a 2D array back to given labels
def labelrows(a2D, label_list):
    # Reduce array and labels to 1D
    a1D,b1D = view1D(a2D, label_list)

    # Use searchsorted to trace back label indices
    sidx = b1D.argsort()
    return sidx[np.searchsorted(b1D, a1D, sorter=sidx)]

Hence, to use it on a 3D image array of shape (height, width, channels), reshape the image so that height and width merge into one axis while the color channel axis is kept as it is, then pass the resulting 2D array to the labeling function.
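A minimal sketch of that reshape step follows. The helper label_rows_bruteforce is a hypothetical stand-in used only so the snippet runs by itself; it compares every row with every label, which is fine for a toy array but memory-hungry on large images, so in practice labelrows from above would take its place.

```python
import numpy as np

def label_rows_bruteforce(rows, label_list):
    # Hypothetical stand-in for labelrows: compare each row with each label.
    labels = np.asarray(label_list, dtype=rows.dtype)
    matches = (rows[:, None, :] == labels[None, :, :]).all(axis=-1)  # shape (N, L)
    return matches.argmax(axis=1)  # index of the first matching label per row

label_list = [[0, 0, 0], [0, 0, 255]]
image = np.array([[[0, 0, 255], [0, 0, 0]],
                  [[0, 0, 0], [0, 0, 255]]], dtype=np.uint8)

h, w, c = image.shape
rows = image.reshape(-1, c)              # merge height and width: (h*w, c)
flat_labels = label_rows_bruteforce(rows, label_list)
label_image = flat_labels.reshape(h, w)  # back to the 2D image layout
print(label_image)
```

The final reshape restores the (height, width) layout, so each entry of label_image lines up with the pixel it came from.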

Approach #2

This approach is tuned for images whose channel values lie in the [0,255] range. Each pixel can then be read as a base-256 number, so a matrix-multiplication reduces every pixel to a single scalar, which should boost the performance further, like so -

def labelpixels(img3D, label_list):
    # scale array: one power of 256 per channel
    s = 256**np.arange(img3D.shape[-1])

    # Reduce image and labels to 1D
    img1D = img3D.reshape(-1,img3D.shape[-1]).dot(s)
    label1D = np.dot(label_list, s)

    # Use searchsorted to trace back label indices
    sidx = label1D.argsort()
    return sidx[np.searchsorted(label1D, img1D, sorter=sidx)]
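To make the dimensionality reduction concrete, here is a small worked example of the scaling step, assuming three channels with values in [0,255]. The variable names are illustrative only. Because each channel is weighted by a distinct power of 256, different colour triples always produce different scalars.

```python
import numpy as np

# One weight per channel: 256**0, 256**1, 256**2
s = 256 ** np.arange(3)
print(s)  # [    1   256 65536]

label_list = [[0, 0, 0], [0, 0, 255], [0, 255, 255], [255, 0, 255]]

# Each label becomes c0*1 + c1*256 + c2*65536
codes = np.dot(label_list, s)
print(codes)

# A uint8 pixel is promoted to a wider integer type by the dot product,
# so the result does not wrap around at 255.
pixel = np.array([[0, 0, 255]], dtype=np.uint8)
pixel_code = pixel.dot(s)
print(pixel_code)
```

For instance, [0,0,255] maps to 255 * 65536 = 16711680, which is the same value whether it is computed from the label list or from the image.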

Sample run showing how to apply both functions to an image and verify the result -

In [194]: label_list = [[0,255,255], [0,0,0], [0,0,255], [255, 0, 255]]

In [195]: idx = [2,0,3,1,0,3,1,2] # We need to retrieve this back

In [196]: img = np.asarray(label_list)[idx].reshape(2,4,3)

In [197]: img
Out[197]: 
array([[[  0,   0, 255],
        [  0, 255, 255],
        [255,   0, 255],
        [  0,   0,   0]],

       [[  0, 255, 255],
        [255,   0, 255],
        [  0,   0,   0],
        [  0,   0, 255]]])

In [198]: labelrows(img.reshape(-1,img.shape[-1]), label_list)
Out[198]: array([2, 0, 3, 1, 0, 3, 1, 2])

In [217]: labelpixels(img, label_list)
Out[217]: array([2, 0, 3, 1, 0, 3, 1, 2])

Finally, the output would need a reshape back to 2D -

In [222]: labelpixels(img, label_list).reshape(img.shape[:-1])
Out[222]: 
array([[2, 0, 3, 1],
       [0, 3, 1, 2]])
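One caveat: both functions assume that every pixel colour appears in label_list. The sketch below is a hedged variant of the second approach that marks unlisted colours with a fill value and returns a uint8 array ready for saving. The name rgb_to_label_image and the unknown parameter are hypothetical, and uint8 output assumes fewer than 255 classes.

```python
import numpy as np

def rgb_to_label_image(img3D, label_list, unknown=255):
    # Hypothetical wrapper; `unknown` is an assumed fill value for colours
    # that are not present in label_list.
    img3D = np.asarray(img3D)
    s = 256 ** np.arange(img3D.shape[-1])

    # Reduce image and labels to 1D scalar codes
    img1D = img3D.reshape(-1, img3D.shape[-1]).dot(s)
    label1D = np.dot(label_list, s)

    # Locate each pixel code among the sorted label codes
    sidx = label1D.argsort()
    sorted_labels = label1D[sidx]
    pos = np.searchsorted(sorted_labels, img1D)
    pos = np.minimum(pos, len(sorted_labels) - 1)  # keep indices in range

    # Accept only exact matches; everything else gets the fill value
    found = sorted_labels[pos] == img1D
    out = np.where(found, sidx[pos], unknown).astype(np.uint8)
    return out.reshape(img3D.shape[:-1])

label_list = [[0, 255, 255], [0, 0, 0], [0, 0, 255], [255, 0, 255]]
idx = [2, 0, 3, 1, 0, 3, 1, 2]
img = np.asarray(label_list, dtype=np.uint8)[idx].reshape(2, 4, 3)
labels2D = rgb_to_label_image(img, label_list)

# A colour that is not in label_list is reported as `unknown`
img_bad = img.copy()
img_bad[0, 0] = [1, 2, 3]
labels2D_bad = rgb_to_label_image(img_bad, label_list)
```

Checking for exact matches costs one extra comparison per pixel, but it avoids silently assigning a wrong class (or raising an index error) when an image contains a stray colour.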

Upvotes: 1
