🚀 Efficient implementations of state-of-the-art linear attention models in Torch and Triton


This repo aims to provide a collection of efficient Triton-based implementations of state-of-the-art linear attention models. Any pull requests are welcome!


News

  • $\texttt{[2025-06]}$: 🐍 Add Comba implementation to fla (paper).
  • $\texttt{[2025-05]}$: 🎉 Add Rodimus* implementation to fla (paper).
  • $\texttt{[2025-04]}$: 🎉 Add DeltaProduct implementation to fla (paper).
  • $\texttt{[2025-04]}$: 🎉 Add FoX implementation to fla (paper).
  • $\texttt{[2025-03]}$: The default initializer_range, which we had changed to the magic 🐳 0.006, was rolled back to the default value of 0.02. For actual training, we recommend trying both.
  • $\texttt{[2025-02]}$: 🐳 Add NSA implementations to fla. See the kernels here.
  • $\texttt{[2025-01]}$: 🔥 We are migrating to a torchtitan-based training framework. Check out the flame repo for more details.
  • $\texttt{[2025-01]}$: 🎉 Add RWKV7 implementations (both kernels and models) to fla.
  • $\texttt{[2024-12]}$: Integrated flash-bidirectional-attention into fla-org (repo).
  • $\texttt{[2024-12]}$: 🎉 Add Gated DeltaNet implementation to fla (paper).
  • $\texttt{[2024-12]}$: 🚀 fla now officially supports kernels with variable-length inputs.
  • $\texttt{[2024-11]}$: The inputs are now switched from head-first to seq-first format.
  • $\texttt{[2024-11]}$: 💥 fla now provides a flexible way of training hybrid models.
  • $\texttt{[2024-10]}$: 🔥 Announcing flame, a minimal and scalable framework for training fla models. Check out the details here.
  • $\texttt{[2024-09]}$: fla now includes a fused linear and cross-entropy layer, significantly reducing memory usage during training.
  • $\texttt{[2024-09]}$: 🎉 Add GSA implementation to fla (paper).
  • $\texttt{[2024-05]}$: 🎉 Add DeltaNet implementation to fla (paper).
  • $\texttt{[2024-05]}$: 💥 fla v0.1: a variety of subquadratic kernels/layers/models integrated (RetNet/GLA/Mamba/HGRN/HGRN2/RWKV6, etc.; see Models).
  • $\texttt{[2023-12]}$: 💥 Launched fla, offering a collection of implementations of state-of-the-art linear attention models.

Models

Roughly sorted by the timeline of support in fla. The recommended training mode is chunk when available (a configuration sketch follows the table).

| Year | Venue | Model | Paper | Code |
|------|-------|-------|-------|------|
| 2023 | | RetNet | Retentive network: a successor to transformer for large language models | official, fla |
| 2024 | ICML | GLA | Gated Linear Attention Transformers with Hardware-Efficient Training | official, fla |
| 2024 | ICML | Based | Simple linear attention language models balance the recall-throughput tradeoff | official, fla |
| 2024 | ACL | Rebased | Linear Transformers with Learnable Kernel Functions are Better In-Context Models | official, fla |
| 2024 | NeurIPS | DeltaNet | Parallelizing Linear Transformers with Delta Rule over Sequence Length | official, fla |
| 2022 | ACL | ABC | ABC: Attention with Bounded-memory Control | fla |
| 2023 | NeurIPS | HGRN | Hierarchically Gated Recurrent Neural Network for Sequence Modeling | official, fla |
| 2024 | COLM | HGRN2 | HGRN2: Gated Linear RNNs with State Expansion | official, fla |
| 2024 | COLM | RWKV6 | Eagle and Finch: RWKV with Matrix-Valued States and Dynamic Recurrence | official, fla |
| 2024 | | LightNet | You Only Scan Once: Efficient Multi-dimension Sequential Modeling with LightNet | official, fla |
| 2025 | ICLR | Samba | Samba: Simple Hybrid State Space Models for Efficient Unlimited Context Language Modeling | official, fla |
| 2024 | ICML | Mamba2 | Transformers are SSMs: Generalized Models and Efficient Algorithms Through Structured State Space Duality | official, fla |
| 2024 | NeurIPS | GSA | Gated Slot Attention for Efficient Linear-Time Sequence Modeling | official, fla |
| 2025 | ICLR | Gated DeltaNet | Gated Delta Networks: Improving Mamba2 with Delta Rule | official, fla |
| 2025 | | RWKV7 | RWKV-7 "Goose" with Expressive Dynamic State Evolution | official, fla |
| 2025 | | NSA | Native Sparse Attention: Hardware-Aligned and Natively Trainable Sparse Attention | fla |
| 2025 | | FoX | Forgetting Transformer: Softmax Attention with a Forget Gate | official, fla |
| 2025 | | DeltaProduct | DeltaProduct: Improving State-Tracking in Linear RNNs via Householder Products | fla |
| 2025 | ICLR | Rodimus* | Rodimus*: Breaking the Accuracy-Efficiency Trade-Off with Efficient Attentions | official, fla |
| 2025 | | MesaNet | MesaNet: Sequence Modeling by Locally Optimal Test-Time Training | fla |
| 2025 | | Comba | Comba: Improving Bilinear RNNs with Closed-loop Control | official, fla |
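The training mode is selected through the model config. The sketch below only relies on the attn_mode field visible in the GLAConfig printout in the Usage section; the set of valid values besides 'chunk' is model-dependent, so treat anything else as an assumption.

import torch
from fla.models import GLAConfig
from transformers import AutoModelForCausalLM

# Select the recommended chunkwise training mode via the config's `attn_mode` field.
config = GLAConfig(attn_mode='chunk')
model = AutoModelForCausalLM.from_config(config)
print(model.config.attn_mode)  # -> 'chunk'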

Installation


Make sure the package requirements (in particular, recent versions of torch and triton) are satisfied. You can install fla with pip:

pip install flash-linear-attention

As fla is under active development, you can alternatively install the package from source to get the latest features and updates:

# uninstall `fla` first to ensure a successful upgrade
pip uninstall flash-linear-attention && pip install -U git+https://github.com/fla-org/flash-linear-attention

or manage fla with submodules:

git submodule add https://github.com/fla-org/flash-linear-attention.git 3rdparty/flash-linear-attention
ln -s 3rdparty/flash-linear-attention/fla fla

If you have installed triton-nightly and a torch pre-release version, please use the following commands:

pip install einops ninja datasets transformers numpy
pip uninstall flash-linear-attention && pip install -U --no-use-pep517 git+https://github.com/fla-org/flash-linear-attention --no-deps

ARM (aarch64) Support for Triton

You need to choose a specific Triton version to install; see the FAQs.
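To verify an installation, a quick import check is usually enough. This is only a minimal sketch that reuses the fla.layers entry point from the Usage section below; running a forward pass additionally requires a Triton-capable GPU.

import torch
import fla
from fla.layers import MultiScaleRetention  # same entry point as in the Usage section

# Constructing a layer only builds plain PyTorch modules; no GPU is needed yet.
layer = MultiScaleRetention(hidden_size=1024, num_heads=4)
print(type(layer).__name__, 'imported from', fla.__name__)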

Usage

Token Mixing

We provide "token mixing" linear attention layers in fla.layers. You can replace the standard multi-head attention layer in your model with any of these linear attention layers. Example usage is as follows:

>>> import torch
>>> from fla.layers import MultiScaleRetention
>>> batch_size, num_heads, seq_len, hidden_size = 32, 4, 2048, 1024
>>> device, dtype = 'cuda:0', torch.bfloat16
>>> retnet = MultiScaleRetention(hidden_size=hidden_size, num_heads=num_heads).to(device=device, dtype=dtype)
>>> retnet
MultiScaleRetention(
  (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
  (k_proj): Linear(in_features=1024, out_features=1024, bias=False)
  (v_proj): Linear(in_features=1024, out_features=2048, bias=False)
  (g_proj): Linear(in_features=1024, out_features=2048, bias=False)
  (o_proj): Linear(in_features=2048, out_features=1024, bias=False)
  (g_norm_swish_gate): FusedRMSNormGated(512, eps=1e-05, activation=swish)
  (rotary): RotaryEmbedding(dim=256, base=10000.0, interleaved=False, pos_idx_in_fp32=True)
)
>>> x = torch.randn(batch_size, seq_len, hidden_size).to(device=device, dtype=dtype)
>>> y, *_ = retnet(x)
>>> y.shape
torch.Size([32, 2048, 1024])

We provide implementations of models that are compatible with the 🤗 Transformers library. Here's an example of how to initialize a GLA model from the default configs in fla:

>>> from fla.models import GLAConfig
>>> from transformers import AutoModelForCausalLM
>>> config = GLAConfig()
>>> config
GLAConfig {
  "attn": null,
  "attn_mode": "chunk",
  "bos_token_id": 1,
  "clamp_min": null,
  "conv_size": 4,
  "elementwise_affine": true,
  "eos_token_id": 2,
  "expand_k": 0.5,
  "expand_v": 1,
  "feature_map": null,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "fuse_swiglu": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 2048,
  "initializer_range": 0.006,
  "intermediate_size": null,
  "max_position_embeddings": 2048,
  "model_type": "gla",
  "norm_eps": 1e-06,
  "num_heads": 4,
  "num_hidden_layers": 24,
  "num_kv_heads": null,
  "tie_word_embeddings": false,
  "transformers_version": "4.50.1",
  "use_cache": true,
  "use_gk": true,
  "use_gv": false,
  "use_output_gate": true,
  "use_short_conv": false,
  "vocab_size": 32000
}
>>> AutoModelForCausalLM.from_config(config)
GLAForCausalLM(
  (model): GLAModel(
    (embeddings): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-23): 24 x GLABlock(
        (attn_norm): RMSNorm(2048, eps=1e-06)
        (attn): GatedLinearAttention(
          (q_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (k_proj): Linear(in_features=2048, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (g_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (gk_proj): Sequential(
            (0): Linear(in_features=2048, out_features=16, bias=False)
            (1): Linear(in_features=16, out_features=1024, bias=True)
          )
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (g_norm_swish_gate): FusedRMSNormGated(512, eps=1e-06, activation=swish)
        )
        (mlp_norm): RMSNorm(2048, eps=1e-06)
        (mlp): GatedMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (swiglu_linear): SwiGLULinear()
        )
      )
    )
    (norm): RMSNorm(2048, eps=1e-06)
  )
  (lm_head): Linear(in_features=2048, out_features=32000, bias=False)
)

Fused Modules

We offer a collection of fused modules in fla.modules to facilitate faster training (a usage sketch follows this list):

  • Rotary Embedding: rotary positional embeddings as adopted by the Llama architecture, a.k.a., Transformer++.
  • Norm Layers:
    • RMSNorm, LayerNorm and GroupNorm
    • RMSNormLinear, LayerNormLinear and GroupNormLinear, which reduce the memory usage of intermediate tensors for improved memory efficiency.
  • Norm Layers with Gating: combine norm layers with element-wise sigmoid or swish gating, as used by RetNet/GLA.
  • Cross Entropy: faster Triton implementation of cross entropy loss.
  • Linear Cross Entropy: fused linear layer and cross entropy loss to avoid the materialization of large logits tensors. Also refer to implementations by mgmalek and Liger-Kernel.
  • Linear KL Divergence: fused linear layer and KL divergence loss in a similar vein as CE loss.
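A rough usage sketch is given below. RMSNorm appears in the model printouts in this README; FusedCrossEntropyLoss is an assumed class name for the fused cross-entropy loss described above and may differ across versions.

import torch
from fla.modules import RMSNorm, FusedCrossEntropyLoss  # FusedCrossEntropyLoss: assumed name

# Drop-in replacements for nn.LayerNorm / nn.CrossEntropyLoss on a CUDA device.
device, dtype = 'cuda', torch.bfloat16
hidden = torch.randn(2, 16, 1024, device=device, dtype=dtype)
logits = torch.randn(2 * 16, 32000, device=device, dtype=dtype, requires_grad=True)
labels = torch.randint(0, 32000, (2 * 16,), device=device)

norm = RMSNorm(1024).to(device=device, dtype=dtype)
loss = FusedCrossEntropyLoss()(logits, labels)
loss.backward()
print(norm(hidden).shape, loss.item())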

Generation

Once a model has been pretrained, it can generate text through the 🤗 text generation APIs. Here is a generation example:

>>> import fla
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> name = 'fla-hub/gla-1.3B-100B'
>>> tokenizer = AutoTokenizer.from_pretrained(name)
>>> model = AutoModelForCausalLM.from_pretrained(name).cuda()
>>> input_prompt = "Power goes with permanence. Impermanence is impotence. And rotation is castration."
>>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.cuda()
>>> outputs = model.generate(input_ids, max_length=64)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

We also provide a simple script here for benchmarking the generation speed. Simply run it with:

$ python -m benchmarks.benchmark_generation \
    --path 'fla-hub/gla-1.3B-100B' \
    --repetition_penalty 2. \
    --prompt="Hello everyone, I'm Songlin Yang"

Prompt:
Hello everyone, I'm Songlin Yang
Generated:
Hello everyone, I'm Songlin Yang.
I am a 20 year old girl from China who is currently studying in the United States of America for my Master degree and also working as an English teacher at school here on campus since last summer (1st semester). My main goal to be able do well with this course so that we can have

Prompt length: 10, generation length: 64
Total prompt processing + decoding time: 4593ms

All of the pretrained models currently available can be found in fla-hub.

>>> from huggingface_hub import list_models
>>> for model in list_models(author='fla-hub'): print(model.id)

Hybrid Models

fla provides a flexible method to incorporate standard attention layers into existing linear attention models. This is easily achieved by specifying the attn argument in the model configuration.

For example, to create a 2-layer Samba model with interleaved Mamba and local attention layers, using a sliding window size of 2048:

>>> from fla.models import SambaConfig
>>> from transformers import AutoModelForCausalLM
>>> config = SambaConfig(num_hidden_layers=2)
>>> config.attn = {
...     'layers': [1],
...     'num_heads': 18,
...     'num_kv_heads': 18,
...     'qkv_bias': False,
...     'rope_theta': 10000.,
...     'window_size': 2048
... }
>>> config
SambaConfig {
  "attn": {
    "layers": [
      1
    ],
    "num_heads": 18,
    "num_kv_heads": 18,
    "qkv_bias": false,
    "rope_theta": 10000.0,
    "window_size": 2048
  },
  "bos_token_id": 1,
  "conv_kernel": 4,
  "eos_token_id": 2,
  "expand": 2,
  "fuse_cross_entropy": true,
  "fuse_norm": true,
  "fuse_swiglu": true,
  "hidden_act": "swish",
  "hidden_ratio": 4,
  "hidden_size": 2304,
  "initializer_range": 0.02,
  "intermediate_size": 4608,
  "max_position_embeddings": 2048,
  "model_type": "samba",
  "norm_eps": 1e-05,
  "num_hidden_layers": 2,
  "pad_token_id": 0,
  "rescale_prenorm_residual": false,
  "residual_in_fp32": false,
  "state_size": 16,
  "tie_word_embeddings": false,
  "time_step_floor": 0.0001,
  "time_step_init_scheme": "random",
  "time_step_max": 0.1,
  "time_step_min": 0.001,
  "time_step_rank": 144,
  "time_step_scale": 1.0,
  "transformers_version": "4.50.1",
  "use_bias": false,
  "use_cache": true,
  "use_conv_bias": true,
  "vocab_size": 32000
}
>>> AutoModelForCausalLM.from_config(config)
SambaForCausalLM(
  (backbone): SambaModel(
    (embeddings): Embedding(32000, 2304)
    (layers): ModuleList(
      (0): SambaBlock(
        (mixer_norm): RMSNorm(2304, eps=1e-05)
        (mixer): Mamba(
          (conv1d): Conv1d(4608, 4608, kernel_size=(4,), stride=(1,), padding=(3,), groups=4608)
          (in_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (x_proj): Linear(in_features=4608, out_features=176, bias=False)
          (dt_proj): Linear(in_features=144, out_features=4608, bias=True)
          (out_proj): Linear(in_features=4608, out_features=2304, bias=False)
        )
        (mlp_norm): RMSNorm(2304, eps=1e-05)
        (mlp): GatedMLP(
          (gate_proj): Linear(in_features=2304, out_features=6144, bias=False)
          (up_proj): Linear(in_features=2304, out_features=6144, bias=False)
          (down_proj): Linear(in_features=6144, out_features=2304, bias=False)
          (swiglu_linear): SwiGLULinear()
        )
      )
      (1): SambaBlock(
        (mixer_norm): RMSNorm(2304, eps=1e-05)
        (mixer): Attention(
          (q_proj): Linear(in_features=2304, out_features=2304, bias=False)
          (k_proj): Linear(in_features=2304, out_features=2304, bias=False)
          (v_proj): Linear(in_features=2304, out_features=2304, bias=False)
          (o_proj): Linear(in_features=2304, out_features=2304, bias=False)
          (rotary): RotaryEmbedding(dim=128, base=10000.0, interleaved=False, pos_idx_in_fp32=True)
        )
        (mlp_norm): RMSNorm(2304, eps=1e-05)
        (mlp): GatedMLP(
          (gate_proj): Linear(in_features=2304, out_features=6144, bias=False)
          (up_proj): Linear(in_features=2304, out_features=6144, bias=False)
          (down_proj): Linear(in_features=6144, out_features=2304, bias=False)
          (swiglu_linear): SwiGLULinear()
        )
      )
    )
    (norm_f): RMSNorm(2304, eps=1e-05)
  )
  (lm_head): Linear(in_features=2304, out_features=32000, bias=False)
)

During inference, you DO NOT need to change anything for generation! The model produces output as-is, without any additional configuration or modifications.
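As an illustration, the hybrid model goes straight through the standard generate API. This is only a sketch: the model built from the config above is untrained, so random token ids are used and the output is meaningless.

import torch
from transformers import AutoModelForCausalLM

# `config` is the SambaConfig with interleaved attention layers defined in the example above.
model = AutoModelForCausalLM.from_config(config).to('cuda', torch.bfloat16)
input_ids = torch.randint(0, config.vocab_size, (1, 16), device='cuda')
outputs = model.generate(input_ids, max_new_tokens=8)
print(outputs.shape)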

Training

We provide a minimal framework called 🔥 flame, built on top of torchtitan, for efficient training of fla models.

Check out the GLA example for more details.

Evaluation

The lm-evaluation-harness library allows you to easily perform (zero-shot) model evaluations. Follow the steps below to use this library:

  1. Install lm_eval following their instructions.

  2. Run evaluation with:

$ MODEL='fla-hub/gla-1.3B-100B'
$ python -m evals.harness --model hf \
    --model_args pretrained=$MODEL,dtype=bfloat16 \
    --tasks wikitext,lambada_openai,piqa,hellaswag,winogrande,arc_easy,arc_challenge,boolq,sciq,copa,openbookqa \
    --batch_size 64 \
    --num_fewshot 0 \
    --device cuda \
    --show_config

We've made fla compatible with hf-style evaluations; you can call evals.harness to run them. Running the command above reproduces the task results reported in the GLA paper.

  3. Multi-GPU Evaluation with Hugging Face accelerate 🚀

To perform data-parallel evaluation (where each GPU loads a separate full copy of the model), we leverage the accelerate launcher as follows:

$ MODEL='fla-hub/gla-1.3B-100B'
$ accelerate launch -m evals.harness --model hf \
    --model_args pretrained=$MODEL,dtype=bfloat16,trust_remote_code=True \
    --tasks wikitext,lambada_openai,piqa,hellaswag,winogrande,arc_easy,arc_challenge,boolq,sciq,copa,openbookqa \
    --batch_size 64 \
    --num_fewshot 0 \
    --device cuda \
    --show_config \
    --trust_remote_code

  4. 📏 RULER Benchmark suite

The RULER benchmarks are commonly used for evaluating model performance on long-context tasks. You can evaluate fla models on RULER directly using lm-evaluation-harness. RULER is only available in a relatively recent version of lm-evaluation-harness, so make sure you have the latest version installed:

git clone --depth 1 https://github.com/EleutherAI/lm-evaluation-harness
cd lm-evaluation-harness
pip install -e .

Then, install the necessary dependencies for RULER:

pip install lm_eval["ruler"]

and run the evaluation (e.g., with 32k contexts) by:

$ accelerate launch -m evals.harness \
    --output_path $OUTPUT \
    --tasks niah_single_1,niah_single_2,niah_single_3,niah_multikey_1,niah_multikey_2,niah_multikey_3,niah_multiquery,niah_multivalue,ruler_vt,ruler_cwe,ruler_fwe,ruler_qa_hotpot,ruler_qa_squad \
    --model_args pretrained=$MODEL,dtype=bfloat16,max_length=32768,trust_remote_code=True \
    --metadata='{"max_seq_lengths":[4096,8192,16384,32768]}' \
    --batch_size 2 \
    --show_config \
    --trust_remote_code

If a GPU can't load a full copy of the model, please refer to this link for FSDP settings.

Tip

If you are using lm-evaluation-harness as an external library and can't find (almost) any tasks available, run the following before calling lm_eval.evaluate() or lm_eval.simple_evaluate() to load the library's stock tasks:

>>> from lm_eval.tasks import TaskManager; TaskManager().initialize_tasks()

Benchmarks

We compared our Triton-based RetNet implementation with CUDA-based FlashAttention2, using a batch size of 8, 32 heads, and a head dimension of 128, across different sequence lengths. These tests were conducted on a single H100 80GB GPU; the results are shown below.

# you might have to first install `fla` to enable its import via `pip install -e .`
$ python benchmark_retention.py
Performance:
         T  chunk_fwd  parallel_fwd  flash_fwd  chunk_fwdbwd  parallel_fwdbwd  flash_fwdbwd
0    128.0   0.264032      0.243536   0.083488      1.301856         1.166784      0.320704
1    256.0   0.273472      0.252848   0.094304      1.345872         1.300608      0.807936
2    512.0   0.303600      0.278896   0.098112      1.503168         1.433184      0.857216
3   1024.0   0.357248      0.367360   0.156528      1.773552         2.303424      1.160864
4   2048.0   0.454624      0.605616   0.340928      2.283728         4.483360      1.955936
5   4096.0   0.638960      1.378016   1.004992      3.374720        12.271215      4.813776
6   8192.0   1.012352      4.201344   3.625008      5.581808        40.833618     15.023697
7  16384.0   1.748512     14.489664  13.710080     10.191552       153.093765     54.336864

Citation

If you find this repository helpful, please cite our work:

@software{yang2024fla,
  title  = {FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism},
  author = {Yang, Songlin and Zhang, Yu},
  url    = {https://github.com/fla-org/flash-linear-attention},
  month  = jan,
  year   = {2024}
}

@inproceedings{yang2024gdn,
  title     = {Gated Delta Networks: Improving Mamba2 with Delta Rule},
  author    = {Songlin Yang and Jan Kautz and Ali Hatamizadeh},
  booktitle = {Proceedings of ICLR},
  year      = {2025}
}

@inproceedings{yang2024deltanet,
  title     = {Parallelizing Linear Transformers with the Delta Rule over Sequence Length},
  author    = {Yang, Songlin and Wang, Bailin and Zhang, Yu and Shen, Yikang and Kim, Yoon},
  booktitle = {Proceedings of NeurIPS},
  year      = {2024}
}

@inproceedings{zhang2024gsa,
  title     = {Gated Slot Attention for Efficient Linear-Time Sequence Modeling},
  author    = {Zhang, Yu and Yang, Songlin and Zhu, Ruijie and Zhang, Yue and Cui, Leyang and Wang, Yiqiao and Wang, Bolun and Shi, Freda and Wang, Bailin and Bi, Wei and Zhou, Peng and Fu, Guohong},
  booktitle = {Proceedings of NeurIPS},
  year      = {2024}
}

@inproceedings{qin2024hgrn2,
  title     = {HGRN2: Gated Linear RNNs with State Expansion},
  author    = {Qin, Zhen and Yang, Songlin and Sun, Weixuan and Shen, Xuyang and Li, Dong and Sun, Weigao and Zhong, Yiran},
  booktitle = {Proceedings of COLM},
  year      = {2024}
}

@inproceedings{yang2024gla,
  title     = {Gated Linear Attention Transformers with Hardware-Efficient Training},
  author    = {Yang, Songlin and Wang, Bailin and Shen, Yikang and Panda, Rameswar and Kim, Yoon},
  booktitle = {Proceedings of ICML},
  year      = {2024}
}


Acknowledgments

We extend our gratitude to Intel Corporation and Bitdeer for providing the CI server resources that power our infrastructure.
