MLP running but not learning

I am currently trying to implement a MLP by hand, with ReLU activations for each hidden layer and a softmax activation for the output layer. The MLP is running but it is not learning from its training, but i cant seem to find the issue with my backpropagation.

The important functions are forward and backward which seem correct to me. I am pretty new to pytorch, so sorry in advance if there are blatant mistakes, i am still learning.

Here is the corresponding code

import torch
def inputs_tilde(x, axis=-1):
    # augments the inputs `x` with ones along `axis`
    # todo : implémenter code ici.
    pad = [0 for i in range(2 * len(x.size()))]
    pad[-(2 * axis + 1)] = 1
    x_tilde = torch.nn.functional.pad(x,pad, "constant", 1)
    return x_tilde

def softmax(x, axis=-1):
    # assurez vous que la fonction est numeriquement stable
    # e.g. softmax(np.array([1000, 10000, 100000], ndim=2))

    # todo : calcul des valeurs de softmax(x)

    scaled =  x - x.max()
    exp = torch.exp(scaled)
    values = exp / exp.sum()
    if not torch.allclose(values.sum(),torch.tensor(1.000)):
        raise BaseException(values.sum())
    return values

def cross_entropy(y, y_pred):
    # todo : calcul de la valeur d'entropie croisée.
    eps = 10**-6
    loss = (y * torch.log(y_pred + eps) + (1 - y) * torch.log(abs(1 - y_pred) + eps)).sum()
    return loss

def softmax_cross_entropy_backward(y, y_pred):
    # todo : calcul de la valeur du gradient de l'entropie croisée composée avec `softmax`
    values = (y_pred - y)
    # gradients were chosen as vertical for convention
    return values[:,None]

def relu_forward(x):
    # todo : calcul des valeurs de relu(x)
    values = x.clone()
    values[values < 0] = 0
    return values

def relu_backward(x):
    # todo : calcul des valeurs du gradient de la fonction `relu`
    values = x.clone()
    values[values >= 0] = 1
    values[values < 0] = 0
    return values

class MLPModel:
    def __init__(self, n_features, n_hidden_features, n_hidden_layers, n_classes):
        self.n_features        = n_features
        self.n_hidden_features = n_hidden_features
        self.n_hidden_layers   = n_hidden_layers
        self.n_classes         = n_classes

        # todo : initialiser la liste des paramètres Teta de l'estimateur.

        self.params = []
        self.params.append(torch.normal(0, 0.01, (n_hidden_features, n_features + 1)))

        for l in range(0, n_hidden_layers):
            self.params.append(torch.normal(0, 0.01, (n_hidden_features, n_hidden_features+1)))

        self.params.append(torch.normal(0, 0.01, (n_classes,n_hidden_features + 1)))
        print(f"Teta params={[p.shape for p in self.params]}")

        self.a = [] # liste contenant le resultat des multiplications matricielles
        self.h = [] # liste contenant le resultat des fonctions d'activations

    def forward(self, x):
        # todo : implémenter calcul des outputs en fonction des inputs `x`.

        outputs = torch.empty(x.size()[0],self.n_classes)

        for k in range(x.size()[0]):
            out = x[k][:,None]
            # h = torch.empty(self.n_hidden_layers,self.n_features+1,self.n_features)
            # a = torch.empty(self.n_hidden_layers,self.n_features+1,self.n_features)

            h = [0] * (self.n_hidden_layers + 2)
            a = [0] * (self.n_hidden_layers + 2)

            h[0] = out
            for i in range(self.n_hidden_layers+1):
                out = torch.matmul(self.params[i],inputs_tilde(out,0))
                a[i] = out
                out = relu_forward(out)
                h[i+1] = out

            out = torch.matmul(self.params[-1],inputs_tilde(out,0))
            a[-1] = out
            out = softmax(out)
            outputs[k] = out.squeeze(1)

        return outputs

    def backward(self, y, y_pred):
        # todo : implémenter calcul des gradients.

        b_size = y.size()[0]
        grads = []
        for i in range(self.n_hidden_layers):
            grads.append(torch.empty([b_size] + [ s for s in self.params[1].size()]))

        grads.append(torch.empty([b_size] + [ s for s in self.params[-1].size()]))

        smloss = softmax_cross_entropy_backward(y,y_pred) 

        for k in range(y.size()[0]):
            g = smloss[k]
            g = g.transpose(0,1)

            grads[-1][k] = torch.matmul(g,inputs_tilde(self.h[k][-1],0).transpose(0,1))

            for i in range(self.n_hidden_layers,0,-1):

                g = torch.matmul(self.params[i+1].transpose(0,1)[:-1],g  * relu_backward(self.a[k][i]))
                grads[i][k] = torch.matmul(g,inputs_tilde(self.h[k][i],0).transpose(0,1))
            g = torch.matmul(self.params[1].transpose(0,1)[:-1],g * relu_backward(self.a[k][0]))
            grads[0][k] = torch.matmul(g,inputs_tilde(self.h[k][0],0).transpose(0,1))

        return grads

    def sgd_update(self, lr, grads):
        # TODO : implémenter mise à jour des paramètres ici.
        avg_grads = [g.mean(dim=0) for g in grads]

        for i in range(len(self.params)):
            self.params[i] -= lr * avg_grads[i]

def train(model, lr=0.1, nb_epochs=10, sgd=True, data_loader_train=None, data_loader_val=None):
    best_model = None
    best_val_accuracy = 0
    logger = Logger()

    for epoch in range(nb_epochs+1):

        # at epoch 0 evaluate random initial model
        #   then for subsequent epochs, do optimize before evaluation.
        if epoch > 0:
            for x, y in data_loader_train:
                x, y = reshape_input(x, y)

                y_pred = model.forward(x)
                grads  = model.backward(y, y_pred)

                model.sgd_update(lr, grads)

        accuracy_train, loss_train = accuracy_and_loss_whole_dataset(data_loader_train, model)
        accuracy_val, loss_val = accuracy_and_loss_whole_dataset(data_loader_val, model)
        if accuracy_val > best_val_accuracy:
            # TODO : record the best model parameters and best validation accuracy
            best_model = model
            best_val_accuracy = accuracy_val

        logger.log(accuracy_train, loss_train, accuracy_val, loss_val)
        print(f"Epoch {epoch:2d}, \
                Train:loss={loss_train.item():.3f}, accuracy={accuracy_train.item()*100:.1f}%, \
                Valid: loss={loss_val.item():.3f}, accuracy={accuracy_val.item()*100:.1f}%", flush=True)

    return best_model, best_val_accuracy, logger

