Gilfoyle

Reputation: 3616

Tensorflow: Convert image to rgb if grayscale

I have a dataset of RGB and grayscale images. While iterating over the dataset, I want to detect whether an image is grayscale so that I can convert it to RGB. I wanted to use tf.shape(image) to get the dimensions of the image. For an RGB image I get something like [1, 100, 100, 3]; for a grayscale image the function returns, for example, [1, 100, 100]. I wanted to use len(tf.shape(image)) to check whether the length is 4 (RGB) or 3 (grayscale). That did not work.

This is my code so far, which did not work:

def process_image(image):
    # Convert numpy array to tensor
    image = tf.convert_to_tensor(image, dtype=tf.uint8)
    # Take care of grayscale images
    dims = len(tf.shape(image))
    if dims == 3:
        image = np.expand_dims(image, axis=3)
        image = tf.image.grayscale_to_rgb(image)
    return image

Is there an alternative way to convert grayscale images to rgb?

Upvotes: 4

Views: 7232

Answers (2)

LC117

Reputation: 786

I had a very similar problem: I wanted to load RGB and greyscale images in one go. TensorFlow supports setting the number of channels when decoding images. So if the images have different numbers of channels, this might be what you are looking for:

# to get greyscale:
tf.io.decode_image(raw_img, expand_animations = False, dtype=tf.float32, channels=1)

# to get rgb:
tf.io.decode_image(raw_img, expand_animations = False, dtype=tf.float32, channels=3)

-> You can even do both on the same image and inside tf.data.Dataset mappings!

You now have to set the channels argument to match the shape you need, so that all loaded images have that shape. Then you can reshape without a condition.

This also allows you to load a grayscale image directly as RGB in TensorFlow. Here is an example:

    >> from PIL import Image
    >> import numpy as np
    >> import tensorflow as tf
    >> a = Image.open(r"Path/to/rgb_img.JPG")
    >> np.array(a).shape
    (666, 1050, 3)
    >> a = a.convert('L')
    >> np.array(a).shape
    (666, 1050)
    >> b = np.array(a)
    >> im = Image.fromarray(b) 
    >> im.save(r"Path/to/now_it_is_greyscale.jpg")
    >> raw_img = tf.io.read_file(r"Path/to/now_it_is_greyscale.jpg")
    >> img = tf.io.decode_image(raw_img, dtype=tf.float32, channels=3)
    >> img.shape
    TensorShape([666, 1050, 3])
    >> img = tf.io.decode_image(raw_img, dtype=tf.float32, channels=1)
    >> img.shape
    TensorShape([666, 1050, 1])

Use expand_animations=False if you get "ValueError: 'images' contains no shape." See: https://stackoverflow.com/a/59944421/9621080
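
To put the pieces above together, here is a minimal sketch of decoding mixed greyscale/RGB images inside a tf.data.Dataset mapping. The in-memory PNGs, their sizes, the 8x8 target size and the helper name load_rgb are assumptions made up for illustration; with real files you would map over file paths and call tf.io.read_file first.

```python
import tensorflow as tf

# Two in-memory PNGs standing in for files on disk (made-up sizes).
gray_png = tf.io.encode_png(tf.zeros([4, 5, 1], dtype=tf.uint8))
rgb_png = tf.io.encode_png(tf.zeros([6, 7, 3], dtype=tf.uint8))

def load_rgb(raw_img):
    """Hypothetical helper: decode any image to 3 channels, then resize."""
    # channels=3 makes greyscale inputs come out as RGB;
    # expand_animations=False keeps a static rank so tf.image.resize works.
    img = tf.io.decode_image(raw_img, channels=3, dtype=tf.float32,
                             expand_animations=False)
    return tf.image.resize(img, [8, 8])

ds = tf.data.Dataset.from_tensor_slices(tf.stack([gray_png, rgb_png]))
ds = ds.map(load_rgb).batch(2)

batch = next(iter(ds))
print(batch.shape)
```

Because every element leaves the mapping with three channels, the batch can be built without checking each image for greyscale.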

Upvotes: 4

javidcf

Reputation: 59681

You can use a function like this for that:

import tensorflow as tf

def process_image(image):
    image = tf.convert_to_tensor(image, dtype=tf.uint8)
    # Rank < 4 means [batch, height, width]: add a channel axis and convert
    image_rgb = tf.cond(tf.rank(image) < 4,
                        lambda: tf.image.grayscale_to_rgb(tf.expand_dims(image, -1)),
                        lambda: tf.identity(image))
    # Add shape information (only one set_shape call, so the ranks cannot conflict)
    s = image.shape
    if s.ndims is not None:
        if s.ndims < 4:
            image_rgb.set_shape(s.concatenate(3))
        else:
            image_rgb.set_shape(s)
    return image_rgb
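
If the rank of the input is known statically (which is always the case in eager mode), a plain Python check on image.shape.rank may be enough and tf.cond is not needed. The following is only a sketch under that assumption; the helper name to_rgb_static is made up, and it assumes the [batch, height, width] layout from the question.

```python
import numpy as np
import tensorflow as tf

def to_rgb_static(image):
    """Hypothetical helper: branch on the static rank instead of tf.cond."""
    image = tf.convert_to_tensor(image, dtype=tf.uint8)
    rank = image.shape.rank  # Python int, or None if the rank is unknown
    if rank is None:
        raise ValueError("static rank unknown; use a tf.cond-based version")
    if rank == 3:  # [batch, height, width]: greyscale without a channel axis
        image = tf.image.grayscale_to_rgb(tf.expand_dims(image, -1))
    return image

gray = np.zeros((1, 100, 100), dtype=np.uint8)
rgb = np.zeros((1, 100, 100, 3), dtype=np.uint8)
gray_out = to_rgb_static(gray)
rgb_out = to_rgb_static(rgb)
print(gray_out.shape, rgb_out.shape)
```

Unlike len(tf.shape(image)), which calls len() on a tensor, image.shape.rank is an ordinary Python value, so it can be used in a normal if statement.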

Upvotes: 4
