
egrund/SARSA-gridworld


This is an implementation of n-step SARSA and a gridworld to be solved by it.
It was created for the course Deep Reinforcement Learning (at UOS).
Author: Eosandra Grund
Date last modified: 20.07.2022
Sample execution code is in Main_SARSA.py; execute it in a shell.

The Gridworld

Visualization of the gridworld (printed as text):

  • double lines mark the edges of the gridworld; outside of them (in the first row and the last line) are the x and y values
  • A = agent
  • X = barrier
  • number = reward at that field

The class Gridworld is implemented in the file Grid.py. The constructor receives a dictionary with the layout; see the example after the following list.
Structure of the dictionary:

  • x_dim (int>0) : x dimension of gridworld
  • y_dim (int>0) : y dimension of gridworld
  • epsilon (0<float<1) : for the epsilon-greedy state transition function
  • start [x,y] : starting state of agent for each episode
  • terminal [x,y] : terminal state with a positive reward
  • neg_rewards [[x,y,reward],[x,y,reward],...] : list of fields with negative rewards
  • barrier [[x,y],[x,y],...] : list of fields that are barriers
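
For illustration, a minimal layout dictionary following this structure could look like the sketch below. The concrete coordinates and rewards are made up, and the string keys are an assumption based on the names above, not verified against Grid.py:

```python
# Hypothetical layout; all field values are illustrative only.
layout = {
    "x_dim": 5,                                   # width of the grid
    "y_dim": 4,                                   # height of the grid
    "epsilon": 0.1,                               # noise of the state transition function
    "start": [0, 0],                              # starting state of the agent
    "terminal": [4, 3],                           # terminal state (positive reward)
    "neg_rewards": [[2, 1, -5.0], [3, 2, -1.0]],  # fields with negative rewards
    "barrier": [[1, 1], [1, 2]],                  # impassable fields
}
```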

There are some hardcoded gridworld dictionaries in the Gridworlds.py file (access them via the class variable Gridworlds.GRIDWORLD[index]), but you can also create your own.
A Gridworld has a starting state, a terminal state (with a positive reward of 10), some other negative rewards, and barriers. Possible actions are up, down, left and right.

State transition function: In the environment you take the given action with probability 1-epsilon and a random action with probability epsilon.
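
In code, this transition rule amounts to something like the following sketch (the function name and the use of Python's random module are assumptions for illustration, not the repository's actual implementation):

```python
import random

ACTIONS = ["up", "down", "left", "right"]

def executed_action(chosen_action, epsilon):
    """Return the action the environment actually performs.

    With probability 1 - epsilon the chosen action is executed,
    with probability epsilon a uniformly random action is taken instead.
    """
    if random.random() < epsilon:
        return random.choice(ACTIONS)
    return chosen_action
```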

Reward function (a code sketch follows this list):

  • 10 for the terminal state
  • other rewards as user inputs
  • -0.5 for invalid moves (against barriers or outside of the gridworld)
  • -0.1 for every move (if no other reward)
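
Taken together, the reward logic corresponds to a rule like this sketch (function and parameter names are illustrative, not the repository's API):

```python
def step_reward(field_reward, move_is_valid):
    """Illustrative reward rule for a single move.

    field_reward: the reward stored on the target field, i.e. 10 for the
    terminal state, a user-defined negative value, or None for plain fields.
    """
    if not move_is_valid:         # bumped into a barrier or the grid edge
        return -0.5
    if field_reward is not None:  # terminal state or user-defined reward
        return field_reward
    return -0.1                   # default cost for every other move
```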

The gridworld is visualized via stdout, and because old prints have to be removed so that the gridworld stays in the same place, it is best to execute everything in a shell.

The Agent

The agent is in the SARSAn.py file. It is an implementation of the reinforcement-learning algorithm n-step SARSA and can also do 1-step SARSA and Monte Carlo.
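
The core of n-step SARSA is the n-step return used as the update target: up to n rewards are summed with discounting, and the value of the state-action pair reached after n steps is bootstrapped on top. A minimal tabular sketch of that update, with illustrative names rather than the actual structures in SARSAn.py:

```python
def n_step_update(Q, trajectory, t, n, alpha, gamma):
    """Update Q for the (state, action) pair visited at time t.

    Q: dict mapping (state, action) tuples to values.
    trajectory: list of (state, action, reward) tuples of one episode.
    With n = float('inf') no bootstrapping ever happens, so the update
    degenerates to a Monte Carlo update over the full episode.
    """
    steps = min(n, len(trajectory) - t)      # fewer steps near the episode end
    G = 0.0
    for i in range(steps):                   # discounted sum of up to n rewards
        G += (gamma ** i) * trajectory[t + i][2]
    if t + n < len(trajectory):              # bootstrap if the episode continues
        state_n, action_n, _ = trajectory[t + n]
        G += (gamma ** n) * Q[(state_n, action_n)]
    state, action, _ = trajectory[t]
    Q[(state, action)] += alpha * (G - Q[(state, action)])
```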

It uses an epsilon-greedy policy with the possibility of decreasing the exploration over time (set decreasing_epsilon = True).
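
One simple way such a decrease could be realized, shown here purely as an assumed schedule and not necessarily what SARSAn.py does, is a linear decay of epsilon over the episodes:

```python
def decayed_epsilon(epsilon_start, episode, total_episodes):
    # Linearly anneal epsilon from epsilon_start down to 0 over all episodes
    # (hypothetical schedule for illustration).
    return epsilon_start * (1 - episode / total_episodes)
```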

If you set visualize_policy = True, the Q-values are visualized after each episode as a matplotlib heatmap showing all state-action values.

Start the learning process with the start method. As parameters, it receives the number of episodes you want to run and whether you want an evaluation. The evaluation yields:

  • a list of the average return, the total return and the steps per episode
  • a plot of the total return and the steps per episode (the plot only works if visualize_policy = False)

How to execute

First you have to clone the repository. Then you can use or modify Main_SARSA.py and execute it in the terminal.
Imports:

```python
import matplotlib.pyplot as plt
import numpy as np
import SARSAn
import Grid
import Gridworlds
```

Creation of the Gridworld:

```python
which_gridworld = 0
world = Grid.Gridworld(Gridworlds.Gridworlds.GRIDWORLD[which_gridworld])
world.visualize()
```

Decide which default world to use by changing which_gridworld to any value between 0 and 4. The gridworld in the picture above is gridworld 0.

Creation of the player and start of learning:

```python
player = SARSAn.SARSAn(gridworld=world, n=10, epsilon=0.5, decreasing_epsilon=True,
                       gamma=0.99, alpha=0.3, visualize_policy=False, visualize_grid=True)
player.start(episodes=50, evaluation=True)
```

This creates a 10-step SARSA agent solving gridworld 0. You can change all of the parameters and see what happens, but changing them can make the algorithm inefficient or keep it from learning.
Create a Monte Carlo approach by executing this instead of the previous cell:

```python
player = SARSAn.SARSAn(gridworld=world, n=np.inf, epsilon=0.3, alpha=1)
player.start(episodes=50, evaluation=True)
```

You can export your pyplot plots by executing the following line after learning is done (only the plots shown during learning will appear in the picture):

```python
plt.savefig("Figure_SARSA_policy_returns.png")
```
