Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions sqlmodel/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Annotated,
Any,
ForwardRef,
Literal,
Optional,
TypeVar,
Union,
Expand Down Expand Up @@ -64,6 +65,36 @@ def _is_union_type(t: Any) -> bool:
return t is UnionType or t is Union


def get_literal_annotation_info(
annotation: Any,
) -> Optional[tuple[type[Any], tuple[Any, ...]]]:
if annotation is None or get_origin(annotation) is None:
return None
origin = get_origin(annotation)
if origin is Annotated:
return get_literal_annotation_info(get_args(annotation)[0])
if _is_union_type(origin):
bases = get_args(annotation)
if len(bases) > 2:
raise ValueError("Cannot have a Union with more than 2 members")
if bases[0] is not NoneType and bases[1] is not NoneType:
raise ValueError("Cannot have a Union without None")
use_type = bases[0] if bases[0] is not NoneType else bases[1]
return get_literal_annotation_info(use_type)
if origin is Literal:
literal_args = get_args(annotation)
if not literal_args:
return None
if all(isinstance(arg, bool) for arg in literal_args): # all bools
base_type: type[Any] = bool
elif all(isinstance(arg, int) for arg in literal_args): # all ints
base_type = int
else:
base_type = str
return base_type, tuple(literal_args)
return None


finish_init: ContextVar[bool] = ContextVar("finish_init", default=True)


Expand Down Expand Up @@ -189,6 +220,12 @@ def get_sa_type_from_type_annotation(annotation: Any) -> Any:
# Optional unions are allowed
use_type = bases[0] if bases[0] is not NoneType else bases[1]
return get_sa_type_from_type_annotation(use_type)
if origin is Literal:
literal_info = get_literal_annotation_info(annotation)
if literal_info is None:
raise ValueError("Literal without values is not supported")
base_type, _ = literal_info
return base_type
return origin


Expand Down
27 changes: 27 additions & 0 deletions sqlmodel/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pydantic.fields import FieldInfo as PydanticFieldInfo
from sqlalchemy import (
Boolean,
CheckConstraint,
Column,
Date,
DateTime,
Expand Down Expand Up @@ -63,6 +64,7 @@
finish_init,
get_annotations,
get_field_metadata,
get_literal_annotation_info,
get_model_fields,
get_relationship_to,
get_sa_type_from_field,
Expand Down Expand Up @@ -678,6 +680,31 @@ def __init__(
# Ref: https://github.com/sqlalchemy/sqlalchemy/commit/428ea01f00a9cc7f85e435018565eb6da7af1b77
# Tag: 1.4.36
DeclarativeMeta.__init__(cls, classname, bases, dict_, **kw)
table = getattr(cls, "__table__", None)
if table is not None:
# Attach Literal-based value constraints at the database level
for field_name, field in get_model_fields(cls).items():
annotation = getattr(field, "annotation", None)
literal_info = get_literal_annotation_info(annotation)
if literal_info is None:
continue
base_type, values = literal_info
assert base_type in (str, int, bool)
column = table.c.get(field_name)
if column is None:
continue
if base_type is int:
coerced_values = tuple(int(v) for v in values)
elif base_type is bool:
coerced_values = tuple(bool(v) for v in values)
else:
coerced_values = tuple(str(v) for v in values)
constraint_name = f"ck_{table.name}_{field_name}_literal"
constraint = CheckConstraint(
column.in_(coerced_values),
name=constraint_name,
)
table.append_constraint(constraint)
else:
ModelMetaclass.__init__(cls, classname, bases, dict_, **kw)

Expand Down
147 changes: 146 additions & 1 deletion tests/test_main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Annotated, Optional
from typing import Annotated, Literal, Optional, Union

import pytest
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import RelationshipProperty
from sqlmodel import Field, Relationship, Session, SQLModel, create_engine, select
Expand Down Expand Up @@ -216,3 +217,147 @@ class Hero(SQLModel, table=True):
assert len(foreign_keys) == 1
assert foreign_keys[0].ondelete == "CASCADE"
assert team_id_column.nullable is False


def test_literal_valid_values(clear_sqlmodel, caplog):
"""Test https://github.com/fastapi/sqlmodel/issues/57"""

class Model(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
all_str: Literal["a", "b", "c"]
mixed: Literal["yes", "no", 1, 0]
all_int: Literal[1, 2, 3]
int_bool: Literal[0, 1, True, False]
all_bool: Literal[True, False]

obj = Model(
all_str="a",
mixed="yes",
all_int=1,
int_bool=True,
all_bool=False,
)

engine = create_engine("sqlite://", echo=True)

SQLModel.metadata.create_all(engine)

# Check DDL
assert "all_str VARCHAR NOT NULL" in caplog.text
assert "mixed VARCHAR NOT NULL" in caplog.text
assert "all_int INTEGER NOT NULL" in caplog.text
assert "int_bool INTEGER NOT NULL" in caplog.text
assert "all_bool BOOLEAN NOT NULL" in caplog.text

# Check query
with Session(engine) as session:
session.add(obj)
session.commit()
session.refresh(obj)
assert isinstance(obj.all_str, str)
assert obj.all_str == "a"
assert isinstance(obj.mixed, str)
assert obj.mixed == "yes"
assert isinstance(obj.all_int, int)
assert obj.all_int == 1
assert isinstance(obj.int_bool, int)
assert obj.int_bool == 1
assert isinstance(obj.all_bool, bool)
assert obj.all_bool is False


def test_literal_constraints_invalid_values(clear_sqlmodel):
"""DB should reject values that are not part of the Literal choices."""

class Model(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
all_str: Literal["a", "b", "c"]
mixed: Literal["yes", "no", 1, 0]
all_int: Literal[1, 2, 3]
int_bool: Literal[0, 1, True, False]
all_bool: Literal[True, False]

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)

# Helper to attempt a raw insert that bypasses Pydantic validation so we
# can verify that the database-level CHECK constraints are enforced.
def insert_raw(values: dict[str, object]) -> None:
stmt = text(
"INSERT INTO model (all_str, mixed, all_int, int_bool, all_bool) "
"VALUES (:all_str, :mixed, :all_int, :int_bool, :all_bool)"
).bindparams(**values)
with pytest.raises(IntegrityError):
with Session(engine) as session:
session.exec(stmt)
session.commit()

# Invalid string literal for all_str
insert_raw(
{
"all_str": "z", # invalid, not in {"a","b","c"}
"mixed": "yes",
"all_int": 1,
"int_bool": 1,
"all_bool": 0,
}
)

# Invalid int literal for all_int
insert_raw(
{
"all_str": "a",
"mixed": "yes",
"all_int": 5, # invalid, not in {1,2,3}
"int_bool": 1,
"all_bool": 0,
}
)

# Invalid bool literal for all_bool
insert_raw(
{
"all_str": "a",
"mixed": "yes",
"all_int": 1,
"int_bool": 1,
"all_bool": 2, # invalid boolean value
}
)


def test_literal_optional_and_union_constraints(clear_sqlmodel):
"""Literals inside Optional/Union should also be enforced at the DB level."""

class Model(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
opt_str: Optional[Literal["x", "y"]] = None
union_int: Union[Literal[10, 20], None] = None

engine = create_engine("sqlite://")
SQLModel.metadata.create_all(engine)

# Valid values should be accepted
obj = Model(opt_str="x", union_int=10)
with Session(engine) as session:
session.add(obj)
session.commit()
session.refresh(obj)
assert obj.opt_str == "x"
assert obj.union_int == 10

# Invalid values should be rejected by the database
def insert_raw(values: dict[str, object]) -> None:
stmt = text(
"INSERT INTO model (opt_str, union_int) VALUES (:opt_str, :union_int)"
).bindparams(**values)
with pytest.raises(IntegrityError):
with Session(engine) as session:
session.exec(stmt)
session.commit()

# opt_str not in {"x", "y"}
insert_raw({"opt_str": "z", "union_int": 10})

# union_int not in {10, 20}
insert_raw({"opt_str": "x", "union_int": 30})
Loading