A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
Documentation | TensorDict | Features | Examples, tutorials and demos | Citation | Installation | Asking a question | Contributing
TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.
TorchRL now includes a comprehensive LLM API for post-training and fine-tuning of language models! This new framework provides everything you need for RLHF, supervised fine-tuning, and tool-augmented training:
- 🤖 Unified LLM Wrappers: Seamless integration with Hugging Face models and vLLM inference engines, with more to come!
- 💬 Conversation Management: Advanced `History` class for multi-turn dialogue with automatic chat template detection
- 🛠️ Tool Integration: Built-in support for Python code execution, function calling, and custom tool transforms
- 🎯 Specialized Objectives: GRPO (Group Relative Policy Optimization) and SFT loss functions optimized for language models
- ⚡ High-Performance Collectors: Async data collection with distributed training support
- 🔄 Flexible Environments: Transform-based architecture for reward computation, data loading, and conversation augmentation
The LLM API follows TorchRL's modular design principles, allowing you to mix and match components for your specific use case. Check out the complete documentation and GRPO implementation example to get started!
Quick LLM API Example
```python
from torchrl.envs.llm import ChatEnv
from torchrl.envs.llm.transforms import PythonInterpreter
from torchrl.modules.llm import TransformersWrapper
from torchrl.objectives.llm import GRPOLoss
from torchrl.collectors.llm import LLMCollector

# `model`, `tokenizer`, `critic` and `optimizer` are assumed to be defined beforehand

# Create environment with Python tool execution
env = ChatEnv(
    tokenizer=tokenizer,
    system_prompt="You are an assistant that can execute Python code.",
    batch_size=[1]
).append_transform(PythonInterpreter())

# Wrap your language model
llm = TransformersWrapper(
    model=model,
    tokenizer=tokenizer,
    input_mode="history"
)

# Set up GRPO training
loss_fn = GRPOLoss(llm, critic, gamma=0.99)
collector = LLMCollector(env, llm, frames_per_batch=100)

# Training loop
for data in collector:
    loss = loss_fn(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()  # reset gradients before the next batch
```
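The core idea behind GRPO, as introduced in the DeepSeekMath paper (Shao et al., 2024), is to score each sampled completion relative to the other completions drawn for the same prompt. The plain-Python sketch below shows only that group normalization; `group_relative_advantages`, the population standard deviation and the `eps` term are illustrative choices, not TorchRL's `GRPOLoss` code.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-8):
    """Normalize each reward against the mean/std of its own group.

    `rewards` is a list of groups; each group holds the scalar rewards of
    several completions sampled for the same prompt (illustrative layout).
    """
    advantages = []
    for group in rewards:
        mu = mean(group)
        sigma = pstdev(group)  # population std over the group
        advantages.append([(r - mu) / (sigma + eps) for r in group])
    return advantages


# Two prompts, four sampled completions each (made-up rewards).
adv = group_relative_advantages([[1.0, 0.0, 0.0, 1.0], [0.5, 0.5, 0.5, 0.5]])
print(adv[0])  # completions above the group average get a positive advantage
print(adv[1])  # identical rewards carry no learning signal: all zeros
```

Because the baseline is the group average, advantages within a group sum to zero, and a group where every completion gets the same reward contributes nothing to the update.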
- 🐍 Python-first: Designed with Python as the primary language for ease of use and flexibility
- ⏱️ Efficient: Optimized for performance to support demanding RL research applications
- 🧮 Modular, customizable, extensible: Highly modular architecture allows for easy swapping, transformation, or creation of new components
- 📚 Documented: Thorough documentation ensures that users can quickly understand and utilize the library
- ✅ Tested: Rigorously tested to ensure reliability and stability
- ⚙️ Reusable functionals: Provides a set of highly reusable functions for cost functions, returns, and data processing
- 🔥 Aligns with PyTorch ecosystem: Follows the structure and conventions of popular PyTorch libraries (e.g., dataset pillar, transforms, models, data utilities)
- ➖ Minimal dependencies: Only requires the Python standard library, NumPy, and PyTorch; optional dependencies for common environment libraries (e.g., OpenAI Gym) and datasets (D4RL, OpenX...)
Read the full paper for a more curated description of the library.
Check our Getting Started tutorials to quickly ramp up with the basic features of the library!
The TorchRL documentation can be found here. It contains tutorials and the API reference.
TorchRL also provides an RL knowledge base to help you debug your code, or simply learn the basics of RL. Check it out here.
We have some introductory videos to help you get to know the library better; check them out:
Because TorchRL is domain-agnostic, you can use it across many different fields. Here are a few examples:
- ACEGEN: Reinforcement Learning of Generative Chemical Agents for Drug Discovery
- BenchMARL: Benchmarking Multi-Agent Reinforcement Learning
- BricksRL: A Platform for Democratizing Robotics and Reinforcement Learning Research and Education with LEGO
- OmniDrones: An Efficient and Flexible Platform for Reinforcement Learning in Drone Control
- RL4CO: an Extensive Reinforcement Learning for Combinatorial Optimization Benchmark
- Robohive: A unified framework for robot learning
RL algorithms are very heterogeneous, and it can be hard to recycle a codebase across settings (e.g. from online to offline, from state-based to pixel-based learning). TorchRL solves this problem through `TensorDict`, a convenient data structure(1) that can be used to streamline one's RL codebase. With this tool, one can write a complete PPO training script in less than 100 lines of code!
```python
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
    LazyTensorStorage, SamplerWithoutReplacement
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

env = GymEnv("Pendulum-v1")
model = TensorDictModule(
    nn.Sequential(
        nn.Linear(3, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 2),
        NormalParamExtractor()
    ),
    in_keys=["observation"],
    out_keys=["loc", "scale"]
)
critic = ValueOperator(
    nn.Sequential(
        nn.Linear(3, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 1),
    ),
    in_keys=["observation"],
)
actor = ProbabilisticActor(
    model,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={"low": -1.0, "high": 1.0},
    return_log_prob=True
)
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=SamplerWithoutReplacement(),
    batch_size=50,
)
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=1000,
    total_frames=1_000_000,
)
loss_fn = ClipPPOLoss(actor, critic)
adv_fn = GAE(value_network=critic, average_gae=True, gamma=0.99, lmbda=0.95)
optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)

for data in collector:  # collect data
    for epoch in range(10):
        adv_fn(data)  # compute advantage
        buffer.extend(data)
        for sample in buffer:  # consume data
            loss_vals = loss_fn(sample)
            loss_val = sum(
                value for key, value in loss_vals.items() if
                key.startswith("loss")
            )
            loss_val.backward()
            optim.step()
            optim.zero_grad()
    print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
```
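The `GAE` module in the script above computes generalized advantage estimates with `gamma=0.99` and `lmbda=0.95`. As a plain-Python sketch of the backward recursion behind GAE (not TorchRL's vectorized implementation), assuming a single trajectory stored as lists and a single `done` flag that stops bootstrapping:

```python
def generalized_advantage_estimate(rewards, values, next_values, dones, gamma=0.99, lmbda=0.95):
    """Backward recursion for GAE over one trajectory.

    delta_t = r_t + gamma * (1 - done_t) * V(s_{t+1}) - V(s_t)
    A_t     = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - float(dones[t])
        delta = rewards[t] + gamma * not_done * next_values[t] - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    # value targets that a critic would regress onto
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


def normalize(xs, eps=1e-8):
    """Mean/std standardization, similar in spirit to what `average_gae=True` requests."""
    mu = sum(xs) / len(xs)
    std = (sum((x - mu) ** 2 for x in xs) / len(xs)) ** 0.5
    return [(x - mu) / (std + eps) for x in xs]


adv, targets = generalized_advantage_estimate(
    rewards=[1.0, 1.0, 1.0],
    values=[0.0, 0.0, 0.0],
    next_values=[0.0, 0.0, 0.0],
    dones=[False, False, True],
    gamma=0.5,
    lmbda=1.0,
)
print(adv)  # [1.75, 1.5, 1.0]
```

With `lmbda=1.0` and zero value estimates, the advantages reduce to discounted returns (1 + 0.5 + 0.25 = 1.75 at the first step), which is a convenient sanity check.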
Here is an example of how the environment API relies on tensordict to carry data from one function to another during a rollout execution:

`TensorDict` makes it easy to reuse pieces of code across environments, models and algorithms.
For instance, here's how to code a rollout in TorchRL:

```diff
- obs, done = env.reset()
+ tensordict = env.reset()
 policy = SafeModule(
     model,
     in_keys=["observation_pixels", "observation_vector"],
     out_keys=["action"],
 )
 out = []
 for i in range(n_steps):
-     action, log_prob = policy(obs)
-     next_obs, reward, done, info = env.step(action)
-     out.append((obs, next_obs, action, log_prob, reward, done))
-     obs = next_obs
+     tensordict = policy(tensordict)
+     tensordict = env.step(tensordict)
+     out.append(tensordict)
+     tensordict = step_mdp(tensordict)  # renames next_observation_* keys to observation_*
- obs, next_obs, action, log_prob, reward, done = [torch.stack(vals, 0) for vals in zip(*out)]
+ out = torch.stack(out, 0)  # TensorDict supports multiple tensor operations
```
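In recent TorchRL versions, `env.step` writes the results of a step under a `"next"` entry of the tensordict, and `step_mdp` turns that entry into the input of the following step. The dict-based sketch below models only that promotion step; `toy_step_mdp` is a hypothetical helper, and the real function exposes options controlling which keys (reward, done, action) are kept.

```python
def toy_step_mdp(td, drop_keys=("reward",)):
    """Build the root entries for time t+1 from td["next"] (hypothetical helper).

    Entries at the root (time t, e.g. the action) are discarded, entries under
    "next" are promoted to the root, and keys in `drop_keys` are omitted.
    """
    return {key: value for key, value in td["next"].items() if key not in drop_keys}


# One transition as a nested dict: the root holds time t, "next" holds time t+1.
transition = {
    "observation": [0.0, 1.0],
    "action": [0.5],
    "next": {"observation": [0.1, 0.9], "reward": [1.0], "done": [False]},
}
state = toy_step_mdp(transition)
print(state)  # {'observation': [0.1, 0.9], 'done': [False]}
```

The original transition is left untouched, which is why the rollout can append it to `out` before moving on.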
Using this, TorchRL abstracts away the input / output signatures of the modules, env, collectors, replay buffers and losses of the library, allowing all primitives to be easily recycled across settings.
Here's another example of an off-policy training loop in TorchRL (assuming that a data collector, a replay buffer, a loss and an optimizer have been instantiated):

```diff
- for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
+ for i, tensordict in enumerate(collector):
-     replay_buffer.add((obs, next_obs, action, hidden_state, reward, done))
+     replay_buffer.extend(tensordict)  # extend() stores every frame of the collected batch
      for j in range(num_optim_steps):
-         obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size)
-         loss = loss_fn(obs, next_obs, action, hidden_state, reward, done)
+         tensordict = replay_buffer.sample(batch_size)
+         loss = loss_fn(tensordict)
          loss.backward()
          optim.step()
          optim.zero_grad()
```
This training loop can be re-used across algorithms as it makes a minimal number of assumptions about the structure of the data.
TensorDict supports multiple tensor operations on its device and shape (the shape of a TensorDict, or its batch size, is an arbitrary number N of leading dimensions common to all the tensors it contains):
```python
# Each line below is an independent example.
# stack and cat
tensordict = torch.stack(list_of_tensordicts, 0)
tensordict = torch.cat(list_of_tensordicts, 0)
# reshape
tensordict = tensordict.view(-1)
tensordict = tensordict.permute(0, 2, 1)
tensordict = tensordict.unsqueeze(-1)
tensordict = tensordict.squeeze(-1)
# indexing
tensordict = tensordict[:2]
tensordict[:, 2] = sub_tensordict
# device and memory location (.cuda() and .to() return a new TensorDict)
tensordict = tensordict.cuda()
tensordict = tensordict.to("cuda:1")
tensordict.share_memory_()
```
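The batch-size rule stated above (every entry's shape must start with the TensorDict's batch size) can be checked with a few lines of plain Python. The helper names below are hypothetical; they only illustrate the rule on shape tuples and are not part of the `tensordict` API.

```python
def common_leading_dims(shapes):
    """Longest prefix shared by all shapes: the largest admissible batch size."""
    prefix = []
    for dims in zip(*shapes):  # zip stops at the shortest shape
        if all(d == dims[0] for d in dims):
            prefix.append(dims[0])
        else:
            break
    return tuple(prefix)


def is_valid_batch_size(batch_size, shapes):
    """A batch size is valid if every shape starts with it."""
    batch_size = tuple(batch_size)
    return all(tuple(s[: len(batch_size)]) == batch_size for s in shapes)


obs_shape, reward_shape = (4, 20, 3), (4, 20, 1)
print(common_leading_dims([obs_shape, reward_shape]))  # (4, 20)
```

Any prefix of the common leading dimensions, including the empty one, is an admissible batch size; anything longer is rejected.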
TensorDict comes with a dedicated `tensordict.nn` module that contains everything you might need to write your model with it. And it is `functorch` and `torch.compile` compatible!
```diff
 transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
+ td_module = SafeModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
 src = torch.rand((10, 32, 512))
 tgt = torch.rand((20, 32, 512))
+ # src and tgt share no leading dimension (10 vs 20), so the batch size is empty
+ tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[])
- out = transformer_model(src, tgt)
+ td_module(tensordict)
+ out = tensordict["out"]
```
The `TensorDictSequential` class allows you to branch sequences of `nn.Module` instances in a highly modular way. For instance, here is an implementation of a transformer using the encoder and decoder blocks:
```python
encoder_module = TransformerEncoder(...)
encoder = TensorDictModule(encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"])
decoder_module = TransformerDecoder(...)
decoder = TensorDictModule(decoder_module, in_keys=["tgt", "memory"], out_keys=["output"])
transformer = TensorDictSequential(encoder, decoder)
assert transformer.in_keys == ["src", "src_mask", "tgt"]
assert transformer.out_keys == ["memory", "output"]
```
`TensorDictSequential` allows you to isolate subgraphs by querying a set of desired input / output keys:
```python
transformer.select_subsequence(out_keys=["memory"])  # returns the encoder
transformer.select_subsequence(in_keys=["tgt", "memory"])  # returns the decoder
```
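The key bookkeeping behind these calls can be modeled with plain tuples of `(name, in_keys, out_keys)`. The sketch below is a simplified, hypothetical model of how a chain's external inputs and outputs are derived and how a subsequence is selected by tracing key dependencies; it is not how `TensorDictSequential` is implemented.

```python
def sequence_keys(modules):
    """Inputs the chain needs from outside, and every key it writes."""
    in_keys, out_keys = [], []
    for _, reads, writes in modules:
        for k in reads:
            if k not in out_keys and k not in in_keys:
                in_keys.append(k)  # read before any earlier module wrote it
        for k in writes:
            if k not in out_keys:
                out_keys.append(k)
    return in_keys, out_keys


def select_by_out_keys(modules, wanted):
    """Keep only the modules needed (directly or not) to produce `wanted`."""
    needed, kept = set(wanted), []
    for name, reads, writes in reversed(modules):
        if needed & set(writes):
            kept.append((name, reads, writes))
            needed |= set(reads)
    return list(reversed(kept))


def select_by_in_keys(modules, given):
    """Keep only the modules whose inputs depend (directly or not) on `given`."""
    tainted, kept = set(given), []
    for name, reads, writes in modules:
        if tainted & set(reads):
            kept.append((name, reads, writes))
            tainted |= set(writes)
    return kept


transformer = [
    ("encoder", ["src", "src_mask"], ["memory"]),
    ("decoder", ["tgt", "memory"], ["output"]),
]
print(sequence_keys(transformer))  # (['src', 'src_mask', 'tgt'], ['memory', 'output'])
```

Selecting by output keys walks the chain backwards from what is requested, while selecting by input keys walks it forwards from what is provided.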
Check the TensorDict tutorials to learn more!
A common interface for environments which supports common libraries (OpenAI Gym, DeepMind Control Suite, etc.)(1) and state-less execution (e.g. model-based environments). The batched environment containers allow parallel execution(2). A common PyTorch-first tensor-specification class is also provided. TorchRL's environments API is simple but stringent and specific. Check the documentation and tutorial to learn more!
```python
import torch
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
env_parallel = ParallelEnv(4, env_make)  # creates 4 envs in parallel
tensordict = env_parallel.rollout(max_steps=20, policy=None)  # random rollout (no policy given)
assert tensordict.shape == torch.Size([4, 20])  # 4 envs, 20 steps rollout
env_parallel.action_spec.is_in(tensordict["action"])  # spec check returns True
```
Multiprocess and distributed data collectors(2) that work synchronously or asynchronously. Through the use of TensorDict, TorchRL's training loops are made very similar to regular training loops in supervised learning (although the "dataloader", i.e. the data collector, is modified on-the-fly):
```python
env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
collector = MultiaSyncDataCollector(
    [env_make, env_make],
    policy=policy,
    devices=["cuda:0", "cuda:0"],
    total_frames=10000,
    frames_per_batch=50,
    # ...
)
for i, tensordict_data in enumerate(collector):
    loss = loss_module(tensordict_data)
    loss.backward()
    optim.step()
    optim.zero_grad()
    collector.update_policy_weights_()
```
Check our distributed collector examples to learn more about ultra-fast data collection with TorchRL.
Efficient(2) and generic(1) replay buffers with modularized storage:
```python
storage = LazyMemmapStorage(  # memory-mapped (physical) storage
    cfg.buffer_size,
    scratch_dir="/tmp/"
)
buffer = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,
    beta=0.5,
    collate_fn=lambda x: x,
    pin_memory=device != torch.device("cpu"),
    prefetch=10,  # multi-threaded sampling
    storage=storage
)
```
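`alpha` and `beta` above follow the usual prioritized experience replay convention (Schaul et al., 2016): `alpha` controls how strongly sampling favors high-priority items and `beta` controls the importance-sampling correction. A plain-Python sketch of both formulas, assuming proportional prioritization; TorchRL's own sampler relies on segment trees and its exact normalization may differ.

```python
import random


def per_probabilities(priorities, alpha):
    """Proportional prioritization: P(i) = p_i**alpha / sum_k p_k**alpha."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def importance_weights(probs, beta):
    """IS correction w_i = (N * P(i))**(-beta), divided by the largest weight."""
    n = len(probs)
    raw = [(n * p) ** (-beta) for p in probs]
    w_max = max(raw)
    return [w / w_max for w in raw]


priorities = [1.0, 1.0, 2.0]
probs = per_probabilities(priorities, alpha=1.0)  # [0.25, 0.25, 0.5]
weights = importance_weights(probs, beta=1.0)     # frequently sampled items are down-weighted
batch = random.Random(0).choices(range(len(priorities)), weights=probs, k=4)
```

Setting `alpha=0` recovers uniform sampling, and `beta=1` fully compensates for the non-uniform sampling in the loss.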
Replay buffers are also offered as wrappers around common datasets for offline RL:
```python
from torchrl.data.replay_buffers import SamplerWithoutReplacement
from torchrl.data.datasets.d4rl import D4RLExperienceReplay
data = D4RLExperienceReplay(
    "maze2d-open-v0",
    split_trajs=True,
    batch_size=128,
    sampler=SamplerWithoutReplacement(drop_last=True),
)
for sample in data:  # or alternatively sample = data.sample()
    fun(sample)
```
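`SamplerWithoutReplacement(drop_last=True)` makes one pass over the dataset visit every stored item at most once and skip a final, incomplete batch. That behavior can be sketched with a shuffled index list; `iterate_without_replacement` is a hypothetical stand-in that models only the `drop_last=True` case, not the sampler's code.

```python
import random


def iterate_without_replacement(n_items, batch_size, seed=0):
    """Yield index batches that visit each item at most once (drop_last=True case)."""
    order = list(range(n_items))
    random.Random(seed).shuffle(order)  # a fresh random permutation of the indices
    for start in range(0, n_items - batch_size + 1, batch_size):
        yield order[start:start + batch_size]  # the incomplete tail is never yielded


batches = list(iterate_without_replacement(10, batch_size=4))
```

With 10 items and a batch size of 4, one pass yields two full batches and the 2 leftover indices are dropped.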
Cross-library environment transforms(1), executed on device and in a vectorized fashion(2), which process and prepare the data coming out of the environments to be used by the agent:
```python
env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
env_base = ParallelEnv(4, env_make, device="cuda:0")  # creates 4 envs in parallel
env = TransformedEnv(
    env_base,
    Compose(
        ToTensorImage(),
        ObservationNorm(loc=0.5, scale=1.0, in_keys=["pixels"]),  # normalize the image produced by ToTensorImage
    ),  # executes the transforms once and on device
)
tensordict = env.reset()
assert tensordict.device == torch.device("cuda:0")
```
Other transforms include: reward scaling (`RewardScaling`), shape operations (concatenation of tensors, unsqueezing etc.), concatenation of successive operations (`CatFrames`), resizing (`Resize`) and many more. Unlike other libraries, the transforms are stacked as a list (and not wrapped in each other), which makes it easy to add and remove them at will:
```python
env.insert_transform(0, NoopResetEnv())  # inserts the NoopResetEnv transform at the index 0
```
Nevertheless, transforms can access and execute operations on the parent environment:
```python
transform = env.transform[1]  # gathers the second transform of the list
parent_env = transform.parent  # returns the base environment of the second transform, i.e. the base env + the first transform
```
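Because transforms are kept in a flat list, inserting one is a list operation, and a transform's parent is the base environment followed by the transforms that precede it. The toy class below models only that bookkeeping with strings; its name and methods are hypothetical and it does not execute any transform.

```python
class ToyTransformedEnv:
    """Hypothetical stand-in: a base env plus a *list* of transforms."""

    def __init__(self, base_name, transforms):
        self.base_name = base_name
        self.transforms = list(transforms)

    def insert_transform(self, index, transform):
        self.transforms.insert(index, transform)  # plain list insertion, no re-wrapping

    def parent_of(self, index):
        """Base env followed by every transform that precedes `index`."""
        return [self.base_name] + self.transforms[:index]


env = ToyTransformedEnv("ParallelEnv", ["ToTensorImage", "ObservationNorm"])
env.insert_transform(0, "NoopResetEnv")
print(env.transforms)    # ['NoopResetEnv', 'ToTensorImage', 'ObservationNorm']
print(env.parent_of(1))  # ['ParallelEnv', 'NoopResetEnv']
```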
Various tools for distributed learning (e.g. memory-mapped tensors)(2);
Various architectures and models (e.g. actor-critic)(1):
```python
# create an nn.Module
common_module = ConvNet(
    bias_last_layer=True,
    depth=None,
    num_cells=[32, 64, 64],
    kernel_sizes=[8, 4, 3],
    strides=[4, 2, 1],
)
# Wrap it in a SafeModule, indicating what key to read in and where to
# write out the output
common_module = SafeModule(
    common_module,
    in_keys=["pixels"],
    out_keys=["hidden"],
)
# Wrap the policy module in NormalParamsWrapper, such that the output
# tensor is split in loc and scale, and scale is mapped onto a positive space
policy_module = SafeModule(
    NormalParamsWrapper(
        MLP(num_cells=[64, 64], out_features=32, activation=nn.ELU)
    ),
    in_keys=["hidden"],
    out_keys=["loc", "scale"],
)
# Use a SafeProbabilisticTensorDictSequential to combine the SafeModule with a
# SafeProbabilisticModule, indicating how to build the
# torch.distributions.Distribution object and what to do with it
policy_module = SafeProbabilisticTensorDictSequential(  # stochastic policy
    policy_module,
    SafeProbabilisticModule(
        in_keys=["loc", "scale"],
        out_keys="action",
        distribution_class=TanhNormal,
    ),
)
# Wrap the value MLP so that it reads "hidden" and writes "state_value"
value_module = ValueOperator(
    MLP(
        num_cells=[64, 64],
        out_features=1,
        activation=nn.ELU,
    ),
    in_keys=["hidden"],
)
# Wrap the policy and value function in a common module
actor_value = ActorValueOperator(common_module, policy_module, value_module)
# standalone policy from this
standalone_policy = actor_value.get_policy_operator()
```
Exploration wrappers and modules to easily swap between exploration and exploitation(1):
```python
policy_explore = EGreedyWrapper(policy)
with set_exploration_type(ExplorationType.RANDOM):
    tensordict = policy_explore(tensordict)  # will use eps-greedy
with set_exploration_type(ExplorationType.DETERMINISTIC):
    tensordict = policy_explore(tensordict)  # will not use eps-greedy
```
A series of efficient loss modules and highly vectorized functional return and advantage computation.
```python
from torchrl.objectives import DQNLoss
loss_module = DQNLoss(value_network=value_network, gamma=0.99)
tensordict = replay_buffer.sample(batch_size)
loss = loss_module(tensordict)
```
```python
from torchrl.objectives.value.functional import vec_td_lambda_return_estimate
advantage = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done, terminated)
```
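`vec_td_lambda_return_estimate` is a vectorized form of the TD(λ) return. For a single trajectory, the same quantity can be written as a short backward loop. This sketch treats `done` as both the end of an episode and the end of bootstrapping (the TorchRL signature keeps `done` and `terminated` separate), and its function name is hypothetical.

```python
def td_lambda_returns(gamma, lmbda, next_state_value, reward, done):
    """Backward recursion for the TD(lambda) return of one trajectory.

    G_t = r_t + gamma * (1 - done_t) * ((1 - lmbda) * V(s_{t+1}) + lmbda * G_{t+1})
    """
    returns = [0.0] * len(reward)
    g_next = next_state_value[-1]  # beyond the last step, bootstrap from V(s_T)
    for t in reversed(range(len(reward))):
        not_done = 1.0 - float(done[t])
        g = reward[t] + gamma * not_done * ((1 - lmbda) * next_state_value[t] + lmbda * g_next)
        returns[t] = g
        g_next = g
    return returns


rewards, next_values, dones = [1.0, 1.0, 1.0], [0.0, 0.0, 4.0], [False, False, False]
print(td_lambda_returns(0.5, 0.0, next_values, rewards, dones))  # [1.0, 1.0, 3.0]: one-step TD targets
print(td_lambda_returns(0.5, 1.0, next_values, rewards, dones))  # [2.25, 2.5, 3.0]: bootstrapped Monte Carlo
```

The two extremes of `lmbda` recover the one-step TD target and the discounted return bootstrapped from the final value, which makes them useful sanity checks.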
A generic trainer class(1) that executes the aforementioned training loop. Through a hooking mechanism, it also supports any logging or data transformation operation at any given time.
Various recipes to build models that correspond to the environment being deployed.
LLM API: Complete framework for language model fine-tuning with unified wrappers for Hugging Face and vLLM backends, conversation management with automatic chat template detection, tool integration (Python execution, function calling), specialized objectives (GRPO, SFT), and high-performance async collectors. Perfect for RLHF, supervised fine-tuning, and tool-augmented training scenarios.
```python
from tensordict import TensorDict
from torchrl.envs.llm import ChatEnv
from torchrl.modules.llm import TransformersWrapper
from torchrl.envs.llm.transforms import PythonInterpreter

# `model` and `tokenizer` are assumed to be defined beforehand

# Create environment with tool execution
env = ChatEnv(
    tokenizer=tokenizer,
    system_prompt="You can execute Python code.",
    batch_size=[1]
).append_transform(PythonInterpreter())

# Wrap language model for training
llm = TransformersWrapper(
    model=model,
    tokenizer=tokenizer,
    input_mode="history"
)

# Multi-turn conversation with tool use
obs = env.reset(TensorDict({"query": "Calculate 2+2"}, batch_size=[1]))
llm_output = llm(obs)  # Generates response
obs = env.step(llm_output)  # Environment processes response
```
If you feel a feature is missing from the library, please submit an issue! If you would like to contribute to new features, check our call for contributions and our contribution page.
A series of State-of-the-Art implementations are provided for illustrative purposes:
Algorithm | Compile Support** | Tensordict-free API | Modular Losses | Continuous and Discrete |
---|---|---|---|---|
DQN | 1.9x | + | NA | + (through `ActionDiscretizer` transform) |
DDPG | 1.87x | + | + | - (continuous only) |
IQL | 3.22x | + | + | + |
CQL | 2.68x | + | + | + |
TD3 | 2.27x | + | + | - (continuous only) |
TD3+BC | untested | + | + | - (continuous only) |
A2C | 2.67x | + | - | + |
PPO | 2.42x | + | - | + |
SAC | 2.62x | + | - | + |
REDQ | 2.28x | + | - | - (continuous only) |
Dreamer v1 | untested | + | + (different classes) | - (continuous only) |
Decision Transformers | untested | + | NA | - (continuous only) |
CrossQ | untested | + | + | - (continuous only) |
Gail | untested | + | NA | + |
Impala | untested | + | - | + |
IQL (MARL) | untested | + | + | + |
DDPG (MARL) | untested | + | + | - (continuous only) |
PPO (MARL) | untested | + | - | + |
QMIX-VDN (MARL) | untested | + | NA | + |
SAC (MARL) | untested | + | - | + |
RLHF | NA | + | NA | NA |
LLM API (GRPO) | NA | + | + | NA |
** The number indicates the expected speed-up compared to eager mode when executed on CPU. Numbers may vary depending on architecture and device.

And many more to come!
Code examples displaying toy code snippets and training scripts are also available:
- LLM API & GRPO - Complete language model fine-tuning pipeline
- RLHF
- Memory-mapped replay buffers
Check the examples directory for more details about handling the various configuration settings.
We also provide tutorials and demos that give a sense of what the library can do.
If you're using TorchRL, please refer to this BibTeX entry to cite this work:
```bibtex
@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```
```bash
python -m venv torchrl
source torchrl/bin/activate  # On Windows use: torchrl\Scripts\activate
```
Or create a conda environment where the packages will be installed.
```bash
conda create --name torchrl python=3.9
conda activate torchrl
```
Depending on the use of torchrl that you want to make, you may want to install the latest (nightly) PyTorch release or the latest stable version of PyTorch. See here for a detailed list of commands, including `pip3` or other special installation instructions.
TorchRL offers a few pre-defined dependencies such as `"torchrl[tests]"`, `"torchrl[atari]"`, etc.
You can install the latest stable release by using

```bash
pip3 install torchrl
```
This should work on Linux (including AArch64 machines), Windows 10 and macOS (Apple Silicon chips only). On certain Windows machines (Windows 11), one should build the library locally. This can be done in two ways:
```bash
# Install and build locally v0.8.1 of the library without cloning
pip3 install git+https://github.com/pytorch/rl@v0.8.1
# Clone the library and build it locally
git clone https://github.com/pytorch/tensordict
git clone https://github.com/pytorch/rl
pip install -e tensordict
pip install -e rl
```
Note that a local build of tensordict requires `cmake` to be installed via homebrew (macOS) or another package manager such as `apt`, `apt-get`, `conda` or `yum`, but NOT `pip`, as well as `pip install "pybind11[global]"`.
One can also build the wheels to distribute to co-workers using

```bash
python setup.py bdist_wheel
```

Your wheels will be stored in `./dist/torchrl<name>.whl` and can be installed via

```bash
pip install torchrl<name>.whl
```
The nightly build can be installed via

```bash
pip3 install tensordict-nightly torchrl-nightly
```
which we currently only ship for Linux machines. Importantly, the nightly builds require the nightly builds of PyTorch too. Also, a local build of torchrl with the nightly build of tensordict may fail: install both nightlies or both local builds, but do not mix them.
Disclaimer: As of today, TorchRL is roughly compatible with any pytorch version >= 2.1 and installing it will not directly require a newer version of pytorch to be installed. Indirectly though, tensordict still requires the latest PyTorch to be installed and we are working hard to loosen that requirement. The C++ binaries of TorchRL (mainly for prioritized replay buffers) will only work with PyTorch 2.7.0 and above. Some features (e.g., working with nested jagged tensors) may also be limited with older versions of pytorch. It is recommended to use the latest TorchRL with the latest PyTorch version unless there is a strong reason not to do so.
Optional dependencies
The following libraries can be installed depending on the usage one wants to make of torchrl:

```bash
# diverse
pip3 install tqdm tensorboard "hydra-core>=1.1" hydra-submitit-launcher

# rendering
pip3 install "moviepy<2.0.0"

# deepmind control suite
pip3 install dm_control

# gym, atari games
pip3 install "gym[atari]" "gym[accept-rom-license]" pygame

# tests
pip3 install pytest pyyaml pytest-instafail

# tensorboard
pip3 install tensorboard

# wandb
pip3 install wandb
```
Versioning issues can cause error messages of the type `undefined symbol` and such. For these, refer to the versioning issues document for a complete explanation and proposed workarounds.
If you spot a bug in the library, please raise an issue in this repo.
If you have a more generic question regarding RL in PyTorch, post it on the PyTorch forum.
Internal collaborations to torchrl are welcome! Feel free to fork, submit issues and PRs. You can check out the detailed contribution guide here. As mentioned above, a list of open contributions can be found here.
Contributors are recommended to install pre-commit hooks (using `pre-commit install`). pre-commit will check for linting-related issues when the code is committed locally. You can disable the check by appending `-n` to your commit command: `git commit -m <commit message> -n`
This library is released as a PyTorch beta feature. BC-breaking changes are likely to happen, but they will be introduced with a deprecation warranty after a few release cycles.
TorchRL is licensed under the MIT License. See LICENSE for details.