Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 85e9a12

Browse files
authored
fix: preprocess tensor names in tensor types map (leejet#607)
Thank you for your contribution
1 parent fbd42b6 commit 85e9a12

File tree

1 file changed

+40
-11
lines changed

1 file changed

+40
-11
lines changed

‎model.cpp‎

Lines changed: 40 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,26 @@ std::string convert_tensor_name(std::string name) {
558558
return new_name;
559559
}
560560

561+
voidadd_preprocess_tensor_storage_types(std::map<std::string,enum ggml_type>& tensor_storages_types, std::string name,enum ggml_type type) {
562+
std::string new_name =convert_tensor_name(name);
563+
564+
if (new_name.find("cond_stage_model") != std::string::npos &&ends_with(new_name,"attn.in_proj_weight")) {
565+
size_t prefix_size = new_name.find("attn.in_proj_weight");
566+
std::string prefix = new_name.substr(0, prefix_size);
567+
tensor_storages_types[prefix +"self_attn.q_proj.weight"] = type;
568+
tensor_storages_types[prefix +"self_attn.k_proj.weight"] = type;
569+
tensor_storages_types[prefix +"self_attn.v_proj.weight"] = type;
570+
}elseif (new_name.find("cond_stage_model") != std::string::npos &&ends_with(new_name,"attn.in_proj_bias")) {
571+
size_t prefix_size = new_name.find("attn.in_proj_bias");
572+
std::string prefix = new_name.substr(0, prefix_size);
573+
tensor_storages_types[prefix +"self_attn.q_proj.bias"] = type;
574+
tensor_storages_types[prefix +"self_attn.k_proj.bias"] = type;
575+
tensor_storages_types[prefix +"self_attn.v_proj.bias"] = type;
576+
}else {
577+
tensor_storages_types[new_name] = type;
578+
}
579+
}
580+
561581
voidpreprocess_tensor(TensorStorage tensor_storage,
562582
std::vector<TensorStorage>& processed_tensor_storages) {
563583
std::vector<TensorStorage> result;
@@ -927,7 +947,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
927947
GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes());
928948

929949
tensor_storages.push_back(tensor_storage);
930-
tensor_storages_types[tensor_storage.name] =tensor_storage.type;
950+
add_preprocess_tensor_storage_types(tensor_storages_types,tensor_storage.name,tensor_storage.type);
931951
}
932952

933953
gguf_free(ctx_gguf_);
@@ -1072,7 +1092,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
10721092
}
10731093

10741094
tensor_storages.push_back(tensor_storage);
1075-
tensor_storages_types[tensor_storage.name] =tensor_storage.type;
1095+
add_preprocess_tensor_storage_types(tensor_storages_types,tensor_storage.name,tensor_storage.type);
10761096

10771097
// LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
10781098
}
@@ -1403,7 +1423,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
14031423
// printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str());
14041424
reader.tensor_storage.name = prefix + reader.tensor_storage.name;
14051425
tensor_storages.push_back(reader.tensor_storage);
1406-
tensor_storages_types[reader.tensor_storage.name] =reader.tensor_storage.type;
1426+
add_preprocess_tensor_storage_types(tensor_storages_types,reader.tensor_storage.name,reader.tensor_storage.type);
14071427

14081428
// LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
14091429
// reset
@@ -1461,10 +1481,10 @@ SDVersion ModelLoader::get_sd_version() {
14611481
TensorStorage token_embedding_weight, input_block_weight;
14621482
bool input_block_checked =false;
14631483

1464-
bool has_multiple_encoders=false;
1465-
bool is_unet =false;
1484+
bool has_multiple_encoders =false;
1485+
bool is_unet=false;
14661486

1467-
bool is_xl =false;
1487+
bool is_xl=false;
14681488
bool is_flux =false;
14691489

14701490
#definefound_family (is_xl || is_flux)
@@ -1481,7 +1501,7 @@ SDVersion ModelLoader::get_sd_version() {
14811501
}
14821502
if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) {
14831503
is_unet =true;
1484-
if(has_multiple_encoders){
1504+
if(has_multiple_encoders){
14851505
is_xl =true;
14861506
if (input_block_checked) {
14871507
break;
@@ -1490,7 +1510,7 @@ SDVersion ModelLoader::get_sd_version() {
14901510
}
14911511
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
14921512
has_multiple_encoders =true;
1493-
if(is_unet){
1513+
if(is_unet){
14941514
is_xl =true;
14951515
if (input_block_checked) {
14961516
break;
@@ -1635,11 +1655,20 @@ ggml_type ModelLoader::get_vae_wtype() {
16351655
voidModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) {
16361656
for (auto& pair : tensor_storages_types) {
16371657
if (prefix.size() <1 || pair.first.substr(0, prefix.size()) == prefix) {
1658+
bool found =false;
16381659
for (auto& tensor_storage : tensor_storages) {
1639-
if (tensor_storage.name == pair.first) {
1640-
if (tensor_should_be_converted(tensor_storage, wtype)) {
1641-
pair.second = wtype;
1660+
std::map<std::string, ggml_type> temp;
1661+
add_preprocess_tensor_storage_types(temp, tensor_storage.name, tensor_storage.type);
1662+
for (auto& preprocessed_name : temp) {
1663+
if (preprocessed_name.first == pair.first) {
1664+
if (tensor_should_be_converted(tensor_storage, wtype)) {
1665+
pair.second = wtype;
1666+
}
1667+
found =true;
1668+
break;
16421669
}
1670+
}
1671+
if (found) {
16431672
break;
16441673
}
16451674
}

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp