[ROCm] NLLLoss (torch.nll_loss) Performance Tuning by Dynamically Selecting # of GPU threads #149548
Conversation
pytorch-bot commented Mar 19, 2025 (edited)
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149548
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 0827d28 with merge base ffa0853.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
jeffdaily left a comment
approved, pending clean CI
apakbin commented Mar 20, 2025
The benchmark we ran:
apakbin commented Mar 20, 2025
@pytorchbot rebase
pytorchmergebot commented Mar 20, 2025
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
pytorchmergebot commented Mar 20, 2025
Successfully rebased 6e8dc92 to bd13475.
To add the ciflow label: this helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.
jeffdaily commented Mar 21, 2025
@pytorchbot rebase
pytorchmergebot commented Mar 21, 2025
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
…amp(2^round(log2(dim0/16)), min = 32, max = 1024)
pytorchmergebot commented Mar 21, 2025
Successfully rebased bd13475 to 0827d28.
jeffdaily commented Mar 21, 2025
@pytorchbot merge
pytorchmergebot commented Mar 21, 2025
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…ecting # of GPU threads (#149548)
Instead of fixing the number of GPU threads to 32 regardless of input size, this PR dynamically selects the number of threads based on the formula: clamp(2^round(log2(dim0/16)), min = 32, max = 1024). The experiments below were done on an MI300 machine for data type float32:
Pull Request resolved: #149548
Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony
…n. (#149779)
#149548 fixed the arbitrarily missing parallelism for NLL, but it also added an arbitrary #ifdef ROCM guard around this fix to prevent its use on CUDA GPUs. There is also a problem with the way the kernel does the reduction from the intermediate shared memory, using only thread 0 walking linearly. This has been changed to a simple parallel reduction algorithm.
Tested changes with `python3 test/test_nn.py`:
```
Ran 3551 tests in 200.554s
OK (skipped=998, expected failures=4)
```
Performance before and after with the script below on an RTX 3090, batch size on the x axis, time (sec) on the y axis. This GPU is also used for display graphics and such, so the measurements are pretty noisy, even with 100 samples.
## Before
## After ifdef removal
## After Parallel SMEM reduction
```python
import torch
from matplotlib import pyplot as plt
from torch.nn import functional as F

timing = []
batches = list(range(32, 4096, 32))
for batch in [32] + batches:
    samples = []
    for _ in range(100):
        probs = torch.rand(batch, 10).cuda()
        labels = torch.randint(0, 10, (batch,)).cuda()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        F.nll_loss(probs, labels)
        end.record()
        torch.cuda.synchronize()
        elapsed = start.elapsed_time(end)
        samples.append(elapsed)
    timing.append(sum(samples) / len(samples))
timing = timing[1:]
plt.plot(batches, timing)
plt.show()
```
Pull Request resolved: #149779
Approved by: https://github.com/jeffdaily
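For readers unfamiliar with the reduction change in that commit: one standard way to replace a single thread walking the shared-memory buffer linearly is a stride-halving tree reduction. The Python below is only a host-side simulation of that indexing pattern, not the actual kernel; the function name `tree_reduce` is ours, `partial` stands in for the intermediate shared-memory buffer, and it assumes the buffer length is a power of two (as CUDA block sizes are).

```python
def tree_reduce(partial):
    # Simulate a stride-halving shared-memory reduction on the host.
    # Each iteration halves the number of active "threads": thread tid
    # accumulates its partner at tid + stride, with a barrier
    # (__syncthreads() in the kernel) between iterations.
    buf = list(partial)
    stride = len(buf) // 2
    while stride > 0:
        for tid in range(stride):  # these additions run in parallel on a GPU
            buf[tid] += buf[tid + stride]
        stride //= 2
    return buf[0]
```

The point of the stride-halving order is that the partial sums are combined in log2(n) dependent steps, each executed in parallel, instead of the n sequential additions performed when thread 0 walks the buffer alone.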
Instead of fixing the number of GPU threads to 32 regardless of input size, this PR dynamically selects the number of threads based on the formula: clamp(2^round(log2(dim0/16)), min = 32, max = 1024). The experiments below were done on an MI300 machine for data type float32:
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd
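As a minimal sketch of the thread-selection rule quoted in the description above: the helper name `nll_loss_num_threads` is ours for illustration, and the actual selection happens in the kernel launch code, not in Python.

```python
import math

def nll_loss_num_threads(dim0: int) -> int:
    # clamp(2^round(log2(dim0 / 16)), min = 32, max = 1024)
    raw = 2 ** round(math.log2(dim0 / 16))
    return int(max(32, min(1024, raw)))
```

For example, dim0 = 4096 gives 2^round(log2(256)) = 256 threads, while small inputs fall back to the previous fixed value of 32 and very large ones cap at 1024.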