from __future__ import annotations
import functools
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Literal
from pipefunc import PipeFunc, Pipeline
from pipefunc._utils import assert_complete_kwargs, is_installed, is_running_in_ipynb, requires
if TYPE_CHECKING:
from collections.abc import Callable
import ipywidgets
from pipefunc._pipefunc import PipeFunc
[docs]
class VariantPipeline:
"""A pipeline container that supports multiple implementations (variants) of functions.
`VariantPipeline` allows you to define multiple implementations of functions and
select which variant to use at runtime. This is particularly useful for:
- A/B testing different implementations
- Experimenting with algorithm variations
- Managing multiple processing options
- Creating configurable pipelines
The pipeline can have multiple variant groups, where each group contains alternative
implementations of a function. Functions can be assigned to variant groups using
the ``variant`` parameter which can be a single string (for the default group) or
a dictionary mapping group names to variant names.
All parameters below (except ``functions`` and ``default_variant``) are simply passed to
the `~pipefunc.Pipeline` constructor when creating a new pipeline with the selected
variant(s) using the `with_variant` method.
Parameters
----------
functions
List of `PipeFunc` instances.
default_variant
Default variant to use if none is specified in `with_variant`.
Either a single variant name or a dictionary mapping variant
groups to variants.
lazy
Flag indicating whether the pipeline should be lazy.
debug
Flag indicating whether debug information should be printed.
If ``None``, the value of each PipeFunc's debug attribute is used.
print_error
Flag indicating whether errors raised during the function execution should
be printed.
If ``None``, the value of each PipeFunc's print_error attribute is used.
profile
Flag indicating whether profiling information should be collected.
If ``None``, the value of each PipeFunc's profile attribute is used.
Profiling is only available for sequential execution.
cache_type
The type of cache to use. See the notes below for more *important* information.
cache_kwargs
Keyword arguments passed to the cache constructor.
validate_type_annotations
Flag indicating whether type validation should be performed. If ``True``,
the type annotations of the functions are validated during the pipeline
initialization. If ``False``, the type annotations are not validated.
scope
If provided, *all* parameter names and output names of the pipeline functions will
be prefixed with the specified scope followed by a dot (``'.'``), e.g., parameter
``x`` with scope ``foo`` becomes ``foo.x``. This allows multiple functions in a
pipeline to have parameters with the same name without conflict. To be selective
about which parameters and outputs to include in the scope, use the
`Pipeline.update_scope` method.
When providing parameter values for pipelines that have scopes, they can
be provided either as a dictionary for the scope, or by using the
``f'{scope}.{name}'`` notation. For example,
a `Pipeline` instance with scope "foo" and "bar", the parameters
can be provided as:
``pipeline(output_name, foo=dict(a=1, b=2), bar=dict(a=3, b=4))`` or
``pipeline(output_name, **{"foo.a": 1, "foo.b": 2, "bar.a": 3, "bar.b": 4})``.
default_resources
Default resources to use for the pipeline functions. If ``None``,
the resources are not set. Either a dict or a `pipefunc.resources.Resources`
instance can be provided. If provided, the resources in the `PipeFunc`
instances are updated with the default resources.
name
A name for the pipeline. If provided, it will be used to generate e.g., docs
and MCP server descriptions.
description
A description of the pipeline. If provided, it will be used to generate e.g., docs
and MCP server descriptions.
Examples
--------
Simple variant selection:
>>> @pipefunc(output_name="c", variant="add")
... def f(a, b):
... return a + b
...
>>> @pipefunc(output_name="c", variant="sub")
... def f_alt(a, b):
... return a - b
...
>>> @pipefunc(output_name="d")
... def g(b, c):
... return b * c
...
>>> pipeline = VariantPipeline([f, f_alt, g], default_variant="add")
>>> pipeline_add = pipeline.with_variant() # Uses default variant
>>> pipeline_sub = pipeline.with_variant(select="sub")
>>> pipeline_add(a=2, b=3) # (2 + 3) * 3 = 15
15
>>> pipeline_sub(a=2, b=3) # (2 - 3) * 3 = -3
-3
Multiple variant groups:
>>> @pipefunc(output_name="c", variant={"method": "add"})
... def f1(a, b):
... return a + b
...
>>> @pipefunc(output_name="c", variant={"method": "sub"})
... def f2(a, b):
... return a - b
...
>>> @pipefunc(output_name="d", variant={"analysis": "mul"})
... def g1(b, c):
... return b * c
...
>>> @pipefunc(output_name="d", variant={"analysis": "div"})
... def g2(b, c):
... return b / c
...
>>> pipeline = VariantPipeline(
... [f1, f2, g1, g2],
... default_variant={"method": "add", "analysis": "mul"}
... )
>>> # Select specific variants for each group
>>> pipeline_sub_div = pipeline.with_variant(
... select={"method": "sub", "analysis": "div"}
... )
Notes
-----
- Functions without variants can be included in the pipeline and will be used
regardless of variant selection.
- When using ``with_variant()``, if all variants are resolved, a regular `~pipefunc.Pipeline`
is returned. If some variants remain unselected, another `VariantPipeline` is
returned.
- The ``default_variant`` can be a single string (if there's only one variant group)
or a dictionary mapping variant groups to their default variants.
- Variants in the same group can have different output names, allowing for
flexible pipeline structures.
See Also
--------
pipefunc.Pipeline
The base pipeline class.
pipefunc.PipeFunc
Function wrapper that supports variants.
"""
def __init__(
self,
functions: list[PipeFunc],
*,
default_variant: str | dict[str | None, str] | None = None,
lazy: bool = False,
debug: bool | None = None,
print_error: bool | None = None,
profile: bool | None = None,
cache_type: Literal["lru", "hybrid", "disk", "simple"] | None = None,
cache_kwargs: dict[str, Any] | None = None,
validate_type_annotations: bool = True,
scope: str | None = None,
default_resources: dict[str, Any] | None = None,
name: str | None = None,
description: str | None = None,
) -> None:
"""Initialize a VariantPipeline."""
self.functions = functions
self.default_variant = default_variant
self.lazy = lazy
self.debug = debug
self.print_error = print_error
self.profile = profile
self.cache_type = cache_type
self.cache_kwargs = cache_kwargs
self.validate_type_annotations = validate_type_annotations
self.scope = scope
self.default_resources = default_resources
self.name = name
self.description = description
if not self.variants_mapping():
msg = "No variants found in the pipeline. Use a regular `Pipeline` instead."
raise ValueError(msg)
[docs]
def variants_mapping(self) -> dict[str | None, set[str]]:
"""Return a dictionary of variant groups and their variants."""
variant_groups: dict[str | None, set[str]] = {}
for function in self.functions:
for group, variant in function.variant.items():
variants = variant_groups.setdefault(group, set())
variants.add(variant)
return variant_groups
def _variants_mapping_inverse(self) -> dict[str, set[str | None]]:
"""Return a dictionary of variants and their variant groups."""
variants: dict[str, set[str | None]] = {}
for function in self.functions:
for group, variant in function.variant.items():
groups = variants.setdefault(variant, set())
groups.add(group)
return variants
[docs]
def with_variant(
self,
select: str | dict[str | None, str] | None = None,
**kwargs: Any,
) -> Pipeline | VariantPipeline:
"""Create a new Pipeline or VariantPipeline with the specified variant selected.
Parameters
----------
select
Name of the variant to select. If not provided, `default_variant` is used.
If `select` is a string, it selects a single variant if no ambiguity exists.
If `select` is a dictionary, it selects a variant for each variant group, where
the keys are variant group names and the values are variant names.
If a partial dictionary is provided (not covering all variant groups) and
default_variant is a dictionary, it will merge the defaults with the selection.
kwargs
Keyword arguments for changing the parameters for a Pipeline or VariantPipeline.
Returns
-------
A new Pipeline or VariantPipeline with the selected variant(s).
If variants remain, a VariantPipeline is returned.
If no variants remain, a Pipeline is returned.
Raises
------
ValueError
If the specified variant is ambiguous or unknown, or if an invalid variant type is provided.
TypeError
If `select` is not a string or a dictionary.
"""
if select is None:
if self.default_variant is None:
msg = "No variant selected and no default variant provided."
raise ValueError(msg)
select = self.default_variant
if isinstance(select, str):
select = self._resolve_single_variant(select)
elif not isinstance(select, dict):
msg = f"Invalid variant type: `{type(select)}`. Expected `str` or `dict`."
raise TypeError(msg)
if isinstance(self.default_variant, dict):
select = self.default_variant | select
assert isinstance(select, dict)
_validate_variants_exist(self.variants_mapping(), select)
new_functions = self._select_functions(select)
variants_remain = self._check_remaining_variants(new_functions)
if variants_remain:
return self.copy(functions=new_functions, **kwargs)
# No variants left, return a regular Pipeline
return Pipeline(
new_functions, # type: ignore[arg-type]
lazy=kwargs.get("lazy", self.lazy),
debug=kwargs.get("debug", self.debug),
print_error=kwargs.get("print_error", self.print_error),
profile=kwargs.get("profile", self.profile),
cache_type=kwargs.get("cache_type", self.cache_type),
cache_kwargs=kwargs.get("cache_kwargs", self.cache_kwargs),
validate_type_annotations=kwargs.get(
"validate_type_annotations",
self.validate_type_annotations,
),
scope=kwargs.get("scope", self.scope),
default_resources=kwargs.get("default_resources", self.default_resources),
)
def _resolve_single_variant(self, select: str) -> dict[str | None, str]:
"""Resolve a single variant string to a dictionary."""
inv = self._variants_mapping_inverse()
group = inv.get(select, set())
if len(group) > 1:
msg = f"Ambiguous variant: `{select}`, could be in either `{group}`"
raise ValueError(msg)
if not group:
msg = f"Unknown variant: `{select}`, choose one of: `{', '.join(inv)}`"
raise ValueError(msg)
return {group.pop(): select}
def _select_functions(self, select: dict[str | None, str]) -> list[PipeFunc]:
"""Select functions based on the given variant selection."""
new_functions: list[PipeFunc] = []
for function in self.functions:
# For functions with no variants, always include them
if not function.variant:
new_functions.append(function)
continue
# Check if function matches the selected variants
include = True
# Check variants dict
for group, variant in function.variant.items():
if group in select and select[group] != variant:
include = False
break
if include:
new_functions.append(function)
return new_functions
def _check_remaining_variants(self, functions: list[PipeFunc]) -> bool:
"""Check if any variants remain after selection."""
left_over = defaultdict(set)
for function in functions:
for group, variant in function.variant.items():
left_over[group].add(variant)
return any(len(variants) > 1 for variants in left_over.values())
[docs]
def copy(self, **kwargs: Any) -> VariantPipeline:
"""Return a copy of the VariantPipeline.
Parameters
----------
kwargs
Keyword arguments passed to the `VariantPipeline` constructor instead of the
original values.
"""
original_kwargs = {
"functions": self.functions,
"lazy": self.lazy,
"debug": self.debug,
"print_error": self.print_error,
"profile": self.profile,
"cache_type": self.cache_type,
"cache_kwargs": self.cache_kwargs,
"validate_type_annotations": self.validate_type_annotations,
"scope": self.scope,
"default_resources": self.default_resources,
"default_variant": self.default_variant,
"name": self.name,
"description": self.description,
}
assert_complete_kwargs(original_kwargs, VariantPipeline.__init__, skip={"self"})
original_kwargs.update(kwargs)
return VariantPipeline(**original_kwargs) # type: ignore[arg-type]
[docs]
@classmethod
def from_pipelines(
cls,
*variant_pipeline: tuple[str, str, Pipeline] | tuple[str, Pipeline],
) -> VariantPipeline:
"""Create a new `VariantPipeline` from multiple `Pipeline` instances.
This method constructs a `VariantPipeline` by combining functions from
multiple `Pipeline` instances, identifying common functions and assigning
variants based on the input tuples.
Each input tuple can either be a 2-tuple or a 3-tuple.
- A 2-tuple contains: ``(variant_name, pipeline)``.
- A 3-tuple contains: ``(variant_group, variant_name, pipeline)``.
Functions that are identical across all input pipelines (as determined by
the `is_identical_pipefunc` function) are considered "common" and are added to
the resulting `VariantPipeline` without any variant information.
Functions that are unique to a specific pipeline are added with their
corresponding variant information (if provided in the input tuple).
Parameters
----------
*variant_pipeline
Variable number of tuples, where each tuple represents a pipeline and its
associated variant information. Each tuple can be either:
- `(variant_name, pipeline)`: Specifies the variant name for all functions
in the pipeline. The variant group will be set to `None` (default group).
- `(variant_group, variant_name, pipeline)`: Specifies both the variant
group and variant name for all functions in the pipeline.
Returns
-------
A new `VariantPipeline` instance containing the combined functions from
the input pipelines, with appropriate variant assignments.
Examples
--------
>>> @pipefunc(output_name="x")
... def f(a, b):
... return a + b
...
>>> @pipefunc(output_name="y")
... def g(x, c):
... return x * c
...
>>> pipeline1 = Pipeline([f, g])
>>> pipeline2 = Pipeline([f, g.copy(func=lambda x, c: x / c)])
>>> variant_pipeline = VariantPipeline.from_pipelines(
... ("add_mul", pipeline1),
... ("add_div", pipeline2)
... )
>>> add_mul_pipeline = variant_pipeline.with_variant(select="add_mul")
>>> add_div_pipeline = variant_pipeline.with_variant(select="add_div")
>>> add_mul_pipeline(a=1, b=2, c=3) # (1 + 2) * 3 = 9
9
>>> add_div_pipeline(a=1, b=2, c=3) # (1 + 2) / 3 = 1.0
1.0
Notes
-----
- The `is_identical_pipefunc` function is used to determine if two `PipeFunc`
instances are identical.
- If multiple pipelines contain the same function but with different variant
information, the function will be included multiple times in the
resulting `VariantPipeline`, each with its respective variant assignment.
"""
if len(variant_pipeline) < 2: # noqa: PLR2004
msg = "At least 2 pipelines must be provided."
raise ValueError(msg)
all_funcs: list[list[PipeFunc]] = []
variant_info: list[tuple[str | None, str]] = []
for item in variant_pipeline:
if len(item) == 3: # noqa: PLR2004
variant_group, variant, pipeline = item
else:
variant, pipeline = item
variant_group = None
all_funcs.append(pipeline.functions)
variant_info.append((variant_group, variant))
# Find common functions using is_identical_pipefunc
common_funcs: list[PipeFunc] = []
for func in all_funcs[0]:
is_common = True
for other_funcs in all_funcs[1:]:
if not _pipefunc_in_list(func, other_funcs):
is_common = False
break
if is_common and not _pipefunc_in_list(func, common_funcs):
common_funcs.append(func)
functions: list[PipeFunc] = common_funcs[:]
# Add unique functions with variant information
for i, funcs in enumerate(all_funcs):
variant_group, variant = variant_info[i]
# Create the variants parameter based on variant_group
variants_param = {variant_group: variant} if variant_group is not None else variant
unique_funcs = [
func.copy(variant=variants_param)
for func in funcs
if not _pipefunc_in_list(func, common_funcs)
]
functions.extend(unique_funcs)
return cls(functions)
[docs]
def visualize(self, **kwargs: Any) -> Any:
"""Visualize the VariantPipeline with interactive variant selection.
Parameters
----------
kwargs
Additional keyword arguments passed to the `pipefunc.Pipeline.visualize` method.
Returns
-------
The output of the widget.
"""
requires("ipywidgets", reason="show_progress", extras="ipywidgets")
return _create_variant_selection_widget(
self,
_update_visualization, # type: ignore[arg-type]
**kwargs,
)
def _repr_mimebundle_(
self,
include: set[str] | None = None,
exclude: set[str] | None = None,
) -> dict[str, str]: # pragma: no cover
"""Display the VariantPipeline widget or a text representation.
Also displays a rich table of information if `rich` is installed.
"""
if is_running_in_ipynb() and is_installed("rich") and is_installed("ipywidgets"):
widget = _create_variant_selection_widget(
self,
_update_repr_mimebundle, # type: ignore[arg-type]
)
return widget._repr_mimebundle_(include=include, exclude=exclude)
# Return a plaintext representation of the object
return {"text/plain": repr(self)}
def __getattr__(self, name: str) -> None:
if name in Pipeline.__dict__:
unresolved = {
group: variants
for group, variants in self.variants_mapping().items()
if len(variants) > 1
}
if unresolved:
parts = []
if None in unresolved:
parts.append(f"variants {unresolved[None]}")
parts.extend(
f"variant group `{g} = {v}`" for g, v in unresolved.items() if g is not None
)
variants_info = f" The {' and '.join(parts)} are not yet resolved."
else:
variants_info = ""
msg = (
"This is a `VariantPipeline`, not a `Pipeline`."
f"{variants_info}"
" Use `VariantPipeline.with_variant(...)` to instanciate a Pipeline first."
f" Then access `Pipeline.{name}` again."
)
raise AttributeError(msg)
default_msg = f"'VariantPipeline' object has no attribute '{name}'"
raise AttributeError(default_msg)
def _validate_variants_exist(
variants_mapping: dict[str | None, set[str]],
selection: dict[str | None, str],
) -> None:
"""Validate that the specified variants exist."""
for group, variant_name in selection.items():
if group not in variants_mapping:
msg = f"Unknown variant group: `{group}`."
if variants_mapping:
groups = (str(k) for k in variants_mapping)
msg += f" Use one of: `{', '.join(groups)}`"
raise ValueError(msg)
if variant_name not in variants_mapping[group]:
msg = (
f"Unknown variant: `{variant_name}` in group `{group}`."
f" Use one of: `{', '.join(variants_mapping[group])}`"
)
raise ValueError(msg)
def _pipefunc_in_list(func: PipeFunc, funcs: list[PipeFunc]) -> bool:
"""Check if a PipeFunc instance is in a list of PipeFunc instances."""
return any(is_identical_pipefunc(func, f) for f in funcs)
def is_identical_pipefunc(first: PipeFunc, second: PipeFunc) -> bool:
"""Check if two PipeFunc instances are identical.
Note: This is not implemented as PipeFunc.__eq__ to avoid
hashing issues.
"""
cls = type(first)
for attr, value in first.__dict__.items():
if isinstance(getattr(cls, attr, None), functools.cached_property):
continue
if attr == "_pipelines":
continue
if value != second.__dict__[attr]:
return False
return True
def _create_variant_selection_widget(
vp: VariantPipeline,
update_func: Callable[[Pipeline, ipywidgets.Output, Any], None],
**kwargs: Any,
) -> ipywidgets.VBox:
"""Create a widget for interactive variant selection.
Parameters
----------
vp
The VariantPipeline.
update_func
The function to call when the selected variant changes.
kwargs
Additional keyword arguments passed to the `Pipeline.visualize` method
in `update_visualization`.
Returns
-------
A widget containing dropdown menus for variant selection.
"""
import ipywidgets
dropdowns: dict[str | None, ipywidgets.Dropdown] = {}
output = ipywidgets.Output()
default = _ensure_dict(vp.default_variant)
def wrapped_update_func(_change: dict | None = None) -> None:
"""Update the output with the selected variants."""
selected_variants = {group: dropdowns[group].value for group in vp.variants_mapping()}
pipeline = vp.with_variant(select=selected_variants)
assert isinstance(pipeline, Pipeline)
update_func(pipeline, output, **kwargs) # type: ignore[call-arg]
for group, variants in vp.variants_mapping().items():
options = list(variants)
dropdown = ipywidgets.Dropdown(
options=options,
value=default.get(group, options[0]),
description=f"{group}:",
disabled=False,
)
dropdown.observe(wrapped_update_func, names="value")
dropdowns[group] = dropdown
# Initial update
wrapped_update_func()
return ipywidgets.VBox([*dropdowns.values(), output])
def _ensure_dict(default_variant: str | dict[str | None, str] | None) -> dict[str | None, str]:
"""Ensure that the default_variant is a dictionary."""
if default_variant is None:
return {}
if isinstance(default_variant, str):
return {None: default_variant}
return default_variant
def _update_visualization(
pipeline: Pipeline,
output: ipywidgets.Output,
**kwargs: Any,
) -> None:
"""Update the visualization with the selected variants."""
from IPython.display import display
with output:
output.clear_output()
backend = kwargs.pop("backend", "graphviz")
viz = pipeline.visualize(backend=backend, **kwargs)
if viz is not None:
display(viz)
def _update_repr_mimebundle(
pipeline: Pipeline,
output: ipywidgets.Output,
**kwargs: Any,
) -> None: # pragma: no cover
"""Update the displayed output with the selected variant's mimebundle."""
from IPython.display import display
with output:
output.clear_output(wait=True)
display(pipeline, **kwargs)