user2212461

Reputation: 3253

Classify in tensorflowjs

Following this tutorial, I want to load a model in TensorFlow.js and then use its classify method to classify an input.

I load and execute the model like this:

const model = await window.tf.loadGraphModel(MODEL_URL);

const threshold = 0.9;
const labelsToInclude = ["test1"];

model.load(threshold, labelsToInclude).then(model2 => {
  model2.classify(["test sentence"])
    .then(predictions => {
      console.log('prediction: ' + predictions);
      return true;
    });
});

But I am getting the error:

TypeError: model2.classify is not a function at App.js:23

How can I use the classify method in tensorflowjs correctly?

Upvotes: 1

Views: 459

Answers (1)

Thomas Dondorf

Reputation: 25230

The tutorial uses a specific model (toxicity). Its load and classify functions are not part of the TensorFlow.js model API itself; they are implemented by that specific model's own package.
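
To make the distinction concrete, here is a minimal sketch of the wrapper pattern. It is purely illustrative and is not the toxicity package's source: SentenceClassifier, stubModel and stubEncode are hypothetical names, and the stub stands in for a loaded model that offers predict but nothing else.

```javascript
// Illustrative only: a hypothetical wrapper, NOT the toxicity package's code.
class SentenceClassifier {
  // model: any object exposing predict()
  // encode: hypothetical function turning sentences into model input
  constructor(model, encode) {
    this.model = model;
    this.encode = encode;
  }

  // classify() exists only on the wrapper; it delegates to model.predict()
  classify(sentences) {
    return this.model.predict(this.encode(sentences));
  }
}

// Stand-in for a loaded model: it has predict() but no classify()
const stubModel = { predict: (inputs) => inputs.map((x) => x * 2) };
// Stand-in encoder: maps each sentence to its length
const stubEncode = (sentences) => sentences.map((s) => s.length);

const classifier = new SentenceClassifier(stubModel, stubEncode);
console.log(typeof stubModel.classify); // the bare model has no classify
console.log(classifier.classify(['test sentence']));
```

Calling classify on the bare model fails for the same reason as in the question: the method lives on the wrapper, not on the object returned by loadGraphModel.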

Check out the API to see which functions models support in general. If you load a GraphModel, use model.predict (or model.execute) to run the model.

Therefore, your code should look like this:

const model = await window.tf.loadGraphModel(MODEL_URL);
// Build an input tensor with the shape and dtype your model expects
const input = window.tf.tensor(/* ... */);
// predict() returns raw output tensors, not labelled predictions
const predictions = model.predict([input]);
console.log('prediction: ' + predictions);
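
Since predict returns raw tensors, any thresholding or labelling has to be done by hand. A minimal sketch of that step follows, assuming the model emits one score per label in a single output tensor (this layout is an assumption; check your own model's outputs). labelsAboveThreshold is a hypothetical helper, and fakeOutput is a stub standing in for a real tensor so the snippet runs without a model.

```javascript
// output: result of model.predict(), assumed to be a tensor (or an array of
// tensors) holding one score per label. This layout is an assumption.
function labelsAboveThreshold(output, labels, threshold) {
  // predict() may return several tensors; this sketch looks at the first one
  const tensor = Array.isArray(output) ? output[0] : output;
  // dataSync() reads the tensor's values synchronously into a typed array
  const scores = Array.from(tensor.dataSync());
  return labels
    .map((label, i) => ({ label, score: scores[i] }))
    .filter((entry) => entry.score >= threshold);
}

// Stub with the same dataSync() method a tf.Tensor provides
const fakeOutput = { dataSync: () => new Float32Array([0.97, 0.12]) };
const hits = labelsAboveThreshold(fakeOutput, ['test1', 'test2'], 0.9);
console.log(hits.map((h) => h.label));
```

In a browser, the asynchronous tensor.data() is usually preferable to dataSync(), because it does not block the UI thread while the values are read back.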

Upvotes: 1
