Arash Rafiee

Reputation: 1

Saving and Resuming Skorch GridSearchCV in the Event of Interruption

I am working with Skorch and using GridSearchCV to perform a grid search. However, I have concerns about what would happen if an unexpected event, such as a system failure or interruption, were to occur during the search. In such cases, I would like to save the model's progression and resume the grid search from where I left off.

I have attempted to utilize the checkpoint callback in Skorch for this purpose. However, I am unsure of the correct approach to properly save and load the model's state in Skorch. Can anyone provide a comprehensive example or guide me on achieving this?

Upvotes: 0

Views: 166

Answers (1)

seralouk

Reputation: 33147

You can use the Checkpoint callback to save the model's state to disk during the grid search and load it again afterwards.

Here is a simple example since you did not provide any code:

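import torch
from skorch.callbacks import Checkpoint
from skorch import NeuralNetClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from torch import nn

# simple neural network classifier; hidden_units is tuned by the grid search
class Net(nn.Module):
    def __init__(self, hidden_units=10):
        super().__init__()
        self.hidden = nn.Linear(20, hidden_units)
        self.out = nn.Linear(hidden_units, 2)

    def forward(self, x):
        x = torch.relu(self.hidden(x))
        # NeuralNetClassifier's default NLLLoss criterion expects probabilities
        return torch.softmax(self.out(x), dim=-1)

# Skorch NeuralNetClassifier; Checkpoint writes the parameters to best_model.pt
# whenever the validation accuracy improves
net = NeuralNetClassifier(
    Net,
    max_epochs=10,
    lr=0.1,
    callbacks=[Checkpoint(monitor='valid_acc_best', f_params='best_model.pt')],
)

# fake data; skorch expects float32 features and int64 labels
X, y = make_classification(n_samples=100, n_features=20, random_state=42)
X, y = X.astype('float32'), y.astype('int64')

# grid search params
param_grid = {
    'lr': [0.1, 0.01, 0.001],
    'module__hidden_units': [10, 20, 30],
}

# Every fit started by the search runs the Checkpoint callback and overwrites
# best_model.pt, so after the final refit the file belongs to the refit model
gs = GridSearchCV(net, param_grid, scoring='accuracy', cv=3, refit=True)
gs.fit(X, y)

To load the saved model, rebuild the module with the selected hidden_units and load the checkpoint:

best_model = Net(hidden_units=gs.best_params_['module__hidden_units'])
best_model.load_state_dict(torch.load('best_model.pt'))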
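
Keep in mind that Checkpoint only stores the weights of an individual fit. GridSearchCV has no built-in way to skip candidates that were already evaluated before an interruption. One possible workaround is to loop over the parameter combinations yourself and write each score to disk as soon as it is available. The following is only a rough sketch of that idea: resumable_grid_search, the JSON results file and the toy evaluate function are made up for illustration and are not part of skorch or scikit-learn. In practice evaluate would probably wrap something like cross_val_score on the net with the given parameters.

```python
import itertools
import json
import os
import tempfile


def resumable_grid_search(param_grid, evaluate, results_path):
    """Score every parameter combination, skipping those already on disk."""
    # Load the scores of a previous (possibly interrupted) run, if any.
    results = {}
    if os.path.exists(results_path):
        with open(results_path) as f:
            results = json.load(f)

    names = sorted(param_grid)
    keys = []
    for values in itertools.product(*(param_grid[n] for n in names)):
        params = dict(zip(names, values))
        key = json.dumps(params, sort_keys=True)  # stable key per candidate
        keys.append(key)
        if key in results:
            continue  # already finished before the interruption
        results[key] = evaluate(params)
        # Write to a temporary file and rename it, so a crash in the middle
        # of writing cannot corrupt the existing results file.
        tmp_path = results_path + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump(results, f)
        os.replace(tmp_path, results_path)

    # Pick the best candidate among those in the current grid.
    best_key = max(keys, key=results.get)
    return json.loads(best_key), results[best_key]


# Hypothetical stand-in for the real scoring step, e.g. the mean of
# cross_val_score(net.set_params(**params), X, y, cv=3).
calls = []

def evaluate(params):
    calls.append(params)
    return -abs(params["lr"] - 0.01) - abs(params["module__hidden_units"] - 20) / 100


param_grid = {"lr": [0.1, 0.01, 0.001], "module__hidden_units": [10, 20, 30]}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "grid_results.json")
    best_params, best_score = resumable_grid_search(param_grid, evaluate, path)
    first_run_calls = len(calls)
    # A second run finds every candidate in the file and evaluates nothing.
    resumable_grid_search(param_grid, evaluate, path)
    second_run_calls = len(calls) - first_run_calls
```

If the process dies halfway through, running the same call again with the same results file continues with the first candidate that has no stored score. Once the best parameters are known, you can fit one final net with them and keep the Checkpoint callback for that fit.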

Upvotes: 0
