diff --git a/python/tvm/relay/backend/compile_engine.py b/python/tvm/relay/backend/compile_engine.py
index 28f2ac6d489b..59ba53bf52ac 100644
--- a/python/tvm/relay/backend/compile_engine.py
+++ b/python/tvm/relay/backend/compile_engine.py
@@ -255,7 +255,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
 
 
 @tvm._ffi.register_func("relay.backend.lower_call")
-def lower_call(call, inputs, target):
+def lower_call(call, inputs, target, no_trace=False):
     """Lower the call expression to op implementation and tensor outputs."""
     assert isinstance(call.op, tvm.ir.Op)
     op = call.op
@@ -283,7 +283,7 @@ def lower_call(call, inputs, target):
     env = autotvm.task.TaskExtractEnv.current
     reenable_tracing = False
     if env is not None and env.tracing:
-        if env.wanted_relay_ops is not None and op not in env.wanted_relay_ops:
+        if (env.wanted_relay_ops is not None and op not in env.wanted_relay_ops) or no_trace:
             env.tracing = False
             reenable_tracing = True
 
@@ -410,3 +410,7 @@ def get():
         The compile engine.
     """
     return _backend._CompileEngineGlobal()
+
+
+def translate_to_te(prim_func, target):
+    return _backend._TranslateToTE(prim_func, target)
diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc
index 1559d7edf35f..983c33ad056d 100644
--- a/src/relay/backend/compile_engine.cc
+++ b/src/relay/backend/compile_engine.cc
@@ -52,6 +52,7 @@ namespace tvm {
 namespace relay {
 
 TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
+TVM_REGISTER_NODE_TYPE(TEGraphNode);
 TVM_REGISTER_NODE_TYPE(CachedFuncNode);
 TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
 TVM_REGISTER_NODE_TYPE(CCacheValueNode);
@@ -94,26 +95,44 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
   return res;
 }
 
-// The getter to get schedule from compile engine.
-// Get schedule from functor.
-class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
+te::Tensor GetScalar(const ConstantNode* op) {
+  using tir::make_const;
+  ICHECK(op->is_scalar());
+  void* data = op->data->data;
+  DataType dtype = DataType(op->data->dtype);
+  auto value = te::compute(
+      {},
+      [&](const Array<tvm::tir::Var>&) {
+        if (dtype == DataType::Int(32)) {
+          return make_const(dtype, static_cast<const int32_t*>(data)[0]);
+        } else if (dtype == DataType::Int(64)) {
+          return make_const(dtype, static_cast<const int64_t*>(data)[0]);
+        } else if (dtype == DataType::Float(32)) {
+          return make_const(dtype, static_cast<const float*>(data)[0]);
+        } else if (dtype == DataType::Float(64)) {
+          return make_const(dtype, static_cast<const double*>(data)[0]);
+        } else if (dtype == DataType::Bool()) {
+          return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
+        } else {
+          LOG(FATAL) << "not handled";
+          return tvm::PrimExpr();
+        }
+      },
+      "compile_engine_const", topi::kBroadcast);
+  return value;
+}
+
+class TETranslator : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
  public:
-  explicit ScheduleGetter(Target target)
-      : target_(target), device_copy_op_(Op::Get("device_copy")) {
-    // Whether to use auto_scheduler schedule.
-    use_auto_scheduler_ = transform::PassContext::Current()
-                              ->GetConfig<Bool>("relay.backend.use_auto_scheduler", Bool(false))
-                              .value();
-  }
+  explicit TETranslator(Target target) : target_(target), device_copy_op_(Op::Get("device_copy")) {}
 
-  CachedFunc Create(const Function& prim_func) {
-    auto cache_node = make_object<CachedFuncNode>();
-    cache_node->target = target_;
+  TEGraph Translate(const Function& prim_func) {
+    auto graph_node = make_object<TEGraphNode>();
     for (Var param : prim_func->params) {
       Array<tvm::te::Tensor> inputs;
       if (const auto* ttype = param->checked_type().as<TensorTypeNode>()) {
         tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
-        cache_node->inputs.push_back(tensor);
+        graph_node->inputs.push_back(tensor);
         inputs.push_back(tensor);
       } else {
         // flatten tuple of tensor type.
@@ -123,14 +142,123 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
           // TODO(@icemelon): Allow recursive tuple
           ICHECK(ttype != nullptr);
           tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
-          cache_node->inputs.push_back(tensor);
+          graph_node->inputs.push_back(tensor);
           inputs.push_back(tensor);
         }
       }
       memo_[param] = inputs;
     }
+    graph_node->outputs = this->VisitExpr(prim_func->body);
+    return TEGraph(graph_node);
+  }
+
+  Array<te::Tensor> VisitExpr_(const VarNode* op) final {
+    LOG(FATAL) << "Free variable " << op->name_hint();
+    return {};
+  }
+
+  Array<te::Tensor> VisitExpr_(const ConstantNode* op) final { return {GetScalar(op)}; }
+
+  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
+    ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
+
+    Array<te::Tensor> inputs;
+    int count_tuple = 0;
+    for (Expr arg : call_node->args) {
+      if (arg->checked_type().as<TupleTypeNode>()) {
+        ++count_tuple;
+      }
+      for (te::Tensor tensor : VisitExpr(arg)) {
+        inputs.push_back(tensor);
+      }
+    }
+    if (count_tuple) {
+      ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
+    }
+
+    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
+    Op op = Downcast<Op>(call_node->op);
+
+    Array<te::Tensor> outputs;
+    OpImplementation impl;
+    // Skip fcompute for device copy operators as it is not registered.
+    if (op == device_copy_op_) {
+      const auto* copy_input = inputs[0].operator->();
+      outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0));
+    } else {
+      LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_, true);
+      outputs = lowered_out->outputs;
+    }
+
+    if (outputs.size() != 1) {
+      const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
+      ICHECK(tuple_type) << "Expect output to be a tuple type";
+      ICHECK_EQ(tuple_type->fields.size(), outputs.size());
+    }
+    return outputs;
+  }
+
+  Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
+    LOG(FATAL) << "Do not support sub function";
+    return Array<te::Tensor>();
+  }
+
+  Array<te::Tensor> VisitExpr_(const LetNode* op) final {
+    Array<te::Tensor> val = VisitExpr(op->value);
+    ICHECK(!memo_.count(op->var));
+    memo_[op->var] = val;
+    return VisitExpr(op->body);
+  }
+
+  Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
+    Array<te::Tensor> fields;
+    for (Expr field : op->fields) {
+      ICHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple of Tensor";
+      Array<te::Tensor> res = VisitExpr(field);
+      ICHECK_EQ(res.size(), 1);
+      fields.push_back(res[0]);
+    }
+    return fields;
+  }
+
+  Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
+    const auto* tuple_type = op->tuple->type_as<TupleTypeNode>();
+    Array<te::Tensor> tuple = VisitExpr(op->tuple);
+    ICHECK_EQ(tuple_type->fields.size(), tuple.size());
+    ICHECK_GE(op->index, 0);
+    ICHECK_LT(static_cast<size_t>(op->index), tuple.size());
+    return {tuple[op->index]};
+  }
+
+ private:
+  tvm::Target target_;
+  // Cache device copy op for equivalence checking to reduce registry lookup
+  // overhead for each invocation of call node when retrieving schedules.
+  const Op& device_copy_op_;
+};
+
+// The getter to get schedule from compile engine.
+// Get schedule from functor.
+class ScheduleGetter : public ExprVisitor {
+ public:
+  explicit ScheduleGetter(Target target)
+      : target_(target), device_copy_op_(Op::Get("device_copy")) {
+    // Whether to use auto_scheduler schedule.
+    use_auto_scheduler_ = transform::PassContext::Current()
+                              ->GetConfig<Bool>("relay.backend.use_auto_scheduler", Bool(false))
+                              .value();
+  }
+
+  CachedFunc Create(const Function& prim_func) {
+    auto translator = TETranslator(target_);
+    auto te_graph = translator.Translate(prim_func);
+    auto cache_node = make_object<CachedFuncNode>();
+    cache_node->target = target_;
+    cache_node->inputs = te_graph->inputs;
+    cache_node->outputs = te_graph->outputs;
     readable_name_stream_ << "fused";
-    cache_node->outputs = this->VisitExpr(prim_func->body);
+    this->VisitExpr(prim_func->body);
     auto candidate_name = readable_name_stream_.str();
     constexpr static size_t kMaxFuncNameLength = 80;
     if (candidate_name.size() > kMaxFuncNameLength) {
@@ -166,7 +294,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
       }
     }
 
-  // Use TOPI schdule if user specificed, or the function has no auto_scheduler schedule.
+  // Use TOPI schedule if user specified, or the function has no auto_scheduler schedule.
   if (!schedule.defined()) {
     ICHECK(anchor_implementation_.defined());
     schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
@@ -181,72 +309,50 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
     return CachedFunc(cache_node);
   }
 
-  Array<te::Tensor> VisitExpr_(const VarNode* op) final {
-    LOG(FATAL) << "Free variable " << op->name_hint();
-    return {};
-  }
-
-  Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
-    using tir::make_const;
-    ICHECK(op->is_scalar());
-    void* data = op->data->data;
-    DataType dtype = DataType(op->data->dtype);
-    auto value = te::compute(
-        {},
-        [&](const Array<tvm::tir::Var>&) {
-          if (dtype == DataType::Int(32)) {
-            return make_const(dtype, static_cast<const int32_t*>(data)[0]);
-          } else if (dtype == DataType::Int(64)) {
-            return make_const(dtype, static_cast<const int64_t*>(data)[0]);
-          } else if (dtype == DataType::Float(32)) {
-            return make_const(dtype, static_cast<const float*>(data)[0]);
-          } else if (dtype == DataType::Float(64)) {
-            return make_const(dtype, static_cast<const double*>(data)[0]);
-          } else if (dtype == DataType::Bool()) {
-            return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
-          } else {
-            LOG(FATAL) << "not handled";
-            return tvm::PrimExpr();
-          }
-        },
-        "compile_engine_const", topi::kBroadcast);
+  void VisitExpr_(const ConstantNode* op) final {
+    auto value = GetScalar(op);
     scalars_.push_back(value->op);
-    return {value};
   }
 
-  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
+  void VisitExpr_(const CallNode* call_node) final {
     static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
     static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
     ICHECK(flower_call) << "relay.backend.lower_call is not registered.";
 
-    Array<te::Tensor> inputs;
     int count_tuple = 0;
+    Array<te::Tensor> inputs;
     for (Expr arg : call_node->args) {
-      if (arg->checked_type().as<TupleTypeNode>()) {
-        ++count_tuple;
-      }
-      for (te::Tensor tensor : VisitExpr(arg)) {
+      VisitExpr(arg);
+      if (const auto* ttype = arg->checked_type().as<TensorTypeNode>()) {
+        tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
         inputs.push_back(tensor);
+      } else {
+        ICHECK_EQ(count_tuple, 0) << "Only allow function with a single tuple input";
+        // flatten tuple of tensor type.
+        const auto* tuple_type = arg->type_as<TupleTypeNode>();
+        for (Type field : tuple_type->fields) {
+          const auto* ttype = field.as<TensorTypeNode>();
+          // TODO(@icemelon): Allow recursive tuple
+          ICHECK(ttype != nullptr);
+          tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
+          inputs.push_back(tensor);
+          ++count_tuple;
+        }
       }
     }
-    if (count_tuple) {
-      ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
-    }
 
     ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
     Op op = Downcast<Op>(call_node->op);
 
-    Array<te::Tensor> outputs;
-    OpImplementation impl;
-    // Skip fcompute for device copy operators as it is not registered.
     if (op == device_copy_op_) {
-      const auto* copy_input = inputs[0].operator->();
-      outputs.push_back(te::Tensor(copy_input->shape, copy_input->dtype, te::Operation(), 0));
-    } else {
-      LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
-      outputs = lowered_out->outputs;
-      impl = lowered_out->implementation;
+      // Set the name to `__copy`. It will be detected in graph runtime to perform
+      // data copy across devices.
+      readable_name_stream_.str(std::string());
+      readable_name_stream_ << "__copy";
+      return;
+    }
+    LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_, false);
+    OpImplementation impl = lowered_out->implementation;
 
     int op_pattern = fpattern[op];
     if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
@@ -260,52 +366,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
       anchor_op_pattern_ = op_pattern;
       anchor_implementation_ = impl;
     }
-    if (outputs.size() != 1) {
-      const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
-      ICHECK(tuple_type) << "Expect output to be a tuple type";
-      ICHECK_EQ(tuple_type->fields.size(), outputs.size());
-    }
-    // Set the name to `__copy`. It will be detected in graph runtime to perform
-    // data copy across devices.
-    if (op == device_copy_op_) {
-      readable_name_stream_.str(std::string());
-      readable_name_stream_ << "__copy";
-    } else {
-      readable_name_stream_ << '_' << op->name;
-    }
-    return outputs;
-  }
-
-  Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
-    LOG(FATAL) << "Do not support sub function";
-    return Array<te::Tensor>();
-  }
-
-  Array<te::Tensor> VisitExpr_(const LetNode* op) final {
-    Array<te::Tensor> val = VisitExpr(op->value);
-    ICHECK(!memo_.count(op->var));
-    memo_[op->var] = val;
-    return VisitExpr(op->body);
-  }
-
-  Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
-    Array<te::Tensor> fields;
-    for (Expr field : op->fields) {
-      ICHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple of Tensor";
-      Array<te::Tensor> res = VisitExpr(field);
-      ICHECK_EQ(res.size(), 1);
-      fields.push_back(res[0]);
-    }
-    return fields;
-  }
-
-  Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
-    const auto* tuple_type = op->tuple->type_as<TupleTypeNode>();
-    Array<te::Tensor> tuple = VisitExpr(op->tuple);
-    ICHECK_EQ(tuple_type->fields.size(), tuple.size());
-    ICHECK_GE(op->index, 0);
-    ICHECK_LT(static_cast<size_t>(op->index), tuple.size());
-    return {tuple[op->index]};
+    readable_name_stream_ << '_' << op->name;
   }
 
  private:
@@ -836,6 +897,12 @@ CompileEngine& CompileEngine::Global() {
 
 TVM_REGISTER_PASS_CONFIG_OPTION("relay.backend.use_auto_scheduler", Bool);
 
+TVM_REGISTER_GLOBAL("relay.backend._TranslateToTE")
+    .set_body_typed([](Function prim_func, Target target) {
+      auto translator = TETranslator(target);
+      return translator.Translate(prim_func);
+    });
+
 TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
     .set_body_typed([](tvm::Array<te::Tensor> outputs, OpImplementation impl) {
       return LoweredOutput(outputs, impl);
diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h
index 55822917b6b7..cfd2f9e9281a 100644
--- a/src/relay/backend/compile_engine.h
+++ b/src/relay/backend/compile_engine.h
@@ -69,6 +69,27 @@ class LoweredOutput : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_METHODS(LoweredOutput, ObjectRef, LoweredOutputNode);
 };
 
+/*! \brief Node container to represent a Tensor Expression graph. */
+struct TEGraphNode : public Object {
+  /* \brief The inputs to the graph */
+  tvm::Array<te::Tensor> inputs;
+  /* \brief The outputs to the graph */
+  tvm::Array<te::Tensor> outputs;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("inputs", &inputs);
+    v->Visit("outputs", &outputs);
+  }
+
+  static constexpr const char* _type_key = "relay.TEGraph";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TEGraphNode, Object);
+};
+
+class TEGraph : public ObjectRef {
+ public:
+  TVM_DEFINE_OBJECT_REF_METHODS(TEGraph, ObjectRef, TEGraphNode);
+};
+
 /*! \brief Node container to represent a cached function. */
 struct CachedFuncNode : public Object {
   /* \brief compiled target */
diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc
index 3212f9079619..cc3dd853acf3 100644
--- a/tests/cpp/relay_build_module_test.cc
+++ b/tests/cpp/relay_build_module_test.cc
@@ -58,7 +58,7 @@ TVM_REGISTER_GLOBAL("test.strategy")
 
 TVM_REGISTER_GLOBAL("relay.backend.lower_call")
     .set_body_typed([](const relay::Call& call, const Array<te::Tensor>& inputs,
-                       const Target& target) {
+                       const Target& target, bool no_trace = false) {
       static auto fstrategy = Op::GetAttrMap<relay::FTVMStrategy>("FTVMStrategy");
       Op op = Downcast<Op>(call->op);
       auto out_type = call->checked_type();
diff --git a/tests/python/relay/test_backend_compile_engine.py b/tests/python/relay/test_backend_compile_engine.py
index bf53dc5360e3..ba3d25e7d6db 100644
--- a/tests/python/relay/test_backend_compile_engine.py
+++ b/tests/python/relay/test_backend_compile_engine.py
@@ -259,6 +259,33 @@ def test_compile_nhwc_pack():
     relay.build(mod, target="llvm")
 
 
+def test_lower_to_te():
+    data = relay.var("data", shape=(1, 1, 1, 1024), dtype="uint8")
+    weight = relay.var("weight", shape=(1, 1, 1024, 1001), dtype="int8")
+    p2 = relay.var("p2", shape=(1, 1, 1, 1), dtype="int32")
+    conv = relay.nn.conv2d(
+        data,
+        weight,
+        kernel_size=(1, 1),
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        out_dtype="int32",
+    )
+    multiply = relay.multiply(relay.const(-22, dtype="int32"), p2)
+    tile = relay.tile(multiply, reps=(1, 1, 1, 1001))
+    subtract = relay.subtract(conv, tile)
+
+    func = subtract
+    expr = relay.Function(relay.analysis.free_vars(func), func)
+    mod = tvm.IRModule.from_expr(expr)
+    mod = relay.transform.InferType()(mod)
+    lowered = relay.backend.compile_engine.translate_to_te(mod["main"], tvm.target.create("llvm"))
+    input_shapes = set()
+    for inp in lowered.inputs:
+        input_shapes.add(tuple([x.value for x in inp.shape]))
+    assert input_shapes == {(1, 1, 1, 1), (1, 1, 1024, 1001), (1, 1, 1, 1024)}
+
+
 if __name__ == "__main__":
     test_get_valid_implementations()
     test_select_implementation()
@@ -268,3 +295,4 @@ def test_compile_nhwc_pack():
     test_compile_tuple_dup()
     test_compile_full()
     test_compile_nhwc_pack()
+    test_lower_to_te()
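Usage note (not part of the patch): a minimal sketch of how the new `translate_to_te` entry point exposed above can be exercised from Python once the patch is applied. The Relay function, shapes, and variable names below are illustrative only; `tvm.target.create` mirrors the test added above.

import tvm
from tvm import relay
from tvm.relay.backend import compile_engine

# Build and type-infer a small Relay function (illustrative only).
x = relay.var("x", shape=(4, 8), dtype="float32")
y = relay.var("y", shape=(4, 8), dtype="float32")
func = relay.Function([x, y], relay.add(x, y))
mod = relay.transform.InferType()(tvm.IRModule.from_expr(func))

# Translate the typed function into a TE graph without selecting a schedule.
te_graph = compile_engine.translate_to_te(mod["main"], tvm.target.create("llvm"))
print([tuple(d.value for d in t.shape) for t in te_graph.inputs])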