Autoencoder: Decoder has not same size as encoder

If I build the decoder as a mirror of encoder the output size of the last layer does not match.

This is the model summary:

Model: "model"
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28, 1)]       0         
conv_1_j (Conv2D)            (None, 28, 28, 64)        640       
batch_normalization_v2 (Batc (None, 28, 28, 64)        256       
max_pooling2d (MaxPooling2D) (None, 14, 14, 64)        0         
conv_2_j (Conv2D)            (None, 14, 14, 64)        36928     
batch_normalization_v2_1 (Ba (None, 14, 14, 64)        256       
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64)          0         
conv_3_j (Conv2D)            (None, 7, 7, 64)          36928     
batch_normalization_v2_2 (Ba (None, 7, 7, 64)          256       
max_pooling2d_2 (MaxPooling2 (None, 3, 3, 64)          0         
conv_4_j (Conv2D)            (None, 3, 3, 64)          36928     
batch_normalization_v2_3 (Ba (None, 3, 3, 64)          256       
max_pooling2d_3 (MaxPooling2 (None, 1, 1, 64)          0         
flatten (Flatten)            (None, 64)                0         
dense_1_j (Dense)            (None, 64)                4160      
reshape_out (Lambda)         (None, 1, 1, 64)          0         
conv2d (Conv2D)              (None, 1, 1, 64)          36928     
batch_normalization_v2_4 (Ba (None, 1, 1, 64)          256       
up_sampling2d (UpSampling2D) (None, 2, 2, 64)          0         
conv2d_1 (Conv2D)            (None, 2, 2, 64)          36928     
batch_normalization_v2_5 (Ba (None, 2, 2, 64)          256       
up_sampling2d_1 (UpSampling2 (None, 4, 4, 64)          0         
conv2d_2 (Conv2D)            (None, 4, 4, 64)          36928     
batch_normalization_v2_6 (Ba (None, 4, 4, 64)          256       
up_sampling2d_2 (UpSampling2 (None, 8, 8, 64)          0         
conv2d_3 (Conv2D)            (None, 8, 8, 64)          36928     
batch_normalization_v2_7 (Ba (None, 8, 8, 64)          256       
up_sampling2d_3 (UpSampling2 (None, 16, 16, 64)        0         
conv2d_4 (Conv2D)            (None, 16, 16, 1)         577       
Total params: 265,921
Trainable params: 264,897
Non-trainable params: 1,024

Code to reproduce:

import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from tensorflow.python.keras.layers import Lambda

from tensorflow.python.keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping

def resize(example):
    image = example['image']
    image = tf.image.resize(image, [28, 28])
    image = tf.image.rgb_to_grayscale(image, )
    image = image / 255

    example['image'] = image
    return example

def get_tupel(example):
    return example['image'], example['image']

def gen_dataset(dataset, batch_size):
    dataset = dataset.map(resize, num_parallel_calls=4)
    dataset = dataset.map(get_tupel, num_parallel_calls=4)
    dataset = dataset.shuffle(batch_size*50).repeat()  # infinite stream
    dataset = dataset.prefetch(10000)
    dataset = dataset.batch(batch_size)
    return dataset

def main():
    builder = tfds.builder("cifar10")
    datasets = builder.as_dataset()
    train_dataset, test_dataset = datasets['train'], datasets['test']

    batch_size = 48

    train_dataset = gen_dataset(train_dataset, batch_size)
    test_dataset = gen_dataset(test_dataset, batch_size)

    device = '/cpu:0' if not tf.test.is_gpu_available() else tf.test.gpu_device_name()
    with tf.device(device):
        filters = 64
        kernel = 3
        pooling = 2
        image_size = 28

        inp_layer = tf.keras.layers.Input(shape=(image_size, image_size, 1))
        cnn_embedding_out = cnn_encoder(inp_layer, filters, kernel, pooling)
        cnn_decoder_out = cnn_decoder(cnn_embedding_out, filters, kernel, pooling)

        model = tf.keras.Model(inputs=inp_layer, outputs=cnn_decoder_out)

        model.compile(optimizer=tf.optimizers.Adam(0.0001), loss='binary_crossentropy',


        model.fit(train_dataset, validation_data=test_dataset,
                  steps_per_epoch=100,  # 1000

def cnn_encoder(inp_layer, filters, kernel, pooling):
    cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_1_j')(inp_layer)
    bn1 = tf.keras.layers.BatchNormalization()(cnn1)
    max1 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn1)
    cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_2_j')(max1)
    bn2 = tf.keras.layers.BatchNormalization()(cnn2)
    max2 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn2)
    cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_3_j')(max2)
    bn3 = tf.keras.layers.BatchNormalization()(cnn3)
    max3 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn3)
    cnn4 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_4_j')(max3)
    bn4 = tf.keras.layers.BatchNormalization()(cnn4)
    max4 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn4)
    flat = tf.keras.layers.Flatten()(max4)
    fc = tf.keras.layers.Dense(64, name='dense_1_j')(flat)  # this is the encoder layer!

    return fc

def cnn_decoder(inp_layer, filters, kernel, pooling):
    res1 = reshape([1, 1, filters], name="reshape_out")(inp_layer)
    cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(res1)
    bn1 = tf.keras.layers.BatchNormalization()(cnn1)
    up1 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn1)
    cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up1)
    bn2 = tf.keras.layers.BatchNormalization()(cnn2)
    up2 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn2)
    cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up2)
    bn3 = tf.keras.layers.BatchNormalization()(cnn3)
    up3 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn3)
    cnn4 = tf.keras.layers.Conv2D(filters, kernel,  padding="same", activation='relu',)(up3)
    bn4 = tf.keras.layers.BatchNormalization()(cnn4)
    up4 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn4)
    decoded = tf.keras.layers.Conv2D(1, kernel,  padding="same", activation='sigmoid')(up4)
    return decoded

def reshape(dim, name="", complete=False):
    def func(x):
        if complete:
            ret = tf.reshape(x, dim)
            ret = tf.reshape(x, [-1, ] + dim)
        return ret
    return Lambda(func, name=name)

if __name__ == "__main__":

I tried to use Conv2dTranspose and different size of upsampling, but this doesn't feel right.

I would expect the output as the input (48,28,28,1)

What am I doing wrong?

I changed the cnn_decoder function:

def cnn_decoder(inp_layer, filters, kernel, pooling):
    res1 = reshape([1, 1, filters], name="reshape_out")(inp_layer)
    cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(res1)
    bn1 = tf.keras.layers.BatchNormalization()(cnn1)
    up1 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn1)
    cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up1)
    bn2 = tf.keras.layers.BatchNormalization()(cnn2)
    up2 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn2)
    cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up2)
    bn3 = tf.keras.layers.BatchNormalization()(cnn3)
    up3 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn3)
    cnn4 = tf.keras.layers.Conv2D(filters, kernel,  padding="same", activation='relu',)(up3)
    bn4 = tf.keras.layers.BatchNormalization()(cnn4)
    up4 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn4)
    cnn5 = tf.keras.layers.Conv2D(filters, kernel,  padding="same", activation='relu',)(up4)
    bn5 = tf.keras.layers.BatchNormalization()(cnn5)
    up5 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn5)
    cnn6 = tf.keras.layers.Conv2D(filters, kernel, padding="valid",  activation='relu')(up5)
    bn5 = tf.keras.layers.BatchNormalization()(cnn6)
    decoded = tf.keras.layers.Conv2D(1, kernel, padding="valid",  activation='sigmoid')(bn5)
    return decoded

Does it seem correct to you?

Upvotes: 1


as you can see in reshape_out (Lambda) you have shape 1, so by doing operations like UpSampling2 you can just get to size 2, 4, 8, 16, 32. To do your way you may have to do UpSampling2 until size 32, then reshape by doing for instance two Conv2D of kernel=3 to get a result of shape (x, 28, 28, 64)

Upvotes: 2

