janbolle

Reputation: 196

Autoencoder: decoder output does not have the same size as the encoder input

If I build the decoder as a mirror of the encoder, the output shape of the last layer does not match the input shape: the model takes (None, 28, 28, 1) but produces (None, 16, 16, 1).

This is the model summary:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv_1_j (Conv2D)            (None, 28, 28, 64)        640       
_________________________________________________________________
batch_normalization_v2 (Batc (None, 28, 28, 64)        256       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 64)        0         
_________________________________________________________________
conv_2_j (Conv2D)            (None, 14, 14, 64)        36928     
_________________________________________________________________
batch_normalization_v2_1 (Ba (None, 14, 14, 64)        256       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64)          0         
_________________________________________________________________
conv_3_j (Conv2D)            (None, 7, 7, 64)          36928     
_________________________________________________________________
batch_normalization_v2_2 (Ba (None, 7, 7, 64)          256       
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 3, 3, 64)          0         
_________________________________________________________________
conv_4_j (Conv2D)            (None, 3, 3, 64)          36928     
_________________________________________________________________
batch_normalization_v2_3 (Ba (None, 3, 3, 64)          256       
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 1, 1, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 64)                0         
_________________________________________________________________
dense_1_j (Dense)            (None, 64)                4160      
_________________________________________________________________
reshape_out (Lambda)         (None, 1, 1, 64)          0         
_________________________________________________________________
conv2d (Conv2D)              (None, 1, 1, 64)          36928     
_________________________________________________________________
batch_normalization_v2_4 (Ba (None, 1, 1, 64)          256       
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 2, 2, 64)          0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 2, 2, 64)          36928     
_________________________________________________________________
batch_normalization_v2_5 (Ba (None, 2, 2, 64)          256       
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 4, 4, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 4, 4, 64)          36928     
_________________________________________________________________
batch_normalization_v2_6 (Ba (None, 4, 4, 64)          256       
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 8, 8, 64)          0         
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 8, 8, 64)          36928     
_________________________________________________________________
batch_normalization_v2_7 (Ba (None, 8, 8, 64)          256       
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 16, 16, 64)        0         
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 16, 16, 1)         577       
=================================================================
Total params: 265,921
Trainable params: 264,897
Non-trainable params: 1,024
_________________________________________________________________

Code to reproduce:

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from tensorflow.keras.layers import Lambda  # public API; tensorflow.python.* is private

from tensorflow.keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping


def resize(example):
    image = example['image']
    image = tf.image.resize(image, [28, 28])
    image = tf.image.rgb_to_grayscale(image, )
    image = image / 255

    example['image'] = image
    return example


def get_tupel(example):
    return example['image'], example['image']


def gen_dataset(dataset, batch_size):
    dataset = dataset.map(resize, num_parallel_calls=4)
    dataset = dataset.map(get_tupel, num_parallel_calls=4)
    dataset = dataset.shuffle(batch_size*50).repeat()  # infinite stream
    dataset = dataset.prefetch(10000)
    dataset = dataset.batch(batch_size)
    return dataset


def main():
    builder = tfds.builder("cifar10")
    builder.download_and_prepare()
    datasets = builder.as_dataset()
    train_dataset, test_dataset = datasets['train'], datasets['test']

    batch_size = 48

    train_dataset = gen_dataset(train_dataset, batch_size)
    test_dataset = gen_dataset(test_dataset, batch_size)

    device = '/cpu:0' if not tf.test.is_gpu_available() else tf.test.gpu_device_name()
    print(tf.test.gpu_device_name())
    with tf.device(device):
        filters = 64
        kernel = 3
        pooling = 2
        image_size = 28

        inp_layer = tf.keras.layers.Input(shape=(image_size, image_size, 1))
        cnn_embedding_out = cnn_encoder(inp_layer, filters, kernel, pooling)
        cnn_decoder_out = cnn_decoder(cnn_embedding_out, filters, kernel, pooling)

        model = tf.keras.Model(inputs=inp_layer, outputs=cnn_decoder_out)

        model.compile(optimizer=tf.optimizers.Adam(0.0001), loss='binary_crossentropy',
                      metrics=['accuracy'])

        model.summary()  # summary() prints the table itself and returns None

        model.fit(train_dataset, validation_data=test_dataset,
                  steps_per_epoch=100,  # 1000
                  validation_steps=100,
                  epochs=150,)


def cnn_encoder(inp_layer, filters, kernel, pooling):
    cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_1_j')(inp_layer)
    bn1 = tf.keras.layers.BatchNormalization()(cnn1)
    max1 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn1)
    cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_2_j')(max1)
    bn2 = tf.keras.layers.BatchNormalization()(cnn2)
    max2 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn2)
    cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_3_j')(max2)
    bn3 = tf.keras.layers.BatchNormalization()(cnn3)
    max3 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn3)
    cnn4 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_4_j')(max3)
    bn4 = tf.keras.layers.BatchNormalization()(cnn4)
    max4 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn4)
    flat = tf.keras.layers.Flatten()(max4)
    fc = tf.keras.layers.Dense(64, name='dense_1_j')(flat)  # this is the encoder layer!

    return fc


def cnn_decoder(inp_layer, filters, kernel, pooling):
    res1 = reshape([1, 1, filters], name="reshape_out")(inp_layer)
    cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(res1)
    bn1 = tf.keras.layers.BatchNormalization()(cnn1)
    up1 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn1)
    cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up1)
    bn2 = tf.keras.layers.BatchNormalization()(cnn2)
    up2 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn2)
    cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up2)
    bn3 = tf.keras.layers.BatchNormalization()(cnn3)
    up3 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn3)
    cnn4 = tf.keras.layers.Conv2D(filters, kernel,  padding="same", activation='relu',)(up3)
    bn4 = tf.keras.layers.BatchNormalization()(cnn4)
    up4 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn4)
    decoded = tf.keras.layers.Conv2D(1, kernel,  padding="same", activation='sigmoid')(up4)
    return decoded


def reshape(dim, name="", complete=False):
    def func(x):
        if complete:
            ret = tf.reshape(x, dim)
        else:
            ret = tf.reshape(x, [-1, ] + dim)
        return ret
    return Lambda(func, name=name)


if __name__ == "__main__":
    main()

I tried Conv2DTranspose and different upsampling sizes, but that doesn't feel right.

I expected the output to have the same shape as the input, (48, 28, 28, 1).

What am I doing wrong?

Upvotes: 0

Views: 1777

Answers (2)

janbolle

Reputation: 196

I changed the cnn_decoder function:

def cnn_decoder(inp_layer, filters, kernel, pooling):
    res1 = reshape([1, 1, filters], name="reshape_out")(inp_layer)
    cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(res1)
    bn1 = tf.keras.layers.BatchNormalization()(cnn1)
    up1 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn1)
    cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up1)
    bn2 = tf.keras.layers.BatchNormalization()(cnn2)
    up2 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn2)
    cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up2)
    bn3 = tf.keras.layers.BatchNormalization()(cnn3)
    up3 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn3)
    cnn4 = tf.keras.layers.Conv2D(filters, kernel,  padding="same", activation='relu',)(up3)
    bn4 = tf.keras.layers.BatchNormalization()(cnn4)
    up4 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn4)
    cnn5 = tf.keras.layers.Conv2D(filters, kernel,  padding="same", activation='relu',)(up4)
    bn5 = tf.keras.layers.BatchNormalization()(cnn5)
    up5 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn5)
    cnn6 = tf.keras.layers.Conv2D(filters, kernel, padding="valid",  activation='relu')(up5)  # 32 -> 30
    bn6 = tf.keras.layers.BatchNormalization()(cnn6)
    decoded = tf.keras.layers.Conv2D(1, kernel, padding="valid",  activation='sigmoid')(bn6)  # 30 -> 28
    return decoded

Does it seem correct to you?
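
As a quick sanity check on the shapes only, the spatial sizes of this decoder can be traced without building the model. This is a minimal plain-Python sketch, not Keras code: trace_decoder and its parameters are hypothetical names, and it assumes stride-1 convolutions, that "same" padding keeps the size, and that a "valid" convolution removes kernel - 1 pixels per dimension.

```python
def trace_decoder(size=1, upsamplings=5, factor=2, kernel=3, valid_convs=2):
    """Return the spatial size after each size-changing layer of the decoder."""
    sizes = [size]
    for _ in range(upsamplings):
        # UpSampling2D multiplies height and width by `factor`; the
        # "same"-padded Conv2D and BatchNormalization in between keep the size.
        size *= factor
        sizes.append(size)
    for _ in range(valid_convs):
        # A stride-1 Conv2D with padding="valid" trims kernel - 1 pixels.
        size -= kernel - 1
        sizes.append(size)
    return sizes


print(trace_decoder())  # [1, 2, 4, 8, 16, 32, 30, 28]
```

Under those assumptions the last size is 28, which matches the 28x28 input. The trace says nothing about reconstruction quality, only about shapes.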

Upvotes: 1

thetradingdogdj

Reputation: 521

As the reshape_out (Lambda) row of your summary shows, the decoder starts from a spatial size of 1, so repeated UpSampling2D with a factor of 2 can only reach sizes 2, 4, 8, 16, 32. To keep your approach, you may have to upsample until the size is 32 and then shrink it, for instance with two Conv2D layers with kernel=3 and padding="valid" (each one removes 2 pixels per dimension), to get a result of shape (x, 28, 28, 64).
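
To see why the plain mirror cannot get back to 28, it may help to compare the encoder sizes with the mirrored decoder sizes. The sketch below is plain Python with hypothetical helper names (encoder_sizes, mirrored_decoder_sizes); it assumes MaxPooling2D with padding="valid" and stride equal to the pool size, which produces floor((n - pool) / pool) + 1 outputs.

```python
def encoder_sizes(size=28, pool=2, stages=4):
    """Spatial size after each "valid" max-pooling stage (stride == pool)."""
    sizes = [size]
    for _ in range(stages):
        size = (size - pool) // pool + 1  # odd sizes lose their last row/column
        sizes.append(size)
    return sizes


def mirrored_decoder_sizes(size=1, factor=2, stages=4):
    """Spatial size after each UpSampling2D stage of a mirrored decoder."""
    sizes = [size]
    for _ in range(stages):
        size *= factor
        sizes.append(size)
    return sizes


down = encoder_sizes()
up = mirrored_decoder_sizes()
print(down)  # [28, 14, 7, 3, 1]
print(up)    # [1, 2, 4, 8, 16]

# Walk the decoder against the reversed encoder and report the mismatches.
for enc, dec in zip(reversed(down), up):
    print(enc, dec, "ok" if enc == dec else "mismatch")
```

The pooling step 7 -> 3 rounds down, so doubling on the way back gives 2, 4, 8, 16 instead of 3, 7, 14, 28, and the sizes diverge from the first upsampling onwards.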

Upvotes: 2
