- Notifications
You must be signed in to change notification settings - Fork29.5k
[Performance 2/6] Replace einops.rearrange with torch native ops#15804
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.
Already on GitHub?Sign in to your account
Uh oh!
There was an error while loading.Please reload this page.
Conversation
| """ | ||
| b,n,_=t.shape# Get the batch size (b) and sequence length (n) | ||
| d=t.shape[2]//h# Determine the depth per head | ||
| returnt.reshape(b,n,h,d) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
t.reshape(b,n,h,-1) should achieve similar result without having to explicitly calculated
| q=_reshape(q_in) | ||
| k=_reshape(k_in) | ||
| v=_reshape(v_in) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others.Learn more.
This can be done in 1 line withq, k, v = (_reshape(t) for t in (q_in, k_in, v_in))
* replace rearrange to viewAUTOMATIC1111#15804 * see alsolllyasviel/stable-diffusion-webui-forge@79adfa8 * conditional use torch.rms_norm for torch 2.4 * fix RMSNorm() for clear: use torch.ones()
* replace rearrange to viewAUTOMATIC1111#15804 * see alsolllyasviel/stable-diffusion-webui-forge@79adfa8 * conditional use torch.rms_norm for torch 2.4 * fix RMSNorm() for clear: use torch.ones()
Uh oh!
There was an error while loading.Please reload this page.
Description
According tolllyasviel/stable-diffusion-webui-forge#716 (comment),
einops.rearrangecalls in crossattn is causing extra overhead. Replacing it with torch native ops can save ~55ms/it.Screenshots/videos:
TODO
There are other places where
einops.rearrangecan be replaced by torch native ops, but this one in CrossAttn is the most critical one. Instrument the usage ofeinops.rearrangeelsewhere might also yield some improvements.Checklist: