diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index e7f733864cc5..069ec7f3ea41 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -30,8 +30,42 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // opt_realize.defined() ? opt_realize.value().get() : nullptr; const ObjectPathNode* realize_p = opt_realize_p.defined() ? opt_realize_p.get() : nullptr; // Step 1. Handle block var and block bindings - int n_vars = block->iter_vars.size(); - for (int i = 0; i < n_vars; ++i) { + // Step 1.1. Obtain all loop var defined along path + std::unordered_map loop_vars; + for (Frame f : d->frames) { + if (const auto* tir_f = f.as()) { + if (const auto* for_loop = tir_f->tir.as()) { + for (const tir::ForNode* l = for_loop; l != nullptr; l = l->body.as()) { + loop_vars.insert(std::make_pair(l->loop_var.get(), GetRef(l))); + } + } + } + } + + std::vector remap_vars_indices; + auto add_remapped_iter_var = [&](int i) -> bool { + if (realize) { + tir::ExprDeepEqual expr_equal; + tir::IterVar iter_var = block->iter_vars[i]; + PrimExpr value = realize->iter_values[i]; + if (iter_var->iter_type == tir::IterVarType::kDataPar || + iter_var->iter_type == tir::IterVarType::kCommReduce) { + if (const auto* var = value.as()) { + if (loop_vars.count(var)) { + tir::For for_loop = loop_vars.at(var); + if (expr_equal(for_loop->min, iter_var->dom->min) && + expr_equal(for_loop->extent, iter_var->dom->extent)) { + remap_vars_indices.push_back(i); + return true; + } + } + } + } + } + return false; + }; + + auto print_single_iter_var = [&](int i) { tir::IterVar iter_var = block->iter_vars[i]; ObjectPath iter_var_p = block_p->Attr("iter_var")->ArrayIndex(i); ExprDoc rhs = TIR("axis"); @@ -66,7 +100,49 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // rhs = rhs->Call({dom}); } (*frame)->stmts.push_back(AssignDoc(DefineVar(iter_var->var, *frame, d), rhs, NullOpt)); + }; + + auto print_remapped_iter_var = [&]() { + if (remap_vars_indices.size()) { + int m = remap_vars_indices.size(); + if (!m) { + return; + } + if (m == 1) { + print_single_iter_var(remap_vars_indices[0]); + remap_vars_indices.clear(); + return; + } + Array lhs; + Array loop_var_doc; + lhs.reserve(m); + loop_var_doc.reserve(m); + std::string binding_type = ""; + for (int i : remap_vars_indices) { + tir::IterVar iter_var = block->iter_vars[i]; + ObjectPath iter_var_p = block_p->Attr("iter_var")->ArrayIndex(i); + lhs.push_back(DefineVar(iter_var->var, *frame, d)); + loop_var_doc.push_back(d->AsDoc(realize->iter_values[i], + realize_p->Attr("iter_values")->ArrayIndex(i))); + binding_type += iter_var->iter_type == tir::IterVarType::kDataPar ? "S" : "R"; + } + ExprDoc rhs = TIR("axis")->Attr("remap"); + rhs = rhs->Call({LiteralDoc::Str(binding_type), ListDoc(loop_var_doc)}); + (*frame)->stmts.push_back(AssignDoc(TupleDoc(lhs), rhs, NullOpt)); + remap_vars_indices.clear(); + } + }; + + // Step 1.2. Construct all block var bindings + int n_vars = block->iter_vars.size(); + for (int i = 0; i < n_vars; ++i) { + if (!add_remapped_iter_var(i)) { + print_remapped_iter_var(); + print_single_iter_var(i); + } } + print_remapped_iter_var(); + // Step 2. Handle block predicate if (realize) { ICHECK(realize->predicate.defined() && realize->predicate->dtype.is_bool()); diff --git a/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py b/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py new file mode 100644 index 000000000000..1bccb8188c9d --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_syntax_sugar.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm.testing +from tvm.script.parser import tir as T +from tvm.script import script + + +def _test(obj, expected: str): + assert script(obj).strip() == expected.strip() + + +def test_remap(): + @T.prim_func + def block_with_remap_implicitly(): + for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128): + with T.block("update"): + v0 = T.axis.spatial(128, i0 + 1) + v1 = T.axis.spatial(128, i1) + v2 = T.axis.reduce(128, i2) + v3 = T.axis.spatial(128, i3 - 1) + v4 = T.axis.reduce(128, i4) + v5 = T.axis.spatial(128, i5) + pass + + @T.prim_func + def block_with_remap_explicitly(): + for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128): + with T.block("update"): + v0 = T.axis.spatial(128, i0 + 1) + v1, v2 = T.axis.remap("SR", [i1, i2]) + v3 = T.axis.spatial(128, i3 - 1) + v4, v5 = T.axis.remap("RS", [i4, i5]) + pass + + expected_output = """@T.prim_func +def main(): + with T.block("root"): + T.reads() + T.writes() + for i0, i1, i2, i3, i4, i5 in T.grid(128, 128, 128, 128, 128, 128): + with T.block("update"): + v0 = T.axis.spatial(128, i0 + 1) + v1, v2 = T.axis.remap("SR", [i1, i2]) + v3 = T.axis.spatial(128, i3 - 1) + v4, v5 = T.axis.remap("RS", [i4, i5]) + T.reads() + T.writes() + T.evaluate(0)""" + _test(block_with_remap_implicitly, expected_output) + _test(block_with_remap_explicitly, expected_output) + + +if __name__ == "__main__": + tvm.testing.main()