Reputation: 1511
I have a pytorch lightning module like this:
class GraphLevelGNN(pl.LightningModule):
    def __init__(self, **model_kwargs):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()
        self.model = GraphGNNModel(**model_kwargs)
        self.loss_module = nn.BCEWithLogitsLoss()

    def forward(self, data, mode="train"):
        x, edge_index, batch_idx = data.x, data.edge_index, data.batch
        x = self.model(x.cpu(), edge_index.cpu(), batch_idx.cpu())
        x = x.squeeze(dim=-1)

        if self.hparams.c_out == 1:
            preds = (x > 0).float().cpu()
            data.y = data.y.float().cpu()
        else:
            preds = x.argmax(dim=-1).cpu()

        loss = self.loss_module(x.cpu(), data.y.cpu())
        acc = (preds.cpu() == data.y.cpu()).sum().float() / preds.shape[0]
        f1 = f1_score(preds.cpu(), data.y.cpu())  ## change f1/precision and recall was just testing
        precision = precision_score(preds.cpu(), data.y.cpu())
        recall = recall_score(preds.cpu(), data.y.cpu())
        return loss, acc, f1, precision, recall, preds

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(), lr=0.1)  # High lr because of small dataset and small model
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, acc, _, _, _, _ = self.forward(batch, mode="train")
        self.log('train_loss', loss, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc, _, _, _, _ = self.forward(batch, mode="val")
        self.log('val_acc', acc, on_epoch=True, logger=True)
        self.log('val_loss', loss, on_epoch=True, logger=True)

    def test_step(self, batch, batch_idx):
        loss, acc, f1, precision, recall, preds = self.forward(batch, mode="test")
        self.log('test_acc', acc, on_epoch=True, logger=True)
        self.log('test_f1', f1, on_epoch=True, logger=True)
        self.log('test_precision', precision, on_epoch=True, logger=True)
        self.log('test_recall', recall, on_epoch=True, logger=True)
When I run the code, I get a warning:
(train_fn pid=404034) /opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/data.py:99: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 127. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.
I'm not clear which function I'm meant to add the extra self.log argument to.
Upvotes: 2
Views: 3719
Reputation: 40648
This warning means PyTorch Lightning is having trouble inferring the batch size of your training batch, perhaps because the batch contains different element types with varying numbers of elements inside them. To make sure it uses the correct batch_size for loss and metric computation, you can specify it yourself, as the warning message suggests, by setting the batch_size argument on each log call, e.g.
self.log('train_acc', acc, on_epoch=True, logger=True, batch_size=batch_size)
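In your module that would be every self.log call in training_step, validation_step, and test_step. A minimal sketch of training_step, assuming your batches are torch_geometric Batch objects so batch.num_graphs gives the number of graphs per batch:

    def training_step(self, batch, batch_idx):
        loss, acc, _, _, _, _ = self.forward(batch, mode="train")
        # Number of graphs in this batch; adjust if your batches are structured differently.
        batch_size = batch.num_graphs
        self.log('train_loss', loss, on_epoch=True, logger=True, batch_size=batch_size)
        self.log('train_acc', acc, on_epoch=True, logger=True, batch_size=batch_size)
        return loss

The same batch_size argument goes on the self.log calls in validation_step and test_step.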
Upvotes: 4