# recurrent-visual-attention

A PyTorch Implementation of "Recurrent Models of Visual Attention"
This is a PyTorch implementation of *Recurrent Models of Visual Attention* by Volodymyr Mnih, Nicolas Heess, Alex Graves and Koray Kavukcuoglu.
The *Recurrent Attention Model* (RAM) is a neural network that processes inputs sequentially, attending to different locations within the image one at a time, and incrementally combining information from these fixations to build up a dynamic internal representation of the image.
In this paper, the attention problem is modeled as the sequential decision process of a goal-directed agent interacting with a visual environment. The agent is built around a recurrent neural network: at each time step, it processes the sensor data, integrates information over time, and chooses how to act and how to deploy its sensor at the next time step.
- **glimpse sensor**: a retina that extracts a foveated glimpse `phi` around location `l` from an image `x`. It encodes the region around `l` at a high resolution but uses a progressively lower resolution for pixels further from `l`, resulting in a compressed representation of the original image `x`.
- **glimpse network**: a network that combines the "what" (`phi`) and the "where" (`l`) into a glimpse feature vector `g_t`.
- **core network**: an RNN that maintains an internal state which integrates information extracted from the history of past observations. It encodes the agent's knowledge of the environment through a state vector `h_t` that gets updated at every time step `t`.
- **location network**: uses the internal state `h_t` of the core network to produce the location coordinates `l_t` for the next time step.
- **action network**: after a fixed number of time steps, uses the internal state `h_t` of the core network to produce the final output classification `y`.
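To make the glimpse sensor concrete, here is a minimal sketch of foveated glimpse extraction: it crops progressively larger patches around `l` and resizes them all to the same size, so the center is sharp and the periphery is coarse. The function name, signature, and padding strategy are my own illustrative choices, not the repo's actual code.

```python
import torch
import torch.nn.functional as F

def extract_glimpse(x, l, size=8, num_patches=3, scale=2):
    """Sketch of a foveated glimpse sensor (illustrative, not the repo's code).

    Extracts `num_patches` square patches centered at location l (in [-1, 1]),
    each `scale` times larger than the previous one, and resizes them all to
    `size` x `size`, so resolution drops with distance from l.
    """
    B, C, H, W = x.shape
    patches = []
    s = size
    for _ in range(num_patches):
        # Pad so glimpses near the image border stay in-bounds.
        pad = s // 2 + 1
        xp = F.pad(x, (pad, pad, pad, pad))
        # Convert l from [-1, 1] to pixel coordinates in the padded image.
        cx = ((l[:, 0] + 1) / 2 * W).long() + pad
        cy = ((l[:, 1] + 1) / 2 * H).long() + pad
        patch = torch.stack([
            xp[b, :, cy[b] - s // 2: cy[b] - s // 2 + s,
                     cx[b] - s // 2: cx[b] - s // 2 + s]
            for b in range(B)
        ])
        # Downsample the larger patches back to the base size.
        patches.append(patch if s == size else F.interpolate(patch, size=(size, size)))
        s *= scale
    # Stack scales along the channel dimension: (B, C * num_patches, size, size).
    return torch.cat(patches, dim=1)
```

On a batch of `28x28` MNIST images, `extract_glimpse(x, l, size=8, num_patches=2)` returns an `(B, 2, 8, 8)` tensor: an 8x8 crop plus a 16x16 crop downsampled to 8x8.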
I decided to tackle the `28x28` MNIST task with the RAM model containing 6 glimpses, of size `8x8`, with a scale factor of `1`.
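The glimpse–core–location–action loop described above can be sketched end to end as follows. All module sizes, names, and the crude crop helper are hypothetical stand-ins for illustration (the real model uses the foveated sensor and separate "what"/"where" pathways), but the control flow — 6 glimpses, then classify — matches the setup described here.

```python
import torch
import torch.nn as nn

def crop8(x, l, s=8):
    """Crude 8x8 crop around l in [-1, 1], clamped to stay in-bounds (illustrative)."""
    B, C, H, W = x.shape
    cx = (((l[:, 0] + 1) / 2) * (W - s)).long().clamp(0, W - s)
    cy = (((l[:, 1] + 1) / 2) * (H - s)).long().clamp(0, H - s)
    return torch.stack([x[b, :, cy[b]:cy[b] + s, cx[b]:cx[b] + s] for b in range(B)])

class TinyRAM(nn.Module):
    """Minimal sketch of the RAM forward pass: T glimpses, then classify.

    Assumes single-channel 8x8 glimpses; all sizes are illustrative.
    """
    def __init__(self, glimpse_dim=64, hidden=128, num_classes=10, T=6):
        super().__init__()
        self.T = T
        self.glimpse_net = nn.Linear(8 * 8 + 2, glimpse_dim)  # "what" + "where"
        self.core = nn.GRUCell(glimpse_dim, hidden)           # core network
        self.locator = nn.Linear(hidden, 2)                   # location network
        self.classifier = nn.Linear(hidden, num_classes)      # action network

    def forward(self, x):
        B = x.size(0)
        h = x.new_zeros(B, self.core.hidden_size)
        l = x.new_zeros(B, 2)  # start fixating at the image center
        for _ in range(self.T):
            patch = crop8(x, l)
            # Glimpse network fuses the patch ("what") with the location ("where").
            g = torch.relu(self.glimpse_net(torch.cat([patch.flatten(1), l], dim=1)))
            h = self.core(g, h)               # integrate into the internal state
            l = torch.tanh(self.locator(h))   # choose the next fixation
        return self.classifier(h)             # final classification y
```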
Model | Validation Error (%) | Test Error (%)
---|---|---
6 8x8 | 1.1 | 1.21
I haven't done a random search over the policy standard deviation to tune it, so I expect the test error can be brought below 1%. I'll be updating the table above with results for `60x60` Translated MNIST, `60x60` Cluttered Translated MNIST and the new Fashion MNIST dataset when I get the time.
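For context on why that standard deviation matters: the location network is a stochastic Gaussian policy trained with REINFORCE, and the std controls how widely it explores. A minimal sketch, with hypothetical sizes and a fixed std of 0.17 as an assumed default:

```python
import torch
import torch.nn as nn

class LocationNetwork(nn.Module):
    """Sketch of a Gaussian location policy (illustrative sizes and std).

    The mean comes from the hidden state; the standard deviation is a fixed
    hyperparameter — the quantity one would tune via random search.
    """
    def __init__(self, hidden_size=256, std=0.17):
        super().__init__()
        self.fc = nn.Linear(hidden_size, 2)
        self.std = std

    def forward(self, h_t):
        # Detach so the location network learns only via REINFORCE,
        # not through the classification loss.
        mu = torch.tanh(self.fc(h_t.detach()))  # mean fixation in [-1, 1]
        dist = torch.distributions.Normal(mu, self.std)
        l_t = dist.sample()
        log_pi = dist.log_prob(l_t).sum(dim=1)  # log-probability of the sample
        return log_pi, l_t.clamp(-1.0, 1.0)
```

The REINFORCE objective then weights these log-probabilities by the advantage, e.g. `loss = -(log_pi * (reward - baseline)).mean()`, so glimpse trajectories that led to a correct classification are reinforced.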
Finally, here's an animation showing the glimpses extracted by the network on a random batch at epoch 23.
With the Adam optimizer, paper accuracy can be reached in ~160 epochs.
The easiest way to start training your RAM variant is to edit the parameters in `config.py` and run the following command:
python main.py
To resume training, run:
python main.py --resume=True
Finally, to test a checkpoint of your model that has achieved the best validation accuracy, run the following command:
python main.py --is_train=False