Source code for pipefunc.exceptions

"""Define error-related classes for `pipefunc`."""

from __future__ import annotations

import datetime
import getpass
import os
import platform
import traceback
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal

import cloudpickle

from pipefunc._utils import get_local_ip

Reason = Literal["input_is_error", "array_contains_errors"]

if TYPE_CHECKING:
    from collections.abc import Callable
    from pathlib import Path

    from pipefunc._error_handling import ErrorInfo


class UnusedParametersError(ValueError):
    """Exception raised when unused parameters are provided to a function."""


def _timestamp() -> str:
    return datetime.datetime.now(tz=datetime.timezone.utc).isoformat()


[docs] @dataclass class ErrorSnapshot: """A snapshot that represents an error in a function call.""" function: Callable[..., Any] exception: Exception args: tuple[Any, ...] kwargs: dict[str, Any] traceback: str = field(init=False) timestamp: str = field(default_factory=_timestamp) user: str = field(default_factory=getpass.getuser) machine: str = field(default_factory=platform.node) ip_address: str = field(default_factory=get_local_ip) current_directory: str = field(default_factory=os.getcwd) def __post_init__(self) -> None: """Initialize the error snapshot with a formatted traceback.""" tb = traceback.format_exception( type(self.exception), self.exception, self.exception.__traceback__, ) self.traceback = "".join(tb) def __repr__(self) -> str: """Return a concise representation for use in arrays and containers.""" func_name = getattr(self.function, "__name__", "?") exc_type = type(self.exception).__name__ return f"ErrorSnapshot({func_name!r}, {exc_type}: {self.exception})" def __str__(self) -> str: """Return a detailed string representation of the error snapshot.""" args_repr = ", ".join(repr(a) for a in self.args) kwargs_repr = ", ".join(f"{k}={v!r}" for k, v in self.kwargs.items()) func_name = f"{self.function.__module__}.{self.function.__qualname__}" return ( "ErrorSnapshot:\n" "--------------\n" f"- 🛠 Function: {func_name}\n" f"- 🚨 Exception type: {type(self.exception).__name__}\n" f"- 💥 Exception message: {self.exception}\n" f"- 📋 Args: ({args_repr})\n" f"- 🗂 Kwargs: {{{kwargs_repr}}}\n" f"- 🕒 Timestamp: {self.timestamp}\n" f"- 👤 User: {self.user}\n" f"- 💻 Machine: {self.machine}\n" f"- 📡 IP Address: {self.ip_address}\n" f"- 📂 Current Directory: {self.current_directory}\n" "\n" "🔁 Reproduce the error by calling `error_snapshot.reproduce()`.\n" "📄 Or see the full stored traceback using `error_snapshot.traceback`.\n" "🔍 Inspect `error_snapshot.args` and `error_snapshot.kwargs`.\n" "💾 Or save the error to a file using `error_snapshot.save_to_file(filename)`" " and load it using `ErrorSnapshot.load_from_file(filename)`." )
[docs] def reproduce(self) -> Any | None: """Attempt to recreate the error by calling the function with stored arguments.""" return self.function(*self.args, **self.kwargs)
[docs] def save_to_file(self, filename: str | Path) -> None: """Save the error snapshot to a file using cloudpickle.""" with open(filename, "wb") as f: # noqa: PTH123 cloudpickle.dump(self, f)
[docs] @classmethod def load_from_file(cls, filename: str | Path) -> ErrorSnapshot: """Load an error snapshot from a file using cloudpickle.""" with open(filename, "rb") as f: # noqa: PTH123 return cloudpickle.load(f)
def _ipython_display_(self) -> None: # pragma: no cover from IPython.display import HTML, display display(HTML(f"<pre>{self}</pre>")) def __getstate__(self) -> dict[str, Any]: """Custom pickling to handle function references using cloudpickle.""" from pipefunc._error_handling import cloudpickle_function_state return cloudpickle_function_state(self.__dict__.copy(), "function") def __setstate__(self, state: dict[str, Any]) -> None: """Custom unpickling to restore function references.""" from pipefunc._error_handling import cloudunpickle_function_state self.__dict__.update(cloudunpickle_function_state(state, "function"))
@dataclass class PropagatedErrorSnapshot: """Represents a function that was skipped due to upstream errors.""" error_info: dict[str, ErrorInfo] # parameter -> error details skipped_function: Callable[..., Any] reason: Reason # normalized reason label attempted_kwargs: dict[str, Any] # kwargs that were not errors timestamp: str = field(default_factory=_timestamp) def __repr__(self) -> str: """Return a concise representation for use in arrays and containers.""" func_name = getattr(self.skipped_function, "__name__", str(self.skipped_function)) return f"PropagatedErrorSnapshot({func_name!r}, reason={self.reason!r})" def __str__(self) -> str: """Return a detailed string representation of the propagated error snapshot.""" func_name = getattr(self.skipped_function, "__name__", str(self.skipped_function)) error_summary = [] for param, info in self.error_info.items(): if info.type == "full": error_summary.append(f"{param} (complete failure)") else: error_summary.append(f"{param} ({info.error_count} errors in array)") return ( f"PropagatedErrorSnapshot: Function '{func_name}' was skipped\n" f"Reason: {self.reason}\n" f"Errors in: {', '.join(error_summary)}" ) def __getstate__(self) -> dict[str, Any]: """Custom pickling to handle function references using cloudpickle.""" from pipefunc._error_handling import cloudpickle_function_state state = cloudpickle_function_state(self.__dict__.copy(), "skipped_function") # Also handle nested ErrorSnapshots in error_info state["error_info"] = self._pickle_error_info(self.error_info) return state def __setstate__(self, state: dict[str, Any]) -> None: """Custom unpickling to restore function references.""" from pipefunc._error_handling import cloudunpickle_function_state state = cloudunpickle_function_state(state, "skipped_function") # Restore error_info state["error_info"] = self._unpickle_error_info(state["error_info"]) self.__dict__.update(state) def _pickle_error_info( self, error_info: dict[str, ErrorInfo], ) -> dict[str, dict[str, Any]]: """Helper to pickle error_info dict that may contain ErrorSnapshots.""" pickled_info = {} for param, info in error_info.items(): # Convert ErrorInfo to dict for pickling info_dict = { "type": info.type, "shape": info.shape, "error_indices": info.error_indices, "error_count": info.error_count, } if info.type == "full" and info.error is not None: # The error might be an ErrorSnapshot or PropagatedErrorSnapshot # Let their own __getstate__ handle it info_dict["error"] = cloudpickle.dumps(info.error) pickled_info[param] = info_dict return pickled_info def _unpickle_error_info( self, pickled_info: dict[str, dict[str, Any]], ) -> dict[str, ErrorInfo]: """Helper to unpickle error_info dict.""" from pipefunc._error_handling import ErrorInfo error_info = {} for param, info_dict in pickled_info.items(): if info_dict["type"] == "full" and "error" in info_dict: serialized_error = info_dict["error"] if isinstance(serialized_error, bytes): error = cloudpickle.loads(serialized_error) else: # NOTE: v0.88 stores serialized bytes, while # older snapshots stored the ErrorSnapshot directly. error = serialized_error error_info[param] = ErrorInfo.from_full_error(error) else: error_info[param] = ErrorInfo( type=info_dict["type"], shape=info_dict.get("shape"), error_indices=info_dict.get("error_indices"), error_count=info_dict.get("error_count"), ) return error_info def get_root_causes(self) -> list[ErrorSnapshot]: """Extract all original ErrorSnapshot objects (for full-error inputs). For array-containing errors (reductions), this currently returns an empty list. Downstream code can still rely on `reason` and `error_info` metadata to understand which parameters contained errors. """ root_causes: list[ErrorSnapshot] = [] for info in self.error_info.values(): if info.type == "full" and info.error is not None: if isinstance(info.error, PropagatedErrorSnapshot): root_causes.extend(info.error.get_root_causes()) elif isinstance(info.error, ErrorSnapshot): root_causes.append(info.error) return root_causes