Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
/common/common.* @ggerganov
/common/console.* @ggerganov
/common/http.* @angt
/common/jinja/ @ngxson @CISC @aldehir
/common/llguidance.* @ggerganov
/common/log.* @ggerganov
/common/peg-parser.* @aldehir
Expand Down
83 changes: 52 additions & 31 deletions common/jinja/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -775,19 +775,30 @@ const func_builtins & value_array_t::get_builtins() const {
if (!is_val<value_array>(args.get_pos(0))) {
throw raised_exception("join() first argument must be an array");
}
value val_delim = args.get_kwarg_or_pos("d", 1);
value val_attribute = args.get_kwarg_or_pos("attribute", 2);
if (!val_attribute->is_undefined()) {
throw not_implemented_exception("array attribute join not implemented");
}
value val_delim = args.get_kwarg_or_pos("d", 1);
value attribute = args.get_kwarg_or_pos("attribute", 2);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

both map, sort, join can have attribute, probably we can find a way to abstract out the attribute handling?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe, but didn't immediately see how since they are applied differently.

const auto & arr = args.get_pos(0)->as_array();
std::string delim = is_val<value_string>(val_delim) ? val_delim->as_string().str() : "";
const bool attr_is_int = is_val<value_int>(attribute);
if (!attribute->is_undefined() && !is_val<value_string>(attribute) && !attr_is_int) {
throw raised_exception("join() attribute must be string or integer");
}
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
const std::string delim = val_delim->is_undefined() ? "" : val_delim->as_string().str();
const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
std::string result;
for (size_t i = 0; i < arr.size(); ++i) {
if (!is_val<value_string>(arr[i]) && !is_val<value_int>(arr[i]) && !is_val<value_float>(arr[i])) {
value val_arr = arr[i];
if (!attribute->is_undefined()) {
if (attr_is_int && is_val<value_array>(val_arr)) {
val_arr = val_arr->at(attr_int);
} else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(val_arr)) {
val_arr = val_arr->at(attr_name);
}
}
if (!is_val<value_string>(val_arr) && !is_val<value_int>(val_arr) && !is_val<value_float>(val_arr)) {
throw raised_exception("join() can only join arrays of strings or numerics");
}
result += arr[i]->as_string().str();
result += val_arr->as_string().str();
if (i < arr.size() - 1) {
result += delim;
}
Expand All @@ -802,26 +813,30 @@ const func_builtins & value_array_t::get_builtins() const {
}},
{"tojson", tojson},
{"map", [](const func_args & args) -> value {
args.ensure_count(2, 3);
args.ensure_count(2);
if (!is_val<value_array>(args.get_pos(0))) {
throw raised_exception("map: first argument must be an array");
}
value attribute = args.get_kwarg_or_pos("attribute", 1);
if (is_val<value_int>(attribute)) {
throw not_implemented_exception("map: integer attribute not implemented");
if (!is_val<value_kwarg>(args.get_args().at(1))) {
throw not_implemented_exception("map: filter-mapping not implemented");
}
if (!is_val<value_string>(attribute)) {
value attribute = args.get_kwarg_or_pos("attribute", 1);
const bool attr_is_int = is_val<value_int>(attribute);
if (!is_val<value_string>(attribute) && !attr_is_int) {
throw raised_exception("map: attribute must be string or integer");
}
std::string attr_name = attribute->as_string().str();
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
const std::string attr_name = attribute->as_string().str();
value default_val = args.get_kwarg("default", mk_val<value_undefined>());
auto out = mk_val<value_array>();
auto arr = args.get_pos(0)->as_array();
for (const auto & item : arr) {
if (!is_val<value_object>(item)) {
throw raised_exception("map: item is not an object");
value attr_val;
if (attr_is_int) {
attr_val = is_val<value_array>(item) ? item->at(attr_int, default_val) : default_val;
} else {
attr_val = is_val<value_object>(item) ? item->at(attr_name, default_val) : default_val;
}
value attr_val = item->at(attr_name, default_val);
out->push_back(attr_val);
}
return out;
Expand All @@ -847,29 +862,35 @@ const func_builtins & value_array_t::get_builtins() const {
return arr_editable->pop_at(index);
}},
{"sort", [](const func_args & args) -> value {
args.ensure_count(1, 3);
args.ensure_count(1, 4);
if (!is_val<value_array>(args.get_pos(0))) {
throw raised_exception("sort: first argument must be an array");
}
bool reverse = args.get_kwarg("reverse", mk_val<value_undefined>())->as_bool();
value attribute = args.get_kwarg("attribute", mk_val<value_undefined>());
std::string attr = attribute->is_undefined() ? "" : attribute->as_string().str();
value val_reverse = args.get_kwarg_or_pos("reverse", 1);
value val_case = args.get_kwarg_or_pos("case_sensitive", 2);
value attribute = args.get_kwarg_or_pos("attribute", 3);
// FIXME: sorting is currently always case sensitive
//const bool case_sensitive = val_case->as_bool(); // undefined == false
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe throw not_implemented_exception if case_sensitive is set?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's the other way around, case_sensitive=False is default, but we are always case sensitive ATM.

const bool reverse = val_reverse->as_bool(); // undefined == false
const bool attr_is_int = is_val<value_int>(attribute);
const int64_t attr_int = attr_is_int ? attribute->as_int() : 0;
const std::string attr_name = attribute->is_undefined() ? "" : attribute->as_string().str();
std::vector<value> arr = cast_val<value_array>(args.get_pos(0))->as_array(); // copy
std::sort(arr.begin(), arr.end(),[&](const value & a, const value & b) {
value val_a = a;
value val_b = b;
if (!attribute->is_undefined()) {
if (!is_val<value_object>(a) || !is_val<value_object>(b)) {
throw raised_exception("sort: items are not objects");
if (attr_is_int && is_val<value_array>(a) && is_val<value_array>(b)) {
val_a = a->at(attr_int);
val_b = b->at(attr_int);
} else if (!attr_is_int && !attr_name.empty() && is_val<value_object>(a) && is_val<value_object>(b)) {
val_a = a->at(attr_name);
val_b = b->at(attr_name);
} else {
throw raised_exception("sort: unsupported object attribute comparison");
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe also log the type of a and b in the message, so that it's easier to debug in the future

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't want to add std::format here since it's not used anywhere else.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean you can just use string concat +, we already doing that in multiple places. Not the best, but it makes debugging much easier.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I spaced out a little, thinking I needed to format values or something. :)

I'll fix it in the next PR which will touch this code anyway.

}
val_a = attr.empty() ? a : a->at(attr);
val_b = attr.empty() ? b : b->at(attr);
}
if (reverse) {
return value_compare(val_a, val_b, value_compare_op::gt);
} else {
return !value_compare(val_a, val_b, value_compare_op::gt);
}
return value_compare(val_a, val_b, reverse ? value_compare_op::gt : value_compare_op::lt);
});
return mk_val<value_array>(arr);
}},
Expand Down Expand Up @@ -958,7 +979,7 @@ const func_builtins & value_object_t::get_builtins() const {
value val_case = args.get_kwarg_or_pos("case_sensitive", 1);
value val_by = args.get_kwarg_or_pos("by", 2);
value val_reverse = args.get_kwarg_or_pos("reverse", 3);
// FIXME: sorting is case sensitive
// FIXME: sorting is currently always case sensitive
//const bool case_sensitive = val_case->as_bool(); // undefined == false
const bool reverse = val_reverse->as_bool(); // undefined == false
if (!val_by->is_undefined()) {
Expand Down
16 changes: 14 additions & 2 deletions common/jinja/value.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,20 @@ struct value_t {
}
return val_obj.unordered.at(key);
}
virtual value & at(size_t index) {
if (index >= val_arr.size()) {
virtual value & at(int64_t index, value & default_val) {
if (index < 0) {
index += val_arr.size();
}
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
return default_val;
}
return val_arr[index];
}
virtual value & at(int64_t index) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not quite sure why I placed the at(...) inside value_t in the first place, probably better to place them inside value_array_t and value_object_t respectively - let's have a look in a follow-up PR

if (index < 0) {
index += val_arr.size();
}
if (index < 0 || static_cast<size_t>(index) >= val_arr.size()) {
throw std::runtime_error("Index " + std::to_string(index) + " out of bounds for array of size " + std::to_string(val_arr.size()));
}
return val_arr[index];
Expand Down
82 changes: 77 additions & 5 deletions tests/test-jinja.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,32 @@ static void test_filters(testing & t) {
"123"
);

test_template(t, "sort reverse",
"{% for i in items|sort(true) %}{{ i }}{% endfor %}",
{{"items", json::array({3, 1, 2})}},
"321"
);

test_template(t, "sort with attribute",
"{{ items|sort(attribute='name')|join(attribute='age') }}",
{{"items", json::array({
json({{"name", "c"}, {"age", 3}}),
json({{"name", "a"}, {"age", 1}}),
json({{"name", "b"}, {"age", 2}}),
})}},
"123"
);

test_template(t, "sort with numeric attribute",
"{{ items|sort(attribute=0)|join(attribute=1) }}",
{{"items", json::array({
json::array({3, "z"}),
json::array({1, "x"}),
json::array({2, "y"}),
})}},
"xyz"
);

test_template(t, "join",
"{{ items|join(', ') }}",
{{"items", json::array({"a", "b", "c"})}},
Expand Down Expand Up @@ -934,7 +960,17 @@ static void test_array_methods(testing & t) {
);

test_template(t, "array|join attribute",
"{{ arr|join(attribute=0) }}",
"{{ arr|join(attribute='age') }}",
{{"arr", json::array({
json({{"name", "a"}, {"age", 1}}),
json({{"name", "b"}, {"age", 2}}),
json({{"name", "c"}, {"age", 3}}),
})}},
"123"
);

test_template(t, "array|join numeric attribute",
"{{ arr|join(attribute=-1) }}",
{{"arr", json::array({json::array({1}), json::array({2}), json::array({3})})}},
"123"
);
Expand All @@ -957,8 +993,8 @@ static void test_array_methods(testing & t) {
"a,b,c,d"
);

test_template(t, "array.map() with attribute",
"{% for v in arr.map('age') %}{{ v }} {% endfor %}",
test_template(t, "array|map with attribute",
"{% for v in arr|map(attribute='age') %}{{ v }} {% endfor %}",
{{"arr", json::array({
json({{"name", "a"}, {"age", 1}}),
json({{"name", "b"}, {"age", 2}}),
Expand All @@ -967,8 +1003,28 @@ static void test_array_methods(testing & t) {
"1 2 3 "
);

test_template(t, "array.map() with numeric attribute",
"{% for v in arr.map(0) %}{{ v }} {% endfor %}",
test_template(t, "array|map with attribute default",
"{% for v in arr|map(attribute='age', default=3) %}{{ v }} {% endfor %}",
{{"arr", json::array({
json({{"name", "a"}, {"age", 1}}),
json({{"name", "b"}, {"age", 2}}),
json({{"name", "c"}}),
})}},
"1 2 3 "
);

test_template(t, "array|map without attribute default",
"{% for v in arr|map(attribute='age') %}{{ v }} {% endfor %}",
{{"arr", json::array({
json({{"name", "a"}, {"age", 1}}),
json({{"name", "b"}, {"age", 2}}),
json({{"name", "c"}}),
})}},
"1 2 "
);

test_template(t, "array|map with numeric attribute",
"{% for v in arr|map(attribute=0) %}{{ v }} {% endfor %}",
{{"arr", json::array({
json::array({10, "x"}),
json::array({20, "y"}),
Expand All @@ -977,6 +1033,22 @@ static void test_array_methods(testing & t) {
"10 20 30 "
);

test_template(t, "array|map with negative attribute",
"{% for v in arr|map(attribute=-1) %}{{ v }} {% endfor %}",
{{"arr", json::array({
json::array({10, "x"}),
json::array({20, "y"}),
json::array({30, "z"}),
})}},
"x y z "
);

test_template(t, "array|map with filter",
"{{ arr|map('int')|sum }}",
{{"arr", json::array({"1", "2", "3"})}},
"6"
);

// not used by any chat templates
// test_template(t, "array.insert()",
// "{% set _ = arr.insert(1, 'x') %}{{ arr|join(',') }}",
Expand Down