Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit2abe945

Browse files
authored
fix: optimize the handling of CLIP embedding weight (leejet#840)
1 parentf3140ea commit2abe945

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

‎clip.hpp‎

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -553,12 +553,13 @@ class CLIPEmbeddings : public GGMLBlock {
553553
voidinit_params(structggml_context* ctx,const String2GGMLType& tensor_types = {},const std::string prefix ="") {
554554
enum ggml_type token_wtype = GGML_TYPE_F32;
555555
if (!force_clip_f32) {
556-
auto tensor_type = tensor_types.find(prefix +"token_embedding.weight");
557-
if (tensor_type != tensor_types.end())
556+
auto tensor_type = tensor_types.find(prefix +"token_embedding.weight");
557+
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
558+
if (tensor_type != tensor_types.end() && allow_types.find(tensor_type->second) != allow_types.end()) {
558559
token_wtype = tensor_type->second;
560+
}
559561
}
560-
enum ggml_type position_wtype = GGML_TYPE_F32;
561-
562+
enum ggml_type position_wtype = GGML_TYPE_F32;
562563
params["token_embedding.weight"] =ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size);
563564
params["position_embedding.weight"] =ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions);
564565
}

‎model.cpp‎

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2422,6 +2422,8 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage
24222422
// Pass, do not convert. For MMDiT
24232423
}elseif (contains(name,"time_embed.") ||contains(name,"label_emb.")) {
24242424
// Pass, do not convert. For Unet
2425+
}elseif (contains(name,"embedding")) {
2426+
// Pass, do not convert embedding
24252427
}else {
24262428
returntrue;
24272429
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp