Reputation: 13
I am trying to get a simple network to output the probability that a number is in one of three classes: smaller than 1.1, between 1.1 and 1.5, and bigger than 1.5. I am using cross-entropy loss with class labels of 0, 1, and 2, but cannot solve the problem.
Every time I train, the network outputs the maximum probability for class 2, regardless of the input. The lowest loss I seem to be able to achieve is around 0.9. Any advice on where I am going wrong would be greatly appreciated! All code is below.
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau

class gating_net(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(gating_net, self).__init__()
        self.linear1 = nn.Linear(input_dim, 32)
        self.linear2 = nn.Linear(32, output_dim)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.sigmoid(self.linear2(x))
        return x
learning_rate = 0.01
batch_size = 64
epochs = 500
test = 1

gating_network = gating_net(1, 3)
optimizer = torch.optim.SGD(gating_network.parameters(), lr=learning_rate, momentum=0.9)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=20, verbose=True)

for epoch in range(epochs):
    input_ = []
    label_ = []
    for i in range(batch_size):
        scale = random.randint(10, 20) / 10
        input = scale
        if scale < 1.1:
            label = np.array([0])
        elif 1.1 < scale < 1.5:
            label = np.array([1])
        else:
            label = np.array([2])
        input_.append(np.array([input]))
        label_.append(label)
    optimizer.zero_grad()
    # get output from the model, given the inputs
    output = gating_network.forward(torch.FloatTensor(input_))
    old_label = torch.FloatTensor(label_)
    # get loss for the predicted output
    loss = nn.CrossEntropyLoss()(output, old_label.squeeze().long())
    # get gradients w.r.t. the parameters
    loss.backward()
    # update parameters
    optimizer.step()
    scheduler.step(loss)
    print('epoch {}, loss {}'.format(epoch, loss.item()))
    if loss.item() < 0.01:
        print("########## Solved! ##########")
        torch.save(gating_network.state_dict(), './supervised_learning/run_{}.pth'.format(test))
        break
    # save every 100 epochs
    if epoch % 100 == 0:
        torch.save(gating_network.state_dict(), './run_{}.pth'.format(test))
Upvotes: 1
Views: 2200
Reputation: 871
Because of how scale is generated, the majority of your samples are labeled 2. So intuitively, your network will always learn more about class 2. That is why, with only 500 epochs, the network classifies everything as class 2: it is a fast and easy way to lower the loss. However, when the network can't get the loss much lower by applying its knowledge about label 2, it will learn about labels 1 and 0 too. So it is possible to train the network, though not very efficiently. ReduceLROnPlateau also isn't helping, because by the time the network starts to learn about labels 0 and 1, the learning rate is already too small (intuitively speaking). That doesn't mean it is untrainable, but it will probably take a lot of time.
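To make the imbalance concrete, here is a quick check of my own (not from the original post) that enumerates the 11 values random.randint(10, 20) / 10 can produce and applies the question's labeling rules:

from collections import Counter

# Count how the 11 possible scale values map to classes,
# using the exact conditions from the question's code.
counts = Counter()
for i in range(10, 21):
    scale = i / 10
    if scale < 1.1:
        counts[0] += 1
    elif 1.1 < scale < 1.5:
        counts[1] += 1
    else:
        counts[2] += 1
print(counts)  # Counter({2: 7, 1: 3, 0: 1})

Roughly 64% of the samples land in class 2, so always predicting 2 is an easy way for the network to lower the loss early on.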
CrossEntropyLoss calculates LogSoftmax internally, so having Sigmoid at the end of the network means you have a Softmax layer right after a Sigmoid layer, which is probably not what you want. I think the network isn't necessarily 'wrong', but it will be much harder to train.
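A minimal sketch of that fix, keeping the rest of the architecture as-is (the renamed class is my choice for illustration): drop the sigmoid and return raw logits, since nn.CrossEntropyLoss applies LogSoftmax itself.

import torch.nn as nn
import torch.nn.functional as F

class GatingNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, 32)
        self.linear2 = nn.Linear(32, output_dim)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        # Return raw logits: nn.CrossEntropyLoss applies LogSoftmax
        # internally, so no sigmoid/softmax belongs here.
        return self.linear2(x)

If you need actual probabilities at inference time, apply torch.softmax(output, dim=1) outside of the loss computation.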
Also, exactly 1.1 is being labeled 2, because your conditions check scale < 1.1 and then 1.1 < scale < 1.5, so 1.1 itself falls through to the else branch.
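Assuming 1.1 is meant to belong to the middle class (the question's wording suggests so, but that is my reading), a drop-in replacement for the labeling block would use an inclusive lower bound:

if scale < 1.1:
    label = np.array([0])
elif 1.1 <= scale < 1.5:  # inclusive bound so exactly 1.1 maps to class 1
    label = np.array([1])
else:
    label = np.array([2])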
Get rid of the sigmoid and the scheduler.
I was able to get Solved! somewhere around epoch 15000 (with the same learning rate and batch size as your code).
Upvotes: 3