Jason Hong
Jason Hong

Reputation: 31

I followed the tensorflow image segmentation tutorial, but the predicted mask is blank

I'd like to try image segmentation with my grayscale tif images (the shape of original images are (512,512) and the value of each pixel is between 0-2 or NaN which is in float32 type and the mask images have 0, 1, or NaN also in float32 type). I followed Google Colab and tensorflow tutorial to create the following code:

from glob import glob
from PIL import Image
from tensorflow import keras
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import layers
from tensorflow.python.keras import losses
from tensorflow.python.keras import models
from tensorflow.python.keras import backend as K

#get the path of my data
img = sorted(glob('train_sub_5/*.tif'))
mask = sorted(glob('train_mask_sub_5/*.tif'))

#split into train and test data
img, img_val, mask, mask_val = train_test_split(img, mask, test_size=0.2, random_state=42)

#load image as array and append to a list
train_image = []
for m in img:
    img= Image.open(m)
    img_arr = np.array(img)
    stacked_img = np.stack((img_arr,)*1, axis=-1)
    train_image.append(stacked_img)

train_mask = []
for n in mask:
    mask= Image.open(n)
    mask_arr= np.array(mask)
    stacked_mask = np.stack((mask_arr,)*1, axis=-1)
    train_mask.append(stacked_mask)

test_img = []
for o in img_val:
    img= Image.open(o)
    img_arr = np.array(img)
    stacked_img = np.stack((img_arr,)*1, axis=-1)
    test_img.append(stacked_img)

test_mask = []
for p in mask_val:
    mask= Image.open(p)
    mask_arr = np.array(mask)
    stacked_mask = np.stack((mask_arr,)*1, axis=-1)
    test_mask.append(stacked_mask)

#create TensorSliceDataset
for i, j in zip(train_image, train_mask):
    train= tf.data.Dataset.from_tensor_slices(([i], [j])) 

for k, l in zip(test_img, test_mask):
    test= tf.data.Dataset.from_tensor_slices(([k], [l]))

#for visualization
def display(display_list):
    plt.figure(figsize=(15, 15))
    title = ['Input Image', 'True Mask', 'Predicted Mask']
    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i+1)
        plt.title(title[i])
        plt.imshow(display_list[i])
        plt.axis('off')
    plt.show()

for img, mask in train.take(1):
    sample_image = img.numpy()[:,:,0]
    sample_mask = mask.numpy()[:,:,0]

display([sample_image, sample_mask])

The output of the visualization looks normal like below: out put of the visualization

#build the model
train_length = len(train_image)
img_shape = (512,512,1)
batch_size = 8
buffer_size = 5
epochs = 5

train_dataset = train.cache().shuffle(train_length).batch(batch_size).repeat()
train_dataset = train_dataset.prefetch(buffer_size)
test_dataset = test.batch(batch_size).repeat()

def conv_block(input_tensor, num_filters):
    encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(input_tensor)
    encoder = layers.BatchNormalization()(encoder)
    encoder = layers.Activation('relu')(encoder)
    encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(encoder)
    encoder = layers.BatchNormalization()(encoder)
    encoder = layers.Activation('relu')(encoder)
    return encoder

def encoder_block(input_tensor, num_filters):
    encoder = conv_block(input_tensor, num_filters)
    encoder_pool = layers.MaxPooling2D((2, 2), strides=(2, 2))(encoder)

    return encoder_pool, encoder

def decoder_block(input_tensor, concat_tensor, num_filters):
    decoder = layers.Conv2DTranspose(num_filters, (2, 2), strides=(2, 2), padding='same')(input_tensor)
    decoder = layers.concatenate([concat_tensor, decoder], axis=-1)
    decoder = layers.BatchNormalization()(decoder)
    decoder = layers.Activation('relu')(decoder)
    decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
    decoder = layers.BatchNormalization()(decoder)
    decoder = layers.Activation('relu')(decoder)
    decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
    decoder = layers.BatchNormalization()(decoder)
    decoder = layers.Activation('relu')(decoder)
    return decoder

inputs = layers.Input(shape=img_shape)
# 256

encoder0_pool, encoder0 = encoder_block(inputs, 32)
# 128

encoder1_pool, encoder1 = encoder_block(encoder0_pool, 64)
# 64

encoder2_pool, encoder2 = encoder_block(encoder1_pool, 128)
# 32

encoder3_pool, encoder3 = encoder_block(encoder2_pool, 256)
# 16

encoder4_pool, encoder4 = encoder_block(encoder3_pool, 512)
# 8

center = conv_block(encoder4_pool, 1024)
# center

decoder4 = decoder_block(center, encoder4, 512)
# 16

decoder3 = decoder_block(decoder4, encoder3, 256)
# 32

decoder2 = decoder_block(decoder3, encoder2, 128)
# 64

decoder1 = decoder_block(decoder2, encoder1, 64)
# 128

decoder0 = decoder_block(decoder1, encoder0, 32)
# 256

outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(decoder0)

model = models.Model(inputs=[inputs], outputs=[outputs])

def dice_coeff(y_true, y_pred):
    smooth = 1.
    # Flatten
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
    return score

def dice_loss(y_true, y_pred):
    loss = 1 - dice_coeff(y_true, y_pred)
    return loss

def bce_dice_loss(y_true, y_pred):
    loss = losses.binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
    return loss

model.compile(optimizer='adam', loss=bce_dice_loss, metrics=[dice_loss])

model.summary()

#save model
save_model_path = 'tmp/weights.hdf5'
cp = tf.keras.callbacks.ModelCheckpoint(filepath=save_model_path, monitor='val_dice_loss', mode='max', save_best_only=True)

#start training
history = model.fit(train_dataset, 
                   steps_per_epoch=int(np.ceil(train_length / float(batch_size))),
                   epochs=epochs,
                   validation_data=test_dataset,
                   validation_steps=int(np.ceil(len(test_img) / float(batch_size))),
                   callbacks=[cp])

#training process visualization
dice = history.history['dice_loss']
val_dice = history.history['val_dice_loss']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(16, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, dice, label='Training Dice Loss')
plt.plot(epochs_range, val_dice, label='Validation Dice Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Dice Loss')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()

The output of the training process visualization looks like below: The output of the training process visualization The model seems functioning.

#make prediction
def show_predictions(dataset=None, num=1):
    for image, mask in dataset.take(num):
        pred_mask = model.predict(image)
        display([image[0,:,:,0], mask[0,:,:,0], create_mask(pred_mask)])

def create_mask(pred_mask):
    pred_mask = tf.argmax(pred_mask, axis=-1)
    pred_mask = pred_mask[..., tf.newaxis]
    return pred_mask[0,:,:,0]

show_predictions(test_dataset, 3)

The output of the prediction is below: The output of predictions

I tried to inspect the variables test and test_dataset using:

for img, mask in test:
    print(img,mask)

But I only got one image array and one mask array. Does it mean that there's only one image array and one mask array in the dataset? What's wrong with my code creating train and test TensorSliceDataset?

The Second question is why I got the predicted mask blank? Is it because some of my patches have nan? As you can see in output, the white part of the input image and the true mask, the sea is represented by NaN. If this is the problem, how do I set the value for NaN if I hope the model can ignore sea?

Thank you for your help.

Upvotes: 1

Views: 2129

Answers (1)

abir hasan
abir hasan

Reputation: 1

def display(display_list):
    fig = plt.figure(figsize=(15, 15))
    title = ['Input Image', 'True Mask', 'Predicted Mask']

    for i in range(len(display_list)):
        plt.subplot(1, len(display_list), i + 1)
        plt.title(title[i])
        plt.imshow(tf.keras.preprocessing.image.array_to_img
        (display_list[i]))
        plt.axis('off')
    plt.show()

def show_predictions(dataset=None, num=1):
    for image, mask in dataset.take(num):
        pred_mask = model.predict(image)
        pred_mask *= 255.0
        print(pred_mask.min())
        print(pred_mask.max())
        print(np.unique(pred_mask, return_counts=True))
        display([image[0], mask[0], pred_mask[0]])

show_predictions(test_dataset, 3)

Upvotes: 0

Related Questions