Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings
This repository was archived by the owner on Jan 21, 2025. It is now read-only.
/mesh (Public archive)

Commit d419b67

Browse files
author
Mesh TensorFlow Team
committed
Minor changes to make Experts Attention work.
PiperOrigin-RevId: 388312437
1 parent 7e78cf8 · commit d419b67

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

‎mesh_tensorflow/transformer/attention.py‎

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,8 @@ def __init__(self,
663663
if mtf.layers.unit_scaling_convention():
664664
raise NotImplementedError
665665

666+
# TODO(barretzoph): Make this work for model parallelism by not outputting
667+
# a tensor with `heads` dim.
666668
moe_output_dims=self.q_shape[-1]
667669
tf.logging.info("ExpertsAttention moe_hidden_size: {}".format(
668670
experts_hparams.hidden_size))
@@ -680,10 +682,12 @@ def __init__(self,
680682
switch_dropout=experts_hparams.switch_dropout,
681683
switch_temperature=experts_hparams.switch_temperature,
682684
switch_jitter=experts_hparams.switch_jitter,
683-
switch_top_k=experts_hparams.switch_top_k,
685+
ntlb_top_k=experts_hparams.ntlb_top_k,
684686
hidden_size=experts_hparams.hidden_size,
685687
output_dim=moe_output_dims,
686-
use_experts_attention=experts_hparams.use_experts_attention)
688+
use_experts_attention=experts_hparams.use_experts_attention,
689+
activation=experts_hparams.activation,
690+
z_loss=experts_hparams.z_loss)
687691

688692
def _compute_merge_qkv(self, antecedent):
689693
"""Computes qkv all in one call using MoE layer."""

‎mesh_tensorflow/transformer/transformer_layers.py‎

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -402,9 +402,11 @@ def __init__(self,
402402
switch_dropout=0.0,
403403
switch_temperature=1.0,
404404
switch_jitter=1e-2,
405-
switch_top_k=4,
405+
ntlb_top_k=4,
406406
hidden_size=3072,
407407
use_experts_attention=True,
408+
activation="relu",
409+
z_loss=None,
408410
**kwargs):
409411
super(ExpertsSelfAttention, self).__init__(**kwargs)
410412
self._hparams = mtf.transformer.moe.HParams(
@@ -420,9 +422,11 @@ def __init__(self,
420422
switch_dropout=switch_dropout,
421423
switch_temperature=switch_temperature,
422424
switch_jitter=switch_jitter,
423-
switch_top_k=switch_top_k,
425+
ntlb_top_k=ntlb_top_k,
424426
hidden_size=hidden_size,
425-
use_experts_attention=use_experts_attention)
427+
use_experts_attention=use_experts_attention,
428+
activation=activation,
429+
z_loss=z_loss)
426430

427431
def make_params(self, context):
428432
num_heads = self.num_heads

0 commit comments

Comments (0)

[8]ページ先頭

©2009-2025 Movatter.jp