Ljackson

Reputation: 13

PyTorch Multi Class Classification using CrossEntropyLoss - not converging

I am trying to get a simple network to output the probability that a number is in one of three classes: smaller than 1.1, between 1.1 and 1.5, and bigger than 1.5. I am using cross entropy loss with class labels 0, 1 and 2, but I cannot get the network to converge.

Every time I train, the network outputs the maximum probability for class 2, regardless of input. The lowest loss I seem to be able to achieve is about 0.9. Any advice on where I am going wrong would be greatly appreciated! All code is below.

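import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

class gating_net(nn.Module):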
    def __init__(self, input_dim, output_dim):
        super(gating_net, self).__init__()
        self.linear1 = nn.Linear(input_dim, 32)
        self.linear2 = nn.Linear(32, output_dim)

    def forward(self, x):
        # The original input (action) is used as the residual.
        x = F.relu(self.linear1(x))
        x = F.sigmoid(self.linear2(x))
        return x

learning_rate = 0.01
batch_size = 64
epochs = 500
test = 1

gating_network = gating_net(1,3)

optimizer = torch.optim.SGD(gating_network.parameters(), lr=learning_rate, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=20, verbose=True)

for epoch in range (epochs):
    input_ = []
    label_ = []

    for i in range (batch_size):
        scale = random.randint(10,20)/10

        input = scale
        if scale < 1.1:
            label = np.array([0])
        elif 1.1 < scale < 1.5:
            label = np.array([1])
        else:
            label = np.array([2])

        input_.append(np.array([input]))
        label_.append(label)

    optimizer.zero_grad()

    # get output from the model, given the inputs
    output = gating_network.forward(torch.FloatTensor(input_))
    old_label  = torch.FloatTensor(label_)

    # get loss for the predicted output
    loss = nn.CrossEntropyLoss()(output, old_label.squeeze().long())

    # get gradients w.r.t to parameters
    loss.backward()
    # update parameters
    optimizer.step()
    scheduler.step(loss)

    print('epoch {}, loss {}'.format(epoch, loss.item()))

    if loss.item() < 0.01:
        print("########## Solved! ##########")
        torch.save(gating_network.state_dict(), './supervised_learning/run_{}.pth'.format(test))
        break

    # save a checkpoint every 100 epochs
    if epoch % 100 == 0:
        torch.save(gating_network.state_dict(), './run_{}.pth'.format(test))

Upvotes: 1

Views: 2200

Answers (1)

ddoGas

Reputation: 871

  • Your code generates fresh training data every epoch (which is also every batch in this case). That is redundant, but it does not stop the code from working. What does affect training is the class imbalance: with your code the majority of the training data is always labeled 2, so the network learns mostly about class 2. That is why, within only 500 epochs, the network classifies everything as 2: it is a fast and easy way to lower the loss. Once the network cannot lower the loss much further using what it knows about label 2, it starts to learn about labels 1 and 0 as well. So the network can be trained, just not very efficiently.
  • Following on from that, ReduceLROnPlateau does not help here either, because by the time the network starts to learn about labels 0 and 1, the learning rate is (intuitively speaking) already too small. That does not make the network untrainable, but training will probably take a long time.
  • CrossEntropyLoss applies LogSoftmax internally, so a Sigmoid at the end of the network means you effectively have a Softmax layer right after a Sigmoid layer, which is probably not what you want. I don't think the network is necessarily 'wrong', but it will be much harder to train.
  • Also, a scale of exactly 1.1 is labeled 2, because your conditions are < 1.1 and > 1.1 and neither covers 1.1 itself.
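
Two of the points above can be checked with plain Python, without training anything. The sketch below enumerates the 11 values that random.randint(10, 20) / 10 can produce and counts the label each one receives, then computes the lowest cross entropy reachable when the values fed to the loss are squashed into (0, 1) by a sigmoid. The helper name label_for and the variable names are illustrative and not part of the original code.

```python
import math
from collections import Counter

def label_for(scale):
    # Same branching as the question's data generator.
    if scale < 1.1:
        return 0
    elif 1.1 < scale < 1.5:
        return 1
    else:
        return 2

# random.randint(10, 20) / 10 draws uniformly from these 11 values.
scales = [n / 10 for n in range(10, 21)]
counts = Counter(label_for(s) for s in scales)
print(counts)           # how many of the 11 values fall into each class
print(label_for(1.1))   # the boundary value ends up in the else branch

# With a final sigmoid, every value passed to CrossEntropyLoss lies in (0, 1).
# The best case is true-class value -> 1 and the other two -> 0, which gives
# a lower bound on the per-sample loss: -log(e / (e + 2)).
loss_floor = -math.log(math.e / (math.e + 2))
print(round(loss_floor, 4))
```

Seven of the 11 possible inputs are labeled 2, three are labeled 1, and only one is labeled 0. The loss floor comes out at roughly 0.55, so with the sigmoid in place the `loss.item() < 0.01` check in the question can never be satisfied, no matter how long training runs.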

TL;DR

Get rid of the sigmoid and the scheduler. I was able to reach Solved! at around epoch 15000 (with the same learning rate and batch size as in your code).
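
As a rough sketch of that change, assuming the rest of the setup stays the same: the network below returns raw logits instead of sigmoid outputs, and the training loop has no scheduler. The names GatingNet, make_batch and train are made up for this example, and the labeling rule is copied unchanged from the question (so 1.1 still falls into class 2). The step count of 200 is only there to keep the demo short; it is not expected to reach the 0.01 threshold.

```python
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatingNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, 32)
        self.linear2 = nn.Linear(32, output_dim)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        # Return raw logits; CrossEntropyLoss applies log-softmax itself.
        return self.linear2(x)

def make_batch(batch_size):
    # Same sampling and labeling rule as in the question.
    scales = [random.randint(10, 20) / 10 for _ in range(batch_size)]
    labels = [0 if s < 1.1 else 1 if 1.1 < s < 1.5 else 2 for s in scales]
    x = torch.tensor(scales, dtype=torch.float32).unsqueeze(1)  # shape (B, 1)
    y = torch.tensor(labels, dtype=torch.long)                  # shape (B,)
    return x, y

def train(steps, batch_size=64, lr=0.01):
    model = GatingNet(1, 3)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    loss = None
    for _ in range(steps):
        x, y = make_batch(batch_size)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()  # no learning-rate scheduler
    return model, loss.item()

model, final_loss = train(steps=200)
print(final_loss)
```

Building the label tensor directly with dtype=torch.long also avoids the FloatTensor, squeeze and long round trip from the original loop.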

Upvotes: 3
