FF123456
FF123456

Reputation: 83

How to modify my code to handle RGBX (4-channel) images for semantic segmentation?

I'm new to this field and have been following a U-Net tutorial that uses 3-channel RGB images for semantic segmentation (https://www.youtube.com/watch?v=68HR_eyzk00&list=PLZsOBAyNTZwbR08R959iCvYT3qzhxvGOE&index=2&ab_channel=DigitalSreeni), and it worked fine for me. However, I now need to extend the pipeline to support 4-channel RGBX images (i.e., RGB plus one additional channel), but I’m not sure how to modify the code to accommodate the extra channel, especially in the preprocessing and ImageDataGenerator parts (I think that ImageDataGenerator doesn’t support 4-channel images).

This is the code (it runs after the images have been patchified to 256 × 256 × 4 and the masks to 256 × 256):

import os
import cv2
import numpy as np
import glob
from matplotlib import pyplot as plt
import tensorflow as tf
import splitfolders
import segmentation_models as sm
from tensorflow.keras.metrics import MeanIoU
from sklearn.preprocessing import MinMaxScaler
from keras.utils import to_categorical


# Folder holding the images and masks, and the destination of the split.
input_folder = 'path folder to my images and masks '
output_folder = 'path to output folder'

# Split the input folder with a 0.75 / 0.25 ratio and a fixed seed.
splitfolders.ratio(
    input_folder,
    output=output_folder,
    seed=42,
    ratio=(0.75, 0.25),
    group_prefix=None,
)

# After the split, rearrange the folder structure for Keras augmentation.


# Shared settings for the data generators and the model.
seed = 24
batch_size = 16
n_classes = 2


# Min-max scaler applied to the image pixel values.
scaler = MinMaxScaler()


# Preprocessing function that matches the chosen backbone.
BACKBONE = 'resnet34'
preprocess_input = sm.get_preprocessing(BACKBONE)

def preprocess_data(img, mask, num_class):
    """Scale an image batch and one-hot encode its mask batch.

    The image array is flattened to (pixels, channels), fitted and
    transformed with the module-level ``scaler`` (so each channel is scaled
    independently, whatever the channel count), restored to its original
    shape and then passed through the backbone's ``preprocess_input``.

    Args:
        img: Image array whose last axis is the channel axis.
        mask: Array of integer class labels -- assumed to lie in
            ``[0, num_class)``; TODO confirm against the mask files.
        num_class: Number of classes for the one-hot encoding.

    Returns:
        Tuple ``(img, mask)`` with the preprocessed image array and the
        one-hot encoded mask.
    """
    original_shape = img.shape
    n_channels = original_shape[-1]

    # The scaler works on 2-D data: one row per pixel, one column per channel.
    flat_pixels = img.reshape(-1, n_channels)
    scaled = scaler.fit_transform(flat_pixels).reshape(original_shape)

    # Apply the preprocessing expected by the pretrained backbone.
    prepared = preprocess_input(scaled)
    one_hot_mask = to_categorical(mask, num_class)
    return (prepared, one_hot_mask)

from tensorflow.keras.preprocessing.image import ImageDataGenerator
def trainGenerator(train_img_path, train_mask_path, num_class, image_color_mode='rgb'):
    """Yield preprocessed (image_batch, one_hot_mask_batch) pairs.

    Images and masks are read with two ``ImageDataGenerator`` flows that
    share the same augmentation settings, batch size and seed, so each image
    batch is paired with the matching mask batch. Every pair is passed
    through ``preprocess_data`` before being yielded.

    Args:
        train_img_path: Directory passed to ``flow_from_directory`` for the
            images.
        train_mask_path: Directory passed to ``flow_from_directory`` for the
            masks (read as single-channel grayscale).
        num_class: Number of classes used to one-hot encode the masks.
        image_color_mode: ``color_mode`` used when loading the images.
            Defaults to ``'rgb'`` (3 channels); pass ``'rgba'`` to load
            4-channel image files. NOTE(review): this relies on the image
            format being readable as 4-channel by Keras' loader -- confirm
            for the file type in use.

    Yields:
        Tuple ``(img, mask)`` as returned by ``preprocess_data``.
    """
    # Data augmentation: identical settings for images and masks.
    img_data_gen_args = dict(horizontal_flip=True, vertical_flip=True, fill_mode='reflect')

    image_datagen = ImageDataGenerator(**img_data_gen_args)
    mask_datagen = ImageDataGenerator(**img_data_gen_args)

    image_generator = image_datagen.flow_from_directory(
        train_img_path,
        class_mode=None,
        color_mode=image_color_mode,
        batch_size=batch_size,
        seed=seed,
    )
    # Masks use their own generator (previously created but never used);
    # the shared seed keeps them aligned with the image batches.
    mask_generator = mask_datagen.flow_from_directory(
        train_mask_path,
        class_mode=None,
        color_mode='grayscale',
        batch_size=batch_size,
        seed=seed,
    )

    for img, mask in zip(image_generator, mask_generator):
        img, mask = preprocess_data(img, mask, num_class)
        yield (img, mask)

# Number of segmentation classes, shared by the generators and the model.
n_classes = 2

train_img_path = 'path for training images'
train_mask_path = 'path for training masks'
train_img_gen = trainGenerator(train_img_path, train_mask_path, num_class=n_classes)

# NOTE(review): the validation batches also come from trainGenerator, so they
# go through whatever augmentation that generator applies.
val_img_path = 'path for validation images'
val_mask_path = 'path for validation masks'
val_img_gen = trainGenerator(val_img_path, val_mask_path, num_class=n_classes)


# Pull one batch to sanity-check the data and to read off the input shape.
x, y = next(train_img_gen)

for i in range(3):
    # Show only the first three channels: imshow would interpret a 4th
    # channel as alpha. For 3-channel batches this slice is a no-op.
    image = x[i][..., :3]
    # Collapse the one-hot mask back to a label map for display.
    mask = np.argmax(y[i], axis=2)
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.subplot(1, 2, 2)
    plt.imshow(mask, cmap='gray')
    plt.show()


# Steps per epoch are derived from the number of directory entries.
# NOTE(review): os.listdir counts the entries directly inside the given
# folder, so these paths presumably need to point at the folder that holds
# the image files themselves -- confirm against the folder layout.
num_train_imgs = len(os.listdir('path for training images'))
num_val_images = len(os.listdir('path for validation image'))
steps_per_epochs = num_train_imgs // batch_size
val_steps_per_epoch = num_val_images // batch_size

# Input shape is taken from the sample batch: (batch, height, width, channels).
IMG_HEIGHT = x.shape[1]
IMG_WIDTH = x.shape[2]
IMG_CHANNELS = x.shape[3]

# encoder_weights must be the None object (randomly initialised encoder), not
# the string 'None', which would be looked up as a weights name. Per the
# segmentation_models documentation, None is also required when the number of
# input channels is not 3.
model = sm.Unet(
    'resnet34',
    encoder_weights=None,
    input_shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS),
    classes=n_classes,
    activation='softmax',
)
model.compile(
    'Adam',
    loss=sm.losses.binary_crossentropy,
    metrics=[sm.metrics.iou_score, sm.metrics.FScore()],
)

history = model.fit(
    train_img_gen,
    steps_per_epoch=steps_per_epochs,
    epochs=100,
    verbose=1,
    validation_data=val_img_gen,
    validation_steps=val_steps_per_epoch,
)


Upvotes: 0

Views: 74

Answers (1)

Haley Schuhl
Haley Schuhl

Reputation: 21

You could drop the 4th band of data, typically an alpha channel, while reading the image in with OpenCV. By default, cv2.imread returns a 3-channel (BGR) array, like this:

import cv2

# IMREAD_COLOR is imread's default flag: the image is returned as a
# 3-channel BGR array, so a 4th channel in the file is not loaded.
img = cv2.imread(filename, cv2.IMREAD_COLOR)

If the workflow requires an image path instead of a NumPy array, then I might run a pre-processing step that saves the 3-channel copies of the images to a new directory and use that directory as train_img_path.

Upvotes: 0

Related Questions