Luiscri

Reputation: 1057

spaCy - Most efficient way to sort entities by label

I'm using a spaCy pipeline to extract all entities from articles. I need to store these entities in variables according to the label they have been tagged with. For now I have the solution below, but I don't think it is the most suitable one, because it iterates over all the entities once per label:

import spacy

nlp = spacy.load("es_core_news_md")
text = "..."  # the article text is loaded here
doc = nlp(text)

personEntities = list(set([e.text for e in doc.ents if e.label_ == "PER"]))
locationEntities = list(set([e.text for e in doc.ents if e.label_ == "LOC"]))
organizationEntities = list(set([e.text for e in doc.ents if e.label_ == "ORG"]))

Is there a built-in spaCy method to get all the entities for each label, or do I need to write for ent in doc.ents: if ... elif ... elif ... to achieve that?

Upvotes: 3

Views: 2665

Answers (1)

Wiktor Stribiżew

Reputation: 627327

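I suggest using the groupby function from itertools. Note that groupby only groups consecutive items with the same key, so the entities must be sorted by label first: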

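from itertools import groupby

# ... doc = nlp(text) as in the question
# Sort by label so that entities with the same label are adjacent, then group them
entities = {key: list(g) for key, g in groupby(sorted(doc.ents, key=lambda x: x.label_), lambda x: x.label_)}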

Or, if you only need the unique entity strings:

entities = {key: list({ent.text for ent in g}) for key, g in groupby(sorted(doc.ents, key=lambda x: x.label_), lambda x: x.label_)}

Then you can print the entities for a known label with

print(entities['ORG'])
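The sorted() call matters because itertools.groupby starts a new group every time the key changes. Below is a minimal sketch that runs without spaCy. It uses a hypothetical Ent namedtuple as a stand-in for spaCy's Span, mimicking only the text and label_ attributes, to show what goes wrong without sorting.

```python
from collections import namedtuple
from itertools import groupby
from operator import attrgetter

# Hypothetical stand-in for spaCy's Span: only .text and .label_ are mimicked
Ent = namedtuple("Ent", ["text", "label_"])

ents = [
    Ent("Mr. Wood", "PER"),
    Ent("New York", "LOC"),
    Ent("Mrs. Winston", "PER"),
]

by_label = attrgetter("label_")

# Without sorting, groupby yields PER, LOC, PER: the PER key appears twice,
# so the second PER group overwrites the first in the dict comprehension
unsorted_groups = {k: [e.text for e in g] for k, g in groupby(ents, by_label)}

# After sorting by label, every label forms one contiguous run
sorted_groups = {k: [e.text for e in g] for k, g in groupby(sorted(ents, key=by_label), by_label)}

print(unsorted_groups)  # {'PER': ['Mrs. Winston'], 'LOC': ['New York']}
print(sorted_groups)    # {'LOC': ['New York'], 'PER': ['Mr. Wood', 'Mrs. Winston']}
```

Without sorting, "Mr. Wood" is silently lost. With a real Doc, the same applies to doc.ents.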

If you need unique entity objects (spaCy Span instances) rather than just strings, keeping the first occurrence of each text, you can use:

import spacy
from itertools import groupby

nlp = spacy.load("en_core_web_sm")
s = "Hello, Mr. Wood! We are in New York. Mrs. Winston is not coming, John hasn't sent her any invite. They will meet in California next time. General Motors and Toyota are companies."
doc = nlp(s * 2)  # the text is repeated, so every entity occurs twice

entities = dict()
for key, g in groupby(sorted(doc.ents, key=lambda x: x.label_), lambda x: x.label_):
    seen = set()  # entity texts already collected for this label
    unique = []   # first Span object found for each distinct text
    for ent in g:
        if ent.text not in seen:
            seen.add(ent.text)
            unique.append(ent)
    entities[key] = unique

Here, print(entities['GPE'][0].text) outputs New York.
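Sorting costs O(n log n). If label order does not matter, a single pass that files each entity under its label also works and needs no sorting. The sketch below is an alternative to groupby, not part of the answer above. It again uses a hypothetical Ent namedtuple in place of spaCy's Span (mimicking only text, start and label_). With a real Doc you would loop over doc.ents instead. dict.setdefault keeps the first entity seen for each text, like the seen-set loop above.

```python
from collections import defaultdict, namedtuple

# Hypothetical stand-in for spaCy's Span (only .text, .start and .label_ are mimicked);
# with spaCy you would loop over doc.ents instead of this list
Ent = namedtuple("Ent", ["text", "start", "label_"])

ents = [
    Ent("Wood", 3, "PERSON"),
    Ent("New York", 8, "GPE"),
    Ent("California", 25, "GPE"),
    Ent("Wood", 30, "PERSON"),
    Ent("New York", 35, "GPE"),
]

# One pass: label -> {entity text -> first entity object with that text}
first_by_text = defaultdict(dict)
for ent in ents:
    first_by_text[ent.label_].setdefault(ent.text, ent)

# Unique Span-like objects per label, in order of first appearance
entities = {label: list(by_text.values()) for label, by_text in first_by_text.items()}

# Unique strings per label, if only the text is needed
entity_texts = {label: list(by_text) for label, by_text in first_by_text.items()}

print(entity_texts)              # {'PERSON': ['Wood'], 'GPE': ['New York', 'California']}
print(entities["GPE"][0].start)  # 8, the first "New York", not the repeat at 35
```

Unlike the groupby version, labels come out in order of first appearance rather than alphabetically, and the unique strings keep document order instead of arbitrary set order.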

Upvotes: 5
