terence1022

Reputation: 1

Unable to get the metrics plots of SentenceTransformer model

I'm fine-tuning the SentenceTransformer('all-MiniLM-L6-v2') model, using a dictionary (loaded from JSON) called category_descriptions as the dataset.

Below is the structure of category_descriptions:

{
    "CategoryA": {
        "CategorySearch": "Description for CategoryA",
        "SubCategory1": "Description for SubCategory1 of CategoryA",
        "SubCategory2": "Description for SubCategory2 of CategoryA",
        ...
    },
    "CategoryB": {
        "CategorySearch": "Description for CategoryB",
        "SubCategory1": "Description for SubCategory1 of CategoryB",
        "SubCategory2": "Description for SubCategory2 of CategoryB",
        ...
    },
    ...
}

I'm unable to get the training accuracy, training loss, validation accuracy, and validation loss. I've tried several approaches, but they usually end with: TypeError: FitMixin.smart_batching_collate() missing 1 required positional argument: 'batch'

Did I pass the wrong collate_fn to the data loaders, or define the loss function incorrectly?

Here are my fine-tuning steps:

  1. Read the category and subcategory descriptions from the JSON file.
  2. Convert the data into InputExample objects.
  3. Convert the text labels into numeric indices.
  4. Split the dataset into 80% training data and 20% validation data.
  5. Use SentenceTransformer('all-MiniLM-L6-v2') as the pre-trained model.
  6. Create data loaders for the training and validation data.
  7. Use SoftmaxLoss as the loss function.
  8. Fine-tune the model.
  9. Plot the results.
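Steps 1-4 can be sketched with plain tuples standing in for InputExample objects (placeholder data below; the real descriptions come from the JSON file, and the real split uses train_test_split):

```python
import random

# Placeholder stand-in for the JSON structure shown above.
category_descriptions = {
    "CategoryA": {
        "CategorySearch": "Description for CategoryA",
        "SubCategoryA1": "Description for SubCategoryA1 of CategoryA",
    },
    "CategoryB": {
        "CategorySearch": "Description for CategoryB",
        "SubCategoryB1": "Description for SubCategoryB1 of CategoryB",
    },
}

# Step 2: flatten into (text, label) pairs. The "CategorySearch" entry is
# labelled with the category name; every other entry with its subcategory name.
dataset = []
for category, subcats in category_descriptions.items():
    for subcat, description in subcats.items():
        label = category if subcat == "CategorySearch" else subcat
        dataset.append((description, label))

# Step 3: map text labels to numeric indices (sorted for determinism).
all_labels = sorted({label for _, label in dataset})
label_to_index = {label: idx for idx, label in enumerate(all_labels)}
dataset = [(text, label_to_index[label]) for text, label in dataset]

# Step 4: 80/20 split (sketch; train_test_split shuffles the same way).
random.seed(42)
random.shuffle(dataset)
cut = int(0.8 * len(dataset))
train_data, valid_data = dataset[:cut], dataset[cut:]
```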

Here is the fine-tuning part of my code:

import os
import json
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sentence_transformers import SentenceTransformer, InputExample, losses

with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Embedding Dictionary.json'), 'r', encoding='utf-8') as f:
    category_descriptions = json.load(f)

dataset = []

for category, subcats in category_descriptions.items():
    if 'CategorySearch' in subcats:
        dataset.append(InputExample(texts=[subcats["CategorySearch"], ''], label=category))

    for subcat, description in subcats.items():
        if subcat != "CategorySearch":
            dataset.append(InputExample(texts=[description, ''], label=subcat))

all_labels = list(set(example.label for example in dataset))

label_to_index = {label: idx for idx, label in enumerate(all_labels)}
index_to_label = {idx: label for label, idx in label_to_index.items()}

for example in dataset:
    example.label = label_to_index[example.label]

train_data, valid_data = train_test_split(dataset, test_size=0.2, random_state=42)

org_model = SentenceTransformer('all-MiniLM-L6-v2')

train_dataloader = DataLoader(train_data, shuffle=True, batch_size=8, collate_fn=SentenceTransformer.smart_batching_collate)
valid_dataloader = DataLoader(valid_data, shuffle=False, batch_size=8, collate_fn=SentenceTransformer.smart_batching_collate)

loss_function = losses.SoftmaxLoss(model=org_model, num_labels=len(all_labels), sentence_embedding_dimension=org_model.get_sentence_embedding_dimension())

epochs = 5
warmup_steps = 20

train_losses, valid_losses = [], []
train_accuracies, valid_accuracies = [], []

for epoch in range(epochs):

    org_model.train()
    train_loss, train_correct, train_total = 0, 0, 0
    for batch in train_dataloader:
        loss_value = loss_function(batch)
        loss_value.backward()
        org_model.optimizer.step()
        org_model.optimizer.zero_grad()
        train_loss += loss_value.item()
        predictions = loss_function.get_prediction(batch)
        train_correct += (predictions == batch['labels']).sum().item()
        train_total += len(batch['labels'])
    
    train_losses.append(train_loss / len(train_dataloader))
    train_accuracies.append(train_correct / train_total)

    org_model.eval()
    valid_loss, valid_correct, valid_total = 0, 0, 0
    with torch.no_grad():
        for batch in valid_dataloader:
            loss_value = loss_function(batch)
            valid_loss += loss_value.item()
            predictions = loss_function.get_prediction(batch)
            valid_correct += (predictions == batch['labels']).sum().item()
            valid_total += len(batch['labels'])
    
    valid_losses.append(valid_loss / len(valid_dataloader))
    valid_accuracies.append(valid_correct / valid_total)
    
    print(f"Epoch {epoch + 1}/{epochs}")
    print(f"Training Loss: {train_losses[-1]:.4f}, Training Accuracy: {train_accuracies[-1]:.4f}")
    print(f"Validation Loss: {valid_losses[-1]:.4f}, Validation Accuracy: {valid_accuracies[-1]:.4f}")

plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.plot(range(1, epochs + 1), train_losses, label='Training Loss', marker='o')
plt.plot(range(1, epochs + 1), valid_losses, label='Validation Loss', marker='o')
plt.title('Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(range(1, epochs + 1), train_accuracies, label='Training Accuracy', marker='o')
plt.plot(range(1, epochs + 1), valid_accuracies, label='Validation Accuracy', marker='o')
plt.title('Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()
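For what it's worth, I can reproduce the same TypeError pattern in plain Python without sentence-transformers: passing an unbound instance method (ClassName.method) as a callback makes the first positional argument get consumed as self, leaving the method's own parameter unfilled. A minimal sketch (Collator here is a made-up stand-in, not the library's API):

```python
class Collator:
    """Hypothetical stand-in for a class exposing an instance-method collate."""

    def collate(self, batch):
        return [item.upper() for item in batch]


batch = ["a", "b"]

# Passing the *unbound* method: Collator.collate(batch) binds batch to
# 'self', so the 'batch' parameter is missing -- the same error shape as
# "missing 1 required positional argument: 'batch'".
try:
    Collator.collate(batch)
except TypeError as err:
    print(err)

# Binding through an instance supplies 'self', so the call succeeds.
collator = Collator()
print(collator.collate(batch))  # ['A', 'B']
```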

Upvotes: 0

Views: 51

Answers (0)
