41 changes: 23 additions & 18 deletions src/relay/transforms/alter_op_layout.cc
@@ -50,19 +50,6 @@ namespace alter_op_layout {
class AlterTransformMemorizerNode : public TransformMemorizerNode {
public:
static constexpr const char* _type_key = "relay.alter_op_layout.AlterTransformMemorizerNode";
};

/*!
 * \brief Container that provides the transformation function for alter layout.
*/
class AlterTransformMemorizer : public TransformMemorizer {
public:
AlterTransformMemorizer() {}
explicit AlterTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}

AlterTransformMemorizerNode* operator->() {
return static_cast<AlterTransformMemorizerNode*>(get_mutable());
}

/*!
* \brief Defines the call transformation for AlterOpLayout pass. The new layouts are defined by
@@ -102,7 +89,23 @@ class AlterTransformMemorizer : public TransformMemorizer {
return GetRef<Call>(new_call);
}

using TransformMemorizer::CallWithNewLayouts;
Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
return CallWithNewLayouts(ref_call, ref_call->attrs, new_args);
}
};

/*!
 * \brief Container that provides the transformation function for alter layout.
*/
class AlterTransformMemorizer : public TransformMemorizer {
public:
AlterTransformMemorizer() = default;
explicit AlterTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}

AlterTransformMemorizerNode* operator->() {
return static_cast<AlterTransformMemorizerNode*>(get_mutable());
}

using ContainerType = AlterTransformMemorizerNode;
};

@@ -113,10 +116,12 @@ class AlterTransformMemorizer : public TransformMemorizer {
*/
Expr AlterOpLayout(const Expr& expr) {
// TODO(@icemelon9): need to rerun type inference after applying an alter op.
AlterTransformMemorizer alterMemorizer(make_object<AlterTransformMemorizerNode>());
auto fcontext = [&](const Call& call) -> ObjectRef { return alterMemorizer; };

return ForwardRewrite(expr, LayoutRewriter<AlterTransformMemorizer>, fcontext);
AlterTransformMemorizer alter_memorizer(make_object<AlterTransformMemorizerNode>());
std::function<ObjectRef(const Call&)> fcontext = [=](const Call& call) -> ObjectRef {
return alter_memorizer;
};
FForwardRewrite rewrite_func = LayoutRewriter<AlterTransformMemorizer>;
return ForwardRewrite(expr, rewrite_func, fcontext);
}

} // namespace alter_op_layout
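A note on the rewritten AlterOpLayout entry point above: the fcontext callback returns the same memorizer for every call node that ForwardRewrite visits, which is what lets a single pass share one memo cache across the whole expression, and capturing it by value keeps the reference-counted handle alive for as long as the callback exists. The sketch below is a standalone illustration of that callback pattern using plain std:: types; the names MemoState and fcontext here are illustrative and are not TVM APIs or part of this patch.

// Minimal mock of the per-call context callback passed to ForwardRewrite:
// the same shared state object is handed out for every call node, so all
// rewrites in one pass see one memo cache. Plain std:: types, illustrative names.
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

struct Call { std::string op; };

// Stand-in for the memorizer node: shared, mutable per-pass state.
struct MemoState {
  std::unordered_map<std::string, int> rewrites_seen;
};

int main() {
  auto memorizer = std::make_shared<MemoState>();

  // Capture by value: the callback owns a counted handle, mirroring how the
  // ObjectRef-based memorizer is captured with [=] in AlterOpLayout above.
  std::function<std::shared_ptr<MemoState>(const Call&)> fcontext =
      [=](const Call&) { return memorizer; };

  for (std::string op : {"nn.conv2d", "nn.conv2d", "nn.dense"}) {
    auto ctx = fcontext(Call{op});  // same state object on every invocation
    ++ctx->rewrites_seen[op];
  }
  std::cout << memorizer->rewrites_seen["nn.conv2d"] << "\n";  // prints 2
}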
39 changes: 21 additions & 18 deletions src/relay/transforms/convert_layout.cc
@@ -58,22 +58,6 @@ class ConvertTransformMemorizerNode : public TransformMemorizerNode {
explicit ConvertTransformMemorizerNode(Map<String, Array<String>> desired_layouts)
: desired_layouts_(std::move(desired_layouts)) {}

/*! \brief A mapping of op_name to array of desired layouts for each input. */
Map<String, Array<String>> desired_layouts_;
};

/*!
* \brief Container that provides the transformation function for convert layout.
*/
class ConvertTransformMemorizer : public TransformMemorizer {
public:
ConvertTransformMemorizer() {}
explicit ConvertTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}

ConvertTransformMemorizerNode* operator->() {
return static_cast<ConvertTransformMemorizerNode*>(get_mutable());
}

/*!
* \brief Defines the call transformation for ConvertLayout pass. The new layouts should be the
* desired layout as specified by the user.
@@ -89,7 +73,7 @@ class ConvertTransformMemorizer : public TransformMemorizer {
Expr new_e;
bool modified = false;
if (fconvert_layout.count(op)) {
auto desired_layouts = operator->()->desired_layouts_;
auto desired_layouts = desired_layouts_;
if (desired_layouts.find(op->name) != desired_layouts.end()) {
tvm::Array<tvm::te::Tensor> tinfos;
for (auto& expr : ref_call->args) {
@@ -124,7 +108,26 @@ class ConvertTransformMemorizer : public TransformMemorizer {
return Call(new_call->op, new_call->args, new_call->attrs, new_call->type_args, ref_call->span);
}

using TransformMemorizer::CallWithNewLayouts;
Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override {
return CallWithNewLayouts(ref_call, ref_call->attrs, new_args);
}

/*! \brief A mapping of op_name to array of desired layouts for each input. */
Map<String, Array<String>> desired_layouts_;
};

/*!
* \brief Container that provides the transformation function for convert layout.
*/
class ConvertTransformMemorizer : public TransformMemorizer {
public:
ConvertTransformMemorizer() = default;
explicit ConvertTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {}

ConvertTransformMemorizerNode* operator->() {
return static_cast<ConvertTransformMemorizerNode*>(get_mutable());
}

using ContainerType = ConvertTransformMemorizerNode;
};

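A note on the desired_layouts_ member that now sits next to the CallWithNewLayouts overrides: it maps an operator name to the layout requested for each of that operator's inputs, and only calls whose op name appears in the map are rewritten. The rough standalone sketch below shows that lookup with plain std:: containers; the example table and the reading of "default" as "leave this input's layout to the op's own inference" follow the usual ConvertLayout convention and are given here as assumptions for illustration, not as part of this patch.

// Standalone sketch of a desired-layouts table keyed by operator name,
// using plain std:: containers instead of TVM's Map/String/Array.
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // op name -> desired layout per input (here: data, then kernel);
  // "default" is assumed to mean "let the op's layout inference decide".
  std::unordered_map<std::string, std::vector<std::string>> desired_layouts = {
      {"nn.conv2d", {"NHWC", "default"}},
  };

  const std::string op_name = "nn.conv2d";
  auto it = desired_layouts.find(op_name);
  if (it != desired_layouts.end()) {
    // Mirrors the branch above: only listed ops get their call rewritten.
    for (const auto& layout : it->second) {
      std::cout << op_name << " input layout -> " << layout << "\n";
    }
  }
}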
36 changes: 18 additions & 18 deletions src/relay/transforms/transform_layout.h
@@ -57,6 +57,21 @@ class TransformMemorizerNode : public Object {
}
};

/*!
 * \brief Defines the call transformation for derived passes. The new layouts are
 * produced by a packed func, which may be registered differently per target.
* \param ref_call The original call.
* \param new_attrs Updated attributes consistent with new layouts.
* \param new_args The traversed/recursed args to the call.
* \return The new Call after calling the packed func.
*/
virtual Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs,
const std::vector<Expr>& new_args) = 0;

virtual Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) {
return CallWithNewLayouts(ref_call, ref_call->attrs, new_args);
}

/*! \brief The memorizer map. */
std::unordered_map<TransformKey, Expr, key_hash> memo;

@@ -69,11 +84,9 @@ class TransformMemorizerNode : public Object {
*/
class TransformMemorizer : public ObjectRef {
public:
TransformMemorizer() {}
TransformMemorizer() = default;
explicit TransformMemorizer(ObjectPtr<Object> n) : ObjectRef(n) {}

virtual ~TransformMemorizer() {}

TransformMemorizerNode* operator->() {
return static_cast<TransformMemorizerNode*>(get_mutable());
}
@@ -146,19 +159,6 @@ class TransformMemorizer : public ObjectRef {
return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name());
}

/*!
 * \brief Defines the call transformation for derived passes. The new layouts are
 * produced by a packed func, which may be registered differently per target.
* \param ref_call The original call.
* \param new_attrs Updated attributes consistent with new layouts.
* \param new_args The traversed/recursed args to the call.
* \return The new Call after calling the packed func.
*/
virtual Call CallWithNewLayouts(const Call& ref_call, Attrs new_attrs,
const std::vector<Expr>& new_args) = 0;
virtual Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) {
return CallWithNewLayouts(ref_call, ref_call->attrs, new_args);
}
using ContainerType = TransformMemorizerNode;
};

@@ -312,7 +312,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
if (ref_call->op.as<OpNode>()) {
Op op = Downcast<Op>(ref_call->op);
if (falter_layout.count(op) && !finfer_layout.count(op)) {
return memorizer.CallWithNewLayouts(ref_call, normal_new_args);
return memorizer->CallWithNewLayouts(ref_call, normal_new_args);
}
}
}
@@ -349,7 +349,7 @@ Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const Obj
}

// new_op = alter(op)
Call new_call = memorizer.CallWithNewLayouts(ref_call, infer_out->new_attrs, normal_new_args);
Call new_call = memorizer->CallWithNewLayouts(ref_call, infer_out->new_attrs, normal_new_args);

// new_in2, new_out = op.infer(new_in)
if (new_call->op->IsInstance<OpNode>()) {
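The shape of the refactor in transform_layout.h is easiest to see in isolation: the virtual CallWithNewLayouts overloads now live on the reference-counted node, the TransformMemorizer handle keeps only operator->(), and LayoutRewriter switches from dot calls on the handle to arrow calls through the node. Below is a small self-contained C++ mock of that shape built on std::shared_ptr rather than TVM's object system; every class and member name in it is illustrative, not a TVM API.

// Standalone mock (not TVM code) of the pattern adopted by this patch: the
// virtual hook lives on the node, while the handle stays a thin wrapper whose
// operator->() exposes the node for virtual dispatch.
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Call { std::string op; std::string attrs; };

// Analogue of TransformMemorizerNode: a pure virtual hook plus a convenience
// overload that forwards the original attrs.
class MemorizerNode {
 public:
  virtual ~MemorizerNode() = default;
  virtual Call CallWithNewLayouts(const Call& ref_call, const std::string& new_attrs,
                                  const std::vector<Call>& new_args) = 0;
  virtual Call CallWithNewLayouts(const Call& ref_call, const std::vector<Call>& new_args) {
    return CallWithNewLayouts(ref_call, ref_call.attrs, new_args);
  }
};

// Analogue of AlterTransformMemorizerNode / ConvertTransformMemorizerNode.
class AlterMemorizerNode : public MemorizerNode {
 public:
  using MemorizerNode::CallWithNewLayouts;  // keep the two-argument overload visible
  Call CallWithNewLayouts(const Call& ref_call, const std::string& new_attrs,
                          const std::vector<Call>& /*new_args*/) override {
    return Call{ref_call.op + "_altered", new_attrs};
  }
};

// Analogue of TransformMemorizer : ObjectRef, with no virtual functions of its own.
class Memorizer {
 public:
  explicit Memorizer(std::shared_ptr<MemorizerNode> n) : node_(std::move(n)) {}
  MemorizerNode* operator->() const { return node_.get(); }

 private:
  std::shared_ptr<MemorizerNode> node_;
};

int main() {
  Memorizer memo(std::make_shared<AlterMemorizerNode>());
  // Mirrors the call-site change in LayoutRewriter: memorizer->CallWithNewLayouts(...)
  Call out = memo->CallWithNewLayouts(Call{"nn.conv2d", "NCHW"}, {});
  std::cout << out.op << " / " << out.attrs << "\n";  // nn.conv2d_altered / NCHW
}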
30 changes: 0 additions & 30 deletions tutorials/dev/use_pass_infra.py
@@ -69,20 +69,6 @@ def example():
return relay.Function([x, weight], z2)


###############################################################################
# Let us register layout alteration for a conv2d op so that we can apply the
# layout alteration pass on the example. How the alter layout pass works is outside
# the scope of this tutorial.


@relay.op.register_alter_op_layout("nn.conv2d", level=101)
def alter_conv2d(attrs, inputs, tinfos, out_type):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs["data_layout"] = "NCHW16c"
return relay.nn.conv2d(data, weight, **new_attrs)


###############################################################################
# Optimize the Program
# --------------------
@@ -188,21 +174,6 @@ def alter_conv2d(attrs, inputs, tinfos, out_type):
mod3 = seq(mod)
print(mod3)

###############################################################################
# The passes applied so far are target independent. The pass infra also
# provides a means to make a pass target-aware. For example, the layout
# alteration pass falls into that category.

with tvm.transform.PassContext(opt_level=3):
mod4 = seq(mod)
print(mod4)

seq1 = tvm.transform.Sequential([relay.transform.AlterOpLayout()])
with tvm.transform.PassContext(opt_level=3):
with tvm.target.Target("llvm"):
mod5 = seq1(mod)
print(mod5)

##############################################################################
# Implement a Pass Using Python Decorator
# ------------------------------------------
@@ -257,7 +228,6 @@ def visit_constant(self, c):
tvm.transform.PrintIR(),
relay.transform.EliminateCommonSubexpr(),
relay.transform.FuseOps(),
relay.transform.AlterOpLayout(),
]
)
