Skip to content

Commit ccd261c

Browse files
committed
Refactor builder class
1 parent c0f2859 commit ccd261c

File tree

4 files changed

+129
-54
lines changed

4 files changed

+129
-54
lines changed

pymongosql/sql/ast.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Dict
44

55
from ..error import SqlSyntaxError
6-
from .builder import ExecutionPlan
6+
from .builder import BuilderFactory, ExecutionPlan
77
from .handler import BaseHandler, HandlerFactory, ParseResult
88
from .partiql.PartiQLLexer import PartiQLLexer
99
from .partiql.PartiQLParser import PartiQLParser
@@ -47,15 +47,14 @@ def parse_result(self) -> ParseResult:
4747
return self._parse_result
4848

4949
def parse_to_execution_plan(self) -> ExecutionPlan:
50-
"""Convert the parse result to an ExecutionPlan"""
51-
return ExecutionPlan(
52-
collection=self._parse_result.collection,
53-
filter_stage=self._parse_result.filter_conditions,
54-
projection_stage=self._parse_result.projection,
55-
sort_stage=self._parse_result.sort_fields,
56-
limit_stage=self._parse_result.limit_value,
57-
skip_stage=self._parse_result.offset_value,
58-
)
50+
"""Convert the parse result to an ExecutionPlan using BuilderFactory"""
51+
builder = BuilderFactory.create_query_builder().collection(self._parse_result.collection)
52+
53+
builder.filter(self._parse_result.filter_conditions).project(self._parse_result.projection).sort(
54+
self._parse_result.sort_fields
55+
).limit(self._parse_result.limit_value).skip(self._parse_result.offset_value)
56+
57+
return builder.build()
5958

6059
def visitRoot(self, ctx: PartiQLParser.RootContext) -> Any:
6160
"""Visit root node and process child nodes"""

pymongosql/sql/builder.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -104,18 +104,36 @@ def project(self, fields: Union[Dict[str, int], List[str]]) -> "MongoQueryBuilde
104104
_logger.debug(f"Set projection: {projection}")
105105
return self
106106

107-
def sort(self, field: str, direction: int = 1) -> "MongoQueryBuilder":
108-
"""Add sort criteria"""
109-
if not field or not isinstance(field, str):
110-
self._add_error("Sort field must be a non-empty string")
111-
return self
107+
def sort(self, specs: List[Dict[str, int]]) -> "MongoQueryBuilder":
108+
"""Add sort criteria.
109+
110+
Only accepts a list of single-key dicts in the form:
111+
[{"field": 1}, {"other": -1}]
112112
113-
if direction not in [-1, 1]:
114-
self._add_error("Sort direction must be 1 (ascending) or -1 (descending)")
113+
This matches the output produced by the SQL parser (`sort_fields`).
114+
"""
115+
if not isinstance(specs, list):
116+
self._add_error("Sort specifications must be a list of single-key dicts")
115117
return self
116118

117-
self._execution_plan.sort_stage.append({field: direction})
118-
_logger.debug(f"Added sort: {field} -> {direction}")
119+
for spec in specs:
120+
if not isinstance(spec, dict) or len(spec) != 1:
121+
self._add_error("Each sort specification must be a single-key dict, e.g. {'name': 1}")
122+
continue
123+
124+
field, direction = next(iter(spec.items()))
125+
126+
if not isinstance(field, str) or not field:
127+
self._add_error("Sort field must be a non-empty string")
128+
continue
129+
130+
if direction not in [-1, 1]:
131+
self._add_error(f"Sort direction for field '{field}' must be 1 or -1")
132+
continue
133+
134+
self._execution_plan.sort_stage.append({field: direction})
135+
_logger.debug(f"Added sort: {field} -> {direction}")
136+
119137
return self
120138

121139
def limit(self, count: int) -> "MongoQueryBuilder":

tests/test_result_set.py

Lines changed: 68 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from pymongosql.error import ProgrammingError
55
from pymongosql.result_set import ResultSet
6-
from pymongosql.sql.builder import ExecutionPlan
6+
from pymongosql.sql.builder import BuilderFactory
77

88

99
class TestResultSet:
@@ -19,7 +19,9 @@ def test_result_set_init(self, conn):
1919
# Execute a real command to get results
2020
command_result = db.command({"find": "users", "filter": {"age": {"$gt": 25}}, "limit": 1})
2121

22-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_WITH_FIELDS)
22+
execution_plan = (
23+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build()
24+
)
2325
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
2426
assert result_set._command_result == command_result
2527
assert result_set._execution_plan == execution_plan
@@ -30,7 +32,9 @@ def test_result_set_init_empty_projection(self, conn):
3032
db = conn.database
3133
command_result = db.command({"find": "users", "limit": 1})
3234

33-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY)
35+
execution_plan = (
36+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build()
37+
)
3438
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
3539
assert result_set._execution_plan.projection_stage == {}
3640

@@ -40,7 +44,9 @@ def test_fetchone_with_data(self, conn):
4044
# Get real user data with projection mapping
4145
command_result = db.command({"find": "users", "projection": {"name": 1, "email": 1}, "limit": 1})
4246

43-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_WITH_FIELDS)
47+
execution_plan = (
48+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build()
49+
)
4450
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
4551
row = result_set.fetchone()
4652

@@ -66,7 +72,9 @@ def test_fetchone_no_data(self, conn):
6672
{"find": "users", "filter": {"age": {"$gt": 999}}, "limit": 1} # No users over 999 years old
6773
)
6874

69-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_WITH_FIELDS)
75+
execution_plan = (
76+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build()
77+
)
7078
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
7179
row = result_set.fetchone()
7280

@@ -77,7 +85,9 @@ def test_fetchone_empty_projection(self, conn):
7785
db = conn.database
7886
command_result = db.command({"find": "users", "limit": 1, "sort": {"_id": 1}})
7987

80-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY)
88+
execution_plan = (
89+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build()
90+
)
8191
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
8292
row = result_set.fetchone()
8393

@@ -102,7 +112,9 @@ def test_fetchone_closed_cursor(self, conn):
102112
db = conn.database
103113
command_result = db.command({"find": "users", "limit": 1})
104114

105-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_WITH_FIELDS)
115+
execution_plan = (
116+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build()
117+
)
106118
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
107119
result_set.close()
108120

@@ -115,7 +127,9 @@ def test_fetchmany_with_data(self, conn):
115127
# Get multiple users with projection
116128
command_result = db.command({"find": "users", "projection": {"name": 1, "email": 1}, "limit": 5})
117129

118-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_WITH_FIELDS)
130+
execution_plan = (
131+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build()
132+
)
119133
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
120134
rows = result_set.fetchmany(2)
121135

@@ -141,7 +155,9 @@ def test_fetchmany_default_size(self, conn):
141155
# Get all users (22 total in test dataset)
142156
command_result = db.command({"find": "users"})
143157

144-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY)
158+
execution_plan = (
159+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build()
160+
)
145161
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
146162
rows = result_set.fetchmany() # Should use default arraysize (1000)
147163

@@ -153,7 +169,9 @@ def test_fetchmany_less_data_available(self, conn):
153169
# Get only 2 users but request 5
154170
command_result = db.command({"find": "users", "limit": 2})
155171

156-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY)
172+
execution_plan = (
173+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build()
174+
)
157175
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
158176
rows = result_set.fetchmany(5) # Request 5 but only 2 available
159177

@@ -165,7 +183,9 @@ def test_fetchmany_no_data(self, conn):
165183
# Query for non-existent data
166184
command_result = db.command({"find": "users", "filter": {"age": {"$gt": 999}}}) # No users over 999 years old
167185

168-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY)
186+
execution_plan = (
187+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build()
188+
)
169189
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
170190
rows = result_set.fetchmany(3)
171191

@@ -179,7 +199,9 @@ def test_fetchall_with_data(self, conn):
179199
{"find": "users", "filter": {"age": {"$gt": 25}}, "projection": {"name": 1, "email": 1}}
180200
)
181201

182-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_WITH_FIELDS)
202+
execution_plan = (
203+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build()
204+
)
183205
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
184206
rows = result_set.fetchall()
185207

@@ -201,7 +223,9 @@ def test_fetchall_no_data(self, conn):
201223
db = conn.database
202224
command_result = db.command({"find": "users", "filter": {"age": {"$gt": 999}}}) # No users over 999 years old
203225

204-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY)
226+
execution_plan = (
227+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build()
228+
)
205229
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
206230
rows = result_set.fetchall()
207231

@@ -212,7 +236,9 @@ def test_fetchall_closed_cursor(self, conn):
212236
db = conn.database
213237
command_result = db.command({"find": "users", "limit": 1})
214238

215-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY)
239+
execution_plan = (
240+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build()
241+
)
216242
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
217243
result_set.close()
218244

@@ -222,7 +248,7 @@ def test_fetchall_closed_cursor(self, conn):
222248
def test_apply_projection_mapping(self):
223249
"""Test _process_document method"""
224250
projection = {"name": 1, "age": 1, "email": 1}
225-
execution_plan = ExecutionPlan(collection="users", projection_stage=projection)
251+
execution_plan = BuilderFactory.create_query_builder().collection("users").project(projection).build()
226252

227253
# Create empty command result for testing _process_document method
228254
command_result = {"cursor": {"firstBatch": []}}
@@ -248,7 +274,7 @@ def test_apply_projection_mapping_missing_fields(self):
248274
"age": 1,
249275
"missing": 1,
250276
}
251-
execution_plan = ExecutionPlan(collection="users", projection_stage=projection)
277+
execution_plan = BuilderFactory.create_query_builder().collection("users").project(projection).build()
252278

253279
command_result = {"cursor": {"firstBatch": []}}
254280
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
@@ -264,7 +290,7 @@ def test_apply_projection_mapping_missing_fields(self):
264290
def test_apply_projection_mapping_identity_mapping(self):
265291
"""Test projection with MongoDB standard format"""
266292
projection = {"name": 1, "age": 1}
267-
execution_plan = ExecutionPlan(collection="users", projection_stage=projection)
293+
execution_plan = BuilderFactory.create_query_builder().collection("users").project(projection).build()
268294

269295
command_result = {"cursor": {"firstBatch": []}}
270296
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
@@ -279,7 +305,7 @@ def test_apply_projection_mapping_identity_mapping(self):
279305
def test_array_projection_mapping(self):
280306
"""Test projection mapping with array bracket/number conversion"""
281307
projection = {"items.0": 1, "items.1.name": 1}
282-
execution_plan = ExecutionPlan(collection="orders", projection_stage=projection)
308+
execution_plan = BuilderFactory.create_query_builder().collection("orders").project(projection).build()
283309

284310
command_result = {"cursor": {"firstBatch": []}}
285311
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
@@ -294,7 +320,9 @@ def test_array_projection_mapping(self):
294320
def test_close(self):
295321
"""Test close method"""
296322
command_result = {"cursor": {"firstBatch": []}}
297-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY)
323+
execution_plan = (
324+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build()
325+
)
298326
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
299327

300328
# Should not be closed initially
@@ -308,7 +336,9 @@ def test_close(self):
308336
def test_context_manager(self):
309337
"""Test ResultSet as context manager"""
310338
command_result = {"cursor": {"firstBatch": []}}
311-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY)
339+
execution_plan = (
340+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build()
341+
)
312342
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
313343

314344
with result_set as rs:
@@ -321,7 +351,9 @@ def test_context_manager(self):
321351
def test_context_manager_with_exception(self):
322352
"""Test context manager with exception"""
323353
command_result = {"cursor": {"firstBatch": []}}
324-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY)
354+
execution_plan = (
355+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build()
356+
)
325357
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
326358

327359
try:
@@ -340,7 +372,9 @@ def test_iterator_protocol(self, conn):
340372
# Get 2 users from database
341373
command_result = db.command({"find": "users", "limit": 2})
342374

343-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY)
375+
execution_plan = (
376+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build()
377+
)
344378
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
345379

346380
# Test iterator protocol
@@ -363,7 +397,9 @@ def test_iterator_with_projection(self, conn):
363397
db = conn.database
364398
command_result = db.command({"find": "users", "projection": {"name": 1, "email": 1}, "limit": 2})
365399

366-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_WITH_FIELDS)
400+
execution_plan = (
401+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_WITH_FIELDS).build()
402+
)
367403
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
368404

369405
rows = list(result_set)
@@ -378,7 +414,9 @@ def test_iterator_with_projection(self, conn):
378414
def test_iterator_closed_cursor(self):
379415
"""Test iteration on closed cursor"""
380416
command_result = {"cursor": {"firstBatch": []}}
381-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY)
417+
execution_plan = (
418+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build()
419+
)
382420
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
383421
result_set.close()
384422

@@ -388,7 +426,9 @@ def test_iterator_closed_cursor(self):
388426
def test_arraysize_property(self):
389427
"""Test arraysize property"""
390428
command_result = {"cursor": {"firstBatch": []}}
391-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY)
429+
execution_plan = (
430+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build()
431+
)
392432
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
393433

394434
# Default arraysize should be 1000
@@ -401,7 +441,9 @@ def test_arraysize_property(self):
401441
def test_arraysize_validation(self):
402442
"""Test arraysize validation"""
403443
command_result = {"cursor": {"firstBatch": []}}
404-
execution_plan = ExecutionPlan(collection="users", projection_stage=self.PROJECTION_EMPTY)
444+
execution_plan = (
445+
BuilderFactory.create_query_builder().collection("users").project(self.PROJECTION_EMPTY).build()
446+
)
405447
result_set = ResultSet(command_result=command_result, execution_plan=execution_plan)
406448

407449
# Should reject invalid values

0 commit comments

Comments
 (0)