Reputation: 31
I'd like to try image segmentation with my grayscale tif images (the shape of original images are (512,512) and the value of each pixel is between 0-2 or NaN which is in float32 type and the mask images have 0, 1, or NaN also in float32 type). I followed Google Colab and tensorflow tutorial to create the following code:
from glob import glob
from PIL import Image
from tensorflow import keras
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import layers
from tensorflow.python.keras import losses
from tensorflow.python.keras import models
from tensorflow.python.keras import backend as K
#get the path of my data
img = sorted(glob('train_sub_5/*.tif'))
mask = sorted(glob('train_mask_sub_5/*.tif'))
#split into train and test data
img, img_val, mask, mask_val = train_test_split(img, mask, test_size=0.2, random_state=42)
#load image as array and append to a list
train_image = []
for m in img:
img= Image.open(m)
img_arr = np.array(img)
stacked_img = np.stack((img_arr,)*1, axis=-1)
train_image.append(stacked_img)
train_mask = []
for n in mask:
mask= Image.open(n)
mask_arr= np.array(mask)
stacked_mask = np.stack((mask_arr,)*1, axis=-1)
train_mask.append(stacked_mask)
test_img = []
for o in img_val:
img= Image.open(o)
img_arr = np.array(img)
stacked_img = np.stack((img_arr,)*1, axis=-1)
test_img.append(stacked_img)
test_mask = []
for p in mask_val:
mask= Image.open(p)
mask_arr = np.array(mask)
stacked_mask = np.stack((mask_arr,)*1, axis=-1)
test_mask.append(stacked_mask)
#create TensorSliceDataset
for i, j in zip(train_image, train_mask):
train= tf.data.Dataset.from_tensor_slices(([i], [j]))
for k, l in zip(test_img, test_mask):
test= tf.data.Dataset.from_tensor_slices(([k], [l]))
#for visualization
def display(display_list):
plt.figure(figsize=(15, 15))
title = ['Input Image', 'True Mask', 'Predicted Mask']
for i in range(len(display_list)):
plt.subplot(1, len(display_list), i+1)
plt.title(title[i])
plt.imshow(display_list[i])
plt.axis('off')
plt.show()
for img, mask in train.take(1):
sample_image = img.numpy()[:,:,0]
sample_mask = mask.numpy()[:,:,0]
display([sample_image, sample_mask])
The output of the visualization looks normal like below: out put of the visualization
#build the model
train_length = len(train_image)
img_shape = (512,512,1)
batch_size = 8
buffer_size = 5
epochs = 5
train_dataset = train.cache().shuffle(train_length).batch(batch_size).repeat()
train_dataset = train_dataset.prefetch(buffer_size)
test_dataset = test.batch(batch_size).repeat()
def conv_block(input_tensor, num_filters):
encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(input_tensor)
encoder = layers.BatchNormalization()(encoder)
encoder = layers.Activation('relu')(encoder)
encoder = layers.Conv2D(num_filters, (3, 3), padding='same')(encoder)
encoder = layers.BatchNormalization()(encoder)
encoder = layers.Activation('relu')(encoder)
return encoder
def encoder_block(input_tensor, num_filters):
encoder = conv_block(input_tensor, num_filters)
encoder_pool = layers.MaxPooling2D((2, 2), strides=(2, 2))(encoder)
return encoder_pool, encoder
def decoder_block(input_tensor, concat_tensor, num_filters):
decoder = layers.Conv2DTranspose(num_filters, (2, 2), strides=(2, 2), padding='same')(input_tensor)
decoder = layers.concatenate([concat_tensor, decoder], axis=-1)
decoder = layers.BatchNormalization()(decoder)
decoder = layers.Activation('relu')(decoder)
decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
decoder = layers.BatchNormalization()(decoder)
decoder = layers.Activation('relu')(decoder)
decoder = layers.Conv2D(num_filters, (3, 3), padding='same')(decoder)
decoder = layers.BatchNormalization()(decoder)
decoder = layers.Activation('relu')(decoder)
return decoder
inputs = layers.Input(shape=img_shape)
# 256
encoder0_pool, encoder0 = encoder_block(inputs, 32)
# 128
encoder1_pool, encoder1 = encoder_block(encoder0_pool, 64)
# 64
encoder2_pool, encoder2 = encoder_block(encoder1_pool, 128)
# 32
encoder3_pool, encoder3 = encoder_block(encoder2_pool, 256)
# 16
encoder4_pool, encoder4 = encoder_block(encoder3_pool, 512)
# 8
center = conv_block(encoder4_pool, 1024)
# center
decoder4 = decoder_block(center, encoder4, 512)
# 16
decoder3 = decoder_block(decoder4, encoder3, 256)
# 32
decoder2 = decoder_block(decoder3, encoder2, 128)
# 64
decoder1 = decoder_block(decoder2, encoder1, 64)
# 128
decoder0 = decoder_block(decoder1, encoder0, 32)
# 256
outputs = layers.Conv2D(1, (1, 1), activation='sigmoid')(decoder0)
model = models.Model(inputs=[inputs], outputs=[outputs])
def dice_coeff(y_true, y_pred):
smooth = 1.
# Flatten
y_true_f = tf.reshape(y_true, [-1])
y_pred_f = tf.reshape(y_pred, [-1])
intersection = tf.reduce_sum(y_true_f * y_pred_f)
score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
return score
def dice_loss(y_true, y_pred):
loss = 1 - dice_coeff(y_true, y_pred)
return loss
def bce_dice_loss(y_true, y_pred):
loss = losses.binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
return loss
model.compile(optimizer='adam', loss=bce_dice_loss, metrics=[dice_loss])
model.summary()
#save model
save_model_path = 'tmp/weights.hdf5'
cp = tf.keras.callbacks.ModelCheckpoint(filepath=save_model_path, monitor='val_dice_loss', mode='max', save_best_only=True)
#start training
history = model.fit(train_dataset,
steps_per_epoch=int(np.ceil(train_length / float(batch_size))),
epochs=epochs,
validation_data=test_dataset,
validation_steps=int(np.ceil(len(test_img) / float(batch_size))),
callbacks=[cp])
#training process visualization
dice = history.history['dice_loss']
val_dice = history.history['val_dice_loss']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)
plt.figure(figsize=(16, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, dice, label='Training Dice Loss')
plt.plot(epochs_range, val_dice, label='Validation Dice Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Dice Loss')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
The output of the training process visualization looks like below: The output of the training process visualization The model seems functioning.
#make prediction
def show_predictions(dataset=None, num=1):
for image, mask in dataset.take(num):
pred_mask = model.predict(image)
display([image[0,:,:,0], mask[0,:,:,0], create_mask(pred_mask)])
def create_mask(pred_mask):
pred_mask = tf.argmax(pred_mask, axis=-1)
pred_mask = pred_mask[..., tf.newaxis]
return pred_mask[0,:,:,0]
show_predictions(test_dataset, 3)
The output of the prediction is below: The output of predictions
I tried to inspect the variables test and test_dataset using:
for img, mask in test:
print(img,mask)
But I only got one image array and one mask array. Does it mean that there's only one image array and one mask array in the dataset? What's wrong with my code creating train and test TensorSliceDataset?
The Second question is why I got the predicted mask blank? Is it because some of my patches have nan? As you can see in output, the white part of the input image and the true mask, the sea is represented by NaN. If this is the problem, how do I set the value for NaN if I hope the model can ignore sea?
Thank you for your help.
Upvotes: 1
Views: 2129
Reputation: 1
def display(display_list):
fig = plt.figure(figsize=(15, 15))
title = ['Input Image', 'True Mask', 'Predicted Mask']
for i in range(len(display_list)):
plt.subplot(1, len(display_list), i + 1)
plt.title(title[i])
plt.imshow(tf.keras.preprocessing.image.array_to_img
(display_list[i]))
plt.axis('off')
plt.show()
def show_predictions(dataset=None, num=1):
for image, mask in dataset.take(num):
pred_mask = model.predict(image)
pred_mask *= 255.0
print(pred_mask.min())
print(pred_mask.max())
print(np.unique(pred_mask, return_counts=True))
display([image[0], mask[0], pred_mask[0]])
show_predictions(test_dataset, 3)
Upvotes: 0