This repository was archived by the owner on Sep 24, 2023. It is now read-only.

A PyTorch Implementation of "Recurrent Models of Visual Attention"

kevinzakka/recurrent-visual-attention

This is a PyTorch implementation of Recurrent Models of Visual Attention by Volodymyr Mnih, Nicolas Heess, Alex Graves and Koray Kavukcuoglu.

The Recurrent Attention Model (RAM) is a neural network that processes inputs sequentially. It attends to different locations within the image one at a time and incrementally combines information from these fixations to build up a dynamic internal representation of the image.

Model Description

In this paper, the attention problem is modeled as the sequential decision process of a goal-directed agent interacting with a visual environment. The agent is built around a recurrent neural network: at each time step, it processes the sensor data, integrates information over time, and chooses how to act and how to deploy its sensor at the next time step.
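The observe, integrate and act cycle described above can be sketched as a plain Python loop. This is a minimal sketch, not this repository's code. The callables `glimpse_fn`, `core_fn`, `locate_fn` and `classify_fn` are hypothetical stand-ins for the networks listed below. Starting the first fixation at the image centre is an assumption of the sketch.

```python
from typing import Callable, List, Tuple

Location = Tuple[float, float]  # (row, col) in normalised [-1, 1] coordinates


def run_episode(
    image: object,
    glimpse_fn: Callable[[object, Location], List[float]],
    core_fn: Callable[[List[float], List[float]], List[float]],
    locate_fn: Callable[[List[float]], Location],
    classify_fn: Callable[[List[float]], int],
    num_glimpses: int,
    hidden_size: int,
) -> Tuple[int, List[Location]]:
    """Run one RAM-style episode and return (prediction, fixations used)."""
    h = [0.0] * hidden_size      # internal state h_0 before any observation
    loc: Location = (0.0, 0.0)   # first fixation at the centre (assumption of this sketch)
    fixations: List[Location] = []
    for _ in range(num_glimpses):
        fixations.append(loc)
        g = glimpse_fn(image, loc)  # glimpse feature g_t from the "what" and the "where"
        h = core_fn(h, g)           # core network integrates g_t into h_t
        loc = locate_fn(h)          # location network picks the next fixation
    # After the last glimpse, the action network classifies from the final state.
    return classify_fn(h), fixations
```

Toy callables are enough to trace the control flow. For example, a `core_fn` that adds the glimpse feature to the state, combined with a `locate_fn` that shifts right as the state grows, yields one fixation per glimpse.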

  • glimpse sensor: a retina that extracts a foveated glimpse phi around location l from an image x. It encodes the region around l at high resolution but uses progressively lower resolution for pixels farther from l, resulting in a compressed representation of the original image x.
  • glimpse network: a network that combines the "what" (phi) and the "where" (l) into a glimpse feature vector g_t.
  • core network: an RNN whose internal state integrates information extracted from the history of past observations. It encodes the agent's knowledge of the environment through a state vector h_t that is updated at every time step t.
  • location network: uses the internal state h_t of the core network to produce the location coordinates l_t for the next time step.
  • action network: after a fixed number of time steps, uses the internal state h_t of the core network to produce the final output classification y.
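The glimpse sensor can be illustrated with a small pure-Python retina. This is a sketch under stated assumptions, not the repository's implementation. It extracts `num_patches` concentric square patches around `l`, each `scale` times wider than the previous one. Every patch is then average-pooled back to `patch_size x patch_size`, so the outer patches are coarser.

The following details are choices made for this illustration and are not taken from the source:

  • mapping normalised coordinates to pixels with `int(0.5 * (l + 1) * H)`
  • zero padding outside the image
  • an integer `scale`

With `scale=1`, every patch has the same size, so all patches would be identical crops.

```python
from typing import List, Tuple

Image = List[List[float]]


def extract_patch(img: Image, cy: int, cx: int, size: int) -> Image:
    """Crop a size x size square around pixel (cy, cx), zero-padding outside the image."""
    h, w = len(img), len(img[0])
    top, left = cy - size // 2, cx - size // 2
    return [
        [img[y][x] if 0 <= y < h and 0 <= x < w else 0.0 for x in range(left, left + size)]
        for y in range(top, top + size)
    ]


def avg_pool(patch: Image, out: int) -> Image:
    """Downsample a square patch to out x out by averaging non-overlapping k x k blocks."""
    side = len(patch)
    if side % out:
        raise ValueError("patch side must be a multiple of the output size")
    k = side // out
    return [
        [
            sum(patch[by * k + i][bx * k + j] for i in range(k) for j in range(k)) / (k * k)
            for bx in range(out)
        ]
        for by in range(out)
    ]


def glimpse_sensor(
    img: Image, loc: Tuple[float, float], patch_size: int, num_patches: int, scale: int
) -> List[Image]:
    """Foveated glimpse phi: concentric patches around loc, each resized to patch_size x patch_size."""
    h, w = len(img), len(img[0])
    # Map normalised (row, col) in [-1, 1] to pixel indices (assumed convention).
    cy = int(0.5 * (loc[0] + 1.0) * h)
    cx = int(0.5 * (loc[1] + 1.0) * w)
    phi: List[Image] = []
    size = patch_size
    for _ in range(num_patches):
        phi.append(avg_pool(extract_patch(img, cy, cx, size), patch_size))
        size *= scale  # each successive patch covers a wider area at lower resolution
    return phi
```

On a 28x28 image, `glimpse_sensor(img, (0.0, 0.0), 8, 3, 2)` returns three 8x8 patches covering 8x8, 16x16 and 32x32 regions around the centre. The largest region spills past the border, so its corner cells average in zeros.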

Results

I decided to tackle the 28x28 MNIST task with a RAM model that takes 6 glimpses of size 8x8 with a scale factor of 1.

Model                  Validation Error (%)   Test Error (%)
6 glimpses, 8x8        1.1                    1.21

I haven't run a random search to tune the policy standard deviation, so I expect the test error can be brought below 1%. I'll update the table above with results for 60x60 Translated MNIST, 60x60 Cluttered Translated MNIST and the new Fashion MNIST dataset when I get the time.
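In the paper, the location network outputs the mean of a Gaussian with fixed variance, and the next fixation is sampled from that Gaussian. Sampling is not differentiable, so the location policy is trained with REINFORCE. The reward is 1 if the final classification is correct and 0 otherwise, and a baseline reduces variance. The policy standard deviation mentioned above is the fixed spread of this Gaussian.

The sketch below computes the sample, its log-probability and the value of the REINFORCE surrogate loss in pure Python. An autograd framework would do the actual differentiation with respect to the mean. This is a simplified sketch, not the repository's code:

  • It uses a single scalar baseline, where the paper's baseline can vary per time step.
  • It clamps the sampled location to [-1, 1], which is an assumption.
  • The `std=0.1` used in the usage note is an arbitrary illustrative value.

```python
import math
import random
from typing import List, Sequence, Tuple


def gaussian_log_prob(x: float, mean: float, std: float) -> float:
    """Log-density of N(mean, std^2) evaluated at x."""
    return -((x - mean) ** 2) / (2 * std**2) - math.log(std) - 0.5 * math.log(2 * math.pi)


def sample_location(
    mu: Sequence[float], std: float, rng: random.Random
) -> Tuple[List[float], float]:
    """Sample l_t ~ N(mu, std^2 I); return the clamped location and log pi(l_t)."""
    raw = [rng.gauss(m, std) for m in mu]
    # Coordinates are independent, so the joint log-prob is the sum over coordinates.
    log_pi = sum(gaussian_log_prob(x, m, std) for x, m in zip(raw, mu))
    # Keep the fixation inside the image; the log-prob above uses the unclamped sample.
    loc = [max(-1.0, min(1.0, x)) for x in raw]
    return loc, log_pi


def reinforce_loss(log_pis: Sequence[float], reward: float, baseline: float) -> float:
    """REINFORCE surrogate with a scalar baseline: -sum_t log pi(l_t) * (R - b)."""
    advantage = reward - baseline
    return -sum(lp * advantage for lp in log_pis)
```

A larger `std` spreads sampled fixations further from the location network's output. This means more exploration but noisier gradient estimates, which is why the value is worth tuning.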

Finally, here's an animation showing the glimpses extracted by the network on a random batch at epoch 23.

[Animation: glimpses extracted by the network on a random batch at epoch 23]

With the Adam optimizer, the accuracy reported in the paper can be reached in about 160 epochs.

Usage

The easiest way to start training your RAM variant is to edit the parameters in config.py and run the following command:

python main.py

To resume training, run:

python main.py --resume=True

Finally, to test the checkpoint of your model that achieved the best validation accuracy, run the following command:

python main.py --is_train=False
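The commands above pass booleans as strings (`--resume=True`, `--is_train=False`). With argparse, `type=bool` would not handle this correctly, because `bool("False")` is `True`. Flags like these are therefore usually parsed with a small string-to-bool converter.

The sketch below assumes config.py is built on argparse. The `str2bool` helper, the parser and its defaults are hypothetical and not copied from the repository.

```python
import argparse


def str2bool(value: str) -> bool:
    """Interpret 'True'/'False'-style strings; argparse's type=bool would treat 'False' as True."""
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes"):
        return True
    if lowered in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser covering only the two flags used in the commands above."""
    parser = argparse.ArgumentParser(description="RAM training/testing (illustrative sketch)")
    parser.add_argument("--is_train", type=str2bool, default=True)  # assumed default: train
    parser.add_argument("--resume", type=str2bool, default=False)   # assumed default: fresh run
    return parser
```

With this parser, `build_parser().parse_args(["--is_train=False"]).is_train` evaluates to `False`, as intended.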
