user3443033

Reputation: 779

Keras multiple input, output, loss model

I am working on a super-resolution GAN and have some doubts about code I found on GitHub. In particular, the model has multiple inputs and multiple outputs, and I have two different loss functions.

In the following code, will the mse loss be applied to img_hr and fake_features?

# Imports used by this fragment (it comes from the model's __init__ method)
from keras.layers import Input
from keras.models import Model

# Build and compile the discriminator
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='mse',
                           optimizer=optimizer,
                           metrics=['accuracy'])

# Build the generator
self.generator = self.build_generator()

# High res. and low res. images
img_hr = Input(shape=self.hr_shape)
img_lr = Input(shape=self.lr_shape)

# Generate high res. version from low res.
fake_hr = self.generator(img_lr)

# Extract image features of the generated img
fake_features = self.vgg(fake_hr)

# For the combined model we will only train the generator
self.discriminator.trainable = False

# Discriminator determines validity of generated high res. images
validity = self.discriminator(fake_hr)

# Two inputs, two outputs: one loss (and one loss weight) per output, matched by position
self.combined = Model([img_lr, img_hr], [validity, fake_features])
self.combined.compile(loss=['binary_crossentropy', 'mse'],
                      loss_weights=[1e-3, 1],
                      optimizer=optimizer)

Upvotes: 1

Views: 1207

Answers (2)

ben

Reputation: 1390

In neural networks, a loss is applied to the outputs of a network to measure "how wrong is this output?", so that you can minimize that value via gradient descent and backpropagation. Following this intuition, the losses in Keras are given as a list with the same length as the outputs of your model, and each loss is applied to the output with the same index.
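To make the "same index" rule concrete, here is a minimal pure-Python sketch of the pairing. It is not Keras's implementation; the loss functions are hand-written stand-ins and the numbers are made-up toy values for the two outputs [validity, fake_features].

```python
import math

def mse(y_true, y_pred):
    """Mean squared error over a flat list of numbers."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped away from 0 and 1."""
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1 - eps)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(y_true)

def per_output_losses(losses, y_trues, y_preds):
    """Pair the i-th loss with the i-th output and the i-th target."""
    if not (len(losses) == len(y_trues) == len(y_preds)):
        raise ValueError("need exactly one loss per model output")
    return [loss(t, p) for loss, t, p in zip(losses, y_trues, y_preds)]

# Toy values standing in for the outputs [validity, fake_features]
y_preds = [[0.9, 0.8], [0.5, 1.5, 2.0]]
y_trues = [[1.0, 1.0], [0.0, 1.0, 2.0]]

# Index 0 -> binary_crossentropy on validity, index 1 -> mse on fake_features
per_output = per_output_losses([binary_crossentropy, mse], y_trues, y_preds)
print(per_output)
```

Swapping the order of the loss list would silently apply binary cross-entropy to the features, which is why the order has to match the order of the outputs passed to Model(...).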

self.combined = Model([img_lr, img_hr], [validity, fake_features])

This gives you a model with 2 inputs (img_lr, img_hr) and 2 outputs (validity, fake_features). So combined.compile(loss=['binary_crossentropy', 'mse'], ...) uses binary cross-entropy loss for validity and mean squared error for fake_features.
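The loss_weights=[1e-3, 1] argument in the question follows the same index rule: Keras documents the minimized value as the weighted sum of the individual per-output losses. A minimal sketch of that arithmetic (ignoring any regularization terms), with made-up per-output loss values:

```python
def total_loss(per_output_losses, loss_weights):
    """Weighted sum of per-output losses, one weight per output (same order)."""
    if len(per_output_losses) != len(loss_weights):
        raise ValueError("need exactly one weight per model output")
    return sum(w * l for w, l in zip(loss_weights, per_output_losses))

# Made-up values: [binary_crossentropy on validity, mse on fake_features]
bce_value, mse_value = 0.7, 0.25
loss_weights = [1e-3, 1]  # as in the question's compile() call

combined = total_loss([bce_value, mse_value], loss_weights)
print(combined)
```

With these weights the adversarial (validity) term contributes only 0.0007 here, so the feature mse dominates the generator's training signal.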

Upvotes: 0

Manoj Mohan

Reputation: 6034

In the following code, will the mse loss be applied to img_hr and fake_features?

From the documentation (https://keras.io/models/model/#compile):

"If the model has multiple outputs, you can use a different loss on each output by passing a dictionary or a list of losses."
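With the dictionary form, losses are matched by output name instead of by position. The sketch below only mimics that lookup in plain Python; the output names "discriminator" and "vgg" are hypothetical (in Keras the names come from the layers or nested models that produce each output, e.g. check model.output_names in Keras 2.x before relying on them).

```python
def losses_from_dict(loss_dict, output_names):
    """Resolve a {output_name: loss} mapping into a list ordered like the outputs."""
    missing = [name for name in output_names if name not in loss_dict]
    if missing:
        raise KeyError(f"no loss given for outputs: {missing}")
    return [loss_dict[name] for name in output_names]

# Hypothetical output names, in the order the outputs were passed to Model(...)
output_names = ["discriminator", "vgg"]

# Dict order does not matter; the lookup is by name
loss_list = losses_from_dict({"vgg": "mse", "discriminator": "binary_crossentropy"},
                             output_names)
print(loss_list)
```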

In this case, the mse loss will be computed between fake_features and the corresponding y_true target passed as part of self.combined.fit().
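Note that in the code shown, img_hr is an input of the combined model but neither output is computed from it: both validity and fake_features come from fake_hr. So the mse never sees img_hr directly; the high-res image can only enter through the y_true you pass at training time. Assuming (this is not shown in the question) that the training loop passes the VGG features of the real high-res image as that target, the loss works like this toy sketch, where generator and vgg_features are hypothetical stand-ins for self.generator and self.vgg:

```python
def mse(a, b):
    """Mean squared error between two equal-length lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def generator(img_lr):
    """Hypothetical stand-in for self.generator: naive 2x upsampling by repetition."""
    return [p for p in img_lr for _ in range(2)]

def vgg_features(image):
    """Hypothetical stand-in for self.vgg: toy 'features' = sums of neighbouring pixels."""
    return [image[i] + image[i + 1] for i in range(len(image) - 1)]

img_lr = [0.1, 0.5]
img_hr = [0.1, 0.2, 0.5, 0.6]

fake_features = vgg_features(generator(img_lr))  # model output (y_pred)
real_features = vgg_features(img_hr)             # assumed y_true passed to fit()

# The mse compares features of the generated image with features of the real one
content_loss = mse(real_features, fake_features)
print(content_loss)
```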

Upvotes: 1
