diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 8288b676b52e..50d123a4703a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -5436,6 +5436,15 @@ def visit_dict_expr(self, e: DictExpr) -> Type: expected_types.append( self.chk.named_generic_type("_typeshed.SupportsKeysAndGetItem", [kt, vt]) ) + # If this DictExpr came from a dict() call translation, validate that + # any unpacked dict has string keys (keywords must be strings) + if e.from_dict_call: + value_type = self.accept(value) + if not self.is_valid_keyword_var_arg(value_type): + is_mapping = is_subtype( + value_type, self.chk.named_type("_typeshed.SupportsKeysAndGetItem") + ) + self.msg.invalid_keyword_var_arg(value_type, is_mapping, value) else: tup = TupleExpr([key, value]) if key.line >= 0: diff --git a/mypy/nodes.py b/mypy/nodes.py index 4168b2e00f15..58ec788a2142 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2672,15 +2672,17 @@ def accept(self, visitor: ExpressionVisitor[T]) -> T: class DictExpr(Expression): """Dictionary literal expression {key: value, ...}.""" - __slots__ = ("items",) + __slots__ = ("items", "from_dict_call") __match_args__ = ("items",) items: list[tuple[Expression | None, Expression]] + from_dict_call: bool # True if this came from a dict(...) call translation def __init__(self, items: list[tuple[Expression | None, Expression]]) -> None: super().__init__() self.items = items + self.from_dict_call = False def accept(self, visitor: ExpressionVisitor[T]) -> T: return visitor.visit_dict_expr(self) diff --git a/mypy/semanal.py b/mypy/semanal.py index f38a71cb16e3..97a89f7d5109 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -6058,6 +6058,7 @@ def translate_dict_call(self, call: CallExpr) -> DictExpr | None: ] ) expr.set_line(call) + expr.from_dict_call = True expr.accept(self) return expr diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 1acda7079cc8..31b778a72f12 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2582,3 +2582,11 @@ def last_known_value() -> None: x, y, z = xy # E: Unpacking a string is disallowed reveal_type(z) # N: Revealed type is "builtins.str" [builtins fixtures/primitives.pyi] + + +[case testDictUnpackNonStringKey] +def f() -> None: + dict(**{10: 20}) # E: Argument after ** must have string keys + dict(**{**{1: 1}}) # E: Argument after ** must have string keys + dict(**{b'a': 1}) # E: Argument after ** must have string keys +[builtins fixtures/dict.pyi]