Reputation: 1
I'm working on a relation extraction task using a transformer-based model. The pipeline is expected to extract entity pairs along with their relation labels. When I run the evaluation after training the model on labelled data, it works fine. However, when I run the model pipeline, it sometimes returns only one entity instead of a complete entity pair, along with a relation label. I don't understand how the model can predict a relation label from only one entity.
def entity_extract(text):
    """Annotate *text* with ``MODEL_NER`` and every ``MODEL_REL`` component.

    The text is first passed to ``MODEL_NER``; the resulting object is then
    fed through each ``(name, component)`` entry of ``MODEL_REL.pipeline``
    in order, each component receiving the previous component's output.

    Returns:
        The object returned by the last component (or by ``MODEL_NER`` when
        ``MODEL_REL.pipeline`` is empty).
    """
    annotated = MODEL_NER(text)
    # Only the callable of each pipeline entry is needed; the name is ignored.
    for _component_name, component in MODEL_REL.pipeline:
        annotated = component(annotated)
    return annotated
def relation_extraction(doc, threshold=0.5):
    """Collect the top-scoring relation for every entry of ``doc._.rel``.

    For each ``(start, end)`` key of ``doc._.rel`` the label with the highest
    score is selected. Entries whose best score is not strictly greater than
    ``threshold`` are dropped. The entities found in ``doc[start:end].ents``
    form the key under which the relation and its score are recorded.

    Args:
        doc: Object exposing ``doc._.rel`` (a mapping from a 2-tuple to a
            ``{label: score}`` mapping) and slice access returning an object
            with an ``ents`` attribute whose items have a ``text`` attribute.
        threshold: Cutoff; a relation is kept only when its score is
            strictly above this value. Defaults to ``0.5``.

    Returns:
        A list of dicts, one per distinct entity tuple in order of first
        appearance, each with the keys ``'entities'`` (sorted tuple of
        entity texts), ``'relations'`` and ``'scores'`` (parallel lists).
    """
    grouped = {}

    for span, rel_dict in doc._.rel.items():
        # max() raises ValueError on an empty mapping; nothing to report.
        if not rel_dict:
            continue

        # Keep only the label with the highest confidence score.
        most_probable_relation = max(rel_dict, key=rel_dict.get)
        score = rel_dict[most_probable_relation]
        if score <= threshold:
            continue

        start, end = span
        # NOTE(review): this treats the key as a (start, end) token range.
        # If the keys are instead the start offsets of the two entities
        # (as in spaCy's rel_component tutorial -- TODO confirm against the
        # extension's definition), doc[start:end] stops before the second
        # entity and the entities should be looked up by ``ent.start``.
        entities = list(doc[start:end].ents)
        if not entities:
            continue

        # Entities are deliberately NOT de-duplicated across relations: an
        # entity that takes part in several relations must appear in every
        # one of them, otherwise later "pairs" lose a member.
        entity_pair = tuple(sorted(ent.text for ent in entities))

        bucket = grouped.setdefault(entity_pair, {'relations': [], 'scores': []})
        bucket['relations'].append(most_probable_relation)
        bucket['scores'].append(score)

    # Flatten the grouping into a list of plain dictionaries.
    return [
        {
            'entities': entity_pair,
            'relations': found['relations'],
            'scores': found['scores'],
        }
        for entity_pair, found in grouped.items()
    ]
def extract_relation(file):
    """Announce processing of *file* and set up the per-row helper.

    The file's base name without extension is printed in a start-up message.

    NOTE(review): ``chunk_size`` and ``process_row`` are defined but not used
    in the visible part of this function; presumably the remainder of the
    body was elided -- confirm against the full source.
    """
    base_name = os.path.basename(file)
    filename, _extension = os.path.splitext(base_name)
    print("Starting to process RelationExtract part for file:", filename)

    chunk_size = 10000

    def process_row(row):
        """Run entity and relation extraction on ``row["sent"]``.

        Returns the relation list, or ``None`` (after printing the error)
        when any step raises.
        """
        try:
            return relation_extraction(entity_extract(row["sent"]))
        except Exception as e:
            print(f"Error processing row: {e}")
            return None
Expected output: entity pair: [entity1, entity2], relation: [relation_label]. Actual output: entity pair: [entity1, ], relation: [relation_label].
Upvotes: 0
Views: 18