lucidrains/memory-efficient-attention-pytorch
Implementation of a memory efficient multi-head attention as proposed in the paper, Self-attention Does Not Need O(n²) Memory. In addition, the module takes care of masking, causal masking, as well as cross attention.
This repository also contains a naive non-CUDA implementation of the improvements made by Tri Dao with his Flash Attention 2 paper, for educational purposes. It is a game changer for attention and for building long-context transformers.
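To give a sense of how the memory savings are achieved, below is a minimal sketch of the bucketed computation described in the paper. It is not this repository's exact code; the function name chunked_attention is illustrative, though q_bucket_size and k_bucket_size mirror the module's arguments, and masking is omitted for brevity. Keys and values are consumed one bucket at a time while a running max and softmax denominator are carried per query, so the full n² attention matrix is never materialized.

```python
# Illustrative sketch of bucketed (memory efficient) attention, no masking.
import torch
from torch import einsum

def chunked_attention(q, k, v, q_bucket_size = 1024, k_bucket_size = 2048, eps = 1e-10):
    # q, k, v: (batch, heads, seq, dim_head)
    scale = q.shape[-1] ** -0.5
    q = q * scale

    outs = []
    for q_chunk in q.split(q_bucket_size, dim = -2):
        # running accumulators for this bucket of queries
        num = torch.zeros_like(q_chunk)                                      # weighted sum of values
        den = q_chunk.new_zeros((*q_chunk.shape[:-1], 1))                    # softmax denominator
        run_max = q_chunk.new_full((*q_chunk.shape[:-1], 1), float('-inf'))  # running max of logits

        for k_chunk, v_chunk in zip(k.split(k_bucket_size, dim = -2), v.split(k_bucket_size, dim = -2)):
            # attention logits for this (query bucket, key bucket) tile only
            sim = einsum('... i d, ... j d -> ... i j', q_chunk, k_chunk)

            chunk_max = sim.amax(dim = -1, keepdim = True)
            new_max = torch.maximum(run_max, chunk_max)

            # renormalize previous accumulators to the new running max
            exp_sim = (sim - new_max).exp()
            correction = (run_max - new_max).exp()

            num = num * correction + einsum('... i j, ... j d -> ... i d', exp_sim, v_chunk)
            den = den * correction + exp_sim.sum(dim = -1, keepdim = True)
            run_max = new_max

        outs.append(num / den.clamp(min = eps))

    return torch.cat(outs, dim = -2)
```

Causal masking adds a little bookkeeping per tile (skipping or masking buckets of keys that lie in the future) and is left out of the sketch above.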
Update: from now on, you should just be using the `F.scaled_dot_product_attention` function in PyTorch 2.0 for built-in Flash Attention v1 support - or use Flash Attention v2 at the official repository.
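A minimal sketch of that built-in path, assuming PyTorch >= 2.0 and a CUDA device (the shapes below are only for illustration):

```python
import torch
import torch.nn.functional as F

# q, k, v: (batch, heads, seq, dim_head), half precision so the fused kernels can be used
q = torch.randn(1, 8, 65536, 64, device = 'cuda', dtype = torch.float16)
k = torch.randn(1, 8, 65536, 64, device = 'cuda', dtype = torch.float16)
v = torch.randn(1, 8, 65536, 64, device = 'cuda', dtype = torch.float16)

out = F.scaled_dot_product_attention(q, k, v, is_causal = True)  # (1, 8, 65536, 64)
```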
Install

```bash
$ pip install memory-efficient-attention-pytorch
```
For autoregressive language model
```python
import torch
from memory_efficient_attention_pytorch import Attention

attn = Attention(
    dim = 512,
    dim_head = 64,            # dimension per head
    heads = 8,                # number of attention heads
    causal = True,            # autoregressive or not
    memory_efficient = True,  # whether to use memory efficient attention (can be turned off to test against normal attention)
    q_bucket_size = 1024,     # bucket size along queries dimension
    k_bucket_size = 2048      # bucket size along key / values dimension
).cuda()

x = torch.randn(1, 65536, 512).cuda()
out = attn(x)  # (1, 65536, 512)
```
Cross attention
```python
import torch
from memory_efficient_attention_pytorch import Attention

cross_attn = Attention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    memory_efficient = True,
    q_bucket_size = 1024,
    k_bucket_size = 2048
).cuda()

x = torch.randn(1, 65536, 512).cuda()
context = torch.randn(1, 65536, 512).cuda()
mask = torch.ones(1, 65536).bool().cuda()

out = cross_attn(x, context = context, mask = mask)  # (1, 65536, 512)
```
Citations

```bibtex
@misc{rabe2021selfattention,
    title   = {Self-attention Does Not Need $O(n^2)$ Memory},
    author  = {Markus N. Rabe and Charles Staats},
    year    = {2021},
    eprint  = {2112.05682},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```

```bibtex
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

```bibtex
@article{Dao2022FlashAttentionFA,
    title   = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
    author  = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R{\'e}},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.14135}
}
```

```bibtex
@article{dao2023flashattention2,
    title   = {Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
    author  = {Dao, Tri},
    year    = {2023}
}
```