djvaroli

Reputation: 1423

Tensorflow 1.15 / Keras 2.3.1 Model.train_on_batch() returns more values than there are outputs/loss functions

I am trying to train a model that has more than one output and, as a result, more than one loss function attached to it when I compile it.

I haven't done anything like this before (at least not from scratch).

Here's some code I am using to figure out how this works.

import numpy as np
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

batch_size = 50
input_size = 10
output_size = 1  # width of each output head (any value works here)

# One shared hidden layer feeding two separate output heads
i = Input(shape=(input_size,))
x = Dense(100)(i)
x_1 = Dense(output_size)(x)
x_2 = Dense(output_size)(x)

model = Model(i, [x_1, x_2])

# One loss per output, listed in the same order as the model's outputs
model.compile(optimizer='adam', loss=["mse", "mse"])

# Data creation
x = np.random.random_sample([batch_size, input_size]).astype('float32')
y = np.random.random_sample([batch_size, output_size]).astype('float32')

# The same target is used for both outputs
loss = model.train_on_batch(x, [y, y])

print(loss)  # sample output [0.8311912, 0.3519104, 0.47928077]

I would expect the variable loss to have two entries (one for each loss function); however, I get back three. I thought one of them might be a weighted average, but that does not seem to be the case.

Could anyone explain how passing in multiple loss functions works? I am obviously misunderstanding something.

Upvotes: 3

Views: 1532

Answers (2)

Dr. Snoopy

Reputation: 56417

Your assumption that there should be two losses is incorrect. You have a model with two outputs and you specified one loss for each output, but the model has to be trained on a single loss, so Keras trains it on a new loss that is the (weighted) sum of the per-output losses.

You can control how these losses are mixed using the loss_weights parameter in model.compile. By default every weight is 1.0, so the total is simply the sum of the per-output losses.

So train_on_batch returns three values: the total loss, the MSE for output one, and the MSE for output two. That is why you get three values.
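A minimal NumPy sketch of that combination, assuming plain MSE for both outputs and no regularization losses (Keras would also add those to the total). `combined_loss` is a hypothetical helper, not a Keras API. It returns values in the same order as train_on_batch, and, as I understand tf.keras behaves, applies the weights only to the total while reporting each output's loss unweighted.

```python
import numpy as np


def mse(y_true, y_pred):
    # Mean squared error over every element in the batch
    return float(np.mean((y_true - y_pred) ** 2))


def combined_loss(y_true_list, y_pred_list, loss_weights=None):
    # Hypothetical helper (not a Keras API): reduce per-output MSEs to a
    # single training loss the way a multi-output model does.
    if loss_weights is None:
        # Default: every output contributes with weight 1.0
        loss_weights = [1.0] * len(y_true_list)
    per_output = [mse(t, p) for t, p in zip(y_true_list, y_pred_list)]
    # The weights only affect the total; per-output losses stay unweighted
    total = sum(w * loss for w, loss in zip(loss_weights, per_output))
    return [total] + per_output


# Usage: two heads with constant errors of 1 and 2
y = np.zeros((4, 2))
pred_1 = np.ones((4, 2))       # MSE = 1.0
pred_2 = np.full((4, 2), 2.0)  # MSE = 4.0
print(combined_loss([y, y], [pred_1, pred_2]))  # [5.0, 1.0, 4.0]
print(combined_loss([y, y], [pred_1, pred_2], loss_weights=[0.5, 2.0]))  # [8.5, 1.0, 4.0]
```

With loss_weights=[0.5, 2.0] the total becomes 0.5 * 1.0 + 2.0 * 4.0 = 8.5, while the two per-output entries are unchanged.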
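To see which value is which, tf.keras exposes model.metrics_names, which lists labels in the same order as the values train_on_batch returns, so dict(zip(model.metrics_names, loss)) labels them. The sketch below hardcodes a plausible list instead of building the model; the real names depend on the output layer names, so "dense_1_loss" and "dense_2_loss" are assumptions.

```python
# Stand-in for model.metrics_names: the real names come from the model's
# output layers, so "dense_1"/"dense_2" are assumptions for illustration.
metrics_names = ["loss", "dense_1_loss", "dense_2_loss"]
result = [0.8311912, 0.3519104, 0.47928077]  # sample output from the question

# Pair each returned value with its label
labelled = dict(zip(metrics_names, result))
for name, value in labelled.items():
    print(f"{name}: {value:.4f}")
```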

Upvotes: 1

josephkibe

Reputation: 1343

I believe the three values are the sum of all the losses, followed by the individual loss for each output.

You can check this against the sample output you printed:

0.3519104 + 0.47928077 = 0.83119117 ≈ 0.8311912
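The same check in a few lines of Python, using the numbers printed in the question. The 1e-6 tolerance is an arbitrary choice to absorb float32 rounding; the check also rules out the equal-weight average the question suspected.

```python
# Sample output from the question: [total, output-1 MSE, output-2 MSE]
values = [0.8311912, 0.3519104, 0.47928077]
total, per_output = values[0], values[1:]

summed = sum(per_output)
averaged = summed / len(per_output)  # what an equal-weight average would give

# The first value matches the sum (up to float32 rounding), not the average
print(abs(total - summed) < 1e-6)    # True
print(abs(total - averaged) < 1e-6)  # False
```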

Upvotes: 1
