
Removed ROCM ifdef that governs thread count + smem parallel reduction. #149779


Closed
5had3z wants to merge 2 commits into pytorch:main from 5had3z:nll_remove_threads_ifdef

Conversation

@5had3z (Contributor) commented Mar 21, 2025 (edited by pytorch-bot bot)

#149548 fixed the arbitrarily missing parallelism for NLL, but it also added an arbitrary #ifdef ROCM guard around that fix, preventing its use on CUDA GPUs. There is also a problem with the way the kernel reduces the intermediate shared memory: only thread 0 walks it linearly. This has been changed to a simple parallel reduction algorithm.
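For context, here is a minimal sketch of the two shared-memory reduction patterns described above: the old path, where thread 0 alone walks the per-thread partial sums, versus a tree-style parallel reduction. The kernel names, the NT block size, and the buffer names are illustrative assumptions (this is not the actual ATen NLL kernel), and the sketch assumes a power-of-two thread count:

```cuda
// Illustrative only: not the ATen nll_loss kernel. NT is an assumed,
// power-of-two block size; "partial" stands in for the per-thread partial sums.
#define NT 256

__global__ void reduce_thread0_linear(const float* partial, float* out) {
    __shared__ float smem[NT];
    smem[threadIdx.x] = partial[threadIdx.x];
    __syncthreads();
    if (threadIdx.x == 0) {
        // A single thread walks all NT entries: O(NT) serial additions.
        float acc = 0.0f;
        for (int i = 0; i < NT; ++i) {
            acc += smem[i];
        }
        *out = acc;
    }
}

__global__ void reduce_parallel_tree(const float* partial, float* out) {
    __shared__ float smem[NT];
    smem[threadIdx.x] = partial[threadIdx.x];
    __syncthreads();
    // Halve the number of active threads each step: O(log NT) steps,
    // with all additions within a step running in parallel.
    for (int stride = NT / 2; stride > 0; stride >>= 1) {
        if (threadIdx.x < stride) {
            smem[threadIdx.x] += smem[threadIdx.x + stride];
        }
        __syncthreads();
    }
    if (threadIdx.x == 0) {
        *out = smem[0];
    }
}
```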

Tested changes with python3 test/test_nn.py

Ran 3551 tests in 200.554s
OK (skipped=998, expected failures=4)

Performance before and after, measured with the script below on an RTX 3090 (batch size on the x axis, time in ms on the y axis). This GPU also drives the display and such, so the measurements are pretty noisy, even with 100 samples per batch size.

Before: [plot: before_nll]

After ifdef removal: [plot: after_nll]

After Parallel SMEM reduction: [plot: after_reduction]

```python
import torch
from matplotlib import pyplot as plt
from torch.nn import functional as F

timing = []
batches = list(range(32, 4096, 32))
# The extra leading batch of 32 acts as a warm-up; its timing is dropped below.
for batch in [32] + batches:
    samples = []
    for _ in range(100):
        probs = torch.rand(batch, 10).cuda()
        labels = torch.randint(0, 10, (batch,)).cuda()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        F.nll_loss(probs, labels)
        end.record()
        torch.cuda.synchronize()
        elapsed = start.elapsed_time(end)  # milliseconds
        samples.append(elapsed)
    timing.append(sum(samples) / len(samples))
timing = timing[1:]  # discard the warm-up measurement
plt.plot(batches, timing)
plt.show()
```

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

@pytorch-bot (bot) commented Mar 21, 2025 (edited)

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149779

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 2 Unrelated Failures

As of commit f485bbb with merge base 14f0cd7:

NEW FAILURE - The following job has failed:

UNSTABLE - The following jobs are marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@linux-foundation-easycla (bot) commented Mar 21, 2025 (edited)

CLA Signed

The committers listed above are authorized under a signed CLA.

@pytorch-bot (bot) added the module: rocm (AMD GPU support for Pytorch) and release notes: cuda (release notes category) labels Mar 21, 2025
@5had3z (Contributor, Author) commented

Those failures don't seem to be related to my changes? The NN test also works fine when I call it directly with python3 test/test_nn.py TestNNDeviceTypeCUDA.test_variable_sequence_cuda_float32, so maybe there is a race condition or something.

TestNNDeviceTypeCUDA.test_variable_sequence_cuda_float32

WIN: benchmark ('sum_floordiv_regression', 'compile_time_instruction_count') failed, actual result 971748003 is -5.29% lower than expected 1026000000 ±1.50% please update the expected results.

I'm also curious as to the purpose of NLLLoss2d.cu? torch.nn.NLLLoss2d is an alias for torch.nn.NLLLoss, so this seems to just be dead code?

@colesbury added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Mar 24, 2025
@jeffdaily added the ciflow/rocm (Trigger "default" config CI on ROCm) label Mar 27, 2025
@jeffdaily (Collaborator) commented

@5had3z I'm attempting to rerun the failed CUDA UT job. The CUDA benchmark job failure is unrelated. Also, I forgot to trigger ROCm CI, so I've done that now. I approved your changes since they LGTM, pending CI passing.

@jeffdaily (Collaborator) commented

@5had3z rerunning job got the same error. Let's try a rebase this time and see if we just happened to get an unlucky random set of inputs for that test.


@jeffdaily (Collaborator) commented

@pytorchbot rebase


@pytorchmergebot (Collaborator) commented

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.

Signed-off-by: Bryce Ferenczi <frenzi@hotmail.com.au>
Signed-off-by: Bryce Ferenczi <frenzi@hotmail.com.au>
@pytorchmergebot (Collaborator) commented

Successfully rebased nll_remove_threads_ifdef onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout nll_remove_threads_ifdef && git pull --rebase)

@pytorch-bot (bot) removed the ciflow/rocm (Trigger "default" config CI on ROCm) label Mar 27, 2025
@cyyever (Collaborator) commented

@pytorchmergebot merge -f "Unrelated failures"


@pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team.

Advanced Debugging: check the merge workflow status here.

amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025

Removed ROCM ifdef that governs thread count + smem parallel reduction. (pytorch#149779)
Pull Request resolved: pytorch#149779
Approved by: https://github.com/jeffdaily

Reviewers

@jeffdaily approved these changes
@eqy: awaiting requested review (code owner)
@syed-ahmed: awaiting requested review (code owner)

Assignees

No one assigned

Labels

Merged, module: rocm (AMD GPU support for Pytorch), open source, release notes: cuda (release notes category), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

None yet

Milestone

No milestone


6 participants

@5had3z, @jeffdaily, @pytorchmergebot, @cyyever, @colesbury, @pytorchbot
