"""Provides `pipefunc.cache` module with cache classes for memoization and caching."""
from __future__ import annotations
import abc
import array
import collections
import functools
import hashlib
import pickle
import sys
import time
import warnings
from contextlib import nullcontext, suppress
from multiprocessing import Manager
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal
import cloudpickle
if TYPE_CHECKING:
from collections.abc import Callable, Hashable, Iterable
from multiprocessing.managers import DictProxy, ListProxy, SyncManager
class _CacheBase(abc.ABC):
@abc.abstractmethod
def get(self, key: Hashable) -> Any:
raise NotImplementedError
@abc.abstractmethod
def put(self, key: Hashable, value: Any) -> None:
raise NotImplementedError
@abc.abstractmethod
def __contains__(self, key: Hashable) -> bool:
raise NotImplementedError
@abc.abstractmethod
def __len__(self) -> int:
raise NotImplementedError
@abc.abstractmethod
def clear(self) -> None:
raise NotImplementedError
class HybridCache(_CacheBase):
    """A hybrid cache implementation.

    This uses a combination of Least Frequently Used (LFU) and
    Least Computationally Expensive (LCE) strategies for invalidating cache
    entries. The cache invalidation strategy calculates a score for each entry
    based on its access frequency and computation duration. The entry with the
    lowest score will be invalidated when the cache reaches its maximum size.

    Attributes
    ----------
    max_size
        The maximum number of entries the cache can store.
    access_weight
        The weight given to the access frequency in the score calculation.
    duration_weight
        The weight given to the computation duration in the score calculation.
    allow_cloudpickle
        Use cloudpickle for storing the data in memory if using shared memory.
    shared
        Whether the cache should be shared between multiple processes.

    """

    def __init__(
        self,
        max_size: int = 128,
        access_weight: float = 0.5,
        duration_weight: float = 0.5,
        *,
        allow_cloudpickle: bool = True,
        shared: bool = True,
    ) -> None:
        """Initialize the HybridCache instance."""
        if shared:
            # Manager-backed proxies make the cache usable from multiple processes.
            manager = Manager()
            self._cache_dict = manager.dict()
            self._access_counts = manager.dict()
            self._computation_durations = manager.dict()
            self._cache_lock = manager.Lock()
        else:
            self._cache_dict = {}  # type: ignore[assignment]
            self._access_counts = {}  # type: ignore[assignment]
            self._computation_durations = {}  # type: ignore[assignment]
            # `nullcontext` keeps the `with self._cache_lock:` blocks below
            # valid when no real locking is needed.
            self._cache_lock = nullcontext()  # type: ignore[assignment]
        self.max_size: int = max_size
        self.access_weight: float = access_weight
        self.duration_weight: float = duration_weight
        self.shared: bool = shared
        self._allow_cloudpickle: bool = allow_cloudpickle

    @property
    def cache(self) -> dict[Hashable, Any]:
        """Return the cache entries (deserialized when shared)."""
        if not self.shared:
            assert isinstance(self._cache_dict, dict)
            return self._cache_dict
        with self._cache_lock:
            return {k: _maybe_load(v, self._allow_cloudpickle) for k, v in self._cache_dict.items()}

    @property
    def access_counts(self) -> dict[Hashable, int]:
        """Return the access counts of the cache entries."""
        if not self.shared:
            assert isinstance(self._access_counts, dict)
            return self._access_counts
        with self._cache_lock:
            return dict(self._access_counts.items())

    @property
    def computation_durations(self) -> dict[Hashable, float]:
        """Return the computation durations of the cache entries."""
        if not self.shared:
            assert isinstance(self._computation_durations, dict)
            return self._computation_durations
        with self._cache_lock:
            return dict(self._computation_durations.items())

    def get(self, key: Hashable) -> Any | None:
        """Retrieve a value from the cache by its key.

        If the key is present in the cache, its access count is incremented.

        Parameters
        ----------
        key
            The key associated with the value in the cache.

        Returns
        -------
        The value associated with the key if the key is present in the cache,
        otherwise None.

        """
        with self._cache_lock:
            # Check membership while holding the lock: a concurrent eviction
            # between an unlocked check and the lookup would raise KeyError.
            if key not in self._cache_dict:
                return None
            self._access_counts[key] += 1
            value = self._cache_dict[key]
        # Deserialize outside the lock to keep the critical section short.
        if self._allow_cloudpickle and self.shared:
            value = cloudpickle.loads(value)
        return value

    def put(self, key: Hashable, value: Any, duration: float) -> None:  # type: ignore[override]
        """Add a value to the cache with its associated key and computation duration.

        If the cache is full, the entry with the lowest score based on the access
        frequency and computation duration will be invalidated.

        Parameters
        ----------
        key
            The key associated with the value.
        value
            The value to store in the cache.
        duration
            The duration of the computation that generated the value.

        """
        if self._allow_cloudpickle and self.shared:
            value = cloudpickle.dumps(value)
        with self._cache_lock:
            # Only evict when inserting a *new* key would grow the cache past
            # max_size; overwriting an existing key requires no eviction.
            if key not in self._cache_dict and len(self._cache_dict) >= self.max_size:
                self._expire()
            self._cache_dict[key] = value
            self._access_counts[key] = 1
            self._computation_durations[key] = duration

    def _expire(self) -> None:
        """Invalidate the entry with the lowest weighted LFU/LCE score."""
        # Normalize both metrics so the weighted sum compares like with like.
        # (access counts are always >= 1, so the totals cannot be zero here)
        total_access_count = sum(self._access_counts.values())
        total_duration = sum(self._computation_durations.values())
        normalized_access_counts = {
            k: v / total_access_count for k, v in self._access_counts.items()
        }
        normalized_durations = {
            k: v / total_duration for k, v in self._computation_durations.items()
        }
        # Calculate scores using a weighted sum of the normalized metrics.
        scores = {
            k: self.access_weight * normalized_access_counts[k]
            + self.duration_weight * normalized_durations[k]
            for k in self._access_counts
        }
        # Evict the entry with the lowest score.
        lowest_score_key = min(scores, key=lambda k: scores[k])
        del self._cache_dict[lowest_score_key]
        del self._access_counts[lowest_score_key]
        del self._computation_durations[lowest_score_key]

    def clear(self) -> None:
        """Clear the cache."""
        with self._cache_lock:
            self._cache_dict.clear()
            self._access_counts.clear()
            self._computation_durations.clear()

    def __contains__(self, key: Hashable) -> bool:
        """Check if a key is present in the cache.

        Parameters
        ----------
        key
            The key to check for in the cache.

        Returns
        -------
        True if the key is present in the cache, otherwise False.

        """
        return key in self._cache_dict

    def __str__(self) -> str:
        """Return a string representation of the HybridCache.

        The string representation includes information about the cache, access counts,
        and computation durations for each key.

        Returns
        -------
        A string representation of the HybridCache.

        """
        cache_str = f"Cache: {self._cache_dict}\n"
        access_counts_str = f"Access Counts: {self._access_counts}\n"
        computation_durations_str = f"Computation Durations: {self._computation_durations}\n"
        return cache_str + access_counts_str + computation_durations_str

    def __len__(self) -> int:
        """Return the number of entries in the cache."""
        return len(self._cache_dict)

    def __getstate__(self) -> dict[str, Any]:
        """Prepare the object for pickling."""
        state = self.__dict__.copy()
        if self.shared:
            # Manager proxies and locks cannot be pickled: convert the shared
            # structures to regular ones and drop the lock. For non-shared
            # caches the nullcontext lock is picklable, so keep it — otherwise
            # the unpickled instance would lack `_cache_lock` entirely.
            state["_cache_dict"] = _dict_to_regular(self._cache_dict)
            state["_access_counts"] = _dict_to_regular(self._access_counts)
            state["_computation_durations"] = _dict_to_regular(self._computation_durations)
            state.pop("_cache_lock", None)
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore the object after unpickling."""
        self.__dict__.update(state)
        if not self.shared:
            return
        # Recreate the shared structures (and lock) in a fresh Manager process.
        manager = Manager()
        self._cache_dict = _create_shared_dict(manager, self._cache_dict)  # type: ignore[arg-type]
        self._access_counts = _create_shared_dict(manager, self._access_counts)  # type: ignore[arg-type]
        self._computation_durations = _create_shared_dict(manager, self._computation_durations)  # type: ignore[arg-type]
        self._cache_lock = manager.Lock()
def _maybe_load(value: bytes | str, allow_cloudpickle: bool) -> Any:
return cloudpickle.loads(value) if allow_cloudpickle else value
class LRUCache(_CacheBase):
    """A shared memory LRU cache implementation.

    Parameters
    ----------
    max_size
        Cache size of the LRU cache, by default 128.
    allow_cloudpickle
        Use cloudpickle for storing the data in memory if using shared memory.
    shared
        Whether the cache should be shared between multiple processes.

    """

    def __init__(
        self,
        *,
        max_size: int = 128,
        allow_cloudpickle: bool = True,
        shared: bool = True,
    ) -> None:
        """Initialize the cache."""
        self.max_size = max_size
        self.shared = shared
        self._allow_cloudpickle = allow_cloudpickle
        if max_size == 0:  # pragma: no cover
            msg = "max_size must be greater than 0"
            raise ValueError(msg)
        if shared:
            # Manager-backed proxies make the cache usable from multiple processes.
            manager = Manager()
            self._cache_dict = manager.dict()
            self._cache_queue = manager.list()
            self._cache_lock = manager.Lock()
        else:
            self._cache_dict = {}  # type: ignore[assignment]
            self._cache_queue = []  # type: ignore[assignment]
            # `nullcontext` keeps the `with self._cache_lock:` blocks valid
            # when no real locking is needed.
            self._cache_lock = nullcontext()  # type: ignore[assignment]

    def get(self, key: Hashable) -> Any:
        """Get a value from the cache by key."""
        with self._cache_lock:
            # Check membership under the lock: a concurrent eviction between
            # an unlocked check and `remove(key)` would raise ValueError.
            if key not in self._cache_dict:
                return None
            value = self._cache_dict[key]
            # Move key to the back of the queue (most recently used).
            self._cache_queue.remove(key)
            self._cache_queue.append(key)
        if self._allow_cloudpickle and self.shared:
            return cloudpickle.loads(value)
        return value

    def put(self, key: Hashable, value: Any) -> None:
        """Insert a key value pair into the cache."""
        if self._allow_cloudpickle and self.shared:
            value = cloudpickle.dumps(value)
        with self._cache_lock:
            if key in self._cache_dict:
                # Overwrite in place and refresh recency. Appending the key
                # again (as the previous implementation did) would leave a
                # duplicate in the queue and corrupt the eviction order.
                self._cache_dict[key] = value
                self._cache_queue.remove(key)
                self._cache_queue.append(key)
                return
            if len(self._cache_queue) >= self.max_size:
                # Evict the least recently used entry (front of the queue).
                key_to_evict = self._cache_queue.pop(0)
                self._cache_dict.pop(key_to_evict)
            self._cache_dict[key] = value
            self._cache_queue.append(key)

    def __contains__(self, key: Hashable) -> bool:
        """Check if a key is present in the cache."""
        return key in self._cache_dict

    @property
    def cache(self) -> dict:
        """Return the cache contents (a deserialized copy when shared)."""
        if not self.shared:
            assert isinstance(self._cache_dict, dict)
            return self._cache_dict
        with self._cache_lock:
            return {k: _maybe_load(v, self._allow_cloudpickle) for k, v in self._cache_dict.items()}

    def __len__(self) -> int:
        """Return the number of entries in the cache."""
        return len(self._cache_dict)

    def clear(self) -> None:
        """Clear the cache."""
        with self._cache_lock:
            # `.clear()` works on both regular dicts and Manager DictProxy
            # objects (consistent with HybridCache.clear).
            self._cache_dict.clear()
            del self._cache_queue[:]

    def __getstate__(self) -> dict[str, Any]:
        """Prepare the object for pickling."""
        state = self.__dict__.copy()
        if self.shared:
            # Manager proxies and locks cannot be pickled: convert the shared
            # structures to regular ones and drop the lock. For non-shared
            # caches the nullcontext lock is picklable, so keep it — otherwise
            # the unpickled instance would lack `_cache_lock` entirely.
            state["_cache_dict"] = _dict_to_regular(self._cache_dict)
            state["_cache_queue"] = _list_to_regular(self._cache_queue)
            state.pop("_cache_lock", None)
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Restore the object after unpickling."""
        self.__dict__.update(state)
        if not self.shared:
            return
        # Recreate the shared structures (and lock) in a fresh Manager process.
        manager = Manager()
        self._cache_dict = _create_shared_dict(manager, self._cache_dict)  # type: ignore[arg-type]
        self._cache_queue = _create_shared_list(manager, self._cache_queue)  # type: ignore[arg-type]
        self._cache_lock = manager.Lock()
class SimpleCache(_CacheBase):
    """A simple cache without any eviction strategy.

    Entries accumulate without bound; use `LRUCache` or `HybridCache` when a
    size limit is needed.
    """

    def __init__(self) -> None:
        """Initialize the cache."""
        self._cache_dict: dict[Hashable, Any] = {}

    def get(self, key: Hashable) -> Any:
        """Get a value from the cache by key (``None`` when absent)."""
        return self._cache_dict.get(key)

    def put(self, key: Hashable, value: Any) -> None:
        """Insert a key value pair into the cache."""
        self._cache_dict[key] = value

    def __contains__(self, key: Hashable) -> bool:
        """Check if a key is present in the cache."""
        return key in self._cache_dict

    @property
    def cache(self) -> dict:
        """Return the underlying cache dictionary (not a copy)."""
        return self._cache_dict

    def __len__(self) -> int:
        """Return the number of entries in the cache."""
        return len(self._cache_dict)

    def clear(self) -> None:
        """Clear the cache."""
        # dict.clear() replaces the previous one-key-at-a-time deletion loop.
        self._cache_dict.clear()
class DiskCache(_CacheBase):
    """Disk cache implementation using pickle or cloudpickle for serialization.

    Parameters
    ----------
    cache_dir
        The directory where the cache files are stored.
    max_size
        The maximum number of cache files to store. If None, no limit is set.
    use_cloudpickle
        Use cloudpickle for storing the data in memory.
    with_lru_cache
        Use an in-memory LRU cache to prevent reading from disk too often.
    lru_cache_size
        The maximum size of the in-memory LRU cache. Only used if with_lru_cache is True.
    lru_shared
        Whether the in-memory LRU cache should be shared between multiple processes.
    permissions
        The file permissions to set for the cache files.
        If None, the default permissions are used.
        Some examples:

        - 0o660 (read/write for owner and group, no access for others)
        - 0o644 (read/write for owner, read-only for group and others)
        - 0o777 (read/write/execute for everyone - generally not recommended)
        - 0o600 (read/write for owner, no access for group and others)
        - None (use the system's default umask)

    """

    def __init__(
        self,
        cache_dir: str | Path,
        max_size: int | None = None,
        *,
        use_cloudpickle: bool = True,
        with_lru_cache: bool = True,
        lru_cache_size: int = 128,
        lru_shared: bool = True,
        permissions: int | None = None,
    ) -> None:
        """Initialize the disk cache (creates ``cache_dir`` if needed)."""
        self.cache_dir = Path(cache_dir)
        self.max_size = max_size
        self.use_cloudpickle = use_cloudpickle
        self.with_lru_cache = with_lru_cache
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.permissions = permissions
        if self.with_lru_cache:
            self.lru_cache = LRUCache(
                max_size=lru_cache_size,
                allow_cloudpickle=use_cloudpickle,
                shared=lru_shared,
            )

    def _get_file_path(self, key: Hashable) -> Path:
        # Hash the pickled key so arbitrary keys map to valid filenames.
        key_hash = _pickle_key(key)
        return self.cache_dir / f"{key_hash}.pkl"

    def get(self, key: Hashable) -> Any:
        """Get a value from the cache by key (``None`` on a miss)."""
        if self.with_lru_cache and key in self.lru_cache:
            return self.lru_cache.get(key)
        file_path = self._get_file_path(key)
        if file_path.exists():
            with file_path.open("rb") as f:
                value = (
                    cloudpickle.load(f) if self.use_cloudpickle else pickle.load(f)  # noqa: S301
                )
            if self.with_lru_cache:
                # Warm the in-memory cache so the next read skips the disk.
                self.lru_cache.put(key, value)
            return value
        return None

    def put(self, key: Hashable, value: Any) -> None:
        """Insert a key value pair into the cache."""
        file_path = self._get_file_path(key)
        with file_path.open("wb") as f:
            if self.use_cloudpickle:
                cloudpickle.dump(value, f)
            else:
                pickle.dump(value, f)
        if self.permissions is not None:
            file_path.chmod(self.permissions)  # Set permissions after writing
        if self.with_lru_cache:
            self.lru_cache.put(key, value)
        self._evict_if_needed()

    def _all_files(self) -> list[Path]:
        return list(self.cache_dir.glob("*.pkl"))

    def _evict_if_needed(self) -> None:
        """Delete the oldest cache files until at most ``max_size`` remain."""
        if self.max_size is None:
            return
        files = self._all_files()
        n_to_remove = len(files) - self.max_size
        if n_to_remove <= 0:
            return
        # Sort once by creation time and delete the oldest files. The previous
        # implementation took `min` over an unmodified list on every loop
        # iteration, re-selecting the already-deleted file when more than one
        # eviction was needed (FileNotFoundError from `stat()`).
        files.sort(key=lambda f: f.stat().st_ctime_ns)
        for old_file in files[:n_to_remove]:
            try:
                old_file.unlink()
            except PermissionError:  # pragma: no cover
                warnings.warn(
                    f"Permission denied when trying to delete {old_file}.",
                    RuntimeWarning,
                    stacklevel=2,
                )

    def __contains__(self, key: Hashable) -> bool:
        """Check if a key is present in the cache."""
        if self.with_lru_cache and key in self.lru_cache:
            return True
        file_path = self._get_file_path(key)
        return file_path.exists()

    def __len__(self) -> int:
        """Return the number of cache files."""
        return len(self._all_files())

    def clear(self) -> None:
        """Clear the cache by deleting all cache files."""
        for file_path in self._all_files():
            with suppress(PermissionError, FileNotFoundError):
                file_path.unlink()
        if self.with_lru_cache:
            self.lru_cache.clear()

    @property
    def cache(self) -> dict:
        """Returns a copy of the cache, but only if with_lru_cache is True."""
        if not self.with_lru_cache:  # pragma: no cover
            msg = "LRU cache is not enabled."
            raise AttributeError(msg)
        return self.lru_cache.cache

    @property
    def shared(self) -> bool:
        """Return whether the cache is shared."""
        return self.lru_cache.shared if self.with_lru_cache else True
def _pickle_key(obj: Any) -> str:
# Based on the implementation of `diskcache` although that also
# does pickle_tools.optimize which we don't need here
data = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
return hashlib.md5(data).hexdigest() # noqa: S324
def _cloudpickle_key(obj: Any) -> str:
    """Return a hex digest for *obj* serialized with cloudpickle (non-cryptographic)."""
    serialized = cloudpickle.dumps(obj)
    return hashlib.md5(serialized).hexdigest()  # noqa: S324
def memoize(
    cache: HybridCache | LRUCache | SimpleCache | DiskCache | None = None,
    key_func: Callable[..., Hashable] | None = None,
    *,
    fallback_to_pickle: bool = True,
    unhashable_action: Literal["error", "warning", "ignore"] = "error",
) -> Callable:
    """A flexible memoization decorator that works with different cache types.

    Parameters
    ----------
    cache
        An instance of a cache class (_CacheBase). If None, a SimpleCache is used.
    key_func
        A function to generate cache keys. If None, the default key generation which
        attempts to make all arguments hashable.
    fallback_to_pickle
        If ``True``, unhashable objects will be pickled to bytes using `cloudpickle` as a last resort.
        If ``False``, an exception will be raised for unhashable objects.
        Only used if ``key_func`` is None.
    unhashable_action
        Determines the behavior when encountering unhashable objects:

        - "error": Raise an UnhashableError (default).
        - "warning": Log a warning and skip caching for that call.
        - "ignore": Silently skip caching for that call.

        Only used if ``key_func`` is None.

    Returns
    -------
    Decorated function with memoization.

    Raises
    ------
    UnhashableError
        If the object cannot be made hashable and ``fallback_to_pickle`` is ``False``.

    Notes
    -----
    This function creates a hashable representation of both positional and keyword
    arguments, allowing for effective caching of function calls with various
    argument types.

    """
    if cache is None:
        cache = SimpleCache()

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if key_func:
                key = key_func(*args, **kwargs)
            else:
                key = try_to_hashable(  # type: ignore[assignment]
                    (args, kwargs),
                    fallback_to_pickle,
                    unhashable_action,
                    func.__name__,
                )
            if key is UnhashableError:
                # The arguments could not be hashed; call through uncached.
                return func(*args, **kwargs)
            if key in cache:
                return cache.get(key)
            if isinstance(cache, HybridCache):
                # HybridCache scores entries by computation cost, so time the
                # call and store the *measured* duration with the result.
                t_start = time.monotonic()
                result = func(*args, **kwargs)
                cache.put(key, result, time.monotonic() - t_start)
            else:
                result = func(*args, **kwargs)
                cache.put(key, result)
            return result

        wrapper.cache = cache  # type: ignore[attr-defined]
        return wrapper

    return decorator
def try_to_hashable(
    obj: Any,
    fallback_to_pickle: bool = True,  # noqa: FBT002
    unhashable_action: Literal["error", "warning", "ignore"] = "error",
    where: str = "function",
) -> Hashable | type[UnhashableError]:
    """Attempt to build a hashable representation of *obj*.

    Thin wrapper around ``to_hashable`` that maps an `UnhashableError` to the
    behavior selected by ``unhashable_action``.

    Parameters
    ----------
    obj
        The object to convert.
    fallback_to_pickle
        If ``True``, unhashable objects will be pickled to bytes using
        `cloudpickle` as a last resort. If ``False``, an exception is raised
        for unhashable objects.
    unhashable_action
        Determines the behavior when encountering unhashable objects:

        - ``"error"``: Raise an `UnhashableError` (default).
        - ``"warning"``: Log a warning and skip caching for that call.
        - ``"ignore"``: Silently skip caching for that call. Returns `UnhashableError`.
    where
        The location where the unhashable object was encountered.
        Used for warning or error messages.

    Returns
    -------
    A hashable representation of the input object, or the `UnhashableError`
    class itself as a sentinel when caching should be skipped.

    Raises
    ------
    UnhashableError
        If the object cannot be made hashable and ``fallback_to_pickle`` is ``False``.

    """
    try:
        return to_hashable(obj, fallback_to_pickle=fallback_to_pickle)
    except UnhashableError:
        if unhashable_action == "error":
            raise
        if unhashable_action == "warning":
            msg = f"Unhashable arguments in '{where}'. Skipping cache."
            warnings.warn(msg, UserWarning, stacklevel=3)
        # Both "warning" and "ignore" fall through to the sentinel.
        return UnhashableError
def _hashable_iterable(
    iterable: Iterable,
    fallback_to_pickle: bool,
    *,
    sort: bool = False,
) -> tuple:
    """Return a tuple with every element converted via ``to_hashable``."""
    if sort:
        iterable = sorted(iterable)
    converted = [to_hashable(element, fallback_to_pickle) for element in iterable]
    return tuple(converted)
def _hashable_mapping(
    mapping: dict,
    fallback_to_pickle: bool,
    *,
    sort: bool = False,
) -> tuple:
    """Return ``(key, hashable_value)`` pairs of *mapping* as a tuple."""
    pairs = mapping.items()
    if sort:
        pairs = sorted(pairs)
    return tuple((key, to_hashable(value, fallback_to_pickle)) for key, value in pairs)
# Unique marker string embedded in every converted representation produced by
# `to_hashable`, so a converted object cannot collide with a genuinely
# hashable tuple that happens to have the same shape.
_HASH_MARKER = "__CONVERTED__"
def to_hashable(  # noqa: C901, PLR0911, PLR0912
    obj: Any,
    fallback_to_pickle: bool = True,  # noqa: FBT002
) -> Any:
    """Convert any object to a hashable representation if not hashable yet.

    Parameters
    ----------
    obj
        The object to convert.
    fallback_to_pickle
        If ``True``, unhashable objects will be pickled to bytes using `cloudpickle` as a last resort.
        If ``False``, an exception will be raised for unhashable objects.

    Returns
    -------
    A hashable representation of the input object.

    Raises
    ------
    UnhashableError
        If the object cannot be made hashable and fallback_to_pickle is False.

    Notes
    -----
    This function attempts to create a hashable representation of any input object.
    It handles most built-in Python types and some common third-party types like
    numpy arrays and pandas Series/DataFrames.
    """
    # Fast path: if the object already hashes successfully, return it unchanged.
    try:
        hash(obj)
    except Exception:  # noqa: BLE001, S110
        pass
    else:
        return obj
    # The converted representation embeds the object's type; if the type itself
    # is unhashable, fall back to its name instead.
    tp: type | str = type(obj)
    try:
        hash(tp)
    except Exception:  # noqa: BLE001
        tp = tp.__name__  # type: ignore[union-attr]
    m = _HASH_MARKER
    # NOTE: the isinstance checks below are ordered most-specific first
    # (OrderedDict/defaultdict/Counter before plain dict); reordering them
    # would change which branch handles a given object.
    if isinstance(obj, collections.OrderedDict):
        # Insertion order is significant for OrderedDict, so do not sort.
        return (m, tp, _hashable_mapping(obj, fallback_to_pickle))
    if isinstance(obj, collections.defaultdict):
        # Include default_factory so two defaultdicts with equal items but
        # different factories produce different representations.
        data = (
            to_hashable(obj.default_factory, fallback_to_pickle),
            _hashable_mapping(obj, fallback_to_pickle, sort=True),
        )
        return (m, tp, data)
    if isinstance(obj, collections.Counter):
        return (m, tp, tuple(sorted(obj.items())))
    if isinstance(obj, dict):
        # Sort items so insertion order does not affect the result.
        return (m, tp, _hashable_mapping(obj, fallback_to_pickle, sort=True))
    if isinstance(obj, set | frozenset):
        # Sets are unordered; sort for a deterministic representation.
        return (m, tp, _hashable_iterable(obj, fallback_to_pickle, sort=True))
    if isinstance(obj, list | tuple):
        return (m, tp, _hashable_iterable(obj, fallback_to_pickle))
    if isinstance(obj, collections.deque):
        # maxlen is part of a deque's behavior, so include it.
        return (m, tp, (obj.maxlen, _hashable_iterable(obj, fallback_to_pickle)))
    if isinstance(obj, bytearray):
        return (m, tp, tuple(obj))
    if isinstance(obj, array.array):
        return (m, tp, (obj.typecode, tuple(obj)))
    # Third-party types are detected via sys.modules so this module never has
    # to import numpy/pandas/polars itself.
    # Handle numpy arrays
    if "numpy" in sys.modules and isinstance(obj, sys.modules["numpy"].ndarray):
        return (m, tp, (obj.shape, obj.dtype.str, tuple(obj.flatten())))
    # Handle pandas Series and DataFrames
    if "pandas" in sys.modules:
        if isinstance(obj, sys.modules["pandas"].Series):
            return (m, tp, (obj.name, to_hashable(obj.to_dict(), fallback_to_pickle)))
        if isinstance(obj, sys.modules["pandas"].DataFrame):
            return (m, tp, to_hashable(obj.to_dict("list"), fallback_to_pickle))
    # Handle polars Series and DataFrames
    if "polars" in sys.modules:
        polars = sys.modules["polars"]
        if isinstance(obj, polars.Series):
            # Include dtype to distinguish Series with different dtypes but same values
            hsh = (
                obj.name,
                str(obj.dtype),
                to_hashable(obj.to_list(), fallback_to_pickle),
            )
            return (m, tp, hsh)
        if isinstance(obj, polars.DataFrame):
            hsh = to_hashable(obj.to_dict(as_series=False), fallback_to_pickle)
            return (m, tp, hsh)
    if fallback_to_pickle:
        # Last resort: serialize with cloudpickle and hash the resulting bytes.
        try:
            return (m, tp, _cloudpickle_key(obj))
        except Exception as e:
            raise UnhashableError(obj) from e
    raise UnhashableError(obj)
class UnhashableError(TypeError):
    """Exception raised for objects that cannot be made hashable."""

    def __init__(self, obj: Any) -> None:
        """Record the offending object and build the error message."""
        msg = f"Object of type {type(obj)} cannot be hashed using `pipefunc.cache.to_hashable`."
        self.obj = obj
        self.message = msg
        super().__init__(msg)
# Helper functions for pickling
def _dict_to_regular(shared_dict: DictProxy) -> dict:
"""Convert a shared dictionary to a regular dictionary."""
return dict(shared_dict.items())
def _list_to_regular(shared_list: ListProxy) -> list:
"""Convert a shared list to a regular list."""
return list(shared_list)
def _create_shared_dict(manager: SyncManager, regular_dict: dict) -> DictProxy:
"""Create a shared dictionary and populate it."""
shared_dict = manager.dict()
shared_dict.update(regular_dict)
return shared_dict
def _create_shared_list(manager: SyncManager, regular_list: list) -> ListProxy:
"""Create a shared list and populate it."""
shared_list = manager.list()
shared_list.extend(regular_list)
return shared_list