Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 112 additions & 4 deletions tests/test-jinja.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <cstdlib>

#include <nlohmann/json.hpp>
#include <sheredom/subprocess.h>

#include "jinja/runtime.h"
#include "jinja/parser.h"
Expand Down Expand Up @@ -31,12 +32,24 @@ static void test_array_methods(testing & t);
static void test_object_methods(testing & t);
static void test_fuzzing(testing & t);

static bool g_python_mode = false;

int main(int argc, char *argv[]) {
testing t(std::cout);
t.verbose = true;

if (argc >= 2) {
t.set_filter(argv[1]);
// usage: test-jinja [-py] [filter_regex]
// -py : enable python mode (use python jinja2 for rendering expected output)
// only use this for cross-checking, not for correctness
// note: the implementation of this flag is basic, only intended to be used by maintainers

for (int i = 1; i < argc; i++) {
std::string arg = argv[i];
if (arg == "-py") {
g_python_mode = true;
} else {
t.set_filter(arg);
}
}

t.test("whitespace control", test_whitespace_control);
Expand All @@ -53,7 +66,9 @@ int main(int argc, char *argv[]) {
t.test("string methods", test_string_methods);
t.test("array methods", test_array_methods);
t.test("object methods", test_object_methods);
t.test("fuzzing", test_fuzzing);
if (!g_python_mode) {
t.test("fuzzing", test_fuzzing);
}

return t.summary();
}
Expand Down Expand Up @@ -1215,7 +1230,7 @@ static void test_object_methods(testing & t) {
);
}

static void test_template(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
static void test_template_cpp(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
t.test(name, [&tmpl, &vars, &expect](testing & t) {
jinja::lexer lexer;
auto lexer_res = lexer.tokenize(tmpl);
Expand Down Expand Up @@ -1248,6 +1263,99 @@ static void test_template(testing & t, const std::string & name, const std::stri
});
}

// Reference renderer executed via `python -c` when python mode is enabled.
//
// keep this in-sync with https://github.com/huggingface/transformers/blob/main/src/transformers/utils/chat_template_utils.py
// note: we use SandboxedEnvironment instead of ImmutableSandboxedEnvironment to allow usage of in-place array methods like append() and pop()
//
// Contract with the caller:
//   sys.argv[1] : the template source, encoded as a JSON string
//   sys.argv[2] : the template variables, encoded as a JSON object
//   stdout      : the rendered template, with no trailing newline added
//
// The script text is Python, so the indentation inside the raw string is significant:
// the body of raise_exception() must stay indented under its `def`.
static std::string py_script = R"(
import jinja2
import jinja2.ext as jinja2_ext
import json
import sys
from datetime import datetime
from jinja2.sandbox import SandboxedEnvironment

tmpl = json.loads(sys.argv[1])
vars_json = json.loads(sys.argv[2])

env = SandboxedEnvironment(
    trim_blocks=True,
    lstrip_blocks=True,
    extensions=[jinja2_ext.loopcontrols],
)

def raise_exception(message):
    raise jinja2.exceptions.TemplateError(message)

env.filters["tojson"] = lambda x, ensure_ascii=False, indent=None, separators=None, sort_keys=False: json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
env.globals["strftime_now"] = lambda format: datetime.now().strftime(format)
env.globals["raise_exception"] = raise_exception

template = env.from_string(tmpl)
result = template.render(**vars_json)
print(result, end='')
)";

// Registers a sub-test `name` that renders `tmpl` with `vars` through the Python jinja2
// script held in py_script and asserts that the script's output equals `expect`.
//
// The template and the variables are passed to the script as JSON-encoded command-line
// arguments. stdout and stderr of the child are combined, so on failure the logged
// output also contains whatever Python wrote to stderr.
//
// The sub-test fails (and logs the reason) when:
//   - the subprocess cannot be created or joined
//   - the script exits with a non-zero code
//   - the captured output differs from `expect`
static void test_template_py(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
    t.test(name, [&tmpl, &vars, &expect](testing & t) {
        // Prepare arguments: both are JSON text, decoded by the script with json.loads()
        const std::string tmpl_json = json(tmpl).dump();
        const std::string vars_json = vars.dump();

#ifdef _WIN32
        const char * python_executable = "python.exe";
#else
        const char * python_executable = "python3";
#endif

        // argv-style, null-terminated; the strings above must outlive the subprocess
        const char * command_line[] = {python_executable, "-c", py_script.c_str(), tmpl_json.c_str(), vars_json.c_str(), nullptr};

        struct subprocess_s subprocess;
        const int options = subprocess_option_combined_stdout_stderr
                          | subprocess_option_no_window
                          | subprocess_option_inherit_environment
                          | subprocess_option_search_user_path;
        const int result = subprocess_create(command_line, options, &subprocess);

        if (result != 0) {
            t.log("Failed to create subprocess, error code: " + std::to_string(result));
            t.assert_true("subprocess creation", false);
            return;
        }

        // Read output until EOF before joining, so the child is never left blocked on a full pipe.
        // fread + append(len) keeps the exact bytes, including any embedded NUL characters.
        std::string output;
        char buffer[1024];
        FILE * p_stdout = subprocess_stdout(&subprocess);
        if (p_stdout != nullptr) {
            size_t n_read = 0;
            while ((n_read = fread(buffer, 1, sizeof(buffer), p_stdout)) > 0) {
                output.append(buffer, n_read);
            }
        }

        // Initialised so that a failed join can never leave an indeterminate exit code behind
        int process_return = -1;
        const int join_result = subprocess_join(&subprocess, &process_return);
        subprocess_destroy(&subprocess);

        if (join_result != 0) {
            t.log("Failed to join subprocess, error code: " + std::to_string(join_result));
            t.log("Output: " + output);
            t.assert_true("subprocess join", false);
            return;
        }

        if (process_return != 0) {
            t.log("Python script failed with exit code: " + std::to_string(process_return));
            t.log("Output: " + output);
            t.assert_true("python execution", false);
            return;
        }

        if (!t.assert_true("Template render mismatch", expect == output)) {
            t.log("Template: " + json(tmpl).dump());
            t.log("Expected: " + json(expect).dump());
            t.log("Python  : " + json(output).dump());
        }
    });
}

// Entry point used by the template test cases: forwards every argument unchanged to the
// renderer selected by g_python_mode (Python jinja2 when set, the C++ engine otherwise).
static void test_template(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) {
    if (!g_python_mode) {
        // default path: exercise the C++ implementation
        test_template_cpp(t, name, tmpl, vars, expect);
        return;
    }

    // cross-check path: validate the expected output against Python jinja2
    test_template_py(t, name, tmpl, vars, expect);
}

//
// fuzz tests to ensure no crashes occur on malformed inputs
//
Expand Down
Loading