"""Test cases for type inference helper functions."""

from typing import List, Optional, Tuple, Union, Dict, Set

from mypy.test.helpers import Suite, assert_equal
from mypy.argmap import map_actuals_to_formals
from mypy.checker import group_comparison_operands, DisjointDict
from mypy.literals import Key
from mypy.nodes import ArgKind, ARG_POS, ARG_OPT, ARG_STAR, ARG_STAR2, ARG_NAMED, NameExpr
from mypy.types import AnyType, TupleType, Type, TypeOfAny
from mypy.test.typefixture import TypeFixture


class MapActualsToFormalsSuite(Suite):
    """Test cases for argmap.map_actuals_to_formals."""

    def test_basic(self) -> None:
        # Nothing passed and nothing accepted: the mapping is empty.
        self.assert_map([], [], expected=[])

    def test_positional_only(self) -> None:
        # Positional actuals line up one-to-one with positional formals.
        self.assert_map([ARG_POS], [ARG_POS], expected=[[0]])
        self.assert_map([ARG_POS, ARG_POS], [ARG_POS, ARG_POS], expected=[[0], [1]])

    def test_optional(self) -> None:
        # An optional formal without a matching actual gets an empty list.
        self.assert_map([], [ARG_OPT], expected=[[]])
        self.assert_map([ARG_POS], [ARG_OPT], expected=[[0]])
        self.assert_map([ARG_POS], [ARG_OPT, ARG_OPT], expected=[[0], []])

    def test_callee_star(self) -> None:
        # A *args formal collects every positional actual.
        self.assert_map([], [ARG_STAR], expected=[[]])
        self.assert_map([ARG_POS], [ARG_STAR], expected=[[0]])
        self.assert_map([ARG_POS, ARG_POS], [ARG_STAR], expected=[[0, 1]])

    def test_caller_star(self) -> None:
        # A *args actual may be mapped to several formals at once.
        self.assert_map([ARG_STAR], [ARG_STAR], expected=[[0]])
        self.assert_map([ARG_POS, ARG_STAR], [ARG_STAR], expected=[[0, 1]])
        self.assert_map([ARG_STAR], [ARG_POS, ARG_STAR], expected=[[0], [0]])
        self.assert_map([ARG_STAR], [ARG_OPT, ARG_STAR], expected=[[0], [0]])

    def test_too_many_caller_args(self) -> None:
        # The result has one entry per formal, so surplus actuals simply do not appear.
        self.assert_map([ARG_POS], [], expected=[])
        self.assert_map([ARG_STAR], [], expected=[])
        self.assert_map([ARG_STAR], [ARG_POS], expected=[[0]])

    def test_tuple_star(self) -> None:
        # The *args actual is given a fixed-length tuple type in each case below.
        any_type = AnyType(TypeOfAny.special_form)
        self.assert_vararg_map(
            [ARG_STAR],
            [ARG_POS],
            expected=[[0]],
            vararg_type=self.tuple(any_type),
        )
        self.assert_vararg_map(
            [ARG_STAR],
            [ARG_POS, ARG_POS],
            expected=[[0], [0]],
            vararg_type=self.tuple(any_type, any_type),
        )
        self.assert_vararg_map(
            [ARG_STAR],
            [ARG_POS, ARG_OPT, ARG_OPT],
            expected=[[0], [0], []],
            vararg_type=self.tuple(any_type, any_type),
        )

    def tuple(self, *args: Type) -> TupleType:
        """Build a TupleType over the given item types using a fresh TypeFixture's std_tuple."""
        item_types = list(args)
        return TupleType(item_types, TypeFixture().std_tuple)

    def test_named_args(self) -> None:
        # Keyword actuals are matched by name, regardless of their order.
        self.assert_map(['x'], [(ARG_POS, 'x')], expected=[[0]])
        self.assert_map(['y', 'x'], [(ARG_POS, 'x'), (ARG_POS, 'y')], expected=[[1], [0]])

    def test_some_named_args(self) -> None:
        self.assert_map(
            ['y'],
            [(ARG_OPT, 'x'), (ARG_OPT, 'y'), (ARG_OPT, 'z')],
            expected=[[], [0], []],
        )

    def test_missing_named_arg(self) -> None:
        # A keyword actual naming no formal leaves the formal unmapped.
        self.assert_map(['y'], [(ARG_OPT, 'x')], expected=[[]])

    def test_duplicate_named_arg(self) -> None:
        # Both occurrences of the same keyword are recorded for the formal.
        self.assert_map(['x', 'x'], [(ARG_OPT, 'x')], expected=[[0, 1]])

    def test_varargs_and_bare_asterisk(self) -> None:
        self.assert_map(
            [ARG_STAR],
            [ARG_STAR, (ARG_NAMED, 'x')],
            expected=[[0], []],
        )
        self.assert_map(
            [ARG_STAR, 'x'],
            [ARG_STAR, (ARG_NAMED, 'x')],
            expected=[[0], [1]],
        )

    def test_keyword_varargs(self) -> None:
        self.assert_map(['x'], [ARG_STAR2], expected=[[0]])
        self.assert_map(['x', ARG_STAR2], [ARG_STAR2], expected=[[0, 1]])
        self.assert_map(
            ['x', ARG_STAR2],
            [(ARG_POS, 'x'), ARG_STAR2],
            expected=[[0], [1]],
        )
        self.assert_map(
            [ARG_POS, ARG_STAR2],
            [(ARG_POS, 'x'), ARG_STAR2],
            expected=[[0], [1]],
        )

    def test_both_kinds_of_varargs(self) -> None:
        self.assert_map(
            [ARG_STAR, ARG_STAR2],
            [(ARG_POS, 'x'), (ARG_POS, 'y')],
            expected=[[0, 1], [0, 1]],
        )

    def test_special_cases(self) -> None:
        self.assert_map([ARG_STAR], [ARG_STAR, ARG_STAR2], expected=[[0], []])
        self.assert_map([ARG_STAR, ARG_STAR2], [ARG_STAR, ARG_STAR2], expected=[[0], [1]])
        self.assert_map([ARG_STAR2], [(ARG_POS, 'x'), ARG_STAR2], expected=[[0], [0]])
        self.assert_map([ARG_STAR2], [ARG_STAR2], expected=[[0]])

    def assert_map(self,
                   caller_kinds_: List[Union[ArgKind, str]],
                   callee_kinds_: List[Union[ArgKind, Tuple[ArgKind, str]]],
                   expected: List[List[int]],
                   ) -> None:
        """Check map_actuals_to_formals against `expected` using shorthand argument specs.

        A bare string in `caller_kinds_` stands for a keyword actual with that name, and
        a (kind, name) pair in `callee_kinds_` stands for a named formal.  Every actual
        is reported to have type Any.
        """
        def any_actual_type(index: int) -> Type:
            return AnyType(TypeOfAny.special_form)

        actual_kinds, actual_names = expand_caller_kinds(caller_kinds_)
        formal_kinds, formal_names = expand_callee_kinds(callee_kinds_)
        mapping = map_actuals_to_formals(
            actual_kinds,
            actual_names,
            formal_kinds,
            formal_names,
            any_actual_type)
        assert_equal(mapping, expected)

    def assert_vararg_map(self,
                          caller_kinds: List[ArgKind],
                          callee_kinds: List[ArgKind],
                          expected: List[List[int]],
                          vararg_type: Type,
                          ) -> None:
        """Check map_actuals_to_formals when every actual is reported to have `vararg_type`.

        Empty name lists are passed for both the caller and the callee.
        """
        def fixed_actual_type(index: int) -> Type:
            return vararg_type

        mapping = map_actuals_to_formals(
            caller_kinds,
            [],
            callee_kinds,
            [],
            fixed_actual_type)
        assert_equal(mapping, expected)


def expand_caller_kinds(kinds_or_names: List[Union[ArgKind, str]]
                        ) -> Tuple[List[ArgKind], List[Optional[str]]]:
    """Split shorthand caller specs into parallel lists of kinds and names.

    A string entry becomes (ARG_NAMED, that string); any other entry is taken
    as the kind itself and paired with a None name.
    """
    kinds: List[ArgKind] = []
    names: List[Optional[str]] = []
    for entry in kinds_or_names:
        if not isinstance(entry, str):
            kinds.append(entry)
            names.append(None)
            continue
        kinds.append(ARG_NAMED)
        names.append(entry)
    return kinds, names


def expand_callee_kinds(kinds_and_names: List[Union[ArgKind, Tuple[ArgKind, str]]]
                        ) -> Tuple[List[ArgKind], List[Optional[str]]]:
    """Split shorthand callee specs into parallel lists of kinds and names.

    A tuple entry contributes its first item as the kind and its second as the
    name; any other entry is taken as the kind itself and paired with a None name.
    """
    kinds: List[ArgKind] = []
    names: List[Optional[str]] = []
    for entry in kinds_and_names:
        if not isinstance(entry, tuple):
            kinds.append(entry)
            names.append(None)
            continue
        kinds.append(entry[0])
        names.append(entry[1])
    return kinds, names


class OperandDisjointDictSuite(Suite):
    """Test cases for checker.DisjointDict, which is used for type inference with operands."""

    def new(self) -> DisjointDict[int, str]:
        """Return a fresh, empty DisjointDict."""
        return DisjointDict()

    def test_independent_maps(self) -> None:
        # Key sets that never overlap stay in separate groups.
        disjoint = self.new()
        disjoint.add_mapping({0, 1}, {"group1"})
        disjoint.add_mapping({2, 3, 4}, {"group2"})
        disjoint.add_mapping({5, 6, 7}, {"group3"})

        expected = [
            ({0, 1}, {"group1"}),
            ({2, 3, 4}, {"group2"}),
            ({5, 6, 7}, {"group3"}),
        ]
        self.assertEqual(disjoint.items(), expected)

    def test_partial_merging(self) -> None:
        # Overlapping key sets are merged, leaving two groups in the end.
        disjoint = self.new()
        disjoint.add_mapping({0, 1}, {"group1"})
        disjoint.add_mapping({1, 2}, {"group2"})
        disjoint.add_mapping({3, 4}, {"group3"})
        disjoint.add_mapping({5, 0}, {"group4"})
        disjoint.add_mapping({5, 6}, {"group5"})
        disjoint.add_mapping({4, 7}, {"group6"})

        expected = [
            ({0, 1, 2, 5, 6}, {"group1", "group2", "group4", "group5"}),
            ({3, 4, 7}, {"group3", "group6"}),
        ]
        self.assertEqual(disjoint.items(), expected)

    def test_full_merging(self) -> None:
        # The final mapping bridges the remaining groups, so everything collapses into one.
        disjoint = self.new()
        disjoint.add_mapping({0, 1, 2}, {"a"})
        disjoint.add_mapping({3, 4, 2}, {"b"})
        disjoint.add_mapping({10, 11, 12}, {"c"})
        disjoint.add_mapping({13, 14, 15}, {"d"})
        disjoint.add_mapping({14, 10, 16}, {"e"})
        disjoint.add_mapping({0, 10}, {"f"})

        expected = [
            ({0, 1, 2, 3, 4, 10, 11, 12, 13, 14, 15, 16}, {"a", "b", "c", "d", "e", "f"}),
        ]
        self.assertEqual(disjoint.items(), expected)

    def test_merge_with_multiple_overlaps(self) -> None:
        # One mapping may overlap several existing groups at the same time.
        disjoint = self.new()
        disjoint.add_mapping({0, 1, 2}, {"a"})
        disjoint.add_mapping({3, 4, 5}, {"b"})
        disjoint.add_mapping({1, 2, 4, 5}, {"c"})
        disjoint.add_mapping({6, 1, 2, 4, 5}, {"d"})
        disjoint.add_mapping({6, 1, 2, 4, 5}, {"e"})

        expected = [
            ({0, 1, 2, 3, 4, 5, 6}, {"a", "b", "c", "d", "e"}),
        ]
        self.assertEqual(disjoint.items(), expected)


class OperandComparisonGroupingSuite(Suite):
    """Test cases for checker.group_comparison_operands."""

    def literal_keymap(self, assignable_operands: Dict[int, NameExpr]) -> Dict[int, Key]:
        """Map each operand index to a fake literal key derived from the expression's name."""
        return {
            index: ('FakeExpr', expr.name)
            for index, expr in assignable_operands.items()
        }

    def test_basic_cases(self) -> None:
        # Note: the grouping function doesn't actually inspect the input exprs, so
        # NameExprs are used throughout purely for simplicity.
        x0, x1, x2, x3, x4 = [NameExpr(name) for name in ('x0', 'x1', 'x2', 'x3', 'x4')]

        basic_input = [('==', x0, x1), ('==', x1, x2), ('<', x2, x3), ('==', x3, x4)]

        none_assignable = self.literal_keymap({})
        all_assignable = self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x4})

        ungrouped = [('==', [0, 1]), ('==', [1, 2]), ('<', [2, 3]), ('==', [3, 4])]
        eq_grouped = [('==', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4])]

        # Each entry pairs the operators to group by with the grouping we expect.
        scenarios: List[Tuple[Set[str], List[Tuple[str, List[int]]]]] = [
            (set(), ungrouped),
            ({'=='}, eq_grouped),
            ({'<'}, ungrouped),
            ({'==', '<'}, eq_grouped),
        ]

        for assignable in (none_assignable, all_assignable):
            for operators, expected in scenarios:
                grouped = group_comparison_operands(basic_input, assignable, operators)
                self.assertEqual(grouped, expected)

    def test_multiple_groups(self) -> None:
        x0, x1, x2, x3, x4, x5 = [
            NameExpr(name) for name in ('x0', 'x1', 'x2', 'x3', 'x4', 'x5')
        ]

        grouped = group_comparison_operands(
            [('==', x0, x1), ('==', x1, x2), ('is', x2, x3), ('is', x3, x4)],
            self.literal_keymap({}),
            {'==', 'is'},
        )
        self.assertEqual(grouped, [('==', [0, 1, 2]), ('is', [2, 3, 4])])

        grouped = group_comparison_operands(
            [('==', x0, x1), ('==', x1, x2), ('==', x2, x3), ('==', x3, x4)],
            self.literal_keymap({}),
            {'==', 'is'},
        )
        self.assertEqual(grouped, [('==', [0, 1, 2, 3, 4])])

        grouped = group_comparison_operands(
            [('is', x0, x1), ('==', x1, x2), ('==', x2, x3), ('==', x3, x4)],
            self.literal_keymap({}),
            {'==', 'is'},
        )
        self.assertEqual(grouped, [('is', [0, 1]), ('==', [1, 2, 3, 4])])

        grouped = group_comparison_operands(
            [('is', x0, x1), ('is', x1, x2), ('<', x2, x3), ('==', x3, x4), ('==', x4, x5)],
            self.literal_keymap({}),
            {'==', 'is'},
        )
        self.assertEqual(grouped, [('is', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4, 5])])

    def test_multiple_groups_coalescing(self) -> None:
        x0, x1, x2, x3, x4 = [NameExpr(name) for name in ('x0', 'x1', 'x2', 'x3', 'x4')]

        nothing_combined = [('==', [0, 1, 2]), ('<', [2, 3]), ('==', [3, 4, 5])]
        everything_combined = [('==', [0, 1, 2, 3, 4, 5]), ('<', [2, 3])]

        # Note: 'x4 == x0' comes at the very end, so x0 is both operand 0 and operand 5.
        two_groups = [
            ('==', x0, x1),
            ('==', x1, x2),
            ('<', x2, x3),
            ('==', x3, x4),
            ('==', x4, x0),
        ]

        # (assignable operands, expected grouping, assertion message)
        scenarios: List[Tuple[Dict[int, NameExpr], List[Tuple[str, List[int]]], str]] = [
            (
                {0: x0, 1: x1, 2: x2, 3: x3, 4: x4, 5: x0},
                everything_combined,
                "All vars are assignable, everything is combined",
            ),
            (
                {1: x1, 2: x2, 3: x3, 4: x4},
                nothing_combined,
                "x0 is unassignable, so no combining",
            ),
            (
                {0: x0, 1: x1, 3: x3, 5: x0},
                everything_combined,
                "Some vars are unassignable but x0 is, so we combine",
            ),
            (
                {0: x0, 5: x0},
                everything_combined,
                "All vars are unassignable but x0 is, so we combine",
            ),
        ]

        for assignable, expected, message in scenarios:
            grouped = group_comparison_operands(
                two_groups,
                self.literal_keymap(assignable),
                {'=='},
            )
            self.assertEqual(grouped, expected, message)

    def test_multiple_groups_different_operators(self) -> None:
        x0, x1, x2, x3 = [NameExpr(name) for name in ('x0', 'x1', 'x2', 'x3')]

        groups = [('==', x0, x1), ('==', x1, x2), ('is', x2, x3), ('is', x3, x0)]
        keymap = self.literal_keymap({0: x0, 1: x1, 2: x2, 3: x3, 4: x0})

        grouped = group_comparison_operands(groups, keymap, {'==', 'is'})
        self.assertEqual(
            grouped,
            [('==', [0, 1, 2]), ('is', [2, 3, 4])],
            "Different operators can never be combined"
        )

    def test_single_pair(self) -> None:
        x0, x1 = NameExpr('x0'), NameExpr('x1')

        single_comparison = [('==', x0, x1)]
        expected_output = [('==', [0, 1])]

        assignable_combinations: List[Dict[int, NameExpr]] = [
            {},
            {0: x0},
            {1: x1},
            {0: x0, 1: x1},
        ]
        to_group_by: List[Set[str]] = [set(), {"=="}, {"is"}]

        # A lone comparison yields the same output for every combination tried here.
        for assignable in assignable_combinations:
            for grouping_operators in to_group_by:
                keymap = self.literal_keymap(assignable)
                grouped = group_comparison_operands(
                    single_comparison, keymap, grouping_operators)
                self.assertEqual(grouped, expected_output)

    def test_empty_pair_list(self) -> None:
        # This case should never occur in practice -- a ComparisonExpr always
        # contains at least one comparison. But in case it does...
        without_grouping = group_comparison_operands([], {}, set())
        self.assertEqual(without_grouping, [])

        with_grouping = group_comparison_operands([], {}, {'=='})
        self.assertEqual(with_grouping, [])