robertspierre
robertspierre

Reputation: 4341

Implement UNet in Tensorflow

I'm trying to implement UNet for image segmentation in TensorFlow 2 using the Keras API, but I am not sure how to implement the Concatenate layer. Here is what I have tried:

def create_model_myunet(depth, start_f, output_channels, encoder_kernel_size):

    # Encoder
    model = tf.keras.Sequential()

    for i in range(0, depth):

        if i == 0:
            print("Specifying an input shape")
            input_shape = [config.img_h, config.img_w, 3]
        else:
            input_shape = [None]

        model.add(tf.keras.layers.Conv2D(filters=2**(start_f+i), 
                                         kernel_size=(encoder_kernel_size, encoder_kernel_size),
                                         strides=(1, 1),
                                         padding='same',
                                         input_shape=input_shape,
                                         name = "enc_conv2d_" + str(i)))
        model.add(tf.keras.layers.ReLU(name = "enc_relu_" + str(i)))
        model.add(tf.keras.layers.MaxPool2D(pool_size=(2, 2), name="enc_maxpool2d_" + str(i)))

    # Decoder
    initializer = tf.random_normal_initializer(0., 0.02)

    for i in range(depth, 1, -1):

        model.add(
        tf.keras.layers.Conv2DTranspose(2**(start_f+i),
                                        encoder_kernel_size,
                                        strides=2,
                                        padding='same',
                                        kernel_initializer=initializer,
                                        use_bias=False)
        )

        model.add(tf.keras.layers.BatchNormalization())

        model.add(tf.keras.layers.ReLU(name="dec_relu_"+str(i)))

        model.add(tf.keras.layers.Concatenate([
            model.get_layer(name="dec_relu_"+str(i)).output,  
            model.get_layer(name="enc_relu_"+str(i-1)).output
        ] ))
        pass

    last = tf.keras.layers.Conv2DTranspose(
      output_channels, 3, strides=2,
      padding='same', activation='softmax')  #64x64 -> 128x128

    model.add(last)

    return model

It gives me the following error:

ValueError: A Concatenate layer should be called on a list of at least 2 inputs

Upvotes: 0

Views: 1495

Answers (1)

Kaushik Roy
Kaushik Roy

Reputation: 1685

You need to change

model.add(tf.keras.layers.Concatenate([
    model.get_layer(name="dec_relu_"+str(i)).output,  
    model.get_layer(name="enc_relu_"+str(i-1)).output
] ))

to

model.add(tf.keras.layers.Concatenate()([  # Sequential api
    model.get_layer(name="dec_relu_"+str(i)).output,  
    model.get_layer(name="enc_relu_"+str(i-1)).output
] ))

or

model.add(tf.keras.layers.concatenate([  # Functional api
    model.get_layer(name="dec_relu_"+str(i)).output,  
    model.get_layer(name="enc_relu_"+str(i-1)).output
] ))

Upvotes: 3

Related Questions