user29424767
user29424767

Reputation: 1

Relation Extraction Model returns only one entity instead of entity pairs

I'm working on a relation extraction model task using a transformer-based model. the `pipeline is expected to extract entity pairs along with their labelled relation labels. When I run the evaluation after training the model on labelled data, it works fine. However, when I run the model pipeline, the model sometimes returns only one entity instead of a complete entity pair, along with relation label. I don't quite understand how the model is making a prediction for relation label just by considering one entity.

def entity_extract(text):
    doc_=MODEL_NER(text)
    for name, proc in MODEL_REL.pipeline:
            doc_ = proc(doc_)
    return doc_
def relation_extraction(doc):
    relation_extraction_output = {}
    unique_entities = set()

    for span, rel_dict in doc._.rel.items():
        # Extract the relation with the highest confidence score
        most_probable_relation = max(rel_dict, key=rel_dict.get)
        score = rel_dict[most_probable_relation]
        
        # Skip if the score is below the cutoff (0.5 50%  in your case)
        if score <= 0.5:
            continue

        start, end = span
        relation_span = doc[start:end]
        # Extract entities involved in the relation
        entities = [ent for ent in relation_span.ents if ent.text not in unique_entities]
        # Store the information if entities exist
        if entities:
            unique_entities.update(ent.text for ent in entities)  # Efficient update
            # Create a key for the entity pair
            entity_pair = tuple(sorted([ent.text for ent in entities]))
            if entity_pair not in relation_extraction_output:
                relation_extraction_output[entity_pair] = {
                    'relations': [],
                    'scores': []
                }
            # Append the relation and score
            relation_extraction_output[entity_pair]['relations'].append(most_probable_relation)
            relation_extraction_output[entity_pair]['scores'].append(score)
    # Convert the output to a list of dictionaries
    final_output = [
        {
            'entities': key,
            'relations': value['relations'],
            'scores': value['scores']
        }
        for key, value in relation_extraction_output.items()
    ]
    return final_output
def extract_relation(file):
    filename = os.path.splitext(os.path.basename(file))[0]
    print("Starting to process RelationExtract part for file:", filename)
    chunk_size=10000
    # Define a function to process each row in a chunk
    def process_row(row):
        try:
            text = entity_extract(row["sent"])
            return relation_extraction(text)
        except Exception as e:
            print(f"Error processing row: {e}")
            return None

Expected output: entity pair: [entity1, entity2], relation: [relation_label] Actual output: [entity1, ], relation: [relation_label]

Upvotes: 0

Views: 18

Answers (0)

Related Questions