189 lines
6.2 KiB
Python
189 lines
6.2 KiB
Python
|
"""Helpers for AST (Abstract Syntax Tree)."""
|
||
|
|
||
|
from __future__ import annotations
|
||
|
|
||
|
import ast
|
||
|
from typing import overload
|
||
|
|
||
|
# Source-text symbol for every operator node class the unparser understands.
# Binary, boolean and unary operators share one table; the unary and binary
# forms of ``+`` / ``-`` are distinct AST classes that map to the same symbol.
OPERATORS: dict[type[ast.AST], str] = {
    ast.Add: "+",        # binary plus
    ast.And: "and",      # boolean and
    ast.BitAnd: "&",
    ast.BitOr: "|",
    ast.BitXor: "^",
    ast.Div: "/",
    ast.FloorDiv: "//",
    ast.Invert: "~",     # unary bitwise inversion
    ast.LShift: "<<",
    ast.MatMult: "@",
    ast.Mult: "*",
    ast.Mod: "%",
    ast.Not: "not",      # unary boolean negation
    ast.Pow: "**",
    ast.Or: "or",        # boolean or
    ast.RShift: ">>",
    ast.Sub: "-",        # binary minus
    ast.UAdd: "+",       # unary plus
    ast.USub: "-",       # unary minus
}
|
||
|
|
||
|
|
||
|
@overload
def unparse(node: None, code: str = '') -> None:
    ...


@overload
def unparse(node: ast.AST, code: str = '') -> str:
    ...


def unparse(node: ast.AST | None, code: str = '') -> str | None:
    """Convert an AST node back into source text.

    :param node: The node to render.  ``None`` and plain strings are handed
        back unchanged.
    :param code: Source text passed on to ``_UnparseVisitor``.
    :return: The rendered string, or ``None`` when *node* is ``None``.
    """
    # Nothing to render: ``None`` stays ``None`` and an already-textual
    # value is returned as-is.
    if node is None or isinstance(node, str):
        return node

    visitor = _UnparseVisitor(code)
    return visitor.visit(node)
|
||
|
|
||
|
|
||
|
# a greatly cut-down version of `ast._Unparser`
class _UnparseVisitor(ast.NodeVisitor):
    """Render a small subset of AST nodes (expressions, arguments) as text.

    Each ``visit_*`` method returns the source string for its node.  Node
    types without a dedicated method fall through to :meth:`generic_visit`,
    which raises ``NotImplementedError``.
    """

    def __init__(self, code: str = '') -> None:
        # Original source text; when non-empty, numeric constants are copied
        # from it verbatim instead of being re-rendered with ``repr()``.
        self.code = code

    def _visit_op(self, node: ast.AST) -> str:
        """Return the symbol registered in ``OPERATORS`` for an operator node."""
        return OPERATORS[node.__class__]

    # Expose ``_visit_op`` as ``visit_Add``, ``visit_And``, ... for every
    # operator class in ``OPERATORS``.  In a class body ``locals()`` is the
    # namespace under construction, so these assignments define methods.
    for _op in OPERATORS:
        locals()[f'visit_{_op.__name__}'] = _visit_op

    def visit_arg(self, node: ast.arg) -> str:
        """Unparse one parameter name, with its annotation if present."""
        if node.annotation:
            return f"{node.arg}: {self.visit(node.annotation)}"
        return node.arg

    def _visit_arg_with_default(self, arg: ast.arg, default: ast.AST | None) -> str:
        """Unparse a single argument to a string.

        An annotated parameter gets spaces around ``=`` (``x: int = 1``);
        an unannotated one does not (``x=1``).
        """
        name = self.visit(arg)
        if default is None:
            return name
        separator = " = " if arg.annotation else "="
        return f"{name}{separator}{self.visit(default)}"

    def visit_arguments(self, node: ast.arguments) -> str:
        """Unparse a full parameter list, including ``/`` and ``*`` markers."""
        posonlyargs = len(node.posonlyargs)
        positionals = posonlyargs + len(node.args)

        # ``node.defaults`` is shorter than the positional parameter list when
        # only trailing parameters have defaults; left-pad it with ``None`` so
        # it lines up with the parameters one-to-one.
        defaults: list[ast.expr | None] = [None] * (positionals - len(node.defaults))
        defaults.extend(node.defaults)

        # Same alignment for keyword-only parameters.
        kw_defaults: list[ast.expr | None] = (
            [None] * (len(node.kwonlyargs) - len(node.kw_defaults))
        )
        kw_defaults.extend(node.kw_defaults)

        args: list[str] = [
            self._visit_arg_with_default(arg, default)
            for arg, default in zip(node.posonlyargs, defaults)
        ]
        if node.posonlyargs:
            args.append('/')

        args.extend(
            self._visit_arg_with_default(arg, default)
            for arg, default in zip(node.args, defaults[posonlyargs:])
        )

        if node.vararg:
            args.append("*" + self.visit(node.vararg))

        # A bare ``*`` is only needed when no ``*args`` already separates the
        # keyword-only parameters from the positional ones.
        if node.kwonlyargs and not node.vararg:
            args.append('*')
        args.extend(
            self._visit_arg_with_default(arg, default)
            for arg, default in zip(node.kwonlyargs, kw_defaults)
        )

        if node.kwarg:
            args.append("**" + self.visit(node.kwarg))

        return ", ".join(args)

    def visit_Attribute(self, node: ast.Attribute) -> str:
        """Unparse ``value.attr``."""
        return f"{self.visit(node.value)}.{node.attr}"

    def visit_BinOp(self, node: ast.BinOp) -> str:
        """Unparse a binary operation; no parentheses are added."""
        # Special case ``**`` to not have surrounding spaces.
        if isinstance(node.op, ast.Pow):
            return "".join(map(self.visit, (node.left, node.op, node.right)))
        return " ".join(self.visit(e) for e in [node.left, node.op, node.right])

    def visit_BoolOp(self, node: ast.BoolOp) -> str:
        """Unparse ``a and b and ...`` / ``a or b or ...``."""
        op = " %s " % self.visit(node.op)
        return op.join(self.visit(e) for e in node.values)

    def visit_Call(self, node: ast.Call) -> str:
        """Unparse a call expression with positional and keyword arguments."""
        args = [self.visit(e) for e in node.args]
        for keyword in node.keywords:
            value = self.visit(keyword.value)
            if keyword.arg is None:
                # ``f(**mapping)`` is stored as a keyword whose name is None.
                args.append(f"**{value}")
            else:
                args.append(f"{keyword.arg}={value}")
        return f"{self.visit(node.func)}({', '.join(args)})"

    def visit_Constant(self, node: ast.Constant) -> str:
        """Unparse a literal constant."""
        if node.value is Ellipsis:
            return "..."
        if self.code and isinstance(node.value, (int, float, complex)):
            # Prefer the literal exactly as written in the source; fall back
            # to ``repr()`` when no source segment can be recovered.
            return ast.get_source_segment(self.code, node) or repr(node.value)
        return repr(node.value)

    def visit_Dict(self, node: ast.Dict) -> str:
        """Unparse a dict display, including ``**mapping`` unpacking."""
        items: list[str] = []
        for key, value in zip(node.keys, node.values):
            if key is None:
                # A ``None`` key marks ``**value`` unpacking; it must stay
                # paired with its value so later items do not shift.
                items.append("**" + self.visit(value))
            else:
                items.append(self.visit(key) + ": " + self.visit(value))
        return "{" + ", ".join(items) + "}"

    def visit_Lambda(self, node: ast.Lambda) -> str:
        """Unparse a lambda; the body is always abbreviated to ``...``."""
        params = self.visit(node.args)
        if not params:
            # Avoid the stray space of ``lambda : ...``.
            return "lambda: ..."
        return "lambda %s: ..." % params

    def visit_List(self, node: ast.List) -> str:
        """Unparse a list display."""
        return "[" + ", ".join(self.visit(e) for e in node.elts) + "]"

    def visit_Name(self, node: ast.Name) -> str:
        """Unparse a bare identifier."""
        return node.id

    def visit_Set(self, node: ast.Set) -> str:
        """Unparse a set display."""
        return "{" + ", ".join(self.visit(e) for e in node.elts) + "}"

    def visit_Subscript(self, node: ast.Subscript) -> str:
        """Unparse ``value[slice]``.

        A non-empty tuple index without starred elements is written without
        its parentheses (``Dict[str, int]`` rather than ``Dict[(str, int)]``).
        """
        index = node.slice
        is_simple_tuple = (
            isinstance(index, ast.Tuple)
            and bool(index.elts)
            and not any(isinstance(elt, ast.Starred) for elt in index.elts)
        )
        if is_simple_tuple:
            elts = ", ".join(self.visit(e) for e in index.elts)
            return f"{self.visit(node.value)}[{elts}]"
        return f"{self.visit(node.value)}[{self.visit(index)}]"

    def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
        """Unparse a unary operation."""
        # UnaryOp is one of {UAdd, USub, Invert, Not}, which refer to ``+x``,
        # ``-x``, ``~x``, and ``not x``. Only Not needs a space.
        if isinstance(node.op, ast.Not):
            return f"{self.visit(node.op)} {self.visit(node.operand)}"
        return f"{self.visit(node.op)}{self.visit(node.operand)}"

    def visit_Tuple(self, node: ast.Tuple) -> str:
        """Unparse a tuple, always parenthesised."""
        if not node.elts:
            return "()"
        if len(node.elts) == 1:
            # A one-element tuple needs the trailing comma.
            return "(%s,)" % self.visit(node.elts[0])
        return "(" + ", ".join(self.visit(e) for e in node.elts) + ")"

    def generic_visit(self, node):
        """Reject every node type that has no dedicated ``visit_*`` method.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError('Unable to parse %s object' % type(node).__name__)
|