uber-research/deconstructing-lottery-tickets
Hattie Zhou, Janice Lan, Rosanne Liu, Jason Yosinski
This codebase implements the experiments in Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask. The paper performs ablation studies to shed light on the Lottery Ticket (LT) phenomenon observed by Frankle & Carbin in The Lottery Ticket Hypothesis: Finding Small, Trainable Neural Networks.
@inproceedings{zhou_2019_dlt,
  title={Deconstructing Lottery Tickets: Zeros, Signs, and the Supermask},
  author={Zhou, Hattie and Lan, Janice and Liu, Rosanne and Yosinski, Jason},
  booktitle={Advances in Neural Information Processing Systems},
  year={2019}
}
For more on this project, see the Uber Eng Blog post.
- data/download_mnist.py, data/download_cifar10.py: download the MNIST/CIFAR10 data, split it into train, val, and test sets, and save them in the data folder as h5 files
- get_weight_init.py: computes various mask criteria
- masked_layers.py: defines new layer classes with masking options
- masked_networks.py: defines new layers and networks used in training Supermasks
- network_builders.py: defines the four network architectures evaluated in the paper (FC, Conv2, Conv4, Conv6)
- train.py: trains original unmasked networks
- train_lottery.py: reads in initial and final weights from a previously trained model, calculates the mask, and trains a lottery-style network
- train_supermask.py: trains a supermask directly using Bernoulli sampling
- get_init_loss_train_lottery.py: derives masks and calculates the initial accuracy of the masked network for various pruning percentages and mask criteria. Note that this uses a one-shot approach rather than an iterative approach.
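To make the idea of a mask criterion concrete, here is a minimal sketch of a one-shot "large final" mask: keep the weights with the largest final magnitude and rewind them to their initial values. The function names `large_final_mask` and `apply_mask` are hypothetical and are not taken from this codebase, and the flat-list representation is a simplification for illustration.

```python
def large_final_mask(final_weights, prune_fraction):
    """Return a 0/1 mask keeping the largest-magnitude final weights.

    final_weights: flat list of trained weight values for one layer.
    prune_fraction: fraction of weights to remove, in [0, 1).
    """
    n = len(final_weights)
    n_keep = n - int(round(prune_fraction * n))
    # Rank positions by |w_final|, largest first.
    ranked = sorted(range(n), key=lambda i: abs(final_weights[i]), reverse=True)
    keep = set(ranked[:n_keep])
    return [1 if i in keep else 0 for i in range(n)]


def apply_mask(init_weights, mask):
    """Rewind kept weights to their initial values and zero out the rest."""
    return [w * m for w, m in zip(init_weights, mask)]


# Toy layer with five weights (illustrative values only).
w_init = [0.10, -0.40, 0.05, 0.30, -0.20]
w_final = [0.90, -0.10, 0.02, -0.70, 0.25]

mask = large_final_mask(w_final, prune_fraction=0.4)
masked_init = apply_mask(w_init, mask)
print(mask)
print(masked_init)
```

Other criteria can be expressed by changing the ranking key, for example ranking by the initial magnitude or by the change in magnitude during training.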
This codebase uses the GitResultsManager package to keep track of experiments. See: https://github.com/yosinski/GitResultsManager
The following commands provide examples for running experiments in Deconstructing Lottery Tickets.
- Train an FC network (300-100-10) on MNIST:
./print_train_command.sh iter fc test 0 t
- Perform iterative LT training for an FC network on MNIST using the large_final mask criterion:
./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask none t
Randomly reinitialize weights prior to each round of iterative retraining:
./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask random_reinit t
Randomly reshuffle the initial values of remaining weights prior to each round of iterative retraining:
./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask random_reshuffle t
Convert the initial values of weights to a signed constant, then randomly reshuffle the initial values of the remaining weights prior to each round of iterative retraining:
./print_train_lottery_iterative_command.sh fc test 0 large_final -1 mask rand_signed_constant t
For versions that maintain the same sign, see signed_reinit, signed_reshuffle, and signed_constant.
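The treatments above differ in what value each kept weight starts from before retraining. The sketch below illustrates three of them on a flat list of weights. It is one possible reading of the option names, not the code in this repository: the helper names are hypothetical, `alpha` is a free parameter here, and the sign-preserving reshuffle simply reattaches each position's original sign to a shuffled magnitude.

```python
import math
import random


def random_reshuffle(init_weights, mask, rng):
    """Permute the initial values of the kept weights among the kept positions."""
    kept = [i for i, m in enumerate(mask) if m]
    values = [init_weights[i] for i in kept]
    rng.shuffle(values)
    out = [0.0] * len(init_weights)
    for i, v in zip(kept, values):
        out[i] = v
    return out


def signed_reshuffle(init_weights, mask, rng):
    """Reshuffle magnitudes, but keep each position's original sign."""
    shuffled = random_reshuffle(init_weights, mask, rng)
    return [math.copysign(abs(s), w) if m else 0.0
            for s, w, m in zip(shuffled, init_weights, mask)]


def signed_constant(init_weights, mask, alpha):
    """Replace each kept weight by +alpha or -alpha, matching its original sign."""
    return [math.copysign(alpha, w) if m else 0.0
            for w, m in zip(init_weights, mask)]


# Toy example (illustrative values only).
rng = random.Random(0)
w_init = [0.5, -0.2, 0.1, -0.7]
mask = [1, 0, 1, 1]
print(random_reshuffle(w_init, mask, rng))
print(signed_reshuffle(w_init, mask, rng))
print(signed_constant(w_init, mask, alpha=0.3))
```

Comparing the signed and unsigned variants isolates how much of a lottery ticket's performance depends on the sign of the initial weights rather than on their exact values.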
Freeze pruned weights at initial values:
./print_train_lottery_iterative_command.sh fc test 0 large_final -1 freeze_init none t
Freeze pruned weights that increased in magnitude at their initial values, and set the remaining pruned weights (those that decreased in magnitude) to zero:
./print_train_lottery_iterative_command.sh fc test 0 large_final -1 freeze_init_zero_mask none t
Initialize weights that decreased in magnitude to 0, and freeze pruned weights at their initial values:
./print_train_lottery_iterative_command.sh fc test 0 large_final -1 freeze_init_zero_all none t
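The freezing options change what happens to pruned weights: instead of being set to zero, they can be held fixed at a chosen value while the kept weights are retrained. The sketch below shows how the frozen value might be chosen for the first two options. The function name `frozen_values` and the `zero_if_shrunk` parameter are hypothetical and do not come from this codebase.

```python
def frozen_values(init_weights, final_weights, mask, zero_if_shrunk=False):
    """Values at which pruned weights are held fixed during retraining.

    Kept positions (mask == 1) yield None because they remain trainable.
    With zero_if_shrunk=True, a pruned weight whose magnitude decreased
    during the original training run is fixed at 0 instead of its init value.
    """
    out = []
    for w0, wf, m in zip(init_weights, final_weights, mask):
        if m:
            out.append(None)   # kept weight: trainable
        elif zero_if_shrunk and abs(wf) < abs(w0):
            out.append(0.0)    # pruned weight that moved toward zero
        else:
            out.append(w0)     # pruned weight frozen at its initial value
    return out


# Toy example (illustrative values only).
w_init = [0.10, -0.40, 0.05]
w_final = [0.90, -0.10, 0.08]
mask = [1, 0, 0]

freeze_all = frozen_values(w_init, w_final, mask)
freeze_or_zero = frozen_values(w_init, w_final, mask, zero_if_shrunk=True)
print(freeze_all)
print(freeze_or_zero)
```

During retraining, gradient updates would be applied only at the positions marked None, so the frozen entries contribute to the forward pass without ever changing.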
Evaluate the initial test accuracy of all alternative mask criteria:
python get_init_loss_train_lottery.py --output_dir ./results/iter_lot_fc_orig/test_seed_0/ --train_h5 ./data/mnist_train.h5 --test_h5 ./data/mnist_test.h5 --arch fc_lot --seed 0 --opt adam --lr 0.0012 --exp none --layer_cutoff 4,6 --prune_base 0.8,0.9 --prune_power 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24
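The meaning of --prune_base and --prune_power is not documented in this README. Assuming, as a guess, that the fraction of weights kept is prune_base raised to prune_power (so each power corresponds to one more round of pruning at a fixed rate), the sweep in the command above would cover the sparsity levels computed below. The function name `remaining_fraction` is hypothetical.

```python
def remaining_fraction(prune_base, prune_power):
    """Fraction of weights kept, ASSUMING kept = prune_base ** prune_power."""
    return prune_base ** prune_power


# The two bases and a few of the powers from the example command.
for power in (0, 1, 2, 10, 24):
    kept_a = remaining_fraction(0.8, power)
    kept_b = remaining_fraction(0.9, power)
    print(f"power={power:2d}  base 0.8 keeps {kept_a:.4f}  base 0.9 keeps {kept_b:.4f}")
```

If the scripts interpret these flags differently, the arithmetic changes accordingly; check the argument parser in get_init_loss_train_lottery.py before relying on these numbers.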
Train a Supermask directly:
python train_supermask.py --output_dir ./results/iter_lot_fc_orig/learned_supermasks/run1/ --train_h5 ./data/mnist_train.h5 --test_h5 ./data/mnist_test.h5 --arch fc_mask --opt sgd --lr 100 --num_epochs 2000 --print_every 220 --eval_every 220 --log_every 220 --save_weights --save_every 22000
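As a rough illustration of training a mask with Bernoulli sampling, the sketch below shows only the forward sampling step: each weight has a real-valued mask parameter, a sigmoid turns it into a keep probability, and a 0/1 mask is sampled and applied to the fixed initial weights. The function names are hypothetical, and the gradient update of the mask parameters is omitted.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def sample_supermask(mask_params, rng):
    """Sample a 0/1 mask with keep probability sigmoid(mask_param)."""
    return [1 if rng.random() < sigmoid(p) else 0 for p in mask_params]


def masked_weights(init_weights, mask):
    """Effective weights: untrained initial weights times the sampled mask."""
    return [w * b for w, b in zip(init_weights, mask)]


# Toy example (illustrative values only): extreme parameters give a
# near-deterministic mask, a parameter of 0 gives a 50% keep probability.
rng = random.Random(0)
w_init = [0.3, -0.5, 0.2]
mask_params = [50.0, -50.0, 0.0]
mask = sample_supermask(mask_params, rng)
print(mask, masked_weights(w_init, mask))
```

In this setup only the mask parameters would be learned; the underlying weights stay at their random initial values throughout.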