Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit5932c98

Browse files
committed
TYP: np.argmin changes
1 parent4e628f4 commit5932c98

File tree

7 files changed

+36
-25
lines changed

7 files changed

+36
-25
lines changed

‎numpy/__init__.pyi

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,7 @@ _CharacterItemT_co = TypeVar("_CharacterItemT_co", bound=_CharLike_co, default=_
826826
_TD64ItemT_co=TypeVar("_TD64ItemT_co",bound=dt.timedelta|int|None,default=dt.timedelta|int|None,covariant=True)
827827
_DT64ItemT_co=TypeVar("_DT64ItemT_co",bound=dt.date|int|None,default=dt.date|int|None,covariant=True)
828828
_TD64UnitT=TypeVar("_TD64UnitT",bound=_TD64Unit,default=_TD64Unit)
829+
_BoolOrIntArrayT=TypeVar("_BoolOrIntArrayT",bound=NDArray[integer|np.bool])
829830

830831
### Type Aliases (for internal use only)
831832

@@ -1704,18 +1705,18 @@ class _ArrayOrScalarCommon:
17041705
@overload# axis=index, out=None (default)
17051706
defargmax(self,/,axis:SupportsIndex,out:None=None,*,keepdims:builtins.bool=False)->Any: ...
17061707
@overload# axis=index, out=ndarray
1707-
defargmax(self,/,axis:SupportsIndex|None,out:_ArrayT,*,keepdims:builtins.bool=False)->_ArrayT: ...
1708+
defargmax(self,/,axis:SupportsIndex|None,out:_BoolOrIntArrayT,*,keepdims:builtins.bool=False)->_BoolOrIntArrayT: ...
17081709
@overload
1709-
defargmax(self,/,axis:SupportsIndex|None=None,*,out:_ArrayT,keepdims:builtins.bool=False)->_ArrayT: ...
1710+
defargmax(self,/,axis:SupportsIndex|None=None,*,out:_BoolOrIntArrayT,keepdims:builtins.bool=False)->_BoolOrIntArrayT: ...
17101711

17111712
@overload# axis=None (default), out=None (default), keepdims=False (default)
17121713
defargmin(self,/,axis:None=None,out:None=None,*,keepdims:L[False]=False)->intp: ...
17131714
@overload# axis=index, out=None (default)
17141715
defargmin(self,/,axis:SupportsIndex,out:None=None,*,keepdims:builtins.bool=False)->Any: ...
17151716
@overload# axis=index, out=ndarray
1716-
defargmin(self,/,axis:SupportsIndex|None,out:_ArrayT,*,keepdims:builtins.bool=False)->_ArrayT: ...
1717+
defargmin(self,/,axis:SupportsIndex|None,out:_BoolOrIntArrayT,*,keepdims:builtins.bool=False)->_BoolOrIntArrayT: ...
17171718
@overload
1718-
defargmin(self,/,axis:SupportsIndex|None=None,*,out:_ArrayT,keepdims:builtins.bool=False)->_ArrayT: ...
1719+
defargmin(self,/,axis:SupportsIndex|None=None,*,out:_BoolOrIntArrayT,keepdims:builtins.bool=False)->_BoolOrIntArrayT: ...
17191720

17201721
@overload# out=None (default)
17211722
defround(self,/,decimals:SupportsIndex=0,out:None=None)->Self: ...
@@ -5364,19 +5365,19 @@ class matrix(ndarray[_2DShapeT_co, _DTypeT_co]):
53645365
@overload
53655366
defargmax(self,axis:_ShapeLike,out:None=None)->matrix[_2D,dtype[intp]]: ...
53665367
@overload
5367-
defargmax(self,axis:_ShapeLike|None,out:_ArrayT)->_ArrayT: ...
5368+
defargmax(self,axis:_ShapeLike|None,out:_BoolOrIntArrayT)->_BoolOrIntArrayT: ...
53685369
@overload
5369-
defargmax(self,axis:_ShapeLike|None=None,*,out:_ArrayT)->_ArrayT: ...# pyright: ignore[reportIncompatibleMethodOverride]
5370+
defargmax(self,axis:_ShapeLike|None=None,*,out:_BoolOrIntArrayT)->_BoolOrIntArrayT: ...# pyright: ignore[reportIncompatibleMethodOverride]
53705371

53715372
# keep in sync with `argmax`
53725373
@overload# type: ignore[override]
53735374
defargmin(self:NDArray[_ScalarT],axis:None=None,out:None=None)->intp: ...
53745375
@overload
53755376
defargmin(self,axis:_ShapeLike,out:None=None)->matrix[_2D,dtype[intp]]: ...
53765377
@overload
5377-
defargmin(self,axis:_ShapeLike|None,out:_ArrayT)->_ArrayT: ...
5378+
defargmin(self,axis:_ShapeLike|None,out:_BoolOrIntArrayT)->_BoolOrIntArrayT: ...
53785379
@overload
5379-
defargmin(self,axis:_ShapeLike|None=None,*,out:_ArrayT)->_ArrayT: ...# pyright: ignore[reportIncompatibleMethodOverride]
5380+
defargmin(self,axis:_ShapeLike|None=None,*,out:_BoolOrIntArrayT)->_BoolOrIntArrayT: ...# pyright: ignore[reportIncompatibleMethodOverride]
53805381

53815382
#the second overload handles the (rare) case that the matrix is not 2-d
53825383
@overload

‎numpy/_core/fromnumeric.pyi

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ _NumberOrObjectT = TypeVar("_NumberOrObjectT", bound=np.number | np.object_)
111111
_ArrayT=TypeVar("_ArrayT",bound=np.ndarray[Any,Any])
112112
_ShapeT=TypeVar("_ShapeT",bound=tuple[int, ...])
113113
_ShapeT_co=TypeVar("_ShapeT_co",bound=tuple[int, ...],covariant=True)
114+
_BoolOrIntArrayT=TypeVar("_BoolOrIntArrayT",bound=NDArray[np.integer|np.bool])
114115

115116
@type_check_only
116117
class_SupportsShape(Protocol[_ShapeT_co]):
@@ -418,18 +419,18 @@ def argmax(
418419
defargmax(
419420
a:ArrayLike,
420421
axis:SupportsIndex|None,
421-
out:_ArrayT,
422+
out:_BoolOrIntArrayT,
422423
*,
423424
keepdims:bool= ...,
424-
)->_ArrayT: ...
425+
)->_BoolOrIntArrayT: ...
425426
@overload
426427
defargmax(
427428
a:ArrayLike,
428429
axis:SupportsIndex|None= ...,
429430
*,
430-
out:_ArrayT,
431+
out:_BoolOrIntArrayT,
431432
keepdims:bool= ...,
432-
)->_ArrayT: ...
433+
)->_BoolOrIntArrayT: ...
433434

434435
@overload
435436
defargmin(
@@ -451,18 +452,18 @@ def argmin(
451452
defargmin(
452453
a:ArrayLike,
453454
axis:SupportsIndex|None,
454-
out:_ArrayT,
455+
out:_BoolOrIntArrayT,
455456
*,
456457
keepdims:bool= ...,
457-
)->_ArrayT: ...
458+
)->_BoolOrIntArrayT: ...
458459
@overload
459460
defargmin(
460461
a:ArrayLike,
461462
axis:SupportsIndex|None= ...,
462463
*,
463-
out:_ArrayT,
464+
out:_BoolOrIntArrayT,
464465
keepdims:bool= ...,
465-
)->_ArrayT: ...
466+
)->_BoolOrIntArrayT: ...
466467

467468
@overload
468469
defsearchsorted(

‎numpy/typing/tests/data/fail/fromnumeric.pyi

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ A = np.array(True, ndmin=2, dtype=bool)
77
A.setflags(write=False)
88
AR_U:npt.NDArray[np.str_]
99
AR_M:npt.NDArray[np.datetime64]
10+
AR_f4:npt.NDArray[np.float32]
1011

1112
a=np.bool(True)
1213

@@ -50,9 +51,11 @@ np.argsort(A, order=range(5)) # type: ignore[arg-type]
5051

5152
np.argmax(A,axis="bob")# type: ignore[call-overload]
5253
np.argmax(A,kind="bob")# type: ignore[call-overload]
54+
np.argmax(A,out=AR_f4)# type: ignore[arg-type]
5355

5456
np.argmin(A,axis="bob")# type: ignore[call-overload]
5557
np.argmin(A,kind="bob")# type: ignore[call-overload]
58+
np.argmin(A,out=AR_f4)# type: ignore[arg-type]
5659

5760
np.searchsorted(A[0],0,side="bob")# type: ignore[call-overload]
5861
np.searchsorted(A[0],0,sorter=1.0)# type: ignore[call-overload]

‎numpy/typing/tests/data/pass/ndarray_misc.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
importnumpy.typingasnpt
1616

1717
classSubClass(npt.NDArray[np.float64]): ...
18-
18+
classIntSubClass(npt.NDArray[np.intp]): ...
1919

2020
i4=np.int32(1)
2121
A:np.ndarray[Any,np.dtype[np.int32]]=np.array([[1]],dtype=np.int32)
2222
B0=np.empty((),dtype=np.int32).view(SubClass)
2323
B1=np.empty((1,),dtype=np.int32).view(SubClass)
2424
B2=np.empty((1,1),dtype=np.int32).view(SubClass)
25+
B_int0:IntSubClass=np.empty((),dtype=np.intp).view(IntSubClass)
2526
C:np.ndarray[Any,np.dtype[np.int32]]=np.array([0,1,2],dtype=np.int32)
2627
D=np.ones(3).view(SubClass)
2728

@@ -42,12 +43,12 @@ class SubClass(npt.NDArray[np.float64]): ...
4243
i4.argmax()
4344
A.argmax()
4445
A.argmax(axis=0)
45-
A.argmax(out=B0)
46+
A.argmax(out=B_int0)
4647

4748
i4.argmin()
4849
A.argmin()
4950
A.argmin(axis=0)
50-
A.argmin(out=B0)
51+
A.argmin(out=B_int0)
5152

5253
i4.argsort()
5354
A.argsort()

‎numpy/typing/tests/data/reveal/fromnumeric.pyi

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ f4: np.float32
2525
i8:np.int64
2626
f:float
2727

28+
# integer‑dtype subclass for argmin/argmax
29+
classNDArrayIntSubclass(npt.NDArray[np.intp]): ...
30+
AR_sub_i:NDArrayIntSubclass
31+
2832
assert_type(np.take(b,0),np.bool)
2933
assert_type(np.take(f4,0),np.float32)
3034
assert_type(np.take(f,0),Any)
@@ -89,13 +93,13 @@ assert_type(np.argmax(AR_b), np.intp)
8993
assert_type(np.argmax(AR_f4),np.intp)
9094
assert_type(np.argmax(AR_b,axis=0),Any)
9195
assert_type(np.argmax(AR_f4,axis=0),Any)
92-
assert_type(np.argmax(AR_f4,out=AR_subclass),NDArraySubclass)
96+
assert_type(np.argmax(AR_f4,out=AR_sub_i),NDArrayIntSubclass)
9397

9498
assert_type(np.argmin(AR_b),np.intp)
9599
assert_type(np.argmin(AR_f4),np.intp)
96100
assert_type(np.argmin(AR_b,axis=0),Any)
97101
assert_type(np.argmin(AR_f4,axis=0),Any)
98-
assert_type(np.argmin(AR_f4,out=AR_subclass),NDArraySubclass)
102+
assert_type(np.argmin(AR_f4,out=AR_sub_i),NDArrayIntSubclass)
99103

100104
assert_type(np.searchsorted(AR_b[0],0),np.intp)
101105
assert_type(np.searchsorted(AR_f4[0],0),np.intp)

‎numpy/typing/tests/data/reveal/matrix.pyi

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ _Shape2D: TypeAlias = tuple[int, int]
77

88
mat:np.matrix[_Shape2D,np.dtype[np.int64]]
99
ar_f8:npt.NDArray[np.float64]
10+
ar_ip:npt.NDArray[np.intp]
1011

1112
assert_type(mat*5,np.matrix[_Shape2D,Any])
1213
assert_type(5*mat,np.matrix[_Shape2D,Any])
@@ -50,8 +51,8 @@ assert_type(mat.any(out=ar_f8), npt.NDArray[np.float64])
5051
assert_type(mat.all(out=ar_f8),npt.NDArray[np.float64])
5152
assert_type(mat.max(out=ar_f8),npt.NDArray[np.float64])
5253
assert_type(mat.min(out=ar_f8),npt.NDArray[np.float64])
53-
assert_type(mat.argmax(out=ar_f8),npt.NDArray[np.float64])
54-
assert_type(mat.argmin(out=ar_f8),npt.NDArray[np.float64])
54+
assert_type(mat.argmax(out=ar_ip),npt.NDArray[np.intp])
55+
assert_type(mat.argmin(out=ar_ip),npt.NDArray[np.intp])
5556
assert_type(mat.ptp(out=ar_f8),npt.NDArray[np.float64])
5657

5758
assert_type(mat.T,np.matrix[_Shape2D,np.dtype[np.int64]])

‎numpy/typing/tests/data/reveal/ndarray_misc.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,12 @@ assert_type(AR_f8.any(out=B), SubClass)
5858
assert_type(f8.argmax(),np.intp)
5959
assert_type(AR_f8.argmax(),np.intp)
6060
assert_type(AR_f8.argmax(axis=0),Any)
61-
assert_type(AR_f8.argmax(out=B),SubClass)
61+
assert_type(AR_f8.argmax(out=AR_i8),npt.NDArray[np.intp])
6262

6363
assert_type(f8.argmin(),np.intp)
6464
assert_type(AR_f8.argmin(),np.intp)
6565
assert_type(AR_f8.argmin(axis=0),Any)
66-
assert_type(AR_f8.argmin(out=B),SubClass)
66+
assert_type(AR_f8.argmin(out=AR_i8),npt.NDArray[np.intp])
6767

6868
assert_type(f8.argsort(),npt.NDArray[Any])
6969
assert_type(AR_f8.argsort(),npt.NDArray[Any])

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp