"""Simple type inference for decorated functions during semantic analysis."""

from typing import Optional

from mypy.nodes import Expression, Decorator, CallExpr, FuncDef, RefExpr, Var, ARG_POS
from mypy.types import (
    Type, CallableType, AnyType, TypeOfAny, TypeVarType, ProperType, get_proper_type
)
from mypy.typeops import function_type
from mypy.typevars import has_no_typevars
from mypy.semanal_shared import SemanticAnalyzerInterface


def infer_decorator_signature_if_simple(dec: Decorator,
                                        analyzer: SemanticAnalyzerInterface) -> None:
    """Try to infer the type of the decorated function.

    This lets us resolve additional references to decorated functions
    during type checking. Otherwise the type might not be available
    when we need it, since module top levels can't be deferred.

    This basically uses a simple special-purpose type inference
    engine just for decorators.

    The result, if any, is stored in ``dec.var.type``; nothing is returned.
    Three cases are recognised:

    * properties, which get the function's own callable type (or a generic
      one-argument callable if the function is unannotated);
    * decorator stacks made up only of identity-typed (``T -> T``) functions,
      which leave the function's signature unchanged;
    * an outermost decorator whose return type is ``Any`` or a callable type
      without type variables, which fixes the resulting type.
    """
    if dec.var.is_property:
        # Decorators are expected to have a callable type (it's a little odd).
        if dec.func.type is None:
            dec.var.type = CallableType(
                [AnyType(TypeOfAny.special_form)],
                [ARG_POS],
                [None],
                AnyType(TypeOfAny.special_form),
                analyzer.named_type('builtins.function'),
                name=dec.var.name)
        elif isinstance(dec.func.type, CallableType):
            dec.var.type = dec.func.type
        return

    # True when every decorator is a plain reference to a function whose
    # signature is T -> T (vacuously true when there are no decorators).
    decorator_preserves_type = all(
        isinstance(expr, RefExpr)
        and isinstance(expr.node, FuncDef)
        and expr.node.type
        and is_identity_signature(expr.node.type)
        for expr in dec.decorators
    )
    if decorator_preserves_type:
        # No non-identity decorators left. We can trivially infer the type
        # of the function here.
        dec.var.type = function_type(dec.func, analyzer.named_type('builtins.function'))

    if dec.decorators:
        return_type = calculate_return_type(dec.decorators[0])
        if isinstance(return_type, AnyType):
            # The outermost decorator will return Any so we know the type of the
            # decorated function.
            dec.var.type = AnyType(TypeOfAny.from_another_any, source_any=return_type)
        sig = find_fixed_callable_return(dec.decorators[0])
        if sig:
            # The outermost decorator always returns the same kind of function,
            # so we know that this is the type of the decorated function.
            orig_sig = function_type(dec.func, analyzer.named_type('builtins.function'))
            # Give the decorated function's name to a *copy* of the signature.
            # `sig` comes from the decorator's own return type and can be the
            # very object stored there, so renaming it in place would leak
            # this function's name into the decorator's signature and into
            # every other function decorated with it.
            dec.var.type = sig.copy_modified(name=orig_sig.items[0].name)


def is_identity_signature(sig: Type) -> bool:
    """Is type a callable of form T -> T (where T is a type variable)?

    Only callables taking exactly one positional argument qualify, and the
    argument and return types must be the same type variable (compared by id).
    """
    sig = get_proper_type(sig)
    if not isinstance(sig, CallableType) or sig.arg_kinds != [ARG_POS]:
        return False
    arg_type = sig.arg_types[0]
    ret_type = sig.ret_type
    if isinstance(arg_type, TypeVarType) and isinstance(ret_type, TypeVarType):
        return arg_type.id == ret_type.id
    return False


def calculate_return_type(expr: Expression) -> Optional[ProperType]:
    """Return the return type if we can calculate it.

    This only uses information available during semantic analysis so this
    will sometimes return None because of insufficient information (as
    type inference hasn't run yet).

    Handled forms:

    * a reference to a function: its declared return type, or ``Any`` if the
      function has no signature (None if its type is not a plain callable);
    * a reference to a variable: the variable's own type;
    * a call expression: whatever is calculated for the callee.
    """
    if isinstance(expr, RefExpr):
        if isinstance(expr.node, FuncDef):
            typ = expr.node.type
            if typ is None:
                # No signature -> default to Any.
                return AnyType(TypeOfAny.unannotated)
            # Explicit Any return?
            if isinstance(typ, CallableType):
                return get_proper_type(typ.ret_type)
            return None
        elif isinstance(expr.node, Var):
            return get_proper_type(expr.node.type)
    elif isinstance(expr, CallExpr):
        return calculate_return_type(expr.callee)
    return None


def find_fixed_callable_return(expr: Expression) -> Optional[CallableType]:
    """Return the return type, if expression refers to a callable that returns a callable.

    But only do this if the return type has no type variables. Return None otherwise.
    This approximates things a lot as this is supposed to be called before type checking
    when full type information is not available yet.

    For a call expression the callee is resolved first, and the result is the
    return type of the callable it returns (if that is itself a callable).
    """
    if isinstance(expr, RefExpr):
        if not isinstance(expr.node, FuncDef):
            return None
        typ = expr.node.type
        if typ and isinstance(typ, CallableType) and has_no_typevars(typ.ret_type):
            ret_type = get_proper_type(typ.ret_type)
            if isinstance(ret_type, CallableType):
                return ret_type
    elif isinstance(expr, CallExpr):
        # A decorator factory call: look one level deeper into the callee's
        # fixed callable return type.
        callee_ret = find_fixed_callable_return(expr.callee)
        if callee_ret:
            ret_type = get_proper_type(callee_ret.ret_type)
            if isinstance(ret_type, CallableType):
                return ret_type
    return None