Note
Go to the end to download the full example code.
TorchRL objectives: Coding a DDPG loss#
Created On: Aug 14, 2023 | Last Updated: Mar 20, 2025 | Last Verified: Not Verified
Author: Vincent Moens
Overview#
TorchRL separates the training of RL algorithms into various pieces that will be assembled in your training script: the environment, the data collection and storage, the model and, finally, the loss function.
TorchRL losses (or “objectives”) are stateful objects that contain the trainable parameters (policy and value models). This tutorial will guide you through the steps to code a loss from the ground up using TorchRL.
To this aim, we will be focusing on DDPG, which is a relatively straightforward algorithm to code. Deep Deterministic Policy Gradient (DDPG) is a simple continuous control algorithm. It consists of learning a parametric value function for an action-observation pair, and then learning a policy that outputs actions that maximize this value function given a certain observation.
What you will learn:
how to write a loss module and customize its value estimator;
how to build an environment in TorchRL, including transforms (for example, data normalization) and parallel execution;
how to design a policy and value network;
how to collect data from your environment efficiently and store them in a replay buffer;
how to store trajectories (and not transitions) in your replay buffer;
how to evaluate your model.
Prerequisites#
This tutorial assumes that you have completed the PPO tutorial, which gives an overview of the TorchRL components and dependencies, such as tensordict.TensorDict and tensordict.nn.TensorDictModule, although it should be sufficiently transparent to be understood without a deep understanding of these classes.
Note
We do not aim to give a SOTA implementation of the algorithm, but rather to provide a high-level illustration of TorchRL’s loss implementations and the library features that are to be used in the context of this algorithm.
Imports and setup#
%%bash
pip3 install torchrl mujoco glfw
import multiprocessing

import torch
import tqdm
We will execute the policy on CUDA if available
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
collector_device = torch.device("cpu")  # Change the device to ``cuda`` to use CUDA
TorchRL LossModule#
TorchRL provides a series of losses to use in your training scripts.The aim is to have losses that are easily reusable/swappable and that havea simple signature.
The main characteristics of TorchRL losses are:
They are stateful objects: they contain a copy of the trainable parameters such that loss_module.parameters() gives whatever is needed to train the algorithm.

They follow the TensorDict convention: the torch.nn.Module.forward() method will receive a TensorDict as input that contains all the necessary information to return a loss value.

>>> data = replay_buffer.sample()
>>> loss_dict = loss_module(data)

They output a tensordict.TensorDict instance with the loss values written under a "loss_<smth>" entry, where smth is a string describing the loss. Additional keys in the TensorDict may be useful metrics to log during training time.

Note

The reason we return independent losses is to let the user use a different optimizer for different sets of parameters, for instance. Summing the losses can be simply done via

>>> loss_val = sum(loss for key, loss in loss_dict.items() if key.startswith("loss_"))
The __init__ method#
The parent class of all losses is LossModule. As with many other components of the library, its forward() method expects as input a tensordict.TensorDict instance sampled from an experience replay buffer, or any similar data structure. Using this format makes it possible to re-use the module across modalities, or in complex settings where the model needs to read multiple entries, for instance. In other words, it allows us to code a loss module that is oblivious to the data type that is being given to it and that focuses on running the elementary steps of the loss function and only those.
To keep the tutorial as didactic as we can, we’ll be displaying each methodof the class independently and we’ll be populating the class at a laterstage.
Let us start with the __init__() method. DDPG aims at solving a control task with a simple strategy: training a policy to output actions that maximize the value predicted by a value network. Hence, our loss module needs to receive two networks in its constructor: an actor and a value network. We expect both of these to be TensorDict-compatible objects, such as tensordict.nn.TensorDictModule. Our loss function will need to compute a target value and fit the value network to it, and generate an action and fit the policy such that its value estimate is maximized.
The crucial step of the LossModule.__init__() method is the call to convert_to_functional(). This method will extract the parameters from the module and convert it to a functional module. Strictly speaking, this is not necessary and one may perfectly code all the losses without it. However, we encourage its usage for the following reason.
The reason TorchRL does this is that RL algorithms often execute the same model with different sets of parameters, called “trainable” and “target” parameters. The “trainable” parameters are those that the optimizer needs to fit. The “target” parameters are usually a copy of the former with some time lag (absolute or diluted through a moving average). These target parameters are used to compute the value associated with the next observation. One of the advantages of using a set of target parameters for the value model that does not match exactly the current configuration is that they provide a pessimistic bound on the value function being computed. Pay attention to the create_target_params keyword argument below: this argument tells the convert_to_functional() method to create a set of target parameters in the loss module to be used for target value computation. If this is set to False (see the actor network, for instance), the target_actor_network_params attribute will still be accessible, but this will just return a detached version of the actor parameters.
Later, we will see how the target parameters should be updated in TorchRL.
from tensordict.nn import TensorDictModule, TensorDictSequential


def _init(
    self,
    actor_network: TensorDictModule,
    value_network: TensorDictModule,
) -> None:
    super(type(self), self).__init__()

    self.convert_to_functional(
        actor_network,
        "actor_network",
        create_target_params=True,
    )
    self.convert_to_functional(
        value_network,
        "value_network",
        create_target_params=True,
        compare_against=list(actor_network.parameters()),
    )
    self.actor_in_keys = actor_network.in_keys

    # Since the value we'll be using is based on the actor and value network,
    # we put them together in a single actor-critic container.
    actor_critic = ActorCriticWrapper(actor_network, value_network)
    self.actor_critic = actor_critic
    self.loss_function = "l2"
The value estimator loss method#
In many RL algorithms, the value network (or Q-value network) is trained based on an empirical value estimate. This can be bootstrapped (TD(0), low variance, high bias), meaning that the target value is obtained using the next reward and nothing else, or a Monte-Carlo estimate can be obtained (TD(1)), in which case the whole sequence of upcoming rewards will be used (high variance, low bias). An intermediate estimator (TD(\(\lambda\))) can also be used to trade off bias and variance. TorchRL makes it easy to use one or the other estimator via the ValueEstimators Enum class, which contains pointers to all the value estimators implemented. Let us define the default value function here. We will take the simplest version (TD(0)), and show later on how this can be changed.
from torchrl.objectives.utils import ValueEstimators

default_value_estimator = ValueEstimators.TD0
We also need to give some instructions to DDPG on how to build the valueestimator, depending on the user query. Depending on the estimator provided,we will build the corresponding module to be used at train time:
from torchrl.objectives.utils import default_value_kwargs
from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator


def make_value_estimator(self, value_type: ValueEstimators, **hyperparams):
    hp = dict(default_value_kwargs(value_type))
    if hasattr(self, "gamma"):
        hp["gamma"] = self.gamma
    hp.update(hyperparams)
    value_key = "state_action_value"
    if value_type == ValueEstimators.TD1:
        self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp)
    elif value_type == ValueEstimators.TD0:
        self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp)
    elif value_type == ValueEstimators.GAE:
        raise NotImplementedError(
            f"Value type {value_type} is not implemented for loss {type(self)}."
        )
    elif value_type == ValueEstimators.TDLambda:
        self._value_estimator = TDLambdaEstimator(value_network=self.actor_critic, **hp)
    else:
        raise NotImplementedError(f"Unknown value type {value_type}")
    self._value_estimator.set_keys(value=value_key)
The make_value_estimator method can be called, but does not have to be: if it isn't, the LossModule will query this method with its default estimator.
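For instance, once a loss module has been instantiated (as will be done later in this tutorial), switching estimators could look like the sketch below; the gamma and lmbda values here are illustrative only.

# Hypothetical usage sketch: override the default TD(0) estimator after construction
loss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=0.99, lmbda=0.95)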
The actor loss method#
The central piece of an RL algorithm is the training loss for the actor.In the case of DDPG, this function is quite simple: we just need to computethe value associated with an action computed using the policy and optimizethe actor weights to maximize this value.
When computing this value, we must make sure to take the value parameters out of the graph, otherwise the actor and value losses will be mixed up. For this, the hold_out_params() function can be used.
def _loss_actor(
    self,
    tensordict,
) -> torch.Tensor:
    td_copy = tensordict.select(*self.actor_in_keys)
    # Get an action from the actor network: since we made it functional, we need to pass the params
    with self.actor_network_params.to_module(self.actor_network):
        td_copy = self.actor_network(td_copy)
    # get the value associated with that action
    with self.value_network_params.detach().to_module(self.value_network):
        td_copy = self.value_network(td_copy)
    return -td_copy.get("state_action_value")
The value loss method#
We now need to optimize our value network parameters.To do this, we will rely on the value estimator of our class:
from torchrl.objectives.utils import distance_loss


def _loss_value(
    self,
    tensordict,
):
    td_copy = tensordict.clone()

    # V(s, a)
    with self.value_network_params.to_module(self.value_network):
        self.value_network(td_copy)
    pred_val = td_copy.get("state_action_value").squeeze(-1)

    # we manually reconstruct the parameters of the actor-critic, where the first
    # set of parameters belongs to the actor and the second to the value function.
    target_params = TensorDict(
        {
            "module": {
                "0": self.target_actor_network_params,
                "1": self.target_value_network_params,
            }
        },
        batch_size=self.target_actor_network_params.batch_size,
        device=self.target_actor_network_params.device,
    )
    with target_params.to_module(self.actor_critic):
        target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)

    # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function`
    loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function)
    td_error = (pred_val - target_value).pow(2)

    return loss_value, td_error, pred_val, target_value
Putting things together in a forward call#
The only missing piece is the forward method, which will glue together the value and actor losses, collect the cost values and write them in a TensorDict delivered to the user.
from tensordict import TensorDict, TensorDictBase


def _forward(self, input_tensordict: TensorDictBase) -> TensorDict:
    loss_value, td_error, pred_val, target_value = self.loss_value(
        input_tensordict,
    )
    td_error = td_error.detach()
    td_error = td_error.unsqueeze(input_tensordict.ndimension())
    if input_tensordict.device is not None:
        td_error = td_error.to(input_tensordict.device)
    input_tensordict.set(
        "td_error",
        td_error,
        inplace=True,
    )
    loss_actor = self.loss_actor(input_tensordict)
    return TensorDict(
        source={
            "loss_actor": loss_actor.mean(),
            "loss_value": loss_value.mean(),
            "pred_value": pred_val.mean().detach(),
            "target_value": target_value.mean().detach(),
            "pred_value_max": pred_val.max().detach(),
            "target_value_max": target_value.max().detach(),
        },
        batch_size=[],
    )


from torchrl.objectives import LossModule


class DDPGLoss(LossModule):
    default_value_estimator = default_value_estimator
    make_value_estimator = make_value_estimator

    __init__ = _init
    forward = _forward
    loss_value = _loss_value
    loss_actor = _loss_actor
Now that we have our loss, we can use it to train a policy to solve acontrol task.
Environment#
In most algorithms, the first thing that needs to be taken care of is theconstruction of the environment as it conditions the remainder of thetraining script.
For this example, we will be using the"cheetah" task. The goal is to makea half-cheetah run as fast as possible.
In TorchRL, one can create such a task by relying on dm_control or gym:
env=GymEnv("HalfCheetah-v4")
or
env = DMControlEnv("cheetah", "run")
By default, these environments disable rendering. Training from states is usually easier than training from images. To keep things simple, we focus on learning from states only. To pass the pixels to the tensordicts that are collected by env.step(), simply pass the from_pixels=True argument to the constructor:
env=GymEnv("HalfCheetah-v4",from_pixels=True,pixels_only=True)
We write a make_env() helper function that will create an environment with either one of the two backends considered above (dm-control or gym).
from torchrl.envs.libs.dm_control import DMControlEnv
from torchrl.envs.libs.gym import GymEnv

env_library = None
env_name = None


def make_env(from_pixels=False):
    """Create a base ``env``."""
    global env_library
    global env_name

    if backend == "dm_control":
        env_name = "cheetah"
        env_task = "run"
        env_args = (env_name, env_task)
        env_library = DMControlEnv
    elif backend == "gym":
        env_name = "HalfCheetah-v4"
        env_args = (env_name,)
        env_library = GymEnv
    else:
        raise NotImplementedError

    env_kwargs = {
        "device": device,
        "from_pixels": from_pixels,
        "pixels_only": from_pixels,
        "frame_skip": 2,
    }
    env = env_library(*env_args, **env_kwargs)
    return env
Transforms#
Now that we have a base environment, we may want to modify its representation to make it more policy-friendly. In TorchRL, transforms are appended to the base environment in a specialized torchrl.envs.TransformedEnv class.
It is common in DDPG to rescale the reward using some heuristic value. Wewill multiply the reward by 5 in this example.
If we are using dm_control, it is also important to build an interface between the simulator, which works with double precision numbers, and our script, which presumably uses single precision ones. This transformation goes both ways: when calling env.step(), our actions will need to be represented in double precision, and the output will need to be transformed to single precision. The DoubleToFloat transform does exactly this: the in_keys list refers to the keys that will need to be transformed from double to float, while the in_keys_inv refers to those that need to be transformed to double before being passed to the environment.

We concatenate the state keys together using the CatTensors transform.

Finally, we also leave the possibility of normalizing the states: we will take care of computing the normalizing constants later on.
from torchrl.envs import (
    CatTensors,
    DoubleToFloat,
    EnvCreator,
    InitTracker,
    ObservationNorm,
    ParallelEnv,
    RewardScaling,
    StepCounter,
    TransformedEnv,
)


def make_transformed_env(
    env,
):
    """Apply transforms to the ``env`` (such as reward scaling and state normalization)."""

    env = TransformedEnv(env)

    # we append transforms one by one, although we might as well create the
    # transformed environment using the `env = TransformedEnv(base_env, transforms)`
    # syntax.
    env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling))

    # We concatenate all states into a single "observation_vector"
    # even if there is a single tensor, it'll be renamed in "observation_vector".
    # This facilitates the downstream operations as we know the name of the
    # output tensor.
    # In some environments (not half-cheetah), there may be more than one
    # observation vector: in this case this code snippet will concatenate them
    # all.
    selected_keys = list(env.observation_spec.keys())
    out_key = "observation_vector"
    env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key))

    # we normalize the states, but for now let's just instantiate a stateless
    # version of the transform
    env.append_transform(ObservationNorm(in_keys=[out_key], standard_normal=True))

    env.append_transform(DoubleToFloat())

    env.append_transform(StepCounter(max_frames_per_traj))

    # We need a marker for the start of trajectories for our Ornstein-Uhlenbeck (OU)
    # exploration:
    env.append_transform(InitTracker())

    return env
Parallel execution#
The following helper function allows us to run environments in parallel. Running environments in parallel can significantly speed up the collection throughput. When using transformed environments, we need to choose whether we want to execute the transform individually for each environment, or centralize the data and transform it in batch. Both approaches are easy to code:
env = ParallelEnv(
    lambda: TransformedEnv(GymEnv("HalfCheetah-v4"), transforms),
    num_workers=4,
)
env = TransformedEnv(
    ParallelEnv(lambda: GymEnv("HalfCheetah-v4"), num_workers=4),
    transforms,
)
To leverage the vectorization capabilities of PyTorch, we adopt the second method, transforming the batched data centrally:
def parallel_env_constructor(
    env_per_collector,
    transform_state_dict,
):
    if env_per_collector == 1:

        def make_t_env():
            env = make_transformed_env(make_env())
            env.transform[2].init_stats(3)
            env.transform[2].loc.copy_(transform_state_dict["loc"])
            env.transform[2].scale.copy_(transform_state_dict["scale"])
            return env

        env_creator = EnvCreator(make_t_env)
        return env_creator

    parallel_env = ParallelEnv(
        num_workers=env_per_collector,
        create_env_fn=EnvCreator(lambda: make_env()),
        create_env_kwargs=None,
        pin_memory=False,
    )
    env = make_transformed_env(parallel_env)
    # we call `init_stats` for a limited number of steps, just to instantiate
    # the lazy buffers.
    env.transform[2].init_stats(3, cat_dim=1, reduce_dim=[0, 1])
    env.transform[2].load_state_dict(transform_state_dict)
    return env


# The backend can be ``gym`` or ``dm_control``
backend = "gym"
Note
frame_skip batches multiple steps together with a single action. If > 1, the other frame counts (for example, frames_per_batch, total_frames) need to be adjusted to have a consistent total number of frames collected across experiments. This is important as raising the frame-skip but keeping the total number of frames unchanged may seem like cheating: all things compared, a dataset of 10M elements collected with a frame-skip of 2 and another with a frame-skip of 1 actually have a ratio of interactions with the environment of 2:1! In a nutshell, one should be cautious about the frame count of a training script when dealing with frame skipping, as this may lead to biased comparisons between training strategies.
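As a rough illustration (a sketch with hypothetical variable names, not part of the training script), one way to keep the simulator-interaction budget comparable when changing the frame skip is to scale the frame counts by it:

# Sketch: keep the number of simulator interactions constant across frame-skip settings
base_interaction_budget = 1_000_000  # budget expressed in simulator steps
frame_skip = 2

# each collected frame triggers `frame_skip` simulator steps,
# so the number of frames to collect shrinks accordingly
adjusted_total_frames = base_interaction_budget // frame_skip
adjusted_frames_per_batch = 1_000 // frame_skip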
Scaling the reward helps us control the signal magnitude for more efficient learning.
reward_scaling = 5.0
We also define when a trajectory will be truncated. A thousand steps (500 ifframe-skip = 2) is a good number to use for the cheetah task:
max_frames_per_traj = 500
Normalization of the observations#
To compute the normalizing statistics, we run an arbitrary number of random steps in the environment and compute the mean and standard deviation of the collected observations. The ObservationNorm.init_stats() method can be used for this purpose. To get the summary statistics, we create a dummy environment, run it for a given number of steps, collect the data and compute its summary statistics.
def get_env_stats():
    """Gets the stats of an environment."""
    proof_env = make_transformed_env(make_env())
    t = proof_env.transform[2]
    t.init_stats(init_env_steps)
    transform_state_dict = t.state_dict()
    proof_env.close()
    return transform_state_dict
Normalization stats#
Number of random steps used for the stats computation with ObservationNorm
init_env_steps = 5000

transform_state_dict = get_env_stats()
Number of environments in each data collector
env_per_collector = 4
We pass the stats computed earlier to normalize the output of ourenvironment:
parallel_env = parallel_env_constructor(
    env_per_collector=env_per_collector,
    transform_state_dict=transform_state_dict,
)

from torchrl.data import CompositeSpec
Building the model#
We now turn to the setup of the model. As we have seen, DDPG requires avalue network, trained to estimate the value of a state-action pair, and aparametric actor that learns how to select actions that maximize this value.
Recall that building a TorchRL module requires two steps:
writing the torch.nn.Module that will be used as the network;

wrapping the network in a tensordict.nn.TensorDictModule, where the data flow is handled by specifying the input and output keys.

In more complex scenarios, tensordict.nn.TensorDictSequential can also be used. A minimal sketch of this two-step recipe is shown below.
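The module, keys and shapes in this sketch are placeholders, not the networks used in this tutorial:

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

# 1. a plain PyTorch network
toy_net = nn.Sequential(nn.LazyLinear(64), nn.Tanh(), nn.LazyLinear(4))

# 2. wrap it so that it reads and writes entries of a TensorDict
toy_module = TensorDictModule(toy_net, in_keys=["observation"], out_keys=["action_value"])

data = TensorDict({"observation": torch.randn(8, 10)}, batch_size=[8])
toy_module(data)  # the output is written under data["action_value"]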
The Q-value network is wrapped in a ValueOperator that automatically sets the out_keys to "state_action_value" for Q-value networks and "state_value" for other value networks.
TorchRL provides a built-in version of the DDPG networks as presented in the original paper. These can be found under DdpgMlpActor and DdpgMlpQNet.
Since we use lazy modules, it is necessary to materialize the lazy modules before being able to move the policy from device to device and perform other operations. Hence, it is good practice to run the modules with a small sample of data. For this purpose, we generate fake data from the environment specs.
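As a side note, random data matching the environment specs can also be drawn directly from the specs themselves. A small sketch, assuming a transformed environment named proof_environment as in the helper below:

# Sketch: draw random, spec-shaped data that could be used to materialize lazy layers
fake_obs = proof_environment.observation_spec.rand()  # TensorDict of random observations
fake_action = proof_environment.action_spec.rand()    # random action with the right shape and bounds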
from torchrl.modules import (
    ActorCriticWrapper,
    DdpgMlpActor,
    DdpgMlpQNet,
    OrnsteinUhlenbeckProcessModule,
    ProbabilisticActor,
    TanhDelta,
    ValueOperator,
)


def make_ddpg_actor(
    transform_state_dict,
    device="cpu",
):
    proof_environment = make_transformed_env(make_env())
    proof_environment.transform[2].init_stats(3)
    proof_environment.transform[2].load_state_dict(transform_state_dict)

    out_features = proof_environment.action_spec.shape[-1]

    actor_net = DdpgMlpActor(
        action_dim=out_features,
    )

    in_keys = ["observation_vector"]
    out_keys = ["param"]

    actor = TensorDictModule(
        actor_net,
        in_keys=in_keys,
        out_keys=out_keys,
    )

    actor = ProbabilisticActor(
        actor,
        distribution_class=TanhDelta,
        in_keys=["param"],
        spec=CompositeSpec(action=proof_environment.action_spec),
    ).to(device)

    q_net = DdpgMlpQNet()

    in_keys = in_keys + ["action"]
    qnet = ValueOperator(
        in_keys=in_keys,
        module=q_net,
    ).to(device)

    # initialize lazy modules
    qnet(actor(proof_environment.reset().to(device)))
    return actor, qnet


actor, qnet = make_ddpg_actor(
    transform_state_dict=transform_state_dict,
    device=device,
)
Exploration#
The policy is passed into an OrnsteinUhlenbeckProcessModule exploration module, as suggested in the original paper. Let's define the number of frames before OU noise reaches its minimum value:
annealing_frames = 1_000_000

actor_model_explore = TensorDictSequential(
    actor,
    OrnsteinUhlenbeckProcessModule(
        spec=actor.spec.clone(),
        annealing_num_steps=annealing_frames,
    ).to(device),
)
if device == torch.device("cpu"):
    actor_model_explore.share_memory()
Data collector#
TorchRL provides specialized classes to help you collect data by executing the policy in the environment. These “data collectors” iteratively compute the action to be executed at a given time, then execute a step in the environment and reset it when required. Data collectors are designed to help developers have tight control over the number of frames per batch of data, over the (a)sync nature of this collection, and over the resources allocated to the data collection (for example GPU, number of workers, and so on).
Here we will use SyncDataCollector, a simple, single-process data collector. TorchRL offers other collectors, such as MultiaSyncDataCollector, which executes the rollouts in an asynchronous manner (for example, data will be collected while the policy is being optimized, thereby decoupling the training and data collection).
The parameters to specify are:
an environment factory or an environment,
the policy,
the total number of frames before the collector is considered empty,
the maximum number of frames per trajectory (useful for non-terminating environments, like dm_control ones).

Note

The max_frames_per_traj passed to the collector will have the effect of registering a new StepCounter transform with the environment used for inference. We can achieve the same result manually, as we do in this script.
One should also pass:
the number of frames in each batch collected,
the number of random steps executed independently from the policy,
the devices used for policy execution
the devices used to store data before the data is passed to the mainprocess.
The total frames we will use during training should be around 1M.
total_frames = 10_000  # 1_000_000
The number of frames returned by the collector at each iteration of the outerloop is equal to the length of each sub-trajectories times the number ofenvironments run in parallel in each collector.
In other words, we expect batches from the collector to have a shape [env_per_collector, traj_len] where traj_len = frames_per_batch / env_per_collector:
traj_len = 200
frames_per_batch = env_per_collector * traj_len
init_random_frames = 5000
num_collectors = 2

from torchrl.collectors import SyncDataCollector
from torchrl.envs import ExplorationType

collector = SyncDataCollector(
    parallel_env,
    policy=actor_model_explore,
    total_frames=total_frames,
    frames_per_batch=frames_per_batch,
    init_random_frames=init_random_frames,
    reset_at_each_iter=False,
    split_trajs=False,
    device=collector_device,
    exploration_type=ExplorationType.RANDOM,
)
Evaluator: building your recorder object#
As the training data is obtained using some exploration strategy, the true performance of our algorithm needs to be assessed in deterministic mode. We do this using a dedicated class, Recorder, which executes the policy in the environment at a given frequency and returns some statistics obtained from these simulations.
The following helper function builds this object:
from torchrl.trainers import Recorder


def make_recorder(actor_model_explore, transform_state_dict, record_interval):
    base_env = make_env()
    environment = make_transformed_env(base_env)
    environment.transform[2].init_stats(
        3
    )  # must be instantiated to load the state dict
    environment.transform[2].load_state_dict(transform_state_dict)

    recorder_obj = Recorder(
        record_frames=1000,
        policy_exploration=actor_model_explore,
        environment=environment,
        exploration_type=ExplorationType.DETERMINISTIC,
        record_interval=record_interval,
    )
    return recorder_obj
We will be recording the performance every 10 batches collected.
record_interval = 10

recorder = make_recorder(
    actor_model_explore, transform_state_dict, record_interval=record_interval
)

from torchrl.data.replay_buffers import (
    LazyMemmapStorage,
    PrioritizedSampler,
    RandomSampler,
    TensorDictReplayBuffer,
)
Replay buffer#
Replay buffers come in two flavors: prioritized (where some error signalis used to give a higher likelihood of sampling to some items than others)and regular, circular experience replay.
TorchRL replay buffers are composable: one can pick the storage, sampling and writing strategies. It is also possible to store tensors on physical memory using a memory-mapped array. The following function takes care of creating the replay buffer with the desired hyperparameters:
from torchrl.envs import RandomCropTensorDict


def make_replay_buffer(buffer_size, batch_size, random_crop_len, prefetch=3, prb=False):
    if prb:
        sampler = PrioritizedSampler(
            max_capacity=buffer_size,
            alpha=0.7,
            beta=0.5,
        )
    else:
        sampler = RandomSampler()
    replay_buffer = TensorDictReplayBuffer(
        storage=LazyMemmapStorage(
            buffer_size,
            scratch_dir=buffer_scratch_dir,
        ),
        batch_size=batch_size,
        sampler=sampler,
        pin_memory=False,
        prefetch=prefetch,
        transform=RandomCropTensorDict(random_crop_len, sample_dim=1),
    )
    return replay_buffer
We’ll store the replay buffer in a temporary directory on disk
import tempfile

tmpdir = tempfile.TemporaryDirectory()
buffer_scratch_dir = tmpdir.name
Replay buffer storage and batch size#
TorchRL replay buffers count the number of elements along the first dimension. Since we’ll be feeding trajectories to our buffer, we need to adapt the buffer size by dividing it by the length of the sub-trajectories yielded by our data collector. Regarding the batch size, our sampling strategy will consist in sampling trajectories of length traj_len=200 before selecting sub-trajectories of length random_crop_len=25 on which the loss will be computed. This strategy balances the choice of storing whole trajectories of a certain length with the need for providing samples with sufficient heterogeneity to our loss. The following figure shows the dataflow from a collector that gets 8 frames in each batch with 2 environments run in parallel, feeds them to a replay buffer that contains 1000 trajectories and samples sub-trajectories of 2 time steps each.
(Figure: dataflow from the data collector to the replay buffer and the sampled sub-trajectories.)
Let’s start with the number of frames stored in the buffer
def ceil_div(x, y):
    return -x // (-y)


buffer_size = 1_000_000
buffer_size = ceil_div(buffer_size, traj_len)
Prioritized replay buffer is disabled by default
prb = False
We also need to define how many updates we’ll be doing per batch of data collected. This is known as the update-to-data or UTD ratio:
update_to_data = 64
We’ll be feeding the loss with trajectories of length 25:
random_crop_len = 25
In the original paper, the authors perform one update with a batch of 64 elements for each frame collected. Here, we reproduce the same ratio while performing several updates at each batch collection. We adapt our batch size to achieve the same update-per-frame ratio:
batch_size = ceil_div(64 * frames_per_batch, update_to_data * random_crop_len)

replay_buffer = make_replay_buffer(
    buffer_size=buffer_size,
    batch_size=batch_size,
    random_crop_len=random_crop_len,
    prefetch=3,
    prb=prb,
)
Loss module construction#
We build our loss module with the actor and qnet we've just created. Because we have target parameters to update, we _must_ create a target network updater.
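A minimal instantiation sketch follows; the gamma, lmbda and tau values below are illustrative placeholders for the hyperparameters used by the next cells.

gamma = 0.99   # discount factor (illustrative value)
lmbda = 0.9    # lambda for the TD(lambda) estimator (illustrative value)
tau = 0.001    # decay factor for the soft target-network update (illustrative value)

loss_module = DDPGLoss(actor, qnet)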
Let's use the TD(lambda) estimator!
loss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=gamma, lmbda=lmbda, device=device)
Note
Off-policy learning usually dictates a TD(0) estimator. Here, we use a TD(\(\lambda\)) estimator, which will introduce some bias as the trajectory that follows a certain state has been collected with an outdated policy. This trick, like the multi-step trick that can be used during data collection, is one of the “hacks” that we usually find to work well in practice despite the fact that they introduce some bias in the return estimates.
Target network updater#
Target networks are a crucial part of off-policy RL algorithms.Updating the target network parameters is made easy thanks to theHardUpdate andSoftUpdateclasses. They’re built with the loss module as argument, and the update isachieved via a call toupdater.step() at the appropriate location in thetraining loop.
from torchrl.objectives.utils import SoftUpdate

target_net_updater = SoftUpdate(loss_module, eps=1 - tau)
Optimizer#
Finally, we will use the Adam optimizer for the policy and value network:
from torch import optim

optimizer_actor = optim.Adam(
    loss_module.actor_network_params.values(True, True), lr=1e-4, weight_decay=0.0
)
optimizer_value = optim.Adam(
    loss_module.value_network_params.values(True, True), lr=1e-3, weight_decay=1e-2
)
total_collection_steps = total_frames // frames_per_batch
Time to train the policy#
The training loop is pretty straightforward now that we have built all themodules we need.
rewards = []
rewards_eval = []

# Main loop

collected_frames = 0
pbar = tqdm.tqdm(total=total_frames)
r0 = None
for i, tensordict in enumerate(collector):

    # update weights of the inference policy
    collector.update_policy_weights_()

    if r0 is None:
        r0 = tensordict["next", "reward"].mean().item()
    pbar.update(tensordict.numel())

    # extend the replay buffer with the new data
    current_frames = tensordict.numel()
    collected_frames += current_frames
    replay_buffer.extend(tensordict.cpu())

    # optimization steps
    if collected_frames >= init_random_frames:
        for _ in range(update_to_data):
            # sample from replay buffer
            sampled_tensordict = replay_buffer.sample().to(device)

            # Compute loss
            loss_dict = loss_module(sampled_tensordict)

            # optimize
            loss_dict["loss_actor"].backward()
            gn1 = torch.nn.utils.clip_grad_norm_(
                loss_module.actor_network_params.values(True, True), 10.0
            )
            optimizer_actor.step()
            optimizer_actor.zero_grad()

            loss_dict["loss_value"].backward()
            gn2 = torch.nn.utils.clip_grad_norm_(
                loss_module.value_network_params.values(True, True), 10.0
            )
            optimizer_value.step()
            optimizer_value.zero_grad()

            gn = (gn1**2 + gn2**2) ** 0.5

            # update priority
            if prb:
                replay_buffer.update_tensordict_priority(sampled_tensordict)
            # update target network
            target_net_updater.step()

    rewards.append(
        (
            i,
            tensordict["next", "reward"].mean().item(),
        )
    )
    td_record = recorder(None)
    if td_record is not None:
        rewards_eval.append((i, td_record["r_evaluation"].item()))
    if len(rewards_eval) and collected_frames >= init_random_frames:
        target_value = loss_dict["target_value"].item()
        loss_value = loss_dict["loss_value"].item()
        loss_actor = loss_dict["loss_actor"].item()
        rn = sampled_tensordict["next", "reward"].mean().item()
        rs = sampled_tensordict["next", "reward"].std().item()
        pbar.set_description(
            f"reward: {rewards[-1][1]: 4.2f} (r0 = {r0: 4.2f}), "
            f"reward eval: reward: {rewards_eval[-1][1]: 4.2f}, "
            f"reward normalized={rn:4.2f}/{rs:4.2f}, "
            f"grad norm={gn: 4.2f}, "
            f"loss_value={loss_value: 4.2f}, "
            f"loss_actor={loss_actor: 4.2f}, "
            f"target value: {target_value: 4.2f}"
        )

    # update the exploration strategy
    actor_model_explore[1].step(current_frames)

collector.shutdown()
del collector
  0%|          | 0/10000 [00:00<?, ?it/s]
reward: -2.52 (r0 = -1.82), reward eval: reward: -0.01, reward normalized=-1.89/6.00, grad norm= 114.01, loss_value= 290.27, loss_actor= 12.75, target value: -10.96:  56%|█████▌    | 5600/10000 [00:06<00:02, 1640.71it/s]
...
reward: -3.02 (r0 = -1.82), reward eval: reward: -4.59, reward normalized=-3.04/5.97, grad norm= 158.36, loss_value= 250.89, loss_actor= 22.75, target value: -21.72: : 10400it [00:23, 304.58it/s]
Experiment results#
We make a simple plot of the average rewards during training. We can observe that our policy learned to solve the task quite well.
Note
As already mentioned above, to get a more reasonable performance, use a greater value for total_frames, for example, 1M.
from matplotlib import pyplot as plt

plt.figure()
plt.plot(*zip(*rewards), label="training")
plt.plot(*zip(*rewards_eval), label="eval")
plt.legend()
plt.xlabel("iter")
plt.ylabel("reward")
plt.tight_layout()
(Figure: training and evaluation rewards plotted against iterations.)
Conclusion#
In this tutorial, we have learned how to code a loss module in TorchRL giventhe concrete example of DDPG.
The key takeaways are:
How to use the LossModule class to code up a new loss component;

How to use (or not) a target network, and how to update its parameters;
How to create an optimizer associated with a loss module.
Next Steps#
To iterate further on this loss module we might consider:
Using @dispatch (see [Feature] Distpatch IQL loss module.)
Allowing flexible TensorDict keys.
Total running time of the script: (0 minutes 29.412 seconds)