pytorch/ao (Public)

enable smoothquant for int8 static tensor #3468


Open

jcaip wants to merge 37 commits into main from jcaip/enable-smoothquant

Conversation

@jcaip (Contributor) commented Dec 8, 2025 (edited)

This PR hooks up the static quant workflow added in #3442 to the prototype smoothquant API.

You can use the new flow as follows:

```python
from torchao.quantization import PerRow, quantize_
from torchao.quantization.quant_api import (
    Int8StaticActivationInt8WeightConfig,
)
from torchao.prototype.smoothquant import (
    SmoothQuantConfig,
    SmoothQuantStep,
)

config = SmoothQuantConfig(
    base_config=Int8StaticActivationInt8WeightConfig(granularity=PerRow()),
    step=SmoothQuantStep.PREPARE,
    alpha=0.5,
)
quantize_(model, config)

# Perform calibration with test data
model(*x)

config.step = SmoothQuantStep.CONVERT
quantize_(model, config)

# model will now be statically quantized with the inputs used in the smoothquant observer
model(*x)
```

Summary: This PR creates a new Int8Tensor and updates the configs to use the new Int8Tensor flow.

Test Plan:

To ensure BC:
```
pytest test/quantization/test_quant_api.py
```

To test the new Int8Tensor:
```
pytest test/quantization/quantize_/workflows/int8/test_int8_tensor.py
```
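For context on the `alpha=0.5` used above: SmoothQuant balances quantization difficulty between activations and weights by computing a per-channel equalization scale, commonly `s_j = max|X_j|**alpha / max|W_j|**(1 - alpha)`. The sketch below is illustrative plain Python, not torchao's implementation; the function name and inputs are assumptions.

```python
# Illustrative sketch of SmoothQuant's equalization scale (not torchao code).
# For input channel j:
#   s_j = max|X[:, j]|**alpha / max|W[:, j]|**(1 - alpha)
# Activations are divided by s_j and weights are multiplied by s_j, so the
# matmul output is unchanged while activation outliers are smoothed out.

def smoothquant_scales(act_absmax, weight_absmax, alpha=0.5):
    """Per-channel equalization scales from absolute-max statistics."""
    return [
        (a ** alpha) / (w ** (1.0 - alpha))
        for a, w in zip(act_absmax, weight_absmax)
    ]

scales = smoothquant_scales([8.0, 2.0], [2.0, 8.0], alpha=0.5)
# channel 0 has large activations -> scale > 1 shifts difficulty to the weight
print(scales)  # [2.0, 0.5]
```

With `alpha=0.5` the difficulty is split evenly; a larger alpha migrates more of the outlier burden from activations to weights.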
@pytorch-bot (bot) commented Dec 8, 2025 (edited)

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3468

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Pending

As of commit 0c23589 with merge base f99105a:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the "CLA Signed" label (This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.) Dec 8, 2025
@jcaip added the "topic: improvement" label (Use this tag if this PR is an improvement that doesn't fit into any of the other categories) Dec 8, 2025
@jcaip changed the title from "[wip] enable smoothquant for int8 static tensor" to "enable smoothquant for int8 static tensor" Dec 8, 2025
@jcaip marked this pull request as ready for review December 8, 2025 22:24
@jcaip (Contributor, Author) commented:

cc @Xia-Weiwen and @cyxlily fyi

qw = quant_mod.weight

# Add smoothing factor metadata
qw = to_weight_tensor_with_linear_activation_scale_metadata(
Review comment (Contributor):
we should not be using this, please check AWQ for how this should be implemented in the new stack:

```
assert isinstance(qw, SupportsActivationPreScaling), (
    "weight must support activation scaling through implementing `SupportsActivationPreScaling`"
)
# since we want to do `act` * `act_pre_scale` during runtime for speed, we'll save the
# reciprocal of the `equalization_scale`
qw.act_pre_scale = 1.0 / equalization_scale
```
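The reciprocal trick in the snippet above trades a runtime division for a multiplication: the reciprocal is computed once at convert time, so the hot path only multiplies. A minimal sketch of the equivalence in plain Python (variable names are illustrative, not torchao's):

```python
# Illustrative: dividing the activation by the equalization scale equals
# multiplying by its precomputed reciprocal; the multiply is cheaper at runtime.

equalization_scale = [2.0, 0.5, 4.0]
act = [1.0, 2.0, 8.0]

# Convert time: store the reciprocal once.
act_pre_scale = [1.0 / s for s in equalization_scale]

# Runtime: multiply instead of divide.
scaled_mul = [a * p for a, p in zip(act, act_pre_scale)]
scaled_div = [a / s for a, s in zip(act, equalization_scale)]

print(scaled_mul)  # [0.5, 4.0, 2.0]
assert scaled_mul == scaled_div
```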

"""

scale:torch.Tensor
scale:torch.Tensor=None
Review comment (Contributor):

nit: Optional[torch.Tensor]
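The nit above is about the annotation: a field that defaults to `None` should advertise that as `Optional[torch.Tensor]` rather than `torch.Tensor`, so type checkers don't flag the `None` default. A minimal sketch with a plain dataclass (the class name is hypothetical and `float` stands in for `torch.Tensor` to avoid the dependency):

```python
# Illustrative only: `QuantParams` is a hypothetical name, not torchao code.
from dataclasses import dataclass
from typing import Optional

@dataclass
class QuantParams:
    # Optional[...] tells type checkers the None default is intentional;
    # in the real config this would be Optional[torch.Tensor] = None.
    scale: Optional[float] = None

p = QuantParams()
print(p.scale is None)  # True
```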

[
Int8DynamicActivationInt8WeightConfig(),
Int8DynamicActivationInt8WeightConfig(version=2),
# TODO: not sure if we should allow not passing scales as part of static config?
Review comment (Contributor):

yeah I think it's fine

side note: we may need a separate API/flow for plain static quant without SmoothQuant if needed.

@jcaip changed the base branch from jcaip/static-quant-rebased to main December 9, 2025 04:34
@meta-codesync commented:

@jcaip has imported this pull request. If you are a Meta employee, you can view this in D88784212.

@cyxlily (Contributor) commented:

@jcaip Our customer needs activation quantization PerTensor and weight quantization PerRow. Will you implement it, or may I create a new PR to do it?

@jcaip (Contributor, Author) commented Dec 17, 2025 (edited):

@cyxlily Feel free to open a new PR for activation per-tensor x weight per-row; it's not something I'm planning to do currently.

Thank you for your SmoothQuant PR, btw; I used it to implement this.


Reviewers

@jerryzh168 left review comments

Assignees

No one assigned

Labels

CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed), topic: improvement (use this tag if this PR is an improvement that doesn't fit into any of the other categories)

Projects

None yet

Milestone

No milestone

Development

Successfully merging this pull request may close these issues.

4 participants

@jcaip @cyxlily @jerryzh168
