Zaba

Reputation: 27

Linear layers for LoRA

I have been trying to run DPO on the LLaVA models (llava-hf/llava-v1.6-mistral-7b-hf). Looking at the training script the LLaVA authors provide, I noticed that all of the multimodal linear layers are skipped when selecting LoRA targets. Can someone please explain why?

https://github.com/LLaVA-VL/LLaVA-NeXT/blob/09e5840d5589ad2d6a8656c0a60f21ae134b3309/llava/train/train_dpo.py#L226

Here is the function they have for selecting the layers:

import torch


def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    # Any module whose name contains one of these multimodal keywords is skipped.
    multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            # Keep only the last component of the module path, e.g. "q_proj".
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if "lm_head" in lora_module_names:  # needed for 16-bit
        lora_module_names.remove("lm_head")
    return list(lora_module_names)
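
For context, my assumption is that the returned list is fed into a PEFT LoraConfig as the LoRA target modules, roughly like the sketch below (r, lora_alpha and lora_dropout are placeholder values, not taken from the LLaVA script):

from peft import LoraConfig, get_peft_model

# Rough usage sketch (my assumption of how the returned names are consumed;
# the exact arguments in the LLaVA script may differ):
lora_config = LoraConfig(
    r=16,                                        # placeholder rank
    lora_alpha=32,                               # placeholder scaling
    lora_dropout=0.05,                           # placeholder dropout
    target_modules=find_all_linear_names(model),
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)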

I expected the linear layers inside the vision_tower (primarily fc1 and fc2) to be included as well, but the LLaVA training script ignores them. I'm trying to understand why.
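
For what it's worth, this is the kind of variant I had in mind (my own sketch, not from the LLaVA repo): only the projector and resampler are skipped, so the vision_tower linears such as fc1 and fc2 would also show up as LoRA targets.

def find_all_linear_names_with_vision(model):
    # Hypothetical variant of the function above: vision_tower is not filtered
    # out, so its linear layers (fc1, fc2 and the attention projections) would
    # also end up in the LoRA target list.
    lora_module_names = set()
    skip_keywords = ["mm_projector", "vision_resampler"]
    for name, module in model.named_modules():
        if any(keyword in name for keyword in skip_keywords):
            continue
        if isinstance(module, torch.nn.Linear):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    lora_module_names.discard("lm_head")  # needed for 16-bit
    return list(lora_module_names)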

Upvotes: 0

Views: 19

Answers (0)
