Shofel2_T124_python/venv/lib/python3.10/site-packages/overrides/signature.py

305 lines
11 KiB
Python
Raw Normal View History

2024-05-25 16:45:07 +00:00
import inspect
from inspect import Parameter
from types import FunctionType
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar, Union, get_type_hints
from .typing_utils import get_args, issubtype
# Type variables used to annotate the callables compared in this module.
# Two separate variables let the overridden and the overriding callable that
# appear in one signature (e.g. `ensure_signature_is_compatible`) be typed
# independently of each other.
_WrappedMethod = TypeVar("_WrappedMethod", bound=Union[FunctionType, Callable])
_WrappedMethod2 = TypeVar("_WrappedMethod2", bound=Union[FunctionType, Callable])
def _contains_unbound_typevar(t: Type) -> bool:
"""Recursively check if `t` or any types contained by `t` is a `TypeVar`.
Examples where we return `True`: `T`, `Optional[T]`, `Tuple[Optional[T], ...]`, ...
Examples where we return `False`: `int`, `Optional[str]`, ...
:param t: Type to evaluate.
:return: `True` if the input type contains an unbound `TypeVar`, `False` otherwise.
"""
# Check self
if isinstance(t, TypeVar):
return True
# Check children
for arg in get_args(t):
if _contains_unbound_typevar(arg):
return True
return False
def _issubtype(left, right):
    """Leniently decide whether `left` may be used where `right` is expected.

    Returns `True` without consulting `issubtype` when `left` contains an
    unbound `TypeVar`, when `right` is `None` (callers pass `None` for a
    missing annotation), or when `right` contains an unbound `TypeVar`.
    Otherwise defers to `issubtype`, treating a `TypeError` from it as
    compatible.
    """
    # Evaluated left-to-right, matching the order of the checks.
    skip_check = (
        _contains_unbound_typevar(left)
        or right is None
        or _contains_unbound_typevar(right)
    )
    if skip_check:
        return True
    try:
        return issubtype(left, right)
    except TypeError:
        # Ignore all broken cases: constructs issubtype cannot compare pass.
        return True
def _get_type_hints(callable) -> Optional[Dict]:
try:
return get_type_hints(callable)
except (NameError, TypeError):
return None
def _is_same_module(callable1: _WrappedMethod, callable2: _WrappedMethod2) -> bool:
mod1 = callable1.__module__.split(".")[0]
try:
mod2 = callable2.__module__
except AttributeError:
return False
mod2 = mod2.split(".")[0]
return mod1 == mod2
def ensure_signature_is_compatible(
    super_callable: _WrappedMethod,
    sub_callable: _WrappedMethod2,
    is_static: bool = False,
) -> None:
    """Raise `TypeError` if `sub_callable` cannot stand in for `super_callable`.

    The override is accepted only when any call valid for `super_callable`
    also works on `sub_callable`, i.e.:

    1. The return type of `sub_callable` is a subtype of the return type of
       `super_callable`.
    2. Every parameter of `super_callable` exists in `sub_callable`, unless
       `sub_callable` declares `*args` or `**kwargs`.
    3. Positional parameters of `super_callable` keep their order in
       `sub_callable`.
    4. Every parameter type of `super_callable` is a subtype of the matching
       parameter type of `sub_callable`.
    5. Every required parameter of `sub_callable` exists in `super_callable`,
       unless `super_callable` declares `*args` or `**kwargs`.

    Checks 1-4 run only when the type hints of both callables can be resolved;
    check 5 always runs. If `inspect.signature` raises `ValueError` for
    `super_callable`, nothing is checked.

    :param super_callable: Function to check compatibility with.
    :param sub_callable: Function to check compatibility of.
    :param is_static: True if staticmethod and should check first argument.
    """
    # Bound methods are compared through their underlying functions.
    base = _unbound_func(super_callable)
    override = _unbound_func(sub_callable)
    try:
        base_sig = inspect.signature(base)
    except ValueError:
        # No signature can be obtained for the base callable: nothing to compare.
        return
    base_hints = _get_type_hints(base)
    override_sig = inspect.signature(override)
    override_hints = _get_type_hints(override)
    qualname = override.__qualname__
    same_main_module = _is_same_module(override, base)

    hints_resolved = base_hints is not None and override_hints is not None
    if hints_resolved:
        ensure_return_type_compatibility(base_hints, override_hints, qualname)
        ensure_all_kwargs_defined_in_sub(
            base_sig, override_sig, base_hints, override_hints, is_static, qualname
        )
        ensure_all_positional_args_defined_in_sub(
            base_sig,
            override_sig,
            base_hints,
            override_hints,
            is_static,
            same_main_module,
            qualname,
        )
    ensure_no_extra_args_in_sub(base_sig, override_sig, is_static, qualname)
def _unbound_func(callable: _WrappedMethod) -> _WrappedMethod:
if hasattr(callable, "__self__") and hasattr(callable, "__func__"):
return callable.__func__ # type: ignore
return callable
def ensure_all_kwargs_defined_in_sub(
    super_sig: inspect.Signature,
    sub_sig: inspect.Signature,
    super_type_hints: Dict,
    sub_type_hints: Dict,
    check_first_parameter: bool,
    method_name: str,
):
    """Check that arguments passable by keyword to `super_sig` are accepted by `sub_sig`.

    Positional-only and `*args` parameters of `super_sig` are skipped, as is
    its first parameter unless `check_first_parameter` is set. Every other
    super parameter must be defined in `sub_sig` (by name, or absorbed by the
    sub-signature's `**kwargs`). When `sub_sig` declares it under the same
    name and it is not the super-signature's `**kwargs`:

    * its kind must match, except keyword-only may become positional-or-keyword;
    * unless keyword-only in `super_sig`, it must not sit at an earlier index
      in `sub_sig` than in `super_sig`;
    * when both sides annotate it, the super type must be a subtype of the
      sub type.

    :raises TypeError: describing the first violation found.
    """
    sub_params = sub_sig.parameters
    sub_names = list(sub_params.keys())
    sub_has_var_kwargs = any(
        param.kind == Parameter.VAR_KEYWORD for param in sub_params.values()
    )
    not_keyword_passable = (Parameter.VAR_POSITIONAL, Parameter.POSITIONAL_ONLY)

    for super_index, (name, super_param) in enumerate(super_sig.parameters.items()):
        if super_index == 0 and not check_first_parameter:
            continue
        if super_param.kind in not_keyword_passable:
            continue
        # `True` is passed for sub_has_var_args, so a positional-or-keyword
        # parameter counts as defined whenever the sub-signature has **kwargs.
        if not is_param_defined_in_sub(
            name, True, sub_has_var_kwargs, sub_sig, super_param
        ):
            raise TypeError(f"{method_name}: `{name}` is not present.")
        if name not in sub_params or super_param.kind == Parameter.VAR_KEYWORD:
            continue

        sub_index = sub_names.index(name)
        sub_param = sub_params[name]
        kinds_compatible = super_param.kind == sub_param.kind or (
            super_param.kind == Parameter.KEYWORD_ONLY
            and sub_param.kind == Parameter.POSITIONAL_OR_KEYWORD
        )
        if not kinds_compatible:
            raise TypeError(f"{method_name}: `{name}` is not `{super_param.kind}`")
        if super_index > sub_index and super_param.kind != Parameter.KEYWORD_ONLY:
            raise TypeError(
                f"{method_name}: `{name}` is not parameter at index `{super_index}`"
            )
        both_annotated = name in super_type_hints and name in sub_type_hints
        if both_annotated and not _issubtype(
            super_type_hints[name], sub_type_hints[name]
        ):
            raise TypeError(
                f"`{method_name}: {name} must be a supertype of `{super_param.annotation}` but is `{sub_param.annotation}`"
            )
def ensure_all_positional_args_defined_in_sub(
    super_sig: inspect.Signature,
    sub_sig: inspect.Signature,
    super_type_hints: Dict,
    sub_type_hints: Dict,
    check_first_parameter: bool,
    is_same_main_module: bool,
    method_name: str,
):
    """Check that positional calls valid for `super_sig` remain valid for `sub_sig`.

    Only parameters that can be filled positionally (positional-only,
    positional-or-keyword and `*args`) are compared. The sub-signature's
    positional parameters are walked in order and matched against the
    super-signature's parameter at the same position, shifted by
    `super_shift` once the super-signature's `*args` has been reached.

    :param super_sig: Signature of the overridden callable.
    :param sub_sig: Signature of the overriding callable.
    :param super_type_hints: Resolved type hints of the overridden callable.
    :param sub_type_hints: Resolved type hints of the overriding callable.
    :param check_first_parameter: If False, the parameter at index 0 is skipped.
    :param is_same_main_module: If True, the type comparison runs even when the
        super parameter has no type hint (its hint is then passed as `None`).
    :param method_name: Name used in error messages.
    :raises TypeError: if the sub-signature has fewer positional parameters
        (and no `*args`), requires an extra positional-only argument, lacks
        `*args` where the super-signature has it, uses an incompatible
        parameter kind, or annotates a parameter with a type that is not a
        supertype of the super parameter's type.
    """
    # Keep only parameters that can be filled positionally.
    sub_parameter_values = [
        v
        for v in sub_sig.parameters.values()
        if v.kind not in (Parameter.KEYWORD_ONLY, Parameter.VAR_KEYWORD)
    ]
    super_parameter_values = [
        v
        for v in super_sig.parameters.values()
        if v.kind not in (Parameter.KEYWORD_ONLY, Parameter.VAR_KEYWORD)
    ]
    sub_has_var_args = any(
        p.kind == Parameter.VAR_POSITIONAL for p in sub_parameter_values
    )
    super_has_var_args = any(
        p.kind == Parameter.VAR_POSITIONAL for p in super_parameter_values
    )
    # Without *args, the sub-signature must list at least as many positional
    # parameters as the super-signature (both counts include any *args entry).
    if not sub_has_var_args and len(sub_parameter_values) < len(super_parameter_values):
        raise TypeError(f"{method_name}: parameter list too short")
    # Offset added to the sub index to locate the matching super parameter;
    # it goes negative once the super-signature's *args is reached.
    super_shift = 0
    for index, sub_param in enumerate(sub_parameter_values):
        if index == 0 and not check_first_parameter:
            continue
        if index + super_shift >= len(super_parameter_values):
            # Extra positional parameter in the sub-signature: tolerated if it
            # is *args, positional-only with a default, or positional-or-keyword;
            # any other extra is an error.
            if sub_param.kind == Parameter.VAR_POSITIONAL:
                continue
            if (
                sub_param.kind == Parameter.POSITIONAL_ONLY
                and sub_param.default != Parameter.empty
            ):
                continue
            if sub_param.kind == Parameter.POSITIONAL_OR_KEYWORD:
                continue  # Assume use as keyword
            raise TypeError(
                f"{method_name}: `{sub_param.name}` positionally required in subclass but not in supertype"
            )
        # The sub-signature's *args is the last positional entry and accepts
        # whatever remains, so checking stops here.
        if sub_param.kind == Parameter.VAR_POSITIONAL:
            return
        super_param = super_parameter_values[index + super_shift]
        # Decrementing keeps `index + super_shift` pointing at the super *args
        # for every later sub parameter.
        if super_param.kind == Parameter.VAR_POSITIONAL:
            super_shift -= 1
        # A super *args requires a sub *args; the parameter is not type-checked.
        if super_param.kind == Parameter.VAR_POSITIONAL:
            if not sub_has_var_args:
                raise TypeError(f"{method_name}: `{super_param.name}` must be present")
            continue
        # Kinds must match, except positional-only may widen to
        # positional-or-keyword, and a sub positional-only parameter is
        # tolerated when the super-signature has *args.
        if (
            super_param.kind != sub_param.kind
            and not (
                super_param.kind == Parameter.POSITIONAL_ONLY
                and sub_param.kind == Parameter.POSITIONAL_OR_KEYWORD
            )
            and not (sub_param.kind == Parameter.POSITIONAL_ONLY and super_has_var_args)
        ):
            raise TypeError(
                f"{method_name}: `{sub_param.name}` is not `{super_param.kind}` and is `{sub_param.kind}`"
            )
        # Parameter types are contravariant: the super type must be a subtype
        # of the sub type. Positional names may differ, so each side's hint is
        # looked up under its own parameter name.
        elif (
            super_param.name in super_type_hints or is_same_main_module
        ) and not _issubtype(
            super_type_hints.get(super_param.name, None),
            sub_type_hints.get(sub_param.name, None),
        ):
            raise TypeError(
                f"`{method_name}: {sub_param.name} overriding must be a supertype of `{super_param.annotation}` but is `{sub_param.annotation}`"
            )
def is_param_defined_in_sub(
    name: str,
    sub_has_var_args: bool,
    sub_has_var_kwargs: bool,
    sub_sig: inspect.Signature,
    super_param: inspect.Parameter,
) -> bool:
    """Tell whether the sub-signature can accept the super parameter `name`.

    It can if `sub_sig` declares a parameter with that name. Otherwise the
    sub-signature's variadic parameters must absorb it, depending on the
    super parameter's kind:

    * `*args` / positional-only: the sub needs `*args`;
    * `**kwargs` / keyword-only: the sub needs `**kwargs`;
    * positional-or-keyword: the sub needs both `*args` and `**kwargs`.
    """
    if name in sub_sig.parameters:
        return True
    kind = super_param.kind
    if kind in (Parameter.VAR_POSITIONAL, Parameter.POSITIONAL_ONLY):
        return sub_has_var_args
    if kind in (Parameter.VAR_KEYWORD, Parameter.KEYWORD_ONLY):
        return sub_has_var_kwargs
    if kind == Parameter.POSITIONAL_OR_KEYWORD:
        return sub_has_var_args and sub_has_var_kwargs
    return False
def ensure_no_extra_args_in_sub(
    super_sig: inspect.Signature,
    sub_sig: inspect.Signature,
    check_first_parameter: bool,
    method_name: str,
) -> None:
    """Reject required sub parameters that callers of `super_sig` cannot supply.

    A parameter of `sub_sig` is reported unless at least one of these holds:

    * it is positional-only and the super parameter at the same index is also
      positional-only (positional-only names are never used by callers);
    * its name appears in `super_sig`;
    * it has a default value;
    * it is `*args` or `**kwargs`;
    * it is keyword-only and `super_sig` has `**kwargs`;
    * it is positional-only or positional-or-keyword and `super_sig` has `*args`;
    * it is the first parameter and `check_first_parameter` is False.

    :param super_sig: Signature of the overridden callable.
    :param sub_sig: Signature of the overriding callable.
    :param check_first_parameter: If False, the parameter at index 0 is ignored.
    :param method_name: Name used in the error message.
    :raises TypeError: naming the first offending sub parameter.
    """
    # Materialize once: the loop below indexes into the super parameters, and
    # rebuilding a list from the view on every iteration was quadratic.
    super_params = list(super_sig.parameters.values())
    super_var_args = any(p.kind == Parameter.VAR_POSITIONAL for p in super_params)
    super_var_kwargs = any(p.kind == Parameter.VAR_KEYWORD for p in super_params)
    for sub_index, (name, sub_param) in enumerate(sub_sig.parameters.items()):
        if (
            sub_param.kind == Parameter.POSITIONAL_ONLY
            and sub_index < len(super_params)
            and super_params[sub_index].kind == Parameter.POSITIONAL_ONLY
        ):
            continue
        if (
            name not in super_sig.parameters
            and sub_param.default == Parameter.empty
            and sub_param.kind != Parameter.VAR_POSITIONAL
            and sub_param.kind != Parameter.VAR_KEYWORD
            and not (sub_param.kind == Parameter.KEYWORD_ONLY and super_var_kwargs)
            and not (sub_param.kind == Parameter.POSITIONAL_ONLY and super_var_args)
            and not (
                sub_param.kind == Parameter.POSITIONAL_OR_KEYWORD and super_var_args
            )
            and (sub_index > 0 or check_first_parameter)
        ):
            raise TypeError(f"{method_name}: `{name}` is not a valid parameter.")
def ensure_return_type_compatibility(
    super_type_hints: Dict, sub_type_hints: Dict, method_name: str
):
    """Raise `TypeError` unless the sub return type is a subtype of the super one.

    A missing `"return"` hint is read as `None`. When the super callable has
    no return hint, the check always passes.
    """
    expected = super_type_hints.get("return", None)
    actual = sub_type_hints.get("return", None)
    compatible = _issubtype(actual, expected)
    if compatible or expected is None:
        return
    raise TypeError(
        f"{method_name}: return type `{actual}` is not a `{expected}`."
    )