203 lines
6.9 KiB
Python
203 lines
6.9 KiB
Python
|
from typing import List, Optional, Union
|
||
|
|
||
|
from mypy.nodes import (
|
||
|
ARG_POS, MDEF, Argument, Block, CallExpr, ClassDef, Expression, SYMBOL_FUNCBASE_TYPES,
|
||
|
FuncDef, PassStmt, RefExpr, SymbolTableNode, Var, JsonDict,
|
||
|
)
|
||
|
from mypy.plugin import CheckerPluginInterface, ClassDefContext, SemanticAnalyzerPluginInterface
|
||
|
from mypy.semanal import set_callable_name, ALLOW_INCOMPATIBLE_OVERRIDE
|
||
|
from mypy.types import (
|
||
|
CallableType, Overloaded, Type, TypeVarType, deserialize_type, get_proper_type,
|
||
|
)
|
||
|
from mypy.typevars import fill_typevars
|
||
|
from mypy.util import get_unique_redefinition_name
|
||
|
from mypy.typeops import try_getting_str_literals # noqa: F401 # Part of public API
|
||
|
from mypy.fixup import TypeFixer
|
||
|
|
||
|
|
||
|
def _get_decorator_bool_argument(
    ctx: ClassDefContext,
    name: str,
    default: bool,
) -> bool:
    """Return the bool argument ``name`` of the class decorator in ``ctx``.

    This handles both @decorator(...) and @decorator: when ``ctx.reason`` is
    not a call expression (the bare ``@decorator`` form) there are no
    arguments to inspect and ``default`` is returned unchanged.
    """
    reason = ctx.reason
    if not isinstance(reason, CallExpr):
        # Bare decorator -- nothing was passed explicitly.
        return default
    return _get_bool_argument(ctx, reason, name, default)
|
||
|
|
||
|
|
||
|
def _get_bool_argument(ctx: ClassDefContext, expr: CallExpr,
                       name: str, default: bool) -> bool:
    """Return the boolean value of argument ``name`` in the call ``expr``.

    ``default`` is returned when the argument is absent. When the argument
    is present but ``ctx.api.parse_bool`` cannot interpret it (it returns
    None), an error is reported on ``expr`` and ``default`` is returned.
    """
    value_expr = _get_argument(expr, name)
    if not value_expr:
        return default

    parsed = ctx.api.parse_bool(value_expr)
    if parsed is not None:
        return parsed

    # The argument exists but is not a literal True/False.
    ctx.api.fail(f'"{name}" argument must be True or False.', expr)
    return default
|
||
|
|
||
|
|
||
|
def _get_callee_type(call: CallExpr) -> Optional[CallableType]:
    """Return the callable type of ``call``'s callee, or None if unavailable.

    Only callees that are references (RefExpr) to a ``Var`` or a
    function-like symbol with a known type are handled. For an overloaded
    callee the last overload item is used.
    """
    callee = call.callee
    if not isinstance(callee, RefExpr):
        return None

    target = callee.node
    if not isinstance(target, (Var, SYMBOL_FUNCBASE_TYPES)) or not target.type:
        return None

    proper = get_proper_type(target.type)
    if isinstance(proper, Overloaded):
        # We take the last overload.
        return proper.items[-1]
    if isinstance(proper, CallableType):
        return proper
    return None


def _get_argument(call: CallExpr, name: str) -> Optional[Expression]:
    """Return the expression passed for argument ``name`` in ``call``.

    The callee's CallableType is used to find the formal argument called
    ``name``; the actual arguments of ``call`` are then scanned for one
    passed either positionally at the formal's position or by keyword.
    Returns None if the callee type is unknown, the callee has no such
    argument, or the call does not supply it.
    """
    # Note: the index is not hard-coded so that in the future other attrib
    # and class makers can be supported.
    callee_type = _get_callee_type(call)
    if not callee_type:
        return None

    formal = callee_type.argument_by_name(name)
    if not formal:
        return None
    assert formal.name

    for index, (actual_name, actual_value) in enumerate(zip(call.arg_names, call.args)):
        passed_positionally = (
            formal.pos is not None and not actual_name and index == formal.pos
        )
        if passed_positionally or actual_name == formal.name:
            return actual_value
    return None
|
||
|
|
||
|
|
||
|
def add_method(
    ctx: ClassDefContext,
    name: str,
    args: List[Argument],
    return_type: Type,
    self_type: Optional[Type] = None,
    tvar_def: Optional[TypeVarType] = None,
) -> None:
    """Add a new method to the class described by ``ctx``.

    Deprecated, use add_method_to_class() instead. This shim only forwards
    ``ctx.api`` and ``ctx.cls`` together with the remaining arguments.
    """
    api, cls = ctx.api, ctx.cls
    add_method_to_class(
        api,
        cls,
        name,
        args,
        return_type,
        self_type=self_type,
        tvar_def=tvar_def,
    )
|
||
|
|
||
|
|
||
|
def add_method_to_class(
    api: Union[SemanticAnalyzerPluginInterface, CheckerPluginInterface],
    cls: ClassDef,
    name: str,
    args: List[Argument],
    return_type: Type,
    self_type: Optional[Type] = None,
    tvar_def: Optional[TypeVarType] = None,
) -> None:
    """Add a plugin-generated method ``name`` to the class definition ``cls``.

    A leading ``self`` argument (typed ``self_type``, or ``fill_typevars`` of
    the class when omitted) is prepended to ``args``. Every argument must
    carry a type annotation. When ``tvar_def`` is given it becomes the
    signature's only type variable. The new FuncDef is registered in the
    class symbol table and appended to the class body.
    """
    info = cls.info

    # First remove any previously generated methods with the same name
    # to avoid clashes and problems in the semantic analyzer.
    stale = info.names.get(name)
    if stale is not None and stale.plugin_generated and isinstance(stale.node, FuncDef):
        cls.defs.body.remove(stale.node)

    self_type = self_type or fill_typevars(info)
    # The two plugin interfaces expose different helpers for building an
    # instance type of 'builtins.function'.
    if isinstance(api, SemanticAnalyzerPluginInterface):
        function_type = api.named_type('builtins.function')
    else:
        function_type = api.named_generic_type('builtins.function', [])

    all_args = [Argument(Var('self'), self_type, None, ARG_POS)] + args
    arg_types, arg_names, arg_kinds = [], [], []
    for argument in all_args:
        annotation = argument.type_annotation
        assert annotation, 'All arguments must be fully typed.'
        arg_types.append(annotation)
        arg_names.append(argument.variable.name)
        arg_kinds.append(argument.kind)

    signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
    if tvar_def:
        signature.variables = [tvar_def]

    func = FuncDef(name, all_args, Block([PassStmt()]))
    func.info = info
    func.type = set_callable_name(signature, func)
    func._fullname = info.fullname + '.' + name
    func.line = info.line

    # NOTE: we would like the plugin generated node to dominate, but we still
    # need to keep any existing definitions so they get semantically analyzed.
    previous = info.names.get(name)
    if previous is not None:
        # Park the old entry under a unique, non-clashing name.
        info.names[get_unique_redefinition_name(name, info.names)] = previous

    info.names[name] = SymbolTableNode(MDEF, func, plugin_generated=True)
    info.defn.defs.body.append(func)
|
||
|
|
||
|
|
||
|
def add_attribute_to_class(
    api: SemanticAnalyzerPluginInterface,
    cls: ClassDef,
    name: str,
    typ: Type,
    final: bool = False,
    no_serialize: bool = False,
    override_allow_incompatible: bool = False,
    fullname: Optional[str] = None,
) -> Var:
    """Add a new plugin-generated attribute ``name`` of type ``typ`` to ``cls``.

    This currently only generates the symbol table entry and no corresponding
    assignment statement in the class body.

    Args:
        api: The semantic analyzer plugin interface (currently unused).
        cls: The class definition to extend.
        name: Name of the new attribute.
        typ: Type of the new attribute.
        final: Mark the attribute as final.
        no_serialize: Passed through to the SymbolTableNode.
        override_allow_incompatible: Allow incompatible overrides of this
            attribute. Names in ALLOW_INCOMPATIBLE_OVERRIDE always allow them.
        fullname: Optional explicit fully qualified name for the Var. It
            defaults to ``info.fullname + '.' + name``.

    Returns:
        The newly created Var node, so callers can set further flags on it.
        Previously nothing was returned; ignoring the result still works.
    """
    info = cls.info

    # NOTE: we would like the plugin generated node to dominate, but we still
    # need to keep any existing definitions so they get semantically analyzed.
    if name in info.names:
        # Get a nice unique name instead.
        r_name = get_unique_redefinition_name(name, info.names)
        info.names[r_name] = info.names[name]

    node = Var(name, typ)
    node.info = info
    node.is_final = final
    if name in ALLOW_INCOMPATIBLE_OVERRIDE:
        node.allow_incompatible_override = True
    else:
        node.allow_incompatible_override = override_allow_incompatible

    if fullname:
        node._fullname = fullname
    else:
        node._fullname = info.fullname + '.' + name

    info.names[name] = SymbolTableNode(
        MDEF,
        node,
        plugin_generated=True,
        no_serialize=no_serialize,
    )
    return node
|
||
|
|
||
|
|
||
|
def deserialize_and_fixup_type(
    data: Union[str, JsonDict], api: SemanticAnalyzerPluginInterface
) -> Type:
    """Deserialize ``data`` into a Type and run mypy's fixup pass over it.

    The fixup is done by a TypeFixer built from ``api.modules`` with
    ``allow_missing=False``; presumably this resolves symbol references left
    unresolved by deserialization (TODO confirm against mypy.fixup).
    """
    result = deserialize_type(data)
    fixer = TypeFixer(api.modules, allow_missing=False)
    result.accept(fixer)
    return result
|