Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 2b7f921

Browse files
committed
Added gemma-7b performance
1 parent ef055fc commit 2b7f921

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

‎model.py‎

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@ class ModelArgs:
2626
dim:int=4096
2727
intermediate_size:int=None
2828
n_local_heads:int=-1
29-
head_dim:int=64
29+
head_dim:int=None
3030
rope_base:float=10000
3131
norm_eps:float=1e-5
3232

@@ -37,7 +37,8 @@ def __post_init__(self):
3737
hidden_dim=4*self.dim
3838
n_hidden=int(2*hidden_dim/3)
3939
self.intermediate_size=find_multiple(n_hidden,256)
40-
self.head_dim=self.dim//self.n_head
40+
if self.head_dim is None:
41+
self.head_dim=self.dim//self.n_head
4142

4243
@classmethod
4344
def from_name(cls, name: str):
@@ -51,6 +52,7 @@ def from_name(cls, name: str):
5152

5253
transformer_configs= {
5354
"gemma-2b":dict(dim=2048,vocab_size=256000,n_layer=18,n_head=8,n_local_heads=1,intermediate_size=16384),
55+
"gemma-7b":dict(dim=3072,vocab_size=256000,n_layer=28,n_head=16,n_local_heads=16,intermediate_size=24576,head_dim=256),
5456
"CodeLlama-7b-Python-hf":dict(block_size=16384,vocab_size=32000,n_layer=32,dim=4096,rope_base=1000000),
5557
"7B":dict(n_layer=32,n_head=32,dim=4096),
5658
"13B":dict(n_layer=40,n_head=40,dim=5120),
@@ -95,14 +97,13 @@ def __init__(self, config: ModelArgs) -> None:
9597
def setup_caches(self, max_batch_size, max_seq_length):
9698
if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
9799
return
98-
head_dim=self.config.dim//self.config.n_head
99100
max_seq_length=find_multiple(max_seq_length,8)
100101
self.max_seq_length=max_seq_length
101102
self.max_batch_size=max_batch_size
102103
for b in self.layers:
103-
b.attention.kv_cache=KVCache(max_batch_size,max_seq_length,self.config.n_local_heads,head_dim)
104+
b.attention.kv_cache=KVCache(max_batch_size,max_seq_length,self.config.n_local_heads,self.config.head_dim)
104105

105-
self.freqs_cis=precompute_freqs_cis(self.config.block_size,self.config.dim//self.config.n_head,self.config.rope_base)
106+
self.freqs_cis=precompute_freqs_cis(self.config.block_size,self.config.head_dim,self.config.rope_base)
106107
self.causal_mask=torch.tril(torch.ones(self.max_seq_length,self.max_seq_length,dtype=torch.bool))
107108

108109
def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
@@ -145,7 +146,7 @@ def __init__(self, config: ModelArgs):
145146
total_head_dim= (config.n_head+2*config.n_local_heads)*config.head_dim
146147
# key, query, value projections for all heads, but in a batch
147148
self.wqkv=nn.Linear(config.dim,total_head_dim,bias=False)
148-
self.wo=nn.Linear(config.dim,config.dim,bias=False)
149+
self.wo=nn.Linear(config.n_head*config.head_dim,config.dim,bias=False)
149150
self.kv_cache=None
150151

151152
self.n_head=config.n_head
@@ -165,7 +166,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
165166
bsz,seqlen,_=x.shape
166167

167168
kv_size=self.n_local_heads*self.head_dim
168-
q,k,v=self.wqkv(x).split([self.dim,kv_size,kv_size],dim=-1)
169+
q,k,v=self.wqkv(x).split([self.n_head*self.head_dim,kv_size,kv_size],dim=-1)
169170

170171
q=q.view(bsz,seqlen,self.n_head,self.head_dim)
171172
k=k.view(bsz,seqlen,self.n_local_heads,self.head_dim)
@@ -183,7 +184,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
183184
v=v.repeat_interleave(self.n_head//self.n_local_heads,dim=1)
184185
y=F.scaled_dot_product_attention(q,k,v,attn_mask=mask,dropout_p=0.0)
185186

186-
y=y.transpose(1,2).contiguous().view(bsz,seqlen,self.dim)
187+
y=y.transpose(1,2).contiguous().view(bsz,seqlen,self.n_head*self.head_dim)
187188

188189
y=self.wo(y)
189190
return y

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp