# lcswillems/torch-ac

Recurrent and multi-process PyTorch implementation of the deep reinforcement learning Actor-Critic algorithms A2C and PPO.
The `torch_ac` package contains the PyTorch implementation of two Actor-Critic deep reinforcement learning algorithms:

- Advantage Actor-Critic (A2C)
- Proximal Policy Optimization (PPO)
Note: An example of use of this package is given in the `rl-starter-files` repository. More details below.
- Recurrent policies
- Reward shaping
- Handles observation spaces that are tensors or dicts of tensors
- Handles discrete action spaces
- Observation preprocessing
- Multiprocessing
- CUDA
```
pip3 install torch-ac
```
Note: If you want to modify `torch-ac` algorithms, you will need to install a cloned version instead, i.e.:
```
git clone https://github.com/lcswillems/torch-ac.git
cd torch-ac
pip3 install -e .
```
A brief overview of the components of the package:
- `torch_ac.A2CAlgo` and `torch_ac.PPOAlgo` classes for the A2C and PPO algorithms
- `torch_ac.ACModel` and `torch_ac.RecurrentACModel` abstract classes for non-recurrent and recurrent actor-critic models
- `torch_ac.DictList` class for making dictionaries of lists list-indexable and hence batch-friendly
The most important components of the package are detailed below.
`torch_ac.A2CAlgo` and `torch_ac.PPOAlgo` have 2 methods:

- `__init__` that may take, among other parameters:
    - an `acmodel` actor-critic model, i.e. an instance of a class inheriting from either `torch_ac.ACModel` or `torch_ac.RecurrentACModel`.
    - a `preprocess_obss` function that transforms a list of observations into a list-indexable object `X` (e.g. a PyTorch tensor). The default `preprocess_obss` function converts observations into a PyTorch tensor.
    - a `reshape_reward` function that takes as parameters an observation `obs`, the action `action` taken, the reward `reward` received and the terminal status `done`, and returns a new reward. By default, the reward is not reshaped.
    - a `recurrence` number that specifies over how many timesteps gradients are backpropagated. This number is only taken into account if a recurrent model is used, and it must divide the `num_frames_per_agent` parameter and, for PPO, the `batch_size` parameter.
- an `update_parameters` method that first collects experiences, then updates the parameters and finally returns logs.
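As an illustration of the `reshape_reward` hook described above, here is a hedged sketch; the scaling factor and terminal bonus are arbitrary assumptions for the example, not defaults of the package:

```python
def reshape_reward(obs, action, reward, done):
    """Hypothetical reward shaping: scale the raw reward and add a
    bonus when the episode terminates. Parameter order follows the
    description above (obs, action, reward, done)."""
    shaped = 0.1 * reward      # arbitrary scaling, an assumption
    if done:
        shaped += 1.0          # arbitrary terminal bonus, an assumption
    return shaped


# Example calls:
r1 = reshape_reward(None, None, 10.0, False)   # 0.1 * 10.0
r2 = reshape_reward(None, None, 0.0, True)     # 0.0 + terminal bonus
```

A function with this signature can then be passed as the `reshape_reward` argument of the algorithm constructor.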
`torch_ac.ACModel` has 2 abstract methods:

- `__init__` that takes as parameters an `observation_space` and an `action_space`.
- `forward` that takes as parameter N preprocessed observations `obs` and returns a PyTorch distribution `dist` and a tensor of values `value`. The tensor of values must be of size N, not N x 1.
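The interface above can be sketched as follows. This is a minimal, hypothetical non-recurrent model assuming flat float observations and a discrete action space; the layer sizes are arbitrary, and in real use the class would also inherit from `torch_ac.ACModel`:

```python
from types import SimpleNamespace

import torch
import torch.nn as nn
from torch.distributions import Categorical


class FlatACModel(nn.Module):  # in real use, also inherit torch_ac.ACModel
    """Hypothetical minimal actor-critic model for flat observations
    and a discrete action space (names and sizes are assumptions)."""

    def __init__(self, observation_space, action_space):
        super().__init__()
        obs_dim = observation_space.shape[0]
        self.actor = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.Tanh(), nn.Linear(64, action_space.n))
        self.critic = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.Tanh(), nn.Linear(64, 1))

    def forward(self, obs):
        dist = Categorical(logits=self.actor(obs))
        # squeeze so the value tensor has size N, not N x 1
        value = self.critic(obs).squeeze(1)
        return dist, value


# Quick shape check with stand-in spaces (replacing real Gym spaces):
model = FlatACModel(SimpleNamespace(shape=(4,)), SimpleNamespace(n=3))
dist, value = model(torch.zeros(5, 4))
```

Note the `squeeze(1)` on the critic output: it enforces the size-N constraint stated above.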
`torch_ac.RecurrentACModel` has 3 abstract methods:

- `__init__` that takes the same parameters as `torch_ac.ACModel`.
- `forward` that takes the same parameters as `torch_ac.ACModel`, along with a tensor of N memories `memory` of size N x M, where M is the size of a memory. It returns the same outputs as `torch_ac.ACModel`, plus a tensor of N memories `memory`.
- `memory_size` that returns the size M of a memory.
Note: The `preprocess_obss` function must return a list-indexable object (e.g. a PyTorch tensor). If your observations are dictionaries, your `preprocess_obss` function may first convert a list of dictionaries into a dictionary of lists and then make it list-indexable using the `torch_ac.DictList` class as follows:
```python
>>> d = DictList({"a": [[1, 2], [3, 4]], "b": [[5], [6]]})
>>> d.a
[[1, 2], [3, 4]]
>>> d[0]
DictList({"a": [1, 2], "b": [5]})
```
Note: if you use an RNN, you will need to set `batch_first` to `True`.
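To illustrate the `batch_first` note above, here is a hedged sketch; the input and hidden sizes are arbitrary choices for the example:

```python
import torch
import torch.nn as nn

# With batch_first=True, the RNN expects input shaped (N, T, features)
# rather than PyTorch's default (T, N, features) layout.
rnn = nn.GRU(input_size=8, hidden_size=16, batch_first=True)

x = torch.zeros(5, 3, 8)   # N=5 sequences of length T=3
out, h = rnn(x)
# out is batch-first: (5, 3, 16); h keeps shape (num_layers, 5, 16)
```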
Examples of use of the package components are given in the `rl-starter-scripts` repository.
```python
...

algo = torch_ac.PPOAlgo(envs, acmodel, args.frames_per_proc, args.discount, args.lr,
                        args.gae_lambda, args.entropy_coef, args.value_loss_coef,
                        args.max_grad_norm, args.recurrence, args.optim_eps,
                        args.clip_eps, args.epochs, args.batch_size, preprocess_obss)

...

exps, logs1 = algo.collect_experiences()
logs2 = algo.update_parameters(exps)
```
More details here.
```python
torch_ac.DictList({
    "image": preprocess_images([obs["image"] for obs in obss], device=device),
    "text": preprocess_texts([obs["mission"] for obs in obss], vocab, device=device)
})
```
More details here.
```python
class ACModel(nn.Module, torch_ac.RecurrentACModel):
    ...

    def forward(self, obs, memory):
        ...

        return dist, value, memory
```
More details here.