"""Provides `pipefunc.helpers` module with various tools."""
from __future__ import annotations
import asyncio
import importlib.util
import inspect
import os
import warnings
from collections import Counter
from pathlib import Path
from typing import TYPE_CHECKING, Any
from pipefunc._utils import at_least_tuple, dump, is_running_in_ipynb, load, requires
from pipefunc.map._storage_array._file import FileArray
if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from ipywidgets import Widget
from pipefunc import PipeFunc
from pipefunc._widgets.output_tabs import OutputTabs
from pipefunc.map._result import ResultDict
from pipefunc.map._run import AsyncMap
__all__ = [
"FileArray", # To keep in the same namespace as FileValue
"FileValue",
"chain",
"collect_kwargs",
"gather_maps",
"get_attribute_factory",
"launch_maps",
]
class _ReturnsKwargs:
def __call__(self, **kwargs: Any) -> dict[str, Any]:
"""Returns keyword arguments it receives as a dictionary."""
return kwargs
[docs]
def collect_kwargs(
parameters: tuple[str, ...],
*,
annotations: tuple[type, ...] | None = None,
function_name: str = "call",
) -> Callable[..., dict[str, Any]]:
"""Returns a callable with a signature as specified in ``parameters`` which returns a dict.
Parameters
----------
parameters
Tuple of names, these names will be used for the function parameters.
annotations
Optionally, provide type-annotations for the ``parameters``. Must be
the same length as ``parameters`` or ``None``.
function_name
The ``__name__`` that is assigned to the returned callable.
Returns
-------
Callable that returns the parameters in a dictionary.
Examples
--------
This creates ``def yolo(a: int, b: list[int]) -> dict[str, Any]``:
>>> f = collect_kwargs(("a", "b"), annotations=(int, list[int]), function_name="yolo")
>>> f(a=1, b=2)
{"a": 1, "b": 2}
"""
cls = _ReturnsKwargs()
sig = inspect.signature(cls.__call__)
if annotations is None:
annotations = (inspect.Parameter.empty,) * len(parameters)
elif len(parameters) != len(annotations):
msg = f"`parameters` and `annotations` should have equal length ({len(parameters)}!={len(annotations)})"
raise ValueError(msg)
new_params = [
inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=annotation)
for name, annotation in zip(parameters, annotations)
]
new_sig = sig.replace(parameters=new_params)
def _wrapped(*args: Any, **kwargs: Any) -> Any:
bound = new_sig.bind(*args, **kwargs)
bound.apply_defaults()
return cls(**bound.arguments)
_wrapped.__signature__ = new_sig # type: ignore[attr-defined]
_wrapped.__name__ = function_name
return _wrapped
[docs]
def get_attribute_factory(
attribute_name: str,
parameter_name: str,
parameter_annotation: type = inspect.Parameter.empty,
return_annotation: type = inspect.Parameter.empty,
function_name: str = "get_attribute",
) -> Callable[[Any], Any]:
"""Returns a callable that retrieves an attribute from its input parameter.
Parameters
----------
attribute_name
The name of the attribute to access.
parameter_name
The name of the input parameter.
parameter_annotation
Optional, type annotation for the input parameter.
return_annotation
Optional, type annotation for the return value.
function_name
The ``__name__`` that is assigned to the returned callable.
Returns
-------
Callable that returns an attribute of its input parameter.
Examples
--------
This creates ``def get_data(obj: MyClass) -> int``:
>>> class MyClass:
... def __init__(self, data: int) -> None:
... self.data = data
>>> f = get_attribute_factory("data", parameter_name="obj", parameter_annotation=MyClass, return_annotation=int, function_name="get_data")
>>> f(MyClass(data=123))
123
"""
param = inspect.Parameter(
parameter_name,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
annotation=parameter_annotation,
)
sig = inspect.Signature(parameters=[param], return_annotation=return_annotation)
def _wrapped(*args: Any, **kwargs: Any) -> Any:
bound = sig.bind(*args, **kwargs)
bound.apply_defaults()
obj = bound.arguments[parameter_name]
return getattr(obj, attribute_name)
_wrapped.__signature__ = sig # type: ignore[attr-defined]
_wrapped.__name__ = function_name
return _wrapped
[docs]
class FileValue:
"""A reference to a value stored in a file.
This class provides a way to store and load values from files, which is useful
for passing large objects between processes without serializing them directly.
Parameters
----------
path
Path to the file containing the serialized value.
Examples
--------
>>> ref = FileValue.from_data([1, 2, 3], Path("data.pkl"))
>>> ref.load()
[1, 2, 3]
"""
def __init__(self, path: str | Path) -> None:
self.path = Path(path).absolute()
[docs]
def load(self) -> Any:
"""Load the stored data from disk."""
return load(self.path)
[docs]
@classmethod
def from_data(cls, data: Any, path: Path) -> FileValue:
"""Serializes data to the given file path and returns a FileValue to it.
This is useful for preparing a single large, non-iterable object
for use with `pipeline.map` in distributed environments.
The object is stored once on disk, and the lightweight FileValue
can be passed to tasks, which then load the data on demand.
Parameters
----------
data
The Python object to serialize and store.
path
The full file path (including filename) where the data will be stored.
This path must be accessible by all worker nodes if used in
a distributed setting.
Returns
-------
FileValue
A new FileValue instance pointing to the stored data.
"""
path.parent.mkdir(parents=True, exist_ok=True)
dump(data, path)
return cls(path=path)
def _setup_automatic_tab_updates(index_output: int, tabs: OutputTabs, async_map: AsyncMap) -> None:
def create_callback() -> Callable[[asyncio.Task[ResultDict]], None]:
def callback(task: asyncio.Task[ResultDict]) -> None:
if task.exception() is not None:
tabs.set_tab_status(index_output, "failed")
else:
tabs.set_tab_status(index_output, "completed")
return callback
# Set initial status to running and add callbacks
tabs.set_tab_status(index_output, "running")
async_map.task.add_done_callback(create_callback())
[docs]
async def gather_maps(
*async_maps: AsyncMap,
max_concurrent: int = 1,
max_completed_tabs: int | None = None,
_tabs: OutputTabs | None = None,
) -> list[ResultDict]:
"""Run AsyncMap objects with a limit on simultaneous executions.
Parameters
----------
async_maps
`AsyncMap` objects created with ``pipeline.map_async(..., start=False)``.
max_concurrent
Maximum number of concurrent jobs
max_completed_tabs
Maximum number of completed tabs to show. If ``None``, all completed tabs
are shown. Only used if ``display_widgets=True``.
Returns
-------
List of results from each AsyncMap's task
"""
_validate_async_maps(async_maps)
for async_map in async_maps:
if async_map._task is not None:
msg = "`pipeline.map_async(..., start=False)` must be called before `launch_maps`."
raise RuntimeError(msg)
if _tabs is None: # Prefer to get from the caller (in sync context), otherwise create it here
_tabs = _maybe_output_tabs(async_maps, max_completed_tabs)
else:
_tabs._max_completed_tabs = max_completed_tabs
semaphore = asyncio.Semaphore(max_concurrent)
async def run_with_semaphore(index: int, async_map: AsyncMap) -> ResultDict:
async with semaphore:
if _tabs is not None and async_map._display_widgets:
# Cannot use output_context here, because it is not thread-safe
# See https://github.com/jupyter-widgets/ipywidgets/issues/3993
from pipefunc._widgets.progress_ipywidgets import IPyWidgetsProgressTracker
# Disable `display` on the first call to `start`
async_map._display_widgets = False
async_map.start()
widgets = []
if async_map.status_widget is not None: # pragma: no cover
widgets.append(async_map.status_widget.widget)
if isinstance(async_map.progress, IPyWidgetsProgressTracker):
widgets.append(async_map.progress._style())
widgets.append(async_map.progress._widgets)
elif async_map.progress is not None: # pragma: no cover
msg = "Only `show_progress='ipywidgets'` is supported in this tab widget."
widgets.append(msg)
if async_map.multi_run_manager is not None: # pragma: no cover
widgets.append(async_map.multi_run_manager.info())
for widget in widgets:
_register_widget(widget)
_tabs.output(index).append_display_data(widget)
if widgets:
_tabs.show_output(index)
_setup_automatic_tab_updates(index, _tabs, async_map)
else:
async_map.start()
return await async_map.task
tasks = [run_with_semaphore(index, async_map) for index, async_map in enumerate(async_maps)]
return await asyncio.gather(*tasks)
def _maybe_output_tabs(
async_maps: Sequence[AsyncMap],
max_completed_tabs: int | None,
) -> OutputTabs | None:
display_widgets = any(async_map._display_widgets for async_map in async_maps)
has_ipywidgets = importlib.util.find_spec("ipywidgets") is not None
if has_ipywidgets and display_widgets and is_running_in_ipynb():
requires("ipywidgets", reason="tab_widget=True", extras="widgets")
from pipefunc._widgets.output_tabs import OutputTabs
if max_completed_tabs and os.environ.get("VSCODE_PID") is not None: # pragma: no cover
warnings.warn(
"`max_completed_tabs` is buggy in VS Code Jupyter notebook environment.",
stacklevel=2,
)
tabs = OutputTabs(len(async_maps), max_completed_tabs)
tabs.display()
return tabs
return None
def _register_widget(widget: Widget) -> None: # pragma: no cover
"""Register widget in VS Code to work around widget rendering bug.
This is a workaround for VS Code Jupyter notebook environment where
widgets created and immediately used in append_display_data() without
being displayed first don't get properly registered in the widget state.
See: https://github.com/microsoft/vscode-jupyter/issues/16739
"""
if os.environ.get("VSCODE_PID") is None:
return
from IPython.display import display
from ipywidgets import Output
with Output():
display(widget)
[docs]
def launch_maps(
*async_maps: AsyncMap,
max_concurrent: int = 1,
max_completed_tabs: int | None = None,
) -> asyncio.Task[list[ResultDict]]:
"""Launch a collection of map operations to run concurrently in the background.
This is a user-friendly, non-blocking wrapper around ``gather_maps``.
It immediately returns an ``asyncio.Task`` object, which can be awaited
later to retrieve the results. This is ideal for use in interactive
environments like Jupyter notebooks.
Parameters
----------
async_maps
`AsyncMap` objects created with ``pipeline.map_async(..., start=False)``.
max_concurrent
Maximum number of map operations to run at the same time.
max_completed_tabs
Maximum number of completed tabs to show. If ``None``, all completed tabs
are shown. Only used if ``display_widgets=True``.
Returns
-------
asyncio.Task
A task handle representing the background execution of the maps.
``await`` this task to get the list of results.
Examples
--------
>>> # In a Jupyter notebook cell:
>>> task = launch_maps(runners, max_concurrent=2)
>>> # In a later cell:
>>> results = await task
>>> print("Computation finished!")
"""
_validate_async_maps(async_maps)
tabs = _maybe_output_tabs(async_maps, max_completed_tabs)
coro = gather_maps(
*async_maps,
max_concurrent=max_concurrent,
max_completed_tabs=max_completed_tabs,
_tabs=tabs,
)
return asyncio.create_task(coro)
[docs]
def chain(
functions: Sequence[PipeFunc | Callable],
*,
copy: bool = True,
) -> list[PipeFunc]:
"""Return a new list of PipeFuncs connected linearly by applying minimal renames.
The i+1-th function's first parameter is renamed to the i-th function's output name,
creating a linear data flow. Other parameters (including additional inputs) are untouched.
Parameters
----------
functions
Sequence of PipeFuncs (or callables). Callables are wrapped as PipeFuncs with
``output_name=f.__name__``.
copy
If True (default), return copies of the input PipeFuncs; original instances are
not modified.
Returns
-------
list[PipeFunc]
New PipeFunc objects with renames applied so the data flows linearly.
Notes
-----
- If a downstream function already has an *unbound* parameter matching an upstream
output name, no rename is applied (prefer existing matches).
- When no explicit match exists, the first parameter is renamed to the upstream output.
The first parameter must not be bound; if it is, a ValueError is raised.
- If a function has zero parameters (and is not the first in the chain), a ValueError
is raised.
"""
from pipefunc import PipeFunc as _PipeFunc # local import to avoid cyclic in typing
if not functions:
msg = "chain requires at least one function"
raise ValueError(msg)
# Normalize to PipeFunc instances
pfs: list[_PipeFunc] = []
for f in functions:
pf = (
f
if isinstance(f, _PipeFunc)
else _PipeFunc(f, output_name=getattr(f, "__name__", "output"))
)
pfs.append(pf.copy() if copy else pf)
# Nothing to connect if only one
if len(pfs) == 1:
return pfs
# Apply renames to connect each pair
upstream = pfs[0]
for downstream in pfs[1:]:
upstream_outputs = at_least_tuple(upstream.output_name)
free_params = [p for p in downstream.parameters if p not in downstream.bound]
# Prefer existing matches among free parameters
if any(name in free_params for name in upstream_outputs):
upstream = downstream
continue
# No explicit match - validate and rename first parameter
if not downstream.parameters:
msg = f"Function {downstream} has no parameters to receive upstream value."
raise ValueError(msg)
if not free_params:
msg = f"All parameters of {downstream} are bound; cannot auto-select input parameter."
raise ValueError(msg)
# Require first parameter to be non-bound for auto-selection
first_param = downstream.parameters[0]
if first_param in downstream.bound:
upstream_out = upstream_outputs[0]
msg = (
f"chain: First parameter '{first_param}' of {downstream.output_name} is bound.\n"
f"Solution: Either reorder parameters to put the data-flow parameter first,\n"
f"or rename a parameter to '{upstream_out}' to create an explicit match."
)
raise ValueError(msg)
# Rename first parameter to upstream output
desired_name = upstream_outputs[0]
downstream.update_renames({first_param: desired_name}, update_from="current")
upstream = downstream
return pfs
def _validate_async_maps(async_maps: Sequence[AsyncMap]) -> None:
caller_name = inspect.stack()[1].function
_validate_async_maps_length(async_maps, caller_name)
_validate_unique_run_folders(async_maps, caller_name)
_validate_slurm_executor_names(async_maps, caller_name)
def _validate_async_maps_length(async_maps: Sequence[AsyncMap], caller_name: str) -> None:
if len(async_maps) == 0:
msg = f"`{caller_name}` requires at least one `AsyncMap` object."
raise ValueError(msg)
if len(async_maps) == 1 and isinstance(async_maps[0], tuple | list):
msg = (
f"It seems you passed a list or tuple of `AsyncMap` objects as a single argument to `{caller_name}`. "
"Instead, you should unpack the sequence into individual arguments. "
f"For example, use `{caller_name}(*my_async_maps)` instead of `{caller_name}(my_async_maps)`."
)
raise ValueError(msg)
def _validate_unique_run_folders(async_maps: Sequence[AsyncMap], caller_name: str) -> None:
run_folders = [
am.run_info.run_folder for am in async_maps if am.run_info.run_folder is not None
]
if len(run_folders) != len(set(run_folders)):
msg = (
f"All `run_folder`s must be unique among the provided `AsyncMap` objects in `{caller_name}` "
"unless they are None."
)
raise ValueError(msg)
def _validate_slurm_executor_names(async_maps: Sequence[AsyncMap], caller_name: str) -> None:
from pipefunc.map._adaptive_scheduler_slurm_executor import is_slurm_executor
cnt: Counter[str] = Counter()
for am in async_maps:
executors = am._prepared.executor
if executors is None:
continue
for v in executors.values():
if is_slurm_executor(v):
cnt[v.name] += 1
violations = [name for name, count in cnt.items() if count > 1]
if violations:
msg = (
f"All `map_async`s provided to `{caller_name}` that use a `SlurmExecutor`"
" must have instances with a unique `name`."
f" Currently, the following names are used multiple times: {violations}."
" Use `SlurmExecutor(name=...)` to set a unique name."
)
raise ValueError(msg)