"""Utilities for pretty-printing IR in a human-readable form.""" from collections import defaultdict from typing import Any, Dict, List, Union, Sequence, Tuple from typing_extensions import Final from mypyc.common import short_name from mypyc.ir.ops import ( Goto, Branch, Return, Unreachable, Assign, Integer, LoadErrorValue, GetAttr, SetAttr, LoadStatic, InitStatic, TupleGet, TupleSet, IncRef, DecRef, Call, MethodCall, Cast, Box, Unbox, RaiseStandardError, CallC, Truncate, LoadGlobal, IntOp, ComparisonOp, LoadMem, SetMem, GetElementPtr, LoadAddress, Register, Value, OpVisitor, BasicBlock, ControlOp, LoadLiteral, AssignMulti, KeepAlive, Op, Extend, ERR_NEVER ) from mypyc.ir.func_ir import FuncIR, all_values_full from mypyc.ir.module_ir import ModuleIRs from mypyc.ir.rtypes import is_bool_rprimitive, is_int_rprimitive, RType ErrorSource = Union[BasicBlock, Op] class IRPrettyPrintVisitor(OpVisitor[str]): """Internal visitor that pretty-prints ops.""" def __init__(self, names: Dict[Value, str]) -> None: # This should contain a name for all values that are shown as # registers in the output. This is not just for Register # instances -- all Ops that produce values need (generated) names. self.names = names def visit_goto(self, op: Goto) -> str: return self.format('goto %l', op.label) branch_op_names: Final = { Branch.BOOL: ('%r', 'bool'), Branch.IS_ERROR: ('is_error(%r)', ''), } def visit_branch(self, op: Branch) -> str: fmt, typ = self.branch_op_names[op.op] if op.negated: fmt = f'not {fmt}' cond = self.format(fmt, op.value) tb = '' if op.traceback_entry: tb = ' (error at %s:%d)' % op.traceback_entry fmt = f'if {cond} goto %l{tb} else goto %l' if typ: fmt += f' :: {typ}' return self.format(fmt, op.true, op.false) def visit_return(self, op: Return) -> str: return self.format('return %r', op.value) def visit_unreachable(self, op: Unreachable) -> str: return "unreachable" def visit_assign(self, op: Assign) -> str: return self.format('%r = %r', op.dest, op.src) def visit_assign_multi(self, op: AssignMulti) -> str: return self.format('%r = [%s]', op.dest, ', '.join(self.format('%r', v) for v in op.src)) def visit_load_error_value(self, op: LoadErrorValue) -> str: return self.format('%r = :: %s', op, op.type) def visit_load_literal(self, op: LoadLiteral) -> str: prefix = '' # For values that have a potential unboxed representation, make # it explicit that this is a Python object. if isinstance(op.value, int): prefix = 'object ' return self.format('%r = %s%s', op, prefix, repr(op.value)) def visit_get_attr(self, op: GetAttr) -> str: return self.format('%r = %s%r.%s', op, self.borrow_prefix(op), op.obj, op.attr) def borrow_prefix(self, op: Op) -> str: if op.is_borrowed: return 'borrow ' return '' def visit_set_attr(self, op: SetAttr) -> str: if op.is_init: assert op.error_kind == ERR_NEVER # Initialization and direct struct access can never fail return self.format('%r.%s = %r', op.obj, op.attr, op.src) else: return self.format('%r.%s = %r; %r = is_error', op.obj, op.attr, op.src, op) def visit_load_static(self, op: LoadStatic) -> str: ann = f' ({repr(op.ann)})' if op.ann else '' name = op.identifier if op.module_name is not None: name = f'{op.module_name}.{name}' return self.format('%r = %s :: %s%s', op, name, op.namespace, ann) def visit_init_static(self, op: InitStatic) -> str: name = op.identifier if op.module_name is not None: name = f'{op.module_name}.{name}' return self.format('%s = %r :: %s', name, op.value, op.namespace) def visit_tuple_get(self, op: TupleGet) -> str: return self.format('%r = %r[%d]', op, op.src, op.index) def visit_tuple_set(self, op: TupleSet) -> str: item_str = ', '.join(self.format('%r', item) for item in op.items) return self.format('%r = (%s)', op, item_str) def visit_inc_ref(self, op: IncRef) -> str: s = self.format('inc_ref %r', op.src) # TODO: Remove bool check (it's unboxed) if is_bool_rprimitive(op.src.type) or is_int_rprimitive(op.src.type): s += f' :: {short_name(op.src.type.name)}' return s def visit_dec_ref(self, op: DecRef) -> str: s = self.format('%sdec_ref %r', 'x' if op.is_xdec else '', op.src) # TODO: Remove bool check (it's unboxed) if is_bool_rprimitive(op.src.type) or is_int_rprimitive(op.src.type): s += f' :: {short_name(op.src.type.name)}' return s def visit_call(self, op: Call) -> str: args = ', '.join(self.format('%r', arg) for arg in op.args) # TODO: Display long name? short_name = op.fn.shortname s = f'{short_name}({args})' if not op.is_void: s = self.format('%r = ', op) + s return s def visit_method_call(self, op: MethodCall) -> str: args = ', '.join(self.format('%r', arg) for arg in op.args) s = self.format('%r.%s(%s)', op.obj, op.method, args) if not op.is_void: s = self.format('%r = ', op) + s return s def visit_cast(self, op: Cast) -> str: return self.format('%r = %scast(%s, %r)', op, self.borrow_prefix(op), op.type, op.src) def visit_box(self, op: Box) -> str: return self.format('%r = box(%s, %r)', op, op.src.type, op.src) def visit_unbox(self, op: Unbox) -> str: return self.format('%r = unbox(%s, %r)', op, op.type, op.src) def visit_raise_standard_error(self, op: RaiseStandardError) -> str: if op.value is not None: if isinstance(op.value, str): return self.format('%r = raise %s(%s)', op, op.class_name, repr(op.value)) elif isinstance(op.value, Value): return self.format('%r = raise %s(%r)', op, op.class_name, op.value) else: assert False, 'value type must be either str or Value' else: return self.format('%r = raise %s', op, op.class_name) def visit_call_c(self, op: CallC) -> str: args_str = ', '.join(self.format('%r', arg) for arg in op.args) if op.is_void: return self.format('%s(%s)', op.function_name, args_str) else: return self.format('%r = %s(%s)', op, op.function_name, args_str) def visit_truncate(self, op: Truncate) -> str: return self.format("%r = truncate %r: %t to %t", op, op.src, op.src_type, op.type) def visit_extend(self, op: Extend) -> str: if op.signed: extra = ' signed' else: extra = '' return self.format("%r = extend%s %r: %t to %t", op, extra, op.src, op.src_type, op.type) def visit_load_global(self, op: LoadGlobal) -> str: ann = f' ({repr(op.ann)})' if op.ann else '' return self.format('%r = load_global %s :: static%s', op, op.identifier, ann) def visit_int_op(self, op: IntOp) -> str: return self.format('%r = %r %s %r', op, op.lhs, IntOp.op_str[op.op], op.rhs) def visit_comparison_op(self, op: ComparisonOp) -> str: if op.op in (ComparisonOp.SLT, ComparisonOp.SGT, ComparisonOp.SLE, ComparisonOp.SGE): sign_format = " :: signed" elif op.op in (ComparisonOp.ULT, ComparisonOp.UGT, ComparisonOp.ULE, ComparisonOp.UGE): sign_format = " :: unsigned" else: sign_format = "" return self.format('%r = %r %s %r%s', op, op.lhs, ComparisonOp.op_str[op.op], op.rhs, sign_format) def visit_load_mem(self, op: LoadMem) -> str: return self.format("%r = load_mem %r :: %t*", op, op.src, op.type) def visit_set_mem(self, op: SetMem) -> str: return self.format("set_mem %r, %r :: %t*", op.dest, op.src, op.dest_type) def visit_get_element_ptr(self, op: GetElementPtr) -> str: return self.format("%r = get_element_ptr %r %s :: %t", op, op.src, op.field, op.src_type) def visit_load_address(self, op: LoadAddress) -> str: if isinstance(op.src, Register): return self.format("%r = load_address %r", op, op.src) else: return self.format("%r = load_address %s", op, op.src) def visit_keep_alive(self, op: KeepAlive) -> str: return self.format('keep_alive %s' % ', '.join(self.format('%r', v) for v in op.src)) # Helpers def format(self, fmt: str, *args: Any) -> str: """Helper for formatting strings. These format sequences are supported in fmt: %s: arbitrary object converted to string using str() %r: name of IR value/register %d: int %f: float %l: BasicBlock (formatted as label 'Ln') %t: RType """ result = [] i = 0 arglist = list(args) while i < len(fmt): n = fmt.find('%', i) if n < 0: n = len(fmt) result.append(fmt[i:n]) if n < len(fmt): typespec = fmt[n + 1] arg = arglist.pop(0) if typespec == 'r': # Register/value assert isinstance(arg, Value) if isinstance(arg, Integer): result.append(str(arg.value)) else: result.append(self.names[arg]) elif typespec == 'd': # Integer result.append('%d' % arg) elif typespec == 'f': # Float result.append('%f' % arg) elif typespec == 'l': # Basic block (label) assert isinstance(arg, BasicBlock) result.append('L%s' % arg.label) elif typespec == 't': # RType assert isinstance(arg, RType) result.append(arg.name) elif typespec == 's': # String result.append(str(arg)) else: raise ValueError(f'Invalid format sequence %{typespec}') i = n + 2 else: i = n return ''.join(result) def format_registers(func_ir: FuncIR, names: Dict[Value, str]) -> List[str]: result = [] i = 0 regs = all_values_full(func_ir.arg_regs, func_ir.blocks) while i < len(regs): i0 = i group = [names[regs[i0]]] while i + 1 < len(regs) and regs[i + 1].type == regs[i0].type: i += 1 group.append(names[regs[i]]) i += 1 result.append('{} :: {}'.format(', '.join(group), regs[i0].type)) return result def format_blocks(blocks: List[BasicBlock], names: Dict[Value, str], source_to_error: Dict[ErrorSource, List[str]]) -> List[str]: """Format a list of IR basic blocks into a human-readable form.""" # First label all of the blocks for i, block in enumerate(blocks): block.label = i handler_map: Dict[BasicBlock, List[BasicBlock]] = {} for b in blocks: if b.error_handler: handler_map.setdefault(b.error_handler, []).append(b) visitor = IRPrettyPrintVisitor(names) lines = [] for i, block in enumerate(blocks): handler_msg = '' if block in handler_map: labels = sorted('L%d' % b.label for b in handler_map[block]) handler_msg = ' (handler for {})'.format(', '.join(labels)) lines.append('L%d:%s' % (block.label, handler_msg)) if block in source_to_error: for error in source_to_error[block]: lines.append(f" ERR: {error}") ops = block.ops if (isinstance(ops[-1], Goto) and i + 1 < len(blocks) and ops[-1].label == blocks[i + 1] and not source_to_error.get(ops[-1], [])): # Hide the last goto if it just goes to the next basic block, # and there are no assocatiated errors with the op. ops = ops[:-1] for op in ops: line = ' ' + op.accept(visitor) lines.append(line) if op in source_to_error: for error in source_to_error[op]: lines.append(f" ERR: {error}") if not isinstance(block.ops[-1], (Goto, Branch, Return, Unreachable)): # Each basic block needs to exit somewhere. lines.append(' [MISSING BLOCK EXIT OPCODE]') return lines def format_func(fn: FuncIR, errors: Sequence[Tuple[ErrorSource, str]] = ()) -> List[str]: lines = [] cls_prefix = fn.class_name + '.' if fn.class_name else '' lines.append('def {}{}({}):'.format(cls_prefix, fn.name, ', '.join(arg.name for arg in fn.args))) names = generate_names_for_ir(fn.arg_regs, fn.blocks) for line in format_registers(fn, names): lines.append(' ' + line) source_to_error = defaultdict(list) for source, error in errors: source_to_error[source].append(error) code = format_blocks(fn.blocks, names, source_to_error) lines.extend(code) return lines def format_modules(modules: ModuleIRs) -> List[str]: ops = [] for module in modules.values(): for fn in module.functions: ops.extend(format_func(fn)) ops.append('') return ops def generate_names_for_ir(args: List[Register], blocks: List[BasicBlock]) -> Dict[Value, str]: """Generate unique names for IR values. Give names such as 'r5' to temp values in IR which are useful when pretty-printing or generating C. Ensure generated names are unique. """ names: Dict[Value, str] = {} used_names = set() temp_index = 0 for arg in args: names[arg] = arg.name used_names.add(arg.name) for block in blocks: for op in block.ops: values = [] for source in op.sources(): if source not in names: values.append(source) if isinstance(op, (Assign, AssignMulti)): values.append(op.dest) elif isinstance(op, ControlOp) or op.is_void: continue elif op not in names: values.append(op) for value in values: if value in names: continue if isinstance(value, Register) and value.name: name = value.name elif isinstance(value, Integer): continue else: name = 'r%d' % temp_index temp_index += 1 # Append _2, _3, ... if needed to make the name unique. if name in used_names: n = 2 while True: candidate = '%s_%d' % (name, n) if candidate not in used_names: name = candidate break n += 1 names[value] = name used_names.add(name) return names