
Commit cac0a28

Add afmoe model (#42168)

* Add AFMoE model support
* Address review feedback for AFMoE implementation
* Add flex attention support to AFMoE model
* Fix expert_bias routing in AFMoE
* Remove test-results directory
* Address PR review feedback for AFMoE model
* fix(afmoe): ensure RMSNorm output dtype matches input dtype
* properly return attn weights
* fix most tests
* cleanup: remove shared expert if/else as it defaults to 2; remove `route_norm` as it defaults to `True`; make test smaller and faster
* fix input embeds api
* update rope API, smaller test and should be good to go
* oops, wrong place to skip unittest
* quality
* update
* rope parameter docstring fill

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Arthur <arthur.zucker@gmail.com>

1 parent 2a61590 · commit cac0a28

File tree

13 files changed: +1703 -0 lines changed

‎.gitignore‎

Lines changed: 1 addition & 0 deletions

@@ -175,3 +175,4 @@ tags
 
 # Cursor IDE files
 .cursor/
+test-results/

‎docs/source/en/_toctree.yml‎

Lines changed: 2 additions & 0 deletions

@@ -384,6 +384,8 @@
      title: Main Classes
 - sections:
   - sections:
+    - local: model_doc/afmoe
+      title: AFMoE
     - local: model_doc/albert
       title: ALBERT
     - local: model_doc/apertus

‎docs/source/en/model_doc/afmoe.md‎

Lines changed: 130 additions & 0 deletions

@@ -0,0 +1,130 @@

<!--Copyright 2025 Arcee AI and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->
*This model was released on {release_date} and added to Hugging Face Transformers on 2025-11-18.*

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>
# AFMoE

AFMoE (Arcee Foundational Mixture of Experts) is a decoder-only transformer model that extends the Llama architecture with a sparse Mixture of Experts (MoE) approach. The model combines token-choice routing with shared experts and employs several architectural innovations for efficient inference and improved performance.

## Key Architecture Features

AFMoE introduces several key modifications to the standard transformer architecture:

- **Mixture of Experts with Shared Experts**: Combines routed experts (activated per-token via learned routing) with always-active shared experts for stable base computation
- **Token-Choice Routing**: Uses sigmoid or softmax-based routing with normalization and scaling for expert selection
- **Q/K Normalization and Gating**: Applies RMSNorm to query and key projections and uses sigmoid gating on attention outputs for improved stability
- **Hybrid Attention Patterns**: Alternates between sliding window attention and full attention across layers for efficiency with long contexts
- **Dual Normalization**: Uses pre- and post-normalization around both attention and MLP blocks for training stability
- **Configurable Dense Layers**: Allows initial layers to use dense MLPs before transitioning to sparse MoE layers

The model supports extended context lengths with RoPE embeddings and includes all standard Transformers features including Flash Attention 2, SDPA, gradient checkpointing, and quantization support.

> [!TIP]
> AFMoE is particularly well-suited for scenarios requiring efficient scaling through sparsity while maintaining strong performance. The shared experts provide a stable computation baseline while routed experts enable model capacity scaling.
The example below demonstrates how to generate text with AFMoE using [`Pipeline`] or the [`AutoModel`].

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
import torch
from transformers import pipeline

pipeline = pipeline(
    task="text-generation",
    model="arcee-ai/Trinity-Mini",
    torch_dtype=torch.bfloat16,
    device=0
)

output = pipeline("The key innovation in mixture of experts is")
print(output[0]["generated_text"])
```

</hfoption>
<hfoption id="AutoModel">

```py
import torch
from transformers import AutoTokenizer, AfmoeForCausalLM

tokenizer = AutoTokenizer.from_pretrained("arcee-ai/Trinity-Mini")
model = AfmoeForCausalLM.from_pretrained(
    "arcee-ai/Trinity-Mini",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

inputs = tokenizer("The key innovation in mixture of experts is", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=50)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

</hfoption>
</hfoptions>
## Model Architecture Details

### Expert Routing

AFMoE uses token-choice routing where each token independently selects top-k experts based on router logits (a short sketch follows the list below). The routing mechanism includes:

- Configurable scoring function (sigmoid or softmax)
- Optional route normalization for balanced expert utilization
- Route scaling to control expert contribution strength
- Bias correction for expert selection
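A minimal sketch of this token-choice routing, assuming a selection-only bias term (function and argument names here are illustrative, not the actual `modeling_afmoe.py` API):

```py
import torch
import torch.nn.functional as F


def route_tokens(router_logits, top_k=6, score_func="sigmoid", route_norm=True, route_scale=1.0, expert_bias=None):
    # router_logits: (num_tokens, num_experts)
    if score_func == "sigmoid":
        scores = torch.sigmoid(router_logits)
    else:
        scores = F.softmax(router_logits, dim=-1)

    # The bias correction (if any) is assumed to affect which experts are selected,
    # not the weights used to combine their outputs.
    selection_scores = scores if expert_bias is None else scores + expert_bias
    _, expert_indices = torch.topk(selection_scores, top_k, dim=-1)
    routing_weights = torch.gather(scores, -1, expert_indices)

    if route_norm:
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    return expert_indices, routing_weights * route_scale
```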
### Shared Experts

Unlike standard MoE models, AFMoE includes shared experts that are always activated for every token (see the sketch after the list), providing:

- A stable computation baseline across all tokens
- Reduced variance in model outputs
- Better handling of out-of-distribution inputs
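Conceptually, each MoE layer adds the always-on shared-expert output to the weighted outputs of the selected routed experts. A simplified, deliberately unvectorized sketch (illustrative names only, not the actual implementation):

```py
def moe_forward(hidden_states, shared_experts, routed_experts, expert_indices, routing_weights):
    # hidden_states: (num_tokens, hidden_size)
    # Shared experts process every token and provide the stable baseline.
    output = sum(expert(hidden_states) for expert in shared_experts)
    # Routed experts only process the tokens assigned to them.
    for t in range(hidden_states.shape[0]):
        for k, expert_idx in enumerate(expert_indices[t].tolist()):
            output[t] += routing_weights[t, k] * routed_experts[expert_idx](hidden_states[t])
    return output
```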
### Attention Mechanism

The hybrid attention pattern alternates between:

- **Sliding Window Attention**: For efficiency on long sequences, with configurable window size
- **Full Attention**: Applied every N layers (configurable via `global_attn_every_n_layers`) for global context

All attention layers include Q/K normalization and output gating for improved training dynamics.
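The default layer pattern can be reproduced directly from `global_attn_every_n_layers`; this mirrors the fallback logic in `AfmoeConfig` (the commented schematic of the per-layer stabilizers is only a rough illustration, not the modeling code):

```py
global_attn_every_n_layers = 4
num_hidden_layers = 8

layer_types = [
    "sliding_attention" if (i + 1) % global_attn_every_n_layers else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# ['sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention',
#  'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention']

# Rough schematic of the attention stabilizers applied in every layer (not the real code):
#   q = rms_norm(W_q(x)); k = rms_norm(W_k(x))          # Q/K normalization
#   attn_out = attention(q, k, v) * sigmoid(W_gate(x))  # sigmoid output gating
```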
## AfmoeConfig

[[autodoc]] AfmoeConfig

## AfmoeModel

[[autodoc]] AfmoeModel
    - forward

## AfmoeForCausalLM

[[autodoc]] AfmoeForCausalLM
    - forward

‎splitted_tests.txt‎

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+tests/models/afmoe/test_modeling_afmoe.py
Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@

# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_afmoe import *
    from .modeling_afmoe import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
Lines changed: 210 additions & 0 deletions

@@ -0,0 +1,210 @@

# coding=utf-8
# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AFMoE model configuration"""

from typing import Optional

from ...configuration_utils import PreTrainedConfig, layer_type_validation
from ...modeling_rope_utils import RopeParameters
from ...utils import logging


logger = logging.get_logger(__name__)


class AfmoeConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`AfmoeModel`]. It is used to instantiate an
    AFMoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of [arcee-ai/Trinity-Mini](https://huggingface.co/arcee-ai/Trinity-Mini).

    AFMoE is an Adaptive Feedforward MoE (Mixture of Experts) model with token-choice routing, shared experts, and a
    hybrid attention mechanism combining sliding window and full attention patterns.

    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PreTrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 200192):
            Vocabulary size of the AFMoE model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`AfmoeModel`].
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 6144):
            Dimension of the dense MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 1408):
            Intermediate size of the routed expert MLPs.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_dense_layers (`int`, *optional*, defaults to 1):
            Number of initial dense layers before MoE layers begin. Layers with index < num_dense_layers will use
            standard dense MLPs instead of MoE.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 128):
            The dimension of each attention head.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the MLP blocks.
        max_position_embeddings (`int`, *optional*, defaults to 16384):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the RMS normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_parameters (`RopeParameters`, *optional*):
            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
            a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE
            with longer `max_position_embeddings`.
        num_experts (`int`, *optional*, defaults to 64):
            Number of routed experts in MoE layers.
        num_experts_per_tok (`int`, *optional*, defaults to 6):
            Number of experts to route each token to. This is the top-k value for the token-choice routing.
        num_shared_experts (`int`, *optional*, defaults to 2):
            Number of shared experts that are always activated for all tokens.
        route_scale (`float`, *optional*, defaults to 1.0):
            Scaling factor applied to routing weights.
        global_attn_every_n_layers (`int`, *optional*, defaults to 4):
            The frequency of full attention layers. Every Nth layer will use full attention, while others use sliding
            window attention.
        sliding_window (`int`, *optional*, defaults to 1024):
            Sliding window size for local attention layers.
        layer_types (`list[str]`, *optional*):
            A list that explicitly maps each layer index with its attention type. Each element should be either
            "sliding_attention" or "full_attention". If not provided, it will be automatically generated based on
            `global_attn_every_n_layers`.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        mup_enabled (`bool`, *optional*, defaults to `False`):
            Whether to enable muP (Maximal Update Parametrization) input scaling. When enabled, input embeddings
            are scaled by `sqrt(hidden_size)`.

    Example:
    ```python
    >>> from transformers import AfmoeModel, AfmoeConfig

    >>> # Initializing an AFMoE configuration
    >>> configuration = AfmoeConfig()

    >>> # Initializing a model from the afmoe-small-sft-v1 style configuration
    >>> model = AfmoeModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "afmoe"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default pipeline parallel plan for base model
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size: Optional[int] = 200192,
        hidden_size: Optional[int] = 2048,
        intermediate_size: Optional[int] = 6144,
        moe_intermediate_size: Optional[int] = 1408,
        num_hidden_layers: Optional[int] = 32,
        num_dense_layers: Optional[int] = 1,
        num_attention_heads: Optional[int] = 16,
        num_key_value_heads: Optional[int] = None,
        head_dim: Optional[int] = 128,
        hidden_act: Optional[str] = "silu",
        max_position_embeddings: Optional[int] = 16384,
        initializer_range: Optional[float] = 0.02,
        rms_norm_eps: Optional[float] = 1e-5,
        use_cache: Optional[bool] = True,
        tie_word_embeddings: Optional[bool] = False,
        rope_theta: Optional[float] = 10000.0,
        rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None,
        num_experts: Optional[int] = 64,
        num_experts_per_tok: Optional[int] = 6,
        num_shared_experts: Optional[int] = 2,
        route_scale: Optional[float] = 1.0,
        global_attn_every_n_layers: Optional[int] = 4,
        sliding_window: Optional[int] = 1024,
        layer_types: Optional[list] = None,
        attention_dropout: Optional[float] = 0.0,
        mup_enabled: Optional[bool] = False,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_dense_layers = num_dense_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_parameters = rope_parameters

        # MoE specific
        self.moe_intermediate_size = moe_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.num_shared_experts = num_shared_experts
        self.route_scale = route_scale
        self.attention_bias = False

        # Attention specific
        self.attention_dropout = attention_dropout
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.sliding_window = sliding_window
        self.mup_enabled = mup_enabled
        self.layer_types = layer_types
        if self.layer_types is None:
            self.layer_types = [
                "sliding_attention" if bool((i + 1) % global_attn_every_n_layers) else "full_attention"
                for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types)

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["AfmoeConfig"]
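As a quick sanity check of the defaults above, a deliberately tiny configuration can be built once this commit is installed from source (the sizes below are arbitrary illustration values, not a released checkpoint):

```py
from transformers import AfmoeConfig

config = AfmoeConfig(
    hidden_size=256,
    intermediate_size=512,
    moe_intermediate_size=128,
    num_hidden_layers=8,
    num_attention_heads=4,
    num_experts=8,
    num_experts_per_tok=2,
    global_attn_every_n_layers=4,
)

print(config.layer_types)          # every 4th layer is "full_attention", the rest "sliding_attention"
print(config.num_key_value_heads)  # falls back to num_attention_heads (4) when left unset
```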
