Commit df51a6e

fix: plamo from PR ggml-org#3557

1 parent b9fdfbd, commit df51a6e

File tree

2 files changed: +256 −13 lines changed

convert-plamo-hf-to-gguf.py

Lines changed: 236 additions & 0 deletions

@@ -0,0 +1,236 @@
import argparse
import json
import sys
import os
import torch
import numpy as np
from pathlib import Path
import gguf
from sentencepiece import SentencePieceProcessor  # type: ignore[import]

try:
    from safetensors import safe_open
except ImportError:
    print("Please install `safetensors` python package")
    sys.exit(1)


def count_model_parts(dir_model: Path) -> int:
    # get number of model parts
    num_parts = 0
    for filename in os.listdir(dir_model):
        if filename.startswith("model-00"):
            num_parts += 1

    if num_parts > 0:
        print("gguf: found " + str(num_parts) + " model parts")
    return num_parts


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Convert a PLaMo model to a GGML compatible file")
    parser.add_argument(
        "--vocab-only", action="store_true",
        help="extract only the vocab",
    )
    parser.add_argument(
        "--outfile", type=Path,
        help="path to write to; default: based on input",
    )
    parser.add_argument(
        "model", type=Path,
        help="directory containing model file, or model file itself (*.bin)",
    )
    parser.add_argument(
        "ftype", type=int, choices=[0, 1], default=1, nargs='?',
        help="output format - use 0 for float32, 1 for float16",
    )
    return parser.parse_args()


args = parse_args()

dir_model = args.model
ftype = args.ftype
if not dir_model.is_dir():
    print(f'Error: {args.model} is not a directory', file=sys.stderr)
    sys.exit(1)


# possible tensor data types
# ftype == 0 -> float32
# ftype == 1 -> float16

# map from ftype to string
ftype_str = ["f32", "f16"]

if args.outfile is not None:
    fname_out = args.outfile
else:
    # output in the same directory as the model by default
    fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf'

print("gguf: loading model " + dir_model.name)

with open(dir_model / "config.json", "r", encoding="utf-8") as f:
    hparams = json.load(f)

if hparams["architectures"][0] != "PlamoForCausalLM":
    print("Model architecture not supported: " + hparams["architectures"][0])

    sys.exit(1)

# get number of model parts
num_parts = count_model_parts(dir_model)

# from add PLaMo model #3557
# https://github.com/ggerganov/llama.cpp/pull/3557/files

ARCH = gguf.MODEL_ARCH.PLAMO
gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])

print("gguf: get model metadata")

block_count = hparams["num_hidden_layers"]

gguf_writer.add_name("PLaMo")
gguf_writer.add_context_length(4096)  # not in config.json
gguf_writer.add_embedding_length(hparams["hidden_size"])
gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
gguf_writer.add_block_count(block_count)
gguf_writer.add_head_count(hparams["num_attention_heads"])
gguf_writer.add_head_count_kv(hparams["num_attention_heads"] // hparams["n_shared_head"])
gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
gguf_writer.add_file_type(ftype)


# TOKENIZATION

print("gguf: get tokenizer metadata")

tokens: list[bytes] = []
scores: list[float] = []
toktypes: list[int] = []

tokenizer_model_file = dir_model / 'tokenizer.model'
if not tokenizer_model_file.is_file():
    print(f'Error: Missing {tokenizer_model_file}', file=sys.stderr)
    sys.exit(1)

# vocab type sentencepiece
print("gguf: get sentencepiece tokenizer vocab, scores and token types")

tokenizer = SentencePieceProcessor(str(tokenizer_model_file))

for i in range(tokenizer.vocab_size()):
    text: bytes
    score: float

    piece = tokenizer.id_to_piece(i)
    text = piece.encode("utf-8")
    score = tokenizer.get_score(i)

    toktype = 1  # default to normal token type
    if tokenizer.is_unknown(i):
        toktype = 2
    if tokenizer.is_control(i):
        toktype = 3

    # toktype = 4 is user-defined = tokens from added_tokens.json

    if tokenizer.is_unused(i):
        toktype = 5
    if tokenizer.is_byte(i):
        toktype = 6

    tokens.append(text)
    scores.append(score)
    toktypes.append(toktype)

gguf_writer.add_tokenizer_model("llama")
gguf_writer.add_token_list(tokens)
gguf_writer.add_token_scores(scores)
gguf_writer.add_token_types(toktypes)
gguf_writer.add_sep_token_id(5)
gguf_writer.add_pad_token_id(3)

special_vocab = gguf.SpecialVocab(dir_model)
special_vocab.add_to_gguf(gguf_writer)

# TENSORS

tensor_map = gguf.get_tensor_name_map(ARCH, block_count)

# params for qkv transform
n_head = hparams["num_attention_heads"]
n_head_kv = hparams["num_key_value_heads"]

head_dim = hparams["hidden_size"] // n_head

# tensor info
print("gguf: get tensor metadata")

if num_parts == 0:
    part_names = iter(("model.safetensors",))
else:
    part_names = (
        f"model-{n:05}-of-{num_parts:05}.safetensors" for n in range(1, num_parts + 1)
    )

for part_name in part_names:
    if args.vocab_only:
        break
    print("gguf: loading model part '" + part_name + "'")
    model_part = safe_open(dir_model / part_name, framework="pt")

    for name in model_part.keys():
        if "self_attn.rotary_emb.inv_freq" in name:
            continue
        data = model_part.get_tensor(name)

        old_dtype = data.dtype

        # convert any unsupported data types to float32
        if data.dtype != torch.float16 and data.dtype != torch.float32:
            data = data.to(torch.float32)

        data = data.squeeze().numpy()

        # map tensor names
        new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
        if new_name is None:
            print("Cannot map tensor '" + name + "'")
            sys.exit()

        n_dims = len(data.shape)
        data_dtype = data.dtype

        # if f32 desired, convert any float16 to float32
        if ftype == 0 and data_dtype == np.float16:
            data = data.astype(np.float32)

        # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
        if ftype == 1 and data_dtype == np.float16 and n_dims == 1:
            data = data.astype(np.float32)

        # if f16 desired, convert any float32 2-dim weight tensors to float16
        if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
            data = data.astype(np.float16)

        print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype))

        gguf_writer.add_tensor(new_name, data)


print("gguf: write header")
gguf_writer.write_header_to_file()
print("gguf: write metadata")
gguf_writer.write_kv_data_to_file()
if not args.vocab_only:
    print("gguf: write tensors")
    gguf_writer.write_tensors_to_file()

gguf_writer.close()

print(f"gguf: model successfully exported to '{fname_out}'")
print("")
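
Not part of the commit, but a quick way to check the result of a run: a minimal sketch that reads the exported file back with gguf-py. It assumes the installed gguf package provides GGUFReader (the script above only needs GGUFWriter), and the path ggml-model-f16.gguf is just the default output of an f16 run.

# Hypothetical sanity check, not from the commit; assumes gguf-py ships GGUFReader.
from gguf import GGUFReader

reader = GGUFReader("ggml-model-f16.gguf")  # example path: default f16 output of the script

# metadata keys written by the gguf_writer.add_* calls above
for key in reader.fields:
    print("kv:", key)

# tensor names after tensor_map renaming, with their shapes
for tensor in reader.tensors:
    print("tensor:", tensor.name, tuple(tensor.shape))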

llama.cpp

Lines changed: 20 additions & 13 deletions
@@ -7232,7 +7232,7 @@ static struct ggml_cgraph * llm_build_plamo(
                             ggml_element_size(kv_self.k)*n_embd_head,
                             ggml_element_size(kv_self.k)*n_embd_gqa,
                             ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
-                        K_shift, n_embd_head, 0, 0, freq_base, freq_scale);
+                        K_shift, n_embd_head, 2, 0, freq_base, freq_scale);
             offload_func_kq(tmp);
             ggml_build_forward_expand(gf, tmp);
         }
@@ -7274,11 +7274,11 @@ static struct ggml_cgraph * llm_build_plamo(
         offload_func_kq(tmpq);
         ggml_set_name(tmpq, "tmpq");

-        struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+        struct ggml_tensor * Kcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale);
         offload_func_kq(Kcur);
         ggml_set_name(Kcur, "Kcur");

-        struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_embd_head, 0, 0, freq_base, freq_scale);
+        struct ggml_tensor * Qcur = ggml_rope_custom(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), KQ_pos, n_embd_head, 2, 0, freq_base, freq_scale);
         offload_func_kq(Qcur);
         ggml_set_name(Qcur, "Qcur");

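The hunks above change the RoPE mode argument of the three ggml_rope_custom calls from 0 to 2. In ggml, mode 0 rotates adjacent element pairs (x[2i], x[2i+1]), while mode 2 selects the GPT-NeoX-style rotation that pairs x[i] with x[i + n_rot/2]; the fix switches PLaMo to the NeoX convention. A rough NumPy illustration of the two pairings, not code from the commit; the head dimension and angle values are made up.

# Illustration only: the two RoPE pairing schemes selected by the ggml mode argument
# (0 = adjacent pairs, 2 = NeoX-style half-split pairs). Sizes and angles are examples.
import numpy as np

def rope_normal(x, theta):
    # mode 0: rotate adjacent pairs (x[2i], x[2i+1])
    out = x.copy()
    out[0::2] = x[0::2] * np.cos(theta) - x[1::2] * np.sin(theta)
    out[1::2] = x[0::2] * np.sin(theta) + x[1::2] * np.cos(theta)
    return out

def rope_neox(x, theta):
    # mode 2: rotate pairs split across the two halves (x[i], x[i + d/2])
    d = x.shape[0] // 2
    out = x.copy()
    out[:d] = x[:d] * np.cos(theta) - x[d:] * np.sin(theta)
    out[d:] = x[:d] * np.sin(theta) + x[d:] * np.cos(theta)
    return out

head_dim = 8
x = np.arange(head_dim, dtype=np.float32)
theta = 3 * 0.5 ** np.arange(head_dim // 2)  # position times per-pair frequencies (example values)
print(rope_normal(x, theta))
print(rope_neox(x, theta))  # same rotation angles, different element pairing
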
@@ -7322,8 +7322,17 @@ static struct ggml_cgraph * llm_build_plamo(
         offload_func_kq(K);
         ggml_set_name(K, "K");

+        // from this PR
+        // https://github.com/ggerganov/llama.cpp/pull/3557
+
         // K * Q
-        struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+        //struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+        // we should avoid to repeat K but current ggml_mul_mat generates wrong values for grouped query att
+        struct ggml_tensor * K_repeated = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, K->ne[0], K->ne[1], Q->ne[2]);
+        offload_func_kq(K_repeated);
+        ggml_set_name(K_repeated, "K_repeated");
+
+        struct ggml_tensor * KQ = ggml_mul_mat(ctx0, ggml_repeat(ctx0, K, K_repeated), Q);
         offload_func_kq(KQ);
         ggml_set_name(KQ, "KQ");

@@ -7353,17 +7362,15 @@ static struct ggml_cgraph * llm_build_plamo(
         offload_func_v(V);
         ggml_set_name(V, "V");

-#if 1
-        struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+        //struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
+        // we should avoid to repeat V but current ggml_mul_mat generates wrong values for grouped query att
+        struct ggml_tensor * V_repeated = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, V->ne[0], V->ne[1], Q->ne[2]);
+        offload_func_v(V_repeated);
+        ggml_set_name(V_repeated, "V_repeated");
+
+        struct ggml_tensor * KQV = ggml_mul_mat(ctx0, ggml_repeat(ctx0, V, V_repeated), KQ_soft_max);
         offload_func_v(KQV);
         ggml_set_name(KQV, "KQV");
-#else
-        // make V contiguous in memory to speed up the matmul, however we waste time on the copy
-        // on M1 this is faster for the perplexity computation, but ~5% slower for the single-token generation
-        // is there a better way?
-        struct ggml_tensor * V_cont = ggml_cpy(ctx0, V, ggml_new_tensor_3d(ctx0, kv_self.v->type, n_ctx, n_embd_head, n_head));
-        struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_cont, KQ_soft_max);
-#endif

         // KQV_merged = KQV.permute(0, 2, 1, 3)
         struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
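
The K_repeated and V_repeated tensors exist because PLaMo uses grouped-query attention: Q has n_head heads while K and V only have n_head // n_shared_head heads, and, per the comment carried over from PR #3557, ggml_mul_mat at the time produced wrong values when the head counts differed, so K and V are expanded to Q's head count with ggml_repeat before the two matmuls. A rough NumPy sketch of the idea, not code from the commit; sizes are made up, and whether kv heads map to query heads by repetition or by tiling depends on the model's head layout.

# Illustration only: expanding K to the query head count so a plain per-head
# batched matmul can be used for grouped-query attention. Sizes are examples.
import numpy as np

n_head, n_head_kv, head_dim, n_tokens, n_kv = 8, 2, 4, 3, 5

Q = np.random.rand(n_head, n_tokens, head_dim).astype(np.float32)
K = np.random.rand(n_head_kv, n_kv, head_dim).astype(np.float32)

# expand the 2 kv heads so every one of the 8 query heads has a matching K
K_expanded = np.repeat(K, n_head // n_head_kv, axis=0)  # (n_head, n_kv, head_dim)

# per-head attention scores, shape (n_head, n_tokens, n_kv)
KQ = Q @ K_expanded.transpose(0, 2, 1)

# the same result computed group by group, without materialising K_expanded
KQ_grouped = np.concatenate(
    [Q[h] @ K[h * n_head_kv // n_head].T for h in range(n_head)]
).reshape(n_head, n_tokens, n_kv)

assert np.allclose(KQ, KQ_grouped)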

0 commit comments
