Noha Atef

Reputation: 11

Why is the discriminator set as non-trainable in the CycleGAN composite model?

I am training a CycleGAN for a face aging task, following the DigitalSreeni CycleGAN tutorial. In the composite model he sets the generator of interest as trainable, while the corresponding discriminator and the other generator are set as non-trainable. Why set them as non-trainable? Aren't we trying to train the model? Aren't the generator and the discriminator supposed to compete against each other, with the discriminator giving feedback to the generator so that the weights get updated? How will the discriminator learn if it is set as non-trainable? Below is the composite model where this happens; the passed generator is either the first or the second one, depending on whether we are training the AtoB generator or the BtoA one (a sketch of how I build the two composites is shown after the function).

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.optimizers import Adam

def define_composite_model(g_model_1, d_model, g_model_2, image_shape):
    # Make the generator of interest trainable, as these are the weights we will update,
    # while keeping the other models constant.
    # Remember that this same function is used to train both generators, one at a time.
    g_model_1.trainable = True
    # mark the discriminator and the second generator as non-trainable
    d_model.trainable = False
    g_model_2.trainable = False
    
    # adversarial loss
    input_gen = Input(shape=image_shape)
    gen1_out = g_model_1(input_gen)
    output_d = d_model(gen1_out)
    # identity loss
    input_id = Input(shape=image_shape)
    output_id = g_model_1(input_id)
    # cycle loss - forward
    output_f = g_model_2(gen1_out)
    # cycle loss - backward
    gen2_out = g_model_2(input_id)
    output_b = g_model_1(gen2_out)
    
    # define model graph
    model = Model([input_gen, input_id], [output_d, output_id, output_f, output_b])
    
    # define the optimizer
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    # compile model with weighting of least squares loss and L1 loss
    model.compile(loss=['mse', 'mae', 'mae', 'mae'], 
                loss_weights=[1, 5, 10, 10], optimizer=opt)
    return model
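
For context, this is roughly how the two composite models are built in my script. It is a sketch following the tutorial, not copied verbatim; names like define_generator and define_discriminator are the tutorial's function names as I remember them.

# generators and discriminators for the two domains
g_model_AtoB = define_generator(image_shape)    # generator A -> B (e.g. young -> old)
g_model_BtoA = define_generator(image_shape)    # generator B -> A (e.g. old -> young)
d_model_A = define_discriminator(image_shape)   # judges real vs. fake images of domain A
d_model_B = define_discriminator(image_shape)   # judges real vs. fake images of domain B

# composite used to update the A->B generator: only g_model_AtoB is trainable inside it
c_model_AtoB = define_composite_model(g_model_AtoB, d_model_B, g_model_BtoA, image_shape)
# composite used to update the B->A generator: only g_model_BtoA is trainable inside it
c_model_BtoA = define_composite_model(g_model_BtoA, d_model_A, g_model_AtoB, image_shape)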

I also tried a pix2pix model where both the generator and the discriminator were trainable, so what is the difference? These are the discriminator and generator losses in the first iteration: Iteration>1, dA[5.195,3.602] dB[2.547,2.116] g[18.355,19.853]. Is that good or bad? Also, training is really slow, more than 10 seconds per iteration; should it be that slow? (The loop that prints these numbers is sketched below.)
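In case it helps, those numbers are printed from a training step roughly like this. This is a sketch, not copied verbatim from my script; helper names such as generate_real_samples, generate_fake_samples and update_image_pool, and variables like trainA, trainB, poolA, poolB, n_batch and n_patch, are assumed from the tutorial.

# one training iteration (sketch): each discriminator is updated on its own compiled model,
# while the composite models (with the discriminator frozen inside) update only the generators
X_realA, y_realA = generate_real_samples(trainA, n_batch, n_patch)
X_realB, y_realB = generate_real_samples(trainB, n_batch, n_patch)
X_fakeA, y_fakeA = generate_fake_samples(g_model_BtoA, X_realB, n_patch)
X_fakeB, y_fakeB = generate_fake_samples(g_model_AtoB, X_realA, n_patch)
X_fakeA = update_image_pool(poolA, X_fakeA)
X_fakeB = update_image_pool(poolB, X_fakeB)

# update generator B->A through its composite (discriminator A is frozen inside it)
g_loss2, _, _, _, _ = c_model_BtoA.train_on_batch([X_realB, X_realA],
                                                  [y_realA, X_realA, X_realB, X_realA])
# update discriminator A on real and fake samples (its own, fully trainable model)
dA_loss1 = d_model_A.train_on_batch(X_realA, y_realA)
dA_loss2 = d_model_A.train_on_batch(X_fakeA, y_fakeA)

# update generator A->B through its composite (discriminator B is frozen inside it)
g_loss1, _, _, _, _ = c_model_AtoB.train_on_batch([X_realA, X_realB],
                                                  [y_realB, X_realB, X_realA, X_realB])
# update discriminator B on real and fake samples
dB_loss1 = d_model_B.train_on_batch(X_realB, y_realB)
dB_loss2 = d_model_B.train_on_batch(X_fakeB, y_fakeB)

print('Iteration>%d, dA[%.3f,%.3f] dB[%.3f,%.3f] g[%.3f,%.3f]' %
      (i + 1, dA_loss1, dA_loss2, dB_loss1, dB_loss2, g_loss1, g_loss2))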

Upvotes: 1

Views: 47

Answers (0)
