From a9ada8be34086f5850a975f00da84b80e2cb6b20 Mon Sep 17 00:00:00 2001 From: Juan Posadas Date: Wed, 29 Apr 2026 11:10:22 -0600 Subject: [PATCH 01/13] Add AST node dataclasses for hrw4u linter Frozen dataclasses representing the semantic AST that a visitor produces from the ANTLR parse tree. Includes Target decomposition (namespace/field), statement nodes (Assignment, FunctionCall, Break; standalone operators are represented as a FunctionCall with no arguments), condition expression nodes (Comparison, LogicalOp, NotOp, BoolLiteral, IdentCondition), control flow (IfBlock, ElifBranch), and top-level constructs (UseDirective, VarSection/VarDecl, ProcedureDecl/ProcParam, Section) collected under the HRW4UAST root. Type aliases ConditionExpr, BodyNode, and TopLevelNode provide convenience unions. Tests cover Target.from_dotted parsing. --- tools/hrw4u/src/ast_nodes.py | 152 ++++++++++++++++++++++++++++ tools/hrw4u/tests/test_ast_nodes.py | 40 ++++++++ 2 files changed, 192 insertions(+) create mode 100644 tools/hrw4u/src/ast_nodes.py create mode 100644 tools/hrw4u/tests/test_ast_nodes.py diff --git a/tools/hrw4u/src/ast_nodes.py b/tools/hrw4u/src/ast_nodes.py new file mode 100644 index 00000000000..aee65c02a71 --- /dev/null +++ b/tools/hrw4u/src/ast_nodes.py @@ -0,0 +1,152 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Union + + +@dataclass(frozen=True, kw_only=True) +class Node: + line: int + + +@dataclass(frozen=True) +class Target: + namespace: str | None + field: str + + @staticmethod + def from_dotted(name: str) -> Target: + # TODO: the grammar lexes dotted paths as a single IDENT token; + # ideally the grammar would split namespace/field so this + # heuristic isn't needed. + dot = name.rfind(".") + if dot == -1: + return Target(namespace=None, field=name) + return Target(namespace=name[:dot], field=name[dot + 1:]) + + +@dataclass(frozen=True, kw_only=True) +class Assignment(Node): + target: Target + operator: str # "=" or "+=" + value: str | int | bool | tuple + + +@dataclass(frozen=True, kw_only=True) +class FunctionCall(Node): + name: str + args: tuple[str | int | bool, ...] + + +@dataclass(frozen=True, kw_only=True) +class Break(Node): + pass + + +@dataclass(frozen=True, kw_only=True) +class Comparison(Node): + left: str | FunctionCall + operator: str # "==", "!=", ">", "<", "~", "!~", "in", "!in" + right: str | int | bool | tuple + modifiers: tuple[str, ...] + + +@dataclass(frozen=True, kw_only=True) +class LogicalOp(Node): + operator: str # "&&" or "||" + left: ConditionExpr + right: ConditionExpr + + +@dataclass(frozen=True, kw_only=True) +class NotOp(Node): + operand: ConditionExpr + + +@dataclass(frozen=True, kw_only=True) +class BoolLiteral(Node): + value: bool + + +@dataclass(frozen=True, kw_only=True) +class IdentCondition(Node): + name: str + + +@dataclass(frozen=True, kw_only=True) +class ElifBranch(Node): + condition: ConditionExpr + body: tuple[BodyNode, ...] + + +@dataclass(frozen=True, kw_only=True) +class IfBlock(Node): + condition: ConditionExpr + body: tuple[BodyNode, ...] + elif_branches: tuple[ElifBranch, ...] + else_body: tuple[BodyNode, ...] 
+ + +@dataclass(frozen=True, kw_only=True) +class Section(Node): + type: str + body: tuple[BodyNode, ...] + + +@dataclass(frozen=True, kw_only=True) +class ProcParam(Node): + name: str + default: str | int | bool | None + + +@dataclass(frozen=True, kw_only=True) +class VarDecl(Node): + name: str + type_name: str + slot: int | None + + +@dataclass(frozen=True, kw_only=True) +class VarSection(Node): + scope: str + declarations: tuple[VarDecl, ...] + + +@dataclass(frozen=True, kw_only=True) +class UseDirective(Node): + spec: str + + +@dataclass(frozen=True, kw_only=True) +class ProcedureDecl(Node): + name: str + params: tuple[ProcParam, ...] + body: tuple[BodyNode, ...] + + +@dataclass(frozen=True, kw_only=True) +class HRW4UAST: + body: tuple[TopLevelNode, ...] + + +# Type aliases: must follow all class definitions (evaluated at runtime). +ConditionExpr = Union[Comparison, LogicalOp, NotOp, BoolLiteral, IdentCondition, FunctionCall] +BodyNode = Union[Assignment, FunctionCall, IfBlock, Break] +TopLevelNode = Union[UseDirective, VarSection, ProcedureDecl, Section] diff --git a/tools/hrw4u/tests/test_ast_nodes.py b/tools/hrw4u/tests/test_ast_nodes.py new file mode 100644 index 00000000000..1c2320c4e53 --- /dev/null +++ b/tools/hrw4u/tests/test_ast_nodes.py @@ -0,0 +1,40 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from hrw4u.ast_nodes import Target + + +class TestTarget: + def test_dotted_path(self): + t = Target.from_dotted("inbound.req.X-Foo") + assert t.namespace == "inbound.req" + assert t.field == "X-Foo" + + def test_two_segments(self): + t = Target.from_dotted("inbound.ip") + assert t.namespace == "inbound" + assert t.field == "ip" + + def test_no_dots(self): + t = Target.from_dotted("bool_0") + assert t.namespace is None + assert t.field == "bool_0" + + def test_deep_namespace(self): + t = Target.from_dotted("http.cntl.TXN_DEBUG") + assert t.namespace == "http.cntl" + assert t.field == "TXN_DEBUG" From 8cf91ac75c11eeb427d10e0dad10e0c056d56ecd Mon Sep 17 00:00:00 2001 From: Juan Posadas Date: Wed, 29 Apr 2026 11:32:37 -0600 Subject: [PATCH 02/13] Add AST visitor to build AST from ANTLR parse tree ASTVisitor walks the ANTLR parse tree and produces HRW4UAST. Handles named sections, assignments (= and +=), function calls, break statements, standalone operators, condition expressions (comparisons, logical operators, negation, set membership, IP ranges, WITH modifiers), if/elif/else blocks with arbitrary nesting, and top-level var/use/procedure declarations. Only visitProgram is overridden from the ANTLR visitor base class; all other dispatch is internal, keeping the public API surface minimal. Raises ValueError for unhandled statement, factor, and comparison alternatives to surface visitor-grammar drift early; a value that matches no known alternative falls back to its raw source text, and unrecognized program or body items are skipped. Makefile updated to include ast_nodes.py and ast_visitor.py in the build copy step. 
--- tools/hrw4u/Makefile | 4 +- tools/hrw4u/src/ast_visitor.py | 293 +++++++++++++++++++++++++++++++++ 2 files changed, 296 insertions(+), 1 deletion(-) create mode 100644 tools/hrw4u/src/ast_visitor.py diff --git a/tools/hrw4u/Makefile b/tools/hrw4u/Makefile index 7b929356af3..ba25de02d98 100644 --- a/tools/hrw4u/Makefile +++ b/tools/hrw4u/Makefile @@ -53,7 +53,9 @@ SRC_FILES_HRW4U=src/visitor.py \ src/suggestions.py \ src/procedures.py \ src/sandbox.py \ - src/kg_visitor.py + src/kg_visitor.py \ + src/ast_nodes.py \ + src/ast_visitor.py ALL_HRW4U_FILES=$(SHARED_FILES) $(UTILS_FILES) $(SRC_FILES_HRW4U) diff --git a/tools/hrw4u/src/ast_visitor.py b/tools/hrw4u/src/ast_visitor.py new file mode 100644 index 00000000000..8b5c63aab77 --- /dev/null +++ b/tools/hrw4u/src/ast_visitor.py @@ -0,0 +1,293 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from hrw4u.hrw4uVisitor import hrw4uVisitor +from hrw4u.ast_nodes import ( + HRW4UAST, + Section, + Assignment, + FunctionCall, + Break, + Target, + IfBlock, + ElifBranch, + BoolLiteral, + Comparison, + LogicalOp, + NotOp, + IdentCondition, + ProcParam, + VarDecl, + VarSection, + UseDirective, + ProcedureDecl, +) + + +class ASTVisitor(hrw4uVisitor): + """ANTLR visitor that walks an HRW4U parse tree and produces an AST for HRW4U.""" + + # Only visitProgram is overridden from the ANTLR visitor interface; + # all other traversal uses private _visit_* helpers so that each + # method has an explicit return type and full control over how + # child results are assembled into parent AST nodes. + + def visitProgram(self, ctx): + items = [] + for item in ctx.programItem(): + if item.useDirective() is not None: + items.append(self._visit_use_directive(item.useDirective())) + elif item.procedureDecl() is not None: + items.append(self._visit_procedure_decl(item.procedureDecl())) + elif item.section() is not None: + items.append(self._visit_section(item.section())) + return HRW4UAST(body=tuple(items)) + + def _visit_use_directive(self, ctx): + return UseDirective( + spec=ctx.QUALIFIED_IDENT().getText(), + line=ctx.start.line, + ) + + def _visit_procedure_decl(self, ctx): + name = ctx.QUALIFIED_IDENT().getText() + params = () + if ctx.paramList(): + params = tuple( + self._visit_proc_param(p) for p in ctx.paramList().param() + ) + body = tuple(self._visit_body(ctx.block().blockItem())) + return ProcedureDecl(name=name, params=params, body=body, line=ctx.start.line) + + def _visit_proc_param(self, ctx): + name = ctx.IDENT().getText() + default = self._extract_value(ctx.value()) if ctx.value() else None + return ProcParam(name=name, default=default, line=ctx.start.line) + + def _visit_section(self, ctx): + if ctx.varSection() is not None: + return self._visit_var_section(ctx.varSection(), "txn") + if ctx.sessionVarSection() is not None: + 
return self._visit_var_section(ctx.sessionVarSection(), "session") + name = ctx.name.text + body = self._visit_body(ctx.sectionBody()) + return Section(type=name, body=tuple(body), line=ctx.start.line) + + def _visit_var_section(self, ctx, scope): + decls = [] + for var_item in ctx.variables().variablesItem(): + if var_item.variableDecl() is not None: + decls.append(self._visit_var_decl(var_item.variableDecl())) + return VarSection(scope=scope, declarations=tuple(decls), line=ctx.start.line) + + def _visit_var_decl(self, ctx): + return VarDecl( + name=ctx.name.text, + type_name=ctx.typeName.text, + slot=int(ctx.slot.text) if ctx.slot else None, + line=ctx.start.line, + ) + + def _visit_body(self, items): + """Shared helper for sectionBody and blockItem lists.""" + result = [] + for item in items: + if item.statement() is not None: + result.append(self._visit_statement(item.statement())) + elif item.conditional() is not None: + result.append(self._visit_conditional(item.conditional())) + return result + + def _visit_statement(self, ctx): + line = ctx.start.line + if ctx.BREAK(): + return Break(line=line) + if ctx.functionCall(): + return self._visit_function_call(ctx.functionCall()) + if ctx.EQUAL(): + target = Target.from_dotted(ctx.lhs.text) + value = self._extract_value(ctx.value()) + return Assignment(target=target, operator="=", value=value, line=line) + if ctx.PLUSEQUAL(): + target = Target.from_dotted(ctx.lhs.text) + value = self._extract_value(ctx.value()) + return Assignment(target=target, operator="+=", value=value, line=line) + if ctx.op: + return FunctionCall(name=ctx.op.text, args=(), line=line) + raise ValueError(f"Unhandled statement alternative at line {line}") + + def _visit_function_call(self, ctx): + name = ctx.funcName.text + args = () + if ctx.argumentList(): + args = tuple( + self._extract_value(v) for v in ctx.argumentList().value() + ) + return FunctionCall(name=name, args=args, line=ctx.start.line) + + def _extract_value(self, ctx): + if 
ctx.number is not None: + return int(ctx.number.text) + if ctx.str_ is not None: + return ctx.str_.text[1:-1] + if ctx.TRUE(): + return True + if ctx.FALSE(): + return False + if ctx.ident is not None: + return ctx.ident.text + if ctx.ip(): + return ctx.ip().getText() + if ctx.iprange(): + return tuple(ip.getText() for ip in ctx.iprange().ip()) + if ctx.paramRef(): + return ctx.paramRef().getText() + return ctx.getText() + + def _visit_conditional(self, ctx): + if_stmt = ctx.ifStatement() + condition = self._visit_condition(if_stmt.condition()) + block = if_stmt.block() + body = tuple(self._visit_body(block.blockItem())) if block else () + + elif_branches = [] + for elif_ctx in ctx.elifClause(): + elif_cond = self._visit_condition(elif_ctx.condition()) + elif_block = elif_ctx.block() + elif_body = tuple(self._visit_body(elif_block.blockItem())) if elif_block else () + elif_branches.append(ElifBranch( + condition=elif_cond, + body=elif_body, + line=elif_ctx.start.line, + )) + + else_body = () + if ctx.elseClause(): + else_block = ctx.elseClause().block() + if else_block: + else_body = tuple(self._visit_body(else_block.blockItem())) + + return IfBlock( + condition=condition, + body=body, + elif_branches=tuple(elif_branches), + else_body=else_body, + line=ctx.start.line, + ) + + def _visit_condition(self, ctx): + return self._visit_expression(ctx.expression()) + + def _visit_expression(self, ctx): + if ctx.OR(): + left = self._visit_expression(ctx.expression()) + right = self._visit_term(ctx.term()) + return LogicalOp( + operator="||", left=left, right=right, line=ctx.start.line, + ) + return self._visit_term(ctx.term()) + + def _visit_term(self, ctx): + if ctx.AND(): + left = self._visit_term(ctx.term()) + right = self._visit_factor(ctx.factor()) + return LogicalOp( + operator="&&", left=left, right=right, line=ctx.start.line, + ) + return self._visit_factor(ctx.factor()) + + def _visit_factor(self, ctx): + if ctx.getChildCount() == 2 and ctx.getChild(0).getText() == 
"!": + return NotOp( + operand=self._visit_factor(ctx.factor()), + line=ctx.start.line, + ) + if ctx.LPAREN(): + return self._visit_expression(ctx.expression()) + if ctx.functionCall(): + return self._visit_function_call(ctx.functionCall()) + if ctx.comparison(): + return self._visit_comparison(ctx.comparison()) + if ctx.ident is not None: + return IdentCondition(name=ctx.ident.text, line=ctx.start.line) + if ctx.TRUE(): + return BoolLiteral(value=True, line=ctx.start.line) + if ctx.FALSE(): + return BoolLiteral(value=False, line=ctx.start.line) + raise ValueError(f"Unhandled factor alternative at line {ctx.start.line}") + + def _visit_comparison(self, ctx): + line = ctx.start.line + comp = ctx.comparable() + if comp.ident is not None: + left = comp.ident.text + else: + left = self._visit_function_call(comp.functionCall()) + + operator = self._detect_comparison_operator(ctx) + right = self._extract_comparison_rhs(ctx, operator) + modifiers = self._extract_modifiers(ctx) + + return Comparison( + left=left, operator=operator, right=right, + modifiers=modifiers, line=line, + ) + + def _detect_comparison_operator(self, ctx): + if ctx.EQUALS(): + return "==" + if ctx.NEQ(): + return "!=" + if ctx.GT(): + return ">" + if ctx.LT(): + return "<" + if ctx.TILDE(): + return "~" + if ctx.NOT_TILDE(): + return "!~" + if ctx.IN(): + for child in ctx.children: + if hasattr(child, "getText") and child.getText() == "!": + return "!in" + return "in" + raise ValueError(f"Unhandled comparison operator at line {ctx.start.line}") + + def _extract_comparison_rhs(self, ctx, operator): + if operator in ("~", "!~"): + return ctx.regex().getText()[1:-1] + if operator in ("in", "!in"): + if ctx.set_(): + return tuple( + self._extract_value(v) for v in ctx.set_().value() + ) + if ctx.iprange(): + return tuple( + ip.getText() for ip in ctx.iprange().ip() + ) + if ctx.value(): + return self._extract_value(ctx.value()) + raise ValueError(f"Unhandled comparison RHS at line {ctx.start.line}") + + 
def _extract_modifiers(self, ctx): + if ctx.modifier(): + return tuple( + tok.text for tok in ctx.modifier().modifierList().mods + ) + return () From d36900166c8f4dd4b889ea360936bde09154b3a6 Mon Sep 17 00:00:00 2001 From: Juan Posadas Date: Wed, 29 Apr 2026 11:36:37 -0600 Subject: [PATCH 03/13] Add AST visitor tests Integration tests covering the full visitor pipeline from source text to AST nodes. Tests are organized by concern: sections and simple statements, condition expressions (all operators, logical combinators, negation, parenthesized grouping), if/elif/else blocks with nesting, real config patterns (nested conditionals, boolean coercion, IP ranges, set membership with modifiers, exact match patterns), line number tracking across all 17 node types, and error handling for unhandled grammar alternatives. --- tools/hrw4u/tests/test_ast_visitor.py | 755 ++++++++++++++++++++++++++ 1 file changed, 755 insertions(+) create mode 100644 tools/hrw4u/tests/test_ast_visitor.py diff --git a/tools/hrw4u/tests/test_ast_visitor.py b/tools/hrw4u/tests/test_ast_visitor.py new file mode 100644 index 00000000000..77fedc874ec --- /dev/null +++ b/tools/hrw4u/tests/test_ast_visitor.py @@ -0,0 +1,755 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from hrw4u.ast_nodes import ( + Target, Assignment, FunctionCall, Break, Section, HRW4UAST, + Comparison, IfBlock, ElifBranch, BoolLiteral, NotOp, LogicalOp, IdentCondition, + VarSection, VarDecl, UseDirective, ProcedureDecl, ProcParam, +) +from utils import parse_input_text +from hrw4u.ast_visitor import ASTVisitor + + +def _build(source: str) -> HRW4UAST: + _, tree = parse_input_text(source) + return ASTVisitor().visit(tree) + + +class TestAssignments: + def test_simple_assignment(self): + ast = _build('REMAP {\n inbound.req.X-Foo = "test";\n}') + a = ast.body[0].body[0] + assert isinstance(a, Assignment) + assert a.target == Target.from_dotted("inbound.req.X-Foo") + assert a.operator == "=" + assert a.value == "test" + + def test_bool_value(self): + ast = _build('SEND_RESPONSE {\n http.cntl.TXN_DEBUG = true;\n}') + a = ast.body[0].body[0] + assert isinstance(a, Assignment) + assert a.value is True + + def test_int_value(self): + ast = _build('REMAP {\n http.cntl.INTERCEPT_RETRY = 1;\n}') + a = ast.body[0].body[0] + assert a.value == 1 + + def test_plus_equals(self): + ast = _build('REMAP {\n inbound.req.X-Foo += "extra";\n}') + a = ast.body[0].body[0] + assert a.operator == "+=" + + def test_ip_value(self): + ast = _build('REMAP {\n inbound.req.X-IP = 10.0.0.1;\n}') + a = ast.body[0].body[0] + assert isinstance(a, Assignment) + assert a.value == "10.0.0.1" + + def test_param_ref_value(self): + src = 'procedure local::stamp($tag) {\n inbound.req.X-Stamp = $tag;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + a = ast.body[0].body[0] + assert isinstance(a, Assignment) + assert a.value == "$tag" + + +class TestFunctionCalls: + def test_no_args(self): + ast = _build('REMAP {\n set-debug();\n}') + fc = ast.body[0].body[0] + assert isinstance(fc, FunctionCall) + assert fc.name == "set-debug" + assert fc.args == () + + def test_with_args(self): + ast = 
_build('REMAP {\n set-header("X-Foo", "bar");\n}') + fc = ast.body[0].body[0] + assert fc.name == "set-header" + assert fc.args == ("X-Foo", "bar") + + def test_standalone_operator(self): + ast = _build('REMAP {\n skip-remap;\n}') + fc = ast.body[0].body[0] + assert isinstance(fc, FunctionCall) + assert fc.name == "skip-remap" + assert fc.args == () + + def test_break(self): + ast = _build('REMAP {\n if true {\n break;\n }\n}') + body = ast.body[0].body[0].body + assert isinstance(body[0], Break) + + +class TestSections: + def test_section_type(self): + ast = _build('REMAP {\n set-debug();\n}') + s = ast.body[0] + assert isinstance(s, Section) + assert s.type == "REMAP" + + def test_multiple_sections(self): + src = 'REMAP {\n set-debug();\n}\nSEND_RESPONSE {\n set-debug();\n}' + ast = _build(src) + sections = [i for i in ast.body if isinstance(i, Section)] + assert len(sections) == 2 + assert sections[0].type == "REMAP" + assert sections[1].type == "SEND_RESPONSE" + + def test_use_directive(self): + src = 'use test::add-debug-header\nREMAP {\n test::add-debug-header("tag");\n}' + ast = _build(src) + assert len(ast.body) == 2 + u = ast.body[0] + assert isinstance(u, UseDirective) + assert u.spec == "test::add-debug-header" + + def test_item_ordering(self): + src = 'VARS {\n x: bool;\n}\nREMAP {\n set-debug();\n}\nSEND_RESPONSE {\n set-debug();\n}' + ast = _build(src) + assert len(ast.body) == 3 + assert isinstance(ast.body[0], VarSection) + assert isinstance(ast.body[1], Section) + assert isinstance(ast.body[2], Section) + + +class TestVarSections: + def test_txn_scope(self): + src = 'VARS {\n flag: bool;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + vs = ast.body[0] + assert isinstance(vs, VarSection) + assert vs.scope == "txn" + assert len(vs.declarations) == 1 + assert vs.declarations[0].name == "flag" + assert vs.declarations[0].type_name == "bool" + assert vs.declarations[0].slot is None + + def test_session_scope(self): + src = 'SESSION_VARS {\n 
counter: int;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + vs = ast.body[0] + assert isinstance(vs, VarSection) + assert vs.scope == "session" + assert vs.declarations[0].name == "counter" + + def test_slot(self): + src = 'VARS {\n x: int @3;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + vs = ast.body[0] + assert isinstance(vs, VarSection) + assert vs.declarations[0].slot == 3 + + def test_multiple_declarations(self): + src = 'VARS {\n a: bool;\n b: int;\n c: string;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + vs = ast.body[0] + assert isinstance(vs, VarSection) + assert len(vs.declarations) == 3 + assert vs.declarations[0].name == "a" + assert vs.declarations[1].name == "b" + assert vs.declarations[2].name == "c" + + +class TestProcedures: + def test_basic_decl(self): + src = 'procedure local::stamp($tag) {\n inbound.req.X-Stamp = "$tag";\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + pd = ast.body[0] + assert isinstance(pd, ProcedureDecl) + assert pd.name == "local::stamp" + assert len(pd.params) == 1 + assert pd.params[0].name == "tag" + assert pd.params[0].default is None + + def test_default_param(self): + src = 'procedure local::cache($ttl=300) {\n set-debug();\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + pd = ast.body[0] + assert isinstance(pd, ProcedureDecl) + assert pd.params[0].name == "ttl" + assert pd.params[0].default == 300 + + def test_body(self): + src = ('procedure local::multi() {\n inbound.req.X = "a";\n' + ' set-debug();\n}\nREMAP {\n set-debug();\n}') + ast = _build(src) + pd = ast.body[0] + assert isinstance(pd, ProcedureDecl) + assert len(pd.body) == 2 + assert isinstance(pd.body[0], Assignment) + assert isinstance(pd.body[1], FunctionCall) + + +class TestConditionExpressions: + def _first_condition(self, source: str): + ast = _build(source) + return ast.body[0].body[0].condition + + def test_equality_comparison(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.X-Foo == "bar" {\n 
set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.left == "inbound.req.X-Foo" + assert cond.operator == "==" + assert cond.right == "bar" + assert cond.modifiers == () + + def test_regex_comparison(self): + cond = self._first_condition( + 'REMAP {\n if inbound.url.path ~ /\\.php$/ {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == "~" + assert isinstance(cond.right, str) + + def test_in_set(self): + cond = self._first_condition( + 'REMAP {\n if inbound.url.path in ["a", "b"] {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == "in" + assert cond.right == ("a", "b") + + def test_not_in_set(self): + cond = self._first_condition( + 'REMAP {\n if inbound.url.path !in ["a"] {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == "!in" + + def test_in_iprange(self): + cond = self._first_condition( + 'REMAP {\n if inbound.ip in {10.0.0.0/8} {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == "in" + assert cond.right == ("10.0.0.0/8",) + + def test_modifiers(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.X-Foo == "bar" with NOCASE {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.modifiers == ("NOCASE",) + + def test_function_call_comparable(self): + cond = self._first_condition( + 'REMAP {\n if url(true) ~ /pat/ {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert isinstance(cond.left, FunctionCall) + assert cond.left.name == "url" + assert cond.left.args == (True,) + + def test_bool_literal_true(self): + cond = self._first_condition( + 'REMAP {\n if true {\n set-debug();\n }\n}' + ) + assert isinstance(cond, BoolLiteral) + assert cond.value is True + + def test_ident_condition(self): + cond = self._first_condition( + 'REMAP {\n if inbound.resp.All-Cache {\n set-debug();\n }\n}' + ) + assert isinstance(cond, 
IdentCondition) + assert cond.name == "inbound.resp.All-Cache" + + def test_not_condition(self): + cond = self._first_condition( + 'REMAP {\n if !inbound.resp.All-Cache {\n set-debug();\n }\n}' + ) + assert isinstance(cond, NotOp) + assert isinstance(cond.operand, IdentCondition) + + def test_and_condition(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.X-A == "a" && inbound.req.X-B == "b" {\n set-debug();\n }\n}' + ) + assert isinstance(cond, LogicalOp) + assert cond.operator == "&&" + assert isinstance(cond.left, Comparison) + assert isinstance(cond.right, Comparison) + + def test_or_condition(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.X-A == "a" || inbound.req.X-B == "b" {\n set-debug();\n }\n}' + ) + assert isinstance(cond, LogicalOp) + assert cond.operator == "||" + + def test_function_call_in_condition(self): + cond = self._first_condition( + 'REMAP {\n if access("/tmp/bar") {\n set-debug();\n }\n}' + ) + assert isinstance(cond, FunctionCall) + assert cond.name == "access" + assert cond.args == ("/tmp/bar",) + + def test_not_tilde_comparison(self): + cond = self._first_condition( + 'REMAP {\n if inbound.url.path !~ /\\.jpg$/ {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == "!~" + assert isinstance(cond.right, str) + + def test_greater_than_comparison(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.Content-Length > 1000 {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == ">" + assert cond.right == 1000 + + def test_less_than_comparison(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.Content-Length < 500 {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == "<" + assert cond.right == 500 + + def test_neq_comparison(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.X-Foo != "bar" {\n set-debug();\n }\n}' + ) + assert isinstance(cond, 
Comparison) + assert cond.operator == "!=" + assert cond.right == "bar" + + def test_parenthesized_condition(self): + cond = self._first_condition( + 'REMAP {\n if (inbound.req.X-Foo == "bar") {\n set-debug();\n }\n}' + ) + assert isinstance(cond, Comparison) + assert cond.operator == "==" + assert cond.right == "bar" + + def test_and_binds_tighter_than_or(self): + # a || b && c should parse as a || (b && c) + cond = self._first_condition( + 'REMAP {\n' + ' if inbound.req.X-A == "a" || inbound.req.X-B == "b" && inbound.req.X-C == "c" {\n' + ' set-debug();\n }\n}' + ) + assert isinstance(cond, LogicalOp) + assert cond.operator == "||" + assert isinstance(cond.left, Comparison) + assert cond.left.left == "inbound.req.X-A" + assert isinstance(cond.right, LogicalOp) + assert cond.right.operator == "&&" + assert cond.right.left.left == "inbound.req.X-B" + assert cond.right.right.left == "inbound.req.X-C" + + def test_not_with_and(self): + # !ident && comparison should parse as (!ident) && comparison + cond = self._first_condition( + 'REMAP {\n' + ' if !inbound.resp.All-Cache && inbound.req.X-B == "b" {\n' + ' set-debug();\n }\n}' + ) + assert isinstance(cond, LogicalOp) + assert cond.operator == "&&" + assert isinstance(cond.left, NotOp) + assert isinstance(cond.left.operand, IdentCondition) + assert cond.left.operand.name == "inbound.resp.All-Cache" + assert isinstance(cond.right, Comparison) + assert cond.right.left == "inbound.req.X-B" + + def test_not_comparison_with_or(self): + # !(a == "x") || b == "y" should parse as (!(a == "x")) || (b == "y") + cond = self._first_condition( + 'REMAP {\n' + ' if !(inbound.req.X-A == "x") || inbound.req.X-B == "y" {\n' + ' set-debug();\n }\n}' + ) + assert isinstance(cond, LogicalOp) + assert cond.operator == "||" + assert isinstance(cond.left, NotOp) + assert isinstance(cond.left.operand, Comparison) + assert cond.left.operand.left == "inbound.req.X-A" + assert cond.left.operand.right == "x" + assert isinstance(cond.right, 
Comparison) + assert cond.right.left == "inbound.req.X-B" + + def test_double_negation(self): + cond = self._first_condition( + 'REMAP {\n if !!inbound.resp.All-Cache {\n set-debug();\n }\n}' + ) + assert isinstance(cond, NotOp) + assert isinstance(cond.operand, NotOp) + assert isinstance(cond.operand.operand, IdentCondition) + assert cond.operand.operand.name == "inbound.resp.All-Cache" + + def test_not_bool_literal(self): + cond = self._first_condition( + 'REMAP {\n if !false {\n set-debug();\n }\n}' + ) + assert isinstance(cond, NotOp) + assert isinstance(cond.operand, BoolLiteral) + assert cond.operand.value is False + + def test_parens_override_precedence(self): + # (a || b) && c — parens force || to bind first + cond = self._first_condition( + 'REMAP {\n' + ' if (inbound.req.X-A == "a" || inbound.req.X-B == "b") && inbound.req.X-C == "c" {\n' + ' set-debug();\n }\n}' + ) + assert isinstance(cond, LogicalOp) + assert cond.operator == "&&" + assert isinstance(cond.left, LogicalOp) + assert cond.left.operator == "||" + assert cond.left.left.left == "inbound.req.X-A" + assert cond.left.right.left == "inbound.req.X-B" + assert isinstance(cond.right, Comparison) + assert cond.right.left == "inbound.req.X-C" + + def test_nested_parens_with_not(self): + # !(a == "x" || b == "y") && c == "z" + cond = self._first_condition( + 'REMAP {\n' + ' if !(inbound.req.X-A == "x" || inbound.req.X-B == "y") && inbound.req.X-C == "z" {\n' + ' set-debug();\n }\n}' + ) + assert isinstance(cond, LogicalOp) + assert cond.operator == "&&" + assert isinstance(cond.left, NotOp) + assert isinstance(cond.left.operand, LogicalOp) + assert cond.left.operand.operator == "||" + assert isinstance(cond.right, Comparison) + assert cond.right.left == "inbound.req.X-C" + + +class TestIfBlocks: + def test_simple_if(self): + ast = _build( + 'REMAP {\n if true {\n inbound.req.X = "y";\n }\n}' + ) + ib = ast.body[0].body[0] + assert isinstance(ib, IfBlock) + assert len(ib.body) == 1 + assert 
ib.elif_branches == () + assert ib.else_body == () + + def test_if_else(self): + src = 'REMAP {\n if true {\n inbound.req.X = "a";\n } else {\n inbound.req.X = "b";\n }\n}' + ast = _build(src) + ib = ast.body[0].body[0] + assert len(ib.else_body) == 1 + + def test_if_elif_else(self): + src = ('SEND_RESPONSE {\n if inbound.url.path == "foo" {\n' + ' inbound.resp.X = "f";\n } elif inbound.url.path == "bar" {\n' + ' inbound.resp.X = "b";\n } else {\n' + ' inbound.resp.X = "other";\n }\n}') + ast = _build(src) + ib = ast.body[0].body[0] + assert isinstance(ib, IfBlock) + assert len(ib.elif_branches) == 1 + assert isinstance(ib.elif_branches[0], ElifBranch) + assert len(ib.elif_branches[0].body) == 1 + assert len(ib.else_body) == 1 + + def test_multiple_elif(self): + src = ('SEND_RESPONSE {\n if inbound.url.path == "a" {\n set-debug();\n' + ' } elif inbound.url.path == "b" {\n set-debug();\n' + ' } elif inbound.url.path == "c" {\n set-debug();\n' + ' } else {\n set-debug();\n }\n}') + ast = _build(src) + ib = ast.body[0].body[0] + assert len(ib.elif_branches) == 2 + + def test_nested_if(self): + src = ('REMAP {\n if inbound.req.X == "a" {\n' + ' if inbound.req.Y == "b" {\n set-debug();\n }\n }\n}') + ast = _build(src) + outer = ast.body[0].body[0] + assert isinstance(outer, IfBlock) + inner = outer.body[0] + assert isinstance(inner, IfBlock) + + def test_mixed_body(self): + src = ('REMAP {\n inbound.req.X = "before";\n' + ' if true {\n set-debug();\n }\n' + ' inbound.req.Y = "after";\n}') + ast = _build(src) + body = ast.body[0].body + assert len(body) == 3 + assert isinstance(body[0], Assignment) + assert isinstance(body[1], IfBlock) + assert isinstance(body[2], Assignment) + + +class TestLineNumbers: + SRC = ( + "use test::helper\n" # line 1 + "VARS {\n" # line 2 + " flag: bool;\n" # line 3 + "}\n" # line 4 + "procedure local::stamp($tag) {\n" # line 5 + " inbound.req.X-Stamp = $tag;\n" # line 6 + "}\n" # line 7 + "REMAP {\n" # line 8 + ' inbound.req.X-Foo = "val";\n' 
# line 9 + " set-debug();\n" # line 10 + " skip-remap;\n" # line 11 + ' if inbound.req.X-A == "a" {\n' # line 12 + " break;\n" # line 13 + ' } elif inbound.req.X-B == "b" {\n' # line 14 + ' inbound.req.X = "elif";\n' # line 15 + " } else {\n" # line 16 + ' inbound.req.X = "else";\n' # line 17 + " }\n" # line 18 + ' if inbound.req.X-C == "c" && inbound.req.X-D == "d" {\n' # line 19 + " set-debug();\n" # line 20 + " }\n" # line 21 + " if !inbound.resp.All-Cache {\n" # line 22 + " set-debug();\n" # line 23 + " }\n" # line 24 + " if true {\n" # line 25 + " set-debug();\n" # line 26 + " }\n" # line 27 + " if inbound.resp.All-Cache {\n" # line 28 + " set-debug();\n" # line 29 + " }\n" # line 30 + "}\n" # line 31 + ) + + def setup_method(self): + self.ast = _build(self.SRC) + + def test_use_directive(self): + u = self.ast.body[0] + assert isinstance(u, UseDirective) + assert u.line == 1 + + def test_var_section(self): + vs = self.ast.body[1] + assert isinstance(vs, VarSection) + assert vs.line == 2 + + def test_var_decl(self): + vd = self.ast.body[1].declarations[0] + assert isinstance(vd, VarDecl) + assert vd.line == 3 + + def test_procedure_decl(self): + pd = self.ast.body[2] + assert isinstance(pd, ProcedureDecl) + assert pd.line == 5 + + def test_proc_param(self): + pp = self.ast.body[2].params[0] + assert isinstance(pp, ProcParam) + assert pp.line == 5 + + def test_procedure_body_assignment(self): + a = self.ast.body[2].body[0] + assert isinstance(a, Assignment) + assert a.line == 6 + + def test_section(self): + s = self.ast.body[3] + assert isinstance(s, Section) + assert s.line == 8 + + def test_assignment(self): + a = self.ast.body[3].body[0] + assert isinstance(a, Assignment) + assert a.line == 9 + + def test_function_call(self): + fc = self.ast.body[3].body[1] + assert isinstance(fc, FunctionCall) + assert fc.line == 10 + + def test_standalone_operator(self): + fc = self.ast.body[3].body[2] + assert isinstance(fc, FunctionCall) + assert fc.line == 11 + + def 
test_if_block(self): + ib = self.ast.body[3].body[3] + assert isinstance(ib, IfBlock) + assert ib.line == 12 + + def test_comparison_in_condition(self): + cond = self.ast.body[3].body[3].condition + assert isinstance(cond, Comparison) + assert cond.line == 12 + + def test_break(self): + brk = self.ast.body[3].body[3].body[0] + assert isinstance(brk, Break) + assert brk.line == 13 + + def test_elif_branch(self): + eb = self.ast.body[3].body[3].elif_branches[0] + assert isinstance(eb, ElifBranch) + assert eb.line == 14 + + def test_elif_condition(self): + cond = self.ast.body[3].body[3].elif_branches[0].condition + assert isinstance(cond, Comparison) + assert cond.line == 14 + + def test_logical_op(self): + cond = self.ast.body[3].body[4].condition + assert isinstance(cond, LogicalOp) + assert cond.line == 19 + + def test_not_op(self): + cond = self.ast.body[3].body[5].condition + assert isinstance(cond, NotOp) + assert cond.line == 22 + + def test_bool_literal(self): + cond = self.ast.body[3].body[6].condition + assert isinstance(cond, BoolLiteral) + assert cond.line == 25 + + def test_ident_condition(self): + cond = self.ast.body[3].body[7].condition + assert isinstance(cond, IdentCondition) + assert cond.line == 28 + + +class TestRealConfigs: + def test_nested_ifs_from_test_data(self): + """Validates AST for tests/data/conds/nested-ifs.input.txt pattern.""" + src = '''VARS { + bool_0: bool; + bool_1: bool; + bool_2: bool; +} + +REMAP { + if inbound.req.X-Foo == "bar" { + inbound.req.X-Hello = "there"; + if inbound.req.X-Fie == "fie" { + inbound.req.X-first = "1"; + if bool_0 || (bool_1 && bool_2) { + inbound.req.X-Parsed = "more"; + } else { + inbound.req.X-Parsed = "yes"; + } + } elif inbound.req.X-Fum == "bar" { + inbound.req.X-Parsed = "no"; + } else { + inbound.req.X-More = "yes"; + } + } elif inbound.req.X-Foo == "foo" with NOCASE,PRE { + inbound.req.X-Nocase = "foo"; + } else { + inbound.req.X-Something = "no-bar"; + } +}''' + ast = _build(src) + sections = 
[i for i in ast.body if isinstance(i, Section)] + assert len(sections) == 1 + s = sections[0] + assert s.type == "REMAP" + + # Top-level if block + outer = s.body[0] + assert isinstance(outer, IfBlock) + + # Body: assignment + nested if + assert isinstance(outer.body[0], Assignment) + assert isinstance(outer.body[1], IfBlock) + middle = outer.body[1] + + # Middle if has elif and else + assert len(middle.elif_branches) == 1 + assert len(middle.else_body) == 1 + + # Deepest nested if (3 levels) + inner = middle.body[1] + assert isinstance(inner, IfBlock) + assert isinstance(inner.condition, LogicalOp) + assert inner.condition.operator == "||" + + # Outer elif has modifiers + assert len(outer.elif_branches) == 1 + elif_cond = outer.elif_branches[0].condition + assert isinstance(elif_cond, Comparison) + assert elif_cond.modifiers == ("NOCASE", "PRE") + + def test_http_cntl_booleans(self): + """Validates value coercion for boolean-like assignments.""" + src = '''SEND_RESPONSE { + http.cntl.TXN_DEBUG = true; + http.cntl.LOGGING = FALSE; +}''' + ast = _build(src) + body = ast.body[0].body + assert body[0].value is True + assert body[1].value is False + + def test_ip_range_condition(self): + """Validates IP range handling from tests/data/conds/ip.input.txt.""" + src = '''SEND_REQUEST { + if inbound.ip in {192.168.0.0/16, 10.0.0.0/8} { + set-debug(); + } +}''' + ast = _build(src) + cond = ast.body[0].body[0].condition + assert isinstance(cond, Comparison) + assert cond.operator == "in" + assert len(cond.right) == 2 + + def test_set_membership_with_modifier(self): + """From tests/data/conds/in-sets.input.txt.""" + src = '''REMAP { + if inbound.url.path in ["php", "php3", "php4"] with EXT { + inbound.req.X-Is-PHP = "yes"; + } +}''' + ast = _build(src) + cond = ast.body[0].body[0].condition + assert isinstance(cond, Comparison) + assert cond.operator == "in" + assert cond.right == ("php", "php3", "php4") + assert cond.modifiers == ("EXT",) + + def 
test_debug_pattern_for_lint_rules(self): + """Validates the exact pattern the no-debug lint rule will match.""" + src = '''REMAP { + set-debug(); + http.cntl.TXN_DEBUG = true; + inbound.req.X-Foo = "test"; +}''' + ast = _build(src) + body = ast.body[0].body + + # set-debug() function call + assert isinstance(body[0], FunctionCall) + assert body[0].name == "set-debug" + + # TXN_DEBUG assignment with True + assert isinstance(body[1], Assignment) + assert body[1].target == Target.from_dotted("http.cntl.TXN_DEBUG") + assert body[1].value is True + + # Regular assignment (not flagged) + assert isinstance(body[2], Assignment) + assert body[2].target.namespace == "inbound.req" From 3898823c8d79c74193400806e2573983f0fbd597 Mon Sep 17 00:00:00 2001 From: Juan Posadas Date: Wed, 29 Apr 2026 15:57:40 -0600 Subject: [PATCH 04/13] Add tagged Value type to AST for codegen support Introduce ValueKind enum and Value dataclass to preserve semantic distinction between string literals, identifiers, param refs, IPs, and regexes in the AST. Without this, the codegen visitor cannot re-emit values correctly since _extract_value was collapsing all string-like values into bare Python str. 
--- tools/hrw4u/src/ast_nodes.py | 25 ++++++++++--- tools/hrw4u/src/ast_visitor.py | 20 +++++----- tools/hrw4u/tests/test_ast_visitor.py | 53 ++++++++++++++------------- 3 files changed, 59 insertions(+), 39 deletions(-) diff --git a/tools/hrw4u/src/ast_nodes.py b/tools/hrw4u/src/ast_nodes.py index aee65c02a71..8667290cc15 100644 --- a/tools/hrw4u/src/ast_nodes.py +++ b/tools/hrw4u/src/ast_nodes.py @@ -18,9 +18,24 @@ from __future__ import annotations from dataclasses import dataclass +from enum import Enum from typing import Union +class ValueKind(Enum): + STRING = "string" + IDENT = "ident" + PARAM_REF = "param_ref" + IP = "ip" + REGEX = "regex" + + +@dataclass(frozen=True, kw_only=True) +class Value: + raw: str + kind: ValueKind + + @dataclass(frozen=True, kw_only=True) class Node: line: int @@ -46,13 +61,13 @@ def from_dotted(name: str) -> Target: class Assignment(Node): target: Target operator: str # "=" or "+=" - value: str | int | bool | tuple + value: Value | int | bool | tuple @dataclass(frozen=True, kw_only=True) class FunctionCall(Node): name: str - args: tuple[str | int | bool, ...] + args: tuple[Value | int | bool, ...] @dataclass(frozen=True, kw_only=True) @@ -62,9 +77,9 @@ class Break(Node): @dataclass(frozen=True, kw_only=True) class Comparison(Node): - left: str | FunctionCall + left: Value | FunctionCall operator: str # "==", "!=", ">", "<", "~", "!~", "in", "!in" - right: str | int | bool | tuple + right: Value | int | bool | tuple modifiers: tuple[str, ...] 
@@ -113,7 +128,7 @@ class Section(Node): @dataclass(frozen=True, kw_only=True) class ProcParam(Node): name: str - default: str | int | bool | None + default: Value | int | bool | None @dataclass(frozen=True, kw_only=True) diff --git a/tools/hrw4u/src/ast_visitor.py b/tools/hrw4u/src/ast_visitor.py index 8b5c63aab77..a847d929648 100644 --- a/tools/hrw4u/src/ast_visitor.py +++ b/tools/hrw4u/src/ast_visitor.py @@ -37,6 +37,8 @@ VarSection, UseDirective, ProcedureDecl, + Value, + ValueKind, ) @@ -145,20 +147,20 @@ def _extract_value(self, ctx): if ctx.number is not None: return int(ctx.number.text) if ctx.str_ is not None: - return ctx.str_.text[1:-1] + return Value(raw=ctx.str_.text[1:-1], kind=ValueKind.STRING) if ctx.TRUE(): return True if ctx.FALSE(): return False if ctx.ident is not None: - return ctx.ident.text + return Value(raw=ctx.ident.text, kind=ValueKind.IDENT) if ctx.ip(): - return ctx.ip().getText() + return Value(raw=ctx.ip().getText(), kind=ValueKind.IP) if ctx.iprange(): - return tuple(ip.getText() for ip in ctx.iprange().ip()) + return tuple(Value(raw=ip.getText(), kind=ValueKind.IP) for ip in ctx.iprange().ip()) if ctx.paramRef(): - return ctx.paramRef().getText() - return ctx.getText() + return Value(raw=ctx.paramRef().IDENT().getText(), kind=ValueKind.PARAM_REF) + return Value(raw=ctx.getText(), kind=ValueKind.IDENT) def _visit_conditional(self, ctx): if_stmt = ctx.ifStatement() @@ -236,7 +238,7 @@ def _visit_comparison(self, ctx): line = ctx.start.line comp = ctx.comparable() if comp.ident is not None: - left = comp.ident.text + left = Value(raw=comp.ident.text, kind=ValueKind.IDENT) else: left = self._visit_function_call(comp.functionCall()) @@ -271,7 +273,7 @@ def _detect_comparison_operator(self, ctx): def _extract_comparison_rhs(self, ctx, operator): if operator in ("~", "!~"): - return ctx.regex().getText()[1:-1] + return Value(raw=ctx.regex().getText()[1:-1], kind=ValueKind.REGEX) if operator in ("in", "!in"): if ctx.set_(): return tuple( @@ 
-279,7 +281,7 @@ def _extract_comparison_rhs(self, ctx, operator): ) if ctx.iprange(): return tuple( - ip.getText() for ip in ctx.iprange().ip() + Value(raw=ip.getText(), kind=ValueKind.IP) for ip in ctx.iprange().ip() ) if ctx.value(): return self._extract_value(ctx.value()) diff --git a/tools/hrw4u/tests/test_ast_visitor.py b/tools/hrw4u/tests/test_ast_visitor.py index 77fedc874ec..6d3cf574b05 100644 --- a/tools/hrw4u/tests/test_ast_visitor.py +++ b/tools/hrw4u/tests/test_ast_visitor.py @@ -19,6 +19,7 @@ Target, Assignment, FunctionCall, Break, Section, HRW4UAST, Comparison, IfBlock, ElifBranch, BoolLiteral, NotOp, LogicalOp, IdentCondition, VarSection, VarDecl, UseDirective, ProcedureDecl, ProcParam, + Value, ValueKind, ) from utils import parse_input_text from hrw4u.ast_visitor import ASTVisitor @@ -36,7 +37,7 @@ def test_simple_assignment(self): assert isinstance(a, Assignment) assert a.target == Target.from_dotted("inbound.req.X-Foo") assert a.operator == "=" - assert a.value == "test" + assert a.value == Value(raw="test", kind=ValueKind.STRING) def test_bool_value(self): ast = _build('SEND_RESPONSE {\n http.cntl.TXN_DEBUG = true;\n}') @@ -58,14 +59,14 @@ def test_ip_value(self): ast = _build('REMAP {\n inbound.req.X-IP = 10.0.0.1;\n}') a = ast.body[0].body[0] assert isinstance(a, Assignment) - assert a.value == "10.0.0.1" + assert a.value == Value(raw="10.0.0.1", kind=ValueKind.IP) def test_param_ref_value(self): src = 'procedure local::stamp($tag) {\n inbound.req.X-Stamp = $tag;\n}\nREMAP {\n set-debug();\n}' ast = _build(src) a = ast.body[0].body[0] assert isinstance(a, Assignment) - assert a.value == "$tag" + assert a.value == Value(raw="tag", kind=ValueKind.PARAM_REF) class TestFunctionCalls: @@ -80,7 +81,7 @@ def test_with_args(self): ast = _build('REMAP {\n set-header("X-Foo", "bar");\n}') fc = ast.body[0].body[0] assert fc.name == "set-header" - assert fc.args == ("X-Foo", "bar") + assert fc.args == (Value(raw="X-Foo", kind=ValueKind.STRING), 
Value(raw="bar", kind=ValueKind.STRING)) def test_standalone_operator(self): ast = _build('REMAP {\n skip-remap;\n}') @@ -205,9 +206,9 @@ def test_equality_comparison(self): 'REMAP {\n if inbound.req.X-Foo == "bar" {\n set-debug();\n }\n}' ) assert isinstance(cond, Comparison) - assert cond.left == "inbound.req.X-Foo" + assert cond.left == Value(raw="inbound.req.X-Foo", kind=ValueKind.IDENT) assert cond.operator == "==" - assert cond.right == "bar" + assert cond.right == Value(raw="bar", kind=ValueKind.STRING) assert cond.modifiers == () def test_regex_comparison(self): @@ -216,7 +217,8 @@ def test_regex_comparison(self): ) assert isinstance(cond, Comparison) assert cond.operator == "~" - assert isinstance(cond.right, str) + assert isinstance(cond.right, Value) + assert cond.right.kind == ValueKind.REGEX def test_in_set(self): cond = self._first_condition( @@ -224,7 +226,7 @@ def test_in_set(self): ) assert isinstance(cond, Comparison) assert cond.operator == "in" - assert cond.right == ("a", "b") + assert cond.right == (Value(raw="a", kind=ValueKind.STRING), Value(raw="b", kind=ValueKind.STRING)) def test_not_in_set(self): cond = self._first_condition( @@ -239,7 +241,7 @@ def test_in_iprange(self): ) assert isinstance(cond, Comparison) assert cond.operator == "in" - assert cond.right == ("10.0.0.0/8",) + assert cond.right == (Value(raw="10.0.0.0/8", kind=ValueKind.IP),) def test_modifiers(self): cond = self._first_condition( @@ -300,7 +302,7 @@ def test_function_call_in_condition(self): ) assert isinstance(cond, FunctionCall) assert cond.name == "access" - assert cond.args == ("/tmp/bar",) + assert cond.args == (Value(raw="/tmp/bar", kind=ValueKind.STRING),) def test_not_tilde_comparison(self): cond = self._first_condition( @@ -308,7 +310,8 @@ def test_not_tilde_comparison(self): ) assert isinstance(cond, Comparison) assert cond.operator == "!~" - assert isinstance(cond.right, str) + assert isinstance(cond.right, Value) + assert cond.right.kind == ValueKind.REGEX 
def test_greater_than_comparison(self): cond = self._first_condition( @@ -332,7 +335,7 @@ def test_neq_comparison(self): ) assert isinstance(cond, Comparison) assert cond.operator == "!=" - assert cond.right == "bar" + assert cond.right == Value(raw="bar", kind=ValueKind.STRING) def test_parenthesized_condition(self): cond = self._first_condition( @@ -340,7 +343,7 @@ def test_parenthesized_condition(self): ) assert isinstance(cond, Comparison) assert cond.operator == "==" - assert cond.right == "bar" + assert cond.right == Value(raw="bar", kind=ValueKind.STRING) def test_and_binds_tighter_than_or(self): # a || b && c should parse as a || (b && c) @@ -352,11 +355,11 @@ def test_and_binds_tighter_than_or(self): assert isinstance(cond, LogicalOp) assert cond.operator == "||" assert isinstance(cond.left, Comparison) - assert cond.left.left == "inbound.req.X-A" + assert cond.left.left == Value(raw="inbound.req.X-A", kind=ValueKind.IDENT) assert isinstance(cond.right, LogicalOp) assert cond.right.operator == "&&" - assert cond.right.left.left == "inbound.req.X-B" - assert cond.right.right.left == "inbound.req.X-C" + assert cond.right.left.left == Value(raw="inbound.req.X-B", kind=ValueKind.IDENT) + assert cond.right.right.left == Value(raw="inbound.req.X-C", kind=ValueKind.IDENT) def test_not_with_and(self): # !ident && comparison should parse as (!ident) && comparison @@ -371,7 +374,7 @@ def test_not_with_and(self): assert isinstance(cond.left.operand, IdentCondition) assert cond.left.operand.name == "inbound.resp.All-Cache" assert isinstance(cond.right, Comparison) - assert cond.right.left == "inbound.req.X-B" + assert cond.right.left == Value(raw="inbound.req.X-B", kind=ValueKind.IDENT) def test_not_comparison_with_or(self): # !(a == "x") || b == "y" should parse as (!(a == "x")) || (b == "y") @@ -384,10 +387,10 @@ def test_not_comparison_with_or(self): assert cond.operator == "||" assert isinstance(cond.left, NotOp) assert isinstance(cond.left.operand, Comparison) - 
assert cond.left.operand.left == "inbound.req.X-A" - assert cond.left.operand.right == "x" + assert cond.left.operand.left == Value(raw="inbound.req.X-A", kind=ValueKind.IDENT) + assert cond.left.operand.right == Value(raw="x", kind=ValueKind.STRING) assert isinstance(cond.right, Comparison) - assert cond.right.left == "inbound.req.X-B" + assert cond.right.left == Value(raw="inbound.req.X-B", kind=ValueKind.IDENT) def test_double_negation(self): cond = self._first_condition( @@ -417,10 +420,10 @@ def test_parens_override_precedence(self): assert cond.operator == "&&" assert isinstance(cond.left, LogicalOp) assert cond.left.operator == "||" - assert cond.left.left.left == "inbound.req.X-A" - assert cond.left.right.left == "inbound.req.X-B" + assert cond.left.left.left == Value(raw="inbound.req.X-A", kind=ValueKind.IDENT) + assert cond.left.right.left == Value(raw="inbound.req.X-B", kind=ValueKind.IDENT) assert isinstance(cond.right, Comparison) - assert cond.right.left == "inbound.req.X-C" + assert cond.right.left == Value(raw="inbound.req.X-C", kind=ValueKind.IDENT) def test_nested_parens_with_not(self): # !(a == "x" || b == "y") && c == "z" @@ -435,7 +438,7 @@ def test_nested_parens_with_not(self): assert isinstance(cond.left.operand, LogicalOp) assert cond.left.operand.operator == "||" assert isinstance(cond.right, Comparison) - assert cond.right.left == "inbound.req.X-C" + assert cond.right.left == Value(raw="inbound.req.X-C", kind=ValueKind.IDENT) class TestIfBlocks: @@ -728,7 +731,7 @@ def test_set_membership_with_modifier(self): cond = ast.body[0].body[0].condition assert isinstance(cond, Comparison) assert cond.operator == "in" - assert cond.right == ("php", "php3", "php4") + assert cond.right == (Value(raw="php", kind=ValueKind.STRING), Value(raw="php3", kind=ValueKind.STRING), Value(raw="php4", kind=ValueKind.STRING)) assert cond.modifiers == ("EXT",) def test_debug_pattern_for_lint_rules(self): From 8b9cf3569a7a9175837c0ca0d17e00008e84e3b8 Mon Sep 17 
00:00:00 2001 From: Juan Posadas Date: Wed, 29 Apr 2026 16:24:27 -0600 Subject: [PATCH 05/13] Handle commentLine explicitly in visitProgram Skip comments intentionally and raise on unrecognized programItem alternatives to catch visitor/grammar drift. --- tools/hrw4u/src/ast_visitor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tools/hrw4u/src/ast_visitor.py b/tools/hrw4u/src/ast_visitor.py index a847d929648..ec8fcedf1aa 100644 --- a/tools/hrw4u/src/ast_visitor.py +++ b/tools/hrw4u/src/ast_visitor.py @@ -59,6 +59,10 @@ def visitProgram(self, ctx): items.append(self._visit_procedure_decl(item.procedureDecl())) elif item.section() is not None: items.append(self._visit_section(item.section())) + elif item.commentLine() is not None: + pass + else: + raise ValueError(f"Unhandled programItem alternative at line {item.start.line}") return HRW4UAST(body=tuple(items)) def _visit_use_directive(self, ctx): From 2685f2629ab4a05ce4679639ad6dc3321b2e63dd Mon Sep 17 00:00:00 2001 From: Juan Posadas Date: Mon, 4 May 2026 14:20:56 -0600 Subject: [PATCH 06/13] Replace tagged Value type with concrete value types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Split Value/ValueKind into LiteralStringValue, IdentValue, IPValue, ParamRef, and RegexValue so the type system encodes what kind of value each node carries. This makes the AST more precise — e.g. iprange is now tuple[IPValue, ...] instead of tuple[Value, ...], and RegexValue only appears in Comparison.right where the grammar allows it. 
--- tools/hrw4u/src/ast_nodes.py | 41 +++++++++++++------- tools/hrw4u/src/ast_visitor.py | 25 +++++++------ tools/hrw4u/tests/test_ast_visitor.py | 54 +++++++++++++-------------- 3 files changed, 67 insertions(+), 53 deletions(-) diff --git a/tools/hrw4u/src/ast_nodes.py b/tools/hrw4u/src/ast_nodes.py index 8667290cc15..846265f001a 100644 --- a/tools/hrw4u/src/ast_nodes.py +++ b/tools/hrw4u/src/ast_nodes.py @@ -18,22 +18,35 @@ from __future__ import annotations from dataclasses import dataclass -from enum import Enum from typing import Union -class ValueKind(Enum): - STRING = "string" - IDENT = "ident" - PARAM_REF = "param_ref" - IP = "ip" - REGEX = "regex" +@dataclass(frozen=True, kw_only=True) +class LiteralStringValue: + raw: str + + +@dataclass(frozen=True, kw_only=True) +class IdentValue: + raw: str + + +@dataclass(frozen=True, kw_only=True) +class IPValue: + raw: str + + +@dataclass(frozen=True, kw_only=True) +class ParamRef: + raw: str @dataclass(frozen=True, kw_only=True) -class Value: +class RegexValue: raw: str - kind: ValueKind + + +ValueExpr = Union[LiteralStringValue, IdentValue, IPValue, ParamRef, int, bool, tuple[IPValue, ...]] @dataclass(frozen=True, kw_only=True) @@ -61,13 +74,13 @@ def from_dotted(name: str) -> Target: class Assignment(Node): target: Target operator: str # "=" or "+=" - value: Value | int | bool | tuple + value: ValueExpr @dataclass(frozen=True, kw_only=True) class FunctionCall(Node): name: str - args: tuple[Value | int | bool, ...] + args: tuple[ValueExpr, ...] @dataclass(frozen=True, kw_only=True) @@ -77,9 +90,9 @@ class Break(Node): @dataclass(frozen=True, kw_only=True) class Comparison(Node): - left: Value | FunctionCall + left: IdentValue | FunctionCall operator: str # "==", "!=", ">", "<", "~", "!~", "in", "!in" - right: Value | int | bool | tuple + right: ValueExpr | RegexValue | tuple[ValueExpr, ...] modifiers: tuple[str, ...] 
@@ -128,7 +141,7 @@ class Section(Node): @dataclass(frozen=True, kw_only=True) class ProcParam(Node): name: str - default: Value | int | bool | None + default: ValueExpr | None @dataclass(frozen=True, kw_only=True) diff --git a/tools/hrw4u/src/ast_visitor.py b/tools/hrw4u/src/ast_visitor.py index ec8fcedf1aa..3e8619001e4 100644 --- a/tools/hrw4u/src/ast_visitor.py +++ b/tools/hrw4u/src/ast_visitor.py @@ -37,8 +37,11 @@ VarSection, UseDirective, ProcedureDecl, - Value, - ValueKind, + LiteralStringValue, + IdentValue, + IPValue, + ParamRef, + RegexValue, ) @@ -151,20 +154,20 @@ def _extract_value(self, ctx): if ctx.number is not None: return int(ctx.number.text) if ctx.str_ is not None: - return Value(raw=ctx.str_.text[1:-1], kind=ValueKind.STRING) + return LiteralStringValue(raw=ctx.str_.text[1:-1]) if ctx.TRUE(): return True if ctx.FALSE(): return False if ctx.ident is not None: - return Value(raw=ctx.ident.text, kind=ValueKind.IDENT) + return IdentValue(raw=ctx.ident.text) if ctx.ip(): - return Value(raw=ctx.ip().getText(), kind=ValueKind.IP) + return IPValue(raw=ctx.ip().getText()) if ctx.iprange(): - return tuple(Value(raw=ip.getText(), kind=ValueKind.IP) for ip in ctx.iprange().ip()) + return tuple(IPValue(raw=ip.getText()) for ip in ctx.iprange().ip()) if ctx.paramRef(): - return Value(raw=ctx.paramRef().IDENT().getText(), kind=ValueKind.PARAM_REF) - return Value(raw=ctx.getText(), kind=ValueKind.IDENT) + return ParamRef(raw=ctx.paramRef().IDENT().getText()) + return IdentValue(raw=ctx.getText()) def _visit_conditional(self, ctx): if_stmt = ctx.ifStatement() @@ -242,7 +245,7 @@ def _visit_comparison(self, ctx): line = ctx.start.line comp = ctx.comparable() if comp.ident is not None: - left = Value(raw=comp.ident.text, kind=ValueKind.IDENT) + left = IdentValue(raw=comp.ident.text) else: left = self._visit_function_call(comp.functionCall()) @@ -277,7 +280,7 @@ def _detect_comparison_operator(self, ctx): def _extract_comparison_rhs(self, ctx, operator): if 
operator in ("~", "!~"): - return Value(raw=ctx.regex().getText()[1:-1], kind=ValueKind.REGEX) + return RegexValue(raw=ctx.regex().getText()[1:-1]) if operator in ("in", "!in"): if ctx.set_(): return tuple( @@ -285,7 +288,7 @@ def _extract_comparison_rhs(self, ctx, operator): ) if ctx.iprange(): return tuple( - Value(raw=ip.getText(), kind=ValueKind.IP) for ip in ctx.iprange().ip() + IPValue(raw=ip.getText()) for ip in ctx.iprange().ip() ) if ctx.value(): return self._extract_value(ctx.value()) diff --git a/tools/hrw4u/tests/test_ast_visitor.py b/tools/hrw4u/tests/test_ast_visitor.py index 6d3cf574b05..408801b4f21 100644 --- a/tools/hrw4u/tests/test_ast_visitor.py +++ b/tools/hrw4u/tests/test_ast_visitor.py @@ -19,7 +19,7 @@ Target, Assignment, FunctionCall, Break, Section, HRW4UAST, Comparison, IfBlock, ElifBranch, BoolLiteral, NotOp, LogicalOp, IdentCondition, VarSection, VarDecl, UseDirective, ProcedureDecl, ProcParam, - Value, ValueKind, + LiteralStringValue, IdentValue, IPValue, ParamRef, RegexValue, ) from utils import parse_input_text from hrw4u.ast_visitor import ASTVisitor @@ -37,7 +37,7 @@ def test_simple_assignment(self): assert isinstance(a, Assignment) assert a.target == Target.from_dotted("inbound.req.X-Foo") assert a.operator == "=" - assert a.value == Value(raw="test", kind=ValueKind.STRING) + assert a.value == LiteralStringValue(raw="test") def test_bool_value(self): ast = _build('SEND_RESPONSE {\n http.cntl.TXN_DEBUG = true;\n}') @@ -59,14 +59,14 @@ def test_ip_value(self): ast = _build('REMAP {\n inbound.req.X-IP = 10.0.0.1;\n}') a = ast.body[0].body[0] assert isinstance(a, Assignment) - assert a.value == Value(raw="10.0.0.1", kind=ValueKind.IP) + assert a.value == IPValue(raw="10.0.0.1") def test_param_ref_value(self): src = 'procedure local::stamp($tag) {\n inbound.req.X-Stamp = $tag;\n}\nREMAP {\n set-debug();\n}' ast = _build(src) a = ast.body[0].body[0] assert isinstance(a, Assignment) - assert a.value == Value(raw="tag", 
kind=ValueKind.PARAM_REF) + assert a.value == ParamRef(raw="tag") class TestFunctionCalls: @@ -81,7 +81,7 @@ def test_with_args(self): ast = _build('REMAP {\n set-header("X-Foo", "bar");\n}') fc = ast.body[0].body[0] assert fc.name == "set-header" - assert fc.args == (Value(raw="X-Foo", kind=ValueKind.STRING), Value(raw="bar", kind=ValueKind.STRING)) + assert fc.args == (LiteralStringValue(raw="X-Foo"), LiteralStringValue(raw="bar")) def test_standalone_operator(self): ast = _build('REMAP {\n skip-remap;\n}') @@ -206,9 +206,9 @@ def test_equality_comparison(self): 'REMAP {\n if inbound.req.X-Foo == "bar" {\n set-debug();\n }\n}' ) assert isinstance(cond, Comparison) - assert cond.left == Value(raw="inbound.req.X-Foo", kind=ValueKind.IDENT) + assert cond.left == IdentValue(raw="inbound.req.X-Foo") assert cond.operator == "==" - assert cond.right == Value(raw="bar", kind=ValueKind.STRING) + assert cond.right == LiteralStringValue(raw="bar") assert cond.modifiers == () def test_regex_comparison(self): @@ -217,8 +217,7 @@ def test_regex_comparison(self): ) assert isinstance(cond, Comparison) assert cond.operator == "~" - assert isinstance(cond.right, Value) - assert cond.right.kind == ValueKind.REGEX + assert isinstance(cond.right, RegexValue) def test_in_set(self): cond = self._first_condition( @@ -226,7 +225,7 @@ def test_in_set(self): ) assert isinstance(cond, Comparison) assert cond.operator == "in" - assert cond.right == (Value(raw="a", kind=ValueKind.STRING), Value(raw="b", kind=ValueKind.STRING)) + assert cond.right == (LiteralStringValue(raw="a"), LiteralStringValue(raw="b")) def test_not_in_set(self): cond = self._first_condition( @@ -241,7 +240,7 @@ def test_in_iprange(self): ) assert isinstance(cond, Comparison) assert cond.operator == "in" - assert cond.right == (Value(raw="10.0.0.0/8", kind=ValueKind.IP),) + assert cond.right == (IPValue(raw="10.0.0.0/8"),) def test_modifiers(self): cond = self._first_condition( @@ -302,7 +301,7 @@ def 
test_function_call_in_condition(self): ) assert isinstance(cond, FunctionCall) assert cond.name == "access" - assert cond.args == (Value(raw="/tmp/bar", kind=ValueKind.STRING),) + assert cond.args == (LiteralStringValue(raw="/tmp/bar"),) def test_not_tilde_comparison(self): cond = self._first_condition( @@ -310,8 +309,7 @@ def test_not_tilde_comparison(self): ) assert isinstance(cond, Comparison) assert cond.operator == "!~" - assert isinstance(cond.right, Value) - assert cond.right.kind == ValueKind.REGEX + assert isinstance(cond.right, RegexValue) def test_greater_than_comparison(self): cond = self._first_condition( @@ -335,7 +333,7 @@ def test_neq_comparison(self): ) assert isinstance(cond, Comparison) assert cond.operator == "!=" - assert cond.right == Value(raw="bar", kind=ValueKind.STRING) + assert cond.right == LiteralStringValue(raw="bar") def test_parenthesized_condition(self): cond = self._first_condition( @@ -343,7 +341,7 @@ def test_parenthesized_condition(self): ) assert isinstance(cond, Comparison) assert cond.operator == "==" - assert cond.right == Value(raw="bar", kind=ValueKind.STRING) + assert cond.right == LiteralStringValue(raw="bar") def test_and_binds_tighter_than_or(self): # a || b && c should parse as a || (b && c) @@ -355,11 +353,11 @@ def test_and_binds_tighter_than_or(self): assert isinstance(cond, LogicalOp) assert cond.operator == "||" assert isinstance(cond.left, Comparison) - assert cond.left.left == Value(raw="inbound.req.X-A", kind=ValueKind.IDENT) + assert cond.left.left == IdentValue(raw="inbound.req.X-A") assert isinstance(cond.right, LogicalOp) assert cond.right.operator == "&&" - assert cond.right.left.left == Value(raw="inbound.req.X-B", kind=ValueKind.IDENT) - assert cond.right.right.left == Value(raw="inbound.req.X-C", kind=ValueKind.IDENT) + assert cond.right.left.left == IdentValue(raw="inbound.req.X-B") + assert cond.right.right.left == IdentValue(raw="inbound.req.X-C") def test_not_with_and(self): # !ident && comparison 
should parse as (!ident) && comparison @@ -374,7 +372,7 @@ def test_not_with_and(self): assert isinstance(cond.left.operand, IdentCondition) assert cond.left.operand.name == "inbound.resp.All-Cache" assert isinstance(cond.right, Comparison) - assert cond.right.left == Value(raw="inbound.req.X-B", kind=ValueKind.IDENT) + assert cond.right.left == IdentValue(raw="inbound.req.X-B") def test_not_comparison_with_or(self): # !(a == "x") || b == "y" should parse as (!(a == "x")) || (b == "y") @@ -387,10 +385,10 @@ def test_not_comparison_with_or(self): assert cond.operator == "||" assert isinstance(cond.left, NotOp) assert isinstance(cond.left.operand, Comparison) - assert cond.left.operand.left == Value(raw="inbound.req.X-A", kind=ValueKind.IDENT) - assert cond.left.operand.right == Value(raw="x", kind=ValueKind.STRING) + assert cond.left.operand.left == IdentValue(raw="inbound.req.X-A") + assert cond.left.operand.right == LiteralStringValue(raw="x") assert isinstance(cond.right, Comparison) - assert cond.right.left == Value(raw="inbound.req.X-B", kind=ValueKind.IDENT) + assert cond.right.left == IdentValue(raw="inbound.req.X-B") def test_double_negation(self): cond = self._first_condition( @@ -420,10 +418,10 @@ def test_parens_override_precedence(self): assert cond.operator == "&&" assert isinstance(cond.left, LogicalOp) assert cond.left.operator == "||" - assert cond.left.left.left == Value(raw="inbound.req.X-A", kind=ValueKind.IDENT) - assert cond.left.right.left == Value(raw="inbound.req.X-B", kind=ValueKind.IDENT) + assert cond.left.left.left == IdentValue(raw="inbound.req.X-A") + assert cond.left.right.left == IdentValue(raw="inbound.req.X-B") assert isinstance(cond.right, Comparison) - assert cond.right.left == Value(raw="inbound.req.X-C", kind=ValueKind.IDENT) + assert cond.right.left == IdentValue(raw="inbound.req.X-C") def test_nested_parens_with_not(self): # !(a == "x" || b == "y") && c == "z" @@ -438,7 +436,7 @@ def test_nested_parens_with_not(self): assert 
isinstance(cond.left.operand, LogicalOp) assert cond.left.operand.operator == "||" assert isinstance(cond.right, Comparison) - assert cond.right.left == Value(raw="inbound.req.X-C", kind=ValueKind.IDENT) + assert cond.right.left == IdentValue(raw="inbound.req.X-C") class TestIfBlocks: @@ -731,7 +729,7 @@ def test_set_membership_with_modifier(self): cond = ast.body[0].body[0].condition assert isinstance(cond, Comparison) assert cond.operator == "in" - assert cond.right == (Value(raw="php", kind=ValueKind.STRING), Value(raw="php3", kind=ValueKind.STRING), Value(raw="php4", kind=ValueKind.STRING)) + assert cond.right == (LiteralStringValue(raw="php"), LiteralStringValue(raw="php3"), LiteralStringValue(raw="php4")) assert cond.modifiers == ("EXT",) def test_debug_pattern_for_lint_rules(self): From 8084e00c63f5d9d99f053f88b7bd8e5997bca346 Mon Sep 17 00:00:00 2001 From: Juan Posadas Date: Mon, 4 May 2026 14:39:22 -0600 Subject: [PATCH 07/13] Raise on unhandled grammar alternatives in AST visitor Add explicit commentLine handling and ValueError raises in _visit_var_section, _visit_body, and _extract_value so new grammar alternatives fail fast instead of being silently dropped or misclassified. 
--- tools/hrw4u/src/ast_visitor.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tools/hrw4u/src/ast_visitor.py b/tools/hrw4u/src/ast_visitor.py index 3e8619001e4..b567836f564 100644 --- a/tools/hrw4u/src/ast_visitor.py +++ b/tools/hrw4u/src/ast_visitor.py @@ -103,6 +103,10 @@ def _visit_var_section(self, ctx, scope): for var_item in ctx.variables().variablesItem(): if var_item.variableDecl() is not None: decls.append(self._visit_var_decl(var_item.variableDecl())) + elif var_item.commentLine() is not None: + pass + else: + raise ValueError(f"Unhandled variablesItem alternative at line {var_item.start.line}") return VarSection(scope=scope, declarations=tuple(decls), line=ctx.start.line) def _visit_var_decl(self, ctx): @@ -121,6 +125,10 @@ def _visit_body(self, items): result.append(self._visit_statement(item.statement())) elif item.conditional() is not None: result.append(self._visit_conditional(item.conditional())) + elif item.commentLine() is not None: + pass + else: + raise ValueError(f"Unhandled body item alternative at line {item.start.line}") return result def _visit_statement(self, ctx): @@ -167,7 +175,7 @@ def _extract_value(self, ctx): return tuple(IPValue(raw=ip.getText()) for ip in ctx.iprange().ip()) if ctx.paramRef(): return ParamRef(raw=ctx.paramRef().IDENT().getText()) - return IdentValue(raw=ctx.getText()) + raise ValueError(f"Unhandled value alternative at line {ctx.start.line}") def _visit_conditional(self, ctx): if_stmt = ctx.ifStatement() From 6a4d45b0bc6c68ce004eb2034521fa6234cb1da2 Mon Sep 17 00:00:00 2001 From: Juan Posadas Date: Mon, 4 May 2026 15:15:56 -0600 Subject: [PATCH 08/13] Format code --- tools/hrw4u/src/ast_nodes.py | 6 +- tools/hrw4u/src/ast_visitor.py | 37 +++-- tools/hrw4u/tests/test_ast_nodes.py | 1 + tools/hrw4u/tests/test_ast_visitor.py | 222 ++++++++++++-------------- 4 files changed, 126 insertions(+), 140 deletions(-) diff --git a/tools/hrw4u/src/ast_nodes.py b/tools/hrw4u/src/ast_nodes.py 
index 846265f001a..09fd5f1a964 100644 --- a/tools/hrw4u/src/ast_nodes.py +++ b/tools/hrw4u/src/ast_nodes.py @@ -73,7 +73,7 @@ def from_dotted(name: str) -> Target: @dataclass(frozen=True, kw_only=True) class Assignment(Node): target: Target - operator: str # "=" or "+=" + operator: str # "=" or "+=" value: ValueExpr @@ -91,14 +91,14 @@ class Break(Node): @dataclass(frozen=True, kw_only=True) class Comparison(Node): left: IdentValue | FunctionCall - operator: str # "==", "!=", ">", "<", "~", "!~", "in", "!in" + operator: str # "==", "!=", ">", "<", "~", "!~", "in", "!in" right: ValueExpr | RegexValue | tuple[ValueExpr, ...] modifiers: tuple[str, ...] @dataclass(frozen=True, kw_only=True) class LogicalOp(Node): - operator: str # "&&" or "||" + operator: str # "&&" or "||" left: ConditionExpr right: ConditionExpr diff --git a/tools/hrw4u/src/ast_visitor.py b/tools/hrw4u/src/ast_visitor.py index b567836f564..784b92f3c08 100644 --- a/tools/hrw4u/src/ast_visitor.py +++ b/tools/hrw4u/src/ast_visitor.py @@ -78,9 +78,7 @@ def _visit_procedure_decl(self, ctx): name = ctx.QUALIFIED_IDENT().getText() params = () if ctx.paramList(): - params = tuple( - self._visit_proc_param(p) for p in ctx.paramList().param() - ) + params = tuple(self._visit_proc_param(p) for p in ctx.paramList().param()) body = tuple(self._visit_body(ctx.block().blockItem())) return ProcedureDecl(name=name, params=params, body=body, line=ctx.start.line) @@ -153,9 +151,7 @@ def _visit_function_call(self, ctx): name = ctx.funcName.text args = () if ctx.argumentList(): - args = tuple( - self._extract_value(v) for v in ctx.argumentList().value() - ) + args = tuple(self._extract_value(v) for v in ctx.argumentList().value()) return FunctionCall(name=name, args=args, line=ctx.start.line) def _extract_value(self, ctx): @@ -216,7 +212,10 @@ def _visit_expression(self, ctx): left = self._visit_expression(ctx.expression()) right = self._visit_term(ctx.term()) return LogicalOp( - operator="||", left=left, right=right, 
line=ctx.start.line, + operator="||", + left=left, + right=right, + line=ctx.start.line, ) return self._visit_term(ctx.term()) @@ -225,7 +224,10 @@ def _visit_term(self, ctx): left = self._visit_term(ctx.term()) right = self._visit_factor(ctx.factor()) return LogicalOp( - operator="&&", left=left, right=right, line=ctx.start.line, + operator="&&", + left=left, + right=right, + line=ctx.start.line, ) return self._visit_factor(ctx.factor()) @@ -262,8 +264,11 @@ def _visit_comparison(self, ctx): modifiers = self._extract_modifiers(ctx) return Comparison( - left=left, operator=operator, right=right, - modifiers=modifiers, line=line, + left=left, + operator=operator, + right=right, + modifiers=modifiers, + line=line, ) def _detect_comparison_operator(self, ctx): @@ -291,20 +296,14 @@ def _extract_comparison_rhs(self, ctx, operator): return RegexValue(raw=ctx.regex().getText()[1:-1]) if operator in ("in", "!in"): if ctx.set_(): - return tuple( - self._extract_value(v) for v in ctx.set_().value() - ) + return tuple(self._extract_value(v) for v in ctx.set_().value()) if ctx.iprange(): - return tuple( - IPValue(raw=ip.getText()) for ip in ctx.iprange().ip() - ) + return tuple(IPValue(raw=ip.getText()) for ip in ctx.iprange().ip()) if ctx.value(): return self._extract_value(ctx.value()) raise ValueError(f"Unhandled comparison RHS at line {ctx.start.line}") def _extract_modifiers(self, ctx): if ctx.modifier(): - return tuple( - tok.text for tok in ctx.modifier().modifierList().mods - ) + return tuple(tok.text for tok in ctx.modifier().modifierList().mods) return () diff --git a/tools/hrw4u/tests/test_ast_nodes.py b/tools/hrw4u/tests/test_ast_nodes.py index 1c2320c4e53..d76d4a89b26 100644 --- a/tools/hrw4u/tests/test_ast_nodes.py +++ b/tools/hrw4u/tests/test_ast_nodes.py @@ -19,6 +19,7 @@ class TestTarget: + def test_dotted_path(self): t = Target.from_dotted("inbound.req.X-Foo") assert t.namespace == "inbound.req" diff --git a/tools/hrw4u/tests/test_ast_visitor.py 
b/tools/hrw4u/tests/test_ast_visitor.py index 408801b4f21..0bc7d648b20 100644 --- a/tools/hrw4u/tests/test_ast_visitor.py +++ b/tools/hrw4u/tests/test_ast_visitor.py @@ -16,10 +16,29 @@ # limitations under the License. from hrw4u.ast_nodes import ( - Target, Assignment, FunctionCall, Break, Section, HRW4UAST, - Comparison, IfBlock, ElifBranch, BoolLiteral, NotOp, LogicalOp, IdentCondition, - VarSection, VarDecl, UseDirective, ProcedureDecl, ProcParam, - LiteralStringValue, IdentValue, IPValue, ParamRef, RegexValue, + Target, + Assignment, + FunctionCall, + Break, + Section, + HRW4UAST, + Comparison, + IfBlock, + ElifBranch, + BoolLiteral, + NotOp, + LogicalOp, + IdentCondition, + VarSection, + VarDecl, + UseDirective, + ProcedureDecl, + ProcParam, + LiteralStringValue, + IdentValue, + IPValue, + ParamRef, + RegexValue, ) from utils import parse_input_text from hrw4u.ast_visitor import ASTVisitor @@ -31,6 +50,7 @@ def _build(source: str) -> HRW4UAST: class TestAssignments: + def test_simple_assignment(self): ast = _build('REMAP {\n inbound.req.X-Foo = "test";\n}') a = ast.body[0].body[0] @@ -70,6 +90,7 @@ def test_param_ref_value(self): class TestFunctionCalls: + def test_no_args(self): ast = _build('REMAP {\n set-debug();\n}') fc = ast.body[0].body[0] @@ -97,6 +118,7 @@ def test_break(self): class TestSections: + def test_section_type(self): ast = _build('REMAP {\n set-debug();\n}') s = ast.body[0] @@ -129,6 +151,7 @@ def test_item_ordering(self): class TestVarSections: + def test_txn_scope(self): src = 'VARS {\n flag: bool;\n}\nREMAP {\n set-debug();\n}' ast = _build(src) @@ -167,6 +190,7 @@ def test_multiple_declarations(self): class TestProcedures: + def test_basic_decl(self): src = 'procedure local::stamp($tag) {\n inbound.req.X-Stamp = "$tag";\n}\nREMAP {\n set-debug();\n}' ast = _build(src) @@ -197,14 +221,13 @@ def test_body(self): class TestConditionExpressions: + def _first_condition(self, source: str): ast = _build(source) return 
ast.body[0].body[0].condition def test_equality_comparison(self): - cond = self._first_condition( - 'REMAP {\n if inbound.req.X-Foo == "bar" {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if inbound.req.X-Foo == "bar" {\n set-debug();\n }\n}') assert isinstance(cond, Comparison) assert cond.left == IdentValue(raw="inbound.req.X-Foo") assert cond.operator == "==" @@ -212,77 +235,58 @@ def test_equality_comparison(self): assert cond.modifiers == () def test_regex_comparison(self): - cond = self._first_condition( - 'REMAP {\n if inbound.url.path ~ /\\.php$/ {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if inbound.url.path ~ /\\.php$/ {\n set-debug();\n }\n}') assert isinstance(cond, Comparison) assert cond.operator == "~" assert isinstance(cond.right, RegexValue) def test_in_set(self): - cond = self._first_condition( - 'REMAP {\n if inbound.url.path in ["a", "b"] {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if inbound.url.path in ["a", "b"] {\n set-debug();\n }\n}') assert isinstance(cond, Comparison) assert cond.operator == "in" assert cond.right == (LiteralStringValue(raw="a"), LiteralStringValue(raw="b")) def test_not_in_set(self): - cond = self._first_condition( - 'REMAP {\n if inbound.url.path !in ["a"] {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if inbound.url.path !in ["a"] {\n set-debug();\n }\n}') assert isinstance(cond, Comparison) assert cond.operator == "!in" def test_in_iprange(self): - cond = self._first_condition( - 'REMAP {\n if inbound.ip in {10.0.0.0/8} {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if inbound.ip in {10.0.0.0/8} {\n set-debug();\n }\n}') assert isinstance(cond, Comparison) assert cond.operator == "in" assert cond.right == (IPValue(raw="10.0.0.0/8"),) def test_modifiers(self): - cond = self._first_condition( - 'REMAP {\n if inbound.req.X-Foo == "bar" with NOCASE {\n set-debug();\n }\n}' - ) + cond = 
self._first_condition('REMAP {\n if inbound.req.X-Foo == "bar" with NOCASE {\n set-debug();\n }\n}') assert isinstance(cond, Comparison) assert cond.modifiers == ("NOCASE",) def test_function_call_comparable(self): - cond = self._first_condition( - 'REMAP {\n if url(true) ~ /pat/ {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if url(true) ~ /pat/ {\n set-debug();\n }\n}') assert isinstance(cond, Comparison) assert isinstance(cond.left, FunctionCall) assert cond.left.name == "url" assert cond.left.args == (True,) def test_bool_literal_true(self): - cond = self._first_condition( - 'REMAP {\n if true {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if true {\n set-debug();\n }\n}') assert isinstance(cond, BoolLiteral) assert cond.value is True def test_ident_condition(self): - cond = self._first_condition( - 'REMAP {\n if inbound.resp.All-Cache {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if inbound.resp.All-Cache {\n set-debug();\n }\n}') assert isinstance(cond, IdentCondition) assert cond.name == "inbound.resp.All-Cache" def test_not_condition(self): - cond = self._first_condition( - 'REMAP {\n if !inbound.resp.All-Cache {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if !inbound.resp.All-Cache {\n set-debug();\n }\n}') assert isinstance(cond, NotOp) assert isinstance(cond.operand, IdentCondition) def test_and_condition(self): cond = self._first_condition( - 'REMAP {\n if inbound.req.X-A == "a" && inbound.req.X-B == "b" {\n set-debug();\n }\n}' - ) + 'REMAP {\n if inbound.req.X-A == "a" && inbound.req.X-B == "b" {\n set-debug();\n }\n}') assert isinstance(cond, LogicalOp) assert cond.operator == "&&" assert isinstance(cond.left, Comparison) @@ -290,55 +294,42 @@ def test_and_condition(self): def test_or_condition(self): cond = self._first_condition( - 'REMAP {\n if inbound.req.X-A == "a" || inbound.req.X-B == "b" {\n set-debug();\n }\n}' - ) + 'REMAP {\n if inbound.req.X-A 
== "a" || inbound.req.X-B == "b" {\n set-debug();\n }\n}') assert isinstance(cond, LogicalOp) assert cond.operator == "||" def test_function_call_in_condition(self): - cond = self._first_condition( - 'REMAP {\n if access("/tmp/bar") {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if access("/tmp/bar") {\n set-debug();\n }\n}') assert isinstance(cond, FunctionCall) assert cond.name == "access" assert cond.args == (LiteralStringValue(raw="/tmp/bar"),) def test_not_tilde_comparison(self): - cond = self._first_condition( - 'REMAP {\n if inbound.url.path !~ /\\.jpg$/ {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if inbound.url.path !~ /\\.jpg$/ {\n set-debug();\n }\n}') assert isinstance(cond, Comparison) assert cond.operator == "!~" assert isinstance(cond.right, RegexValue) def test_greater_than_comparison(self): - cond = self._first_condition( - 'REMAP {\n if inbound.req.Content-Length > 1000 {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if inbound.req.Content-Length > 1000 {\n set-debug();\n }\n}') assert isinstance(cond, Comparison) assert cond.operator == ">" assert cond.right == 1000 def test_less_than_comparison(self): - cond = self._first_condition( - 'REMAP {\n if inbound.req.Content-Length < 500 {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if inbound.req.Content-Length < 500 {\n set-debug();\n }\n}') assert isinstance(cond, Comparison) assert cond.operator == "<" assert cond.right == 500 def test_neq_comparison(self): - cond = self._first_condition( - 'REMAP {\n if inbound.req.X-Foo != "bar" {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if inbound.req.X-Foo != "bar" {\n set-debug();\n }\n}') assert isinstance(cond, Comparison) assert cond.operator == "!=" assert cond.right == LiteralStringValue(raw="bar") def test_parenthesized_condition(self): - cond = self._first_condition( - 'REMAP {\n if (inbound.req.X-Foo == "bar") {\n set-debug();\n 
}\n}' - ) + cond = self._first_condition('REMAP {\n if (inbound.req.X-Foo == "bar") {\n set-debug();\n }\n}') assert isinstance(cond, Comparison) assert cond.operator == "==" assert cond.right == LiteralStringValue(raw="bar") @@ -348,8 +339,7 @@ def test_and_binds_tighter_than_or(self): cond = self._first_condition( 'REMAP {\n' ' if inbound.req.X-A == "a" || inbound.req.X-B == "b" && inbound.req.X-C == "c" {\n' - ' set-debug();\n }\n}' - ) + ' set-debug();\n }\n}') assert isinstance(cond, LogicalOp) assert cond.operator == "||" assert isinstance(cond.left, Comparison) @@ -364,8 +354,7 @@ def test_not_with_and(self): cond = self._first_condition( 'REMAP {\n' ' if !inbound.resp.All-Cache && inbound.req.X-B == "b" {\n' - ' set-debug();\n }\n}' - ) + ' set-debug();\n }\n}') assert isinstance(cond, LogicalOp) assert cond.operator == "&&" assert isinstance(cond.left, NotOp) @@ -379,8 +368,7 @@ def test_not_comparison_with_or(self): cond = self._first_condition( 'REMAP {\n' ' if !(inbound.req.X-A == "x") || inbound.req.X-B == "y" {\n' - ' set-debug();\n }\n}' - ) + ' set-debug();\n }\n}') assert isinstance(cond, LogicalOp) assert cond.operator == "||" assert isinstance(cond.left, NotOp) @@ -391,18 +379,14 @@ def test_not_comparison_with_or(self): assert cond.right.left == IdentValue(raw="inbound.req.X-B") def test_double_negation(self): - cond = self._first_condition( - 'REMAP {\n if !!inbound.resp.All-Cache {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if !!inbound.resp.All-Cache {\n set-debug();\n }\n}') assert isinstance(cond, NotOp) assert isinstance(cond.operand, NotOp) assert isinstance(cond.operand.operand, IdentCondition) assert cond.operand.operand.name == "inbound.resp.All-Cache" def test_not_bool_literal(self): - cond = self._first_condition( - 'REMAP {\n if !false {\n set-debug();\n }\n}' - ) + cond = self._first_condition('REMAP {\n if !false {\n set-debug();\n }\n}') assert isinstance(cond, NotOp) assert isinstance(cond.operand, 
BoolLiteral) assert cond.operand.value is False @@ -412,8 +396,7 @@ def test_parens_override_precedence(self): cond = self._first_condition( 'REMAP {\n' ' if (inbound.req.X-A == "a" || inbound.req.X-B == "b") && inbound.req.X-C == "c" {\n' - ' set-debug();\n }\n}' - ) + ' set-debug();\n }\n}') assert isinstance(cond, LogicalOp) assert cond.operator == "&&" assert isinstance(cond.left, LogicalOp) @@ -428,8 +411,7 @@ def test_nested_parens_with_not(self): cond = self._first_condition( 'REMAP {\n' ' if !(inbound.req.X-A == "x" || inbound.req.X-B == "y") && inbound.req.X-C == "z" {\n' - ' set-debug();\n }\n}' - ) + ' set-debug();\n }\n}') assert isinstance(cond, LogicalOp) assert cond.operator == "&&" assert isinstance(cond.left, NotOp) @@ -440,10 +422,9 @@ def test_nested_parens_with_not(self): class TestIfBlocks: + def test_simple_if(self): - ast = _build( - 'REMAP {\n if true {\n inbound.req.X = "y";\n }\n}' - ) + ast = _build('REMAP {\n if true {\n inbound.req.X = "y";\n }\n}') ib = ast.body[0].body[0] assert isinstance(ib, IfBlock) assert len(ib.body) == 1 @@ -457,10 +438,11 @@ def test_if_else(self): assert len(ib.else_body) == 1 def test_if_elif_else(self): - src = ('SEND_RESPONSE {\n if inbound.url.path == "foo" {\n' - ' inbound.resp.X = "f";\n } elif inbound.url.path == "bar" {\n' - ' inbound.resp.X = "b";\n } else {\n' - ' inbound.resp.X = "other";\n }\n}') + src = ( + 'SEND_RESPONSE {\n if inbound.url.path == "foo" {\n' + ' inbound.resp.X = "f";\n } elif inbound.url.path == "bar" {\n' + ' inbound.resp.X = "b";\n } else {\n' + ' inbound.resp.X = "other";\n }\n}') ast = _build(src) ib = ast.body[0].body[0] assert isinstance(ib, IfBlock) @@ -470,17 +452,19 @@ def test_if_elif_else(self): assert len(ib.else_body) == 1 def test_multiple_elif(self): - src = ('SEND_RESPONSE {\n if inbound.url.path == "a" {\n set-debug();\n' - ' } elif inbound.url.path == "b" {\n set-debug();\n' - ' } elif inbound.url.path == "c" {\n set-debug();\n' - ' } else {\n set-debug();\n 
}\n}') + src = ( + 'SEND_RESPONSE {\n if inbound.url.path == "a" {\n set-debug();\n' + ' } elif inbound.url.path == "b" {\n set-debug();\n' + ' } elif inbound.url.path == "c" {\n set-debug();\n' + ' } else {\n set-debug();\n }\n}') ast = _build(src) ib = ast.body[0].body[0] assert len(ib.elif_branches) == 2 def test_nested_if(self): - src = ('REMAP {\n if inbound.req.X == "a" {\n' - ' if inbound.req.Y == "b" {\n set-debug();\n }\n }\n}') + src = ( + 'REMAP {\n if inbound.req.X == "a" {\n' + ' if inbound.req.Y == "b" {\n set-debug();\n }\n }\n}') ast = _build(src) outer = ast.body[0].body[0] assert isinstance(outer, IfBlock) @@ -488,9 +472,10 @@ def test_nested_if(self): assert isinstance(inner, IfBlock) def test_mixed_body(self): - src = ('REMAP {\n inbound.req.X = "before";\n' - ' if true {\n set-debug();\n }\n' - ' inbound.req.Y = "after";\n}') + src = ( + 'REMAP {\n inbound.req.X = "before";\n' + ' if true {\n set-debug();\n }\n' + ' inbound.req.Y = "after";\n}') ast = _build(src) body = ast.body[0].body assert len(body) == 3 @@ -501,37 +486,37 @@ def test_mixed_body(self): class TestLineNumbers: SRC = ( - "use test::helper\n" # line 1 - "VARS {\n" # line 2 - " flag: bool;\n" # line 3 - "}\n" # line 4 - "procedure local::stamp($tag) {\n" # line 5 - " inbound.req.X-Stamp = $tag;\n" # line 6 - "}\n" # line 7 - "REMAP {\n" # line 8 - ' inbound.req.X-Foo = "val";\n' # line 9 - " set-debug();\n" # line 10 - " skip-remap;\n" # line 11 - ' if inbound.req.X-A == "a" {\n' # line 12 - " break;\n" # line 13 - ' } elif inbound.req.X-B == "b" {\n' # line 14 - ' inbound.req.X = "elif";\n' # line 15 - " } else {\n" # line 16 - ' inbound.req.X = "else";\n' # line 17 - " }\n" # line 18 + "use test::helper\n" # line 1 + "VARS {\n" # line 2 + " flag: bool;\n" # line 3 + "}\n" # line 4 + "procedure local::stamp($tag) {\n" # line 5 + " inbound.req.X-Stamp = $tag;\n" # line 6 + "}\n" # line 7 + "REMAP {\n" # line 8 + ' inbound.req.X-Foo = "val";\n' # line 9 + " set-debug();\n" # line 
10 + " skip-remap;\n" # line 11 + ' if inbound.req.X-A == "a" {\n' # line 12 + " break;\n" # line 13 + ' } elif inbound.req.X-B == "b" {\n' # line 14 + ' inbound.req.X = "elif";\n' # line 15 + " } else {\n" # line 16 + ' inbound.req.X = "else";\n' # line 17 + " }\n" # line 18 ' if inbound.req.X-C == "c" && inbound.req.X-D == "d" {\n' # line 19 - " set-debug();\n" # line 20 - " }\n" # line 21 - " if !inbound.resp.All-Cache {\n" # line 22 - " set-debug();\n" # line 23 - " }\n" # line 24 - " if true {\n" # line 25 - " set-debug();\n" # line 26 - " }\n" # line 27 - " if inbound.resp.All-Cache {\n" # line 28 - " set-debug();\n" # line 29 - " }\n" # line 30 - "}\n" # line 31 + " set-debug();\n" # line 20 + " }\n" # line 21 + " if !inbound.resp.All-Cache {\n" # line 22 + " set-debug();\n" # line 23 + " }\n" # line 24 + " if true {\n" # line 25 + " set-debug();\n" # line 26 + " }\n" # line 27 + " if inbound.resp.All-Cache {\n" # line 28 + " set-debug();\n" # line 29 + " }\n" # line 30 + "}\n" # line 31 ) def setup_method(self): @@ -634,6 +619,7 @@ def test_ident_condition(self): class TestRealConfigs: + def test_nested_ifs_from_test_data(self): """Validates AST for tests/data/conds/nested-ifs.input.txt pattern.""" src = '''VARS { From 8ad960a24dcee7eee19cac52f10f4c4ccd0dc3c0 Mon Sep 17 00:00:00 2001 From: Juan Posadas Date: Mon, 4 May 2026 15:58:08 -0600 Subject: [PATCH 09/13] Add tests for comment skipping and modifier casing Cover comment handling in section bodies, blocks, and var sections. Add test confirming modifiers preserve source casing rather than normalizing. 
--- tools/hrw4u/tests/test_ast_visitor.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tools/hrw4u/tests/test_ast_visitor.py b/tools/hrw4u/tests/test_ast_visitor.py index 0bc7d648b20..d9da440e75c 100644 --- a/tools/hrw4u/tests/test_ast_visitor.py +++ b/tools/hrw4u/tests/test_ast_visitor.py @@ -119,6 +119,16 @@ def test_break(self): class TestSections: + def test_comments_in_section_body_skipped(self): + src = 'REMAP {\n # a comment\n set-debug();\n # another comment\n}' + ast = _build(src) + assert len(ast.body[0].body) == 1 + + def test_comments_in_block_skipped(self): + src = 'REMAP {\n if true {\n # comment\n set-debug();\n }\n}' + ast = _build(src) + assert len(ast.body[0].body[0].body) == 1 + def test_section_type(self): ast = _build('REMAP {\n set-debug();\n}') s = ast.body[0] @@ -152,6 +162,13 @@ def test_item_ordering(self): class TestVarSections: + def test_comments_in_var_section_skipped(self): + src = 'VARS {\n # comment\n x: bool;\n # another\n y: int;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + vs = ast.body[0] + assert isinstance(vs, VarSection) + assert len(vs.declarations) == 2 + def test_txn_scope(self): src = 'VARS {\n flag: bool;\n}\nREMAP {\n set-debug();\n}' ast = _build(src) @@ -262,6 +279,11 @@ def test_modifiers(self): assert isinstance(cond, Comparison) assert cond.modifiers == ("NOCASE",) + def test_modifiers_preserve_source_casing(self): + cond = self._first_condition('REMAP {\n if inbound.req.X-Foo == "bar" with nocase,Pre {\n set-debug();\n }\n}') + assert isinstance(cond, Comparison) + assert cond.modifiers == ("nocase", "Pre") + def test_function_call_comparable(self): cond = self._first_condition('REMAP {\n if url(true) ~ /pat/ {\n set-debug();\n }\n}') assert isinstance(cond, Comparison) From 59b1e5db2236cad356a1e61e63d00af35292f5d0 Mon Sep 17 00:00:00 2001 From: Juan Posadas Date: Tue, 5 May 2026 13:54:26 -0600 Subject: [PATCH 10/13] Add __all__ to ast_nodes.py for safe wildcard imports Defines 
the public API surface explicitly, enabling shorter `from hrw4u.ast_nodes import *` in downstream modules. --- tools/hrw4u/src/ast_nodes.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tools/hrw4u/src/ast_nodes.py b/tools/hrw4u/src/ast_nodes.py index 09fd5f1a964..acf5bacccb3 100644 --- a/tools/hrw4u/src/ast_nodes.py +++ b/tools/hrw4u/src/ast_nodes.py @@ -20,6 +20,37 @@ from dataclasses import dataclass from typing import Union +__all__ = [ + "LiteralStringValue", + "IdentValue", + "IPValue", + "ParamRef", + "RegexValue", + "ValueExpr", + "Node", + "Target", + "Assignment", + "FunctionCall", + "Break", + "Comparison", + "LogicalOp", + "NotOp", + "BoolLiteral", + "IdentCondition", + "ElifBranch", + "IfBlock", + "Section", + "ProcParam", + "VarDecl", + "VarSection", + "UseDirective", + "ProcedureDecl", + "HRW4UAST", + "ConditionExpr", + "BodyNode", + "TopLevelNode", +] + @dataclass(frozen=True, kw_only=True) class LiteralStringValue: From 8f5a1bb8d00fe59d1abe1ab7e9112c682ead9368 Mon Sep 17 00:00:00 2001 From: Juan Posadas Date: Tue, 5 May 2026 14:02:02 -0600 Subject: [PATCH 11/13] Use wildcard imports from ast_nodes Now that __all__ is defined, switch ast_visitor.py and test_ast_visitor.py to `from hrw4u.ast_nodes import *`. 
--- tools/hrw4u/src/ast_visitor.py | 26 +------------------------- tools/hrw4u/tests/test_ast_visitor.py | 26 +------------------------- 2 files changed, 2 insertions(+), 50 deletions(-) diff --git a/tools/hrw4u/src/ast_visitor.py b/tools/hrw4u/src/ast_visitor.py index 784b92f3c08..0aebbba28ad 100644 --- a/tools/hrw4u/src/ast_visitor.py +++ b/tools/hrw4u/src/ast_visitor.py @@ -18,31 +18,7 @@ from __future__ import annotations from hrw4u.hrw4uVisitor import hrw4uVisitor -from hrw4u.ast_nodes import ( - HRW4UAST, - Section, - Assignment, - FunctionCall, - Break, - Target, - IfBlock, - ElifBranch, - BoolLiteral, - Comparison, - LogicalOp, - NotOp, - IdentCondition, - ProcParam, - VarDecl, - VarSection, - UseDirective, - ProcedureDecl, - LiteralStringValue, - IdentValue, - IPValue, - ParamRef, - RegexValue, -) +from hrw4u.ast_nodes import * class ASTVisitor(hrw4uVisitor): diff --git a/tools/hrw4u/tests/test_ast_visitor.py b/tools/hrw4u/tests/test_ast_visitor.py index d9da440e75c..ec919d1f060 100644 --- a/tools/hrw4u/tests/test_ast_visitor.py +++ b/tools/hrw4u/tests/test_ast_visitor.py @@ -15,31 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from hrw4u.ast_nodes import ( - Target, - Assignment, - FunctionCall, - Break, - Section, - HRW4UAST, - Comparison, - IfBlock, - ElifBranch, - BoolLiteral, - NotOp, - LogicalOp, - IdentCondition, - VarSection, - VarDecl, - UseDirective, - ProcedureDecl, - ProcParam, - LiteralStringValue, - IdentValue, - IPValue, - ParamRef, - RegexValue, -) +from hrw4u.ast_nodes import * from utils import parse_input_text from hrw4u.ast_visitor import ASTVisitor From 639f477c40b4f07bcb3b7c9689c996bdd2f10d4e Mon Sep 17 00:00:00 2001 From: Juan Posadas Date: Tue, 5 May 2026 14:04:16 -0600 Subject: [PATCH 12/13] Add return type annotations to ASTVisitor methods Matches the convention used in visitor.py where all visit/helper methods declare their return types. 
--- tools/hrw4u/src/ast_visitor.py | 40 +++++++++++++++++----------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/tools/hrw4u/src/ast_visitor.py b/tools/hrw4u/src/ast_visitor.py index 0aebbba28ad..b0be2247dbd 100644 --- a/tools/hrw4u/src/ast_visitor.py +++ b/tools/hrw4u/src/ast_visitor.py @@ -29,7 +29,7 @@ class ASTVisitor(hrw4uVisitor): # method has an explicit return type and full control over how # child results are assembled into parent AST nodes. - def visitProgram(self, ctx): + def visitProgram(self, ctx) -> HRW4UAST: items = [] for item in ctx.programItem(): if item.useDirective() is not None: @@ -44,13 +44,13 @@ def visitProgram(self, ctx): raise ValueError(f"Unhandled programItem alternative at line {item.start.line}") return HRW4UAST(body=tuple(items)) - def _visit_use_directive(self, ctx): + def _visit_use_directive(self, ctx) -> UseDirective: return UseDirective( spec=ctx.QUALIFIED_IDENT().getText(), line=ctx.start.line, ) - def _visit_procedure_decl(self, ctx): + def _visit_procedure_decl(self, ctx) -> ProcedureDecl: name = ctx.QUALIFIED_IDENT().getText() params = () if ctx.paramList(): @@ -58,12 +58,12 @@ def _visit_procedure_decl(self, ctx): body = tuple(self._visit_body(ctx.block().blockItem())) return ProcedureDecl(name=name, params=params, body=body, line=ctx.start.line) - def _visit_proc_param(self, ctx): + def _visit_proc_param(self, ctx) -> ProcParam: name = ctx.IDENT().getText() default = self._extract_value(ctx.value()) if ctx.value() else None return ProcParam(name=name, default=default, line=ctx.start.line) - def _visit_section(self, ctx): + def _visit_section(self, ctx) -> VarSection | Section: if ctx.varSection() is not None: return self._visit_var_section(ctx.varSection(), "txn") if ctx.sessionVarSection() is not None: @@ -72,7 +72,7 @@ def _visit_section(self, ctx): body = self._visit_body(ctx.sectionBody()) return Section(type=name, body=tuple(body), line=ctx.start.line) - def _visit_var_section(self, ctx, 
scope): + def _visit_var_section(self, ctx, scope) -> VarSection: decls = [] for var_item in ctx.variables().variablesItem(): if var_item.variableDecl() is not None: @@ -83,7 +83,7 @@ def _visit_var_section(self, ctx, scope): raise ValueError(f"Unhandled variablesItem alternative at line {var_item.start.line}") return VarSection(scope=scope, declarations=tuple(decls), line=ctx.start.line) - def _visit_var_decl(self, ctx): + def _visit_var_decl(self, ctx) -> VarDecl: return VarDecl( name=ctx.name.text, type_name=ctx.typeName.text, @@ -91,7 +91,7 @@ def _visit_var_decl(self, ctx): line=ctx.start.line, ) - def _visit_body(self, items): + def _visit_body(self, items) -> list[BodyNode]: """Shared helper for sectionBody and blockItem lists.""" result = [] for item in items: @@ -105,7 +105,7 @@ def _visit_body(self, items): raise ValueError(f"Unhandled body item alternative at line {item.start.line}") return result - def _visit_statement(self, ctx): + def _visit_statement(self, ctx) -> BodyNode: line = ctx.start.line if ctx.BREAK(): return Break(line=line) @@ -123,14 +123,14 @@ def _visit_statement(self, ctx): return FunctionCall(name=ctx.op.text, args=(), line=line) raise ValueError(f"Unhandled statement alternative at line {line}") - def _visit_function_call(self, ctx): + def _visit_function_call(self, ctx) -> FunctionCall: name = ctx.funcName.text args = () if ctx.argumentList(): args = tuple(self._extract_value(v) for v in ctx.argumentList().value()) return FunctionCall(name=name, args=args, line=ctx.start.line) - def _extract_value(self, ctx): + def _extract_value(self, ctx) -> ValueExpr: if ctx.number is not None: return int(ctx.number.text) if ctx.str_ is not None: @@ -149,7 +149,7 @@ def _extract_value(self, ctx): return ParamRef(raw=ctx.paramRef().IDENT().getText()) raise ValueError(f"Unhandled value alternative at line {ctx.start.line}") - def _visit_conditional(self, ctx): + def _visit_conditional(self, ctx) -> IfBlock: if_stmt = ctx.ifStatement() condition = 
self._visit_condition(if_stmt.condition()) block = if_stmt.block() @@ -180,10 +180,10 @@ def _visit_conditional(self, ctx): line=ctx.start.line, ) - def _visit_condition(self, ctx): + def _visit_condition(self, ctx) -> ConditionExpr: return self._visit_expression(ctx.expression()) - def _visit_expression(self, ctx): + def _visit_expression(self, ctx) -> ConditionExpr: if ctx.OR(): left = self._visit_expression(ctx.expression()) right = self._visit_term(ctx.term()) @@ -195,7 +195,7 @@ def _visit_expression(self, ctx): ) return self._visit_term(ctx.term()) - def _visit_term(self, ctx): + def _visit_term(self, ctx) -> ConditionExpr: if ctx.AND(): left = self._visit_term(ctx.term()) right = self._visit_factor(ctx.factor()) @@ -207,7 +207,7 @@ def _visit_term(self, ctx): ) return self._visit_factor(ctx.factor()) - def _visit_factor(self, ctx): + def _visit_factor(self, ctx) -> ConditionExpr: if ctx.getChildCount() == 2 and ctx.getChild(0).getText() == "!": return NotOp( operand=self._visit_factor(ctx.factor()), @@ -227,7 +227,7 @@ def _visit_factor(self, ctx): return BoolLiteral(value=False, line=ctx.start.line) raise ValueError(f"Unhandled factor alternative at line {ctx.start.line}") - def _visit_comparison(self, ctx): + def _visit_comparison(self, ctx) -> Comparison: line = ctx.start.line comp = ctx.comparable() if comp.ident is not None: @@ -247,7 +247,7 @@ def _visit_comparison(self, ctx): line=line, ) - def _detect_comparison_operator(self, ctx): + def _detect_comparison_operator(self, ctx) -> str: if ctx.EQUALS(): return "==" if ctx.NEQ(): @@ -267,7 +267,7 @@ def _detect_comparison_operator(self, ctx): return "in" raise ValueError(f"Unhandled comparison operator at line {ctx.start.line}") - def _extract_comparison_rhs(self, ctx, operator): + def _extract_comparison_rhs(self, ctx, operator) -> ValueExpr | RegexValue | tuple[ValueExpr, ...]: if operator in ("~", "!~"): return RegexValue(raw=ctx.regex().getText()[1:-1]) if operator in ("in", "!in"): @@ -279,7 +279,7 
@@ def _extract_comparison_rhs(self, ctx, operator): return self._extract_value(ctx.value()) raise ValueError(f"Unhandled comparison RHS at line {ctx.start.line}") - def _extract_modifiers(self, ctx): + def _extract_modifiers(self, ctx) -> tuple[str, ...]: if ctx.modifier(): return tuple(tok.text for tok in ctx.modifier().modifierList().mods) return () From 04837adef0a8267b760469efe9f16179e118d7b1 Mon Sep 17 00:00:00 2001 From: Juan Posadas Date: Tue, 5 May 2026 14:04:41 -0600 Subject: [PATCH 13/13] Collapse short constructor calls to single lines Fits within the project's 132-char column limit per .style.yapf. --- tools/hrw4u/src/ast_visitor.py | 52 ++++++---------------------------- 1 file changed, 8 insertions(+), 44 deletions(-) diff --git a/tools/hrw4u/src/ast_visitor.py b/tools/hrw4u/src/ast_visitor.py index b0be2247dbd..4a66ec0a710 100644 --- a/tools/hrw4u/src/ast_visitor.py +++ b/tools/hrw4u/src/ast_visitor.py @@ -45,10 +45,7 @@ def visitProgram(self, ctx) -> HRW4UAST: return HRW4UAST(body=tuple(items)) def _visit_use_directive(self, ctx) -> UseDirective: - return UseDirective( - spec=ctx.QUALIFIED_IDENT().getText(), - line=ctx.start.line, - ) + return UseDirective(spec=ctx.QUALIFIED_IDENT().getText(), line=ctx.start.line) def _visit_procedure_decl(self, ctx) -> ProcedureDecl: name = ctx.QUALIFIED_IDENT().getText() @@ -85,11 +82,7 @@ def _visit_var_section(self, ctx, scope) -> VarSection: def _visit_var_decl(self, ctx) -> VarDecl: return VarDecl( - name=ctx.name.text, - type_name=ctx.typeName.text, - slot=int(ctx.slot.text) if ctx.slot else None, - line=ctx.start.line, - ) + name=ctx.name.text, type_name=ctx.typeName.text, slot=int(ctx.slot.text) if ctx.slot else None, line=ctx.start.line) def _visit_body(self, items) -> list[BodyNode]: """Shared helper for sectionBody and blockItem lists.""" @@ -160,11 +153,7 @@ def _visit_conditional(self, ctx) -> IfBlock: elif_cond = self._visit_condition(elif_ctx.condition()) elif_block = elif_ctx.block() elif_body = 
tuple(self._visit_body(elif_block.blockItem())) if elif_block else () - elif_branches.append(ElifBranch( - condition=elif_cond, - body=elif_body, - line=elif_ctx.start.line, - )) + elif_branches.append(ElifBranch(condition=elif_cond, body=elif_body, line=elif_ctx.start.line)) else_body = () if ctx.elseClause(): @@ -172,13 +161,7 @@ def _visit_conditional(self, ctx) -> IfBlock: if else_block: else_body = tuple(self._visit_body(else_block.blockItem())) - return IfBlock( - condition=condition, - body=body, - elif_branches=tuple(elif_branches), - else_body=else_body, - line=ctx.start.line, - ) + return IfBlock(condition=condition, body=body, elif_branches=tuple(elif_branches), else_body=else_body, line=ctx.start.line) def _visit_condition(self, ctx) -> ConditionExpr: return self._visit_expression(ctx.expression()) @@ -187,32 +170,19 @@ def _visit_expression(self, ctx) -> ConditionExpr: if ctx.OR(): left = self._visit_expression(ctx.expression()) right = self._visit_term(ctx.term()) - return LogicalOp( - operator="||", - left=left, - right=right, - line=ctx.start.line, - ) + return LogicalOp(operator="||", left=left, right=right, line=ctx.start.line) return self._visit_term(ctx.term()) def _visit_term(self, ctx) -> ConditionExpr: if ctx.AND(): left = self._visit_term(ctx.term()) right = self._visit_factor(ctx.factor()) - return LogicalOp( - operator="&&", - left=left, - right=right, - line=ctx.start.line, - ) + return LogicalOp(operator="&&", left=left, right=right, line=ctx.start.line) return self._visit_factor(ctx.factor()) def _visit_factor(self, ctx) -> ConditionExpr: if ctx.getChildCount() == 2 and ctx.getChild(0).getText() == "!": - return NotOp( - operand=self._visit_factor(ctx.factor()), - line=ctx.start.line, - ) + return NotOp(operand=self._visit_factor(ctx.factor()), line=ctx.start.line) if ctx.LPAREN(): return self._visit_expression(ctx.expression()) if ctx.functionCall(): @@ -239,13 +209,7 @@ def _visit_comparison(self, ctx) -> Comparison: right = 
self._extract_comparison_rhs(ctx, operator) modifiers = self._extract_modifiers(ctx) - return Comparison( - left=left, - operator=operator, - right=right, - modifiers=modifiers, - line=line, - ) + return Comparison(left=left, operator=operator, right=right, modifiers=modifiers, line=line) def _detect_comparison_operator(self, ctx) -> str: if ctx.EQUALS():