lucidrains/memory-efficient-attention-pytorch


Implementation of memory-efficient multi-head attention, as proposed in the paper Self-attention Does Not Need O(n²) Memory. The module also takes care of masking, causal masking, and cross attention.

This repository also contains a naive non-CUDA implementation of the improvements made by Tri Dao in his Flash Attention 2 paper, for educational purposes. It is a game changer for attention and for building long-context transformers.
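
As a rough illustration of the underlying idea only (a sketch for intuition, not the code used in this repository or in the official CUDA kernels), attention can be computed one bucket of queries and keys at a time, carrying a running max and running softmax denominator so that the full n × n attention matrix is never materialized:

```python
import torch

def chunked_attention(q, k, v, q_bucket_size = 1024, k_bucket_size = 2048):
    # q, k, v have shape (batch, heads, seq_len, dim_head)
    scale = q.shape[-1] ** -0.5
    out = []

    for q_chunk in q.split(q_bucket_size, dim = -2):
        # running accumulators for this bucket of queries
        acc = torch.zeros_like(q_chunk)                                      # running weighted sum of values
        row_sum = q_chunk.new_zeros(*q_chunk.shape[:-1], 1)                  # running softmax denominator
        row_max = q_chunk.new_full((*q_chunk.shape[:-1], 1), float('-inf')) # running max, for numerical stability

        for k_chunk, v_chunk in zip(k.split(k_bucket_size, dim = -2), v.split(k_bucket_size, dim = -2)):
            scores = q_chunk @ k_chunk.transpose(-2, -1) * scale

            block_max = scores.amax(dim = -1, keepdim = True)
            new_max = torch.maximum(row_max, block_max)

            # rescale previously accumulated results to the new running max, then fold in this block
            correction = (row_max - new_max).exp()
            exp_scores = (scores - new_max).exp()

            acc = acc * correction + exp_scores @ v_chunk
            row_sum = row_sum * correction + exp_scores.sum(dim = -1, keepdim = True)
            row_max = new_max

        out.append(acc / row_sum)

    return torch.cat(out, dim = -2)
```

The result matches plain softmax attention, while peak memory scales with the bucket sizes rather than with the full sequence length squared. Bucket sizes above are illustrative, and masking / causal handling is omitted for brevity.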

Update: from now on, you should just use the F.scaled_dot_product_attention function in PyTorch 2.0 for built-in Flash Attention v1 support, or use Flash Attention v2 at the official repository.
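
For reference, a minimal example of calling the built-in function (shapes, device, and dtype are illustrative; `is_causal = True` applies causal masking):

```python
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, dim_head)
q = torch.randn(1, 8, 1024, 64, device = 'cuda', dtype = torch.float16)
k = torch.randn(1, 8, 1024, 64, device = 'cuda', dtype = torch.float16)
v = torch.randn(1, 8, 1024, 64, device = 'cuda', dtype = torch.float16)

# PyTorch dispatches to a fused flash / memory-efficient kernel when one is available
out = F.scaled_dot_product_attention(q, k, v, is_causal = True) # (1, 8, 1024, 64)
```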

Install

```bash
$ pip install memory-efficient-attention-pytorch
```

Usage

For autoregressive language model

```python
import torch
from memory_efficient_attention_pytorch import Attention

attn = Attention(
    dim = 512,
    dim_head = 64,                # dimension per head
    heads = 8,                    # number of attention heads
    causal = True,                # autoregressive or not
    memory_efficient = True,      # whether to use memory efficient attention (can be turned off to test against normal attention)
    q_bucket_size = 1024,         # bucket size along queries dimension
    k_bucket_size = 2048          # bucket size along key / values dimension
).cuda()

x = torch.randn(1, 65536, 512).cuda()
out = attn(x) # (1, 65536, 512)
```

Cross attention

```python
import torch
from memory_efficient_attention_pytorch import Attention

cross_attn = Attention(
    dim = 512,
    dim_head = 64,
    heads = 8,
    memory_efficient = True,
    q_bucket_size = 1024,
    k_bucket_size = 2048
).cuda()

x = torch.randn(1, 65536, 512).cuda()
context = torch.randn(1, 65536, 512).cuda()
mask = torch.ones(1, 65536).bool().cuda()

out = cross_attn(x, context = context, mask = mask) # (1, 65536, 512)
```

Citations

```bibtex
@misc{rabe2021selfattention,
    title   = {Self-attention Does Not Need $O(n^2)$ Memory},
    author  = {Markus N. Rabe and Charles Staats},
    year    = {2021},
    eprint  = {2112.05682},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```

```bibtex
@misc{liu2021swin,
    title   = {Swin Transformer V2: Scaling Up Capacity and Resolution},
    author  = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
    year    = {2021},
    eprint  = {2111.09883},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
```

```bibtex
@article{Dao2022FlashAttentionFA,
    title   = {FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness},
    author  = {Tri Dao and Daniel Y. Fu and Stefano Ermon and Atri Rudra and Christopher R{\'e}},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2205.14135}
}
```
```bibtex
@article{dao2023flashattention2,
    title   = {Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning},
    author  = {Dao, Tri},
    year    = {2023}
}
```
