pytorch/ao · Public

enable smoothquant for int8 static tensor #3468
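For context (not part of the PR itself): SmoothQuant shifts quantization difficulty from activations to weights by dividing each activation channel by a per-channel smoothing factor and multiplying the matching weight column by the same factor, which leaves the matmul result unchanged while shrinking activation outliers. A minimal standalone sketch of that idea; the function name and the alpha default are illustrative assumptions, not the torchao API:

```python
import torch

def smooth(x, w, alpha=0.5):
    # x: (tokens, in_features) calibration activations
    # w: (out_features, in_features) linear weight
    # Per-input-channel smoothing factor s = max|x|^alpha / max|w|^(1-alpha).
    x_max = x.abs().amax(dim=0)                  # (in_features,)
    w_max = w.abs().amax(dim=0)                  # (in_features,)
    s = (x_max.pow(alpha) / w_max.pow(1 - alpha)).clamp(min=1e-5)
    # Fold s into the weight and its inverse into the activations:
    # (x / s) @ (w * s).T == x @ w.T, but x / s is easier to quantize.
    return x / s, w * s

x = torch.randn(32, 64)
w = torch.randn(16, 64)
x_s, w_s = smooth(x, w)
torch.testing.assert_close(x @ w.T, x_s @ w_s.T, rtol=1e-4, atol=1e-3)
```

"int8 static" here refers to the quantization mode this PR enables the smoothing for: activation scales are fixed at calibration time rather than computed per batch.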


Open · jcaip wants to merge 40 commits into main (base: main) from jcaip/enable-smoothquant

Changes shown below are from 1 commit of the 40.
Commits (40, all by jcaip):

48cdb61  Int8Tensor migration (Dec 1, 2025)
0b73aed  ruff fixes (Dec 1, 2025)
1e49945  add init (Dec 1, 2025)
669b6ee  fix ruff again (Dec 1, 2025)
9071526  update (Dec 1, 2025)
1539e0f  wip (Dec 2, 2025)
d9a2b1b  Merge branch 'main' into jcaip/int8-tensor (Dec 3, 2025)
673f228  undo update tests (Dec 3, 2025)
739fd64  fix ruff (Dec 3, 2025)
750db1a  fix varname (Dec 3, 2025)
9410488  fix typing (Dec 3, 2025)
45a3a76  add tests (Dec 3, 2025)
4e2f09c  fix dtype (Dec 3, 2025)
dd80cca  fix ci (Dec 3, 2025)
7f73062  address granularity cr (Dec 4, 2025)
ac6a2b6  update _choose_quant_func_and_quantize_tensor (Dec 4, 2025)
f28df4a  make block size required attribute (Dec 4, 2025)
328585e  made dtype required as well (Dec 4, 2025)
ce4d568  address nits (Dec 4, 2025)
a665d45  skip per tensor weight only test for now (Dec 4, 2025)
0338016  add static quant (Dec 3, 2025)
ee39691  add static quant (Dec 4, 2025)
9eb0aa9  update (Dec 5, 2025)
d4a1514  static quant working eager + compile (Dec 6, 2025)
3cdea56  remove file (Dec 6, 2025)
fa9022d  added asserts (Dec 6, 2025)
8ce5cde  undo smoothquant change (Dec 6, 2025)
6f64121  fix return (Dec 6, 2025)
8ae921d  Merge branch 'main' into jcaip/static-quant-rebased (Dec 7, 2025)
5b9e243  got smoothquant + int8 static working (Dec 8, 2025)
7a0e38f  generalized smoothquat code (Dec 8, 2025)
3d18edf  free tests (Dec 8, 2025)
9e07f8b  fix static scale check (Dec 8, 2025)
4274e02  update (Dec 8, 2025)
b5309eb  address cr feedback (Dec 9, 2025)
a732fee  Merge branch 'jcaip/static-quant-rebased' into jcaip/enable-smoothquant (Dec 9, 2025)
0c23589  Merge branch 'main' into jcaip/enable-smoothquant (Dec 9, 2025)
0872986  update (Dec 17, 2025)
049830f  fix ruff (Dec 17, 2025)
2586ab6  fix varname (Dec 18, 2025)
Viewing a single commit from this PR: "add init"
Committed by @jcaip on Dec 1, 2025 · commit 1e49945d5f3380f0a7ccd976fb20b64d401b7453

```diff
@@ -90,8 +90,8 @@ def test_int8_linear_variants(

         quantize_(model_q, config)

-        self.assertEqual(model_q.linear2.weight.scale.shape, (K,))
-        self.assertEqual(model_q.linear2.weight.scale.ndim, 1)
+        self.assertEqual(model_q.linear2.weight.scale.shape, (K, 1))
+        self.assertEqual(model_q.linear2.weight.scale.ndim, 2)

         if compile:
             model_q = torch.compile(model_q, fullgraph=True)
```
Empty file.
125 changes: 39 additions & 86 deletions · torchao/quantization/quantize_/workflows/int8/int8_tensor.py
```diff
@@ -21,6 +21,7 @@
 from torchao.quantization.quantize_.common import QuantizeTensorKwargs
 from torchao.quantization.utils import get_block_size
 from torchao.utils import TorchAOBaseTensor, fill_defaults
+from torchao.float8.inference import _slice_scale_for_dimension

 __all__ = ["Int8Tensor", "QuantizeTensorToInt8Kwargs"]
```
```diff
@@ -136,6 +137,11 @@ def from_hp(
             output_dtype=torch.int8,
         )

+        if isinstance(granularity, PerRow):
+            scale = scale.unsqueeze(1)
+        else:
+            scale = scale.unsqueeze(0).unsqueeze(1)
+
         return cls(
             int_data,
             scale,
```
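The unsqueeze above is what the test change at the top of this page asserts: per-row scales are now stored as (K, 1) rather than (K,), and per-tensor scales as (1, 1), so the scale broadcasts directly against the 2-D qdata. A standalone sketch of the broadcasting rule (plain PyTorch, not the PR's code):

```python
import torch

K, N = 4, 8
qdata = torch.randint(-128, 128, (K, N), dtype=torch.int8)

scale_1d = torch.rand(K)            # (K,)   old layout
scale_2d = scale_1d.unsqueeze(1)    # (K, 1) new layout

# (K, 1) broadcasts cleanly over (K, N); a bare (K,) would try to
# broadcast over the last dimension instead, which is wrong for
# per-row scaling (and errors whenever N != K).
dequant = qdata.to(torch.float32) * scale_2d
assert dequant.shape == (K, N)
```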
```diff
@@ -152,7 +158,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
         return dequantize_affine(
             input=self.qdata,
             block_size=self.block_size,
-            scale=self.scale,
+            scale=self.scale.squeeze(),
             zero_point=None,
             input_dtype=torch.int8,
             quant_min=-128,
```
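Conversely, dequantize still hands the scale to dequantize_affine in its original low-rank form, hence the squeeze(). A toy round trip of the two layouts (plain PyTorch, not the torchao call):

```python
import torch

scale = torch.rand(4)            # per-row scale as originally computed
stored = scale.unsqueeze(1)      # (4, 1): 2-D layout kept on the tensor
restored = stored.squeeze()      # (4,): shape passed back to the kernel
assert torch.equal(scale, restored)
```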
```diff
@@ -164,65 +170,6 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
 implements = Int8Tensor.implements
 implements_torch_function = Int8Tensor.implements_torch_function


-def _slice_scale(
-    scale: torch.Tensor,
-    data_shape: list[int],
-    dim: int,
-    start: int,
-    end: int,
-    step: int,
-) -> torch.Tensor:
-    """
-    Slice the scale tensor appropriately based on the data tensor slicing.
-
-    This function calculates how the scale should be sliced when the data tensor
-    is sliced along a given dimension, taking into account the block structure.
-
-    Example:
-        If data_shape is [256, 128] and scale shape is [1] (indicating per-tensor scaling),
-        slicing along any dimension should return the same scale tensor.
-
-        If data_shape is [256, 128] and scale shape is [256] (indicating per-row scaling),
-        and we slice data along dim=0 from 64 to 192, the corresponding scale
-        should be sliced from 64 to 192 as well.
-    """
-    aten = torch.ops.aten
-
-    # Case 1: Per-tensor quantization (scalar scale)
-    if scale.numel() <= 1:
-        return scale
-
-    # Case 2: Per-row quantization (1D scale)
-    # Scale is per-element along this dimension
-    if scale.ndim == 1:
-        if dim == 0:
-            return aten.slice.Tensor(scale, 0, start, end, step)
-        else:
-            return scale
-
-    # Case 3: Per-block quantization (2D scale)
-    block_sizes = tuple(
-        data_shape[i] // scale.shape[i] for i in range(len(scale.shape))
-    )
-
-    block_size_for_dim = block_sizes[dim]
-
-    if step > 1:
-        raise NotImplementedError(
-            "Slicing with step > 1 is not implemented for scale tensors."
-        )
-
-    # There is blocking in this dimension
-    # Calculate which scale elements correspond to the sliced data
-    scale_start = start // block_size_for_dim if start is not None else None
-    scale_end = (
-        (end + block_size_for_dim - 1) // block_size_for_dim
-        if end is not None
-        else None
-    )
-
-    return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)
-
-
 @implements(aten.linear.default)
 @implements_torch_function(torch.nn.functional.linear)
 def _(func, types, args, kwargs):
```
```diff
@@ -233,8 +180,7 @@ def _(func, types, args, kwargs):
         args[2] if len(args) > 2 else None,
     )

-    if not isinstance(weight_tensor, Int8Tensor):
-        raise TypeError(f"Expected weight to be Int8Tensor, got {type(weight_tensor)}")
+    assert isinstance(weight_tensor, Int8Tensor), f"Expected weight to be Int8Tensor, got {type(weight_tensor)}"

     output_dtype = activation_tensor.dtype
```
```diff
@@ -266,7 +212,7 @@ def _(func, types, args, kwargs):
         y_dot_scaled = int_scaled_matmul(
             tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)
         ).to(output_dtype)
-        y = (y_dot_scaled * w_scales).reshape(
+        y = (y_dot_scaled * w_scales.flatten()).reshape(
             *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
         )
```
```diff
@@ -277,7 +223,7 @@ def _(func, types, args, kwargs):
             activation_tensor.reshape(-1, activation_tensor.shape[-1]),
             w_vals_int8_t.to(output_dtype),
         )
-        y = m * weight_tensor.scale.to(m.dtype)
+        y = m * weight_tensor.scale.to(m.dtype).flatten()
         y = y.reshape(*activation_tensor.shape[:-1], weight_tensor.qdata.shape[0])

         if bias is not None:
```
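Both flatten() calls exist for the same reason: the weight scale is now (N, 1), and multiplying the (M, N) matmul output by it would broadcast along the wrong axis. A standalone sketch with illustrative shapes:

```python
import torch

M, N = 2, 4
m = torch.rand(M, N)        # matmul output: one column per output channel
w_scale = torch.rand(N, 1)  # per-row weight scale in its new 2-D layout

# (M, N) * (N, 1) would broadcast N against M and fail (or silently
# misalign when M == N), so flatten back to (N,) for per-column scaling.
y = m * w_scale.flatten()
assert y.shape == (M, N)
```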
```diff
@@ -306,7 +252,19 @@ def _(func, types, args, kwargs):
         end = self.shape[dim]

     sliced_qdata = aten.slice.Tensor(self.qdata, dim, start, end, step)
-    sliced_scale = _slice_scale(self.scale, self.qdata.shape, dim, start, end, step)
+    if self.scale.numel() == 1:
+        # Per-tensor quantization - scale doesn't change
+        sliced_scale = self.scale
+    else:
+        # Block-wise quantization - need to slice the scale appropriately
+        sliced_scale = _slice_scale_for_dimension(
+            self.scale, self.qdata.shape, dim, start, end, step
+        )
+
+    # adjust block_size since the shape has changed; block_size[i] should not be greater than shape[i]
+    block_size = self.block_size.copy()
+    for i in range(len(self.block_size)):
+        block_size[i] = min(block_size[i], sliced_qdata.shape[i])

     return return_and_correct_aliasing(
         func,
@@ -315,7 +273,7 @@ def _(func, types, args, kwargs):
         Int8Tensor(
             sliced_qdata,
             sliced_scale,
-            block_size=self.block_size[1:],
+            block_size=block_size,
             act_quant_kwargs=self.act_quant_kwargs,
             dtype=self.dtype,
         ),
```
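The shared helper this now delegates to maps a slice of the data tensor onto the scale tensor by dividing indices by the per-dimension block size (floor for the start, ceiling for the end). A standalone sketch of that index arithmetic, assuming step 1 and 2-D tensors; this is not the helper's actual source:

```python
import torch

def slice_scale_for_dim(scale, data_shape, dim, start, end):
    # Each scale element covers a block of data_shape[dim] // scale.shape[dim] rows.
    block = data_shape[dim] // scale.shape[dim]
    s = start // block                 # floor: first block touched
    e = (end + block - 1) // block     # ceil: last block touched
    return scale.narrow(dim, s, e - s)

data = torch.randn(256, 128)
scale = torch.rand(4, 1)   # 4 row-blocks of 64 rows each along dim 0
sliced = slice_scale_for_dim(scale, data.shape, dim=0, start=64, end=192)
assert sliced.shape == (2, 1)   # data rows 64..192 span scale rows 1..3
```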
```diff
@@ -325,27 +283,22 @@ def _(func, types, args, kwargs):

 @implements(aten.select.int)
 def _(func, types, args, kwargs):
     """Select operation for Int8Tensor"""
-    self, dim, index = args
-    if dim != 0:
-        raise NotImplementedError(f"Only dim=0 supported, got dim={dim}")
-
-    selected_qdata = self.qdata[index]
-    selected_scale = _slice_scale(
-        self.scale, self.qdata.shape, dim, index, index + 1, step=1
-    ).squeeze(0)
-
-    return return_and_correct_aliasing(
-        func,
-        args,
-        kwargs,
-        Int8Tensor(
-            selected_qdata,
-            selected_scale,
-            block_size=self.block_size[1:],
-            act_quant_kwargs=self.act_quant_kwargs,
-            dtype=self.dtype,
-        ),
-    )
+    old_int8_tensor, dim, index = args
+    assert dim == 0, f"Int8Tensor aten.select.int with {dim=} is not yet supported"
+    assert len(old_int8_tensor.qdata.shape) == len(old_int8_tensor.scale.shape), (
+        "unsupported"
+    )
+    assert len(old_int8_tensor.qdata.shape) == len(old_int8_tensor.block_size), (
+        "unsupported"
+    )
+    new_int8_tensor = old_int8_tensor.__class__(
+        old_int8_tensor.qdata[index],
+        old_int8_tensor.scale[index],
+        old_int8_tensor.block_size[1:],
+        old_int8_tensor.act_quant_kwargs,
+        old_int8_tensor.dtype,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new_int8_tensor)


 Int8Tensor.__module__ = "torchao.quantization"
```
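The rewritten select keeps qdata, scale, and block_size aligned: the asserts require them to have the same rank, so indexing at dim 0 drops that dimension from all three together. A small illustration (plain PyTorch, shapes assumed):

```python
import torch

qdata = torch.randint(-128, 128, (3, 4, 8), dtype=torch.int8)
scale = torch.rand(3, 4, 1)     # same rank as qdata, as the asserts require
block_size = [1, 1, 8]          # one entry per qdata dimension

index = 1
new_qdata = qdata[index]            # (4, 8): dim 0 dropped
new_scale = scale[index]            # (4, 1): stays aligned with qdata
new_block_size = block_size[1:]     # [1, 8]: one entry per remaining dim
assert new_qdata.ndim == new_scale.ndim == len(new_block_size)
```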

