alvas
alvas

Reputation: 122260

How to fine-tune a Mistral-7B model for machine translation?

There's a lot of tutorials online that uses raw text affix with arcane syntax to indicate document boundary and accessed through Huggingface datasets.Dataset object through the text key. E.g.

from datasets import load_dataset

dataset_name = "mlabonne/guanaco-llama2-1k"

dataset = load_dataset(dataset_name, split="train")
dataset["text"][42]

[out]:

<s>[INST] ¿Cuáles son los actuales presidentes de la región de Sur América? Enumérelos en una lista con su respectivo país. [/INST] A fecha del 13 de febrero de 2023, estos son los presidentes de los países de Sudamérica, según Wikipedia:
-Argentina: Alberto Fernández
-Bolivia: Luis Arce
-Brasil: Luiz Inácio Lula da Silva
-Chile: Gabriel Boric
-Colombia: Gustavo Petro
-Ecuador: Guillermo Lasso
-Paraguay: Mario Abdo Benítez
-Perú: Dina Boluarte
-Uruguay: Luis Lacalle Pou
-Venezuela: Nicolás Maduro
-Guyana: Irfaan Ali
-Surinam: Chan Santokhi
-Trinidad y Tobago: Paula-Mae Weekes </s>

But machine translation datasets are usually structured in 2 parts, source and target text with sentence_eng_Latn and sentence_deu_Latn keys, e.g.


valid_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False, 
                          split="dev")
valid_data[42]

[out]:

{'id': 43,
 'URL': 'https://en.wikinews.org/wiki/Hurricane_Fred_churns_the_Atlantic',
 'domain': 'wikinews',
 'topic': 'disaster',
 'has_image': 0,
 'has_hyperlink': 0,
 'sentence_eng_Latn': 'The storm, situated about 645 miles (1040 km) west of the Cape Verde islands, is likely to dissipate before threatening any land areas, forecasters say.',
 'sentence_deu_Latn': 'Prognostiker sagen, dass sich der Sturm, der etwa 645 Meilen (1040 km) westlich der Kapverdischen Inseln befindet, wahrscheinlich auflösen wird, bevor er Landflächen bedroht.'}

How to fine-tune a Mistral-7b model for the machine translation task?

Upvotes: 4

Views: 895

Answers (1)

alvas
alvas

Reputation: 122260

The key is to re-format the data from a traditional machine translation dataset that splits the source and target text and piece them up together in a format that the model expects.

For the Mistral 7B model specifically, it usually expects:

  • each row of data would be encapsulated between <s> and where
    • the input source sentence would be embedded between the [INST] ... [/INST]
    • the output target sentence would be after the [/INST] symbol
  • any pre-data prompts before the [INST] ... [/INST]

E.g. if we want to use a translation prompt as such "Translate English to German:",

valid_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False, split="dev")

def preprocess_func(row):
  return {'text': "Translate from English to German: <s>[INST] " + row['sentence_eng_Latn'] + " [INST] " + row['sentence_deu_Latn'] + " </s>"}

valid_dataset = valid_data.map(preprocess_func)

valid_dataset[42]

[out]:

{'id': 43,
 'URL': 'https://en.wikinews.org/wiki/Hurricane_Fred_churns_the_Atlantic',
 'domain': 'wikinews',
 'topic': 'disaster',
 'has_image': 0,
 'has_hyperlink': 0,
 'sentence_eng_Latn': 'The storm, situated about 645 miles (1040 km) west of the Cape Verde islands, is likely to dissipate before threatening any land areas, forecasters say.',
 'sentence_deu_Latn': 'Prognostiker sagen, dass sich der Sturm, der etwa 645 Meilen (1040 km) westlich der Kapverdischen Inseln befindet, wahrscheinlich auflösen wird, bevor er Landflächen bedroht.',
 'text': 'Translate from English to German: <s>[INST] The storm, situated about 645 miles (1040 km) west of the Cape Verde islands, is likely to dissipate before threatening any land areas, forecasters say. [INST] Prognostiker sagen, dass sich der Sturm, der etwa 645 Meilen (1040 km) westlich der Kapverdischen Inseln befindet, wahrscheinlich auflösen wird, bevor er Landflächen bedroht. </s>'}

Then the normal fine-tuning Mistral-7b scripts could just read the text key in the dataset, e.g.

Requires

!pip install -U transformers sentencepiece datasets
!pip install -q -U transformers
!pip install -q -U accelerate
!pip install -q -U bitsandbytes

!pip install -U peft
!pip install -U trl

And if you are in a Jupyter environment, you'll need to reset the kernel after installing accelerate, so:

import os
os._exit(00)

Then:

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,HfArgumentParser,TrainingArguments,pipeline, logging
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training, get_peft_model
import os,torch
from datasets import load_dataset
from trl import SFTTrainer


base_model = "mistralai/Mistral-7B-Instruct-v0.2"
new_model = "mistral_7b_flores_dev_en_de"


bnb_config = BitsAndBytesConfig(  
    load_in_4bit= True,
    bnb_4bit_quant_type= "nf4",
    bnb_4bit_compute_dtype= torch.bfloat16,
    bnb_4bit_use_double_quant= False,
)
model = AutoModelForCausalLM.from_pretrained(
        base_model,
        quantization_config=bnb_config,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
)
model.config.use_cache = False 
model.config.pretraining_tp = 1
model.gradient_checkpointing_enable()



tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
tokenizer.padding_side = 'right'
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_eos_token = True
tokenizer.add_bos_token, tokenizer.add_eos_token



valid_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False, 
                          split="dev")

test_data = load_dataset("facebook/flores", "eng_Latn-deu_Latn", streaming=False, 
                          split="devtest")



def preprocess_func(row):
  return {'text': "Translate from English to German: <s>[INST] " + row['sentence_eng_Latn'] + " [INST] " + row['sentence_deu_Latn'] + " </s>"}


valid_dataset = valid_data.map(preprocess_func)
test_dataset = test_data.map(preprocess_func)



model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj","gate_proj"]
)
model = get_peft_model(model, peft_config)


training_arguments = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_steps=25,
    logging_steps=25,
    learning_rate=2e-4,
    weight_decay=0.001,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant",
    report_to=None
)


trainer = SFTTrainer(
    model=model,
    train_dataset=valid_dataset,
    peft_config=peft_config,
    max_seq_length= None,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing= False,
)

trainer.train()

Upvotes: 5

Related Questions