11from __future__import annotations
22
33from collections .abc import Iterable ,Mapping ,MutableMapping
4- from dataclasses import dataclass ,field , replace
4+ from dataclasses import dataclass ,replace
55from enum import Enum
66from functools import lru_cache
77from operator import itemgetter
5454from zarr .registry import get_ndbuffer_class ,get_pipeline_class
5555
5656if TYPE_CHECKING :
57- from collections .abc import Awaitable , Callable , Iterator
57+ from collections .abc import Iterator
5858from typing import Self
5959
6060from zarr .core .common import JSON
6161from zarr .core .dtype .wrapper import TBaseDType ,TBaseScalar ,ZDType
6262
6363MAX_UINT_64 = 2 ** 64 - 1
64- ShardMapping = Mapping [tuple [int , ...],Buffer ]
65- ShardMutableMapping = MutableMapping [tuple [int , ...],Buffer ]
64+ ShardMapping = Mapping [tuple [int , ...],Buffer | None ]
65+ ShardMutableMapping = MutableMapping [tuple [int , ...],Buffer | None ]
6666
6767
6868class ShardingCodecIndexLocation (Enum ):
@@ -219,114 +219,6 @@ def __len__(self) -> int:
219219def __iter__ (self )-> Iterator [tuple [int , ...]]:
220220return c_order_iter (self .index .offsets_and_lengths .shape [:- 1 ])
221221
222- def is_empty (self )-> bool :
223- return self .index .is_all_empty ()
224-
225-
226- class _ShardBuilder (_ShardReader ,ShardMutableMapping ):
227- buf :Buffer
228- index :_ShardIndex
229-
230- @classmethod
231- def merge_with_morton_order (
232- cls ,
233- chunks_per_shard :tuple [int , ...],
234- tombstones :set [tuple [int , ...]],
235- * shard_dicts :ShardMapping ,
236- )-> _ShardBuilder :
237- obj = cls .create_empty (chunks_per_shard )
238- for chunk_coords in morton_order_iter (chunks_per_shard ):
239- if chunk_coords in tombstones :
240- continue
241- for shard_dict in shard_dicts :
242- maybe_value = shard_dict .get (chunk_coords ,None )
243- if maybe_value is not None :
244- obj [chunk_coords ]= maybe_value
245- break
246- return obj
247-
248- @classmethod
249- def create_empty (
250- cls ,chunks_per_shard :tuple [int , ...],buffer_prototype :BufferPrototype | None = None
251- )-> _ShardBuilder :
252- if buffer_prototype is None :
253- buffer_prototype = default_buffer_prototype ()
254- obj = cls ()
255- obj .buf = buffer_prototype .buffer .create_zero_length ()
256- obj .index = _ShardIndex .create_empty (chunks_per_shard )
257- return obj
258-
259- def __setitem__ (self ,chunk_coords :tuple [int , ...],value :Buffer )-> None :
260- chunk_start = len (self .buf )
261- chunk_length = len (value )
262- self .buf += value
263- self .index .set_chunk_slice (chunk_coords ,slice (chunk_start ,chunk_start + chunk_length ))
264-
265- def __delitem__ (self ,chunk_coords :tuple [int , ...])-> None :
266- raise NotImplementedError
267-
268- async def finalize (
269- self ,
270- index_location :ShardingCodecIndexLocation ,
271- index_encoder :Callable [[_ShardIndex ],Awaitable [Buffer ]],
272- )-> Buffer :
273- index_bytes = await index_encoder (self .index )
274- if index_location == ShardingCodecIndexLocation .start :
275- empty_chunks_mask = self .index .offsets_and_lengths [...,0 ]== MAX_UINT_64
276- self .index .offsets_and_lengths [~ empty_chunks_mask ,0 ]+= len (index_bytes )
277- index_bytes = await index_encoder (self .index )# encode again with corrected offsets
278- out_buf = index_bytes + self .buf
279- else :
280- out_buf = self .buf + index_bytes
281- return out_buf
282-
283-
284- @dataclass (frozen = True )
285- class _MergingShardBuilder (ShardMutableMapping ):
286- old_dict :_ShardReader
287- new_dict :_ShardBuilder
288- tombstones :set [tuple [int , ...]]= field (default_factory = set )
289-
290- def __getitem__ (self ,chunk_coords :tuple [int , ...])-> Buffer :
291- chunk_bytes_maybe = self .new_dict .get (chunk_coords )
292- if chunk_bytes_maybe is not None :
293- return chunk_bytes_maybe
294- return self .old_dict [chunk_coords ]
295-
296- def __setitem__ (self ,chunk_coords :tuple [int , ...],value :Buffer )-> None :
297- self .new_dict [chunk_coords ]= value
298-
299- def __delitem__ (self ,chunk_coords :tuple [int , ...])-> None :
300- self .tombstones .add (chunk_coords )
301-
302- def __len__ (self )-> int :
303- return self .old_dict .__len__ ()
304-
305- def __iter__ (self )-> Iterator [tuple [int , ...]]:
306- return self .old_dict .__iter__ ()
307-
308- def is_empty (self )-> bool :
309- full_chunk_coords_map = self .old_dict .index .get_full_chunk_map ()
310- full_chunk_coords_map = np .logical_or (
311- full_chunk_coords_map ,self .new_dict .index .get_full_chunk_map ()
312- )
313- for tombstone in self .tombstones :
314- full_chunk_coords_map [tombstone ]= False
315- return bool (np .array_equiv (full_chunk_coords_map ,False ))
316-
317- async def finalize (
318- self ,
319- index_location :ShardingCodecIndexLocation ,
320- index_encoder :Callable [[_ShardIndex ],Awaitable [Buffer ]],
321- )-> Buffer :
322- shard_builder = _ShardBuilder .merge_with_morton_order (
323- self .new_dict .index .chunks_per_shard ,
324- self .tombstones ,
325- self .new_dict ,
326- self .old_dict ,
327- )
328- return await shard_builder .finalize (index_location ,index_encoder )
329-
330222
331223@dataclass (frozen = True )
332224class ShardingCodec (
@@ -573,7 +465,7 @@ async def _encode_single(
573465 )
574466 )
575467
576- shard_builder = _ShardBuilder . create_empty ( chunks_per_shard )
468+ shard_builder = dict . fromkeys ( morton_order_iter ( chunks_per_shard ) )
577469
578470await self .codec_pipeline .write (
579471 [
@@ -589,7 +481,11 @@ async def _encode_single(
589481shard_array ,
590482 )
591483
592- return await shard_builder .finalize (self .index_location ,self ._encode_shard_index )
484+ return await self ._encode_shard_dict (
485+ shard_builder ,
486+ chunks_per_shard = chunks_per_shard ,
487+ buffer_prototype = default_buffer_prototype (),
488+ )
593489
594490async def _encode_partial_single (
595491self ,
@@ -603,15 +499,13 @@ async def _encode_partial_single(
603499chunks_per_shard = self ._get_chunks_per_shard (shard_spec )
604500chunk_spec = self ._get_chunk_spec (shard_spec )
605501
606- shard_dict = _MergingShardBuilder (
607- await self ._load_full_shard_maybe (
608- byte_getter = byte_setter ,
609- prototype = chunk_spec .prototype ,
610- chunks_per_shard = chunks_per_shard ,
611- )
612- or _ShardReader .create_empty (chunks_per_shard ),
613- _ShardBuilder .create_empty (chunks_per_shard ),
502+ shard_reader = await self ._load_full_shard_maybe (
503+ byte_getter = byte_setter ,
504+ prototype = chunk_spec .prototype ,
505+ chunks_per_shard = chunks_per_shard ,
614506 )
507+ shard_reader = shard_reader or _ShardReader .create_empty (chunks_per_shard )
508+ shard_dict = {k :shard_reader .get (k )for k in morton_order_iter (chunks_per_shard )}
615509
616510indexer = list (
617511get_indexer (
@@ -632,16 +526,57 @@ async def _encode_partial_single(
632526 ],
633527shard_array ,
634528 )
529+ buf = await self ._encode_shard_dict (
530+ shard_dict ,
531+ chunks_per_shard = chunks_per_shard ,
532+ buffer_prototype = default_buffer_prototype (),
533+ )
635534
636- if shard_dict . is_empty () :
535+ if buf is None :
637536await byte_setter .delete ()
638537else :
639- await byte_setter .set (
640- await shard_dict .finalize (
641- self .index_location ,
642- self ._encode_shard_index ,
643- )
644- )
538+ await byte_setter .set (buf )
539+
540+ async def _encode_shard_dict (
541+ self ,
542+ map :ShardMapping ,
543+ chunks_per_shard :tuple [int , ...],
544+ buffer_prototype :BufferPrototype ,
545+ )-> Buffer | None :
546+ index = _ShardIndex .create_empty (chunks_per_shard )
547+
548+ buffers = []
549+
550+ template = buffer_prototype .buffer .create_zero_length ()
551+ chunk_start = 0
552+ for chunk_coords in morton_order_iter (chunks_per_shard ):
553+ value = map .get (chunk_coords )
554+ if value is None :
555+ continue
556+
557+ if len (value )== 0 :
558+ continue
559+
560+ chunk_length = len (value )
561+ buffers .append (value )
562+ index .set_chunk_slice (chunk_coords ,slice (chunk_start ,chunk_start + chunk_length ))
563+ chunk_start += chunk_length
564+
565+ if len (buffers )== 0 :
566+ return None
567+
568+ index_bytes = await self ._encode_shard_index (index )
569+ if self .index_location == ShardingCodecIndexLocation .start :
570+ empty_chunks_mask = index .offsets_and_lengths [...,0 ]== MAX_UINT_64
571+ index .offsets_and_lengths [~ empty_chunks_mask ,0 ]+= len (index_bytes )
572+ index_bytes = await self ._encode_shard_index (
573+ index
574+ )# encode again with corrected offsets
575+ buffers .insert (0 ,index_bytes )
576+ else :
577+ buffers .append (index_bytes )
578+
579+ return template .combine (buffers )
645580
646581def _is_total_shard (
647582self ,all_chunk_coords :set [tuple [int , ...]],chunks_per_shard :tuple [int , ...]