Source code for pipefunc.typing

"""Custom type hinting utilities for pipefunc."""

import re
import sys
import warnings
from collections.abc import Callable, Iterable
from types import UnionType
from typing import (
    Annotated,
    Any,
    ForwardRef,
    Generic,
    NamedTuple,
    TypeVar,
    Union,
    get_args,
    get_origin,
    get_type_hints,
)

import numpy as np


[docs] class NoAnnotation: """Marker class for missing type annotations."""
T = TypeVar("T")
[docs] class ArrayElementType(Generic[T]): """Marker class for the element type of an annotated numpy array."""
[docs] class Array(Generic[T], np.ndarray[Any, np.dtype[np.object_]]): """Annotated numpy array type hint with element type.""" # NOTE: Ideally we would do something like this: # `Array = Annotated[np.ndarray[Any, np.dtype[object]], ArrayElementType[T]]` # however, Annotated doesn't support generics in metadata, see: # https://discuss.python.org/t/generics-in-metadata-of-annotated/62059 def __class_getitem__(cls, item: T) -> Any: """Return an annotated numpy array with the provided element type.""" return Annotated[ np.ndarray[Any, np.dtype[np.object_]], ArrayElementType[item], # type: ignore[valid-type] ]
[docs] class TypeCheckMemo(NamedTuple): """Named tuple to store memoization data for type checking.""" globals: dict[str, Any] | None locals: dict[str, Any] | None self_type: type | None = None
def _evaluate_forwardref(ref: ForwardRef, memo: TypeCheckMemo) -> Any: # pragma: no cover """Evaluate a forward reference using the provided memo.""" if sys.version_info < (3, 13): return ref._evaluate(memo.globals, memo.locals, recursive_guard=frozenset()) return ref._evaluate(memo.globals, memo.locals, recursive_guard=frozenset(), type_params={}) def _resolve_type(type_: Any, memo: TypeCheckMemo) -> Any: """Resolve forward references in a type hint.""" if isinstance(type_, str): return _evaluate_forwardref(ForwardRef(type_), memo) if isinstance(type_, ForwardRef): return _evaluate_forwardref(type_, memo) origin = get_origin(type_) if origin: args = get_args(type_) resolved_args = tuple(_resolve_type(arg, memo) for arg in args) if origin in {Union, UnionType}: # Handle both Union and new | syntax return Union[resolved_args] # noqa: UP007 return origin[resolved_args] # Ensure correct subscripting for generic types return type_ def _check_identical_or_any(incoming_type: type[Any], required_type: type[Any]) -> bool: """Check if types are identical or if required_type is Any.""" for t in (incoming_type, required_type): if isinstance(t, Unresolvable): warnings.warn( f"⚠️ Unresolvable type hint: `{t.type_str}`. Skipping type comparison.", stacklevel=3, ) return True return ( incoming_type == required_type or required_type is Any or incoming_type is NoAnnotation or required_type is NoAnnotation ) def _all_types_compatible( incoming_args: tuple[Any, ...], required_args: tuple[Any, ...], memo: TypeCheckMemo, ) -> bool: """Helper function to check if all incoming types are compatible with any required type.""" return all( any(is_type_compatible(t1, t2, memo) for t2 in required_args) for t1 in incoming_args ) def _handle_union_types( incoming_type: type[Any], required_type: type[Any], memo: TypeCheckMemo, ) -> bool | None: """Handle compatibility logic for Union types with directional consideration.""" if (isinstance(incoming_type, UnionType) or get_origin(incoming_type) is Union) and ( isinstance(required_type, UnionType) or get_origin(required_type) is Union ): incoming_type_args = get_args(incoming_type) required_type_args = get_args(required_type) return _all_types_compatible(incoming_type_args, required_type_args, memo) if isinstance(incoming_type, UnionType) or get_origin(incoming_type) is Union: return all(is_type_compatible(t, required_type, memo) for t in get_args(incoming_type)) if isinstance(required_type, UnionType) or get_origin(required_type) is Union: return any(is_type_compatible(incoming_type, t, memo) for t in get_args(required_type)) return None def _extract_array_element_type(metadata: Iterable[Any]) -> Any | None: """Extract the ArrayElementType from the metadata if it exists.""" return next((get_args(t)[0] for t in metadata if get_origin(t) is ArrayElementType), None) def _compare_annotated_types( incoming_type: type[Any], required_type: type[Any], memo: TypeCheckMemo, ) -> bool: """Compare Annotated types including metadata.""" incoming_primary, *incoming_metadata = get_args(incoming_type) required_primary, *required_metadata = get_args(required_type) # Recursively check the primary types if not is_type_compatible(incoming_primary, required_primary, memo): return False # Compare metadata (extras) incoming_array_element_type = _extract_array_element_type(incoming_metadata) required_array_element_type = _extract_array_element_type(required_metadata) if incoming_array_element_type is not None and required_array_element_type is not None: return is_type_compatible(incoming_array_element_type, required_array_element_type, memo) return True def _compare_single_annotated_type( annotated_type: type[Any], other_type: type[Any], memo: TypeCheckMemo, ) -> bool: """Handle cases where only one of the types is Annotated.""" primary_type, *_ = get_args(annotated_type) return is_type_compatible(primary_type, other_type, memo) def _compare_generic_type_origins(incoming_origin: type[Any], required_origin: type[Any]) -> bool: """Compare the origins of generic types for compatibility.""" if isinstance(incoming_origin, type) and isinstance(required_origin, type): return issubclass(incoming_origin, required_origin) return incoming_origin == required_origin def _compare_generic_type_args( incoming_args: tuple[Any, ...], required_args: tuple[Any, ...], memo: TypeCheckMemo, ) -> bool: """Compare the arguments of generic types for compatibility.""" if not required_args or not incoming_args: return True return all(is_type_compatible(t1, t2, memo) for t1, t2 in zip(incoming_args, required_args)) def _handle_generic_types( incoming_type: type[Any], required_type: type[Any], memo: TypeCheckMemo, ) -> bool | None: incoming_origin = get_origin(incoming_type) or incoming_type required_origin = get_origin(required_type) or required_type # Handle Annotated types if incoming_origin is Annotated and required_origin is Annotated: return _compare_annotated_types(incoming_type, required_type, memo) if incoming_origin is Annotated: return _compare_single_annotated_type(incoming_type, required_type, memo) if required_origin is Annotated: return _compare_single_annotated_type(required_type, incoming_type, memo) # Handle generic types if incoming_origin and required_origin: if not _compare_generic_type_origins(incoming_origin, required_origin): return False incoming_args = get_args(incoming_type) required_args = get_args(required_type) return _compare_generic_type_args(incoming_args, required_args, memo) return None
[docs] def is_type_compatible( incoming_type: Any, required_type: Any, memo: TypeCheckMemo | None = None, ) -> bool: """Check if the incoming type is compatible with the required type, resolving forward references.""" if memo is None: # for testing purposes memo = TypeCheckMemo(globals={}, locals={}) incoming_type = _resolve_type(incoming_type, memo) required_type = _resolve_type(required_type, memo) if isinstance(incoming_type, TypeVar): # TODO: the incoming type needs to be resolved to a concrete type # using the types of the arguments passed to the function. This might # require a more complex implementation. For now, we just return True. return True if _check_identical_or_any(incoming_type, required_type): return True if (result := _is_typevar_compatible(incoming_type, required_type, memo)) is not None: return result if (result := _handle_union_types(incoming_type, required_type, memo)) is not None: return result if (result := _handle_generic_types(incoming_type, required_type, memo)) is not None: return result return False
def _is_typevar_compatible( incoming_type: Any, required_type: Any, memo: TypeCheckMemo, ) -> bool | None: """Check if the required type is a TypeVar and is compatible with incoming type.""" if not isinstance(required_type, TypeVar): return None if not required_type.__constraints__ and not required_type.__bound__: return True if required_type.__constraints__ and any( is_type_compatible(incoming_type, c, memo) for c in required_type.__constraints__ ): return True return required_type.__bound__ and is_type_compatible( incoming_type, required_type.__bound__, memo, )
[docs] def is_object_array_type(tp: Any) -> bool: """Check if the given type is similar to `Array[T]`. Specifically, this function checks if the type is either: 1. `Annotated[numpy.ndarray[Any, numpy.dtype[numpy.object_]], T]` 2. `numpy.ndarray[Any, numpy.dtype[numpy.object_]]` """ if get_origin(tp) is np.ndarray: # Base case: directly an np.ndarray[Any, np.dtype[np.object_]] return get_args(tp) == (Any, np.dtype[np.object_]) if get_origin(tp) is Annotated: # Recursive case: strip the Annotated and check the first argument array_type, _ = get_args(tp) return is_object_array_type(array_type) return False
[docs] class Unresolvable: # noqa: PLW1641 """Class to represent an unresolvable type hint.""" def __init__(self, type_str: str) -> None: """Initialize the Unresolvable instance.""" self.type_str = type_str def __repr__(self) -> str: """Return a string representation of the Unresolvable instance.""" return f"Unresolvable[{self.type_str}]" def __eq__(self, other: object) -> bool: """Check equality between two Unresolvable instances.""" if isinstance(other, Unresolvable): return self.type_str == other.type_str return False
[docs] def safe_get_type_hints( func: Callable[..., Any], include_extras: bool = False, # noqa: FBT002 ) -> dict[str, Any]: """Safely get type hints for a function, resolving forward references.""" try: hints = get_type_hints(func, include_extras=include_extras) except Exception: # noqa: BLE001 hints = func.__annotations__ _globals = getattr(func, "__globals__", {}) memo = TypeCheckMemo(globals=_globals, locals=None) resolved_hints = {} for arg, hint in hints.items(): processed_hint = type(None) if hint is None else hint try: resolved = _resolve_type(processed_hint, memo) if resolved is None: resolved = type(None) resolved_hints[arg] = resolved except (NameError, Exception): resolved_hints[arg] = Unresolvable(str(processed_hint)) return resolved_hints
def _args_as_string(args: Iterable[Any]) -> str: return ", ".join(type_as_string(arg) for arg in args)
[docs] def type_as_string(type_: Any) -> str: # noqa: PLR0911 """Get a string representation of a type.""" if isinstance(type_, str): return _clean_type_string(type_) # Handle forward references if isinstance(type_, Unresolvable): return _clean_type_string(type_.type_str) if isinstance(type_, ForwardRef): return _clean_type_string(type_.__forward_arg__) if isinstance(type_, TypeVar): return _clean_type_string(type_.__name__) if isinstance(type_, list): # e.g., the arg list in `Callable[[here], Any]` return f"[{_args_as_string(type_)}]" origin = get_origin(type_) if origin is not None: args = get_args(type_) if is_object_array_type(type_): element_type = _extract_array_element_type(args[1:]) _Array = Array.__name__ # noqa: N806 return f"{_Array}[{type_as_string(element_type)}]" if element_type else _Array return f"{_clean_type_string(origin.__name__)}[{_args_as_string(args)}]" if hasattr(type_, "__name__"): return _clean_type_string(type_.__name__) # Fall back to string representation if all else fails return _clean_type_string(str(type_))
def _clean_type_string(type_str: str) -> str: # Remove 'typing.' prefix type_str = re.sub(r"\btyping\.", "", type_str) # Remove 'collections.abc.' prefix type_str = re.sub(r"\bcollections\.abc\.", "", type_str) # Replace 'UnionType' with 'Union' type_str = re.sub(r"\bUnionType\b", "Union", type_str) return type_str # noqa: RET504