Slowat_Kela

Reputation: 1511

What does the output of PyTorch Lightning's predict mean?

Could someone explain where I can find out what the tensors returned by PyTorch Lightning's prediction mean?

I have this code:

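import pytorch_lightning as pl

# Predicting
path = analysis.best_checkpoint + '/' + "ray_ckpt"

model = GraphLevelGNN.load_from_checkpoint(path)
model.eval()  # optional: trainer.test and trainer.predict switch the model to eval mode themselves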

trainer = pl.Trainer()
test_result = trainer.test(model, graph_test_loader, verbose=False)

print(test_result)
# [{'test_acc': 0.65625, 'test_f1': 0.7678904428904428, 'test_precision': 1.0, 'test_recall': 0.65625}]

predictions = trainer.predict(model, graph_test_loader)

print(predictions)

And it prints:

[(tensor(0.7582), tensor(0.5000), 0.6666666666666666, 1.0, 0.5), (tensor(0.4276), tensor(0.7500), 0.8571428571428571, 1.0, 0.75), (tensor(0.4436), tensor(0.7500), 0.8571428571428571, 1.0, 0.75), (tensor(0.2545), tensor(1.), 1.0, 1.0, 1.0), (tensor(1.0004), tensor(0.3750), 0.5454545454545454, 1.0, 0.375)]

But I can't figure out what these numbers mean. Where can I find more information about them?


Answers (1)

Edwin Cheong

Reputation: 989

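In short, trainer.predict returns a list with one entry per batch of the dataloader, and each entry is whatever your LightningModule's predict_step returns for that batch. If you don't override predict_step, the default one simply calls self(batch), i.e. your forward. So the meaning of those numbers is defined by your own model code, and you can control it by overriding predict_step: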

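import torch
from torch import nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Placeholder network; replace with your own model
        self.base_model = nn.Linear(4, 2)

    def forward(self, inputs):
        return self.base_model(inputs)

    # Override predict_step: whatever it returns is collected, one item per batch
    def predict_step(self, batch, batch_idx):
        return self(batch)

# Toy dataloader: 8 random samples with 4 features each, in batches of 4
data = DataLoader(torch.randn(8, 4), batch_size=4)

model = LitModel()
trainer = pl.Trainer()
predictions = trainer.predict(model, data)  # a list of 2 tensors, each of shape (4, 2)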
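Applied to the output in the question: there are five tuples, so graph_test_loader produced five batches, and each tuple is what GraphLevelGNN returns for one batch. The field names used below (loss, acc, f1, precision, recall) are an assumption guessed from the metric names in test_result, not something the printed output confirms. One quick sanity check: if the third, fourth and fifth values really are F1, precision and recall, then F1 = 2PR/(P+R) must hold for every row.

```python
import math
from collections import namedtuple

# Hypothetical field names: a guess at what the model returns per batch.
# Check them against GraphLevelGNN's own forward / predict_step code.
BatchResult = namedtuple("BatchResult", ["loss", "acc", "f1", "precision", "recall"])

# The printed predictions, with tensor(...) values written as plain floats
raw = [
    (0.7582, 0.5000, 0.6666666666666666, 1.0, 0.5),
    (0.4276, 0.7500, 0.8571428571428571, 1.0, 0.75),
    (0.4436, 0.7500, 0.8571428571428571, 1.0, 0.75),
    (0.2545, 1.0, 1.0, 1.0, 1.0),
    (1.0004, 0.3750, 0.5454545454545454, 1.0, 0.375),
]
results = [BatchResult(*row) for row in raw]


def f1_from(precision, recall):
    """Harmonic mean of precision and recall (the standard F1 formula)."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


for i, r in enumerate(results):
    consistent = math.isclose(r.f1, f1_from(r.precision, r.recall))
    print(f"batch {i}: loss={r.loss:.4f} acc={r.acc:.4f} F1 consistent: {consistent}")
```

If every row reports True, the guess is at least consistent; the definitive answer is still whatever GraphLevelGNN's forward or predict_step returns.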

For a deeper explanation, read this: output prediction of pytorch lightning model
