ENH: Improve np.linalg.det performance #28649
Conversation
So far asv isn't showing much perf improvement on this branch on x86_64 Linux (i9-13900K):
asv continuous -E virtualenv -e -b "time_det.*" main linalg_refactor
BENCHMARKS NOT SIGNIFICANTLY CHANGED.
It may be because you're only making the first of the series of proposed changes.
Regardless of the performance changes, I suppose this is a reduction in lines of code, so maybe "ok" on its own anyway.
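For context, the -b "time_det.*" filter in the asv command above selects benchmarks whose method names match time_det. A minimal sketch of such a benchmark is below; the class name and array shapes are illustrative assumptions, not necessarily what NumPy's actual benchmark suite uses.

import numpy as np


class TimeDet:
    # asv runs setup() before timing each of the time_* methods
    def setup(self):
        rng = np.random.default_rng(42)
        self.small = rng.random((3, 3))          # single small matrix
        self.stacked = rng.random((100, 4, 4))   # stack of small matrices

    def time_det_small(self):
        np.linalg.det(self.small)

    def time_det_stacked(self):
        np.linalg.det(self.stacked)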
numpy/linalg/_linalg.py Outdated
@@ -152,8 +152,8 @@ def _commonType(*arrays):
     for a in arrays:
         type_ = a.dtype.type
         if issubclass(type_, inexact):
-            if isComplexType(type_):
-                is_complex = True
+            is_complex = is_complex or isComplexType(type_)
Is this really worth doing? It takes my brain longer to process and shouldn't matter much performance-wise?
I'm also not sure this routine is worth it, but if one goes for it, I'd start with something like
types = set(a.dtype.type for a in arrays)
which will generally reduce the number already, and then do
is_complex = any(isComplexType(type_) for type_ in types)
and something along similar lines for result_type (but with a check whether one can really not simply use the built-in np.result_type - not obvious to me).
But I'd do it in a separate PR.
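A self-contained sketch of that set-based approach could look as follows. _is_any_complex is a hypothetical helper for illustration, and isComplexType is re-implemented here only so the snippet runs on its own; this is not the actual _commonType refactor.

import numpy as np


def isComplexType(t):
    # same test as the helper of this name in numpy/linalg/_linalg.py
    return issubclass(t, np.complexfloating)


def _is_any_complex(*arrays):
    # Deduplicate the dtypes first, so repeated occurrences of the same
    # dtype are only inspected once, then test the (usually tiny) set.
    types = {a.dtype.type for a in arrays}
    return any(isComplexType(type_) for type_ in types)


print(_is_any_complex(np.ones(3), np.ones(3, dtype=np.complex128)))  # True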
Interesting idea! It might not work out because the common use case is with len(arrays) just 1 or 2, and the set overhead is too large. I will remove this change here and test it out in a new PR.
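One way to check that intuition about set overhead for one or two arrays is a quick timeit comparison along these lines (purely illustrative, not part of the PR):

import timeit

import numpy as np

a = np.ones((3, 3))
b = np.ones((3, 3), dtype=np.complex128)


def loop_version(*arrays):
    # current style: plain loop over the (usually 1 or 2) arrays
    is_complex = False
    for arr in arrays:
        if issubclass(arr.dtype.type, np.complexfloating):
            is_complex = True
    return is_complex


def set_version(*arrays):
    # suggested style: deduplicate the dtypes first
    types = {arr.dtype.type for arr in arrays}
    return any(issubclass(t, np.complexfloating) for t in types)


print(timeit.timeit(lambda: loop_version(a, b), number=100_000))
print(timeit.timeit(lambda: set_version(a, b), number=100_000))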
I like the cleanup, though I'm not surprised it doesn't have that much of an effect on speed. A suggestion in-line for an extra (if very minor) performance boost.
Also, I would suggest doing just the removal of _assert_stacked_square here.
@@ -1320,11 +1317,6 @@ def eigvalsh(a, UPLO='L'):
         w = gufunc(a, signature=signature)
     return w.astype(_realType(result_t), copy=False)

-def _convertarray(a):
Nice catch that this is not actually used!
numpy/linalg/_linalg.py Outdated
@@ -197,6 +197,9 @@ def _assert_stacked_2d(*arrays):
 def _assert_stacked_square(*arrays):
     for a in arrays:
         if a.ndim < 2:
I like the combination. If one really wants to get the most out of it, it could be

try:
    m, n = a.shape[-2:]
except ValueError:
    raise LinAlgError(f"{a.ndim}-dimensional...") from None
if m != n:
    ...

This uses the fact that these days try/except has no cost if no exception is raised.
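Putting the pieces together, the combined check could look roughly like this; a sketch based on the suggestion above, with illustrative error messages rather than the exact ones in the PR:

from numpy.linalg import LinAlgError


def _assert_stacked_square(*arrays):
    for a in arrays:
        try:
            # unpacking fails with ValueError for 0-D and 1-D inputs
            m, n = a.shape[-2:]
        except ValueError:
            raise LinAlgError(
                f"{a.ndim}-dimensional array given. Array must be "
                "at least two-dimensional") from None
        if m != n:
            raise LinAlgError(
                "Last 2 dimensions of the array must be square")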
Sorry, pushed this after submitting the review - above is the most relevant comment! (and not very relevant at that!)
The _assert_stacked_square check is about 10% faster using your suggestion; I updated the PR with this.
Looks good to me, thanks! Let's get it in.
Merged commit 422ca44 into numpy:main.
Thanks for reviewing! Next PR is #28686
ENH: Improve np.linalg.det performance (#28649)

* ENH: Improve np.linalg.det performance
* Update numpy/linalg/_linalg.py
* revert change to complex detection
* use suggestion
* whitespace
* add more small array benchmarks
* trigger build
We can improve the performance of np.linalg.det for small arrays by up to 40% with 3 changes:

1. Merge the _assert_stacked_2d check into _assert_stacked_square.
2. Speed up _commonType by using a cache.
3. Avoid r.astype(...) for scalar arguments making a copy of the data (and internally converting to an array and back).

In this PR we perform the first step.
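For the third proposed change (left for a later PR), the point is that calling r.astype(...) on a NumPy scalar copies the data via an array round-trip even when the dtype already matches. A hypothetical guard like the one below would skip that call when it is not needed; it only illustrates the idea and is not the change that was eventually made.

import numpy as np

r = np.float64(2.5)      # e.g. the scalar returned by det() for a 2-D input
result_t = np.float64

# Unconditional conversion: per the description above, this copies the data
# and internally converts the scalar to an array and back.
converted = r.astype(result_t, copy=False)

# Hypothetical guard: only convert when the dtype actually differs.
converted = r if r.dtype.type is result_t else r.astype(result_t, copy=False)
print(converted)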