diff --git a/tools/hrw4u/Makefile b/tools/hrw4u/Makefile index 7b929356af3..ba25de02d98 100644 --- a/tools/hrw4u/Makefile +++ b/tools/hrw4u/Makefile @@ -53,7 +53,9 @@ SRC_FILES_HRW4U=src/visitor.py \ src/suggestions.py \ src/procedures.py \ src/sandbox.py \ - src/kg_visitor.py + src/kg_visitor.py \ + src/ast_nodes.py \ + src/ast_visitor.py ALL_HRW4U_FILES=$(SHARED_FILES) $(UTILS_FILES) $(SRC_FILES_HRW4U) diff --git a/tools/hrw4u/src/ast_nodes.py b/tools/hrw4u/src/ast_nodes.py new file mode 100644 index 00000000000..acf5bacccb3 --- /dev/null +++ b/tools/hrw4u/src/ast_nodes.py @@ -0,0 +1,211 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from __future__ import annotations

from dataclasses import dataclass
from typing import Union

__all__ = [
    "LiteralStringValue",
    "IdentValue",
    "IPValue",
    "ParamRef",
    "RegexValue",
    "ValueExpr",
    "Node",
    "Target",
    "Assignment",
    "FunctionCall",
    "Break",
    "Comparison",
    "LogicalOp",
    "NotOp",
    "BoolLiteral",
    "IdentCondition",
    "ElifBranch",
    "IfBlock",
    "Section",
    "ProcParam",
    "VarDecl",
    "VarSection",
    "UseDirective",
    "ProcedureDecl",
    "HRW4UAST",
    "ConditionExpr",
    "BodyNode",
    "TopLevelNode",
]


@dataclass(frozen=True, kw_only=True)
class LiteralStringValue:
    """String literal value; ``raw`` is the literal text as stored by the visitor."""

    raw: str


@dataclass(frozen=True, kw_only=True)
class IdentValue:
    """Identifier used as a value; ``raw`` is the identifier token text."""

    raw: str


@dataclass(frozen=True, kw_only=True)
class IPValue:
    """IP value; ``raw`` is the unparsed source text (e.g. ``"10.0.0.1"``)."""

    raw: str


@dataclass(frozen=True, kw_only=True)
class ParamRef:
    """Reference to a parameter; ``raw`` is the parameter's identifier text."""

    raw: str


@dataclass(frozen=True, kw_only=True)
class RegexValue:
    """Regular-expression operand; ``raw`` is the pattern text."""

    raw: str


# Anything that may appear where a value is expected.  Evaluated at runtime,
# so it must follow the value classes it names.
ValueExpr = Union[LiteralStringValue, IdentValue, IPValue, ParamRef, int, bool, tuple[IPValue, ...]]


@dataclass(frozen=True, kw_only=True)
class Node:
    """Base class for AST nodes that record the source line they came from."""

    line: int


@dataclass(frozen=True)
class Target:
    """Left-hand side of an assignment, split into namespace and field."""

    namespace: str | None
    field: str

    @classmethod
    def from_dotted(cls, name: str) -> Target:
        """Split a dotted path at its last ``.`` into namespace and field.

        ``"inbound.req.X-Foo"`` becomes namespace ``"inbound.req"`` and field
        ``"X-Foo"``; a name without any dot gets ``namespace=None``.

        Args:
            name: Dotted identifier text.

        Returns:
            A new instance of ``cls``.
        """
        # TODO: the grammar lexes dotted paths as a single IDENT token;
        #       ideally the grammar would split namespace/field so this
        #       heuristic isn't needed.
        namespace, dot, field = name.rpartition(".")
        # rpartition returns an empty separator when no "." is present.
        return cls(namespace=namespace if dot else None, field=field)


@dataclass(frozen=True, kw_only=True)
class Assignment(Node):
    """Assignment statement: ``target operator value``."""

    target: Target
    operator: str  # "=" or "+="
    value: ValueExpr


@dataclass(frozen=True, kw_only=True)
class FunctionCall(Node):
    """Call of a named function or operator with positional arguments."""

    name: str
    args: tuple[ValueExpr, ...]
@dataclass(frozen=True, kw_only=True)
class Break(Node):
    """A ``break`` statement; carries only its source line."""


@dataclass(frozen=True, kw_only=True)
class Comparison(Node):
    """Binary comparison between a comparable and a right-hand operand."""

    left: IdentValue | FunctionCall
    operator: str  # "==", "!=", ">", "<", "~", "!~", "in", "!in"
    right: ValueExpr | RegexValue | tuple[ValueExpr, ...]
    modifiers: tuple[str, ...]


@dataclass(frozen=True, kw_only=True)
class LogicalOp(Node):
    """Two conditions joined by a logical operator."""

    operator: str  # "&&" or "||"
    left: ConditionExpr
    right: ConditionExpr


@dataclass(frozen=True, kw_only=True)
class NotOp(Node):
    """Logical negation of a single condition."""

    operand: ConditionExpr


@dataclass(frozen=True, kw_only=True)
class BoolLiteral(Node):
    """A literal ``true`` / ``false`` used as a condition."""

    value: bool


@dataclass(frozen=True, kw_only=True)
class IdentCondition(Node):
    """A bare identifier used as a condition."""

    name: str


@dataclass(frozen=True, kw_only=True)
class ElifBranch(Node):
    """One ``elif`` arm: its condition and the statements it guards."""

    condition: ConditionExpr
    body: tuple[BodyNode, ...]


@dataclass(frozen=True, kw_only=True)
class IfBlock(Node):
    """An ``if`` with its optional ``elif`` arms and ``else`` body."""

    condition: ConditionExpr
    body: tuple[BodyNode, ...]
    elif_branches: tuple[ElifBranch, ...]
    else_body: tuple[BodyNode, ...]


@dataclass(frozen=True, kw_only=True)
class Section(Node):
    """A named section and the statements inside it."""

    type: str
    body: tuple[BodyNode, ...]


@dataclass(frozen=True, kw_only=True)
class ProcParam(Node):
    """Procedure parameter; ``default`` is ``None`` when none was given."""

    name: str
    default: ValueExpr | None


@dataclass(frozen=True, kw_only=True)
class VarDecl(Node):
    """Variable declaration; ``slot`` is ``None`` when no slot was given."""

    name: str
    type_name: str
    slot: int | None


@dataclass(frozen=True, kw_only=True)
class VarSection(Node):
    """A group of variable declarations sharing one scope."""

    scope: str
    declarations: tuple[VarDecl, ...]


@dataclass(frozen=True, kw_only=True)
class UseDirective(Node):
    """A ``use`` directive; ``spec`` is the qualified name it refers to."""

    spec: str


@dataclass(frozen=True, kw_only=True)
class ProcedureDecl(Node):
    """Procedure declaration: name, parameters and body statements."""

    name: str
    params: tuple[ProcParam, ...]
    body: tuple[BodyNode, ...]


@dataclass(frozen=True, kw_only=True)
class HRW4UAST:
    """Root of the tree: the top-level items in source order."""

    body: tuple[TopLevelNode, ...]


# Type aliases: must follow all class definitions (evaluated at runtime).
+ConditionExpr = Union[Comparison, LogicalOp, NotOp, BoolLiteral, IdentCondition, FunctionCall] +BodyNode = Union[Assignment, FunctionCall, IfBlock, Break] +TopLevelNode = Union[UseDirective, VarSection, ProcedureDecl, Section] diff --git a/tools/hrw4u/src/ast_visitor.py b/tools/hrw4u/src/ast_visitor.py new file mode 100644 index 00000000000..4a66ec0a710 --- /dev/null +++ b/tools/hrw4u/src/ast_visitor.py @@ -0,0 +1,249 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from hrw4u.hrw4uVisitor import hrw4uVisitor +from hrw4u.ast_nodes import * + + +class ASTVisitor(hrw4uVisitor): + """ANTLR visitor that walks an HRW4U parse tree and produces an AST for HRW4U.""" + + # Only visitProgram is overridden from the ANTLR visitor interface; + # all other traversal uses private _visit_* helpers so that each + # method has an explicit return type and full control over how + # child results are assembled into parent AST nodes. 
# The definitions below are the methods of ASTVisitor; the class statement and
# its docstring sit on the preceding source line.  Every ``ctx`` argument is a
# parse-tree context object produced by the generated HRW4U parser.

def visitProgram(self, ctx) -> HRW4UAST:
    """Build the root AST node from the program parse tree.

    Use directives, procedure declarations and sections are converted in
    source order; comment lines produce no node.

    Raises:
        ValueError: if a program item matches none of the known alternatives.
    """
    items = []
    for item in ctx.programItem():
        # Each accessor is called once and its result reused.
        if (use_ctx := item.useDirective()) is not None:
            items.append(self._visit_use_directive(use_ctx))
        elif (proc_ctx := item.procedureDecl()) is not None:
            items.append(self._visit_procedure_decl(proc_ctx))
        elif (section_ctx := item.section()) is not None:
            items.append(self._visit_section(section_ctx))
        elif item.commentLine() is not None:
            continue  # comments carry no AST node
        else:
            raise ValueError(f"Unhandled programItem alternative at line {item.start.line}")
    return HRW4UAST(body=tuple(items))


def _visit_use_directive(self, ctx) -> UseDirective:
    """Convert a use directive; ``spec`` is the qualified identifier text."""
    return UseDirective(spec=ctx.QUALIFIED_IDENT().getText(), line=ctx.start.line)


def _visit_procedure_decl(self, ctx) -> ProcedureDecl:
    """Convert a procedure declaration with its parameters and body."""
    name = ctx.QUALIFIED_IDENT().getText()
    param_list = ctx.paramList()
    # A declaration without a parameter list yields an empty tuple.
    params = tuple(self._visit_proc_param(p) for p in param_list.param()) if param_list else ()
    body = tuple(self._visit_body(ctx.block().blockItem()))
    return ProcedureDecl(name=name, params=params, body=body, line=ctx.start.line)


def _visit_proc_param(self, ctx) -> ProcParam:
    """Convert one procedure parameter; ``default`` is None when absent."""
    name = ctx.IDENT().getText()
    value_ctx = ctx.value()
    default = self._extract_value(value_ctx) if value_ctx else None
    return ProcParam(name=name, default=default, line=ctx.start.line)


def _visit_section(self, ctx) -> VarSection | Section:
    """Convert a section.

    Variable sections become a VarSection with scope ``"txn"`` or
    ``"session"``; anything else becomes a Section named by ``ctx.name``.
    """
    if (var_ctx := ctx.varSection()) is not None:
        return self._visit_var_section(var_ctx, "txn")
    if (session_ctx := ctx.sessionVarSection()) is not None:
        return self._visit_var_section(session_ctx, "session")
    name = ctx.name.text
    body = self._visit_body(ctx.sectionBody())
    return Section(type=name, body=tuple(body), line=ctx.start.line)


def _visit_var_section(self, ctx, scope) -> VarSection:
    """Collect the variable declarations of one section under ``scope``.

    Raises:
        ValueError: if an item is neither a declaration nor a comment line.
    """
    decls = []
    for var_item in ctx.variables().variablesItem():
        if (decl_ctx := var_item.variableDecl()) is not None:
            decls.append(self._visit_var_decl(decl_ctx))
        elif var_item.commentLine() is not None:
            continue  # comments carry no AST node
        else:
            raise ValueError(f"Unhandled variablesItem alternative at line {var_item.start.line}")
    return VarSection(scope=scope, declarations=tuple(decls), line=ctx.start.line)


def _visit_var_decl(self, ctx) -> VarDecl:
    """Convert one variable declaration; ``slot`` is None when not given."""
    slot = int(ctx.slot.text) if ctx.slot else None
    return VarDecl(name=ctx.name.text, type_name=ctx.typeName.text, slot=slot, line=ctx.start.line)


def _visit_body(self, items) -> list[BodyNode]:
    """Shared helper for sectionBody and blockItem lists.

    Statements and conditionals are converted in order; comment lines are
    skipped.

    Raises:
        ValueError: if an item matches none of the known alternatives.
    """
    result = []
    for item in items:
        if (stmt_ctx := item.statement()) is not None:
            result.append(self._visit_statement(stmt_ctx))
        elif (cond_ctx := item.conditional()) is not None:
            result.append(self._visit_conditional(cond_ctx))
        elif item.commentLine() is not None:
            continue  # comments carry no AST node
        else:
            raise ValueError(f"Unhandled body item alternative at line {item.start.line}")
    return result


def _visit_optional_block(self, block) -> tuple[BodyNode, ...]:
    """Convert a block's items to a tuple; a missing block yields ``()``."""
    return tuple(self._visit_body(block.blockItem())) if block else ()


def _visit_statement(self, ctx) -> BodyNode:
    """Convert a statement: break, function call, assignment or bare operator.

    Raises:
        ValueError: if the statement matches none of those forms.
    """
    line = ctx.start.line
    if ctx.BREAK():
        return Break(line=line)
    if call_ctx := ctx.functionCall():
        return self._visit_function_call(call_ctx)

    # "=" and "+=" differ only in the operator recorded on the node.
    if ctx.EQUAL():
        operator = "="
    elif ctx.PLUSEQUAL():
        operator = "+="
    else:
        operator = None
    if operator is not None:
        target = Target.from_dotted(ctx.lhs.text)
        value = self._extract_value(ctx.value())
        return Assignment(target=target, operator=operator, value=value, line=line)

    if ctx.op:
        # A bare operator is represented as a call with no arguments.
        return FunctionCall(name=ctx.op.text, args=(), line=line)
    raise ValueError(f"Unhandled statement alternative at line {line}")


def _visit_function_call(self, ctx) -> FunctionCall:
    """Convert a function call; ``args`` is empty when no list is present."""
    name = ctx.funcName.text
    arg_list = ctx.argumentList()
    args = tuple(self._extract_value(v) for v in arg_list.value()) if arg_list else ()
    return FunctionCall(name=name, args=args, line=ctx.start.line)


def _ip_range_values(self, iprange_ctx) -> tuple[IPValue, ...]:
    """Return one IPValue per ``ip`` entry of an iprange context."""
    return tuple(IPValue(raw=ip.getText()) for ip in iprange_ctx.ip())


def _extract_value(self, ctx) -> ValueExpr:
    """Convert a value context to its AST representation.

    Numbers become ``int``, TRUE/FALSE become ``bool``, string literals lose
    their first and last character (the delimiters), and an ip range becomes
    a tuple of IPValue.

    Raises:
        ValueError: if the value matches none of the known alternatives.
    """
    if ctx.number is not None:
        return int(ctx.number.text)
    if ctx.str_ is not None:
        return LiteralStringValue(raw=ctx.str_.text[1:-1])
    if ctx.TRUE():
        return True
    if ctx.FALSE():
        return False
    if ctx.ident is not None:
        return IdentValue(raw=ctx.ident.text)
    if ip_ctx := ctx.ip():
        return IPValue(raw=ip_ctx.getText())
    if range_ctx := ctx.iprange():
        return self._ip_range_values(range_ctx)
    if ref_ctx := ctx.paramRef():
        return ParamRef(raw=ref_ctx.IDENT().getText())
    raise ValueError(f"Unhandled value alternative at line {ctx.start.line}")


def _visit_conditional(self, ctx) -> IfBlock:
    """Convert an if / elif / else chain into a single IfBlock."""
    if_stmt = ctx.ifStatement()
    condition = self._visit_condition(if_stmt.condition())
    body = self._visit_optional_block(if_stmt.block())

    elif_branches = tuple(
        ElifBranch(
            condition=self._visit_condition(elif_ctx.condition()),
            body=self._visit_optional_block(elif_ctx.block()),
            line=elif_ctx.start.line,
        ) for elif_ctx in ctx.elifClause())

    else_clause = ctx.elseClause()
    else_body = self._visit_optional_block(else_clause.block()) if else_clause else ()

    return IfBlock(condition=condition, body=body, elif_branches=elif_branches, else_body=else_body, line=ctx.start.line)


def _visit_condition(self, ctx) -> ConditionExpr:
    """Convert a condition by converting the expression it wraps."""
    return self._visit_expression(ctx.expression())


def _visit_expression(self, ctx) -> ConditionExpr:
    """Convert an expression; an OR token yields a ``"||"`` LogicalOp."""
    if ctx.OR():
        left = self._visit_expression(ctx.expression())
        right = self._visit_term(ctx.term())
        return LogicalOp(operator="||", left=left, right=right, line=ctx.start.line)
    return self._visit_term(ctx.term())


def _visit_term(self, ctx) -> ConditionExpr:
    """Convert a term; an AND token yields a ``"&&"`` LogicalOp."""
    if ctx.AND():
        left = self._visit_term(ctx.term())
        right = self._visit_factor(ctx.factor())
        return LogicalOp(operator="&&", left=left, right=right, line=ctx.start.line)
    return self._visit_factor(ctx.factor())


def _visit_factor(self, ctx) -> ConditionExpr:
    """Convert a factor: negation, parenthesised expression, call,
    comparison, bare identifier or boolean literal.

    Raises:
        ValueError: if the factor matches none of those forms.
    """
    # Negation is recognised as exactly two children with "!" first.
    if ctx.getChildCount() == 2 and ctx.getChild(0).getText() == "!":
        return NotOp(operand=self._visit_factor(ctx.factor()), line=ctx.start.line)
    if ctx.LPAREN():
        # Parentheses add no node of their own.
        return self._visit_expression(ctx.expression())
    if call_ctx := ctx.functionCall():
        return self._visit_function_call(call_ctx)
    if cmp_ctx := ctx.comparison():
        return self._visit_comparison(cmp_ctx)
    if ctx.ident is not None:
        return IdentCondition(name=ctx.ident.text, line=ctx.start.line)
    if ctx.TRUE():
        return BoolLiteral(value=True, line=ctx.start.line)
    if ctx.FALSE():
        return BoolLiteral(value=False, line=ctx.start.line)
    raise ValueError(f"Unhandled factor alternative at line {ctx.start.line}")


def _visit_comparison(self, ctx) -> Comparison:
    """Convert a comparison: left operand, operator, right operand, modifiers."""
    line = ctx.start.line
    comp = ctx.comparable()
    if comp.ident is not None:
        left = IdentValue(raw=comp.ident.text)
    else:
        left = self._visit_function_call(comp.functionCall())

    operator = self._detect_comparison_operator(ctx)
    right = self._extract_comparison_rhs(ctx, operator)
    modifiers = self._extract_modifiers(ctx)

    return Comparison(left=left, operator=operator, right=right, modifiers=modifiers, line=line)


def _detect_comparison_operator(self, ctx) -> str:
    """Return the operator symbol of a comparison context.

    Raises:
        ValueError: if no known operator token is present.
    """
    # Token accessor name -> symbol, probed in this order.
    simple_operators = (
        ("EQUALS", "=="),
        ("NEQ", "!="),
        ("GT", ">"),
        ("LT", "<"),
        ("TILDE", "~"),
        ("NOT_TILDE", "!~"),
    )
    for accessor, symbol in simple_operators:
        if getattr(ctx, accessor)():
            return symbol
    if ctx.IN():
        # "!in" is detected by looking for a "!" child next to the IN token.
        negated = any(hasattr(child, "getText") and child.getText() == "!" for child in ctx.children)
        return "!in" if negated else "in"
    raise ValueError(f"Unhandled comparison operator at line {ctx.start.line}")


def _extract_comparison_rhs(self, ctx, operator) -> ValueExpr | RegexValue | tuple[ValueExpr, ...]:
    """Convert the right-hand side of a comparison for ``operator``.

    Regex operators yield a RegexValue without its first and last character
    (the delimiters); membership operators yield a tuple from a set or ip
    range; everything else falls back to a plain value.

    Raises:
        ValueError: if no right-hand side form matches.
    """
    if operator in ("~", "!~"):
        return RegexValue(raw=ctx.regex().getText()[1:-1])
    if operator in ("in", "!in"):
        if set_ctx := ctx.set_():
            return tuple(self._extract_value(v) for v in set_ctx.value())
        if range_ctx := ctx.iprange():
            return self._ip_range_values(range_ctx)
    if value_ctx := ctx.value():
        return self._extract_value(value_ctx)
    raise ValueError(f"Unhandled comparison RHS at line {ctx.start.line}")


def _extract_modifiers(self, ctx) -> tuple[str, ...]:
    """Return the modifier token texts as written, or ``()`` when absent."""
    modifier_ctx = ctx.modifier()
    if modifier_ctx:
        return tuple(tok.text for tok in modifier_ctx.modifierList().mods)
    return ()
from hrw4u.ast_nodes import Target


class TestTarget:
    """Checks how Target.from_dotted splits a dotted name at its last dot."""

    def test_dotted_path(self):
        # Everything before the last dot is the namespace.
        target = Target.from_dotted("inbound.req.X-Foo")
        assert (target.namespace, target.field) == ("inbound.req", "X-Foo")

    def test_two_segments(self):
        target = Target.from_dotted("inbound.ip")
        assert (target.namespace, target.field) == ("inbound", "ip")

    def test_no_dots(self):
        # Without a dot there is no namespace at all.
        target = Target.from_dotted("bool_0")
        assert target.namespace is None
        assert target.field == "bool_0"

    def test_deep_namespace(self):
        target = Target.from_dotted("http.cntl.TXN_DEBUG")
        assert (target.namespace, target.field) == ("http.cntl", "TXN_DEBUG")
+ +from hrw4u.ast_nodes import * +from utils import parse_input_text +from hrw4u.ast_visitor import ASTVisitor + + +def _build(source: str) -> HRW4UAST: + _, tree = parse_input_text(source) + return ASTVisitor().visit(tree) + + +class TestAssignments: + + def test_simple_assignment(self): + ast = _build('REMAP {\n inbound.req.X-Foo = "test";\n}') + a = ast.body[0].body[0] + assert isinstance(a, Assignment) + assert a.target == Target.from_dotted("inbound.req.X-Foo") + assert a.operator == "=" + assert a.value == LiteralStringValue(raw="test") + + def test_bool_value(self): + ast = _build('SEND_RESPONSE {\n http.cntl.TXN_DEBUG = true;\n}') + a = ast.body[0].body[0] + assert isinstance(a, Assignment) + assert a.value is True + + def test_int_value(self): + ast = _build('REMAP {\n http.cntl.INTERCEPT_RETRY = 1;\n}') + a = ast.body[0].body[0] + assert a.value == 1 + + def test_plus_equals(self): + ast = _build('REMAP {\n inbound.req.X-Foo += "extra";\n}') + a = ast.body[0].body[0] + assert a.operator == "+=" + + def test_ip_value(self): + ast = _build('REMAP {\n inbound.req.X-IP = 10.0.0.1;\n}') + a = ast.body[0].body[0] + assert isinstance(a, Assignment) + assert a.value == IPValue(raw="10.0.0.1") + + def test_param_ref_value(self): + src = 'procedure local::stamp($tag) {\n inbound.req.X-Stamp = $tag;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + a = ast.body[0].body[0] + assert isinstance(a, Assignment) + assert a.value == ParamRef(raw="tag") + + +class TestFunctionCalls: + + def test_no_args(self): + ast = _build('REMAP {\n set-debug();\n}') + fc = ast.body[0].body[0] + assert isinstance(fc, FunctionCall) + assert fc.name == "set-debug" + assert fc.args == () + + def test_with_args(self): + ast = _build('REMAP {\n set-header("X-Foo", "bar");\n}') + fc = ast.body[0].body[0] + assert fc.name == "set-header" + assert fc.args == (LiteralStringValue(raw="X-Foo"), LiteralStringValue(raw="bar")) + + def test_standalone_operator(self): + ast = _build('REMAP {\n 
skip-remap;\n}') + fc = ast.body[0].body[0] + assert isinstance(fc, FunctionCall) + assert fc.name == "skip-remap" + assert fc.args == () + + def test_break(self): + ast = _build('REMAP {\n if true {\n break;\n }\n}') + body = ast.body[0].body[0].body + assert isinstance(body[0], Break) + + +class TestSections: + + def test_comments_in_section_body_skipped(self): + src = 'REMAP {\n # a comment\n set-debug();\n # another comment\n}' + ast = _build(src) + assert len(ast.body[0].body) == 1 + + def test_comments_in_block_skipped(self): + src = 'REMAP {\n if true {\n # comment\n set-debug();\n }\n}' + ast = _build(src) + assert len(ast.body[0].body[0].body) == 1 + + def test_section_type(self): + ast = _build('REMAP {\n set-debug();\n}') + s = ast.body[0] + assert isinstance(s, Section) + assert s.type == "REMAP" + + def test_multiple_sections(self): + src = 'REMAP {\n set-debug();\n}\nSEND_RESPONSE {\n set-debug();\n}' + ast = _build(src) + sections = [i for i in ast.body if isinstance(i, Section)] + assert len(sections) == 2 + assert sections[0].type == "REMAP" + assert sections[1].type == "SEND_RESPONSE" + + def test_use_directive(self): + src = 'use test::add-debug-header\nREMAP {\n test::add-debug-header("tag");\n}' + ast = _build(src) + assert len(ast.body) == 2 + u = ast.body[0] + assert isinstance(u, UseDirective) + assert u.spec == "test::add-debug-header" + + def test_item_ordering(self): + src = 'VARS {\n x: bool;\n}\nREMAP {\n set-debug();\n}\nSEND_RESPONSE {\n set-debug();\n}' + ast = _build(src) + assert len(ast.body) == 3 + assert isinstance(ast.body[0], VarSection) + assert isinstance(ast.body[1], Section) + assert isinstance(ast.body[2], Section) + + +class TestVarSections: + + def test_comments_in_var_section_skipped(self): + src = 'VARS {\n # comment\n x: bool;\n # another\n y: int;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + vs = ast.body[0] + assert isinstance(vs, VarSection) + assert len(vs.declarations) == 2 + + def test_txn_scope(self): 
+ src = 'VARS {\n flag: bool;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + vs = ast.body[0] + assert isinstance(vs, VarSection) + assert vs.scope == "txn" + assert len(vs.declarations) == 1 + assert vs.declarations[0].name == "flag" + assert vs.declarations[0].type_name == "bool" + assert vs.declarations[0].slot is None + + def test_session_scope(self): + src = 'SESSION_VARS {\n counter: int;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + vs = ast.body[0] + assert isinstance(vs, VarSection) + assert vs.scope == "session" + assert vs.declarations[0].name == "counter" + + def test_slot(self): + src = 'VARS {\n x: int @3;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + vs = ast.body[0] + assert isinstance(vs, VarSection) + assert vs.declarations[0].slot == 3 + + def test_multiple_declarations(self): + src = 'VARS {\n a: bool;\n b: int;\n c: string;\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + vs = ast.body[0] + assert isinstance(vs, VarSection) + assert len(vs.declarations) == 3 + assert vs.declarations[0].name == "a" + assert vs.declarations[1].name == "b" + assert vs.declarations[2].name == "c" + + +class TestProcedures: + + def test_basic_decl(self): + src = 'procedure local::stamp($tag) {\n inbound.req.X-Stamp = "$tag";\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + pd = ast.body[0] + assert isinstance(pd, ProcedureDecl) + assert pd.name == "local::stamp" + assert len(pd.params) == 1 + assert pd.params[0].name == "tag" + assert pd.params[0].default is None + + def test_default_param(self): + src = 'procedure local::cache($ttl=300) {\n set-debug();\n}\nREMAP {\n set-debug();\n}' + ast = _build(src) + pd = ast.body[0] + assert isinstance(pd, ProcedureDecl) + assert pd.params[0].name == "ttl" + assert pd.params[0].default == 300 + + def test_body(self): + src = ('procedure local::multi() {\n inbound.req.X = "a";\n' + ' set-debug();\n}\nREMAP {\n set-debug();\n}') + ast = _build(src) + pd = ast.body[0] + assert isinstance(pd, 
ProcedureDecl) + assert len(pd.body) == 2 + assert isinstance(pd.body[0], Assignment) + assert isinstance(pd.body[1], FunctionCall) + + +class TestConditionExpressions: + + def _first_condition(self, source: str): + ast = _build(source) + return ast.body[0].body[0].condition + + def test_equality_comparison(self): + cond = self._first_condition('REMAP {\n if inbound.req.X-Foo == "bar" {\n set-debug();\n }\n}') + assert isinstance(cond, Comparison) + assert cond.left == IdentValue(raw="inbound.req.X-Foo") + assert cond.operator == "==" + assert cond.right == LiteralStringValue(raw="bar") + assert cond.modifiers == () + + def test_regex_comparison(self): + cond = self._first_condition('REMAP {\n if inbound.url.path ~ /\\.php$/ {\n set-debug();\n }\n}') + assert isinstance(cond, Comparison) + assert cond.operator == "~" + assert isinstance(cond.right, RegexValue) + + def test_in_set(self): + cond = self._first_condition('REMAP {\n if inbound.url.path in ["a", "b"] {\n set-debug();\n }\n}') + assert isinstance(cond, Comparison) + assert cond.operator == "in" + assert cond.right == (LiteralStringValue(raw="a"), LiteralStringValue(raw="b")) + + def test_not_in_set(self): + cond = self._first_condition('REMAP {\n if inbound.url.path !in ["a"] {\n set-debug();\n }\n}') + assert isinstance(cond, Comparison) + assert cond.operator == "!in" + + def test_in_iprange(self): + cond = self._first_condition('REMAP {\n if inbound.ip in {10.0.0.0/8} {\n set-debug();\n }\n}') + assert isinstance(cond, Comparison) + assert cond.operator == "in" + assert cond.right == (IPValue(raw="10.0.0.0/8"),) + + def test_modifiers(self): + cond = self._first_condition('REMAP {\n if inbound.req.X-Foo == "bar" with NOCASE {\n set-debug();\n }\n}') + assert isinstance(cond, Comparison) + assert cond.modifiers == ("NOCASE",) + + def test_modifiers_preserve_source_casing(self): + cond = self._first_condition('REMAP {\n if inbound.req.X-Foo == "bar" with nocase,Pre {\n set-debug();\n }\n}') + assert 
isinstance(cond, Comparison) + assert cond.modifiers == ("nocase", "Pre") + + def test_function_call_comparable(self): + cond = self._first_condition('REMAP {\n if url(true) ~ /pat/ {\n set-debug();\n }\n}') + assert isinstance(cond, Comparison) + assert isinstance(cond.left, FunctionCall) + assert cond.left.name == "url" + assert cond.left.args == (True,) + + def test_bool_literal_true(self): + cond = self._first_condition('REMAP {\n if true {\n set-debug();\n }\n}') + assert isinstance(cond, BoolLiteral) + assert cond.value is True + + def test_ident_condition(self): + cond = self._first_condition('REMAP {\n if inbound.resp.All-Cache {\n set-debug();\n }\n}') + assert isinstance(cond, IdentCondition) + assert cond.name == "inbound.resp.All-Cache" + + def test_not_condition(self): + cond = self._first_condition('REMAP {\n if !inbound.resp.All-Cache {\n set-debug();\n }\n}') + assert isinstance(cond, NotOp) + assert isinstance(cond.operand, IdentCondition) + + def test_and_condition(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.X-A == "a" && inbound.req.X-B == "b" {\n set-debug();\n }\n}') + assert isinstance(cond, LogicalOp) + assert cond.operator == "&&" + assert isinstance(cond.left, Comparison) + assert isinstance(cond.right, Comparison) + + def test_or_condition(self): + cond = self._first_condition( + 'REMAP {\n if inbound.req.X-A == "a" || inbound.req.X-B == "b" {\n set-debug();\n }\n}') + assert isinstance(cond, LogicalOp) + assert cond.operator == "||" + + def test_function_call_in_condition(self): + cond = self._first_condition('REMAP {\n if access("/tmp/bar") {\n set-debug();\n }\n}') + assert isinstance(cond, FunctionCall) + assert cond.name == "access" + assert cond.args == (LiteralStringValue(raw="/tmp/bar"),) + + def test_not_tilde_comparison(self): + cond = self._first_condition('REMAP {\n if inbound.url.path !~ /\\.jpg$/ {\n set-debug();\n }\n}') + assert isinstance(cond, Comparison) + assert cond.operator == "!~" + assert 
isinstance(cond.right, RegexValue) + + def test_greater_than_comparison(self): + cond = self._first_condition('REMAP {\n if inbound.req.Content-Length > 1000 {\n set-debug();\n }\n}') + assert isinstance(cond, Comparison) + assert cond.operator == ">" + assert cond.right == 1000 + + def test_less_than_comparison(self): + cond = self._first_condition('REMAP {\n if inbound.req.Content-Length < 500 {\n set-debug();\n }\n}') + assert isinstance(cond, Comparison) + assert cond.operator == "<" + assert cond.right == 500 + + def test_neq_comparison(self): + cond = self._first_condition('REMAP {\n if inbound.req.X-Foo != "bar" {\n set-debug();\n }\n}') + assert isinstance(cond, Comparison) + assert cond.operator == "!=" + assert cond.right == LiteralStringValue(raw="bar") + + def test_parenthesized_condition(self): + cond = self._first_condition('REMAP {\n if (inbound.req.X-Foo == "bar") {\n set-debug();\n }\n}') + assert isinstance(cond, Comparison) + assert cond.operator == "==" + assert cond.right == LiteralStringValue(raw="bar") + + def test_and_binds_tighter_than_or(self): + # a || b && c should parse as a || (b && c) + cond = self._first_condition( + 'REMAP {\n' + ' if inbound.req.X-A == "a" || inbound.req.X-B == "b" && inbound.req.X-C == "c" {\n' + ' set-debug();\n }\n}') + assert isinstance(cond, LogicalOp) + assert cond.operator == "||" + assert isinstance(cond.left, Comparison) + assert cond.left.left == IdentValue(raw="inbound.req.X-A") + assert isinstance(cond.right, LogicalOp) + assert cond.right.operator == "&&" + assert cond.right.left.left == IdentValue(raw="inbound.req.X-B") + assert cond.right.right.left == IdentValue(raw="inbound.req.X-C") + + def test_not_with_and(self): + # !ident && comparison should parse as (!ident) && comparison + cond = self._first_condition( + 'REMAP {\n' + ' if !inbound.resp.All-Cache && inbound.req.X-B == "b" {\n' + ' set-debug();\n }\n}') + assert isinstance(cond, LogicalOp) + assert cond.operator == "&&" + assert 
isinstance(cond.left, NotOp) + assert isinstance(cond.left.operand, IdentCondition) + assert cond.left.operand.name == "inbound.resp.All-Cache" + assert isinstance(cond.right, Comparison) + assert cond.right.left == IdentValue(raw="inbound.req.X-B") + + def test_not_comparison_with_or(self): + # !(a == "x") || b == "y" should parse as (!(a == "x")) || (b == "y") + cond = self._first_condition( + 'REMAP {\n' + ' if !(inbound.req.X-A == "x") || inbound.req.X-B == "y" {\n' + ' set-debug();\n }\n}') + assert isinstance(cond, LogicalOp) + assert cond.operator == "||" + assert isinstance(cond.left, NotOp) + assert isinstance(cond.left.operand, Comparison) + assert cond.left.operand.left == IdentValue(raw="inbound.req.X-A") + assert cond.left.operand.right == LiteralStringValue(raw="x") + assert isinstance(cond.right, Comparison) + assert cond.right.left == IdentValue(raw="inbound.req.X-B") + + def test_double_negation(self): + cond = self._first_condition('REMAP {\n if !!inbound.resp.All-Cache {\n set-debug();\n }\n}') + assert isinstance(cond, NotOp) + assert isinstance(cond.operand, NotOp) + assert isinstance(cond.operand.operand, IdentCondition) + assert cond.operand.operand.name == "inbound.resp.All-Cache" + + def test_not_bool_literal(self): + cond = self._first_condition('REMAP {\n if !false {\n set-debug();\n }\n}') + assert isinstance(cond, NotOp) + assert isinstance(cond.operand, BoolLiteral) + assert cond.operand.value is False + + def test_parens_override_precedence(self): + # (a || b) && c — parens force || to bind first + cond = self._first_condition( + 'REMAP {\n' + ' if (inbound.req.X-A == "a" || inbound.req.X-B == "b") && inbound.req.X-C == "c" {\n' + ' set-debug();\n }\n}') + assert isinstance(cond, LogicalOp) + assert cond.operator == "&&" + assert isinstance(cond.left, LogicalOp) + assert cond.left.operator == "||" + assert cond.left.left.left == IdentValue(raw="inbound.req.X-A") + assert cond.left.right.left == IdentValue(raw="inbound.req.X-B") + 
assert isinstance(cond.right, Comparison) + assert cond.right.left == IdentValue(raw="inbound.req.X-C") + + def test_nested_parens_with_not(self): + # !(a == "x" || b == "y") && c == "z" + cond = self._first_condition( + 'REMAP {\n' + ' if !(inbound.req.X-A == "x" || inbound.req.X-B == "y") && inbound.req.X-C == "z" {\n' + ' set-debug();\n }\n}') + assert isinstance(cond, LogicalOp) + assert cond.operator == "&&" + assert isinstance(cond.left, NotOp) + assert isinstance(cond.left.operand, LogicalOp) + assert cond.left.operand.operator == "||" + assert isinstance(cond.right, Comparison) + assert cond.right.left == IdentValue(raw="inbound.req.X-C") + + +class TestIfBlocks: + + def test_simple_if(self): + ast = _build('REMAP {\n if true {\n inbound.req.X = "y";\n }\n}') + ib = ast.body[0].body[0] + assert isinstance(ib, IfBlock) + assert len(ib.body) == 1 + assert ib.elif_branches == () + assert ib.else_body == () + + def test_if_else(self): + src = 'REMAP {\n if true {\n inbound.req.X = "a";\n } else {\n inbound.req.X = "b";\n }\n}' + ast = _build(src) + ib = ast.body[0].body[0] + assert len(ib.else_body) == 1 + + def test_if_elif_else(self): + src = ( + 'SEND_RESPONSE {\n if inbound.url.path == "foo" {\n' + ' inbound.resp.X = "f";\n } elif inbound.url.path == "bar" {\n' + ' inbound.resp.X = "b";\n } else {\n' + ' inbound.resp.X = "other";\n }\n}') + ast = _build(src) + ib = ast.body[0].body[0] + assert isinstance(ib, IfBlock) + assert len(ib.elif_branches) == 1 + assert isinstance(ib.elif_branches[0], ElifBranch) + assert len(ib.elif_branches[0].body) == 1 + assert len(ib.else_body) == 1 + + def test_multiple_elif(self): + src = ( + 'SEND_RESPONSE {\n if inbound.url.path == "a" {\n set-debug();\n' + ' } elif inbound.url.path == "b" {\n set-debug();\n' + ' } elif inbound.url.path == "c" {\n set-debug();\n' + ' } else {\n set-debug();\n }\n}') + ast = _build(src) + ib = ast.body[0].body[0] + assert len(ib.elif_branches) == 2 + + def test_nested_if(self): + src = ( + 
'REMAP {\n if inbound.req.X == "a" {\n' + ' if inbound.req.Y == "b" {\n set-debug();\n }\n }\n}') + ast = _build(src) + outer = ast.body[0].body[0] + assert isinstance(outer, IfBlock) + inner = outer.body[0] + assert isinstance(inner, IfBlock) + + def test_mixed_body(self): + src = ( + 'REMAP {\n inbound.req.X = "before";\n' + ' if true {\n set-debug();\n }\n' + ' inbound.req.Y = "after";\n}') + ast = _build(src) + body = ast.body[0].body + assert len(body) == 3 + assert isinstance(body[0], Assignment) + assert isinstance(body[1], IfBlock) + assert isinstance(body[2], Assignment) + + +class TestLineNumbers: + SRC = ( + "use test::helper\n" # line 1 + "VARS {\n" # line 2 + " flag: bool;\n" # line 3 + "}\n" # line 4 + "procedure local::stamp($tag) {\n" # line 5 + " inbound.req.X-Stamp = $tag;\n" # line 6 + "}\n" # line 7 + "REMAP {\n" # line 8 + ' inbound.req.X-Foo = "val";\n' # line 9 + " set-debug();\n" # line 10 + " skip-remap;\n" # line 11 + ' if inbound.req.X-A == "a" {\n' # line 12 + " break;\n" # line 13 + ' } elif inbound.req.X-B == "b" {\n' # line 14 + ' inbound.req.X = "elif";\n' # line 15 + " } else {\n" # line 16 + ' inbound.req.X = "else";\n' # line 17 + " }\n" # line 18 + ' if inbound.req.X-C == "c" && inbound.req.X-D == "d" {\n' # line 19 + " set-debug();\n" # line 20 + " }\n" # line 21 + " if !inbound.resp.All-Cache {\n" # line 22 + " set-debug();\n" # line 23 + " }\n" # line 24 + " if true {\n" # line 25 + " set-debug();\n" # line 26 + " }\n" # line 27 + " if inbound.resp.All-Cache {\n" # line 28 + " set-debug();\n" # line 29 + " }\n" # line 30 + "}\n" # line 31 + ) + + def setup_method(self): + self.ast = _build(self.SRC) + + def test_use_directive(self): + u = self.ast.body[0] + assert isinstance(u, UseDirective) + assert u.line == 1 + + def test_var_section(self): + vs = self.ast.body[1] + assert isinstance(vs, VarSection) + assert vs.line == 2 + + def test_var_decl(self): + vd = self.ast.body[1].declarations[0] + assert isinstance(vd, VarDecl) + 
assert vd.line == 3 + + def test_procedure_decl(self): + pd = self.ast.body[2] + assert isinstance(pd, ProcedureDecl) + assert pd.line == 5 + + def test_proc_param(self): + pp = self.ast.body[2].params[0] + assert isinstance(pp, ProcParam) + assert pp.line == 5 + + def test_procedure_body_assignment(self): + a = self.ast.body[2].body[0] + assert isinstance(a, Assignment) + assert a.line == 6 + + def test_section(self): + s = self.ast.body[3] + assert isinstance(s, Section) + assert s.line == 8 + + def test_assignment(self): + a = self.ast.body[3].body[0] + assert isinstance(a, Assignment) + assert a.line == 9 + + def test_function_call(self): + fc = self.ast.body[3].body[1] + assert isinstance(fc, FunctionCall) + assert fc.line == 10 + + def test_standalone_operator(self): + fc = self.ast.body[3].body[2] + assert isinstance(fc, FunctionCall) + assert fc.line == 11 + + def test_if_block(self): + ib = self.ast.body[3].body[3] + assert isinstance(ib, IfBlock) + assert ib.line == 12 + + def test_comparison_in_condition(self): + cond = self.ast.body[3].body[3].condition + assert isinstance(cond, Comparison) + assert cond.line == 12 + + def test_break(self): + brk = self.ast.body[3].body[3].body[0] + assert isinstance(brk, Break) + assert brk.line == 13 + + def test_elif_branch(self): + eb = self.ast.body[3].body[3].elif_branches[0] + assert isinstance(eb, ElifBranch) + assert eb.line == 14 + + def test_elif_condition(self): + cond = self.ast.body[3].body[3].elif_branches[0].condition + assert isinstance(cond, Comparison) + assert cond.line == 14 + + def test_logical_op(self): + cond = self.ast.body[3].body[4].condition + assert isinstance(cond, LogicalOp) + assert cond.line == 19 + + def test_not_op(self): + cond = self.ast.body[3].body[5].condition + assert isinstance(cond, NotOp) + assert cond.line == 22 + + def test_bool_literal(self): + cond = self.ast.body[3].body[6].condition + assert isinstance(cond, BoolLiteral) + assert cond.line == 25 + + def 
test_ident_condition(self): + cond = self.ast.body[3].body[7].condition + assert isinstance(cond, IdentCondition) + assert cond.line == 28 + + +class TestRealConfigs: + + def test_nested_ifs_from_test_data(self): + """Validates AST for tests/data/conds/nested-ifs.input.txt pattern.""" + src = '''VARS { + bool_0: bool; + bool_1: bool; + bool_2: bool; +} + +REMAP { + if inbound.req.X-Foo == "bar" { + inbound.req.X-Hello = "there"; + if inbound.req.X-Fie == "fie" { + inbound.req.X-first = "1"; + if bool_0 || (bool_1 && bool_2) { + inbound.req.X-Parsed = "more"; + } else { + inbound.req.X-Parsed = "yes"; + } + } elif inbound.req.X-Fum == "bar" { + inbound.req.X-Parsed = "no"; + } else { + inbound.req.X-More = "yes"; + } + } elif inbound.req.X-Foo == "foo" with NOCASE,PRE { + inbound.req.X-Nocase = "foo"; + } else { + inbound.req.X-Something = "no-bar"; + } +}''' + ast = _build(src) + sections = [i for i in ast.body if isinstance(i, Section)] + assert len(sections) == 1 + s = sections[0] + assert s.type == "REMAP" + + # Top-level if block + outer = s.body[0] + assert isinstance(outer, IfBlock) + + # Body: assignment + nested if + assert isinstance(outer.body[0], Assignment) + assert isinstance(outer.body[1], IfBlock) + middle = outer.body[1] + + # Middle if has elif and else + assert len(middle.elif_branches) == 1 + assert len(middle.else_body) == 1 + + # Deepest nested if (3 levels) + inner = middle.body[1] + assert isinstance(inner, IfBlock) + assert isinstance(inner.condition, LogicalOp) + assert inner.condition.operator == "||" + + # Outer elif has modifiers + assert len(outer.elif_branches) == 1 + elif_cond = outer.elif_branches[0].condition + assert isinstance(elif_cond, Comparison) + assert elif_cond.modifiers == ("NOCASE", "PRE") + + def test_http_cntl_booleans(self): + """Validates value coercion for boolean-like assignments.""" + src = '''SEND_RESPONSE { + http.cntl.TXN_DEBUG = true; + http.cntl.LOGGING = FALSE; +}''' + ast = _build(src) + body = 
ast.body[0].body + assert body[0].value is True + assert body[1].value is False + + def test_ip_range_condition(self): + """Validates IP range handling from tests/data/conds/ip.input.txt.""" + src = '''SEND_REQUEST { + if inbound.ip in {192.168.0.0/16, 10.0.0.0/8} { + set-debug(); + } +}''' + ast = _build(src) + cond = ast.body[0].body[0].condition + assert isinstance(cond, Comparison) + assert cond.operator == "in" + assert len(cond.right) == 2 + + def test_set_membership_with_modifier(self): + """From tests/data/conds/in-sets.input.txt.""" + src = '''REMAP { + if inbound.url.path in ["php", "php3", "php4"] with EXT { + inbound.req.X-Is-PHP = "yes"; + } +}''' + ast = _build(src) + cond = ast.body[0].body[0].condition + assert isinstance(cond, Comparison) + assert cond.operator == "in" + assert cond.right == (LiteralStringValue(raw="php"), LiteralStringValue(raw="php3"), LiteralStringValue(raw="php4")) + assert cond.modifiers == ("EXT",) + + def test_debug_pattern_for_lint_rules(self): + """Validates the exact pattern the no-debug lint rule will match.""" + src = '''REMAP { + set-debug(); + http.cntl.TXN_DEBUG = true; + inbound.req.X-Foo = "test"; +}''' + ast = _build(src) + body = ast.body[0].body + + # set-debug() function call + assert isinstance(body[0], FunctionCall) + assert body[0].name == "set-debug" + + # TXN_DEBUG assignment with True + assert isinstance(body[1], Assignment) + assert body[1].target == Target.from_dotted("http.cntl.TXN_DEBUG") + assert body[1].value is True + + # Regular assignment (not flagged) + assert isinstance(body[2], Assignment) + assert body[2].target.namespace == "inbound.req"