diff --git a/pymongosql/__init__.py b/pymongosql/__init__.py index 09bcd31..0732d1a 100644 --- a/pymongosql/__init__.py +++ b/pymongosql/__init__.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: from .connection import Connection -__version__: str = "0.2.2" +__version__: str = "0.2.3" # Globals https://www.python.org/dev/peps/pep-0249/#globals apilevel: str = "2.0" diff --git a/pymongosql/connection.py b/pymongosql/connection.py index d31f34d..f547e96 100644 --- a/pymongosql/connection.py +++ b/pymongosql/connection.py @@ -45,14 +45,18 @@ def __init__( """ # Check if connection string specifies mode connection_string = host if isinstance(host, str) else None - self._mode, host = ConnectionHelper.parse_connection_string(connection_string) + mode, host = ConnectionHelper.parse_connection_string(connection_string) + + self._mode = kwargs.pop("mode", None) + if not self._mode and mode: + self._mode = mode # Extract commonly used parameters for backward compatibility self._host = host or "localhost" self._port = port or 27017 # Handle database parameter separately (not a MongoClient parameter) - self._database_name = kwargs.pop("database", None) # Remove from kwargs + self._database_name = kwargs.pop("database", None) # Store all PyMongo parameters to pass through directly self._pymongo_params = kwargs.copy() diff --git a/pymongosql/result_set.py b/pymongosql/result_set.py index 1a78db2..20597b7 100644 --- a/pymongosql/result_set.py +++ b/pymongosql/result_set.py @@ -65,10 +65,16 @@ def _process_and_cache_batch(self, batch: List[Dict[str, Any]]) -> None: self._total_fetched += len(batch) def _build_description(self) -> None: - """Build column description from execution plan projection""" + """Build column description from execution plan projection or established column names""" if not self._execution_plan.projection_stage: - # No projection specified, description will be built dynamically - self._description = None + # No projection specified, build description from column names if available + if self._column_names: + self._description = [ + (col_name, "VARCHAR", None, None, None, None, None) for col_name in self._column_names + ] + else: + # Will be built dynamically when columns are established + self._description = None return # Build description from projection (now in MongoDB format {field: 1}) @@ -198,10 +204,13 @@ def description( self, ) -> Optional[List[Tuple[str, str, None, None, None, None, None]]]: """Return column description""" - if self._description is None and not self._cache_exhausted: - # Try to fetch one result to build description dynamically + if self._description is None: + # Try to build description from established column names try: - self._ensure_results_available(1) + if not self._cache_exhausted: + # Fetch one result to establish column names if needed + self._ensure_results_available(1) + if self._column_names: # Build description from established column names self._description = [ diff --git a/pymongosql/sqlalchemy_mongodb/__init__.py b/pymongosql/sqlalchemy_mongodb/__init__.py index 9bf4cc2..94c3078 100644 --- a/pymongosql/sqlalchemy_mongodb/__init__.py +++ b/pymongosql/sqlalchemy_mongodb/__init__.py @@ -29,22 +29,31 @@ __supports_sqlalchemy_2x__ = False -def create_engine_url(host: str = "localhost", port: int = 27017, database: str = "test", **kwargs) -> str: +def create_engine_url( + host: str = "localhost", port: int = 27017, database: str = "test", mode: str = "standard", **kwargs +) -> str: """Create a SQLAlchemy engine URL for PyMongoSQL. Args: host: MongoDB host port: MongoDB port database: Database name + mode: Connection mode - "standard" (default) or "superset" (with subquery support) **kwargs: Additional connection parameters Returns: - SQLAlchemy URL string (uses mongodb:// format) + SQLAlchemy URL string Example: + >>> # Standard mode >>> url = create_engine_url("localhost", 27017, "mydb") >>> engine = sqlalchemy.create_engine(url) + >>> # Superset mode with subquery support + >>> url = create_engine_url("localhost", 27017, "mydb", mode="superset") + >>> engine = sqlalchemy.create_engine(url) """ + scheme = "mongodb+superset" if mode == "superset" else "mongodb" + params = [] for key, value in kwargs.items(): params.append(f"{key}={value}") @@ -53,7 +62,7 @@ def create_engine_url(host: str = "localhost", port: int = 27017, database: str if param_str: param_str = "?" + param_str - return f"mongodb://{host}:{port}/{database}{param_str}" + return f"{scheme}://{host}:{port}/{database}{param_str}" def create_mongodb_url(mongodb_uri: str) -> str: @@ -77,11 +86,11 @@ def create_mongodb_url(mongodb_uri: str) -> str: def create_engine_from_mongodb_uri(mongodb_uri: str, **engine_kwargs): """Create a SQLAlchemy engine from any MongoDB connection string. - This function handles both mongodb:// and mongodb+srv:// URIs properly. - Use this instead of create_engine() directly for mongodb+srv URIs. + This function handles mongodb://, mongodb+srv://, and mongodb+superset:// URIs properly. + Use this instead of create_engine() directly for special URI schemes. Args: - mongodb_uri: Standard MongoDB connection string + mongodb_uri: MongoDB connection string (supports standard, SRV, and superset modes) **engine_kwargs: Additional arguments passed to create_engine Returns: @@ -92,6 +101,8 @@ def create_engine_from_mongodb_uri(mongodb_uri: str, **engine_kwargs): >>> engine = create_engine_from_mongodb_uri("mongodb+srv://user:pass@cluster.net/db") >>> # For standard MongoDB >>> engine = create_engine_from_mongodb_uri("mongodb://localhost:27017/mydb") + >>> # For superset mode (with subquery support) + >>> engine = create_engine_from_mongodb_uri("mongodb+superset://localhost:27017/mydb") """ try: from sqlalchemy import create_engine @@ -109,6 +120,22 @@ def custom_create_connect_args(url): opts = {"host": mongodb_uri} return [], opts + engine.dialect.create_connect_args = custom_create_connect_args + return engine + elif mongodb_uri.startswith("mongodb+superset://"): + # For MongoDB+Superset, convert to standard mongodb:// for SQLAlchemy compatibility + # but preserve the superset mode by passing it through connection options + converted_uri = mongodb_uri.replace("mongodb+superset://", "mongodb://") + + # Create engine with converted URI + engine = create_engine(converted_uri, **engine_kwargs) + + def custom_create_connect_args(url): + # Use original superset URI for actual MongoDB connection + # This preserves the superset mode for subquery support + opts = {"host": mongodb_uri} + return [], opts + engine.dialect.create_connect_args = custom_create_connect_args return engine else: @@ -123,7 +150,7 @@ def register_dialect(): """Register the PyMongoSQL dialect with SQLAlchemy. This function handles registration for both SQLAlchemy 1.x and 2.x. - Registers support for standard MongoDB connection strings only. + Registers support for standard, SRV, and superset MongoDB connection strings. """ try: from sqlalchemy.dialects import registry @@ -131,10 +158,10 @@ def register_dialect(): # Register for standard MongoDB URLs registry.register("mongodb", "pymongosql.sqlalchemy_mongodb.sqlalchemy_dialect", "PyMongoSQLDialect") - # Try to register both SRV forms so SQLAlchemy can resolve SRV-style URLs - # (either 'mongodb+srv' or the dotted 'mongodb.srv' plugin name). - # Some SQLAlchemy versions accept '+' in scheme names; others import - # the dotted plugin name. Attempt both registrations in one block. + # Try to register SRV and Superset forms so SQLAlchemy can resolve these URL patterns + # (either with '+' or dotted notation for compatibility with different SQLAlchemy versions). + # Some SQLAlchemy versions accept '+' in scheme names; others import the dotted plugin name. + # Attempt all registrations but don't fail if some are not supported. try: registry.register("mongodb+srv", "pymongosql.sqlalchemy_mongodb.sqlalchemy_dialect", "PyMongoSQLDialect") registry.register("mongodb.srv", "pymongosql.sqlalchemy_mongodb.sqlalchemy_dialect", "PyMongoSQLDialect") @@ -143,6 +170,18 @@ def register_dialect(): # create_engine_from_mongodb_uri by converting 'mongodb+srv' to 'mongodb'. pass + try: + registry.register( + "mongodb+superset", "pymongosql.sqlalchemy_mongodb.sqlalchemy_dialect", "PyMongoSQLDialect" + ) + registry.register( + "mongodb.superset", "pymongosql.sqlalchemy_mongodb.sqlalchemy_dialect", "PyMongoSQLDialect" + ) + except Exception: + # If registration fails we fall back to handling Superset URIs in + # create_engine_from_mongodb_uri by converting 'mongodb+superset' to 'mongodb'. + pass + return True except ImportError: # Fallback for versions without registry diff --git a/pymongosql/superset_mongodb/detector.py b/pymongosql/superset_mongodb/detector.py index 01de7f0..9a6eb6b 100644 --- a/pymongosql/superset_mongodb/detector.py +++ b/pymongosql/superset_mongodb/detector.py @@ -89,20 +89,58 @@ def extract_outer_query(cls, query: str) -> Optional[Tuple[str, str]]: """ Extract outer query with subquery placeholder. + Preserves the complete outer query structure while replacing the subquery + with a reference to the temporary table. + Returns: - Tuple of (outer_query, subquery_alias) or None + Tuple of (outer_query, subquery_alias) or None if not a wrapped subquery """ info = cls.detect(query) if not info.is_wrapped: return None - # Replace subquery with temporary table reference - outer = cls.WRAPPED_SUBQUERY_PATTERN.sub( - f"SELECT * FROM {info.subquery_alias}", - query, + # Pattern to capture: SELECT FROM ( ) AS + # Matches both SELECT col1, col2 and SELECT col1 AS alias1, col2 AS alias2 formats + pattern = re.compile( + r"(SELECT\s+.+?)\s+FROM\s*\(\s*(?:select|SELECT)\s+.+?\s*\)\s+(?:AS\s+)?(\w+)(.*)", + re.IGNORECASE | re.DOTALL, ) - return outer, info.subquery_alias + match = pattern.search(query) + if match: + select_clause = match.group(1).strip() + table_alias = match.group(2) + rest_of_query = match.group(3).strip() + + if rest_of_query: + outer = f"{select_clause} FROM {table_alias} {rest_of_query}" + else: + outer = f"{select_clause} FROM {table_alias}" + + return outer, table_alias + + # If pattern doesn't match exactly, fall back to preserving SELECT clause + # Extract from SELECT to FROM keyword + select_match = re.search(r"(SELECT\s+.+?)\s+FROM", query, re.IGNORECASE | re.DOTALL) + if not select_match: + return None + + select_clause = select_match.group(1).strip() + + # Extract table alias and rest of query after the closing paren + rest_match = re.search(r"\)\s+(?:AS\s+)?(\w+)(.*)", query, re.IGNORECASE | re.DOTALL) + if rest_match: + table_alias = rest_match.group(1) + rest_of_query = rest_match.group(2).strip() + + if rest_of_query: + outer = f"{select_clause} FROM {table_alias} {rest_of_query}" + else: + outer = f"{select_clause} FROM {table_alias}" + + return outer, table_alias + + return None @classmethod def is_simple_select(cls, query: str) -> bool: diff --git a/pymongosql/superset_mongodb/executor.py b/pymongosql/superset_mongodb/executor.py index 9cecd47..920c3fb 100644 --- a/pymongosql/superset_mongodb/executor.py +++ b/pymongosql/superset_mongodb/executor.py @@ -105,10 +105,15 @@ def execute( try: # Create temporary table with MongoDB results querydb_query, table_name = SubqueryDetector.extract_outer_query(context.query) + if querydb_query is None or table_name is None: + # Fallback to original query if extraction fails + querydb_query = context.query + table_name = "virtual_table" + query_db.insert_records(table_name, mongo_dicts) # Execute outer query against intermediate DB - _logger.debug(f"Stage 2: Executing {db_name} query: {querydb_query}") + _logger.debug(f"Stage 2: Executing QueryDBSQLite query: {querydb_query}") querydb_rows = query_db.execute_query(querydb_query) _logger.debug(f"Stage 2 complete: Got {len(querydb_rows)} rows from {db_name}") @@ -116,23 +121,41 @@ def execute( # Create a ResultSet-like object from intermediate DB results result_set = self._create_result_set_from_db(querydb_rows, querydb_query) - self._execution_plan = ExecutionPlan(collection="query_db_result", projection_stage={}) + # Build projection_stage from query database result columns + projection_stage = {} + if querydb_rows and isinstance(querydb_rows[0], dict): + # Extract column names from first result row + for col_name in querydb_rows[0].keys(): + projection_stage[col_name] = 1 # 1 means included in projection + else: + # If no rows, get column names from the SQLite query directly + try: + cursor = query_db.execute_query_cursor(querydb_query) + if cursor.description: + # Extract column names from cursor description + for col_desc in cursor.description: + col_name = col_desc[0] + projection_stage[col_name] = 1 + except Exception as e: + _logger.warning(f"Could not extract column names from empty result: {e}") + + self._execution_plan = ExecutionPlan(collection="query_db_result", projection_stage=projection_stage) return result_set finally: query_db.close() - def _create_result_set_from_db(self, rows: List[Dict[str, Any]], query: str) -> ResultSet: + def _create_result_set_from_db(self, rows: List[Dict[str, Any]], query: str) -> Dict[str, Any]: """ - Create a ResultSet from query database results. + Create a command result from query database results. Args: rows: List of dictionaries from query database query: Original SQL query Returns: - ResultSet with query database results + Dictionary with command result format """ # Create a mock command result structure compatible with ResultSet command_result = { diff --git a/tests/test_superset_connection.py b/tests/test_superset_connection.py index fecd009..3e89a05 100644 --- a/tests/test_superset_connection.py +++ b/tests/test_superset_connection.py @@ -154,24 +154,39 @@ def test_core_connection_with_standard_queries(self, conn): assert "age" in col_names def test_subquery_simple_wrapping(self, superset_conn): - """Test simple subquery wrapping on users""" + """Test simple subquery wrapping on users (Superset-style SQL)""" assert superset_conn.mode == "superset" cursor = superset_conn.cursor() - # Simple subquery: wrap a MongoDB query result - subquery_sql = "SELECT * FROM (SELECT _id, name, age FROM users) AS u LIMIT 5" + # Superset-style query with column aliases + subquery_sql = """ + SELECT _id AS _id, name AS name, age AS age + FROM (SELECT _id, name, age FROM users) AS virtual_table + LIMIT 5 + """ cursor.execute(subquery_sql) rows = cursor.fetchall() assert len(rows) == 5 + # Verify column names + description = cursor.description + col_names = [desc[0] for desc in description] if description else [] + assert "_id" in col_names + assert "name" in col_names + assert "age" in col_names + def test_subquery_with_where_condition(self, superset_conn): - """Test subquery with WHERE on wrapper""" + """Test subquery with WHERE on wrapper (Superset-style SQL)""" cursor = superset_conn.cursor() - # Subquery: select from users, then filter in wrapper - subquery_sql = "SELECT * FROM (SELECT _id, name, age FROM users) AS u WHERE age > 30" + # Superset-style query with column aliases and WHERE clause + subquery_sql = """ + SELECT _id AS _id, name AS name, age AS age + FROM (SELECT _id, name, age FROM users) AS virtual_table + WHERE age > 30 + """ cursor.execute(subquery_sql) rows = cursor.fetchall() @@ -179,13 +194,16 @@ def test_subquery_with_where_condition(self, superset_conn): assert len(rows) == 11 def test_subquery_products_by_price_range(self, superset_conn): - """Test subquery filtering products by price range""" + """Test subquery filtering products by price range (Superset-style SQL)""" cursor = superset_conn.cursor() - # Subquery: get products, filter by price range in wrapper + # Superset-style query with column aliases and GROUP BY subquery_sql = """ - SELECT * FROM (SELECT _id, name, price, category FROM products WHERE price > 100) - AS p WHERE price < 2000 LIMIT 10 + SELECT _id AS _id, name AS name, price AS price, category AS category + FROM (SELECT _id, name, price, category FROM products WHERE price > 100) AS virtual_table + WHERE price < 2000 + GROUP BY _id, name, price, category + LIMIT 10 """ cursor.execute(subquery_sql) @@ -193,13 +211,15 @@ def test_subquery_products_by_price_range(self, superset_conn): assert len(rows) == 10 def test_subquery_orders_aggregation(self, superset_conn): - """Test subquery on orders with multiple conditions""" + """Test subquery on orders with multiple conditions (Superset-style SQL)""" cursor = superset_conn.cursor() - # Subquery: get orders, then filter for high-value completed orders + # Superset-style query with column aliases and GROUP BY aggregation subquery_sql = """ - SELECT * FROM (SELECT _id, user_id, total_amount, status FROM orders) - AS o WHERE status = 'completed' LIMIT 18 + SELECT order_date AS order_date, status AS status, total_amount AS total_amount, currency AS currency + FROM (SELECT order_date, status, total_amount, currency FROM orders) AS virtual_table + GROUP BY order_date, status, total_amount, currency + LIMIT 18 """ cursor.execute(subquery_sql) @@ -224,3 +244,292 @@ def test_multiple_queries_in_session(self, superset_conn): cursor.execute("SELECT _id, name, price FROM products LIMIT 3") products = cursor.fetchall() assert len(products) == 3 + + def test_description_matches_data_length(self, superset_conn): + """Test that cursor.description column count matches actual data tuple length""" + cursor = superset_conn.cursor() + + # Superset-style query with column aliases + subquery_sql = """ + SELECT _id AS _id, name AS name, age AS age + FROM (SELECT _id, name, age FROM users) AS virtual_table + LIMIT 5 + """ + cursor.execute(subquery_sql) + rows = cursor.fetchall() + description = cursor.description + + # Verify description exists + assert description is not None + assert len(description) > 0 + + # Verify each row tuple has same length as description + for row in rows: + assert len(row) == len( + description + ), f"Row tuple length {len(row)} doesn't match description length {len(description)}" + + def test_description_column_names_match_data(self, superset_conn): + """Test that description column names match the actual data fields""" + cursor = superset_conn.cursor() + + # Superset-style query with column aliases and GROUP BY + subquery_sql = """ + SELECT _id AS _id, name AS name, age AS age + FROM (SELECT _id, name, age FROM users) AS virtual_table + GROUP BY _id, name, age + LIMIT 3 + """ + cursor.execute(subquery_sql) + _ = cursor.fetchall() + description = cursor.description + + # Extract column names from description + col_names = [desc[0] for desc in description] + + # Verify expected columns are present + assert "_id" in col_names + assert "name" in col_names + assert "age" in col_names + + # Verify description has correct structure (7-tuple per DB API 2.0) + for desc in description: + assert len(desc) == 7 # name, type_code, display_size, internal_size, precision, scale, null_ok + + def test_data_values_integrity_through_stages(self, superset_conn): + """Test that data values are preserved correctly through MongoDB->SQLite->ResultSet stages""" + cursor = superset_conn.cursor() + + # First, get a known user name from the database + cursor.execute("SELECT _id AS _id, name AS name FROM (SELECT _id, name FROM users) AS virtual_table LIMIT 1") + sample_rows = cursor.fetchall() + assert len(sample_rows) >= 1 + known_name = sample_rows[0][1] + + # Now query for that specific user with Superset-style SQL + cursor.execute( + f"SELECT _id AS _id, name AS name FROM (SELECT _id, name FROM users) AS virtual_table " + f"WHERE name = '{known_name}' GROUP BY _id, name LIMIT 1" + ) + rows = cursor.fetchall() + + # Verify data was retrieved + assert len(rows) >= 1 + + # Verify data structure + row = rows[0] + assert len(row) == 2 # _id and name + assert row[1] == known_name # Name should match the known value + + def test_numeric_data_preserved_through_stages(self, superset_conn): + """Test that numeric data is correctly preserved through the two-stage execution""" + cursor = superset_conn.cursor() + + # Superset-style query with column aliases and numeric filtering + subquery_sql = """ + SELECT _id AS _id, name AS name, age AS age + FROM (SELECT _id, name, age FROM users) AS virtual_table + WHERE age > 25 + GROUP BY _id, name, age + LIMIT 5 + """ + cursor.execute(subquery_sql) + rows = cursor.fetchall() + description = cursor.description + + assert len(rows) > 0 + assert description is not None + col_names = [desc[0] for desc in description] + + # Find age column index + age_idx = col_names.index("age") + + # Verify numeric values are preserved and filtered correctly + for row in rows: + age_value = row[age_idx] + assert isinstance(age_value, (int, float)), f"Age should be numeric, got {type(age_value)}" + assert age_value > 25, f"Age {age_value} should be > 25" + + def test_description_consistency_across_fetches(self, superset_conn): + """Test that cursor.description remains consistent across multiple fetches""" + cursor = superset_conn.cursor() + + # Superset-style query with column aliases and GROUP BY + subquery_sql = """ + SELECT _id AS _id, name AS name, price AS price + FROM (SELECT _id, name, price FROM products) AS virtual_table + GROUP BY _id, name, price + LIMIT 10 + """ + cursor.execute(subquery_sql) + + # Get description before fetch + description_before = cursor.description + assert description_before is not None + + # Fetch all and get description again + _ = cursor.fetchall() + description_after = cursor.description + + # Descriptions should be identical + assert description_before == description_after + assert len(description_before) == 3 # _id, name, price + + def test_all_columns_in_description_match_data(self, superset_conn): + """Test that all columns in description are present in actual data""" + cursor = superset_conn.cursor() + + # Superset-style query with column aliases and GROUP BY aggregation + subquery_sql = """ + SELECT _id AS _id, order_date AS order_date, status AS status, total_amount AS total_amount + FROM (SELECT _id, order_date, status, total_amount FROM orders) AS virtual_table + GROUP BY _id, order_date, status, total_amount + LIMIT 5 + """ + cursor.execute(subquery_sql) + rows = cursor.fetchall() + description = cursor.description + + assert len(rows) > 0 + assert description is not None + + # Verify description has 4 columns + assert len(description) == 4 + + # Extract column names from description + desc_col_names = [desc[0] for desc in description] + + # Verify expected columns + expected_cols = ["_id", "order_date", "status", "total_amount"] + for expected_col in expected_cols: + assert expected_col in desc_col_names, f"Expected column {expected_col} not in description" + + # Verify every row has 4 values (matching description) + for row in rows: + assert len(row) == 4, f"Row has {len(row)} values but description has {len(description)} columns" + + def test_empty_result_with_valid_description(self, superset_conn): + """Test that description is available for result sets, even if empty after filtering""" + cursor = superset_conn.cursor() + + # Superset-style query that filters to empty results using a numeric condition + # Use a very large age value that unlikely to exist + subquery_sql = """ + SELECT _id AS _id, name AS name, age AS age + FROM (SELECT _id, name, age FROM users) AS virtual_table + WHERE age > 999 + GROUP BY _id, name, age + """ + cursor.execute(subquery_sql) + _ = cursor.fetchall() + description = cursor.description + + # Description should be available based on projection_stage + # (even if actual data is empty, the schema is known) + assert description is not None + assert len(description) == 3 + col_names = [desc[0] for desc in description] + assert "_id" in col_names + assert "name" in col_names + assert "age" in col_names + + +class TestSubqueryDetector: + """Test subquery detection and outer query extraction""" + + def test_detect_wrapped_subquery(self): + """Test detection of wrapped subquery pattern""" + from pymongosql.superset_mongodb.detector import SubqueryDetector + + query = "SELECT col1, col2 FROM (SELECT col1, col2 FROM table1) AS t1 WHERE col1 > 5" + info = SubqueryDetector.detect(query) + + assert info.has_subquery is True + assert info.is_wrapped is True + assert info.subquery_alias == "t1" + + def test_extract_outer_query_preserves_select_clause(self): + """Test that extract_outer_query preserves SELECT clause with column aliases""" + from pymongosql.superset_mongodb.detector import SubqueryDetector + + # Exact pattern from Superset + query = """ + SELECT order_date AS order_date, status AS status, total_amount AS total_amount, currency AS currency + FROM (SELECT order_date, status, total_amount, currency FROM orders) AS virtual_table + GROUP BY order_date, status, total_amount, currency + LIMIT 1000 + """ + + result = SubqueryDetector.extract_outer_query(query) + assert result is not None + + outer_query, table_alias = result + + # Verify table alias is correct + assert table_alias == "virtual_table" + + # Verify SELECT clause is preserved + assert "order_date AS order_date" in outer_query + assert "status AS status" in outer_query + assert "total_amount AS total_amount" in outer_query + assert "currency AS currency" in outer_query + + # Verify it has the table reference + assert "virtual_table" in outer_query + + # Verify GROUP BY is preserved + assert "GROUP BY" in outer_query + assert "order_date" in outer_query + assert "status" in outer_query + + # Verify LIMIT is preserved + assert "LIMIT 1000" in outer_query + + def test_extract_outer_query_with_where_clause(self): + """Test outer query extraction with WHERE clause""" + from pymongosql.superset_mongodb.detector import SubqueryDetector + + query = """ + SELECT _id AS _id, name AS name + FROM (SELECT _id, name FROM users) AS virtual_table + WHERE name = 'test' + GROUP BY _id, name + """ + + result = SubqueryDetector.extract_outer_query(query) + assert result is not None + + outer_query, table_alias = result + + # Verify WHERE clause is preserved + assert "WHERE name = 'test'" in outer_query + assert "_id AS _id" in outer_query + assert "virtual_table" in outer_query + + def test_extract_outer_query_complex_pattern(self): + """Test extraction with complex column aliases and multiple conditions""" + from pymongosql.superset_mongodb.detector import SubqueryDetector + + query = """ + SELECT col1 AS column_one, col2 AS column_two, col3 AS column_three + FROM (SELECT col1, col2, col3 FROM data_table WHERE col1 > 0) AS virtual_table + WHERE col2 IS NOT NULL + GROUP BY col1, col2, col3 + ORDER BY col1 DESC + LIMIT 100 + """ + + result = SubqueryDetector.extract_outer_query(query) + assert result is not None + + outer_query, table_alias = result + + # Verify all important elements are preserved + assert "col1 AS column_one" in outer_query + assert "col2 AS column_two" in outer_query + assert "col3 AS column_three" in outer_query + assert "WHERE col2 IS NOT NULL" in outer_query + assert "GROUP BY col1, col2, col3" in outer_query + assert "ORDER BY col1 DESC" in outer_query + assert "LIMIT 100" in outer_query + assert "virtual_table" in outer_query