Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit e6ef2b1

Browse files
authored
optimize shard writing (#3561)
* optimize shard writing

  Writing to sharded arrays was up to 10x slower for largish chunk sizes because the _ShardBuilder object has many calls to np.concatenate. This commit coalesces these into a single concatenate call, and improves write performance by a factor of 10 on the benchmarking script in #3560.

  Added a new core.Buffer.combine API

  Resolves #3560

  Signed-off-by: Noah D. Brenowitz <nbren12@gmail.com>

* remove redundant method

  Signed-off-by: Noah D. Brenowitz <nbren12@gmail.com>

* remove redundant np.asanyarray

  Signed-off-by: Noah D. Brenowitz <nbren12@gmail.com>

* clarify ShardBuilder API

  remove inheritance, hide the index attribute and remove some indirection

  Signed-off-by: Noah D. Brenowitz <nbren12@gmail.com>

* Remove shard builder objects

  just use dicts

  Signed-off-by: Noah D. Brenowitz <nbren12@gmail.com>

* fix missing chunk case

  Signed-off-by: Noah D. Brenowitz <nbren12@gmail.com>

* add release note

---------

Signed-off-by: Noah D. Brenowitz <nbren12@gmail.com>
1 parent b3e9aed, commit e6ef2b1

File tree

5 files changed

+87
-146
lines changed

5 files changed

+87
-146
lines changed

‎changes/3560.bugfix.md‎

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
Improve write performance to large shards by up to 10x.

‎src/zarr/codecs/sharding.py‎

Lines changed: 64 additions & 129 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
from __future__importannotations
22

33
fromcollections.abcimportIterable,Mapping,MutableMapping
4-
fromdataclassesimportdataclass,field,replace
4+
fromdataclassesimportdataclass,replace
55
fromenumimportEnum
66
fromfunctoolsimportlru_cache
77
fromoperatorimportitemgetter
@@ -54,15 +54,15 @@
5454
fromzarr.registryimportget_ndbuffer_class,get_pipeline_class
5555

5656
ifTYPE_CHECKING:
57-
fromcollections.abcimportAwaitable,Callable,Iterator
57+
fromcollections.abcimportIterator
5858
fromtypingimportSelf
5959

6060
fromzarr.core.commonimportJSON
6161
fromzarr.core.dtype.wrapperimportTBaseDType,TBaseScalar,ZDType
6262

6363
MAX_UINT_64=2**64-1
64-
ShardMapping=Mapping[tuple[int, ...],Buffer]
65-
ShardMutableMapping=MutableMapping[tuple[int, ...],Buffer]
64+
ShardMapping=Mapping[tuple[int, ...],Buffer|None]
65+
ShardMutableMapping=MutableMapping[tuple[int, ...],Buffer|None]
6666

6767

6868
classShardingCodecIndexLocation(Enum):
@@ -219,114 +219,6 @@ def __len__(self) -> int:
219219
def__iter__(self)->Iterator[tuple[int, ...]]:
220220
returnc_order_iter(self.index.offsets_and_lengths.shape[:-1])
221221

222-
defis_empty(self)->bool:
223-
returnself.index.is_all_empty()
224-
225-
226-
class_ShardBuilder(_ShardReader,ShardMutableMapping):
227-
buf:Buffer
228-
index:_ShardIndex
229-
230-
@classmethod
231-
defmerge_with_morton_order(
232-
cls,
233-
chunks_per_shard:tuple[int, ...],
234-
tombstones:set[tuple[int, ...]],
235-
*shard_dicts:ShardMapping,
236-
)->_ShardBuilder:
237-
obj=cls.create_empty(chunks_per_shard)
238-
forchunk_coordsinmorton_order_iter(chunks_per_shard):
239-
ifchunk_coordsintombstones:
240-
continue
241-
forshard_dictinshard_dicts:
242-
maybe_value=shard_dict.get(chunk_coords,None)
243-
ifmaybe_valueisnotNone:
244-
obj[chunk_coords]=maybe_value
245-
break
246-
returnobj
247-
248-
@classmethod
249-
defcreate_empty(
250-
cls,chunks_per_shard:tuple[int, ...],buffer_prototype:BufferPrototype|None=None
251-
)->_ShardBuilder:
252-
ifbuffer_prototypeisNone:
253-
buffer_prototype=default_buffer_prototype()
254-
obj=cls()
255-
obj.buf=buffer_prototype.buffer.create_zero_length()
256-
obj.index=_ShardIndex.create_empty(chunks_per_shard)
257-
returnobj
258-
259-
def__setitem__(self,chunk_coords:tuple[int, ...],value:Buffer)->None:
260-
chunk_start=len(self.buf)
261-
chunk_length=len(value)
262-
self.buf+=value
263-
self.index.set_chunk_slice(chunk_coords,slice(chunk_start,chunk_start+chunk_length))
264-
265-
def__delitem__(self,chunk_coords:tuple[int, ...])->None:
266-
raiseNotImplementedError
267-
268-
asyncdeffinalize(
269-
self,
270-
index_location:ShardingCodecIndexLocation,
271-
index_encoder:Callable[[_ShardIndex],Awaitable[Buffer]],
272-
)->Buffer:
273-
index_bytes=awaitindex_encoder(self.index)
274-
ifindex_location==ShardingCodecIndexLocation.start:
275-
empty_chunks_mask=self.index.offsets_and_lengths[...,0]==MAX_UINT_64
276-
self.index.offsets_and_lengths[~empty_chunks_mask,0]+=len(index_bytes)
277-
index_bytes=awaitindex_encoder(self.index)# encode again with corrected offsets
278-
out_buf=index_bytes+self.buf
279-
else:
280-
out_buf=self.buf+index_bytes
281-
returnout_buf
282-
283-
284-
@dataclass(frozen=True)
285-
class_MergingShardBuilder(ShardMutableMapping):
286-
old_dict:_ShardReader
287-
new_dict:_ShardBuilder
288-
tombstones:set[tuple[int, ...]]=field(default_factory=set)
289-
290-
def__getitem__(self,chunk_coords:tuple[int, ...])->Buffer:
291-
chunk_bytes_maybe=self.new_dict.get(chunk_coords)
292-
ifchunk_bytes_maybeisnotNone:
293-
returnchunk_bytes_maybe
294-
returnself.old_dict[chunk_coords]
295-
296-
def__setitem__(self,chunk_coords:tuple[int, ...],value:Buffer)->None:
297-
self.new_dict[chunk_coords]=value
298-
299-
def__delitem__(self,chunk_coords:tuple[int, ...])->None:
300-
self.tombstones.add(chunk_coords)
301-
302-
def__len__(self)->int:
303-
returnself.old_dict.__len__()
304-
305-
def__iter__(self)->Iterator[tuple[int, ...]]:
306-
returnself.old_dict.__iter__()
307-
308-
defis_empty(self)->bool:
309-
full_chunk_coords_map=self.old_dict.index.get_full_chunk_map()
310-
full_chunk_coords_map=np.logical_or(
311-
full_chunk_coords_map,self.new_dict.index.get_full_chunk_map()
312-
)
313-
fortombstoneinself.tombstones:
314-
full_chunk_coords_map[tombstone]=False
315-
returnbool(np.array_equiv(full_chunk_coords_map,False))
316-
317-
asyncdeffinalize(
318-
self,
319-
index_location:ShardingCodecIndexLocation,
320-
index_encoder:Callable[[_ShardIndex],Awaitable[Buffer]],
321-
)->Buffer:
322-
shard_builder=_ShardBuilder.merge_with_morton_order(
323-
self.new_dict.index.chunks_per_shard,
324-
self.tombstones,
325-
self.new_dict,
326-
self.old_dict,
327-
)
328-
returnawaitshard_builder.finalize(index_location,index_encoder)
329-
330222

331223
@dataclass(frozen=True)
332224
classShardingCodec(
@@ -573,7 +465,7 @@ async def _encode_single(
573465
)
574466
)
575467

576-
shard_builder=_ShardBuilder.create_empty(chunks_per_shard)
468+
shard_builder=dict.fromkeys(morton_order_iter(chunks_per_shard))
577469

578470
awaitself.codec_pipeline.write(
579471
[
@@ -589,7 +481,11 @@ async def _encode_single(
589481
shard_array,
590482
)
591483

592-
returnawaitshard_builder.finalize(self.index_location,self._encode_shard_index)
484+
returnawaitself._encode_shard_dict(
485+
shard_builder,
486+
chunks_per_shard=chunks_per_shard,
487+
buffer_prototype=default_buffer_prototype(),
488+
)
593489

594490
asyncdef_encode_partial_single(
595491
self,
@@ -603,15 +499,13 @@ async def _encode_partial_single(
603499
chunks_per_shard=self._get_chunks_per_shard(shard_spec)
604500
chunk_spec=self._get_chunk_spec(shard_spec)
605501

606-
shard_dict=_MergingShardBuilder(
607-
awaitself._load_full_shard_maybe(
608-
byte_getter=byte_setter,
609-
prototype=chunk_spec.prototype,
610-
chunks_per_shard=chunks_per_shard,
611-
)
612-
or_ShardReader.create_empty(chunks_per_shard),
613-
_ShardBuilder.create_empty(chunks_per_shard),
502+
shard_reader=awaitself._load_full_shard_maybe(
503+
byte_getter=byte_setter,
504+
prototype=chunk_spec.prototype,
505+
chunks_per_shard=chunks_per_shard,
614506
)
507+
shard_reader=shard_readeror_ShardReader.create_empty(chunks_per_shard)
508+
shard_dict= {k:shard_reader.get(k)forkinmorton_order_iter(chunks_per_shard)}
615509

616510
indexer=list(
617511
get_indexer(
@@ -632,16 +526,57 @@ async def _encode_partial_single(
632526
],
633527
shard_array,
634528
)
529+
buf=awaitself._encode_shard_dict(
530+
shard_dict,
531+
chunks_per_shard=chunks_per_shard,
532+
buffer_prototype=default_buffer_prototype(),
533+
)
635534

636-
ifshard_dict.is_empty():
535+
ifbufisNone:
637536
awaitbyte_setter.delete()
638537
else:
639-
awaitbyte_setter.set(
640-
awaitshard_dict.finalize(
641-
self.index_location,
642-
self._encode_shard_index,
643-
)
644-
)
538+
awaitbyte_setter.set(buf)
539+
540+
asyncdef_encode_shard_dict(
541+
self,
542+
map:ShardMapping,
543+
chunks_per_shard:tuple[int, ...],
544+
buffer_prototype:BufferPrototype,
545+
)->Buffer|None:
546+
index=_ShardIndex.create_empty(chunks_per_shard)
547+
548+
buffers= []
549+
550+
template=buffer_prototype.buffer.create_zero_length()
551+
chunk_start=0
552+
forchunk_coordsinmorton_order_iter(chunks_per_shard):
553+
value=map.get(chunk_coords)
554+
ifvalueisNone:
555+
continue
556+
557+
iflen(value)==0:
558+
continue
559+
560+
chunk_length=len(value)
561+
buffers.append(value)
562+
index.set_chunk_slice(chunk_coords,slice(chunk_start,chunk_start+chunk_length))
563+
chunk_start+=chunk_length
564+
565+
iflen(buffers)==0:
566+
returnNone
567+
568+
index_bytes=awaitself._encode_shard_index(index)
569+
ifself.index_location==ShardingCodecIndexLocation.start:
570+
empty_chunks_mask=index.offsets_and_lengths[...,0]==MAX_UINT_64
571+
index.offsets_and_lengths[~empty_chunks_mask,0]+=len(index_bytes)
572+
index_bytes=awaitself._encode_shard_index(
573+
index
574+
)# encode again with corrected offsets
575+
buffers.insert(0,index_bytes)
576+
else:
577+
buffers.append(index_bytes)
578+
579+
returntemplate.combine(buffers)
645580

646581
def_is_total_shard(
647582
self,all_chunk_coords:set[tuple[int, ...]],chunks_per_shard:tuple[int, ...]

‎src/zarr/core/buffer/core.py‎

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
importsys
44
fromabcimportABC,abstractmethod
5+
fromcollections.abcimportIterable
56
fromtypingimport (
67
TYPE_CHECKING,
78
Any,
@@ -294,9 +295,13 @@ def __len__(self) -> int:
294295
returnself._data.size
295296

296297
@abstractmethod
298+
defcombine(self,others:Iterable[Buffer])->Self:
299+
"""Concatenate many buffers"""
300+
...
301+
297302
def__add__(self,other:Buffer)->Self:
298303
"""Concatenate two buffers"""
299-
...
304+
returnself.combine([other])
300305

301306
def__eq__(self,other:object)->bool:
302307
# Another Buffer class can override this to choose a more efficient path

‎src/zarr/core/buffer/cpu.py‎

Lines changed: 7 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -107,14 +107,13 @@ def as_numpy_array(self) -> npt.NDArray[Any]:
107107
"""
108108
returnnp.asanyarray(self._data)
109109

110-
def__add__(self,other:core.Buffer)->Self:
111-
"""Concatenate two buffers"""
112-
113-
other_array=other.as_array_like()
114-
assertother_array.dtype==np.dtype("B")
115-
returnself.__class__(
116-
np.concatenate((np.asanyarray(self._data),np.asanyarray(other_array)))
117-
)
110+
defcombine(self,others:Iterable[core.Buffer])->Self:
111+
data= [np.asanyarray(self._data)]
112+
forbufinothers:
113+
other_array=buf.as_array_like()
114+
assertother_array.dtype==np.dtype("B")
115+
data.append(np.asanyarray(other_array))
116+
returnself.__class__(np.concatenate(data))
118117

119118

120119
classNDBuffer(core.NDBuffer):

‎src/zarr/core/buffer/gpu.py‎

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -107,14 +107,15 @@ def from_bytes(cls, bytes_like: BytesLike) -> Self:
107107
defas_numpy_array(self)->npt.NDArray[Any]:
108108
returncast("npt.NDArray[Any]",cp.asnumpy(self._data))
109109

110-
def__add__(self,other:core.Buffer)->Self:
111-
other_array=other.as_array_like()
112-
assertother_array.dtype==np.dtype("B")
113-
gpu_other=Buffer(other_array)
114-
gpu_other_array=gpu_other.as_array_like()
115-
returnself.__class__(
116-
cp.concatenate((cp.asanyarray(self._data),cp.asanyarray(gpu_other_array)))
117-
)
110+
defcombine(self,others:Iterable[core.Buffer])->Self:
111+
data= [cp.asanyarray(self._data)]
112+
forotherinothers:
113+
other_array=other.as_array_like()
114+
assertother_array.dtype==np.dtype("B")
115+
gpu_other=Buffer(other_array)
116+
gpu_other_array=gpu_other.as_array_like()
117+
data.append(cp.asanyarray(gpu_other_array))
118+
returnself.__class__(cp.concatenate(data))
118119

119120

120121
classNDBuffer(core.NDBuffer):

0 commit comments

Comments
 (0)

[8]ページ先頭

©2009-2025 Movatter.jp