Boom
Boom

Reputation: 1325

Autoencoder give wrong results (Not as shown in basic examples)

What am I missing ? Why the examples on the web shows good results and when I test it, I'm getting different results ?

import tensorflow as tf
from tensorflow.python.keras.layers import Input, Dense
from tensorflow.python.keras.models import Model, Sequential
from tensorflow.python.keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt

#   Build models
hiden_size = 784 # After It didn't work for 32 , I have tried 784 which didn't improve results
input_layer = Input(shape=(784,))
decoder_input_layer = Input(shape=(hiden_size,))
hidden_layer = Dense(hiden_size, activation="relu", name="hidden1")
autoencoder_output_layer = Dense(784, activation="sigmoid", name="output")

autoencoder = Sequential()
autoencoder.add(input_layer)
autoencoder.add(hidden_layer)
autoencoder.add(autoencoder_output_layer)
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy')

encoder = Sequential()
encoder.add(input_layer)
encoder.add(hidden_layer)

decoder = Sequential()
decoder.add(decoder_input_layer)
decoder.add(autoencoder_output_layer)

#
#   Prepare Input
(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = x_train.reshape((len(x_train), np.prod(x_train.shape[1:])))
x_test = x_test.reshape((len(x_test), np.prod(x_test.shape[1:])))

#
# Fit & Predict
autoencoder.fit(x_train, x_train,
                epochs=50,
                batch_size=256,
                validation_data=(x_test, x_test),
                verbose=1)

encoded_imgs = encoder.predict(x_test)
decoded_imgs = decoder.predict(encoded_imgs)

#
# Show results
n = 10  # how many digits we will display
plt.figure(figsize=(20, 4))
for i in range(n):
    # display original
    ax = plt.subplot(2, n, i + 1)
    plt.imshow(x_test[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # display reconstruction
    ax = plt.subplot(2, n, i + 1 + n)
    plt.imshow(decoded_imgs[i].reshape(28, 28))
    plt.gray()
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()

Results: enter image description here

Upvotes: 1

Views: 317

Answers (1)

user3668129
user3668129

Reputation: 4820

Try to change the optimizer. I changed it to adam and got:

enter image description here

Upvotes: 1

Related Questions