Reputation: 1
I'm currently trying to fine-tune an LLM in AWS SageMaker. Since it's too big to fit on a single GPU, I'm trying to distribute the model weights over multiple GPUs in an AWS SageMaker instance. In my training script, I use the HuggingFace Trainer. Since the HuggingFace Trainer (with its fsdp parameter), the PyTorch library (with torch.distributed), and AWS SageMaker (with smdistributed) all have mechanisms for distributed training, I'm entirely confused about how I can (or should?) enable FSDP for my use case.
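For context, this is roughly how I understand the three layers (just a sketch of my understanding, not working code; the role name, instance type, and version strings below are placeholders):
# 1) HuggingFace Trainer: FSDP switched on via TrainingArguments
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="/opt/ml/model",
    fsdp="full_shard auto_wrap",     # Trainer-level FSDP flag
    fsdp_config="fsdp_config.json",  # dict or JSON file with the FSDP settings
)

# 2) Plain PyTorch: torch.distributed, with one process per GPU started by a launcher,
#    e.g. torchrun --nproc_per_node 4 train.py (or accelerate launch)

# 3) SageMaker: the distribution argument of the estimator that starts the Training Job
from sagemaker.huggingface import HuggingFace

estimator = HuggingFace(
    entry_point="train.py",
    instance_type="ml.p4d.24xlarge",   # example multi-GPU instance
    instance_count=1,
    role="my-sagemaker-role",          # placeholder
    transformers_version="4.28",       # placeholder versions
    pytorch_version="2.0",
    py_version="py310",
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
)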
I'd be very glad if someone could help me out here by providing a minimal but working example of how to enable FSDP with the HuggingFace Trainer in an AWS SageMaker Training Job.
Edit: I have now tried to implement the suggestion, but I run into the following error:
Traceback (most recent call last):
  File "/app/train.py", line 178, in <module>
    train_results = trainer.train()
  File "/usr/local/lib/python3.9/site-packages/transformers/trainer.py", line 1624, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.9/site-packages/transformers/trainer.py", line 1766, in _inner_training_loop
    self.model = self.accelerator.prepare(self.model)
  File "/usr/local/lib/python3.9/site-packages/accelerate/accelerator.py", line 1228, in prepare
    result = tuple(
  File "/usr/local/lib/python3.9/site-packages/accelerate/accelerator.py", line 1229, in <genexpr>
    self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  File "/usr/local/lib/python3.9/site-packages/accelerate/accelerator.py", line 1105, in _prepare_one
    return self.prepare_model(obj, device_placement=device_placement)
  File "/usr/local/lib/python3.9/site-packages/accelerate/accelerator.py", line 1328, in prepare_model
    if torch.device(current_device_index) != self.device:
TypeError: device() received an invalid combination of arguments - got (NoneType), but expected one of:
 * (torch.device device)
      didn't match because some of the arguments have invalid types: (!NoneType!)
 * (str type, int index)
This is my training script:
from torch.utils.data import Dataset
import torch
from peft import (
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
prepare_model_for_int8_training,
)
from datasets import load_from_disk
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import argparse
import sys
import os
import logging
import matplotlib.pyplot as plt
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client
    parser.add_argument("--batch_size", type=int, default=128)
    parser.add_argument("--per_device_train_batch_size", type=int, default=32)
    parser.add_argument("--model_name", type=str, default="codellama/CodeLlama-7b-hf")
    parser.add_argument("--learn_rate", type=str, default="3e-4")
    parser.add_argument("--warmup_steps", type=int, default=400)

    # Data, model and output directories
    parser.add_argument("--output_data_dir", type=str, default="/opt/ml/output/data")
    parser.add_argument("--model-dir", type=str, default="/opt/ml/model")
    parser.add_argument("--n_gpus", type=str, default="4")
    parser.add_argument("--training_dir", type=str, default="/opt/ml/input/data/train")
    parser.add_argument("--test_dir", type=str, default="/opt/ml/input/data/test")

    args, _ = parser.parse_known_args()

    # Set up logging
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        load_in_8bit=True,
        torch_dtype=torch.float16
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # %%
    tokenizer.add_eos_token = True
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"

    # %%
    class MappingDataset(Dataset):
        def __init__(self, train=True):
            if train:
                self.dataset = load_from_disk(dataset_path=args.training_dir)
                logger.info(f"loaded train dataset with a length of:{len(self.dataset)}")
            else:
                self.dataset = load_from_disk(dataset_path=args.test_dir)
                logger.info(f"loaded test dataset with a length of:{len(self.dataset)}")
                self.dataset = self.dataset.select(range(1000))

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            return self.dataset[idx]

    # %%
    train_dataset = MappingDataset(train=True)
    val_dataset = MappingDataset(train=False)

    # %%
    model.train()  # put model back into training mode
    model = prepare_model_for_int8_training(model)

    config = LoraConfig(
        r=16,
        lora_alpha=16,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, config)

    # %% [markdown]
    # The cell below keeps the Trainer from trying its own DataParallelism when more than 1 GPU is available

    # %%
    if torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True

    # %%
    gradient_accumulation_steps = args.batch_size // args.per_device_train_batch_size
    output_dir = "bis-mapping-code-llama"

    training_args = TrainingArguments(
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        warmup_steps=args.warmup_steps,
        learning_rate=float(args.learn_rate),
        fp16=True,
        logging_steps=1,
        optim="adamw_torch",
        evaluation_strategy="steps",  # if val_set_size > 0 else "no",
        save_strategy="steps",
        eval_steps=20,
        save_steps=20,
        output_dir=args.model_dir,
        load_best_model_at_end=False,
        group_by_length=True,  # group sequences of roughly the same length together to speed up training
        log_level='debug',
        logging_dir=f"{args.output_data_dir}/logs"
    )

    def compute_perplexity(pred):
        # Extract the predicted logits from the model output
        logits = pred.predictions
        # Flatten the logits and labels to compute cross-entropy loss
        logits = logits.view(-1, logits.size(-1))
        labels = pred.label_ids.view(-1)
        # Compute cross-entropy loss
        loss = torch.nn.functional.cross_entropy(logits, labels)
        # Compute perplexity
        perplexity = torch.exp(loss)
        return {"perplexity": perplexity.item()}

    trainer = Trainer(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        args=training_args,
        compute_metrics=compute_perplexity,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
        ),
    )

    # %% [markdown]
    # The cell below only serves for optimizing the model training

    # %%
    model.config.use_cache = False

    old_state_dict = model.state_dict
    model.state_dict = (lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())).__get__(
        model, type(model)
    )

    if torch.__version__ >= "2" and sys.platform != "win32":
        print("compiling the model")
        model = torch.compile(model)

    # %%
    train_results = trainer.train()
    train_loss_values = train_results["train_loss"]

    # Plot the loss values
    plt.plot(train_loss_values, label="Training Loss")
    plt.xlabel("Training Steps")
    plt.ylabel("Loss")
    plt.title("Training Loss over Steps")
    plt.legend()

    # Save the plot to disk
    plt.savefig(f"{args.output_data_dir}/plots/training_loss_plot.png")

    # %%
    eval_results = trainer.evaluate()
    print(f"Perplexity: {2**eval_results['eval_loss']}")
And this is my FSDP config:
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
I'm starting the script with:
accelerate launch --use_fsdp --config_file=fsdp_config.yaml train.py
Can you help me out some more?
Upvotes: 0
Views: 1547
Reputation: 15
I don't know of a specific example, but here are my thoughts:
To enable FSDP with HF's Trainer, you launch your training script with accelerate launch and an FSDP config (explained here). In general, using FSDP through HF's Trainer just abstracts away having to deal with the actual training loop yourself. As for the SM DDP optimized collectives, those are now better integrated into PyTorch: they simply replace the default nccl backend with the smddp backend when you instantiate a process group with PyTorch. This will be turned on for you by default if you run HF's Trainer in a SageMaker Training Job, as you can see from HF's integration here.
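For illustration, swapping in the SageMaker collectives looks roughly like this when you manage the process group yourself (a sketch based on the AWS docs, assuming the smdistributed package is available in your training container; you don't need this when the Trainer/accelerate sets up the process group for you):
# Register and use the SageMaker smddp backend instead of nccl.
import torch.distributed as dist
import smdistributed.dataparallel.torch.torch_smddp  # registers the "smddp" backend

dist.init_process_group(backend="smddp")  # instead of backend="nccl"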
Upvotes: 0