I understand that for multiclass classification, micro-averaged F1 is the same as accuracy. I am testing a binary classifier in PyTorch Lightning, but I always get identical F1 and accuracy values.
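For reference, that identity follows from how micro averaging sums counts over classes. In single-label classification, every wrong prediction is one false positive (for the predicted class) and one false negative (for the true class). With N samples, C classes and N_correct correct predictions:

```latex
\begin{aligned}
\sum_{c=1}^{C} \mathrm{TP}_c &= N_{\text{correct}}, \qquad
\sum_{c=1}^{C} \mathrm{FP}_c = \sum_{c=1}^{C} \mathrm{FN}_c = N - N_{\text{correct}} \\
P_{\text{micro}} &= \frac{\sum_c \mathrm{TP}_c}{\sum_c \mathrm{TP}_c + \sum_c \mathrm{FP}_c}
  = \frac{N_{\text{correct}}}{N} = R_{\text{micro}} \\
F_{1,\text{micro}} &= \frac{2 P_{\text{micro}} R_{\text{micro}}}{P_{\text{micro}} + R_{\text{micro}}}
  = \frac{N_{\text{correct}}}{N} = \text{accuracy}
\end{aligned}
```

The identity needs both classes to be counted; binary F1 counts only the positive class, so it generally differs from accuracy.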
For more detail, I shared my full code in a Gist, which uses the MUTAG dataset. Below are the parts I would like to discuss.
The function where I compute accuracy and F1 (lines #28-40 of the Gist):
import torch
from torchmetrics.functional import accuracy, f1

# Excerpt: a method of my LightningModule
def evaluate(self, batch, stage=None):
    y_hat = self(batch.x, batch.edge_index, batch.batch)
    loss = self.criterion(y_hat, batch.y)
    # Integer class predictions (0 or 1) from the logits
    preds = torch.argmax(y_hat.softmax(dim=1), dim=1)
    acc = accuracy(preds, batch.y)
    f1_score = f1(preds, batch.y)
    if stage:
        self.log(f"{stage}_loss", loss, on_step=True, on_epoch=True, logger=True)
        self.log(f"{stage}_acc", acc, on_step=True, on_epoch=True, logger=True)
        self.log(f"{stage}_f1", f1_score, on_step=True, on_epoch=True, logger=True)
    return loss
To inspect, I set a breakpoint at line #35 and got acc=0.5 and f1_score=0.5, while the predictions and labels were, respectively:
preds = tensor([1, 1, 1, 0, 1, 1, 1, 1, 0, 0])
batch.y = tensor([1, 0, 1, 1, 0, 1, 0, 1, 1, 0])
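As a hand check of these two tensors, treating class 1 as the positive class: the batch contains 4 true positives, 3 false positives, 2 false negatives and 1 true negative. A minimal pure-Python sketch (no torch needed; the variable names are illustrative) counts them and derives both metrics, showing that accuracy and binary F1 should differ here:

```python
# Predictions and labels copied from the batch above
preds = [1, 1, 1, 0, 1, 1, 1, 1, 0, 0]
target = [1, 0, 1, 1, 0, 1, 0, 1, 1, 0]

pairs = list(zip(preds, target))
# Binary confusion counts with class 1 as the positive class
tp = sum(p == 1 and t == 1 for p, t in pairs)
fp = sum(p == 1 and t == 0 for p, t in pairs)
fn = sum(p == 0 and t == 1 for p, t in pairs)
tn = sum(p == 0 and t == 0 for p, t in pairs)

accuracy = (tp + tn) / len(pairs)          # (4 + 1) / 10
binary_f1 = 2 * tp / (2 * tp + fp + fn)    # 8 / 13

print(tp, fp, fn, tn)                  # 4 3 2 1
print(accuracy, round(binary_f1, 4))   # 0.5 0.6154
```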
Using these values, I ran a notebook to double-check with scikit-learn:
from sklearn.metrics import f1_score, accuracy_score
y_hat = [1, 1, 1, 0, 1, 1, 1, 1, 0, 0]
y = [1, 0, 1, 1, 0, 1, 0, 1, 1, 0]
# scikit-learn expects (y_true, y_pred)
f1_score(y, y_hat, average='binary')  # got 0.6153846153846153
accuracy_score(y, y_hat)  # 0.5
This F1 differs from what my evaluation code reports. I also checked again with torchmetrics directly and, interestingly, got the correct result:
from torchmetrics.functional import accuracy, f1
import torch
# y_hat and y as defined in the scikit-learn check above;
# torch.Tensor(...) builds a float32 tensor, torch.LongTensor(...) an int64 one
f1(torch.Tensor(y_hat), torch.LongTensor(y))  # tensor(0.6154)
accuracy(torch.Tensor(y_hat), torch.LongTensor(y))  # tensor(0.5000)
I suspect that, somehow, the metric calls inside my PyTorch Lightning module treat the calculation as a multiclass task. How can I correct this behavior?
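The multiclass suspicion matches the observed 0.5: if both classes are scored one-vs-rest and their counts are summed before computing F1 (micro averaging), the result collapses to accuracy. The pure-Python sketch below uses the same batch; it illustrates micro averaging in general and is not torchmetrics' actual implementation.

```python
preds = [1, 1, 1, 0, 1, 1, 1, 1, 0, 0]
target = [1, 0, 1, 1, 0, 1, 0, 1, 1, 0]
classes = sorted(set(preds) | set(target))  # [0, 1]

# Micro averaging: sum one-vs-rest counts over every class, then compute F1 once
tp = sum(sum(p == c and t == c for p, t in zip(preds, target)) for c in classes)
fp = sum(sum(p == c and t != c for p, t in zip(preds, target)) for c in classes)
fn = sum(sum(p != c and t == c for p, t in zip(preds, target)) for c in classes)

micro_f1 = 2 * tp / (2 * tp + fp + fn)
accuracy = sum(p == t for p, t in zip(preds, target)) / len(target)

print(tp, fp, fn)          # 5 5 5
print(micro_f1, accuracy)  # 0.5 0.5
```

Every misclassified sample adds one false positive and one false negative across the two classes, which is why fp and fn both equal the number of errors.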
Upvotes: 0
Views: 950
Reputation: 1240
You can pass multiclass=False if your dataset is binary.
This gives a result that matches scikit-learn's F1 score with average="binary" (its default).
from sklearn.metrics import f1_score, accuracy_score
y_hat = [1, 1, 1, 0, 1, 1, 1, 1, 0, 0]
y = [1, 0, 1, 1, 0, 1, 0, 1, 1, 0]
print("Binary f1:", f1_score(y, y_hat, average="binary"))  # default
print("Micro f1:", f1_score(y, y_hat, average="micro"))  # this is the same as accuracy
print("accuracy_score:", accuracy_score(y, y_hat))
# Binary f1: 0.6153846153846153
# Micro f1: 0.5
# accuracy_score: 0.5
import torch
import torchmetrics.functional as F
y_hat = [1, 1, 1, 0, 1, 1, 1, 1, 0, 0]
y = [1, 0, 1, 1, 0, 1, 0, 1, 1, 0]
# torchmetrics 0.x functional API; arguments are (preds, target).
# Newer torchmetrics releases take task="binary" instead of multiclass=False (check your version).
print("Non-multiclass f1:", F.f1_score(torch.tensor(y_hat), torch.tensor(y), multiclass=False))
print("Multiclass f1:", F.f1_score(torch.tensor(y_hat), torch.tensor(y)))  # same as accuracy
print("accuracy:", F.accuracy(torch.tensor(y_hat), torch.tensor(y)))
# Non-multiclass f1: tensor(0.6154)
# Multiclass f1: tensor(0.5000)
# accuracy: tensor(0.5000)