enable smoothquant for int8 static tensor #3468
base: main
Conversation
Summary: This PR creates a new Int8Tensor and updates the configs to use the new Int8Tensor flow.

Test Plan:

To ensure BC:

```
pytest test/quantization/test_quant_api.py
```

To test the new Int8Tensor:

```
pytest test/quantization/quantize_/workflows/int8/test_int8_tensor.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
pytorch-bot bot commented Dec 8, 2025 (edited)
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3468
Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Pending
As of commit 0c23589 with merge base f99105a (NEW FAILURE). The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
jcaip commented Dec 8, 2025

cc @Xia-Weiwen and @cyxlily fyi
```
qw = quant_mod.weight
# Add smoothing factor metadata
qw = to_weight_tensor_with_linear_activation_scale_metadata(
```
we should not be using this; please check AWQ for how this should be implemented in the new stack:

ao/torchao/prototype/awq/api.py, lines 108 to 113 in 08e5e20:
```
assert isinstance(qw, SupportsActivationPreScaling), (
    "weight must support activation scaling through implementing `SupportsActivationPreScaling`"
)
# since we want to do `act` * `act_pre_scale` during runtime for speed, we'll save the
# reciprocal of the `equalization_scale`
qw.act_pre_scale = 1.0 / equalization_scale
```
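For context, a hedged sketch of how a weight carrying `act_pre_scale` might be consumed at runtime; the function name and the `dequantize()` fallback here are illustrative, not the actual kernel path:

```
import torch

def linear_with_act_pre_scale(act: torch.Tensor, qw) -> torch.Tensor:
    # SmoothQuant needs `act / equalization_scale`; because the weight stored
    # act_pre_scale = 1.0 / equalization_scale, a single multiply applies it
    if getattr(qw, "act_pre_scale", None) is not None:
        act = act * qw.act_pre_scale
    # ...then the usual (de)quantized linear path runs on the smoothed input
    return torch.nn.functional.linear(act, qw.dequantize())
```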
| """ | ||
| scale:torch.Tensor | ||
| scale:torch.Tensor=None |
nit: `Optional[torch.Tensor]`
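A minimal sketch of the suggested annotation (the enclosing class body here is illustrative):

```
from typing import Optional

import torch

class Int8Tensor(torch.Tensor):
    # None signals that no static scale was provided, so the annotation
    # should be Optional[torch.Tensor] rather than a bare torch.Tensor
    scale: Optional[torch.Tensor] = None
```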
```
[
    Int8DynamicActivationInt8WeightConfig(),
    Int8DynamicActivationInt8WeightConfig(version=2),
    # TODO: not sure if we should allow not passing scales as part of static config?
```
yeah I think it's fine

side note: we may need a separate API/flow for plain static quant without SmoothQuant if needed.
cyxlily commented Dec 17, 2025

@jcaip Our customer needs PerTensor activation quantization and PerRow weight quantization. Will you implement it, or may I create a new PR to do it?
jcaip commented Dec 17, 2025 (edited)
@cyxlily feel free to open a new PR for activation per-tensor x weight per-row; it's not something I'm planning to do currently. Thank you for your SmoothQuant PR btw; I used it to implement this.
This PR hooks up the static quant workflow added in #3442 to the prototype SmoothQuant API.
You can use the new flow as follows:
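A hedged sketch of what that flow could look like, using the prepare/calibrate/convert pattern of the prototype SmoothQuant API; `SmoothQuantConfig`, `SmoothQuantStep`, and the choice of base config are assumptions drawn from the surrounding discussion, not verbatim from this PR:

```
import torch
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_
from torchao.prototype.smoothquant import SmoothQuantConfig, SmoothQuantStep  # assumed names

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
base_config = Int8DynamicActivationInt8WeightConfig(version=2)  # assumed base config

# 1. Prepare: insert observers so activation statistics can be recorded
quantize_(model, SmoothQuantConfig(base_config=base_config, step=SmoothQuantStep.PREPARE))

# 2. Calibrate on representative inputs to collect activation maxima
for _ in range(8):
    model(torch.randn(1, 64))

# 3. Convert: fold the smoothing factors into the weights and swap in Int8Tensor
quantize_(model, SmoothQuantConfig(base_config=base_config, step=SmoothQuantStep.CONVERT))
```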