Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pymongosql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
if TYPE_CHECKING:
from .connection import Connection

__version__: str = "0.4.2"
__version__: str = "0.4.3"

# Globals https://www.python.org/dev/peps/pep-0249/#globals
apilevel: str = "2.0"
Expand Down
20 changes: 18 additions & 2 deletions pymongosql/sql/query_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,18 @@ def can_handle(self, ctx: Any) -> bool:
"""Check if this is a from context"""
return hasattr(ctx, "tableReference")

@staticmethod
def _strip_collection_quotes(name: str) -> str:
"""Strip surrounding double quotes from collection name if present.

Args:
name: Collection name, potentially quoted

Returns:
Collection name with quotes removed
"""
return re.sub(r'^"([^"]+)"$', r"\1", name)

def _parse_function_call(self, ctx: Any) -> Optional[Dict[str, Any]]:
"""
Detect and parse aggregate() function calls in FROM clause.
Expand Down Expand Up @@ -196,13 +208,17 @@ def _parse_function_call(self, ctx: Any) -> Optional[Dict[str, Any]]:

# Pattern: [qualifier.]functionName(arg1, arg2)
# We need to match: (optional_collection.)aggregate('...', '...')
pattern = r"^(?:(\w+)\.)?aggregate\s*\(\s*'([^']*)'\s*,\s*'([^']*)'\s*\)$"
# Support collection names with double quotes for special characters like hyphens
pattern = r"^(?:(\"[^\"]+\"|\w+)\.)?aggregate\s*\(\s*'([^']*)'\s*,\s*'([^']*)'\s*\)$"
match = re.match(pattern, text, re.IGNORECASE | re.DOTALL)

if not match:
return None

collection = match.group(1) # Can be None for unqualified aggregate()
# Strip quotes from collection name if present
if collection:
collection = self._strip_collection_quotes(collection)
pipeline = match.group(2)
options = match.group(3)

Expand Down Expand Up @@ -245,7 +261,7 @@ def handle_visitor(self, ctx: PartiQLParser.FromClauseContext, parse_result: "Qu
# Regular collection reference
table_text = ctx.tableReference().getText()
# Strip surrounding quotes from collection name (e.g., "user.accounts" -> user.accounts)
collection_name = re.sub(r'^"([^"]+)"$', r"\1", table_text)
collection_name = self._strip_collection_quotes(table_text)
parse_result.collection = collection_name
_logger.debug(f"Parsed regular collection: {collection_name}")
return collection_name
Expand Down
27 changes: 27 additions & 0 deletions tests/test_cursor_aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,3 +327,30 @@ def test_aggregate_multiple_stages(self, conn):
total_users_idx = col_names.index("total_users")
assert row[avg_age_idx] is not None and isinstance(row[avg_age_idx], (int, float))
assert row[total_users_idx] is not None and isinstance(row[total_users_idx], (int, float))

def test_aggregate_collection_name_with_hyphen(self, conn):
"""Test aggregate function with collection name containing hyphen (user-orders)"""
pipeline = json.dumps([{"$match": {"customer_type": "premium"}}])

# Test collection name with hyphen
sql = f"""
SELECT *
FROM "user-orders".aggregate('{pipeline}', '{{}}')
"""

cursor = conn.cursor()
result = cursor.execute(sql)

assert result == cursor
assert isinstance(cursor.result_set, ResultSet)

rows = cursor.result_set.fetchall()
assert len(rows) > 0, "Should have results from user-orders collection"

# Verify all returned rows are premium customers
col_names = [desc[0] for desc in cursor.result_set.description]
assert "customer_type" in col_names, "customer_type should be in result columns"

customer_type_idx = col_names.index("customer_type")
for row in rows:
assert row[customer_type_idx] == "premium", "All rows should have customer_type='premium'"