Slowat_Kela

Reputation: 1511

UserWarning: Trying to infer the `batch_size` from an ambiguous collection

I have a pytorch lightning module like this:

import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from sklearn.metrics import f1_score, precision_score, recall_score

class GraphLevelGNN(pl.LightningModule):

    def __init__(self,**model_kwargs):
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()

        self.model = GraphGNNModel(**model_kwargs)
        self.loss_module = nn.BCEWithLogitsLoss() 
        

    def forward(self, data, mode="train"):
        x, edge_index, batch_idx = data.x, data.edge_index, data.batch
        x = self.model(x.cpu(), edge_index.cpu(), batch_idx.cpu())
        x = x.squeeze(dim=-1)
        
        if self.hparams.c_out == 1:
            preds = (x > 0).float().cpu()
            data.y = data.y.float().cpu()
        else:
            preds = x.argmax(dim=-1).cpu()

        loss = self.loss_module(x.cpu(), data.y.cpu())
        acc = (preds.cpu() == data.y.cpu()).sum().float() / preds.shape[0]
        f1 = f1_score(data.y.cpu(), preds.cpu())  # sklearn expects (y_true, y_pred); change f1/precision and recall was just testing
        precision = precision_score(data.y.cpu(), preds.cpu())
        recall = recall_score(data.y.cpu(), preds.cpu())

        return loss, acc, f1,precision, recall,preds

    def configure_optimizers(self):
        optimizer = optim.SGD(self.parameters(),lr=0.1) # High lr because of small dataset and small model
        return optimizer

    def training_step(self, batch, batch_idx):
        loss, acc, _,_,_,_ = self.forward(batch, mode="train")
        self.log('train_loss', loss,on_epoch=True,logger=True)
        self.log('train_acc', acc,on_epoch=True,logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc, _,_,_,_ = self.forward(batch, mode="val")
        self.log('val_acc', acc,on_epoch=True,logger=True)
        self.log('val_loss', loss,on_epoch=True,logger=True)

    def test_step(self, batch, batch_idx):
        loss,acc, f1,precision, recall,preds = self.forward(batch, mode="test")
        self.log('test_acc', acc,on_epoch=True,logger=True)
        self.log('test_f1', f1,on_epoch=True,logger=True)
        self.log('test_precision', precision,on_epoch=True,logger=True)       
        self.log('test_recall', recall,on_epoch=True,logger=True)

When I run the code, I get a warning:

(train_fn pid=404034) /opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/data.py:99: UserWarning: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 127. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.

I'm not clear on which function I'm meant to add the extra `batch_size` argument to the `self.log` calls in?

Upvotes: 2

Views: 3719

Answers (1)

Ivan

Reputation: 40648

This warning means PyTorch Lightning is having trouble inferring the batch size of your training data, perhaps because the batch contains different element types with varying numbers of elements inside them. To make sure the correct `batch_size` is used for loss and metric computation, you can specify it yourself, as described in the warning message, by setting the `batch_size` argument on each log call, e.g.

self.log('train_acc', acc, on_epoch=True, logger=True, batch_size=batch_size)
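The warning fires for every `self.log` call where Lightning cannot infer a batch size, so the `batch_size` argument belongs on each `log` call in `training_step`, `validation_step`, and `test_step`. A minimal sketch (no torch required) of where it goes, assuming a PyTorch Geometric batch whose `num_graphs` attribute holds the number of graphs in the batch; `LogRecorder` is a hypothetical stand-in for the `LightningModule` that just records what was passed to `self.log`:

```python
from types import SimpleNamespace

class LogRecorder:
    """Hypothetical stand-in for a LightningModule: records self.log kwargs."""
    def __init__(self):
        self.logged = {}

    def log(self, name, value, **kwargs):
        # Record which batch_size each metric was logged with.
        self.logged[name] = kwargs.get('batch_size')

def training_step(self, batch, batch_idx):
    loss, acc = 0.5, 0.9                 # placeholders for the real metrics
    bs = batch.num_graphs                # explicit batch size (PyG attribute)
    self.log('train_loss', loss, on_epoch=True, logger=True, batch_size=bs)
    self.log('train_acc', acc, on_epoch=True, logger=True, batch_size=bs)
    return loss

module = LogRecorder()
training_step(module, SimpleNamespace(num_graphs=127), 0)
print(module.logged)  # {'train_loss': 127, 'train_acc': 127}
```

The same `batch_size=...` should go on the `validation_step` and `test_step` log calls; with plain tensor batches you could use something like `batch.y.size(0)` instead of `num_graphs`.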

Upvotes: 4
