pytorch/ao (Public)

Commit 0c23589

Merge branch 'main' into jcaip/enable-smoothquant

2 parents a732fee + f99105a · commit 0c23589

File tree: 21 files changed, +421 −93 lines changed

‎.github/workflows/xpu_test.yml‎

Lines changed: 2 additions & 2 deletions

@@ -21,7 +21,7 @@ jobs:
   test:
     # Don't run on forked repos or empty test matrix
     # if: github.repository_owner == 'pytorch' && toJSON(fromJSON(inputs.test-matrix).include) != '[]'
-    timeout-minutes: 60
+    timeout-minutes: 120
     runs-on: linux.idc.xpu
     env:
       DOCKER_IMAGE: ci-image:pytorch-linux-noble-xpu-n-py3
@@ -166,7 +166,7 @@ jobs:
           GITHUB_RUN_NUMBER: ${{ github.run_number }}
           GITHUB_RUN_ATTEMPT: ${{ github.run_attempt }}
           SHA1: ${{ github.event.pull_request.head.sha || github.sha }}
-        timeout-minutes: 60
+        timeout-minutes: 120
         run: |
           set -x

benchmarks/quantization/eval_accuracy_for_readme.py

Lines changed: 176 additions & 0 deletions (new file)

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import subprocess

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8DynamicActivationInt4WeightConfig,
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    PerRow,
)


def string_to_config(s):
    if s is None:
        return None
    elif s == "float8_rowwise":
        return Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
    elif s == "int4_groupwise_weight_float8_rowwise_activation":
        return Float8DynamicActivationInt4WeightConfig()
    elif s == "int4_groupwise_hqq_weight_only":
        return Int4WeightOnlyConfig(
            group_size=32,
            int4_packing_format="tile_packed_to_4d",
            int4_choose_qparams_algorithm="hqq",
        )
    elif s == "int8_rowwise_weight_only":
        return Int8WeightOnlyConfig()
    elif s == "int8_rowwise":
        return Int8DynamicActivationInt8WeightConfig()
    else:
        raise AssertionError(f"unsupported {s}")


def quantize_model_and_save(model_id, quant_config, output_dir="results"):
    """Quantize the model and save it to the output directory."""
    print("Quantizing model with config: ", quant_config)
    if quant_config is None:
        quantization_config = None
    else:
        quantization_config = TorchAoConfig(quant_type=quant_config)
    quantized_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        dtype=torch.bfloat16,
        quantization_config=quantization_config,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    quantized_model.save_pretrained(output_dir, safe_serialization=False)
    tokenizer.save_pretrained(output_dir, safe_serialization=False)
    return quantized_model, tokenizer


def run_lm_eval(model_dir, tasks_list=["hellaswag"], device="cuda:0", batch_size=8):
    """Run the lm_eval command using subprocess."""
    tasks_str = ",".join(tasks_list)
    command = [
        "lm_eval",
        "--model",
        "hf",
        "--model_args",
        f"pretrained={model_dir}",
        "--tasks",
        f"{tasks_str}",
        "--device",
        f"{device}",
        "--batch_size",
        f"{batch_size}",
        "--output_path",
        f"{model_dir}/lm_eval_outputs/",
    ]
    subprocess.run(command, check=True)


def get_size_of_dir(model_output_dir):
    # get dir size from shell, to skip complexity of dealing with tensor
    # subclasses
    result = subprocess.run(
        ["du", "-sb", model_output_dir], capture_output=True, text=True
    )
    size = int(result.stdout.split()[0])
    return size


def run(
    model_id: str,
    quant_recipe_name: str | None,
    tasks,
    device,
    batch_size,
    model_output_dir,
):
    print(f"\nRunning {model_id=} with {quant_recipe_name=}\n")
    model_name = model_id.split("/")[-1]
    model_output_dir = (
        f"benchmarks/data/quantized_model/{model_name}-{quant_recipe_name}"
    )
    quant_config = string_to_config(quant_recipe_name)
    quantized_model, tokenizer = quantize_model_and_save(
        model_id, quant_config=quant_config, output_dir=model_output_dir
    )
    print(quantized_model)

    model_size = get_size_of_dir(model_output_dir) / 1e9
    print(f"checkpoint size: {model_size} GB")

    run_lm_eval(
        model_output_dir, tasks_list=tasks, device=device, batch_size=batch_size
    )
    print("done\n")


if __name__ == "__main__":
    try:
        import lm_eval  # noqa: F401
    except:
        print(
            "lm_eval is required to run this script. Please install it using pip install lm-eval."
        )
        exit(0)

    # Set up argument parser
    parser = argparse.ArgumentParser(
        description="Quantize a model and evaluate its throughput."
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="meta-llama/Llama-3.1-8B",
        help="The model ID to use.",
    )
    parser.add_argument(
        "--quant_recipe_name",
        type=str,
        default=None,
        help="The quantization recipe to use.",
    )
    parser.add_argument(
        "--tasks",
        nargs="+",
        type=str,
        default=["wikitext"],
        help="List of lm-eluther tasks to evaluate usage: --tasks task1 task2",
    )
    parser.add_argument(
        "--device", type=str, default="cuda:0", help="Device to run the model on."
    )
    parser.add_argument(
        "--batch_size", type=str, default="auto", help="Batch size for lm_eval."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="quantized_models",
        help="Output directory for quantized model.",
    )
    args = parser.parse_args()

    # Use parsed arguments
    run(
        model_id=args.model_id,
        quant_recipe_name=args.quant_recipe_name,
        tasks=args.tasks,
        device=args.device,
        batch_size=args.batch_size,
        model_output_dir=args.output_dir,
    )
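
For orientation, here is a minimal usage sketch that is not part of the commit; it assumes the new script is importable as a module (repo root on PYTHONPATH), that lm-eval is installed, and that a CUDA device plus Hugging Face access to the checkpoint are available. It mirrors what `run()` does for a single recipe and is roughly equivalent to `python benchmarks/quantization/eval_accuracy_for_readme.py --quant_recipe_name float8_rowwise --tasks wikitext`.

```python
# Minimal sketch only (assumed import path); not part of the commit.
from benchmarks.quantization.eval_accuracy_for_readme import (
    quantize_model_and_save,
    run_lm_eval,
    string_to_config,
)

# "float8_rowwise" maps to Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()).
quant_config = string_to_config("float8_rowwise")

# Quantize on load via TorchAoConfig and save the checkpoint for lm_eval to pick up.
output_dir = "benchmarks/data/quantized_model/Llama-3.1-8B-float8_rowwise"
model, tokenizer = quantize_model_and_save(
    "meta-llama/Llama-3.1-8B", quant_config, output_dir=output_dir
)

# Evaluate the saved checkpoint with lm_eval on wikitext.
run_lm_eval(output_dir, tasks_list=["wikitext"], device="cuda:0", batch_size=8)
```
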
Lines changed: 30 additions & 0 deletions (new shell script)

#!/bin/bash

set -e

# Get model_id as first positional argument (optional)
MODEL_ID="${1:-meta-llama/Llama-3.1-8B}"

# Get log file as second positional argument (optional)
LOG_FILE="${2:-benchmarks/data/eval_accuracy_for_readme_log.txt}"

# Build the base command arguments
BASE_ARGS="--tasks wikitext winogrande"
if [[ -n "$MODEL_ID" ]]; then
    BASE_ARGS="--model_id $MODEL_ID $BASE_ARGS"
fi

# baseline
# note: the -u flag is to prevent python from buffering stdout and stderr
# and make the output log file be in chronological order
time python -u benchmarks/quantization/eval_accuracy_for_readme.py $BASE_ARGS 2>&1 | tee "$LOG_FILE"

# quantized recipes
# note:
# * `int4_groupwise_hqq_weight_float8_rowwise_activation` doesn't work with dtype_map auto: https://gist.github.com/vkuzo/6b128681b628744d445c553cdeac8a85
# * `int4_groupwise_hqq_weight_only` only works on A100
for quant_recipe in float8_rowwise int4_groupwise_weight_float8_rowwise_activation int4_groupwise_hqq_weight_only int8_rowwise_weight_only int8_rowwise; do
    time python -u benchmarks/quantization/eval_accuracy_for_readme.py $BASE_ARGS --quant_recipe_name $quant_recipe 2>&1 | tee -a "$LOG_FILE"
done

# TODO(future PR): script to parse the log file instead of manual copy-paste

‎test/dtypes/test_bitpacking.py‎

Lines changed: 16 additions & 8 deletions

@@ -8,9 +8,11 @@
 from torch.utils._triton import has_triton
 
 from torchao.dtypes.uintx.bitpacking import pack, pack_cpu, unpack, unpack_cpu
+from torchao.utils import get_current_accelerator_device
 
 bit_widths = (1, 2, 3, 4, 5, 6, 7)
 dimensions = (0, -1, 1)
+_DEVICE = get_current_accelerator_device()
 
 
 @pytest.fixture(autouse=True)
@@ -30,40 +32,46 @@ def test_CPU(bit_width, dim):
     assert unpacked.allclose(test_tensor)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.parametrize("bit_width", bit_widths)
 @pytest.mark.parametrize("dim", dimensions)
 def test_GPU(bit_width, dim):
-    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda()
+    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).to(
+        _DEVICE
+    )
     packed = pack(test_tensor, bit_width, dim=dim)
     unpacked = unpack(packed, bit_width, dim=dim)
     assert unpacked.allclose(test_tensor)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.parametrize("bit_width", bit_widths)
 @pytest.mark.parametrize("dim", dimensions)
 def test_compile(bit_width, dim):
     torch._dynamo.config.specialize_int = True
     torch.compile(pack, fullgraph=True)
     torch.compile(unpack, fullgraph=True)
-    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).cuda()
+    test_tensor = torch.randint(0, 2**bit_width, (32, 32, 32), dtype=torch.uint8).to(
+        _DEVICE
+    )
     packed = pack(test_tensor, bit_width, dim=dim)
     unpacked = unpack(packed, bit_width, dim=dim)
     assert unpacked.allclose(test_tensor)
 
 
 # these test cases are for the example pack walk through in the bitpacking.py file
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
 def test_pack_example():
     test_tensor = torch.tensor(
         [0x30, 0x29, 0x17, 0x5, 0x20, 0x16, 0x9, 0x22], dtype=torch.uint8
-    ).cuda()
+    ).to(_DEVICE)
     shard_4, shard_2 = pack(test_tensor, 6)
     print(shard_4, shard_2)
-    assert torch.tensor([0, 105, 151, 37], dtype=torch.uint8).cuda().allclose(shard_4)
-    assert torch.tensor([39, 146], dtype=torch.uint8).cuda().allclose(shard_2)
+    assert (
+        torch.tensor([0, 105, 151, 37], dtype=torch.uint8).to(_DEVICE).allclose(shard_4)
+    )
+    assert torch.tensor([39, 146], dtype=torch.uint8).to(_DEVICE).allclose(shard_2)
     unpacked = unpack([shard_4, shard_2], 6)
     assert unpacked.allclose(test_tensor)
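
The device-agnostic changes above lean on a `get_current_accelerator_device` helper from `torchao.utils`. Below is a hypothetical sketch of what such a helper can look like, built on the same `torch.accelerator` API the updated skipif guards use; the actual torchao implementation may differ.

```python
# Hypothetical sketch only; the real helper is torchao.utils.get_current_accelerator_device
# and its implementation may differ.
import torch


def get_current_accelerator_device() -> torch.device:
    # torch.accelerator is PyTorch's device-agnostic accelerator API (CUDA, XPU, MPS, ...).
    if torch.accelerator.is_available():
        return torch.accelerator.current_accelerator()
    return torch.device("cpu")


# Usage mirroring the tests above: allocate on whatever accelerator is present.
_DEVICE = get_current_accelerator_device()
t = torch.randint(0, 2**4, (32, 32, 32), dtype=torch.uint8).to(_DEVICE)
```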

‎test/dtypes/test_floatx.py‎

Lines changed: 7 additions & 6 deletions

@@ -33,10 +33,11 @@
     quantize_,
 )
 from torchao.testing.utils import skip_if_rocm
-from torchao.utils import is_fbcode
+from torchao.utils import get_current_accelerator_device, is_fbcode
 
-_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
 _Floatx_DTYPES = [(3, 2), (2, 2)]
+_DEVICE = get_current_accelerator_device()
+_DEVICES = ["cpu"] + ([_DEVICE] if torch.accelerator.is_available() else [])
 
 
 class TestFloatxTensorCoreAQTTensorImpl(TestCase):
@@ -87,7 +88,7 @@ def test_from_scaled_tc_floatx_compile(self, ebits, mbits, device):
         )
         torch.testing.assert_close(actual, expected)
 
-    @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA not available")
+    @unittest.skipIf(not torch.accelerator.is_available(), reason="GPU not available")
     @parametrize("ebits,mbits", _Floatx_DTYPES)
     def test_to_copy_device(self, ebits, mbits):
         from torchao.quantization.quant_primitives import (
@@ -101,8 +102,8 @@ def test_to_copy_device(self, ebits, mbits):
         _layout = FloatxTensorCoreLayout(ebits, mbits)
         floatx_tensor_impl = FloatxTensorCoreAQTTensorImpl.from_plain(
             x, scale, None, _layout
-        ).cuda()
-        assert floatx_tensor_impl.device.type == "cuda"
+        ).to(_DEVICE)
+        assert floatx_tensor_impl.device.type == _DEVICE.type
         floatx_tensor_impl = floatx_tensor_impl.cpu()
         assert floatx_tensor_impl.device.type == "cpu"
 
@@ -114,7 +115,7 @@ def test_to_copy_device(self, ebits, mbits):
     @skip_if_rocm("ROCm enablement in progress")
     def test_fpx_weight_only(self, ebits, mbits, bias, dtype):
         N, OC, IC = 4, 256, 64
-        device = "cuda"
+        device = _DEVICE
 
         linear = torch.nn.Linear(IC, OC, bias=bias, device=device, dtype=dtype)
         fpx_linear = copy.deepcopy(linear)
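
For context, a small illustrative sketch (an assumption, not taken from this diff) of how a `_DEVICES` list like the one built above typically drives pytest parametrization, so each test body runs once on CPU and once on whatever accelerator is present:

```python
# Illustrative only; the test name and tensor shape here are made up.
import pytest
import torch

_DEVICE = (
    torch.accelerator.current_accelerator()
    if torch.accelerator.is_available()
    else None
)
_DEVICES = ["cpu"] + ([_DEVICE] if _DEVICE is not None else [])


@pytest.mark.parametrize("device", _DEVICES)
def test_cpu_accelerator_round_trip(device):
    # Runs once per entry in _DEVICES: "cpu", plus the current accelerator if any.
    x = torch.randn(4, 4, device=device)
    torch.testing.assert_close(x.to("cpu").to(device), x)
```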

‎test/prototype/moe_training/test_everything.sh‎

Lines changed: 3 additions & 2 deletions

@@ -12,8 +12,9 @@ IS_ROCM=$(rocm-smi --version || true)
 # These tests do not work on ROCm yet
 if [ -z "$IS_ROCM" ]
 then
-pytest test/prototype/moe_training/test_kernels.py -s
-pytest test/prototype/moe_training/test_training.py -s
+pytest test/prototype/moe_training/test_kernels.py -s -v
+pytest test/prototype/moe_training/test_scaled_grouped_mm.py -s -v
+pytest test/prototype/moe_training/test_training.py -s -v
 ./test/prototype/moe_training/test_fsdp.sh
 ./test/prototype/moe_training/test_tp.sh
 ./test/prototype/moe_training/test_fsdp_tp.sh

test/prototype/moe_training/test_fsdp.sh

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp.py -s
+torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp.py -s -v

test/prototype/moe_training/test_fsdp_tp.sh

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-torchrun --nproc_per_node=4 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp_tp.py -s
+torchrun --nproc_per_node=4 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp_tp.py -s -v

test/prototype/moe_training/test_tp.sh

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_tp.py -s
+torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_tp.py -s -v

‎test/prototype/mx_formats/test_kernels.py‎

Lines changed: 13 additions & 0 deletions

@@ -43,6 +43,7 @@
 from torchao.prototype.mx_formats.mx_tensor import ScaleCalculationMode, to_dtype, to_mx
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.utils import (
+    is_cuda_version_at_least,
     is_sm_at_least_89,
     is_sm_at_least_100,
     torch_version_at_least,
@@ -529,6 +530,10 @@ def test_rearrange(shape):
     not is_sm_at_least_100(),
     reason="MXFP8 requires CUDA capability 10.0 or greater",
 )
+@pytest.mark.skipif(
+    not is_cuda_version_at_least(12, 8),
+    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
+)
 @pytest.mark.parametrize("M", (32, 256))
 @pytest.mark.parametrize("K", (32, 256))
 @pytest.mark.parametrize("input_dtype", (torch.float32, torch.bfloat16))
@@ -577,6 +582,10 @@ def test_cuda_mx_dim1_numerics(M, K, input_dtype, scaling_mode):
     not is_sm_at_least_100(),
     reason="MXFP8 requires CUDA capability 10.0 or greater",
 )
+@pytest.mark.skipif(
+    not is_cuda_version_at_least(12, 8),
+    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
+)
 def test_cuda_mx_dim0_not_supported():
     from torchao.prototype import mxfp8_cuda
 
@@ -601,6 +610,10 @@ def test_cuda_mx_dim0_not_supported():
     not is_sm_at_least_100(),
     reason="MXFP8 requires CUDA capability 10.0 or greater",
 )
+@pytest.mark.skipif(
+    not is_cuda_version_at_least(12, 8),
+    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
+)
 def test_cuda_mx_dim1_invalid_block_size():
     from torchao.prototype import mxfp8_cuda
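
The new guards call `is_cuda_version_at_least` from `torchao.utils`. As a hypothetical sketch of what such a check can look like (the real torchao helper may be implemented differently), based on the CUDA toolkit version PyTorch reports:

```python
# Hypothetical sketch only; the real helper is torchao.utils.is_cuda_version_at_least.
import torch


def is_cuda_version_at_least(major: int, minor: int) -> bool:
    # torch.version.cuda is the CUDA toolkit version this torch build was compiled against,
    # e.g. "12.8"; it is None for CPU-only or ROCm builds.
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    cuda_major, cuda_minor = (int(x) for x in torch.version.cuda.split(".")[:2])
    return (cuda_major, cuda_minor) >= (major, minor)


# Matches how the new skipif guards use it:
# @pytest.mark.skipif(not is_cuda_version_at_least(12, 8), reason="CUDA >= 12.8 required")
```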

0 commit comments

