beridzeg45

Reputation: 416

Image channel error while training CNN model

I get the following error when trying to train a CNN model:

InvalidArgumentError: Graph execution error:

Detected at node decode_image/DecodeImage defined at (most recent call last):
<stack traces unavailable>
Number of channels inherent in the image must be 1, 3 or 4, was 2
     [[{{node decode_image/DecodeImage}}]]
     [[IteratorGetNext]] [Op:__inference_train_function_1598]

I am working with the Cats and Dogs classification dataset from Kaggle. I defined the data like this:

import tensorflow as tf

path = r'C:\Users\berid\python\cats and dogs\PetImages'
data = tf.keras.utils.image_dataset_from_directory(path)

Any suggestions would be appreciated.

Upvotes: 1

Views: 106

Answers (1)

Fred Myers

Reputation: 23

I had this exact same problem with the same dataset. The copy I downloaded from Kaggle contains some bad images: the files have a .jpg extension, but `img.format` reports BMP or None. Some images also have an unusual number of channels. I used the code below to remove those files. It was only about 150 out of 25,000 images, so not a big deal IMO. After that I was able to fit the model without issue. Here is my code:

import os

import matplotlib.image as mpimg
import tensorflow as tf

# Class subfolders of the Kaggle PetImages folder (assumed layout; adjust to your path)
data_dir_cats = r'C:\Users\berid\python\cats and dogs\PetImages\Cat'
data_dir_dogs = r'C:\Users\berid\python\cats and dogs\PetImages\Dog'


def remove_bad_images(folder, label):
    """Delete files in `folder` that are not JPEG or have an unsupported channel count."""
    print(f'Validating {label} files....')
    for filename in os.listdir(folder):
        image_path = os.path.join(folder, filename)
        try:
            img = tf.keras.utils.load_img(image_path)
        except Exception as e:
            # Empty or non-image files cannot be opened at all, so drop them too
            print('Unreadable.  removing...', e, image_path)
            os.remove(image_path)
            continue
        # img is a PIL image; img.format is 'JPEG', 'BMP', ... or None
        if img.format != 'JPEG':
            print('Not jpeg.  removing...', img.format, image_path)
            os.remove(image_path)
            continue
        pixels = mpimg.imread(image_path)
        # Grayscale images load as 2-D arrays; only 3-D arrays have a channel axis
        if pixels.ndim == 3 and pixels.shape[2] not in (1, 3, 4):
            print(f'Removing...  {pixels.shape=} {image_path}')
            os.remove(image_path)


remove_bad_images(data_dir_cats, 'cat')
remove_bad_images(data_dir_dogs, 'dog')
print('Done Validating....')
print(f"There are {len(os.listdir(data_dir_dogs))} images of dogs.")
print(f"There are {len(os.listdir(data_dir_cats))} images of cats.")

# Re-list the filenames for cats and dogs now that the bad files are gone
cats_filenames = [os.path.join(data_dir_cats, filename) for filename in os.listdir(data_dir_cats)]
dogs_filenames = [os.path.join(data_dir_dogs, filename) for filename in os.listdir(data_dir_dogs)]
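If you would rather list the suspicious files before deleting anything, one rough alternative is to look at each file's first bytes (its "magic number"), which identify the real format regardless of the .jpg extension. This is a stdlib-only sketch, not part of the original answer: it checks only a handful of common signatures, does not check channel counts, and the names `sniff_format` and `find_mislabeled_jpegs` are hypothetical helpers made up for illustration.

```python
import os

# Leading "magic" bytes of a few common formats (hypothetical, non-exhaustive list).
# Anything that matches none of these is reported as None.
SIGNATURES = {
    b"\xff\xd8\xff": "JPEG",
    b"BM": "BMP",
    b"\x89PNG\r\n\x1a\n": "PNG",
    b"GIF87a": "GIF",
    b"GIF89a": "GIF",
}


def sniff_format(path):
    """Guess an image format from the file's first bytes, ignoring its extension."""
    with open(path, "rb") as f:
        header = f.read(8)
    if not header:
        return "EMPTY"  # zero-byte file
    for magic, name in SIGNATURES.items():
        if header.startswith(magic):
            return name
    return None


def find_mislabeled_jpegs(folder):
    """Return (path, detected_format) for .jpg/.jpeg files whose bytes are not JPEG."""
    suspects = []
    for filename in sorted(os.listdir(folder)):
        # Only files that claim to be JPEG by extension are checked
        if not filename.lower().endswith((".jpg", ".jpeg")):
            continue
        path = os.path.join(folder, filename)
        detected = sniff_format(path)
        if detected != "JPEG":
            suspects.append((path, detected))
    return suspects
```

For example, `find_mislabeled_jpegs(data_dir_cats)` returns a list of `(path, detected_format)` pairs that you can inspect before passing them to `os.remove`. Reading only 8 bytes per file keeps the scan cheap compared to decoding all 25,000 images.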

Upvotes: 0
