lcswillems/torch-ac

Recurrent and multi-process PyTorch implementation of deep reinforcement Actor-Critic algorithms A2C and PPO


The torch_ac package contains the PyTorch implementation of two Actor-Critic deep reinforcement learning algorithms:

  • Advantage Actor-Critic (A2C)
  • Proximal Policy Optimization (PPO)

Note: An example of using this package is given in the rl-starter-files repository. More details below.

Features

  • Recurrent policies
  • Reward shaping
  • Observation spaces that are tensors or dicts of tensors
  • Discrete action spaces
  • Observation preprocessing
  • Multiprocessing
  • CUDA

Installation

pip3 install torch-ac

Note: If you want to modify the torch-ac algorithms, install a cloned version instead:

git clone https://github.com/lcswillems/torch-ac.git
cd torch-ac
pip3 install -e .

Package components overview

A brief overview of the components of the package:

  • torch_ac.A2CAlgo and torch_ac.PPOAlgo classes for the A2C and PPO algorithms
  • torch_ac.ACModel and torch_ac.RecurrentACModel abstract classes for non-recurrent and recurrent actor-critic models
  • torch_ac.DictList class for making dictionaries of lists list-indexable and hence batch-friendly

Package components details

The most important components of the package are detailed here.

torch_ac.A2CAlgo and torch_ac.PPOAlgo have 2 methods:

  • __init__ that may take, among other parameters:
    • an acmodel actor-critic model, i.e. an instance of a class inheriting from either torch_ac.ACModel or torch_ac.RecurrentACModel.
    • a preprocess_obss function that transforms a list of observations into a list-indexable object X (e.g. a PyTorch tensor). The default preprocess_obss function converts observations into a PyTorch tensor.
    • a reshape_reward function that takes as parameters an observation obs, the action action taken, the reward reward received and the terminal status done, and returns a new reward. By default, the reward is not reshaped.
    • a recurrence number specifying over how many timesteps gradients are backpropagated. This number is only taken into account if a recurrent model is used, and must divide the num_frames_per_agent parameter and, for PPO, the batch_size parameter.
  • update_parameters that first collects experiences, then updates the parameters and finally returns logs.
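As an illustration of the reshape_reward signature described above, the sketch below is a hypothetical shaping function that simply clips rewards to [-1, 1]; it is not part of the package, and the obs, action and done arguments are unused here but available for richer shaping schemes:

```python
# Hypothetical reshape_reward: clips every reward to [-1, 1].
# obs, action and done are part of the signature described above
# but unused in this simple example.
def reshape_reward(obs, action, reward, done):
    return max(-1.0, min(1.0, reward))

# It would then be passed to the algorithm constructor, e.g.
# torch_ac.PPOAlgo(..., reshape_reward=reshape_reward).
```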

torch_ac.ACModel has 2 abstract methods:

  • __init__ that takes as parameters an observation_space and an action_space.
  • forward that takes as parameter N preprocessed observations obs and returns a PyTorch distribution dist and a tensor of values value. The tensor of values must be of size N, not N x 1.
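Under these constraints, a minimal non-recurrent model might look like the sketch below. The sizes (4-dimensional observations, 64 hidden units, 2 actions) are arbitrary assumptions, and the class inherits only nn.Module so the sketch is self-contained; a real model would also inherit from torch_ac.ACModel:

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

# Minimal sketch of the forward contract described above.
# All sizes are arbitrary assumptions for illustration.
class MyACModel(nn.Module):
    def __init__(self, obs_size=4, num_actions=2):
        super().__init__()
        self.actor = nn.Sequential(nn.Linear(obs_size, 64), nn.Tanh(),
                                   nn.Linear(64, num_actions))
        self.critic = nn.Sequential(nn.Linear(obs_size, 64), nn.Tanh(),
                                    nn.Linear(64, 1))

    def forward(self, obs):
        dist = Categorical(logits=self.actor(obs))
        value = self.critic(obs).squeeze(1)  # size N, not N x 1
        return dist, value
```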

torch_ac.RecurrentACModel has 3 abstract methods:

  • __init__ that takes the same parameters as torch_ac.ACModel.
  • forward that takes the same parameters as torch_ac.ACModel, along with a tensor of N memories memory of size N x M where M is the size of a memory. It returns the same as torch_ac.ACModel, plus a tensor of N memories memory.
  • memory_size that returns the size M of a memory.
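A hypothetical recurrent sketch following this contract is shown below (sizes are again arbitrary assumptions, and only nn.Module is inherited so the sketch is self-contained). Here the memory tensor packs the LSTM hidden and cell states side by side, so memory_size is twice the hidden size:

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

# Hypothetical recurrent sketch; in the package it would also
# inherit from torch_ac.RecurrentACModel.
class MyRecurrentACModel(nn.Module):
    def __init__(self, obs_size=4, num_actions=2, hidden=64):
        super().__init__()
        self.hidden = hidden
        self.rnn = nn.LSTMCell(obs_size, hidden)
        self.actor = nn.Linear(hidden, num_actions)
        self.critic = nn.Linear(hidden, 1)

    @property
    def memory_size(self):
        # M = hidden state + cell state, concatenated
        return 2 * self.hidden

    def forward(self, obs, memory):
        # unpack the N x M memory into LSTM hidden and cell states
        h, c = memory[:, :self.hidden], memory[:, self.hidden:]
        h, c = self.rnn(obs, (h, c))
        dist = Categorical(logits=self.actor(h))
        value = self.critic(h).squeeze(1)      # size N, not N x 1
        memory = torch.cat([h, c], dim=1)      # repack to N x M
        return dist, value, memory
```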

Note: The preprocess_obss function must return a list-indexable object (e.g. a PyTorch tensor). If your observations are dictionaries, your preprocess_obss function may first convert a list of dictionaries into a dictionary of lists and then make it list-indexable using the torch_ac.DictList class as follows:

>>> d = DictList({"a": [[1, 2], [3, 4]], "b": [[5], [6]]})
>>> d.a
[[1, 2], [3, 4]]
>>> d[0]
DictList({"a": [1, 2], "b": [5]})
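This indexing behaviour can be mimicked in a few lines of plain Python. The class below is a hypothetical stand-in, not the package's implementation, showing the "dict of lists, indexable like a list" idea:

```python
# Hypothetical minimal stand-in for torch_ac.DictList.
class MiniDictList(dict):
    def __getattr__(self, key):
        return self[key]

    def __getitem__(self, index):
        if isinstance(index, str):   # plain dict access, e.g. d["a"]
            return dict.__getitem__(self, index)
        # list-style access: index every stored list at once
        return MiniDictList({k: v[index] for k, v in self.items()})

# Converting a list of observation dicts into a dict of lists:
obss = [{"a": [1, 2], "b": [5]}, {"a": [3, 4], "b": [6]}]
d = MiniDictList({k: [obs[k] for obs in obss] for k in obss[0]})
```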

Note: If you use an RNN, you will need to set batch_first to True.

Examples

Examples of use of the package components are given in the rl-starter-files repository.

Example of use of torch_ac.A2CAlgo and torch_ac.PPOAlgo

...

algo = torch_ac.PPOAlgo(envs, acmodel, args.frames_per_proc, args.discount, args.lr, args.gae_lambda,
                        args.entropy_coef, args.value_loss_coef, args.max_grad_norm, args.recurrence,
                        args.optim_eps, args.clip_eps, args.epochs, args.batch_size, preprocess_obss)

...

exps, logs1 = algo.collect_experiences()
logs2 = algo.update_parameters(exps)

More details here.

Example of use of torch_ac.DictList

torch_ac.DictList({
    "image": preprocess_images([obs["image"] for obs in obss], device=device),
    "text": preprocess_texts([obs["mission"] for obs in obss], vocab, device=device)
})

More details here.

Example of implementation of torch_ac.RecurrentACModel

class ACModel(nn.Module, torch_ac.RecurrentACModel):
    ...

    def forward(self, obs, memory):
        ...
        return dist, value, memory

More details here.

Examples of preprocess_obss functions

More details here.

