I am working with Skorch and using GridSearchCV to perform a grid search. However, I am concerned about what would happen if an unexpected event, such as a system failure or interruption, occurred during the search. In that case, I would like to have saved the progress so far and resume the grid search from where it left off.
I have tried using the Checkpoint callback in Skorch for this, but I am unsure how to properly save and load the model's state in Skorch. Can anyone provide a complete example or point me in the right direction?
Answer:
Have you tried using the Checkpoint callback to save and load the model's state during a grid search?
Here is a simple example, since you did not provide any code:
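```python
import numpy as np
import torch
from torch import nn
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from skorch import NeuralNetClassifier
from skorch.callbacks import Checkpoint


# Simple neural network classifier; hidden_units is an __init__ argument
# so GridSearchCV can tune it through the 'module__hidden_units' key
class Net(nn.Module):
    def __init__(self, hidden_units=10):
        super().__init__()
        self.fc1 = nn.Linear(20, hidden_units)
        self.fc2 = nn.Linear(hidden_units, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        # NeuralNetClassifier's default criterion is NLLLoss, applied to the
        # log of the module output, so the module should return probabilities
        return torch.softmax(self.fc2(x), dim=-1)


# Skorch NeuralNetClassifier; Checkpoint saves the module's parameters
# whenever the validation accuracy reaches a new best
net = NeuralNetClassifier(
    Net,
    max_epochs=10,
    lr=0.1,
    callbacks=[Checkpoint(monitor='valid_acc_best', f_params='best_model.pt')],
)

# Fake data; PyTorch expects float32 features and int64 class labels
X, y = make_classification(n_samples=100, n_features=20, random_state=42)
X = X.astype(np.float32)
y = y.astype(np.int64)

# Grid search params
param_grid = {
    'lr': [0.1, 0.01, 0.001],
    'module__hidden_units': [10, 20, 30],
}

# GridSearchCV clones `net` for every candidate and fold, and each clone's
# Checkpoint writes to the same best_model.pt. With the default sequential
# search and refit=True, the last writer is the refit on the best parameters.
# GridSearchCV itself cannot resume an interrupted search.
gs = GridSearchCV(net, param_grid, scoring='accuracy', cv=3, refit=True)
gs.fit(X, y)
```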
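Because GridSearchCV has no resume option, one workaround (a swapped-in technique, not a Skorch or scikit-learn feature) is to drive the grid yourself with `ParameterGrid` and persist each candidate's score as soon as it finishes. This is a minimal sketch under stated assumptions: `resumable_grid_search` and the JSON progress file are hypothetical names, and all parameter values are assumed to be JSON-serializable (true for the `lr` and `module__hidden_units` grid here).

```python
import json
import os

from sklearn.base import clone
from sklearn.model_selection import ParameterGrid, cross_val_score


def resumable_grid_search(estimator, param_grid, X, y, results_path,
                          cv=3, scoring="accuracy"):
    """Hypothetical helper: a grid search whose progress survives a crash.

    Each finished candidate's mean CV score is saved to `results_path`
    (a JSON file); on a rerun, candidates already in the file are skipped.
    Assumes all parameter values are JSON-serializable.
    """
    # Reload results from an earlier (possibly interrupted) run, if any
    results = []
    if os.path.exists(results_path):
        with open(results_path) as f:
            results = json.load(f)
    # Canonical string form of each finished parameter combination
    done = {json.dumps(r["params"], sort_keys=True) for r in results}

    for params in ParameterGrid(param_grid):
        if json.dumps(params, sort_keys=True) in done:
            continue  # already evaluated in an earlier run
        model = clone(estimator)
        model.set_params(**params)
        scores = cross_val_score(model, X, y, cv=cv, scoring=scoring)
        results.append({"params": params, "mean_score": float(scores.mean())})

        # Write to a temporary file, then atomically swap it in, so a crash
        # during the write cannot corrupt the previously saved progress
        tmp_path = results_path + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump(results, f)
        os.replace(tmp_path, results_path)

    best = max(results, key=lambda r: r["mean_score"])
    return best, results
```

With the net and data from the answer, a call would look like `best, results = resumable_grid_search(net, param_grid, X, y, 'gs_progress.json')`; rerunning that line after a crash skips every candidate already recorded in `gs_progress.json`. The granularity is one candidate: a candidate interrupted midway through its CV folds is recomputed from scratch. Unlike GridSearchCV, this sketch does not refit, so fit a clone of the net with `best['params']` afterwards.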
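To load the saved weights back into the refit model, use:
```python
import torch

# Checkpoint stores the module's state_dict, so it can be loaded back into
# the module of the refit estimator
best_net = gs.best_estimator_
best_net.module_.load_state_dict(torch.load('best_model.pt'))
```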