huggingface/trl


TRL Banner


A comprehensive library to post-train foundation models


Overview

TRL is a cutting-edge library designed for post-training foundation models using advanced techniques like Supervised Fine-Tuning (SFT), Proximal Policy Optimization (PPO), and Direct Preference Optimization (DPO). Built on top of the 🤗 Transformers ecosystem, TRL supports a variety of model architectures and modalities, and can be scaled up across various hardware setups.

Highlights

  • Trainers: Various fine-tuning methods are easily accessible via trainers like SFTTrainer, GRPOTrainer, DPOTrainer, RewardTrainer, and more.

  • Efficient and scalable:

    • Leverages 🤗 Accelerate to scale from a single GPU to multi-node clusters using methods like DDP and DeepSpeed.
    • Full integration with 🤗 PEFT enables training on large models with modest hardware via quantization and LoRA/QLoRA (see the sketch after this list).
    • Integrates 🦥 Unsloth for accelerating training using optimized kernels.
  • Command Line Interface (CLI): A simple interface lets you fine-tune models without needing to write code.
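
As a quick illustration of the 🤗 PEFT integration mentioned above, here is a minimal sketch (not taken from the documentation; the LoRA hyperparameters are placeholders) of passing a peft_config to SFTTrainer so that only LoRA adapters are trained:

from datasets import load_dataset
from peft import LoraConfig  # requires `pip install peft`
from trl import SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

# With a peft_config, the trainer wraps the base model and trains only the LoRA adapters
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
    peft_config=LoraConfig(r=16, lora_alpha=32),  # illustrative LoRA settings, not a recommendation
)
trainer.train()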

Installation

Python Package

Install the library using pip:

pip install trl

From source

If you want to use the latest features before an official release, you can install TRL from source:

pip install git+https://github.com/huggingface/trl.git

Repository

If you want to use the examples, you can clone the repository with the following command:

git clone https://github.com/huggingface/trl.git

Quick Start

For more flexibility and control over training, TRL provides dedicated trainer classes to post-train language models or PEFT adapters on a custom dataset. Each trainer in TRL is a light wrapper around the 🤗 Transformers trainer and natively supports distributed training methods like DDP, DeepSpeed ZeRO, and FSDP.

SFTTrainer

Here is a basic example of how to use the SFTTrainer:

from trl import SFTTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    train_dataset=dataset,
)
trainer.train()
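
Since each trainer wraps the 🤗 Transformers Trainer, standard training options can be passed through the corresponding config class. Here is a minimal sketch (the hyperparameter values are illustrative only) using SFTConfig:

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

# SFTConfig subclasses transformers.TrainingArguments, so the usual options apply
training_args = SFTConfig(
    output_dir="Qwen2.5-0.5B-SFT",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    logging_steps=10,
)
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()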

GRPOTrainer

GRPOTrainer implements the Group Relative Policy Optimization (GRPO) algorithm, which is more memory-efficient than PPO and was used to train DeepSeek AI's R1.

from datasets import load_dataset
from trl import GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Dummy reward function: count the number of unique characters in the completions
def reward_num_unique_chars(completions, **kwargs):
    return [len(set(c)) for c in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_num_unique_chars,
    train_dataset=dataset,
)
trainer.train()
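
reward_funcs can also be given as a list of reward functions whose outputs are combined per completion. A hedged sketch (the length penalty below is a made-up shaping term, not a recommendation) extending the example above:

# Reusing `dataset` and `reward_num_unique_chars` from the example above.

# Hypothetical second reward: gently penalize very long completions
def reward_length_penalty(completions, **kwargs):
    return [-0.01 * len(c) for c in completions]

trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=[reward_num_unique_chars, reward_length_penalty],  # rewards are combined per completion
    train_dataset=dataset,
)
trainer.train()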

DPOTrainer

DPOTrainer implements the popular Direct Preference Optimization (DPO) algorithm that was used to post-train Llama 3 and many other models. Here is a basic example of how to use the DPOTrainer:

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")

trainer = DPOTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
trainer.train()

RewardTrainer

Here is a basic example of how to use the RewardTrainer:

from trl import RewardConfig, RewardTrainer
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
)
model.config.pad_token_id = tokenizer.pad_token_id
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
trainer = RewardTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()
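
Once training finishes, the result is an ordinary sequence-classification checkpoint. Here is a minimal sketch (assuming the checkpoint and tokenizer were saved to the Qwen2.5-0.5B-Reward output directory above) of scoring a single prompt-completion pair with plain 🤗 Transformers:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Path assumed from the RewardConfig output_dir above
reward_model = AutoModelForSequenceClassification.from_pretrained("Qwen2.5-0.5B-Reward", num_labels=1)
reward_tokenizer = AutoTokenizer.from_pretrained("Qwen2.5-0.5B-Reward")

text = "User: How do I reverse a list in Python?\nAssistant: Use my_list[::-1] or my_list.reverse()."
inputs = reward_tokenizer(text, return_tensors="pt")
with torch.no_grad():
    score = reward_model(**inputs).logits[0].item()  # higher scores mean the reward model prefers the text
print(score)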

Command Line Interface (CLI)

You can use the TRL Command Line Interface (CLI) to quickly get started with post-training methods like Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO):

SFT:

trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/Capybara \
    --output_dir Qwen2.5-0.5B-SFT

DPO:

trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name argilla/Capybara-Preferences \
    --output_dir Qwen2.5-0.5B-DPO

Read more about the CLI in the relevant documentation section or use --help for more details.

Development

If you want to contribute to trl or customize it to your needs, make sure to read the contribution guide and set up a dev install:

git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .[dev]

Citation

@misc{vonwerra2022trl,
  author       = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
  title        = {TRL: Transformer Reinforcement Learning},
  year         = {2020},
  publisher    = {GitHub},
  journal      = {GitHub repository},
  howpublished = {\url{https://github.com/huggingface/trl}}
}

License

This repository's source code is available under the Apache-2.0 License.
