Reputation: 1511
Could someone explain where I can find out what the values returned by PyTorch Lightning's trainer.predict mean?
I have this code:
# Predicting
path = analysis.best_checkpoint + '/' + "ray_ckpt"
model = GraphLevelGNN.load_from_checkpoint(path)
model.eval()
trainer = pl.Trainer()
test_result = trainer.test(model, graph_test_loader, verbose=False)
print(test_result)
# Output: [{'test_acc': 0.65625, 'test_f1': 0.7678904428904428, 'test_precision': 1.0, 'test_recall': 0.65625}]
predictions = trainer.predict(model, graph_test_loader)
print(predictions)
And it prints:
[(tensor(0.7582), tensor(0.5000), 0.6666666666666666, 1.0, 0.5), (tensor(0.4276), tensor(0.7500), 0.8571428571428571, 1.0, 0.75), (tensor(0.4436), tensor(0.7500), 0.8571428571428571, 1.0, 0.75), (tensor(0.2545), tensor(1.), 1.0, 1.0, 1.0), (tensor(1.0004), tensor(0.3750), 0.5454545454545454, 1.0, 0.375)]
But I can't work out what these numbers mean. Can someone explain how to find out what each value represents?
Upvotes: 0
Views: 744
Reputation: 989
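In short, trainer.predict runs predict_step on every batch and returns a list with one entry per batch. By default predict_step just calls forward, so each entry is whatever your model's forward (or your own predict_step override) returns. Your output has five tuples, one per batch; what the five values inside each tuple mean is decided by the return statement of GraphLevelGNN's forward or predict_step, so that is the place to look. You can define the prediction step yourself: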
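import pytorch_lightning as pl
from torch import nn

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.base_model = nn.Linear(4, 2)  # placeholder network, for illustration only

    def forward(self, inputs):
        return self.base_model(inputs)

    # Override the predict step (the default implementation also just calls forward)
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

model = LitModel()
trainer = pl.Trainer()
predictions = trainer.predict(model, data)  # data is a DataLoader; the result has one entry per batch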
For a deeper explanation, read this: output prediction of pytorch lightning model
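Once you have checked what your model returns, it can help to attach names to the positions in each tuple. Below is a minimal sketch in plain Python. The field names and their order (loss, acc, f1, precision, recall) are only my guess from the printed values, not something Lightning defines, so confirm them against the return statement in GraphLevelGNN. The helper names label_predictions and average_metric are made up for this example.

```python
from statistics import mean

# ASSUMPTION: this order is guessed from the printed output in the question.
# Check the return statement of your own forward/predict_step before relying on it.
FIELDS = ("loss", "acc", "f1", "precision", "recall")


def label_predictions(predictions, fields=FIELDS):
    """Turn each per-batch tuple into a dict of plain Python floats."""
    # float() accepts Python numbers as well as 0-dim torch tensors.
    return [
        {name: float(value) for name, value in zip(fields, batch)}
        for batch in predictions
    ]


def average_metric(labelled, name):
    """Unweighted mean of one metric across batches."""
    return mean(row[name] for row in labelled)


# Values copied from the question, with the tensors written as floats.
predictions = [
    (0.7582, 0.5000, 0.6666666666666666, 1.0, 0.5),
    (0.4276, 0.7500, 0.8571428571428571, 1.0, 0.75),
    (0.4436, 0.7500, 0.8571428571428571, 1.0, 0.75),
    (0.2545, 1.0, 1.0, 1.0, 1.0),
    (1.0004, 0.3750, 0.5454545454545454, 1.0, 0.375),
]

labelled = label_predictions(predictions)
print(labelled[0])
print(average_metric(labelled, "acc"))
```

Note that average_metric is an unweighted mean over batches, so it need not equal the value reported by trainer.test, which may aggregate differently.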
Upvotes: 1