# Source code for pipefunc.map._storage_array._base

"""Implements the base class and helpers for file/memory-based arrays."""

from __future__ import annotations

import abc
import functools
import itertools
from typing import TYPE_CHECKING, Any

from pipefunc._utils import prod
from pipefunc.map._mapspec import shape_to_strides
from pipefunc.map._shapes import shape_is_resolved

if TYPE_CHECKING:
    from collections.abc import Iterator
    from pathlib import Path

    import numpy as np

    from pipefunc.map._types import ShapeTuple


# Registry mapping storage identifiers (by default a class's ``storage_id``)
# to storage classes. Written by ``register_storage`` and read by
# ``get_storage_class``.
storage_registry: dict[str, type[StorageBase]] = {}


def iterate_shape_indices(shape: tuple[int, ...]) -> Iterator[tuple[int, ...]]:
    """Yield every index tuple of an array with the given ``shape``.

    Indices come out in row-major order (the last axis varies fastest), as
    produced by ``itertools.product``. An empty ``shape`` yields exactly one
    empty tuple; any zero-length axis yields nothing.
    """
    axis_ranges = [range(extent) for extent in shape]
    return itertools.product(*axis_ranges)


def select_by_mask(
    mask: tuple[bool, ...],
    tuple1: tuple[Any, ...],
    tuple2: tuple[Any, ...],
) -> tuple[Any, ...]:
    """Interleave two tuples according to a boolean mask.

    ``mask`` is walked left to right. Each truthy entry takes the next unused
    element of ``tuple1``; each falsy entry takes the next unused element of
    ``tuple2``. The result has ``len(mask)`` elements. Surplus elements in
    either tuple are ignored, and running out of elements raises
    ``IndexError``.
    """
    # Indexed by the mask flag: 0 -> tuple2 (falsy), 1 -> tuple1 (truthy).
    sources = (tuple2, tuple1)
    cursors = [0, 0]
    picked: list[Any] = []
    for flag in mask:
        which = 1 if flag else 0
        picked.append(sources[which][cursors[which]])
        cursors[which] += 1
    return tuple(picked)


class StorageBase(abc.ABC):
    """Base class for file- and memory-based arrays.

    Subclasses implement the abstract storage API (construction, indexing,
    dumping, masking and conversion to a masked array). This base class
    provides the shape bookkeeping shared by all storage backends.
    """

    # Shape of the axes selected by ``True`` entries of ``shape_mask``.
    shape: ShapeTuple
    # Shape of the axes selected by ``False`` entries of ``shape_mask``.
    internal_shape: ShapeTuple
    # One flag per axis of the full array; ``full_shape`` takes the axis
    # size from ``shape`` where True and from ``internal_shape`` where False.
    shape_mask: tuple[bool, ...]
    # Identifier used by ``register_storage`` when no explicit id is given.
    storage_id: str
    # NOTE(review): not read in this module; presumably signals whether values
    # must be serialized before storing — confirm in the subclasses.
    requires_serialization: bool
    # Cache for ``full_shape_is_resolved``; it only ever flips False -> True.
    _is_resolved: bool = False

    @abc.abstractmethod
    def __init__(
        self,
        folder: str | Path | None,
        shape: ShapeTuple,
        internal_shape: ShapeTuple | None = None,
        shape_mask: tuple[bool, ...] | None = None,
    ) -> None: ...

    @functools.cached_property
    def resolved_shape(self) -> tuple[int, ...]:
        """Return the resolved shape of the array."""
        # This cached property (and resolved_internal_shape) only exist to help mypy.
        # For performance reasons, we assume this is only called once the shape is resolved.
        # Because the value is cached, later changes to ``self.shape`` are not seen.
        assert shape_is_resolved(self.shape)
        return self.shape

    @functools.cached_property
    def resolved_internal_shape(self) -> tuple[int, ...]:
        """Return the resolved internal shape of the array."""
        # See comment in `resolved_shape`.
        assert shape_is_resolved(self.internal_shape)
        return self.internal_shape

    def full_shape_is_resolved(self) -> bool:
        """Return whether every entry of ``shape`` and ``internal_shape`` is an int."""
        # This function is called many times, so we cache the result
        # (only a positive answer is cached; a negative one is recomputed).
        if self._is_resolved:
            return True
        self._is_resolved = all(isinstance(s, int) for s in self.shape + self.internal_shape)
        return self._is_resolved

    @abc.abstractmethod
    def get_from_index(self, index: int) -> Any: ...

    @abc.abstractmethod
    def has_index(self, index: int) -> bool: ...

    @abc.abstractmethod
    def __getitem__(self, key: tuple[int | slice, ...]) -> Any: ...

    @abc.abstractmethod
    def to_array(self, *, splat_internal: bool | None = None) -> np.ma.core.MaskedArray: ...

    @property
    @abc.abstractmethod
    def mask(self) -> np.ma.core.MaskedArray: ...

    @abc.abstractmethod
    def mask_linear(self) -> list[bool]: ...

    @abc.abstractmethod
    def dump(self, key: tuple[int | slice, ...], value: Any) -> None: ...

    @property
    @abc.abstractmethod
    def dump_in_subprocess(self) -> bool:
        """Indicates if the storage can be dumped in a subprocess and read by the main process."""

    @property
    def size(self) -> int:
        """Return the number of elements spanned by ``shape`` (internal axes excluded)."""
        return prod(self.resolved_shape)

    @property
    def rank(self) -> int:
        """Return the number of axes in ``shape`` (internal axes excluded)."""
        return len(self.resolved_shape)

    @functools.cached_property
    def full_shape(self) -> tuple[int, ...]:
        """Return the full shape: ``shape`` and ``internal_shape`` interleaved by ``shape_mask``."""
        full_shape = select_by_mask(
            self.shape_mask,
            self.resolved_shape,
            self.resolved_internal_shape,
        )
        assert shape_is_resolved(full_shape)
        return full_shape

    @functools.cached_property
    def strides(self) -> tuple[int, ...]:
        """Return the strides of the array, computed from ``shape`` only."""
        return shape_to_strides(self.resolved_shape)

    def persist(self) -> None:  # noqa: B027
        """Save a memory-based storage to disk."""
        # No-op by default; intentionally not abstract (hence the B027 noqa),
        # so storages with nothing to persist need not override it.
def register_storage(cls: type[StorageBase], storage_id: str | None = None) -> None:
    """Register a StorageBase class.

    Parameters
    ----------
    cls
        Storage class that should be registered.
    storage_id
        Storage identifier, defaults to the `storage_id` attribute of the class.

    Notes
    -----
    This function maintains a mapping from storage identifiers to storage
    classes. When a storage class is registered, it will replace any class
    previously registered under the same storage identifier, if present.
    """
    if storage_id is None:
        storage_id = cls.storage_id
    # Plain assignment: re-registering an id silently overwrites the old class.
    storage_registry[storage_id] = cls
def normalize_key( key: tuple[int | slice, ...], shape: tuple[int, ...], internal_shape: tuple[int, ...], shape_mask: tuple[bool, ...], *, for_dump: bool = False, ) -> tuple[int | slice, ...]: if not isinstance(key, tuple): key = (key,) expected_rank = sum(shape_mask) if for_dump else len(shape_mask) if len(key) != expected_rank: msg = ( f"Too many indices for array: array is {expected_rank}-dimensional, " f"but {len(key)} were indexed" ) raise IndexError(msg) normalized_key: list[int | slice] = [] shape_index = 0 internal_shape_index = 0 if for_dump: # internal_shape is not involved when dumping shape_mask = (True,) * len(key) for axis, (mask, k) in enumerate(zip(shape_mask, key, strict=True)): if mask: axis_size = shape[shape_index] shape_index += 1 else: axis_size = internal_shape[internal_shape_index] internal_shape_index += 1 if isinstance(k, slice): normalized_key.append(k) else: normalized_k = k if k >= 0 else (k + axis_size) if not (0 <= normalized_k < axis_size): msg = f"Index {k} is out of bounds for axis {axis} with size {axis_size}" raise IndexError(msg) normalized_key.append(normalized_k) return tuple(normalized_key) def get_storage_class(storage: str) -> type[StorageBase]: """Get the storage class by its identifier. See `pipefunc.map.storage_registry` for available storage classes. Parameters ---------- storage The storage class identifier. Returns ------- The storage class. Raises ------ ValueError If the storage class is not found. """ if storage not in storage_registry: available = ", ".join(storage_registry.keys()) msg = f"Storage class `{storage}` not found, only `{available}` available." raise ValueError(msg) return storage_registry[storage]