J.Todd

Reputation: 827

How can I get this basic example working of TensorFlow.js training for a true / false output with tf.browser.fromPixels(image)?

I've Googled every version of the question I could think of, but for the life of me I can't find a single basic example of TensorFlow.js training on tf.browser.fromPixels(image) input to produce a yes or a no. All the examples I could find start with pre-trained nets.

I've built a database of 25x25 pixel images and have them all stored as canvases in a variable like:

let data = {
    t: [canvas1, canvas2, canvas3, ... canvas3000 ....],
    f: [canvas1, canvas2, ... and so on ...]
}

And I think it should be trivial to do something like:

data.t.forEach(canvas => {
    const xs = tf.browser.fromPixels(canvas);
    const ys = tf.tensor([1]); // output 1, since this canvas is from the `t` (true) dataset
    model.fit(xs, ys, {
      batchSize: 1,
      epochs: 1000
    });
});

data.f.forEach(canvas => {
    const xs = tf.browser.fromPixels(canvas);
    const ys = tf.tensor([0]); // output 0, since this canvas is from the `f` (false) dataset
    model.fit(xs, ys, {
      batchSize: 1,
      epochs: 1000
    });
});

model.predict(tf.browser.fromPixels(data.t[0])).print(); // -> [1]
model.predict(tf.browser.fromPixels(data.t[1])).print(); // -> [1]
model.predict(tf.browser.fromPixels(data.t[2])).print(); // -> [1]

model.predict(tf.browser.fromPixels(data.f[0])).print(); // -> [0]
model.predict(tf.browser.fromPixels(data.f[1])).print(); // -> [0]
model.predict(tf.browser.fromPixels(data.f[2])).print(); // -> [0]

But I'm new to TF, and without a basic example to follow, the specifics (inputShape and various other little details) make this a painful learning curve. What would a valid version of this training function look like? Here's the code so far:

// Just imagine DataSet builds a large data set like the one described in my
// question and calls a callback function with the data variable as
// its only argument, full of pre-categorized images. Since my database
// of images is stored locally, I can't really produce an example here
// that works fully, but this gets the idea across at least.
new DataSet(
  data => {
    
    const model = tf.sequential();
    
    model.add(
    
      // And yes, I realize I would want a convolutional layer, 
      // some max pooling, filtering, etc, but I'm trying to start simple
      
      tf.layers.dense({
        units: [1],
        inputShape: [25, 25, 3],
        dataFormat: "channelsLast",
        activation: "tanh"
      })
    );
    
    model.compile({optimizer: "sgd", loss: "binaryCrossentropy", lr: 0.1});
    
    data.t.forEach(canvas => {
        const xs = tf.browser.fromPixels(canvas);
        const ys = tf.tensor([1]); // output 1, since this canvas is 
        // from the `t` (true) dataset
        model.fit(xs, ys, {
          batchSize: 1,
          epochs: 1000
        });
    });
    
    data.f.forEach(canvas => {
        const xs = tf.browser.fromPixels(canvas);
        const ys = tf.tensor([0]); // output 0, since this canvas is 
        // from the `f` (false) dataset
        model.fit(xs, ys, {
          batchSize: 1,
          epochs: 1000
        });
    });
    
    model.predict(tf.browser.fromPixels(data.t[0])).print(); // -> [1]
    model.predict(tf.browser.fromPixels(data.t[1])).print(); // -> [1]
    model.predict(tf.browser.fromPixels(data.t[2])).print(); // -> [1]
    
    model.predict(tf.browser.fromPixels(data.f[0])).print(); // -> [0]
    model.predict(tf.browser.fromPixels(data.f[1])).print(); // -> [0]
    model.predict(tf.browser.fromPixels(data.f[2])).print(); // -> [0]
    
  },
  {canvas: true}
);
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/[email protected]/dist/tf.min.js"></script>

Upvotes: 1

Views: 440

Answers (1)

edkeveked

Reputation: 18371

Your model has only one layer; you need more than that. There are lots of tutorials you can follow to build a classifier that distinguishes between two or more classes of images. The official TensorFlow website has a tutorial that does this with a CNN.
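
As a rough illustration of what a small CNN for the 25x25 RGB canvases from the question could look like, here is a minimal sketch. It is not the code from the tutorial. The filter count, kernel size, pool size and the `adam` optimizer are illustrative assumptions, and the helper names (`validOutputSize`, `flattenedFeatures`, `buildCnnModel`) are made up for this example. It assumes `tf` is the global created by the tf.min.js script tag.

```javascript
// Assumes `tf` is the global created by the tf.min.js <script> tag.
const IMAGE_SIZE = 25;  // from the question: 25x25 RGB canvases
const FILTERS = 8;      // illustrative assumption
const KERNEL_SIZE = 3;  // illustrative assumption
const POOL_SIZE = 2;    // illustrative assumption

// Output width/height of a conv or pooling layer with "valid" padding,
// which is the default padding for tf.layers.conv2d and maxPooling2d.
function validOutputSize(size, windowSize, stride = 1) {
  return Math.floor((size - windowSize) / stride) + 1;
}

// Number of values the flatten layer passes on to the final dense layer.
function flattenedFeatures() {
  const afterConv = validOutputSize(IMAGE_SIZE, KERNEL_SIZE);
  // Pooling strides default to the pool size.
  const afterPool = validOutputSize(afterConv, POOL_SIZE, POOL_SIZE);
  return afterPool * afterPool * FILTERS;
}

// Hypothetical helper: conv -> max pool -> flatten -> single sigmoid output.
function buildCnnModel() {
  const model = tf.sequential();
  model.add(tf.layers.conv2d({
    inputShape: [IMAGE_SIZE, IMAGE_SIZE, 3],
    filters: FILTERS,
    kernelSize: KERNEL_SIZE,
    activation: "relu"
  }));
  model.add(tf.layers.maxPooling2d({ poolSize: POOL_SIZE }));
  model.add(tf.layers.flatten());
  // Sigmoid keeps the output in (0, 1), which binaryCrossentropy expects.
  model.add(tf.layers.dense({ units: 1, activation: "sigmoid" }));
  model.compile({
    optimizer: "adam",
    loss: "binaryCrossentropy",
    metrics: ["accuracy"]
  });
  return model;
}
```

With these assumed sizes the image shrinks from 25 to 23 after the convolution and to 11 after pooling. A model like this would be trained on a batched input of shape [N, 25, 25, 3] with labels of shape [N, 1], not on one canvas at a time.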

Additionally, you can see how to use a fully connected neural network to build a classifier in this snippet, though the accuracy might not be as good as with CNN models.
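
The linked snippet is not reproduced here, so the following is only a minimal sketch of a fully connected classifier for the `data` object described in the question. The hidden layer size, learning rate, batch size and epoch count are assumptions, and the helper names are hypothetical. It differs from the question's code in a few ways: the images are stacked into one batched tensor instead of calling `model.fit` once per canvas, `fit` is awaited because it is asynchronous, a flatten layer comes before the dense layers, `units` is a number, and the output uses a sigmoid so it fits `binaryCrossentropy`.

```javascript
// Assumes `tf` is the global created by the tf.min.js <script> tag, and that
// `data` has the shape from the question: { t: [canvas, ...], f: [canvas, ...] }.

// One label per image: 1 for the `t` images, then 0 for the `f` images.
function makeLabels(numTrue, numFalse) {
  return [...Array(numTrue).fill(1), ...Array(numFalse).fill(0)];
}

// Stack canvases into one [N, 25, 25, 3] float tensor scaled to the 0..1 range.
function canvasesToTensor(canvases) {
  return tf.tidy(() =>
    tf.stack(canvases.map(c => tf.browser.fromPixels(c))).toFloat().div(255)
  );
}

// Hypothetical helper: flatten -> hidden dense layer -> single sigmoid output.
function buildDenseModel() {
  const model = tf.sequential();
  model.add(tf.layers.flatten({ inputShape: [25, 25, 3] }));
  model.add(tf.layers.dense({ units: 32, activation: "relu" })); // size is an assumption
  model.add(tf.layers.dense({ units: 1, activation: "sigmoid" }));
  model.compile({
    optimizer: tf.train.sgd(0.1), // the learning rate is set on the optimizer
    loss: "binaryCrossentropy",
    metrics: ["accuracy"]
  });
  return model;
}

async function trainAndPredict(data) {
  const xs = canvasesToTensor([...data.t, ...data.f]);
  const labels = makeLabels(data.t.length, data.f.length);
  const ys = tf.tensor2d(labels, [labels.length, 1]);

  const model = buildDenseModel();
  // A single awaited fit call over the whole data set, shuffled every epoch.
  await model.fit(xs, ys, { batchSize: 32, epochs: 20, shuffle: true });
  xs.dispose();
  ys.dispose();

  // predict also expects a batch dimension, so pass a stacked tensor.
  model.predict(canvasesToTensor([data.t[0], data.f[0]])).print();
  return model;
}
```

Inside the `DataSet` callback from the question this would be called as `trainAndPredict(data)`. The printed values are probabilities between 0 and 1 rather than exact 0s and 1s, so a threshold such as 0.5 is one way to turn them into a yes or a no.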

Upvotes: 1
