Reputation: 196
If I build the decoder as a mirror of encoder the output size of the last layer does not match.
This is the model summary:
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 28, 28, 1)] 0
_________________________________________________________________
conv_1_j (Conv2D) (None, 28, 28, 64) 640
_________________________________________________________________
batch_normalization_v2 (Batc (None, 28, 28, 64) 256
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 64) 0
_________________________________________________________________
conv_2_j (Conv2D) (None, 14, 14, 64) 36928
_________________________________________________________________
batch_normalization_v2_1 (Ba (None, 14, 14, 64) 256
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 64) 0
_________________________________________________________________
conv_3_j (Conv2D) (None, 7, 7, 64) 36928
_________________________________________________________________
batch_normalization_v2_2 (Ba (None, 7, 7, 64) 256
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 3, 3, 64) 0
_________________________________________________________________
conv_4_j (Conv2D) (None, 3, 3, 64) 36928
_________________________________________________________________
batch_normalization_v2_3 (Ba (None, 3, 3, 64) 256
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 (None, 1, 1, 64) 0
_________________________________________________________________
flatten (Flatten) (None, 64) 0
_________________________________________________________________
dense_1_j (Dense) (None, 64) 4160
_________________________________________________________________
reshape_out (Lambda) (None, 1, 1, 64) 0
_________________________________________________________________
conv2d (Conv2D) (None, 1, 1, 64) 36928
_________________________________________________________________
batch_normalization_v2_4 (Ba (None, 1, 1, 64) 256
_________________________________________________________________
up_sampling2d (UpSampling2D) (None, 2, 2, 64) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 2, 2, 64) 36928
_________________________________________________________________
batch_normalization_v2_5 (Ba (None, 2, 2, 64) 256
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 4, 4, 64) 0
_________________________________________________________________
conv2d_2 (Conv2D) (None, 4, 4, 64) 36928
_________________________________________________________________
batch_normalization_v2_6 (Ba (None, 4, 4, 64) 256
_________________________________________________________________
up_sampling2d_2 (UpSampling2 (None, 8, 8, 64) 0
_________________________________________________________________
conv2d_3 (Conv2D) (None, 8, 8, 64) 36928
_________________________________________________________________
batch_normalization_v2_7 (Ba (None, 8, 8, 64) 256
_________________________________________________________________
up_sampling2d_3 (UpSampling2 (None, 16, 16, 64) 0
_________________________________________________________________
conv2d_4 (Conv2D) (None, 16, 16, 1) 577
=================================================================
Total params: 265,921
Trainable params: 264,897
Non-trainable params: 1,024
_________________________________________________________________
Code to reproduce:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from tensorflow.python.keras.layers import Lambda
from tensorflow.python.keras.callbacks import TensorBoard, ModelCheckpoint, EarlyStopping
def resize(example):
image = example['image']
image = tf.image.resize(image, [28, 28])
image = tf.image.rgb_to_grayscale(image, )
image = image / 255
example['image'] = image
return example
def get_tupel(example):
return example['image'], example['image']
def gen_dataset(dataset, batch_size):
dataset = dataset.map(resize, num_parallel_calls=4)
dataset = dataset.map(get_tupel, num_parallel_calls=4)
dataset = dataset.shuffle(batch_size*50).repeat() # infinite stream
dataset = dataset.prefetch(10000)
dataset = dataset.batch(batch_size)
return dataset
def main():
builder = tfds.builder("cifar10")
builder.download_and_prepare()
datasets = builder.as_dataset()
train_dataset, test_dataset = datasets['train'], datasets['test']
batch_size = 48
train_dataset = gen_dataset(train_dataset, batch_size)
test_dataset = gen_dataset(test_dataset, batch_size)
device = '/cpu:0' if not tf.test.is_gpu_available() else tf.test.gpu_device_name()
print(tf.test.gpu_device_name())
with tf.device(device):
filters = 64
kernel = 3
pooling = 2
image_size = 28
inp_layer = tf.keras.layers.Input(shape=(image_size, image_size, 1))
cnn_embedding_out = cnn_encoder(inp_layer, filters, kernel, pooling)
cnn_decoder_out = cnn_decoder(cnn_embedding_out, filters, kernel, pooling)
model = tf.keras.Model(inputs=inp_layer, outputs=cnn_decoder_out)
model.compile(optimizer=tf.optimizers.Adam(0.0001), loss='binary_crossentropy',
metrics=['accuracy'])
print(model.summary())
model.fit(train_dataset, validation_data=test_dataset,
steps_per_epoch=100, # 1000
validation_steps=100,
epochs=150,)
def cnn_encoder(inp_layer, filters, kernel, pooling):
cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_1_j')(inp_layer)
bn1 = tf.keras.layers.BatchNormalization()(cnn1)
max1 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn1)
cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_2_j')(max1)
bn2 = tf.keras.layers.BatchNormalization()(cnn2)
max2 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn2)
cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_3_j')(max2)
bn3 = tf.keras.layers.BatchNormalization()(cnn3)
max3 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn3)
cnn4 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu', name='conv_4_j')(max3)
bn4 = tf.keras.layers.BatchNormalization()(cnn4)
max4 = tf.keras.layers.MaxPooling2D(pooling, pooling, padding="valid")(bn4)
flat = tf.keras.layers.Flatten()(max4)
fc = tf.keras.layers.Dense(64, name='dense_1_j')(flat) # this is the encoder layer!
return fc
def cnn_decoder(inp_layer, filters, kernel, pooling):
res1 = reshape([1, 1, filters], name="reshape_out")(inp_layer)
cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(res1)
bn1 = tf.keras.layers.BatchNormalization()(cnn1)
up1 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn1)
cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up1)
bn2 = tf.keras.layers.BatchNormalization()(cnn2)
up2 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn2)
cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up2)
bn3 = tf.keras.layers.BatchNormalization()(cnn3)
up3 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn3)
cnn4 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up3)
bn4 = tf.keras.layers.BatchNormalization()(cnn4)
up4 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn4)
decoded = tf.keras.layers.Conv2D(1, kernel, padding="same", activation='sigmoid')(up4)
return decoded
def reshape(dim, name="", complete=False):
def func(x):
if complete:
ret = tf.reshape(x, dim)
else:
ret = tf.reshape(x, [-1, ] + dim)
return ret
return Lambda(func, name=name)
if __name__ == "__main__":
main()
I tried to use Conv2dTranspose and different size of upsampling, but this doesn't feel right.
I would expect the output as the input (48,28,28,1)
What am I doing wrong?
Upvotes: 0
Views: 1777
Reputation: 196
I changed the cnn_decoder
function:
def cnn_decoder(inp_layer, filters, kernel, pooling):
res1 = reshape([1, 1, filters], name="reshape_out")(inp_layer)
cnn1 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(res1)
bn1 = tf.keras.layers.BatchNormalization()(cnn1)
up1 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn1)
cnn2 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up1)
bn2 = tf.keras.layers.BatchNormalization()(cnn2)
up2 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn2)
cnn3 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up2)
bn3 = tf.keras.layers.BatchNormalization()(cnn3)
up3 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn3)
cnn4 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up3)
bn4 = tf.keras.layers.BatchNormalization()(cnn4)
up4 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn4)
cnn5 = tf.keras.layers.Conv2D(filters, kernel, padding="same", activation='relu',)(up4)
bn5 = tf.keras.layers.BatchNormalization()(cnn5)
up5 = tf.keras.layers.UpSampling2D((pooling, pooling))(bn5)
cnn6 = tf.keras.layers.Conv2D(filters, kernel, padding="valid", activation='relu')(up5)
bn5 = tf.keras.layers.BatchNormalization()(cnn6)
decoded = tf.keras.layers.Conv2D(1, kernel, padding="valid", activation='sigmoid')(bn5)
return decoded
Does it seem correct to you?
Upvotes: 1
Reputation: 521
as you can see in reshape_out (Lambda)
you have shape 1, so by doing operations
like UpSampling2
you can just get to size 2, 4, 8, 16, 32.
To do your way you may have to do UpSampling2
until size 32
, then reshape by doing for instance two Conv2D
of kernel=3
to get a result of shape (x, 28, 28, 64)
Upvotes: 2