Kokodoko

Reputation: 28128

Basic training of ML5 neural network not working

I am using ML5 to train a Neural Network. I am loading a CSV file with data from the Titanic. This works when I download the demo file from the ML5 GitHub.

But when I use a different CSV file, and I replace the column names in my code, it stops working. Am I missing something? Is it a problem that my CSV file contains numbers, while the demo file contains strings?

let neuralNetwork

function start() {    
    const nnOptions = {
        dataUrl: "data/titanic.csv",
        inputs: ["Pclass", "Sex", "Age", "SibSp"],     // labels from my CSV file
        outputs: ["Survived"],
        task: "classification",
        debug: true,
    };

    neuralNetwork = ml5.neuralNetwork(nnOptions, modelReady);
}

function modelReady() {
    neuralNetwork.normalizeData();
    neuralNetwork.train({ epochs: 50 }, whileTraining, finishedTraining);
}

// this doesn't get called at all
function whileTraining(epoch, logs) {
    console.log(`Epoch: ${epoch} - loss: ${logs.loss.toFixed(2)}`);
}

// this gets called immediately
function finishedTraining() {
    console.log("done!");
}

start()

The console immediately shows "done!", but the model is not trained. There is no error message. The strange thing is, when a label name is incorrect, then I do get an error. So the label names are actually recognised.

Original CSV file, working:

survived,fare_class,sex,age,fare
died,first,male,39,0
died,first,male,29,0

My CSV file, not working:

Survived,Pclass,Sex,Age,SibSp,Parch,Fare,Embarked
0,3,1,22.0,1,0,7.25,1
1,1,0,38.0,1,0,71.2833,2

Upvotes: 0

Views: 468

Answers (2)

Prewish Autar

Reputation: 1

I'm not sure if this will help, but from what I've learned, all the inputs should be numbers. If the demo works, though, then the code itself should work. With ML5 you have to pay attention to the CSV file: make sure the output that you are trying to predict is a string. If you're using a CSV with only numbers and are trying to predict a number as output, then you should change the task from "classification" to "regression". You can also give this code a try if you would like to:

Upvotes: 0

Kokodoko

Reputation: 28128

Just in case anyone runs into this issue: when you are classifying, the output label (here, the Survived column) always has to be a string, not a number.

Working CSV file:

Survived,Pclass,Sex,Age,SibSp,Parch,Fare,Embarked
no,3,1,22.0,1,0,7.25,1
yes,1,0,38.0,1,0,71.2833,2
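
If you would rather convert an existing numeric file than edit it by hand, something like the following might work. This is only a minimal sketch in plain JavaScript: `relabelCsv` is a hypothetical helper of my own, not part of ML5, the mapping assumes 0 means "no" and 1 means "yes" as in the standard Titanic dataset, and the naive comma split assumes the file has no quoted fields.

```javascript
// Hypothetical helper (not an ML5 function): rewrites one column of a CSV
// string so that numeric class labels become string labels.
// Assumes a simple CSV without quoted fields or embedded commas.
function relabelCsv(csvText, column, mapping) {
    const lines = csvText.trim().split(/\r?\n/);
    const header = lines[0].split(",");
    const index = header.indexOf(column);
    if (index === -1) {
        throw new Error(`Column "${column}" not found in header`);
    }

    const rows = lines.slice(1).map((line) => {
        const cells = line.split(",");
        const raw = cells[index];
        // Fail loudly on values the mapping does not cover.
        if (!Object.prototype.hasOwnProperty.call(mapping, raw)) {
            throw new Error(`No mapping for label value "${raw}"`);
        }
        cells[index] = mapping[raw];
        return cells.join(",");
    });

    // Keep the header line unchanged and append the relabeled rows.
    return [lines[0], ...rows].join("\n");
}

// The two sample rows from the question.
const original = [
    "Survived,Pclass,Sex,Age,SibSp,Parch,Fare,Embarked",
    "0,3,1,22.0,1,0,7.25,1",
    "1,1,0,38.0,1,0,71.2833,2",
].join("\n");

// Assumed mapping: 0 = did not survive, 1 = survived.
const relabeled = relabelCsv(original, "Survived", { "0": "no", "1": "yes" });
console.log(relabeled);
```

The relabeled text can then be saved as a new CSV file and used as the `dataUrl` in the options from the question, with `task: "classification"` left unchanged.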

Upvotes: 1
