Steven Aronson

Skorch: Help constructing classifier for multiple outputs

I am trying to learn skorch by translating a simple PyTorch model that predicts the two digits contained in multi-digit MNIST images. Each image contains 2 overlapping digits, which are the output labels (y). I get the following error:

ValueError: Stratified CV requires explicitely passing a suitable y

I followed the "MNIST with SciKit-Learn and skorch" notebook and applied the multiple-output fix described in "Multiple return values from forward" by writing a custom get_loss method.

Data dimensions are:

Code:

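import torch
import torch.nn as nn
import torch.nn.functional as F
from skorch import NeuralNetClassifier

# Use the GPU when one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'


class Flatten(nn.Module):
    """A custom layer that views an input as 1D."""

    def forward(self, input):
        return input.view(input.size(0), -1)


class CNN(nn.Module):

    def __init__(self):
        super(CNN, self).__init__()

        self.conv1 = nn.Conv2d(1, 32, 3)
        self.pool1 = nn.MaxPool2d((2, 2))
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.pool2 = nn.MaxPool2d((2, 2))
        self.flatten = Flatten()
        self.fc1 = nn.Linear(2880, 64)  # must match the flattened conv output size
        self.drop1 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(64, 10)  # head for the first digit
        self.fc3 = nn.Linear(64, 10)  # head for the second digit

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.drop1(x)
        out_first_digit = self.fc2(x)
        out_second_digit = self.fc3(x)

        # Two logit tensors, one per digit
        return out_first_digit, out_second_digit


torch.manual_seed(0)

class CNN_net(NeuralNetClassifier):
    def get_loss(self, y_pred, y_true, *args, **kwargs):
        # y_pred is the (first, second) tuple returned by CNN.forward;
        # y_true holds both labels, shape (n_samples, 2)
        loss1 = F.cross_entropy(y_pred[0], y_true[:,0])
        loss2 = F.cross_entropy(y_pred[1], y_true[:,1])

        # Average the two per-digit losses
        return 0.5 * (loss1 + loss2)

net = CNN_net(
    CNN,
    max_epochs=5,
    lr=0.1,
    device=device,
)

# X_train: float32 images, shape (n_samples, 1, H, W); y_train: int64 labels, shape (n_samples, 2)
net.fit(X_train, y_train);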
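
The 2880 in fc1 must equal the number of features that Flatten produces, which depends on the image size. The data dimensions are missing from this post, so this plain-Python sketch assumes hypothetical 1-channel 42×28 inputs (one size consistent with 2880) and traces the shapes using the standard output-size formulas for an unpadded nn.Conv2d and for nn.MaxPool2d. The helper names are illustrative only, not PyTorch or skorch functions.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Length of one spatial dimension after nn.Conv2d (dilation=1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2):
    """Length after nn.MaxPool2d(kernel) with the default stride=kernel."""
    return (size - kernel) // kernel + 1


def flattened_features(height, width, channels=64):
    """Trace conv1 -> pool1 -> conv2 -> pool2 -> Flatten for the CNN above."""
    h, w = height, width
    for _ in range(2):  # two (3x3 conv, 2x2 max-pool) stages
        h, w = pool_out(conv_out(h)), pool_out(conv_out(w))
    return channels * h * w


print(flattened_features(42, 28))  # 2880 (hypothetical 42x28 input)
print(flattened_features(28, 28))  # 1600 (plain 28x28 MNIST size)
```

With 28×28 inputs, for example, fc1 would instead need nn.Linear(1600, 64).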

  1. Do I need to modify the format of y?
  2. Do I need to write additional custom methods (e.g. predict)?
  3. Any other suggestions?
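
Regarding question 2: because forward returns two logit tensors, any prediction step has to decode each head separately. A NumPy-only sketch of that decoding, using hypothetical helpers decode_two_digits and per_digit_accuracy (not skorch API) and made-up logits; hooking it into a NeuralNetClassifier subclass is not shown.

```python
import numpy as np


def decode_two_digits(logits_first, logits_second):
    """Hypothetical helper: turn both heads' logits into an (n, 2) label array."""
    return np.stack([logits_first.argmax(axis=1), logits_second.argmax(axis=1)], axis=1)


def per_digit_accuracy(y_pred, y_true):
    """Accuracy for each digit position, plus the rate of both digits correct."""
    correct = y_pred == y_true
    first, second = correct.mean(axis=0)
    both = correct.all(axis=1).mean()
    return float(first), float(second), float(both)


# Made-up logits for 2 samples, 10 classes per head.
logits_first = np.zeros((2, 10))
logits_first[0, 3] = 1.0
logits_first[1, 7] = 1.0
logits_second = np.zeros((2, 10))
logits_second[0, 5] = 1.0
logits_second[1, 2] = 1.0

y_pred = decode_two_digits(logits_first, logits_second)  # rows [3, 5] and [7, 2]
print(per_digit_accuracy(y_pred, np.array([[3, 5], [7, 9]])))  # (1.0, 0.5, 0.5)
```

In PyTorch the equivalent per-head step would be logits.argmax(dim=1).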

Answers (1)

nemo

By default, skorch's NeuralNetClassifier holds out a validation set using a stratified cross-validation split, so that it can report metrics such as validation accuracy during training. This requires that your labels can be stratified. Since you have two labels for each image, there is no trivial way to do a stratified split (although there are ways).
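
To see why the split fails, here is a scikit-learn sketch. It assumes (an assumption about skorch internals, worth checking in your version) that skorch's CVSplit resolves an integer cv through sklearn's check_cv and raises the error from the question when the result is not stratified. check_cv only returns StratifiedKFold for 1-D binary or multiclass targets. The labels are synthetic, and the pair-encoding trick at the end is just one illustrative way to stratify, not a skorch feature.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold, check_cv

# Synthetic labels for 100 two-digit images: shape (100, 2), digits 0..9.
idx = np.arange(100)
y = np.column_stack([idx % 10, (idx * 3) % 10])

# 2-D targets are "multiclass-multioutput", so check_cv falls back to KFold.
cv_2d = check_cv(5, y, classifier=True)
print(type(cv_2d).__name__)  # KFold

# A single 1-D label column can be stratified.
cv_1d = check_cv(5, y[:, 0], classifier=True)
print(type(cv_1d).__name__)  # StratifiedKFold

# Illustrative workaround: encode each digit pair as one class in 0..99.
pair_label = y[:, 0] * 10 + y[:, 1]
cv_pair = check_cv(5, pair_label, classifier=True)
print(type(cv_pair).__name__)  # StratifiedKFold
```

With real data each digit pair would need to occur at least n_splits times, otherwise scikit-learn warns or refuses to split, and using such a label with skorch would need a custom train_split.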

Two solutions come to mind:

  1. disable the train split altogether (pass train_split=None), at the cost of losing validation metrics during training
  2. change the train split to a non-stratified one by passing train_split=skorch.dataset.CVSplit(5, stratified=False)
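
Option 2 works because a non-stratified split never looks at the labels. This scikit-learn sketch assumes, without having verified skorch's source for every version, that CVSplit(5, stratified=False) holds out one fold of a plain KFold(5), i.e. about 20% of the samples. The arrays are placeholders.

```python
import numpy as np
from sklearn.model_selection import KFold

n_samples = 100
X = np.zeros((n_samples, 1))  # placeholder features; KFold only uses the row count
idx = np.arange(n_samples)
y = np.column_stack([idx % 10, (idx * 7) % 10])  # 2-D labels, shape (100, 2)

# KFold ignores y entirely, so 2-D labels cause no error.
train_idx, valid_idx = next(KFold(n_splits=5).split(X, y))
print(len(train_idx), len(valid_idx))  # 80 20
# Without shuffling, the first held-out fold is the first 20% of rows.
print(valid_idx[:3])  # [0 1 2]
```

If the split is unshuffled as assumed here, it holds out a contiguous block of rows, so shuffling X_train and y_train together before fitting may be worthwhile when the data is ordered.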

Since I assume you want validation metrics during training, your final code should look like this:

import skorch.dataset

net = CNN_net(
    CNN,
    max_epochs=5,
    lr=0.1,
    device=device,
    train_split=skorch.dataset.CVSplit(5, stratified=False),
)

net.fit(X_train, y_train);
