Skip to content
1 change: 0 additions & 1 deletion sqlalchemy_bind_manager/_bind_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def __init_bind(self, name: str, config: SQLAlchemyConfig):

engine_options: dict = config.engine_options or {}
engine_options.setdefault("echo", False)
engine_options.setdefault("future", True)

session_options: dict = config.session_options or {}
session_options.setdefault("expire_on_commit", False)
Expand Down
9 changes: 0 additions & 9 deletions sqlalchemy_bind_manager/_repository/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,6 @@
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
#
from typing import (
Any,
Iterable,
Expand Down
2 changes: 1 addition & 1 deletion sqlalchemy_bind_manager/_repository/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ async def cursor_paginated_find(
).scalar() or 0
result_items = [
x for x in (await session.execute(paginated_stmt)).scalars()
] or []
]

return CursorPaginatedResultPresenter.build_result(
result_items=result_items,
Expand Down
20 changes: 8 additions & 12 deletions sqlalchemy_bind_manager/_repository/base_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
# DEALINGS IN THE SOFTWARE.

from abc import ABC
from functools import partial
from typing import (
Any,
Callable,
Expand All @@ -33,7 +32,7 @@
Union,
)

from sqlalchemy import asc, desc, func, inspect, select
from sqlalchemy import asc, desc, func, select
from sqlalchemy.orm import Mapper, aliased, class_mapper, lazyload
from sqlalchemy.orm.exc import UnmappedClassError
from sqlalchemy.sql import Select
Expand All @@ -43,6 +42,7 @@
from .common import (
MODEL,
CursorReference,
get_model_pk_name,
)


Expand Down Expand Up @@ -131,9 +131,9 @@ def _filter_order_by(
:param order_by: a list of columns, or tuples (column, direction)
:return: The filtered query
"""
_partial_registry: Dict[Literal["asc", "desc"], Callable] = {
"desc": partial(desc),
"asc": partial(asc),
_order_funcs: Dict[Literal["asc", "desc"], Callable] = {
"desc": desc,
"asc": asc,
}

for value in order_by:
Expand All @@ -143,7 +143,7 @@ def _filter_order_by(
else:
self._validate_mapped_property(value[0])
stmt = stmt.order_by(
_partial_registry[value[1]](getattr(self._model, value[0]))
_order_funcs[value[1]](getattr(self._model, value[0]))
)

return stmt
Expand Down Expand Up @@ -344,14 +344,10 @@ def _model_pk(self) -> str:

:return:
"""
primary_keys = inspect(self._model).primary_key # type: ignore
if len(primary_keys) > 1:
raise NotImplementedError("Composite primary keys are not supported.")

return primary_keys[0].name
return get_model_pk_name(self._model)

def _fail_if_invalid_models(self, objects: Iterable[MODEL]) -> None:
if [x for x in objects if not isinstance(x, self._model)]:
if any(not isinstance(x, self._model) for x in objects):
raise InvalidModelError(
"Cannot handle models not belonging to this repository"
)
16 changes: 15 additions & 1 deletion sqlalchemy_bind_manager/_repository/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,29 @@
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from typing import Generic, List, TypeVar, Union
from typing import Generic, List, Type, TypeVar, Union
from uuid import UUID

from pydantic import BaseModel, StrictInt, StrictStr
from sqlalchemy import inspect

MODEL = TypeVar("MODEL")
PRIMARY_KEY = Union[str, int, tuple, dict, UUID]


def get_model_pk_name(model_class: Type) -> str:
"""Retrieves the primary key column name from a SQLAlchemy model class.

:param model_class: A SQLAlchemy model class
:return: The name of the primary key column
:raises NotImplementedError: If the model has composite primary keys
"""
primary_keys = inspect(model_class).primary_key # type: ignore
if len(primary_keys) > 1:
raise NotImplementedError("Composite primary keys are not supported.")
return primary_keys[0].name


class PageInfo(BaseModel):
"""
Paginated query metadata.
Expand Down
13 changes: 2 additions & 11 deletions sqlalchemy_bind_manager/_repository/result_presenters.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@
from math import ceil
from typing import List, Union

from sqlalchemy import inspect

from .common import (
MODEL,
CursorPageInfo,
CursorPaginatedResult,
CursorReference,
PageInfo,
PaginatedResult,
get_model_pk_name,
)


Expand Down Expand Up @@ -93,7 +92,7 @@ def _build_no_cursor_result(
has_next_page = len(result_items) > items_per_page
if has_next_page:
result_items = result_items[0:items_per_page]
reference_column = _pk_from_result_object(result_items[0])
reference_column = get_model_pk_name(type(result_items[0]))

return CursorPaginatedResult(
items=result_items,
Expand Down Expand Up @@ -237,11 +236,3 @@ def build_result(
has_previous_page=has_previous_page,
),
)


def _pk_from_result_object(model) -> str:
primary_keys = inspect(type(model)).primary_key # type: ignore
if len(primary_keys) > 1:
raise NotImplementedError("Composite primary keys are not supported.")

return primary_keys[0].name
10 changes: 8 additions & 2 deletions sqlalchemy_bind_manager/_session_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# DEALINGS IN THE SOFTWARE.

import asyncio
import logging
from contextlib import asynccontextmanager, contextmanager
from typing import AsyncIterator, Iterator

Expand All @@ -34,6 +35,8 @@
)
from sqlalchemy_bind_manager.exceptions import UnsupportedBindError

logger = logging.getLogger(__name__)


class SessionHandler:
scoped_session: scoped_session
Expand All @@ -45,8 +48,11 @@ def __init__(self, bind: SQLAlchemyBind):
self.scoped_session = scoped_session(bind.session_class)

def __del__(self):
if getattr(self, "scoped_session", None):
self.scoped_session.remove()
try:
if getattr(self, "scoped_session", None):
self.scoped_session.remove()
except Exception:
logger.debug("Failed to remove scoped session", exc_info=True)

@contextmanager
def get_session(self, read_only: bool = False) -> Iterator[Session]:
Expand Down
6 changes: 3 additions & 3 deletions tests/repository/result_presenters/test_composite_pk.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

import pytest

from sqlalchemy_bind_manager._repository.result_presenters import _pk_from_result_object
from sqlalchemy_bind_manager._repository.common import get_model_pk_name


def test_exception_raised_if_multiple_primary_keys():
with (
patch(
"sqlalchemy_bind_manager._repository.result_presenters.inspect",
"sqlalchemy_bind_manager._repository.common.inspect",
return_value=Mock(primary_key=["1", "2"]),
),
pytest.raises(NotImplementedError),
):
_pk_from_result_object("irrelevant")
get_model_pk_name(str)
15 changes: 15 additions & 0 deletions tests/session_handler/test_session_lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@ def test_sync_session_is_removed_on_cleanup(sa_manager):
mocked_remove.assert_called_once()


def test_sync_session_cleanup_handles_exception(sa_manager):
"""Test that __del__ gracefully handles exceptions from scoped_session.remove()."""
sh = SessionHandler(sa_manager.get_bind("sync"))

with patch.object(
sh.scoped_session,
"remove",
side_effect=Exception("Connection already closed"),
) as mocked_remove:
# This should not raise - the exception should be caught and logged
sh.__del__()

mocked_remove.assert_called_once()


@pytest.mark.parametrize("read_only_flag", [True, False])
async def test_commit_is_called_only_if_not_read_only(
read_only_flag,
Expand Down
Loading