From b9bcc619029bd38014c31b1f25ecab2893d71143 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Mon, 13 Dec 2021 15:47:17 +0800 Subject: [PATCH 1/8] add ceildiv and shapediv --- include/tvm/tir/op.h | 27 +++++++++++++++++++++++++++ src/tir/op/op.cc | 11 +++++++++++ 2 files changed, 38 insertions(+) diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h index 9cf7d0a3cd1f..565267540447 100644 --- a/include/tvm/tir/op.h +++ b/include/tvm/tir/op.h @@ -495,6 +495,22 @@ TVM_DLL PrimExpr truncmod(PrimExpr a, PrimExpr b, Span span = Span()); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span = Span()); +/*! + * \brief compute ceil(a / b) where a and b are non-negative. + * + * Use this function for shape split calculation. + * + * This function might take advantage of the fact + * that a and b are non-negative. + * + * \param a left operand + * \param b right operand + * \param span The location of this operation in the source. + * \return The result expression. + * \note this function does eager constant folding for + * shape types(int32, int64) when possible. + */ +TVM_DLL PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span = Span()); /*! * \brief compute the remainder floor(a / b) where a and b are non-negative. * @@ -521,6 +537,17 @@ TVM_DLL PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span = Span()); * index types(int32, int64) when possible. */ TVM_DLL PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span = Span()); +/*! + * \brief compute ceil(a / b) + * + * \param a left operand + * \param b right operand + * \param span The location of this operation in the source. + * \return The result expression. + * \note this function does eager constant folding for + * index types(int32, int64) when possible. + */ +TVM_DLL PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span = Span()); /*! * \brief compute the remainder of floordiv * diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index d08bef2ab91a..eadf608fa530 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -369,6 +369,8 @@ PrimExpr operator%(PrimExpr a, PrimExpr b) { return truncmod(a, b); } // TODO(tqchen): switch to floordiv PrimExpr indexdiv(PrimExpr a, PrimExpr b, Span span) { return floordiv(a, b, span); } +PrimExpr shapediv(PrimExpr a, PrimExpr b, Span span) { return ceildiv(a, b, span); } + PrimExpr indexmod(PrimExpr a, PrimExpr b, Span span) { return floormod(a, b, span); } PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) { @@ -380,6 +382,15 @@ PrimExpr floordiv(PrimExpr a, PrimExpr b, Span span) { return tir::FloorDiv(a, b, span); } +PrimExpr ceildiv(PrimExpr a, PrimExpr b, Span span) { + ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; + ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; + BinaryOpMatchTypes(a, b, span); + PrimExpr ret = arith::TryConstFold(a + b - 1, b); + if (ret.defined()) return ret; + return tir::FloorDiv(a + b - 1, b, span); +} + PrimExpr floormod(PrimExpr a, PrimExpr b, Span span) { ICHECK(a.dtype().is_int() || a.dtype().is_uint()) << a; ICHECK(b.dtype().is_int() || b.dtype().is_uint()) << b; From de43055dd9ab095e6f376d09525607063bc1cf9b Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Mon, 13 Dec 2021 16:15:11 +0800 Subject: [PATCH 2/8] add boundary checking in layout_transform --- include/tvm/topi/transform.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 1ad9d7da72ba..675b055a589f 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -1608,7 +1608,11 @@ inline Tensor layout_transform(const Tensor& src, const std::string& src_layout, [&](const Array& dst_indices) { Array dst_indices_expr(dst_indices.begin(), dst_indices.end()); Array src_indices = layout_converter.BackwardIndex(dst_indices_expr); - return src(src_indices); + PrimExpr in_range = PrimExpr(1) > PrimExpr(0); // init with dtype=bool and value=true + for (size_t i = 0; i < src.ndim(); ++i) { + in_range = in_range && (src_indices[i] < src->shape[i]); + } + return if_then_else(in_range, src(src_indices), tvm::cast(src->dtype, PrimExpr(0))); }, name, tag); } From 488f51e60755df72bc546fe3a88992c157230dcf Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Mon, 13 Dec 2021 16:20:08 +0800 Subject: [PATCH 3/8] support multi-blocking and shape padding --- include/tvm/tir/data_layout.h | 14 ++++-- src/tir/ir/data_layout.cc | 82 ++++++++++++++++++++++++++--------- 2 files changed, 71 insertions(+), 25 deletions(-) diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index 73da05c549e2..81c3e98e663d 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -295,9 +295,13 @@ class BijectiveLayoutNode : public Object { /*! \brief Describes how source axes can be mapped to the destination axes, * e.g., [i0 / 16, i1, i0 % 16] can describe NC -> NC16n */ - Array forward_rule; + Array index_forward_rule; /*! \brief Describes how destination axes can be mapped to the source axes */ - Array backward_rule; + Array index_backward_rule; + /*! \brief Describes how source shapes can be mapped to the destination shapes */ + Array shape_forward_rule; + /*! \brief Describes how destination shapes can be mapped to the source shapes */ + Array shape_backward_rule; /*! \brief The source layout */ Layout src_layout; @@ -307,8 +311,10 @@ class BijectiveLayoutNode : public Object { void VisitAttrs(AttrVisitor* v) { v->Visit("src_layout", &src_layout); v->Visit("dst_layout", &dst_layout); - v->Visit("forward_rule", &forward_rule); - v->Visit("backward_rule", &backward_rule); + v->Visit("index_forward_rule", &index_forward_rule); + v->Visit("index_backward_rule", &index_backward_rule); + v->Visit("shape_forward_rule", &shape_forward_rule); + v->Visit("shape_backward_rule", &shape_backward_rule); } static constexpr const char* _type_key = "tir.BijectiveLayout"; diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 65be0900f9e3..cc54d0c2208b 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -131,7 +131,8 @@ Layout::Layout(const std::string& name) { // NOLINT(*) ICHECK_EQ(axis_str.size(), 1); char axis = axis_str[0]; ICHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z')); - ICHECK(!exist_axis[axis]) << "Invalid layout " << name << ": duplicate axis " << axis; + // skip this check to support multi-blocking layout + // ICHECK(!exist_axis[axis]) << "Invalid layout " << name << ": duplicate axis " << axis; exist_axis[axis] = true; } for (const IterVar& v : node->axes) { @@ -182,15 +183,20 @@ Layout Layout::Split(const LayoutAxis& axis, size_t target_pos, int32_t factor) int32_t Layout::FactorOf(const LayoutAxis& axis) const { if (!defined()) return -1; const LayoutAxis& sub = axis.ToSubordinate(); - if (!this->defined()) return -1; + + int32_t factor = 1; + bool has_sub = false; for (const IterVar& itvar : operator->()->axes) { if (sub == LayoutAxis::Get(itvar)) { - const auto* factor = itvar->dom->extent.as(); - ICHECK(factor); - return factor->value; + has_sub = true; + int32_t val = itvar->dom->extent.as()->value; + ICHECK(val); + factor *= val; } } - return -1; + factor = has_sub ? factor : -1; + + return factor; } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) @@ -199,8 +205,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "Layout(" << l->name << ")"; }); -inline bool GetStoreRule(Array* rule, const Layout& src_layout, - const Layout& dst_layout) { +inline bool GetStoreRule(Array* index_rule, Array* shape_rule, + const Layout& src_layout, const Layout& dst_layout) { if (!src_layout.defined() || src_layout.name().empty() || !dst_layout.defined() || dst_layout.name().empty()) { return false; @@ -208,7 +214,7 @@ inline bool GetStoreRule(Array* rule, const Layout& src_layout, for (size_t i = 0; i < dst_layout.ndim(); ++i) { const auto& store_axis = dst_layout[i]; const IterVar& store_axis_impl = dst_layout->axes[i]; - PrimExpr store(0); + PrimExpr index_store(0); for (size_t j = 0; j < src_layout.ndim(); ++j) { const auto& orig_axis = src_layout[j]; @@ -220,28 +226,54 @@ inline bool GetStoreRule(Array* rule, const Layout& src_layout, if (factor > 0) { orig_var = orig_var * factor; } - store = store + orig_var; + index_store = index_store + orig_var; } else { - store = store + orig_axis_impl->var; + PrimExpr factor(1); + for (size_t k = j + 1; k < src_layout.ndim(); ++k) { + if (LayoutAxis::Get(orig_axis_impl) == LayoutAxis::Get(src_layout->axes[k])) { + factor = factor * src_layout->axes[k]->dom->extent; + } + } + index_store = index_store + orig_axis_impl->var * factor; } } } - if (tir::is_zero(store)) { + if (tir::is_zero(index_store)) { // Not convertible return false; } + PrimExpr shape_store = index_store; if (store_axis.IsPrimal()) { const int32_t factor = dst_layout.FactorOf(store_axis); if (factor > 0) { - store = indexdiv(store, PrimExpr(factor)); + shape_store = shapediv(index_store, PrimExpr(factor)); + index_store = indexdiv(index_store, PrimExpr(factor)); } } else { - store = indexmod(store, store_axis_impl->dom->extent); + PrimExpr stride(1); + PrimExpr factor(1); + for (size_t j = i; j < dst_layout.ndim(); ++j) { + if (LayoutAxis::Get(store_axis_impl) == LayoutAxis::Get(dst_layout->axes[j])) { + stride = stride * dst_layout->axes[j]->dom->extent; + if (j > i) { + factor = factor * dst_layout->axes[j]->dom->extent; + } + } + } + shape_store = indexdiv(indexmod(index_store, stride), factor); + index_store = indexdiv(indexmod(index_store, stride), factor); } - rule->push_back(store); + index_rule->push_back(index_store); + shape_rule->push_back(shape_store); } + + VLOG(1) << "index rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ "; + for (const auto& r : *index_rule) { VLOG(1) << r << ", "; }; VLOG(1) << "]" << std::endl; + VLOG(1) << "shape rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ "; + for (const auto& r : *shape_rule) { VLOG(1) << r << ", "; }; VLOG(1) << "]" << std::endl; + return true; } @@ -265,7 +297,7 @@ Array BijectiveLayout::ForwardIndex(const Array& src_index) const BijectiveLayoutNode* self = operator->(); ICHECK_EQ(src_index.size(), self->src_layout->axes.size()) << "Input mismatch with layout " << self->src_layout; - return TransformIndex(src_index, self->src_layout->axes, self->forward_rule); + return TransformIndex(src_index, self->src_layout->axes, self->index_forward_rule); } Array BijectiveLayout::BackwardIndex(const Array& dst_index) const { @@ -273,7 +305,7 @@ Array BijectiveLayout::BackwardIndex(const Array& dst_index) const BijectiveLayoutNode* self = operator->(); ICHECK_EQ(dst_index.size(), self->dst_layout->axes.size()) << "Output mismatch with layout " << self->dst_layout; - return TransformIndex(dst_index, self->dst_layout->axes, self->backward_rule); + return TransformIndex(dst_index, self->dst_layout->axes, self->index_backward_rule); } inline Array TransformShape(const Array& src_shape, @@ -331,19 +363,27 @@ inline Array TransformShape(const Array& src_shape, } } } + + VLOG(1) << "shape rule for " << Layout(src_axis).name() << "-->" << Layout(target_axis).name() << ": [ "; + for (const auto& r : transform_rule) { VLOG(1) << r << ", "; }; VLOG(1) << "]" << std::endl; + + VLOG(1) << "shape transform: [ "; + for (const auto& s : src_shape) { VLOG(1) << s << ", "; }; VLOG(1) << "] --> [ "; + for (const auto& r : result) { VLOG(1) << r << ", "; }; VLOG(1) << "]" << std::endl; + return result; } Array BijectiveLayout::ForwardShape(const Array& shape) const { ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); - return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes, self->forward_rule); + return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes, self->shape_forward_rule); } Array BijectiveLayout::BackwardShape(const Array& shape) const { ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); - return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes, self->backward_rule); + return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes, self->shape_backward_rule); } BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { @@ -354,8 +394,8 @@ BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { // To be consistent with previous behavior, a nullptr layout is created // when argument is invalid. - if (GetStoreRule(&n->forward_rule, n->src_layout, n->dst_layout)) { - ICHECK(GetStoreRule(&n->backward_rule, n->dst_layout, n->src_layout)); + if (GetStoreRule(&n->index_forward_rule, &n->shape_forward_rule, n->src_layout, n->dst_layout)) { + ICHECK(GetStoreRule(&n->index_backward_rule, &n->shape_backward_rule, n->dst_layout, n->src_layout)); data_ = std::move(n); } } From 5eba789bc23d1f3794137f4457b31ef6ca987015 Mon Sep 17 00:00:00 2001 From: Youlei Yang Date: Mon, 20 Dec 2021 10:46:35 +0800 Subject: [PATCH 4/8] refine the log for shape transform --- src/tir/ir/data_layout.cc | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index cc54d0c2208b..3d17877195bd 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -269,10 +269,13 @@ inline bool GetStoreRule(Array* index_rule, Array* shape_rul shape_rule->push_back(shape_store); } - VLOG(1) << "index rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ "; - for (const auto& r : *index_rule) { VLOG(1) << r << ", "; }; VLOG(1) << "]" << std::endl; - VLOG(1) << "shape rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ "; - for (const auto& r : *shape_rule) { VLOG(1) << r << ", "; }; VLOG(1) << "]" << std::endl; + std::stringstream ss; + ss << "index rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ "; + for (const auto& r : *index_rule) { ss << r << ", "; }; ss << "]" << std::endl; + + ss << "shape rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ "; + for (const auto& r : *shape_rule) { ss << r << ", "; }; ss << "]" << std::endl; + VLOG(1) << std::endl << ss.str(); return true; } @@ -364,12 +367,14 @@ inline Array TransformShape(const Array& src_shape, } } - VLOG(1) << "shape rule for " << Layout(src_axis).name() << "-->" << Layout(target_axis).name() << ": [ "; - for (const auto& r : transform_rule) { VLOG(1) << r << ", "; }; VLOG(1) << "]" << std::endl; + std::stringstream ss; + ss << "shape rule for " << Layout(src_axis).name() << "-->" << Layout(target_axis).name() << ": [ "; + for (const auto& r : transform_rule) { ss << r << ", "; }; ss << "]" << std::endl; - VLOG(1) << "shape transform: [ "; - for (const auto& s : src_shape) { VLOG(1) << s << ", "; }; VLOG(1) << "] --> [ "; - for (const auto& r : result) { VLOG(1) << r << ", "; }; VLOG(1) << "]" << std::endl; + ss << "shape transform: [ "; + for (const auto& s : src_shape) { ss << s << ", "; }; ss << "] --> [ "; + for (const auto& r : result) { ss << r << ", "; }; ss << "]" << std::endl; + VLOG(1) << std::endl << ss.str(); return result; } From 3c7895baac9053a28450e4727aae935c16428c70 Mon Sep 17 00:00:00 2001 From: yangulei Date: Thu, 13 Jan 2022 15:28:48 +0800 Subject: [PATCH 5/8] add test for multi-blocking layout transform --- .../python/relay/test_pass_alter_op_layout.py | 44 +++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index ea7fe0bd7871..1971600ae66b 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -162,6 +162,50 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_alter_layout_multi(): + """Test alternating the layout of a conv2d. + The layout of broadcast operators and the weight should be changed accordingly. + """ + + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var("weight") + y = relay.nn.conv2d(x, weight, channels=128, kernel_size=(3, 3), padding=(1, 1)) + y = relay.Function(analysis.free_vars(y), y) + return y + + def alter_conv2d(attrs, inputs, tinfos, out_type): + data, weight = inputs + new_attrs = dict(attrs) + new_attrs["data_layout"] = "NCHW16c" + new_attrs["kernel_layout"] = "OHWI16i64o2i" + return relay.nn.conv2d(data, weight, **new_attrs) + + def expected(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var("weight", shape=(128, 64, 3, 3)) + + y = relay.layout_transform(x, "NCHW", "NCHW16c") + w = relay.layout_transform(weight, "OIHW", "OHWI16i64o2i") + y = relay.nn.conv2d( + y, + w, + channels=128, + kernel_size=(3, 3), + padding=(1, 1), + kernel_layout="OHWI16i64o2i", + data_layout="NCHW16c", + ) + y = relay.layout_transform(y, "NCHW16c", "NCHW") + y = relay.Function(analysis.free_vars(y), y) + return y + + with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d): + a = before() + a = run_opt_pass(a, [transform.CanonicalizeOps(), transform.AlterOpLayout()]) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) def test_alter_layout_lrn(): """Test alternating the layout of a conv2d. From 9c274e9cbc8add203fc0b41ebe8b5c7b2cf52e8e Mon Sep 17 00:00:00 2001 From: yangulei Date: Thu, 20 Jan 2022 16:14:47 +0800 Subject: [PATCH 6/8] delete unwanted comments --- src/tir/ir/data_layout.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index 3d17877195bd..ae9cf5c987de 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -131,8 +131,6 @@ Layout::Layout(const std::string& name) { // NOLINT(*) ICHECK_EQ(axis_str.size(), 1); char axis = axis_str[0]; ICHECK((axis >= 'a' && axis <= 'z') || (axis >= 'A' && axis <= 'Z')); - // skip this check to support multi-blocking layout - // ICHECK(!exist_axis[axis]) << "Invalid layout " << name << ": duplicate axis " << axis; exist_axis[axis] = true; } for (const IterVar& v : node->axes) { From 1c7ef996d9d9c3fe72d717eb847d6dc25c44366d Mon Sep 17 00:00:00 2001 From: yangulei Date: Mon, 24 Jan 2022 11:02:34 +0800 Subject: [PATCH 7/8] remove workaround --- python/tvm/topi/x86/conv2d_alter_op.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/python/tvm/topi/x86/conv2d_alter_op.py b/python/tvm/topi/x86/conv2d_alter_op.py index 3f2df655a615..5d2698d12ca7 100644 --- a/python/tvm/topi/x86/conv2d_alter_op.py +++ b/python/tvm/topi/x86/conv2d_alter_op.py @@ -164,25 +164,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape) out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape) ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1] - n_elems = 4 - - # convert kernel data layout from 4D to 7D - data_expr, kernel_expr = inputs - kernel_IHWO = relay.transpose(kernel_expr, axes=(1, 2, 3, 0)) - kernel_IHWOo = relay.reshape(kernel_IHWO, (in_channel, kh, kw, out_channel // oc_bn, oc_bn)) - kernel_OHWoI = relay.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0)) - kernel_OHWoIi = relay.reshape( - kernel_OHWoI, (out_channel // oc_bn, kh, kw, oc_bn, in_channel // ic_bn, ic_bn) - ) - kernel_OHWoIie = relay.reshape( - kernel_OHWoIi, - (out_channel // oc_bn, kh, kw, oc_bn, in_channel // ic_bn, ic_bn // n_elems, n_elems), - ) - kernel_OIHWioe = relay.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5, 3, 6)) # update new attrs + n_elems = 4 new_attrs["channels"] = out_channel new_attrs["data_layout"] = "NCHW%dc" % ic_bn + new_attrs["kernel_layout"] = "OIHW{:n}i{:n}o{:n}i".format(ic_bn // n_elems, oc_bn, n_elems) new_attrs["out_layout"] = "NCHW%dc" % oc_bn # Store altered operator's config. @@ -208,7 +195,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): ) dispatch_ctx.update(target, new_workload, cfg) - return relay.nn.contrib_conv2d_nchwc(data_expr, kernel_OIHWioe, **new_attrs) + return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs) if topi_tmpl == "depthwise_conv2d_NCHWc.x86": if data_layout == "NCHW" and kernel_layout == "OIHW": From aab406bf634146f470025f1af529b8ffb930e858 Mon Sep 17 00:00:00 2001 From: yangulei Date: Thu, 17 Feb 2022 17:02:51 +0800 Subject: [PATCH 8/8] fix lint errors --- src/tir/ir/data_layout.cc | 37 ++++++++++++++----- .../python/relay/test_pass_alter_op_layout.py | 2 + 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/src/tir/ir/data_layout.cc b/src/tir/ir/data_layout.cc index ae9cf5c987de..070cd7077d18 100644 --- a/src/tir/ir/data_layout.cc +++ b/src/tir/ir/data_layout.cc @@ -269,10 +269,16 @@ inline bool GetStoreRule(Array* index_rule, Array* shape_rul std::stringstream ss; ss << "index rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ "; - for (const auto& r : *index_rule) { ss << r << ", "; }; ss << "]" << std::endl; + for (const auto& r : *index_rule) { + ss << r << ", "; + } + ss << "]" << std::endl; ss << "shape rule for " << src_layout.name() << "-->" << dst_layout.name() << ": [ "; - for (const auto& r : *shape_rule) { ss << r << ", "; }; ss << "]" << std::endl; + for (const auto& r : *shape_rule) { + ss << r << ", "; + } + ss << "]" << std::endl; VLOG(1) << std::endl << ss.str(); return true; @@ -366,12 +372,22 @@ inline Array TransformShape(const Array& src_shape, } std::stringstream ss; - ss << "shape rule for " << Layout(src_axis).name() << "-->" << Layout(target_axis).name() << ": [ "; - for (const auto& r : transform_rule) { ss << r << ", "; }; ss << "]" << std::endl; + ss << "shape rule for " << Layout(src_axis).name() << "-->" << Layout(target_axis).name() + << ": [ "; + for (const auto& r : transform_rule) { + ss << r << ", "; + } + ss << "]" << std::endl; ss << "shape transform: [ "; - for (const auto& s : src_shape) { ss << s << ", "; }; ss << "] --> [ "; - for (const auto& r : result) { ss << r << ", "; }; ss << "]" << std::endl; + for (const auto& s : src_shape) { + ss << s << ", "; + } + ss << "] --> [ "; + for (const auto& r : result) { + ss << r << ", "; + } + ss << "]" << std::endl; VLOG(1) << std::endl << ss.str(); return result; @@ -380,13 +396,15 @@ inline Array TransformShape(const Array& src_shape, Array BijectiveLayout::ForwardShape(const Array& shape) const { ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); - return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes, self->shape_forward_rule); + return TransformShape(shape, self->src_layout->axes, self->dst_layout->axes, + self->shape_forward_rule); } Array BijectiveLayout::BackwardShape(const Array& shape) const { ICHECK(defined()) << "Cannot operate on an undefined bijective layout."; const BijectiveLayoutNode* self = operator->(); - return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes, self->shape_backward_rule); + return TransformShape(shape, self->dst_layout->axes, self->src_layout->axes, + self->shape_backward_rule); } BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { @@ -398,7 +416,8 @@ BijectiveLayout::BijectiveLayout(Layout src_layout, Layout dst_layout) { // To be consistent with previous behavior, a nullptr layout is created // when argument is invalid. if (GetStoreRule(&n->index_forward_rule, &n->shape_forward_rule, n->src_layout, n->dst_layout)) { - ICHECK(GetStoreRule(&n->index_backward_rule, &n->shape_backward_rule, n->dst_layout, n->src_layout)); + ICHECK(GetStoreRule(&n->index_backward_rule, &n->shape_backward_rule, n->dst_layout, + n->src_layout)); data_ = std::move(n); } } diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 1971600ae66b..4df3d1943871 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -162,6 +162,7 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + def test_alter_layout_multi(): """Test alternating the layout of a conv2d. The layout of broadcast operators and the weight should be changed accordingly. @@ -207,6 +208,7 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + def test_alter_layout_lrn(): """Test alternating the layout of a conv2d. The layout of broadcast operators and the weight should be changed accordingly.