[KDD'22] Learned Token Pruning for Transformers
Check our paper for more details.
We follow the same installation procedure as the original Huggingface transformer repo.

```
pip install sklearn scipy datasets torch
pip install -e .  # in the top directory
```
LTP is implemented on top of Huggingface transformer's I-BERT implementation. Therefore, we first need to generate a checkpoint file of ibert finetuned on the target downstream task. While you can do this on the original Huggingface repository, we also support our base branch `ltp/base`, where you can run the following code to finetune `ibert` on the GLUE tasks.

```
git checkout ltp/base
cd examples/text-classification
python run_glue.py --model_name_or_path kssteven/ibert-roberta-base --output_dir {CKPT} --task {TASK} --do_train --do_eval {--some_more_arguments}
```
- `{TASK}`: RTE, MRPC, STSB, SST2, QNLI, QQP, MNLI
- Please refer to the Huggingface tutorial and the official documentation for more details on arguments and hyperparameters.
- Note that by default ibert behaves the same as roberta (see this tutorial), hence the resulting model will be the same as `roberta-base` finetuned on the target GLUE task.

The final model will be checkpointed in `{CKPT}`.
- Remove `{CKPT}/trainer_state.json`.
- In the configuration file `{CKPT}/config.json`, change (1) `"architectures"` to `["LTPForSequenceClassification"]` and (2) `"model_type"` to `"ltp"` (see the excerpt below).
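For clarity, after this edit the corresponding entries in `{CKPT}/config.json` would look roughly like the following excerpt (illustrative only; every other field stays exactly as produced by finetuning):

```
"architectures": ["LTPForSequenceClassification"],
"model_type": "ltp",
```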
Add the following lines in the configuration file `{CKPT}/config.json`.

```
"prune_mode": "absolute_threshold",
"final_token_threshold": 0.01,
```
`final_token_threshold` determines the token threshold of the last layer, and the thresholds of the remaining layers will be linearly scaled. For instance, the thresholds for the 3rd, 6th, and 9th layers will be 0.0025, 0.005, and 0.0075, respectively, when setting `final_token_threshold`, i.e., the threshold for the last (12th) layer, to 0.01. This number is a hyperparameter, and we found that 0.01 works well in many cases.
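To make the linear scaling concrete, here is a minimal sketch (our own illustration, not code from the repository; it assumes 1-indexed layers, which matches the 3rd/6th/9th-layer example above):

```python
def layer_thresholds(final_token_threshold, num_layers=12):
    """Per-layer token thresholds, scaled linearly so that the last layer
    uses final_token_threshold (illustrative sketch, 1-indexed layers)."""
    return [final_token_threshold * layer / num_layers
            for layer in range(1, num_layers + 1)]

# With final_token_threshold = 0.01 this yields 0.0025, 0.005, and 0.0075
# for layers 3, 6, and 9, matching the example above.
print(layer_thresholds(0.01))
```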
The learnable mode consists of 2 stages: soft threshold and hard threshold. Please refer to our paper for more details.

We first train the model using the soft threshold mode. This trains the thresholds as well as the model parameters to search for the best threshold configuration.

Run the following command:
```
python run.py --arch ltp-base --task {TASK} --restore {CKPT} --lr 2e-5 --temperature {T} \
  --lambda 0.1 --weight_decay 0 --bs 64 --masking_mode soft --epoch {epoch} --save_step 100 --no_load
```
- `{TASK}`: RTE, MRPC, STSB, SST2, QNLI, QQP, MNLI
- You can assign a different learning rate for `lr`, but 2e-5 worked fine.
- We set `{epoch}` to be 10 for smaller datasets (e.g., RTE, MRPC) and 1 for larger datasets (e.g., SST2, QNLI, MNLI).
- The `--no_load` flag will not load the best model at the end of training (i.e., the final checkpoint will be the one at the end of training).
- `lambda` is an important hyperparameter that controls the pruning level: the higher the value, the more tokens we prune. 0.01 ~ 0.2 worked well in many cases, but we recommend empirically searching for the best value for your task (see the sketch after this list for how `lambda` and `temperature` enter the soft-threshold stage).
- `temperature` is another hyperparameter, and 1e-3 ~ 1e-5 worked well. In the paper, we searched over {1e−4, 2e−4, 5e−4, 1e−3, 2e−3}.
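For intuition, the following is a rough conceptual sketch of the soft-threshold stage based on our reading of the paper (the function name and shapes are illustrative, not the repository's actual API): each token's importance score is compared against the learnable layer threshold through a sigmoid, `temperature` sets how sharp that comparison is, and `lambda` weights a penalty on the fraction of tokens kept.

```python
import torch

def soft_token_mask(importance_scores, threshold, temperature):
    """Soft keep-probability per token (conceptual sketch): scores well above
    the learnable threshold give a mask near 1 (keep), scores well below give
    a mask near 0 (prune); temperature controls how sharply the mask switches."""
    return torch.sigmoid((importance_scores - threshold) / temperature)

# Conceptually, the soft-stage objective is then roughly
#   total_loss = task_loss + lambda_ * soft_mask.mean()
# so increasing lambda pushes the thresholds up and prunes more tokens.
```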
The final model will be checkpointed in `{CKPT_soft} = checkpoints/base/{TASK}/absolute_threshold/rate_{final_token_threshold}/temperature_{T}/lambda_{lambda}/lr_{lr}`. Remove `trainer_state.json` from the checkpoint directory `{CKPT_soft}`.
Once we learn the thresholds, we fix those values, switch back to the hard threshold mode, and finetune the model parameters only.

Run the following command:

```
python run.py --arch ltp-base --task {TASK} --restore {CKPT_soft} --lr {LR} --bs 64 --masking_mode hard --epoch 5
```
- We used `{LR}` of {0.5, 1, 2}e-5 in the paper.
- You can additionally set `--save_step 500` for more frequent evaluation/logging. The default setting evaluates once per epoch.
The final model will be checkpointed in `{CKPT_soft}/hard/lr_{LR}`.
We additionally provide code to reproduce the baseline methods used in our paper (i.e., top-k and manual threshold).
Add the following lines in `{CKPT}/config.json`.

```
"prune_mode": "topk",
"token_keep_rate": 0.2,
```
The token keep rates of the first three layers and the last layer are 1 and `token_keep_rate`, respectively. The keep rates of the remaining layers are scaled linearly. The smaller `token_keep_rate` is, the more aggressively we prune tokens. You can also assign a negative number to `token_keep_rate`; in that case, the keep rate of each layer will be assigned as `max(0, keep_rate)`. The sketch below illustrates the resulting per-layer schedule.
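A minimal sketch of this schedule, under our reading of the description above (1-indexed layers; the helper name is illustrative and not part of the repository):

```python
def layer_keep_rates(token_keep_rate, num_layers=12):
    """Keep rate 1.0 for the first three layers, token_keep_rate for the last
    layer, linear interpolation in between, and per-layer clipping at 0
    (illustrative sketch)."""
    rates = []
    for layer in range(1, num_layers + 1):
        if layer <= 3:
            rate = 1.0
        else:
            # Interpolate linearly from 1.0 at layer 3 down to token_keep_rate at the last layer.
            rate = 1.0 + (token_keep_rate - 1.0) * (layer - 3) / (num_layers - 3)
        rates.append(max(0.0, rate))
    return rates

print(layer_keep_rates(0.2))  # roughly [1.0, 1.0, 1.0, 0.91, ..., 0.29, 0.2]
```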
Run the following command:

```
python run.py --arch ltp-base --task {TASK} --restore {CKPT} --lr {LR} --bs 64 --masking_mode hard --epoch 5
```
- We used `{LR}` of {0.5, 1, 2}e-5 in the paper.
- You can additionally set `--save_step 500` for more frequent evaluation/logging. The default setting evaluates once per epoch.
The final model will be checkpointed in `{CKPT}/topk/lr_{LR}`.
Add the following lines in `{CKPT}/config.json`.

```
"prune_mode": "absolute_threshold",
"final_token_threshold": 0.01,
```
Run the following command:

```
python run.py --arch ltp-base --task {TASK} --restore {CKPT} --lr {LR} --bs 64 --masking_mode hard --epoch 5 --save_step 500
```
- We used `{LR}` of {0.5, 1, 2}e-5 in the paper.
- You can additionally set `--save_step 500` for more frequent evaluation/logging. The default setting evaluates once per epoch.
- Note that the only difference from the learned token pruning mode is that we run the hard threshold mode from the beginning.
The final model will be checkpointed in `{CKPT}/hard/lr_{LR}`.
THIS SOFTWARE AND/OR DATA WAS DEPOSITED IN THE BAIR OPEN RESEARCH COMMONS REPOSITORY ON 02/07/23.