Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

[cuDNN][SDPA][Convolution] Expose cuDNN runtime version in CUDA hooks#167111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Closed
eqy wants to merge2 commits intopytorch:mainfromeqy:cudnnruntimeversion

Conversation

@eqy
Copy link
Collaborator

@eqyeqy commentedNov 5, 2025
edited
Loading

cuDNN dispatching heuristics rely on versions checks but currently only that compile-time version is exposed, if we want to allow users toresolve#166643 on their end by updating their cuDNN version locally we need to check the runtime version rather than compile-time version.

cc@csarofeen@ptrblck@xwang233@jgong5@mingfeima@XiaobingSuper@sanchitintel@ashokei@jingxu10@jerryzh168@aditew01

@eqyeqy added the module: cudnnRelated to torch.backends.cudnn, and CuDNN support labelNov 5, 2025
@eqyeqy requested a review fromAidyn-A as acode ownerNovember 5, 2025 19:06
@eqyeqy added module: convolutionProblems related to convolutions (THNN, THCUNN, CuDNN) open source release notes: cudnn module: sdpaAll things related to torch.nn.functional.scaled_dot_product_attentiion labelsNov 5, 2025
@pytorch-bot
Copy link

pytorch-botbot commentedNov 5, 2025
edited
Loading

🔗 Helpful Links

🧪 See artifacts and rendered test results athud.pytorch.org/pr/167111

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commitcaa7a77 with merge base5c63946 (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-botpytorch-botbot added the module: cpuCPU specific problem (e.g., perf, algorithm) labelNov 5, 2025
longversionCUDART()constoverride;
longversionCuDNN()constoverride;
longversionRuntimeCuDNN()constoverride;
longversionCuDNNFrontend()constoverride;
Copy link
Collaborator

@Skylion007Skylion007Nov 5, 2025
edited
Loading

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Why does Runtime CUDNN frontend matter? It cannot be changed right? It's a compile time include header?

Copy link
CollaboratorAuthor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

I sidecar'd this change in as we'll need it in the near future for SDPA issues that require a cuDNN frontend version to be available for gating. In theorysdp_utils.cpp could be able to access this but I'm not sure I want to include that directly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Can the runtime version be different for cudNNFronteEnd or should it be constexpr?

staticboolhasCuDNN() {
returndetail::getCUDAHooks().hasCuDNN();
}
staticlongversionCuDNN() {
Copy link
Collaborator

@Skylion007Skylion007Nov 5, 2025
edited
Loading

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

If this is really compile time? Why no constexpr? Would enable if constexpr logic that would simplify critical code paths in CUDNN dispatch.

Copy link
CollaboratorAuthor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

yes see


other uses ofCUDNN_VERSION in the file are macros, etc.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Yeah, if they are macros they should be propogated with constexpr then. :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Yeah, CUDNN_FRONTNED has it's equivalent function as constexpr

@eqy
Copy link
CollaboratorAuthor

eqy commentedNov 5, 2025

@Skylion007 are we building with C++20 only? not sure ifvirtual functions (as these are CUDAHooks) can beconstexpr

@eqy
Copy link
CollaboratorAuthor

eqy commentedNov 6, 2025

@pytorchmergebot merge

pytorch-bot[bot] reacted with thumbs up emoji

@pytorch-botpytorch-botbot added the ciflow/trunkTrigger trunk jobs on your pull request labelNov 6, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in thewiki.

Questions? Feedback? Please reach out to thePyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@Skylion007
Copy link
Collaborator

@Skylion007 are we building with C++20 only? not sure ifvirtual functions (as these are CUDAHooks) can beconstexpr

Ah, wasn't aware of that limitation. Not yet, no. :(

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are:trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 5, 5, lf.linux.g6.4xlarge.experimental.nvidia.gpu)

Details for Dev Infra teamRaised byworkflow job

@eqy
Copy link
CollaboratorAuthor

eqy commentedNov 7, 2025

@pytorchmergebot merge

pytorch-bot[bot] reacted with thumbs up emoji

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in thewiki.

Questions? Feedback? Please reach out to thePyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@eqy
Copy link
CollaboratorAuthor

eqy commentedNov 7, 2025

@pytorchbot cherry-pick --onto release/2.9 --fixes "cuDNN conv3d performance workaround" -c regression

pytorch-bot[bot] reacted with thumbs up emoji

pytorchbot pushed a commit that referenced this pull requestNov 7, 2025
…#167111)cuDNN dispatching heuristics rely on versions checks but currently only that compile-time version is exposed, if we want to allow users toresolve#166643 on their end by updating their cuDNN version locally we need to check the runtime version rather than compile-time version.Pull Requestresolved:#167111Approved by:https://github.com/Skylion007(cherry picked from commite678450)
@pytorchbot
Copy link
Collaborator

Cherry picking#167111

The cherry pick PR is at#167327 and it is linked with issue cuDNN conv3d performance workaround. The following tracker issues are updated:

Details for Dev Infra teamRaised byworkflow job

atalman pushed a commit that referenced this pull requestNov 7, 2025
…#167327)[cuDNN][SDPA][Convolution] Expose cuDNN runtime version in CUDA hooks (#167111)cuDNN dispatching heuristics rely on versions checks but currently only that compile-time version is exposed, if we want to allow users toresolve#166643 on their end by updating their cuDNN version locally we need to check the runtime version rather than compile-time version.Pull Requestresolved:#167111Approved by:https://github.com/Skylion007(cherry picked from commite678450)Co-authored-by: Eddie Yan <eddiey@nvidia.com>
jovan2009 referenced this pull request in comfyanonymous/ComfyUINov 14, 2025
Silv3S pushed a commit to Silv3S/pytorch that referenced this pull requestNov 18, 2025
…pytorch#167111)cuDNN dispatching heuristics rely on versions checks but currently only that compile-time version is exposed, if we want to allow users toresolvepytorch#166643 on their end by updating their cuDNN version locally we need to check the runtime version rather than compile-time version.Pull Requestresolved:pytorch#167111Approved by:https://github.com/Skylion007
Sign up for freeto join this conversation on GitHub. Already have an account?Sign in to comment

Reviewers

@Skylion007Skylion007Skylion007 approved these changes

@syed-ahmedsyed-ahmedAwaiting requested review from syed-ahmedsyed-ahmed is a code owner

@Aidyn-AAidyn-AAwaiting requested review from Aidyn-AAidyn-A is a code owner

Assignees

No one assigned

Labels

ciflow/trunkTrigger trunk jobs on your pull requestMergedmodule: convolutionProblems related to convolutions (THNN, THCUNN, CuDNN)module: cpuCPU specific problem (e.g., perf, algorithm)module: cudnnRelated to torch.backends.cudnn, and CuDNN supportmodule: sdpaAll things related to torch.nn.functional.scaled_dot_product_attentiionopen sourcerelease notes: cudnn

Projects

None yet

Milestone

No milestone

Development

Successfully merging this pull request may close these issues.

Significant Memory Regression inF.conv3d withbfloat16 Inputs inPyTorch 2.9.0

4 participants

@eqy@pytorchmergebot@Skylion007@pytorchbot

[8]ページ先頭

©2009-2025 Movatter.jp