Union find

Reputation: 8160

How does TextCategorizer.predict work with spaCy?

I've been following the spaCy quick-start guide for text classification.

Let's say I have a very simple dataset.

TRAIN_DATA = [
    ("beef", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
    ("apple", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
]

I'm training a pipe to classify text. It trains, and the loss ends up low.

import random
from spacy.util import minibatch

textcat = nlp.create_pipe("pytt_textcat", config={"exclusive_classes": True})
for label in ("POSITIVE", "NEGATIVE"):
    textcat.add_label(label)
nlp.add_pipe(textcat)

optimizer = nlp.resume_training()
for i in range(10):
    random.shuffle(TRAIN_DATA)
    losses = {}
    for batch in minibatch(TRAIN_DATA, size=8):
        texts, cats = zip(*batch)
        nlp.update(texts, cats, sgd=optimizer, losses=losses)
    print(i, losses)
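To make the data flow of the training loop concrete: each batch is a list of (text, annotations) pairs, and zip(*batch) splits it into a tuple of texts and a tuple of annotation dicts, which is what gets passed to nlp.update. The sketch below shows this in plain Python; minibatch_sketch is a hypothetical fixed-size stand-in for spacy.util.minibatch, not spaCy's implementation.

```python
from itertools import islice

TRAIN_DATA = [
    ("beef", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
    ("apple", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
]


def minibatch_sketch(items, size):
    """Hypothetical stand-in for spacy.util.minibatch: yield lists of up to `size` items."""
    it = iter(items)
    while True:
        batch = list(islice(it, size))
        if not batch:
            return
        yield batch


for batch in minibatch_sketch(TRAIN_DATA, size=8):
    # Split the (text, annotations) pairs into two parallel tuples.
    texts, cats = zip(*batch)
    print(texts)  # ('beef', 'apple')
    print(cats)
```

With only two examples and size=8, the whole dataset fits in a single batch.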

Now, how do I predict whether a new string of text is "POSITIVE" or "NEGATIVE"?

This will work:

doc = nlp(u'Pork')
print(doc.cats)

It returns a score for each category the model was trained on.
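Since doc.cats is a plain dict mapping each label to a score, turning it into a single POSITIVE/NEGATIVE answer only means picking the highest-scoring label. In the sketch below, top_label is a hypothetical helper and the scores are made-up placeholders in the shape of doc.cats, not real model output.

```python
def top_label(cats):
    """Return the label with the highest score from a doc.cats-style dict."""
    return max(cats, key=cats.get)


# Hypothetical scores in the shape doc.cats uses; not real model output.
cats = {"POSITIVE": 0.83, "NEGATIVE": 0.17}
print(top_label(cats))  # POSITIVE
```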

But that seems at odds with the docs, which say I should use the predict method of the pipeline component itself (the TextCategorizer subclass).

That doesn't work though.

Trying textcat.predict('text'), textcat.predict(['text']), etc. throws:

AttributeError          Traceback (most recent call last)
<ipython-input-29-39e0c6e34fd8> in <module>
----> 1 textcat.predict(['text'])

pipes.pyx in spacy.pipeline.pipes.TextCategorizer.predict()

AttributeError: 'str' object has no attribute 'tensor'

Upvotes: 2

Views: 1615

Answers (1)

Sofie VL

Reputation: 3106

The predict methods of pipeline components expect Doc objects, not strings, and they take a batch (a list) of them, so you'll need something like textcat.predict([nlp(text)]). The nlp used there does not need to have a textcat component itself. The result of that call then needs to be passed to set_annotations(), as shown in the TextCategorizer API docs.
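The error in the question can be reproduced without spaCy. The sketch below uses hypothetical ToyDoc and ToyTextCategorizer classes (not spaCy's API) whose predict iterates over a batch and reads each item's .tensor, which is what the 'str' object has no attribute 'tensor' message suggests spaCy's predict does. Passing strings fails the same way; passing a list of Docs and handing the scores to set_annotations works. The scoring rule is a dummy for illustration only.

```python
class ToyDoc:
    """Hypothetical stand-in for a spaCy Doc: holds text, a tensor and cats."""

    def __init__(self, text):
        self.text = text
        self.tensor = [float(len(text))]  # placeholder "upstream" features
        self.cats = {}


class ToyTextCategorizer:
    """Hypothetical component with a predict / set_annotations split."""

    labels = ("POSITIVE", "NEGATIVE")

    def predict(self, docs):
        # Iterate over the batch and read each item's tensor, so every
        # item must be Doc-like; a str has no .tensor attribute.
        tensors = [doc.tensor for doc in docs]
        # Dummy scoring rule, for illustration only.
        return [(1.0, 0.0) if t[0] <= 4 else (0.0, 1.0) for t in tensors]

    def set_annotations(self, docs, scores):
        # Write each row of scores back onto the matching Doc.
        for doc, row in zip(docs, scores):
            doc.cats = dict(zip(self.labels, row))


textcat = ToyTextCategorizer()

try:
    textcat.predict(["text"])
except AttributeError as err:
    print(err)  # 'str' object has no attribute 'tensor'

docs = [ToyDoc("Pork")]
scores = textcat.predict(docs)
textcat.set_annotations(docs, scores)
print(docs[0].cats)  # {'POSITIVE': 1.0, 'NEGATIVE': 0.0}
```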

However, your first approach is just fine:

...
nlp.add_pipe(textcat)
...
doc = nlp(u'Pork')
print(doc.cats)
...

Internally, when you call nlp(text), the Doc for the text is created first, and then each pipeline component, one by one, runs its predict method on that Doc and adds its results to it with set_annotations. The textcat component is the one that eventually sets the Doc's cats attribute.
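The flow described above can be sketched with hypothetical toy classes (ToyDoc, ToyTok2Vec, ToyTextCat, ToyLanguage are illustrations, not spaCy's classes). A toy nlp object creates a Doc, then runs each component's predict on a one-item batch and writes the result back with set_annotations before the next component runs. The classifier only works because an earlier component has already filled in doc.tensor.

```python
class ToyDoc:
    """Hypothetical stand-in for a spaCy Doc."""

    def __init__(self, text):
        self.text = text
        self.tensor = None
        self.cats = {}


class ToyTok2Vec:
    """Hypothetical upstream component that fills in doc.tensor."""

    def predict(self, docs):
        return [[float(len(doc.text))] for doc in docs]

    def set_annotations(self, docs, tensors):
        for doc, tensor in zip(docs, tensors):
            doc.tensor = tensor


class ToyTextCat:
    """Hypothetical classifier: reads doc.tensor, writes doc.cats (dummy rule)."""

    def predict(self, docs):
        return [min(doc.tensor[0] / 10.0, 1.0) for doc in docs]

    def set_annotations(self, docs, scores):
        for doc, pos in zip(docs, scores):
            doc.cats = {"POSITIVE": pos, "NEGATIVE": 1.0 - pos}


class ToyLanguage:
    """Hypothetical nlp object: make a Doc, then run each component in order."""

    def __init__(self, components):
        self.components = components

    def __call__(self, text):
        doc = ToyDoc(text)
        for component in self.components:
            # Score a one-item batch, then annotate the Doc before
            # the next component sees it.
            scores = component.predict([doc])
            component.set_annotations([doc], scores)
        return doc


nlp = ToyLanguage([ToyTok2Vec(), ToyTextCat()])
doc = nlp("Pork")
print(doc.cats)
```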

The API docs you're citing for the other approach give you a look "under the hood", so the two aren't really conflicting approaches ;-)

Upvotes: 1
