Reputation: 1325
I read about the Autoencoder and tried to implement a simple one.
I'm using the MNIST digits dataset and plotting the digits before the Autoencoder
and after it. What am I missing? Why do the examples on the web show good results, while I get different results when I test it?
import tensorflow as tf
from tensorflow.python.keras.layers import Input, Dense
from tensorflow.python.keras.models import Model, Sequential
from tensorflow.python.keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
# Build models
# Three models are assembled from the same two Dense layer objects:
#   autoencoder: input -> hidden1 -> output   (the one that is trained)
#   encoder:     input -> hidden1
#   decoder:     encoded input -> output
# Because encoder and decoder reuse the layer objects added to the
# autoencoder, they are built from the very layers that get trained.
hiden_size = 784 # After It didn't work for 32 , I have tried 784 which didn't improve results
input_layer = Input(shape=(784,))
decoder_input_layer = Input(shape=(hiden_size,))
hidden_layer = Dense(hiden_size, activation="relu", name="hidden1")
autoencoder_output_layer = Dense(784, activation="sigmoid", name="output")

autoencoder = Sequential()
autoencoder.add(input_layer)
autoencoder.add(hidden_layer)
autoencoder.add(autoencoder_output_layer)
# FIX: the string 'adadelta' builds Adadelta with the TensorFlow 2 default
# learning_rate of 0.001 (older standalone-Keras examples relied on 1.0).
# With such a small step the network barely moves in 50 epochs and the
# reconstructions look like noise. 'adam' converges with its defaults; an
# Adadelta optimizer constructed with learning_rate=1.0 is the alternative
# if the original optimizer must be kept.
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

encoder = Sequential()
encoder.add(input_layer)
encoder.add(hidden_layer)

decoder = Sequential()
decoder.add(decoder_input_layer)
decoder.add(autoencoder_output_layer)
#
# Prepare Input
# Only the images are needed; the label halves of both splits are discarded.
(x_train, _), (x_test, _) = mnist.load_data()
# Convert to float32, divide by 255 and flatten every sample into a single
# row (one row per image, all remaining axes collapsed).
x_train = (x_train.astype('float32') / 255.).reshape(len(x_train), -1)
x_test = (x_test.astype('float32') / 255.).reshape(len(x_test), -1)
#
# Fit & Predict
# The autoencoder is trained to reproduce its own input, so the training
# images serve as both features and targets; the test images are used the
# same way for validation.
autoencoder.fit(
    x=x_train,
    y=x_train,
    batch_size=256,
    epochs=50,
    verbose=1,
    validation_data=(x_test, x_test),
)
# Run the test images through the two halves separately: encode first,
# then decode the encoded representation.
encoded_imgs = encoder.predict(x_test)
decoded_imgs = decoder.predict(encoded_imgs)
#
# Show results
# Draw a 2 x n grid: originals on the top row, reconstructions underneath.
n = 10 # how many digits we will display
plt.figure(figsize=(20, 4))
for i in range(n):
    # row 0 -> original test image, row 1 -> its reconstruction
    for row, images in enumerate((x_test, decoded_imgs)):
        ax = plt.subplot(2, n, row * n + i + 1)
        plt.imshow(images[i].reshape(28, 28))
        plt.gray()
        # Hide both axes so only the digit is visible.
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
plt.show()
Upvotes: 1
Views: 317