Reputation: 4341
I'm trying to implement UNet for image segmentation in TensorFlow 2 using the Keras API, but I am not sure how to implement the Concatenate layer. Here is what I have tried:
def create_model_myunet(depth, start_f, output_channels, encoder_kernel_size):
# Encoder
model = tf.keras.Sequential()
for i in range(0, depth):
if i == 0:
print("Specifying an input shape")
input_shape = [config.img_h, config.img_w, 3]
else:
input_shape = [None]
model.add(tf.keras.layers.Conv2D(filters=2**(start_f+i),
kernel_size=(encoder_kernel_size, encoder_kernel_size),
strides=(1, 1),
padding='same',
input_shape=input_shape,
name = "enc_conv2d_" + str(i)))
model.add(tf.keras.layers.ReLU(name = "enc_relu_" + str(i)))
model.add(tf.keras.layers.MaxPool2D(pool_size=(2, 2), name="enc_maxpool2d_" + str(i)))
# Decoder
initializer = tf.random_normal_initializer(0., 0.02)
for i in range(depth, 1, -1):
model.add(
tf.keras.layers.Conv2DTranspose(2**(start_f+i),
encoder_kernel_size,
strides=2,
padding='same',
kernel_initializer=initializer,
use_bias=False)
)
model.add(tf.keras.layers.BatchNormalization())
model.add(tf.keras.layers.ReLU(name="dec_relu_"+str(i)))
model.add(tf.keras.layers.Concatenate([
model.get_layer(name="dec_relu_"+str(i)).output,
model.get_layer(name="enc_relu_"+str(i-1)).output
] ))
pass
last = tf.keras.layers.Conv2DTranspose(
output_channels, 3, strides=2,
padding='same', activation='softmax') #64x64 -> 128x128
model.add(last)
return model
It gives me the following error:
ValueError: A
Concatenate
layer should be called on a list of at least 2 inputs
Upvotes: 0
Views: 1495
Reputation: 1685
You need to change
model.add(tf.keras.layers.Concatenate([
model.get_layer(name="dec_relu_"+str(i)).output,
model.get_layer(name="enc_relu_"+str(i-1)).output
] ))
to
model.add(tf.keras.layers.Concatenate()([ # Sequential api
model.get_layer(name="dec_relu_"+str(i)).output,
model.get_layer(name="enc_relu_"+str(i-1)).output
] ))
or
model.add(tf.keras.layers.concatenate([ # Functional api
model.get_layer(name="dec_relu_"+str(i)).output,
model.get_layer(name="enc_relu_"+str(i-1)).output
] ))
Upvotes: 3