Recurrent DQN: Training recurrent policies#
Created On: Nov 08, 2023 | Last Updated: Jan 27, 2025 | Last Verified: Not Verified
Author: Vincent Moens
What you will learn:
How to incorporate an RNN in an actor in TorchRL
How to use that memory-based policy with a replay buffer and a loss module

Prerequisites:
PyTorch v2.0.0
gym[mujoco]
tqdm
Overview#
Memory-based policies are crucial not only when the observations are partially observable but also when the time dimension must be taken into account to make informed decisions.
Recurrent neural networks have long been a popular tool for memory-based policies. The idea is to keep a recurrent state in memory between two consecutive steps, and use this as an input to the policy along with the current observation.
This tutorial shows how to incorporate an RNN in a policy using TorchRL.
Key learnings:
Incorporating an RNN in an actor in TorchRL;
Using that memory-based policy with a replay buffer and a loss module.
The core idea of using RNNs in TorchRL is to use TensorDict as a data carrier for the hidden states from one step to another. We'll build a policy that reads the previous recurrent state from the current TensorDict, and writes the current recurrent states in the TensorDict of the next state:
[Figure: recurrent states flowing between the policy and the environment through the TensorDict across two consecutive steps]
As this figure shows, our environment populates the TensorDict with zeroed recurrent states which are read by the policy together with the observation to produce an action, and recurrent states that will be used for the next step. When the step_mdp() function is called, the recurrent states from the next state are brought to the current TensorDict. Let's see how this is implemented in practice.
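To make this data flow concrete, here is a minimal hand-built sketch (the TensorDict below and its key names are illustrative, not produced by the tutorial's environment) showing how step_mdp() promotes the entries stored under "next", including the recurrent state written by the policy, to the root of the TensorDict:

import torch
from tensordict import TensorDict
from torchrl.envs.utils import step_mdp

# Illustrative TensorDict mimicking a single environment step
td = TensorDict(
    {
        "observation": torch.zeros(3),
        "recurrent_state": torch.zeros(1, 128),  # state read by the policy at step t
        "action": torch.zeros(2),
        "next": TensorDict(
            {
                "observation": torch.ones(3),
                "recurrent_state": torch.ones(1, 128),  # state written for step t+1
                "reward": torch.zeros(1),
                "done": torch.zeros(1, dtype=torch.bool),
            },
            [],
        ),
    },
    [],
)
td_next = step_mdp(td)
# The recurrent state stored under "next" is now available at the root,
# ready to be read by the policy at the following step.
print(td_next["recurrent_state"])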
If you are running this in Google Colab, make sure you install the following dependencies:
!pip3 install torchrl
!pip3 install gym[mujoco]
!pip3 install tqdm
Setup#
import multiprocessing  # needed for the start-method check below

import torch
import tqdm
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
from torch import nn
from torchrl.collectors import SyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.envs import (
    Compose,
    ExplorationType,
    GrayScale,
    InitTracker,
    ObservationNorm,
    Resize,
    RewardScaling,
    set_exploration_type,
    StepCounter,
    ToTensorImage,
    TransformedEnv,
)
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ConvNet, EGreedyModule, LSTMModule, MLP, QValueModule
from torchrl.objectives import DQNLoss, SoftUpdate

is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)
Environment#
As usual, the first step is to build our environment: it helps us define the problem and build the policy network accordingly. For this tutorial, we'll be running a single pixel-based instance of the CartPole gym environment with some custom transforms: turning to grayscale, resizing to 84x84, scaling down the rewards and normalizing the observations.
Note
The StepCounter transform is optional. Since the CartPole task goal is to make trajectories as long as possible, counting the steps can help us track the performance of our policy.
Two transforms are important for the purpose of this tutorial:
InitTracker will stamp the calls to reset() by adding an "is_init" boolean mask in the TensorDict that will track which steps require a reset of the RNN hidden states (see the short standalone check after this list).
The TensorDictPrimer transform is a bit more technical. It is not required to use RNN policies. However, it instructs the environment (and subsequently the collector) that some extra keys are to be expected. Once added, a call to env.reset() will populate the entries indicated in the primer with zeroed tensors. Knowing that these tensors are expected by the policy, the collector will pass them on during collection. Eventually, we'll be storing our hidden states in the replay buffer, which will help us bootstrap the computation of the RNN operations in the loss module (which would otherwise be initiated with 0s). In summary: not including this transform will not hugely impact the training of our policy, but it will make the recurrent keys disappear from the collected data and the replay buffer, which will in turn lead to slightly less optimal training. Fortunately, the LSTMModule we propose is equipped with a helper method to build just that transform for us, so we can wait until we build it!
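As a quick illustration of the first point, a bare CartPole environment wrapped with InitTracker exposes the "is_init" entry right after a reset. This standalone check is a sketch and is separate from the pixel-based environment built below:

from torchrl.envs import InitTracker, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

check_env = TransformedEnv(GymEnv("CartPole-v1"), InitTracker())
# "is_init" is True on the very first step of a trajectory
print(check_env.reset()["is_init"])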
env = TransformedEnv(
    GymEnv("CartPole-v1", from_pixels=True, device=device),
    Compose(
        ToTensorImage(),
        GrayScale(),
        Resize(84, 84),
        StepCounter(),
        InitTracker(),
        RewardScaling(loc=0.0, scale=0.1),
        ObservationNorm(standard_normal=True, in_keys=["pixels"]),
    ),
)
As always, we need to manually initialize our normalization constants:
env.transform[-1].init_stats(1000, reduce_dim=[0, 1, 2], cat_dim=0, keep_dims=[0])
td = env.reset()
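Once the statistics have been collected, we can inspect the normalization constants the transform has computed. This quick check is a sketch and assumes the loc and scale attributes exposed by ObservationNorm:

# Location and scale tensors computed by init_stats, one value per pixel row
print(env.transform[-1].loc.shape, env.transform[-1].scale.shape)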
Policy#
Our policy will have 3 components: a ConvNet backbone, an LSTMModule memory layer and a shallow MLP block that will map the LSTM output onto the action values.
Convolutional network#
We build a convolutional network flanked with a torch.nn.AdaptiveAvgPool2d that will squash the output into a vector of size 64. The ConvNet can assist us with this:
feature = Mod(
    ConvNet(
        num_cells=[32, 32, 64],
        squeeze_output=True,
        aggregator_class=nn.AdaptiveAvgPool2d,
        aggregator_kwargs={"output_size": (1, 1)},
        device=device,
    ),
    in_keys=["pixels"],
    out_keys=["embed"],
)
We execute the first module on a batch of data to gather the size of the output vector:
n_cells = feature(env.reset())["embed"].shape[-1]
LSTM Module#
TorchRL provides a specialized LSTMModule class to incorporate LSTMs in your code base. It is a TensorDictModuleBase subclass: as such, it has a set of in_keys and out_keys that indicate what values should be expected to be read and written/updated during the execution of the module. The class comes with customizable predefined values for these attributes to facilitate its construction.
Note
Usage limitations: The class supports almost all LSTM features such as dropout or multi-layered LSTMs. However, to respect TorchRL's conventions, this LSTM must have the batch_first attribute set to True, which is not the default in PyTorch. Our LSTMModule changes this default behavior, so we're good with a native call.
Also, the LSTM cannot have a bidirectional attribute set to True as this wouldn't be usable in online settings. In this case, the default value is the correct one.
lstm = LSTMModule(
    input_size=n_cells,
    hidden_size=128,
    device=device,
    in_key="embed",
    out_key="embed",
)
Let us look at the LSTM Module class, specifically its in_keys and out_keys:
print("in_keys", lstm.in_keys)
print("out_keys", lstm.out_keys)
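With the defaults used here, the printed keys should look roughly like the following (this is an indicative sketch; the exact recurrent key names may differ across TorchRL versions):

# in_keys  ['embed', 'recurrent_state_h', 'recurrent_state_c', 'is_init']
# out_keys ['embed', ('next', 'recurrent_state_h'), ('next', 'recurrent_state_c')]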
We can see that these values contain the key we indicated as the in_key (and out_key) as well as recurrent key names. The out_keys are preceded by a "next" prefix that indicates that they will need to be written in the "next" TensorDict. We use this convention (which can be overridden by passing the in_keys/out_keys arguments) to make sure that a call to step_mdp() will move the recurrent state to the root TensorDict, making it available to the RNN during the following call (see figure in the intro).
As mentioned earlier, we have one more optional transform to add to our environment to make sure that the recurrent states are passed to the buffer. The make_tensordict_primer() method does exactly that:
env.append_transform(lstm.make_tensordict_primer())
and that's it! We can print the environment to check that everything looks good now that we have added the primer:
print(env)
MLP#
We use a single-layer MLP to represent the action values we’ll be using forour policy.
mlp = MLP(
    out_features=2,
    num_cells=[64],
    device=device,
)
and fill the bias with zeros:
mlp[-1].bias.data.fill_(0.0)
mlp = Mod(mlp, in_keys=["embed"], out_keys=["action_value"])
Using the Q-Values to select an action#
The last part of our policy is the Q-Value Module. The QValueModule will read the "action_value" key that is produced by our MLP and, from it, gather the action that has the maximum value. The only thing we need to do is to specify the action space, which can be done either by passing a string or an action-spec. This allows us to use Categorical (sometimes called "sparse") encoding or the one-hot version of it.
qval = QValueModule(spec=env.action_spec)
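As a small sanity check (a sketch with a hand-crafted input, not part of the tutorial's pipeline), we can feed fake action values to qval and verify that it picks the action with the maximum value:

from tensordict import TensorDict

dummy = TensorDict({"action_value": torch.tensor([[1.0, 3.0]])}, batch_size=[1])
dummy = qval(dummy)
print(dummy["action"])               # expected: an action selecting index 1
print(dummy["chosen_action_value"])  # expected: the corresponding value, 3.0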
Note
TorchRL also provides a wrapper class torchrl.modules.QValueActor that wraps a module in a Sequential together with a QValueModule like we are doing explicitly here. There is little advantage to doing so and the process is less transparent, but the end results will be similar to what we do here.
We can now put things together in a TensorDictSequential:
stoch_policy = Seq(feature, lstm, mlp, qval)
DQN being a deterministic algorithm, exploration is a crucial part of it. We'll be using an \(\epsilon\)-greedy policy with an epsilon of 0.2 decaying progressively to 0. This decay is achieved via a call to step() (see training loop below).
exploration_module = EGreedyModule(
    annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2
)
stoch_policy = Seq(
    stoch_policy,
    exploration_module,
)
Using the model for the loss#
The model as we've built it is well equipped to be used in sequential settings. However, the class torch.nn.LSTM can use a cuDNN-optimized backend to run the RNN sequence faster on a GPU device. We would not want to miss such an opportunity to speed up our training loop! To use it, we just need to tell the LSTM module to run in "recurrent mode" when used by the loss. As we'll usually want to have two copies of the LSTM module, we do this by calling a set_recurrent_mode() method that will return a new instance of the LSTM (with shared weights) that will assume that the input data is sequential in nature.
policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)
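As a quick sanity check (a sketch that relies on the recurrent-mode copy sharing its weights with the original module, as described above), the two policies should expose the very same parameter objects:

# True if the loss-side policy and the collection-side policy share all parameters
print(set(policy.parameters()) == set(stoch_policy.parameters()))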
Because we still have a couple of uninitialized parameters, we should initialize them before creating an optimizer and the like.
policy(env.reset())
DQN Loss#
Our DQN loss requires us to pass the policy and, again, the action space. While this may seem redundant, it is important as we want to make sure that the DQNLoss and the QValueModule classes are compatible, but aren't strongly dependent on each other.
To use the Double-DQN, we ask for a delay_value argument that will create a non-differentiable copy of the network parameters to be used as a target network.
loss_fn = DQNLoss(policy, action_space=env.action_spec, delay_value=True)
Since we are using a double DQN, we need to update the target parameters. We'll use a SoftUpdate instance to carry out this work.
updater = SoftUpdate(loss_fn, eps=0.95)
optim = torch.optim.Adam(policy.parameters(), lr=3e-4)
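With eps=0.95, each call to updater.step() applies a Polyak-style update, roughly \(\theta_{\text{target}} \leftarrow 0.95\,\theta_{\text{target}} + 0.05\,\theta_{\text{online}}\), so the target network slowly tracks the online network. This formula is a paraphrase of the soft-update rule, not a transcript of the implementation.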
Collector and replay buffer#
We build the simplest data collector there is. We'll try to train our algorithm with a million frames, extending the buffer with 50 frames at a time. The buffer will be designed to store 20 thousand trajectories of 50 steps each. At each optimization step (16 per data collection), we'll collect 4 items from our buffer, for a total of 200 transitions. We'll use a LazyMemmapStorage storage to keep the data on disk.
Note
For the sake of efficiency, we're only running a few thousand iterations here. In a real setting, the total number of frames should be set to 1M.
collector = SyncDataCollector(
    env, stoch_policy, frames_per_batch=50, total_frames=200, device=device
)
rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10
)
Training loop#
To keep track of the progress, we will run the policy in the environment once every 50 data collections, and plot the results after training.
utd = 16
pbar = tqdm.tqdm(total=1_000_000)
longest = 0

traj_lens = []
for i, data in enumerate(collector):
    if i == 0:
        print(
            "Let us print the first batch of data.\nPay attention to the key names "
            "which will reflect what can be found in this data structure, in particular: "
            "the output of the QValueModule (action_value, action and chosen_action_value), "
            "the 'is_init' key that will tell us if a step is initial or not, and the "
            "recurrent_state keys.\n",
            data,
        )
    pbar.update(data.numel())
    # it is important to pass data that is not flattened
    rb.extend(data.unsqueeze(0).to_tensordict().cpu())
    for _ in range(utd):
        s = rb.sample().to(device, non_blocking=True)
        loss_vals = loss_fn(s)
        loss_vals["loss"].backward()
        optim.step()
        optim.zero_grad()
    longest = max(longest, data["step_count"].max().item())
    pbar.set_description(
        f"steps: {longest}, loss_val: {loss_vals['loss'].item(): 4.4f}, "
        f"action_spread: {data['action'].sum(0)}"
    )
    exploration_module.step(data.numel())
    updater.step()

    with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():
        rollout = env.rollout(10000, stoch_policy)
        traj_lens.append(rollout.get(("next", "step_count")).max().item())
Let’s plot our results:
if traj_lens:
    from matplotlib import pyplot as plt

    plt.plot(traj_lens)
    plt.xlabel("Test collection")
    plt.title("Test trajectory lengths")
Conclusion#
We have seen how an RNN can be incorporated in a policy in TorchRL. You should now be able to:

Create an LSTM module that acts as a TensorDictModule
Indicate to the LSTM module that a reset is needed via an InitTracker transform
Incorporate this module in a policy and in a loss module
Make sure that the collector is made aware of the recurrent state entries such that they can be stored in the replay buffer along with the rest of the data
Further Reading#
The TorchRL documentation can be found here.