# Source code for pipefunc.map._storage_array._zarr

"""Provides `zarr` integration for `pipefunc`."""

from __future__ import annotations

import itertools
import multiprocessing.managers
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, cast

import cloudpickle
import numpy as np
import zarr  # noqa: TC002
from numcodecs.abc import Codec
from numcodecs.compat import ensure_contiguous_ndarray
from numcodecs.registry import register_codec
from zarr.abc.store import Store
from zarr.api import synchronous as zs
from zarr.core.buffer.core import default_buffer_prototype
from zarr.core.sync import sync
from zarr.dtype import VariableLengthBytes
from zarr.errors import UnstableSpecificationWarning
from zarr.storage import LocalStore, MemoryStore

from pipefunc._utils import prod

from ._base import StorageBase, normalize_key, register_storage, select_by_mask

if TYPE_CHECKING:
    from pipefunc.map._types import ShapeTuple


# Fill value handed to the Zarr data array (``fill_value=_FILL_VALUE``): empty bytes.
_FILL_VALUE = b""


def _open_or_create_array(
    store: Store,
    *,
    name: str,
    shape: tuple[int, ...],
    chunks: tuple[int, ...],
    dtype: Any,
    fill_value: Any,
) -> zarr.Array:
    """Return the array stored at ``name`` in ``store``, creating it if missing.

    The heavy lifting is delegated to
    ``zarr.api.synchronous.open_array(..., mode='a', ...)`` so that the
    open-or-create logic is not reimplemented here.

    Raises
    ------
    ValueError
        If the array returned by Zarr does not have the requested ``shape``.

    """
    with warnings.catch_warnings():
        # VariableLengthBytes triggers Zarr's UnstableSpecificationWarning; silence it
        # only for the duration of this call.
        warnings.simplefilter("ignore", UnstableSpecificationWarning)
        opened = zs.open_array(
            store=store,
            path=name,
            mode="a",
            shape=shape,
            chunks=chunks,
            dtype=dtype,
            fill_value=fill_value,
        )
        # An array that already existed keeps its own shape; accept it only on a match.
        if opened.shape == shape:
            return opened
        msg = f"Existing array '{name}' has unexpected shape {opened.shape}, expected {shape}."
        raise ValueError(msg)


def _encode_scalar(codec: CloudPickleCodec, value: Any) -> bytes:
    """Encode a single element with ``codec``.

    A 0-D ``numpy.ndarray`` is unwrapped via ``.item()`` first, so the contained
    object is what gets passed to ``codec.encode`` rather than the array wrapper.
    """
    is_zero_dim = isinstance(value, np.ndarray) and value.shape == ()
    payload = value.item() if is_zero_dim else value
    return codec.encode(payload)


def _encode_array(codec: CloudPickleCodec, value: np.ndarray) -> np.ndarray:
    """Encode every element of ``value`` individually.

    Returns an object-dtype array with the same shape as ``value`` whose entries
    are the results of ``_encode_scalar`` for the corresponding elements.
    """
    shape = value.shape
    result = np.empty(shape, dtype=object)
    for position in np.ndindex(shape):
        element = value[position]
        result[position] = _encode_scalar(codec, element)
    return result


def _decode_scalar(codec: CloudPickleCodec, value: Any) -> Any:
    """Decode one stored element with ``codec``.

    With ``VariableLengthBytes``, slicing yields an object-dtype ``ndarray`` of
    Python ``bytes`` while scalar indexing yields a 0-D ``numpy.ndarray``; both
    forms are accepted here. A 0-D array is unwrapped with ``.item()`` and the
    resulting bytes-like object is viewed as ``uint8`` before being handed to
    ``codec.decode``.
    """
    raw = value.item() if isinstance(value, np.ndarray) and value.shape == () else value
    assert isinstance(raw, (bytes, bytearray, np.bytes_))
    as_uint8 = np.frombuffer(raw, dtype="uint8")
    return codec.decode(as_uint8)


def _decode_array(codec: CloudPickleCodec, value: np.ndarray) -> np.ndarray:
    """Decode every element of ``value`` individually.

    Returns an object-dtype array with the same shape as ``value`` whose entries
    are the results of ``_decode_scalar`` for the corresponding elements.
    """
    shape = value.shape
    result = np.empty(shape, dtype=object)
    for position in np.ndindex(shape):
        element = value[position]
        result[position] = _decode_scalar(codec, element)
    return result


def _decode_with_mask(
    codec: CloudPickleCodec,
    data: Any,
    mask: Any,
) -> tuple[np.ndarray, np.ndarray]:
    """Decode ``data`` element-wise, skipping positions flagged in ``mask``.

    ``data`` is coerced to an object array and ``mask`` to a boolean array; when
    their shapes differ the mask is broadcast to the data's shape. Positions
    where the mask is true are set to ``None`` without touching the stored
    bytes; all other positions are decoded with ``_decode_scalar``.

    Returns
    -------
    A ``(decoded, mask)`` pair of arrays that share the shape of ``data``.

    """
    payloads = np.asarray(data, dtype=object)
    flags = np.asarray(mask, dtype=bool)
    if flags.shape != payloads.shape:
        flags = np.broadcast_to(flags, payloads.shape)

    result = np.empty(payloads.shape, dtype=object)
    # ``np.ndindex(())`` yields a single empty index, so 0-D input needs no special case.
    for position in np.ndindex(payloads.shape):
        if flags[position]:
            result[position] = None
            continue
        result[position] = _decode_scalar(codec, payloads[position])

    return result, flags


def _copy_store(source: Store, destination: Store) -> None:
    """Copy every key of ``source`` into ``destination`` via the async Store API.

    Keys already present in the destination are overwritten ("replace"
    semantics). Keys for which ``source.get`` returns ``None`` are skipped.
    """

    async def _transfer() -> None:
        # Only the public Store API (list/get/set) is used, no private attributes.
        buffer_prototype = default_buffer_prototype()
        async for entry in source.list():
            payload = await source.get(entry, prototype=buffer_prototype)
            if payload is None:
                continue
            await destination.set(entry, payload)

    sync(_transfer())


class ZarrFileArray(StorageBase):
    """Array interface to a Zarr store.

    Two Zarr arrays are kept in the store: ``"array"`` holds the serialized
    elements (one ``VariableLengthBytes`` entry per element) and ``"mask"``
    holds one boolean per external index that is ``True`` while nothing has
    been dumped there.

    Only exists if the `zarr` package is installed!
    """

    storage_id = "zarr_file_array"
    requires_serialization = True

    def __init__(
        self,
        folder: str | Path | None,
        shape: ShapeTuple,
        internal_shape: ShapeTuple | None = None,
        shape_mask: tuple[bool, ...] | None = None,
        *,
        store: Store | str | Path | None = None,
        object_codec: Any = None,
    ) -> None:
        """Initialize the ZarrFileArray.

        Parameters
        ----------
        folder
            Folder associated with the array; used as the root of a
            ``LocalStore`` when no ``store`` is given.
        shape
            External shape of the array.
        internal_shape
            Shape of the block stored at each external index, if any.
        shape_mask
            Mask passed to ``select_by_mask`` to interleave external and
            internal axes. Defaults to all ``True`` (external axes only).
        store
            A ``Store`` instance to use directly, or a path used as the root
            of a ``LocalStore``.
        object_codec
            Codec used to serialize elements; defaults to ``CloudPickleCodec()``.

        Raises
        ------
        ValueError
            If a non-empty ``internal_shape`` is given without ``shape_mask``,
            if ``shape_mask`` has the wrong length, or if neither ``store`` nor
            ``folder`` is provided.

        """
        if internal_shape and shape_mask is None:
            msg = "shape_mask must be provided if internal_shape is provided"
            raise ValueError(msg)
        # An empty ``internal_shape`` without a mask is valid (the default mask is used
        # below), so only measure the mask when one was actually supplied.
        if (
            internal_shape is not None
            and shape_mask is not None
            and len(shape_mask) != len(shape) + len(internal_shape)
        ):
            msg = "shape_mask must have the same length as shape + internal_shape"
            raise ValueError(msg)
        self.folder = Path(folder) if folder is not None else None
        if isinstance(store, Store):
            self.store = store
        else:
            if store is None and self.folder is None:
                msg = "Either a `store` or `folder` must be provided"
                raise ValueError(msg)
            root = Path(store) if store is not None else self.folder
            assert root is not None
            self.store = LocalStore(str(root))
        self.object_codec = (
            object_codec if object_codec is not None else CloudPickleCodec()
        )  # Used for encoding and in __repr__
        self.shape = tuple(shape)
        self.shape_mask = tuple(shape_mask) if shape_mask is not None else (True,) * len(shape)
        self.internal_shape = tuple(internal_shape) if internal_shape is not None else ()
        # One chunk per external index; internal axes are kept whole inside the chunk.
        raw_chunks = select_by_mask(self.shape_mask, (1,) * len(self.shape), self.internal_shape)
        chunks: tuple[int, ...] = tuple(int(value) for value in raw_chunks)
        # NOTE(review): ``full_shape`` / ``resolved_shape`` are not defined in this class;
        # presumably they are provided by ``StorageBase`` - confirm.
        self.array = _open_or_create_array(
            self.store,
            name="array",
            shape=self.full_shape,
            chunks=chunks,
            dtype=VariableLengthBytes(),
            fill_value=_FILL_VALUE,
        )
        self._mask = _open_or_create_array(
            self.store,
            name="mask",
            shape=self.resolved_shape,
            # Zarr v3 requires an empty tuple for 0-D arrays; using chunks=(1,)
            # raises ValueError. Use chunks=() for scalars.
            chunks=() if len(self.resolved_shape) == 0 else (1,) * len(self.resolved_shape),
            dtype=bool,
            fill_value=True,
        )

    def __repr__(self) -> str:
        """Return a constructor-like representation of the array."""
        folder = f"'{self.folder}'" if self.folder is not None else self.folder
        return (
            f"ZarrFileArray(folder={folder}, "
            f"shape={self.shape}, "
            f"internal_shape={self.internal_shape}, "
            f"shape_mask={self.shape_mask}, "
            f"store={self.store}, "
            f"object_codec={self.object_codec})"
        )

    @property
    def size(self) -> int:
        """Return number of elements in the array."""
        return prod(self.resolved_shape)

    @property
    def rank(self) -> int:
        """Return the rank of the array."""
        return len(self.resolved_shape)

    def get_from_index(self, index: int) -> Any:
        """Return the data associated with the given linear index.

        Returns ``np.ma.masked`` when the mask entry for that index is set.
        """
        np_index = np.unravel_index(index, self.resolved_shape)
        full_index = select_by_mask(
            self.shape_mask,
            np_index,
            (slice(None),) * len(self.internal_shape),
        )
        if self._mask[np_index]:
            return np.ma.masked
        data = self.array[full_index]
        # In Zarr v3 array indexing returns a NumPy ndarray (scalar
        # indices yield a 0-D ndarray); unwrap to a Python scalar for parity
        decoded = np.asarray(_decode_array(self.object_codec, data), dtype=object)
        return decoded.item() if decoded.shape == () else decoded

    def has_index(self, index: int) -> bool:
        """Return whether the given linear index exists."""
        np_index = np.unravel_index(index, self.resolved_shape)
        return not self._mask[np_index]

    def __getitem__(self, key: tuple[int | slice, ...]) -> Any:
        """Return the data associated with the given key.

        A 0-D result is returned as ``np.ma.masked`` (when masked) or as the
        decoded object itself (lists are converted to object arrays); any other
        result is returned as an object-dtype masked array.
        """
        data = self.array[key]
        if self.internal_shape:
            # The mask only has external axes, so drop the key entries of internal axes.
            mask_key = tuple(x for x, m in zip(key, self.shape_mask) if m)
            if any(isinstance(k, slice) for k in key):
                mask = self._mask[mask_key]
                assert len(key) == len(self.shape_mask)
                # For each sliced axis: keep it if external, insert a new axis if internal.
                slices = tuple(
                    slice(None) if m else None
                    for k, m in zip(key, self.shape_mask)
                    if isinstance(k, slice)
                )
                # Repeat the mask along the inserted (internal) axes to match ``data``.
                tile_shape = tuple(
                    1 if isinstance(sl, slice) else s
                    for sl, s in zip(slices, np.asarray(data).shape)
                )
                mask = mask[slices]
                mask = np.tile(mask, tile_shape)
            else:
                mask = self._mask[mask_key]
        else:
            mask = self._mask[key]

        decoded, mask_array = _decode_with_mask(self.object_codec, data, mask)
        if decoded.shape == ():
            if mask_array.item():
                return np.ma.masked
            decoded_value = decoded.item()
            if isinstance(decoded_value, list):
                return np.array(decoded_value, dtype=object)
            return decoded_value
        return np.ma.masked_array(decoded, mask=mask_array, dtype=object)

    def to_array(self, *, splat_internal: bool | None = None) -> np.ma.core.MaskedArray:
        """Return the array as a NumPy masked array.

        Parameters
        ----------
        splat_internal
            Must be ``True`` or ``None`` (treated as ``True``).

        Raises
        ------
        ValueError
            If ``splat_internal`` is truthy but the array has no ``internal_shape``.
        NotImplementedError
            If ``splat_internal`` is ``False``.

        """
        if splat_internal and not self.internal_shape:
            msg = "internal_shape must be provided if splat_internal is True"
            raise ValueError(msg)
        if splat_internal is None:
            splat_internal = True
        if not splat_internal:
            msg = "splat_internal must be True"
            raise NotImplementedError(msg)

        data = self.array[:]
        mask = self._mask[:]
        slc = select_by_mask(
            self.shape_mask,
            (slice(None),) * len(self.shape),
            (None,) * len(self.internal_shape),  # Adds axes with size 1
        )
        tile_shape = select_by_mask(self.shape_mask, (1,) * len(self.shape), self.internal_shape)
        mask = np.tile(mask[slc], tile_shape)

        decoded, mask_array = _decode_with_mask(self.object_codec, data, mask)
        return np.ma.MaskedArray(decoded, mask=mask_array, dtype=object)

    @property
    def mask(self) -> np.ma.core.MaskedArray:
        """Return the mask associated with the array."""
        mask = self._mask[:]
        return np.ma.MaskedArray(mask, dtype=bool)

    def mask_linear(self) -> list[bool]:
        """Return a list of booleans indicating which elements are missing."""
        return list(self._mask[:].flat)

    def dump(self, key: tuple[int | slice, ...], value: Any) -> None:
        """Dump 'value' into the location associated with 'key'.

        When ``key`` contains slices, the same ``value`` is written to every
        external index selected by the key. The value is validated and encoded
        once, before any index is written.

        Raises
        ------
        ValueError
            If the array has an ``internal_shape`` and ``value`` does not have
            exactly that shape.

        Examples
        --------
        >>> arr = ZarrFileArray(...)
        >>> arr.dump((2, 1, 5), dict(a=1, b=2))

        """
        if any(isinstance(k, slice) for k in key):
            external_indices = itertools.product(*self._slice_indices(key))
            if self.internal_shape:
                # The payload is identical for every index: encode it a single time.
                encoded_block = self._encode_block(value)
                internal_slices = (slice(None),) * len(self.internal_shape)
                for external_index in external_indices:
                    full_index = select_by_mask(self.shape_mask, external_index, internal_slices)
                    self.array[full_index] = encoded_block
                    self._mask[external_index] = False
            else:
                encoded_value = _encode_scalar(self.object_codec, value)
                for external_index in external_indices:
                    self._store_scalar_encoded(external_index, encoded_value)
                    self._mask[external_index] = False
            return

        if self.internal_shape:
            encoded_block = self._encode_block(value)
            assert len(key) == len(self.shape)
            full_index = select_by_mask(
                self.shape_mask,
                key,
                (slice(None),) * len(self.internal_shape),
            )
            self.array[full_index] = encoded_block
        else:
            encoded_value = _encode_scalar(self.object_codec, value)
            assert all(isinstance(k, int) for k in key)
            indices = cast(tuple[int, ...], key)
            self._store_scalar_encoded(indices, encoded_value)
        self._mask[key] = False

    def _encode_block(self, value: Any) -> np.ndarray:
        """Check that ``value`` has ``internal_shape`` and encode it element-wise."""
        block = np.asarray(value, dtype=object)  # in case it's a list
        if block.shape != self.internal_shape:
            msg = "Value has incorrect internal_shape"
            raise ValueError(msg)
        return _encode_array(self.object_codec, block)

    def _slice_indices(self, key: tuple[int | slice, ...]) -> list[range]:
        """Return, per axis of ``key``, the range of concrete indices it selects."""
        slice_indices: list[range] = []
        for size, k in zip(self.resolved_shape, key, strict=False):
            if isinstance(k, slice):
                slice_indices.append(range(*k.indices(size)))
            else:
                slice_indices.append(range(k, k + 1))
        return slice_indices

    @property
    def dump_in_subprocess(self) -> bool:
        """Indicates if the storage can be dumped in a subprocess and read by the main process."""
        return True

    def _store_scalar_encoded(self, indices: tuple[int, ...], encoded: bytes) -> None:
        """Store a single serialized value at the provided indices."""
        normalized_indices = cast(
            tuple[int, ...],
            normalize_key(
                indices,
                self.resolved_shape,
                self.resolved_internal_shape,
                self.shape_mask,
                for_dump=True,
            ),
        )
        slice_key = tuple(slice(i, i + 1) for i in normalized_indices)
        # Pass the target shape as a tuple: unpacking an empty tuple would call
        # ``reshape()`` without arguments, which NumPy rejects, breaking 0-D arrays.
        shaped = np.array([encoded], dtype=object).reshape((1,) * len(indices))
        self.array[slice_key] = shaped
class _SharedDictStore(MemoryStore):
    """``MemoryStore`` variant whose backing dictionary is a shared dict proxy."""

    def __init__(self, shared_dict: multiprocessing.managers.DictProxy | None = None) -> None:
        """Initialize the _SharedDictStore.

        Parameters
        ----------
        shared_dict
            Dictionary proxy used as the underlying storage. When ``None``
            (the default), a fresh one is obtained from
            ``multiprocessing.Manager().dict()``.

        """
        backing = multiprocessing.Manager().dict() if shared_dict is None else shared_dict
        super().__init__(store_dict=backing)
class ZarrMemoryArray(ZarrFileArray):
    """Array interface to an in-memory Zarr store.

    Data lives in a ``MemoryStore`` (unless another store is supplied) and can
    be copied to and from a ``LocalStore`` rooted at ``folder`` with
    ``persist`` and ``load``.

    Only exists if the `zarr` package is installed!
    """

    storage_id = "zarr_memory"
    requires_serialization = False

    def __init__(
        self,
        folder: str | Path | None,
        shape: tuple[int, ...],
        internal_shape: tuple[int, ...] | None = None,
        shape_mask: tuple[bool, ...] | None = None,
        *,
        store: Store | None = None,
        object_codec: Any = None,
    ) -> None:
        """Initialize the ZarrMemoryArray and load any data found in ``folder``."""
        backing_store = MemoryStore() if store is None else store
        super().__init__(
            folder=folder,
            shape=shape,
            internal_shape=internal_shape,
            shape_mask=shape_mask,
            store=backing_store,
            object_codec=object_codec,
        )
        self.load()

    @property
    def persistent_store(self) -> Store | None:
        """Return a ``LocalStore`` rooted at ``folder``, or ``None`` without a folder."""
        if self.folder is not None:
            return LocalStore(self.folder)
        return None  # pragma: no cover

    def persist(self) -> None:
        """Copy the in-memory store to the persistent store, if there is one."""
        target = self.persistent_store
        if target is not None:
            _copy_store(self.store, target)

    def load(self) -> None:
        """Copy the persistent store into memory when its folder exists on disk."""
        persisted = self.persistent_store
        if persisted is None:  # pragma: no cover
            return
        folder = self.folder
        if folder is not None and folder.exists():
            _copy_store(persisted, self.store)

    @property
    def dump_in_subprocess(self) -> bool:
        """Indicates if the storage can be dumped in a subprocess and read by the main process."""
        return False
class ZarrSharedMemoryArray(ZarrMemoryArray):
    """Array interface to a shared memory Zarr store.

    Uses a ``_SharedDictStore`` as the backing store unless another store is
    supplied.

    Only exists if the `zarr` package is installed!
    """

    storage_id = "zarr_shared_memory"
    requires_serialization = True

    def __init__(
        self,
        folder: str | Path | None,
        shape: tuple[int, ...],
        internal_shape: tuple[int, ...] | None = None,
        shape_mask: tuple[bool, ...] | None = None,
        *,
        store: Store | None = None,
        object_codec: Any = None,
    ) -> None:
        """Initialize the ZarrSharedMemoryArray."""
        backing_store = _SharedDictStore() if store is None else store
        super().__init__(
            folder=folder,
            shape=shape,
            internal_shape=internal_shape,
            shape_mask=shape_mask,
            store=backing_store,
            object_codec=object_codec,
        )

    @property
    def dump_in_subprocess(self) -> bool:
        """Indicates if the storage can be dumped in a subprocess and read by the main process."""
        return True
class CloudPickleCodec(Codec):
    """Codec that serializes data to cloudpickled bytes.

    Handy for encoding arrays of arbitrary Python objects.

    Parameters
    ----------
    protocol
        The pickle protocol used when serializing.

    Examples
    --------
    >>> from pipefunc.map._storage_array._zarr import CloudPickleCodec
    >>> import numpy as np
    >>> x = np.array(['foo', 'bar', 'baz'], dtype='object')
    >>> f = CloudPickleCodec()
    >>> f.decode(f.encode(x))
    array(['foo', 'bar', 'baz'], dtype=object)

    """

    codec_id = "cloudpickle"

    def __init__(self, protocol: int = cloudpickle.DEFAULT_PROTOCOL) -> None:
        """Create the codec.

        Parameters
        ----------
        protocol
            The pickle protocol to use, by default `cloudpickle.DEFAULT_PROTOCOL`

        """
        self.protocol = protocol

    def encode(self, buf: Any) -> bytes:
        """Serialize ``buf`` with cloudpickle.

        Parameters
        ----------
        buf
            The object to serialize.

        Returns
        -------
        The cloudpickled bytes.

        """
        pickled = cloudpickle.dumps(buf, protocol=self.protocol)
        return pickled

    def decode(self, buf: np.ndarray, out: np.ndarray | None = None) -> Any:
        """Deserialize cloudpickled data.

        Parameters
        ----------
        buf
            The cloudpickled data.
        out
            Optional array that receives a copy of the decoded data, by default None

        Returns
        -------
        ``out`` when it was given, otherwise the decoded object.

        """
        contiguous = ensure_contiguous_ndarray(buf)
        decoded = cloudpickle.loads(contiguous)
        if out is None:
            return decoded
        np.copyto(out, decoded)
        return out

    def get_config(self) -> dict[str, Any]:
        """Return the codec configuration (its id and pickle protocol)."""
        config: dict[str, Any] = {"id": self.codec_id, "protocol": self.protocol}
        return config

    def __repr__(self) -> str:
        """Return a string representation of the codec."""
        protocol = self.protocol
        return f"CloudPickleCodec(protocol={protocol})"


# Make the codec and the storage classes discoverable through their registries.
register_codec(CloudPickleCodec)
register_storage(ZarrFileArray)
register_storage(ZarrMemoryArray)
register_storage(ZarrSharedMemoryArray)