Source code for pipefunc.lazy
"""Provides the `pipefunc.lazy` module, which contains functions for lazy evaluation."""
from __future__ import annotations
import contextlib
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, NamedTuple
import networkx as nx
from pipefunc._utils import format_function_call
from pipefunc.cache import SimpleCache
if TYPE_CHECKING:
from collections.abc import Callable, Generator
class _LazyFunction:
    """Lazy function wrapper for deferred evaluation of a function.

    Calling :meth:`evaluate` runs ``func(*args, **kwargs)`` at most once
    (successfully), after first evaluating any ``_LazyFunction`` objects found
    in ``args``/``kwargs`` with ``evaluate_lazy``; the result is cached.

    When an instance is created while a task graph is active (see
    ``construct_dag``), it registers itself as a node keyed by its integer id
    and adds an edge from every ``_LazyFunction`` it depends on.
    """

    __slots__ = ["_evaluated", "_id", "_result", "args", "func", "kwargs"]
    # Class-wide counter handing out a unique integer id to every instance;
    # the id is used as the node key in the task graph.
    _counter = 0

    def __init__(
        self,
        func: Callable[..., Any],
        args: tuple[Any, ...] = (),
        kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Wrap ``func`` together with its (possibly lazy) arguments.

        Parameters
        ----------
        func
            The callable invoked on evaluation.
        args
            Positional arguments. May contain ``_LazyFunction`` objects,
            directly or nested inside dict values, tuples, lists, or sets.
        kwargs
            Keyword arguments, with the same nesting rules as ``args``.
            ``None`` (or any falsy value) is stored as an empty dict.
        """
        self.func = func
        self.args = args
        self.kwargs = kwargs or {}
        self._result = None
        self._evaluated = False
        self._id = _LazyFunction._counter
        _LazyFunction._counter += 1

        dag = _TASK_GRAPH
        if dag is not None:
            dag.graph.add_node(self._id, lazy_func=self)
            dag.mapping[self._id] = self
            # Only walk the containers that `evaluate_lazy` descends into.
            # Iterating arbitrary iterables would exhaust generators, crawl
            # strings/large arrays element by element, and raise on objects
            # that are `Iterable` but not actually iterable (e.g. 0-d arrays).
            for dependency in _iter_lazy_dependencies((self.args, self.kwargs)):
                dag.graph.add_edge(dependency._id, self._id)

    def evaluate(self) -> Any:
        """Evaluate the lazy function and return the result.

        The first successful call evaluates the lazy arguments, calls
        ``func`` and caches the result; later calls return the cached value.
        If ``func`` raises, nothing is cached and the next call retries.
        """
        if self._evaluated:
            return self._result
        args = evaluate_lazy(self.args)
        kwargs = evaluate_lazy(self.kwargs)
        result = self.func(*args, **kwargs)
        # Mark as evaluated only after `func` returned successfully.
        self._result = result
        self._evaluated = True
        return result

    def __repr__(self) -> str:
        """Render the deferred call as ``func(arg, ..., key=value, ...)``."""
        # Imported here rather than at module level (presumably to avoid a
        # circular import between the lazy and pipefunc modules).
        from pipefunc._pipefunc import PipeFunc

        func = str(self.func.__name__) if isinstance(self.func, PipeFunc) else str(self.func)
        return format_function_call(func, self.args, self.kwargs)


def _iter_lazy_dependencies(obj: Any) -> Generator[_LazyFunction, None, None]:
    """Yield every ``_LazyFunction`` that ``evaluate_lazy`` would evaluate in ``obj``.

    Recurses through dict values, tuples, lists, and sets (the containers
    ``evaluate_lazy`` traverses) and nothing else, so other iterables such as
    strings, generators, or arrays are never iterated.
    """
    if isinstance(obj, _LazyFunction):
        yield obj
    elif isinstance(obj, dict):
        for value in obj.values():
            yield from _iter_lazy_dependencies(value)
    elif isinstance(obj, (tuple, list, set)):
        for item in obj:
            yield from _iter_lazy_dependencies(item)
[docs]
class TaskGraph(NamedTuple):
    """A named tuple representing a task graph.

    A fresh instance is created by ``construct_dag``; ``_LazyFunction``
    instances created while it is active add themselves to it.
    """

    # Dependency graph: nodes are `_LazyFunction` integer ids (each carrying a
    # `lazy_func` node attribute); an edge points from a dependency to the
    # function that consumes it.
    graph: nx.DiGraph
    # Node id -> the `_LazyFunction` registered under that id.
    mapping: dict[int, _LazyFunction]
    # A new `SimpleCache` per graph. It is not used within this module;
    # NOTE(review): presumably consumed by callers of the graph — confirm.
    cache: SimpleCache
[docs]
@contextlib.contextmanager
def construct_dag() -> Generator[TaskGraph, None, None]:
    """Create a directed acyclic graph (DAG) for a pipeline.

    While the context is active, every ``_LazyFunction`` that gets constructed
    is recorded in the yielded :class:`TaskGraph`.

    Yields
    ------
    TaskGraph
        A fresh graph: an empty ``nx.DiGraph``, an empty id-to-function
        mapping, and a new ``SimpleCache``.

    Notes
    -----
    Contexts may be nested. On exit the previously active graph (or ``None``)
    is restored instead of unconditionally clearing it, so leaving an inner
    context does not silently stop recording into an enclosing one.
    """
    global _TASK_GRAPH
    previous = _TASK_GRAPH
    graph = TaskGraph(nx.DiGraph(), {}, SimpleCache())
    _TASK_GRAPH = graph
    try:
        yield graph
    finally:
        _TASK_GRAPH = previous
# The task graph currently being recorded, or None when no `construct_dag()`
# context has set one. `_LazyFunction.__init__` reads it to register new nodes.
_TASK_GRAPH: TaskGraph | None = None
[docs]
def task_graph() -> TaskGraph | None:
    """Get the task graph that ``construct_dag()`` has made active.

    Returns
    -------
    TaskGraph | None
        The module-level task graph, or ``None`` if none is currently set.
    """
    current = _TASK_GRAPH
    return current
[docs]
def evaluate_lazy(x: Any) -> Any:
    """Evaluate a lazy object.

    ``_LazyFunction`` instances are evaluated via their ``evaluate`` method.
    Dicts (values only), tuples, lists, and sets are traversed recursively and
    rebuilt with their contents evaluated; any other object is returned as-is.

    Named tuples are rebuilt with their own type so field access keeps
    working; other container subclasses are rebuilt as the plain built-in
    type.

    Parameters
    ----------
    x
        The object to evaluate, possibly containing lazy functions.

    Returns
    -------
    Any
        ``x`` with every reachable lazy function replaced by its result.
    """
    if isinstance(x, _LazyFunction):
        return x.evaluate()
    if isinstance(x, dict):
        return {k: evaluate_lazy(v) for k, v in x.items()}
    if isinstance(x, tuple):
        items = [evaluate_lazy(v) for v in x]
        cls = type(x)
        # `tuple(...)` would turn a namedtuple into a plain tuple and drop its
        # field names; `_make` rebuilds it with the original type.
        if hasattr(cls, "_fields") and hasattr(cls, "_make"):
            return cls._make(items)
        return tuple(items)
    if isinstance(x, list):
        return [evaluate_lazy(v) for v in x]
    if isinstance(x, set):
        return {evaluate_lazy(v) for v in x}
    return x