[ROCm] NLLLoss (torch.nll_loss) Performance Tuning by Dynamically Selecting # of GPU threads #149548


Closed
apakbin wants to merge 3 commits into pytorch:main from apakbin:nll_loss_tune

Conversation

@apakbin (Contributor) commented Mar 19, 2025 (edited by pytorch-bot bot)

Instead of fixing the number of GPU threads at 32 regardless of input size, this PR dynamically selects the thread count using the formula clamp(2^round(log2(dim0/16)), min = 32, max = 1024). The experiments below were run on an MI300 machine with data type float32:

[Image: nll_loss_threads_bests]
[Image: nll_loss_heauristic]
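
The heuristic itself is a single host-side computation. As a minimal sketch (assuming a helper named `nll_loss_threads` that takes the outer dimension `dim0`; this is illustrative, not the PR's actual code):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// clamp(2^round(log2(dim0 / 16)), min = 32, max = 1024)
uint32_t nll_loss_threads(int64_t dim0) {
  // Nearest power of two to dim0 / 16.
  const double power = std::round(std::log2(static_cast<double>(dim0) / 16.0));
  const double threads = std::pow(2.0, power);
  // Never fewer than one warp/wavefront (32 threads) and never more than the
  // 1024-thread block limit; the result is always a power of two.
  return static_cast<uint32_t>(std::clamp(threads, 32.0, 1024.0));
}
```

For example, dim0 = 4096 gives round(log2(256)) = 8 and therefore 256 threads, while dim0 = 512 still yields the old fixed value of 32.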

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

pytorch-bot bot added the release notes: cuda (release notes category) label Mar 19, 2025
pytorch-bot bot commented Mar 19, 2025 (edited)

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149548

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 0827d28 with merge base ffa0853:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

apakbin changed the title from "NLLLoss (torch.nll_loss) Performance Tuning by Dynamically Selecting # of GPU threads" to "[ROCm] NLLLoss (torch.nll_loss) Performance Tuning by Dynamically Selecting # of GPU threads" Mar 19, 2025
pytorch-bot bot added the module: rocm (AMD GPU support for PyTorch) label Mar 19, 2025
apakbin marked this pull request as draft March 19, 2025 22:47
jithunnair-amd added the ciflow/rocm (Trigger "default" config CI on ROCm) label Mar 19, 2025
pytorch-bot bot removed the ciflow/rocm (Trigger "default" config CI on ROCm) label Mar 19, 2025
jeffdaily added the ciflow/rocm (Trigger "default" config CI on ROCm) label Mar 20, 2025
@jeffdaily (Collaborator) left a comment

approved, pending clean CI

@apakbin (Contributor, Author)

The benchmark we ran:

repro.txt

@apakbin (Contributor, Author)

@pytorchbot rebase


@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.

@pytorchmergebot (Collaborator)

Successfully rebased nll_loss_tune onto refs/remotes/origin/viable/strict; please pull locally before adding more changes (for example, via git checkout nll_loss_tune && git pull --rebase).

pytorch-bot bot removed the ciflow/rocm (Trigger "default" config CI on ROCm) label Mar 20, 2025
pruthvistony added the ciflow/periodic (Trigger jobs run periodically on master (periodic.yml) on the PR), ciflow/rocm (Trigger "default" config CI on ROCm), ciflow/inductor-rocm (Trigger "inductor" config CI on ROCm), and ciflow/rocm-mi300 (Trigger "default" config CI on ROCm MI300) labels Mar 20, 2025
@pytorch-bot

To add the ciflow label ciflow/periodic, please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot

To add the ciflow label ciflow/inductor-rocm, please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot

To add the ciflow label ciflow/rocm, please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot

To add the ciflow label ciflow/rocm-mi300, please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

jeffdaily marked this pull request as ready for review March 21, 2025 00:01

@jeffdaily (Collaborator)

@pytorchbot rebase


@pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.

@pytorchmergebot (Collaborator)

Successfully rebased nll_loss_tune onto refs/remotes/origin/viable/strict; please pull locally before adding more changes (for example, via git checkout nll_loss_tune && git pull --rebase).

pytorch-bot bot removed the ciflow/periodic, ciflow/rocm, ciflow/inductor-rocm, and ciflow/rocm-mi300 labels Mar 21, 2025
jeffdaily added the ciflow/rocm (Trigger "default" config CI on ROCm) and ciflow/rocm-mi300 (Trigger "default" config CI on ROCm MI300) labels Mar 21, 2025
@jeffdaily (Collaborator)

@pytorchbot merge


pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Mar 21, 2025
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team.

Advanced Debugging: check the merge workflow status here.

apakbin added a commit to ROCm/pytorch that referenced this pull request Mar 21, 2025
jerrymannil pushed a commit to ROCm/pytorch that referenced this pull request Mar 21, 2025
svekars pushed a commit that referenced this pull request Mar 21, 2025
…ecting # of GPU threads (#149548)

Instead of fixing the number of GPU threads to 32 regardless of input size, this PR dynamically selects the number of threads based on the formula: clamp(2^round(log2(dim0/16)), min = 32, max = 1024). The experiments below were done on an MI300 machine for data type float32:

![nll_loss_threads_bests](https://github.com/user-attachments/assets/3be3d465-e3db-44ed-991a-fdfcab03baae)
![nll_loss_heauristic](https://github.com/user-attachments/assets/e82b9788-9b4d-4862-a180-8df7ad298182)

Pull Request resolved: #149548
Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony
pytorchmergebot pushed a commit that referenced this pull request Mar 29, 2025
…n. (#149779)

#149548 fixed the arbitrarily missing parallelism for NLL, but it also added an arbitrary #ifdef ROCM guard around this fix to prevent its use on CUDA GPUs. There is also a problem with the way the kernel does the reduction from the intermediate shared memory, using only thread 0 walking linearly. This has been changed to a simple parallel reduction algorithm.

Tested changes with `python3 test/test_nn.py`:

```
Ran 3551 tests in 200.554s
OK (skipped=998, expected failures=4)
```

Performance before and after, measured with the script below on an RTX 3090 (batch size on the x axis, time in seconds on the y axis). This GPU is also used for display graphics and such, so the measurements are pretty noisy, even with 100 samples.

## Before
![before_nll](https://github.com/user-attachments/assets/c19044aa-7bc2-4223-b560-9be7acedef35)
## After ifdef removal
![after_nll](https://github.com/user-attachments/assets/4672f5ca-93b0-4c34-a257-81b2ab364995)
## After parallel SMEM reduction
![after_reduction](https://github.com/user-attachments/assets/9607b68c-7d9d-4ee0-9f99-8989d134e4fd)

```python
import torch
from matplotlib import pyplot as plt
from torch.nn import functional as F

timing = []
batches = list(range(32, 4096, 32))
for batch in [32] + batches:
    samples = []
    for _ in range(100):
        probs = torch.rand(batch, 10).cuda()
        labels = torch.randint(0, 10, (batch,)).cuda()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        F.nll_loss(probs, labels)
        end.record()
        torch.cuda.synchronize()
        elapsed = start.elapsed_time(end)
        samples.append(elapsed)
    timing.append(sum(samples) / len(samples))
timing = timing[1:]
plt.plot(batches, timing)
plt.show()
```

Pull Request resolved: #149779
Approved by: https://github.com/jeffdaily
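
For context, the "simple parallel reduction algorithm" mentioned above is a standard tree reduction over the shared-memory buffer. A hedged sketch (the buffer name `smem` and the assumption that each thread has already written its partial sum are illustrative; this is not the kernel actually merged in #149779):

```cpp
// Replaces the pattern where thread 0 walks the whole buffer linearly:
// the active range is halved on every step instead.
__device__ float block_reduce_sum(float* smem) {
  // smem[threadIdx.x] holds this thread's partial sum on entry.
  for (unsigned int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    __syncthreads();  // make the previous step's writes visible
    if (threadIdx.x < stride) {
      smem[threadIdx.x] += smem[threadIdx.x + stride];
    }
  }
  __syncthreads();
  return smem[0];  // block total, valid in every thread
}
```

This requires a power-of-two block size, which the heuristic from #149548 guarantees, and reduces the serial O(blockDim.x) walk to O(log2(blockDim.x)) steps.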
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025

Reviewers

jerrymannil left review comments

jeffdaily approved these changes

pruthvistony approved these changes

eqy: awaiting requested review (code owner)

syed-ahmed: awaiting requested review (code owner)

Assignees

No one assigned

Labels

ciflow/rocm (Trigger "default" config CI on ROCm), ciflow/rocm-mi300 (Trigger "default" config CI on ROCm MI300), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: rocm (AMD GPU support for PyTorch), open source, release notes: cuda (release notes category)

Projects

None yet

Milestone

No milestone

Development

Successfully merging this pull request may close these issues.

7 participants

@apakbin @pytorchmergebot @jeffdaily @pruthvistony @jerrymannil @pytorchbot @jithunnair-amd
