[Draft] [5526696] Add kv cache quantization support for onnx quantization #486


Open

zhanghaoc wants to merge 5 commits into main from haoxiz/kv_cache (base: main)

Conversation

@zhanghaoc

What does this PR do?

Add KV cache quantization. Currently supports the int8/fp8 MinMax calibration method.

Overview:

  • Add a new file, kv_cache.py. It includes a function to save calibration data, a function to read that data and compute scales, and logic to add attributes and new inputs to the ONNX model.
  • Changes in other files only pass the new parameters through.
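The per-tensor scale math behind MinMax (max) calibration can be sketched roughly as follows. This is an illustration only: `compute_kv_scale`, the int8 symmetric max of 127, and the fp8 E4M3 max of 448 are assumptions, not the PR's actual code in kv_cache.py.

```python
import numpy as np

def compute_kv_scale(calib_tensors: list[np.ndarray], mode: str = "int8") -> float:
    """Per-tensor MinMax (max) calibration: scale = amax / qmax.

    Hypothetical helper illustrating the calibration math only.
    """
    # With symmetric activations, MinMax calibration reduces to max calibration.
    amax = max(float(np.abs(t).max()) for t in calib_tensors)
    # Assumed quantized ranges: int8 symmetric [-127, 127]; fp8 E4M3 max is 448.
    qmax = 127.0 if mode == "int8" else 448.0
    return amax / qmax

# Example: two fake calibration batches observed for one KV tensor
batches = [np.array([-2.0, 1.5]), np.array([0.5, 4.0])]
scale = compute_kv_scale(batches, mode="int8")
```

A dequantized value is then recovered as `q * scale`, which is why a single per-tensor scale suffices in PER_TENSOR mode.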

Usage

python -m modelopt.onnx.quantization --onnx_path="C:\repos\models\Llama-3.2-3B-Instruct-ONNX\cuda\cuda-fp16\model.onnx" --quantize_mode=int4 --calibration_method=rtn_dq --kv_quant_mode=PER_TENSOR --output_path="C:\repos\models\Llama-3.2-3B-Instruct-ONNX\cuda\cuda-fp16\model.int4.rtn_dq.kv_cache.onnx" --log_level=DEBUG

Testing

Testing has not been done yet; still waiting for feedback.

Before your PR is "Ready for review"

  • Make sure you read and follow the Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: No
  • Did you add or update any necessary documentation?: Yes
  • Did you update the Changelog?: No

Additional Information

Signed-off-by: zhanghaoc <2272055687@qq.com>
@copy-pr-bot

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@codecov

codecov bot commented Oct 31, 2025 (edited)

Codecov Report

❌ Patch coverage is 26.27737% with 101 lines in your changes missing coverage. Please review.
✅ Project coverage is 73.02%. Comparing base (9e64f81) to head (a877d02).
⚠️ Report is 38 commits behind head on main.

Files with missing lines               | Patch % | Lines
modelopt/onnx/quantization/kv_cache.py | 15.17%  | 95 Missing ⚠️
modelopt/onnx/quantization/int4.py     | 88.88%  | 2 Missing ⚠️
modelopt/onnx/quantization/ort_patching.py | 33.33% | 2 Missing ⚠️
modelopt/onnx/quantization/quantize.py | 50.00%  | 2 Missing ⚠️

@@            Coverage Diff             @@
##             main     #486      +/-   ##
==========================================
- Coverage   73.38%   73.02%   -0.36%
==========================================
  Files         180      181       +1
  Lines       17934    18260     +326
==========================================
+ Hits        13160    13334     +174
- Misses       4774     4926     +152

☔ View full report in Codecov by Sentry.

@vishalpandya1990 removed their assignment Oct 31, 2025
# call to_dict and save the calibration data to disk with pickle
with open(calib_data_path, "wb") as f:
    pickle.dump(kv_tensor_data, f)
intermediate_generated_files.append(calib_data_path)
Contributor

What is the memory impact (or other issues) if we keep the KV cache related data in variable instead of writing them to disk file?

    f"Unsupported kv_cache_type {kv_cache_type} for kv cache quantization"
)

kv_tensor_names_list.sort()
Contributor

@vishalpandya1990 Nov 5, 2025 (edited)

Should we add an assert/exception/suitable safe early return here if the input model is not GenAI-based, i.e. it doesn't have the expected IO bindings/names? (e.g. if this list is empty, or if there are no GQA nodes seen, etc.)

I think we currently support 8-bit KV cache with GenAI Builder exported ONNX LLMs only, right?
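A guard along the lines the reviewer suggests might look like the sketch below. The function name, the "present" output-name convention, and the GroupQueryAttention op-type check are assumptions about GenAI Builder exported models, not code from this PR; the argument is expected to be an `onnx.ModelProto`, though only `graph.output` and `graph.node` are touched.

```python
def validate_kv_cache_model(onnx_model):
    """Collect KV cache output names, failing fast for non-GenAI models.

    Hypothetical helper: expects an onnx.ModelProto-like object and
    returns the sorted list of present.* output names.
    """
    kv_tensor_names = sorted(
        out.name for out in onnx_model.graph.output if "present" in out.name
    )
    has_gqa = any(
        node.op_type == "GroupQueryAttention" for node in onnx_model.graph.node
    )
    if not kv_tensor_names or not has_gqa:
        raise ValueError(
            "KV cache quantization expects a GenAI Builder exported LLM "
            "with present.* outputs and GroupQueryAttention nodes."
        )
    return kv_tensor_names
```

Calling this once, before any calibration work starts, would turn a confusing downstream failure into an immediate, descriptive error.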

if calibration_method in ["rtn", "rtn_dq", "rtn_trt", "rtn_trt_dq"]:
    # Save kv-cache calibration data if kv_quant_mode is not NONE
    if kv_quant_mode != "NONE":
        save_kv_cache_calib_data_rtn(
Contributor

For INT4 AWQ/RTN + 8-bit KV Cache quantization, can we avoid 2 session runs by preparing KV tensor names before creating augmented model, augmenting model for these KV tensors too, and post-processing for save-kv-cache-calib-data after AWQ/RTN loop?

Just checking if we can avoid 2 session runs, and thereby speedup the combined quantization of matmul and kv-cache.

Author

Yes, that's possible, but this change won't apply to the int8/fp8 path, and awq_lite, awq_clip, and rtn would each need separate implementations, which means not much code can be reused. If you feel it's worthwhile, I can definitely implement it this way.

for output in onnx_model.graph.output:
    if "present" in output.name:
        kv_tensor_names_list.append(output.name)
if kv_cache_type == "fp8":
Contributor

@vishalpandya1990 Nov 5, 2025 (edited)

I think we can simplify this a bit by creating a map, with an assert/ValueError for unsupported types. Something like below:

output.type.tensor_type.elem_type = output_type_map[kv_cache_type]

where output_type_map = {"int8": , "fp8": }

Possibly we can create a util for validating the dtype and input model, i.e. whether it is currently supported or not.

    # With ActivationSymmetric as True, MinMax calibration is equivalent to max calibration
    else CalibrationMethod.MinMax
),
intermediate_generated_files=intermediate_generated_files,
Contributor

I didn't get how KV-Cache quantization meta-data is used with int8/fp8 quantization. Can you please elaborate the flow?

Author

Both int8 and fp8 quantization call quantize_static from ort_patching. The change happens in ort_patching.py: if kv_quant_mode is not NONE, it saves additional calibration data to disk.

node.input.append("")
node.input.append(k_scale.name)
node.input.append(v_scale.name)

Contributor

I think if kv-quant-type is per-channel then things wont go well, since we are not supporting it but not checking / flagging it as well. Is it?

Author

Added the check.

@vishalpandya1990
Contributor

vishalpandya1990 commented Nov 13, 2025 (edited)

Is int4-rtn remaining to be updated for KV-cache quantization?

I see now that we are calling save_***_rtn() separately for INT4-RTN (before quantize_rtn).

calibrate_method=CalibrationMethod.MinMax,
extra_options=None,
intermediate_generated_files: list[str] = [],
kv_quant_mode: str = "NONE",
Contributor

Does this method require matching function signature with the corresponding method from onnxruntime? I guess not so but would like to be clear.

@gcunhase requested review from ajrasane and i-riyad and removed the review request for gcunhase, November 13, 2025 15:52
Reviewers

@vishalpandya1990 left review comments
@i-riyad Awaiting requested review from i-riyad
@ajrasane Awaiting requested review from ajrasane

At least 1 approving review is required to merge this pull request.

Assignees

@zhanghaoc
