Movatterモバイル変換


[0]ホーム

URL:


Hugging Face's logoHugging Face

File size: 3,241 Bytes
dd31960                  9672b38           dd319609672b38dd31960 9672b38  dd319609672b38dd31960 9672b38dd319609672b38       dd31960 9672b38dd31960 9672b38     dd319609672b38dd31960  9672b38 dd31960         9672b38
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Convert an FP8 Hugging Face safetensors checkpoint to BF16.

Each 1-byte weight that has a matching ``<name>_scale_inv`` tensor (located
through the checkpoint index) is dequantized with ``kernel.weight_dequant``.
The ``*_scale_inv`` tensors are not copied to the output, and a new
``model.safetensors.index.json`` without them is written next to the
converted shards.

Usage::

    python <this script> --input-fp8-hf-path IN --output-bf16-hf-path OUT
"""

import json
import os
from argparse import ArgumentParser
from collections import OrderedDict
from glob import glob

import torch
from safetensors.torch import load_file, save_file
from tqdm import tqdm

from kernel import weight_dequant

INDEX_FILE_NAME = "model.safetensors.index.json"
# The tensor named f"{weight}{SCALE_INV_SUFFIX}" is passed to weight_dequant
# together with `weight`; such tensors are consumed here and never written out.
SCALE_INV_SUFFIX = "_scale_inv"


def _load_weight_map(fp8_path):
    """Return the tensor-name -> shard-file-name map from *fp8_path*'s index."""
    index_file = os.path.join(fp8_path, INDEX_FILE_NAME)
    with open(index_file, "r", encoding="utf-8") as f:
        return json.load(f)["weight_map"]


def _write_weight_map(bf16_path, weight_map):
    """Write the output index for the converted shards, without scale tensors."""
    # Every *_scale_inv tensor is skipped when shards are written, so drop all
    # of them from the index too -- not only those whose weight was converted --
    # otherwise the index would point at tensors missing from the shards.
    converted_map = {
        name: file_name
        for name, file_name in weight_map.items()
        if not name.endswith(SCALE_INV_SUFFIX)
    }
    index_file = os.path.join(bf16_path, INDEX_FILE_NAME)
    with open(index_file, "w", encoding="utf-8") as f:
        # Metadata is reset to empty; whatever the source index carried
        # presumably described the FP8 shards, not these.
        json.dump({"metadata": {}, "weight_map": converted_map}, f, indent=2)


def main(fp8_path, bf16_path, max_cached_files=2):
    """Convert the FP8 checkpoint in *fp8_path* into BF16 shards in *bf16_path*.

    Args:
        fp8_path: Directory holding the input ``*.safetensors`` shards and
            ``model.safetensors.index.json``.
        bf16_path: Output directory (created if missing). Each input shard is
            written here under the same file name, plus a new index file.
        max_cached_files: Maximum number of input shards kept loaded on the
            GPU between shards; the least recently used are evicted first.
            Scale tensors are looked up through the index and may live in a
            different shard than their weight, so caching avoids reloading
            such shards.

    Raises:
        ValueError: If *max_cached_files* is negative.
        FileNotFoundError: If the input index file does not exist.

    Side effects:
        Sets torch's global default dtype to ``torch.bfloat16``.
    """
    if max_cached_files < 0:
        raise ValueError(f"max_cached_files must be >= 0, got {max_cached_files}")

    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(bf16_path, exist_ok=True)
    weight_map = _load_weight_map(fp8_path)

    # Shard file name -> state dict loaded on the GPU, least recently used first.
    loaded_files = OrderedDict()

    def get_shard(file_name):
        """Return shard *file_name* (loading it if needed), marking it most recent."""
        if file_name in loaded_files:
            loaded_files.move_to_end(file_name)
        else:
            file_path = os.path.join(fp8_path, file_name)
            loaded_files[file_name] = load_file(file_path, device="cuda")
        return loaded_files[file_name]

    def get_tensor(tensor_name):
        """Fetch *tensor_name* from the shard the index assigns it to.

        Raises KeyError if the index, or that shard, lacks the tensor.
        """
        return get_shard(weight_map[tensor_name])[tensor_name]

    safetensor_files = sorted(glob(os.path.join(fp8_path, "*.safetensors")))
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)
        # Reuse the shard if a scale lookup for an earlier file already loaded it.
        current_state_dict = get_shard(file_name)

        new_state_dict = {}
        for weight_name, weight in current_state_dict.items():
            if weight_name.endswith(SCALE_INV_SUFFIX):
                continue  # consumed alongside the weight it belongs to
            # NOTE(review): any 1-byte dtype is treated as FP8 here; int8/uint8
            # tensors would also take this path -- confirm none exist.
            if weight.element_size() != 1:
                new_state_dict[weight_name] = weight
                continue
            # Keep the try narrow so only a missing scale is reported as such.
            try:
                scale_inv = get_tensor(f"{weight_name}{SCALE_INV_SUFFIX}")
            except KeyError:
                print(
                    f"Warning: Missing scale_inv tensor for {weight_name}, "
                    "skipping conversion"
                )
                new_state_dict[weight_name] = weight
                continue
            new_state_dict[weight_name] = weight_dequant(weight, scale_inv)

        save_file(new_state_dict, os.path.join(bf16_path, file_name))
        # Drop local references so evicted GPU tensors can actually be freed.
        del current_state_dict, new_state_dict

        # Evict least recently used shards; scale lookups may have loaded
        # several, so trim in a loop rather than removing just one.
        if len(loaded_files) > max_cached_files:
            while len(loaded_files) > max_cached_files:
                loaded_files.popitem(last=False)
            torch.cuda.empty_cache()

    _write_weight_map(bf16_path, weight_map)


if __name__ == "__main__":
    parser = ArgumentParser(description="Convert an FP8 safetensors checkpoint to BF16.")
    parser.add_argument("--input-fp8-hf-path", type=str, required=True)
    parser.add_argument("--output-bf16-hf-path", type=str, required=True)
    parser.add_argument(
        "--max-cached-files",
        type=int,
        default=2,
        help="Maximum number of input shards kept loaded on the GPU (default: 2).",
    )
    args = parser.parse_args()
    main(args.input_fp8_hf_path, args.output_bf16_hf_path, args.max_cached_files)

[8]ページ先頭

©2009-2025 Movatter.jp