Reputation: 1
I'm finetuning SentenceTransformer('all-MiniLM-L6-v2') model, using a dictionary (json) called category_descriptions
as dataset.
Below is the data structure of category_descriptions
:
{
"CategoryA": {
"CategorySearch": "Description for CategoryA",
"SubCategory1": "Description for SubCategory1 of CategoryA",
"SubCategory2": "Description for SubCategory2 of CategoryA",
...
},
"CategoryB": {
"CategorySearch": "Description for CategoryB",
"SubCategory1": "Description for SubCategory1 of CategoryB",
"SubCategory2": "Description for SubCategory2 of CategoryB",
...
},
...
}
I'm unable to get the training accuracy, training loss, validation accuracy, validation loss. I've tried plenty of ways, but the result usually end up with TypeError: FitMixin.smart_batching_collate() missing 1 required positional argument: 'batch'
Did I use the wrong collate_fn
in data loader or define the wrong loss function?
Following are my finetuning processes:
InputExample
objects.SentenceTransformer('all-MiniLM-L6-v2')
as pre-trained model.SoftmaxLoss
as loss function.Following are my finetuning part of the code:
with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Embedding Dictionary.json'), 'r', encoding='utf-8') as f:
category_descriptions = json.load(f)
dataset = []
for category, subcats in category_descriptions.items():
if 'CategorySearch' in subcats:
dataset.append(InputExample(texts=[subcats["CategorySearch"], ''], label=category))
for subcat, description in subcats.items():
if subcat != "CategorySearch":
dataset.append(InputExample(texts=[description, ''], label=subcat))
all_labels = list(set(example.label for example in dataset))
label_to_index = {label: idx for idx, label in enumerate(all_labels)}
index_to_label = {idx: label for label, idx in label_to_index.items()}
for example in dataset:
example.label = label_to_index[example.label]
train_data, valid_data = train_test_split(dataset, test_size=0.2, random_state=42)
org_model = SentenceTransformer('all-MiniLM-L6-v2')
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=8, collate_fn=SentenceTransformer.smart_batching_collate)
valid_dataloader = DataLoader(valid_data, shuffle=False, batch_size=8, collate_fn=SentenceTransformer.smart_batching_collate)
loss_function = losses.SoftmaxLoss(model=org_model, num_labels=len(all_labels), sentence_embedding_dimension=org_model.get_sentence_embedding_dimension())
epochs = 5
warmup_steps = 20
train_losses, valid_losses = [], []
train_accuracies, valid_accuracies = [], []
for epoch in range(epochs):
org_model.train()
train_loss, train_correct, train_total = 0, 0, 0
for batch in train_dataloader:
loss_value = loss_function(batch)
loss_value.backward()
org_model.optimizer.step()
org_model.optimizer.zero_grad()
train_loss += loss_value.item()
predictions = loss_function.get_prediction(batch)
train_correct += (predictions == batch['labels']).sum().item()
train_total += len(batch['labels'])
train_losses.append(train_loss / len(train_dataloader))
train_accuracies.append(train_correct / train_total)
org_model.eval()
valid_loss, valid_correct, valid_total = 0, 0, 0
with torch.no_grad():
for batch in valid_dataloader:
loss_value = loss_function(batch)
valid_loss += loss_value.item()
predictions = loss_function.get_prediction(batch)
valid_correct += (predictions == batch['labels']).sum().item()
valid_total += len(batch['labels'])
valid_losses.append(valid_loss / len(valid_dataloader))
valid_accuracies.append(valid_correct / valid_total)
print(f"Epoch {epoch + 1}/{epochs}")
print(f"Training Loss: {train_losses[-1]:.4f}, Training Accuracy: {train_accuracies[-1]:.4f}")
print(f"Validation Loss: {valid_losses[-1]:.4f}, Validation Accuracy: {valid_accuracies[-1]:.4f}")
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.plot(range(1, epochs + 1), train_losses, label='Training Loss', marker='o')
plt.plot(range(1, epochs + 1), valid_losses, label='Validation Loss', marker='o')
plt.title('Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(range(1, epochs + 1), train_accuracies, label='Training Accuracy', marker='o')
plt.plot(range(1, epochs + 1), valid_accuracies, label='Validation Accuracy', marker='o')
plt.title('Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.tight_layout()
plt.show()
Upvotes: 0
Views: 51