fix apparent copy-paste bug in log_softmax reduced-precision fp kernel #156379
Conversation
pytorch-bot bot commented Jun 18, 2025 (edited)
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/156379
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit d1f4041 with merge base 6303cc4.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
cyyever commented Jun 19, 2025
Why the kineto changes?
swolchok commented Jun 19, 2025
Whoops, I think I just forgot to update submodules when I pulled main.
swolchok commented Jun 20, 2025 (edited)
I've been stressing out trying to figure out why I can't detect this bug in testing. Apparently, this is because the actual value of the intermediate max doesn't matter for "normal" values; it's a numerical-accuracy technique called "safe softmax", and the subtracted value cancels out (see e.g. "Why Safe Softmax Doesn't Change the Result" here). I imagine we could try to contrive a very specific case where it matters, but probably we should just fix it and move on.
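Concretely, the subtracted constant cancels: for any finite constant $c$, not just the true maximum,

```math
\log\mathrm{softmax}(x)_i = (x_i - c) - \log\sum_j e^{x_j - c} = (x_i - c) - \Big(\log\sum_j e^{x_j} - c\Big) = x_i - \log\sum_j e^{x_j}.
```

Choosing $c = \max_j x_j$ only changes the floating-point range of the intermediate exponentials, not the mathematical result, which is why a wrong intermediate max is invisible for well-scaled inputs.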
```cpp
max_fvec0 = fVec::blendv(max_fvec0, data_fvec0, data_fvec0 > max_fvec0);
max_fvec1 = fVec::blendv(max_fvec1, data_fvec1, data_fvec1 > max_fvec1);
max_fvec0.store(input_max_data + d1);
max_fvec0.store(input_max_data + d1 + fVec::size());
```
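Presumably the fix is for the second store to write max_fvec1, matching the copy-paste pattern the PR title describes:

```cpp
max_fvec0.store(input_max_data + d1);
max_fvec1.store(input_max_data + d1 + fVec::size());  // previously stored max_fvec0 twice
```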
OMG, I even saw this before and got confused about why we were storing twice into max_fvec0! But hmm, why do we store twice into min_fvec and zero_fvec in lines 1028-1031 above?
I spent some time last night trying to understand how this kernel works (by interrogating the LLM of my choice, since I'm below average at intuitively understanding array-indexing code without lots of pictures), so I can actually answer that! Lines 1028-1031 are storing the identities for the max and sum reductions into our array of accumulators.
We're making a non-contiguous reduction vectorizable anyway by doing an array of reductions all at once. (We can note that outer_size is basically just a batch dimension and ignore it for purposes of understanding the kernel.) We slice the inner dimensions into chunks of length CHUNK_SIZE; the inner loops are doing CHUNK_SIZE reductions at once. Accordingly, they have an array of CHUNK_SIZE accumulators, which is what's getting initialized in lines 1028-1031. Since the CHUNK_SIZE dimension is contiguous, we can vectorize along it and get "parallelization" that way through vector arithmetic. The blocking/chunking is so that the dim_size x CHUNK_SIZE "vertical panel" handed to each thread fits in cache, since softmax ends up reading the data three times (once per inner loop: max, sum + log, and data - logsum - max).
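In scalar form, the structure looks roughly like this (a hypothetical sketch with illustrative names, not the actual kernel; the real code does the inner loops over `c` with at::vec::Vectorized instead of scalar arithmetic):

```cpp
#include <algorithm>
#include <cmath>
#include <limits>

// Each chunk runs CHUNK_SIZE independent reductions in lockstep; because the
// chunk dimension is contiguous, the inner loops over `c` are vectorizable.
constexpr int CHUNK_SIZE = 8;  // chosen so the dim_size x CHUNK_SIZE panel fits in cache

// stride: distance between successive elements along the reduced dim (>= CHUNK_SIZE)
void log_softmax_chunk(const float* input, float* output, long dim_size, long stride) {
  // Accumulators, one per lane, initialized to the reduction identities
  // (the analogue of what lines 1028-1031 do with min_fvec / zero_fvec).
  float max_acc[CHUNK_SIZE];
  float sum_acc[CHUNK_SIZE];
  std::fill(max_acc, max_acc + CHUNK_SIZE, -std::numeric_limits<float>::infinity());
  std::fill(sum_acc, sum_acc + CHUNK_SIZE, 0.0f);

  // Pass 1: running max along the (non-contiguous) reduced dimension.
  for (long d = 0; d < dim_size; ++d)
    for (int c = 0; c < CHUNK_SIZE; ++c)
      max_acc[c] = std::max(max_acc[c], input[d * stride + c]);

  // Pass 2: sum of exp(x - max), then log.
  for (long d = 0; d < dim_size; ++d)
    for (int c = 0; c < CHUNK_SIZE; ++c)
      sum_acc[c] += std::exp(input[d * stride + c] - max_acc[c]);
  for (int c = 0; c < CHUNK_SIZE; ++c)
    sum_acc[c] = std::log(sum_acc[c]);

  // Pass 3: data - max - logsum -- the third read of the panel.
  for (long d = 0; d < dim_size; ++d)
    for (int c = 0; c < CHUNK_SIZE; ++c)
      output[d * stride + c] = input[d * stride + c] - max_acc[c] - sum_acc[c];
}
```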
janeyx99 left a comment
Thanks for the explanation, though I haven't yet absorbed it all
janeyx99 commented Jun 20, 2025
@pytorchbot merge
pytorchmergebot commented Jun 20, 2025
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
This looks like a bug. Check if trying to fix it breaks existing tests; if not, will look into why no test coverage caught it.
cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @jerryzh168