some_name.py

Reputation: 827

Setting the .trainable attribute properly in a GAN implementation in tf.keras

I am confused by the .trainable attribute of tf.keras.Model in the implementation of a GAN.

Given the following code snippet (taken from this repo):

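import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

class GAN():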

    def __init__(self):

        ...

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model  (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):

        ...

        return Model(noise, img)

    def build_discriminator(self):

        ...

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

During the definition of the model self.combined, the discriminator's weights are frozen with self.discriminator.trainable = False, but the flag is never turned back on.
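One way to see what the flag does at the moment it is flipped is to count trainable_weights before and after. This is a minimal sketch assuming TensorFlow 2.x tf.keras; the tiny Sequential models and latent_dim are hypothetical stand-ins for the elided build_generator/build_discriminator, not the repo's architecture.

```python
import tensorflow as tf

latent_dim = 8  # hypothetical latent size

# Hypothetical tiny stand-ins for the elided build_generator / build_discriminator
generator = tf.keras.Sequential([
    tf.keras.Input(shape=(latent_dim,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(4),
])
discriminator = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
discriminator.compile(loss="binary_crossentropy", optimizer="adam")

# Two Dense layers, each with a kernel and a bias -> 4 trainable weights
n_disc_before = len(discriminator.trainable_weights)

# Flip the flag after compiling the discriminator, as the repo does
discriminator.trainable = False
n_disc_frozen = len(discriminator.trainable_weights)

# Stack the models the same way as self.combined
z = tf.keras.Input(shape=(latent_dim,))
combined = tf.keras.Model(z, discriminator(generator(z)))
combined.compile(loss="binary_crossentropy", optimizer="adam")
n_combined = len(combined.trainable_weights)  # only the generator's weights remain

print(n_disc_before, n_disc_frozen, n_combined)  # 4 0 4
```

Right after the flag is flipped, the discriminator reports no trainable weights and the combined model exposes only the generator's. That the discriminator's own train_on_batch still updates it, as observed below, is what the quoted issue attributes to the discriminator having been compiled before the flag changed.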

Still, during the training loop, the discriminator's weights will change in these lines:

# Train the discriminator
d_loss_real = self.discriminator.train_on_batch(imgs, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

and will stay constant during:

# Train the generator (to have the discriminator label samples as valid)
g_loss = self.combined.train_on_batch(noise, valid)

which I wouldn't expect.

Of course, this is the correct (iterative) way to train a GAN, but I don't understand why we don't have to set self.discriminator.trainable = True before we can train the discriminator.

It would be nice if someone could explain this; I guess it is a crucial point to understand.

Upvotes: 4

Views: 882

Answers (1)

o-90

Reputation: 17603

It's usually a good idea to check the issues (both open and closed) when you have a question about code in a GitHub repo. This issue explains why the flag is set to False. It says:

Since self.discriminator.trainable = False is set after the discriminator is compiled, it will not affect the training of the discriminator. However since it is set before the combined model is compiled the discriminator layers will be frozen when the combined model is trained.

It also discusses freezing Keras layers in general.
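If relying on the compile-time behavior of the trainable flag feels fragile, a common alternative in TensorFlow 2.x is a custom training loop with tf.GradientTape, where each update step names the variables it is allowed to change, so the flag never needs to be toggled. This is a swapped-in technique, not what the repo or the linked issue does. It is a minimal sketch: the tiny Sequential models, latent_dim, learning rates, random "real" data, and the d_step/g_step helper names are all hypothetical.

```python
import numpy as np
import tensorflow as tf

latent_dim = 8  # hypothetical latent size

# Hypothetical tiny models standing in for the elided build_* methods
generator = tf.keras.Sequential([
    tf.keras.Input(shape=(latent_dim,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(4),
])
discriminator = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1),  # raw logits; the loss below applies the sigmoid
])

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
d_opt = tf.keras.optimizers.Adam(1e-3)  # one optimizer per variable set
g_opt = tf.keras.optimizers.Adam(1e-3)


def d_step(real):
    """Update only the discriminator: real samples -> 1, generated -> 0."""
    noise = tf.random.normal((real.shape[0], latent_dim))
    fake = generator(noise, training=False)  # generator is not trained here
    with tf.GradientTape() as tape:
        real_logits = discriminator(real, training=True)
        fake_logits = discriminator(fake, training=True)
        loss = 0.5 * (bce(tf.ones_like(real_logits), real_logits)
                      + bce(tf.zeros_like(fake_logits), fake_logits))
    grads = tape.gradient(loss, discriminator.trainable_variables)
    d_opt.apply_gradients(zip(grads, discriminator.trainable_variables))
    return float(loss)


def g_step(batch_size):
    """Update only the generator so the discriminator labels its samples as real."""
    noise = tf.random.normal((batch_size, latent_dim))
    with tf.GradientTape() as tape:
        fake_logits = discriminator(generator(noise, training=True), training=True)
        loss = bce(tf.ones_like(fake_logits), fake_logits)
    # Gradients flow back through the discriminator, but only the
    # generator's variables are handed to the optimizer.
    grads = tape.gradient(loss, generator.trainable_variables)
    g_opt.apply_gradients(zip(grads, generator.trainable_variables))
    return float(loss)


real = tf.random.normal((32, 4))  # placeholder batch of "real" samples
for _ in range(3):
    d_loss = d_step(real)
    g_loss = g_step(32)
```

The discriminator here outputs logits and the loss uses from_logits=True, which differs from the repo's sigmoid output with 'binary_crossentropy'; either works, the logits form is just numerically a bit more stable.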

Upvotes: 7
