A.A

Reputation: 4151

OutOfMemoryError when increasing context_length of LLMs

I'm trying to fine-tune a model with a context length of 50,000 using PyTorch's Fully Sharded Data Parallel (FSDP), but I keep running into an OutOfMemoryError no matter how many GPUs I use.

Here is the llama-recipes command I am using to train the model:

torchrun --nnodes 1 --nproc_per_node 8 recipes/quickstart/finetuning/finetuning.py --context_length 50000 --enable_fsdp --model_name /path_of_model_folder/8B --use_peft --peft_method lora --output_dir Path/to/save/PEFT/model --use_fast_kernels

I've tried scaling out to multiple GPUs (8 in this case), but I still hit out-of-memory errors as soon as I increase --context_length.

A single GPU can handle context_length=5000, but I cannot even train with context_length=6000 on 8 GPUs.
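
To get a feel for the numbers, here is a very rough activation-memory estimate I put together. Everything in it is an assumption on my part, not a measurement: Llama-3-8B-style dimensions (32 layers, hidden size 4096), bf16 activations, and the common ~34 * seq * batch * hidden bytes-per-layer rule of thumb (which ignores the attention-score matrices that fused kernels avoid materializing):

# Back-of-the-envelope estimate only; all constants are my assumptions
# (32 layers / hidden size 4096 for an 8B Llama-style model, ~34 bytes of
# activations per token per hidden unit per layer as a rule of thumb).
def activation_gib(seq_len, batch=1, layers=32, hidden=4096, bytes_per_token_per_hidden=34):
    return seq_len * batch * hidden * bytes_per_token_per_hidden * layers / 1024**3

for s in (5_000, 6_000, 50_000):
    # If FSDP only shards parameters/gradients/optimizer state (my reading of
    # FULL_SHARD), these activations stay on every rank, so adding GPUs does
    # not shrink this number.
    print(f"context_length={s}: ~{activation_gib(s):.0f} GiB of activations per GPU")

If that estimate is anywhere near right, the per-GPU activation footprint alone grows from roughly 21 GiB at 5,000 tokens to over 200 GiB at 50,000 tokens, which would match the behaviour I am seeing.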

Does FSDP actually help with long contexts? (The default sharding strategy is ShardingStrategy.FULL_SHARD.)
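
For reference, my understanding is that --enable_fsdp with the default strategy boils down to roughly the plain PyTorch setup below. This is my own sketch, not the actual llama-recipes code; the toy model is just a placeholder, and the script assumes it is launched with torchrun so the process-group environment variables are set:

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Placeholder model; in my case this would be the 8B checkpoint.
model = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)).cuda()

# FULL_SHARD shards parameters, gradients and optimizer state across ranks.
fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)

dist.destroy_process_group()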

Upvotes: 0

Views: 29

Answers (0)
