Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions src/IRPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,20 @@ using std::vector;
ostream &operator<<(ostream &out, const Type &type) {
switch (type.code()) {
case Type::Int:
out << "int";
out << "i";
break;
case Type::UInt:
out << "uint";
out << "u";
break;
case Type::Float:
out << "float";
out << "f";
break;
case Type::Handle:
// ensure that 'const' (etc) qualifiers are emitted when appropriate
out << "(" << type_to_c_type(type, false) << ")";
break;
case Type::BFloat:
out << "bfloat";
out << "bf";
break;
}
if (!type.is_handle()) {
Expand Down Expand Up @@ -262,8 +262,8 @@ void IRPrinter::test() {
ostringstream source;
source << allocate;
std::string correct_source =
"allocate buf[float32 * 1023] in Stack\n"
"let y = 17\n"
"allocate buf[f32 * 1023] in Stack\n"
"let y : i32 = 17\n"
"assert(y >= 3, halide_error_param_too_small_i64(\"y\", y, 3))\n"
"produce buf {\n"
" parallel (x, -2, y + 2) {\n"
Expand Down Expand Up @@ -724,12 +724,16 @@ void IRPrinter::visit(const IntImm *op) {
if (op->type == Int(32)) {
stream << imm_int(op->value);
} else {
stream << typep(op->type) << imm_int(op->value);
stream << ansi_imm_int << op->value << "_i" << op->type.bits() << ansi_reset;
}
}

void IRPrinter::visit(const UIntImm *op) {
stream << typep(op->type) << imm_int(op->value);
if (op->type.bits() == 1) {
stream << ansi_imm_int << (op->value ? "true" : "false") << ansi_reset;
} else {
stream << ansi_imm_int << op->value << "_u" << op->type.bits() << ansi_reset;
}
}

void IRPrinter::visit(const FloatImm *op) {
Expand Down Expand Up @@ -796,10 +800,20 @@ void IRPrinter::visit(const StringImm *op) {
}

void IRPrinter::visit(const Cast *op) {
stream << type(op->type);
#if 0
// More explicit style of denoting a cast, which we did not yet agree upon.
// Leaving it in as commted out code, because it might be useful at some point.
stream << kw("cast<") << type(op->type) << kw(">");
openf();
print_no_parens(op->value);
closef();
#else
std::stringstream ss;
ss << op->type;
openf(ss.str().c_str());
print_no_parens(op->value);
closef();
#endif
}

void IRPrinter::visit(const Reinterpret *op) {
Expand Down Expand Up @@ -1084,7 +1098,7 @@ void IRPrinter::visit(const Let *op) {
if (!implicit_parens) {
stream << paren("(");
}
stream << paren("let ") << var(op->name) << paren(" = ");
stream << paren("let ") << var(op->name) << paren(" : ") << paren(op->value.type()) << paren(" = ");
print(op->value);
stream << paren(" in ");
if (!is_summary) {
Expand All @@ -1098,7 +1112,7 @@ void IRPrinter::visit(const Let *op) {

void IRPrinter::visit(const LetStmt *op) {
ScopedBinding<> bind(known_type, op->name);
stream << get_indent() << kw("let ") << var(op->name) << kw(" = ");
stream << get_indent() << ansi_kw << "let " << ansi_reset << op->name << ansi_kw << " : " << op->value.type() << " = " << ansi_reset;
{
ScopedValue<int> reset_paren_depth(paren_depth, 0);
print_no_parens(op->value);
Expand Down
20 changes: 15 additions & 5 deletions src/StmtToHTML.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1429,19 +1429,25 @@ class HTMLCodePrinter : public IRVisitor {
void visit(const Cast *op) override {
std::string type_str = type_to_string(op->type);
print_opening_tag("span", "Cast");
print_opening_tag("span", "matched");
print_opening_tag("span", "matched keyword");
#if 0
print_text("cast&lt;");
print_type(op->type);
print_text("(");
print_text("&gt;");
#else
print_text(type_str);
#endif
print_closing_tag("span");
print_html_element("span", "matched", "(", type_str);
print(op->value);
print_html_element("span", "matched", ")");
print_html_element("span", "matched", ")", type_str);
print_closing_tag("span");
}

void visit(const Reinterpret *op) override {
std::string type_str = type_to_string(op->type);
print_opening_tag("span", "Reinterpret");
print_opening_tag("span", "matched Symbol", type_str);
print_opening_tag("span", "matched keyword", type_str);
print_text("reinterpret&lt;");
print_type(op->type);
print_text("&gt;");
Expand Down Expand Up @@ -1574,7 +1580,9 @@ class HTMLCodePrinter : public IRVisitor {
print_opening_tag("span", "matched");
print_text("(");
print_html_element("span", "keyword", "let ");
print_variable(op->name, op->type);
print_variable(op->name, op->value.type());
print_html_element("span", "Operator Assign", " : ");
print_type(op->type);
print_html_element("span", "Operator Assign", " = ");
print_closing_tag("span");
print(op->value);
Expand All @@ -1593,6 +1601,8 @@ class HTMLCodePrinter : public IRVisitor {
print_opening_tag("span", "matched");
print_html_element("span", "keyword", "let ");
print_variable(op->name, op->value.type());
print_html_element("span", "Operator Assign", " : ");
print_type(op->value.type());
print_html_element("span", "Operator Assign", " = ");
print_closing_tag("span"); // matched
print(op->value);
Expand Down
20 changes: 10 additions & 10 deletions test/correctness/callable_errors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ void test_bad_untyped_calls() {
expect_failure(c(Buffer<uint8_t, AnyDims>(), 2, 1.0f, result1), "Buffer argument p_img is nullptr");
expect_failure(c(Buffer<void, 2>(), 2, 1.0f, result1), "Buffer argument p_img is nullptr");
expect_failure(c(Buffer<void, AnyDims>(), 2, 1.0f, result1), "Buffer argument p_img is nullptr");
expect_failure(c(42, 2, 1.0f, result1), "Argument 1 of 4 ('p_img') was expected to be a buffer of type 'uint8' and dimension 2");
expect_failure(c(in1, 2.25, 1.0f, result1), "Argument 2 of 4 ('p_int') was expected to be a scalar of type 'int32' and dimension 0");
expect_failure(c(in1, 2, 1, result1), "Argument 3 of 4 ('p_float') was expected to be a scalar of type 'float32' and dimension 0");
expect_failure(c(42, 2, 1.0f, result1), "Argument 1 of 4 ('p_img') was expected to be a buffer of type 'u8' and dimension 2");
expect_failure(c(in1, 2.25, 1.0f, result1), "Argument 2 of 4 ('p_int') was expected to be a scalar of type 'i32' and dimension 0");
expect_failure(c(in1, 2, 1, result1), "Argument 3 of 4 ('p_float') was expected to be a scalar of type 'f32' and dimension 0");
expect_failure(c(in1, 2, 1.0f, (const halide_buffer_t *)nullptr), "Buffer argument fn1 is nullptr");
expect_failure(c(in1, 2, 1.0f, (halide_buffer_t *)nullptr), "Buffer argument fn1 is nullptr");
expect_failure(c(in1, 2, 1.0f, Buffer<const uint8_t, 2>()), "Buffer argument fn1 is nullptr");
Expand Down Expand Up @@ -112,9 +112,9 @@ void test_bad_untyped_calls() {
expect_failure(c(&context, Buffer<uint8_t, AnyDims>(), 2, 1.0f, result1), "Buffer argument p_img is nullptr");
expect_failure(c(&context, Buffer<void, 2>(), 2, 1.0f, result1), "Buffer argument p_img is nullptr");
expect_failure(c(&context, Buffer<void, AnyDims>(), 2, 1.0f, result1), "Buffer argument p_img is nullptr");
expect_failure(c(&context, 42, 2, 1.0f, result1), "Argument 1 of 4 ('p_img') was expected to be a buffer of type 'uint8' and dimension 2");
expect_failure(c(&context, in1, 2.25, 1.0f, result1), "Argument 2 of 4 ('p_int') was expected to be a scalar of type 'int32' and dimension 0");
expect_failure(c(&context, in1, 2, 1, result1), "Argument 3 of 4 ('p_float') was expected to be a scalar of type 'float32' and dimension 0");
expect_failure(c(&context, 42, 2, 1.0f, result1), "Argument 1 of 4 ('p_img') was expected to be a buffer of type 'u8' and dimension 2");
expect_failure(c(&context, in1, 2.25, 1.0f, result1), "Argument 2 of 4 ('p_int') was expected to be a scalar of type 'i32' and dimension 0");
expect_failure(c(&context, in1, 2, 1, result1), "Argument 3 of 4 ('p_float') was expected to be a scalar of type 'f32' and dimension 0");
expect_failure(c(&context, in1, 2, 1.0f, (const halide_buffer_t *)nullptr), "Buffer argument fn2 is nullptr");
expect_failure(c(&context, in1, 2, 1.0f, (halide_buffer_t *)nullptr), "Buffer argument fn2 is nullptr");
expect_failure(c(&context, in1, 2, 1.0f, Buffer<const uint8_t, 2>()), "Buffer argument fn2 is nullptr");
Expand Down Expand Up @@ -162,16 +162,16 @@ void test_bad_typed_calls() {

// Calls to make_std_function fail
c.make_std_function<bool, int32_t, float, Buffer<uint8_t, 2>>();
expect_failure(-1, "Argument 1 of 4 ('p_img') was expected to be a buffer of type 'uint8' and dimension 2");
expect_failure(-1, "Argument 1 of 4 ('p_img') was expected to be a buffer of type 'u8' and dimension 2");

c.make_std_function<Buffer<uint8_t, 2>, bool, float, Buffer<uint8_t, 2>>();
expect_failure(-1, "Argument 2 of 4 ('p_int') was expected to be a scalar of type 'int32' and dimension 0");
expect_failure(-1, "Argument 2 of 4 ('p_int') was expected to be a scalar of type 'i32' and dimension 0");

c.make_std_function<Buffer<uint8_t, 2>, int32_t, bool, Buffer<uint8_t, 2>>();
expect_failure(-1, "Argument 3 of 4 ('p_float') was expected to be a scalar of type 'float32' and dimension 0");
expect_failure(-1, "Argument 3 of 4 ('p_float') was expected to be a scalar of type 'f32' and dimension 0");

c.make_std_function<Buffer<uint8_t, 2>, int32_t, float, bool>();
expect_failure(-1, "Argument 4 of 4 ('fn3') was expected to be a buffer of type 'uint8' and dimension 2");
expect_failure(-1, "Argument 4 of 4 ('fn3') was expected to be a buffer of type 'u8' and dimension 2");
}

// Test custom error handler in the JITUserContext
Expand Down
Loading