Maryam Sadeghi

Reputation: 159

How to make this loop faster?

I want my image to contain only the specific colors listed in color_list. So I loop through every pixel, and if the color of that pixel is not in the color list, I assign it the color of a neighboring region. But since the images are 2k by 2k pixels, this loop takes 3 minutes or so. I'm sure my approach is not optimal. How can I optimize it?

def clean_img_pixels(atlas_img, color_list):
    # Write into a copy so the neighbor lookups always read the original pixels
    atlas_img_cleaned = atlas_img.copy()
    dd = 3  # row offset of the neighboring pixel to sample
    for ii in range(atlas_img.shape[0]-1):
        for jj in range(atlas_img.shape[1]-1):
            pixelcolor = (atlas_img[ii,jj,0],atlas_img[ii,jj,1],atlas_img[ii,jj,2])
            if pixelcolor not in color_list:
                # Try dd rows above, then dd rows below, then 5 rows below.
                # No boundary checks: ii-dd wraps around near the top edge and
                # ii+dd / ii+5 can raise IndexError near the bottom edge.
                pixel2color = (atlas_img[ii-dd,jj,0],atlas_img[ii-dd,jj,1],atlas_img[ii-dd,jj,2])
                if pixel2color == (0,0,0) or pixel2color not in color_list:
                    pixel2color = (atlas_img[ii+dd,jj,0],atlas_img[ii+dd,jj,1],atlas_img[ii+dd,jj,2])
                    if pixel2color == (0,0,0) or pixel2color not in color_list:
                        pixel2color = (atlas_img[ii+5,jj,0],atlas_img[ii+5,jj,1],atlas_img[ii+5,jj,2])
                atlas_img_cleaned[ii,jj] = pixel2color
    return atlas_img_cleaned

atlas_img_cleaned = clean_img_pixels(atlas_img, color_list)
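The loop above reads only from atlas_img and writes into atlas_img_cleaned, so no iteration depends on an earlier one, which means the whole thing can in principle be expressed with whole-array NumPy operations. Below is a minimal sketch, assuming atlas_img is an (H, W, 3) or (H, W, 4) uint8 array. The names pack_rgb and clean_img_pixels_vectorized are hypothetical, and the row shifts are clamped at the image edges, which is an assumption since the loop has no boundary handling. Unlike the loop, it also processes the last row and column.

```python
import numpy as np

def pack_rgb(rgb):
    """Pack the last axis (r, g, b) of a uint8 array into one int: 0xRRGGBB."""
    rgb = rgb.astype(np.uint32)
    return (rgb[..., 0] << 16) | (rgb[..., 1] << 8) | rgb[..., 2]

def clean_img_pixels_vectorized(atlas_img, color_list, dd=3, far=5):
    rgb = atlas_img[..., :3]
    h = rgb.shape[0]
    packed = pack_rgb(rgb)                                     # (H, W)
    palette = pack_rgb(np.array(color_list, dtype=np.uint8))   # (P,)
    in_palette = np.isin(packed, palette)                      # (H, W) bool
    # A neighbor is usable if it is in the palette and is not black (packed 0)
    usable = in_palette & (packed != 0)

    def shift_rows(arr, offset):
        # Row ii of the result is row ii + offset of arr, clamped to the
        # valid row range (assumption: the loop version has no edge handling)
        idx = np.clip(np.arange(h) + offset, 0, h - 1)
        return arr[idx]

    up_ok = shift_rows(usable, -dd)[..., None]
    down_ok = shift_rows(usable, dd)[..., None]
    # Same fallback order as the loop: dd rows up, dd rows down, far rows down
    replacement = np.where(up_ok, shift_rows(rgb, -dd),
                           np.where(down_ok, shift_rows(rgb, dd), shift_rows(rgb, far)))

    cleaned = atlas_img.copy()
    off_palette = ~in_palette
    # cleaned[..., :3] is a view, so this writes straight into cleaned
    cleaned[..., :3][off_palette] = replacement[off_palette]
    return cleaned
```

Packing each color into a single integer is what lets np.isin test palette membership for every pixel in one call instead of one tuple comparison per pixel.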

To be more precise, here is the part which takes the longest:

out_colors = []
for ii in range(atlas_img.shape[0]-1):
    for jj in range(atlas_img.shape[1]-1):
        pixelcolor = (atlas_img[ii,jj,0],atlas_img[ii,jj,1],atlas_img[ii,jj,2])
        if pixelcolor not in color_list:
            out_colors.append((ii,jj))

This takes 177 seconds.

I also tried it this way:

import itertools

out_colors = [(ii, jj)
              for (ii, jj) in itertools.product(range(atlas_img.shape[0]), range(atlas_img.shape[1]))
              if (atlas_img[ii, jj, 0], atlas_img[ii, jj, 1], atlas_img[ii, jj, 2]) not in color_list]

But it doesn't make much of a difference: it takes 173 seconds.
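Most of the time here is likely spent in the roughly 4 million iterations of Python-level code, each doing three separate NumPy scalar indexing calls plus a linear scan of color_list. itertools.product removes none of that work, which would explain why the timing barely changes. A sketch of a vectorized alternative, assuming an RGB (or RGBA) uint8 array: pack each pixel into one integer (0xRRGGBB) and test membership for the whole image at once with np.isin. The function name out_of_palette_coords is hypothetical, and unlike the loops above it also covers the last row and column.

```python
import numpy as np

def out_of_palette_coords(atlas_img, color_list):
    """Return [(row, col), ...] for every pixel whose RGB is not in color_list."""
    rgb = atlas_img[..., :3].astype(np.uint32)
    packed = (rgb[..., 0] << 16) | (rgb[..., 1] << 8) | rgb[..., 2]  # 0xRRGGBB
    pal = np.array(color_list, dtype=np.uint32)
    palette = (pal[:, 0] << 16) | (pal[:, 1] << 8) | pal[:, 2]
    mask = ~np.isin(packed, palette)  # True where the pixel is off-palette
    # np.argwhere gives a (K, 2) array of (row, col) indices in row-major order
    return [tuple(rc) for rc in np.argwhere(mask).tolist()]
```

In practice the boolean mask itself is often more useful than the coordinate list, since it can index the image directly (for example atlas_img[mask]).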

This is the color list:

color_list = [(52, 26, 75), (9, 165, 216), (245, 34, 208), (146, 185, 85), (251, 6, 217), (223, 144, 239), (190, 224, 121), (252, 26, 157), (150, 130, 142), (51, 129, 172), (97, 85, 204), (1, 108, 233), (138, 201, 180), (210, 63, 175), (26, 138, 43), (216, 141, 61), (38, 89, 118), (0, 0, 0)]

Here is an example image: [image]

Upvotes: 5

Views: 125

Answers (2)

Jeremy Whitcher

Reputation: 721

Here's what I came up with based on the problem statement, Thomas Jungblut's answer and the answer here.

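Rather than sampling neighboring pixels, the algorithm replaces each off-palette pixel with the closest color in COLOR_LIST (Euclidean distance in RGB), caching the answer for each distinct color, so every pixel needs only a single lookup. The result is an image limited to the colors in COLOR_LIST.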

from PIL import Image
from datetime import datetime
from math import sqrt

COLOR_LIST = {(52, 26, 75), (9, 165, 216), (245, 34, 208), (146, 185, 85), (251, 6, 217), (223, 144, 239),
              (190, 224, 121), (252, 26, 157), (150, 130, 142), (51, 129, 172), (97, 85, 204), (1, 108, 233),
              (138, 201, 180), (210, 63, 175), (26, 138, 43), (216, 141, 61), (38, 89, 118), (0, 0, 0)}
COLOR_CACHE = {}

def closest_color(rgb, color_list):
    if rgb not in COLOR_CACHE:
        r, g, b = rgb
        color_diffs = []
        for color in color_list:
            cr, cg, cb = color
            color_diff = sqrt((r - cr)**2 + (g - cg)**2 + (b - cb)**2)
            color_diffs.append((color_diff, color))
        COLOR_CACHE[rgb] = min(color_diffs)[1]
    return COLOR_CACHE[rgb]

def clean_img_pixels(atlas_img, color_list):
    atlas_img_cleaned = atlas_img.copy()
    pixels = atlas_img_cleaned.load()  # pixel access object for fast writes
    # Visit every pixel, including the last column and row
    for ii in range(atlas_img.size[0]):
        for jj in range(atlas_img.size[1]):
            pixel = atlas_img.getpixel((ii, jj))
            if pixel not in color_list:
                # Snap off-palette pixels to the nearest palette color
                pixels[ii, jj] = closest_color(pixel, color_list)
    return atlas_img_cleaned

im = Image.open('7y1JG.png')
im = im.convert('RGB')
start_time = datetime.now()
om = clean_img_pixels(im, COLOR_LIST)
print('Time elapsed (hh:mm:ss.ms) {}'.format(datetime.now() - start_time))
om.save('7y1JG-clean.png', "PNG")

# Time elapsed (hh:mm:ss.ms) 0:00:02.932316
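The same nearest-color idea can also be vectorized with NumPy. A minimal sketch (map_to_palette is a hypothetical name): like COLOR_CACHE, it computes distances only once per distinct color, by packing pixels into integers and calling np.unique(..., return_inverse=True), then broadcasts squared RGB distances against the palette. Squared distance selects the same nearest color as the sqrt version because sqrt is monotonic; on exact ties argmin takes the first palette entry, which may differ from min() over (distance, color) tuples.

```python
import numpy as np

def map_to_palette(rgb, color_list):
    """Snap every pixel of an (H, W, 3) array to its nearest palette color.

    color_list must be a sequence of (r, g, b) tuples (pass list(COLOR_LIST)
    if you have a set).
    """
    h, w = rgb.shape[:2]
    palette = np.array(color_list, dtype=np.int64)              # (P, 3)
    flat = rgb[..., :3].reshape(-1, 3).astype(np.int64)          # (N, 3)
    # Pack to 0xRRGGBB so np.unique works on a 1-D array of ints
    packed = (flat[:, 0] << 16) | (flat[:, 1] << 8) | flat[:, 2]
    uniq, inverse = np.unique(packed, return_inverse=True)       # (U,), (N,)
    uniq_rgb = np.stack([(uniq >> 16) & 255, (uniq >> 8) & 255, uniq & 255], axis=1)
    # Squared Euclidean distance from each distinct color to each palette entry
    d2 = ((uniq_rgb[:, None, :] - palette[None, :, :]) ** 2).sum(axis=2)  # (U, P)
    nearest = palette[d2.argmin(axis=1)]                         # (U, 3)
    return nearest[inverse.ravel()].reshape(h, w, 3).astype(np.uint8)
```

With the Pillow image above, something like `Image.fromarray(map_to_palette(np.asarray(im), list(COLOR_LIST)))` should produce the cleaned image.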

Upvotes: 0

Thomas Jungblut

Reputation: 20969

If you ditch NumPy altogether, operate directly on Pillow images, and store the colors in a set of tuples instead of a list, it's much faster (for me this runs in 5 seconds on your example picture):

from PIL import Image
from datetime import datetime

im = Image.open('7y1JG.png')
im = im.convert('RGB')

color_list = {(52, 26, 75), (9, 165, 216), (245, 34, 208), (146, 185, 85), (251, 6, 217), (223, 144, 239),
              (190, 224, 121), (252, 26, 157), (150, 130, 142), (51, 129, 172), (97, 85, 204), (1, 108, 233),
              (138, 201, 180), (210, 63, 175), (26, 138, 43), (216, 141, 61), (38, 89, 118), (0, 0, 0)}


def clean_img_pixels(atlas_img, color_list):
    atlas_img_cleaned = atlas_img.copy()
    pixels = atlas_img_cleaned.load()  # pixel access object for fast writes
    dd = 3
    for ii in range(atlas_img.size[0] - 1):
        for jj in range(atlas_img.size[1] - 1):
            if atlas_img.getpixel((ii, jj)) not in color_list:
                pixel2_color = atlas_img.getpixel((ii - dd, jj))
                if pixel2_color == (0, 0, 0) or pixel2_color not in color_list:
                    pixel2_color = atlas_img.getpixel((ii + dd, jj))
                    if pixel2_color == (0, 0, 0) or pixel2_color not in color_list:
                        pixel2_color = atlas_img.getpixel((ii + 5, jj))
                pixels[ii, jj] = pixel2_color
    # Return the Image (not the pixel access object) so it can be saved
    return atlas_img_cleaned


start_time = datetime.now()

out_image = clean_img_pixels(im, color_list)
time_elapsed = datetime.now() - start_time
print('Time elapsed (hh:mm:ss.ms) {}'.format(time_elapsed))

I'd still advise you to add some boundary checking; the code only happens to run because of the way your image is laid out.
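A minimal sketch of that boundary checking in plain Python: clamp every neighbor coordinate into the image before looking it up. clamp and pick_replacement are hypothetical helpers, and clamping to the nearest edge pixel is only one choice (skipping an out-of-range candidate would be another). Since Pillow's getpixel takes (x, y), the offsets here move along x, as in the answer's code.

```python
def clamp(value, low, high):
    """Clamp value into the inclusive range [low, high]."""
    return max(low, min(value, high))

def pick_replacement(getpixel, width, x, y, color_set, dd=3, far=5):
    """Return the replacement color for an off-palette pixel at (x, y).

    Follows the answer's fallback order (x - dd, then x + dd, then x + far),
    but clamps each x coordinate to [0, width - 1] so no lookup wraps around
    or falls outside the image. getpixel is any callable taking (x, y),
    e.g. a Pillow Image's bound getpixel method.
    """
    def usable(color):
        # Black is in the palette but is never accepted as a replacement
        return color != (0, 0, 0) and color in color_set

    for offset in (-dd, dd):
        candidate = getpixel((clamp(x + offset, 0, width - 1), y))
        if usable(candidate):
            return candidate
    # Last resort, taken unconditionally as in the original code
    return getpixel((clamp(x + far, 0, width - 1), y))
```

Inside the loop, `pixels[ii, jj] = pick_replacement(atlas_img.getpixel, atlas_img.size[0], ii, jj, color_list)` would replace the nested if block.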

Upvotes: 2
