Skip to content

Commit 577cc54

Browse files
author
Peng Ren
committed
Added transaction support
1 parent fd1e31d commit 577cc54

File tree

5 files changed

+644
-27
lines changed

5 files changed

+644
-27
lines changed

pymongosql/connection.py

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from .common import BaseCursor
1313
from .cursor import Cursor
14-
from .error import DatabaseError, NotSupportedError, OperationalError
14+
from .error import DatabaseError, OperationalError
1515
from .helper import ConnectionHelper
1616

1717
_logger = logging.getLogger(__name__)
@@ -212,8 +212,8 @@ def in_transaction(self) -> bool:
212212
return self._in_transaction
213213

214214
@in_transaction.setter
215-
def in_transaction(self, value: bool) -> bool:
216-
self._in_transaction = False
215+
def in_transaction(self, value: bool) -> None:
216+
self._in_transaction = value
217217

218218
@property
219219
def host(self) -> str:
@@ -407,24 +407,54 @@ def _with_transaction(self, callback, **kwargs):
407407
return self._session.with_transaction(callback, **kwargs)
408408

409409
def begin(self) -> None:
410-
"""Begin transaction (DB-API 2.0 standard method)"""
410+
"""Begin transaction (DB-API 2.0 standard method)
411+
412+
Starts an explicit transaction. After calling begin(), operations
413+
are executed within the transaction context until commit() or
414+
rollback() is called. Requires MongoDB 4.0+ for multi-document
415+
transactions on replica sets or sharded clusters.
416+
417+
Example:
418+
conn.begin()
419+
try:
420+
cursor.execute("INSERT INTO users VALUES (...)")
421+
cursor.execute("UPDATE accounts SET balance = balance - 100")
422+
conn.commit()
423+
except Exception:
424+
conn.rollback()
425+
426+
Raises:
427+
OperationalError: If unable to start transaction
428+
"""
411429
self._start_transaction()
412430

413431
def commit(self) -> None:
414-
"""Commit transaction (DB-API 2.0 standard method)"""
432+
"""Commit transaction (DB-API 2.0 standard method)
433+
434+
Commits the current transaction to the database. All operations
435+
executed since begin() will be atomically persisted. If no
436+
transaction is active, this is a no-op (DB-API 2.0 compliant).
437+
438+
Raises:
439+
OperationalError: If commit fails
440+
"""
415441
if self._session and self._session.in_transaction:
416442
self._commit_transaction()
417-
else:
418-
# Fallback for non-session based operations
419-
self._in_transaction = False
420-
self._autocommit = True
443+
# If no transaction, this is a no-op (DB-API 2.0 compliant)
421444

422445
def rollback(self) -> None:
423-
"""Rollback transaction (DB-API 2.0 standard method)"""
446+
"""Rollback transaction (DB-API 2.0 standard method)
447+
448+
Rolls back (aborts) the current transaction, undoing all operations
449+
executed since begin(). If no transaction is active, this is a no-op
450+
(DB-API 2.0 compliant).
451+
452+
Raises:
453+
OperationalError: If rollback fails
454+
"""
424455
if self._session and self._session.in_transaction:
425456
self._abort_transaction()
426-
else:
427-
raise NotSupportedError("MongoDB doesn't support rollback without an active transaction")
457+
# If no transaction, this is a no-op (DB-API 2.0 compliant)
428458

429459
def test_connection(self) -> bool:
430460
"""Test if the connection is alive"""

pymongosql/executor.py

Lines changed: 59 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,23 @@ def _replace_placeholders(self, obj: Any, parameters: Sequence[Any]) -> Any:
101101
def _execute_execution_plan(
102102
self,
103103
execution_plan: QueryExecutionPlan,
104-
db: Any,
104+
connection: Any = None,
105105
parameters: Optional[Sequence[Any]] = None,
106106
) -> Optional[Dict[str, Any]]:
107-
"""Execute a QueryExecutionPlan against MongoDB using db.command"""
107+
"""Execute a QueryExecutionPlan against MongoDB using db.command
108+
109+
Args:
110+
execution_plan: QueryExecutionPlan to execute
111+
connection: Connection object (for session and database access)
112+
parameters: Parameters for placeholder replacement
113+
"""
108114
try:
115+
# Get database from connection
116+
if not connection:
117+
raise OperationalError("No connection provided")
118+
119+
db = connection.database
120+
109121
# Get database
110122
if not execution_plan.collection:
111123
raise ProgrammingError("No collection specified in query")
@@ -144,8 +156,11 @@ def _execute_execution_plan(
144156

145157
_logger.debug(f"Executing MongoDB command: {find_command}")
146158

147-
# Execute find command directly
148-
result = db.command(find_command)
159+
# Execute find command with session if in transaction
160+
if connection and connection.session and connection.session.in_transaction:
161+
result = db.command(find_command, session=connection.session)
162+
else:
163+
result = db.command(find_command)
149164

150165
# Create command result
151166
return result
@@ -182,7 +197,7 @@ def execute(
182197
# Parse the query
183198
self._execution_plan = self._parse_sql(processed_query)
184199

185-
return self._execute_execution_plan(self._execution_plan, connection.database, processed_params)
200+
return self._execute_execution_plan(self._execution_plan, connection, processed_params)
186201

187202

188203
class InsertExecution(ExecutionStrategy):
@@ -224,10 +239,16 @@ def _replace_placeholders(
224239
def _execute_execution_plan(
225240
self,
226241
execution_plan: InsertExecutionPlan,
227-
db: Any,
242+
connection: Any = None,
228243
parameters: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
229244
) -> Optional[Dict[str, Any]]:
230245
try:
246+
# Get database from connection
247+
if not connection:
248+
raise OperationalError("No connection provided")
249+
250+
db = connection.database
251+
231252
if not execution_plan.collection:
232253
raise ProgrammingError("No collection specified in insert")
233254

@@ -238,7 +259,11 @@ def _execute_execution_plan(
238259

239260
_logger.debug(f"Executing MongoDB insert command: {command}")
240261

241-
return db.command(command)
262+
# Execute with session if in transaction
263+
if connection and connection.session and connection.session.in_transaction:
264+
return db.command(command, session=connection.session)
265+
else:
266+
return db.command(command)
242267
except PyMongoError as e:
243268
_logger.error(f"MongoDB insert failed: {e}")
244269
raise DatabaseError(f"Insert execution failed: {e}")
@@ -259,7 +284,7 @@ def execute(
259284

260285
self._execution_plan = self._parse_sql(context.query)
261286

262-
return self._execute_execution_plan(self._execution_plan, connection.database, parameters)
287+
return self._execute_execution_plan(self._execution_plan, connection, parameters)
263288

264289

265290
class DeleteExecution(ExecutionStrategy):
@@ -293,10 +318,16 @@ def _parse_sql(self, sql: str) -> Any:
293318
def _execute_execution_plan(
294319
self,
295320
execution_plan: Any,
296-
db: Any,
321+
connection: Any = None,
297322
parameters: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
298323
) -> Optional[Dict[str, Any]]:
299324
try:
325+
# Get database from connection
326+
if not connection:
327+
raise OperationalError("No connection provided")
328+
329+
db = connection.database
330+
300331
if not execution_plan.collection:
301332
raise ProgrammingError("No collection specified in delete")
302333

@@ -312,7 +343,11 @@ def _execute_execution_plan(
312343

313344
_logger.debug(f"Executing MongoDB delete command: {command}")
314345

315-
return db.command(command)
346+
# Execute with session if in transaction
347+
if connection and connection.session and connection.session.in_transaction:
348+
return db.command(command, session=connection.session)
349+
else:
350+
return db.command(command)
316351
except PyMongoError as e:
317352
_logger.error(f"MongoDB delete failed: {e}")
318353
raise DatabaseError(f"Delete execution failed: {e}")
@@ -333,7 +368,7 @@ def execute(
333368

334369
self._execution_plan = self._parse_sql(context.query)
335370

336-
return self._execute_execution_plan(self._execution_plan, connection.database, parameters)
371+
return self._execute_execution_plan(self._execution_plan, connection, parameters)
337372

338373

339374
class UpdateExecution(ExecutionStrategy):
@@ -367,10 +402,16 @@ def _parse_sql(self, sql: str) -> Any:
367402
def _execute_execution_plan(
368403
self,
369404
execution_plan: Any,
370-
db: Any,
405+
connection: Any = None,
371406
parameters: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
372407
) -> Optional[Dict[str, Any]]:
373408
try:
409+
# Get database from connection
410+
if not connection:
411+
raise OperationalError("No connection provided")
412+
413+
db = connection.database
414+
374415
if not execution_plan.collection:
375416
raise ProgrammingError("No collection specified in update")
376417

@@ -406,7 +447,11 @@ def _execute_execution_plan(
406447

407448
_logger.debug(f"Executing MongoDB update command: {command}")
408449

409-
return db.command(command)
450+
# Execute with session if in transaction
451+
if connection and connection.session and connection.session.in_transaction:
452+
return db.command(command, session=connection.session)
453+
else:
454+
return db.command(command)
410455
except PyMongoError as e:
411456
_logger.error(f"MongoDB update failed: {e}")
412457
raise DatabaseError(f"Update execution failed: {e}")
@@ -427,7 +472,7 @@ def execute(
427472

428473
self._execution_plan = self._parse_sql(context.query)
429474

430-
return self._execute_execution_plan(self._execution_plan, connection.database, parameters)
475+
return self._execute_execution_plan(self._execution_plan, connection, parameters)
431476

432477

433478
class ExecutionPlanFactory:

pymongosql/superset_mongodb/executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def execute(
6363
_logger.debug(f"Stage 1: Executing MongoDB subquery: {mongo_query}")
6464

6565
mongo_execution_plan = self._parse_sql(mongo_query)
66-
mongo_result = self._execute_execution_plan(mongo_execution_plan, connection.database)
66+
mongo_result = self._execute_execution_plan(mongo_execution_plan, connection)
6767

6868
# Extract result set from MongoDB
6969
mongo_result_set = ResultSet(

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ testpaths = ["tests"]
8989
python_files = ["test_*.py"]
9090
python_classes = ["Test*"]
9191
python_functions = ["test_*"]
92+
markers = [
93+
"transactional: marks tests that require MongoDB transaction support (requires replica set or sharded cluster)",
94+
]
9295

9396
[tool.coverage.run]
9497
source = ["pymongosql"]

0 commit comments

Comments
 (0)