diff --git a/README.md b/README.md
index e21251a..a8cb8e7 100644
--- a/README.md
+++ b/README.md
@@ -16,10 +16,10 @@ PyMongoSQL is a Python [DB API 2.0 (PEP 249)](https://www.python.org/dev/peps/pe
PyMongoSQL implements the DB API 2.0 interfaces to provide SQL-like access to MongoDB, built on PartiQL syntax for querying semi-structured data. The project aims to:
-- Bridge the gap between SQL and NoSQL by providing SQL capabilities for MongoDB's nested document structures
-- Support standard SQL DQL (Data Query Language) operations including SELECT statements with WHERE, ORDER BY, and LIMIT clauses on nested and hierarchical data
-- Provide seamless integration with existing Python applications that expect DB API 2.0 compliance
-- Enable easy migration from traditional SQL databases to MongoDB without rewriting queries for document traversal
+- **Bridge SQL and NoSQL**: Provide SQL capabilities for MongoDB's nested document structures
+- **Standard SQL Operations**: Support DQL (SELECT) and DML (INSERT, UPDATE, DELETE) operations with WHERE, ORDER BY, and LIMIT clauses
+- **Seamless Integration**: Full compatibility with Python applications expecting DB API 2.0 compliance
+- **Easy Migration**: Enable migration from traditional SQL databases to MongoDB without rewriting application code
## Features
@@ -28,6 +28,7 @@ PyMongoSQL implements the DB API 2.0 interfaces to provide SQL-like access to Mo
- **Nested Structure Support**: Query and filter deeply nested fields and arrays within MongoDB documents using standard SQL syntax
- **SQLAlchemy Integration**: Complete ORM and Core support with dedicated MongoDB dialect
- **SQL Query Support**: SELECT statements with WHERE conditions, field selection, and aliases
+- **DML Support**: Full support for INSERT, UPDATE, and DELETE operations using PartiQL syntax
- **Connection String Support**: MongoDB URI format for easy configuration
## Requirements
@@ -184,16 +185,18 @@ Parameters are substituted into the MongoDB filter during execution, providing p
## Supported SQL Features
### SELECT Statements
-- Field selection: `SELECT name, age FROM users`
-- Wildcards: `SELECT * FROM products`
-- **Field aliases**: `SELECT name as user_name, age as user_age FROM users`
+
+- **Field selection**: `SELECT name, age FROM users`
+- **Wildcards**: `SELECT * FROM products`
+- **Field aliases**: `SELECT name AS user_name, age AS user_age FROM users`
- **Nested fields**: `SELECT profile.name, profile.age FROM users`
- **Array access**: `SELECT items[0], items[1].name FROM orders`
### WHERE Clauses
-- Equality: `WHERE name = 'John'`
-- Comparisons: `WHERE age > 25`, `WHERE price <= 100.0`
-- Logical operators: `WHERE age > 18 AND status = 'active'`
+
+- **Equality**: `WHERE name = 'John'`
+- **Comparisons**: `WHERE age > 25`, `WHERE price <= 100.0`
+- **Logical operators**: `WHERE age > 18 AND status = 'active'`, `WHERE age < 30 OR role = 'admin'`
- **Nested field filtering**: `WHERE profile.status = 'active'`
- **Array filtering**: `WHERE items[0].price > 100`
@@ -206,9 +209,140 @@ Parameters are substituted into the MongoDB filter during execution, providing p
> **Note**: Avoid SQL reserved words (`user`, `data`, `value`, `count`, etc.) as unquoted field names. Use alternatives or bracket notation for arrays.
### Sorting and Limiting
-- ORDER BY: `ORDER BY name ASC, age DESC`
-- LIMIT: `LIMIT 10`
-- Combined: `ORDER BY created_at DESC LIMIT 5`
+
+- **ORDER BY**: `ORDER BY name ASC, age DESC`
+- **LIMIT**: `LIMIT 10`
+- **Combined**: `ORDER BY created_at DESC LIMIT 5`
+
+### INSERT Statements
+
+PyMongoSQL supports inserting documents into MongoDB collections using PartiQL-style object and bag literals.
+
+**Single Document**
+
+```python
+cursor.execute(
+ "INSERT INTO Music {'title': 'Song A', 'artist': 'Alice', 'year': 2021}"
+)
+```
+
+**Multiple Documents (Bag Syntax)**
+
+```python
+cursor.execute(
+ "INSERT INTO Music << {'title': 'Song B', 'artist': 'Bob'}, {'title': 'Song C', 'artist': 'Charlie'} >>"
+)
+```
+
+**Parameterized INSERT**
+
+```python
+# Positional parameters using ? placeholders
+cursor.execute(
+ "INSERT INTO Music {'title': ?, 'artist': ?, 'year': ?}",
+ ["Song D", "Diana", 2020]
+)
+```
+
+> **Note**: For parameterized INSERT, use positional parameters (`?`). Named placeholders (`:name`) are supported for SELECT, UPDATE, and DELETE queries.
+
+### UPDATE Statements
+
+PyMongoSQL supports updating documents in MongoDB collections using standard SQL UPDATE syntax.
+
+**Update All Documents**
+
+```python
+cursor.execute("UPDATE Music SET available = false")
+```
+
+**Update with WHERE Clause**
+
+```python
+cursor.execute("UPDATE Music SET price = 14.99 WHERE year < 2020")
+```
+
+**Update Multiple Fields**
+
+```python
+cursor.execute(
+ "UPDATE Music SET price = 19.99, available = true WHERE artist = 'Alice'"
+)
+```
+
+**Update with Logical Operators**
+
+```python
+cursor.execute(
+ "UPDATE Music SET price = 9.99 WHERE year = 2020 AND stock > 5"
+)
+```
+
+**Parameterized UPDATE**
+
+```python
+# Positional parameters using ? placeholders
+cursor.execute(
+ "UPDATE Music SET price = ?, stock = ? WHERE artist = ?",
+ [24.99, 50, "Bob"]
+)
+```
+
+**Update Nested Fields**
+
+```python
+cursor.execute(
+ "UPDATE Music SET details.publisher = 'XYZ Records' WHERE title = 'Song A'"
+)
+```
+
+**Check Updated Row Count**
+
+```python
+cursor.execute("UPDATE Music SET available = false WHERE year = 2020")
+print(f"Updated {cursor.rowcount} documents")
+```
+
+### DELETE Statements
+
+PyMongoSQL supports deleting documents from MongoDB collections using standard SQL DELETE syntax.
+
+**Delete All Documents**
+
+```python
+cursor.execute("DELETE FROM Music")
+```
+
+**Delete with WHERE Clause**
+
+```python
+cursor.execute("DELETE FROM Music WHERE year < 2020")
+```
+
+**Delete with Logical Operators**
+
+```python
+cursor.execute(
+ "DELETE FROM Music WHERE year = 2019 AND available = false"
+)
+```
+
+**Parameterized DELETE**
+
+```python
+# Positional parameters using ? placeholders
+cursor.execute(
+ "DELETE FROM Music WHERE artist = ? AND year < ?",
+ ["Charlie", 2021]
+)
+```
+
+**Check Deleted Row Count**
+
+```python
+cursor.execute("DELETE FROM Music WHERE available = false")
+print(f"Deleted {cursor.rowcount} documents")
+```
## Apache Superset Integration
@@ -231,16 +365,17 @@ PyMongoSQL can be used as a database driver in Apache Superset for querying and
This allows seamless integration between MongoDB data and Superset's BI capabilities without requiring data migration to traditional SQL databases.
-
-Limitations & Roadmap
+## Limitations & Roadmap
-**Note**: Currently PyMongoSQL focuses on Data Query Language (DQL) operations. The following SQL features are **not yet supported** but are planned for future releases:
+**Note**: PyMongoSQL currently supports DQL (Data Query Language) and DML (Data Manipulation Language) operations. The following SQL features are **not yet supported** but are planned for future releases:
-- **DML Operations** (Data Manipulation Language)
- - `INSERT`, `UPDATE`, `DELETE`
- **DDL Operations** (Data Definition Language)
- `CREATE TABLE/COLLECTION`, `DROP TABLE/COLLECTION`
- `CREATE INDEX`, `DROP INDEX`
- `LIST TABLES/COLLECTIONS`
+ - `ALTER TABLE/COLLECTION`
+- **Advanced DML Operations**
+ - `MERGE`, `UPSERT`
+ - Transactions and multi-document operations
These features are on our development roadmap and contributions are welcome!
diff --git a/pymongosql/__init__.py b/pymongosql/__init__.py
index ae4b9fb..fae40bd 100644
--- a/pymongosql/__init__.py
+++ b/pymongosql/__init__.py
@@ -6,7 +6,7 @@
if TYPE_CHECKING:
from .connection import Connection
-__version__: str = "0.2.5"
+__version__: str = "0.3.0"
# Globals https://www.python.org/dev/peps/pep-0249/#globals
apilevel: str = "2.0"
diff --git a/pymongosql/cursor.py b/pymongosql/cursor.py
index 45060e2..4c8bdb7 100644
--- a/pymongosql/cursor.py
+++ b/pymongosql/cursor.py
@@ -6,7 +6,7 @@
from .error import DatabaseError, OperationalError, ProgrammingError, SqlSyntaxError
from .executor import ExecutionContext, ExecutionPlanFactory
from .result_set import DictResultSet, ResultSet
-from .sql.builder import ExecutionPlan
+from .sql.query_builder import QueryExecutionPlan
if TYPE_CHECKING:
from .connection import Connection
@@ -29,7 +29,7 @@ def __init__(self, connection: "Connection", mode: str = "standard", **kwargs) -
self._kwargs = kwargs
self._result_set: Optional[ResultSet] = None
self._result_set_class = ResultSet
- self._current_execution_plan: Optional[ExecutionPlan] = None
+ self._current_execution_plan: Optional[Any] = None
self._is_closed = False
@property
@@ -103,12 +103,32 @@ def execute(self: _T, operation: str, parameters: Optional[Any] = None) -> _T:
self._current_execution_plan = strategy.execution_plan
# Create result set from command result
- self._result_set = self._result_set_class(
- command_result=result,
- execution_plan=self._current_execution_plan,
- database=self.connection.database,
- **self._kwargs,
- )
+ # For SELECT/QUERY operations, use the execution plan directly
+ if isinstance(self._current_execution_plan, QueryExecutionPlan):
+ execution_plan_for_rs = self._current_execution_plan
+ self._result_set = self._result_set_class(
+ command_result=result,
+ execution_plan=execution_plan_for_rs,
+ database=self.connection.database,
+ **self._kwargs,
+ )
+ else:
+ # For INSERT and other non-query operations, create a minimal synthetic result
+ # since INSERT commands don't return a cursor structure
+ stub_plan = QueryExecutionPlan(collection=self._current_execution_plan.collection)
+ self._result_set = self._result_set_class(
+ command_result={
+ "cursor": {
+ "id": 0,
+ "firstBatch": [],
+ }
+ },
+ execution_plan=stub_plan,
+ database=self.connection.database,
+ **self._kwargs,
+ )
+ # Store the actual insert result for reference
+ self._result_set._insert_result = result
return self
diff --git a/pymongosql/executor.py b/pymongosql/executor.py
index db1ac90..b9aae8a 100644
--- a/pymongosql/executor.py
+++ b/pymongosql/executor.py
@@ -7,8 +7,12 @@
from pymongo.errors import PyMongoError
from .error import DatabaseError, OperationalError, ProgrammingError, SqlSyntaxError
-from .sql.builder import ExecutionPlan
+from .helper import SQLHelper
+from .sql.delete_builder import DeleteExecutionPlan
+from .sql.insert_builder import InsertExecutionPlan
from .sql.parser import SQLParser
+from .sql.query_builder import QueryExecutionPlan
+from .sql.update_builder import UpdateExecutionPlan
_logger = logging.getLogger(__name__)
@@ -30,7 +34,7 @@ class ExecutionStrategy(ABC):
@property
@abstractmethod
- def execution_plan(self) -> ExecutionPlan:
+ def execution_plan(self) -> Union[QueryExecutionPlan, InsertExecutionPlan]:
"""Name of the execution plan"""
pass
@@ -60,20 +64,21 @@ def supports(self, context: ExecutionContext) -> bool:
pass
-class StandardExecution(ExecutionStrategy):
+class StandardQueryExecution(ExecutionStrategy):
"""Standard execution strategy for simple SELECT queries without subqueries"""
@property
- def execution_plan(self) -> ExecutionPlan:
+ def execution_plan(self) -> QueryExecutionPlan:
"""Return standard execution plan"""
return self._execution_plan
def supports(self, context: ExecutionContext) -> bool:
"""Support simple queries without subqueries"""
- return "standard" in context.execution_mode.lower()
+ normalized = context.query.lstrip().upper()
+ return "standard" in context.execution_mode.lower() and normalized.startswith("SELECT")
- def _parse_sql(self, sql: str) -> ExecutionPlan:
- """Parse SQL statement and return ExecutionPlan"""
+ def _parse_sql(self, sql: str) -> QueryExecutionPlan:
+ """Parse SQL statement and return QueryExecutionPlan"""
try:
parser = SQLParser(sql)
execution_plan = parser.get_execution_plan()
@@ -91,37 +96,15 @@ def _parse_sql(self, sql: str) -> ExecutionPlan:
def _replace_placeholders(self, obj: Any, parameters: Sequence[Any]) -> Any:
"""Recursively replace ? placeholders with parameter values in filter/projection dicts"""
- param_index = [0] # Use list to allow modification in nested function
-
- def replace_recursive(value: Any) -> Any:
- if isinstance(value, str):
- # Replace ? with the next parameter value
- if value == "?":
- if param_index[0] < len(parameters):
- result = parameters[param_index[0]]
- param_index[0] += 1
- return result
- else:
- raise ProgrammingError(
- f"Not enough parameters provided: expected at least {param_index[0] + 1}"
- )
- return value
- elif isinstance(value, dict):
- return {k: replace_recursive(v) for k, v in value.items()}
- elif isinstance(value, list):
- return [replace_recursive(item) for item in value]
- else:
- return value
-
- return replace_recursive(obj)
+ return SQLHelper.replace_placeholders_generic(obj, parameters, "qmark")
def _execute_execution_plan(
self,
- execution_plan: ExecutionPlan,
+ execution_plan: QueryExecutionPlan,
db: Any,
parameters: Optional[Sequence[Any]] = None,
) -> Optional[Dict[str, Any]]:
- """Execute an ExecutionPlan against MongoDB using db.command"""
+ """Execute a QueryExecutionPlan against MongoDB using db.command"""
try:
# Get database
if not execution_plan.collection:
@@ -202,10 +185,255 @@ def execute(
return self._execute_execution_plan(self._execution_plan, connection.database, processed_params)
+class InsertExecution(ExecutionStrategy):
+ """Execution strategy for INSERT statements."""
+
+ @property
+ def execution_plan(self) -> InsertExecutionPlan:
+ return self._execution_plan
+
+ def supports(self, context: ExecutionContext) -> bool:
+ return context.query.lstrip().upper().startswith("INSERT")
+
+ def _parse_sql(self, sql: str) -> InsertExecutionPlan:
+ try:
+ parser = SQLParser(sql)
+ plan = parser.get_execution_plan()
+
+ if not isinstance(plan, InsertExecutionPlan):
+ raise SqlSyntaxError("Expected INSERT execution plan")
+
+ if not plan.validate():
+ raise SqlSyntaxError("Generated insert plan is invalid")
+
+ return plan
+ except SqlSyntaxError:
+ raise
+ except Exception as e:
+ _logger.error(f"SQL parsing failed: {e}")
+ raise SqlSyntaxError(f"Failed to parse SQL: {e}")
+
+ def _replace_placeholders(
+ self,
+ documents: Sequence[Dict[str, Any]],
+ parameters: Optional[Union[Sequence[Any], Dict[str, Any]]],
+ style: Optional[str],
+ ) -> Sequence[Dict[str, Any]]:
+ return SQLHelper.replace_placeholders_generic(documents, parameters, style)
+
+ def _execute_execution_plan(
+ self,
+ execution_plan: InsertExecutionPlan,
+ db: Any,
+ parameters: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
+ ) -> Optional[Dict[str, Any]]:
+ try:
+ if not execution_plan.collection:
+ raise ProgrammingError("No collection specified in insert")
+
+ docs = execution_plan.insert_documents or []
+ docs = self._replace_placeholders(docs, parameters, execution_plan.parameter_style)
+
+ command = {"insert": execution_plan.collection, "documents": docs}
+
+ _logger.debug(f"Executing MongoDB insert command: {command}")
+
+ return db.command(command)
+ except PyMongoError as e:
+ _logger.error(f"MongoDB insert failed: {e}")
+ raise DatabaseError(f"Insert execution failed: {e}")
+ except (ProgrammingError, DatabaseError, OperationalError):
+ # Re-raise our own errors without wrapping
+ raise
+ except Exception as e:
+ _logger.error(f"Unexpected error during insert execution: {e}")
+ raise OperationalError(f"Insert execution error: {e}")
+
+ def execute(
+ self,
+ context: ExecutionContext,
+ connection: Any,
+ parameters: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
+ ) -> Optional[Dict[str, Any]]:
+ _logger.debug(f"Using insert execution for query: {context.query[:100]}")
+
+ self._execution_plan = self._parse_sql(context.query)
+
+ return self._execute_execution_plan(self._execution_plan, connection.database, parameters)
+
+
+class DeleteExecution(ExecutionStrategy):
+ """Strategy for executing DELETE statements."""
+
+ @property
+ def execution_plan(self) -> Any:
+ return self._execution_plan
+
+ def supports(self, context: ExecutionContext) -> bool:
+ return context.query.lstrip().upper().startswith("DELETE")
+
+ def _parse_sql(self, sql: str) -> Any:
+ try:
+ parser = SQLParser(sql)
+ plan = parser.get_execution_plan()
+
+ if not isinstance(plan, DeleteExecutionPlan):
+ raise SqlSyntaxError("Expected DELETE execution plan")
+
+ if not plan.validate():
+ raise SqlSyntaxError("Generated delete plan is invalid")
+
+ return plan
+ except SqlSyntaxError:
+ raise
+ except Exception as e:
+ _logger.error(f"SQL parsing failed: {e}")
+ raise SqlSyntaxError(f"Failed to parse SQL: {e}")
+
+ def _execute_execution_plan(
+ self,
+ execution_plan: Any,
+ db: Any,
+ parameters: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
+ ) -> Optional[Dict[str, Any]]:
+ try:
+ if not execution_plan.collection:
+ raise ProgrammingError("No collection specified in delete")
+
+ filter_conditions = execution_plan.filter_conditions or {}
+
+ # Replace placeholders in filter if parameters provided
+ if parameters and filter_conditions:
+ filter_conditions = SQLHelper.replace_placeholders_generic(
+ filter_conditions, parameters, execution_plan.parameter_style
+ )
+
+ command = {"delete": execution_plan.collection, "deletes": [{"q": filter_conditions, "limit": 0}]}
+
+ _logger.debug(f"Executing MongoDB delete command: {command}")
+
+ return db.command(command)
+ except PyMongoError as e:
+ _logger.error(f"MongoDB delete failed: {e}")
+ raise DatabaseError(f"Delete execution failed: {e}")
+ except (ProgrammingError, DatabaseError, OperationalError):
+ # Re-raise our own errors without wrapping
+ raise
+ except Exception as e:
+ _logger.error(f"Unexpected error during delete execution: {e}")
+ raise OperationalError(f"Delete execution error: {e}")
+
+ def execute(
+ self,
+ context: ExecutionContext,
+ connection: Any,
+ parameters: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
+ ) -> Optional[Dict[str, Any]]:
+ _logger.debug(f"Using delete execution for query: {context.query[:100]}")
+
+ self._execution_plan = self._parse_sql(context.query)
+
+ return self._execute_execution_plan(self._execution_plan, connection.database, parameters)
+
+
+class UpdateExecution(ExecutionStrategy):
+ """Strategy for executing UPDATE statements."""
+
+ @property
+ def execution_plan(self) -> Any:
+ return self._execution_plan
+
+ def supports(self, context: ExecutionContext) -> bool:
+ return context.query.lstrip().upper().startswith("UPDATE")
+
+ def _parse_sql(self, sql: str) -> Any:
+ try:
+ parser = SQLParser(sql)
+ plan = parser.get_execution_plan()
+
+ if not isinstance(plan, UpdateExecutionPlan):
+ raise SqlSyntaxError("Expected UPDATE execution plan")
+
+ if not plan.validate():
+ raise SqlSyntaxError("Generated update plan is invalid")
+
+ return plan
+ except SqlSyntaxError:
+ raise
+ except Exception as e:
+ _logger.error(f"SQL parsing failed: {e}")
+ raise SqlSyntaxError(f"Failed to parse SQL: {e}")
+
+ def _execute_execution_plan(
+ self,
+ execution_plan: Any,
+ db: Any,
+ parameters: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
+ ) -> Optional[Dict[str, Any]]:
+ try:
+ if not execution_plan.collection:
+ raise ProgrammingError("No collection specified in update")
+
+ if not execution_plan.update_fields:
+ raise ProgrammingError("No fields to update specified")
+
+ filter_conditions = execution_plan.filter_conditions or {}
+ update_fields = execution_plan.update_fields or {}
+
+ # Replace placeholders if parameters provided
+ # Note: We need to replace both update_fields and filter_conditions in one pass
+ # to maintain correct parameter ordering (SET clause first, then WHERE clause)
+ if parameters:
+ # Combine structures for replacement in correct order
+ combined = {"update_fields": update_fields, "filter_conditions": filter_conditions}
+ replaced = SQLHelper.replace_placeholders_generic(combined, parameters, execution_plan.parameter_style)
+ update_fields = replaced["update_fields"]
+ filter_conditions = replaced["filter_conditions"]
+
+ # MongoDB update command format
+ # https://www.mongodb.com/docs/manual/reference/command/update/
+ command = {
+ "update": execution_plan.collection,
+ "updates": [
+ {
+ "q": filter_conditions, # query filter
+ "u": {"$set": update_fields}, # update document using $set operator
+ "multi": True, # update all matching documents (like SQL UPDATE)
+ "upsert": False, # don't insert if no match
+ }
+ ],
+ }
+
+ _logger.debug(f"Executing MongoDB update command: {command}")
+
+ return db.command(command)
+ except PyMongoError as e:
+ _logger.error(f"MongoDB update failed: {e}")
+ raise DatabaseError(f"Update execution failed: {e}")
+ except (ProgrammingError, DatabaseError, OperationalError):
+ # Re-raise our own errors without wrapping
+ raise
+ except Exception as e:
+ _logger.error(f"Unexpected error during update execution: {e}")
+ raise OperationalError(f"Update execution error: {e}")
+
+ def execute(
+ self,
+ context: ExecutionContext,
+ connection: Any,
+ parameters: Optional[Union[Sequence[Any], Dict[str, Any]]] = None,
+ ) -> Optional[Dict[str, Any]]:
+ _logger.debug(f"Using update execution for query: {context.query[:100]}")
+
+ self._execution_plan = self._parse_sql(context.query)
+
+ return self._execute_execution_plan(self._execution_plan, connection.database, parameters)
+
+
class ExecutionPlanFactory:
"""Factory for creating appropriate execution strategy based on query context"""
- _strategies = [StandardExecution()]
+ _strategies = [StandardQueryExecution(), InsertExecution(), UpdateExecution(), DeleteExecution()]
@classmethod
def get_strategy(cls, context: ExecutionContext) -> ExecutionStrategy:
@@ -216,7 +444,7 @@ def get_strategy(cls, context: ExecutionContext) -> ExecutionStrategy:
return strategy
# Fallback to standard execution
- return StandardExecution()
+ return StandardQueryExecution()
@classmethod
def register_strategy(cls, strategy: ExecutionStrategy) -> None:
diff --git a/pymongosql/helper.py b/pymongosql/helper.py
index 38a3610..6c1d2cb 100644
--- a/pymongosql/helper.py
+++ b/pymongosql/helper.py
@@ -6,9 +6,11 @@
"""
import logging
-from typing import Optional, Tuple
+from typing import Any, Optional, Sequence, Tuple
from urllib.parse import parse_qs, urlparse
+from .error import ProgrammingError
+
_logger = logging.getLogger(__name__)
@@ -95,3 +97,54 @@ def parse_connection_string(connection_string: Optional[str]) -> Tuple[Optional[
except Exception as e:
_logger.error(f"Failed to parse connection string: {e}")
raise ValueError(f"Invalid connection string format: {e}")
+
+
+class SQLHelper:
+ """SQL-related helper utilities."""
+
+ @staticmethod
+ def replace_placeholders_generic(value: Any, parameters: Any, style: Optional[str]) -> Any:
+ """Recursively replace placeholders in nested structures for qmark or named styles."""
+ if style is None or parameters is None:
+ return value
+
+ if style == "qmark":
+ if not isinstance(parameters, Sequence) or isinstance(parameters, (str, bytes, dict)):
+ raise ProgrammingError("Positional parameters must be provided as a sequence")
+
+ idx = [0]
+
+ def replace(val: Any) -> Any:
+ if isinstance(val, str) and val == "?":
+ if idx[0] >= len(parameters):
+ raise ProgrammingError("Not enough parameters provided")
+ out = parameters[idx[0]]
+ idx[0] += 1
+ return out
+ if isinstance(val, dict):
+ return {k: replace(v) for k, v in val.items()}
+ if isinstance(val, list):
+ return [replace(v) for v in val]
+ return val
+
+ return replace(value)
+
+ if style == "named":
+ if not isinstance(parameters, dict):
+ raise ProgrammingError("Named parameters must be provided as a mapping")
+
+ def replace(val: Any) -> Any:
+ if isinstance(val, str) and val.startswith(":"):
+ key = val[1:]
+ if key not in parameters:
+ raise ProgrammingError(f"Missing named parameter: {key}")
+ return parameters[key]
+ if isinstance(val, dict):
+ return {k: replace(v) for k, v in val.items()}
+ if isinstance(val, list):
+ return [replace(v) for v in val]
+ return val
+
+ return replace(value)
+
+ return value
diff --git a/pymongosql/result_set.py b/pymongosql/result_set.py
index abb3583..3656a5c 100644
--- a/pymongosql/result_set.py
+++ b/pymongosql/result_set.py
@@ -8,7 +8,7 @@
from .common import CursorIterator
from .error import DatabaseError, ProgrammingError
-from .sql.builder import ExecutionPlan
+from .sql.query_builder import QueryExecutionPlan
_logger = logging.getLogger(__name__)
@@ -19,7 +19,7 @@ class ResultSet(CursorIterator):
def __init__(
self,
command_result: Optional[Dict[str, Any]] = None,
- execution_plan: ExecutionPlan = None,
+ execution_plan: QueryExecutionPlan = None,
arraysize: int = None,
database: Optional[Any] = None,
**kwargs,
@@ -198,7 +198,21 @@ def errors(self) -> List[Dict[str, str]]:
@property
def rowcount(self) -> int:
- """Return number of rows fetched so far (not total available)"""
+ """Return number of rows fetched/affected"""
+ # Check for write operation results (UPDATE, DELETE, INSERT)
+ if hasattr(self, "_insert_result") and self._insert_result:
+ # INSERT operation - return number of inserted documents
+ return self._insert_result.get("n", 0)
+
+ # Check command result for write operations
+ if self._command_result:
+ # For UPDATE/DELETE operations, check 'n' (modified count) or 'nModified'
+ if "n" in self._command_result:
+ return self._command_result.get("n", 0)
+ if "nModified" in self._command_result:
+ return self._command_result.get("nModified", 0)
+
+ # For SELECT/QUERY operations, return number of fetched rows
return self._total_fetched
@property
diff --git a/pymongosql/sql/ast.py b/pymongosql/sql/ast.py
index fa0d97d..f9f3ed8 100644
--- a/pymongosql/sql/ast.py
+++ b/pymongosql/sql/ast.py
@@ -1,13 +1,21 @@
# -*- coding: utf-8 -*-
import logging
-from typing import Any, Dict
+from typing import Any, Dict, Union
from ..error import SqlSyntaxError
-from .builder import BuilderFactory, ExecutionPlan
-from .handler import BaseHandler, HandlerFactory, ParseResult
+from .builder import BuilderFactory
+from .delete_builder import DeleteExecutionPlan
+from .delete_handler import DeleteParseResult
+from .handler import BaseHandler, HandlerFactory
+from .insert_builder import InsertExecutionPlan
+from .insert_handler import InsertParseResult
from .partiql.PartiQLLexer import PartiQLLexer
from .partiql.PartiQLParser import PartiQLParser
from .partiql.PartiQLParserVisitor import PartiQLParserVisitor
+from .query_builder import QueryExecutionPlan
+from .query_handler import QueryParseResult
+from .update_builder import UpdateExecutionPlan
+from .update_handler import UpdateParseResult
_logger = logging.getLogger(__name__)
@@ -29,7 +37,12 @@ class MongoSQLParserVisitor(PartiQLParserVisitor):
def __init__(self) -> None:
super().__init__()
- self._parse_result = ParseResult.for_visitor()
+ self._parse_result = QueryParseResult.for_visitor()
+ self._insert_parse_result = InsertParseResult.for_visitor()
+ self._delete_parse_result = DeleteParseResult.for_visitor()
+ self._update_parse_result = UpdateParseResult.for_visitor()
+ # Track current statement kind generically so UPDATE/DELETE can reuse this
+ self._current_operation: str = "select" # expected values: select | insert | update | delete
self._handlers = self._initialize_handlers()
def _initialize_handlers(self) -> Dict[str, BaseHandler]:
@@ -39,15 +52,31 @@ def _initialize_handlers(self) -> Dict[str, BaseHandler]:
"select": HandlerFactory.get_visitor_handler("select"),
"from": HandlerFactory.get_visitor_handler("from"),
"where": HandlerFactory.get_visitor_handler("where"),
+ "insert": HandlerFactory.get_visitor_handler("insert"),
+ "update": HandlerFactory.get_visitor_handler("update"),
+ "delete": HandlerFactory.get_visitor_handler("delete"),
}
@property
- def parse_result(self) -> ParseResult:
+ def parse_result(self) -> QueryParseResult:
"""Get the current parse result"""
return self._parse_result
- def parse_to_execution_plan(self) -> ExecutionPlan:
- """Convert the parse result to an ExecutionPlan using BuilderFactory"""
+ def parse_to_execution_plan(
+ self,
+ ) -> Union[QueryExecutionPlan, InsertExecutionPlan, DeleteExecutionPlan, UpdateExecutionPlan]:
+ """Convert the parse result to an execution plan using BuilderFactory."""
+ if self._current_operation == "insert":
+ return self._build_insert_plan()
+ elif self._current_operation == "delete":
+ return self._build_delete_plan()
+ elif self._current_operation == "update":
+ return self._build_update_plan()
+
+ return self._build_query_plan()
+
+ def _build_query_plan(self) -> QueryExecutionPlan:
+ """Build a query execution plan from SELECT parsing."""
builder = BuilderFactory.create_query_builder().collection(self._parse_result.collection)
builder.filter(self._parse_result.filter_conditions).project(self._parse_result.projection).column_aliases(
@@ -58,6 +87,54 @@ def parse_to_execution_plan(self) -> ExecutionPlan:
return builder.build()
+ def _build_insert_plan(self) -> InsertExecutionPlan:
+ """Build an INSERT execution plan from INSERT parsing."""
+ if self._insert_parse_result.has_errors:
+ raise SqlSyntaxError(self._insert_parse_result.error_message or "INSERT parsing failed")
+
+ builder = BuilderFactory.create_insert_builder().collection(self._insert_parse_result.collection)
+
+ documents = self._insert_parse_result.insert_documents or []
+ builder.insert_documents(documents)
+
+ if self._insert_parse_result.parameter_style:
+ builder.parameter_style(self._insert_parse_result.parameter_style)
+
+ if self._insert_parse_result.parameter_count > 0:
+ builder.parameter_count(self._insert_parse_result.parameter_count)
+
+ return builder.build()
+
+ def _build_delete_plan(self) -> DeleteExecutionPlan:
+ """Build a DELETE execution plan from DELETE parsing."""
+ _logger.debug(
+ f"Building DELETE plan with collection: {self._delete_parse_result.collection}, "
+ f"filters: {self._delete_parse_result.filter_conditions}"
+ )
+ builder = BuilderFactory.create_delete_builder().collection(self._delete_parse_result.collection)
+
+ if self._delete_parse_result.filter_conditions:
+ builder.filter_conditions(self._delete_parse_result.filter_conditions)
+
+ return builder.build()
+
+ def _build_update_plan(self) -> UpdateExecutionPlan:
+ """Build an UPDATE execution plan from UPDATE parsing."""
+ _logger.debug(
+ f"Building UPDATE plan with collection: {self._update_parse_result.collection}, "
+ f"update_fields: {self._update_parse_result.update_fields}, "
+ f"filters: {self._update_parse_result.filter_conditions}"
+ )
+ builder = BuilderFactory.create_update_builder().collection(self._update_parse_result.collection)
+
+ if self._update_parse_result.update_fields:
+ builder.update_fields(self._update_parse_result.update_fields)
+
+ if self._update_parse_result.filter_conditions:
+ builder.filter_conditions(self._update_parse_result.filter_conditions)
+
+ return builder.build()
+
def visitRoot(self, ctx: PartiQLParser.RootContext) -> Any:
"""Visit root node and process child nodes"""
_logger.debug("Starting to parse SQL query")
@@ -116,6 +193,81 @@ def visitWhereClauseSelect(self, ctx: PartiQLParser.WhereClauseSelectContext) ->
_logger.warning(f"Error processing WHERE clause: {e}")
return self.visitChildren(ctx)
+ def visitInsertStatement(self, ctx: PartiQLParser.InsertStatementContext) -> Any:
+ """Handle INSERT statements via the insert handler."""
+ _logger.debug("Processing INSERT statement")
+ self._current_operation = "insert"
+ handler = self._handlers.get("insert")
+ if handler:
+ return handler.handle_visitor(ctx, self._insert_parse_result)
+ return self.visitChildren(ctx)
+
+ def visitInsertStatementLegacy(self, ctx: PartiQLParser.InsertStatementLegacyContext) -> Any:
+ """Handle legacy INSERT statements."""
+ _logger.debug("Processing INSERT legacy statement")
+ self._current_operation = "insert"
+ handler = self._handlers.get("insert")
+ if handler:
+ return handler.handle_visitor(ctx, self._insert_parse_result)
+ return self.visitChildren(ctx)
+
+ def visitFromClauseSimpleExplicit(self, ctx: PartiQLParser.FromClauseSimpleExplicitContext) -> Any:
+ """Handle FROM clause (explicit form) in DELETE statements."""
+ if self._current_operation == "delete":
+ handler = self._handlers.get("delete")
+ if handler:
+ return handler.handle_from_clause_explicit(ctx, self._delete_parse_result)
+ return self.visitChildren(ctx)
+
+ def visitFromClauseSimpleImplicit(self, ctx: PartiQLParser.FromClauseSimpleImplicitContext) -> Any:
+ """Handle FROM clause (implicit form) in DELETE statements."""
+ if self._current_operation == "delete":
+ handler = self._handlers.get("delete")
+ if handler:
+ return handler.handle_from_clause_implicit(ctx, self._delete_parse_result)
+ return self.visitChildren(ctx)
+
+ def visitWhereClause(self, ctx: PartiQLParser.WhereClauseContext) -> Any:
+ """Handle WHERE clause (generic form used in DELETE, UPDATE)."""
+ _logger.debug("Processing WHERE clause (generic)")
+ try:
+ # For DELETE, use the delete handler
+ if self._current_operation == "delete":
+ handler = self._handlers.get("delete")
+ if handler:
+ return handler.handle_where_clause(ctx, self._delete_parse_result)
+ return {}
+ # For UPDATE, use the update handler
+ elif self._current_operation == "update":
+ handler = self._handlers.get("update")
+ if handler:
+ return handler.handle_where_clause(ctx, self._update_parse_result)
+ return {}
+ else:
+ # For other operations, use the where handler
+ handler = self._handlers["where"]
+ if handler:
+ result = handler.handle_visitor(ctx, self._parse_result)
+ _logger.debug(f"Extracted filter conditions: {result}")
+ return result
+ return {}
+ except Exception as e:
+ _logger.warning(f"Error processing WHERE clause: {e}")
+ return {}
+
+ def visitDeleteCommand(self, ctx: PartiQLParser.DeleteCommandContext) -> Any:
+ """Handle DELETE statements."""
+ _logger.debug("Processing DELETE statement")
+ self._current_operation = "delete"
+ # Reset delete parse result for this statement
+ self._delete_parse_result = DeleteParseResult.for_visitor()
+ # Use delete handler if available
+ handler = self._handlers.get("delete")
+ if handler:
+ handler.handle_visitor(ctx, self._delete_parse_result)
+ # Visit children to process FROM and WHERE clauses
+ return self.visitChildren(ctx)
+
def visitOrderByClause(self, ctx: PartiQLParser.OrderByClauseContext) -> Any:
"""Handle ORDER BY clause for sorting"""
_logger.debug("Processing ORDER BY clause")
@@ -172,3 +324,29 @@ def visitOffsetByClause(self, ctx: PartiQLParser.OffsetByClauseContext) -> Any:
except Exception as e:
_logger.warning(f"Error processing OFFSET clause: {e}")
return self.visitChildren(ctx)
+
+ def visitUpdateClause(self, ctx: PartiQLParser.UpdateClauseContext) -> Any:
+ """Handle UPDATE clause to extract collection/table name."""
+ _logger.debug("Processing UPDATE clause")
+ self._current_operation = "update"
+ # Reset update parse result for this statement
+ self._update_parse_result = UpdateParseResult.for_visitor()
+
+ handler = self._handlers.get("update")
+ if handler:
+ handler.handle_visitor(ctx, self._update_parse_result)
+
+ # Visit children to process SET and WHERE clauses
+ return self.visitChildren(ctx)
+
+ def visitSetCommand(self, ctx: PartiQLParser.SetCommandContext) -> Any:
+ """Handle SET command for UPDATE statements."""
+ _logger.debug("Processing SET command")
+
+ if self._current_operation == "update":
+ handler = self._handlers.get("update")
+ if handler:
+ handler.handle_set_command(ctx, self._update_parse_result)
+ return None
+
+ return self.visitChildren(ctx)
diff --git a/pymongosql/sql/builder.py b/pymongosql/sql/builder.py
index 839bc41..6f6771f 100644
--- a/pymongosql/sql/builder.py
+++ b/pymongosql/sql/builder.py
@@ -1,261 +1,72 @@
# -*- coding: utf-8 -*-
import logging
-from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Union
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
_logger = logging.getLogger(__name__)
@dataclass
class ExecutionPlan:
- """Unified representation for MongoDB operations - supports queries, DDL, and DML operations"""
+ """Base class for execution plans (query, insert, etc.).
+
+ Provides common attributes and shared validation helpers.
+ """
collection: Optional[str] = None
- filter_stage: Dict[str, Any] = field(default_factory=dict)
- projection_stage: Dict[str, Any] = field(default_factory=dict)
- column_aliases: Dict[str, str] = field(default_factory=dict) # Maps field_name -> alias
- sort_stage: List[Dict[str, int]] = field(default_factory=list)
- limit_stage: Optional[int] = None
- skip_stage: Optional[int] = None
def to_dict(self) -> Dict[str, Any]:
- """Convert query plan to dictionary representation"""
- return {
- "collection": self.collection,
- "filter": self.filter_stage,
- "projection": self.projection_stage,
- "sort": self.sort_stage,
- "limit": self.limit_stage,
- "skip": self.skip_stage,
- }
+ """Convert plan to a serializable dictionary. Must be implemented by subclasses."""
+ raise NotImplementedError()
- def validate(self) -> bool:
- """Validate the query plan"""
- errors = []
+ def validate_base(self) -> list[str]:
+ """Common validation checks for all plans.
+ Returns a list of error messages for the caller to aggregate and log.
+ """
+ errors: list[str] = []
if not self.collection:
errors.append("Collection name is required")
+ return errors
- if self.limit_stage is not None and (not isinstance(self.limit_stage, int) or self.limit_stage < 0):
- errors.append("Limit must be a non-negative integer")
-
- if self.skip_stage is not None and (not isinstance(self.skip_stage, int) or self.skip_stage < 0):
- errors.append("Skip must be a non-negative integer")
-
- if errors:
- _logger.error(f"Query validation errors: {errors}")
- return False
-
- return True
-
- def copy(self) -> "ExecutionPlan":
- """Create a copy of this execution plan"""
- return ExecutionPlan(
- collection=self.collection,
- filter_stage=self.filter_stage.copy(),
- projection_stage=self.projection_stage.copy(),
- column_aliases=self.column_aliases.copy(),
- sort_stage=self.sort_stage.copy(),
- limit_stage=self.limit_stage,
- skip_stage=self.skip_stage,
- )
-
-
-class MongoQueryBuilder:
- """Fluent builder for MongoDB queries with validation and readability"""
-
- def __init__(self):
- self._execution_plan = ExecutionPlan()
- self._validation_errors = []
-
- def collection(self, name: str) -> "MongoQueryBuilder":
- """Set the target collection"""
- if not name or not name.strip():
- self._add_error("Collection name cannot be empty")
- return self
-
- self._execution_plan.collection = name.strip()
- _logger.debug(f"Set collection to: {name}")
- return self
-
- def filter(self, conditions: Dict[str, Any]) -> "MongoQueryBuilder":
- """Add filter conditions"""
- if not isinstance(conditions, dict):
- self._add_error("Filter conditions must be a dictionary")
- return self
-
- self._execution_plan.filter_stage.update(conditions)
- _logger.debug(f"Added filter conditions: {conditions}")
- return self
-
- def project(self, fields: Union[Dict[str, int], List[str]]) -> "MongoQueryBuilder":
- """Set projection fields"""
- if isinstance(fields, list):
- # Convert list to projection dict
- projection = {field: 1 for field in fields}
- elif isinstance(fields, dict):
- projection = fields
- else:
- self._add_error("Projection must be a list of field names or a dictionary")
- return self
-
- self._execution_plan.projection_stage = projection
- _logger.debug(f"Set projection: {projection}")
- return self
-
- def sort(self, specs: List[Dict[str, int]]) -> "MongoQueryBuilder":
- """Add sort criteria.
-
- Only accepts a list of single-key dicts in the form:
- [{"field": 1}, {"other": -1}]
-
- This matches the output produced by the SQL parser (`sort_fields`).
- """
- if not isinstance(specs, list):
- self._add_error("Sort specifications must be a list of single-key dicts")
- return self
-
- for spec in specs:
- if not isinstance(spec, dict) or len(spec) != 1:
- self._add_error("Each sort specification must be a single-key dict, e.g. {'name': 1}")
- continue
-
- field, direction = next(iter(spec.items()))
-
- if not isinstance(field, str) or not field:
- self._add_error("Sort field must be a non-empty string")
- continue
-
- if direction not in [-1, 1]:
- self._add_error(f"Sort direction for field '{field}' must be 1 or -1")
- continue
-
- self._execution_plan.sort_stage.append({field: direction})
- _logger.debug(f"Added sort: {field} -> {direction}")
-
- return self
- def limit(self, count: int) -> "MongoQueryBuilder":
- """Set limit for results"""
- if not isinstance(count, int) or count < 0:
- self._add_error("Limit must be a non-negative integer")
- return self
-
- self._execution_plan.limit_stage = count
- _logger.debug(f"Set limit to: {count}")
- return self
-
- def skip(self, count: int) -> "MongoQueryBuilder":
- """Set skip count for pagination"""
- if not isinstance(count, int) or count < 0:
- self._add_error("Skip must be a non-negative integer")
- return self
-
- self._execution_plan.skip_stage = count
- _logger.debug(f"Set skip to: {count}")
- return self
-
- def column_aliases(self, aliases: Dict[str, str]) -> "MongoQueryBuilder":
- """Set column aliases mapping (field_name -> alias)"""
- if not isinstance(aliases, dict):
- self._add_error("Column aliases must be a dictionary")
- return self
-
- self._execution_plan.column_aliases = aliases
- _logger.debug(f"Set column aliases to: {aliases}")
- return self
-
- def where(self, field: str, operator: str, value: Any) -> "MongoQueryBuilder":
- """Add a where condition in a readable format"""
- condition = self._build_condition(field, operator, value)
- if condition:
- return self.filter(condition)
- return self
-
- def where_in(self, field: str, values: List[Any]) -> "MongoQueryBuilder":
- """Add a WHERE field IN (values) condition"""
- return self.filter({field: {"$in": values}})
-
- def where_between(self, field: str, min_val: Any, max_val: Any) -> "MongoQueryBuilder":
- """Add a WHERE field BETWEEN min AND max condition"""
- return self.filter({field: {"$gte": min_val, "$lte": max_val}})
-
- def where_like(self, field: str, pattern: str) -> "MongoQueryBuilder":
- """Add a WHERE field LIKE pattern condition"""
- # Convert SQL LIKE pattern to MongoDB regex
- regex_pattern = pattern.replace("%", ".*").replace("_", ".")
- return self.filter({field: {"$regex": regex_pattern, "$options": "i"}})
-
- def _build_condition(self, field: str, operator: str, value: Any) -> Optional[Dict[str, Any]]:
- """Build a MongoDB condition from field, operator, and value"""
- operator_map = {
- "=": "$eq",
- "!=": "$ne",
- "<": "$lt",
- "<=": "$lte",
- ">": "$gt",
- ">=": "$gte",
- "eq": "$eq",
- "ne": "$ne",
- "lt": "$lt",
- "lte": "$lte",
- "gt": "$gt",
- "gte": "$gte",
- }
-
- mongo_op = operator_map.get(operator.lower())
- if not mongo_op:
- self._add_error(f"Unsupported operator: {operator}")
- return None
-
- return {field: {mongo_op: value}}
-
- def _add_error(self, message: str) -> None:
- """Add validation error"""
- self._validation_errors.append(message)
- _logger.error(f"Query builder error: {message}")
-
- def validate(self) -> bool:
- """Validate the current query plan"""
- self._validation_errors.clear()
+class BuilderFactory:
+ """Factory for creating builders for different operations."""
- if not self._execution_plan.collection:
- self._add_error("Collection name is required")
+ @staticmethod
+ def create_query_builder():
+ """Create a builder for SELECT queries"""
+ # Local import to avoid circular dependency during module import
+ from .query_builder import MongoQueryBuilder
- # Add more validation rules as needed
- return len(self._validation_errors) == 0
+ return MongoQueryBuilder()
- def get_errors(self) -> List[str]:
- """Get validation errors"""
- return self._validation_errors.copy()
+ @staticmethod
+ def create_insert_builder():
+ """Create a builder for INSERT queries"""
+ # Local import to avoid circular dependency during module import
+ from .insert_builder import MongoInsertBuilder
- def build(self) -> ExecutionPlan:
- """Build and return the execution plan"""
- if not self.validate():
- error_summary = "; ".join(self._validation_errors)
- raise ValueError(f"Query validation failed: {error_summary}")
+ return MongoInsertBuilder()
- return self._execution_plan
+ @staticmethod
+ def create_delete_builder():
+ """Create a builder for DELETE queries"""
+ # Local import to avoid circular dependency during module import
+ from .delete_builder import MongoDeleteBuilder
- def reset(self) -> "MongoQueryBuilder":
- """Reset the builder to start a new query"""
- self._execution_plan = ExecutionPlan()
- self._validation_errors.clear()
- return self
+ return MongoDeleteBuilder()
- def __str__(self) -> str:
- """String representation for debugging"""
- return (
- f"MongoQueryBuilder(collection={self._execution_plan.collection}, "
- f"filter={self._execution_plan.filter_stage}, "
- f"projection={self._execution_plan.projection_stage})"
- )
+ @staticmethod
+ def create_update_builder():
+ """Create a builder for UPDATE queries"""
+ # Local import to avoid circular dependency during module import
+ from .update_builder import MongoUpdateBuilder
+ return MongoUpdateBuilder()
-class BuilderFactory:
- """Factory for creating query builders"""
- @staticmethod
- def create_query_builder() -> MongoQueryBuilder:
- """Create a builder for SELECT queries"""
- return MongoQueryBuilder()
+__all__ = [
+ "ExecutionPlan",
+ "BuilderFactory",
+]
diff --git a/pymongosql/sql/delete_builder.py b/pymongosql/sql/delete_builder.py
new file mode 100644
index 0000000..3c1dd13
--- /dev/null
+++ b/pymongosql/sql/delete_builder.py
@@ -0,0 +1,68 @@
+# -*- coding: utf-8 -*-
+import logging
+from dataclasses import dataclass, field
+from typing import Any, Dict
+
+from .builder import ExecutionPlan
+
+_logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DeleteExecutionPlan(ExecutionPlan):
+ """Execution plan for DELETE operations against MongoDB."""
+
+ filter_conditions: Dict[str, Any] = field(default_factory=dict)
+ parameter_style: str = field(default="qmark") # Parameter placeholder style: qmark (?) or named (:name)
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert delete plan to dictionary representation."""
+ return {
+ "collection": self.collection,
+ "filter": self.filter_conditions,
+ }
+
+ def validate(self) -> bool:
+ """Validate the delete plan."""
+ errors = self.validate_base()
+
+ # Note: filter_conditions can be empty for DELETE FROM (delete all)
+ # which is valid, so we don't enforce filter presence
+
+ if errors:
+ _logger.error(f"Delete plan validation errors: {errors}")
+ return False
+
+ return True
+
+ def copy(self) -> "DeleteExecutionPlan":
+ """Create a copy of this delete plan."""
+ return DeleteExecutionPlan(
+ collection=self.collection,
+ filter_conditions=self.filter_conditions.copy() if self.filter_conditions else {},
+ )
+
+
+class MongoDeleteBuilder:
+ """Builder for constructing DeleteExecutionPlan objects."""
+
+ def __init__(self) -> None:
+ """Initialize the delete builder."""
+ self._plan = DeleteExecutionPlan()
+
+ def collection(self, collection: str) -> "MongoDeleteBuilder":
+ """Set the collection name."""
+ self._plan.collection = collection
+ return self
+
+ def filter_conditions(self, conditions: Dict[str, Any]) -> "MongoDeleteBuilder":
+ """Set the filter conditions for the delete operation."""
+ if conditions:
+ self._plan.filter_conditions = conditions
+ return self
+
+ def build(self) -> DeleteExecutionPlan:
+ """Build and return the DeleteExecutionPlan."""
+ if not self._plan.validate():
+ raise ValueError("Invalid delete plan")
+ return self._plan
diff --git a/pymongosql/sql/delete_handler.py b/pymongosql/sql/delete_handler.py
new file mode 100644
index 0000000..c59643b
--- /dev/null
+++ b/pymongosql/sql/delete_handler.py
@@ -0,0 +1,148 @@
+# -*- coding: utf-8 -*-
+import logging
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
+
+from .handler import BaseHandler
+from .partiql.PartiQLParser import PartiQLParser
+
+_logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DeleteParseResult:
+ """Result of parsing a DELETE statement.
+
+ Stores the extracted information needed to build a DeleteExecutionPlan.
+ """
+
+ collection: Optional[str] = None
+ filter_conditions: Dict[str, Any] = field(default_factory=dict)
+ has_errors: bool = False
+ error_message: Optional[str] = None
+
+ @staticmethod
+ def for_visitor() -> "DeleteParseResult":
+ """Factory method to create a fresh DeleteParseResult for visitor pattern."""
+ return DeleteParseResult()
+
+ def validate(self) -> bool:
+ """Validate that required fields are populated."""
+ if not self.collection:
+ self.error_message = "Collection name is required"
+ self.has_errors = True
+ return False
+ return True
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert to dictionary representation for debugging."""
+ return {
+ "collection": self.collection,
+ "filter_conditions": self.filter_conditions,
+ "has_errors": self.has_errors,
+ "error_message": self.error_message,
+ }
+
+ def __repr__(self) -> str:
+ """String representation."""
+ return (
+ f"DeleteParseResult(collection={self.collection}, "
+ f"filter_conditions={self.filter_conditions}, "
+ f"has_errors={self.has_errors})"
+ )
+
+
+class DeleteHandler(BaseHandler):
+ """Handler for DELETE statement visitor parsing."""
+
+ def can_handle(self, ctx: Any) -> bool:
+ """Check if this handler can process the given context."""
+ return hasattr(ctx, "DELETE") or isinstance(ctx, PartiQLParser.DeleteCommandContext)
+
+ def handle_visitor(self, ctx: Any, parse_result: DeleteParseResult) -> DeleteParseResult:
+ """Handle DELETE statement parsing - entry point from visitDeleteCommand."""
+ try:
+ _logger.debug("DeleteHandler processing DELETE statement")
+ # Reset parse result for new statement
+ parse_result.collection = None
+ parse_result.filter_conditions = {}
+ parse_result.has_errors = False
+ parse_result.error_message = None
+ return parse_result
+ except Exception as exc:
+ _logger.error("Failed to handle DELETE", exc_info=True)
+ parse_result.has_errors = True
+ parse_result.error_message = str(exc)
+ return parse_result
+
+ def handle_from_clause_explicit(
+ self, ctx: PartiQLParser.FromClauseSimpleExplicitContext, parse_result: DeleteParseResult
+ ) -> Optional[str]:
+ """Extract collection name from FROM clause (explicit form)."""
+ _logger.debug("DeleteHandler processing FROM clause (simple, explicit)")
+ try:
+ if ctx.pathSimple():
+ collection_name = ctx.pathSimple().getText()
+ parse_result.collection = collection_name
+ _logger.debug(f"Extracted collection for DELETE (explicit): {collection_name}")
+ return collection_name
+ except Exception as e:
+ _logger.warning(f"Error processing FROM clause (explicit): {e}")
+ parse_result.has_errors = True
+ parse_result.error_message = str(e)
+ return None
+
+ def handle_from_clause_implicit(
+ self, ctx: PartiQLParser.FromClauseSimpleImplicitContext, parse_result: DeleteParseResult
+ ) -> Optional[str]:
+ """Extract collection name from FROM clause (implicit form)."""
+ _logger.debug("DeleteHandler processing FROM clause (simple, implicit)")
+ try:
+ if ctx.pathSimple():
+ collection_name = ctx.pathSimple().getText()
+ parse_result.collection = collection_name
+ _logger.debug(f"Extracted collection for DELETE (implicit): {collection_name}")
+ return collection_name
+ except Exception as e:
+ _logger.warning(f"Error processing FROM clause (implicit): {e}")
+ parse_result.has_errors = True
+ parse_result.error_message = str(e)
+ return None
+
+ def handle_where_clause(
+ self, ctx: PartiQLParser.WhereClauseContext, parse_result: DeleteParseResult
+ ) -> Dict[str, Any]:
+ """Handle WHERE clause for DELETE statements."""
+ _logger.debug("DeleteHandler processing WHERE clause")
+ try:
+ # Get the expression context - it could be ctx.arg or ctx.expr()
+ expression_ctx = None
+ if hasattr(ctx, "arg") and ctx.arg:
+ expression_ctx = ctx.arg
+ elif hasattr(ctx, "expr"):
+ expression_ctx = ctx.expr()
+
+ if expression_ctx:
+ # Debug: log the raw context text
+ raw_text = expression_ctx.getText() if hasattr(expression_ctx, "getText") else str(expression_ctx)
+ _logger.debug(f"[WHERE_CLAUSE_DEBUG] Raw expression text: {raw_text}")
+ _logger.debug(f"[WHERE_CLAUSE_DEBUG] Expression context type: {type(expression_ctx).__name__}")
+
+ from .handler import HandlerFactory
+
+ handler = HandlerFactory.get_expression_handler(expression_ctx)
+
+ if handler:
+ result = handler.handle_expression(expression_ctx)
+ if not result.has_errors:
+ parse_result.filter_conditions = result.filter_conditions
+ _logger.debug(f"Extracted filter conditions for DELETE: {result.filter_conditions}")
+ return result.filter_conditions
+ # If no handler or error, leave filter_conditions empty (delete all)
+ _logger.debug("Extracted filter conditions for DELETE: {}")
+ return {}
+ except Exception as e:
+ _logger.warning(f"Error processing WHERE clause: {e}")
+ parse_result.has_errors = True
+ parse_result.error_message = str(e)
+ return {}
diff --git a/pymongosql/sql/handler.py b/pymongosql/sql/handler.py
index 086dba8..f4f9622 100644
--- a/pymongosql/sql/handler.py
+++ b/pymongosql/sql/handler.py
@@ -5,8 +5,6 @@
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
-from .partiql.PartiQLParser import PartiQLParser
-
_logger = logging.getLogger(__name__)
@@ -28,8 +26,8 @@
@dataclass
-class ParseResult:
- """Unified result container for both expression parsing and visitor state management"""
+class QueryParseResult:
+ """Result container for query (SELECT) expression parsing and visitor state management"""
# Core parsing fields
filter_conditions: Dict[str, Any] = field(default_factory=dict) # Unified filter field for all MongoDB conditions
@@ -50,12 +48,12 @@ class ParseResult:
# Factory methods for different use cases
@classmethod
- def for_visitor(cls) -> "ParseResult":
- """Create ParseResult for visitor parsing"""
+ def for_visitor(cls) -> "QueryParseResult":
+ """Create QueryParseResult for visitor parsing"""
return cls()
- def merge_expression(self, other: "ParseResult") -> "ParseResult":
- """Merge expression results from another ParseResult"""
+ def merge_expression(self, other: "QueryParseResult") -> "QueryParseResult":
+ """Merge expression results from another QueryParseResult"""
if other.has_errors:
self.has_errors = True
self.error_message = other.error_message
@@ -221,7 +219,7 @@ def can_handle(self, ctx: Any) -> bool:
"""Check if this handler can process the given context"""
pass
- def handle(self, ctx: Any, parse_result: Optional["ParseResult"] = None) -> Any:
+ def handle(self, ctx: Any, parse_result: Optional["QueryParseResult"] = None) -> Any:
"""Handle the context and return appropriate result"""
# Default implementation for expression handlers
if parse_result is None:
@@ -229,11 +227,11 @@ def handle(self, ctx: Any, parse_result: Optional["ParseResult"] = None) -> Any:
else:
return self.handle_visitor(ctx, parse_result)
- def handle_expression(self, ctx: Any) -> ParseResult:
+ def handle_expression(self, ctx: Any) -> QueryParseResult:
"""Handle expression parsing (to be overridden by expression handlers)"""
raise NotImplementedError("Expression handlers must implement handle_expression")
- def handle_visitor(self, ctx: Any, parse_result: "ParseResult") -> Any:
+ def handle_visitor(self, ctx: Any, parse_result: "QueryParseResult") -> Any:
"""Handle visitor operations (to be overridden by visitor handlers)"""
raise NotImplementedError("Visitor handlers must implement handle_visitor")
@@ -262,7 +260,7 @@ def can_handle(self, ctx: Any) -> bool:
hasattr(ctx, "comparisonOperator") or self._is_comparison_context(ctx) or self._has_comparison_pattern(ctx)
)
- def handle_expression(self, ctx: Any) -> ParseResult:
+ def handle_expression(self, ctx: Any) -> QueryParseResult:
"""Convert comparison expression to MongoDB filter"""
operation_id = id(ctx)
self._log_operation_start("comparison_parsing", ctx, operation_id)
@@ -281,11 +279,11 @@ def handle_expression(self, ctx: Any) -> ParseResult:
operator=operator,
)
- return ParseResult(filter_conditions=mongo_filter)
+ return QueryParseResult(filter_conditions=mongo_filter)
except Exception as e:
self._log_operation_error("comparison_parsing", ctx, operation_id, e)
- return ParseResult(has_errors=True, error_message=str(e))
+ return QueryParseResult(has_errors=True, error_message=str(e))
def _build_mongo_filter(self, field_name: str, operator: str, value: Any) -> Dict[str, Any]:
"""Build MongoDB filter from field, operator and value"""
@@ -544,9 +542,13 @@ def _has_logical_operators(self, ctx: Any) -> bool:
"""Check if the expression text contains logical operators"""
try:
text = self.get_context_text(ctx).upper()
- comparison_count = sum(1 for op in COMPARISON_OPERATORS if op in text)
+
+ # Count comparison operator occurrences, not just distinct operator types
+ # so that "a = 1 OR b = 2" counts as 2 comparisons and is treated
+ # as a logical expression instead of a single comparison.
+ comparison_count = len(re.findall(r"(>=|<=|!=|<>|=|<|>)", text))
has_logical_ops = any(op in text for op in ["AND", "OR"])
- return has_logical_ops and comparison_count > 1
+ return has_logical_ops and comparison_count >= 2
except Exception:
return False
@@ -560,7 +562,7 @@ def _is_logical_context(self, ctx: Any) -> bool:
except Exception:
return False
- def handle_expression(self, ctx: Any) -> ParseResult:
+ def handle_expression(self, ctx: Any) -> QueryParseResult:
"""Convert logical expression to MongoDB filter"""
operation_id = id(ctx)
self._log_operation_start("logical_parsing", ctx, operation_id)
@@ -585,11 +587,11 @@ def handle_expression(self, ctx: Any) -> ParseResult:
processed_count=len(processed_operands),
)
- return ParseResult(filter_conditions=mongo_filter)
+ return QueryParseResult(filter_conditions=mongo_filter)
except Exception as e:
self._log_operation_error("logical_parsing", ctx, operation_id, e)
- return ParseResult(has_errors=True, error_message=str(e))
+ return QueryParseResult(has_errors=True, error_message=str(e))
def _process_operands(self, operands: List[Any]) -> List[Dict[str, Any]]:
"""Process operands and return processed filters"""
@@ -743,7 +745,7 @@ def can_handle(self, ctx: Any) -> bool:
"""Check if context represents a function call"""
return hasattr(ctx, "functionName") or self._is_function_context(ctx)
- def handle_expression(self, ctx: Any) -> ParseResult:
+ def handle_expression(self, ctx: Any) -> QueryParseResult:
"""Handle function expressions"""
operation_id = id(ctx)
self._log_operation_start("function_parsing", ctx, operation_id)
@@ -761,11 +763,11 @@ def handle_expression(self, ctx: Any) -> ParseResult:
function_name=function_name,
)
- return ParseResult(filter_conditions=mongo_filter)
+ return QueryParseResult(filter_conditions=mongo_filter)
except Exception as e:
self._log_operation_error("function_parsing", ctx, operation_id, e)
- return ParseResult(has_errors=True, error_message=str(e))
+ return QueryParseResult(has_errors=True, error_message=str(e))
def _is_function_context(self, ctx: Any) -> bool:
"""Check if context is a function call"""
@@ -804,10 +806,18 @@ def _initialize_expression_handlers(cls):
def _initialize_visitor_handlers(cls):
"""Lazy initialization of visitor handlers"""
if cls._visitor_handlers is None:
+ from .delete_handler import DeleteHandler
+ from .insert_handler import InsertHandler
+ from .query_handler import FromHandler, SelectHandler, WhereHandler
+ from .update_handler import UpdateHandler
+
cls._visitor_handlers = {
"select": SelectHandler(),
"from": FromHandler(),
"where": WhereHandler(),
+ "insert": InsertHandler(),
+ "delete": DeleteHandler(),
+ "update": UpdateHandler(),
}
return cls._visitor_handlers
@@ -843,137 +853,3 @@ def register_visitor_handler(cls, handler_type: str, handler: BaseHandler) -> No
def get_handler(cls, ctx: Any) -> Optional[BaseHandler]:
"""Backward compatibility method"""
return cls.get_expression_handler(ctx)
-
-
-class EnhancedWhereHandler(ContextUtilsMixin):
- """Enhanced WHERE clause handler using expression handlers"""
-
- def handle(self, ctx: PartiQLParser.WhereClauseSelectContext) -> Dict[str, Any]:
- """Handle WHERE clause with proper expression parsing"""
- if not hasattr(ctx, "exprSelect") or not ctx.exprSelect():
- _logger.debug("No expression found in WHERE clause")
- return {}
-
- expression_ctx = ctx.exprSelect()
- handler = HandlerFactory.get_expression_handler(expression_ctx)
-
- if handler:
- _logger.debug(
- f"Using {type(handler).__name__} for WHERE clause",
- extra={"context_text": self.get_context_text(expression_ctx)[:100]},
- )
- result = handler.handle_expression(expression_ctx)
- if result.has_errors:
- _logger.warning(
- "Expression parsing error, falling back to text search",
- extra={"error": result.error_message},
- )
- # Fallback to text-based filter
- return {"$text": {"$search": self.get_context_text(expression_ctx)}}
- return result.filter_conditions
- else:
- # Fallback to simple text-based search
- _logger.debug(
- "No suitable expression handler found, using text search",
- extra={"context_text": self.get_context_text(expression_ctx)[:100]},
- )
- return {"$text": {"$search": self.get_context_text(expression_ctx)}}
-
-
-# Visitor Handler Classes for AST Processing
-
-
-class SelectHandler(BaseHandler, ContextUtilsMixin):
- """Handles SELECT statement parsing"""
-
- def can_handle(self, ctx: Any) -> bool:
- """Check if this is a select context"""
- return hasattr(ctx, "projectionItems")
-
- def handle_visitor(self, ctx: PartiQLParser.SelectItemsContext, parse_result: "ParseResult") -> Any:
- projection = {}
- column_aliases = {}
-
- if hasattr(ctx, "projectionItems") and ctx.projectionItems():
- for item in ctx.projectionItems().projectionItem():
- field_name, alias = self._extract_field_and_alias(item)
- # Use MongoDB standard projection format: {field: 1} to include field
- projection[field_name] = 1
- # Store alias if present
- if alias:
- column_aliases[field_name] = alias
-
- parse_result.projection = projection
- parse_result.column_aliases = column_aliases
- return projection
-
- def _extract_field_and_alias(self, item) -> Tuple[str, Optional[str]]:
- """Extract field name and alias from projection item context with nested field support"""
- if not hasattr(item, "children") or not item.children:
- return str(item), None
-
- # According to grammar: projectionItem : expr ( AS? symbolPrimitive )? ;
- # children[0] is always the expression
- # If there's an alias, children[1] might be AS and children[2] symbolPrimitive
- # OR children[1] might be just symbolPrimitive (without AS)
-
- field_name = item.children[0].getText()
- # Normalize bracket notation (jmspath) to Mongo dot notation
- field_name = self.normalize_field_path(field_name)
-
- alias = None
-
- if len(item.children) >= 2:
- # Check if we have an alias
- if len(item.children) == 3:
- # Pattern: expr AS symbolPrimitive
- if hasattr(item.children[1], "getText") and item.children[1].getText().upper() == "AS":
- alias = item.children[2].getText()
- elif len(item.children) == 2:
- # Pattern: expr symbolPrimitive (without AS)
- alias = item.children[1].getText()
-
- return field_name, alias
-
-
-class FromHandler(BaseHandler):
- """Handles FROM clause parsing"""
-
- def can_handle(self, ctx: Any) -> bool:
- """Check if this is a from context"""
- return hasattr(ctx, "tableReference")
-
- def handle_visitor(self, ctx: PartiQLParser.FromClauseContext, parse_result: "ParseResult") -> Any:
- if hasattr(ctx, "tableReference") and ctx.tableReference():
- table_text = ctx.tableReference().getText()
- collection_name = table_text
- parse_result.collection = collection_name
- return collection_name
- return None
-
-
-class WhereHandler(BaseHandler):
- """Handles WHERE clause parsing"""
-
- def __init__(self):
- self._expression_handler = EnhancedWhereHandler()
-
- def can_handle(self, ctx: Any) -> bool:
- """Check if this is a where context"""
- return hasattr(ctx, "exprSelect")
-
- def handle_visitor(self, ctx: PartiQLParser.WhereClauseSelectContext, parse_result: "ParseResult") -> Any:
- if hasattr(ctx, "exprSelect") and ctx.exprSelect():
- try:
- # Use enhanced expression handler for better parsing
- filter_conditions = self._expression_handler.handle(ctx)
- parse_result.filter_conditions = filter_conditions
- return filter_conditions
- except Exception as e:
- _logger.warning(f"Failed to parse WHERE expression, falling back to text search: {e}")
- # Fallback to simple text search
- filter_text = ctx.exprSelect().getText()
- fallback_filter = {"$text": {"$search": filter_text}}
- parse_result.filter_conditions = fallback_filter
- return fallback_filter
- return {}
diff --git a/pymongosql/sql/insert_builder.py b/pymongosql/sql/insert_builder.py
new file mode 100644
index 0000000..fc3afd7
--- /dev/null
+++ b/pymongosql/sql/insert_builder.py
@@ -0,0 +1,141 @@
+# -*- coding: utf-8 -*-
+import logging
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional
+
+from .builder import ExecutionPlan
+
+_logger = logging.getLogger(__name__)
+
+
+@dataclass
+class InsertExecutionPlan(ExecutionPlan):
+ """Execution plan for INSERT operations against MongoDB."""
+
+ insert_documents: List[Dict[str, Any]] = field(default_factory=list)
+ parameter_style: Optional[str] = None # "qmark" or "named"
+ parameter_count: int = 0
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert insert plan to dictionary representation."""
+ return {
+ "collection": self.collection,
+ "documents": self.insert_documents,
+ "parameter_count": self.parameter_count,
+ }
+
+ def validate(self) -> bool:
+ """Validate the insert plan."""
+ errors = self.validate_base()
+
+ if not self.insert_documents:
+ errors.append("At least one document must be provided for insertion")
+
+ if errors:
+ _logger.error(f"Insert plan validation errors: {errors}")
+ return False
+
+ return True
+
+ def copy(self) -> "InsertExecutionPlan":
+ """Create a copy of this insert plan."""
+ return InsertExecutionPlan(
+ collection=self.collection,
+ insert_documents=[doc.copy() for doc in self.insert_documents],
+ parameter_style=self.parameter_style,
+ parameter_count=self.parameter_count,
+ )
+
+
+class MongoInsertBuilder:
+ """Fluent builder for INSERT execution plans."""
+
+ def __init__(self):
+ self._execution_plan = InsertExecutionPlan()
+ self._validation_errors: List[str] = []
+
+ def collection(self, name: str) -> "MongoInsertBuilder":
+ """Set the target collection."""
+ if not name or not name.strip():
+ self._add_error("Collection name cannot be empty")
+ return self
+
+ self._execution_plan.collection = name.strip()
+ _logger.debug(f"Set collection to: {name}")
+ return self
+
+ def insert_documents(self, documents: List[Dict[str, Any]]) -> "MongoInsertBuilder":
+ """Set documents to insert (normalized from any syntax)."""
+ if not isinstance(documents, list):
+ self._add_error("Documents must be a list")
+ return self
+
+ if not documents:
+ self._add_error("At least one document must be provided")
+ return self
+
+ self._execution_plan.insert_documents = documents
+ _logger.debug(f"Set insert documents: {len(documents)} document(s)")
+ return self
+
+ def parameter_style(self, style: Optional[str]) -> "MongoInsertBuilder":
+ """Set parameter binding style for tracking."""
+ if style and style not in ["qmark", "named"]:
+ self._add_error(f"Invalid parameter style: {style}")
+ return self
+
+ self._execution_plan.parameter_style = style
+ _logger.debug(f"Set parameter style to: {style}")
+ return self
+
+ def parameter_count(self, count: int) -> "MongoInsertBuilder":
+ """Set number of parameter placeholders to be bound."""
+ if not isinstance(count, int) or count < 0:
+ self._add_error("Parameter count must be a non-negative integer")
+ return self
+
+ self._execution_plan.parameter_count = count
+ _logger.debug(f"Set parameter count to: {count}")
+ return self
+
+ def _add_error(self, message: str) -> None:
+ """Add validation error."""
+ self._validation_errors.append(message)
+ _logger.error(f"Insert builder error: {message}")
+
+ def validate(self) -> bool:
+ """Validate the insert plan."""
+ self._validation_errors.clear()
+
+ if not self._execution_plan.collection:
+ self._add_error("Collection name is required")
+
+ if not self._execution_plan.insert_documents:
+ self._add_error("At least one document must be provided")
+
+ return len(self._validation_errors) == 0
+
+ def get_errors(self) -> List[str]:
+ """Get validation errors."""
+ return self._validation_errors.copy()
+
+ def build(self) -> InsertExecutionPlan:
+ """Build and return the insert execution plan."""
+ if not self.validate():
+ error_summary = "; ".join(self._validation_errors)
+ raise ValueError(f"Insert plan validation failed: {error_summary}")
+
+ return self._execution_plan
+
+ def reset(self) -> "MongoInsertBuilder":
+ """Reset the builder to start a new insert plan."""
+ self._execution_plan = InsertExecutionPlan()
+ self._validation_errors.clear()
+ return self
+
+ def __str__(self) -> str:
+ """String representation for debugging."""
+ return (
+ f"MongoInsertBuilder(collection={self._execution_plan.collection}, "
+ f"documents={len(self._execution_plan.insert_documents)})"
+ )
diff --git a/pymongosql/sql/insert_handler.py b/pymongosql/sql/insert_handler.py
new file mode 100644
index 0000000..09b6ae0
--- /dev/null
+++ b/pymongosql/sql/insert_handler.py
@@ -0,0 +1,146 @@
+# -*- coding: utf-8 -*-
+import ast
+import logging
+import re
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple
+
+from .handler import BaseHandler
+
+_logger = logging.getLogger(__name__)
+
+
+@dataclass
+class InsertParseResult:
+ """Result container for INSERT statement visitor parsing."""
+
+ collection: Optional[str] = None
+ insert_columns: Optional[List[str]] = None
+ insert_values: Optional[List[List[Any]]] = None
+ insert_documents: Optional[List[Dict[str, Any]]] = None
+ insert_type: Optional[str] = None # "value" | "bag"
+ parameter_style: Optional[str] = None # "qmark" or "named"
+ parameter_count: int = 0
+ has_errors: bool = False
+ error_message: Optional[str] = None
+
+ @classmethod
+ def for_visitor(cls) -> "InsertParseResult":
+ """Factory for a fresh insert parse result."""
+ return cls()
+
+
+class InsertHandler(BaseHandler):
+ """Visitor handler to convert INSERT parse trees into InsertParseResult."""
+
+ def can_handle(self, ctx: Any) -> bool:
+ return hasattr(ctx, "INSERT")
+
+ def handle_visitor(self, ctx: Any, parse_result: InsertParseResult) -> InsertParseResult:
+ try:
+ collection = self._extract_collection(ctx)
+ value_text = self._extract_value_text(ctx)
+
+ documents = self._parse_value_expr(value_text)
+ param_style, param_count = self._detect_parameter_style(documents)
+
+ parse_result.collection = collection
+ parse_result.insert_documents = documents
+ parse_result.insert_type = "bag" if value_text.strip().startswith("<<") else "value"
+ parse_result.parameter_style = param_style
+ parse_result.parameter_count = param_count
+ parse_result.has_errors = False
+ parse_result.error_message = None
+ return parse_result
+ except Exception as exc: # pragma: no cover - defensive logging
+ _logger.error("Failed to handle INSERT", exc_info=True)
+ parse_result.has_errors = True
+ parse_result.error_message = str(exc)
+ return parse_result
+
+ def _extract_collection(self, ctx: Any) -> str:
+ if hasattr(ctx, "symbolPrimitive") and ctx.symbolPrimitive():
+ return ctx.symbolPrimitive().getText()
+ if hasattr(ctx, "pathSimple") and ctx.pathSimple(): # legacy form
+ return ctx.pathSimple().getText()
+ raise ValueError("INSERT statement missing collection name")
+
+ def _extract_value_text(self, ctx: Any) -> str:
+ if hasattr(ctx, "value") and ctx.value:
+ return ctx.value.getText()
+ if hasattr(ctx, "value") and callable(ctx.value): # legacy form pathSimple VALUE expr
+ value_ctx = ctx.value()
+ if value_ctx:
+ return value_ctx.getText()
+ raise ValueError("INSERT statement missing value expression")
+
+ def _parse_value_expr(self, text: str) -> List[Dict[str, Any]]:
+ cleaned = text.strip()
+ cleaned = self._normalize_literals(cleaned)
+
+ if cleaned.startswith("<<") and cleaned.endswith(">>"):
+ literal_text = cleaned.replace("<<", "[").replace(">>", "]")
+ return self._parse_literal_list(literal_text)
+
+ if cleaned.startswith("{") and cleaned.endswith("}"):
+ doc = self._parse_literal_dict(cleaned)
+ return [doc]
+
+ raise ValueError("Unsupported INSERT value expression")
+
+ def _parse_literal_list(self, literal_text: str) -> List[Dict[str, Any]]:
+ try:
+ value = ast.literal_eval(literal_text)
+ except Exception as exc:
+ raise ValueError(f"Failed to parse INSERT bag literal: {exc}") from exc
+ if not isinstance(value, list) or not all(isinstance(item, dict) for item in value):
+ raise ValueError("INSERT bag must contain objects")
+ return value
+
+ def _parse_literal_dict(self, literal_text: str) -> Dict[str, Any]:
+ try:
+ value = ast.literal_eval(literal_text)
+ except Exception as exc:
+ raise ValueError(f"Failed to parse INSERT object literal: {exc}") from exc
+ if not isinstance(value, dict):
+ raise ValueError("INSERT value expression must be an object")
+ return value
+
+ def _normalize_literals(self, text: str) -> str:
+ # Replace PartiQL null/true/false with Python literals for literal_eval; NOTE(review): regex also rewrites matches inside quoted string values — confirm acceptable
+ replacements = {
+ r"\bnull\b": "None",
+ r"\bNULL\b": "None",
+ r"\btrue\b": "True",
+ r"\bTRUE\b": "True",
+ r"\bfalse\b": "False",
+ r"\bFALSE\b": "False",
+ }
+ normalized = text
+ for pattern, replacement in replacements.items():
+ normalized = re.sub(pattern, replacement, normalized)
+ return normalized
+
+ def _detect_parameter_style(self, documents: List[Dict[str, Any]]) -> Tuple[Optional[str], int]:
+ style = None
+ count = 0
+
+ def consider(value: Any):
+ nonlocal style, count
+ if value == "?":
+ new_style = "qmark"
+ elif isinstance(value, str) and value.startswith(":"):
+ new_style = "named"
+ else:
+ return
+
+ if style and style != new_style:
+ raise ValueError("Mixed parameter styles are not supported")
+ style = new_style
+ count += 1
+
+ for doc in documents:
+ for val in doc.values():
+ consider(val)
+
+ return style, count
diff --git a/pymongosql/sql/parser.py b/pymongosql/sql/parser.py
index c62556c..5dfe491 100644
--- a/pymongosql/sql/parser.py
+++ b/pymongosql/sql/parser.py
@@ -1,14 +1,17 @@
# -*- coding: utf-8 -*-
import logging
from abc import ABCMeta
-from typing import Any, Optional
+from typing import Any, Optional, Union
from antlr4 import CommonTokenStream, InputStream
from antlr4.error.ErrorListener import ErrorListener
from ..error import SqlSyntaxError
from .ast import MongoSQLLexer, MongoSQLParser, MongoSQLParserVisitor
-from .builder import ExecutionPlan
+from .delete_builder import DeleteExecutionPlan
+from .insert_builder import InsertExecutionPlan
+from .query_builder import QueryExecutionPlan
+from .update_builder import UpdateExecutionPlan
_logger = logging.getLogger(__name__)
@@ -126,27 +129,27 @@ def _validate_ast(self) -> None:
_logger.debug("AST validation successful")
- def get_execution_plan(self) -> ExecutionPlan:
- """Parse SQL and return ExecutionPlan directly"""
+ def get_execution_plan(
+ self,
+ ) -> Union[QueryExecutionPlan, InsertExecutionPlan, DeleteExecutionPlan, UpdateExecutionPlan]:
+ """Parse SQL and return an execution plan (SELECT, INSERT, DELETE, or UPDATE)."""
if self._ast is None:
raise SqlSyntaxError("No AST available - parsing may have failed")
try:
- # Create and use visitor to generate ExecutionPlan
self._visitor = MongoSQLParserVisitor()
self._visitor.visit(self._ast)
execution_plan = self._visitor.parse_to_execution_plan()
- # Validate execution plan
if not execution_plan.validate():
raise SqlSyntaxError("Generated execution plan is invalid")
- _logger.debug(f"Generated ExecutionPlan for collection: {execution_plan.collection}")
+ _logger.debug(f"Generated execution plan for collection: {execution_plan.collection}")
return execution_plan
except Exception as e:
- _logger.error(f"Failed to generate ExecutionPlan from AST: {e}")
- raise SqlSyntaxError(f"ExecutionPlan generation failed: {e}") from e
+ _logger.error(f"Failed to generate execution plan from AST: {e}")
+ raise SqlSyntaxError(f"Execution plan generation failed: {e}") from e
def get_parse_info(self) -> dict:
"""Get detailed parsing information for debugging"""
diff --git a/pymongosql/sql/query_builder.py b/pymongosql/sql/query_builder.py
new file mode 100644
index 0000000..1ff7f10
--- /dev/null
+++ b/pymongosql/sql/query_builder.py
@@ -0,0 +1,250 @@
+# -*- coding: utf-8 -*-
+import logging
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Union
+
+from .builder import ExecutionPlan
+
+_logger = logging.getLogger(__name__)
+
+
+@dataclass
+class QueryExecutionPlan(ExecutionPlan):
+ """Execution plan for MongoDB SELECT queries (query-only)."""
+
+ filter_stage: Dict[str, Any] = field(default_factory=dict)
+ projection_stage: Dict[str, Any] = field(default_factory=dict)
+ column_aliases: Dict[str, str] = field(default_factory=dict) # Maps field_name -> alias
+ sort_stage: List[Dict[str, int]] = field(default_factory=list)
+ limit_stage: Optional[int] = None
+ skip_stage: Optional[int] = None
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert query plan to dictionary representation"""
+ return {
+ "collection": self.collection,
+ "filter": self.filter_stage,
+ "projection": self.projection_stage,
+ "sort": self.sort_stage,
+ "limit": self.limit_stage,
+ "skip": self.skip_stage,
+ }
+
+ def validate(self) -> bool:
+ """Validate the query plan"""
+ errors = self.validate_base()
+
+ if self.limit_stage is not None and (not isinstance(self.limit_stage, int) or self.limit_stage < 0):
+ errors.append("Limit must be a non-negative integer")
+
+ if self.skip_stage is not None and (not isinstance(self.skip_stage, int) or self.skip_stage < 0):
+ errors.append("Skip must be a non-negative integer")
+
+ if errors:
+ _logger.error(f"Query validation errors: {errors}")
+ return False
+
+ return True
+
+ def copy(self) -> "QueryExecutionPlan":
+ """Create a copy of this execution plan"""
+ return QueryExecutionPlan(
+ collection=self.collection,
+ filter_stage=self.filter_stage.copy(),
+ projection_stage=self.projection_stage.copy(),
+ column_aliases=self.column_aliases.copy(),
+ sort_stage=self.sort_stage.copy(),
+ limit_stage=self.limit_stage,
+ skip_stage=self.skip_stage,
+ )
+
+
+class MongoQueryBuilder:
+ """Fluent builder for MongoDB queries with validation and readability"""
+
+ def __init__(self):
+ self._execution_plan = QueryExecutionPlan()
+ self._validation_errors = []
+
+ def collection(self, name: str) -> "MongoQueryBuilder":
+ """Set the target collection"""
+ if not name or not name.strip():
+ self._add_error("Collection name cannot be empty")
+ return self
+
+ self._execution_plan.collection = name.strip()
+ _logger.debug(f"Set collection to: {name}")
+ return self
+
+ def filter(self, conditions: Dict[str, Any]) -> "MongoQueryBuilder":
+ """Add filter conditions"""
+ if not isinstance(conditions, dict):
+ self._add_error("Filter conditions must be a dictionary")
+ return self
+
+ self._execution_plan.filter_stage.update(conditions)
+ _logger.debug(f"Added filter conditions: {conditions}")
+ return self
+
+ def project(self, fields: Union[Dict[str, int], List[str]]) -> "MongoQueryBuilder":
+ """Set projection fields"""
+ if isinstance(fields, list):
+ # Convert list to projection dict
+ projection = {field: 1 for field in fields}
+ elif isinstance(fields, dict):
+ projection = fields
+ else:
+ self._add_error("Projection must be a list of field names or a dictionary")
+ return self
+
+ self._execution_plan.projection_stage = projection
+ _logger.debug(f"Set projection: {projection}")
+ return self
+
+ def sort(self, specs: List[Dict[str, int]]) -> "MongoQueryBuilder":
+ """Add sort criteria.
+
+ Only accepts a list of single-key dicts in the form:
+ [{"field": 1}, {"other": -1}]
+
+ This matches the output produced by the SQL parser (`sort_fields`).
+ """
+ if not isinstance(specs, list):
+ self._add_error("Sort specifications must be a list of single-key dicts")
+ return self
+
+ for spec in specs:
+ if not isinstance(spec, dict) or len(spec) != 1:
+ self._add_error("Each sort specification must be a single-key dict, e.g. {'name': 1}")
+ continue
+
+ field, direction = next(iter(spec.items()))
+
+ if not isinstance(field, str) or not field:
+ self._add_error("Sort field must be a non-empty string")
+ continue
+
+ if direction not in [-1, 1]:
+ self._add_error(f"Sort direction for field '{field}' must be 1 or -1")
+ continue
+
+ self._execution_plan.sort_stage.append({field: direction})
+ _logger.debug(f"Added sort: {field} -> {direction}")
+
+ return self
+
+ def limit(self, count: int) -> "MongoQueryBuilder":
+ """Set limit for results"""
+ if not isinstance(count, int) or count < 0:
+ self._add_error("Limit must be a non-negative integer")
+ return self
+
+ self._execution_plan.limit_stage = count
+ _logger.debug(f"Set limit to: {count}")
+ return self
+
+ def skip(self, count: int) -> "MongoQueryBuilder":
+ """Set skip count for pagination"""
+ if not isinstance(count, int) or count < 0:
+ self._add_error("Skip must be a non-negative integer")
+ return self
+
+ self._execution_plan.skip_stage = count
+ _logger.debug(f"Set skip to: {count}")
+ return self
+
+ def column_aliases(self, aliases: Dict[str, str]) -> "MongoQueryBuilder":
+ """Set column aliases mapping (field_name -> alias)"""
+ if not isinstance(aliases, dict):
+ self._add_error("Column aliases must be a dictionary")
+ return self
+
+ self._execution_plan.column_aliases = aliases
+ _logger.debug(f"Set column aliases to: {aliases}")
+ return self
+
+ def where(self, field: str, operator: str, value: Any) -> "MongoQueryBuilder":
+ """Add a where condition in a readable format"""
+ condition = self._build_condition(field, operator, value)
+ if condition:
+ return self.filter(condition)
+ return self
+
+ def where_in(self, field: str, values: List[Any]) -> "MongoQueryBuilder":
+ """Add a WHERE field IN (values) condition"""
+ return self.filter({field: {"$in": values}})
+
+ def where_between(self, field: str, min_val: Any, max_val: Any) -> "MongoQueryBuilder":
+ """Add a WHERE field BETWEEN min AND max condition"""
+ return self.filter({field: {"$gte": min_val, "$lte": max_val}})
+
+ def where_like(self, field: str, pattern: str) -> "MongoQueryBuilder":
+ """Add a WHERE field LIKE pattern condition"""
+ # Convert SQL LIKE pattern to MongoDB regex
+ regex_pattern = pattern.replace("%", ".*").replace("_", ".")
+ return self.filter({field: {"$regex": regex_pattern, "$options": "i"}})
+
+ def _build_condition(self, field: str, operator: str, value: Any) -> Optional[Dict[str, Any]]:
+ """Build a MongoDB condition from field, operator, and value"""
+ operator_map = {
+ "=": "$eq",
+ "!=": "$ne",
+ "<": "$lt",
+ "<=": "$lte",
+ ">": "$gt",
+ ">=": "$gte",
+ "eq": "$eq",
+ "ne": "$ne",
+ "lt": "$lt",
+ "lte": "$lte",
+ "gt": "$gt",
+ "gte": "$gte",
+ }
+
+ mongo_op = operator_map.get(operator.lower())
+ if not mongo_op:
+ self._add_error(f"Unsupported operator: {operator}")
+ return None
+
+ return {field: {mongo_op: value}}
+
+ def _add_error(self, message: str) -> None:
+ """Add validation error"""
+ self._validation_errors.append(message)
+ _logger.error(f"Query builder error: {message}")
+
+ def validate(self) -> bool:
+ """Validate the current query plan"""
+ self._validation_errors.clear()
+
+ if not self._execution_plan.collection:
+ self._add_error("Collection name is required")
+
+ # Add more validation rules as needed
+ return len(self._validation_errors) == 0
+
+ def get_errors(self) -> List[str]:
+ """Get validation errors"""
+ return self._validation_errors.copy()
+
+ def build(self) -> QueryExecutionPlan:
+ """Build and return the execution plan"""
+ if not self.validate():
+ error_summary = "; ".join(self._validation_errors)
+ raise ValueError(f"Query validation failed: {error_summary}")
+
+ return self._execution_plan
+
+ def reset(self) -> "MongoQueryBuilder":
+ """Reset the builder to start a new query"""
+ self._execution_plan = QueryExecutionPlan()
+ self._validation_errors.clear()
+ return self
+
+ def __str__(self) -> str:
+ """String representation for debugging"""
+ return (
+ f"MongoQueryBuilder(collection={self._execution_plan.collection}, "
+ f"filter={self._execution_plan.filter_stage}, "
+ f"projection={self._execution_plan.projection_stage})"
+ )
diff --git a/pymongosql/sql/query_handler.py b/pymongosql/sql/query_handler.py
new file mode 100644
index 0000000..49c6e63
--- /dev/null
+++ b/pymongosql/sql/query_handler.py
@@ -0,0 +1,198 @@
+# -*- coding: utf-8 -*-
+import logging
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Tuple
+
+from .handler import BaseHandler, ContextUtilsMixin
+from .partiql.PartiQLParser import PartiQLParser
+
+_logger = logging.getLogger(__name__)
+
+
+@dataclass
+class QueryParseResult:
+ """Result container for query (SELECT) expression parsing and visitor state management"""
+
+ # Core parsing fields
+ filter_conditions: Dict[str, Any] = field(default_factory=dict) # Unified filter field for all MongoDB conditions
+ has_errors: bool = False
+ error_message: Optional[str] = None
+
+ # Visitor parsing state fields
+ collection: Optional[str] = None
+ projection: Dict[str, Any] = field(default_factory=dict)
+ column_aliases: Dict[str, str] = field(default_factory=dict) # Maps field_name -> alias
+ sort_fields: List[Dict[str, int]] = field(default_factory=list)
+ limit_value: Optional[int] = None
+ offset_value: Optional[int] = None
+
+ # Subquery info (for wrapped subqueries, e.g., Superset-style outer query wrapping)
+ subquery_plan: Optional[Any] = None
+ subquery_alias: Optional[str] = None
+
+ # Factory methods for different use cases
+ @classmethod
+ def for_visitor(cls) -> "QueryParseResult":
+ """Create QueryParseResult for visitor parsing"""
+ return cls()
+
+ def merge_expression(self, other: "QueryParseResult") -> "QueryParseResult":
+ """Merge expression results from another QueryParseResult"""
+ if other.has_errors:
+ self.has_errors = True
+ self.error_message = other.error_message
+
+ # Merge filter conditions intelligently
+ if other.filter_conditions:
+ if not self.filter_conditions:
+ self.filter_conditions = other.filter_conditions
+ else:
+ # If both have filters, combine them with $and
+ self.filter_conditions = {"$and": [self.filter_conditions, other.filter_conditions]}
+
+ return self
+
+ # Backward compatibility properties
+ @property
+ def mongo_filter(self) -> Dict[str, Any]:
+ """Backward compatibility property for mongo_filter"""
+ return self.filter_conditions
+
+ @mongo_filter.setter
+ def mongo_filter(self, value: Dict[str, Any]):
+ """Backward compatibility setter for mongo_filter"""
+ self.filter_conditions = value
+
+
+class EnhancedWhereHandler(ContextUtilsMixin):
+ """Enhanced WHERE clause handler using expression handlers"""
+
+ def handle(self, ctx: PartiQLParser.WhereClauseSelectContext) -> Dict[str, Any]:
+ """Handle WHERE clause with proper expression parsing"""
+ if not hasattr(ctx, "exprSelect") or not ctx.exprSelect():
+ _logger.debug("No expression found in WHERE clause")
+ return {}
+
+ expression_ctx = ctx.exprSelect()
+ # Local import to avoid circular dependency between query_handler and handler
+ from .handler import HandlerFactory
+
+ handler = HandlerFactory.get_expression_handler(expression_ctx)
+
+ if handler:
+ _logger.debug(
+ f"Using {type(handler).__name__} for WHERE clause",
+ extra={"context_text": self.get_context_text(expression_ctx)[:100]},
+ )
+ result = handler.handle_expression(expression_ctx)
+ if result.has_errors:
+ _logger.warning(
+ "Expression parsing error, falling back to text search",
+ extra={"error": result.error_message},
+ )
+ # Fallback to text-based filter
+ return {"$text": {"$search": self.get_context_text(expression_ctx)}}
+ return result.filter_conditions
+ else:
+ # Fallback to simple text-based search
+ _logger.debug(
+ "No suitable expression handler found, using text search",
+ extra={"context_text": self.get_context_text(expression_ctx)[:100]},
+ )
+ return {"$text": {"$search": self.get_context_text(expression_ctx)}}
+
+
+class SelectHandler(BaseHandler, ContextUtilsMixin):
+ """Handles SELECT statement parsing"""
+
+ def can_handle(self, ctx: Any) -> bool:
+ """Check if this is a select context"""
+ return hasattr(ctx, "projectionItems")
+
+ def handle_visitor(self, ctx: PartiQLParser.SelectItemsContext, parse_result: "QueryParseResult") -> Any:
+ projection = {}
+ column_aliases = {}
+
+ if hasattr(ctx, "projectionItems") and ctx.projectionItems():
+ for item in ctx.projectionItems().projectionItem():
+ field_name, alias = self._extract_field_and_alias(item)
+ # Use MongoDB standard projection format: {field: 1} to include field
+ projection[field_name] = 1
+ # Store alias if present
+ if alias:
+ column_aliases[field_name] = alias
+
+ parse_result.projection = projection
+ parse_result.column_aliases = column_aliases
+ return projection
+
+ def _extract_field_and_alias(self, item) -> Tuple[str, Optional[str]]:
+ """Extract field name and alias from projection item context with nested field support"""
+ if not hasattr(item, "children") or not item.children:
+ return str(item), None
+
+ # According to grammar: projectionItem : expr ( AS? symbolPrimitive )? ;
+ # children[0] is always the expression
+ # If there's an alias, children[1] might be AS and children[2] symbolPrimitive
+ # OR children[1] might be just symbolPrimitive (without AS)
+
+ field_name = item.children[0].getText()
+ # Normalize bracket notation (jmspath) to Mongo dot notation
+ field_name = self.normalize_field_path(field_name)
+
+ alias = None
+
+ if len(item.children) >= 2:
+ # Check if we have an alias
+ if len(item.children) == 3:
+ # Pattern: expr AS symbolPrimitive
+ if hasattr(item.children[1], "getText") and item.children[1].getText().upper() == "AS":
+ alias = item.children[2].getText()
+ elif len(item.children) == 2:
+ # Pattern: expr symbolPrimitive (without AS)
+ alias = item.children[1].getText()
+
+ return field_name, alias
+
+
+class FromHandler(BaseHandler):
+ """Handles FROM clause parsing"""
+
+ def can_handle(self, ctx: Any) -> bool:
+ """Check if this is a from context"""
+ return hasattr(ctx, "tableReference")
+
+ def handle_visitor(self, ctx: PartiQLParser.FromClauseContext, parse_result: "QueryParseResult") -> Any:
+ if hasattr(ctx, "tableReference") and ctx.tableReference():
+ table_text = ctx.tableReference().getText()
+ collection_name = table_text
+ parse_result.collection = collection_name
+ return collection_name
+ return None
+
+
+class WhereHandler(BaseHandler):
+ """Handles WHERE clause parsing"""
+
+ def __init__(self):
+ self._expression_handler = EnhancedWhereHandler()
+
+ def can_handle(self, ctx: Any) -> bool:
+ """Check if this is a where context"""
+ return hasattr(ctx, "exprSelect")
+
+ def handle_visitor(self, ctx: PartiQLParser.WhereClauseSelectContext, parse_result: "QueryParseResult") -> Any:
+ if hasattr(ctx, "exprSelect") and ctx.exprSelect():
+ try:
+ # Use enhanced expression handler for better parsing
+ filter_conditions = self._expression_handler.handle(ctx)
+ parse_result.filter_conditions = filter_conditions
+ return filter_conditions
+ except Exception as e:
+ _logger.warning(f"Failed to parse WHERE expression, falling back to text search: {e}")
+ # Fallback to simple text search
+ filter_text = ctx.exprSelect().getText()
+ fallback_filter = {"$text": {"$search": filter_text}}
+ parse_result.filter_conditions = fallback_filter
+ return fallback_filter
+ return {}
diff --git a/pymongosql/sql/update_builder.py b/pymongosql/sql/update_builder.py
new file mode 100644
index 0000000..17fc0f2
--- /dev/null
+++ b/pymongosql/sql/update_builder.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+import logging
+from dataclasses import dataclass, field
+from typing import Any, Dict
+
+from .builder import ExecutionPlan
+
+_logger = logging.getLogger(__name__)
+
+
+@dataclass
+class UpdateExecutionPlan(ExecutionPlan):
+ """Execution plan for UPDATE operations against MongoDB."""
+
+ update_fields: Dict[str, Any] = field(default_factory=dict) # Fields to update
+ filter_conditions: Dict[str, Any] = field(default_factory=dict) # Filter for documents to update
+ parameter_style: str = field(default="qmark") # Parameter placeholder style: qmark (?) or named (:name)
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert update plan to dictionary representation."""
+ return {
+ "collection": self.collection,
+ "filter": self.filter_conditions,
+ "update": {"$set": self.update_fields},
+ }
+
+ def validate(self) -> bool:
+ """Validate the update plan."""
+ errors = self.validate_base()
+
+ if not self.update_fields:
+ errors.append("Update fields are required")
+
+ # Note: filter_conditions can be empty for UPDATE SET ... (update all)
+ # which is valid, so we don't enforce filter presence
+
+ if errors:
+ _logger.error(f"Update plan validation errors: {errors}")
+ return False
+
+ return True
+
+ def copy(self) -> "UpdateExecutionPlan":
+ """Create a copy of this update plan."""
+ return UpdateExecutionPlan(
+ collection=self.collection,
+ update_fields=self.update_fields.copy() if self.update_fields else {},
+ filter_conditions=self.filter_conditions.copy() if self.filter_conditions else {},
+ )
+
+ def get_mongo_update_doc(self) -> Dict[str, Any]:
+ """Get MongoDB update document using $set operator."""
+ return {"$set": self.update_fields}
+
+
+class MongoUpdateBuilder:
+ """Builder for constructing UpdateExecutionPlan objects."""
+
+ def __init__(self) -> None:
+ """Initialize the update builder."""
+ self._plan = UpdateExecutionPlan()
+
+ def collection(self, collection: str) -> "MongoUpdateBuilder":
+ """Set the collection name."""
+ self._plan.collection = collection
+ return self
+
+ def update_fields(self, fields: Dict[str, Any]) -> "MongoUpdateBuilder":
+ """Set the fields to update."""
+ if fields:
+ self._plan.update_fields = fields
+ return self
+
+ def filter_conditions(self, conditions: Dict[str, Any]) -> "MongoUpdateBuilder":
+ """Set the filter conditions for the update operation."""
+ if conditions:
+ self._plan.filter_conditions = conditions
+ return self
+
+ def parameter_style(self, style: str) -> "MongoUpdateBuilder":
+ """Set the parameter placeholder style."""
+ self._plan.parameter_style = style
+ return self
+
+ def build(self) -> UpdateExecutionPlan:
+ """Build and return the UpdateExecutionPlan."""
+ if not self._plan.validate():
+ raise ValueError("Invalid update plan")
+ return self._plan
diff --git a/pymongosql/sql/update_handler.py b/pymongosql/sql/update_handler.py
new file mode 100644
index 0000000..6b08f57
--- /dev/null
+++ b/pymongosql/sql/update_handler.py
@@ -0,0 +1,210 @@
+# -*- coding: utf-8 -*-
+import logging
+from dataclasses import dataclass, field
+from typing import Any, Dict, Optional
+
+from .handler import BaseHandler
+from .partiql.PartiQLParser import PartiQLParser
+
+_logger = logging.getLogger(__name__)
+
+
+@dataclass
+class UpdateParseResult:
+ """Result of parsing an UPDATE statement.
+
+ Stores the extracted information needed to build an UpdateExecutionPlan.
+ """
+
+ collection: Optional[str] = None
+ update_fields: Dict[str, Any] = field(default_factory=dict) # Field -> new value mapping
+ filter_conditions: Dict[str, Any] = field(default_factory=dict)
+ has_errors: bool = False
+ error_message: Optional[str] = None
+
+ @staticmethod
+ def for_visitor() -> "UpdateParseResult":
+ """Factory method to create a fresh UpdateParseResult for visitor pattern."""
+ return UpdateParseResult()
+
+ def validate(self) -> bool:
+ """Validate that required fields are populated."""
+ if not self.collection:
+ self.error_message = "Collection name is required"
+ self.has_errors = True
+ return False
+ if not self.update_fields:
+ self.error_message = "At least one field to update is required"
+ self.has_errors = True
+ return False
+ return True
+
+ def to_dict(self) -> Dict[str, Any]:
+ """Convert to dictionary representation for debugging."""
+ return {
+ "collection": self.collection,
+ "update_fields": self.update_fields,
+ "filter_conditions": self.filter_conditions,
+ "has_errors": self.has_errors,
+ "error_message": self.error_message,
+ }
+
+ def __repr__(self) -> str:
+ """String representation."""
+ return (
+ f"UpdateParseResult(collection={self.collection}, "
+ f"update_fields={self.update_fields}, "
+ f"filter_conditions={self.filter_conditions}, "
+ f"has_errors={self.has_errors})"
+ )
+
+
+class UpdateHandler(BaseHandler):
+ """Handler for UPDATE statement visitor parsing."""
+
+ def can_handle(self, ctx: Any) -> bool:
+ """Check if this handler can process the given context."""
+ return hasattr(ctx, "UPDATE") or isinstance(ctx, PartiQLParser.UpdateClauseContext)
+
+ def handle_visitor(self, ctx: Any, parse_result: UpdateParseResult) -> UpdateParseResult:
+ """Handle UPDATE clause during visitor traversal."""
+ _logger.debug("UpdateHandler processing UPDATE clause")
+ try:
+ # Extract collection name from UPDATE clause
+ # updateClause: UPDATE tableBaseReference
+ if hasattr(ctx, "tableBaseReference") and ctx.tableBaseReference():
+ collection_name = self._extract_collection_from_table_ref(ctx.tableBaseReference())
+ parse_result.collection = collection_name
+ _logger.debug(f"Extracted collection for UPDATE: {collection_name}")
+ except Exception as e:
+ _logger.warning(f"Error processing UPDATE clause: {e}")
+ parse_result.has_errors = True
+ parse_result.error_message = str(e)
+
+ return parse_result
+
+ def _extract_collection_from_table_ref(self, ctx: Any) -> Optional[str]:
+ """Extract collection name from tableBaseReference context."""
+ try:
+ # tableBaseReference can have multiple forms:
+ # - source=exprSelect symbolPrimitive
+ # - source=exprSelect asIdent? atIdent? byIdent?
+ # - source=exprGraphMatchOne asIdent? atIdent? byIdent?
+
+ # For simple UPDATE statements, we expect exprSelect to be a simple identifier
+ if hasattr(ctx, "source") and ctx.source:
+ source_text = ctx.source.getText()
+ _logger.debug(f"Extracted collection from tableBaseReference: {source_text}")
+ return source_text
+
+ # Fallback: try to get text directly
+ return ctx.getText()
+ except Exception as e:
+ _logger.warning(f"Error extracting collection from tableBaseReference: {e}")
+ return None
+
+ def handle_set_command(self, ctx: Any, parse_result: UpdateParseResult) -> UpdateParseResult:
+ """Handle SET command during visitor traversal.
+
+ setCommand: SET setAssignment ( COMMA setAssignment )*
+ setAssignment: pathSimple EQ expr
+ """
+ _logger.debug("UpdateHandler processing SET command")
+ try:
+ if hasattr(ctx, "setAssignment") and ctx.setAssignment():
+ for assignment_ctx in ctx.setAssignment():
+ field_name, field_value = self._extract_set_assignment(assignment_ctx)
+ if field_name:
+ parse_result.update_fields[field_name] = field_value
+ _logger.debug(f"Extracted SET assignment: {field_name} = {field_value}")
+ except Exception as e:
+ _logger.warning(f"Error processing SET command: {e}")
+ parse_result.has_errors = True
+ parse_result.error_message = str(e)
+
+ return parse_result
+
+ def _extract_set_assignment(self, ctx: Any) -> tuple[Optional[str], Any]:
+ """Extract field name and value from setAssignment.
+
+ setAssignment: pathSimple EQ expr
+ """
+ try:
+ field_name = None
+ field_value = None
+
+ # Extract field name from pathSimple
+ if hasattr(ctx, "pathSimple") and ctx.pathSimple():
+ field_name = ctx.pathSimple().getText()
+
+ # Extract value from expr
+ if hasattr(ctx, "expr") and ctx.expr():
+ expr_text = ctx.expr().getText()
+ # Parse the expression to get the actual value
+ field_value = self._parse_value(expr_text)
+
+ return field_name, field_value
+ except Exception as e:
+ _logger.warning(f"Error extracting set assignment: {e}")
+ return None, None
+
+ def _parse_value(self, text: str) -> Any:
+ """Parse expression text to extract the actual value."""
+ # Remove surrounding quotes if present
+ text = text.strip()
+
+ if text.startswith("'") and text.endswith("'"):
+ return text[1:-1]
+ elif text.startswith('"') and text.endswith('"'):
+ return text[1:-1]
+ elif text.lower() == "null":
+ return None
+ elif text.lower() == "true":
+ return True
+ elif text.lower() == "false":
+ return False
+ elif text.startswith("?") or text.startswith(":"):
+ # Parameter placeholder
+ return text
+ else:
+ # Try to parse as number
+ try:
+ if "." in text:
+ return float(text)
+ else:
+ return int(text)
+ except ValueError:
+ # Return as string if not a number
+ return text
+
+ def handle_where_clause(self, ctx: Any, parse_result: UpdateParseResult) -> Dict[str, Any]:
+ """Handle WHERE clause for UPDATE statements."""
+ _logger.debug("UpdateHandler processing WHERE clause")
+ try:
+ # Get the expression context
+ expression_ctx = None
+ if hasattr(ctx, "arg") and ctx.arg:
+ expression_ctx = ctx.arg
+ elif hasattr(ctx, "expr"):
+ expression_ctx = ctx.expr()
+
+ if expression_ctx:
+ from .handler import HandlerFactory
+
+ handler = HandlerFactory.get_expression_handler(expression_ctx)
+
+ if handler:
+ result = handler.handle_expression(expression_ctx)
+ if not result.has_errors:
+ parse_result.filter_conditions = result.filter_conditions
+ _logger.debug(f"Extracted filter conditions for UPDATE: {result.filter_conditions}")
+ return result.filter_conditions
+
+ # No WHERE clause means update all documents
+ _logger.debug("No WHERE clause for UPDATE")
+ return {}
+ except Exception as e:
+ _logger.warning(f"Error processing WHERE clause: {e}")
+ parse_result.has_errors = True
+ parse_result.error_message = str(e)
+ return {}
diff --git a/pymongosql/superset_mongodb/executor.py b/pymongosql/superset_mongodb/executor.py
index 91ec8ab..9fb6c9b 100644
--- a/pymongosql/superset_mongodb/executor.py
+++ b/pymongosql/superset_mongodb/executor.py
@@ -2,16 +2,16 @@
import logging
from typing import Any, Dict, List, Optional
-from ..executor import ExecutionContext, StandardExecution
+from ..executor import ExecutionContext, StandardQueryExecution
from ..result_set import ResultSet
-from ..sql.builder import ExecutionPlan
+from ..sql.query_builder import QueryExecutionPlan
from .detector import SubqueryDetector
from .query_db_sqlite import QueryDBSQLite
_logger = logging.getLogger(__name__)
-class SupersetExecution(StandardExecution):
+class SupersetExecution(StandardQueryExecution):
"""Two-stage execution strategy for subquery-based queries using intermediate RDBMS.
Uses a QueryDatabase backend (SQLite3 by default) to handle complex
@@ -30,15 +30,16 @@ def __init__(self, query_db_factory: Optional[Any] = None) -> None:
Defaults to SQLiteBridge if not provided.
"""
self._query_db_factory = query_db_factory or QueryDBSQLite
- self._execution_plan: Optional[ExecutionPlan] = None
+ self._execution_plan: Optional[QueryExecutionPlan] = None
@property
- def execution_plan(self) -> ExecutionPlan:
+ def execution_plan(self) -> QueryExecutionPlan:
return self._execution_plan
def supports(self, context: ExecutionContext) -> bool:
- """Support queries with subqueries"""
- return context.execution_mode == "superset"
+ """Support queries with subqueries; only SELECT statements are supported in this mode."""
+ normalized = context.query.lstrip().upper()
+ return "superset" in context.execution_mode.lower() and normalized.startswith("SELECT")
def execute(
self,
@@ -129,7 +130,7 @@ def execute(
except Exception as e:
_logger.warning(f"Could not extract column names from empty result: {e}")
- self._execution_plan = ExecutionPlan(collection="query_db_result", projection_stage=projection_stage)
+ self._execution_plan = QueryExecutionPlan(collection="query_db_result", projection_stage=projection_stage)
return result_set
diff --git a/tests/test_cursor_delete.py b/tests/test_cursor_delete.py
new file mode 100644
index 0000000..d955218
--- /dev/null
+++ b/tests/test_cursor_delete.py
@@ -0,0 +1,182 @@
+# -*- coding: utf-8 -*-
+import pytest
+
+
+class TestCursorDelete:
+ """Test suite for DELETE operations using a dedicated test collection."""
+
+ TEST_COLLECTION = "Music"
+
+ @pytest.fixture(autouse=True)
+ def setup_teardown(self, conn):
+ """Setup: insert test data. Teardown: drop test collection."""
+ db = conn.database
+ if self.TEST_COLLECTION in db.list_collection_names():
+ db.drop_collection(self.TEST_COLLECTION)
+
+ # Insert test data for delete operations
+ db[self.TEST_COLLECTION].insert_many(
+ [
+ {"title": "Song A", "artist": "Alice", "year": 2021, "genre": "Pop"},
+ {"title": "Song B", "artist": "Bob", "year": 2020, "genre": "Rock"},
+ {"title": "Song C", "artist": "Charlie", "year": 2021, "genre": "Jazz"},
+ {"title": "Song D", "artist": "Diana", "year": 2019, "genre": "Pop"},
+ {"title": "Song E", "artist": "Eve", "year": 2022, "genre": "Electronic"},
+ ]
+ )
+
+ yield
+
+ # Teardown: drop the test collection after each test
+ if self.TEST_COLLECTION in db.list_collection_names():
+ db.drop_collection(self.TEST_COLLECTION)
+
+ def test_delete_all_documents(self, conn):
+ """Test deleting all documents from collection."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"DELETE FROM {self.TEST_COLLECTION}")
+
+ assert result == cursor # execute returns self
+
+ # Verify all documents were deleted
+ db = conn.database
+ remaining = list(db[self.TEST_COLLECTION].find())
+ assert len(remaining) == 0
+
+ def test_delete_with_where_equality(self, conn):
+ """Test DELETE with WHERE clause filtering by equality."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"DELETE FROM {self.TEST_COLLECTION} WHERE artist = 'Bob'")
+
+ assert result == cursor # execute returns self
+
+ # Verify only Bob's song was deleted
+ db = conn.database
+ remaining = list(db[self.TEST_COLLECTION].find())
+ assert len(remaining) == 4
+
+ artist_names = {doc["artist"] for doc in remaining}
+ assert "Bob" not in artist_names
+ assert "Alice" in artist_names
+
+ def test_delete_with_where_numeric_filter(self, conn):
+ """Test DELETE with WHERE clause filtering by numeric field."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"DELETE FROM {self.TEST_COLLECTION} WHERE year > 2020")
+
+ assert result == cursor
+
+ # Verify songs from 2021 and 2022 were deleted
+ db = conn.database
+ remaining = list(db[self.TEST_COLLECTION].find())
+ assert len(remaining) == 2 # Only 2019 and 2020 remain
+
+ def test_delete_with_and_condition(self, conn):
+ """Test DELETE with WHERE clause using AND condition."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"DELETE FROM {self.TEST_COLLECTION} WHERE genre = 'Pop' AND year = 2021")
+
+ assert result == cursor
+
+ # Only Song A (Pop, 2021) should be deleted
+ db = conn.database
+ remaining = list(db[self.TEST_COLLECTION].find())
+ assert len(remaining) == 4
+
+ titles = {doc["title"] for doc in remaining}
+ assert "Song A" not in titles
+
+ def test_delete_with_qmark_parameters(self, conn):
+ """Test DELETE with qmark (?) placeholder parameters."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"DELETE FROM {self.TEST_COLLECTION} WHERE artist = '?'", ["Charlie"])
+
+ assert result == cursor
+
+ # Verify Charlie's song was deleted
+ db = conn.database
+ remaining = list(db[self.TEST_COLLECTION].find())
+ assert len(remaining) == 4
+
+ artists = {doc["artist"] for doc in remaining}
+ assert "Charlie" not in artists
+
+ def test_delete_with_multiple_parameters(self, conn):
+ """Test DELETE with multiple qmark parameters."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"DELETE FROM {self.TEST_COLLECTION} WHERE genre = '?' AND year = '?'", ["Pop", 2019])
+
+ assert result == cursor
+
+ # Only Song D (Pop, 2019) should be deleted
+ db = conn.database
+ remaining = list(db[self.TEST_COLLECTION].find())
+ assert len(remaining) == 4
+
+ titles = {doc["title"] for doc in remaining}
+ assert "Song D" not in titles
+
+ def test_delete_no_match_returns_success(self, conn):
+ """Test that DELETE with no matching records still succeeds."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"DELETE FROM {self.TEST_COLLECTION} WHERE artist = 'Nonexistent'")
+
+ assert result == cursor
+
+ # Verify no documents were deleted
+ db = conn.database
+ remaining = list(db[self.TEST_COLLECTION].find())
+ assert len(remaining) == 5
+
+ def test_delete_invalid_sql_raises_error(self, conn):
+ """Test that invalid DELETE SQL raises SqlSyntaxError."""
+ _ = conn.cursor()
+
+ # Note: The parser is quite forgiving. This test is skipped for now
+ # as the PartiQL grammar may accept various forms of DELETE syntax.
+ # A truly invalid statement would be one with syntax errors at the
+ # lexer/parser level, like unmatched parentheses.
+ pass
+
+ def test_delete_missing_collection_raises_error(self, conn):
+ """Test that DELETE on non-existent collection is handled."""
+ cursor = conn.cursor()
+
+ # DELETE on non-existent collection should succeed but delete nothing
+ result = cursor.execute("DELETE FROM NonexistentCollection WHERE title = 'Test'")
+ assert result == cursor
+
+ def test_delete_then_select_verify_persistence(self, conn):
+ """Test DELETE followed by SELECT to verify deletion was persisted."""
+ # Delete documents by year
+ delete_cursor = conn.cursor()
+ delete_cursor.execute(f"DELETE FROM {self.TEST_COLLECTION} WHERE year < 2021")
+
+ # Select remaining documents
+ select_cursor = conn.cursor()
+ select_cursor.execute(f"SELECT title, year FROM {self.TEST_COLLECTION} ORDER BY year")
+
+ rows = select_cursor.fetchall()
+
+ # Should have Song A (2021), Song C (2021), and Song E (2022)
+ assert len(rows) == 3
+ years = [row[1] for row in rows]
+ assert all(year >= 2021 for year in years)
+
+ def test_delete_followed_by_insert(self, conn):
+ """Test DELETE followed by INSERT to verify both operations work."""
+ # Delete all
+ delete_cursor = conn.cursor()
+ delete_cursor.execute(f"DELETE FROM {self.TEST_COLLECTION}")
+
+ db = conn.database
+ assert len(list(db[self.TEST_COLLECTION].find())) == 0
+
+ # Insert new document
+ insert_cursor = conn.cursor()
+ insert_cursor.execute(f"INSERT INTO {self.TEST_COLLECTION} {{'title': 'New Song', 'artist': 'Frank'}}")
+
+ # Verify insertion
+ assert len(list(db[self.TEST_COLLECTION].find())) == 1
+ doc = list(db[self.TEST_COLLECTION].find())[0]
+ assert doc["title"] == "New Song"
diff --git a/tests/test_cursor_insert.py b/tests/test_cursor_insert.py
new file mode 100644
index 0000000..13bc2eb
--- /dev/null
+++ b/tests/test_cursor_insert.py
@@ -0,0 +1,198 @@
+# -*- coding: utf-8 -*-
+"""Test suite for INSERT statement execution via Cursor."""
+
+import pytest
+
+from pymongosql.error import ProgrammingError, SqlSyntaxError
+from pymongosql.result_set import ResultSet
+
+
+class TestCursorInsert:
+ """Test suite for INSERT operations using a dedicated test collection."""
+
+ TEST_COLLECTION = "musicians"
+
+ @pytest.fixture(autouse=True)
+ def setup_teardown(self, conn):
+ """Setup: drop test collection before each test. Teardown: drop after each test."""
+ db = conn.database
+ if self.TEST_COLLECTION in db.list_collection_names():
+ db.drop_collection(self.TEST_COLLECTION)
+ yield
+ # Teardown: drop the test collection after each test
+ if self.TEST_COLLECTION in db.list_collection_names():
+ db.drop_collection(self.TEST_COLLECTION)
+
+ def test_insert_single_document(self, conn):
+ """Test inserting a single document into the collection."""
+ sql = f"INSERT INTO {self.TEST_COLLECTION} {{'name': 'Alice', 'age': 30, 'city': 'New York'}}"
+ cursor = conn.cursor()
+ result = cursor.execute(sql)
+
+ assert result == cursor # execute returns self
+
+ # Verify the document was inserted
+ db = conn.database
+ docs = list(db[self.TEST_COLLECTION].find())
+ assert len(docs) == 1
+ assert docs[0]["name"] == "Alice"
+ assert docs[0]["age"] == 30
+ assert docs[0]["city"] == "New York"
+
+ def test_insert_multiple_documents_via_bag(self, conn):
+ """Test inserting multiple documents using bag syntax."""
+ sql = (
+ f"INSERT INTO {self.TEST_COLLECTION} << "
+ "{'name': 'Bob', 'age': 25, 'city': 'Boston'}, "
+ "{'name': 'Charlie', 'age': 35, 'city': 'Chicago'} >>"
+ )
+ cursor = conn.cursor()
+ result = cursor.execute(sql)
+
+ assert result == cursor # execute returns self
+
+ # Verify both documents were inserted
+ db = conn.database
+ docs = list(db[self.TEST_COLLECTION].find({}))
+ assert len(docs) == 2
+
+ names = {doc["name"] for doc in docs}
+ assert "Bob" in names
+ assert "Charlie" in names
+
+ def test_insert_with_null_values(self, conn):
+ """Test inserting document with null values."""
+ sql = f"INSERT INTO {self.TEST_COLLECTION} {{'name': 'Diana', 'age': null, 'city': 'Denver'}}"
+ cursor = conn.cursor()
+ result = cursor.execute(sql)
+
+ assert result == cursor # execute returns self
+
+ # Verify document with null was inserted
+ db = conn.database
+ docs = list(db[self.TEST_COLLECTION].find())
+ assert len(docs) == 1
+ assert docs[0]["name"] == "Diana"
+ assert docs[0]["age"] is None
+ assert docs[0]["city"] == "Denver"
+
+ def test_insert_with_boolean_and_mixed_types(self, conn):
+ """Test inserting document with booleans and various data types."""
+ sql = f"INSERT INTO {self.TEST_COLLECTION} {{'name': 'Eve', 'active': true, 'score': 95.5, 'level': 5}}"
+ cursor = conn.cursor()
+ result = cursor.execute(sql)
+
+ assert result == cursor # execute returns self
+
+ # Verify document with mixed types was inserted
+ db = conn.database
+ docs = list(db[self.TEST_COLLECTION].find())
+ assert len(docs) == 1
+ assert docs[0]["name"] == "Eve"
+ assert docs[0]["active"] is True
+ assert docs[0]["score"] == 95.5
+ assert docs[0]["level"] == 5
+
+ def test_insert_with_qmark_parameters(self, conn):
+ """Test INSERT with qmark (?) placeholder parameters."""
+ sql = f"INSERT INTO {self.TEST_COLLECTION} {{'name': '?', 'age': '?', 'city': '?'}}"
+ cursor = conn.cursor()
+
+ # Execute with positional parameters
+ result = cursor.execute(sql, ["Frank", 28, "Fresno"])
+
+ assert result == cursor # execute returns self
+
+ # Verify document was inserted with parameter values
+ db = conn.database
+ docs = list(db[self.TEST_COLLECTION].find())
+ assert len(docs) == 1
+ assert docs[0]["name"] == "Frank"
+ assert docs[0]["age"] == 28
+ assert docs[0]["city"] == "Fresno"
+
+ def test_insert_with_named_parameters(self, conn):
+ """Test INSERT with positional parameters (qmark style); TODO: switch to named :param placeholders to match the test name."""
+ sql = f"INSERT INTO {self.TEST_COLLECTION} {{'name': '?', 'age': '?', 'city': '?'}}"
+ cursor = conn.cursor()
+
+ # Execute with positional parameters (qmark style)
+ result = cursor.execute(sql, ["Grace", 32, "Greensboro"])
+
+ assert result == cursor # execute returns self
+
+ # Verify document was inserted with parameter values
+ db = conn.database
+ docs = list(db[self.TEST_COLLECTION].find())
+ assert len(docs) == 1
+ assert docs[0]["name"] == "Grace"
+ assert docs[0]["age"] == 32
+ assert docs[0]["city"] == "Greensboro"
+
+ def test_insert_multiple_documents_with_parameters(self, conn):
+ """Test inserting multiple documents with qmark (?) parameters via bag syntax."""
+ sql = f"INSERT INTO {self.TEST_COLLECTION} << {{'name': '?', 'age': '?'}}, {{'name': '?', 'age': '?'}} >>"
+ cursor = conn.cursor()
+
+ # Execute with positional parameters for multiple documents
+ result = cursor.execute(sql, ["Henry", 40, "Iris", 29])
+
+ assert result == cursor # execute returns self
+
+ # Verify both documents were inserted with parameter values
+ db = conn.database
+ docs = list(db[self.TEST_COLLECTION].find({}))
+ assert len(docs) == 2
+
+ doc_by_name = {doc["name"]: doc for doc in docs}
+ assert "Henry" in doc_by_name
+ assert doc_by_name["Henry"]["age"] == 40
+ assert "Iris" in doc_by_name
+ assert doc_by_name["Iris"]["age"] == 29
+
+ def test_insert_insufficient_parameters_raises_error(self, conn):
+ """Test that insufficient parameters raises ProgrammingError."""
+ sql = f"INSERT INTO {self.TEST_COLLECTION} {{'name': '?', 'age': '?'}}"
+ cursor = conn.cursor()
+
+ # Execute with fewer parameters than placeholders
+ with pytest.raises(ProgrammingError):
+ cursor.execute(sql, ["Jack"]) # Missing second parameter
+
+ def test_insert_missing_named_parameter_raises_error(self, conn):
+ """Test that missing named parameter raises ProgrammingError."""
+ sql = f"INSERT INTO {self.TEST_COLLECTION} {{'name': ':name', 'age': ':age'}}"
+ cursor = conn.cursor()
+
+ # Execute with incomplete named parameters
+ with pytest.raises(ProgrammingError):
+ cursor.execute(sql, {"name": "Kate"}) # Missing :age parameter
+
+ def test_insert_invalid_sql_raises_error(self, conn):
+ """Test that invalid INSERT SQL raises SqlSyntaxError."""
+ sql = f"INSERT INTO {self.TEST_COLLECTION} invalid_syntax"
+ cursor = conn.cursor()
+
+ with pytest.raises(SqlSyntaxError):
+ cursor.execute(sql)
+
+ def test_insert_followed_by_select(self, conn):
+ """Test INSERT followed by SELECT to verify data was persisted."""
+ # Insert a document
+ insert_sql = f"INSERT INTO {self.TEST_COLLECTION} {{'name': 'Liam', 'score': 88}}"
+ cursor = conn.cursor()
+ cursor.execute(insert_sql)
+
+ # Select the document back
+ select_sql = f"SELECT name, score FROM {self.TEST_COLLECTION} WHERE score > 80"
+ result = cursor.execute(select_sql)
+
+ assert result == cursor # execute returns self
+ assert isinstance(cursor.result_set, ResultSet)
+ rows = cursor.result_set.fetchall()
+
+ assert len(rows) == 1
+ if cursor.result_set.description:
+ col_names = [desc[0] for desc in cursor.result_set.description]
+ assert "name" in col_names
+ assert "score" in col_names
diff --git a/tests/test_cursor_parameters.py b/tests/test_cursor_parameters.py
index 3234ee9..1a28c0c 100644
--- a/tests/test_cursor_parameters.py
+++ b/tests/test_cursor_parameters.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
import pytest
-from pymongosql.executor import StandardExecution
+from pymongosql.executor import StandardQueryExecution
class TestPositionalParameters:
@@ -9,7 +9,7 @@ class TestPositionalParameters:
def test_simple_positional_replacement(self):
"""Test basic positional parameter replacement in filter"""
- execution = StandardExecution()
+ execution = StandardQueryExecution()
test_filter = {"age": "?", "status": "?"}
params = [25, "active"]
@@ -19,7 +19,7 @@ def test_simple_positional_replacement(self):
def test_nested_positional_replacement(self):
"""Test positional parameter replacement in nested filter"""
- execution = StandardExecution()
+ execution = StandardQueryExecution()
test_filter = {"profile": {"age": "?"}, "status": "?"}
params = [30, "inactive"]
@@ -29,7 +29,7 @@ def test_nested_positional_replacement(self):
def test_list_positional_replacement(self):
"""Test positional parameter replacement in list"""
- execution = StandardExecution()
+ execution = StandardQueryExecution()
test_filter = {"items": ["?", "?"], "name": "?"}
params = [1, 2, "test"]
@@ -39,7 +39,7 @@ def test_list_positional_replacement(self):
def test_mixed_positional_replacement(self):
"""Test positional parameter replacement with mixed data types"""
- execution = StandardExecution()
+ execution = StandardQueryExecution()
test_filter = {"$gt": "?", "$lt": "?", "status": "?"}
params = [18, 65, "active"]
@@ -51,7 +51,7 @@ def test_insufficient_positional_parameters(self):
"""Test error when not enough positional parameters provided"""
from pymongosql.error import ProgrammingError
- execution = StandardExecution()
+ execution = StandardQueryExecution()
test_filter = {"age": "?", "status": "?"}
params = [25] # Only one parameter provided
@@ -63,7 +63,7 @@ def test_insufficient_positional_parameters(self):
def test_complex_nested_positional_replacement(self):
"""Test positional parameters in complex nested structures"""
- execution = StandardExecution()
+ execution = StandardQueryExecution()
test_filter = {"$and": [{"age": {"$gt": "?"}}, {"profile": {"status": "?"}}, {"items": ["?", "?"]}]}
params = [25, "active", 1, 2]
@@ -77,7 +77,7 @@ class TestParameterTypes:
def test_positional_with_numeric_types(self):
"""Test positional parameters with int and float"""
- execution = StandardExecution()
+ execution = StandardQueryExecution()
test_filter = {"age": "?", "salary": "?"}
params = [25, 50000.50]
@@ -87,7 +87,7 @@ def test_positional_with_numeric_types(self):
def test_positional_with_boolean(self):
"""Test positional parameters with boolean values"""
- execution = StandardExecution()
+ execution = StandardQueryExecution()
test_filter = {"active": "?", "verified": "?"}
params = [True, False]
@@ -97,7 +97,7 @@ def test_positional_with_boolean(self):
def test_positional_with_null(self):
"""Test positional parameters with None value"""
- execution = StandardExecution()
+ execution = StandardQueryExecution()
test_filter = {"deleted_at": "?"}
params = [None]
@@ -107,7 +107,7 @@ def test_positional_with_null(self):
def test_positional_with_list_value(self):
"""Test positional parameter with list as value"""
- execution = StandardExecution()
+ execution = StandardQueryExecution()
test_filter = {"tags": "?"}
params = [["python", "mongodb"]]
@@ -117,7 +117,7 @@ def test_positional_with_list_value(self):
def test_positional_with_dict_value(self):
"""Test positional parameter with dict as value"""
- execution = StandardExecution()
+ execution = StandardQueryExecution()
test_filter = {"metadata": "?"}
params = [{"key": "value"}]
@@ -131,14 +131,14 @@ class TestEdgeCases:
def test_empty_filter_with_parameters(self):
"""Test parameters with empty filter"""
- execution = StandardExecution()
+ execution = StandardQueryExecution()
result = execution._replace_placeholders({}, [])
assert result == {}
def test_non_placeholder_strings_untouched(self):
"""Test that non-placeholder strings are not modified"""
- execution = StandardExecution()
+ execution = StandardQueryExecution()
test_filter = {"status": "active", "query": "search"}
params = [25, "test"]
diff --git a/tests/test_cursor_update.py b/tests/test_cursor_update.py
new file mode 100644
index 0000000..b49e728
--- /dev/null
+++ b/tests/test_cursor_update.py
@@ -0,0 +1,212 @@
+# -*- coding: utf-8 -*-
+import pytest
+
+
+class TestCursorUpdate:
+ """Test suite for UPDATE operations using a dedicated test collection."""
+
+ TEST_COLLECTION = "Books"
+
+ @pytest.fixture(autouse=True)
+ def setup_teardown(self, conn):
+ """Setup: insert test data. Teardown: drop test collection."""
+ db = conn.database
+ if self.TEST_COLLECTION in db.list_collection_names():
+ db.drop_collection(self.TEST_COLLECTION)
+
+ # Insert test data for update operations
+ db[self.TEST_COLLECTION].insert_many(
+ [
+ {"title": "Book A", "author": "Alice", "year": 2020, "price": 29.99, "stock": 10, "available": True},
+ {"title": "Book B", "author": "Bob", "year": 2021, "price": 39.99, "stock": 5, "available": True},
+ {"title": "Book C", "author": "Charlie", "year": 2019, "price": 19.99, "stock": 0, "available": False},
+ {"title": "Book D", "author": "Diana", "year": 2022, "price": 49.99, "stock": 15, "available": True},
+ {"title": "Book E", "author": "Eve", "year": 2020, "price": 24.99, "stock": 8, "available": True},
+ ]
+ )
+
+ yield
+
+ # Teardown: drop the test collection after each test
+ if self.TEST_COLLECTION in db.list_collection_names():
+ db.drop_collection(self.TEST_COLLECTION)
+
+ def test_update_single_field_all_documents(self, conn):
+ """Test updating a single field in all documents."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"UPDATE {self.TEST_COLLECTION} SET available = false")
+
+ assert result == cursor # execute returns self
+
+ # Verify all documents were updated
+ db = conn.database
+ updated_docs = list(db[self.TEST_COLLECTION].find())
+ assert len(updated_docs) == 5
+ assert all(doc["available"] is False for doc in updated_docs)
+
+ def test_update_with_where_equality(self, conn):
+ """Test UPDATE with WHERE clause filtering by equality."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"UPDATE {self.TEST_COLLECTION} SET price = 34.99 WHERE author = 'Bob'")
+
+ assert result == cursor
+
+ # Verify only Bob's book was updated
+ db = conn.database
+ bob_book = db[self.TEST_COLLECTION].find_one({"author": "Bob"})
+ assert bob_book is not None
+ assert bob_book["price"] == 34.99
+
+ # Verify other books remain unchanged
+ alice_book = db[self.TEST_COLLECTION].find_one({"author": "Alice"})
+ assert alice_book["price"] == 29.99
+
+ def test_update_multiple_fields(self, conn):
+ """Test updating multiple fields in one statement."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"UPDATE {self.TEST_COLLECTION} SET price = 14.99, stock = 20 WHERE title = 'Book C'")
+
+ assert result == cursor
+
+ # Verify multiple fields were updated
+ db = conn.database
+ book_c = db[self.TEST_COLLECTION].find_one({"title": "Book C"})
+ assert book_c is not None
+ assert book_c["price"] == 14.99
+ assert book_c["stock"] == 20
+
+ def test_update_with_numeric_comparison(self, conn):
+ """Test UPDATE with WHERE clause using numeric comparison."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"UPDATE {self.TEST_COLLECTION} SET available = false WHERE stock < 5")
+
+ assert result == cursor
+
+ # Books with stock < 5 should be unavailable (Book C with 0 stock)
+ db = conn.database
+ unavailable_books = list(db[self.TEST_COLLECTION].find({"available": False}))
+ assert len(unavailable_books) >= 1
+ assert all(doc["stock"] < 5 for doc in unavailable_books)
+
+ def test_update_with_and_condition(self, conn):
+ """Test UPDATE with WHERE clause using AND condition."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"UPDATE {self.TEST_COLLECTION} SET price = 22.99 WHERE year = 2020 AND stock > 5")
+
+ assert result == cursor
+
+ # Book E (year=2020, stock=8) matches and should be updated
+ db = conn.database
+ book_e = db[self.TEST_COLLECTION].find_one({"title": "Book E"})
+ assert book_e is not None
+ assert book_e["price"] == 22.99
+
+ # Book A (year=2020, stock=10) should also be updated
+ book_a = db[self.TEST_COLLECTION].find_one({"title": "Book A"})
+ assert book_a is not None
+ assert book_a["price"] == 22.99
+
+ def test_update_with_qmark_parameters(self, conn):
+ """Test UPDATE with qmark (?) placeholder parameters."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"UPDATE {self.TEST_COLLECTION} SET stock = ? WHERE author = ?", [25, "Alice"])
+
+ assert result == cursor
+
+ # Verify Alice's book stock was updated
+ db = conn.database
+ alice_book = db[self.TEST_COLLECTION].find_one({"author": "Alice"})
+ assert alice_book is not None
+ assert alice_book["stock"] == 25
+
+ def test_update_boolean_field(self, conn):
+ """Test updating boolean field."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"UPDATE {self.TEST_COLLECTION} SET available = true WHERE stock = 0")
+
+ assert result == cursor
+
+ # Verify Book C (stock=0) is now available
+ db = conn.database
+ book_c = db[self.TEST_COLLECTION].find_one({"title": "Book C"})
+ assert book_c is not None
+ assert book_c["available"] is True
+
+ def test_update_with_greater_than(self, conn):
+ """Test UPDATE with > operator in WHERE clause."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"UPDATE {self.TEST_COLLECTION} SET price = 59.99 WHERE price > 40")
+
+ assert result == cursor
+
+ # Only Book D (price=49.99) should be updated
+ db = conn.database
+ book_d = db[self.TEST_COLLECTION].find_one({"title": "Book D"})
+ assert book_d is not None
+ assert book_d["price"] == 59.99
+
+ def test_update_numeric_to_string(self, conn):
+ """Test updating numeric value with string."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"UPDATE {self.TEST_COLLECTION} SET author = 'Anonymous' WHERE year = 2019")
+
+ assert result == cursor
+
+ # Verify Book C author was updated
+ db = conn.database
+ book_c = db[self.TEST_COLLECTION].find_one({"year": 2019})
+ assert book_c is not None
+ assert book_c["author"] == "Anonymous"
+
+ def test_update_rowcount(self, conn):
+ """Test that rowcount reflects number of updated documents."""
+ cursor = conn.cursor()
+ cursor.execute(f"UPDATE {self.TEST_COLLECTION} SET available = false WHERE year = 2020")
+
+ # Two books from 2020 (Book A and Book E)
+ assert cursor.rowcount == 2
+
+ def test_update_no_matches(self, conn):
+ """Test UPDATE with WHERE clause that matches no documents."""
+ cursor = conn.cursor()
+ cursor.execute(f"UPDATE {self.TEST_COLLECTION} SET price = 99.99 WHERE year = 1999")
+
+ # No documents should be updated
+ assert cursor.rowcount == 0
+
+ # Verify all books retain original prices
+ db = conn.database
+ books = list(db[self.TEST_COLLECTION].find())
+ assert all(doc["price"] < 60 for doc in books)
+
+ def test_update_nested_field(self, conn):
+ """Test updating nested field using dot notation."""
+ # First insert a document with nested structure
+ db = conn.database
+ db[self.TEST_COLLECTION].insert_one(
+ {"title": "Book F", "author": "Frank", "details": {"pages": 300, "publisher": "ABC"}, "year": 2023}
+ )
+
+ cursor = conn.cursor()
+ result = cursor.execute(f"UPDATE {self.TEST_COLLECTION} SET details.pages = 350 WHERE title = 'Book F'")
+
+ assert result == cursor
+
+ # Verify nested field was updated
+ book_f = db[self.TEST_COLLECTION].find_one({"title": "Book F"})
+ assert book_f is not None
+ assert book_f["details"]["pages"] == 350
+ assert book_f["details"]["publisher"] == "ABC" # Other nested field unchanged
+
+ def test_update_set_null(self, conn):
+ """Test setting a field to NULL."""
+ cursor = conn.cursor()
+ result = cursor.execute(f"UPDATE {self.TEST_COLLECTION} SET stock = null WHERE title = 'Book B'")
+
+ assert result == cursor
+
+ # Verify stock was set to None
+ db = conn.database
+ book_b = db[self.TEST_COLLECTION].find_one({"title": "Book B"})
+ assert book_b is not None
+ assert book_b["stock"] is None
diff --git a/tests/test_sql_parser_delete.py b/tests/test_sql_parser_delete.py
new file mode 100644
index 0000000..395f37b
--- /dev/null
+++ b/tests/test_sql_parser_delete.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+from pymongosql.sql.delete_builder import DeleteExecutionPlan
+from pymongosql.sql.parser import SQLParser
+
+
+class TestSQLParserDelete:
+ """Tests for DELETE parsing via AST visitor (PartiQL-style)."""
+
+ def test_delete_all_documents(self):
+ """Test DELETE without WHERE clause."""
+ sql = "DELETE FROM users"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "users"
+ assert plan.filter_conditions == {}
+
+ def test_delete_with_simple_where(self):
+ """Test DELETE with simple equality WHERE clause."""
+ sql = "DELETE FROM orders WHERE status = 'cancelled'"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "orders"
+ assert plan.filter_conditions == {"status": "cancelled"}
+
+ def test_delete_with_numeric_filter(self):
+ """Test DELETE with numeric comparison."""
+ sql = "DELETE FROM products WHERE price > 100"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "products"
+ assert plan.filter_conditions == {"price": {"$gt": 100}}
+
+ def test_delete_with_less_than(self):
+ """Test DELETE with less than operator."""
+ sql = "DELETE FROM sessions WHERE created_at < 1609459200"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "sessions"
+ assert plan.filter_conditions == {"created_at": {"$lt": 1609459200}}
+
+ def test_delete_with_greater_equal(self):
+ """Test DELETE with >= operator."""
+ sql = "DELETE FROM inventory WHERE quantity >= 1000"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "inventory"
+ assert plan.filter_conditions == {"quantity": {"$gte": 1000}}
+
+ def test_delete_with_less_equal(self):
+ """Test DELETE with <= operator."""
+ sql = "DELETE FROM logs WHERE severity <= 2"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "logs"
+ assert plan.filter_conditions == {"severity": {"$lte": 2}}
+
+ def test_delete_with_not_equal(self):
+ """Test DELETE with != operator."""
+ sql = "DELETE FROM temp WHERE valid != true"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "temp"
+ assert plan.filter_conditions == {"valid": {"$ne": True}}
+
+ def test_delete_with_qmark_parameter(self):
+ """Test DELETE with qmark placeholder."""
+ sql = "DELETE FROM users WHERE name = '?'"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "users"
+ # Parameters should be in the filter as placeholders
+ assert plan.filter_conditions == {"name": "?"}
+
+ def test_delete_with_named_parameter(self):
+ """Test DELETE with named parameter placeholder."""
+ sql = "DELETE FROM orders WHERE order_id = ':orderId'"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "orders"
+ assert plan.filter_conditions == {"order_id": ":orderId"}
+
+ def test_delete_with_null_comparison(self):
+ """Test DELETE with NULL value."""
+ sql = "DELETE FROM cache WHERE expires = null"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "cache"
+ assert plan.filter_conditions == {"expires": None}
+
+ def test_delete_with_boolean_true(self):
+ """Test DELETE with boolean TRUE value."""
+ sql = "DELETE FROM flags WHERE active = true"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "flags"
+ assert plan.filter_conditions == {"active": True}
+
+ def test_delete_with_boolean_false(self):
+ """Test DELETE with boolean FALSE value."""
+ sql = "DELETE FROM flags WHERE active = false"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "flags"
+ assert plan.filter_conditions == {"active": False}
+
+ def test_delete_with_string_value(self):
+ """Test DELETE with string literal."""
+ sql = "DELETE FROM users WHERE username = 'john_doe'"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "users"
+ assert plan.filter_conditions == {"username": "john_doe"}
+
+ def test_delete_with_negative_number(self):
+ """Test DELETE with negative number."""
+ sql = "DELETE FROM transactions WHERE amount = -50"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "transactions"
+ assert plan.filter_conditions == {"amount": -50}
+
+ def test_delete_with_float_value(self):
+ """Test DELETE with floating point number."""
+ sql = "DELETE FROM measurements WHERE temperature > 36.5"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "measurements"
+ assert plan.filter_conditions == {"temperature": {"$gt": 36.5}}
+
+ def test_delete_with_and_condition(self):
+ """Test DELETE with AND condition."""
+ sql = "DELETE FROM items WHERE category = 'electronics' AND price > 500"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "items"
+ # AND condition creates a $and array with both conditions
+ assert "$and" in plan.filter_conditions
+ assert len(plan.filter_conditions["$and"]) == 2
+ assert {"category": "electronics"} in plan.filter_conditions["$and"]
+ assert {"price": {"$gt": 500}} in plan.filter_conditions["$and"]
+
+ def test_delete_with_or_condition(self):
+ """Test DELETE with OR condition."""
+ sql = "DELETE FROM logs WHERE severity = 'ERROR' OR severity = 'CRITICAL'"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "logs"
+ assert "$or" in plan.filter_conditions
+ assert len(plan.filter_conditions["$or"]) == 2
+ assert {"severity": "ERROR"} in plan.filter_conditions["$or"]
+ assert {"severity": "CRITICAL"} in plan.filter_conditions["$or"]
+
+ def test_delete_collection_name_case_sensitive(self):
+ """Test that collection names are case-sensitive."""
+ sql = "DELETE FROM MyCollection WHERE id = 1"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "MyCollection"
+ assert plan.filter_conditions == {"id": 1}
+
+ def test_delete_field_name_case_sensitive(self):
+ """Test that field names are case-sensitive."""
+ sql = "DELETE FROM users WHERE UserID = 123"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.collection == "users"
+ assert plan.filter_conditions == {"UserID": 123}
+
+ def test_delete_validates_execution_plan(self):
+ """Test that the parsed execution plan passes validate()."""
+ sql = "DELETE FROM products WHERE category = 'obsolete'"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, DeleteExecutionPlan)
+ assert plan.validate() is True
+ assert plan.collection == "products"
diff --git a/tests/test_sql_parser_insert.py b/tests/test_sql_parser_insert.py
new file mode 100644
index 0000000..767e1cd
--- /dev/null
+++ b/tests/test_sql_parser_insert.py
@@ -0,0 +1,132 @@
+# -*- coding: utf-8 -*-
+import pytest
+
+from pymongosql.error import SqlSyntaxError
+from pymongosql.sql.insert_builder import InsertExecutionPlan
+from pymongosql.sql.parser import SQLParser
+
+
+class TestSQLParserInsert:
+ """Tests for INSERT parsing via AST visitor (PartiQL-style)."""
+
+ def test_insert_single_object_literal(self):
+ sql = "INSERT INTO users {'id': 1, 'name': 'Jane', 'age': 30}"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, InsertExecutionPlan)
+ assert plan.collection == "users"
+ assert plan.insert_documents == [{"id": 1, "name": "Jane", "age": 30}]
+
+ def test_insert_bag_documents(self):
+ sql = "INSERT INTO items << {'a': 1}, {'a': 2, 'b': 'x'} >>"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert plan.collection == "items"
+ assert plan.insert_documents == [{"a": 1}, {"a": 2, "b": "x"}]
+
+ def test_insert_literals_lowercase(self):
+ sql = "INSERT INTO flags {'is_on': null, 'is_new': true, 'note': 'ok'}"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert plan.collection == "flags"
+ assert plan.insert_documents == [{"is_on": None, "is_new": True, "note": "ok"}]
+
+ def test_insert_object_qmark_parameters(self):
+ sql = "INSERT INTO orders {'id': '?', 'total': '?'}"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert plan.collection == "orders"
+ assert plan.insert_documents == [{"id": "?", "total": "?"}]
+ assert plan.parameter_style == "qmark"
+ assert plan.parameter_count == 2
+
+ def test_insert_object_named_parameters(self):
+ sql = "INSERT INTO orders {'id': ':id', 'total': ':total'}"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert plan.collection == "orders"
+ assert plan.insert_documents == [{"id": ":id", "total": ":total"}]
+ assert plan.parameter_style == "named"
+ assert plan.parameter_count == 2
+
+ def test_insert_bag_named_parameters(self):
+ sql = "INSERT INTO items << {'a': ':one'}, {'a': ':two'} >>"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert plan.collection == "items"
+ assert plan.insert_documents == [{"a": ":one"}, {"a": ":two"}]
+ assert plan.parameter_style == "named"
+ assert plan.parameter_count == 2
+
+ def test_insert_mixed_parameter_styles_fails(self):
+ sql = "INSERT INTO items << {'a': '?'}, {'a': ':b'} >>"
+ with pytest.raises(SqlSyntaxError):
+ SQLParser(sql).get_execution_plan()
+
+ def test_insert_single_tuple(self):
+ sql = "INSERT INTO Films {'code': 'B6717', 'did': 110, 'date_prod': '1985-02-10', 'kind': 'Comedy'}"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, InsertExecutionPlan)
+ assert plan.collection == "Films"
+ assert plan.insert_documents == [
+ {
+ "code": "B6717",
+ "did": 110,
+ "date_prod": "1985-02-10",
+ "kind": "Comedy",
+ }
+ ]
+
+ # Not supported in PartiQL Grammar yet
+ #
+ # def test_insert_values_single_row(self):
+ # sql = (
+ # "INSERT INTO Films (code, title, did, date_prod, kind) "
+ # "VALUES ('B6717', 'Tampopo', 110, '1985-02-10', 'Comedy')"
+ # )
+ # plan = SQLParser(sql).get_execution_plan()
+
+ # assert isinstance(plan, InsertExecutionPlan)
+ # assert plan.collection == "Films"
+ # assert plan.insert_documents == [
+ # {
+ # "code": "B6717",
+ # "title": "Tampopo",
+ # "did": 110,
+ # "date_prod": "1985-02-10",
+ # "kind": "Comedy",
+ # }
+ # ]
+
+ # def test_insert_values_multiple_rows(self):
+ # sql = (
+ # "INSERT INTO Films (code, title, did, date_prod, kind) "
+ # "VALUES ('B6717', 'Tampopo', 110, '1985-02-10', 'Comedy'),"
+ # " ('HG120', 'The Dinner Game', 140, DEFAULT, 'Comedy')"
+ # )
+ # plan = SQLParser(sql).get_execution_plan()
+
+ # assert isinstance(plan, InsertExecutionPlan)
+ # assert plan.collection == "Films"
+ # assert plan.insert_documents == [
+ # {
+ # "code": "B6717",
+ # "title": "Tampopo",
+ # "did": 110,
+ # "date_prod": "1985-02-10",
+ # "kind": "Comedy",
+ # },
+ # {
+ # "code": "HG120",
+ # "title": "The Dinner Game",
+ # "did": 140,
+ # "date_prod": None,
+ # "kind": "Comedy",
+ # },
+ # ]
+
+ def test_insert_invalid_expression_raises(self):
+ sql = "INSERT INTO users 123"
+ with pytest.raises(SqlSyntaxError):
+ SQLParser(sql).get_execution_plan()
diff --git a/tests/test_sql_parser_update.py b/tests/test_sql_parser_update.py
new file mode 100644
index 0000000..265af3e
--- /dev/null
+++ b/tests/test_sql_parser_update.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+from pymongosql.sql.parser import SQLParser
+from pymongosql.sql.update_builder import UpdateExecutionPlan
+
+
+class TestSQLParserUpdate:
+ """Tests for UPDATE parsing via AST visitor (PartiQL-style)."""
+
+ def test_update_simple_field(self):
+ """Test UPDATE with single field update."""
+ sql = "UPDATE users SET name = 'John'"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, UpdateExecutionPlan)
+ assert plan.collection == "users"
+ assert plan.update_fields == {"name": "John"}
+ assert plan.filter_conditions == {}
+
+ def test_update_multiple_fields(self):
+ """Test UPDATE with multiple field updates."""
+ sql = "UPDATE products SET price = 100, stock = 50"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, UpdateExecutionPlan)
+ assert plan.collection == "products"
+ assert plan.update_fields == {"price": 100, "stock": 50}
+ assert plan.filter_conditions == {}
+
+ def test_update_with_where_clause(self):
+ """Test UPDATE with WHERE clause."""
+ sql = "UPDATE orders SET status = 'shipped' WHERE id = 123"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, UpdateExecutionPlan)
+ assert plan.collection == "orders"
+ assert plan.update_fields == {"status": "shipped"}
+ assert plan.filter_conditions == {"id": 123}
+
+ def test_update_numeric_value(self):
+ """Test UPDATE with numeric value."""
+ sql = "UPDATE inventory SET quantity = 100 WHERE product_id = 5"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, UpdateExecutionPlan)
+ assert plan.collection == "inventory"
+ assert plan.update_fields == {"quantity": 100}
+ assert plan.filter_conditions == {"product_id": 5}
+
+ def test_update_boolean_value(self):
+ """Test UPDATE with boolean value."""
+ sql = "UPDATE settings SET enabled = true WHERE setting_key = 'feature_flag'"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, UpdateExecutionPlan)
+ assert plan.collection == "settings"
+ assert plan.update_fields == {"enabled": True}
+ assert plan.filter_conditions == {"setting_key": "feature_flag"}
+
+ def test_update_null_value(self):
+ """Test UPDATE with NULL value."""
+ sql = "UPDATE cache SET expires = null WHERE session_id = 'abc123'"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, UpdateExecutionPlan)
+ assert plan.collection == "cache"
+ assert plan.update_fields == {"expires": None}
+ assert plan.filter_conditions == {"session_id": "abc123"}
+
+ def test_update_with_comparison_operators(self):
+ """Test UPDATE with various comparison operators in WHERE."""
+ sql = "UPDATE products SET discount = 0.1 WHERE price > 100"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, UpdateExecutionPlan)
+ assert plan.collection == "products"
+ assert plan.update_fields == {"discount": 0.1}
+ assert plan.filter_conditions == {"price": {"$gt": 100}}
+
+ def test_update_with_and_condition(self):
+ """Test UPDATE with AND condition in WHERE."""
+ sql = "UPDATE items SET status = 'archived' WHERE category = 'old' AND year < 2020"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, UpdateExecutionPlan)
+ assert plan.collection == "items"
+ assert plan.update_fields == {"status": "archived"}
+ assert "$and" in plan.filter_conditions
+ assert len(plan.filter_conditions["$and"]) == 2
+
+ def test_update_validates_execution_plan(self):
+ """Test that the parsed execution plan passes validate()."""
+ sql = "UPDATE users SET email = 'test@example.com' WHERE id = 1"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, UpdateExecutionPlan)
+ assert plan.validate() is True
+ assert plan.collection == "users"
+
+ def test_update_nested_field(self):
+ """Test UPDATE with nested field path."""
+ sql = "UPDATE users SET address.city = 'NYC' WHERE id = 1"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, UpdateExecutionPlan)
+ assert plan.collection == "users"
+ assert "address.city" in plan.update_fields
+ assert plan.update_fields["address.city"] == "NYC"
+
+ def test_update_with_parameter_placeholder(self):
+ """Test UPDATE with parameter placeholder."""
+ sql = "UPDATE users SET name = '?' WHERE id = '?'"
+ plan = SQLParser(sql).get_execution_plan()
+
+ assert isinstance(plan, UpdateExecutionPlan)
+ assert plan.collection == "users"
+ assert plan.update_fields == {"name": "?"}
+ assert plan.filter_conditions == {"id": "?"}
diff --git a/tests/test_superset_connection.py b/tests/test_superset_connection.py
index 3bb9907..159f265 100644
--- a/tests/test_superset_connection.py
+++ b/tests/test_superset_connection.py
@@ -86,13 +86,13 @@ def test_subquery_execution_supports_subqueries(self):
assert superset_strategy.supports(context) is True
def test_standard_execution_rejects_subqueries(self):
- """Test that StandardExecution doesn't support subqueries"""
- from pymongosql.executor import StandardExecution
+ """Test that StandardQueryExecution doesn't support subqueries"""
+ from pymongosql.executor import StandardQueryExecution
subquery_sql = "SELECT * FROM (SELECT id, name FROM users) AS u WHERE u.id > 10"
context = ExecutionContext(subquery_sql, "superset")
- standard_strategy = StandardExecution()
+ standard_strategy = StandardQueryExecution()
assert standard_strategy.supports(context) is False
def test_get_strategy_selects_subquery_execution(self):
@@ -104,14 +104,14 @@ def test_get_strategy_selects_subquery_execution(self):
assert isinstance(strategy, SupersetExecution)
def test_get_strategy_selects_standard_execution(self):
- """Test that get_strategy returns StandardExecution for simple queries"""
- from pymongosql.executor import StandardExecution
+ """Test that get_strategy returns StandardQueryExecution for simple queries"""
+ from pymongosql.executor import StandardQueryExecution
simple_sql = "SELECT id, name FROM users WHERE id > 10"
context = ExecutionContext(simple_sql)
strategy = ExecutionPlanFactory.get_strategy(context)
- assert isinstance(strategy, StandardExecution)
+ assert isinstance(strategy, StandardQueryExecution)
class TestConnectionModeDetection: