Florian Rudaj

Reputation: 1

Minimal FSDP example utilizing the HuggingFace Trainer in AWS Sagemaker

I'm currently trying to fine-tune an LLM in AWS SageMaker. Since it's too big to fit on a single GPU, I'm trying to distribute the model weights over multiple GPUs in an AWS SageMaker instance. In my training script I use the HuggingFace Trainer. Since the HuggingFace Trainer (with its fsdp parameter), the PyTorch library (with torch.distributed), and AWS SageMaker (with smdistributed) all have mechanisms to enable FSDP, I'm entirely confused about how I can (or should) enable FSDP for my use case.

I'd be very glad if someone could help me out here by providing a minimal but working example of how to enable FSDP with the HuggingFace Trainer in an AWS SageMaker Training Job.
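
For context, the Trainer-level route I've been looking at is the fsdp argument of TrainingArguments, which I understand roughly like this (just a sketch of my current understanding; the values are placeholders I haven't verified):

# Sketch: enabling FSDP purely through the HuggingFace Trainer.
# The fsdp/fsdp_config values below are placeholders, not a verified setup.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="/opt/ml/model",
    per_device_train_batch_size=1,
    # "full_shard auto_wrap" asks the Trainer to fully shard the model
    # and wrap the transformer blocks automatically.
    fsdp="full_shard auto_wrap",
    fsdp_config={
        # Which transformer block class to wrap depends on the model architecture.
        "transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
    },
)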

Edit: I have now tried to implement the suggestion, but I run into the following error:

Traceback (most recent call last):
  File "/app/train.py", line 178, in <module>
    train_results = trainer.train()
  File "/usr/local/lib/python3.9/site-packages/transformers/trainer.py", line 1624, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.9/site-packages/transformers/trainer.py", line 1766, in _inner_training_loop
    self.model = self.accelerator.prepare(self.model)
  File "/usr/local/lib/python3.9/site-packages/accelerate/accelerator.py", line 1228, in prepare
    result = tuple(
  File "/usr/local/lib/python3.9/site-packages/accelerate/accelerator.py", line 1229, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/usr/local/lib/python3.9/site-packages/accelerate/accelerator.py", line 1105, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/usr/local/lib/python3.9/site-packages/accelerate/accelerator.py", line 1328, in prepare_model
    if torch.device(current_device_index) != self.device:
TypeError: device() received an invalid combination of arguments - got (NoneType), but expected one of:
 * (torch.device device)
      didn't match because some of the arguments have invalid types: (!NoneType!)
 * (str type, int index)

This is my training script:

from torch.utils.data import Dataset
import torch
from peft import (
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    prepare_model_for_int8_training,
)
from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import argparse

import sys
import os
import logging
import matplotlib.pyplot as plt
if __name__ == "__main__":

    
    parser = argparse.ArgumentParser()
    # hyperparameters sent by the client
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--per_device_train_batch_size", type=int, default=32)
    parser.add_argument("--model_name", type=str, default="codellama/CodeLlama-7b-hf")
    parser.add_argument("--learn_rate", type=str, default="3e-4")
    parser.add_argument("--warmup_steps", type=int, default=400)
    # Data, model and output directories
    parser.add_argument("--output_data_dir", type=str, default="/opt/ml/output/data")
    parser.add_argument("--model-dir", type=str, default="/opt/ml/model")
    parser.add_argument("--n_gpus", type=str, default="4")
    parser.add_argument("--training_dir", type=str, default="/opt/ml/input/data/train")
    parser.add_argument("--test_dir", type=str, default="/opt/ml/input/data/test")

    args, _ = parser.parse_known_args()

    # Set up logging
    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
        )

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        load_in_8bit=True,
        torch_dtype=torch.float16
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)


    # %%
    tokenizer.add_eos_token = True
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"

    # %%
    class MappingDataset(Dataset):
        def __init__(self, train=True):
            if train:
                self.dataset = load_from_disk(dataset_path=args.training_dir)
                logger.info(f"loaded train dataset with a length of:{len(self.dataset)}")
            else:
                self.dataset = load_from_disk(dataset_path=args.test_dir)
                logger.info(f"loaded test dataset with a length of:{len(self.dataset)}")

            self.dataset = self.dataset.select(range(1000))
        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            return self.dataset[idx]

    # %%
    train_dataset = MappingDataset(train=True)
    val_dataset = MappingDataset(train=False)


    # %%
    model.train() # put model back into training mode
    model = prepare_model_for_int8_training(model)

    config = LoraConfig(
        r=16,
        lora_alpha=16,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)


    # %% [markdown]
    # The cell below keeps the Trainer from trying its own DataParallelism when more than 1 gpu is available

    # %%
    if torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True

    # %%
    gradient_accumulation_steps = args.batch_size // args.per_device_train_batch_size
    output_dir = "bis-mapping-code-llama"

    training_args = TrainingArguments(
            per_device_train_batch_size=args.per_device_train_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=args.warmup_steps,
            learning_rate=float(args.learn_rate),
            fp16=True,
            logging_steps=1,
            optim="adamw_torch",
            evaluation_strategy="steps", # if val_set_size > 0 else "no",
            save_strategy="steps",
            eval_steps=20,
            save_steps=20,
            output_dir=args.model_dir,
            load_best_model_at_end=False,
            group_by_length=True, # group sequences of roughly the same length together to speed up training
            log_level='debug',
            logging_dir=f"{args.output_data_dir}/logs"
        )
    
    def compute_perplexity(pred):
        # The Trainer passes numpy arrays, so convert them to tensors first
        logits = torch.tensor(pred.predictions)
        labels = torch.tensor(pred.label_ids)
        # Flatten the logits and labels to compute cross-entropy loss
        logits = logits.view(-1, logits.size(-1))
        labels = labels.view(-1)
        # Compute cross-entropy loss (padding labels of -100 are ignored by default)
        loss = torch.nn.functional.cross_entropy(logits, labels)
        # Compute perplexity
        perplexity = torch.exp(loss)
        return {"perplexity": perplexity.item()}

    trainer = Trainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        args=training_args,
        compute_metrics=compute_perplexity,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )


    # %% [markdown]
    # The cell below only serves to optimize the model training

    # %%
    model.config.use_cache = False

    old_state_dict = model.state_dict
    model.state_dict = (lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())).__get__(
        model, type(model)
    )
    if torch.__version__ >= "2" and sys.platform != "win32":
        print("compiling the model")
        model = torch.compile(model)


    # %%
    train_results = trainer.train()

    # trainer.train() returns a TrainOutput; the per-step losses live in the log history
    train_loss_values = [log["loss"] for log in trainer.state.log_history if "loss" in log]

    
    # Plot the loss values
    plt.plot(train_loss_values, label="Training Loss")
    plt.xlabel("Training Steps")
    plt.ylabel("Loss")
    plt.title("Training Loss over Steps")
    plt.legend()

    # Save the plot to disk (create the target directory first)
    os.makedirs(f"{args.output_data_dir}/plots", exist_ok=True)
    plt.savefig(f"{args.output_data_dir}/plots/training_loss_plot.png")

    # %%
    eval_results = trainer.evaluate()
    print(f"Perplexity: {2**eval_results['eval_loss']}")

And this is my FSDP config:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

I'm starting the script with:

accelerate launch --use_fsdp --config_file=fsdp_config.yaml train.py

Can you help me out some more?

Upvotes: 0

Views: 1547

Answers (1)

Jo&#227;o Moura
Jo&#227;o Moura

Reputation: 15

I don't know of a specific example, but here are my thoughts:

  • HuggingFace's FSDP Trainer integration is launched with accelerate launch (explained here)
  • SageMaker allows you to pass in whatever entrypoint script you want (explained here)
  • So in principle nothing should change in your HF Trainer-based FSDP script; you just have to create the appropriate launch script. Here is an example of a launch script for HF accelerate, and a rough sketch of how to wire it into SageMaker follows below.
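
For illustration, a Training Job that hands control to such a launch script could be created with the SageMaker Python SDK roughly like this (a sketch only; launch.sh, the framework versions, and the instance type are assumptions on my side, not something I have verified end to end):

# Rough sketch: a SageMaker Training Job whose entry point is a shell script
# that itself calls `accelerate launch --config_file fsdp_config.yaml train.py`.
# The file names, versions, and instance type below are illustrative assumptions.
import sagemaker
from sagemaker.huggingface import HuggingFace

role = sagemaker.get_execution_role()

estimator = HuggingFace(
    entry_point="launch.sh",          # hypothetical wrapper around `accelerate launch`
    source_dir="scripts",             # contains launch.sh, train.py, fsdp_config.yaml
    instance_type="ml.g5.12xlarge",   # single node with 4 GPUs
    instance_count=1,
    role=role,
    transformers_version="4.36",
    pytorch_version="2.1",
    py_version="py310",
    hyperparameters={"per_device_train_batch_size": 1},
)

estimator.fit({"train": "s3://my-bucket/train", "test": "s3://my-bucket/test"})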

In general, using FSDP through HF's Trainer just abstracts away having to deal with the actual training loop. As for the SageMaker DDP optimized collectives, those are now better integrated into PyTorch: they simply replace the default nccl backend with the smddp backend when you instantiate a process group. This is turned on for you by default if you run HF's Trainer in a SageMaker Training Job; as you can see, HF has integrated it here.
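
For what it's worth, that backend swap at the plain PyTorch level amounts to roughly the following (based on the AWS documentation; shown only to illustrate what the Trainer integration does for you, not something you need to add yourself):

# Sketch of selecting SageMaker's optimized collectives directly in PyTorch.
# Importing the package registers the "smddp" process-group backend.
import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
import torch.distributed as dist

# Instead of backend="nccl", hand the collectives to SageMaker's implementation.
dist.init_process_group(backend="smddp")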

Upvotes: 0
