Reputation: 149
I have implemented a version of U-Net in TensorFlow to identify buildings in satellite images. The implementation works and gives promising classification results. All the metrics seem to be working correctly except mean_iou. Regardless of the hyperparameters and the images chosen from the dataset, mean_iou is always the same: the value is identical to about 15 decimal places after each epoch.
The precision and recall values are considerably higher than mean_iou, and the gap is larger than should be expected, so it seems that something is not working as intended.
I am relatively new to TensorFlow, so the error might be something totally different, but I am here to learn. All feedback will be greatly appreciated.
Here is the relevant code and the printout from training the model.
import numpy as np
import tensorflow as tf
from unet_model import build_unet
from data import load_dataset, tf_dataset
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, CSVLogger, EarlyStopping
model_types = ['segnet-master', 'unet-master', 'simpler', 'even-simpler']
if __name__ == "__main__":
    """ Hyperparameters """
    dataset_path = "building-segmentation"
    input_shape = (64, 64, 3)
    batch_size = 20
    model = 3
    epochs = 5
    res = 64
    lr = 1e-3
    model_path = f"unet_models/unet_{epochs}_epochs_{res}.h5"
    csv_path = f"csv/data_unet_{epochs}_{res}.csv"

    """ Load the dataset """
    (train_images, train_masks), (val_images, val_masks) = load_dataset(dataset_path)
    train_dataset = tf_dataset(train_images, train_masks, batch=batch_size)
    val_dataset = tf_dataset(val_images, val_masks, batch=batch_size)

    model = build_unet(input_shape)
    model.compile(
        loss="binary_crossentropy",
        optimizer=tf.keras.optimizers.Adam(lr),
        metrics=[
            tf.keras.metrics.MeanIoU(num_classes=2),
            tf.keras.metrics.IoU(num_classes=2, target_class_ids=[0]),
            tf.keras.metrics.Recall(),
            tf.keras.metrics.Precision()
        ]
    )

    callbacks = [
        ModelCheckpoint(model_path, monitor="val_loss", verbose=1),
        ReduceLROnPlateau(monitor="val_loss", patience=10, factor=0.1, verbose=1),
        CSVLogger(csv_path),
        EarlyStopping(monitor="val_loss", patience=10)
    ]

    # Round the step counts up so the last partial batch is included.
    train_steps = len(train_images)//batch_size
    if len(train_images) % batch_size != 0:
        train_steps += 1
    test_steps = len(val_images)//batch_size
    if len(val_images) % batch_size != 0:
        test_steps += 1

    model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        steps_per_epoch=train_steps,
        validation_steps=test_steps,
        callbacks=callbacks
    )
epoch | loss | lr | mean_io_u | precision | recall | val_loss | val_mean_io_u | val_precision | val_recall |
---|---|---|---|---|---|---|---|---|---|
0 | 0.41137945652008057 | 0.001 | 0.37184661626815796 | 0.695444643497467 | 0.5243006944656372 | 0.87176513671875 | 0.37157535552978516 | 0.38247567415237427 | 0.9118495583534241 |
1 | 0.3461640477180481 | 0.001 | 0.37182655930519104 | 0.7579150795936584 | 0.6075601577758789 | 0.3907579183578491 | 0.37157535552978516 | 0.8406943082809448 | 0.5024654865264893 |
2 | 0.3203786611557007 | 0.001 | 0.37182655930519104 | 0.7694798707962036 | 0.6599727272987366 | 0.3412915766239166 | 0.37157535552978516 | 0.6986522674560547 | 0.7543279528617859 |
3 | 0.2999393939971924 | 0.001 | 0.37182655930519104 | 0.7859976887702942 | 0.6890525221824646 | 0.40518054366111755 | 0.37157535552978516 | 0.6738141775131226 | 0.6654454469680786 |
4 | 0.28737708926200867 | 0.001 | 0.37182655930519104 | 0.793653130531311 | 0.7092126607894897 | 0.37544798851013184 | 0.37157535552978516 | 0.621263325214386 | 0.768422544002533 |
5 | 0.27629318833351135 | 0.001 | 0.37182655930519104 | 0.8028419613838196 | 0.72260981798172 | 0.4055494964122772 | 0.37157535552978516 | 0.8477562665939331 | 0.5473824143409729 |
6 | 0.2665417492389679 | 0.001 | 0.37182655930519104 | 0.809609055519104 | 0.7353982329368591 | 0.33294594287872314 | 0.37157535552978516 | 0.7307689785957336 | 0.6933897733688354 |
7 | 0.25887876749038696 | 0.001 | 0.37182655930519104 | 0.8132126927375793 | 0.744954526424408 | 0.28797024488449097 | 0.37157535552978516 | 0.7534120082855225 | 0.7735632061958313 |
8 | 0.25271594524383545 | 0.001 | 0.37182655930519104 | 0.8179733753204346 | 0.7538670897483826 | 0.30249008536338806 | 0.37157535552978516 | 0.8644329905509949 | 0.6237345337867737 |
9 | 0.24556593596935272 | 0.001 | 0.37182655930519104 | 0.8207928538322449 | 0.7622584104537964 | 0.3576349914073944 | 0.37157535552978516 | 0.6576451063156128 | 0.8346141576766968 |
10 | 0.23954670131206512 | 0.001 | 0.37182655930519104 | 0.8256030082702637 | 0.769091010093689 | 0.2541409134864807 | 0.37157535552978516 | 0.8100516200065613 | 0.7633218765258789 |
11 | 0.2349284589290619 | 0.001 | 0.37182655930519104 | 0.8274455070495605 | 0.7762861847877502 | 0.24383187294006348 | 0.37157535552978516 | 0.795067310333252 | 0.8124401569366455 |
12 | 0.22480393946170807 | 0.001 | 0.37182655930519104 | 0.8354562520980835 | 0.787416398525238 | 0.3778316378593445 | 0.37157535552978516 | 0.6533672213554382 | 0.8588836789131165 |
13 | 0.22573505342006683 | 0.001 | 0.37182655930519104 | 0.8342418670654297 | 0.7852107882499695 | 0.3342073857784271 | 0.37157535552978516 | 0.6768029928207397 | 0.7917631268501282 |
14 | 0.21639415621757507 | 0.001 | 0.37182655930519104 | 0.8411555886268616 | 0.7972605228424072 | 0.2792396545410156 | 0.37157535552978516 | 0.7611830234527588 | 0.7955203652381897 |
15 | 0.21154287457466125 | 0.001 | 0.37182655930519104 | 0.8441442251205444 | 0.8019176125526428 | 0.27426305413246155 | 0.37157535552978516 | 0.8764772415161133 | 0.6708933115005493 |
16 | 0.20740143954753876 | 0.001 | 0.37182655930519104 | 0.8469985127449036 | 0.8068550825119019 | 0.367437481880188 | 0.37157535552978516 | 0.646026611328125 | 0.8527452945709229 |
17 | 0.2005360722541809 | 0.001 | 0.37182655930519104 | 0.8522992134094238 | 0.8129924535751343 | 0.22591133415699005 | 0.37157535552978516 | 0.8203750252723694 | 0.8089460730552673 |
18 | 0.1976771354675293 | 0.001 | 0.37182655930519104 | 0.853760302066803 | 0.8163849115371704 | 0.2331937551498413 | 0.37157535552978516 | 0.807687520980835 | 0.8157453536987305 |
19 | 0.19583451747894287 | 0.001 | 0.37182655930519104 | 0.8560215830802917 | 0.8190248012542725 | 0.2519392669200897 | 0.37157535552978516 | 0.7935053110122681 | 0.8000433444976807 |
20 | 0.1872621327638626 | 0.001 | 0.37182655930519104 | 0.8615736365318298 | 0.8263705372810364 | 0.22855037450790405 | 0.37157535552978516 | 0.7948822975158691 | 0.8500961065292358 |
21 | 0.1852150857448578 | 0.001 | 0.37182655930519104 | 0.8620718717575073 | 0.8289932012557983 | 0.2352440059185028 | 0.37157535552978516 | 0.7972174286842346 | 0.8323403000831604 |
22 | 0.17845036089420319 | 0.001 | 0.37182655930519104 | 0.8677510023117065 | 0.8351714611053467 | 0.21090157330036163 | 0.37157535552978516 | 0.8470866084098816 | 0.8098670244216919 |
23 | 0.1732502579689026 | 0.001 | 0.37182655930519104 | 0.8711428046226501 | 0.8414102792739868 | 0.32612740993499756 | 0.37157535552978516 | 0.8412857055664062 | 0.695543646812439 |
24 | 0.17396509647369385 | 0.001 | 0.37182655930519104 | 0.8704758882522583 | 0.840953528881073 | 0.2149643898010254 | 0.37157535552978516 | 0.8315027952194214 | 0.8180400729179382 |
25 | 0.1740695685148239 | 0.001 | 0.37182655930519104 | 0.8702647089958191 | 0.8410759568214417 | 0.2138184905052185 | 0.37157535552978516 | 0.8604387044906616 | 0.7878146171569824 |
26 | 0.16104143857955933 | 0.001 | 0.37182655930519104 | 0.8794053196907043 | 0.8530260324478149 | 0.23256370425224304 | 0.37157535552978516 | 0.8179659843444824 | 0.8145195841789246 |
27 | 0.15866029262542725 | 0.001 | 0.37182655930519104 | 0.8813797831535339 | 0.8556373119354248 | 0.21111807227134705 | 0.37157535552978516 | 0.8566364049911499 | 0.805817723274231 |
28 | 0.15867507457733154 | 0.001 | 0.37182655930519104 | 0.8811318874359131 | 0.8551875352859497 | 0.2091868668794632 | 0.37157535552978516 | 0.8498891592025757 | 0.8088852763175964 |
29 | 0.15372247993946075 | 0.001 | 0.37182655930519104 | 0.884833574295044 | 0.8602938055992126 | 0.2100905030965805 | 0.37157535552978516 | 0.8543928265571594 | 0.8121073246002197 |
30 | 0.1550114005804062 | 0.001 | 0.37182655930519104 | 0.8840479850769043 | 0.85946124792099 | 0.21207265555858612 | 0.37157535552978516 | 0.8512551784515381 | 0.814805269241333 |
31 | 0.14192143082618713 | 0.001 | 0.37182655930519104 | 0.8927850127220154 | 0.8717316389083862 | 0.21726688742637634 | 0.37157535552978516 | 0.8147332072257996 | 0.8602878451347351 |
32 | 0.1401694267988205 | 0.001 | 0.37182655930519104 | 0.8940809965133667 | 0.8732201457023621 | 0.21714988350868225 | 0.37157535552978516 | 0.8370103240013123 | 0.8307888507843018 |
33 | 0.13880570232868195 | 0.001 | 0.37182655930519104 | 0.8950505256652832 | 0.8743049502372742 | 0.23316830396652222 | 0.37157535552978516 | 0.8291308283805847 | 0.8264546990394592 |
34 | 0.14308543503284454 | 0.001 | 0.37182655930519104 | 0.892676830291748 | 0.8704872131347656 | 0.2735193967819214 | 0.37157535552978516 | 0.7545790076255798 | 0.8698106408119202 |
35 | 0.14015090465545654 | 0.001 | 0.37182655930519104 | 0.8939213752746582 | 0.8743175864219666 | 0.20235474407672882 | 0.37157535552978516 | 0.8535885810852051 | 0.8286886215209961 |
36 | 0.1288939267396927 | 0.001 | 0.37182655930519104 | 0.9015076756477356 | 0.8844809532165527 | 0.22387968003749847 | 0.37157535552978516 | 0.8760555982589722 | 0.7937673926353455 |
37 | 0.12568938732147217 | 0.001 | 0.37182655930519104 | 0.9041174054145813 | 0.8872519731521606 | 0.21494744718074799 | 0.37157535552978516 | 0.8468613028526306 | 0.8249993324279785 |
38 | 0.12176792323589325 | 0.001 | 0.37182655930519104 | 0.9065613746643066 | 0.8911336064338684 | 0.23827765882015228 | 0.37157535552978516 | 0.8391880989074707 | 0.8176671862602234 |
39 | 0.11993639171123505 | 0.001 | 0.37182655930519104 | 0.9084023237228394 | 0.8925207257270813 | 0.22297391295433044 | 0.37157535552978516 | 0.8404833674430847 | 0.8346469402313232 |
40 | 0.11878598481416702 | 0.001 | 0.37182655930519104 | 0.9090615510940552 | 0.8941413164138794 | 0.22415445744991302 | 0.37157535552978516 | 0.8580552339553833 | 0.8152300715446472 |
41 | 0.1256236732006073 | 0.001 | 0.37182655930519104 | 0.9046309590339661 | 0.8880045413970947 | 0.20100584626197815 | 0.37157535552978516 | 0.8520526885986328 | 0.8423823714256287 |
42 | 0.10843898355960846 | 0.001 | 0.37182655930519104 | 0.9163806438446045 | 0.903978168964386 | 0.21887923777103424 | 0.37157535552978516 | 0.86836838722229 | 0.8237167596817017 |
43 | 0.10670299828052521 | 0.001 | 0.37182655930519104 | 0.9178842902183533 | 0.9054436683654785 | 0.21005834639072418 | 0.37157535552978516 | 0.8679876327514648 | 0.8253417611122131 |
44 | 0.10276217758655548 | 0.001 | 0.37182655930519104 | 0.9207708239555359 | 0.909300684928894 | 0.2151617556810379 | 0.37157535552978516 | 0.8735089302062988 | 0.8225894570350647 |
45 | 0.10141195356845856 | 0.001 | 0.3718271255493164 | 0.9218501448631287 | 0.9108821749687195 | 0.22106514871120453 | 0.37157535552978516 | 0.8555923700332642 | 0.8328163623809814 |
46 | 0.09918847680091858 | 0.001 | 0.37182655930519104 | 0.9235833883285522 | 0.9129346609115601 | 0.23230132460594177 | 0.37157535552978516 | 0.8555824756622314 | 0.8224022388458252 |
47 | 0.10588783025741577 | 0.001 | 0.37182655930519104 | 0.9191931486129761 | 0.9068878293037415 | 0.22423967719078064 | 0.37157535552978516 | 0.8427634239196777 | 0.825032114982605 |
48 | 0.103585384786129 | 0.001 | 0.37182655930519104 | 0.9209527969360352 | 0.9087461233139038 | 0.2110774666070938 | 0.37157535552978516 | 0.8639764785766602 | 0.8252225518226624 |
49 | 0.09157560020685196 | 0.001 | 0.37182655930519104 | 0.9292182922363281 | 0.9203035831451416 | 0.22161123156547546 | 0.37157535552978516 | 0.8649827837944031 | 0.8406093120574951 |
50 | 0.08616402745246887 | 0.001 | 0.37182655930519104 | 0.9334553480148315 | 0.9252204298973083 | 0.2387685328722 | 0.37157535552978516 | 0.8806527256965637 | 0.811405599117279 |
51 | 0.0846954956650734 | 0.001 | 0.37182655930519104 | 0.9345796704292297 | 0.9265674352645874 | 0.22581790387630463 | 0.37157535552978516 | 0.8756505846977234 | 0.8313769698143005 |
Upvotes: 2
Views: 1146
Reputation: 58
For binary problems there is another IoU metric, tf.keras.metrics.BinaryIoU(name='IoU'). This might solve the issue.
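As a minimal sketch of how that metric behaves, assuming a TensorFlow 2.x release recent enough to ship tf.keras.metrics.BinaryIoU (the question already uses tf.keras.metrics.IoU, which I believe arrived alongside it). The toy labels and probabilities are made up for illustration, and the 0.5 threshold is just the usual default choice, not something taken from the question.

```python
import tensorflow as tf

# BinaryIoU turns sigmoid probabilities into 0/1 labels with `threshold`
# before building the confusion matrix, then averages IoU over the
# classes listed in `target_class_ids`.
metric = tf.keras.metrics.BinaryIoU(
    target_class_ids=[0, 1],  # average over background (0) and building (1)
    threshold=0.5,
    name="mean_iou",
)

# Toy data: ground-truth labels and predicted probabilities.
metric.update_state([0, 1, 0, 1], [0.1, 0.9, 0.6, 0.4])
binary_iou_value = float(metric.result())
print(binary_iou_value)
```

In the question's script the same metric object would go into the `metrics=[...]` list of `model.compile` in place of `tf.keras.metrics.MeanIoU(num_classes=2)`.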
I had the same issue with a multi-class segmentation, which was resolved by moving from tf.keras.metrics.MeanIoU to tf.keras.metrics.OneHotMeanIoU, as I am using one-hot encoded labels.
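A likely reason the plain MeanIoU stays frozen in the binary case: as far as I understand, it expects integer class labels, so sigmoid probabilities in [0, 1) end up truncated to 0 and every pixel counts as background. The pure-Python sketch below imitates that under this assumption. The helper names (`mean_iou`, `truncate`, `threshold`) and the toy data are hypothetical and not part of TensorFlow.

```python
def mean_iou(y_true, y_pred_labels, num_classes=2):
    """Mean IoU over classes, computed from integer labels via a confusion matrix."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred_labels):
        cm[t][p] += 1  # rows: true class, columns: predicted class
    ious = []
    for c in range(num_classes):
        tp = cm[c][c]
        fp = sum(cm[r][c] for r in range(num_classes)) - tp
        fn = sum(cm[c]) - tp
        denom = tp + fp + fn
        if denom > 0:  # skip classes absent from both labels and predictions
            ious.append(tp / denom)
    return sum(ious) / len(ious) if ious else 0.0


def truncate(probs):
    """Assumed MeanIoU-like behaviour: float predictions cast straight to int."""
    return [int(p) for p in probs]


def threshold(probs, t=0.5):
    """BinaryIoU-like behaviour: probabilities at or above t become class 1."""
    return [int(p >= t) for p in probs]


y_true = [0, 0, 0, 1]                    # 75% background, 25% building
probs_early = [0.40, 0.55, 0.30, 0.45]   # poorly trained model
probs_late = [0.05, 0.10, 0.02, 0.97]    # well trained model

# Truncation gives the same score no matter how good the model is.
print(mean_iou(y_true, truncate(probs_early)), mean_iou(y_true, truncate(probs_late)))
# Thresholding lets the score follow the model's quality.
print(mean_iou(y_true, threshold(probs_early)), mean_iou(y_true, threshold(probs_late)))
```

With truncation the result is always half the background fraction (the building class scores 0). If the assumption holds, the constant 0.3718 in the log would correspond to a dataset with roughly 74% background pixels.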
Upvotes: 3