Sheetal Nagar

Reputation: 65

Does the following code give recall for multiclass classification in Keras?

Does the following code give recall for multiclass classification in Keras? Even though I am not passing y_true and y_pred when calling the recall function in model.compile, it still showed me a recall result.

from tensorflow.keras import backend as K

def recall(y_true, y_pred):
    y_true = K.ones_like(y_true)  # note: this replaces the ground truth with all ones
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    all_positives = K.sum(K.round(K.clip(y_true, 0, 1)))

    recall = true_positives / (all_positives + K.epsilon())
    return recall

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=[recall])

Upvotes: 2

Views: 150

Answers (1)

ClaudiaR

Reputation: 3424

Yes, it works: recall is called multiple times under the hood inside model.fit(), which supplies those y_true and y_pred values for you.

It works in a way similar (though more complex and optimized) to this:

import tensorflow as tf

# `model` and `dataset` are assumed to be an already-built Keras model
# and a tf.data.Dataset yielding (features, labels) batches.
accuracy = tf.keras.metrics.CategoricalAccuracy()
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

# Iterate over the batches of a dataset.
for step, (x, y) in enumerate(dataset):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        # Compute the loss value for this batch.
        loss_value = loss_fn(y, logits)

    # Update the state of the `accuracy` metric.
    accuracy.update_state(y, logits)

    # Update the weights of the model to minimize the loss value.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

    # Log the running accuracy every 100 steps.
    if step % 100 == 0:
        print('Step:', step)
        print('Total running accuracy so far: %.3f' % accuracy.result())

This is called a gradient tape (tf.GradientTape), and it can be used to write a customized training loop. Basically, it exposes the gradients computed on the trainable tensors of your model, letting you update the weights manually, which makes it really useful for customization. All of this is also done automatically inside model.fit(); you don't need to write it yourself, it is just to explain how things work.

As you can see, at every batch of the dataset the predictions, i.e. the logits, are computed. The logits and the ground truth (the correct y values) are then passed as arguments to accuracy.update_state, just as is done, without you seeing it, inside model.fit(). Even the order is the same: y_true and y are both the ground truth, and y_pred and logits are both the predictions.
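
To make the mapping concrete, here is a minimal, self-contained sketch of that call using your recall function. The tensors y and logits are made up for illustration (their shapes and values are just an assumption); the real call happens inside model.fit():

import tensorflow as tf
from tensorflow.keras import backend as K

# The metric from the question, repeated so this sketch runs on its own.
def recall(y_true, y_pred):
    y_true = K.ones_like(y_true)
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    all_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    return true_positives / (all_positives + K.epsilon())

# Made-up batch: 4 samples, 3 classes (purely illustrative values).
y = tf.constant([[0., 0., 1.],
                 [0., 1., 0.],
                 [1., 0., 0.],
                 [0., 0., 1.]])
logits = tf.constant([[0.1, 0.2, 0.7],
                      [0.1, 0.8, 0.1],
                      [0.6, 0.2, 0.2],
                      [0.3, 0.3, 0.4]])

# This mirrors the call Keras makes for each batch: the ground truth is
# passed as y_true and the model output as y_pred.
print(float(recall(y, logits)))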

I hope this has made things clearer.

Upvotes: 1
