Skip to content

Commit 95dabb2

Browse files
author
Peng Ren
committed
Added basic update support
1 parent 4b18840 commit 95dabb2

File tree

10 files changed

+813
-11
lines changed

10 files changed

+813
-11
lines changed

pymongosql/executor.py

Lines changed: 97 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88

99
from .error import DatabaseError, OperationalError, ProgrammingError, SqlSyntaxError
1010
from .helper import SQLHelper
11+
from .sql.delete_builder import DeleteExecutionPlan
1112
from .sql.insert_builder import InsertExecutionPlan
1213
from .sql.parser import SQLParser
1314
from .sql.query_builder import QueryExecutionPlan
15+
from .sql.update_builder import UpdateExecutionPlan
1416

1517
_logger = logging.getLogger(__name__)
1618

@@ -263,10 +265,6 @@ def execute(
263265
class DeleteExecution(ExecutionStrategy):
264266
"""Strategy for executing DELETE statements."""
265267

266-
def __init__(self) -> None:
267-
"""Initialize DELETE execution strategy."""
268-
self._execution_plan: Optional[Any] = None
269-
270268
@property
271269
def execution_plan(self) -> Any:
272270
return self._execution_plan
@@ -275,8 +273,6 @@ def supports(self, context: ExecutionContext) -> bool:
275273
return context.query.lstrip().upper().startswith("DELETE")
276274

277275
def _parse_sql(self, sql: str) -> Any:
278-
from .sql.delete_builder import DeleteExecutionPlan
279-
280276
try:
281277
parser = SQLParser(sql)
282278
plan = parser.get_execution_plan()
@@ -340,10 +336,104 @@ def execute(
340336
return self._execute_execution_plan(self._execution_plan, connection.database, parameters)
341337

342338

339+
class UpdateExecution(ExecutionStrategy):
340+
"""Strategy for executing UPDATE statements."""
341+
342+
@property
343+
def execution_plan(self) -> Any:
344+
return self._execution_plan
345+
346+
def supports(self, context: ExecutionContext) -> bool:
347+
return context.query.lstrip().upper().startswith("UPDATE")
348+
349+
def _parse_sql(self, sql: str) -> Any:
350+
try:
351+
parser = SQLParser(sql)
352+
plan = parser.get_execution_plan()
353+
354+
if not isinstance(plan, UpdateExecutionPlan):
355+
raise SqlSyntaxError("Expected UPDATE execution plan")
356+
357+
if not plan.validate():
358+
raise SqlSyntaxError("Generated update plan is invalid")
359+
360+
return plan
361+
except SqlSyntaxError:
362+
raise
363+
except Exception as e:
364+
_logger.error(f"SQL parsing failed: {e}")
365+
raise SqlSyntaxError(f"Failed to parse SQL: {e}")
366+
367+
def _execute_execution_plan(
368+
self,
369+
execution_plan: Any,
370+
db: Any,
371+
parameters: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
372+
) -> Optional[Dict[str, Any]]:
373+
try:
374+
if not execution_plan.collection:
375+
raise ProgrammingError("No collection specified in update")
376+
377+
if not execution_plan.update_fields:
378+
raise ProgrammingError("No fields to update specified")
379+
380+
filter_conditions = execution_plan.filter_conditions or {}
381+
update_fields = execution_plan.update_fields or {}
382+
383+
# Replace placeholders if parameters provided
384+
# Note: We need to replace both update_fields and filter_conditions in one pass
385+
# to maintain correct parameter ordering (SET clause first, then WHERE clause)
386+
if parameters:
387+
# Combine structures for replacement in correct order
388+
combined = {"update_fields": update_fields, "filter_conditions": filter_conditions}
389+
replaced = SQLHelper.replace_placeholders_generic(combined, parameters, execution_plan.parameter_style)
390+
update_fields = replaced["update_fields"]
391+
filter_conditions = replaced["filter_conditions"]
392+
393+
# MongoDB update command format
394+
# https://www.mongodb.com/docs/manual/reference/command/update/
395+
command = {
396+
"update": execution_plan.collection,
397+
"updates": [
398+
{
399+
"q": filter_conditions, # query filter
400+
"u": {"$set": update_fields}, # update document using $set operator
401+
"multi": True, # update all matching documents (like SQL UPDATE)
402+
"upsert": False, # don't insert if no match
403+
}
404+
],
405+
}
406+
407+
_logger.debug(f"Executing MongoDB update command: {command}")
408+
409+
return db.command(command)
410+
except PyMongoError as e:
411+
_logger.error(f"MongoDB update failed: {e}")
412+
raise DatabaseError(f"Update execution failed: {e}")
413+
except (ProgrammingError, DatabaseError, OperationalError):
414+
# Re-raise our own errors without wrapping
415+
raise
416+
except Exception as e:
417+
_logger.error(f"Unexpected error during update execution: {e}")
418+
raise OperationalError(f"Update execution error: {e}")
419+
420+
def execute(
421+
self,
422+
context: ExecutionContext,
423+
connection: Any,
424+
parameters: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
425+
) -> Optional[Dict[str, Any]]:
426+
_logger.debug(f"Using update execution for query: {context.query[:100]}")
427+
428+
self._execution_plan = self._parse_sql(context.query)
429+
430+
return self._execute_execution_plan(self._execution_plan, connection.database, parameters)
431+
432+
343433
class ExecutionPlanFactory:
344434
"""Factory for creating appropriate execution strategy based on query context"""
345435

346-
_strategies = [DeleteExecution(), InsertExecution(), StandardQueryExecution()]
436+
_strategies = [StandardQueryExecution(), InsertExecution(), UpdateExecution(), DeleteExecution()]
347437

348438
@classmethod
349439
def get_strategy(cls, context: ExecutionContext) -> ExecutionStrategy:

pymongosql/result_set.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,21 @@ def errors(self) -> List[Dict[str, str]]:
198198

199199
@property
200200
def rowcount(self) -> int:
201-
"""Return number of rows fetched so far (not total available)"""
201+
"""Return number of rows fetched/affected"""
202+
# Check for write operation results (UPDATE, DELETE, INSERT)
203+
if hasattr(self, "_insert_result") and self._insert_result:
204+
# INSERT operation - return number of inserted documents
205+
return self._insert_result.get("n", 0)
206+
207+
# Check command result for write operations
208+
if self._command_result:
209+
# For UPDATE/DELETE operations, check 'n' (modified count) or 'nModified'
210+
if "n" in self._command_result:
211+
return self._command_result.get("n", 0)
212+
if "nModified" in self._command_result:
213+
return self._command_result.get("nModified", 0)
214+
215+
# For SELECT/QUERY operations, return number of fetched rows
202216
return self._total_fetched
203217

204218
@property

pymongosql/sql/ast.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
from .partiql.PartiQLParserVisitor import PartiQLParserVisitor
1515
from .query_builder import QueryExecutionPlan
1616
from .query_handler import QueryParseResult
17+
from .update_builder import UpdateExecutionPlan
18+
from .update_handler import UpdateParseResult
1719

1820
_logger = logging.getLogger(__name__)
1921

@@ -38,6 +40,7 @@ def __init__(self) -> None:
3840
self._parse_result = QueryParseResult.for_visitor()
3941
self._insert_parse_result = InsertParseResult.for_visitor()
4042
self._delete_parse_result = DeleteParseResult.for_visitor()
43+
self._update_parse_result = UpdateParseResult.for_visitor()
4144
# Track current statement kind generically so UPDATE/DELETE can reuse this
4245
self._current_operation: str = "select" # expected values: select | insert | update | delete
4346
self._handlers = self._initialize_handlers()
@@ -50,6 +53,7 @@ def _initialize_handlers(self) -> Dict[str, BaseHandler]:
5053
"from": HandlerFactory.get_visitor_handler("from"),
5154
"where": HandlerFactory.get_visitor_handler("where"),
5255
"insert": HandlerFactory.get_visitor_handler("insert"),
56+
"update": HandlerFactory.get_visitor_handler("update"),
5357
"delete": HandlerFactory.get_visitor_handler("delete"),
5458
}
5559

@@ -58,12 +62,16 @@ def parse_result(self) -> QueryParseResult:
5862
"""Get the current parse result"""
5963
return self._parse_result
6064

61-
def parse_to_execution_plan(self) -> Union[QueryExecutionPlan, InsertExecutionPlan, DeleteExecutionPlan]:
65+
def parse_to_execution_plan(
66+
self,
67+
) -> Union[QueryExecutionPlan, InsertExecutionPlan, DeleteExecutionPlan, UpdateExecutionPlan]:
6268
"""Convert the parse result to an execution plan using BuilderFactory."""
6369
if self._current_operation == "insert":
6470
return self._build_insert_plan()
6571
elif self._current_operation == "delete":
6672
return self._build_delete_plan()
73+
elif self._current_operation == "update":
74+
return self._build_update_plan()
6775

6876
return self._build_query_plan()
6977

@@ -110,6 +118,23 @@ def _build_delete_plan(self) -> DeleteExecutionPlan:
110118

111119
return builder.build()
112120

121+
def _build_update_plan(self) -> UpdateExecutionPlan:
122+
"""Build an UPDATE execution plan from UPDATE parsing."""
123+
_logger.debug(
124+
f"Building UPDATE plan with collection: {self._update_parse_result.collection}, "
125+
f"update_fields: {self._update_parse_result.update_fields}, "
126+
f"filters: {self._update_parse_result.filter_conditions}"
127+
)
128+
builder = BuilderFactory.create_update_builder().collection(self._update_parse_result.collection)
129+
130+
if self._update_parse_result.update_fields:
131+
builder.update_fields(self._update_parse_result.update_fields)
132+
133+
if self._update_parse_result.filter_conditions:
134+
builder.filter_conditions(self._update_parse_result.filter_conditions)
135+
136+
return builder.build()
137+
113138
def visitRoot(self, ctx: PartiQLParser.RootContext) -> Any:
114139
"""Visit root node and process child nodes"""
115140
_logger.debug("Starting to parse SQL query")
@@ -212,6 +237,12 @@ def visitWhereClause(self, ctx: PartiQLParser.WhereClauseContext) -> Any:
212237
if handler:
213238
return handler.handle_where_clause(ctx, self._delete_parse_result)
214239
return {}
240+
# For UPDATE, use the update handler
241+
elif self._current_operation == "update":
242+
handler = self._handlers.get("update")
243+
if handler:
244+
return handler.handle_where_clause(ctx, self._update_parse_result)
245+
return {}
215246
else:
216247
# For other operations, use the where handler
217248
handler = self._handlers["where"]
@@ -293,3 +324,29 @@ def visitOffsetByClause(self, ctx: PartiQLParser.OffsetByClauseContext) -> Any:
293324
except Exception as e:
294325
_logger.warning(f"Error processing OFFSET clause: {e}")
295326
return self.visitChildren(ctx)
327+
328+
def visitUpdateClause(self, ctx: PartiQLParser.UpdateClauseContext) -> Any:
329+
"""Handle UPDATE clause to extract collection/table name."""
330+
_logger.debug("Processing UPDATE clause")
331+
self._current_operation = "update"
332+
# Reset update parse result for this statement
333+
self._update_parse_result = UpdateParseResult.for_visitor()
334+
335+
handler = self._handlers.get("update")
336+
if handler:
337+
handler.handle_visitor(ctx, self._update_parse_result)
338+
339+
# Visit children to process SET and WHERE clauses
340+
return self.visitChildren(ctx)
341+
342+
def visitSetCommand(self, ctx: PartiQLParser.SetCommandContext) -> Any:
343+
"""Handle SET command for UPDATE statements."""
344+
_logger.debug("Processing SET command")
345+
346+
if self._current_operation == "update":
347+
handler = self._handlers.get("update")
348+
if handler:
349+
handler.handle_set_command(ctx, self._update_parse_result)
350+
return None
351+
352+
return self.visitChildren(ctx)

pymongosql/sql/builder.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ def create_delete_builder():
5757

5858
return MongoDeleteBuilder()
5959

60+
@staticmethod
61+
def create_update_builder():
62+
"""Create a builder for UPDATE queries"""
63+
# Local import to avoid circular dependency during module import
64+
from .update_builder import MongoUpdateBuilder
65+
66+
return MongoUpdateBuilder()
67+
6068

6169
__all__ = [
6270
"ExecutionPlan",

pymongosql/sql/handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,13 +809,15 @@ def _initialize_visitor_handlers(cls):
809809
from .delete_handler import DeleteHandler
810810
from .insert_handler import InsertHandler
811811
from .query_handler import FromHandler, SelectHandler, WhereHandler
812+
from .update_handler import UpdateHandler
812813

813814
cls._visitor_handlers = {
814815
"select": SelectHandler(),
815816
"from": FromHandler(),
816817
"where": WhereHandler(),
817818
"insert": InsertHandler(),
818819
"delete": DeleteHandler(),
820+
"update": UpdateHandler(),
819821
}
820822
return cls._visitor_handlers
821823

pymongosql/sql/parser.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .delete_builder import DeleteExecutionPlan
1212
from .insert_builder import InsertExecutionPlan
1313
from .query_builder import QueryExecutionPlan
14+
from .update_builder import UpdateExecutionPlan
1415

1516
_logger = logging.getLogger(__name__)
1617

@@ -128,8 +129,10 @@ def _validate_ast(self) -> None:
128129

129130
_logger.debug("AST validation successful")
130131

131-
def get_execution_plan(self) -> Union[QueryExecutionPlan, InsertExecutionPlan, DeleteExecutionPlan]:
132-
"""Parse SQL and return an execution plan (SELECT, INSERT, or DELETE)."""
132+
def get_execution_plan(
133+
self,
134+
) -> Union[QueryExecutionPlan, InsertExecutionPlan, DeleteExecutionPlan, UpdateExecutionPlan]:
135+
"""Parse SQL and return an execution plan (SELECT, INSERT, DELETE, or UPDATE)."""
133136
if self._ast is None:
134137
raise SqlSyntaxError("No AST available - parsing may have failed")
135138

0 commit comments

Comments
 (0)