@@ -26,7 +26,7 @@ class ModelArgs:
     dim: int = 4096
     intermediate_size: int = None
     n_local_heads: int = -1
-    head_dim: int = 64
+    head_dim: int = None
     rope_base: float = 10000
     norm_eps: float = 1e-5
 
@@ -37,7 +37,8 @@ def __post_init__(self):
         hidden_dim = 4 * self.dim
         n_hidden = int(2 * hidden_dim / 3)
         self.intermediate_size = find_multiple(n_hidden, 256)
-        self.head_dim = self.dim // self.n_head
+        if self.head_dim is None:
+            self.head_dim = self.dim // self.n_head
 
     @classmethod
     def from_name(cls, name: str):
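With the default flipped to `None`, `__post_init__` only derives `head_dim` when a config leaves it unset, so entries like `gemma-7b` below can override it. A minimal sketch of both paths (the trimmed-down dataclass here is illustrative, not the full `ModelArgs`):

```python
from dataclasses import dataclass

@dataclass
class Args:  # illustrative subset of ModelArgs
    dim: int = 4096
    n_head: int = 32
    head_dim: int = None

    def __post_init__(self):
        if self.head_dim is None:  # same guard as the diff above
            self.head_dim = self.dim // self.n_head

print(Args().head_dim)                                   # 128, derived from dim // n_head
print(Args(dim=3072, n_head=16, head_dim=256).head_dim)  # 256, explicit override wins
```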
@@ -51,6 +52,7 @@ def from_name(cls, name: str):
 
 transformer_configs = {
     "gemma-2b": dict(dim=2048, vocab_size=256000, n_layer=18, n_head=8, n_local_heads=1, intermediate_size=16384),
+    "gemma-7b": dict(dim=3072, vocab_size=256000, n_layer=28, n_head=16, n_local_heads=16, intermediate_size=24576, head_dim=256),
     "CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim=4096, rope_base=1000000),
     "7B": dict(n_layer=32, n_head=32, dim=4096),
     "13B": dict(n_layer=40, n_head=40, dim=5120),
@@ -95,14 +97,13 @@ def __init__(self, config: ModelArgs) -> None:
     def setup_caches(self, max_batch_size, max_seq_length):
         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
             return
-        head_dim = self.config.dim // self.config.n_head
         max_seq_length = find_multiple(max_seq_length, 8)
         self.max_seq_length = max_seq_length
         self.max_batch_size = max_batch_size
         for b in self.layers:
-            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim)
+            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, self.config.head_dim)
 
-        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base)
+        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.head_dim, self.config.rope_base)
         self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
 
     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
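Both the KV cache and the RoPE table must be sized by the true `head_dim`: a 192-wide frequency table rotated into 256-wide heads would corrupt positions rather than raise an error. `KVCache` is defined elsewhere in model.py; a rough sketch of the allocation these arguments drive, assuming the usual (batch, heads, seq, head_dim) layout:

```python
import torch

# Hypothetical gemma-7b numbers; the layout is an assumption, not the real KVCache.
max_batch_size, max_seq_length = 1, 2048
n_local_heads, head_dim = 16, 256  # note: 256 != 3072 // 16

k_cache = torch.zeros(max_batch_size, n_local_heads, max_seq_length, head_dim)
v_cache = torch.zeros_like(k_cache)
print(k_cache.shape)  # torch.Size([1, 16, 2048, 256])
```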
@@ -145,7 +146,7 @@ def __init__(self, config: ModelArgs):
         total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
         # key, query, value projections for all heads, but in a batch
         self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
-        self.wo = nn.Linear(config.dim, config.dim, bias=False)
+        self.wo = nn.Linear(config.n_head * config.head_dim, config.dim, bias=False)
         self.kv_cache = None
 
         self.n_head = config.n_head
@@ -165,7 +166,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         bsz, seqlen, _ = x.shape
 
         kv_size = self.n_local_heads * self.head_dim
-        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+        q, k, v = self.wqkv(x).split([self.n_head * self.head_dim, kv_size, kv_size], dim=-1)
 
         q = q.view(bsz, seqlen, self.n_head, self.head_dim)
         k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
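`Tensor.split` requires the section sizes to sum to the dimension being split, so the old `self.dim` slice would raise a `RuntimeError` for gemma-7b rather than silently misbehave. The arithmetic, using the config values above:

```python
n_head, n_local_heads, head_dim, dim = 16, 16, 256, 3072  # gemma-7b

total_head_dim = (n_head + 2 * n_local_heads) * head_dim  # 12288, width of the wqkv output
kv_size = n_local_heads * head_dim                        # 4096

print(sum([dim, kv_size, kv_size]))                # 11264 -- old split, does not cover 12288
print(sum([n_head * head_dim, kv_size, kv_size]))  # 12288 -- new split matches
```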
@@ -183,7 +184,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
 
-        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.n_head * self.head_dim)
 
         y = self.wo(y)
         return y
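Together with the `wo` change above, these fixes keep one invariant: everything between `wqkv` and `wo` is sized `n_head * head_dim`, and only `wo` maps back to `dim`. A shape walk-through with the gemma-7b numbers (random weights, illustrative only):

```python
import torch
import torch.nn as nn

bsz, seqlen = 1, 8
dim, n_head, head_dim = 3072, 16, 256

wo = nn.Linear(n_head * head_dim, dim, bias=False)

y = torch.randn(bsz, n_head, seqlen, head_dim)  # shape of the SDPA output
y = y.transpose(1, 2).contiguous().view(bsz, seqlen, n_head * head_dim)
print(y.shape)      # torch.Size([1, 8, 4096]) -- n_head * head_dim, not dim
print(wo(y).shape)  # torch.Size([1, 8, 3072]) -- back to dim
```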
@@ -197,7 +198,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)
 
     def forward(self, x: Tensor) -> Tensor:
-        return self.w2(F.gelu(self.w1(x)) * self.w3(x))
+        return self.w2(F.gelu(self.w1(x), approximate="tanh") * self.w3(x))
 
 
 class RMSNorm(nn.Module):
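The activation switch matters because Gemma checkpoints were trained with the tanh approximation of GELU; exact (erf-based) GELU is close but not identical, and evaluating a pretrained model with the wrong variant can measurably degrade outputs. A quick comparison:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, 5)
print(F.gelu(x))                      # exact, erf-based GELU
print(F.gelu(x, approximate="tanh"))  # tanh approximation used here
# Per-element differences are small but compound across all layers.
```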