diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 6df02870d104..59b1ada4faf5 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3793,11 +3793,18 @@ def dangerous_comparison( ) and not identity_check: return False - if prefer_literal: + if prefer_literal and not ( + isinstance(left, Instance) + and not left.type.is_enum + and isinstance(right, Instance) + and not right.type.is_enum + ): # Also flag non-overlapping literals in situations like: # x: Literal['a', 'b'] # if x == 'c': # ... + # but only in situations where at least one side is an actual literal/enum type. + # see testNarrowingEqualityFlipFlop for an example where we want to skip this. left = try_getting_literal(left) right = try_getting_literal(right) diff --git a/mypy/join.py b/mypy/join.py index a8c9910e60bb..233a915df76a 100644 --- a/mypy/join.py +++ b/mypy/join.py @@ -141,11 +141,19 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType: new_type = join_types(ta, sa, self) assert new_type is not None args.append(new_type) - result: ProperType = Instance(t.type, args) + + lkv = join_last_known_values(t, s) + result: ProperType = Instance(t.type, args, last_known_value=lkv) elif t.type.bases and is_proper_subtype( t, s, subtype_context=SubtypeContext(ignore_type_params=True) ): result = self.join_instances_via_supertype(t, s) + elif s.type.bases and is_proper_subtype( + s, t, subtype_context=SubtypeContext(ignore_type_params=True) + ): + result = self.join_instances_via_supertype(s, t) + elif is_subtype(t, s, subtype_context=SubtypeContext(ignore_type_params=True)): + result = self.join_instances_via_supertype(t, s) else: # Now t is not a subtype of s, and t != s. Now s could be a subtype # of t; alternatively, we need to find a common supertype. This works @@ -636,13 +644,17 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType: def visit_literal_type(self, t: LiteralType) -> ProperType: if isinstance(self.s, LiteralType): if t == self.s: + # E.g. Literal["x"], Literal["x"] -> Literal["x"] return t if self.s.fallback.type.is_enum and t.fallback.type.is_enum: return mypy.typeops.make_simplified_union([self.s, t]) + # E.g. Literal["x"], Literal["y"] -> str return join_types(self.s.fallback, t.fallback) elif isinstance(self.s, Instance) and self.s.last_known_value == t: - return t + # E.g. Literal["x"], Literal["x"]? -> Literal["x"]? + return self.s else: + # E.g. Literal["x"], Literal["y"]? -> str return join_types(self.s, t.fallback) def visit_partial_type(self, t: PartialType) -> ProperType: @@ -782,6 +794,16 @@ def join_similar_callables(t: CallableType, s: CallableType) -> CallableType: ) +def join_last_known_values(t: Instance, s: Instance) -> LiteralType | None: + r"""Return the join of two last known values.""" + # drop last known value if they differ (e.g. join('x'?, 'y'?) -> str) + left = t.last_known_value + right = s.last_known_value + if left is None or right is None or left != right: + return None + return left + + def safe_join(t: Type, s: Type) -> Type: # This is a temporary solution to prevent crashes in combine_similar_callables() etc., # until relevant TODOs on handling arg_kinds will be addressed there. diff --git a/mypy/meet.py b/mypy/meet.py index 365544d4584f..98e2640b6225 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -81,6 +81,11 @@ def meet_types(s: Type, t: Type) -> ProperType: t = get_proper_type(t) if isinstance(s, Instance) and isinstance(t, Instance) and s.type == t.type: + # special casing for dealing with last known values + lkv = meet_last_known_values(t, s) + t = t.copy_modified(last_known_value=lkv) + s = s.copy_modified(last_known_value=lkv) + # Code in checker.py should merge any extra_items where possible, so we # should have only compatible extra_items here. We check this before # the below subtype check, so that extra_attrs will not get erased. @@ -113,6 +118,31 @@ def meet_types(s: Type, t: Type) -> ProperType: return t.accept(TypeMeetVisitor(s)) +def meet_last_known_values(t: Instance, s: Instance) -> LiteralType | None: + """Return the meet of two last_known_values.""" + left = t.last_known_value + right = s.last_known_value + + if left is None: + return right + if right is None: + return left + + lkv_meet = meet_types(left, right) + + if isinstance(lkv_meet, UninhabitedType): + return None + if isinstance(lkv_meet, LiteralType): + return lkv_meet + + msg = ( + f"Unexpected result: " + f"meet of last_known_values {left=!s} and {right=!s} " + f"resulted in {lkv_meet!s}" + ) + raise ValueError(msg) + + def narrow_declared_type(declared: Type, narrowed: Type) -> Type: """Return the declared type narrowed down to another type.""" # TODO: check infinite recursion for aliases here. @@ -900,7 +930,8 @@ def visit_instance(self, t: Instance) -> ProperType: assert isinstance(meet, UninhabitedType) meet = UnpackType(tv.tuple_fallback.copy_modified(args=[meet])) args.append(meet) - return Instance(t.type, args) + lkv = meet_last_known_values(t, self.s) + return Instance(t.type, args, last_known_value=lkv) else: if state.strict_optional: return UninhabitedType() @@ -1066,13 +1097,22 @@ def meet_tuples(self, s: TupleType, t: TupleType) -> list[Type] | None: items.append(self.meet(fi, vi)) return items + def meet_tuple_fallbacks(self, s: TupleType, t: TupleType) -> Instance: + fall_s = tuple_fallback(s) + fall_t = tuple_fallback(t) + meet_fallback = meet_types(fall_s, fall_t) + if not isinstance(meet_fallback, Instance): + # TODO: emit warning? + return fall_t + return meet_fallback + def visit_tuple_type(self, t: TupleType) -> ProperType: if isinstance(self.s, TupleType): items = self.meet_tuples(self.s, t) if items is None: return self.default(self.s) - # TODO: What if the fallbacks are different? - return TupleType(items, tuple_fallback(t)) + fallback = self.meet_tuple_fallbacks(self.s, t) + return TupleType(items, fallback=fallback) elif isinstance(self.s, Instance): # meet(Tuple[t1, t2, <...>], Tuple[s, ...]) == Tuple[meet(t1, s), meet(t2, s), <...>]. if self.s.type.fullname in TUPLE_LIKE_INSTANCE_NAMES and self.s.args: @@ -1114,8 +1154,14 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType: def visit_literal_type(self, t: LiteralType) -> ProperType: if isinstance(self.s, LiteralType) and self.s == t: return t - elif isinstance(self.s, Instance) and is_subtype(t.fallback, self.s): - return t + elif isinstance(self.s, Instance): + # if is_subtype(t.fallback, self.s): + # return t + if self.s.last_known_value is not None: + # meet(Literal["max"]?, Literal["max"]) -> Literal["max"] + # meet(Literal["sum"]?, Literal["max"]) -> Never + return meet_types(self.s.last_known_value, t) + return self.default(self.s) else: return self.default(self.s) diff --git a/mypy/solve.py b/mypy/solve.py index e3709106996c..b2ea2037652e 100644 --- a/mypy/solve.py +++ b/mypy/solve.py @@ -319,7 +319,8 @@ def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None: elif top is None: candidate = bottom elif is_subtype(bottom, top): - candidate = bottom + # Need to meet in case like Literal["x"]? <: T <: Literal["x"] + candidate = meet_types(bottom, top) else: candidate = None return candidate diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 350d57a7e4ad..7d32221e35a0 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -548,6 +548,13 @@ def visit_instance(self, left: Instance) -> bool: assert isinstance(erased, Instance) t = erased nominal = True + if self.proper_subtype and right.last_known_value is not None: + if left.last_known_value is None: + # E.g. str is not a proper subtype of Literal["x"]? + nominal = False + else: + # E.g. Literal[A]? <: Literal[B]? requires A <: B + nominal &= self._is_subtype(left.last_known_value, right.last_known_value) if right.type.has_type_var_tuple_type: # For variadic instances we simply find the correct type argument mappings, # all the heavy lifting is done by the tuple subtyping. @@ -628,8 +635,14 @@ def visit_instance(self, left: Instance) -> bool: return True if isinstance(item, Instance): return is_named_instance(item, "builtins.object") - if isinstance(right, LiteralType) and left.last_known_value is not None: - return self._is_subtype(left.last_known_value, right) + if isinstance(right, LiteralType): + if self.proper_subtype: + # Instance types like Literal["sum"]? is *assignable* to Literal["sum"], + # but is not a proper subtype of it. (Literal["sum"]? is a gradual type, + # that is a proper subtype of str, and assignable to Literal["sum"]. + return False + if left.last_known_value is not None: + return self._is_subtype(left.last_known_value, right) if isinstance(right, FunctionLike): # Special case: Instance can be a subtype of Callable / Overloaded. call = find_member("__call__", left, left, is_operator=True) @@ -964,6 +977,12 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool: def visit_literal_type(self, left: LiteralType) -> bool: if isinstance(self.right, LiteralType): return left == self.right + elif ( + isinstance(self.right, Instance) + and self.right.last_known_value is not None + and self.proper_subtype + ): + return self._is_subtype(left, self.right.last_known_value) else: return self._is_subtype(left.fallback, self.right) @@ -2138,6 +2157,11 @@ def covers_at_runtime(item: Type, supertype: Type) -> bool: item = get_proper_type(item) supertype = get_proper_type(supertype) + # Use last known value for Instance types, if available. + # This ensures that e.g. Literal["max"]? is covered by Literal["max"]. + if isinstance(item, Instance) and item.last_known_value is not None: + item = item.last_known_value + # Since runtime type checks will ignore type arguments, erase the types. if not (isinstance(supertype, FunctionLike) and supertype.is_type_obj()): supertype = erase_type(supertype) diff --git a/mypy/test/testsubtypes.py b/mypy/test/testsubtypes.py index b75c22bca7f7..5be32f628de1 100644 --- a/mypy/test/testsubtypes.py +++ b/mypy/test/testsubtypes.py @@ -1,7 +1,7 @@ from __future__ import annotations from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT -from mypy.subtypes import is_subtype +from mypy.subtypes import is_proper_subtype, is_subtype, restrict_subtype_away from mypy.test.helpers import Suite from mypy.test.typefixture import InterfaceTypeFixture, TypeFixture from mypy.types import Instance, TupleType, Type, UninhabitedType, UnpackType @@ -277,6 +277,74 @@ def test_type_var_tuple_unpacked_variable_length_tuple(self) -> None: def test_fallback_not_subtype_of_tuple(self) -> None: self.assert_not_subtype(self.fx.a, TupleType([self.fx.b], fallback=self.fx.a)) + def test_literal(self) -> None: + str1 = self.fx.lit_str1 + str2 = self.fx.lit_str2 + str1_inst = self.fx.lit_str1_inst + str2_inst = self.fx.lit_str2_inst + str_type = self.fx.str_type + + # other operand is the fallback type + # "x" ≲ str -> YES + # str ≲ "x" -> NO + # "x"? ≲ str -> YES + # str ≲ "x"? -> YES + self.assert_subtype(str1, str_type) + self.assert_not_subtype(str_type, str1) + self.assert_subtype(str1_inst, str_type) + self.assert_subtype(str_type, str1_inst) + + # other operand is the same literal + # "x" ≲ "x" -> YES + # "x" ≲ "x"? -> YES + # "x"? ≲ "x" -> YES + # "x"? ≲ "x"? -> YES + self.assert_subtype(str1, str1) + self.assert_subtype(str1, str1_inst) + self.assert_subtype(str1_inst, str1) + self.assert_subtype(str1_inst, str1_inst) + + # other operand is a different literal + # "x" ≲ "y" -> NO + # "x" ≲ "y"? -> YES + # "x"? ≲ "y" -> NO + # "x"? ≲ "y"? -> YES + self.assert_not_subtype(str1, str2) + self.assert_subtype(str1, str2_inst) + self.assert_not_subtype(str1_inst, str2) + self.assert_subtype(str1_inst, str2_inst) + + # check proper subtyping + # other operand is the fallback type + # "x" <: str -> YES + # str <: "x" -> NO + # "x"? <: str -> YES + # str <: "x"? -> NO + self.assert_proper_subtype(str1, str_type) + self.assert_not_proper_subtype(str_type, str1) + self.assert_proper_subtype(str1_inst, str_type) + self.assert_not_proper_subtype(str_type, str1_inst) + + # other operand is the same literal + # "x" <: "x" -> YES + # "x" <: "x"? -> YES + # "x"? <: "x" -> NO + # "x"? <: "x"? -> YES + self.assert_proper_subtype(str1, str1) + self.assert_proper_subtype(str1, str1_inst) + self.assert_not_proper_subtype(str1_inst, str1) + self.assert_proper_subtype(str1_inst, str1_inst) + + # other operand is a different literal + # "x" <: "y" -> NO + # "x" <: "y"? -> NO + # "x"? <: "y" -> NO + # "x"? <: "y"? -> NO + self.assert_not_proper_subtype(str1, str2) + self.assert_not_proper_subtype(str1, str2_inst) + self.assert_not_proper_subtype(str1_inst, str2) + self.assert_not_proper_subtype(str1_inst, str2_inst) + # IDEA: Maybe add these test cases (they are tested pretty well in type # checker tests already): # * more interface subtyping test cases @@ -287,6 +355,12 @@ def test_fallback_not_subtype_of_tuple(self) -> None: # * any type # * generic function types + def assert_proper_subtype(self, s: Type, t: Type) -> None: + assert is_proper_subtype(s, t), f"{s} not proper subtype of {t}" + + def assert_not_proper_subtype(self, s: Type, t: Type) -> None: + assert not is_proper_subtype(s, t), f"{s} not proper subtype of {t}" + def assert_subtype(self, s: Type, t: Type) -> None: assert is_subtype(s, t), f"{s} not subtype of {t}" @@ -304,3 +378,53 @@ def assert_equivalent(self, s: Type, t: Type) -> None: def assert_unrelated(self, s: Type, t: Type) -> None: self.assert_not_subtype(s, t) self.assert_not_subtype(t, s) + + +class RestrictionSuite(Suite): + # Tests for type restrictions "A - B", i.e. ``T <: A and not T <: B``. + + def setUp(self) -> None: + self.fx = TypeFixture() + + def assert_restriction(self, s: Type, t: Type, expected: Type) -> None: + actual = restrict_subtype_away(s, t) + msg = f"restrict_subtype_away({s}, {t}) == {{}} ({{}} expected)" + self.assertEqual(actual, expected, msg=msg.format(actual, expected)) + + def test_literal(self) -> None: + str1 = self.fx.lit_str1 + str2 = self.fx.lit_str2 + str1_inst = self.fx.lit_str1_inst + str2_inst = self.fx.lit_str2_inst + str_type = self.fx.str_type + uninhabited = self.fx.uninhabited + + # other operand is the fallback type + # "x" - str -> Never + # str - "x" -> str + # "x"? - str -> Never + # str - "x"? -> Never + self.assert_restriction(str1, str_type, uninhabited) + self.assert_restriction(str_type, str1, str_type) + self.assert_restriction(str1_inst, str_type, uninhabited) + self.assert_restriction(str_type, str1_inst, uninhabited) + + # other operand is the same literal + # "x" - "x" -> Never + # "x" - "x"? -> Never + # "x"? - "x" -> Never + # "x"? - "x"? -> Never + self.assert_restriction(str1, str1, uninhabited) + self.assert_restriction(str1, str1_inst, uninhabited) + self.assert_restriction(str1_inst, str1, uninhabited) + self.assert_restriction(str1_inst, str1_inst, uninhabited) + + # other operand is a different literal + # "x" - "y" -> "x" + # "x" - "y"? -> Never + # "x"? - "y" -> "x"? + # "x"? - "y"? -> Never + self.assert_restriction(str1, str2, str1) + self.assert_restriction(str1, str2_inst, uninhabited) + self.assert_restriction(str1_inst, str2, str1_inst) + self.assert_restriction(str1_inst, str2_inst, uninhabited) diff --git a/mypy/test/testtypes.py b/mypy/test/testtypes.py index 090796ec9f44..e82432cc8a1b 100644 --- a/mypy/test/testtypes.py +++ b/mypy/test/testtypes.py @@ -589,7 +589,7 @@ def test_simplified_union_with_literals(self) -> None: [fx.lit1_inst, fx.lit3_inst], UnionType([fx.lit1_inst, fx.lit3_inst]) ) self.assert_simplified_union([fx.lit1_inst, fx.uninhabited], fx.lit1_inst) - self.assert_simplified_union([fx.lit1, fx.lit1_inst], fx.lit1) + self.assert_simplified_union([fx.lit1, fx.lit1_inst], fx.lit1_inst) self.assert_simplified_union([fx.lit1, fx.lit2_inst], UnionType([fx.lit1, fx.lit2_inst])) self.assert_simplified_union([fx.lit1, fx.lit3_inst], UnionType([fx.lit1, fx.lit3_inst])) @@ -639,7 +639,46 @@ def test_simplified_union_with_mixed_str_literals(self) -> None: [fx.lit_str1, fx.lit_str2, fx.lit_str3_inst], UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst]), ) - self.assert_simplified_union([fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], fx.lit_str1) + self.assert_simplified_union( + [fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], fx.lit_str1_inst + ) + + def test_simplified_union_with_mixed_str_literals2(self) -> None: + str1 = self.fx.lit_str1 + str2 = self.fx.lit_str2 + str1_inst = self.fx.lit_str1_inst + str2_inst = self.fx.lit_str2_inst + str_type = self.fx.str_type + + # other operand is the fallback type + # "x" | str -> str + # str | "x" -> str + # "x"? | str -> str + # str | "x"? -> str + self.assert_simplified_union([str1, str_type], str_type) + self.assert_simplified_union([str_type, str1], str_type) + self.assert_simplified_union([str1_inst, str_type], str_type) + self.assert_simplified_union([str_type, str1_inst], str_type) + + # other operand is the same literal + # "x" | "x" -> "x" + # "x" | "x"? -> "x"? + # "x"? | "x" -> "x"? + # "x"? | "x"? -> "x"? + self.assert_simplified_union([str1, str1], str1) + self.assert_simplified_union([str1, str1_inst], str1_inst) + self.assert_simplified_union([str1_inst, str1], str1_inst) + self.assert_simplified_union([str1_inst, str1_inst], str1_inst) + + # other operand is a different literal + # "x" | "y" -> "x" | "y" + # "x" | "y"? -> "x" | "y"? + # "x"? | "y" -> "x"? | "y" + # "x"? | "y"? -> "x"? | "y"? + self.assert_simplified_union([str1, str2], UnionType([str1, str2])) + self.assert_simplified_union([str1, str2_inst], UnionType([str1, str2_inst])) + self.assert_simplified_union([str1_inst, str2], UnionType([str1_inst, str2])) + self.assert_simplified_union([str1_inst, str2_inst], UnionType([str1_inst, str2_inst])) def assert_simplified_union(self, original: list[Type], union: Type) -> None: assert_equal(make_simplified_union(original), union) @@ -999,6 +1038,39 @@ def test_literal_type(self) -> None: UnionType([lit2, lit3]), UnionType([lit1, lit2]), UnionType([lit2, lit3, lit1]) ) + def test_mixed_literal_types(self) -> None: + str1 = self.fx.lit_str1 + str2 = self.fx.lit_str2 + str1_inst = self.fx.lit_str1_inst + str2_inst = self.fx.lit_str2_inst + str_type = self.fx.str_type + + # other operand is the fallback type + # "x" , str -> str + # str , "x" -> str + # "x"?, str -> str + # str , "x"? -> str + self.assert_join(str1, str_type, str_type) + self.assert_join(str1_inst, str_type, str_type) + + # other operand is the same literal + # "x" , "x" -> "x" + # "x" , "x"? -> "x"? + # "x"?, "x" -> "x"? + # "x"?, "x"? -> "x"? + self.assert_join(str1, str1, str1) + self.assert_join(str1, str1_inst, str1_inst) + self.assert_join(str1_inst, str1_inst, str1_inst) + + # other operand is a different literal + # "x" , "y" -> str (TODO: consider using "x" | "y" (treat real literals like enum)) + # "x" , "y"? -> str + # "x"?, "y" -> str + # "x"?, "y"? -> str + self.assert_join(str1, str2, str_type) + self.assert_join(str1, str2_inst, str_type) + self.assert_join(str1_inst, str2_inst, str_type) + def test_variadic_tuple_joins(self) -> None: # These tests really test just the "arity", to be sure it is handled correctly. self.assert_join( @@ -1321,6 +1393,39 @@ def test_literal_type(self) -> None: assert is_same_type(lit1, narrow_declared_type(lit1, a)) assert is_same_type(lit2, narrow_declared_type(lit2, a)) + def test_mixed_literal_types(self) -> None: + str1 = self.fx.lit_str1 + str2 = self.fx.lit_str2 + str1_inst = self.fx.lit_str1_inst + str2_inst = self.fx.lit_str2_inst + str_type = self.fx.str_type + + # other operand is the fallback type + # "x" , str -> "x" + # str , "x" -> "x" + # "x"?, str -> "x"? + # str , "x"? -> "x"? + self.assert_meet(str1, str_type, str1) + self.assert_meet(str1_inst, str_type, str1_inst) + + # other operand is the same literal + # "x" , "x" -> "x" + # "x" , "x"? -> "x" + # "x"?, "x" -> "x" + # "x"?, "x"? -> "x"? + self.assert_meet(str1, str1, str1) + self.assert_meet(str1, str1_inst, str1) + self.assert_meet(str1_inst, str1_inst, str1_inst) + + # other operand is a different literal + # "x" , "y" -> Never + # "x" , "y"? -> Never + # "x"?, "y" -> Never + # "x"?, "y"? -> str + self.assert_meet_uninhabited(str1, str2) + self.assert_meet_uninhabited(str1, str2_inst) + self.assert_meet(str1_inst, str2_inst, str_type) + # FIX generic interfaces + ranges def assert_meet_uninhabited(self, s: Type, t: Type) -> None: diff --git a/mypy/typeops.py b/mypy/typeops.py index e3657c2fc4d5..6b31f36bcf8a 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -582,6 +582,8 @@ def make_simplified_union( * [int, Any] -> Union[int, Any] (Any types are not simplified away!) * [Any, Any] -> Any * [int, Union[bytes, str]] -> Union[int, bytes, str] + * [Literal[1]?, Literal[1]] -> Literal[1]? + * Literal["max"]?, Literal["max", "sum"] -> Literal["max"]? | Literal["sum"] Note: This must NOT be used during semantic analysis, since TypeInfos may not be fully initialized. @@ -610,13 +612,22 @@ def make_simplified_union( ): simplified_set = try_contracting_literals_in_union(simplified_set) + # Step 5: Combine Literals and Instances with LKVs, e.g. Literal[1]?, Literal[1] -> Literal[1]? + proper_items: list[ProperType] = [get_proper_type(t) for t in simplified_set] + last_known_values: list[LiteralType | None] = [ + p_t.last_known_value if isinstance(p_t, Instance) else None for p_t in proper_items + ] + simplified_set = [ + item for item, p_t in zip(simplified_set, proper_items) if p_t not in last_known_values + ] + result = get_proper_type(UnionType.make_union(simplified_set, line, column)) nitems = len(items) if nitems > 1 and ( nitems > 2 or not (type(items[0]) is NoneType or type(items[1]) is NoneType) ): - # Step 5: At last, we erase any (inconsistent) extra attributes on instances. + # Step 6: At last, we erase any (inconsistent) extra attributes on instances. # Initialize with None instead of an empty set as a micro-optimization. The set # is needed very rarely, so we try to avoid constructing it. diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 22edd12b0c4c..aa8f71a06ef2 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -4220,7 +4220,7 @@ def check_and(maybe: bool) -> None: bar = None if maybe and (foo := [1])[(bar := 0)]: reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]" - reveal_type(bar) # N: Revealed type is "builtins.int" + reveal_type(bar) # N: Revealed type is "Literal[0]?" else: reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int] | None" reveal_type(bar) # N: Revealed type is "builtins.int | None" @@ -4246,7 +4246,7 @@ def check_or(maybe: bool) -> None: reveal_type(bar) # N: Revealed type is "builtins.int | None" else: reveal_type(foo) # N: Revealed type is "builtins.list[builtins.int]" - reveal_type(bar) # N: Revealed type is "builtins.int" + reveal_type(bar) # N: Revealed type is "Literal[0]?" def check_or_nested(maybe: bool) -> None: foo = None diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 0b779f57b615..4c0ddcf06d47 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -3088,20 +3088,48 @@ z: Type[Literal[1, 2]] # E: Type[...] can't contain "Union[Literal[...], Litera [case testJoinLiteralAndInstance] from typing import Generic, TypeVar, Literal -T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +T = TypeVar("T", covariant=False, contravariant=False) +S = TypeVar("S", covariant=False, contravariant=False) -class A(Generic[T]): ... +class A_inv(Generic[T]): ... +class A_co(Generic[T_co]): ... -def f(a: A[T], t: T) -> T: ... -def g(a: T, t: A[T]) -> T: ... +def check_inv(obj: A_inv[Literal[1]]) -> None: + def f(a: A_inv[S], t: S) -> S: ... + def g(a: S, t: A_inv[S]) -> S: ... -def check(obj: A[Literal[1]]) -> None: reveal_type(f(obj, 1)) # N: Revealed type is "Literal[1]" - reveal_type(f(obj, '')) # E: Cannot infer value of type parameter "T" of "f" \ - # N: Revealed type is "Any" + reveal_type(f(obj, '')) # E: Cannot infer value of type parameter "S" of "f" \ + # N: Revealed type is "Any" reveal_type(g(1, obj)) # N: Revealed type is "Literal[1]" - reveal_type(g('', obj)) # E: Cannot infer value of type parameter "T" of "g" \ - # N: Revealed type is "Any" + reveal_type(g('', obj)) # E: Cannot infer value of type parameter "S" of "g" \ + # N: Revealed type is "Any" + +def check_co(obj: A_co[Literal[1]]) -> None: + def f(a: A_co[S], t: S) -> S: ... + def g(a: S, t: A_co[S]) -> S: ... + + reveal_type(f(obj, 1)) # N: Revealed type is "builtins.int" + reveal_type(f(obj, '')) # N: Revealed type is "builtins.object" + reveal_type(g(1, obj)) # N: Revealed type is "builtins.int" + reveal_type(g('', obj)) # N: Revealed type is "builtins.object" + +[case testJoinLiteralInstanceAndEnum] +from typing import Final, TypeVar +from enum import StrEnum + +T = TypeVar("T") +def join(a: T, b: T) -> T: ... + +class Foo(StrEnum): + A = "a" + +CONST: Final = "const" + +reveal_type(CONST) # N: Revealed type is "Literal['const']?" +reveal_type(join(Foo.A, CONST)) # N: Revealed type is "builtins.str" +reveal_type(join(CONST, Foo.A)) # N: Revealed type is "builtins.str" [builtins fixtures/tuple.pyi] [case testEnumLiteralIsNotIntLiteral] diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 117d0e72ed79..e3e3f47c50a8 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -708,20 +708,21 @@ class FlipFlopStr: def test1(switch: FlipFlopStr) -> None: # Naively, we might assume the 'assert' here would narrow the type to - # Literal["state-1"]. However, doing this ends up breaking a fair number of real-world + # a strong literal 'Literal["state-1"]'. + # However, doing this ends up breaking a fair number of real-world # code (usually test cases) that looks similar to this function: e.g. checks # to make sure a field was mutated to some particular value. # # And since mypy can't really reason about state mutation, we take a conservative - # approach and avoid narrowing anything here. + # approach and only narrow to a weak literal 'Literal["state-1"]?' assert switch.state == "state-1" - reveal_type(switch.state) # N: Revealed type is "builtins.str" + reveal_type(switch.state) # N: Revealed type is "Literal['state-1']?" switch.mutate() assert switch.state == "state-2" - reveal_type(switch.state) # N: Revealed type is "builtins.str" + reveal_type(switch.state) # N: Revealed type is "Literal['state-2']?" def test2(switch: FlipFlopEnum) -> None: # This is the same thing as 'test1', except we use enums, which we allow to be narrowed @@ -755,18 +756,18 @@ from typing import Final, Literal A_final: Final = "A" A_literal: Literal["A"] -# Neither the LHS nor the RHS are explicit literals, so regrettably nothing -# is narrowed here -- see 'testNarrowingEqualityFlipFlop' for an example of -# why more precise inference here is problematic. +# narrow only to a weak literal 'Literal["A"]?', as the value may mutate to a different +# string later - see 'testNarrowingEqualityFlipFlop' for an example of +# why we cannot narrow to a strong literal 'Literal["A"]' x_str: str if x_str == "A": - reveal_type(x_str) # N: Revealed type is "builtins.str" + reveal_type(x_str) # N: Revealed type is "Literal['A']?" else: reveal_type(x_str) # N: Revealed type is "builtins.str" reveal_type(x_str) # N: Revealed type is "builtins.str" if x_str == A_final: - reveal_type(x_str) # N: Revealed type is "builtins.str" + reveal_type(x_str) # N: Revealed type is "Literal['A']?" else: reveal_type(x_str) # N: Revealed type is "builtins.str" reveal_type(x_str) # N: Revealed type is "builtins.str" @@ -2077,6 +2078,28 @@ else: reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...] | builtins.tuple[builtins.str, ...]" [builtins fixtures/len.pyi] + +[case testNarrowingNamedTupleLiteralValues] +# flags: --strict-equality --warn-unreachable +from typing import NamedTuple + +class CaptureResult(NamedTuple): + stdout: str + stderr: str + +cr = CaptureResult("out", "err") +reveal_type(cr) # N: Revealed type is "tuple[builtins.str, builtins.str, fallback=__main__.CaptureResult]" +assert cr == ("out", "err") +reveal_type(cr) # N: Revealed type is "tuple[Literal['out']?, Literal['err']?, fallback=__main__.CaptureResult]" +assert cr != ("out", "wrong") +reveal_type(cr) # N: Revealed type is "tuple[Literal['out']?, Literal['err']?, fallback=__main__.CaptureResult]" +cr = cr._replace(stderr="replaced") +reveal_type(cr) # N: Revealed type is "tuple[builtins.str, builtins.str, fallback=__main__.CaptureResult]" +assert cr == ("out", "replaced") +reveal_type(cr) # N: Revealed type is "tuple[Literal['out']?, Literal['replaced']?, fallback=__main__.CaptureResult]" +[builtins fixtures/primitives.pyi] + + [case testNarrowingLenUnionOfNamedTuples] # flags: --strict-equality --warn-unreachable from typing import NamedTuple, Union @@ -2464,7 +2487,7 @@ def foo(location: ParameterLocation): # https://github.com/python/mypy/issues/17864 def f(x: str | int) -> None: if x == "x": - reveal_type(x) # N: Revealed type is "builtins.str" + reveal_type(x) # N: Revealed type is "Literal['x']?" y = x if x in ["x"]: @@ -2853,7 +2876,7 @@ class D(tuple[T], Generic[T]): ... def check_d(arg: D[T]) -> None: if not isinstance(arg, D): return # E: Statement is unreachable - reveal_type(arg) # N: Revealed type is "tuple[T`-1, fallback=__main__.D[Any]]" + reveal_type(arg) # N: Revealed type is "tuple[T`-1, fallback=__main__.D[T`-1]]" [builtins fixtures/tuple.pyi] @@ -3572,7 +3595,7 @@ def foo(x: str | array) -> str: from __future__ import annotations def f(x: bytes | None): if x == b"asdf": - reveal_type(x) # N: Revealed type is "builtins.bytes" + reveal_type(x) # N: Revealed type is "Literal[b'asdf']?" else: reveal_type(x) # N: Revealed type is "builtins.bytes | None" [builtins fixtures/primitives.pyi] diff --git a/test-data/unit/check-optional.test b/test-data/unit/check-optional.test index 6db60275944f..6a61ce2e12bb 100644 --- a/test-data/unit/check-optional.test +++ b/test-data/unit/check-optional.test @@ -448,7 +448,7 @@ from typing import Optional def main(x: Optional[str]): if x == '': - reveal_type(x) # N: Revealed type is "builtins.str" + reveal_type(x) # N: Revealed type is "Literal['']?" else: reveal_type(x) # N: Revealed type is "builtins.str | None" if x is '': @@ -461,7 +461,7 @@ def main(x: Optional[str]): from typing import Union def main(x: Union[str, int, None]): if x == '': - reveal_type(x) # N: Revealed type is "builtins.str" + reveal_type(x) # N: Revealed type is "Literal['']?" else: reveal_type(x) # N: Revealed type is "builtins.str | builtins.int | None" if x is '': @@ -521,8 +521,8 @@ from typing import Optional def main(x: Optional[int], y: Optional[int]): if x == y == 1: - reveal_type(x) # N: Revealed type is "builtins.int" - reveal_type(y) # N: Revealed type is "builtins.int" + reveal_type(x) # N: Revealed type is "Literal[1]?" + reveal_type(y) # N: Revealed type is "Literal[1]?" else: reveal_type(x) # N: Revealed type is "builtins.int | None" reveal_type(y) # N: Revealed type is "builtins.int | None" diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index 1bd106e67d9b..ec8b19a9296f 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -1530,7 +1530,7 @@ m: str match m: case a if a := "test": - reveal_type(a) # N: Revealed type is "builtins.str" + reveal_type(a) # N: Revealed type is "Literal['test']?" [case testMatchNarrowingPatternGuard] m: object @@ -2864,7 +2864,7 @@ def int_literal() -> None: case other: other # E: Statement is unreachable -def str_literal() -> None: +def str_literal_from_literal() -> None: match 'foo': case 'a' as s: reveal_type(s) # E: Statement is unreachable @@ -2873,6 +2873,16 @@ def str_literal() -> None: case other: other # E: Statement is unreachable + +def str_literal_from_str(arg: str) -> None: + match arg: + case 'a' as s: + reveal_type(s) # N: Revealed type is "Literal['a']" + case str(i): + reveal_type(i) # N: Revealed type is "builtins.str" + case other: + other # E: Statement is unreachable + [case testMatchOperations] # flags: --warn-unreachable diff --git a/test-data/unit/check-python38.test b/test-data/unit/check-python38.test index 595ff95f44dc..9ad4c55f04ab 100644 --- a/test-data/unit/check-python38.test +++ b/test-data/unit/check-python38.test @@ -214,10 +214,10 @@ i(arg=0) # E: Unexpected keyword argument "arg" from typing import Final, NamedTuple, Optional, List if a := 2: - reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(a) # N: Revealed type is "Literal[2]?" while b := "x": - reveal_type(b) # N: Revealed type is "builtins.str" + reveal_type(b) # N: Revealed type is "Literal['x']?" l = [y2 := 1, y2 + 2, y2 + 3] reveal_type(y2) # N: Revealed type is "builtins.int" @@ -242,10 +242,10 @@ reveal_type(new_v) # N: Revealed type is "builtins.int" def f(x: int = (c := 4)) -> int: if a := 2: - reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(a) # N: Revealed type is "Literal[2]?" while b := "x": - reveal_type(b) # N: Revealed type is "builtins.str" + reveal_type(b) # N: Revealed type is "Literal['x']?" x = (y := 1) + (z := 2) reveal_type(x) # N: Revealed type is "builtins.int" @@ -284,7 +284,7 @@ def f(x: int = (c := 4)) -> int: f(x=(y7 := 3)) reveal_type(y7) # N: Revealed type is "builtins.int" - reveal_type((lambda: (y8 := 3) and y8)()) # N: Revealed type is "builtins.int" + reveal_type((lambda: (y8 := 3) and y8)()) # N: Revealed type is "Literal[3]?" y8 # E: Name "y8" is not defined y7 = 1.0 # E: Incompatible types in assignment (expression has type "float", variable has type "int") @@ -325,16 +325,16 @@ def check_binder(x: Optional[int], y: Optional[int], z: Optional[int], a: Option reveal_type(y) # N: Revealed type is "builtins.int | None" if x and (y := 1): - reveal_type(y) # N: Revealed type is "builtins.int" + reveal_type(y) # N: Revealed type is "Literal[1]?" if (a := 1) and x: - reveal_type(a) # N: Revealed type is "builtins.int" + reveal_type(a) # N: Revealed type is "Literal[1]?" if (b := 1) or x: reveal_type(b) # N: Revealed type is "builtins.int" if z := 1: - reveal_type(z) # N: Revealed type is "builtins.int" + reveal_type(z) # N: Revealed type is "Literal[1]?" def check_partial() -> None: x = None @@ -351,7 +351,7 @@ def check_narrow(x: Optional[int], s: List[int]) -> None: reveal_type(y) # N: Revealed type is "builtins.int" if (y := x) == 10: - reveal_type(y) # N: Revealed type is "builtins.int" + reveal_type(y) # N: Revealed type is "Literal[10]?" if (y := x) in s: reveal_type(y) # N: Revealed type is "builtins.int" diff --git a/test-data/unit/check-warnings.test b/test-data/unit/check-warnings.test index a2d201fa301d..25aec01566b5 100644 --- a/test-data/unit/check-warnings.test +++ b/test-data/unit/check-warnings.test @@ -49,13 +49,27 @@ from typing import cast a = 1 b = cast(object, 1) -[case testCastFromLiteralRedundant] -# flags: --warn-redundant-casts -from typing import cast +[case testLiteralRedundantCast] +# https://github.com/python/mypy/issues/19055 +# flags: --warn-redundant-cast -cast(int, 1) -[out] -main:4: error: Redundant cast to "int" +from typing import Literal, cast + +# "a" is an AnyOf[str, Literal["a"]], denoted Literal['a']? +# See: https://github.com/python/typing/issues/566 + +# This cast is not redundant because Literal["a"]? is not identical to Literal["a"] +LiteralOnly = Literal["a"] +reveal_type("a") # N: Revealed type is "Literal['a']?" +cast(LiteralOnly, "a") + +# This cast is redundant because the type is already Literal["a"] +already_literal: Literal["a"] = "a" +reveal_type(already_literal) # N: Revealed type is "Literal['a']" +cast(LiteralOnly, already_literal) # E: Redundant cast to "Literal['a']" + +LiteralUnion = Literal["a", "b"] +cast(LiteralUnion, "a") [case testCastFromUnionOfAnyOk] # flags: --warn-redundant-casts