Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit325e4f6

Browse files
Backport PR#3083: Add async oindex and vindex methods to AsyncArray (#3311)
Co-authored-by: Tom Nicholas <tom@earthmover.io>
1 parentb792530 commit325e4f6

File tree

8 files changed

+270
-17
lines changed

8 files changed

+270
-17
lines changed

‎changes/3083.feature.rst‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added support for async vectorized and orthogonal indexing.

‎src/zarr/core/array.py‎

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
ZarrFormat,
6262
_default_zarr_format,
6363
_warn_order_kwarg,
64+
ceildiv,
6465
concurrent_map,
6566
parse_shapelike,
6667
product,
@@ -76,6 +77,8 @@
7677
)
7778
fromzarr.core.dtype.commonimportHasEndianness,HasItemSize,HasObjectCodec
7879
fromzarr.core.indexingimport (
80+
AsyncOIndex,
81+
AsyncVIndex,
7982
BasicIndexer,
8083
BasicSelection,
8184
BlockIndex,
@@ -92,7 +95,6 @@
9295
Selection,
9396
VIndex,
9497
_iter_grid,
95-
ceildiv,
9698
check_fields,
9799
check_no_multi_fields,
98100
is_pure_fancy_indexing,
@@ -1425,6 +1427,56 @@ async def getitem(
14251427
)
14261428
returnawaitself._get_selection(indexer,prototype=prototype)
14271429

1430+
asyncdefget_orthogonal_selection(
1431+
self,
1432+
selection:OrthogonalSelection,
1433+
*,
1434+
out:NDBuffer|None=None,
1435+
fields:Fields|None=None,
1436+
prototype:BufferPrototype|None=None,
1437+
)->NDArrayLikeOrScalar:
1438+
ifprototypeisNone:
1439+
prototype=default_buffer_prototype()
1440+
indexer=OrthogonalIndexer(selection,self.shape,self.metadata.chunk_grid)
1441+
returnawaitself._get_selection(
1442+
indexer=indexer,out=out,fields=fields,prototype=prototype
1443+
)
1444+
1445+
asyncdefget_mask_selection(
1446+
self,
1447+
mask:MaskSelection,
1448+
*,
1449+
out:NDBuffer|None=None,
1450+
fields:Fields|None=None,
1451+
prototype:BufferPrototype|None=None,
1452+
)->NDArrayLikeOrScalar:
1453+
ifprototypeisNone:
1454+
prototype=default_buffer_prototype()
1455+
indexer=MaskIndexer(mask,self.shape,self.metadata.chunk_grid)
1456+
returnawaitself._get_selection(
1457+
indexer=indexer,out=out,fields=fields,prototype=prototype
1458+
)
1459+
1460+
asyncdefget_coordinate_selection(
1461+
self,
1462+
selection:CoordinateSelection,
1463+
*,
1464+
out:NDBuffer|None=None,
1465+
fields:Fields|None=None,
1466+
prototype:BufferPrototype|None=None,
1467+
)->NDArrayLikeOrScalar:
1468+
ifprototypeisNone:
1469+
prototype=default_buffer_prototype()
1470+
indexer=CoordinateIndexer(selection,self.shape,self.metadata.chunk_grid)
1471+
out_array=awaitself._get_selection(
1472+
indexer=indexer,out=out,fields=fields,prototype=prototype
1473+
)
1474+
1475+
ifhasattr(out_array,"shape"):
1476+
# restore shape
1477+
out_array=np.array(out_array).reshape(indexer.sel_shape)
1478+
returnout_array
1479+
14281480
asyncdef_save_metadata(self,metadata:ArrayMetadata,ensure_parents:bool=False)->None:
14291481
"""
14301482
Asynchronously save the array metadata.
@@ -1556,6 +1608,19 @@ async def setitem(
15561608
)
15571609
returnawaitself._set_selection(indexer,value,prototype=prototype)
15581610

1611+
@property
1612+
defoindex(self)->AsyncOIndex[T_ArrayMetadata]:
1613+
"""Shortcut for orthogonal (outer) indexing, see :func:`get_orthogonal_selection` and
1614+
:func:`set_orthogonal_selection` for documentation and examples."""
1615+
returnAsyncOIndex(self)
1616+
1617+
@property
1618+
defvindex(self)->AsyncVIndex[T_ArrayMetadata]:
1619+
"""Shortcut for vectorized (inner) indexing, see :func:`get_coordinate_selection`,
1620+
:func:`set_coordinate_selection`, :func:`get_mask_selection` and
1621+
:func:`set_mask_selection` for documentation and examples."""
1622+
returnAsyncVIndex(self)
1623+
15591624
asyncdefresize(self,new_shape:ShapeLike,delete_outside_chunks:bool=True)->None:
15601625
"""
15611626
Asynchronously resize the array to a new shape.

‎src/zarr/core/chunk_grids.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
ChunkCoords,
1919
ChunkCoordsLike,
2020
ShapeLike,
21+
ceildiv,
2122
parse_named_configuration,
2223
parse_shapelike,
2324
)
24-
fromzarr.core.indexingimportceildiv
2525

2626
ifTYPE_CHECKING:
2727
fromcollections.abcimportIterator

‎src/zarr/core/common.py‎

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
importasyncio
44
importfunctools
5+
importmath
56
importoperator
67
importwarnings
78
fromcollections.abcimportIterable,Mapping,Sequence
@@ -69,6 +70,12 @@ def product(tup: ChunkCoords) -> int:
6970
returnfunctools.reduce(operator.mul,tup,1)
7071

7172

73+
defceildiv(a:float,b:float)->int:
74+
ifa==0:
75+
return0
76+
returnmath.ceil(a/b)
77+
78+
7279
T=TypeVar("T",bound=tuple[Any, ...])
7380
V=TypeVar("V")
7481

‎src/zarr/core/indexing.py‎

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
fromtypingimport (
1313
TYPE_CHECKING,
1414
Any,
15+
Generic,
1516
Literal,
1617
NamedTuple,
1718
Protocol,
@@ -25,14 +26,16 @@
2526
importnumpyasnp
2627
importnumpy.typingasnpt
2728

28-
fromzarr.core.commonimportproduct
29+
fromzarr.core.commonimportceildiv,product
30+
fromzarr.core.metadataimportT_ArrayMetadata
2931

3032
ifTYPE_CHECKING:
31-
fromzarr.core.arrayimportArray
33+
fromzarr.core.arrayimportArray,AsyncArray
3234
fromzarr.core.bufferimportNDArrayLikeOrScalar
3335
fromzarr.core.chunk_gridsimportChunkGrid
3436
fromzarr.core.commonimportChunkCoords
3537

38+
3639
IntSequence=list[int]|npt.NDArray[np.intp]
3740
ArrayOfIntOrBool=npt.NDArray[np.intp]|npt.NDArray[np.bool_]
3841
BasicSelector=int|slice|EllipsisType
@@ -93,12 +96,6 @@ class Indexer(Protocol):
9396
def__iter__(self)->Iterator[ChunkProjection]: ...
9497

9598

96-
defceildiv(a:float,b:float)->int:
97-
ifa==0:
98-
return0
99-
returnmath.ceil(a/b)
100-
101-
10299
_ArrayIndexingOrder:TypeAlias=Literal["lexicographic"]
103100

104101

@@ -960,6 +957,25 @@ def __setitem__(self, selection: OrthogonalSelection, value: npt.ArrayLike) -> N
960957
)
961958

962959

960+
@dataclass(frozen=True)
961+
classAsyncOIndex(Generic[T_ArrayMetadata]):
962+
array:AsyncArray[T_ArrayMetadata]
963+
964+
asyncdefgetitem(self,selection:OrthogonalSelection|Array)->NDArrayLikeOrScalar:
965+
fromzarr.core.arrayimportArray
966+
967+
# if input is a Zarr array, we materialize it now.
968+
ifisinstance(selection,Array):
969+
selection=_zarr_array_to_int_or_bool_array(selection)
970+
971+
fields,new_selection=pop_fields(selection)
972+
new_selection=ensure_tuple(new_selection)
973+
new_selection=replace_lists(new_selection)
974+
returnawaitself.array.get_orthogonal_selection(
975+
cast(OrthogonalSelection,new_selection),fields=fields
976+
)
977+
978+
963979
@dataclass(frozen=True)
964980
classBlockIndexer(Indexer):
965981
dim_indexers:list[SliceDimIndexer]
@@ -1268,6 +1284,32 @@ def __setitem__(
12681284
raiseVindexInvalidSelectionError(new_selection)
12691285

12701286

1287+
@dataclass(frozen=True)
1288+
classAsyncVIndex(Generic[T_ArrayMetadata]):
1289+
array:AsyncArray[T_ArrayMetadata]
1290+
1291+
# TODO: develop Array generic and move zarr.Array[np.intp] | zarr.Array[np.bool_] to ArrayOfIntOrBool
1292+
asyncdefgetitem(
1293+
self,selection:CoordinateSelection|MaskSelection|Array
1294+
)->NDArrayLikeOrScalar:
1295+
# TODO deduplicate these internals with the sync version of getitem
1296+
# TODO requires solving this circular sync issue: https://github.com/zarr-developers/zarr-python/pull/3083#discussion_r2230737448
1297+
fromzarr.core.arrayimportArray
1298+
1299+
# if input is a Zarr array, we materialize it now.
1300+
ifisinstance(selection,Array):
1301+
selection=_zarr_array_to_int_or_bool_array(selection)
1302+
fields,new_selection=pop_fields(selection)
1303+
new_selection=ensure_tuple(new_selection)
1304+
new_selection=replace_lists(new_selection)
1305+
ifis_coordinate_selection(new_selection,self.array.shape):
1306+
returnawaitself.array.get_coordinate_selection(new_selection,fields=fields)
1307+
elifis_mask_selection(new_selection,self.array.shape):
1308+
returnawaitself.array.get_mask_selection(new_selection,fields=fields)
1309+
else:
1310+
raiseVindexInvalidSelectionError(new_selection)
1311+
1312+
12711313
defcheck_fields(fields:Fields|None,dtype:np.dtype[Any])->np.dtype[Any]:
12721314
# early out
12731315
iffieldsisNone:

‎tests/test_array.py‎

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
fromzarr.core.bufferimportNDArrayLike,NDArrayLikeOrScalar,default_buffer_prototype
4242
fromzarr.core.chunk_gridsimport_auto_partition
4343
fromzarr.core.chunk_key_encodingsimportChunkKeyEncodingParams
44-
fromzarr.core.commonimportJSON,ZarrFormat
44+
fromzarr.core.commonimportJSON,ZarrFormat,ceildiv
4545
fromzarr.core.dtypeimport (
4646
DateTime64,
4747
Float32,
@@ -59,7 +59,7 @@
5959
fromzarr.core.dtype.npy.commonimportNUMPY_ENDIANNESS_STR,endianness_from_numpy_str
6060
fromzarr.core.dtype.npy.stringimportUTF8Base
6161
fromzarr.core.groupimportAsyncGroup
62-
fromzarr.core.indexingimportBasicIndexer,ceildiv
62+
fromzarr.core.indexingimportBasicIndexer
6363
fromzarr.core.metadata.v2importArrayV2Metadata
6464
fromzarr.core.metadata.v3importArrayV3Metadata
6565
fromzarr.core.syncimportsync

‎tests/test_indexing.py‎

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1994,3 +1994,110 @@ def test_iter_chunk_regions():
19941994
assert_array_equal(a[region],np.ones_like(a[region]))
19951995
a[region]=0
19961996
assert_array_equal(a[region],np.zeros_like(a[region]))
1997+
1998+
1999+
classTestAsync:
2000+
@pytest.mark.parametrize(
2001+
("indexer","expected"),
2002+
[
2003+
# int
2004+
((0,),np.array([1,2])),
2005+
((1,),np.array([3,4])),
2006+
((0,1),np.array(2)),
2007+
# slice
2008+
((slice(None),),np.array([[1,2], [3,4]])),
2009+
((slice(0,1),),np.array([[1,2]])),
2010+
((slice(1,2),),np.array([[3,4]])),
2011+
((slice(0,2),),np.array([[1,2], [3,4]])),
2012+
((slice(0,0),),np.empty(shape=(0,2),dtype="i8")),
2013+
# ellipsis
2014+
((...,),np.array([[1,2], [3,4]])),
2015+
((0, ...),np.array([1,2])),
2016+
((...,0),np.array([1,3])),
2017+
((0,1, ...),np.array(2)),
2018+
# combined
2019+
((0,slice(None)),np.array([1,2])),
2020+
((slice(None),0),np.array([1,3])),
2021+
((slice(None),slice(None)),np.array([[1,2], [3,4]])),
2022+
# array of ints
2023+
(([0]),np.array([[1,2]])),
2024+
(([1]),np.array([[3,4]])),
2025+
(([0], [1]),np.array(2)),
2026+
(([0,1], [0]),np.array([[1], [3]])),
2027+
(([0,1], [0,1]),np.array([[1,2], [3,4]])),
2028+
# boolean array
2029+
(np.array([True,True]),np.array([[1,2], [3,4]])),
2030+
(np.array([True,False]),np.array([[1,2]])),
2031+
(np.array([False,True]),np.array([[3,4]])),
2032+
(np.array([False,False]),np.empty(shape=(0,2),dtype="i8")),
2033+
],
2034+
)
2035+
@pytest.mark.asyncio
2036+
asyncdeftest_async_oindex(self,store,indexer,expected):
2037+
z=zarr.create_array(store=store,shape=(2,2),chunks=(1,1),zarr_format=3,dtype="i8")
2038+
z[...]=np.array([[1,2], [3,4]])
2039+
async_zarr=z._async_array
2040+
2041+
result=awaitasync_zarr.oindex.getitem(indexer)
2042+
assert_array_equal(result,expected)
2043+
2044+
@pytest.mark.asyncio
2045+
asyncdeftest_async_oindex_with_zarr_array(self,store):
2046+
z1=zarr.create_array(store=store,shape=(2,2),chunks=(1,1),zarr_format=3,dtype="i8")
2047+
z1[...]=np.array([[1,2], [3,4]])
2048+
async_zarr=z1._async_array
2049+
2050+
# create boolean zarr array to index with
2051+
z2=zarr.create_array(
2052+
store=store,name="z2",shape=(2,),chunks=(1,),zarr_format=3,dtype="?"
2053+
)
2054+
z2[...]=np.array([True,False])
2055+
2056+
result=awaitasync_zarr.oindex.getitem(z2)
2057+
expected=np.array([[1,2]])
2058+
assert_array_equal(result,expected)
2059+
2060+
@pytest.mark.parametrize(
2061+
("indexer","expected"),
2062+
[
2063+
(([0], [0]),np.array(1)),
2064+
(([0,1], [0,1]),np.array([1,4])),
2065+
(np.array([[False,True], [False,True]]),np.array([2,4])),
2066+
],
2067+
)
2068+
@pytest.mark.asyncio
2069+
asyncdeftest_async_vindex(self,store,indexer,expected):
2070+
z=zarr.create_array(store=store,shape=(2,2),chunks=(1,1),zarr_format=3,dtype="i8")
2071+
z[...]=np.array([[1,2], [3,4]])
2072+
async_zarr=z._async_array
2073+
2074+
result=awaitasync_zarr.vindex.getitem(indexer)
2075+
assert_array_equal(result,expected)
2076+
2077+
@pytest.mark.asyncio
2078+
asyncdeftest_async_vindex_with_zarr_array(self,store):
2079+
z1=zarr.create_array(store=store,shape=(2,2),chunks=(1,1),zarr_format=3,dtype="i8")
2080+
z1[...]=np.array([[1,2], [3,4]])
2081+
async_zarr=z1._async_array
2082+
2083+
# create boolean zarr array to index with
2084+
z2=zarr.create_array(
2085+
store=store,name="z2",shape=(2,2),chunks=(1,1),zarr_format=3,dtype="?"
2086+
)
2087+
z2[...]=np.array([[False,True], [False,True]])
2088+
2089+
result=awaitasync_zarr.vindex.getitem(z2)
2090+
expected=np.array([2,4])
2091+
assert_array_equal(result,expected)
2092+
2093+
@pytest.mark.asyncio
2094+
asyncdeftest_async_invalid_indexer(self,store):
2095+
z=zarr.create_array(store=store,shape=(2,2),chunks=(1,1),zarr_format=3,dtype="i8")
2096+
z[...]=np.array([[1,2], [3,4]])
2097+
async_zarr=z._async_array
2098+
2099+
withpytest.raises(IndexError):
2100+
awaitasync_zarr.vindex.getitem("invalid_indexer")
2101+
2102+
withpytest.raises(IndexError):
2103+
awaitasync_zarr.oindex.getitem("invalid_indexer")

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp