I am fine-tuning BLIP-2 on the RSICD dataset using LoRA, on Colab with an A100. Strangely, the learning rate I set in the code below seems to have no effect: whether I set it to 10^55 or to 10^(-55), the loss still jumps around by roughly the same amount from step to step.
To show what I'm talking about, I created a public colab notebook here. In this notebook I create an ExponentialLR scheduler with gamma=0.1 and step it after every individual optimizer step, so the size of each weight update (and therefore the change in the loss) should shrink by an order of magnitude per step. This is just to demonstrate my point. When I print the loss and the learning rate, the learning rate does decrease by a factor of 10 at each step, yet the loss keeps jumping around by the same amounts, which suggests that the learning rate being printed is not actually taking effect.
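The schedule described here multiplies the learning rate by gamma after each step, i.e. lr_t = lr_0 * gamma^t. A minimal plain-Python sketch of that arithmetic, assuming the initial value 2e-5 from the code below (the helper name exponential_lr is made up for illustration and is not part of PyTorch):

```python
def exponential_lr(initial_lr: float, gamma: float, steps: int) -> list:
    """Learning rate used at each step when it is multiplied by gamma after every step."""
    return [initial_lr * gamma ** t for t in range(steps)]


# Values taken from the question: learning_rate = 2e-5, gamma = 0.1
schedule = exponential_lr(initial_lr=2e-5, gamma=0.1, steps=5)
for step, lr in enumerate(schedule):
    print(f"step {step}: lr = {lr:.1e}")
```

After a few dozen steps the value is far below anything that can change a bfloat16 weight, so any remaining movement in the loss has to come from somewhere other than the weight updates.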
Does anyone know what might be causing this? Why is the loss still jumping around so much even with a learning rate that is essentially 0?
My code is also copied below.
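import torch
from torch.utils.data import DataLoader
from transformers import AddedToken, AutoProcessor, Blip2ForConditionalGeneration
from peft import LoraConfig, get_peft_model
from accelerate import dispatch_model, infer_auto_device_map
from accelerate.utils import get_balanced_memory
processor = AutoProcessor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained(
    "Salesforce/blip2-opt-2.7b",
    device_map="auto",
    torch_dtype=torch.bfloat16
)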
processor.num_query_tokens = model.config.num_query_tokens
image_token = AddedToken("<image>", normalized=False, special=True)
processor.tokenizer.add_tokens([image_token], special_tokens=True)
model.resize_token_embeddings(len(processor.tokenizer), pad_to_multiple_of=64)
model.config.image_token_index = len(processor.tokenizer) - 1
device = "cuda" if torch.cuda.is_available() else "cpu"
n_epochs = 10
learning_rate = 2e-5
batch_size = 16
gradient_accumulation_steps = 1
weight_decay = 0.01
logging_steps = 5
max_grad_norm = 0
seed = 42
evaluation_strategy = "steps"
lr_scheduler_type = "constant"
lora_alpha = 32
lora_dropout = 0.05
lora_dim = 8
targetData=torch.load("/content/drive/Shareddrives/TEMFOM/target_data1.pt")
print("learning rate: " + str(learning_rate))
config = LoraConfig(
    r=lora_dim,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    bias="none",
    target_modules=["q_proj", "k_proj"]
)
model = get_peft_model(model, config)
model.print_trainable_parameters()
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1, last_epoch=-1, verbose=True)
def fine_tune(model, train_dataloader, optimizer, n_epochs, model_name="fine-tuned"):
    for epoch in range(0, n_epochs):
        for idx, batch in enumerate(train_dataloader):
            input_ids = batch.pop("input_ids").to(device)
            pixel_values = batch.pop("pixel_values").to(device, torch.float16)
            outputs = model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                labels=input_ids
            )
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            print("Loss:", loss.item())
            # Stepped after every batch (not every epoch); uses the global `scheduler`
            scheduler.step()
    # `direct` is the output directory prefix; it is not defined in this excerpt
    model.save_pretrained(direct + model_name)
    return model
train_dataset = ImageCaptioningDataset(targetData, processor)
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=batch_size,
    collate_fn=collate_fn
)
fine_tuned_model = fine_tune(model, train_dataloader, optimizer, n_epochs, model_name="fine-tuned")
An example of the output (which you can also see in the colab notebook) is below. As you can see, there are still very large jumps in the loss even as the learning rate decreases below 10^-50.
Epoch: 0
Loss: 5.318110466003418
learning rate before: 2e-10
learning rate after: 2.0000000000000002e-11
Loss: 4.520220756530762
learning rate before: 2.0000000000000002e-11
learning rate after: 2.0000000000000004e-12
Loss: 4.288538455963135
learning rate before: 2.0000000000000004e-12
learning rate after: 2.0000000000000006e-13
Loss: 4.518784999847412
learning rate before: 2.0000000000000006e-13
learning rate after: 2.0000000000000006e-14
Loss: 4.479536533355713
learning rate before: 2.0000000000000006e-14
learning rate after: 2.0000000000000005e-15
Loss: 4.55037260055542
learning rate before: 2.0000000000000005e-15
learning rate after: 2.0000000000000007e-16
Loss: 4.3671770095825195
learning rate before: 2.0000000000000007e-16
learning rate after: 2.0000000000000008e-17
Loss: 4.405301570892334
learning rate before: 2.0000000000000008e-17
learning rate after: 2.000000000000001e-18
Loss: 4.015200138092041
learning rate before: 2.000000000000001e-18
learning rate after: 2.000000000000001e-19
Loss: 4.679333209991455
learning rate before: 2.000000000000001e-19
learning rate after: 2.000000000000001e-20
Loss: 3.9693050384521484
learning rate before: 2.000000000000001e-20
learning rate after: 2.0000000000000013e-21
Loss: 4.0665154457092285
learning rate before: 2.0000000000000013e-21
learning rate after: 2.0000000000000015e-22
Loss: 4.336864471435547
learning rate before: 2.0000000000000015e-22
learning rate after: 2.0000000000000017e-23
Loss: 4.552571773529053
learning rate before: 2.0000000000000017e-23
learning rate after: 2.0000000000000017e-24
Loss: 4.480626106262207
learning rate before: 2.0000000000000017e-24
learning rate after: 2.0000000000000017e-25
Loss: 4.519588947296143
learning rate before: 2.0000000000000017e-25
learning rate after: 2.0000000000000018e-26
Loss: 4.422896862030029
learning rate before: 2.0000000000000018e-26
learning rate after: 2.000000000000002e-27
Loss: 3.851675033569336
learning rate before: 2.000000000000002e-27
learning rate after: 2.000000000000002e-28
Loss: 3.561893939971924
learning rate before: 2.000000000000002e-28
learning rate after: 2.000000000000002e-29
Loss: 4.885611534118652
learning rate before: 2.000000000000002e-29
learning rate after: 2.0000000000000023e-30
Loss: 4.571497440338135
learning rate before: 2.0000000000000023e-30
learning rate after: 2.0000000000000024e-31
Loss: 4.3077521324157715
learning rate before: 2.0000000000000024e-31
learning rate after: 2.0000000000000026e-32
Loss: 3.834765911102295
learning rate before: 2.0000000000000026e-32
learning rate after: 2.000000000000003e-33
Loss: 4.235876560211182
learning rate before: 2.000000000000003e-33
learning rate after: 2.000000000000003e-34
Loss: 4.281957626342773
learning rate before: 2.000000000000003e-34
learning rate after: 2.000000000000003e-35
Loss: 4.0060648918151855
learning rate before: 2.000000000000003e-35
learning rate after: 2.0000000000000032e-36
Loss: 4.274528503417969
learning rate before: 2.0000000000000032e-36
learning rate after: 2.0000000000000035e-37
Loss: 4.298925876617432
learning rate before: 2.0000000000000035e-37
learning rate after: 2.0000000000000036e-38
Loss: 4.506286144256592
learning rate before: 2.0000000000000036e-38
learning rate after: 2.0000000000000038e-39
Loss: 4.11824369430542
learning rate before: 2.0000000000000038e-39
learning rate after: 2.000000000000004e-40
Loss: 4.141360759735107
learning rate before: 2.000000000000004e-40
learning rate after: 2.000000000000004e-41
Loss: 4.402781963348389
learning rate before: 2.000000000000004e-41
learning rate after: 2.0000000000000042e-42
Loss: 4.450037002563477
learning rate before: 2.0000000000000042e-42
learning rate after: 2.000000000000004e-43
Loss: 4.273048400878906
learning rate before: 2.000000000000004e-43
learning rate after: 2.0000000000000041e-44
Loss: 4.774006366729736
learning rate before: 2.0000000000000041e-44
learning rate after: 2.0000000000000043e-45
Loss: 3.908968687057495
learning rate before: 2.0000000000000043e-45
learning rate after: 2.0000000000000043e-46
Loss: 3.9161949157714844
learning rate before: 2.0000000000000043e-46
learning rate after: 2.0000000000000043e-47
Loss: 4.0039896965026855
learning rate before: 2.0000000000000043e-47
learning rate after: 2.0000000000000045e-48
Loss: 3.8200762271881104
learning rate before: 2.0000000000000045e-48
learning rate after: 2.0000000000000044e-49
Loss: 4.692992687225342
learning rate before: 2.0000000000000044e-49
learning rate after: 2.0000000000000045e-50
Loss: 4.407190799713135
learning rate before: 2.0000000000000045e-50
learning rate after: 2.0000000000000048e-51
Loss: 4.065435886383057
learning rate before: 2.0000000000000048e-51
learning rate after: 2.000000000000005e-52
Loss: 3.7482211589813232
learning rate before: 2.000000000000005e-52
learning rate after: 2.000000000000005e-53
Loss: 4.571844100952148
learning rate before: 2.000000000000005e-53
learning rate after: 2.000000000000005e-54
Loss: 4.8389458656311035
learning rate before: 2.000000000000005e-54
learning rate after: 2.000000000000005e-55
Loss: 3.70975923538208
learning rate before: 2.000000000000005e-55
learning rate after: 2.000000000000005e-56
Loss: 3.7369227409362793
Reputation: 51
If the learning rate in torch.optim.AdamW seems to have no effect, here are a few things you can check:
1. Check Parameter Groups: In PyTorch optimizers like AdamW, there are different parameter groups, each potentially with its own learning rate. It’s possible that the learning rate you're setting is not being applied to the correct group.
You can inspect the learning rates in the optimizer’s parameter groups with this snippet:
for param_group in optimizer.param_groups:
    print(param_group['lr'])
This will let you verify whether the learning rate is being set correctly for all parameter groups.
2. Learning Rate Scheduler: If you are using a learning rate scheduler, it could be overriding the learning rate set in the optimizer. Check if there’s a scheduler active in your code, and if so, ensure it is being properly configured.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
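With this configuration, the learning rate is multiplied by 0.1 every time scheduler.step() is called, which leads to unexpected behavior if the scheduler is stepped more often than intended (for example, once per batch instead of once per epoch).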
3. Frozen Layers in LoRA: Since you're fine-tuning with LoRA, many layers might be frozen, and the optimizer might not actually be updating any of the trainable layers. You can check which layers are being updated by inspecting the requires_grad attribute:
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)
Ensure that the layers you expect to fine-tune actually have requires_grad=True.
4. Manual Gradients or Clipping: If you are manually updating gradients or using gradient clipping, verify that the updates are applied correctly after backpropagation. For example, gradient clipping could dampen the effect of the learning rate if set too aggressively.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
5. Small Gradient Updates: In some cases, gradients might be too small to have a noticeable impact, especially if they're being scaled down by other factors in your model or data processing pipeline. You could try printing out gradient magnitudes to ensure they are not vanishing.
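6. Weight Decay: AdamW applies decoupled weight decay to keep weights from growing too large. In PyTorch's implementation the decay step is multiplied by the learning rate, so a large weight decay matters most at moderate or high learning rates and vanishes along with the learning rate as it approaches zero. Check your weight decay settings: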
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
Consider reducing the weight decay or checking if it is set too aggressively.
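To make the interaction between learning rate, gradient size and weight decay concrete, here is a minimal scalar sketch of one AdamW step, written from the standard decoupled-weight-decay update rule rather than copied from PyTorch's source. The function name adamw_step and the example values are assumptions chosen for illustration.

```python
import math


def adamw_step(p, grad, m, v, t, lr, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    """One AdamW update for a single scalar parameter (decoupled weight decay)."""
    p = p * (1.0 - lr * weight_decay)          # decay term is scaled by the learning rate
    m = beta1 * m + (1.0 - beta1) * grad       # running mean of gradients
    v = beta2 * v + (1.0 - beta2) * grad ** 2  # running mean of squared gradients
    m_hat = m / (1.0 - beta1 ** t)             # bias correction for step t (1-indexed)
    v_hat = v / (1.0 - beta2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


# First step from p = 1.0 with the same gradient and three different learning rates
for lr in (1e-3, 2e-5, 2e-20):
    p_new, _, _ = adamw_step(p=1.0, grad=0.5, m=0.0, v=0.0, t=1, lr=lr)
    print(f"lr={lr:.0e}: parameter change = {p_new - 1.0:.3e}")
```

Because the gradient is divided by the square root of its own running second moment, the step size is roughly lr regardless of how large the gradient is, and the decay term shrinks with lr as well. At lr = 2e-20 the parameter does not change at all, even in double precision.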
If none of these steps help, feel free to share more details about the training setup, and I can help you troubleshoot further.
Reputation: 1176
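I just realized that this is most likely because each printed loss is computed on a single mini-batch, and I'm using a small batch size. Every step sees a different shuffled batch, so the loss varies from batch to batch even if the weights barely change. I think that makes this a rather bad question after all, but I'll keep it here in case it's useful to anyone.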
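This effect can be reproduced without any deep-learning library. The sketch below uses a hypothetical toy regression problem (not the BLIP-2 setup) with a model whose single weight is never updated, which mimics a learning rate of effectively zero, and prints the loss of each shuffled mini-batch.

```python
import random
import statistics

random.seed(0)

# Hypothetical toy data: y = 3x + Gaussian noise
xs = [random.uniform(-1.0, 1.0) for _ in range(1024)]
data = [(x, 3.0 * x + random.gauss(0.0, 1.0)) for x in xs]
weight = 2.0  # frozen and deliberately imperfect; never updated below


def batch_loss(batch, w):
    """Mean squared error of the fixed model y = w * x on one mini-batch."""
    return sum((w * x - y) ** 2 for x, y in batch) / len(batch)


def epoch_losses(batch_size):
    """Shuffle the data and return the loss of every mini-batch, with no weight update."""
    shuffled = random.sample(data, len(data))
    return [
        batch_loss(shuffled[i:i + batch_size], weight)
        for i in range(0, len(shuffled), batch_size)
    ]


small = epoch_losses(batch_size=16)
large = epoch_losses(batch_size=256)
print("batch 16 :", [round(l, 2) for l in small[:5]], "std =", round(statistics.pstdev(small), 3))
print("batch 256:", [round(l, 2) for l in large], "std =", round(statistics.pstdev(large), 3))
print("full-data loss:", round(batch_loss(data, weight), 3))
```

Nothing is trained here, yet the batch-16 losses still differ from one another, and their spread should typically be wider than with batches of 256. To judge whether the learning rate has an effect, it is steadier to average the loss over many steps or to evaluate on one fixed batch.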