Fine-tune Llama 2 with LoRA: Customizing a large language model for question-answering

1 Feb 2024
In this blog, we show you how to fine-tune Llama 2 on an AMD GPU with ROCm. We use Low-Rank Adaptation of Large Language Models (LoRA) to overcome memory and computing limitations and make open-source large language models (LLMs) more accessible. We also show you how to fine-tune and upload models to Hugging Face.
Introduction#
In generative AI (GenAI), fine-tuning LLMs such as Llama 2 is challenging because of the substantial computational and memory requirements. LoRA offers a compelling solution that allows rapid and cost-effective fine-tuning of state-of-the-art LLMs: it both speeds up the tuning process and lowers the associated costs.
To explore the benefits of LoRA, we provide a comprehensive walkthrough of the fine-tuning process for Llama 2 using LoRA, specifically tailored for question-answering (QA) tasks on an AMD GPU.
Before jumping in, let’s take a moment to briefly review the three pivotal components that form the foundation of our discussion:
Llama 2: Meta’s advanced language model with variants that scale up to 70 billion parameters.
Fine-tuning: A crucial process that refines LLMs for specialized tasks, optimizing their performance.
LoRA: The algorithm employed for fine-tuning Llama 2, ensuring effective adaptation to specialized tasks.
Llama 2#
Llama 2 is a collection of second-generation, open-source LLMs from Meta; it comes with a commercial license. Llama 2 is designed to handle a wide range of natural language processing (NLP) tasks, with models ranging in scale from 7 billion to 70 billion parameters.
Llama 2 Chat, which is optimized for dialogue, has shown similar performance to popular closed-source models like ChatGPT and PaLM. You can improve the performance of this model by fine-tuning it with a high-quality conversational data set. In this blog post, we delve into the process of refining a Llama 2 Chat model using a QA data set.
Fine-tuning a model#
Fine-tuning in machine learning is the process of adjusting the weights and parameters of a pre-trained model using new data in order to improve its performance on a specific task. It involves using a new data set, one that is specific to the current task, to update the model’s weights. It’s typically not possible to fine-tune LLMs on consumer hardware due to inadequate memory and computing power. However, in this tutorial, we use LoRA to overcome these challenges.
LoRA#
LoRA is an innovative technique, developed by researchers at Microsoft, designed to address the challenges of fine-tuning LLMs. It freezes the pre-trained model weights and trains only small low-rank matrices that are added to selected layers. This results in a significant reduction in the number of parameters (by a factor of up to 10,000) that need to be fine-tuned, which significantly reduces GPU memory requirements. To learn more about the fundamental principles of LoRA, refer to Using LoRA for efficient fine-tuning: Fundamental principles.
Step-by-step Llama 2 fine-tuning#
Standard (full-parameter) fine-tuning involves considering all parameters. It requires significant computational power to manage optimizer states and gradient check-pointing. The resulting memory footprint is typically about four times larger than the model itself. For example, loading a 7 billion parameter model (e.g., Llama 2) in FP32 (4 bytes per parameter) requires approximately 28 GB of GPU memory, while fine-tuning demands around 28*4=112 GB of GPU memory. Note that the 112 GB figure is derived empirically, and various factors like batch size, data precision, and gradient accumulation contribute to overall memory usage.
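The arithmetic above can be written as a small helper. This is a rough sketch only: `estimate_memory_gb` is a hypothetical function name, the factor of four is the empirical rule of thumb quoted above, and real usage varies with batch size, precision, and gradient accumulation.

```python
def estimate_memory_gb(num_params, bytes_per_param=4, training_multiplier=4):
    """Return (load_gb, train_gb) using the rule of thumb from the text.

    load_gb:  memory needed just to hold the weights.
    train_gb: rough full fine-tuning footprint (weights x empirical multiplier).
    """
    load_gb = num_params * bytes_per_param / 1e9
    train_gb = load_gb * training_multiplier
    return load_gb, train_gb


# 7 billion parameters stored in FP32 (4 bytes each)
load_gb, train_gb = estimate_memory_gb(7e9)
print(f"load: {load_gb:.0f} GB, full fine-tuning: {train_gb:.0f} GB")
```

Passing `bytes_per_param=2` gives the corresponding estimate for a 16-bit data type.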
To overcome this memory limitation, you can use a parameter-efficient fine-tuning (PEFT) technique, such as LoRA.
This example uses the two GCDs (Graphics Compute Dies) of an AMD MI250 GPU; each GCD is equipped with 64 GB of VRAM. This setup allows us to explore different settings for fine-tuning the Llama-2-7b weights with and without LoRA.
Our setup:
Hardware & OS: See this link for a list of supported hardware and OS with ROCm.
Software:
Libraries:
transformers, accelerate, peft, trl, bitsandbytes, scipy
In this blog, we conducted our experiment using a single MI250 GPU with the Docker image rocm/pytorch:rocm6.1.2_ubuntu22.04_py3.10_pytorch_release-2.1.2.
Step 1: Getting started#
First, let’s confirm the availability of the GPU.
!rocm-smi --showproductname
Your output should look like this:
========================= ROCm System Management Interface =========================
=================================== Product Info ===================================
GPU[0] : Card series: AMD INSTINCT MI250 (MCM) OAM AC MBA
GPU[0] : Card model: 0x0b0c
GPU[0] : Card vendor: Advanced Micro Devices, Inc. [AMD/ATI]
GPU[0] : Card SKU: D65209
GPU[1] : Card series: AMD INSTINCT MI250 (MCM) OAM AC MBA
GPU[1] : Card model: 0x0b0c
GPU[1] : Card vendor: Advanced Micro Devices, Inc. [AMD/ATI]
GPU[1] : Card SKU: D65209
====================================================================================
=============================== End of ROCm SMI Log ================================
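The same check can be made from Python. The following is a minimal sketch, assuming a ROCm build of PyTorch, which exposes AMD GPUs through the `torch.cuda` API. `list_visible_gpus` is a hypothetical helper name, and it returns an empty list when PyTorch or a GPU is not available.

```python
def list_visible_gpus():
    """Return a list of (index, name, total_memory_gb) for each visible GPU."""
    try:
        import torch
    except ImportError:
        return []  # PyTorch is not installed in this environment
    if not torch.cuda.is_available():
        return []  # no GPU is visible to PyTorch
    gpus = []
    for i in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(i)
        gpus.append((i, props.name, round(props.total_memory / 1e9, 1)))
    return gpus


for index, name, mem_gb in list_visible_gpus():
    print(f"GPU {index}: {name}, {mem_gb} GB")
```

On the setup described in this blog, each GCD of the MI250 is expected to show up as a separate device.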
Next, install the required libraries.
!pip install -q pandas peft==0.9.0 transformers==4.31.0 trl==0.4.7 accelerate scipy
Install bitsandbytes#
Install bitsandbytes using the following code.
git clone --recurse https://github.com/ROCm/bitsandbytes
cd bitsandbytes
git checkout rocm_enabled
pip install -r requirements-dev.txt
cmake -DCOMPUTE_BACKEND=hip -S .  # Use -DBNB_ROCM_ARCH="gfx90a;gfx942" to target specific gpu arch
make
pip install .
Check the bitsandbytes version.
At the time of writing this blog, the version is 0.43.0.
%%bash
pip list | grep bitsandbytes
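If you prefer to check the version from Python rather than the shell, a sketch using only the standard library could look like this. `installed_version` is a hypothetical helper name; it returns `None` when the package is not installed.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of a package, or None if it is missing."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


print("bitsandbytes:", installed_version("bitsandbytes"))
```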
Import the required packages#
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline
)
from peft import LoraConfig
from trl import SFTTrainer
Step 2: Configuring the model and data#
You can access Meta’s official Llama-2 model from Hugging Face after making a request, which can take a couple of days. Instead of waiting, we’ll use NousResearch’s Llama-2-7b-chat-hf as our base model (it’s the same as the original, but quicker to access).
# Model and tokenizer names
base_model_name = "NousResearch/Llama-2-7b-chat-hf"
new_model_name = "llama-2-7b-enhanced"  # You can give your own name for fine tuned model

# Tokenizer
llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
llama_tokenizer.pad_token = llama_tokenizer.eos_token
llama_tokenizer.padding_side = "right"

# Model
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    device_map="auto"
)
base_model.config.use_cache = False
base_model.config.pretraining_tp = 1
After you have the base model, you can start fine-tuning. We fine-tune our base model for a question-and-answer task using a small data set called mlabonne/guanaco-llama2-1k, which is a subset (1,000 samples) of the timdettmers/openassistant-guanaco data set. This data set is a human-generated, human-annotated, assistant-style conversation corpus that contains 161,443 messages in 35 different languages, annotated with 461,292 quality ratings. This results in over 10,000 fully annotated conversation trees.
# Data set
data_name = "mlabonne/guanaco-llama2-1k"
training_data = load_dataset(data_name, split="train")
# check the data
print(training_data.shape)
# #11 is a QA sample in English
print(training_data[11])
(1000, 1)
{'text': '<s>[INST] write me a 1000 words essay about deez nuts. [/INST] The Deez Nuts meme first gained popularity in 2015 on the social media platform Vine. The video featured a young man named Rodney Bullard, who recorded himself asking people if they had heard of a particular rapper. When they responded that they had not, he would respond with the phrase "Deez Nuts" and film their reactions. The video quickly went viral, and the phrase became a popular meme.\n\nSince then, Deez Nuts has been used in a variety of contexts to interrupt conversations, derail discussions, or simply add humor to a situation. It has been used in internet memes, in popular music, and even in politics. In the 2016 US presidential election, a 15-year-old boy named Brady Olson registered as an independent candidate under the name Deez Nuts...</s>'}
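Each sample is a single `text` field in which the instruction sits between `[INST]` and `[/INST]` and the whole exchange is wrapped in `<s>` and `</s>`. If you want to add your own question-answer pairs in the same shape, a sketch might look like the following. `format_llama2_sample` is a hypothetical helper, and the exact whitespace around the tags is an assumption based on the sample above, so compare against a few rows of the data set before relying on it.

```python
def format_llama2_sample(question, answer):
    """Wrap a question/answer pair in the Llama 2 chat-style tags used by the data set."""
    return f"<s>[INST] {question.strip()} [/INST] {answer.strip()} </s>"


sample = format_llama2_sample(
    "What is LoRA?",
    "A parameter-efficient fine-tuning method.",
)
print(sample)
```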
## There is a dependency during training
!pip install tensorboardX
Step 3: Start fine-tuning#
To set your training parameters, use the following code:
# Training Params
train_params = TrainingArguments(
    output_dir="./results_modified",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    optim="paged_adamw_32bit",
    save_steps=50,
    logging_steps=50,
    learning_rate=4e-5,
    weight_decay=0.001,
    fp16=False,
    bf16=False,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type="constant",
    report_to="tensorboard"
)
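These settings determine how many optimizer steps one epoch takes. The sketch below is an approximation, not the trainer's exact internal logic: `steps_per_epoch` is a hypothetical helper, and it assumes a single training process (the model is sharded across the two GCDs by `device_map="auto"` rather than replicated).

```python
import math


def steps_per_epoch(num_samples, per_device_batch_size, grad_accum_steps=1, num_processes=1):
    """Approximate number of optimizer steps needed to see every sample once."""
    effective_batch = per_device_batch_size * grad_accum_steps * num_processes
    return math.ceil(num_samples / effective_batch)


# 1,000 samples, batch size 4, no gradient accumulation
print(steps_per_epoch(1000, 4, 1))  # 250
```

With `save_steps=50` and `logging_steps=50`, this gives five logged loss values and five checkpoints per epoch.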
Training with LoRA configuration#
Now you can integrate LoRA into the base model and assess its additional parameters. LoRA essentially adds pairs of rank-decomposition weight matrices (called update matrices) to existing weights, and only trains the newly added weights.
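To make the idea concrete, here is a toy NumPy sketch of a single LoRA-adapted linear layer. It is an illustration under stated assumptions, not the PEFT implementation: the layer sizes are made up, the initialization is simplified, dropout is ignored, and the update is scaled by `alpha / r` as described in the LoRA paper.

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, r, alpha = 64, 64, 8, 8  # toy sizes; r and alpha mirror the config used in this blog

W = rng.normal(size=(d_out, d_in))     # frozen pre-trained weight (never updated)
A = rng.normal(size=(r, d_in)) * 0.01  # trainable, small random initialization
B = np.zeros((d_out, r))               # trainable, zero initialization


def lora_forward(x):
    """Base projection plus the scaled low-rank update B @ A."""
    return W @ x + (alpha / r) * (B @ (A @ x))


x = rng.normal(size=d_in)
# Because B starts at zero, the adapted layer initially matches the base layer.
print(np.allclose(lora_forward(x), W @ x))
print(A.size + B.size, "trainable values vs", W.size, "frozen values")
```

Only `A` and `B` would receive gradients during training, which is why the trainable parameter count stays small.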
from peft import get_peft_model
# LoRA Config
peft_parameters = LoraConfig(
    lora_alpha=8,
    lora_dropout=0.1,
    r=8,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(base_model, peft_parameters)
model.print_trainable_parameters()
The output looks like this:
trainable params: 4,194,304 || all params: 6,742,609,920 || trainable%: 0.06220594176090199
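The trainable parameter count can be reproduced by hand. The sketch below assumes that Llama-2-7b has 32 decoder layers with a hidden size of 4096, and that PEFT, when no `target_modules` are given, adapts the query and value projections (`q_proj`, `v_proj`) of each layer. `lora_trainable_params` is a hypothetical helper, and the base parameter count is the one reported for the base model in this blog.

```python
def lora_trainable_params(hidden_size, rank, num_layers, modules_per_layer):
    """Count LoRA parameters when each adapted weight is hidden_size x hidden_size."""
    # Each adapted weight gets A (rank x hidden_size) and B (hidden_size x rank).
    per_module = rank * hidden_size + hidden_size * rank
    return per_module * modules_per_layer * num_layers


lora_params = lora_trainable_params(hidden_size=4096, rank=8, num_layers=32, modules_per_layer=2)
base_params = 6_738_415_616  # parameters of the base model
total_params = base_params + lora_params
percent = 100 * lora_params / total_params
print(f"trainable params: {lora_params:,} || all params: {total_params:,} || trainable%: {percent:.4f}")
```

Doubling `r` doubles the number of trainable parameters, so the rank is the main knob for trading adapter capacity against memory.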
Note that the parameters added by LoRA make up only 0.062% of all parameters, a tiny fraction of the original model. These are the only parameters we update through fine-tuning, as follows.
# Trainer with LoRA configuration
fine_tuning = SFTTrainer(
    model=base_model,
    train_dataset=training_data,
    peft_config=peft_parameters,
    dataset_text_field="text",
    tokenizer=llama_tokenizer,
    args=train_params
)

# Training
fine_tuning.train()
The output looks like this:
[250/250 07:59, Epoch 1/1]
Step    Training Loss
50      1.976400
100     1.613500
150     1.409100
200     1.391500
250     1.377300
TrainOutput(global_step=250, training_loss=1.5535581665039062, metrics={'train_runtime': 484.7942, 'train_samples_per_second': 2.063, 'train_steps_per_second': 0.516, 'total_flos': 1.701064079130624e+16, 'train_loss': 1.5535581665039062, 'epoch': 1.0})
To save your model, run this code:
# Save Model
fine_tuning.model.save_pretrained(new_model_name)
Checking memory usage during training with LoRA#
During training, you can check the memory usage by running the rocm-smi command in a terminal. This command produces the following output:
======================= ROCm System Management Interface =======================
================================= Concise Info =================================
GPU  Temp (DieEdge)  AvgPwr  SCLK     MCLK     Fan  Perf  PwrCap  VRAM%  GPU%
0    52.0c           179.0W  1700Mhz  1600Mhz  0%   auto  300.0W  65%    100%
1    52.0c           171.0W  1650Mhz  1600Mhz  0%   auto  300.0W  66%    100%
================================================================================
============================= End of ROCm SMI Log ==============================
To facilitate a comparison between fine-tuning with and without LoRA, our subsequent phase involves running a thorough fine-tuning process on the base model. This involves updating all parameters within the base model. We then analyze differences in memory usage, training speed, training loss, and other relevant metrics.
Training without LoRA configuration#
For this section, you must restart the kernel and skip the ‘Training with LoRA configuration’ section.
For a direct comparison between the two approaches, we keep the same train_params settings during the full-parameter fine-tuning process, apart from the lower learning rate set later in this section.
To check the trainable parameters in your base model, use the following code.
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}")

print_trainable_parameters(base_model)
The output looks like this:
trainable params: 6738415616 || all params: 6738415616 || trainable%: 100.00
Continue the process using the following code:
# Set a lower learning rate for fine-tuning
train_params.learning_rate = 4e-7
print(train_params.learning_rate)

# Trainer without LoRA configuration
fine_tuning_full = SFTTrainer(
    model=base_model,
    train_dataset=training_data,
    dataset_text_field="text",
    tokenizer=llama_tokenizer,
    args=train_params
)

# Training
fine_tuning_full.train()
The output looks like this:
[250/250 3:02:12, Epoch 1/1]
Step    Training Loss
50      1.712300
100     1.487000
150     1.363800
200     1.371100
250     1.368300
TrainOutput(global_step=250, training_loss=1.4604909362792968, metrics={'train_runtime': 10993.7995, 'train_samples_per_second': 0.091, 'train_steps_per_second': 0.023, 'total_flos': 1.6999849383985152e+16, 'train_loss': 1.4604909362792968, 'epoch': 1.0})
Checking memory usage during training without LoRA#
During training, you can check the memory usage by running the rocm-smi command in a terminal. This command produces the following output:
======================= ROCm System Management Interface =======================
================================= Concise Info =================================
GPU  Temp (DieEdge)  AvgPwr  SCLK     MCLK     Fan  Perf  PwrCap  VRAM%  GPU%
0    40.0c           44.0W   800Mhz   1600Mhz  0%   auto  300.0W  100%   89%
1    39.0c           50.0W   1700Mhz  1600Mhz  0%   auto  300.0W  100%   85%
================================================================================
============================= End of ROCm SMI Log ==============================
Step 4: Comparison between fine-tuning with LoRA and full-parameter fine-tuning#
Comparing the results from the Training with LoRA configuration and Training without LoRA configuration sections, note the following:
Memory usage:
In the case of full-parameter fine-tuning, there are 6,738,415,616 trainable parameters, leading to significant memory consumption during the training back propagation stage.
LoRA only introduces 4,194,304 trainable parameters, accounting for 0.062% of the total trainable parameters in full-parameter fine-tuning.
Monitoring memory usage during training with and without LoRA reveals that fine-tuning with LoRA uses only 65% of the memory consumed by full-parameter fine-tuning. This presents an opportunity to increase batch size and max sequence length, and train on larger data sets using limited hardware resources.
Training speed:
The results demonstrate that full-parameter fine-tuning takes about three hours to complete, while fine-tuning with LoRA finishes in less than 9 minutes. Several factors contribute to this acceleration:
Fewer trainable parameters in LoRA translate to fewer derivative calculations and less memory required to store and update weights.
Full-parameter fine-tuning is more prone to being memory-bound, where data movement becomes a bottleneck for training. This is reflected in lower GPU utilization. Although adjusting training settings can alleviate this, it may require more resources (additional GPUs) and a smaller batch size.
Accuracy:
In both training sessions, the training loss dropped notably. The two approaches reached closely aligned final training losses: 1.368 for full-parameter fine-tuning and 1.377 for fine-tuning with LoRA. If you’re interested in understanding the impact of LoRA on fine-tuning performance, refer to LoRA: Low-Rank Adaptation of Large Language Models.
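The comparison above can be recomputed from the logged metrics. This is a small sketch that copies the runtime, final-step loss, and VRAM figures from the two training runs into plain dictionaries (the dictionary keys are illustrative names, not trainer fields).

```python
# Figures copied from the two training runs reported in this blog
lora_run = {"train_runtime_s": 484.7942, "final_loss": 1.3773, "vram_pct": 65}
full_run = {"train_runtime_s": 10993.7995, "final_loss": 1.3683, "vram_pct": 100}

speedup = full_run["train_runtime_s"] / lora_run["train_runtime_s"]
memory_ratio = lora_run["vram_pct"] / full_run["vram_pct"]
loss_gap = lora_run["final_loss"] - full_run["final_loss"]

print(f"LoRA speedup: {speedup:.1f}x")
print(f"LoRA memory relative to full fine-tuning: {memory_ratio:.0%}")
print(f"Final training loss gap (LoRA minus full): {loss_gap:.3f}")
```

On these numbers, LoRA trains more than 20 times faster while ending within about 0.01 of the full fine-tuning loss. Keep in mind that the two runs used different learning rates, so the loss values are indicative rather than a controlled comparison.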
Step 5: Test the fine-tuned model with LoRA#
To test your model, run the following code:
# Reload model in FP16 and merge it with LoRA weights
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.float16,
    device_map="auto"
)
from peft import LoraConfig, PeftModel
model = PeftModel.from_pretrained(base_model, new_model_name)
model = model.merge_and_unload()

# Reload tokenizer to save it
tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
The output looks like this:
Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00, 2.34s/it]
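Conceptually, merging folds the low-rank update into the base weight so that inference no longer needs the separate adapter matrices. The NumPy sketch below illustrates the idea on one toy layer. It is not the PEFT code path: the sizes are made up, and `B` is filled with random values to stand in for a trained adapter.

```python
import numpy as np

rng = np.random.default_rng(0)
d, r, alpha = 8, 2, 8  # toy sizes, for illustration only
scale = alpha / r

W = rng.normal(size=(d, d))  # base weight
A = rng.normal(size=(r, d))  # adapter matrix A
B = rng.normal(size=(d, r))  # adapter matrix B (nonzero, as after training)

# Merge: a single dense weight that already contains the adapter update
W_merged = W + scale * (B @ A)

x = rng.normal(size=d)
adapter_out = W @ x + scale * (B @ (A @ x))  # base path plus adapter path
merged_out = W_merged @ x                    # merged path

print(np.allclose(adapter_out, merged_out))
```

Because the merged weight has the same shape as the original one, the merged model can be saved, uploaded, and served like any regular checkpoint.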
Uploading the model to Hugging Face lets you conduct subsequent tests or share your model with others (to proceed with this step, you’ll need an active Hugging Face account).
from huggingface_hub import login
# You need to use your Hugging Face access token
login("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
# Push the model to Hugging Face. This can take several minutes, depending on the model size and your
# network speed.
model.push_to_hub(new_model_name, use_temp_dir=False)
tokenizer.push_to_hub(new_model_name, use_temp_dir=False)
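To avoid pasting an access token into a notebook, one option is to read it from an environment variable. This is a sketch, assuming you have exported the token as `HF_TOKEN` before starting the notebook; `read_hf_token` and `login_from_env` are hypothetical helper names.

```python
import os


def read_hf_token(env=None, var_name="HF_TOKEN"):
    """Return the token stored in the given environment mapping, or None if unset."""
    env = os.environ if env is None else env
    token = env.get(var_name, "").strip()
    return token or None


def login_from_env():
    """Log in to Hugging Face only when a token is available; return True on attempt."""
    token = read_hf_token()
    if token is None:
        print("HF_TOKEN is not set; skipping login.")
        return False
    from huggingface_hub import login  # imported lazily so the helper loads without it
    login(token=token)
    return True
```

Calling `login_from_env()` before `push_to_hub` keeps the token out of the saved notebook and its output cells.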
Now you can test with the base model (original) and your fine-tuned model.
Base model:
# Generate text using base model
query = "What do you think is the most important part of building an AI chatbot?"
text_gen = pipeline(task="text-generation", model=base_model_name, tokenizer=llama_tokenizer, max_length=200)
output = text_gen(f"<s>[INST] {query} [/INST]")
print(output[0]['generated_text'])
# Outputs:
<s>[INST] What do you think is the most important part of building an AI chatbot? [/INST] There are several important aspects to consider when building an AI chatbot, but here are some of the most critical elements:

1. Natural Language Processing (NLP): A chatbot's ability to understand and interpret human language is crucial for effective communication. NLP is the foundation of any chatbot, and it involves training the AI model to recognize patterns in language, interpret meaning, and generate responses.
2. Conversational Flow: A chatbot's conversational flow refers to the way it interacts with users. A well-designed conversational flow should be intuitive, easy to follow, and adaptable to different user scenarios. This involves creating a dialogue flowchart that guides the conversation and ensures the chatbot responds appropriately to user inputs.
3. Domain Knowledge: A chat
Fine-tuned model:
# Generate text using fine-tuned model
query = "What do you think is the most important part of building an AI chatbot?"
text_gen = pipeline(task="text-generation", model=new_model_name, tokenizer=llama_tokenizer, max_length=200)
output = text_gen(f"<s>[INST] {query} [/INST]")
print(output[0]['generated_text'])
# Outputs:
<s>[INST] What do you think is the most important part of building an AI chatbot? [/INST] The most important part of building an AI chatbot is to ensure that it is able to understand and respond to user input in a way that is both accurate and natural-sounding. This requires a combination of natural language processing (NLP) capabilities and a well-designed conversational flow.

Here are some key factors to consider when building an AI chatbot:

1. Natural Language Processing (NLP): The chatbot must be able to understand and interpret user input, including both text and voice commands. This requires a robust NLP engine that can handle a wide range of language and dialects.
2. Conversational Flow: The chatbot must be able to respond to user input in a way that is both natural and intuitive. This requires a well-designed conversational flow that can handle a wide range
You can compare the outputs of the two models for the given query. The outputs differ slightly because the fine-tuning process alters the model weights.