I have been trying to run DPO on the LLaVA models (llava-hf/llava-v1.6-mistral-7b-hf) and came across the training script the LLaVA folks provide. I noticed that all of the multimodal linear layers are excluded when selecting LoRA targets. Can someone please explain why?
Here is the function they use to select the target layers:
    import torch

    def find_all_linear_names(model):
        cls = torch.nn.Linear
        lora_module_names = set()
        # Skip every module belonging to the multimodal components.
        multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
        for name, module in model.named_modules():
            if any(mm_keyword in name for mm_keyword in multimodal_keywords):
                continue
            if isinstance(module, cls):
                names = name.split(".")
                # Keep only the leaf module name, e.g. "q_proj" from
                # "model.layers.0.self_attn.q_proj".
                lora_module_names.add(names[0] if len(names) == 1 else names[-1])
        if "lm_head" in lora_module_names:  # needed for 16-bit
            lora_module_names.remove("lm_head")
        return list(lora_module_names)
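For context, here is a minimal sketch of how the returned list would typically be fed into a PEFT LoraConfig; the rank, alpha, and dropout values are placeholders I chose for illustration, not values from the LLaVA script. Note also that the keyword list above matches module names in the original LLaVA repo; the llava-hf port names some modules differently (e.g. the projector is multi_modal_projector there), so the filter may not match the same way on the HF model.

    from transformers import LlavaNextForConditionalGeneration
    from peft import LoraConfig, get_peft_model

    model = LlavaNextForConditionalGeneration.from_pretrained(
        "llava-hf/llava-v1.6-mistral-7b-hf"
    )

    # Hypothetical hyperparameters for illustration; not the LLaVA defaults.
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=find_all_linear_names(model),  # language-model linears only
        task_type="CAUSAL_LM",
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()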
I expected the linear layers in the vision_tower (primarily fc1 and fc2) to be included as well, but the LLaVA training script skips them. I'm trying to understand why.