PiSSA: Principal Singular Values and Singular Vectors Adaptation of Large Language Models (NeurIPS 2024 Spotlight)
GraphPKU/PiSSA
We introduce a parameter-efficient fine-tuning (PEFT) method, Principal Singular values and Singular vectors Adaptation (PiSSA), which optimizes the essential singular values and vectors while freezing the "noisy" parts. In comparison, LoRA freezes the original matrix and updates the "noise". This distinction enables PiSSA to converge much faster than LoRA and to achieve better final performance. On five common benchmarks, PiSSA outperforms LoRA on all of them using exactly the same setup except for the initialization. On GSM8K, Mistral-7B fine-tuned with PiSSA achieves an accuracy of 72.86%, outperforming LoRA's 67.7% by 5.16 percentage points.

Because it shares LoRA's architecture, PiSSA inherits many of LoRA's advantages, such as parameter efficiency and compatibility with quantization. Furthermore, PiSSA reduces the 4-bit quantization error in LLaMA 2-7B by 18.97%, resulting in a substantial improvement in fine-tuning performance. On the GSM8K benchmark, PiSSA achieves an accuracy of 49.13%, surpassing QLoRA at 39.8% and LoftQ at 40.71%.

Leveraging a fast SVD technique, the initialization of PiSSA takes only a few seconds, so switching from LoRA to PiSSA incurs negligible cost.
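To make the idea concrete, here is a minimal NumPy sketch of the decomposition described above, assuming a plain 2-D weight matrix. The function name `pissa_init` and the toy shapes are illustrative assumptions, not the repository's implementation: the top `r` singular values and vectors form the trainable adapter `A @ B`, and everything else goes into the frozen residual.

```python
import numpy as np

def pissa_init(W: np.ndarray, r: int):
    """Split W into a rank-r principal adapter (A, B) and a frozen residual.

    Illustrative sketch only; names and shapes are assumptions.
    """
    U, S, Vt = np.linalg.svd(W, full_matrices=False)
    sqrt_S = np.sqrt(S[:r])
    A = U[:, :r] * sqrt_S           # (m, r): principal left vectors, scaled
    B = sqrt_S[:, None] * Vt[:r]    # (r, n): principal right vectors, scaled
    W_res = W - A @ B               # residual holding the remaining singular values
    return A, B, W_res

rng = np.random.default_rng(0)
W = rng.standard_normal((64, 48))   # toy stand-in for a pre-trained weight
A, B, W_res = pissa_init(W, r=8)

# The split is exact at initialization: the model's output is unchanged.
print(np.allclose(A @ B + W_res, W))
```

Because `A @ B + W_res` reproduces `W` at initialization, training can start from the original model's behavior while gradients flow only through the principal components.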
- [2025.01.09] Provided documentation in English and Chinese to help you use PiSSA for training and testing.
- [2024.07.17] PiSSA now supports Conv2d and Embedding; here is an example of using PiSSA on SDXL.
- [2024.07.16] PiSSA now supports DeepSpeed.
- [2024.05.16] PiSSA has been merged into the main branch of peft as an optional initialization method for LoRA.
Clone the repository, download the dataset, and set up the environment with conda and pip:

```shell
git clone https://github.com/GraphPKU/PiSSA.git
cd PiSSA/
# export HF_ENDPOINT=https://hf-mirror.com
pip install -U huggingface_hub
huggingface-cli download --repo-type dataset --resume-download fxmeng/pissa-dataset --local-dir pissa-dataset
conda create -n pissa python=3.10
conda activate pissa
conda install nvidia/label/cuda-12.1.0::cuda-toolkit
conda install pytorch==2.4.0 torchvision=0.19.0 pytorch-cuda=12.1 -c pytorch -c nvidia
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
```
All the datasets we used are publicly available in the fxmeng/pissa-dataset dataset repository.
The PiSSA-initialized models are shared on Hugging Face for easy reuse. They retain the same input and output as the original models but are split into residual models and PiSSA adapters for more effective fine-tuning.
Model | PiSSA | QPiSSA |
---|---|---|
LLaMA-2-7B | r128 | r16,32,64,128 |
LLaMA-3-8B | r16,32,64,128 | r64,128 |
LLaMA-3-8B-Instruct | r16,32,64,128 | -- |
LLaMA-3-70B | -- | r64,128 |
LLaMA-3-70B-Instruct | -- | r128 |
Qwen2-7B | r128 | r128 |
Qwen2-7B-Instruct | r128 | r128 |
Qwen2-72B | -- | r64,128 |
Qwen2-72B-Instruct | -- | r64,128 |
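The QPiSSA column refers to the quantized variant, in which the principal adapter stays in full precision and only the residual is quantized. The sketch below illustrates why this can lower quantization error. It is a toy under stated assumptions: `fake_quantize_4bit` is a hypothetical absmax uniform quantizer (not NF4), and the weight is synthetic, so the numbers it prints say nothing about the figures reported for real models.

```python
import numpy as np

def fake_quantize_4bit(M: np.ndarray) -> np.ndarray:
    """Toy symmetric 4-bit absmax quantizer (an assumption; not NF4)."""
    scale = np.abs(M).max() / 7.0               # integer levels in [-7, 7]
    return np.clip(np.round(M / scale), -7, 7) * scale

rng = np.random.default_rng(0)
m, n, r = 64, 64, 4
# Synthetic weight: a few dominant directions plus small noise.
W = 3.0 * rng.standard_normal((m, r)) @ rng.standard_normal((r, n))
W = W + 0.1 * rng.standard_normal((m, n))

# Keep the principal rank-r part in full precision; quantize only the residual.
U, S, Vt = np.linalg.svd(W, full_matrices=False)
principal = (U[:, :r] * S[:r]) @ Vt[:r]
residual = W - principal

err_direct = np.linalg.norm(W - fake_quantize_4bit(W))
err_residual = np.linalg.norm(W - (fake_quantize_4bit(residual) + principal))
print(f"quantize W directly:    {err_direct:.3f}")
print(f"quantize residual only: {err_residual:.3f}")
```

Removing the largest singular components narrows the value range of the matrix being quantized, so the same number of levels covers it with a finer step.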
Running any of the following scripts automatically downloads the model and then starts training:

```shell
sh scripts/*/run_full_finetune.sh
sh scripts/*/lora.sh
sh scripts/*/pissa.sh
sh scripts/*/loftq.sh
sh scripts/*/qlora.sh
sh scripts/*/qpissa.sh
```
To evaluate the performance of your fine-tuned model, please follow the instructions in fxmeng/pissa-dataset.
We recommend downloading decomposed models directly from the Hugging Face Collections instead of performing SVD every time. If the existing models do not meet your needs, apply PiSSA initialization to a pre-trained model and store the decomposed model locally:
```python
import torch
import os
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token_id = tokenizer.eos_token_id

lora_config = LoraConfig(
    # init_lora_weights="pissa",  # Configure the initialization method to "pissa", which may take several minutes to execute SVD on the pre-trained model.
    init_lora_weights="pissa_niter_4",  # Initialize PiSSA with fast SVD, which completes in just a few seconds.
    r=128,
    lora_alpha=128,
    lora_dropout=0,  # Since the components of the PiSSA adapter are the principal singular values and vectors, dropout should be set to 0 to avoid random discarding.
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()

OUTPUT_DIR = "PiSSA-Llama-2-7b-hf-r128"
# Save PiSSA modules:
peft_model.peft_config["default"].init_lora_weights = True  # Important
peft_model.save_pretrained(os.path.join(OUTPUT_DIR, "pissa_init"))
# Save residual model:
peft_model = peft_model.unload()
peft_model.save_pretrained(OUTPUT_DIR)
# Save the tokenizer:
tokenizer.save_pretrained(OUTPUT_DIR)
```
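The "fast SVD" option above trades exactness for speed. As a rough illustration of how such a method can work, the following is a generic randomized SVD with subspace iterations, written in NumPy. It is a sketch under assumptions, not PEFT's code: the function name, the `oversample` parameter, and the toy matrix are all hypothetical, and the reading of the `_niter_4` suffix as an iteration count is our interpretation.

```python
import numpy as np

def randomized_svd(W, r, n_iter=4, oversample=8, seed=0):
    """Approximate the top-r singular triplets of W via random projection."""
    rng = np.random.default_rng(seed)
    k = min(r + oversample, min(W.shape))
    # Project onto a random k-dimensional subspace and orthonormalize.
    Q = np.linalg.qr(W @ rng.standard_normal((W.shape[1], k)))[0]
    for _ in range(n_iter):
        # Subspace (power) iterations sharpen the estimate of the top subspace.
        Q = np.linalg.qr(W.T @ Q)[0]
        Q = np.linalg.qr(W @ Q)[0]
    # Exact SVD of the small projected matrix, then map back to the full space.
    U_small, S, Vt = np.linalg.svd(Q.T @ W, full_matrices=False)
    return (Q @ U_small)[:, :r], S[:r], Vt[:r]

rng = np.random.default_rng(1)
# Toy matrix with a decaying spectrum plus small noise.
W = (rng.standard_normal((256, 16)) * np.linspace(10, 1, 16)) @ rng.standard_normal((16, 128))
W = W + 0.01 * rng.standard_normal((256, 128))

U, S, Vt = randomized_svd(W, r=8)
S_exact = np.linalg.svd(W, compute_uv=False)[:8]
max_rel_err = np.max(np.abs(S - S_exact) / S_exact)
print(f"max relative error of top singular values: {max_rel_err:.2e}")
```

The expensive full decomposition is replaced by a few matrix products and an SVD of a small `k x n` matrix, which is why only the leading components are recovered accurately.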
Load a pre-processed model and fine-tune it on the IMDB dataset:

```python
from trl import SFTTrainer
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

MODEL_ID = "PiSSA-Llama-2-7b-hf-r128"
residual_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")
# Load the saved PiSSA adapter on top of the residual model and make it trainable.
peft_model = PeftModel.from_pretrained(residual_model, MODEL_ID, subfolder="pissa_init", is_trainable=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

dataset = load_dataset("imdb", split="train[:1%]")  # Only use 1% of the dataset

# Note: these keyword arguments follow older TRL releases; newer releases may expect some of them in SFTConfig.
trainer = SFTTrainer(
    model=peft_model,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=128,
    tokenizer=tokenizer,
)
trainer.train()
peft_model.save_pretrained("pissa-llama-2-7b-ft")
```
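When using `peft_model.save_pretrained`, if `path_initial_model_for_weight_conversion=None`, the fine-tuned matrices $A$ and $B$ are saved as they are and must be combined with the residual model. If instead `path_initial_model_for_weight_conversion="pissa_init_dir"` is specified, the saving function converts PiSSA to LoRA by $\Delta W = AB - A_0 B_0 = [A \mid A_0]\,[B^\top \mid -B_0^\top]^\top = A'B'$, where $A_0$ and $B_0$ are the initial PiSSA matrices. The converted LoRA can then be loaded on top of the standard base model: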
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.bfloat16, device_map="auto"
)
# No SVD is performed during this step, and the base model remains unaltered.
peft_model = PeftModel.from_pretrained(model, "pissa-llama-2-7b-lora")
```
Utilizing the converted LoRA does not require modifying the parameters of the base model. When multiple converted LoRAs are needed simultaneously, each adapter operates independently without interference, allowing for the adapters to be freely deleted or added.
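The PiSSA-to-LoRA conversion can be checked numerically. The sketch below assumes the paper's shape convention ($A$ is `m x r`, $B$ is `r x n`) and uses random stand-ins for the initial and fine-tuned adapters; the variable names are illustrative. It shows that the change relative to the original weight, `A @ B - A0 @ B0`, equals a single product of two concatenated matrices of rank at most `2r`.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, r = 32, 24, 4

# Hypothetical PiSSA adapter at initialization (A0, B0) and after fine-tuning (A, B).
A0, B0 = rng.standard_normal((m, r)), rng.standard_normal((r, n))
A = A0 + 0.01 * rng.standard_normal((m, r))
B = B0 + 0.01 * rng.standard_normal((r, n))

# Change relative to the ORIGINAL weight W = W_res + A0 @ B0.
delta_W = A @ B - A0 @ B0

# The same change written as one product that a LoRA adapter of rank 2r can hold.
A_lora = np.concatenate([A, A0], axis=1)     # (m, 2r)
B_lora = np.concatenate([B, -B0], axis=0)    # (2r, n)

print(np.allclose(A_lora @ B_lora, delta_W))
```

Since `delta_W` is defined against the original weight rather than the residual, the converted adapter can sit on an unmodified base model, at the price of doubling the adapter rank.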
```
@article{meng2024pissa,
  title={Pissa: Principal singular values and singular vectors adaptation of large language models},
  author={Meng, Fanxu and Wang, Zhaohui and Zhang, Muhan},
  journal={arXiv preprint arXiv:2404.02948},
  year={2024}
}
```
- 2024, May 27, LoRA-XS: Low-Rank Adaptation with Extremely Small Number of Parameters performs basis adaptation for principal singular values and singular vectors.
- 2024, May 30, SVFT: Parameter-Efficient Fine-Tuning with Singular Vectors freezes the singular vectors while fine-tuning the singular values in a sparse manner.
- 2024, Jun 3, OLoRA: Orthonormal Low-Rank Adaptation of Large Language Models leverages orthonormal matrix initialization through QR decomposition.
- 2024, Jun 7, CorDA: Context-Oriented Decomposition Adaptation of Large Language Models leverages knowledge-preserved adaptation and instruction-previewed adaptation through context-oriented decomposition.
- 2024, Jun 7, MiLoRA: Harnessing Minor Singular Components for Parameter-Efficient LLM Finetuning adapts the minor singular components.
- 2024, Jun 18, LaMDA: Large Model Fine-Tuning via Spectrally Decomposed Low-Dimensional Adaptation performs basis adaptation for principal singular values and singular vectors.
- 2024, Jul 6, LoRA-GA: Low-Rank Adaptation with Gradient Approximation aligns the gradients of the low-rank matrix product with those of full fine-tuning at the first step.
- 2024, Jul 25, LoRA-Pro: Are Low-Rank Adapters Properly Optimized? strategically adjusts the gradients of adapters, enabling the low-rank gradients to approximate the full fine-tuning gradients more accurately.
- 2024, Oct 9, One Initialization to Rule them All: Fine-tuning via Explained Variance Adaptation initializes the adapter in a data-driven manner by computing the singular value decomposition on minibatches of activation vectors.
- 2024, Nov 7, SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models consolidates the outliers by shifting them from activations to weights, then employs a high-precision low-rank branch to take in the weight outliers with SVD.