Reputation: 3616
I have a dataset of RGB and grayscale images. While iterating over the dataset, I want to detect whether an image is grayscale so that I can convert it to RGB. I wanted to use tf.shape(image) to detect the dimensions of the image. For an RGB image I get something like [1, 100, 100, 3]. For grayscale images the function returns, for example, [1, 100, 100]. I wanted to use len(tf.shape(image)) to detect whether it is of length 4 (= RGB) or length 3 (= grayscale). That did not work.
This is my code so far, which did not work:
import numpy as np
import tensorflow as tf

def process_image(image):
    # Convert numpy array to tensor
    image = tf.convert_to_tensor(image, dtype=tf.uint8)
    # Take care of grayscale images
    dims = len(tf.shape(image))
    if dims == 3:
        image = np.expand_dims(image, axis=3)
        image = tf.image.grayscale_to_rgb(image)
    return image
Is there an alternative way to convert grayscale images to rgb?
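A minimal sketch of the same check using the static rank instead of len(tf.shape(image)). One likely reason the len() check fails is that tf.data.Dataset.map traces the function in graph mode, where len() of a symbolic tensor is not supported. image.shape.rank, by contrast, is a plain Python int (or None), so it can drive an ordinary if statement. This assumes the rank is known when the function runs or is traced (true in eager mode and when the dataset's element spec has a fixed rank); if the rank is None, this version leaves the image unchanged, and a tf.cond-based check would be needed instead.

```python
import numpy as np
import tensorflow as tf

def process_image(image):
    """Convert a grayscale image without a channel axis to RGB; pass RGB through.

    Assumes the rank of `image` is statically known (not None).
    """
    image = tf.convert_to_tensor(image, dtype=tf.uint8)
    # image.shape.rank is a Python int (or None), usable in a normal `if`,
    # unlike the tensor returned by tf.shape().
    if image.shape.rank == 3:  # e.g. [1, 100, 100]: grayscale, no channel axis
        image = tf.expand_dims(image, axis=-1)    # -> [1, 100, 100, 1]
        image = tf.image.grayscale_to_rgb(image)  # -> [1, 100, 100, 3]
    return image

# Usage inside a tf.data pipeline: each element has static shape [1, 100, 100],
# so the Python `if` is resolved while Dataset.map traces the function.
dataset = tf.data.Dataset.from_tensor_slices(
    np.zeros((5, 1, 100, 100), dtype=np.uint8)).map(process_image)
```

The design choice here is to branch in Python at trace time rather than inside the graph; it only works because the rank does not vary between elements of one dataset.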
Upvotes: 4
Views: 7232
Reputation: 786
I had a very similar problem: I wanted to load RGB and greyscale images in one go. TensorFlow supports setting the number of channels when decoding the images. So if your images have different numbers of channels, this might be what you are looking for:
# to get greyscale:
tf.io.decode_image(raw_img, expand_animations=False, dtype=tf.float32, channels=1)
# to get rgb:
tf.io.decode_image(raw_img, expand_animations=False, dtype=tf.float32, channels=3)
-> You can even do both on the same image and inside tf.data.Dataset mappings!
Set the channels argument to match the shape you need, and all the loaded images will have that shape. You can then reshape without a condition.
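A minimal sketch of this inside a tf.data.Dataset mapping, assuming the images are stored as files. The helper name load_as_rgb and the two small PNG files written to a temporary directory are hypothetical, created only so the example is self-contained. With channels=3, the grayscale file comes out with three channels just like the RGB one.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

def load_as_rgb(path):
    """Hypothetical helper: read an image file and decode it with 3 channels."""
    raw_img = tf.io.read_file(path)
    # channels=3 converts grayscale files to RGB; expand_animations=False
    # keeps the result 3-D ([height, width, 3]) even for GIF files.
    return tf.io.decode_image(raw_img, channels=3, expand_animations=False,
                              dtype=tf.float32)

# Hypothetical test data: one grayscale PNG and one RGB PNG in a temp dir.
tmp_dir = tempfile.mkdtemp()
gray_path = os.path.join(tmp_dir, "gray.png")
rgb_path = os.path.join(tmp_dir, "rgb.png")
tf.io.write_file(gray_path, tf.io.encode_png(np.full((4, 5, 1), 128, np.uint8)))
tf.io.write_file(rgb_path, tf.io.encode_png(np.zeros((4, 5, 3), np.uint8)))

dataset = tf.data.Dataset.from_tensor_slices([gray_path, rgb_path]).map(load_as_rgb)
images = list(dataset)  # both elements have shape (4, 5, 3)
```

Because dtype=tf.float32 is requested, decode_image also rescales the uint8 pixel values to the [0, 1] range.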
This also allows you to load a grayscale image directly as RGB in TensorFlow. Here is an example:
>>> import numpy as np
>>> import tensorflow as tf
>>> from PIL import Image
>>> a = Image.open(r"Path/to/rgb_img.JPG")
>>> np.array(a).shape
(666, 1050, 3)
>>> a = a.convert('L')
>>> np.array(a).shape
(666, 1050)
>>> b = np.array(a)
>>> im = Image.fromarray(b)
>>> im.save(r"Path/to/now_it_is_greyscale.jpg")
>>> raw_img = tf.io.read_file(r"Path/to/now_it_is_greyscale.jpg")
>>> img = tf.io.decode_image(raw_img, dtype=tf.float32, channels=3)
>>> img.shape
TensorShape([666, 1050, 3])
>>> img = tf.io.decode_image(raw_img, dtype=tf.float32, channels=1)
>>> img.shape
TensorShape([666, 1050, 1])
Use expand_animations=False if you get "ValueError: 'images' contains no shape." See: https://stackoverflow.com/a/59944421/9621080
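A minimal sketch of where that error tends to come from, assuming decoding happens inside a traced function (tf.function or a Dataset.map). The function name decode_and_resize and the target size [8, 8] are hypothetical. With the default expand_animations=True, decode_image can return either a 3-D image or a 4-D stack of GIF frames, so its rank is not known at trace time, and tf.image.resize rejects inputs of unknown rank. With expand_animations=False, the static shape is [height, width, channels].

```python
import numpy as np
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec([], tf.string)])
def decode_and_resize(raw_img):
    # expand_animations=False: always a 3-D [height, width, 3] result, so the
    # rank is known while tracing and tf.image.resize accepts it.
    img = tf.io.decode_image(raw_img, channels=3, expand_animations=False)
    # resize returns float32 for the default bilinear method.
    return tf.image.resize(img, [8, 8])

# Hypothetical input: an in-memory grayscale PNG instead of a file on disk.
png_bytes = tf.io.encode_png(np.zeros((4, 5, 1), np.uint8))
resized = decode_and_resize(png_bytes)  # shape (8, 8, 3)
```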
Upvotes: 4
Reputation: 59681
You can use a function like this for that:
import tensorflow as tf

def process_image(image):
    image = tf.convert_to_tensor(image, dtype=tf.uint8)
    # Rank < 4 means no channel axis: add one, then tile it to 3 channels
    image_rgb = tf.cond(tf.rank(image) < 4,
                        lambda: tf.image.grayscale_to_rgb(tf.expand_dims(image, -1)),
                        lambda: tf.identity(image))
    # Add shape information
    s = image.shape
    if s.ndims is not None and s.ndims < 4:
        # Known grayscale rank: the output gains a channel axis of size 3
        image_rgb.set_shape(s.concatenate(3))
    else:
        image_rgb.set_shape(s)
    return image_rgb
Upvotes: 4