File size: 3,241 Bytes
"""Convert an FP8-quantized safetensors checkpoint into a BF16 checkpoint."""
import os
import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm

import torch
from safetensors.torch import load_file, save_file

from kernel import weight_dequant


def main(fp8_path, bf16_path, max_cached_files=2):
    """
    Convert the FP8 checkpoint in ``fp8_path`` into a BF16 checkpoint in ``bf16_path``.

    Every ``*.safetensors`` shard in ``fp8_path`` is loaded onto the GPU. Each
    weight whose elements are 1 byte wide (FP8) is dequantized with its
    ``<name>_scale_inv`` companion via ``weight_dequant``. All ``*_scale_inv``
    tensors are dropped from the output, and every other tensor is copied
    through unchanged. Each shard is saved under the same file name in
    ``bf16_path``. A fresh ``model.safetensors.index.json`` is written whose
    weight map omits the dropped ``*_scale_inv`` tensors, so the index always
    agrees with the written shards.

    Args:
        fp8_path: Directory containing the FP8 shards and
            ``model.safetensors.index.json``.
        bf16_path: Output directory; created if it does not exist.
        max_cached_files: Maximum number of shards kept resident on the GPU at
            once. A scale tensor may live in a different shard than its weight,
            so a small LRU cache avoids reloading neighbouring shards. Must be
            at least 1. Defaults to 2.

    Raises:
        ValueError: If ``max_cached_files`` is smaller than 1.
        FileNotFoundError: If the input index file does not exist.
        KeyError: If the input index has no ``"weight_map"`` entry.
    """
    if max_cached_files < 1:
        raise ValueError(f"max_cached_files must be >= 1, got {max_cached_files}")

    # NOTE(review): presumably weight_dequant allocates its output with the
    # default dtype, which is why it is switched to bfloat16 globally here —
    # confirm against kernel.py.
    torch.set_default_dtype(torch.bfloat16)
    os.makedirs(bf16_path, exist_ok=True)

    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
    with open(model_index_file, "r") as f:
        model_index = json.load(f)
    weight_map = model_index["weight_map"]

    # Shards resident on the GPU, ordered from least to most recently used
    # (dicts preserve insertion order; a hit is popped and re-inserted).
    loaded_files = {}

    def load_shard(file_name):
        """Return the state dict of ``file_name``, loading it on a cache miss."""
        state_dict = loaded_files.pop(file_name, None)
        if state_dict is None:
            state_dict = load_file(os.path.join(fp8_path, file_name), device="cuda")
        # (Re-)insert at the end so this shard becomes the most recently used.
        loaded_files[file_name] = state_dict
        # Trim on every insert (not once per output shard) so that scale
        # lookups spanning several shards cannot grow the cache unbounded.
        evicted = False
        while len(loaded_files) > max_cached_files:
            del loaded_files[next(iter(loaded_files))]
            evicted = True
        if evicted:
            torch.cuda.empty_cache()
        return state_dict

    def get_tensor(tensor_name):
        """Fetch ``tensor_name`` from the shard the index assigns it to.

        Raises KeyError if the index or that shard does not contain it.
        """
        return load_shard(weight_map[tensor_name])[tensor_name]

    # Scale tensors that were consumed/skipped and therefore not written out.
    dropped_scale_names = set()

    safetensor_files = sorted(glob(os.path.join(fp8_path, "*.safetensors")))
    for safetensor_file in tqdm(safetensor_files):
        file_name = os.path.basename(safetensor_file)
        # Reuses the cached copy if a previous scale lookup already loaded it.
        current_state_dict = load_shard(file_name)
        new_state_dict = {}
        for weight_name, weight in current_state_dict.items():
            if weight_name.endswith("_scale_inv"):
                dropped_scale_names.add(weight_name)
                continue
            if weight.element_size() != 1:
                # Not FP8: copy through unchanged.
                new_state_dict[weight_name] = weight
                continue
            # FP8 weight: look up its scale (possibly in another shard).
            scale_inv_name = f"{weight_name}_scale_inv"
            try:
                scale_inv = get_tensor(scale_inv_name)
            except KeyError:
                print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
                new_state_dict[weight_name] = weight
            else:
                # Kept outside the try so errors raised by the kernel are not
                # misreported as a missing scale tensor.
                new_state_dict[weight_name] = weight_dequant(weight, scale_inv)

        save_file(new_state_dict, os.path.join(bf16_path, file_name))

    # Write an index that lists exactly the tensors present in the new shards.
    new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
    new_weight_map = {
        name: shard
        for name, shard in weight_map.items()
        if name not in dropped_scale_names
    }
    with open(new_model_index_file, "w") as f:
        json.dump({"metadata": {}, "weight_map": new_weight_map}, f, indent=2)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--input-fp8-hf-path", type=str, required=True)
    parser.add_argument("--output-bf16-hf-path", type=str, required=True)
    parser.add_argument("--max-cached-files", type=int, default=2)
    args = parser.parse_args()
    main(args.input_fp8_hf_path, args.output_bf16_hf_path, args.max_cached_files)