Source code for pipefunc.map._storage_array._dict

"""Implements a `dict` based `StorageBase` class."""

from __future__ import annotations

import itertools
import multiprocessing
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np

from pipefunc._utils import dump, load

from ._base import (
    StorageBase,
    iterate_shape_indices,
    normalize_key,
    register_storage,
    select_by_mask,
)

if TYPE_CHECKING:
    from collections.abc import MutableMapping
    from multiprocessing.managers import DictProxy

    from pipefunc.map._types import ShapeTuple


class DictArray(StorageBase):
    """A `numpy.ndarray` backed by a `dict` with internal structure.

    Values are stored in ``self._dict`` under tuples of external indices.
    When ``internal_shape`` is set, each stored value is itself treated as an
    array (via `numpy.asarray`) that is indexed by the remaining, internal
    part of a key. ``shape_mask`` marks, per axis of the full array, whether
    the axis is external (``True``) or internal (``False``).

    The properties ``resolved_shape``, ``resolved_internal_shape`` and
    ``full_shape`` used throughout are not defined here; they are expected to
    come from `StorageBase`.
    """

    storage_id = "dict"
    requires_serialization = False

    def __init__(
        self,
        folder: str | Path | None,
        shape: ShapeTuple,
        internal_shape: ShapeTuple | None = None,
        shape_mask: tuple[bool, ...] | None = None,
        *,
        mapping: MutableMapping[tuple[int, ...], Any] | None = None,
    ) -> None:
        """Create a `numpy.ndarray` backed by a `dict`.

        Parameters
        ----------
        folder
            Directory used by `persist` and `load`; ``None`` disables both.
        shape
            External shape of the array.
        internal_shape
            Shape of each stored value, if the values have internal structure.
        shape_mask
            One flag per axis of ``shape + internal_shape``; ``True`` marks an
            external axis. Defaults to all-``True`` with the length of ``shape``.
        mapping
            Mapping used as the backing store; a new ``dict`` if omitted.

        Raises
        ------
        ValueError
            If a non-empty ``internal_shape`` is given without ``shape_mask``,
            or if the length of ``shape_mask`` does not equal
            ``len(shape) + len(internal_shape)``.

        """
        if internal_shape and shape_mask is None:
            msg = "shape_mask must be provided if internal_shape is provided"
            raise ValueError(msg)
        # An empty `internal_shape` together with `shape_mask=None` is a valid
        # combination, so the length check must not call `len(None)`.
        if (
            internal_shape is not None
            and shape_mask is not None
            and len(shape_mask) != len(shape) + len(internal_shape)
        ):
            msg = "shape_mask must have the same length as shape + internal_shape"
            raise ValueError(msg)
        self.folder = Path(folder) if folder is not None else folder
        self.shape = tuple(shape)
        self.shape_mask = tuple(shape_mask) if shape_mask is not None else (True,) * len(shape)
        self.internal_shape = tuple(internal_shape) if internal_shape is not None else ()
        if mapping is None:
            mapping = {}
        self._dict: dict[tuple[int, ...], Any] = mapping  # type: ignore[assignment]
        # Pick up previously persisted data (no-op without a folder or file).
        self.load()

    def __repr__(self) -> str:
        """Return a constructor-like representation including the stored mapping."""
        folder = f"'{self.folder}'" if self.folder is not None else self.folder
        return (
            f"DictArray(folder={folder}, "
            f"shape={self.shape}, "
            f"internal_shape={self.internal_shape}, "
            f"shape_mask={self.shape_mask}, "
            f"mapping={self._dict})"
        )

    def get_from_index(self, index: int) -> Any:
        """Return the data associated with the given linear index.

        The linear index is unravelled against ``self.resolved_shape``; a
        `KeyError` is raised if nothing is stored at that position.
        """
        np_index = np.unravel_index(index, self.resolved_shape)
        return self._dict[np_index]  # type: ignore[index]

    def has_index(self, index: int) -> bool:
        """Return whether the given linear index exists."""
        np_index = np.unravel_index(index, self.resolved_shape)
        return np_index in self._dict

    def _internal_mask(self) -> np.ma.MaskedArray:
        """Return the placeholder used for a missing entry.

        With an ``internal_shape`` this is an empty object masked array of
        that shape, otherwise the `numpy.ma.masked` constant.
        """
        if self.internal_shape:
            return np.ma.empty(self.internal_shape, dtype=object)
        return np.ma.masked

    def __getitem__(self, key: tuple[int | slice, ...]) -> Any:
        """Return the data associated with the given key.

        The key is first passed through `normalize_key`. If it contains
        slices, an object array with one axis per slice is assembled element
        by element. Otherwise the single stored value (or the selected
        element of it, when internal indices are present) is returned.
        Positions without stored data yield the `_internal_mask` placeholder.
        """
        key = normalize_key(key, self.resolved_shape, self.resolved_internal_shape, self.shape_mask)
        assert len(key) == len(self.full_shape)

        if any(isinstance(k, slice) for k in key):
            # Integer components keep a length-1 axis while filling; they are
            # dropped by the final reshape.
            shape = tuple(
                len(range(*k.indices(s))) if isinstance(k, slice) else 1
                for s, k in zip(self.full_shape, key)
            )
            data: np.ndarray = np.empty(shape, dtype=object)
            for i, index in enumerate(
                itertools.product(*self._slice_indices(key, self.full_shape)),
            ):
                # Split the full index into its external and internal parts.
                external_key = tuple(x for x, m in zip(index, self.shape_mask) if m)
                if self.internal_shape:
                    internal_key = tuple(x for x, m in zip(index, self.shape_mask) if not m)
                    if external_key in self._dict:
                        arr = np.asarray(self._dict[external_key])
                        value = arr[internal_key]
                    else:
                        value = self._internal_mask()[internal_key]
                else:  # noqa: PLR5501
                    if external_key in self._dict:
                        value = self._dict[external_key]
                    else:
                        value = self._internal_mask()
                # `itertools.product` iterates in C order, matching `unravel_index`.
                j = np.unravel_index(i, shape)
                data[j] = value
            new_shape = tuple(
                len(range(*k.indices(s)))
                for s, k in zip(self.full_shape, key)
                if isinstance(k, slice)
            )
            return data.reshape(new_shape)

        external_key = tuple(x for x, m in zip(key, self.shape_mask) if m)  # type: ignore[misc]
        internal_key = tuple(x for x, m in zip(key, self.shape_mask) if not m)  # type: ignore[misc]

        if external_key in self._dict:
            data = self._dict[external_key]
        else:
            return self._internal_mask()

        if internal_key:
            arr = np.asarray(data)
            return arr[internal_key]
        return data

    def _slice_indices(self, key: tuple[int | slice, ...], shape: tuple[int, ...]) -> list[range]:
        """Return one `range` per key component: slices expanded, integers as length-1 ranges."""
        assert len(key) == len(shape)
        slice_indices = []
        for size, k in zip(shape, key):
            if isinstance(k, slice):
                slice_indices.append(range(*k.indices(size)))
            else:
                slice_indices.append(range(k, k + 1))
        return slice_indices

    def to_array(self, *, splat_internal: bool | None = None) -> np.ma.core.MaskedArray:
        """Return the array as a NumPy masked array.

        Parameters
        ----------
        splat_internal
            If ``True``, spread each stored value over the internal axes so the
            result has ``self.full_shape``; if ``False``, keep each stored value
            as a single element of an array of ``self.resolved_shape``.
            Defaults to ``True`` exactly when there is a resolved internal shape.

        Raises
        ------
        ValueError
            If ``splat_internal`` is true but there is no internal shape.

        """
        if splat_internal is None:
            splat_internal = bool(self.resolved_internal_shape)

        if not splat_internal:
            data: np.ndarray = _masked_empty(self.resolved_shape)
            mask: np.ndarray = np.full(self.resolved_shape, fill_value=True, dtype=bool)
            for external_index, value in self._dict.items():
                data[external_index] = value
                mask[external_index] = False
            return np.ma.MaskedArray(data, mask=mask, dtype=object)

        if not self.resolved_internal_shape:
            msg = "internal_shape must be provided if splat_internal is True"
            raise ValueError(msg)

        data = _masked_empty(self.full_shape)
        mask = np.full(self.full_shape, fill_value=True, dtype=bool)
        for external_index, value in self._dict.items():
            value_array = np.asarray(value)
            if value_array.shape == self.resolved_internal_shape:
                # Normal case - shapes match
                full_index = select_by_mask(
                    self.shape_mask,
                    external_index,
                    (slice(None),) * len(self.resolved_internal_shape),
                )
                data[full_index] = value_array
                mask[full_index] = False
            else:
                # Shapes differ: copy element by element over the internal indices.
                for internal_index in iterate_shape_indices(self.resolved_internal_shape):
                    full_index = select_by_mask(self.shape_mask, external_index, internal_index)
                    data[full_index] = value_array[internal_index]
                    mask[full_index] = False
        return np.ma.MaskedArray(data, mask=mask, dtype=object)

    @property
    def mask(self) -> np.ma.core.MaskedArray:
        """Return the mask associated with the array.

        An entry is ``True`` where no value is stored for that external index.
        """
        mask: np.ndarray = np.full(self.resolved_shape, fill_value=True, dtype=bool)
        for external_index in self._dict:
            mask[external_index] = False
        return np.ma.MaskedArray(mask, mask=mask, dtype=bool)

    def mask_linear(self) -> list[bool]:
        """Return a list of booleans indicating which elements are missing."""
        return list(self.mask.data[:].flat)

    def dump(self, key: tuple[int | slice, ...], value: Any) -> None:
        """Dump 'value' into the location associated with 'key'.

        If the normalized key contains slices, the same value is stored under
        every external index the slices cover.

        Examples
        --------
        >>> arr = DictArray(...)
        >>> arr.dump((2, 1, 5), dict(a=1, b=2))

        """
        key = normalize_key(
            key,
            self.resolved_shape,
            self.resolved_internal_shape,
            self.shape_mask,
            for_dump=True,
        )
        if any(isinstance(k, slice) for k in key):
            for external_index in itertools.product(*self._slice_indices(key, self.resolved_shape)):
                if self.resolved_internal_shape:
                    value = np.asarray(value)  # in case it's a list
                    assert value.shape == self.resolved_internal_shape
                    self._dict[external_index] = value
                else:
                    self._dict[external_index] = value
            return

        self._dict[key] = value  # type: ignore[index]

    def _path(self) -> Path:
        """Return the file inside ``self.folder`` that holds the persisted mapping."""
        assert self.folder is not None
        return self.folder / "dict_array.cloudpickle"

    def persist(self) -> None:
        """Persist the dict storage to disk.

        Does nothing when no folder is configured. Non-``dict`` mappings are
        copied into a plain ``dict`` before being written.
        """
        if self.folder is None:  # pragma: no cover
            return
        path = self._path()
        path.parent.mkdir(parents=True, exist_ok=True)
        dct = self._dict if isinstance(self._dict, dict) else dict(self._dict)
        dump(dct, path)

    def load(self) -> None:
        """Load the dict storage from disk.

        Does nothing when no folder is configured or no persisted file exists.
        """
        if self.folder is None:  # pragma: no cover
            return
        path = self._path()
        if not path.exists():
            return

        stored = load(path)
        if isinstance(self._dict, dict):
            self._dict = stored
        else:
            # Keep the existing mapping object (e.g. a `multiprocessing.Manager`
            # proxy) and only replace its contents; rebinding to the plain dict
            # read from disk would drop whatever behaviour that mapping provides.
            self._dict.clear()
            self._dict.update(stored)

    @property
    def dump_in_subprocess(self) -> bool:
        """Indicates if the storage can be dumped in a subprocess and read by the main process."""
        return False
def _masked_empty(shape: tuple[int, ...]) -> np.ndarray:
    """Return an object array of exactly ``shape`` filled with `numpy.ma.masked`."""
    # This is a workaround for the fact that setting `x[:] = np.ma.masked`
    # sets the elements to 0.0: a one-element seed is filled by integer index
    # and then replicated.
    seed: np.ndarray = np.empty((1,), dtype=object)
    seed[0] = np.ma.masked
    # `np.tile` never returns fewer dimensions than the seed has, so an empty
    # `shape` would come back as (1,); the reshape restores the requested
    # shape and is a no-op for every non-empty `shape`.
    return np.tile(seed, shape).reshape(shape)
class SharedMemoryDictArray(DictArray):
    """`DictArray` whose backing mapping is a `multiprocessing.Manager` dict proxy."""

    storage_id = "shared_memory_dict"
    requires_serialization = True

    def __init__(
        self,
        folder: str | Path | None,
        shape: tuple[int, ...],
        internal_shape: tuple[int, ...] | None = None,
        shape_mask: tuple[bool, ...] | None = None,
        *,
        mapping: DictProxy[tuple[int, ...], Any] | None = None,
    ) -> None:
        """Set up the array, creating a manager-backed dict when no mapping is given.

        All arguments are forwarded unchanged to `DictArray.__init__`; only a
        missing ``mapping`` is replaced by a fresh ``Manager().dict()`` proxy.
        """
        if mapping is not None:
            shared_mapping = mapping
        else:
            shared_manager = multiprocessing.Manager()
            shared_mapping = shared_manager.dict()
        super().__init__(folder, shape, internal_shape, shape_mask, mapping=shared_mapping)

    @property
    def dump_in_subprocess(self) -> bool:
        """Whether values dumped in a subprocess are readable from the main process."""
        return True
# Import-time side effect: hand both storage classes defined in this module to
# `register_storage` (imported from `._base`).
# NOTE(review): presumably the registry is keyed by each class's `storage_id`
# ("dict" / "shared_memory_dict") — confirm in `._base`.
register_storage(DictArray)
register_storage(SharedMemoryDictArray)