Commit c0e1c18

[bugfix] fix mtp pp (#6812)

1 parent 450679a, commit c0e1c18

7 files changed: +50 -50 lines changed

‎docs/source/BestPractices/Qwen3-Best-Practice.md‎

Lines changed: 1 addition & 1 deletion

@@ -328,7 +328,7 @@ swift rlhf \

 Best-practice reference for single-node 8x H20 LoRA training of Qwen3-235B-A22B-Instruct-250718: [https://github.com/modelscope/ms-swift/pull/5033](https://github.com/modelscope/ms-swift/pull/5033)

-ms-swift introduces Megatron parallelism techniques to accelerate CPT/SFT/DPO/KTO/RM for large models. Supported models are listed in the [supported models documentation](../Instruction/Supported-models-and-datasets.md).
+ms-swift introduces Megatron parallelism techniques to accelerate CPT/SFT/DPO/GRPO for large models. Supported models are listed in the [supported models documentation](../Instruction/Supported-models-and-datasets.md).

 For environment setup and conversion between HF and MCore model weights, see the [Megatron-SWIFT training documentation](../Megatron-SWIFT/Quick-start.md)


‎docs/source/Megatron-SWIFT/Quick-start.md‎

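Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@

 # Quick Start

-ms-swift incorporates Megatron's parallelism techniques to accelerate large-model training, including data parallelism, tensor parallelism, pipeline parallelism, sequence parallelism, context parallelism, and expert parallelism. It supports CPT/SFT/DPO for models such as Qwen3, [Qwen3-MoE](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/mcore_bridge/full/moe.sh), Qwen2.5, Llama3, Deepseek-R1, and GLM4.5. For the complete list of supported models, see the [supported models and datasets documentation](../Instruction/Supported-models-and-datasets.md). Megatron-SWIFT is recommended for MoE training, where it can typically deliver a 10x training speedup.
+ms-swift incorporates Megatron's parallelism techniques to accelerate large-model training, including data parallelism, tensor parallelism, pipeline parallelism, sequence parallelism, context parallelism, and expert parallelism. It supports CPT/SFT/DPO/GRPO for models such as Qwen3, [Qwen3-MoE](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/mcore_bridge/full/moe.sh), Qwen2.5, Llama3, Deepseek-R1, and GLM4.5. For the complete list of supported models, see the [supported models and datasets documentation](../Instruction/Supported-models-and-datasets.md). Megatron-SWIFT is recommended for MoE training, where it can typically deliver a 10x training speedup.


 | Method | Full-Parameter | LoRA | MoE | Multimodal | FP8 |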

‎docs/source_en/BestPractices/Qwen3-Best-Practice.md‎

Lines changed: 1 addition & 1 deletion

@@ -332,7 +332,7 @@ swift rlhf \

 Best practice reference for single-node 8xH20 LoRA training with Qwen3-235B-A22B-Instruct-250718: https://github.com/modelscope/ms-swift/pull/5033.

-ms-swift introduces Megatron parallelism techniques to accelerate CPT/SFT/DPO/KTO/RM for large models. Supported models can be found in the [Supported Models and Datasets Document](../Instruction/Supported-models-and-datasets.md).
+ms-swift introduces Megatron parallelism techniques to accelerate CPT/SFT/DPO/GRPO for large models. Supported models can be found in the [Supported Models and Datasets Document](../Instruction/Supported-models-and-datasets.md).

 For environment setup and conversion between HF and MCore model weights, refer to the [Megatron-SWIFT Training Documentation](../Megatron-SWIFT/Quick-start.md).


‎docs/source_en/Megatron-SWIFT/Quick-start.md‎

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 # Quick Start

-ms-swift incorporates Megatron's parallelization techniques to accelerate the training of large models, including data parallelism, tensor parallelism, pipeline parallelism, sequence parallelism, context parallelism, and expert parallelism. It supports CPT/SFT/DPO for models such as Qwen3, [Qwen3-MoE](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/mcore_bridge/full/moe.sh), Qwen2.5, Llama3, Deepseek-R1 and GLM4.5 series. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md). We recommend using Megatron-SWIFT for MoE training; it can typically achieve a 10x speedup in training.
+ms-swift incorporates Megatron's parallelization techniques to accelerate the training of large models, including data parallelism, tensor parallelism, pipeline parallelism, sequence parallelism, context parallelism, and expert parallelism. It supports CPT/SFT/DPO/GRPO for models such as Qwen3, [Qwen3-MoE](https://github.com/modelscope/ms-swift/blob/main/examples/megatron/mcore_bridge/full/moe.sh), Qwen2.5, Llama3, Deepseek-R1 and GLM4.5 series. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](../Instruction/Supported-models-and-datasets.md). We recommend using Megatron-SWIFT for MoE training; it can typically achieve a 10x speedup in training.


 | Method | Full-Parameter | LoRA | MoE | Multimodal | FP8 |

‎swift/megatron/argument/megatron_args.py‎

Lines changed: 1 addition & 1 deletion

@@ -267,7 +267,7 @@ class MegatronTunerMixin:
     use_rslora: bool = False

     def __post_init__(self):
-        if self.freeze_parameters_ratio > 0 and self.pipeline_model_parallel_size > 1:
+        if 0 < self.freeze_parameters_ratio < 1 and self.pipeline_model_parallel_size > 1:
             raise ValueError('`freeze_parameters_ratio` is not supported when `pipeline_model_parallel_size` > 1')
         if self.target_regex:
             self.target_modules = self.target_regex
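
The effect of the changed condition in the hunk above is easiest to see by comparing the two predicates side by side. The following is a minimal sketch, not code from the repository: the two function names are hypothetical, and only the boolean expressions are taken from the diff.

```python
def rejected_before(freeze_parameters_ratio: float, pipeline_model_parallel_size: int) -> bool:
    """Old condition from the diff: any positive ratio is rejected under PP."""
    return freeze_parameters_ratio > 0 and pipeline_model_parallel_size > 1


def rejected_after(freeze_parameters_ratio: float, pipeline_model_parallel_size: int) -> bool:
    """New condition from the diff: only a partial ratio (strictly between 0 and 1) is rejected under PP."""
    return 0 < freeze_parameters_ratio < 1 and pipeline_model_parallel_size > 1


if __name__ == '__main__':
    # Compare both predicates for a pipeline-parallel size of 2.
    for ratio in (0.0, 0.5, 1.0):
        print(ratio, rejected_before(ratio, 2), rejected_after(ratio, 2))
```

Under these predicates, a ratio of exactly 1 with pipeline parallelism was rejected before the change and is accepted after it, while a partial ratio such as 0.5 is rejected in both versions.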

‎swift/megatron/model/gpt_bridge.py‎

Lines changed: 3 additions & 0 deletions

@@ -1206,6 +1206,9 @@ def _convert(self, mg_models, hf_state_dict, hf_prefix: str, to_mcore: bool, tqd

         if (not to_mcore or is_pp_last_stage) and self.args.mtp_num_layers:
             lm_model = getattr(mg_model, 'language_model') if self.args.is_multimodal else mg_model
+            if to_mcore and self.pp_rank > 0:
+                self._set_state_dict(lm_model, 'embedding.word_embeddings.weight', hf_state_dict, self.hf_embed_key,
+                                     to_mcore)
             layer_idx = 0
             while layer_idx < self.args.mtp_num_layers:
                 res = self._convert_mtp_layer(lm_model, hf_state_dict, f'{self.hf_mtp_prefix}.', layer_idx, to_mcore)
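
The two nested conditions in the hunk above can be restated as a single predicate to make clear which ranks pick up the extra embedding load. This is a sketch under stated assumptions: the function name is hypothetical, the arguments stand in for `to_mcore`, `is_pp_last_stage`, `self.pp_rank`, and `self.args.mtp_num_layers`, and the reading that the MTP block needs its own copy of the word embedding on a later pipeline stage is an inference from the `embedding=self.embedding` argument passed to `self.mtp` in `gpt_model.py`, not something the commit states.

```python
def loads_embedding_for_mtp(to_mcore: bool, is_pp_last_stage: bool, pp_rank: int, mtp_num_layers: int) -> bool:
    """Return True when the added branch in the diff would run.

    Outer condition (unchanged): the MTP conversion branch is entered.
    Inner condition (added): HF -> MCore conversion on a pipeline rank other than 0.
    """
    in_mtp_branch = bool((not to_mcore or is_pp_last_stage) and mtp_num_layers)
    return in_mtp_branch and to_mcore and pp_rank > 0


if __name__ == '__main__':
    # Last stage of a multi-stage pipeline, loading HF weights into MCore.
    print(loads_embedding_for_mtp(True, True, 1, 1))
    # Single pipeline stage: rank 0 is also the last stage, so no extra load.
    print(loads_embedding_for_mtp(True, True, 0, 1))
```

With a single pipeline stage the predicate is false, so the added branch only takes effect when pipeline parallelism is in use.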

‎swift/megatron/model/gpt_model.py‎

Lines changed: 42 additions & 45 deletions

@@ -318,7 +318,6 @@ def forward(
             rotary_pos_emb=rotary_pos_emb,
             rotary_pos_cos=rotary_pos_cos,
             rotary_pos_sin=rotary_pos_sin,
-            mtp_in_postprocess=self.mtp_process,
             loss_mask=loss_mask,
             decoder_input=decoder_input,
             attention_mask=attention_mask,
@@ -339,7 +338,6 @@ def _postprocess(
         rotary_pos_emb,
         rotary_pos_cos,
         rotary_pos_sin,
-        mtp_in_postprocess=None,
         loss_mask=None,
         decoder_input=None,
         attention_mask=None,
@@ -355,6 +353,8 @@ def _postprocess(
         Applies Multi-Token Prediction if enabled, generates output logits through
         the output layer, and computes language model loss when labels are provided.
         """
+        if not self.post_process:
+            return hidden_states
         args = get_args()
         labels = labels if args.task_type == 'causal_lm' else None
         in_inference_mode = inference_context is not None and not self.training
@@ -366,7 +366,7 @@ def _postprocess(
         if self.share_embeddings_and_output_weights:
             output_weight = self.shared_embedding_or_output_weight()

-        if mtp_in_postprocess:
+        if self.mtp_process:
             hidden_states = self.mtp(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -381,51 +381,48 @@ def _postprocess(
                 embedding=self.embedding,
                 **(extra_block_kwargs or {}),
             )
-
-        if not self.post_process:
-            return hidden_states
-
-        if self.mtp_process:
-            mtp_labels = labels.clone()
-            from ..trainers.utils import split_cp_inputs
             hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
             hidden_states = hidden_states_list[0]
-            if loss_mask is None:
-                # if loss_mask is not provided, use all ones as loss_mask
-                loss_mask = mtp_labels.new_ones((1, packed_seq_params.cu_seqlens_q[-1]))
-            cu_seqlens = packed_seq_params.cu_seqlens_q
-            for mtp_layer_number in range(self.config.mtp_num_layers):
-                # output
-                mtp_logits, _ = self.output_layer(
-                    hidden_states_list[mtp_layer_number + 1],
-                    weight=output_weight,
-                    runtime_gather_output=runtime_gather_output,
-                )
-                # Calc loss for the current Multi-Token Prediction (MTP) layers.
-                mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
-                loss_mask[:, cu_seqlens[:-1]] = 0
-                loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1)
-                if args.context_parallel_size > 1:
-                    loss_mask_ = split_cp_inputs(loss_mask, cu_seqlens, dim=1)
-                else:
-                    loss_mask_ = loss_mask.clone()
-                mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
-                mtp_loss = loss_mask_ * mtp_loss
-                num_tokens = loss_mask_.sum()
-                if self.training:
-                    # TODO(shifangx): remove the use of parallel_state here
-                    # after moving loss logging to loss_func in pretrain_gpt.py
-                    MTPLossLoggingHelper.save_loss_to_tracker(
-                        torch.sum(mtp_loss) / num_tokens,
-                        mtp_layer_number,
-                        self.config.mtp_num_layers,
-                        avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
+
+            if labels is not None:
+                from ..trainers.utils import split_cp_inputs
+                mtp_labels = labels.clone()
+                if loss_mask is None:
+                    # if loss_mask is not provided, use all ones as loss_mask
+                    loss_mask = mtp_labels.new_ones((1, packed_seq_params.cu_seqlens_q[-1]))
+                cu_seqlens = packed_seq_params.cu_seqlens_q
+                for mtp_layer_number in range(self.config.mtp_num_layers):
+                    # output
+                    mtp_logits, _ = self.output_layer(
+                        hidden_states_list[mtp_layer_number + 1],
+                        weight=output_weight,
+                        runtime_gather_output=runtime_gather_output,
                     )
-                mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
-                if self.config.calculate_per_token_loss:
-                    hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
-                else:
-                    hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
+                    # Calc loss for the current Multi-Token Prediction (MTP) layers.
+                    mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
+                    loss_mask[:, cu_seqlens[:-1]] = 0
+                    loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1)
+                    if args.context_parallel_size > 1:
+                        loss_mask_ = split_cp_inputs(loss_mask, cu_seqlens, dim=1)
+                    else:
+                        loss_mask_ = loss_mask.clone()
+                    mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
+                    mtp_loss = loss_mask_ * mtp_loss
+                    num_tokens = loss_mask_.sum()
+                    if self.training:
+                        # TODO(shifangx): remove the use of parallel_state here
+                        # after moving loss logging to loss_func in pretrain_gpt.py
+                        MTPLossLoggingHelper.save_loss_to_tracker(
+                            torch.sum(mtp_loss) / num_tokens,
+                            mtp_layer_number,
+                            self.config.mtp_num_layers,
+                            avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
+                        )
+                    mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
+                    if self.config.calculate_per_token_loss:
+                        hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
+                    else:
+                        hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
         sequence_parallel_override = False
         if in_inference_mode and inference_context.materialize_only_last_token_logits:
             if inference_context.is_static_batching():
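
The per-layer loss-mask update inside the MTP loop (`loss_mask[:, cu_seqlens[:-1]] = 0` followed by a roll of -1) may be easier to follow on a small packed batch. The sketch below uses plain Python lists instead of tensors. It rests on two assumptions that the diff does not confirm: that `roll_tensor` fills the vacated last slot with 0, and that context parallelism is disabled. `roll_left_zero_fill` and `mtp_loss_masks` are hypothetical helper names, not functions from the repository.

```python
from typing import List


def roll_left_zero_fill(values: List[int]) -> List[int]:
    """Shift every element one slot to the left; the vacated last slot becomes 0 (assumed behaviour)."""
    return values[1:] + [0] if values else []


def mtp_loss_masks(loss_mask: List[int], cu_seqlens: List[int], mtp_num_layers: int) -> List[List[int]]:
    """Return the loss mask each MTP layer would use for one packed row.

    cu_seqlens holds cumulative sequence boundaries, e.g. [0, 3, 5] for two
    packed sequences of lengths 3 and 2.
    """
    mask = list(loss_mask)
    per_layer = []
    for _ in range(mtp_num_layers):
        # Zero the first token of every packed sequence ...
        for start in cu_seqlens[:-1]:
            mask[start] = 0
        # ... then shift left, so the zero lands on the last token of the previous sequence.
        mask = roll_left_zero_fill(mask)
        per_layer.append(list(mask))
    return per_layer


if __name__ == '__main__':
    print(mtp_loss_masks([1, 1, 1, 1, 1], [0, 3, 5], 2))
```

For two packed sequences of lengths 3 and 2, this prints `[[1, 1, 0, 1, 0], [1, 0, 0, 0, 0]]`: each additional MTP layer masks one more position at the tail of every sequence, so a rolled label never crosses a sequence boundary.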
