Reputation: 4611
I try to train Neural Network model in PyTorch Lightning and training fails on validation step where it executes EarlyStopping callback.
The relevant part of the model is below. See, in particular, validation_step
which must log the metrics necessary for Early stopping.
class DialogActsLightningModel(pl.LightningModule):
def __init__(self, config):
super().__init__()
self.config = config
self.model = ContextAwareDAC(
model_name=self.config['model_name'],
hidden_size=self.config['hidden_size'],
num_classes=self.config['num_classes'],
device=self.config['device']
)
self.tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
def forward(self, batch):
logits = self.model(batch)
return logits
def validation_step(self, batch, batch_idx):
input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
logits = self(batch)
loss = F.cross_entropy(logits, targets)
acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
precision = precision_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
recall = recall_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
return {"val_loss": loss, "val_accuracy": torch.tensor([acc]), "val_f1": torch.tensor([f1]),
"val_precision": torch.tensor([precision]), "val_recall": torch.tensor([recall])}
def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
avg_acc = torch.stack([x['val_accuracy'] for x in outputs]).mean()
avg_f1 = torch.stack([x['val_f1'] for x in outputs]).mean()
avg_precision = torch.stack([x['val_precision'] for x in outputs]).mean()
avg_recall = torch.stack([x['val_recall'] for x in outputs]).mean()
wandb.log({"val_loss": avg_loss, "val_accuracy": avg_acc, "val_f1": avg_f1, "val_precision": avg_precision,
"val_recall": avg_recall})
return {"val_loss": avg_loss, "val_accuracy": avg_acc, "val_f1": avg_f1, "val_precision": avg_precision,
"val_recall": avg_recall}
When I run training in the following way:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
import os
from trainer import DialogActsLightningModel
import wandb
wandb.init()
logger = WandbLogger(
name="model_name",
entity='myname',
save_dir=config["save_dir"],
project=config["project"],
log_model=True,
)
early_stopping = EarlyStopping(
monitor="val_accuracy",
min_delta=0.001,
patience=5,
)
model = DialogActsLightningModel(config=config)
trainer = pl.Trainer(
logger=logger,
gpus=[0],
checkpoint_callback=True,
callbacks=[early_stopping],
default_root_dir=MODELS_DIRECTORY,
max_epochs=config["epochs"],
precision=config["precision"],
limit_train_batches=10, # run for only 10 batches, debug mode
limit_test_batches=10,
limit_val_batches=10
)
trainer.fit(model)
I've got an error, but the model should have calculated and logged the metric "val_accuracy" during validation step.
Epoch 0: 100%
20/20 [00:28<00:00, 1.41s/it, loss=1.95]
/opt/conda/lib/python3.9/site-packages/pytorch_lightning/callbacks/early_stopping.py in _validate_condition_metric(self, logs)
149 if monitor_val is None:
150 if self.strict:
--> 151 raise RuntimeError(error_msg)
152 if self.verbose > 0:
153 rank_zero_warn(error_msg, RuntimeWarning)
RuntimeError: Early stopping conditioned on metric `val_accuracy` which is not available. Pass in or modify your `EarlyStopping` callback to use any of the following: ``
What I am doing wrong? How to fix it?
Upvotes: 1
Views: 8348
Reputation: 193
if you use pytorch-lightning latest version you should want to log the val_accuracy or val_loss while you calling early stopping or similar functions. for more please check out the code below.i think this will definitely helpful for you...
def validation_step(self, batch, batch_idx):
input_ids, attention_mask, targets = batch['input_ids'], batch['attention_mask'], batch['label'].squeeze()
logits = self(batch)
loss = F.cross_entropy(logits, targets)
acc = accuracy_score(targets.cpu(), logits.argmax(dim=1).cpu())
f1 = f1_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
precision = precision_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
recall = recall_score(targets.cpu(), logits.argmax(dim=1).cpu(), average=self.config['average'])
##########################################################################
##########################################################################
self.log("val_accuracy", torch.tensor([acc]) # try this line
##########################################################################
##########################################################################
return {"val_loss": loss, "val_accuracy": torch.tensor([acc]), "val_f1": torch.tensor([f1]),
"val_precision": torch.tensor([precision]), "val_recall": torch.tensor([recall])}
If This Post is Useful Please Up vote
Upvotes: 2