muon

Reputation: 14037

PyTorch prediction stability

This is my predict function. Is there anything wrong with it? The predictions are not stable: every time I run it on the same data, I get different predictions.


import numpy as np
import torch


def predict(model, device, inputs, batch_size=1024):
    model = model.to(device)
    dataset = torch.utils.data.TensorDataset(*inputs)
    loader = torch.utils.data.DataLoader(
                    dataset, 
                    batch_size=batch_size,
                    pin_memory=False
                )

    predictions = []

    for i, batch in enumerate(loader):
        with torch.no_grad():
            pred = model(*(item.to(device) for item in batch))
            pred = pred.detach().cpu().numpy()
        predictions.append(pred)
    return np.concatenate(predictions)

Upvotes: 1

Views: 1403

Answers (2)

Shai

Reputation: 114796

As Usman Ali suggested, you need to set your model to eval mode by calling

model.eval()

before you call your prediction function.

What eval mode does:

Sets the module in evaluation mode.

This has any effect only on certain modules. See documentations of particular modules for details of their behaviors in training/evaluation mode, if they are affected, e.g. Dropout, BatchNorm, etc.

When you finish your prediction and wish to continue training, don't forget to reset your model to training mode by calling

model.train()
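
The two calls above can also be folded into the prediction function itself. The following is a minimal sketch, not the asker's original code: `predict_eval` is a hypothetical name, and saving and restoring `model.training` is just one possible design, assuming the caller wants the previous mode back once inference is done.

```python
import numpy as np
import torch


def predict_eval(model, device, inputs, batch_size=1024):
    """Batched inference with the model temporarily switched to eval mode.

    Hypothetical helper for illustration only.
    """
    was_training = model.training  # remember the mode the caller was in
    model = model.to(device)
    model.eval()  # dropout off, BatchNorm uses its running statistics
    dataset = torch.utils.data.TensorDataset(*inputs)
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    predictions = []
    try:
        with torch.no_grad():  # no gradients are needed for inference
            for batch in loader:
                pred = model(*(item.to(device) for item in batch))
                predictions.append(pred.cpu().numpy())
    finally:
        model.train(was_training)  # restore the original mode
    return np.concatenate(predictions)
```

With a model that contains a dropout layer, two calls such as `predict_eval(model, "cpu", (x,))` on the same tensor `x` should then return identical arrays, and the model is left in whatever mode it was in before the call.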

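Some layers introduce randomness into the forward pass of the net. One example is dropout: in train mode, a dropout layer zeroes each element of its input independently with probability p to improve the model's generalization, so two forward passes over the same data give different outputs. In eval mode, dropout is disabled.
Additionally, BatchNorm (and possibly other adaptive normalization layers) keeps track of the statistics of the data and therefore behaves differently in the two modes: in train mode it normalizes with the statistics of the current batch and updates its running estimates, whereas in eval mode it normalizes with the stored running estimates.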
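
Both effects can be observed on standalone layers. This is a small sketch for illustration, assuming a CPU-only setup; the tensor sizes, the seed, and the variable names are arbitrary choices.

```python
import torch

torch.manual_seed(0)

# Dropout: random in train mode, identity in eval mode.
x = torch.ones(1000)
drop = torch.nn.Dropout(p=0.5)

drop.train()
train_out = drop(x)  # each element is either 0 or scaled by 1 / (1 - p) = 2

drop.eval()
eval_out = drop(x)   # input passes through unchanged

# BatchNorm: running statistics are updated only in train mode.
bn = torch.nn.BatchNorm1d(3)
batch = torch.randn(16, 3) * 5 + 10
mean_before = bn.running_mean.clone()

bn.train()
bn(batch)            # forward pass in train mode updates running_mean
mean_after_train = bn.running_mean.clone()

bn.eval()
out1 = bn(batch)     # eval mode normalizes with the stored statistics
out2 = bn(batch)
mean_after_eval = bn.running_mean.clone()

print("dropout eval output equals input:", torch.equal(eval_out, x))
print("running_mean changed in train mode:", not torch.equal(mean_before, mean_after_train))
print("running_mean changed in eval mode:", not torch.equal(mean_after_train, mean_after_eval))
print("eval outputs repeatable:", torch.equal(out1, out2))
```

A forward pass in train mode therefore changes either the output (dropout) or the layer's internal state (BatchNorm), which is why repeated predictions can differ if `model.eval()` is never called.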

Upvotes: 2

justSomeKid

Reputation: 1

You have defined the function, but you haven't trained the model. The model randomizes predictions before it is trained, which is why yours are inconsistent. If you set up an optimizer with a loss function and train over multiple epochs, the predictions will stabilize. This link may help: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html. Look at sections 3 and 4.

Upvotes: 0
