Bryan

Reputation: 1519

How to fix Accuracy being equal to F1 in PyTorch Lightning for binary classification?

I understand that for multi-class classification, micro-averaged F1 is the same as Accuracy. I am testing binary classification in PyTorch Lightning, but I always get identical values for F1 and Accuracy.

For more detail, I shared my code as a GitHub Gist, where I use the MUTAG dataset. Below are the important parts I would like to bring up for discussion.

The function where I compute Accuracy and F1 (lines #28-40 of the Gist):

def evaluate(self, batch, stage=None):
    y_hat = self(batch.x, batch.edge_index, batch.batch)
    loss = self.criterion(y_hat, batch.y)
    # take the most likely class as the hard prediction
    preds = torch.argmax(y_hat.softmax(dim=1), dim=1)
    acc = accuracy(preds, batch.y)
    f1_score = f1(preds, batch.y)

    if stage:
        self.log(f"{stage}_loss", loss, on_step=True, on_epoch=True, logger=True)
        self.log(f"{stage}_acc", acc, on_step=True, on_epoch=True, logger=True)
        self.log(f"{stage}_f1", f1_score, on_step=True, on_epoch=True, logger=True)

    return loss

To inspect, I set a breakpoint at line #35 and got acc=0.5 and f1_score=0.5, while the predictions and labels respectively are

preds = tensor([1, 1, 1, 0, 1, 1, 1, 1, 0, 0])
batch.y = tensor([1, 0, 1, 1, 0, 1, 0, 1, 1, 0])
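
For reference, tallying the binary confusion matrix for these values by hand (a quick sanity check, not part of the Gist) shows that binary F1 and Accuracy should indeed differ here:

# hand-count the confusion-matrix cells for the batch above
preds  = [1, 1, 1, 0, 1, 1, 1, 1, 0, 0]
labels = [1, 0, 1, 1, 0, 1, 0, 1, 1, 0]

tp = sum(p == 1 and t == 1 for p, t in zip(preds, labels))  # 4
fp = sum(p == 1 and t == 0 for p, t in zip(preds, labels))  # 3
fn = sum(p == 0 and t == 1 for p, t in zip(preds, labels))  # 2
tn = sum(p == 0 and t == 0 for p, t in zip(preds, labels))  # 1

precision = tp / (tp + fp)                            # 4/7
recall    = tp / (tp + fn)                            # 4/6
print(2 * precision * recall / (precision + recall))  # binary F1 = 8/13 ≈ 0.6154
print((tp + tn) / len(labels))                        # accuracy = 0.5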

Using these values, I ran a notebook to double-check with scikit-learn:

from sklearn.metrics import f1_score, accuracy_score

y_hat = [1, 1, 1, 0, 1, 1, 1, 1, 0, 0]
y = [1, 0, 1, 1, 0, 1, 0, 1, 1, 0]

f1_score(y, y_hat, average='binary')  # 0.6153846153846153 (sklearn expects y_true first)
accuracy_score(y, y_hat)  # 0.5

I obtained a different result from my evaluation code. I also verified with torchmetrics directly and, interestingly, got the expected result:

from torchmetrics.functional import accuracy, f1
import torch

# note: torch.Tensor(...) produces float predictions here
f1(torch.Tensor(y_hat), torch.LongTensor(y))        # tensor(0.6154)
accuracy(torch.Tensor(y_hat), torch.LongTensor(y))  # tensor(0.5000)

I guess PyTorch Lightning somehow treats my calculation as a multi-class task. My question is: how do I correct this behavior?

Upvotes: 0

Views: 950

Answers (1)

Rafay Khan

Reputation: 1240

You can pass multiclass=False if your dataset is binary.

This will give you a result that matches scikit-learn's F1 score with average="binary" (the default).

Setting multiclass=False treats the inputs as binary, which is the same as converting the predictions to float beforehand.
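
Applied to the evaluate method from the question, the change would look roughly like this (a sketch, assuming a torchmetrics version whose functional f1 accepts the multiclass argument):

# inside evaluate(): keep accuracy as-is, but force binary F1
acc = accuracy(preds, batch.y)
f1_score = f1(preds, batch.y, multiclass=False)  # binary F1 instead of micro-averaged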

Sklearn results

from sklearn.metrics import f1_score, accuracy_score

y_hat = [1, 1, 1, 0, 1, 1, 1, 1, 0, 0]
y =     [1, 0, 1, 1, 0, 1, 0, 1, 1, 0]

print("Binary f1:", f1_score(y, y_hat, average="binary"))  # default
print("Micro f1:", f1_score(y, y_hat, average="micro"))    # this is the same as accuracy
print("accuracy_score:", accuracy_score(y, y_hat))

>>> Binary f1: 0.6153846153846153
>>> Micro f1: 0.5
>>> accuracy_score: 0.5

Pytorch-Lightning Results

import torchmetrics.functional as F
import torch

y_hat = [1, 1, 1, 0, 1, 1, 1, 1, 0, 0]
y =     [1, 0, 1, 1, 0, 1, 0, 1, 1, 0]

# torchmetrics expects (preds, target)
print("Non-Multiclass f1:", F.f1_score(torch.tensor(y_hat), torch.tensor(y), multiclass=False))
print("Multiclass f1:", F.f1_score(torch.tensor(y_hat), torch.tensor(y)))  # same as accuracy
print("accuracy_score:", F.accuracy(torch.tensor(y_hat), torch.tensor(y)))

>>> Non-Multiclass f1: tensor(0.6154)
>>> Multiclass f1: tensor(0.5000)
>>> accuracy_score: tensor(0.5000)
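
As a cross-check of the float-conversion point above: passing the predictions as a float tensor (which is what the question's torch.Tensor(y_hat) call produced) gives the same binary result without the flag:

print(F.f1_score(torch.tensor(y_hat).float(), torch.tensor(y)))  # tensor(0.6154)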

Upvotes: 1
