Source code for pipefunc.map._storage_array._file

# This file is part of the pipefunc package.

from __future__ import annotations

import concurrent.futures
import itertools
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any

import cloudpickle
import numpy as np

from pipefunc._utils import dump, load

from ._base import (
    StorageBase,
    iterate_shape_indices,
    normalize_key,
    register_storage,
    select_by_mask,
)

if TYPE_CHECKING:
    from collections.abc import Iterator

    from pipefunc.map._types import ShapeTuple

storage_registry: dict[str, type[StorageBase]] = {}

FILENAME_TEMPLATE = "__{:d}__.pickle"


[docs] class FileArray(StorageBase): """Array interface to a folder of files on disk. __getitem__ returns "np.ma.masked" for non-existent files. """ storage_id = "file_array" requires_serialization = True def __init__( self, folder: str | Path, shape: ShapeTuple, internal_shape: ShapeTuple | None = None, shape_mask: tuple[bool, ...] | None = None, *, filename_template: str = FILENAME_TEMPLATE, ) -> None: if internal_shape and shape_mask is None: msg = "shape_mask must be provided if internal_shape is provided" raise ValueError(msg) if internal_shape is not None and len(shape_mask) != len(shape) + len(internal_shape): # type: ignore[arg-type] msg = "shape_mask must have the same length as shape + internal_shape" raise ValueError(msg) self.folder = Path(folder).absolute() self.folder.mkdir(parents=True, exist_ok=True) self.shape = tuple(shape) self.filename_template = str(filename_template) self.shape_mask = tuple(shape_mask) if shape_mask is not None else (True,) * len(shape) self.internal_shape = tuple(internal_shape) if internal_shape is not None else () def __repr__(self) -> str: return ( f"FileArray(folder='{self.folder}', " f"shape={self.shape}, " f"internal_shape={self.internal_shape}, " f"shape_mask={self.shape_mask}, " f"filename_template={self.filename_template!r})" ) def _normalize_key( self, key: tuple[int | slice, ...], *, for_dump: bool = False, ) -> tuple[int | slice, ...]: return normalize_key( key, self.resolved_shape, self.resolved_internal_shape, self.shape_mask, for_dump=for_dump, ) def _index_to_file(self, index: int) -> Path: """Return the filename associated with the given index.""" return self.folder / self.filename_template.format(index) def _key_to_file(self, key: tuple[int, ...]) -> Path: """Return the filename associated with the given key.""" index = sum(k * s for k, s in zip(key, self.strides)) return self._index_to_file(index)
[docs] def get_from_index(self, index: int) -> Any: """Return the data associated with the given linear index.""" return load(self._index_to_file(index))
[docs] def has_index(self, index: int) -> bool: """Return whether the given linear index exists.""" return self._index_to_file(index).is_file()
def _files(self) -> Iterator[Path]: """Yield all the filenames that constitute the data in this array.""" return (self._key_to_file(x) for x in iterate_shape_indices(self.resolved_shape)) def _slice_indices( self, key: tuple[int | slice, ...], *, for_dump: bool = False, ) -> list[range]: slice_indices = [] shape_index = 0 internal_shape_index = 0 normalized_key = self._normalize_key(key, for_dump=for_dump) for k, m in zip(normalized_key, self.shape_mask): shape = self.resolved_shape if m else self.resolved_internal_shape index = shape_index if m else internal_shape_index if isinstance(k, slice): slice_indices.append(range(*k.indices(shape[index]))) else: slice_indices.append(range(k, k + 1)) if m: shape_index += 1 else: internal_shape_index += 1 return slice_indices def __getitem__(self, key: tuple[int | slice, ...]) -> Any: normalized_key = self._normalize_key(key) if any(isinstance(k, slice) for k in normalized_key): return self._get_item_sliced(key, normalized_key) external_indices = tuple(i for i, m in zip(normalized_key, self.shape_mask) if m) internal_indices = tuple(i for i, m in zip(normalized_key, self.shape_mask) if not m) file = self._key_to_file(external_indices) # type: ignore[arg-type] if not file.is_file(): return np.ma.masked sub_array = load(file) if internal_indices: sub_array = np.asarray(sub_array) return sub_array[internal_indices] return sub_array def _get_item_sliced( self, key: tuple[int | slice, ...], normalized_key: tuple[int | slice, ...], ) -> np.ma.core.MaskedArray: """Return a sliced view of the array as a masked array.""" slice_indices = self._slice_indices(key) sliced_data = [] sliced_mask = [] for index in itertools.product(*slice_indices): file_key = tuple(i for i, m in zip(index, self.shape_mask) if m) file = self._key_to_file(file_key) if file.is_file(): sub_array = load(file) internal_index = tuple(i for i, m in zip(index, self.shape_mask) if not m) if internal_index: sub_array = np.asarray(sub_array) # could be a list sliced_sub_array = sub_array[internal_index] sliced_data.append(sliced_sub_array) else: sliced_data.append(sub_array) sliced_mask.append(False) else: sliced_data.append(np.ma.masked) sliced_mask.append(True) sliced_array: np.ndarray = np.empty(len(sliced_data), dtype=object) sliced_array[:] = sliced_data mask: np.ndarray = np.array(sliced_mask, dtype=bool) sliced_array = np.ma.masked_array(sliced_array, mask=mask) new_shape = tuple( len(range_) for k, range_ in zip(normalized_key, slice_indices) if isinstance(k, slice) ) return sliced_array.reshape(new_shape)
[docs] def to_array(self, *, splat_internal: bool | None = None) -> np.ma.core.MaskedArray: """Return a masked numpy array containing all the data. The returned numpy array has dtype "object" and a mask for masking out missing data. Parameters ---------- splat_internal If True, the internal array dimensions will be splatted out. If None, it will happen if and only if `internal_shape` is provided. Returns ------- The array containing all the data. """ if splat_internal is None: splat_internal = bool(self.resolved_internal_shape) items = _load_all(map(self._index_to_file, range(self.size))) if not splat_internal: ma_arr = np.empty(self.size, dtype=object) # type: ignore[var-annotated] ma_arr[:] = items mask = self.mask_linear() return np.ma.MaskedArray(ma_arr, mask=mask, dtype=object).reshape(self.resolved_shape) if not self.resolved_internal_shape: msg = "internal_shape must be provided if splat_internal is True" raise ValueError(msg) arr = np.empty(self.full_shape, dtype=object) # type: ignore[var-annotated] full_mask = np.empty(self.full_shape, dtype=bool) # type: ignore[var-annotated] for external_index in iterate_shape_indices(self.resolved_shape): file = self._key_to_file(external_index) if file.is_file(): sub_array = load(file) sub_array = np.asarray(sub_array) # could be a list for internal_index in iterate_shape_indices(self.resolved_internal_shape): full_index = select_by_mask(self.shape_mask, external_index, internal_index) arr[full_index] = sub_array[internal_index] full_mask[full_index] = False else: for internal_index in iterate_shape_indices(self.resolved_internal_shape): full_index = select_by_mask(self.shape_mask, external_index, internal_index) arr[full_index] = np.ma.masked full_mask[full_index] = True return np.ma.MaskedArray(arr, mask=full_mask, dtype=object)
[docs] def mask_linear(self) -> list[bool]: """Return a list of booleans indicating which elements are missing.""" # We use os.listdir to check if a file exists instead of checking with # self._index_to_file(i).is_file() because this is more efficient. existing_files = set(os.listdir(self.folder)) # noqa: PTH208 return [self.filename_template.format(i) not in existing_files for i in range(self.size)]
@property def mask(self) -> np.ma.core.MaskedArray: """Return a masked numpy array containing the mask. The returned numpy array has dtype "bool" and a mask for masking out missing data. """ mask = self.mask_linear() return np.ma.MaskedArray(mask, mask=mask, dtype=bool).reshape(self.shape)
[docs] def dump(self, key: tuple[int | slice, ...], value: Any) -> None: """Dump 'value' into the file associated with 'key'. Examples -------- >>> arr = FileArray(...) >>> arr.dump((2, 1, 5), dict(a=1, b=2)) """ key = self._normalize_key(key, for_dump=True) if not any(isinstance(k, slice) for k in key): dump(value, self._key_to_file(key)) # type: ignore[arg-type] return for index in itertools.product(*self._slice_indices(key, for_dump=True)): file = self._key_to_file(index) dump(value, file)
@property def dump_in_subprocess(self) -> bool: """Indicates if the storage can be dumped in a subprocess and read by the main process.""" return True
[docs] @classmethod def from_data(cls, data: list[Any] | np.ndarray, folder: str | Path) -> FileArray: """Create a FileArray from an existing list or NumPy array. This method serializes the provided `data` (which can be a Python list or a NumPy array) into the FileArray's on-disk format within the specified `folder`. Each element of the input `data` is dumped individually. This is useful for preparing large datasets for use with `pipeline.map` in distributed environments (like SLURM), as it allows `pipefunc` to pass around a lightweight `FileArray` object instead of serializing the entire large dataset for each task. Parameters ---------- data The list or NumPy-like array to store in the FileArray. folder The directory where the FileArray will store its data files. This folder must be accessible by all worker nodes if used in a distributed setting. Returns ------- FileArray A new FileArray instance populated with the provided data. """ file_array = FileArray(folder, np.shape(data)) for index, value in np.ndenumerate(data): file_array.dump(index, value) return file_array
def _read(name: str | Path) -> bytes: """Load file contents as a bytestring.""" with open(name, "rb") as f: # noqa: PTH123 return f.read() def _load_all(filenames: Iterator[Path]) -> list[Any]: def maybe_read(f: Path) -> Any | None: return _read(f) if f.is_file() else None def maybe_load(x: str | None) -> Any | None: return cloudpickle.loads(x) if x is not None else None # Delegate file reading to the threadpool but deserialize sequentially, # as this is pure Python and CPU bound with concurrent.futures.ThreadPoolExecutor() as tex: return [maybe_load(x) for x in tex.map(maybe_read, filenames)] register_storage(FileArray)