import inspect
from inspect import Parameter
from types import FunctionType
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar, Union, get_type_hints

from .typing_utils import get_args, issubtype

_WrappedMethod = TypeVar("_WrappedMethod", bound=Union[FunctionType, Callable])
_WrappedMethod2 = TypeVar("_WrappedMethod2", bound=Union[FunctionType, Callable])


def _contains_unbound_typevar(t: Type) -> bool:
    """Recursively check if `t` or any types contained by `t` is a `TypeVar`.

    Examples where we return `True`: `T`, `Optional[T]`, `Tuple[Optional[T], ...]`, ...
    Examples where we return `False`: `int`, `Optional[str]`, ...

    :param t: Type to evaluate.
    :return: `True` if the input type contains an unbound `TypeVar`, `False` otherwise.
    """
    # Check self
    if isinstance(t, TypeVar):
        return True

    # Check children (short-circuits on the first TypeVar found)
    return any(_contains_unbound_typevar(arg) for arg in get_args(t))


def _issubtype(left, right):
    """Lenient subtype check used when comparing two signatures.

    Answers "compatible" (`True`) whenever the question cannot be decided:
    either side contains an unbound `TypeVar`, `right` is `None` (nothing to
    compare against), or `issubtype` raises `TypeError`.

    :param left: Candidate subtype.
    :param right: Candidate supertype, or `None` when there is no annotation.
    :return: Result of `issubtype(left, right)`, or `True` in the undecidable
        cases listed above.
    """
    if _contains_unbound_typevar(left):
        return True
    if right is None:
        return True
    if _contains_unbound_typevar(right):
        return True
    try:
        return issubtype(left, right)
    except TypeError:
        # Ignore all broken cases
        return True


def _get_type_hints(callable) -> Optional[Dict]:
    """Return `typing.get_type_hints(callable)`, or `None` if it cannot be resolved.

    `NameError` (e.g. an unresolvable forward reference) and `TypeError` are
    swallowed so that the caller can skip the type-based checks instead of
    failing.
    """
    try:
        return get_type_hints(callable)
    except (NameError, TypeError):
        return None


def _is_same_module(callable1: _WrappedMethod, callable2: _WrappedMethod2) -> bool:
    """Return `True` if both callables belong to the same top-level package.

    Only the first dotted component of ``__module__`` is compared, so
    ``pkg.a`` and ``pkg.b`` count as the same module.

    A callable whose ``__module__`` is missing or is not a string is treated
    as not sharing a module. This includes ``None``, as for functions built
    from globals that lack ``__name__``.
    """
    mod1 = getattr(callable1, "__module__", None)
    mod2 = getattr(callable2, "__module__", None)
    if not isinstance(mod1, str) or not isinstance(mod2, str):
        return False
    return mod1.split(".")[0] == mod2.split(".")[0]


def ensure_signature_is_compatible(
    super_callable: _WrappedMethod,
    sub_callable: _WrappedMethod2,
    is_static: bool = False,
) -> None:
    """Ensure that the signature of `sub_callable` is compatible with the signature of `super_callable`.

    Guarantees that any call to `super_callable` will work on `sub_callable`
    by checking the following criteria:

    1. The return type of `sub_callable` is a subtype of the return type of
       `super_callable`.
    2. All parameters of `super_callable` are present in `sub_callable`,
       unless `sub_callable` declares `*args` or `**kwargs`.
    3. All positional parameters of `super_callable` appear in the same order
       in `sub_callable`.
    4. All parameters of `super_callable` are a subtype of the corresponding
       parameters of `sub_callable`.
    5. All required parameters of `sub_callable` are present in
       `super_callable`, unless `super_callable` declares `*args` or
       `**kwargs`.

    Checks 1-4 run only when the type hints of both callables can be
    resolved. Check 5 always runs. If `inspect.signature` raises
    `ValueError` for either callable, no check is performed.

    :param super_callable: Function to check compatibility with.
    :param sub_callable: Function to check compatibility of.
    :param is_static: True if staticmethod and should check first argument.
    :raises TypeError: If any of the criteria above is violated.
    """
    super_callable = _unbound_func(super_callable)
    sub_callable = _unbound_func(sub_callable)

    try:
        super_sig = inspect.signature(super_callable)
    except ValueError:
        return
    super_type_hints = _get_type_hints(super_callable)

    # Same treatment as the super side: without a signature there is
    # nothing to compare.
    try:
        sub_sig = inspect.signature(sub_callable)
    except ValueError:
        return
    sub_type_hints = _get_type_hints(sub_callable)

    method_name = sub_callable.__qualname__
    same_main_module = _is_same_module(sub_callable, super_callable)

    if super_type_hints is not None and sub_type_hints is not None:
        ensure_return_type_compatibility(super_type_hints, sub_type_hints, method_name)
        ensure_all_kwargs_defined_in_sub(
            super_sig, sub_sig, super_type_hints, sub_type_hints, is_static, method_name
        )
        ensure_all_positional_args_defined_in_sub(
            super_sig,
            sub_sig,
            super_type_hints,
            sub_type_hints,
            is_static,
            same_main_module,
            method_name,
        )
    ensure_no_extra_args_in_sub(super_sig, sub_sig, is_static, method_name)


def _unbound_func(callable: _WrappedMethod) -> _WrappedMethod:
    """Return the underlying function of a bound method, else `callable` unchanged.

    An object counts as bound when it has both ``__self__`` and ``__func__``.
    """
    if hasattr(callable, "__self__") and hasattr(callable, "__func__"):
        return callable.__func__  # type: ignore
    return callable


def ensure_all_kwargs_defined_in_sub(
    super_sig: inspect.Signature,
    sub_sig: inspect.Signature,
    super_type_hints: Dict,
    sub_type_hints: Dict,
    check_first_parameter: bool,
    method_name: str,
):
    """Check that every super parameter usable by keyword is accepted by the sub.

    `*args` and positional-only parameters of the super signature are skipped
    here; they are handled by `ensure_all_positional_args_defined_in_sub`.
    When a parameter exists by name in both signatures, the following are
    also checked:

    - Its kind matches. Keyword-only in the super may become
      positional-or-keyword in the sub.
    - It does not move to a lower index in the sub, unless it is keyword-only.
    - Its annotation in the super is a subtype of its annotation in the sub.

    :param check_first_parameter: When `False`, the first parameter (index 0)
        is skipped.
    :raises TypeError: On the first incompatibility found.
    """
    sub_has_var_kwargs = any(
        p.kind == Parameter.VAR_KEYWORD for p in sub_sig.parameters.values()
    )
    # Position of each sub parameter, computed once instead of per lookup.
    sub_indices = {param_name: i for i, param_name in enumerate(sub_sig.parameters)}

    for super_index, (name, super_param) in enumerate(super_sig.parameters.items()):
        if super_index == 0 and not check_first_parameter:
            continue
        if super_param.kind in (Parameter.VAR_POSITIONAL, Parameter.POSITIONAL_ONLY):
            continue
        # NOTE: `True` is passed for `sub_has_var_args`, so a
        # POSITIONAL_OR_KEYWORD super parameter counts as present whenever the
        # sub accepts **kwargs. Positional use is verified separately by
        # ensure_all_positional_args_defined_in_sub.
        if not is_param_defined_in_sub(
            name, True, sub_has_var_kwargs, sub_sig, super_param
        ):
            raise TypeError(f"{method_name}: `{name}` is not present.")
        elif name in sub_indices and super_param.kind != Parameter.VAR_KEYWORD:
            sub_index = sub_indices[name]
            sub_param = sub_sig.parameters[name]

            if super_param.kind != sub_param.kind and not (
                super_param.kind == Parameter.KEYWORD_ONLY
                and sub_param.kind == Parameter.POSITIONAL_OR_KEYWORD
            ):
                raise TypeError(f"{method_name}: `{name}` is not `{super_param.kind}`")
            elif super_index > sub_index and super_param.kind != Parameter.KEYWORD_ONLY:
                raise TypeError(
                    f"{method_name}: `{name}` is not parameter at index `{super_index}`"
                )
            elif (
                name in super_type_hints
                and name in sub_type_hints
                and not _issubtype(super_type_hints[name], sub_type_hints[name])
            ):
                raise TypeError(
                    f"`{method_name}: {name} must be a supertype of `{super_param.annotation}` but is `{sub_param.annotation}`"
                )


def ensure_all_positional_args_defined_in_sub(
    super_sig: inspect.Signature,
    sub_sig: inspect.Signature,
    super_type_hints: Dict,
    sub_type_hints: Dict,
    check_first_parameter: bool,
    is_same_main_module: bool,
    method_name: str,
):
    """Check that the sub accepts every positional call the super accepts.

    Only non-keyword-only, non-`**kwargs` parameters take part. They are
    compared pairwise by position for kind compatibility and, when the super
    parameter is annotated or both callables share a top-level module, for
    type compatibility (super annotation must be a subtype of the sub's).

    :param check_first_parameter: When `False`, the first parameter (index 0)
        is skipped.
    :param is_same_main_module: Forces the type comparison even when the super
        parameter is unannotated.
    :raises TypeError: On the first incompatibility found.
    """
    positional_kinds_excluded = (Parameter.KEYWORD_ONLY, Parameter.VAR_KEYWORD)
    sub_parameter_values = [
        v for v in sub_sig.parameters.values() if v.kind not in positional_kinds_excluded
    ]
    super_parameter_values = [
        v
        for v in super_sig.parameters.values()
        if v.kind not in positional_kinds_excluded
    ]
    sub_has_var_args = any(
        p.kind == Parameter.VAR_POSITIONAL for p in sub_parameter_values
    )
    super_has_var_args = any(
        p.kind == Parameter.VAR_POSITIONAL for p in super_parameter_values
    )
    if not sub_has_var_args and len(sub_parameter_values) < len(super_parameter_values):
        raise TypeError(f"{method_name}: parameter list too short")

    super_shift = 0
    for index, sub_param in enumerate(sub_parameter_values):
        if index == 0 and not check_first_parameter:
            continue

        if index + super_shift >= len(super_parameter_values):
            # Sub parameter has no positional counterpart in the super.
            if sub_param.kind == Parameter.VAR_POSITIONAL:
                continue
            if (
                sub_param.kind == Parameter.POSITIONAL_ONLY
                and sub_param.default != Parameter.empty
            ):
                continue
            if sub_param.kind == Parameter.POSITIONAL_OR_KEYWORD:
                continue  # Assume use as keyword
            raise TypeError(
                f"{method_name}: `{sub_param.name}` positionally required in subclass but not in supertype"
            )

        if sub_param.kind == Parameter.VAR_POSITIONAL:
            return

        super_param = super_parameter_values[index + super_shift]
        if super_param.kind == Parameter.VAR_POSITIONAL:
            # Decrementing the shift keeps `index + super_shift` pointing at the
            # super's *args, so later sub parameters are matched against it too.
            super_shift -= 1
            if not sub_has_var_args:
                raise TypeError(f"{method_name}: `{super_param.name}` must be present")
            continue

        if (
            super_param.kind != sub_param.kind
            and not (
                super_param.kind == Parameter.POSITIONAL_ONLY
                and sub_param.kind == Parameter.POSITIONAL_OR_KEYWORD
            )
            and not (sub_param.kind == Parameter.POSITIONAL_ONLY and super_has_var_args)
        ):
            raise TypeError(
                f"{method_name}: `{sub_param.name}` is not `{super_param.kind}` and is `{sub_param.kind}`"
            )
        elif (
            super_param.name in super_type_hints or is_same_main_module
        ) and not _issubtype(
            super_type_hints.get(super_param.name, None),
            sub_type_hints.get(sub_param.name, None),
        ):
            raise TypeError(
                f"`{method_name}: {sub_param.name} overriding must be a supertype of `{super_param.annotation}` but is `{sub_param.annotation}`"
            )


def is_param_defined_in_sub(
    name: str,
    sub_has_var_args: bool,
    sub_has_var_kwargs: bool,
    sub_sig: inspect.Signature,
    super_param: inspect.Parameter,
) -> bool:
    """Return whether `super_param` (named `name`) can be satisfied by the sub signature.

    A parameter is satisfied if the sub declares a parameter with the same
    name. Otherwise the sub's `*args`/`**kwargs` must cover it:

    - positional-only and `*args` parameters need `sub_has_var_args`;
    - keyword-only and `**kwargs` parameters need `sub_has_var_kwargs`;
    - positional-or-keyword parameters need both.
    """
    return (
        name in sub_sig.parameters
        or (super_param.kind == Parameter.VAR_POSITIONAL and sub_has_var_args)
        or (super_param.kind == Parameter.VAR_KEYWORD and sub_has_var_kwargs)
        or (super_param.kind == Parameter.POSITIONAL_ONLY and sub_has_var_args)
        or (
            super_param.kind == Parameter.POSITIONAL_OR_KEYWORD
            and sub_has_var_args
            and sub_has_var_kwargs
        )
        or (super_param.kind == Parameter.KEYWORD_ONLY and sub_has_var_kwargs)
    )


def ensure_no_extra_args_in_sub(
    super_sig: inspect.Signature,
    sub_sig: inspect.Signature,
    check_first_parameter: bool,
    method_name: str,
) -> None:
    """Reject required sub parameters that a call to the super could not supply.

    A sub parameter is rejected when all of the following hold:

    - Its name is absent from the super signature.
    - It has no default.
    - It is not `*args`/`**kwargs`.
    - It is not covered by the super's variadics. Keyword-only parameters are
      covered by the super's `**kwargs`; positional-only and
      positional-or-keyword parameters are covered by the super's `*args`.
    - It is not the first parameter while `check_first_parameter` is `False`.

    A positional-only sub parameter whose super counterpart at the same index
    is also positional-only is accepted regardless of name.

    :raises TypeError: On the first such parameter.
    """
    # Materialized once: indexed below for every sub parameter.
    super_params = list(super_sig.parameters.values())
    super_var_args = any(p.kind == Parameter.VAR_POSITIONAL for p in super_params)
    super_var_kwargs = any(p.kind == Parameter.VAR_KEYWORD for p in super_params)
    for sub_index, (name, sub_param) in enumerate(sub_sig.parameters.items()):
        if (
            sub_param.kind == Parameter.POSITIONAL_ONLY
            and len(super_params) > sub_index
            and super_params[sub_index].kind == Parameter.POSITIONAL_ONLY
        ):
            continue
        if (
            name not in super_sig.parameters
            and sub_param.default == Parameter.empty
            and sub_param.kind != Parameter.VAR_POSITIONAL
            and sub_param.kind != Parameter.VAR_KEYWORD
            and not (sub_param.kind == Parameter.KEYWORD_ONLY and super_var_kwargs)
            and not (sub_param.kind == Parameter.POSITIONAL_ONLY and super_var_args)
            and not (
                sub_param.kind == Parameter.POSITIONAL_OR_KEYWORD and super_var_args
            )
            and (sub_index > 0 or check_first_parameter)
        ):
            raise TypeError(f"{method_name}: `{name}` is not a valid parameter.")


def ensure_return_type_compatibility(
    super_type_hints: Dict, sub_type_hints: Dict, method_name: str
):
    """Require the sub's return annotation to be a subtype of the super's.

    Nothing is checked when the super has no return annotation.

    :raises TypeError: If the sub's return type is not a subtype of the super's.
    """
    super_return = super_type_hints.get("return", None)
    sub_return = sub_type_hints.get("return", None)
    # Cheap guard first: with no super annotation there is nothing to compare.
    if super_return is not None and not _issubtype(sub_return, super_return):
        raise TypeError(
            f"{method_name}: return type `{sub_return}` is not a `{super_return}`."
        )