Skip to content
Draft
9 changes: 8 additions & 1 deletion mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3793,11 +3793,18 @@ def dangerous_comparison(
) and not identity_check:
return False

if prefer_literal:
if prefer_literal and not (
isinstance(left, Instance)
and not left.type.is_enum
and isinstance(right, Instance)
and not right.type.is_enum
):
# Also flag non-overlapping literals in situations like:
# x: Literal['a', 'b']
# if x == 'c':
# ...
# but only in situations where at least one side is an actual literal/enum type.
# see testNarrowingEqualityFlipFlop for an example where we want to skip this.
left = try_getting_literal(left)
right = try_getting_literal(right)

Expand Down
26 changes: 24 additions & 2 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,19 @@ def join_instances(self, t: Instance, s: Instance) -> ProperType:
new_type = join_types(ta, sa, self)
assert new_type is not None
args.append(new_type)
result: ProperType = Instance(t.type, args)

lkv = join_last_known_values(t, s)
result: ProperType = Instance(t.type, args, last_known_value=lkv)
elif t.type.bases and is_proper_subtype(
t, s, subtype_context=SubtypeContext(ignore_type_params=True)
):
result = self.join_instances_via_supertype(t, s)
elif s.type.bases and is_proper_subtype(
s, t, subtype_context=SubtypeContext(ignore_type_params=True)
):
result = self.join_instances_via_supertype(s, t)
elif is_subtype(t, s, subtype_context=SubtypeContext(ignore_type_params=True)):
result = self.join_instances_via_supertype(t, s)
else:
# Now t is not a subtype of s, and t != s. Now s could be a subtype
# of t; alternatively, we need to find a common supertype. This works
Expand Down Expand Up @@ -636,13 +644,17 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
def visit_literal_type(self, t: LiteralType) -> ProperType:
if isinstance(self.s, LiteralType):
if t == self.s:
# E.g. Literal["x"], Literal["x"] -> Literal["x"]
return t
if self.s.fallback.type.is_enum and t.fallback.type.is_enum:
return mypy.typeops.make_simplified_union([self.s, t])
# E.g. Literal["x"], Literal["y"] -> str
return join_types(self.s.fallback, t.fallback)
elif isinstance(self.s, Instance) and self.s.last_known_value == t:
return t
# E.g. Literal["x"], Literal["x"]? -> Literal["x"]?
return self.s
else:
# E.g. Literal["x"], Literal["y"]? -> str
return join_types(self.s, t.fallback)

def visit_partial_type(self, t: PartialType) -> ProperType:
Expand Down Expand Up @@ -782,6 +794,16 @@ def join_similar_callables(t: CallableType, s: CallableType) -> CallableType:
)


def join_last_known_values(t: Instance, s: Instance) -> LiteralType | None:
r"""Return the join of two last known values."""
# drop last known value if they differ (e.g. join('x'?, 'y'?) -> str)
left = t.last_known_value
right = s.last_known_value
if left is None or right is None or left != right:
return None
return left


def safe_join(t: Type, s: Type) -> Type:
# This is a temporary solution to prevent crashes in combine_similar_callables() etc.,
# until relevant TODOs on handling arg_kinds will be addressed there.
Expand Down
56 changes: 51 additions & 5 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,11 @@ def meet_types(s: Type, t: Type) -> ProperType:
t = get_proper_type(t)

if isinstance(s, Instance) and isinstance(t, Instance) and s.type == t.type:
# special casing for dealing with last known values
lkv = meet_last_known_values(t, s)
t = t.copy_modified(last_known_value=lkv)
s = s.copy_modified(last_known_value=lkv)

# Code in checker.py should merge any extra_items where possible, so we
# should have only compatible extra_items here. We check this before
# the below subtype check, so that extra_attrs will not get erased.
Expand Down Expand Up @@ -113,6 +118,31 @@ def meet_types(s: Type, t: Type) -> ProperType:
return t.accept(TypeMeetVisitor(s))


def meet_last_known_values(t: Instance, s: Instance) -> LiteralType | None:
"""Return the meet of two last_known_values."""
left = t.last_known_value
right = s.last_known_value

if left is None:
return right
if right is None:
return left

lkv_meet = meet_types(left, right)

if isinstance(lkv_meet, UninhabitedType):
return None
if isinstance(lkv_meet, LiteralType):
return lkv_meet

msg = (
f"Unexpected result: "
f"meet of last_known_values {left=!s} and {right=!s} "
f"resulted in {lkv_meet!s}"
)
raise ValueError(msg)


def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
"""Return the declared type narrowed down to another type."""
# TODO: check infinite recursion for aliases here.
Expand Down Expand Up @@ -900,7 +930,8 @@ def visit_instance(self, t: Instance) -> ProperType:
assert isinstance(meet, UninhabitedType)
meet = UnpackType(tv.tuple_fallback.copy_modified(args=[meet]))
args.append(meet)
return Instance(t.type, args)
lkv = meet_last_known_values(t, self.s)
return Instance(t.type, args, last_known_value=lkv)
else:
if state.strict_optional:
return UninhabitedType()
Expand Down Expand Up @@ -1066,13 +1097,22 @@ def meet_tuples(self, s: TupleType, t: TupleType) -> list[Type] | None:
items.append(self.meet(fi, vi))
return items

def meet_tuple_fallbacks(self, s: TupleType, t: TupleType) -> Instance:
fall_s = tuple_fallback(s)
fall_t = tuple_fallback(t)
meet_fallback = meet_types(fall_s, fall_t)
if not isinstance(meet_fallback, Instance):
# TODO: emit warning?
return fall_t
return meet_fallback

def visit_tuple_type(self, t: TupleType) -> ProperType:
if isinstance(self.s, TupleType):
items = self.meet_tuples(self.s, t)
if items is None:
return self.default(self.s)
# TODO: What if the fallbacks are different?
return TupleType(items, tuple_fallback(t))
fallback = self.meet_tuple_fallbacks(self.s, t)
return TupleType(items, fallback=fallback)
elif isinstance(self.s, Instance):
# meet(Tuple[t1, t2, <...>], Tuple[s, ...]) == Tuple[meet(t1, s), meet(t2, s), <...>].
if self.s.type.fullname in TUPLE_LIKE_INSTANCE_NAMES and self.s.args:
Expand Down Expand Up @@ -1114,8 +1154,14 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
def visit_literal_type(self, t: LiteralType) -> ProperType:
if isinstance(self.s, LiteralType) and self.s == t:
return t
elif isinstance(self.s, Instance) and is_subtype(t.fallback, self.s):
return t
elif isinstance(self.s, Instance):
# if is_subtype(t.fallback, self.s):
# return t
if self.s.last_known_value is not None:
# meet(Literal["max"]?, Literal["max"]) -> Literal["max"]
# meet(Literal["sum"]?, Literal["max"]) -> Never
return meet_types(self.s.last_known_value, t)
return self.default(self.s)
else:
return self.default(self.s)

Expand Down
3 changes: 2 additions & 1 deletion mypy/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,8 @@ def solve_one(lowers: Iterable[Type], uppers: Iterable[Type]) -> Type | None:
elif top is None:
candidate = bottom
elif is_subtype(bottom, top):
candidate = bottom
# Need to meet in case like Literal["x"]? <: T <: Literal["x"]
candidate = meet_types(bottom, top)
else:
candidate = None
return candidate
Expand Down
28 changes: 26 additions & 2 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,13 @@ def visit_instance(self, left: Instance) -> bool:
assert isinstance(erased, Instance)
t = erased
nominal = True
if self.proper_subtype and right.last_known_value is not None:
if left.last_known_value is None:
# E.g. str is not a proper subtype of Literal["x"]?
nominal = False
else:
# E.g. Literal[A]? <: Literal[B]? requires A <: B
nominal &= self._is_subtype(left.last_known_value, right.last_known_value)
if right.type.has_type_var_tuple_type:
# For variadic instances we simply find the correct type argument mappings,
# all the heavy lifting is done by the tuple subtyping.
Expand Down Expand Up @@ -628,8 +635,14 @@ def visit_instance(self, left: Instance) -> bool:
return True
if isinstance(item, Instance):
return is_named_instance(item, "builtins.object")
if isinstance(right, LiteralType) and left.last_known_value is not None:
return self._is_subtype(left.last_known_value, right)
if isinstance(right, LiteralType):
if self.proper_subtype:
# Instance types like Literal["sum"]? is *assignable* to Literal["sum"],
# but is not a proper subtype of it. (Literal["sum"]? is a gradual type,
# that is a proper subtype of str, and assignable to Literal["sum"].
return False
if left.last_known_value is not None:
return self._is_subtype(left.last_known_value, right)
if isinstance(right, FunctionLike):
# Special case: Instance can be a subtype of Callable / Overloaded.
call = find_member("__call__", left, left, is_operator=True)
Expand Down Expand Up @@ -964,6 +977,12 @@ def visit_typeddict_type(self, left: TypedDictType) -> bool:
def visit_literal_type(self, left: LiteralType) -> bool:
if isinstance(self.right, LiteralType):
return left == self.right
elif (
isinstance(self.right, Instance)
and self.right.last_known_value is not None
and self.proper_subtype
):
return self._is_subtype(left, self.right.last_known_value)
else:
return self._is_subtype(left.fallback, self.right)

Expand Down Expand Up @@ -2138,6 +2157,11 @@ def covers_at_runtime(item: Type, supertype: Type) -> bool:
item = get_proper_type(item)
supertype = get_proper_type(supertype)

# Use last known value for Instance types, if available.
# This ensures that e.g. Literal["max"]? is covered by Literal["max"].
if isinstance(item, Instance) and item.last_known_value is not None:
item = item.last_known_value

# Since runtime type checks will ignore type arguments, erase the types.
if not (isinstance(supertype, FunctionLike) and supertype.is_type_obj()):
supertype = erase_type(supertype)
Expand Down
126 changes: 125 additions & 1 deletion mypy/test/testsubtypes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT
from mypy.subtypes import is_subtype
from mypy.subtypes import is_proper_subtype, is_subtype, restrict_subtype_away
from mypy.test.helpers import Suite
from mypy.test.typefixture import InterfaceTypeFixture, TypeFixture
from mypy.types import Instance, TupleType, Type, UninhabitedType, UnpackType
Expand Down Expand Up @@ -277,6 +277,74 @@ def test_type_var_tuple_unpacked_variable_length_tuple(self) -> None:
def test_fallback_not_subtype_of_tuple(self) -> None:
self.assert_not_subtype(self.fx.a, TupleType([self.fx.b], fallback=self.fx.a))

def test_literal(self) -> None:
str1 = self.fx.lit_str1
str2 = self.fx.lit_str2
str1_inst = self.fx.lit_str1_inst
str2_inst = self.fx.lit_str2_inst
str_type = self.fx.str_type

# other operand is the fallback type
# "x" ≲ str -> YES
# str ≲ "x" -> NO
# "x"? ≲ str -> YES
# str ≲ "x"? -> YES
self.assert_subtype(str1, str_type)
self.assert_not_subtype(str_type, str1)
self.assert_subtype(str1_inst, str_type)
self.assert_subtype(str_type, str1_inst)

# other operand is the same literal
# "x" ≲ "x" -> YES
# "x" ≲ "x"? -> YES
# "x"? ≲ "x" -> YES
# "x"? ≲ "x"? -> YES
self.assert_subtype(str1, str1)
self.assert_subtype(str1, str1_inst)
self.assert_subtype(str1_inst, str1)
self.assert_subtype(str1_inst, str1_inst)

# other operand is a different literal
# "x" ≲ "y" -> NO
# "x" ≲ "y"? -> YES
# "x"? ≲ "y" -> NO
# "x"? ≲ "y"? -> YES
self.assert_not_subtype(str1, str2)
self.assert_subtype(str1, str2_inst)
self.assert_not_subtype(str1_inst, str2)
self.assert_subtype(str1_inst, str2_inst)

# check proper subtyping
# other operand is the fallback type
# "x" <: str -> YES
# str <: "x" -> NO
# "x"? <: str -> YES
# str <: "x"? -> NO
self.assert_proper_subtype(str1, str_type)
self.assert_not_proper_subtype(str_type, str1)
self.assert_proper_subtype(str1_inst, str_type)
self.assert_not_proper_subtype(str_type, str1_inst)

# other operand is the same literal
# "x" <: "x" -> YES
# "x" <: "x"? -> YES
# "x"? <: "x" -> NO
# "x"? <: "x"? -> YES
self.assert_proper_subtype(str1, str1)
self.assert_proper_subtype(str1, str1_inst)
self.assert_not_proper_subtype(str1_inst, str1)
self.assert_proper_subtype(str1_inst, str1_inst)

# other operand is a different literal
# "x" <: "y" -> NO
# "x" <: "y"? -> NO
# "x"? <: "y" -> NO
# "x"? <: "y"? -> NO
self.assert_not_proper_subtype(str1, str2)
self.assert_not_proper_subtype(str1, str2_inst)
self.assert_not_proper_subtype(str1_inst, str2)
self.assert_not_proper_subtype(str1_inst, str2_inst)

# IDEA: Maybe add these test cases (they are tested pretty well in type
# checker tests already):
# * more interface subtyping test cases
Expand All @@ -287,6 +355,12 @@ def test_fallback_not_subtype_of_tuple(self) -> None:
# * any type
# * generic function types

def assert_proper_subtype(self, s: Type, t: Type) -> None:
assert is_proper_subtype(s, t), f"{s} not proper subtype of {t}"

def assert_not_proper_subtype(self, s: Type, t: Type) -> None:
assert not is_proper_subtype(s, t), f"{s} not proper subtype of {t}"

def assert_subtype(self, s: Type, t: Type) -> None:
assert is_subtype(s, t), f"{s} not subtype of {t}"

Expand All @@ -304,3 +378,53 @@ def assert_equivalent(self, s: Type, t: Type) -> None:
def assert_unrelated(self, s: Type, t: Type) -> None:
self.assert_not_subtype(s, t)
self.assert_not_subtype(t, s)


class RestrictionSuite(Suite):
# Tests for type restrictions "A - B", i.e. ``T <: A and not T <: B``.

def setUp(self) -> None:
self.fx = TypeFixture()

def assert_restriction(self, s: Type, t: Type, expected: Type) -> None:
actual = restrict_subtype_away(s, t)
msg = f"restrict_subtype_away({s}, {t}) == {{}} ({{}} expected)"
self.assertEqual(actual, expected, msg=msg.format(actual, expected))

def test_literal(self) -> None:
str1 = self.fx.lit_str1
str2 = self.fx.lit_str2
str1_inst = self.fx.lit_str1_inst
str2_inst = self.fx.lit_str2_inst
str_type = self.fx.str_type
uninhabited = self.fx.uninhabited

# other operand is the fallback type
# "x" - str -> Never
# str - "x" -> str
# "x"? - str -> Never
# str - "x"? -> Never
self.assert_restriction(str1, str_type, uninhabited)
self.assert_restriction(str_type, str1, str_type)
self.assert_restriction(str1_inst, str_type, uninhabited)
self.assert_restriction(str_type, str1_inst, uninhabited)

# other operand is the same literal
# "x" - "x" -> Never
# "x" - "x"? -> Never
# "x"? - "x" -> Never
# "x"? - "x"? -> Never
self.assert_restriction(str1, str1, uninhabited)
self.assert_restriction(str1, str1_inst, uninhabited)
self.assert_restriction(str1_inst, str1, uninhabited)
self.assert_restriction(str1_inst, str1_inst, uninhabited)

# other operand is a different literal
# "x" - "y" -> "x"
# "x" - "y"? -> Never
# "x"? - "y" -> "x"?
# "x"? - "y"? -> Never
self.assert_restriction(str1, str2, str1)
self.assert_restriction(str1, str2_inst, uninhabited)
self.assert_restriction(str1_inst, str2, str1_inst)
self.assert_restriction(str1_inst, str2_inst, uninhabited)
Loading