diff --git a/apis/routers/guardrail.py b/apis/routers/guardrail.py index 5f9088c5f2..a956658575 100644 --- a/apis/routers/guardrail.py +++ b/apis/routers/guardrail.py @@ -135,7 +135,7 @@ async def guardstats( ) item = "condenast" - query, comment = strip_comment(query, item) + query, comment = strip_comment(query) # Extract functions from the query all_functions = extract_functions_from_query( diff --git a/apis/routers/statistics.py b/apis/routers/statistics.py index efcc2a3053..92129b9e6f 100644 --- a/apis/routers/statistics.py +++ b/apis/routers/statistics.py @@ -68,7 +68,7 @@ async def stats_api( ) item = "condenast" - query, comment = strip_comment(query, item) + query, comment = strip_comment(query) # Extract functions from the query all_functions = extract_functions_from_query( diff --git a/apis/utils/helpers.py b/apis/utils/helpers.py index dec20474c6..fcbc511c73 100644 --- a/apis/utils/helpers.py +++ b/apis/utils/helpers.py @@ -16,10 +16,12 @@ logger = logging.getLogger(__name__) -def transpile_query(query: str, from_sql: str, to_sql: str) -> str: +def transpile_query(query: str, from_sql: str, to_sql: Optional[str] = "E6") -> str: """ Transpile a SQL query from one dialect to another. """ + if to_sql is None: + to_sql = "E6" try: # original_ast = parse_one(query, read=from_sql) # values_ensured_ast = ensure_select_from_values(original_ast) diff --git a/converter_api.py b/converter_api.py index 699cb5ad98..be253b3bcc 100644 --- a/converter_api.py +++ b/converter_api.py @@ -34,6 +34,7 @@ transform_catalog_schema_only, set_cte_names_case_sensitively, ) +from formatting_utils import preserve_formatting if t.TYPE_CHECKING: from sqlglot._typing import E @@ -147,6 +148,12 @@ async def convert_query( double_quotes_added_query = replace_struct_in_query(double_quotes_added_query) + # Preserve original formatting if enabled via feature flag + if flags_dict.get("PRESERVE_FORMATTING", False): + double_quotes_added_query = preserve_formatting( + query, double_quotes_added_query, from_sql, to_sql + ) + # double_quotes_added_query = add_comment_to_query(double_quotes_added_query, comment) logger.info( @@ -381,6 +388,12 @@ async def stats_api( double_quotes_added_query = replace_struct_in_query(double_quotes_added_query) + # Preserve original formatting if enabled via feature flag + if flags_dict.get("PRESERVE_FORMATTING", False): + double_quotes_added_query = preserve_formatting( + query, double_quotes_added_query, from_sql, to_sql + ) + double_quotes_added_query = add_comment_to_query(double_quotes_added_query, comment) logger.info("Got the converted query!!!!") @@ -562,6 +575,9 @@ async def guardstats( double_quotes_added_query = replace_struct_in_query(double_quotes_added_query) + # Note: PRESERVE_FORMATTING not available here as no flags_dict + # Can be added if needed by adding feature_flags parameter to this endpoint + double_quotes_added_query = add_comment_to_query(double_quotes_added_query, comment) all_functions_converted_query = extract_functions_from_query( diff --git a/formatting_utils.py b/formatting_utils.py new file mode 100644 index 0000000000..f9238c2463 --- /dev/null +++ b/formatting_utils.py @@ -0,0 +1,242 @@ +""" +Formatting Utilities for SQL Transpilation + +This module provides a simple wrapper to preserve original query formatting +after transpilation. It can be integrated into existing code with minimal changes. + +Usage: + from formatting_utils import preserve_formatting + + # Your existing transpilation code + transpiled_query = tree.sql(dialect=to_sql, from_dialect=from_sql) + + # Add this one line to preserve formatting + formatted_query = preserve_formatting(original_query, transpiled_query, from_sql, to_sql) +""" + +from typing import Dict, List, Optional +from sqlglot import tokenize +from sqlglot.tokens import Token + + +def preserve_formatting( + original_sql: str, + transpiled_sql: str, + source_dialect: Optional[str] = None, + target_dialect: Optional[str] = None, +) -> str: + """ + Preserve the original SQL formatting in the transpiled output. + + This function takes the original SQL (with its formatting) and the transpiled SQL, + then reconstructs the transpiled SQL using the whitespace from the original. + + Args: + original_sql: The original SQL query with formatting to preserve + transpiled_sql: The transpiled SQL query (typically loses formatting) + source_dialect: Source dialect (optional, for future use) + target_dialect: Target dialect (optional, for future use) + + Returns: + The transpiled SQL with original formatting preserved + + Example: + >>> original = '''SELECT + ... col1, + ... col2 + ... FROM table1''' + >>> transpiled = "SELECT col1, col2 FROM table1" + >>> preserve_formatting(original, transpiled) + 'SELECT\\n col1,\\n col2\\nFROM table1' + """ + if not original_sql or not transpiled_sql: + return transpiled_sql + + # Get original tokens with whitespace + try: + original_tokens = list(tokenize(original_sql)) + transpiled_tokens = list(tokenize(transpiled_sql)) + except Exception: + # If tokenization fails, return transpiled as-is + return transpiled_sql + + # Check if whitespace_before is available (not available in Rust tokenizer) + if original_tokens and not hasattr(original_tokens[0], "whitespace_before"): + # Rust tokenizer doesn't support whitespace_before, return transpiled as-is + return transpiled_sql + + # Build alignment between transpiled and original tokens + alignment = _align_tokens(original_tokens, transpiled_tokens) + + # Reconstruct with original whitespace + result_parts = [] + for i, token in enumerate(transpiled_tokens): + orig_idx = alignment.get(i) + + if orig_idx is not None: + # Use original whitespace + ws = original_tokens[orig_idx].whitespace_before + # For first token, strip leading whitespace to avoid extra indentation + if i == 0: + ws = ws.lstrip() + result_parts.append(ws) + else: + # New token - use minimal spacing + if i > 0: + # Don't add space before commas or closing parens + if token.token_type.name in ("COMMA", "R_PAREN", "R_BRACKET"): + pass # No space + # Don't add space after opening parens + elif result_parts and result_parts[-1].endswith("("): + pass # No space + else: + result_parts.append(" ") + + # Handle token text (preserve quotes for strings) + if token.token_type.name == "STRING": + result_parts.append(f"'{token.text}'") + else: + result_parts.append(token.text) + + return "".join(result_parts) + + +def _align_tokens( + original_tokens: List[Token], + transpiled_tokens: List[Token], +) -> Dict[int, int]: + """ + Align transpiled tokens to original tokens using position-aware matching. + Returns dict mapping transpiled index -> original index. + """ + alignment: Dict[int, int] = {} + used_original: set = set() + + n_orig = len(original_tokens) + n_trans = len(transpiled_tokens) + + # For each transpiled token, find all matching original tokens + # Then pick the one closest in relative position + for trans_idx, trans_token in enumerate(transpiled_tokens): + trans_ratio = trans_idx / max(n_trans, 1) + + best_orig = None + best_score = float("inf") + + for orig_idx, orig_token in enumerate(original_tokens): + if orig_idx in used_original: + continue + + # Check if tokens match (exact type and text) + if ( + orig_token.token_type == trans_token.token_type + and orig_token.text.upper() == trans_token.text.upper() + ): + # Score based on relative position difference + orig_ratio = orig_idx / max(n_orig, 1) + score = abs(trans_ratio - orig_ratio) + + if score < best_score: + best_score = score + best_orig = orig_idx + + if best_orig is not None: + alignment[trans_idx] = best_orig + used_original.add(best_orig) + + # Second pass: handle function renames (struct -> NAMED_STRUCT, etc.) + for trans_idx, trans_token in enumerate(transpiled_tokens): + if trans_idx in alignment: + continue + + # Only handle function calls (VAR followed by L_PAREN) + if trans_token.token_type.name not in ("VAR", "STRUCT"): + continue + if trans_idx + 1 >= n_trans: + continue + if transpiled_tokens[trans_idx + 1].token_type.name != "L_PAREN": + continue + + trans_ratio = trans_idx / max(n_trans, 1) + best_orig = None + best_score = float("inf") + + for orig_idx, orig_token in enumerate(original_tokens): + if orig_idx in used_original: + continue + if orig_token.token_type.name not in ("VAR", "STRUCT"): + continue + if orig_idx + 1 >= n_orig: + continue + if original_tokens[orig_idx + 1].token_type.name != "L_PAREN": + continue + + orig_ratio = orig_idx / max(n_orig, 1) + score = abs(trans_ratio - orig_ratio) + + if score < best_score: + best_score = score + best_orig = orig_idx + + if best_orig is not None and best_score < 0.2: + alignment[trans_idx] = best_orig + used_original.add(best_orig) + + # Third pass: structural tokens (parens, commas) by closest position + structural_types = {"L_PAREN", "R_PAREN", "COMMA"} + + for trans_idx, trans_token in enumerate(transpiled_tokens): + if trans_idx in alignment: + continue + if trans_token.token_type.name not in structural_types: + continue + + trans_ratio = trans_idx / max(n_trans, 1) + best_orig = None + best_score = float("inf") + + for orig_idx, orig_token in enumerate(original_tokens): + if orig_idx in used_original: + continue + if orig_token.token_type != trans_token.token_type: + continue + + orig_ratio = orig_idx / max(n_orig, 1) + score = abs(trans_ratio - orig_ratio) + + if score < best_score: + best_score = score + best_orig = orig_idx + + if best_orig is not None and best_score < 0.15: + alignment[trans_idx] = best_orig + used_original.add(best_orig) + + return alignment + + +# Convenience function with feature flag support +def transpile_with_formatting( + original_sql: str, + transpiled_sql: str, + preserve_format: bool = True, + source_dialect: Optional[str] = None, + target_dialect: Optional[str] = None, +) -> str: + """ + Wrapper that conditionally preserves formatting based on a flag. + + Args: + original_sql: The original SQL query + transpiled_sql: The transpiled SQL query + preserve_format: If True, preserve original formatting. If False, return transpiled as-is. + source_dialect: Source dialect (optional) + target_dialect: Target dialect (optional) + + Returns: + Formatted or unformatted transpiled SQL based on preserve_format flag + """ + if not preserve_format: + return transpiled_sql + + return preserve_formatting(original_sql, transpiled_sql, source_dialect, target_dialect) diff --git a/setup.cfg b/setup.cfg index 9b492ee3b7..4f6e3daf2c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,3 +13,6 @@ ignore_errors = True [mypy-tests.dataframe.*] ignore_errors = False + +[mypy-guardrail.*] +ignore_errors = True diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index de986f3b14..f7d2a9b11f 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -65,7 +65,7 @@ def __new__(cls, clsname, bases, attrs): SQLGLOT_ANONYMOUS = "sqlglot.anonymous" TABLE_PARTS = ("this", "db", "catalog") COLUMN_PARTS = ("this", "table", "db", "catalog") -POSITION_META_KEYS = ("line", "col", "start", "end") +POSITION_META_KEYS = ("line", "col", "start", "end", "whitespace_before") class Expression(metaclass=_Expression): @@ -878,6 +878,7 @@ def update_positions( "col": other.col, "start": other.start, "end": other.end, + "whitespace_before": getattr(other, "whitespace_before", ""), } ) self.meta.update({k: v for k, v in kwargs.items() if k in POSITION_META_KEYS}) diff --git a/sqlglot/tokens.py b/sqlglot/tokens.py index 5b7352bd5a..c2fe87d6e2 100644 --- a/sqlglot/tokens.py +++ b/sqlglot/tokens.py @@ -440,7 +440,16 @@ class TokenType(AutoName): class Token: - __slots__ = ("token_type", "text", "line", "col", "start", "end", "comments") + __slots__ = ( + "token_type", + "text", + "line", + "col", + "start", + "end", + "comments", + "whitespace_before", + ) @classmethod def number(cls, number: int) -> Token: @@ -471,6 +480,7 @@ def __init__( start: int = 0, end: int = 0, comments: t.Optional[t.List[str]] = None, + whitespace_before: str = "", ) -> None: """Token initializer. @@ -482,6 +492,7 @@ def __init__( start: The start index of the token. end: The ending index of the token. comments: The comments to attach to the token. + whitespace_before: The whitespace preceding this token in the original SQL. """ self.token_type = token_type self.text = text @@ -490,6 +501,7 @@ def __init__( self.start = start self.end = end self.comments = [] if comments is None else comments + self.whitespace_before = whitespace_before def __repr__(self) -> str: attributes = ", ".join(f"{k}: {getattr(self, k)}" for k in self.__slots__) @@ -1022,6 +1034,7 @@ class Tokenizer(metaclass=_Tokenizer): "_end", "_peek", "_prev_token_line", + "_prev_token_end", "_rs_dialect_settings", ) @@ -1063,6 +1076,7 @@ def reset(self) -> None: self._end = False self._peek = "" self._prev_token_line = -1 + self._prev_token_end = -1 # Track end of previous token for whitespace calculation def tokenize(self, sql: str) -> t.List[Token]: """Returns a list of tokens corresponding to the SQL string `sql`.""" @@ -1168,6 +1182,14 @@ def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None: self.tokens[-1].comments.extend(self._comments) self._comments = [] + # Calculate whitespace before this token + if self._prev_token_end == -1: + # First token - whitespace from start of SQL + whitespace_before = self.sql[0 : self._start] + else: + # Whitespace between previous token end and this token start + whitespace_before = self.sql[self._prev_token_end + 1 : self._start] + self.tokens.append( Token( token_type, @@ -1177,9 +1199,11 @@ def _add(self, token_type: TokenType, text: t.Optional[str] = None) -> None: start=self._start, end=self._current - 1, comments=self._comments, + whitespace_before=whitespace_before, ) ) self._comments = [] + self._prev_token_end = self._current - 1 # Update for next token # If we have either a semicolon or a begin token before the command's token, we'll parse # whatever follows the command's token as a string diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index 86d0946d14..90984e75b3 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -2570,38 +2570,58 @@ def test_with_offset(self): ) def test_identifier_meta(self): + from sqlglot import tokenize + + # Check if tokenizer provides accurate whitespace_before (Python tokenizer only) + # Note: Expression meta always includes whitespace_before (with "" fallback for Rust tokenizer) + tokens = list(tokenize("SELECT 1")) + has_accurate_whitespace = hasattr(tokens[0], "whitespace_before") if tokens else False + ast = parse_one( "SELECT a, b FROM test_schema.test_table_a UNION ALL SELECT c, d FROM test_catalog.test_schema.test_table_b", dialect="bigquery", ) + + # Meta always includes whitespace_before (via getattr fallback in expressions.py) + expected_meta_keys = {"line", "col", "start", "end", "whitespace_before"} + for identifier in ast.find_all(exp.Identifier): - self.assertEqual(set(identifier.meta), {"line", "col", "start", "end"}) + self.assertEqual(set(identifier.meta), expected_meta_keys) + + # With Rust tokenizer, whitespace_before is always "" (fallback) + # With Python tokenizer, whitespace_before is accurate + ws_empty = "" + ws_space = " " if has_accurate_whitespace else "" self.assertEqual( ast.this.args["from"].this.args["this"].meta, - {"line": 1, "col": 41, "start": 29, "end": 40}, + {"line": 1, "col": 41, "start": 29, "end": 40, "whitespace_before": ws_empty}, ) self.assertEqual( ast.this.args["from"].this.args["db"].meta, - {"line": 1, "col": 28, "start": 17, "end": 27}, + {"line": 1, "col": 28, "start": 17, "end": 27, "whitespace_before": ws_space}, ) self.assertEqual( ast.expression.args["from"].this.args["this"].meta, - {"line": 1, "col": 106, "start": 94, "end": 105}, + {"line": 1, "col": 106, "start": 94, "end": 105, "whitespace_before": ws_empty}, ) self.assertEqual( ast.expression.args["from"].this.args["db"].meta, - {"line": 1, "col": 93, "start": 82, "end": 92}, + {"line": 1, "col": 93, "start": 82, "end": 92, "whitespace_before": ws_empty}, ) self.assertEqual( ast.expression.args["from"].this.args["catalog"].meta, - {"line": 1, "col": 81, "start": 69, "end": 80}, + {"line": 1, "col": 81, "start": 69, "end": 80, "whitespace_before": ws_space}, ) information_schema_sql = "SELECT a, b FROM region.INFORMATION_SCHEMA.COLUMNS" ast = parse_one(information_schema_sql, dialect="bigquery") meta = ast.args["from"].this.this.meta - self.assertEqual(meta, {"line": 1, "col": 50, "start": 24, "end": 49}) + # Note: whitespace_before is always in meta, but value depends on tokenizer + self.assertEqual(meta["line"], 1) + self.assertEqual(meta["col"], 50) + self.assertEqual(meta["start"], 24) + self.assertEqual(meta["end"], 49) assert ( information_schema_sql[meta["start"] : meta["end"] + 1] == "INFORMATION_SCHEMA.COLUMNS" ) diff --git a/tests/dialects/test_e6.py b/tests/dialects/test_e6.py index ba561d80fa..5f4eaffe11 100644 --- a/tests/dialects/test_e6.py +++ b/tests/dialects/test_e6.py @@ -2994,3 +2994,102 @@ def test_cast_precision_preserved(self): "databricks": "SELECT CAST(a AS DECIMAL(10, 2)), CAST(b AS DECIMAL(5, 0)), CAST(c AS INTEGER)", }, ) + + def test_formatting_preservation(self): + """Test that formatting preservation works correctly during transpilation.""" + from formatting_utils import preserve_formatting + from sqlglot import tokenize + + # Check if whitespace_before is available (not available in Rust tokenizer) + tokens = list(tokenize("SELECT 1")) + has_whitespace_support = hasattr(tokens[0], "whitespace_before") if tokens else False + + # Test 1: Basic formatting with newlines and indentation + original = """SELECT + col1, + col2, + col3 +FROM table1 +WHERE status = 'active'""" + transpiled = "SELECT col1, col2, col3 FROM table1 WHERE status = 'active'" + result = preserve_formatting(original, transpiled, "databricks", "e6") + + # Basic assertions that work with both tokenizers + self.assertIn("col1", result) + self.assertIn("col2", result) + + # Formatting assertions only when whitespace_before is available + if has_whitespace_support: + self.assertIn("\n", result) + + # Test 2: Function rename (IFF -> IF) preserves formatting + original_iff = """SELECT + IFF(col1 > 10, 'high', 'low') AS category, + col2 +FROM table1""" + transpiled_if = "SELECT IF(col1 > 10, 'high', 'low') AS category, col2 FROM table1" + result_if = preserve_formatting(original_iff, transpiled_if, "snowflake", "e6") + + # Verify the function was renamed (works with both tokenizers) + self.assertIn("IF(", result_if) + self.assertNotIn("IFF(", result_if) + if has_whitespace_support: + self.assertIn("\n", result_if) + + # Test 3: Complex query with CTEs + original_cte = """WITH raw AS ( + SELECT + id, + value + FROM source_table +) +SELECT + id, + value +FROM raw""" + transpiled_cte = ( + "WITH raw AS (SELECT id, value FROM source_table) SELECT id, value FROM raw" + ) + result_cte = preserve_formatting(original_cte, transpiled_cte, "databricks", "e6") + + # Verify CTE structure (works with both tokenizers) + self.assertIn("WITH", result_cte) + self.assertIn("raw", result_cte) + if has_whitespace_support: + self.assertIn("\n", result_cte) + + # Test 4: Preserved string quotes + original_strings = """SELECT + 'hello' AS greeting, + 'world' AS target +FROM dual""" + transpiled_strings = "SELECT 'hello' AS greeting, 'world' AS target FROM dual" + result_strings = preserve_formatting( + original_strings, transpiled_strings, "databricks", "e6" + ) + + # Verify strings are preserved with quotes + self.assertIn("'hello'", result_strings) + self.assertIn("'world'", result_strings) + + # Test 5: Empty/None inputs handled gracefully + self.assertEqual(preserve_formatting("", "SELECT 1"), "SELECT 1") + self.assertEqual(preserve_formatting("SELECT 1", ""), "") + self.assertEqual(preserve_formatting(None, "SELECT 1"), "SELECT 1") + + # Test 6: Tab indentation preserved (only with whitespace support) + original_tabs = "SELECT\n\tcol1,\n\tcol2\nFROM table1" + transpiled_tabs = "SELECT col1, col2 FROM table1" + result_tabs = preserve_formatting(original_tabs, transpiled_tabs, "databricks", "e6") + + if has_whitespace_support: + self.assertIn("\t", result_tabs) + + # Test 7: Multiple spaces between tokens preserved + original_spaces = "SELECT col1, col2 FROM table1" + transpiled_spaces = "SELECT col1, col2 FROM table1" + result_spaces = preserve_formatting(original_spaces, transpiled_spaces, "databricks", "e6") + + # Result should have the columns (works with both tokenizers) + self.assertIn("col1", result_spaces) + self.assertIn("col2", result_spaces) diff --git a/tests/test_parser.py b/tests/test_parser.py index 38a99d9313..a556cfe436 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -961,41 +961,66 @@ def test_udf_meta(self): self.assertIsInstance(ast, exp.Year) def test_token_position_meta(self): + from sqlglot import tokenize + + # Check if tokenizer provides accurate whitespace_before (Python tokenizer only) + # Note: Expression meta always includes whitespace_before (with "" fallback for Rust tokenizer) + tokens = list(tokenize("SELECT 1")) + has_accurate_whitespace = hasattr(tokens[0], "whitespace_before") if tokens else False + ast = parse_one( "SELECT a, b FROM test_schema.test_table_a UNION ALL SELECT c, d FROM test_catalog.test_schema.test_table_b" ) + + # Meta always includes whitespace_before (via getattr fallback in expressions.py) + expected_meta_keys = {"line", "col", "start", "end", "whitespace_before"} + for identifier in ast.find_all(exp.Identifier): - self.assertEqual(set(identifier.meta), {"line", "col", "start", "end"}) + self.assertEqual(set(identifier.meta), expected_meta_keys) + + # With Rust tokenizer, whitespace_before is always "" (fallback) + # With Python tokenizer, whitespace_before is accurate + ws_empty = "" + ws_space = " " if has_accurate_whitespace else "" self.assertEqual( ast.this.args["from"].this.args["this"].meta, - {"line": 1, "col": 41, "start": 29, "end": 40}, + {"line": 1, "col": 41, "start": 29, "end": 40, "whitespace_before": ws_empty}, ) self.assertEqual( ast.this.args["from"].this.args["db"].meta, - {"line": 1, "col": 28, "start": 17, "end": 27}, + {"line": 1, "col": 28, "start": 17, "end": 27, "whitespace_before": ws_space}, ) self.assertEqual( ast.expression.args["from"].this.args["this"].meta, - {"line": 1, "col": 106, "start": 94, "end": 105}, + {"line": 1, "col": 106, "start": 94, "end": 105, "whitespace_before": ws_empty}, ) self.assertEqual( ast.expression.args["from"].this.args["db"].meta, - {"line": 1, "col": 93, "start": 82, "end": 92}, + {"line": 1, "col": 93, "start": 82, "end": 92, "whitespace_before": ws_empty}, ) self.assertEqual( ast.expression.args["from"].this.args["catalog"].meta, - {"line": 1, "col": 81, "start": 69, "end": 80}, + {"line": 1, "col": 81, "start": 69, "end": 80, "whitespace_before": ws_space}, ) ast = parse_one("SELECT FOO()") - self.assertEqual(ast.find(exp.Anonymous).meta, {"line": 1, "col": 10, "start": 7, "end": 9}) + self.assertEqual( + ast.find(exp.Anonymous).meta, + {"line": 1, "col": 10, "start": 7, "end": 9, "whitespace_before": ws_space}, + ) ast = parse_one("SELECT * FROM t") - self.assertEqual(ast.find(exp.Star).meta, {"line": 1, "col": 8, "start": 7, "end": 7}) + self.assertEqual( + ast.find(exp.Star).meta, + {"line": 1, "col": 8, "start": 7, "end": 7, "whitespace_before": ws_space}, + ) ast = parse_one("SELECT t.* FROM t") - self.assertEqual(ast.find(exp.Star).meta, {"line": 1, "col": 10, "start": 9, "end": 9}) + self.assertEqual( + ast.find(exp.Star).meta, + {"line": 1, "col": 10, "start": 9, "end": 9, "whitespace_before": ws_empty}, + ) def test_quoted_identifier_meta(self): sql = 'SELECT "a" FROM "test_schema"."test_table_a"' diff --git a/tests/test_tokens.py b/tests/test_tokens.py index 8b78fdf4d4..3293daddc1 100644 --- a/tests/test_tokens.py +++ b/tests/test_tokens.py @@ -204,7 +204,19 @@ def test_partial_token_list(self): def test_token_repr(self): # Ensures both the Python and the Rust tokenizer produce a human-friendly representation - self.assertEqual( - repr(Tokenizer().tokenize("foo")), - "[]", - ) + tokens = Tokenizer().tokenize("foo") + token_repr = repr(tokens) + + # Check if whitespace_before is available (Python tokenizer only) + has_whitespace_support = hasattr(tokens[0], "whitespace_before") if tokens else False + + if has_whitespace_support: + self.assertEqual( + token_repr, + "[]", + ) + else: + self.assertEqual( + token_repr, + "[]", + )