diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 4907a0bf2bd4..33a46cc6e6af 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -268,6 +268,18 @@ def FuseOps(fuse_opt_level=-1):
     return _ffi_api.FuseOps(fuse_opt_level)
 
 
+def DefuseOps():
+    """The inverse operation of FuseOps. It transforms a fused program returned by FuseOps into the
+    program before FuseOps (i.e., x == DefuseOps(FuseOps(x))).
+
+    Returns
+    -------
+    ret : tvm.transform.Pass
+        The registered pass for operator defusion.
+    """
+    return _ffi_api.DefuseOps()
+
+
 def CombineParallelConv2D(min_num_branches=3):
     """Combine multiple conv2d operators into one.
 
diff --git a/src/relay/transforms/defuse_ops.cc b/src/relay/transforms/defuse_ops.cc
new file mode 100644
index 000000000000..6abf4c31d359
--- /dev/null
+++ b/src/relay/transforms/defuse_ops.cc
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *
+ * \file src/relay/transforms/defuse_ops.cc
+ * \brief This is an inverse operation of the fusion pass. It transforms a fused
+ * program returned by relay::transform::FuseOps into the program before FuseOps.
+ * (i.e., x == DefuseOps(FuseOps(x)))
+ */
+
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/container.h>
+
+#include <string>
+#include <unordered_map>
+
+#include "pattern_utils.h"
+
+namespace tvm {
+namespace relay {
+
+class DefuseOpsMutator : public ExprMutator {
+ public:
+  class FuncBodyMutator : public ExprMutator {
+   public:
+    explicit FuncBodyMutator(const Array<Expr>& args) : ExprMutator() { args_ = args; }
+
+    Expr VisitExpr_(const VarNode* n) {
+      const std::string& name = n->name_hint();
+      ICHECK(!name.empty() && (name[0] == 'p'));
+      std::string id_str = name.substr(1);
+      int id = std::stoi(id_str);
+      ICHECK(id >= 0 && size_t(id) < args_.size());
+      return args_[id];
+    }
+
+   private:
+    Array<Expr> args_;
+  };
+
+  Expr VisitExpr_(const CallNode* n) {
+    auto new_n = ExprMutator::VisitExpr_(n);
+
+    if (const auto* call = new_n.as<CallNode>()) {
+      if (const auto* func = call->op.as<FunctionNode>()) {
+        if (func->body->IsInstance<CallNode>()) {
+          return FuncBodyMutator(call->args).Mutate(func->body);
+        }
+      }
+    }
+    return new_n;
+  }
+};
+
+Expr DefuseOps(const Expr& expr) { return DefuseOpsMutator().Mutate(expr); }
+
+namespace transform {
+
+Pass DefuseOps() {
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(DefuseOps(f)); };
+  return CreateFunctionPass(pass_func, 3, "DefuseOps", {"InferType"});
+}
+
+TVM_REGISTER_GLOBAL("relay._transform.DefuseOps").set_body_typed(DefuseOps);
+
+}  // namespace transform
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_pass_defuse_ops.py b/tests/python/relay/test_pass_defuse_ops.py
new file mode 100644
index 000000000000..2312b2d9ec47
--- /dev/null
+++ b/tests/python/relay/test_pass_defuse_ops.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay
+from tvm.relay import transform
+from tvm.relay.testing import run_opt_pass
+
+
+def test_defuse_simple():
+    """Simple testcase."""
+
+    def before():
+        x = relay.var("x", shape=(10, 20))
+        y = relay.add(x, relay.const(1, "float32"))
+        z = relay.exp(y)
+        w = relay.squeeze(z)
+        return relay.Function([x], w)
+
+    x = before()
+    x = run_opt_pass(x, transform.InferType())
+    fused = run_opt_pass(x, transform.FuseOps())
+    defused = run_opt_pass(fused, transform.DefuseOps())
+
+    assert tvm.ir.structural_equal(x, defused)
+
+
+def test_inception_like():
+    def conv(data):
+        y = relay.nn.conv2d(data, relay.var("w"), kernel_size=(3, 3), padding=(1, 1), channels=16)
+        return relay.nn.relu(data=y)
+
+    def inception_like(data):
+        c0 = conv(data)
+        c1 = conv(data)
+        return relay.concatenate((c0, c1), axis=1)
+
+    def before(dshape):
+        x = relay.var("x", shape=dshape)
+        in1 = inception_like(x)
+        in2 = inception_like(in1)
+        return relay.Function(relay.analysis.free_vars(in2), in2)
+
+    dshape = (1, 16, 64, 64)
+    x = before(dshape)
+    x = run_opt_pass(x, transform.InferType())
+    fused = run_opt_pass(x, transform.FuseOps())
+    defused = run_opt_pass(fused, transform.DefuseOps())
+
+    assert tvm.ir.structural_equal(x, defused)
+
+
+if __name__ == "__main__":
+    test_defuse_simple()
+    test_inception_like()
diff --git a/tests/python/relay/test_pass_fuse_ops.py b/tests/python/relay/test_pass_fuse_ops.py
index ff282df7c832..a3146de55d5a 100644
--- a/tests/python/relay/test_pass_fuse_ops.py
+++ b/tests/python/relay/test_pass_fuse_ops.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 import tvm
-from tvm import te
 from tvm import relay
 from tvm.relay import transform
 from tvm.relay.testing import run_opt_pass
@@ -44,7 +43,6 @@ def expected():
         return relay.Function([x], y)
 
     z = before()
-    zz = run_opt_pass(z, transform.FuseOps(fuse_opt_level=2))
     zz = run_opt_pass(z, transform.FuseOps())
     after = run_opt_pass(expected(), transform.InferType())
     assert tvm.ir.structural_equal(zz, after)
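
Usage sketch (not part of the patch, and not the author's example): a minimal module-level round trip through FuseOps and the new DefuseOps pass, mirroring what the tests above do per-function with run_opt_pass. The toy network, its shapes, and the final structural_equal check are illustrative assumptions; the pass and API names are those introduced or already present in TVM.

import tvm
from tvm import relay

# Toy elementwise chain (illustrative assumption, not taken from the test suite).
x = relay.var("x", shape=(1, 8))
y = relay.nn.relu(relay.add(x, relay.const(1.0, "float32")))
mod = tvm.IRModule.from_expr(relay.Function([x], y))

mod = relay.transform.InferType()(mod)        # DefuseOps declares InferType as a required pass
fused = relay.transform.FuseOps()(mod)        # ops grouped into fused Relay functions
defused = relay.transform.DefuseOps()(fused)  # fused functions inlined back into plain calls

# If the DefuseOps(FuseOps(x)) == x contract holds, this should print True.
print(tvm.ir.structural_equal(mod["main"], defused["main"]))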