
TorchRL

Documentation | TensorDict | Features | Examples, tutorials and demos | Citation | Installation | Asking a question | Contributing

TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.

Key features

  • 🐍Python-first: Designed with Python as the primary language for ease of use and flexibility
  • ⏱️Efficient: Optimized for performance to support demanding RL research applications
  • 🧮Modular, customizable, extensible: Highly modular architecture allows for easy swapping, transformation, or creation of new components
  • 📚Documented: Thorough documentation ensures that users can quickly understand and utilize the library
  • Tested: Rigorously tested to ensure reliability and stability
  • ⚙️Reusable functionals: Provides a set of highly reusable functions for cost functions, returns, and data processing

Design Principles

  • 🔥Aligns with PyTorch ecosystem: Follows the structure and conventions of popular PyTorch libraries (e.g., dataset pillar, transforms, models, data utilities)
  • ➖ Minimal dependencies: Only requires Python standard library, NumPy, and PyTorch; optional dependencies for common environment libraries (e.g., OpenAI Gym) and datasets (D4RL, OpenX...)

Read the full paper for a more curated description of the library.

Getting started

Check our Getting Started tutorials to quickly ramp up with the basic features of the library!
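
If you want to see the library in action right away, the snippet below is a minimal sketch of creating an environment and collecting a short random rollout. It assumes that Gym/Gymnasium is installed and provides the Pendulum-v1 environment, and only uses primitives shown later in this README.

```python
# Minimal sketch: assumes a Gym/Gymnasium backend with "Pendulum-v1" is installed.
from torchrl.envs.libs.gym import GymEnv

env = GymEnv("Pendulum-v1")
reset_td = env.reset()               # a TensorDict holding the initial observation
rollout = env.rollout(max_steps=10)  # 10 steps with random actions, stacked in one TensorDict
print(rollout)                       # keys include "observation", "action" and ("next", "reward")
```

The Getting Started tutorials cover the remaining building blocks (policies, collectors, replay buffers, losses) step by step.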

Documentation and knowledge base

The TorchRL documentation can be found here. It contains tutorials and the API reference.

TorchRL also provides an RL knowledge base to help you debug your code, or simply learn the basics of RL. Check it out here.

We also have some introductory videos for you to get to know the library better; check them out.

Spotlight publications

TorchRL is domain-agnostic, so you can use it across many different fields. Here are a few examples:

  • ACEGEN: Reinforcement Learning of Generative Chemical Agents for Drug Discovery
  • BenchMARL: Benchmarking Multi-Agent Reinforcement Learning
  • BricksRL: A Platform for Democratizing Robotics and Reinforcement Learning Research and Education with LEGO
  • OmniDrones: An Efficient and Flexible Platform for Reinforcement Learning in Drone Control
  • RL4CO: an Extensive Reinforcement Learning for Combinatorial Optimization Benchmark
  • Robohive: A unified framework for robot learning

Writing simplified and portable RL codebase with TensorDict

RL algorithms are very heterogeneous, and it can be hard to recycle a codebase across settings (e.g. from online to offline, from state-based to pixel-based learning). TorchRL solves this problem through TensorDict, a convenient data structure(1) that can be used to streamline one's RL codebase. With this tool, one can write a complete PPO training script in less than 100 lines of code!

Code
```python
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
    LazyTensorStorage, SamplerWithoutReplacement
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

env = GymEnv("Pendulum-v1")
model = TensorDictModule(
    nn.Sequential(
        nn.Linear(3, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 2),
        NormalParamExtractor()
    ),
    in_keys=["observation"],
    out_keys=["loc", "scale"]
)
critic = ValueOperator(
    nn.Sequential(
        nn.Linear(3, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 1),
    ),
    in_keys=["observation"],
)
actor = ProbabilisticActor(
    model,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={"low": -1.0, "high": 1.0},
    return_log_prob=True
)
buffer = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=SamplerWithoutReplacement(),
    batch_size=50,
)
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=1000,
    total_frames=1_000_000,
)
loss_fn = ClipPPOLoss(actor, critic)
adv_fn = GAE(value_network=critic, average_gae=True, gamma=0.99, lmbda=0.95)
optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)

for data in collector:  # collect data
    for epoch in range(10):
        adv_fn(data)  # compute advantage
        buffer.extend(data)
        for sample in buffer:  # consume data
            loss_vals = loss_fn(sample)
            loss_val = sum(
                value for key, value in loss_vals.items()
                if key.startswith("loss")
            )
            loss_val.backward()
            optim.step()
            optim.zero_grad()
    print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")
```

The environment API relies on tensordict to carry data from one function to another during a rollout execution.

TensorDict makes it easy to re-use pieces of code across environments, models and algorithms.

Code

For instance, here's how to code a rollout in TorchRL:

```diff
- obs, done = env.reset()
+ tensordict = env.reset()
policy = SafeModule(
    model,
    in_keys=["observation_pixels", "observation_vector"],
    out_keys=["action"],
)
out = []
for i in range(n_steps):
-     action, log_prob = policy(obs)
-     next_obs, reward, done, info = env.step(action)
-     out.append((obs, next_obs, action, log_prob, reward, done))
-     obs = next_obs
+     tensordict = policy(tensordict)
+     tensordict = env.step(tensordict)
+     out.append(tensordict)
+     tensordict = step_mdp(tensordict)  # renames next_observation_* keys to observation_*
- obs, next_obs, action, log_prob, reward, done = [torch.stack(vals, 0) for vals in zip(*out)]
+ out = torch.stack(out, 0)  # TensorDict supports multiple tensor operations
```

Using this, TorchRL abstracts away the input / output signatures of the modules, env, collectors, replay buffers and losses of the library, allowing all primitives to be easily recycled across settings.

Code

Here's another example of an off-policy training loop in TorchRL (assuming that a data collector, a replay buffer, a loss and an optimizer have been instantiated):

```diff
- for i, (obs, next_obs, action, hidden_state, reward, done) in enumerate(collector):
+ for i, tensordict in enumerate(collector):
-     replay_buffer.add((obs, next_obs, action, log_prob, reward, done))
+     replay_buffer.add(tensordict)
    for j in range(num_optim_steps):
-         obs, next_obs, action, hidden_state, reward, done = replay_buffer.sample(batch_size)
-         loss = loss_fn(obs, next_obs, action, hidden_state, reward, done)
+         tensordict = replay_buffer.sample(batch_size)
+         loss = loss_fn(tensordict)
        loss.backward()
        optim.step()
        optim.zero_grad()
```

This training loop can be re-used across algorithms as it makes a minimal number of assumptions about the structure of the data.

TensorDict supports multiple tensor operations on its device and shape (the shape of a TensorDict, or its batch size, is given by the first N dimensions that all of its contained tensors share, for an arbitrary N):

Code
```python
# stack and cat
tensordict = torch.stack(list_of_tensordicts, 0)
tensordict = torch.cat(list_of_tensordicts, 0)
# reshape
tensordict = tensordict.view(-1)
tensordict = tensordict.permute(0, 2, 1)
tensordict = tensordict.unsqueeze(-1)
tensordict = tensordict.squeeze(-1)
# indexing
tensordict = tensordict[:2]
tensordict[:, 2] = sub_tensordict
# device and memory location
tensordict.cuda()
tensordict.to("cuda:1")
tensordict.share_memory_()
```

TensorDict comes with a dedicated tensordict.nn module that contains everything you might need to write your model with it. And it is functorch and torch.compile compatible!

Code
```diff
transformer_model = nn.Transformer(nhead=16, num_encoder_layers=12)
+ td_module = SafeModule(transformer_model, in_keys=["src", "tgt"], out_keys=["out"])
src = torch.rand((10, 32, 512))
tgt = torch.rand((20, 32, 512))
+ tensordict = TensorDict({"src": src, "tgt": tgt}, batch_size=[20, 32])
- out = transformer_model(src, tgt)
+ td_module(tensordict)
+ out = tensordict["out"]
```

The TensorDictSequential class allows you to branch sequences of nn.Module instances in a highly modular way. For instance, here is an implementation of a transformer using the encoder and decoder blocks:

```python
encoder_module = TransformerEncoder(...)
encoder = TensorDictModule(
    encoder_module, in_keys=["src", "src_mask"], out_keys=["memory"]
)
decoder_module = TransformerDecoder(...)
decoder = TensorDictModule(
    decoder_module, in_keys=["tgt", "memory"], out_keys=["output"]
)
transformer = TensorDictSequential(encoder, decoder)
assert transformer.in_keys == ["src", "src_mask", "tgt"]
assert transformer.out_keys == ["memory", "output"]
```

TensorDictSequential allows you to isolate subgraphs by querying a set of desired input / output keys:

```python
transformer.select_subsequence(out_keys=["memory"])  # returns the encoder
transformer.select_subsequence(in_keys=["tgt", "memory"])  # returns the decoder
```

Check the TensorDict tutorials to learn more!

Features

  • A common interface for environments which supports common libraries (OpenAI gym, deepmind control lab, etc.)(1) and state-less execution (e.g. model-based environments). The batched environment containers allow parallel execution(2). A common PyTorch-first tensor-specification class is also provided. TorchRL's environments API is simple but stringent and specific. Check the documentation and tutorial to learn more!

    Code
```python
env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
env_parallel = ParallelEnv(4, env_make)  # creates 4 envs in parallel
tensordict = env_parallel.rollout(max_steps=20, policy=None)  # random rollout (no policy given)
assert tensordict.shape == [4, 20]  # 4 envs, 20 steps rollout
env_parallel.action_spec.is_in(tensordict["action"])  # spec check returns True
```
  • multiprocess and distributed data collectors(2) that work synchronously or asynchronously. Through the use of TensorDict, TorchRL's training loops are made very similar to regular training loops in supervised learning (although the "dataloader" -- read data collector -- is modified on-the-fly):

    Code
```python
env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
collector = MultiaSyncDataCollector(
    [env_make, env_make],
    policy=policy,
    devices=["cuda:0", "cuda:0"],
    total_frames=10000,
    frames_per_batch=50,
    ...
)
for i, tensordict_data in enumerate(collector):
    loss = loss_module(tensordict_data)
    loss.backward()
    optim.step()
    optim.zero_grad()
    collector.update_policy_weights_()
```

    Check our distributed collector examples to learn more about ultra-fast data collection with TorchRL.

  • efficient(2) and generic(1) replay buffers with modularized storage:

    Code
```python
storage = LazyMemmapStorage(  # memory-mapped (physical) storage
    cfg.buffer_size,
    scratch_dir="/tmp/"
)
buffer = TensorDictPrioritizedReplayBuffer(
    alpha=0.7,
    beta=0.5,
    collate_fn=lambda x: x,
    pin_memory=device != torch.device("cpu"),
    prefetch=10,  # multi-threaded sampling
    storage=storage
)
```

    Replay buffers are also offered as wrappers around common datasets for offline RL:

    Code
```python
from torchrl.data.replay_buffers import SamplerWithoutReplacement
from torchrl.data.datasets.d4rl import D4RLExperienceReplay

data = D4RLExperienceReplay(
    "maze2d-open-v0",
    split_trajs=True,
    batch_size=128,
    sampler=SamplerWithoutReplacement(drop_last=True),
)
for sample in data:  # or alternatively sample = data.sample()
    fun(sample)
```
  • cross-library environment transforms(1), executed on device and in a vectorized fashion(2), which process and prepare the data coming out of the environments to be used by the agent:

    Code
```python
env_make = lambda: GymEnv("Pendulum-v1", from_pixels=True)
env_base = ParallelEnv(4, env_make, device="cuda:0")  # creates 4 envs in parallel
env = TransformedEnv(
    env_base,
    Compose(
        ToTensorImage(),
        ObservationNorm(loc=0.5, scale=1.0),
    ),  # executes the transforms once and on device
)
tensordict = env.reset()
assert tensordict.device == torch.device("cuda:0")
```

    Other transforms include: reward scaling (RewardScaling), shape operations (concatenation of tensors, unsqueezing, etc.), concatenation of successive frames (CatFrames), resizing (Resize) and many more.

    Unlike other libraries, the transforms are stacked as a list (and not wrapped in each other), which makes it easy to add and remove them at will:

```python
env.insert_transform(0, NoopResetEnv())  # inserts the NoopResetEnv transform at index 0
```

    Nevertheless, transforms can access and execute operations on the parent environment:

```python
transform = env.transform[1]  # gathers the second transform of the list
parent_env = transform.parent  # returns the base environment of the second transform, i.e. the base env + the first transform
```
  • various tools for distributed learning (e.g. memory-mapped tensors)(2);

  • various architectures and models (e.g. actor-critic)(1):

    Code
```python
# create an nn.Module
common_module = ConvNet(
    bias_last_layer=True,
    depth=None,
    num_cells=[32, 64, 64],
    kernel_sizes=[8, 4, 3],
    strides=[4, 2, 1],
)
# Wrap it in a SafeModule, indicating what key to read in and where to
# write out the output
common_module = SafeModule(
    common_module,
    in_keys=["pixels"],
    out_keys=["hidden"],
)
# Wrap the policy module in NormalParamsWrapper, such that the output
# tensor is split in loc and scale, and scale is mapped onto a positive space
policy_module = SafeModule(
    NormalParamsWrapper(
        MLP(num_cells=[64, 64], out_features=32, activation=nn.ELU)
    ),
    in_keys=["hidden"],
    out_keys=["loc", "scale"],
)
# Use a SafeProbabilisticTensorDictSequential to combine the SafeModule with a
# SafeProbabilisticModule, indicating how to build the
# torch.distribution.Distribution object and what to do with it
policy_module = SafeProbabilisticTensorDictSequential(  # stochastic policy
    policy_module,
    SafeProbabilisticModule(
        in_keys=["loc", "scale"],
        out_keys="action",
        distribution_class=TanhNormal,
    ),
)
value_module = MLP(
    num_cells=[64, 64],
    out_features=1,
    activation=nn.ELU,
)
# Wrap the policy and value function in a common module
actor_value = ActorValueOperator(common_module, policy_module, value_module)
# standalone policy from this
standalone_policy = actor_value.get_policy_operator()
```
  • exploration wrappers and modules to easily swap between exploration and exploitation(1):

    Code
```python
policy_explore = EGreedyWrapper(policy)
with set_exploration_type(ExplorationType.RANDOM):
    tensordict = policy_explore(tensordict)  # will use eps-greedy
with set_exploration_type(ExplorationType.DETERMINISTIC):
    tensordict = policy_explore(tensordict)  # will not use eps-greedy
```
  • A series of efficient loss modules and highly vectorized functional return and advantage computation.

    Code

    Loss modules

```python
from torchrl.objectives import DQNLoss

loss_module = DQNLoss(value_network=value_network, gamma=0.99)
tensordict = replay_buffer.sample(batch_size)
loss = loss_module(tensordict)
```

    Advantage computation

```python
from torchrl.objectives.value.functional import vec_td_lambda_return_estimate

advantage = vec_td_lambda_return_estimate(gamma, lmbda, next_state_value, reward, done, terminated)
```
  • a generic trainer class(1) that executes the aforementioned training loop. Through a hooking mechanism, it also supports any logging or data transformation operation at any given time (a rough sketch of assembling a trainer is given right after this list).

  • various recipes to build models that correspond to the environment being deployed.
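
As referenced above, here is a rough sketch of how such a trainer can be assembled. It assumes a collector, a loss module and an optimizer have already been created (as in the snippets above); the exact keyword arguments and hook names may differ between versions, so treat this as indicative and check the trainer documentation.

```python
# Rough sketch only: keyword arguments and hook names are indicative and should be
# checked against the Trainer documentation for the installed version.
from torchrl.trainers import Trainer

trainer = Trainer(
    collector=collector,        # e.g. a SyncDataCollector
    total_frames=1_000_000,     # total number of frames to collect
    loss_module=loss_module,    # e.g. a ClipPPOLoss or DQNLoss instance
    optimizer=optimizer,
    optim_steps_per_batch=10,   # optimization steps per collected batch
)

# Hooks can be registered at predefined points of the loop, e.g. to post-process
# each collected batch before the optimization steps (hypothetical reward scaling).
def scale_reward(batch):
    batch["next", "reward"] = batch["next", "reward"] / 10
    return batch

trainer.register_op("batch_process", scale_reward)
trainer.train()
```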

If you feel a feature is missing from the library, please submit an issue! If you would like to contribute to new features, check our call for contributions and our contribution page.

Examples, tutorials and demos

A series of State-of-the-Art implementations is provided for illustrative purposes:

| Algorithm | Compile Support** | Tensordict-free API | Modular Losses | Continuous and Discrete |
| --- | --- | --- | --- | --- |
| DQN | 1.9x | + | NA | + (through ActionDiscretizer transform) |
| DDPG | 1.87x | + | + | - (continuous only) |
| IQL | 3.22x | + | + | + |
| CQL | 2.68x | + | + | + |
| TD3 | 2.27x | + | + | - (continuous only) |
| TD3+BC | untested | + | + | - (continuous only) |
| A2C | 2.67x | + | - | + |
| PPO | 2.42x | + | - | + |
| SAC | 2.62x | + | - | + |
| REDQ | 2.28x | + | - | - (continuous only) |
| Dreamer v1 | untested | + | + (different classes) | - (continuous only) |
| Decision Transformers | untested | + | NA | - (continuous only) |
| CrossQ | untested | + | + | - (continuous only) |
| Gail | untested | + | NA | + |
| Impala | untested | + | - | + |
| IQL (MARL) | untested | + | + | + |
| DDPG (MARL) | untested | + | + | - (continuous only) |
| PPO (MARL) | untested | + | - | + |
| QMIX-VDN (MARL) | untested | + | NA | + |
| SAC (MARL) | untested | + | - | + |
| RLHF | NA | + | NA | NA |

** The number indicates expected speed-up compared to eager mode when executed on CPU. Numbers may vary depending on architecture and device.

and many more to come!
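
As a rough illustration of what compilation can look like in user code (the numbers above come from the provided training scripts, which configure compilation through their own settings; the names actor, critic and sampled_tensordict below are assumed to exist already), a loss module can be wrapped with the standard torch.compile API:

```python
# Illustrative sketch only: the SOTA scripts enable compilation through their own
# configuration. `actor`, `critic` and `sampled_tensordict` are assumed to exist.
import torch
from torchrl.objectives import ClipPPOLoss

loss_module = ClipPPOLoss(actor, critic)
compiled_loss = torch.compile(loss_module)     # standard torch.compile wrapping

loss_vals = compiled_loss(sampled_tensordict)  # a batch sampled from a replay buffer
loss_val = sum(v for k, v in loss_vals.items() if k.startswith("loss"))
loss_val.backward()
```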

Code examples displaying toy code snippets and training scripts are also available. Check the examples directory for more details about handling the various configuration settings.

We also provide tutorials and demos that give a sense of what the library can do.

Citation

If you're using TorchRL, please refer to this BibTeX entry to cite this work:

```bibtex
@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
```

Installation

Create a conda environment where the packages will be installed.

```bash
conda create --name torch_rl python=3.9
conda activate torch_rl
```

PyTorch

Depending on the use of functorch that you want to make, you may want to install the latest (nightly) PyTorch release or the latest stable version of PyTorch. See here for a detailed list of commands, including pip3 or other special installation instructions.

Torchrl

You can install the latest stable release by using

pip3 install torchrl

This should work on Linux, Windows 10 and macOS (Intel or Apple Silicon). On certain Windows machines (Windows 11), one should install the library locally (see below).

The nightly build can be installed via

pip3 install torchrl-nightly

which we currently only ship for Linux and macOS (Intel) machines. Importantly, the nightly builds require the nightly builds of PyTorch too.

To install extra dependencies, call

pip3 install "torchrl[atari,dm_control,gym_continuous,rendering,tests,utils,marl,open_spiel,checkpointing]"

or a subset of these.

One may also desire to install the library locally. Three main reasons can motivate this:

  • the nightly/stable release isn't available for one's platform (e.g., Windows 11, nightlies for Apple Silicon, etc.);
  • contributing to the code;
  • installing torchrl with a previous version of PyTorch (any version >= 2.0). Note that this should also be doable via a regular install followed by a downgrade to a previous PyTorch version, but the C++ binaries will then not be available, so some features (such as prioritized replay buffers) will not work.

To install the library locally, start by cloning the repo:

git clone https://github.com/pytorch/rl

and don't forget to check out the branch or tag you want to use for the build:

git checkout v0.4.0

Go to the directory where you have cloned the torchrl repo and install it (after installing ninja):

```bash
cd /path/to/torchrl/
pip3 install ninja -U
python setup.py develop
```

One can also build the wheels to distribute to co-workers using

python setup.py bdist_wheel

Your wheels will be stored in ./dist/torchrl<name>.whl and are installable via

pip install torchrl<name>.whl

Warning: Unfortunately, pip3 install -e . does not currently work. Contributions to help fix this are welcome!

On M1 machines, this should work out-of-the-box with the nightly build of PyTorch. If generating this artifact on macOS M1 doesn't work correctly, or if the message (mach-o file, but is an incompatible architecture (have 'x86_64', need 'arm64e')) appears at execution, then try

ARCHFLAGS="-arch arm64" python setup.py develop

To run a quick sanity check, leave that directory (e.g. by executing cd ~/) and try to import the library.

python -c "import torchrl"

This should not return any warning or error.
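
For a slightly deeper check, importing a few of the main entry points should also succeed without warnings; this is purely illustrative and requires no extra environment backend.

```python
# Optional, deeper sanity check: these imports should not raise or warn either.
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyTensorStorage, ReplayBuffer
from torchrl.envs import TransformedEnv
from torchrl.objectives import ClipPPOLoss

print("torchrl core imports OK")
```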

Optional dependencies

The following libraries can be installed depending on the usage one wants to make of torchrl:

```bash
# diverse
pip3 install tqdm tensorboard "hydra-core>=1.1" hydra-submitit-launcher

# rendering
pip3 install "moviepy<2.0.0"

# deepmind control suite
pip3 install dm_control

# gym, atari games
pip3 install "gym[atari]" "gym[accept-rom-license]" pygame

# tests
pip3 install pytest pyyaml pytest-instafail

# tensorboard
pip3 install tensorboard

# wandb
pip3 install wandb
```

Troubleshooting

If a ModuleNotFoundError: No module named 'torchrl._torchrl' error occurs (or a warning indicating that the C++ binaries could not be loaded), it means that the C++ extensions were not installed or not found.

  • One common reason might be that you are trying to import torchrl from within the git repo location. The following code snippet should return an error if torchrl has not been installed in develop mode:
```bash
cd ~/path/to/rl/repo
python -c 'from torchrl.envs.libs.gym import GymEnv'
```
    If this is the case, consider executing torchrl from another location.
  • If you're not importing torchrl from within its repo location, it could be caused by a problem during the local installation. Check the log after running python setup.py develop. One common cause is a g++/C++ version discrepancy and/or a problem with the ninja library.
  • If the problem persists, feel free to open an issue on the topic in the repo; we'll do our best to help!
  • On macOS, we recommend installing XCode first. With Apple Silicon M1 chips, make sure you are using the arm64-built Python (e.g. here). Running the following lines of code
```bash
wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
python collect_env.py
```
    should display
    OS: macOS *** (arm64)
    and not
    OS: macOS **** (x86_64)

Versioning issues can cause error messages of the type undefined symbol and such. For these, refer to the versioning issues document for a complete explanation and proposed workarounds.

Asking a question

If you spot a bug in the library, please raise an issue in this repo.

If you have a more generic question regarding RL in PyTorch, post it on the PyTorch forum.

Contributing

Internal collaborations to torchrl are welcome! Feel free to fork, submit issues and PRs. You can check out the detailed contribution guide here. As mentioned above, a list of open contributions can be found here.

Contributors are recommended to install pre-commit hooks (using pre-commit install). pre-commit will check for linting-related issues when the code is committed locally. You can disable the check by appending -n to your commit command: git commit -m <commit message> -n

Disclaimer

This library is released as a PyTorch beta feature. BC-breaking changes are likely to happen, but they will be introduced with a deprecation warranty after a few release cycles.

License

TorchRL is licensed under the MIT License. See LICENSE for details.

