From 5629c272605a7b0b54812a67b0fa25895556ebc1 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 14 Aug 2022 04:42:02 +0900 Subject: [PATCH 01/10] [TIR] Support AllocConstantNode in CreatePrimFunc --- .../tvm/meta_schedule/apply_history_best.h | 4 +- include/tvm/meta_schedule/extracted_task.h | 6 +- src/meta_schedule/extracted_task.cc | 27 +++++-- src/printer/tvmscript_printer.cc | 11 +++ src/relay/backend/task_extraction.cc | 16 +++- src/relay/backend/te_compiler_cache.cc | 26 ++++-- src/relay/backend/te_compiler_cache.h | 4 +- src/te/operation/create_primfunc.cc | 9 ++- src/te/operation/create_primfunc.h | 2 + src/tir/schedule/transform.cc | 6 +- src/tir/transforms/bind_params.cc | 79 +++++++++++-------- .../transforms/inject_software_pipeline.cc | 5 +- src/tir/transforms/ir_utils.h | 1 + .../plan_update_buffer_allocation_location.cc | 2 +- 14 files changed, 134 insertions(+), 64 deletions(-) diff --git a/include/tvm/meta_schedule/apply_history_best.h b/include/tvm/meta_schedule/apply_history_best.h index 08c259ea1812..44a34b3ee496 100644 --- a/include/tvm/meta_schedule/apply_history_best.h +++ b/include/tvm/meta_schedule/apply_history_best.h @@ -40,8 +40,8 @@ namespace meta_schedule { class ApplyHistoryBestNode : public runtime::Object { public: /*! \brief A callback function that filters TE compute */ - using FTEFilterFunc = - runtime::TypedPackedFunc(const Array&)>; + using FTEFilterFunc = runtime::TypedPackedFunc( + const Array&, const Array&)>; /*! \brief A callback function that takes a tuning record and does something with it */ using FTakeTuningRecord = runtime::TypedPackedFunc; using FDirectDispatch = runtime::TypedPackedFunc(const IRModule&)>; diff --git a/include/tvm/meta_schedule/extracted_task.h b/include/tvm/meta_schedule/extracted_task.h index bed1428f8303..5260e5a52129 100644 --- a/include/tvm/meta_schedule/extracted_task.h +++ b/include/tvm/meta_schedule/extracted_task.h @@ -81,14 +81,16 @@ class ExtractedTask : public runtime::ObjectRef { * \param args The input/output arguments of the TE compute graph * \return NullOpt if the task is filtered out, otherwise the task in PrimFunc */ -Optional DefaultTaskFilter(const Array& args); +Optional DefaultTaskFilter(const Array& args, + const Array& constants); /*! * \brief The default TE task filter, with `te.extern` allowed * \param args The input/output arguments of the TE compute graph * \return NullOpt if the task is filtered out, otherwise the task in PrimFunc */ -Optional DefaultTaskFilterAllowExtern(const Array& args); +Optional DefaultTaskFilterAllowExtern(const Array& args, + const Array& constants); } // namespace meta_schedule } // namespace tvm diff --git a/src/meta_schedule/extracted_task.cc b/src/meta_schedule/extracted_task.cc index 358f56efab2e..3406f82eb1f0 100644 --- a/src/meta_schedule/extracted_task.cc +++ b/src/meta_schedule/extracted_task.cc @@ -38,7 +38,9 @@ ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target, data_ = n; } -Optional DefaultTaskFilterImpl(const Array& args, bool allow_extern_op) { +Optional DefaultTaskFilterImpl(const Array& args, + const Array& constants, + bool allow_extern_op) { using namespace ::tvm::te; std::vector stack; std::unordered_set visited; @@ -72,7 +74,7 @@ Optional DefaultTaskFilterImpl(const Array& args, boo return NullOpt; } } - PrimFunc func = te::CreatePrimFunc(args); + PrimFunc func = te::CreatePrimFuncWithConstants(args, constants); bool dynamic_loop_extent = false; PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void { if (const auto* loop = obj.as()) { @@ -87,12 +89,14 @@ Optional DefaultTaskFilterImpl(const Array& args, boo return func; } -Optional DefaultTaskFilter(const Array& args) { - return DefaultTaskFilterImpl(args, false); +Optional DefaultTaskFilter(const Array& args, + const Array& constants) { + return DefaultTaskFilterImpl(args, constants, false); } -Optional DefaultTaskFilterAllowExtern(const Array& args) { - return DefaultTaskFilterImpl(args, true); +Optional DefaultTaskFilterAllowExtern(const Array& args, + const Array& constants) { + return DefaultTaskFilterImpl(args, constants, true); } TVM_REGISTER_NODE_TYPE(ExtractedTaskNode); @@ -101,8 +105,15 @@ TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask") int weight) -> ExtractedTask { return ExtractedTask(task_name, mod, target, dispatched, weight); }); -TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilter").set_body_typed(DefaultTaskFilter); + +TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilter") + .set_body_typed([](const Array& args, const Array& constants) { + return DefaultTaskFilter(args, constants); + }); + TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilterAllowExtern") - .set_body_typed(DefaultTaskFilterAllowExtern); + .set_body_typed([](const Array& args, const Array& constants) { + return DefaultTaskFilterAllowExtern(args, constants); + }); } // namespace meta_schedule } // namespace tvm diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 6708922444b6..596655bfaf7b 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -423,6 +423,7 @@ class TVMScriptPrinter : public StmtFunctor, */ template void NDArrayToTIR(::tvm::runtime::NDArray arr, std::ostream& os) { + return; int ndim = arr->ndim; int tot_dim = 1; for (int i = 0; i < ndim; i++) { @@ -1124,6 +1125,16 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { } else { LOG(FATAL) << "DataType not supported"; } + } else if (alloc->dtype.is_uint()) { + if (alloc->dtype.bits() == 8) { + // NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 16) { + NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 32) { + NDArrayToTIR(data, ss); + } else { + LOG(FATAL) << "DataType not supported"; + } } else if (alloc->dtype.is_float()) { if (alloc->dtype.bits() == 16) { NDArrayToTIR(data, ss); diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index c577e8e356d6..3b2a0212473b 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include #include @@ -33,7 +34,7 @@ namespace backend { Array ExtractTask( IRModule mod, Target target, Map params, - runtime::TypedPackedFunc(const Array&)> filter_func) { + meta_schedule::ApplyHistoryBestNode::FTEFilterFunc filter_func) { using meta_schedule::ExtractedTask; if (filter_func == nullptr) { filter_func = tvm::meta_schedule::DefaultTaskFilter; @@ -42,6 +43,14 @@ Array ExtractTask( // is_vm=true for backward compatibility Array pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true); pass_seqs.push_back(transform::FuseOps()); + + if (auto link_params_obj = target->attrs.Get("link-params")) { + ICHECK(link_params_obj.as()); + int link_params = link_params_obj.as()->value; + auto ctx = transform::PassContext::Current(); + ctx->config.Set("relay.FuseOps.link_params", IntImm(DataType::Int(32), link_params)); + } + mod = transform::Sequential(pass_seqs)(std::move(mod)); std::vector tasks; @@ -59,10 +68,11 @@ Array ExtractTask( return; } Array inputs_outputs{nullptr}; + Array constants; std::string fused_name; - std::tie(inputs_outputs, fused_name) = + std::tie(inputs_outputs, constants, fused_name) = tec::LowerTECompute(relay_func, target, /*return_inputs=*/true); - if (Optional prim_func = filter_func(inputs_outputs)) { + if (Optional prim_func = filter_func(inputs_outputs, constants)) { GlobalVar prim_fn_var(fused_name); IRModule relay_mod({{prim_fn_var, relay_func}}); IRModule tir_mod({{prim_fn_var, prim_func.value()}}); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index da52d94b4e46..8c2ce92e0061 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -50,7 +50,6 @@ #include "../../te/operation/create_primfunc.h" #include "../op/memory/memory.h" #include "../transforms/meta_schedule_layout_rewrite.h" -#include "../transforms/pass_utils.h" #include "utils.h" namespace tvm { @@ -362,8 +361,14 @@ class ScheduleBuilder : public ExprVisitor { } if (meta_schedule_ctx_) { Array te_args = Concat(fn_inputs, tensor_outs); + Array constants; + for (auto kv : lower_te_compute.constant_tensors_) { + te_args.push_back(kv.second); + constants.push_back(kv.first->data); + } + if (Optional tir_func = - meta_schedule_ctx_.value()->te_filter_func(te_args)) { + meta_schedule_ctx_.value()->te_filter_func(te_args, constants)) { IRModule relay_mod({{prim_fn_var, relay_func}}); IRModule tir_mod({{prim_fn_var, tir_func.value()}}); if (Optional opt_scheduled_mod = meta_schedule_ctx_.value()->Query( @@ -785,8 +790,8 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, return MakeShapeFunc().Create(prim_func, target, global_var_supply); } -std::pair, std::string> LowerTECompute(const Function& source_func, Target target, - bool return_inputs) { +std::tuple, Array, std::string> LowerTECompute( + const Function& source_func, Target target, bool return_inputs) { LowerToTECompute lower_te_compute(target); Array outputs = lower_te_compute.Lower(source_func); // Following ScheduleBuilder, remove placeholder ops from outputs. @@ -796,11 +801,18 @@ std::pair, std::string> LowerTECompute(const Function& source_ tensor_outs.push_back(tensor); } } + + tvm::Array constants; + for (auto kv : lower_te_compute.constant_tensors_) { + tensor_outs.push_back(kv.second); + constants.push_back(kv.first->data); + } + if (return_inputs) { - return std::make_pair(Concat(lower_te_compute.fn_inputs_, tensor_outs), - lower_te_compute.candidate_name_); + return std::make_tuple(Concat(lower_te_compute.fn_inputs_, tensor_outs), constants, + lower_te_compute.candidate_name_); } - return std::make_pair(tensor_outs, lower_te_compute.candidate_name_); + return std::make_tuple(tensor_outs, constants, lower_te_compute.candidate_name_); } TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) { diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 894a5f5be5f6..075ecae735d8 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -217,8 +217,8 @@ Array GetShape(const Array& shape); * \param return_inputs If true, prepend input tensors to the output array of tensors. * \return Pair of schedule and fused function name. */ -std::pair, std::string> LowerTECompute(const Function& source_func, Target target, - bool return_inputs = true); +std::tuple, Array, std::string> LowerTECompute( + const Function& source_func, Target target, bool return_inputs = true); /*! * \brief Create schedule for target. diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 55df71a8053e..a866769af373 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -27,6 +27,7 @@ #include #include "../../tir/ir/functor_common.h" +#include "../../tir/transforms/ir_utils.h" #include "../schedule/graph.h" namespace tvm { @@ -427,7 +428,7 @@ void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, Arraynum_outputs(), 1); const te::Tensor& tensor = op.output(0); // Check op is in op list - ICHECK(info->IsArg(tensor)); + ICHECK(info->IsArg(tensor)) << tensor; // Declare a buffer for any argument tensors without a pre-existing // buffer declaration recorded in the tensor2buffer binds map if (info->tensor2buffers.count(tensor) == 0) { @@ -492,6 +493,12 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { return GenerateAndCompletePrimFunc(arg_list, root_stmts, &info); } +PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, + const Array& constants) { + PrimFunc func = CreatePrimFunc(arg_list); + return tir::BindParams(func, constants); +} + TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc); } // namespace tir diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index c3cddd83f57a..2cfdde56ea2f 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -30,6 +30,8 @@ namespace tir { /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ PrimFunc CreatePrimFunc(const Array& arg_list); +PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, const Array& constants); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index a739373ab329..0bc0419775ea 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -220,7 +220,11 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ } } if (const auto* block = sref->StmtAs()) { - if (const auto* seq = block->body.as()) { + auto body = block->body; + while (const auto* alloc = body.as()) { + body = alloc->body; + } + if (const auto* seq = body.as()) { ObjectPtr n = make_object(*block); n->body = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); *src_stmt = GetRef(block); diff --git a/src/tir/transforms/bind_params.cc b/src/tir/transforms/bind_params.cc index a5bc519e9a0e..0b71b2e8fa34 100644 --- a/src/tir/transforms/bind_params.cc +++ b/src/tir/transforms/bind_params.cc @@ -84,44 +84,57 @@ class ParamsCollector : public StmtExprVisitor { Map constant_map_; }; -namespace transform { +PrimFunc BindParams(PrimFunc f, const Array& constants) { + Map constant_map; -Pass BindParams(const Array& constants) { - auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - Map constant_map; - - // Remove constants from the primfunc signature - size_t num_constants = constants.size(); - size_t start = f->params.size() - num_constants; - Array params; - for (unsigned i = 0; i < start; i++) { - params.push_back(f->params[i]); - } + // Remove constants from the primfunc signature + size_t num_constants = constants.size(); + size_t start = f->params.size() - num_constants; + Array params; + for (unsigned i = 0; i < start; i++) { + params.push_back(f->params[i]); + } - auto* n = f.CopyOnWrite(); - for (unsigned i = start; i < f->params.size(); i++) { - tir::Var p = n->params[i]; - tir::Var b = n->buffer_map[p]->data; - n->buffer_map.erase(p); - constant_map.Set(b, constants[i - start]); + auto* n = f.CopyOnWrite(); + for (unsigned i = start; i < f->params.size(); i++) { + tir::Var p = n->params[i]; + tir::Var b = n->buffer_map[p]->data; + n->buffer_map.erase(p); + constant_map.Set(b, constants[i - start]); + } + n->params = params; + auto constant_list = ParamsCollector(constant_map).CollectParams(n->body); + + // Allocate constants within the primfunc + for (auto i : constant_list) { + auto var = GetRef(i); + int ndim = constant_map[var]->ndim; + Array extents; + + for (int i = 0; i < ndim; i++) { + int shape = constant_map[var]->shape[i]; + extents.push_back(make_const(DataType::Int(32), shape)); } - n->params = params; - auto constant_list = ParamsCollector(constant_map).CollectParams(n->body); - - // Allocate constants within the primfunc - for (auto i : constant_list) { - auto var = GetRef(i); - int ndim = constant_map[var]->ndim; - Array extents; - - for (int i = 0; i < ndim; i++) { - int shape = constant_map[var]->shape[i]; - extents.push_back(make_const(DataType::Int(32), shape)); - } - DataType dtype = DataType(constant_map[var]->dtype); + DataType dtype = DataType(constant_map[var]->dtype); + + if (n->body->IsInstance()) { + auto* block_realize = n->body.as(); + auto block = block_realize->block; + block.CopyOnWrite()->body = + tir::AllocateConst(var, dtype, extents, constant_map[var], block->body); + n->body = BlockRealize(block_realize->iter_values, block_realize->predicate, block); + } else { n->body = tir::AllocateConst(var, dtype, extents, constant_map[var], n->body); } - return f; + } + return f; +} + +namespace transform { + +Pass BindParams(const Array& constants) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return BindParams(f, constants); }; return CreatePrimFuncPass(pass_func, 0, "tir.BindParams", {}); } diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 227935bf72dd..40a03ef5b75e 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -709,10 +709,7 @@ class PipelineRewriter : public StmtExprMutator { arith::Analyzer* ana_normalized) const { std::vector new_blocks = blocks; std::vector commit_group_indices(new_blocks.size(), -1); - for (const auto& kv : async_states_local) { - const int stage_id = kv.first; - const AsyncStateLocal& state = kv.second; - + for (const auto& [stage_id, state]: async_states_local) { if (!state.commit_groups.empty()) { for (size_t i = 0; i < state.commit_groups.size(); ++i) { for (size_t j = 0; j < state.commit_groups[i].size(); ++j) { diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index d89ee3619699..6dc09a170b06 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -311,6 +311,7 @@ std::unordered_map GetTensorCoreFragmentInfo(const // attr::async_wait_queue_scope annotation. std::pair GetAsyncWaitAttributes(const AttrStmtNode* op); +PrimFunc BindParams(PrimFunc f, const Array& constants); } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 81dfceb40d32..3224aa781619 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -112,7 +112,7 @@ class BufferAllocationLocator : public StmtExprMutator { } ObjectPtr n = CopyOnWrite(op); - n->alloc_buffers = std::move(alloc_buffers); + // n->alloc_buffers = std::move(alloc_buffers); // Erase buffer allocated inside the block from access region. n->reads = RemoveRedundantBufferRegion(n->reads); n->writes = RemoveRedundantBufferRegion(n->writes); From 8d2f700b6de5e6a510006251acf373875238281b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 14 Aug 2022 05:37:03 +0900 Subject: [PATCH 02/10] Handle AllocConstantNode in LeafBlockRemovalPlan --- src/tir/schedule/transform.cc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 0bc0419775ea..36a6fa092235 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -221,12 +221,21 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ } if (const auto* block = sref->StmtAs()) { auto body = block->body; + std::vector allocs; while (const auto* alloc = body.as()) { + allocs.push_back(alloc); body = alloc->body; } if (const auto* seq = body.as()) { ObjectPtr n = make_object(*block); - n->body = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); + auto new_seq = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); + auto new_body = new_seq; + for (int i = 0; i < allocs.size(); ++i) { + auto alloc = allocs[allocs.size() - 1 - i]; + new_body = AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data, + new_body, alloc->annotations, alloc->span); + } + n->body = new_body; *src_stmt = GetRef(block); *tgt_stmt = Stmt(std::move(n)); return; From b8bdb5fb2a24c9d2e53799dfea92e37f87a7825e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 15 Aug 2022 15:49:27 +0900 Subject: [PATCH 03/10] Properly handle AllocConstNode in BufferAllocationLocator --- .../plan_update_buffer_allocation_location.cc | 38 +++++++++++++------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 3224aa781619..db59824bf1ce 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -31,10 +31,31 @@ namespace tvm { namespace tir { +class CollectUnmanagedAllocations : public StmtExprVisitor { + public: + void VisitStmt_(const AllocateNode* op) final { + unmanaged_allocations.insert(op->buffer_var.get()); + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AllocateConstNode* op) final { + unmanaged_allocations.insert(op->buffer_var.get()); + StmtExprVisitor::VisitStmt_(op); + } + + /*! \brief Buffers that are allocated outside of the BlockNode, and should not be moved by + * BufferAllocationLocator. */ + std::unordered_set unmanaged_allocations; +}; + class BufferAllocationLocator : public StmtExprMutator { public: explicit BufferAllocationLocator(const PrimFunc& func) { Map> buffer_lca = DetectBufferAccessLCA(func); + CollectUnmanagedAllocations collector; + collector(func->body); + unmanaged_allocations_ = collector.unmanaged_allocations; + std::unordered_set arg_buffers; for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; @@ -48,7 +69,10 @@ class BufferAllocationLocator : public StmtExprMutator { if (arg_buffers.count(buffer.get())) { continue; } - alloc_buffers_[stmt].push_back(buffer); + if (!unmanaged_allocations_.count(buffer->data.get())) { + alloc_buffers_[stmt].push_back(buffer); + } + buffer_data_to_buffer_.Set(buffer->data, buffer); } } @@ -112,23 +136,13 @@ class BufferAllocationLocator : public StmtExprMutator { } ObjectPtr n = CopyOnWrite(op); - // n->alloc_buffers = std::move(alloc_buffers); + n->alloc_buffers = std::move(alloc_buffers); // Erase buffer allocated inside the block from access region. n->reads = RemoveRedundantBufferRegion(n->reads); n->writes = RemoveRedundantBufferRegion(n->writes); return Stmt(n); } - Stmt VisitStmt_(const AllocateNode* op) final { - unmanaged_allocations_.insert(op->buffer_var.get()); - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const AllocateConstNode* op) final { - unmanaged_allocations_.insert(op->buffer_var.get()); - return StmtExprMutator::VisitStmt_(op); - } - Stmt VisitStmt_(const BufferRealizeNode* op) final { ICHECK(false) << "Internal Error: BufferRealizeNode is not allowed in TensorIR."; throw; From b2b96c882b2e17832c99954807363962c836b093 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 15 Aug 2022 04:16:00 -0700 Subject: [PATCH 04/10] handle AllocateConst in EstimateFlops --- src/tir/analysis/estimate_flops.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index 44a58f8792c4..576476ae30aa 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -169,6 +169,7 @@ class FlopEstimator : private ExprFunctor, TResult VisitExpr_(const IntImmNode* op) override { return TResult(); } TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); } TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); } + TResult VisitStmt_(const AllocateConstNode* op) override { return VisitStmt(op->body); } TResult VisitStmt_(const SeqStmtNode* seq) override { TResult result; From 1c6893ad0b86e1c6f8c3f8db5c78725085256263 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 18 Aug 2022 18:34:30 +0900 Subject: [PATCH 05/10] remove NDArray printing --- src/printer/tvmscript_printer.cc | 57 +------------------ src/te/operation/create_primfunc.cc | 2 +- src/te/operation/create_primfunc.h | 3 +- .../transforms/inject_software_pipeline.cc | 2 +- 4 files changed, 5 insertions(+), 59 deletions(-) diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 596655bfaf7b..303ad7032cc9 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -416,27 +416,6 @@ class TVMScriptPrinter : public StmtFunctor, } }; -/*! - * \brief special method to print NDArray in TIR - * \param arr the NDArray to be printed - * \param os the output stream where the NDArray will be printed to - */ -template -void NDArrayToTIR(::tvm::runtime::NDArray arr, std::ostream& os) { - return; - int ndim = arr->ndim; - int tot_dim = 1; - for (int i = 0; i < ndim; i++) { - tot_dim *= arr->shape[i]; - } - T* data_ptr = reinterpret_cast(arr->data); - os << "["; - for (int i = 0; i < tot_dim; i++) { - os << (i != 0 ? ", " : "") << data_ptr[i]; - } - os << "]"; -} - Doc TVMScriptPrinter::GetUniqueName(std::string prefix) { std::replace(prefix.begin(), prefix.end(), '.', '_'); std::string unique_prefix = prefix; @@ -1113,41 +1092,7 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { std::stringstream ss; ICHECK(alloc->data) << "Should be presented"; - const auto& data = alloc->data.value(); - - if (alloc->dtype.is_int()) { - if (alloc->dtype.bits() == 8) { - NDArrayToTIR(data, ss); - } else if (alloc->dtype.bits() == 16) { - NDArrayToTIR(data, ss); - } else if (alloc->dtype.bits() == 32) { - NDArrayToTIR(data, ss); - } else { - LOG(FATAL) << "DataType not supported"; - } - } else if (alloc->dtype.is_uint()) { - if (alloc->dtype.bits() == 8) { - // NDArrayToTIR(data, ss); - } else if (alloc->dtype.bits() == 16) { - NDArrayToTIR(data, ss); - } else if (alloc->dtype.bits() == 32) { - NDArrayToTIR(data, ss); - } else { - LOG(FATAL) << "DataType not supported"; - } - } else if (alloc->dtype.is_float()) { - if (alloc->dtype.bits() == 16) { - NDArrayToTIR(data, ss); - } else if (alloc->dtype.bits() == 32) { - NDArrayToTIR(data, ss); - } else if (alloc->dtype.bits() == 64) { - NDArrayToTIR(data, ss); - } else { - LOG(FATAL) << "DataType not supported"; - } - } else { - LOG(FATAL) << "DataType not supported"; - } + ss << "..."; auto ndarray_str = ss.str(); auto usage = FindAllocateUsage(alloc, &buffer_var_usage_); diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index a866769af373..0d1a8318e823 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -428,7 +428,7 @@ void RewriteStageToBlock(const te::Operation& op, CreateFuncInfo* info, Arraynum_outputs(), 1); const te::Tensor& tensor = op.output(0); // Check op is in op list - ICHECK(info->IsArg(tensor)) << tensor; + ICHECK(info->IsArg(tensor)); // Declare a buffer for any argument tensors without a pre-existing // buffer declaration recorded in the tensor2buffer binds map if (info->tensor2buffers.count(tensor) == 0) { diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index 2cfdde56ea2f..62ca2627ecb8 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -30,7 +30,8 @@ namespace tir { /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ PrimFunc CreatePrimFunc(const Array& arg_list); -PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, const Array& constants); +PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, + const Array& constants); } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 40a03ef5b75e..ab1489d5ad4a 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -709,7 +709,7 @@ class PipelineRewriter : public StmtExprMutator { arith::Analyzer* ana_normalized) const { std::vector new_blocks = blocks; std::vector commit_group_indices(new_blocks.size(), -1); - for (const auto& [stage_id, state]: async_states_local) { + for (const auto& [stage_id, state] : async_states_local) { if (!state.commit_groups.empty()) { for (size_t i = 0; i < state.commit_groups.size(); ++i) { for (size_t j = 0; j < state.commit_groups[i].size(); ++j) { From 508567d212e3121444366c5a61175842dfe5d0e4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 18 Aug 2022 19:58:32 +0900 Subject: [PATCH 06/10] doc update --- include/tvm/meta_schedule/extracted_task.h | 4 ++++ python/tvm/meta_schedule/apply_history_best.py | 2 +- python/tvm/meta_schedule/relay_integration.py | 2 +- python/tvm/meta_schedule/testing/utils.py | 2 +- src/relay/backend/task_extraction.cc | 5 +---- src/relay/backend/te_compiler_cache.cc | 12 ++++++------ src/relay/backend/te_compiler_cache.h | 2 +- src/te/operation/create_primfunc.cc | 2 ++ src/te/operation/create_primfunc.h | 5 +++++ src/tir/schedule/transform.cc | 4 +++- src/tir/transforms/ir_utils.h | 9 +++++++++ 11 files changed, 34 insertions(+), 15 deletions(-) diff --git a/include/tvm/meta_schedule/extracted_task.h b/include/tvm/meta_schedule/extracted_task.h index 5260e5a52129..bce40e6b95f0 100644 --- a/include/tvm/meta_schedule/extracted_task.h +++ b/include/tvm/meta_schedule/extracted_task.h @@ -79,6 +79,8 @@ class ExtractedTask : public runtime::ObjectRef { /*! * \brief The default TE task filter * \param args The input/output arguments of the TE compute graph + * \param constants Raw data for constant tensors in args. If the size of this array is N, the last + * N tensors in args will be treated as constant tensors. * \return NullOpt if the task is filtered out, otherwise the task in PrimFunc */ Optional DefaultTaskFilter(const Array& args, @@ -87,6 +89,8 @@ Optional DefaultTaskFilter(const Array DefaultTaskFilterAllowExtern(const Array& args, diff --git a/python/tvm/meta_schedule/apply_history_best.py b/python/tvm/meta_schedule/apply_history_best.py index 43a6ffe37620..a7b9b20bf244 100644 --- a/python/tvm/meta_schedule/apply_history_best.py +++ b/python/tvm/meta_schedule/apply_history_best.py @@ -40,7 +40,7 @@ class ApplyHistoryBest(Object): ---------- database : Database The database to be queried from - te_filter_func : Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None + te_filter_func : Union[str, None, Callable[[List[Tensor], List[NDArray]], PrimFunc]] = None The filtering function for TE computation If it's a string, it's the name of the filtering function. Built in functions are - "meta_schedule.DefaultTaskFilter" diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index bd12ac350a61..d3b3ea796532 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -56,7 +56,7 @@ def extract_task_from_relay( The pass config of the compiler disabled_pass : Optional[List[str]] The list of disabled passes of the compiler - te_filter_func : Callable[[List[tvm.te.Tensor]], bool] + te_filter_func : Callable[[List[tvm.te.Tensor], List[NDArray]], bool] The filter function to filter out the extracted tasks If it's a string, it's the name of the filtering function. Built in functions are - "meta_schedule.DefaultTaskFilter" diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py index 0d011b726473..8fd3211f09c5 100644 --- a/python/tvm/meta_schedule/testing/utils.py +++ b/python/tvm/meta_schedule/testing/utils.py @@ -45,7 +45,7 @@ def apply_fixed_schedules( schedule_fn : Callable[[ExtractedTask, Schedule], bool] A callable that is applied for each extracted task and the corresponding default schedule. Returns True if the given schedule should be committed to the database, False otherwise. - te_filter_func : Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None + te_filter_func : Union[str, None, Callable[[List[Tensor], List[NDArray]], PrimFunc]] = None The filtering function for TE computation If it's a string, it's the name of the filtering function. Built in functions are - "meta_schedule.DefaultTaskFilter" diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index 3b2a0212473b..e4b58c3a896f 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -67,10 +67,7 @@ Array ExtractTask( it->second->weight += 1; return; } - Array inputs_outputs{nullptr}; - Array constants; - std::string fused_name; - std::tie(inputs_outputs, constants, fused_name) = + auto [inputs_outputs, constants, fused_name] = tec::LowerTECompute(relay_func, target, /*return_inputs=*/true); if (Optional prim_func = filter_func(inputs_outputs, constants)) { GlobalVar prim_fn_var(fused_name); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index 8c2ce92e0061..92cc6f8cfa46 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -362,9 +362,9 @@ class ScheduleBuilder : public ExprVisitor { if (meta_schedule_ctx_) { Array te_args = Concat(fn_inputs, tensor_outs); Array constants; - for (auto kv : lower_te_compute.constant_tensors_) { - te_args.push_back(kv.second); - constants.push_back(kv.first->data); + for (auto [const_node, te_tensor] : lower_te_compute.constant_tensors_) { + te_args.push_back(te_tensor); + constants.push_back(const_node->data); } if (Optional tir_func = @@ -803,9 +803,9 @@ std::tuple, Array, std::string> LowerTECompu } tvm::Array constants; - for (auto kv : lower_te_compute.constant_tensors_) { - tensor_outs.push_back(kv.second); - constants.push_back(kv.first->data); + for (auto [const_node, te_tensor] : lower_te_compute.constant_tensors_) { + tensor_outs.push_back(te_tensor); + constants.push_back(const_node->data); } if (return_inputs) { diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 075ecae735d8..57e813054c4d 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -215,7 +215,7 @@ Array GetShape(const Array& shape); * \param source_func The primitive function to be lowered. * \param target The target we want to create schedule for. * \param return_inputs If true, prepend input tensors to the output array of tensors. - * \return Pair of schedule and fused function name. + * \return Tuple of the lowered TE compute, constant raw data, and fused function name. */ std::tuple, Array, std::string> LowerTECompute( const Function& source_func, Target target, bool return_inputs = true); diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 0d1a8318e823..a89a28564c0f 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -17,6 +17,8 @@ * under the License. */ +#include "create_primfunc.h" + #include #include #include diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index 62ca2627ecb8..b68d30a2fb82 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -30,6 +30,11 @@ namespace tir { /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ PrimFunc CreatePrimFunc(const Array& arg_list); +/*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the + * constants array is N, the last N tensors in arg_list will be treated as constant tensors. + * Constant tensors will not be part of the parameters of the created PrimFunc, instead constants + * will be embedded in the body as AllocateConstNode. + */ PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, const Array& constants); diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index 36a6fa092235..1c21d770db30 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -221,6 +221,7 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ } if (const auto* block = sref->StmtAs()) { auto body = block->body; + // Peel off AllocateConst nodes at the beginning of the block body. std::vector allocs; while (const auto* alloc = body.as()) { allocs.push_back(alloc); @@ -229,8 +230,9 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ if (const auto* seq = body.as()) { ObjectPtr n = make_object(*block); auto new_seq = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); + // Re-attach AllocateConst nodes auto new_body = new_seq; - for (int i = 0; i < allocs.size(); ++i) { + for (int i = 0; i < static_cast(allocs.size()); ++i) { auto alloc = allocs[allocs.size() - 1 - i]; new_body = AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data, new_body, alloc->annotations, alloc->span); diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 6dc09a170b06..a54eebe4ed05 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -311,7 +311,16 @@ std::unordered_map GetTensorCoreFragmentInfo(const // attr::async_wait_queue_scope annotation. std::pair GetAsyncWaitAttributes(const AttrStmtNode* op); +/*! + * \brief Bind a subset of parameter tensors to constants, replacing them by AllocateConst nodes. + * \param f The function to bind constants to. + * \param constants Raw constant data. If the size of this array is N, the last N parameter tensors + * will be removed from the signature and instead AllocateConst nodes will be introduced in the + * function body. + * \return The updated function. + */ PrimFunc BindParams(PrimFunc f, const Array& constants); + } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ From 1af70e033eb2e5e54b8515963d2bd381f6f45a8f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 18 Aug 2022 20:36:13 +0900 Subject: [PATCH 07/10] add test --- tests/python/unittest/test_link_params.py | 60 ++++++++++++++++++++++- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py index 80c2fbaeb416..beb8cf04bed9 100644 --- a/tests/python/unittest/test_link_params.py +++ b/tests/python/unittest/test_link_params.py @@ -19,16 +19,19 @@ import json import os import re -import sys +from io import StringIO +from contextlib import redirect_stderr import numpy as np -import pytest import tvm import tvm.relay import tvm.testing +from tvm import meta_schedule as ms +from tvm import relay from tvm.relay.backend import Executor, Runtime from tvm.contrib import utils +from tvm.meta_schedule.testing.utils import apply_fixed_schedules INPUT_SHAPE = (1, 3, 16, 16) @@ -382,5 +385,58 @@ def _run_unlinked(lib): np.testing.assert_allclose(unlinked_output, linked_output) +def test_tir_link_params(): + def get_dense(data_shape, weight_shape): + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + dense = relay.nn.dense(data, weight) + return relay.Function([data, weight], dense) + + def get_ref_dense(data_np, weight_np): + return np.dot(data_np, np.transpose(weight_np)) + + def schedule_dense(sch): + dense = sch.get_block("T_matmul_NT") + _y, _x, _k = sch.get_loops(dense) + + M, N, K = 128, 128, 128 + data_shape = (M, K) + weight_shape = (N, K) + relay_mod = tvm.IRModule.from_expr(get_dense(data_shape, weight_shape)) + relay_mod = relay.transform.InferType()(relay_mod) + data_np = np.random.randn(*data_shape).astype("float32") + weight_np = np.random.randn(*weight_shape).astype("float32") + target = "llvm --link-params=1" + params = {"weight": weight_np} + + def schedule_fn(task, sch): + if "nn_dense" in task.task_name: + schedule_dense(sch) + return True + return False + + database = apply_fixed_schedules(relay_mod, target, params, schedule_fn) + + with StringIO() as stderr_buf, redirect_stderr(stderr_buf): + with ms.ApplyHistoryBest(database): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + lib = relay.build(relay_mod, target=target) + + # Workload look up should succeed + assert not "Cannot find workload" in stderr_buf.getvalue() + + dev = tvm.device(target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + runtime.set_input(**params) + runtime.set_input("data", data_np) + runtime.run() + out = runtime.get_output(0).numpy() + ref = get_ref_dense(data_np, weight_np) + tvm.testing.assert_allclose(out, ref, atol=1e-4, rtol=1e-4) + + if __name__ == "__main__": tvm.testing.main() From 7d39cae708fe5bfa02f2fb82e5f2b97f8690062a Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 18 Aug 2022 21:32:39 +0900 Subject: [PATCH 08/10] cpplint --- src/relay/backend/te_compiler_cache.h | 1 + src/te/operation/create_primfunc.cc | 3 +++ 2 files changed, 4 insertions(+) diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 57e813054c4d..95c5bc974181 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -37,6 +37,7 @@ #include #include +#include #include #include diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index a89a28564c0f..4c1358f42519 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -26,7 +26,10 @@ #include #include +#include +#include #include +#include #include "../../tir/ir/functor_common.h" #include "../../tir/transforms/ir_utils.h" From 487aed93bda8383a4a51a5fd45f62e36ba419e78 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 19 Aug 2022 10:28:25 +0900 Subject: [PATCH 09/10] Removed dependency on link-params attribute from target --- python/tvm/meta_schedule/testing/utils.py | 7 ++++++- src/relay/backend/task_extraction.cc | 7 ------- tests/python/unittest/test_link_params.py | 12 ++++++++---- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py index 8fd3211f09c5..7698d50991ed 100644 --- a/python/tvm/meta_schedule/testing/utils.py +++ b/python/tvm/meta_schedule/testing/utils.py @@ -16,7 +16,7 @@ # under the License. """Testing utility functions in meta schedule""" from typing import Callable, Dict, Optional, Union -from tvm.ir import IRModule +from tvm.ir import IRModule, transform from tvm.relay import Function as RelayFunc from tvm.runtime import NDArray from tvm.target import Target @@ -59,11 +59,16 @@ def apply_fixed_schedules( The database containing dummy tuning records for manually scheduled traces. """ target = Target(target) if isinstance(target, str) else target + config = {"relay.backend.use_meta_schedule": True} + for k, v in transform.PassContext.current().config.items(): + config[k] = v + extracted_tasks = ms.extract_task_from_relay( relay_mod, target, params, te_filter_func=te_filter_func, + pass_config=config ) database = ms.database.MemoryDatabase() for task in extracted_tasks: diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index e4b58c3a896f..4f83b6eeed60 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -44,13 +44,6 @@ Array ExtractTask( Array pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true); pass_seqs.push_back(transform::FuseOps()); - if (auto link_params_obj = target->attrs.Get("link-params")) { - ICHECK(link_params_obj.as()); - int link_params = link_params_obj.as()->value; - auto ctx = transform::PassContext::Current(); - ctx->config.Set("relay.FuseOps.link_params", IntImm(DataType::Int(32), link_params)); - } - mod = transform::Sequential(pass_seqs)(std::move(mod)); std::vector tasks; diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py index beb8cf04bed9..394e95f38f71 100644 --- a/tests/python/unittest/test_link_params.py +++ b/tests/python/unittest/test_link_params.py @@ -406,7 +406,7 @@ def schedule_dense(sch): relay_mod = relay.transform.InferType()(relay_mod) data_np = np.random.randn(*data_shape).astype("float32") weight_np = np.random.randn(*weight_shape).astype("float32") - target = "llvm --link-params=1" + target = "llvm" params = {"weight": weight_np} def schedule_fn(task, sch): @@ -415,7 +415,10 @@ def schedule_fn(task, sch): return True return False - database = apply_fixed_schedules(relay_mod, target, params, schedule_fn) + link_params = True + + with tvm.transform.PassContext(config={"relay.FuseOps.link_params": link_params}): + database = apply_fixed_schedules(relay_mod, target, params, schedule_fn) with StringIO() as stderr_buf, redirect_stderr(stderr_buf): with ms.ApplyHistoryBest(database): @@ -423,9 +426,10 @@ def schedule_fn(task, sch): opt_level=3, config={"relay.backend.use_meta_schedule": True}, ): - lib = relay.build(relay_mod, target=target) + executor = Executor("graph", {"link-params":link_params}) + lib = relay.build(relay_mod, target=target, executor=executor) - # Workload look up should succeed + # Workload look up should succeed. This does not work when the test is invoked from pytest. assert not "Cannot find workload" in stderr_buf.getvalue() dev = tvm.device(target, 0) From 4af2e1f7fec1d7df485cf156906cce4e07c65555 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 19 Aug 2022 10:38:56 +0900 Subject: [PATCH 10/10] Restored NDArray printing to unbreak test --- python/tvm/meta_schedule/testing/utils.py | 6 +-- src/printer/tvmscript_printer.cc | 65 ++++++++++++++++++++++- tests/python/unittest/test_link_params.py | 2 +- 3 files changed, 66 insertions(+), 7 deletions(-) diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py index 7698d50991ed..dda492008ffe 100644 --- a/python/tvm/meta_schedule/testing/utils.py +++ b/python/tvm/meta_schedule/testing/utils.py @@ -64,11 +64,7 @@ def apply_fixed_schedules( config[k] = v extracted_tasks = ms.extract_task_from_relay( - relay_mod, - target, - params, - te_filter_func=te_filter_func, - pass_config=config + relay_mod, target, params, te_filter_func=te_filter_func, pass_config=config ) database = ms.database.MemoryDatabase() for task in extracted_tasks: diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 303ad7032cc9..7649b6101919 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -416,6 +416,31 @@ class TVMScriptPrinter : public StmtFunctor, } }; +/*! + * \brief special method to print NDArray in TIR + * \param arr the NDArray to be printed + * \param os the output stream where the NDArray will be printed to + */ +template +void NDArrayToTIR(::tvm::runtime::NDArray arr, std::ostream& os) { + int ndim = arr->ndim; + int tot_dim = 1; + for (int i = 0; i < ndim; i++) { + tot_dim *= arr->shape[i]; + } + T* data_ptr = reinterpret_cast(arr->data); + constexpr int NUM_PRINT = 20; + os << "["; + for (int i = 0; i < tot_dim; i++) { + os << (i != 0 ? ", " : "") << data_ptr[i]; + if (i == NUM_PRINT) { + os << "..."; + break; + } + } + os << "]"; +} + Doc TVMScriptPrinter::GetUniqueName(std::string prefix) { std::replace(prefix.begin(), prefix.end(), '.', '_'); std::string unique_prefix = prefix; @@ -1092,7 +1117,45 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateNode* op) { Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { std::stringstream ss; ICHECK(alloc->data) << "Should be presented"; - ss << "..."; + const auto& data = alloc->data.value(); + + if (alloc->dtype.is_int()) { + if (alloc->dtype.bits() == 8) { + NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 16) { + NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 32) { + NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 64) { + NDArrayToTIR(data, ss); + } else { + LOG(FATAL) << "DataType not supported"; + } + } else if (alloc->dtype.is_uint()) { + if (alloc->dtype.bits() == 8) { + // NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 16) { + NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 32) { + NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 64) { + NDArrayToTIR(data, ss); + } else { + LOG(FATAL) << "DataType not supported"; + } + } else if (alloc->dtype.is_float()) { + if (alloc->dtype.bits() == 16) { + NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 32) { + NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 64) { + NDArrayToTIR(data, ss); + } else { + LOG(FATAL) << "DataType not supported"; + } + } else { + LOG(FATAL) << "DataType not supported"; + } auto ndarray_str = ss.str(); auto usage = FindAllocateUsage(alloc, &buffer_var_usage_); diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py index 394e95f38f71..8e299dc935d5 100644 --- a/tests/python/unittest/test_link_params.py +++ b/tests/python/unittest/test_link_params.py @@ -426,7 +426,7 @@ def schedule_fn(task, sch): opt_level=3, config={"relay.backend.use_meta_schedule": True}, ): - executor = Executor("graph", {"link-params":link_params}) + executor = Executor("graph", {"link-params": link_params}) lib = relay.build(relay_mod, target=target, executor=executor) # Workload look up should succeed. This does not work when the test is invoked from pytest.