Darius

Reputation: 43

Keras GAN (generator) not training well despite accurate discriminator

I've tried sorting this out for a few days now, following many pieces of advice found on forums etc, and now would welcome any suggestions to what is wrong!

I'm attempting to get my first GAN training: a simple feedforward deep net, very similar to the standard MNIST examples, but using spectral power frames derived from the VCTK-Corpus (shape (1, 513)) instead of images.

You can see from the TensorBoard graphs below that the networks seem to be interacting, and there is some sort of training going on: [image: Tensorboard graph overview], [image: Tensorboard graph zoom].

However, the results are poor and noisy: [image: generated and validation comparison].

The generator takes a vector of Gaussian noise (usually of length 30 to 100) with a mean of 0 and a stdev of 0.5.
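
The noise comes from a noisegen helper, which is essentially a thin wrapper around NumPy's Gaussian sampler, roughly:

import numpy as np

def noisegen(shape, mean=0.0, stdev=0.5):
    #Draw a noise tensor of the given shape, e.g. (batch, 1, n_noise)
    return np.random.normal(loc=mean, scale=stdev, size=shape)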

from keras.models import Model
from keras.layers import Input, Dense, Dropout, Reshape
from keras.layers import BatchNormalization, LeakyReLU

def gan_generator(x_shape, frame_size):
    g_input = Input(shape=x_shape)
    H = BatchNormalization()(g_input)
    H = Dense(128)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(128)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(256)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(256)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(256)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    out = Dense(frame_size[1], activation='linear')(H)

    generator = Model(g_input, out)
    generator.summary()
    return generator

The discriminator outputs a one-hot real/generated categorisation of each frame. (I'm not sure about batch normalisation here: I've read it shouldn't be used when real and generated samples are mixed into one batch. However, the generator produces much more convincing results with it than without, despite having a higher loss.)

def gan_discriminator(input_shape):
    d_input = Input(shape=input_shape)
    H = Dropout(0.1)(d_input)
    H = Dense(256)(H)
    H = Dropout(0.1)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(128)(H)
    H = Dropout(0.1)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(100)(H)
    H = Dropout(0.1)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Dense(100)(H)
    H = Dropout(0.1)(H)
    H = LeakyReLU()(H)
    H = BatchNormalization()(H)
    H = Reshape((1, -1))(H)
    d_V = Dense(2, activation='softmax')(H)

    discriminator = Model(d_input, d_V)
    discriminator.summary()
    return discriminator

Wiring the two together into the combined GAN is straightforward:

def init_gan(generator, discriminator):
    x = Input(shape=generator.input_shape[1:])

    #Generator makes a prediction
    pred = generator(x)

    #Discriminator attempts to categorise prediction
    y = discriminator(pred)

    GAN = Model(x, y)
    return GAN
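
The make_trainable and fit_model calls that appear below are small helpers; they work roughly along these lines (sketched here, so details may differ):

def make_trainable(model, flag):
    #Toggle trainability of a whole model; takes effect at the next compile()
    model.trainable = flag
    for layer in model.layers:
        layer.trainable = flag

def fit_model(model, x, y, epochs=1, batch_size=32):
    #Thin wrapper around Keras fit()
    model.fit(x, y, batch_size=batch_size, epochs=epochs, verbose=1)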

Some training variables:
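
Representative placeholders are shown here; the names match what the loop references, but the values are illustrative rather than the exact settings used:

#Illustrative values, not necessarily the exact settings used
from keras.optimizers import Adam

n_noise = 100                   #length of the input noise vector
data_sets = 100                 #number of batches of real frames to loop over
data_dir = './vctk_frames/'     #hypothetical path to the spectral frames

dis_optimizer = Adam(lr=1e-4)
dis_loss = 'categorical_crossentropy'
GAN_optimizer = Adam(lr=1e-4)
GAN_loss = 'categorical_crossentropy'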

The training loop:

#Pre-training Discriminator Network
#Load new batch of real frames
frames = load_data(data_dir)
frames_label = np.zeros((frames.shape[0], 1, 2))
frames_label[:, :, 0] = 1 #mark as real frames

#Generate Frames from noise vector
X_noise = noisegen((frames.shape[0], 1, n_noise))
generated_frames = generator.predict(X_noise)
generated_label = np.zeros((generated_frames.shape[0], 1, 2))
generated_label[:, :, 1] = 1 #mark as false frames

#Prep Data - concat real and false data
dis_batch_x = np.concatenate((frames, generated_frames), axis=0)
dis_batch_y = np.concatenate((frames_label, generated_label), axis=0)

#Make discriminator trainable and train for 8 epochs
make_trainable(discriminator, True)
discriminator.compile(optimizer=dis_optimizer, loss=dis_loss)
fit_model(discriminator, dis_batch_x, dis_batch_y, 8)

#Training Loop
for d in range(data_sets):
    print("Starting New Dataset: {0}/{1}".format(d+1, data_sets))

    """ Fit Discriminator """
    #Load new batch of real frames
    frames = load_data(data_dir)
    frames_label = np.zeros((frames.shape[0], 1, 2))
    frames_label[:, :, 0] = 1 #mark as real frames

    #Generate Frames from noise vector
    X_noise = noisegen((frames.shape[0], 1, n_noise))
    generated_frames = generator.predict(X_noise)
    generated_label = np.zeros((generated_frames.shape[0], 1, 2))
    generated_label[:, :, 1] = 1 #mark as false frames

    #Prep Data - concat real and false data
    dis_batch_x = np.concatenate((frames, generated_frames), axis=0)
    dis_batch_y = np.concatenate((frames_label, generated_label), axis=0)

    #Make discriminator trainable & fit
    make_trainable(discriminator, True)
    discriminator.compile(optimizer=dis_optimizer, loss=dis_loss)
    fit_model(discriminator, dis_batch_x, dis_batch_y)


    """ Fit Generator """
    #Prep Data
    X_noise = noisegen((frames.shape[0], 1, n_noise))
    generated_label = np.zeros((frames.shape[0], 1, 2))
    generated_label[:, :, 1] = 1 #mark as false frames

    make_trainable(discriminator, False)
    GAN.layers[2].trainable = False #done twice just to be sure
    GAN.compile(optimizer=GAN_optimizer, loss=GAN_loss) 
    fit_model(GAN, X_noise, generated_label)


Many thanks in advance!

Upvotes: 1

Views: 2053

Answers (1)

Darius

Reputation: 43

What the solution actually was: I hadn't swapped my True/False class labels in generator training (as suggested at https://github.com/soumith/ganhacks). Labelling the generated frames as fake during the generator update effectively turns it into gradient ascent.

Clarification on this point would still be nice to have.
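
Concretely, the fix is just to mark the generated frames as real when fitting the stacked GAN, so the generator is pushed towards fooling the discriminator. A sketch against the loop above:

""" Fit Generator """
X_noise = noisegen((frames.shape[0], 1, n_noise))
generated_label = np.zeros((frames.shape[0], 1, 2))
generated_label[:, :, 0] = 1 #labelled as REAL so the generator is trained
                             #to make the discriminator call its output real

make_trainable(discriminator, False)
GAN.compile(optimizer=GAN_optimizer, loss=GAN_loss)
fit_model(GAN, X_noise, generated_label)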

Upvotes: 1
