# Source code for pipefunc.testing
"""Testing utilities for the pipefunc package."""
from __future__ import annotations
import contextlib
import unittest.mock
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Generator
from pipefunc import Pipeline
# [docs]  (Sphinx "view source" link marker; not part of the module)
@contextlib.contextmanager
def patch(pipeline: Pipeline, func_name: str) -> Generator[unittest.mock.MagicMock, None, None]:
    """Patch a function within a Pipeline for testing purposes.

    This context manager temporarily replaces the ``func`` attribute of one
    `~pipefunc.PipeFunc` instance in the pipeline with a mock. The original
    callable is restored when the ``with`` block exits.

    Parameters
    ----------
    pipeline
        The Pipeline instance to be patched. Its ``functions`` are searched in
        order and the first match is patched.
    func_name
        The name of the function to be patched. This can be either a simple
        function name or a fully qualified name including the module path.
        The simple name is always compared against the pipeline function's
        ``__name__``. If a dot is present in `func_name`, it is additionally
        compared against ``"<module>.<name>"`` of the wrapped callable.
        Wrapped callables that lack ``__module__`` or ``__name__`` (for
        example ``functools.partial`` objects) can only be matched by their
        simple name.

    Yields
    ------
    mock
        A MagicMock object that can be used to set return values or side effects.

    Raises
    ------
    ValueError
        If no function with the given name is found in the pipeline. Functions
        whose ``func`` is already a MagicMock are skipped during the search.

    Examples
    --------
    >>> @pipefunc(output_name="c")
    ... def f() -> Any:
    ...     raise ValueError("test")
    >>> pipeline = Pipeline([f])
    >>> with patch(pipeline, "f") as mock:
    ...     mock.return_value = 1
    ...     print(pipeline("c"))  # Prints 1

    """
    # Only a dotted name can be a ``module.function`` path, so the qualified
    # name of the wrapped callable is computed only in that case.
    match_qualified = "." in func_name
    target_func = None
    for f in pipeline.functions:
        wrapped = f.func
        if isinstance(wrapped, unittest.mock.MagicMock):
            # Already replaced by a mock; never select it as a patch target.
            continue
        if f.__name__ == func_name:
            target_func = f
            break
        if not match_qualified:
            continue
        try:
            full_name = f"{wrapped.__module__}.{wrapped.__name__}"
        except AttributeError:
            # No qualified name available (e.g. ``functools.partial`` has no
            # ``__name__``); this entry simply cannot match a dotted name.
            continue
        if full_name == func_name:
            target_func = f
            break
    if target_func is None:
        msg = f"No function named '{func_name}' found in the pipeline."
        raise ValueError(msg)
    with unittest.mock.patch.object(target_func, "func") as mock:
        yield mock