usse/funda-scraper/venv/lib/python3.10/site-packages/mypyc/irbuild/constant_fold.py

100 lines
3.0 KiB
Python

"""Constant folding of IR values.
For example, 3 + 5 can be constant folded into 8.
"""
from typing import Optional, Union
from typing_extensions import Final
from mypy.nodes import Expression, IntExpr, StrExpr, OpExpr, UnaryExpr, NameExpr, MemberExpr, Var
from mypyc.irbuild.builder import IRBuilder
# All possible result types of constant folding.
ConstantValue = Union[int, str]
# Runtime tuple mirroring ConstantValue, used in isinstance() checks on the
# values recorded for final names.
# NOTE(review): keep this in sync with ConstantValue when adding types.
CONST_TYPES: Final = (int, str)
def constant_fold_expr(builder: IRBuilder, expr: Expression) -> Optional[ConstantValue]:
    """Return the constant value of an expression for supported operations.

    Foldable forms are:

    * int and str literals;
    * names and member references to final variables whose recorded
      final value is an int or str;
    * binary operations whose operands both fold to ints, or both to strs;
    * unary operations whose operand folds to an int.

    Return None otherwise, including when the operator is not supported for
    the operand values (e.g. division by zero).
    """
    if isinstance(expr, IntExpr):
        return expr.value
    elif isinstance(expr, StrExpr):
        return expr.value
    elif isinstance(expr, NameExpr):
        node = expr.node
        if isinstance(node, Var) and node.is_final:
            value = node.final_value
            if isinstance(value, CONST_TYPES):
                return value
    elif isinstance(expr, MemberExpr):
        final = builder.get_final_ref(expr)
        if final is not None:
            # Only the Var (second item) of the get_final_ref() result is
            # needed for folding.
            final_var = final[1]
            if final_var.is_final:
                value = final_var.final_value
                if isinstance(value, CONST_TYPES):
                    return value
    elif isinstance(expr, OpExpr):
        left = constant_fold_expr(builder, expr.left)
        right = constant_fold_expr(builder, expr.right)
        # Fold only when both operands are the same kind of constant;
        # mixed int/str operations are not folded.
        if isinstance(left, int) and isinstance(right, int):
            return constant_fold_binary_int_op(expr.op, left, right)
        elif isinstance(left, str) and isinstance(right, str):
            return constant_fold_binary_str_op(expr.op, left, right)
    elif isinstance(expr, UnaryExpr):
        value = constant_fold_expr(builder, expr.expr)
        if isinstance(value, int):
            return constant_fold_unary_int_op(expr.op, value)
    return None
def constant_fold_binary_int_op(
    op: str, left: int, right: int, *, max_bits: int = 4096
) -> Optional[int]:
    """Fold a binary operator applied to two int constants.

    Return the folded int, or None when the operator is unsupported or the
    operation is not folded:

    * '//' and '%' with a zero divisor (evaluating would raise);
    * '<<', '>>' and '**' with a negative right operand ('<<'/'>>' would
      raise, and '**' would produce a float rather than an int);
    * '*', '<<' and '**' whose result could be wider than *max_bits* bits.

    The size cap keeps pathological expressions such as ``1 << 10**12`` or
    ``10 ** 10**9`` from exhausting memory or time at compile time. Such
    expressions are simply reported as not foldable, which mirrors the
    safeguards CPython applies when constant-folding.
    """
    if op == '+':
        return left + right
    elif op == '-':
        return left - right
    elif op == '*':
        if left == 0 or right == 0:
            return 0
        # The product is at most as wide as both operands combined.
        if left.bit_length() + right.bit_length() <= max_bits:
            return left * right
    elif op == '//':
        if right != 0:
            return left // right
    elif op == '%':
        if right != 0:
            return left % right
    elif op == '&':
        return left & right
    elif op == '|':
        return left | right
    elif op == '^':
        return left ^ right
    elif op == '<<':
        if right >= 0:
            if left == 0:
                return 0
            # Shifting left widens `left` by exactly `right` bits.
            if left.bit_length() + right <= max_bits:
                return left << right
    elif op == '>>':
        if right >= 0:
            return left >> right
    elif op == '**':
        if right >= 0:
            # 0, 1 and -1 stay tiny for any exponent; otherwise the result
            # has at most right * left.bit_length() bits.
            if abs(left) <= 1 or left.bit_length() * right <= max_bits:
                return left ** right
    return None
def constant_fold_unary_int_op(op: str, value: int) -> Optional[int]:
    """Fold a unary operator applied to an int constant.

    Supports '+', '-' and '~'. Any other operator yields None.
    """
    folded: Optional[int] = None
    if op == '+':
        # Unary plus hands back the operand itself, unchanged.
        folded = value
    elif op == '-':
        folded = -value
    elif op == '~':
        folded = ~value
    return folded
def constant_fold_binary_str_op(op: str, left: str, right: str) -> Optional[str]:
    """Fold a binary operator applied to two str constants.

    Only concatenation ('+') is folded. Every other operator yields None.
    """
    if op != '+':
        return None
    return left + right