Reputation: 4151
I'm trying to train a model with a context length of 50,000 using PyTorch's Fully Sharded Data Parallel (FSDP), but I keep running into an OutOfMemoryError regardless of how many GPUs are available.
Here is the llama-recipes command I am using to train the model:
torchrun --nnodes 1 --nproc_per_node 8 \
    recipes/quickstart/finetuning/finetuning.py \
    --context_length 50000 \
    --enable_fsdp \
    --model_name /path_of_model_folder/8B \
    --use_peft --peft_method lora \
    --output_dir Path/to/save/PEFT/model \
    --use_fast_kernels
I've tried multiple GPUs (8 in this case), but I still hit out-of-memory errors as soon as I increase --context_length. A single GPU can handle context_length=5000, yet I cannot even train with context_length=6000 on 8 GPUs.
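My suspicion is that FSDP shards parameters, gradients, and optimizer state, but not activations, so per-GPU activation memory still grows with context length. Here is the back-of-the-envelope estimate I used (a sketch only: the layer count, hidden size, and the x10 intermediate multiplier below are assumptions for an 8B Llama-style model, not measured numbers):

# Rough activation memory per sample, assuming an 8B Llama-style config:
# 32 layers, hidden size 4096, bf16 (2 bytes per element).
# The x10 multiplier for attention/MLP intermediates is a guess, not measured.
layers, hidden, bytes_per_elem = 32, 4096, 2

for seq_len in (5_000, 6_000, 50_000):
    per_layer = 10 * seq_len * hidden * bytes_per_elem
    total_gib = layers * per_layer / 1024**3
    print(f"seq_len={seq_len:>6}: ~{total_gib:.0f} GiB of activations per GPU")

If that is anywhere near right, sharding the weights 8 ways would not help, because each rank still holds its own full set of activations.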
Does FSDP work for long contexts at all? (The default sharding strategy is ShardingStrategy.FULL_SHARD.)
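For reference, this is my mental model of what the recipe's FSDP setup does, written as a minimal, self-contained sketch (the toy model and flat wrapping here are my own simplifications, not the actual llama-recipes code):

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy

def main():
    # torchrun sets RANK/WORLD_SIZE/LOCAL_RANK, as in the command above.
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Toy stand-in for the 8B model; the real recipe wraps transformer blocks.
    model = nn.Sequential(*[nn.Linear(4096, 4096) for _ in range(8)]).cuda()

    # FULL_SHARD splits parameters, gradients, and optimizer state across
    # ranks, but the activations of each forward pass still live on every rank.
    model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)

    # (batch, seq_len, hidden): per-rank activation memory grows with seq_len.
    x = torch.randn(1, 5000, 4096, device="cuda")
    model(x).sum().backward()
    dist.destroy_process_group()

if __name__ == "__main__":
    main()

If my understanding of FULL_SHARD above is wrong, or there is a recommended way to shard activations too (e.g. along the sequence dimension), I'd appreciate a pointer.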
Upvotes: 0
Views: 29