

Reinforcement Learning (PPO) with TorchRL Tutorial#

Created On: Mar 15, 2023 | Last Updated: Sep 17, 2025 | Last Verified: Nov 05, 2024

Author: Vincent Moens

This tutorial demonstrates how to use PyTorch and torchrl to train a parametric policy network to solve the Inverted Pendulum task from the OpenAI-Gym/Farama-Gymnasium control library.

Figure: Inverted pendulum

Key learnings:

  • How to create an environment in TorchRL, transform its outputs, and collect data from this environment;

  • How to make your classes talk to each other using TensorDict;

  • The basics of building your training loop with TorchRL:

    • How to compute the advantage signal for policy gradient methods;

    • How to create a stochastic policy using a probabilistic neural network;

    • How to create a dynamic replay buffer and sample from it without repetition.

We will cover six crucial components of TorchRL: environments, transforms, models (policy and value function), loss modules, data collectors, and replay buffers.

If you are running this in Google Colab, make sure you install the following dependencies:

!pip3 install torchrl
!pip3 install gym[mujoco]
!pip3 install tqdm

Proximal Policy Optimization (PPO) is a policy-gradient algorithm in which a batch of data is collected and directly consumed to train the policy to maximise the expected return given some proximality constraints. You can think of it as a sophisticated version of REINFORCE, the foundational policy-optimization algorithm. For more information, see the Proximal Policy Optimization Algorithms paper.

PPO is usually regarded as a fast and efficient method for online, on-policy reinforcement learning. TorchRL provides a loss module that does all the work for you, so that you can rely on this implementation and focus on solving your problem rather than re-inventing the wheel every time you want to train a policy.

For completeness, here is a brief overview of what the loss computes, even though this is taken care of by our ClipPPOLoss module. The algorithm works as follows:

  1. We will sample a batch of data by playing the policy in the environment for a given number of steps.

  2. Then, we will perform a given number of optimization steps with random sub-samples of this batch using a clipped version of the REINFORCE loss.

  3. The clipping will put a pessimistic bound on our loss: lower return estimates will be favored compared to higher ones.

The precise formula of the loss is:

\[L(s,a,\theta_k,\theta) = \min\left(\frac{\pi_{\theta}(a|s)}{\pi_{\theta_k}(a|s)} A^{\pi_{\theta_k}}(s,a), \;\;g(\epsilon, A^{\pi_{\theta_k}}(s,a))\right),\]
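
where g is the standard clipping function from the PPO paper: it bounds the contribution of the likelihood ratio depending on the sign of the advantage,

\[g(\epsilon, A) = \begin{cases} (1+\epsilon) A & \text{if } A \geq 0, \\ (1-\epsilon) A & \text{otherwise.} \end{cases}\]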

There are two components in that loss: in the first part of the minimum operator, we simply compute an importance-weighted version of the REINFORCE loss (that is, a REINFORCE loss corrected for the fact that the current policy configuration lags the one that was used for the data collection). The second part of that minimum operator is a similar loss where we have clipped the ratios when they exceeded or were below a given pair of thresholds.

This loss ensures that whether the advantage is positive or negative, policy updates that would produce significant shifts from the previous configuration are being discouraged.

This tutorial is structured as follows:

  1. First, we will define a set of hyperparameters we will be using for training.

  2. Next, we will focus on creating our environment, or simulator, using TorchRL’s wrappers and transforms.

  3. Next, we will design the policy network and the value model, which is indispensable to the loss function. These modules will be used to configure our loss module.

  4. Next, we will create the replay buffer and data loader.

  5. Finally, we will run our training loop and analyze the results.

Throughout this tutorial, we’ll be using the tensordict library. TensorDict is the lingua franca of TorchRL: it helps us abstract what a module reads and writes, so that we can care less about the specific data description and more about the algorithm itself.
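
As a minimal illustration (not part of the original script), a TensorDict behaves like a dictionary of tensors that share a common leading batch size:

from tensordict import TensorDict
import torch

# a batch of 4 "observations" of size 11, matching this environment's observation size
td = TensorDict({"observation": torch.randn(4, 11)}, batch_size=[4])
print(td["observation"].shape)  # torch.Size([4, 11])
td["action"] = torch.zeros(4, 1)  # new entries must be consistent with the batch size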

import warnings
warnings.filterwarnings("ignore")

from torch import multiprocessing

from collections import defaultdict

import matplotlib.pyplot as plt
import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (
    Compose,
    DoubleToFloat,
    ObservationNorm,
    StepCounter,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import check_env_specs, ExplorationType, set_exploration_type
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from tqdm import tqdm

Define Hyperparameters#

We set the hyperparameters for our algorithm. Depending on the resources available, one may choose to execute the policy on GPU or on another device. The frame_skip parameter controls for how many frames a single action is executed. The rest of the arguments that count frames must be corrected for this value (since one environment step will actually return frame_skip frames).

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
num_cells = 256  # number of cells in each layer i.e. output dim.
lr = 3e-4
max_grad_norm = 1.0

Data collection parameters#

When collecting data, we will be able to choose how big each batch will be by defining a frames_per_batch parameter. We will also define how many frames (that is, the number of interactions with the simulator) we will allow ourselves to use. In general, the goal of an RL algorithm is to learn to solve the task as fast as it can in terms of environment interactions: the lower the total_frames the better.

frames_per_batch = 1000
# For a complete training, bring the number of frames up to 1M
total_frames = 50_000

PPO parameters#

At each data collection (or batch collection) we will run the optimization over a certain number of epochs, each time consuming the entire data we just acquired in a nested training loop. Here, the sub_batch_size is different from the frames_per_batch above: recall that we are working with a “batch of data” coming from our collector, whose size is defined by frames_per_batch, and that we will further split into smaller sub-batches during the inner training loop. The size of these sub-batches is controlled by sub_batch_size.

sub_batch_size = 64  # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 10  # optimization steps per batch of data collected
clip_epsilon = (
    0.2  # clip value for PPO loss: see the equation in the intro for more context.
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4

Define an environment#

In RL, an environment is usually the way we refer to a simulator or a control system. Various libraries provide simulation environments for reinforcement learning, including Gymnasium (previously OpenAI Gym), the DeepMind control suite, and many others. As a general library, TorchRL’s goal is to provide an interchangeable interface to a large panel of RL simulators, allowing you to easily swap one environment for another. For example, creating a wrapped gym environment can be achieved with a few characters:

base_env = GymEnv("InvertedDoublePendulum-v4", device=device)
Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.Users of this version of Gym should be able to simply replace 'import gym' with 'import gymnasium as gym' in the vast majority of cases.See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.

There are a few things to notice in this code: first, we created the environment by calling the GymEnv wrapper. If extra keyword arguments are passed, they will be transmitted to the gym.make method, hence covering the most common environment construction commands. Alternatively, one could also directly create a gym environment using gym.make(env_name, **kwargs) and wrap it in a GymWrapper class.
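
For example, a sketch of that alternative construction (assuming the gymnasium package is installed; variable names here are illustrative):

import gymnasium as gym

from torchrl.envs.libs.gym import GymWrapper

# build the raw simulator first, then hand it to TorchRL's wrapper
raw_env = gym.make("InvertedDoublePendulum-v4")
wrapped_env = GymWrapper(raw_env, device=device)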

Note also the device argument: for gym, this only controls the device where input actions and observed states will be stored, but the execution will always be done on CPU. The reason for this is simply that gym does not support on-device execution, unless specified otherwise. For other libraries, we have control over the execution device and, as much as we can, we try to stay consistent in terms of storing and execution backends.

Transforms#

We will append some transforms to our environments to prepare the data for the policy. In Gym, this is usually achieved via wrappers. TorchRL takes a different approach, more similar to other PyTorch domain libraries, through the use of transforms. To add transforms to an environment, one should simply wrap it in a TransformedEnv instance and append the sequence of transforms to it. The transformed environment will inherit the device and meta-data of the wrapped environment, and transform these depending on the sequence of transforms it contains.

Normalization#

The first transform to append is a normalization transform. As a rule of thumb, it is preferable to have data that loosely matches a unit Gaussian distribution: to obtain this, we will run a certain number of random steps in the environment and compute the summary statistics of these observations.

We’ll append two other transforms: the DoubleToFloat transform will convert double entries to single-precision numbers, ready to be read by the policy. The StepCounter transform will be used to count the steps before the environment is terminated. We will use this as a supplementary measure of performance.

As we will see later, many of TorchRL’s classes rely on TensorDict to communicate. You could think of it as a Python dictionary with some extra tensor features. In practice, this means that many modules we will be working with need to be told what key to read (in_keys) and what key to write (out_keys) in the tensordict they will receive. Usually, if out_keys is omitted, it is assumed that the in_keys entries will be updated in-place. For our transforms, the only entry we are interested in is referred to as "observation" and our transform layers will be told to modify this entry and this entry only:

env = TransformedEnv(
    base_env,
    Compose(
        # normalize observations
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
        StepCounter(),
    ),
)

As you may have noticed, we have created a normalization layer but we did not set its normalization parameters. To do this, ObservationNorm can automatically gather the summary statistics of our environment:

env.transform[0].init_stats(num_iter=1000, reduce_dim=0, cat_dim=0)

The ObservationNorm transform has now been populated with a location and a scale that will be used to normalize the data.

Let us do a little sanity check for the shape of our summary stats:

print("normalization constant shape:",env.transform[0].loc.shape)
normalization constant shape: torch.Size([11])

An environment is not only defined by its simulator and transforms, but also by a series of metadata that describe what can be expected during its execution. For efficiency purposes, TorchRL is quite stringent when it comes to environment specs, but you can easily check that your environment specs are adequate. In our example, the GymWrapper and the GymEnv that inherits from it already take care of setting the proper specs for your environment, so you should not have to care about this.

Nevertheless, let’s see a concrete example using our transformed environment by looking at its specs. There are three specs to look at: observation_spec, which defines what is to be expected when executing an action in the environment; reward_spec, which indicates the reward domain; and finally the input_spec (which contains the action_spec), which represents everything an environment requires to execute a single step.

print("observation_spec:",env.observation_spec)print("reward_spec:",env.reward_spec)print("input_spec:",env.input_spec)print("action_spec (as defined by input_spec):",env.action_spec)
observation_spec: Composite(    observation: UnboundedContinuous(        shape=torch.Size([11]),        space=ContinuousBox(            low=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True),            high=Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, contiguous=True)),        device=cpu,        dtype=torch.float32,        domain=continuous),    step_count: BoundedDiscrete(        shape=torch.Size([1]),        space=ContinuousBox(            low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),            high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),        device=cpu,        dtype=torch.int64,        domain=discrete),    device=cpu,    shape=torch.Size([]),    data_cls=None)reward_spec: UnboundedContinuous(    shape=torch.Size([1]),    space=ContinuousBox(        low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),        high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),    device=cpu,    dtype=torch.float32,    domain=continuous)input_spec: Composite(    full_state_spec: Composite(        step_count: BoundedDiscrete(            shape=torch.Size([1]),            space=ContinuousBox(                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True),                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, contiguous=True)),            device=cpu,            dtype=torch.int64,            domain=discrete),        device=cpu,        shape=torch.Size([]),        data_cls=None),    full_action_spec: Composite(        action: BoundedContinuous(            shape=torch.Size([1]),            space=ContinuousBox(                low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),                high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),            device=cpu,            dtype=torch.float32,            domain=continuous),        device=cpu,        shape=torch.Size([]),        data_cls=None),    device=cpu,    shape=torch.Size([]),    data_cls=None)action_spec (as defined by input_spec): BoundedContinuous(    shape=torch.Size([1]),    space=ContinuousBox(        low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True),        high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)),    device=cpu,    dtype=torch.float32,    domain=continuous)

The check_env_specs() function runs a small rollout and compares its output against the environment specs. If no error is raised, we can be confident that the specs are properly defined:
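
Running it on our transformed environment:

check_env_specs(env)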

2026-02-19 16:53:50,744 [torchrl][INFO]    check_env_specs succeeded! [END]

For fun, let’s see what a simple random rollout looks like. You can call env.rollout(n_steps) and get an overview of what the environment inputs and outputs look like. Actions will automatically be drawn from the action spec domain, so you don’t need to care about designing a random sampler.

Typically, at each step, an RL environment receives an action as input, and outputs an observation, a reward and a done state. The observation may be composite, meaning that it could be composed of more than one tensor. This is not a problem for TorchRL, since the whole set of observations is automatically packed in the output TensorDict. After executing a rollout (for example, a sequence of environment steps and random action generations) over a given number of steps, we will retrieve a TensorDict instance with a shape that matches this trajectory length:

rollout = env.rollout(3)
print("rollout of three steps:", rollout)
print("Shape of the rollout TensorDict:", rollout.batch_size)
rollout of three steps: TensorDict(    fields={        action: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),        done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),        next: TensorDict(            fields={                done: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),                observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float32, is_shared=False),                reward: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),                step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),                terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),                truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},            batch_size=torch.Size([3]),            device=cpu,            is_shared=False),        observation: Tensor(shape=torch.Size([3, 11]), device=cpu, dtype=torch.float32, is_shared=False),        step_count: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.int64, is_shared=False),        terminated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False),        truncated: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.bool, is_shared=False)},    batch_size=torch.Size([3]),    device=cpu,    is_shared=False)Shape of the rollout TensorDict: torch.Size([3])

Our rollout data has a shape of torch.Size([3]), which matches the number of steps we ran it for. The "next" entry points to the data coming after the current step. In most cases, the "next" data at time t matches the data at t+1, but this may not be the case if we are using some specific transformations (for example, multi-step).
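
As a quick sanity check (a sketch, assuming no multi-step or similar transform is in use), the shifted entries of a rollout should coincide:

# the "next" observation at step t should equal the observation at step t + 1
rollout = env.rollout(3)
print(
    torch.allclose(rollout["observation"][1:], rollout["next", "observation"][:-1])
)  # expected: True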

Policy#

PPO utilizes a stochastic policy to handle exploration. This means that our neural network will have to output the parameters of a distribution, rather than a single value corresponding to the action taken.

As the data is continuous, we use a Tanh-Normal distribution to respect the action space boundaries. TorchRL provides such a distribution, and the only thing we need to care about is to build a neural network that outputs the right number of parameters for the policy to work with (a location, or mean, and a scale):

\[f_{\theta}(\text{observation}) = \mu_{\theta}(\text{observation}), \sigma^{+}_{\theta}(\text{observation})\]

The only extra difficulty here is to split our output into two equal parts and map the second to a strictly positive space.

We design the policy in three steps:

  1. Define a neural network D_obs -> 2 * D_action. Indeed, our loc (mu) and scale (sigma) both have dimension D_action.

  2. Append a NormalParamExtractor to extract a location and a scale (for example, splits the input in two equal parts and applies a positive transformation to the scale parameter).

  3. Create a probabilistic TensorDictModule that can generate this distribution and sample from it.

To enable the policy to “talk” with the environment through the tensordict data carrier, we wrap the nn.Module in a TensorDictModule. This class will simply read the in_keys it is provided with and write the outputs in-place at the registered out_keys.
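
Below is a minimal sketch of the actor_net wrapped in the next cell, following steps 1 and 2 above (the exact depth and width of the network are assumptions, not necessarily the full example’s architecture):

actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(2 * env.action_spec.shape[-1], device=device),
    # splits the last layer's output into a loc and a scale, and maps
    # the scale to a strictly positive value
    NormalParamExtractor(),
)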

policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)

We now need to build a distribution out of the location and scale of our normal distribution. To do so, we instruct the ProbabilisticActor class to build a TanhNormal out of the location and scale parameters. We also provide the minimum and maximum values of this distribution, which we gather from the environment specs.

The name of the in_keys (and hence the name of the out_keys from the TensorDictModule above) cannot be set to any value one may like, as the TanhNormal distribution constructor will expect the loc and scale keyword arguments. That being said, ProbabilisticActor also accepts Dict[str, str] typed in_keys where the key-value pair indicates what in_key string should be used for every keyword argument that is to be used.

policy_module = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": env.action_spec.space.low,
        "high": env.action_spec.space.high,
    },
    return_log_prob=True,
    # we'll need the log-prob for the numerator of the importance weights
)

Value network#

The value network is a crucial component of the PPO algorithm, even though it won’t be used at inference time. This module will read the observations and return an estimation of the discounted return for the following trajectory. This allows us to amortize learning by relying on some utility estimation that is learned on the fly during training. Our value network shares the same structure as the policy, but for simplicity we assign it its own set of parameters.
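
A minimal sketch of such a value network, wrapped in a ValueOperator so that it writes a "state_value" entry (the exact architecture is an assumption):

value_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(1, device=device),  # a single scalar value estimate per observation
)

value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"],
)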

Let’s try our policy and value modules. As we said earlier, the usage of TensorDictModule makes it possible to directly read the output of the environment to run these modules, as they know what information to read and where to write it:

print("Running policy:",policy_module(env.reset()))print("Running value:",value_module(env.reset()))
Running policy: TensorDict(    fields={        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),        action_log_prob: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),        loc: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),        observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),        scale: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},    batch_size=torch.Size([]),    device=cpu,    is_shared=False)Running value: TensorDict(    fields={        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),        observation: Tensor(shape=torch.Size([11]), device=cpu, dtype=torch.float32, is_shared=False),        state_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),        step_count: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},    batch_size=torch.Size([]),    device=cpu,    is_shared=False)

Data collector#

TorchRL provides a set of DataCollector classes. Briefly, these classes execute three operations: reset an environment, compute an action given the latest observation, execute a step in the environment, and repeat the last two steps until the environment signals a stop (or reaches a done state).

They allow you to control how many frames to collect at each iteration (through the frames_per_batch parameter), when to reset the environment (through the max_frames_per_traj argument), on which device the policy should be executed, etc. They are also designed to work efficiently with batched and multiprocessed environments.

The simplest data collector is the SyncDataCollector: it is an iterator that you can use to get batches of data of a given length, and that will stop once a total number of frames (total_frames) has been collected. Other data collectors (MultiSyncDataCollector and MultiaSyncDataCollector) will execute the same operations in a synchronous and an asynchronous manner, respectively, over a set of multiprocessed workers.

As for the policy and environment before, the data collector will return TensorDict instances with a total number of elements that will match frames_per_batch. Using TensorDict to pass data to the training loop allows you to write data loading pipelines that are 100% oblivious to the actual specificities of the rollout content.

collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)

Replay buffer#

Replay buffers are a common building block of off-policy RL algorithms. In on-policy contexts, a replay buffer is refilled every time a batch of data is collected, and its data is repeatedly consumed for a certain number of epochs.

TorchRL’s replay buffers are built using a common container, ReplayBuffer, which takes as argument the components of the buffer: a storage, a writer, a sampler and possibly some transforms. Only the storage (which indicates the replay buffer capacity) is mandatory. We also specify a sampler without repetition to avoid sampling the same item multiple times in one epoch. Using a replay buffer for PPO is not mandatory and we could simply sample the sub-batches from the collected batch, but using these classes makes it easy for us to build the inner training loop in a reproducible way.

replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)

Loss function#

The PPO loss can be directly imported from TorchRL for convenience using the ClipPPOLoss class. This is the easiest way of utilizing PPO: it hides away the mathematical operations of PPO and the control flow that goes with it.

PPO requires some “advantage estimation” to be computed. In short, an advantage is a value that reflects an expectation over the return while dealing with the bias/variance tradeoff. To compute the advantage, one just needs to (1) build the advantage module, which utilizes our value operator, and (2) pass each batch of data through it before each epoch. The GAE module will update the input tensordict with new "advantage" and "value_target" entries. The "value_target" is a gradient-free tensor that represents the empirical value that the value network should output for the input observation. Both of these will be used by ClipPPOLoss to return the policy and value losses.

advantage_module = GAE(
    gamma=gamma,
    lmbda=lmbda,
    value_network=value_module,
    average_gae=True,
    device=device,
)

loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    # these keys match by default but we set this for completeness
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)
/usr/local/lib/python3.10/dist-packages/torchrl/objectives/ppo.py:445: DeprecationWarning: 'critic_coef' is deprecated and will be removed in torchrl v0.11. Please use 'critic_coeff' instead.
/usr/local/lib/python3.10/dist-packages/torchrl/objectives/ppo.py:511: DeprecationWarning: 'entropy_coef' is deprecated and will be removed in torchrl v0.11. Please use 'entropy_coeff' instead.

Training loop#

We now have all the pieces needed to code our training loop. The steps include:

  • Collect data

    • Compute advantage

      • Loop over the collected data to compute loss values

      • Back propagate

      • Optimize

      • Repeat

    • Repeat

  • Repeat

logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""

# We iterate over the collector until it reaches the total number of frames it was
# designed to collect:
for i, tensordict_data in enumerate(collector):
    # we now have a batch of data to work with. Let's learn something from it.
    for _ in range(num_epochs):
        # We'll need an "advantage" signal to make PPO work.
        # We re-compute it at each epoch as its value depends on the value
        # network which is updated in the inner loop.
        advantage_module(tensordict_data)
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu())
        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata.to(device))
            loss_value = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )

            # Optimization: backward, grad clipping and optimization step
            loss_value.backward()
            # this is not strictly mandatory but it's good practice to keep
            # your gradient norm bounded
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    pbar.update(tensordict_data.numel())
    cum_reward_str = (
        f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
    )
    logs["step_count"].append(tensordict_data["step_count"].max().item())
    stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
    if i % 10 == 0:
        # We evaluate the policy once every 10 batches of data.
        # Evaluation is rather simple: execute the policy without exploration
        # (take the expected value of the action distribution) for a given
        # number of steps (1000, which is our ``env`` horizon).
        # The ``rollout`` method of the ``env`` can take a policy as argument:
        # it will then execute this policy at each step.
        with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
            # execute a rollout with the trained policy
            eval_rollout = env.rollout(1000, policy_module)
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(
                eval_rollout["next", "reward"].sum().item()
            )
            logs["eval step_count"].append(eval_rollout["step_count"].max().item())
            eval_str = (
                f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                f"(init: {logs['eval reward (sum)'][0]: 4.4f}), "
                f"eval step-count: {logs['eval step_count'][-1]}"
            )
            del eval_rollout
    pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))

    # We're also using a learning rate scheduler. Like the gradient clipping,
    # this is a nice-to-have but nothing necessary for PPO to work.
    scheduler.step()
  0%|          | 0/50000 [00:00<?, ?it/s]
  2%|▏         | 1000/50000 [00:03<02:40, 306.09it/s]
eval cumulative reward:  110.5815 (init:  110.5815), eval step-count: 11, average reward= 9.0894 (init= 9.0894), step count (max): 10, lr policy:  0.0003:   2%|▏         | 1000/50000 [00:03<02:40, 306.09it/s]
[...]
eval cumulative reward:  596.4940 (init:  110.5815), eval step-count: 63, average reward= 9.3117 (init= 9.0894), step count (max): 143, lr policy:  0.0000: 100%|██████████| 50000/50000 [02:21<00:00, 352.64it/s]

Results#

Before the 1M step cap is reached, the algorithm should have reached a max step count of 1000 steps, which is the maximum number of steps before the trajectory is truncated.

plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.plot(logs["reward"])
plt.title("training rewards (average)")
plt.subplot(2, 2, 2)
plt.plot(logs["step_count"])
plt.title("Max step count (training)")
plt.subplot(2, 2, 3)
plt.plot(logs["eval reward (sum)"])
plt.title("Return (test)")
plt.subplot(2, 2, 4)
plt.plot(logs["eval step_count"])
plt.title("Max step count (test)")
plt.show()
Figure: training rewards (average), Max step count (training), Return (test), Max step count (test)

Conclusion and next steps#

In this tutorial, we have learned:

  1. How to create and customize an environment with torchrl;

  2. How to write a model and a loss function;

  3. How to set up a typical training loop.

If you want to experiment with this tutorial a bit more, you can apply the following modifications:

  • From an efficiency perspective, we could run several simulations in parallel to speed up data collection. Check ParallelEnv for further information (see also the sketch after this list).

  • From a logging perspective, one could add a torchrl.record.VideoRecorder transform to the environment after asking for rendering to get a visual rendering of the inverted pendulum in action. Check torchrl.record to know more.
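
For the first point, here is a minimal sketch of parallel environment creation, assuming four worker processes (the variable name is illustrative):

from torchrl.envs import ParallelEnv

# run four copies of the simulator in separate worker processes
parallel_base_env = ParallelEnv(4, lambda: GymEnv("InvertedDoublePendulum-v4"))

The resulting environment can then be wrapped in a TransformedEnv just like base_env above.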

Total running time of the script: (2 minutes 23.584 seconds)