Reputation: 31
I am attempting to learn skorch by translating a simple PyTorch model that predicts the 2 digits contained in a set of MNIST multi-digit pictures. Each picture contains 2 overlapping digits, which are the output labels (y). I am getting the following error:
ValueError: Stratified CV requires explicitely passing a suitable y
I followed the "MNIST with SciKit-Learn and skorch" notebook and applied the multiple-output fixes outlined in "Multiple return values from forward" by creating a custom get_loss function.
Data dimensions are:
(40000, 1, 42, 28)
(40000, 2)
Code:
import torch
import torch.nn as nn
import torch.nn.functional as F
from skorch import NeuralNetClassifier


class Flatten(nn.Module):
    """A custom layer that views an input as 1D."""

    def forward(self, input):
        return input.view(input.size(0), -1)


class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.pool1 = nn.MaxPool2d((2, 2))
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.pool2 = nn.MaxPool2d((2, 2))
        self.flatten = Flatten()
        self.fc1 = nn.Linear(2880, 64)
        self.drop1 = nn.Dropout(p=0.5)
        # Two separate heads, one per digit
        self.fc2 = nn.Linear(64, 10)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.drop1(x)
        out_first_digit = self.fc2(x)
        out_second_digit = self.fc3(x)
        return out_first_digit, out_second_digit
torch.manual_seed(0)


class CNN_net(NeuralNetClassifier):
    def get_loss(self, y_pred, y_true, *args, **kwargs):
        # y_pred is the (first_digit, second_digit) tuple returned by CNN.forward
        loss1 = F.cross_entropy(y_pred[0], y_true[:, 0])
        loss2 = F.cross_entropy(y_pred[1], y_true[:, 1])
        return 0.5 * (loss1 + loss2)


net = CNN_net(
    CNN,
    max_epochs=5,
    lr=0.1,
    device=device,
)
net.fit(X_train, y_train);
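As a sanity check on the layer sizes, the 2880 input features of fc1 can be reproduced from the standard output-size formula for unpadded, stride-1 convolutions and 2x2 max pooling. The sketch below is plain Python with hypothetical helper names (conv_out, pool_out, flattened_features); it assumes 42 x 28 images, which is the input size that yields 2880 for this architecture.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of one spatial dimension after a convolution."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel):
    """Output length after max pooling with stride equal to the kernel size."""
    return (size - kernel) // kernel + 1


def flattened_features(height, width, channels=64):
    """Features entering fc1 for the conv(3) -> pool(2) -> conv(3) -> pool(2) stack."""
    h, w = conv_out(height, 3), conv_out(width, 3)
    h, w = pool_out(h, 2), pool_out(w, 2)
    h, w = conv_out(h, 3), conv_out(w, 3)
    h, w = pool_out(h, 2), pool_out(w, 2)
    return channels * h * w


print(flattened_features(42, 28))  # 64 * 9 * 5 = 2880
```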
Upvotes: 3
Views: 1048
Reputation: 57729
skorch's NeuralNetClassifier applies a stratified cross-validation split by default to provide you with metrics such as validation accuracy during training. Of course, this requires that your data can be split that way. Since you have two labels for each image, there is no trivial way to do a stratified split (although there are ways).
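One such way is to collapse the two digits into a single class per image and stratify on that combined class. The following is only a minimal pure-Python sketch of the idea, not skorch's own splitting code: combine_digits and stratified_holdout are hypothetical helper names, and the 20% validation fraction is an assumption.

```python
import random
from collections import defaultdict


def combine_digits(y_pairs):
    """Map each (first, second) digit pair to one class id in 0..99."""
    return [10 * first + second for first, second in y_pairs]


def stratified_holdout(labels, valid_fraction=0.2, seed=0):
    """Hold out roughly valid_fraction of the samples of every class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train_idx, valid_idx = [], []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        n_valid = round(len(indices) * valid_fraction)
        valid_idx.extend(indices[:n_valid])
        train_idx.extend(indices[n_valid:])
    return sorted(train_idx), sorted(valid_idx)


# Toy targets: five images showing (3, 7) and five showing (1, 2)
y_pairs = [(3, 7)] * 5 + [(1, 2)] * 5
labels = combine_digits(y_pairs)
train_idx, valid_idx = stratified_holdout(labels)
```

With real data, some of the 100 digit pairs may be too rare to split this way, which is part of why the simpler options below are usually enough.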
Two solutions come to mind:
1. Disable the train/validation split (train_split=None) and lose validation during training
2. Use a non-stratified split: train_split=skorch.dataset.CVSplit(5, stratified=False)
Since I guess that you want validation metrics during training, your final code should look like this:
net = CNN_net(
    CNN,
    max_epochs=5,
    lr=0.1,
    device=device,
    train_split=skorch.dataset.CVSplit(5, stratified=False),
)
net.fit(X_train, y_train);
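To make concrete what a non-stratified split with 5 folds holds out: as far as I know, skorch builds a plain K-fold split and uses only its first fold as the validation set, so one fifth of the data is set aside. The sketch below mimics that index arithmetic in plain Python; first_kfold_split is a hypothetical helper, not part of skorch, and the 40000 samples come from the shapes in the question. In newer skorch releases the same class is, to my knowledge, named ValidSplit.

```python
def first_kfold_split(n_samples, n_splits=5):
    """Return (train_idx, valid_idx) for the first fold of an unshuffled K-fold split."""
    # When n_samples is not divisible by n_splits, the first folds get one extra sample
    fold_size = n_samples // n_splits + (1 if n_samples % n_splits else 0)
    valid_idx = list(range(fold_size))
    train_idx = list(range(fold_size, n_samples))
    return train_idx, valid_idx


train_idx, valid_idx = first_kfold_split(40000, n_splits=5)
print(len(train_idx), len(valid_idx))  # 32000 8000
```

Because nothing is shuffled in this sketch, the validation set is simply the first block of samples, so it is worth making sure X_train and y_train are not sorted by label before fitting.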
Upvotes: 4