diff --git a/README.md b/README.md index aa3579c..a33bc36 100644 --- a/README.md +++ b/README.md @@ -104,10 +104,11 @@ Minja supports the following subset of the [Jinja2/3 template syntax](https://ji - Full expression syntax - Statements `{{% … %}}`, variable sections `{{ … }}`, and comments `{# … #}` with pre/post space elision `{%- … -%}` / `{{- … -}}` / `{#- … -#}` - `if` / `elif` / `else` / `endif` -- `for` (`recursive`) (`if`) / `else` / `endfor` w/ `loop.*` (including `loop.cycle`) and destructuring +- `for` (`recursive`) (`if`) / `else` / `endfor` w/ `loop.*` (including `loop.cycle`) and destructuring) - `break`, `continue` (aka [loop controls extensions](https://github.com/google/minja/pull/39)) - `set` w/ namespaces & destructuring - `macro` / `endmacro` +- `call` / `endcall` - for calling macro (w/ macro arguments and `caller()` syntax) and passing a macro to another macro (w/o passing arguments back to the call block) - `filter` / `endfilter` - Extensible filters collection: `count`, `dictsort`, `equalto`, `e` / `escape`, `items`, `join`, `joiner`, `namespace`, `raise_exception`, `range`, `reject` / `rejectattr` / `select` / `selectattr`, `tojson`, `trim` diff --git a/include/minja/minja.hpp b/include/minja/minja.hpp index f04073c..5ed0556 100644 --- a/include/minja/minja.hpp +++ b/include/minja/minja.hpp @@ -706,7 +706,7 @@ enum SpaceHandling { Keep, Strip, StripSpaces, StripNewline }; class TemplateToken { public: - enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue }; + enum class Type { Text, Expression, If, Else, Elif, EndIf, For, EndFor, Generation, EndGeneration, Set, EndSet, Comment, Macro, EndMacro, Filter, EndFilter, Break, Continue, Call, EndCall }; static std::string typeToString(Type t) { switch (t) { @@ -729,6 +729,8 @@ class TemplateToken { case Type::EndGeneration: return "endgeneration"; case Type::Break: return "break"; case Type::Continue: return "continue"; + case Type::Call: return "call"; + case Type::EndCall: return "endcall"; } return "Unknown"; } @@ -846,6 +848,17 @@ struct LoopControlTemplateToken : public TemplateToken { LoopControlTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, LoopControlType control_type) : TemplateToken(Type::Break, loc, pre, post), control_type(control_type) {} }; +struct CallTemplateToken : public TemplateToken { + std::shared_ptr expr; + CallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post, std::shared_ptr && e) + : TemplateToken(Type::Call, loc, pre, post), expr(std::move(e)) {} +}; + +struct EndCallTemplateToken : public TemplateToken { + EndCallTemplateToken(const Location & loc, SpaceHandling pre, SpaceHandling post) + : TemplateToken(Type::EndCall, loc, pre, post) {} +}; + class TemplateNode { Location location_; protected: @@ -1050,31 +1063,36 @@ class MacroNode : public TemplateNode { void do_render(std::ostringstream &, const std::shared_ptr & macro_context) const override { if (!name) throw std::runtime_error("MacroNode.name is null"); if (!body) throw std::runtime_error("MacroNode.body is null"); - auto callable = Value::callable([&](const std::shared_ptr & context, ArgumentsValue & args) { - auto call_context = macro_context; + auto callable = Value::callable([this, macro_context](const std::shared_ptr & call_context, ArgumentsValue & args) { + auto execution_context = Context::make(Value::object(), macro_context); + + if (call_context->contains("caller")) { + execution_context->set("caller", call_context->get("caller")); + } + std::vector param_set(params.size(), false); for (size_t i = 0, n = args.args.size(); i < n; i++) { auto & arg = args.args[i]; if (i >= params.size()) throw std::runtime_error("Too many positional arguments for macro " + name->get_name()); param_set[i] = true; auto & param_name = params[i].first; - call_context->set(param_name, arg); + execution_context->set(param_name, arg); } for (auto & [arg_name, value] : args.kwargs) { auto it = named_param_positions.find(arg_name); if (it == named_param_positions.end()) throw std::runtime_error("Unknown parameter name for macro " + name->get_name() + ": " + arg_name); - call_context->set(arg_name, value); + execution_context->set(arg_name, value); param_set[it->second] = true; } // Set default values for parameters that were not passed for (size_t i = 0, n = params.size(); i < n; i++) { if (!param_set[i] && params[i].second != nullptr) { - auto val = params[i].second->evaluate(context); - call_context->set(params[i].first, val); + auto val = params[i].second->evaluate(call_context); + execution_context->set(params[i].first, val); } } - return body->render(call_context); + return body->render(execution_context); }); macro_context->set(name->get_name(), callable); } @@ -1611,6 +1629,40 @@ class CallExpr : public Expression { } }; +class CallNode : public TemplateNode { + std::shared_ptr expr; + std::shared_ptr body; + +public: + CallNode(const Location & loc, std::shared_ptr && e, std::shared_ptr && b) + : TemplateNode(loc), expr(std::move(e)), body(std::move(b)) {} + + void do_render(std::ostringstream & out, const std::shared_ptr & context) const override { + if (!expr) throw std::runtime_error("CallNode.expr is null"); + if (!body) throw std::runtime_error("CallNode.body is null"); + + auto caller = Value::callable([this, context](const std::shared_ptr &, ArgumentsValue &) -> Value { + return Value(body->render(context)); + }); + + context->set("caller", caller); + + auto call_expr = dynamic_cast(expr.get()); + if (!call_expr) { + throw std::runtime_error("Invalid call block syntax - expected function call"); + } + + Value function = call_expr->object->evaluate(context); + if (!function.is_callable()) { + throw std::runtime_error("Call target must be callable: " + function.dump()); + } + ArgumentsValue args = call_expr->args.evaluate(context); + + Value result = function.call(context, args); + out << result.to_str(); + } +}; + class FilterExpr : public Expression { std::vector> parts; public: @@ -2320,7 +2372,7 @@ class Parser { static std::regex comment_tok(R"(\{#([-~]?)([\s\S]*?)([-~]?)#\})"); static std::regex expr_open_regex(R"(\{\{([-~])?)"); static std::regex block_open_regex(R"(^\{%([-~])?\s*)"); - static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue)\b)"); + static std::regex block_keyword_tok(R"((if|else|elif|endif|for|endfor|generation|endgeneration|set|endset|block|endblock|macro|endmacro|filter|endfilter|break|continue|call|endcall)\b)"); static std::regex non_text_open_regex(R"(\{\{|\{%|\{#)"); static std::regex expr_close_regex(R"(\s*([-~])?\}\})"); static std::regex block_close_regex(R"(\s*([-~])?%\})"); @@ -2443,6 +2495,15 @@ class Parser { } else if (keyword == "endmacro") { auto post_space = parseBlockClose(); tokens.push_back(std::make_unique(location, pre_space, post_space)); + } else if (keyword == "call") { + auto expr = parseExpression(); + if (!expr) throw std::runtime_error("Expected expression in call block"); + + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space, std::move(expr))); + } else if (keyword == "endcall") { + auto post_space = parseBlockClose(); + tokens.push_back(std::make_unique(location, pre_space, post_space)); } else if (keyword == "filter") { auto filter = parseExpression(); if (!filter) throw std::runtime_error("Expected expression in filter block"); @@ -2575,6 +2636,12 @@ class Parser { throw unterminated(**start); } children.emplace_back(std::make_shared(token->location, std::move(macro_token->name), std::move(macro_token->params), std::move(body))); + } else if (auto call_token = dynamic_cast(token.get())) { + auto body = parseTemplate(begin, it, end); + if (it == end || (*(it++))->type != TemplateToken::Type::EndCall) { + throw unterminated(**start); + } + children.emplace_back(std::make_shared(token->location, std::move(call_token->expr), std::move(body))); } else if (auto filter_token = dynamic_cast(token.get())) { auto body = parseTemplate(begin, it, end); if (it == end || (*(it++))->type != TemplateToken::Type::EndFilter) { @@ -2588,6 +2655,7 @@ class Parser { } else if (dynamic_cast(token.get()) || dynamic_cast(token.get()) || dynamic_cast(token.get()) + || dynamic_cast(token.get()) || dynamic_cast(token.get()) || dynamic_cast(token.get()) || dynamic_cast(token.get()) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 0388a74..db82c2d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -226,6 +226,7 @@ set(MODEL_IDS OnlyCheeini/greesychat-turbo onnx-community/DeepSeek-R1-Distill-Qwen-1.5B-ONNX open-thoughts/OpenThinker-7B + openbmb/MiniCPM3-4B openchat/openchat-3.5-0106 Orenguteng/Llama-3.1-8B-Lexi-Uncensored-V2 OrionStarAI/Orion-14B-Chat @@ -261,7 +262,6 @@ set(MODEL_IDS prithivMLmods/Qwen2.5-7B-DeepSeek-R1-1M prithivMLmods/QwQ-Math-IO-500M prithivMLmods/Triangulum-v2-10B - qingy2024/Falcon3-2x10B-MoE-Instruct Qwen/QVQ-72B-Preview Qwen/Qwen1.5-7B-Chat Qwen/Qwen2-7B-Instruct diff --git a/tests/test-syntax.cpp b/tests/test-syntax.cpp index b4bf638..36bdaa3 100644 --- a/tests/test-syntax.cpp +++ b/tests/test-syntax.cpp @@ -429,6 +429,54 @@ TEST(SyntaxTest, SimpleCases) { {%- endmacro -%} {{- foo() }} {{ foo() -}})", {}, {})); + EXPECT_EQ( + "x,x", + render(R"( + {%- macro test() -%}{{ caller() }},{{ caller() }}{%- endmacro -%} + {%- call test() -%}x{%- endcall -%} + )", {}, {})); + + EXPECT_EQ( + "Outer[Inner(X)]", + render(R"( + {%- macro outer() -%}Outer[{{ caller() }}]{%- endmacro -%} + {%- macro inner() -%}Inner({{ caller() }}){%- endmacro -%} + {%- call outer() -%}{%- call inner() -%}X{%- endcall -%}{%- endcall -%} + )", {}, {})); + + EXPECT_EQ( + "
  • A
  • B
", + render(R"( + {%- macro test(prefix, suffix) -%}{{ prefix }}{{ caller() }}{{ suffix }}{%- endmacro -%} + {%- set items = ["a", "b"] -%} + {%- call test("
    ", "
") -%} + {%- for item in items -%} +
  • {{ item | upper }}
  • + {%- endfor -%} + {%- endcall -%} + )", {}, {})); + + EXPECT_EQ( + "\\n\\nclass A:\\n b: 1\\n c: 2\\n", + render(R"( + {%- macro recursive(obj) -%} + {%- set ns = namespace(content = caller()) -%} + {%- for key, value in obj.items() %} + {%- if value is mapping %} + {%- call recursive(value) -%} + {{ '\\n\\nclass ' + key.title() + ':\\n' }} + {%- endcall -%} + {%- else -%} + {%- set ns.content = ns.content + ' ' + key + ': ' + value + '\\n' -%} + {%- endif -%} + {%- endfor -%} + {{ ns.content }} + {%- endmacro -%} + + {%- call recursive({"a": {"b": "1", "c": "2"}}) -%} + {%- endcall -%} + )", {}, {})); + if (!getenv("USE_JINJA2")) { EXPECT_EQ( "Foo", @@ -576,6 +624,8 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_THAT([]() { render("{% elif 1 %}", {}, {}); }, ThrowsWithSubstr("Unexpected elif")); EXPECT_THAT([]() { render("{% endfor %}", {}, {}); }, ThrowsWithSubstr("Unexpected endfor")); EXPECT_THAT([]() { render("{% endfilter %}", {}, {}); }, ThrowsWithSubstr("Unexpected endfilter")); + EXPECT_THAT([]() { render("{% endmacro %}", {}, {}); }, ThrowsWithSubstr("Unexpected endmacro")); + EXPECT_THAT([]() { render("{% endcall %}", {}, {}); }, ThrowsWithSubstr("Unexpected endcall")); EXPECT_THAT([]() { render("{% if 1 %}", {}, {}); }, ThrowsWithSubstr("Unterminated if")); EXPECT_THAT([]() { render("{% for x in 1 %}", {}, {}); }, ThrowsWithSubstr("Unterminated for")); @@ -584,6 +634,12 @@ TEST(SyntaxTest, SimpleCases) { EXPECT_THAT([]() { render("{% if 1 %}{% else %}{% elif 1 %}{% endif %}", {}, {}); }, ThrowsWithSubstr("Unterminated if")); EXPECT_THAT([]() { render("{% filter trim %}", {}, {}); }, ThrowsWithSubstr("Unterminated filter")); EXPECT_THAT([]() { render("{# ", {}, {}); }, ThrowsWithSubstr("Missing end of comment tag")); + EXPECT_THAT([]() { render("{% macro test() %}", {}, {}); }, ThrowsWithSubstr("Unterminated macro")); + EXPECT_THAT([]() { render("{% call test %}", {}, {}); }, ThrowsWithSubstr("Unterminated call")); + + EXPECT_THAT([]() { + render("{%- macro test() -%}content{%- endmacro -%}{%- call test -%}caller_content{%- endcall -%}", {}, {}); + }, ThrowsWithSubstr("Invalid call block syntax - expected function call")); } EXPECT_EQ(