diff --git a/include/tvm/meta_schedule/apply_history_best.h b/include/tvm/meta_schedule/apply_history_best.h index 08c259ea1812..44a34b3ee496 100644 --- a/include/tvm/meta_schedule/apply_history_best.h +++ b/include/tvm/meta_schedule/apply_history_best.h @@ -40,8 +40,8 @@ namespace meta_schedule { class ApplyHistoryBestNode : public runtime::Object { public: /*! \brief A callback function that filters TE compute */ - using FTEFilterFunc = - runtime::TypedPackedFunc(const Array&)>; + using FTEFilterFunc = runtime::TypedPackedFunc( + const Array&, const Array&)>; /*! \brief A callback function that takes a tuning record and does something with it */ using FTakeTuningRecord = runtime::TypedPackedFunc; using FDirectDispatch = runtime::TypedPackedFunc(const IRModule&)>; diff --git a/include/tvm/meta_schedule/extracted_task.h b/include/tvm/meta_schedule/extracted_task.h index bed1428f8303..bce40e6b95f0 100644 --- a/include/tvm/meta_schedule/extracted_task.h +++ b/include/tvm/meta_schedule/extracted_task.h @@ -79,16 +79,22 @@ class ExtractedTask : public runtime::ObjectRef { /*! * \brief The default TE task filter * \param args The input/output arguments of the TE compute graph + * \param constants Raw data for constant tensors in args. If the size of this array is N, the last + * N tensors in args will be treated as constant tensors. * \return NullOpt if the task is filtered out, otherwise the task in PrimFunc */ -Optional DefaultTaskFilter(const Array& args); +Optional DefaultTaskFilter(const Array& args, + const Array& constants); /*! * \brief The default TE task filter, with `te.extern` allowed * \param args The input/output arguments of the TE compute graph + * \param constants Raw data for constant tensors in args. If the size of this array is N, the last + * N tensors in args will be treated as constant tensors. * \return NullOpt if the task is filtered out, otherwise the task in PrimFunc */ -Optional DefaultTaskFilterAllowExtern(const Array& args); +Optional DefaultTaskFilterAllowExtern(const Array& args, + const Array& constants); } // namespace meta_schedule } // namespace tvm diff --git a/python/tvm/meta_schedule/apply_history_best.py b/python/tvm/meta_schedule/apply_history_best.py index 43a6ffe37620..a7b9b20bf244 100644 --- a/python/tvm/meta_schedule/apply_history_best.py +++ b/python/tvm/meta_schedule/apply_history_best.py @@ -40,7 +40,7 @@ class ApplyHistoryBest(Object): ---------- database : Database The database to be queried from - te_filter_func : Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None + te_filter_func : Union[str, None, Callable[[List[Tensor], List[NDArray]], PrimFunc]] = None The filtering function for TE computation If it's a string, it's the name of the filtering function. Built in functions are - "meta_schedule.DefaultTaskFilter" diff --git a/python/tvm/meta_schedule/relay_integration.py b/python/tvm/meta_schedule/relay_integration.py index bd12ac350a61..d3b3ea796532 100644 --- a/python/tvm/meta_schedule/relay_integration.py +++ b/python/tvm/meta_schedule/relay_integration.py @@ -56,7 +56,7 @@ def extract_task_from_relay( The pass config of the compiler disabled_pass : Optional[List[str]] The list of disabled passes of the compiler - te_filter_func : Callable[[List[tvm.te.Tensor]], bool] + te_filter_func : Callable[[List[tvm.te.Tensor], List[NDArray]], bool] The filter function to filter out the extracted tasks If it's a string, it's the name of the filtering function. Built in functions are - "meta_schedule.DefaultTaskFilter" diff --git a/python/tvm/meta_schedule/testing/utils.py b/python/tvm/meta_schedule/testing/utils.py index 0d011b726473..dda492008ffe 100644 --- a/python/tvm/meta_schedule/testing/utils.py +++ b/python/tvm/meta_schedule/testing/utils.py @@ -16,7 +16,7 @@ # under the License. """Testing utility functions in meta schedule""" from typing import Callable, Dict, Optional, Union -from tvm.ir import IRModule +from tvm.ir import IRModule, transform from tvm.relay import Function as RelayFunc from tvm.runtime import NDArray from tvm.target import Target @@ -45,7 +45,7 @@ def apply_fixed_schedules( schedule_fn : Callable[[ExtractedTask, Schedule], bool] A callable that is applied for each extracted task and the corresponding default schedule. Returns True if the given schedule should be committed to the database, False otherwise. - te_filter_func : Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None + te_filter_func : Union[str, None, Callable[[List[Tensor], List[NDArray]], PrimFunc]] = None The filtering function for TE computation If it's a string, it's the name of the filtering function. Built in functions are - "meta_schedule.DefaultTaskFilter" @@ -59,11 +59,12 @@ def apply_fixed_schedules( The database containing dummy tuning records for manually scheduled traces. """ target = Target(target) if isinstance(target, str) else target + config = {"relay.backend.use_meta_schedule": True} + for k, v in transform.PassContext.current().config.items(): + config[k] = v + extracted_tasks = ms.extract_task_from_relay( - relay_mod, - target, - params, - te_filter_func=te_filter_func, + relay_mod, target, params, te_filter_func=te_filter_func, pass_config=config ) database = ms.database.MemoryDatabase() for task in extracted_tasks: diff --git a/src/meta_schedule/extracted_task.cc b/src/meta_schedule/extracted_task.cc index 358f56efab2e..3406f82eb1f0 100644 --- a/src/meta_schedule/extracted_task.cc +++ b/src/meta_schedule/extracted_task.cc @@ -38,7 +38,9 @@ ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target, data_ = n; } -Optional DefaultTaskFilterImpl(const Array& args, bool allow_extern_op) { +Optional DefaultTaskFilterImpl(const Array& args, + const Array& constants, + bool allow_extern_op) { using namespace ::tvm::te; std::vector stack; std::unordered_set visited; @@ -72,7 +74,7 @@ Optional DefaultTaskFilterImpl(const Array& args, boo return NullOpt; } } - PrimFunc func = te::CreatePrimFunc(args); + PrimFunc func = te::CreatePrimFuncWithConstants(args, constants); bool dynamic_loop_extent = false; PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void { if (const auto* loop = obj.as()) { @@ -87,12 +89,14 @@ Optional DefaultTaskFilterImpl(const Array& args, boo return func; } -Optional DefaultTaskFilter(const Array& args) { - return DefaultTaskFilterImpl(args, false); +Optional DefaultTaskFilter(const Array& args, + const Array& constants) { + return DefaultTaskFilterImpl(args, constants, false); } -Optional DefaultTaskFilterAllowExtern(const Array& args) { - return DefaultTaskFilterImpl(args, true); +Optional DefaultTaskFilterAllowExtern(const Array& args, + const Array& constants) { + return DefaultTaskFilterImpl(args, constants, true); } TVM_REGISTER_NODE_TYPE(ExtractedTaskNode); @@ -101,8 +105,15 @@ TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask") int weight) -> ExtractedTask { return ExtractedTask(task_name, mod, target, dispatched, weight); }); -TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilter").set_body_typed(DefaultTaskFilter); + +TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilter") + .set_body_typed([](const Array& args, const Array& constants) { + return DefaultTaskFilter(args, constants); + }); + TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilterAllowExtern") - .set_body_typed(DefaultTaskFilterAllowExtern); + .set_body_typed([](const Array& args, const Array& constants) { + return DefaultTaskFilterAllowExtern(args, constants); + }); } // namespace meta_schedule } // namespace tvm diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 6708922444b6..7649b6101919 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -429,9 +429,14 @@ void NDArrayToTIR(::tvm::runtime::NDArray arr, std::ostream& os) { tot_dim *= arr->shape[i]; } T* data_ptr = reinterpret_cast(arr->data); + constexpr int NUM_PRINT = 20; os << "["; for (int i = 0; i < tot_dim; i++) { os << (i != 0 ? ", " : "") << data_ptr[i]; + if (i == NUM_PRINT) { + os << "..."; + break; + } } os << "]"; } @@ -1121,6 +1126,20 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) { NDArrayToTIR(data, ss); } else if (alloc->dtype.bits() == 32) { NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 64) { + NDArrayToTIR(data, ss); + } else { + LOG(FATAL) << "DataType not supported"; + } + } else if (alloc->dtype.is_uint()) { + if (alloc->dtype.bits() == 8) { + // NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 16) { + NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 32) { + NDArrayToTIR(data, ss); + } else if (alloc->dtype.bits() == 64) { + NDArrayToTIR(data, ss); } else { LOG(FATAL) << "DataType not supported"; } diff --git a/src/relay/backend/task_extraction.cc b/src/relay/backend/task_extraction.cc index c577e8e356d6..4f83b6eeed60 100644 --- a/src/relay/backend/task_extraction.cc +++ b/src/relay/backend/task_extraction.cc @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include #include @@ -33,7 +34,7 @@ namespace backend { Array ExtractTask( IRModule mod, Target target, Map params, - runtime::TypedPackedFunc(const Array&)> filter_func) { + meta_schedule::ApplyHistoryBestNode::FTEFilterFunc filter_func) { using meta_schedule::ExtractedTask; if (filter_func == nullptr) { filter_func = tvm::meta_schedule::DefaultTaskFilter; @@ -42,6 +43,7 @@ Array ExtractTask( // is_vm=true for backward compatibility Array pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true); pass_seqs.push_back(transform::FuseOps()); + mod = transform::Sequential(pass_seqs)(std::move(mod)); std::vector tasks; @@ -58,11 +60,9 @@ Array ExtractTask( it->second->weight += 1; return; } - Array inputs_outputs{nullptr}; - std::string fused_name; - std::tie(inputs_outputs, fused_name) = + auto [inputs_outputs, constants, fused_name] = tec::LowerTECompute(relay_func, target, /*return_inputs=*/true); - if (Optional prim_func = filter_func(inputs_outputs)) { + if (Optional prim_func = filter_func(inputs_outputs, constants)) { GlobalVar prim_fn_var(fused_name); IRModule relay_mod({{prim_fn_var, relay_func}}); IRModule tir_mod({{prim_fn_var, prim_func.value()}}); diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc index da52d94b4e46..92cc6f8cfa46 100644 --- a/src/relay/backend/te_compiler_cache.cc +++ b/src/relay/backend/te_compiler_cache.cc @@ -50,7 +50,6 @@ #include "../../te/operation/create_primfunc.h" #include "../op/memory/memory.h" #include "../transforms/meta_schedule_layout_rewrite.h" -#include "../transforms/pass_utils.h" #include "utils.h" namespace tvm { @@ -362,8 +361,14 @@ class ScheduleBuilder : public ExprVisitor { } if (meta_schedule_ctx_) { Array te_args = Concat(fn_inputs, tensor_outs); + Array constants; + for (auto [const_node, te_tensor] : lower_te_compute.constant_tensors_) { + te_args.push_back(te_tensor); + constants.push_back(const_node->data); + } + if (Optional tir_func = - meta_schedule_ctx_.value()->te_filter_func(te_args)) { + meta_schedule_ctx_.value()->te_filter_func(te_args, constants)) { IRModule relay_mod({{prim_fn_var, relay_func}}); IRModule tir_mod({{prim_fn_var, tir_func.value()}}); if (Optional opt_scheduled_mod = meta_schedule_ctx_.value()->Query( @@ -785,8 +790,8 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target, return MakeShapeFunc().Create(prim_func, target, global_var_supply); } -std::pair, std::string> LowerTECompute(const Function& source_func, Target target, - bool return_inputs) { +std::tuple, Array, std::string> LowerTECompute( + const Function& source_func, Target target, bool return_inputs) { LowerToTECompute lower_te_compute(target); Array outputs = lower_te_compute.Lower(source_func); // Following ScheduleBuilder, remove placeholder ops from outputs. @@ -796,11 +801,18 @@ std::pair, std::string> LowerTECompute(const Function& source_ tensor_outs.push_back(tensor); } } + + tvm::Array constants; + for (auto [const_node, te_tensor] : lower_te_compute.constant_tensors_) { + tensor_outs.push_back(te_tensor); + constants.push_back(const_node->data); + } + if (return_inputs) { - return std::make_pair(Concat(lower_te_compute.fn_inputs_, tensor_outs), - lower_te_compute.candidate_name_); + return std::make_tuple(Concat(lower_te_compute.fn_inputs_, tensor_outs), constants, + lower_te_compute.candidate_name_); } - return std::make_pair(tensor_outs, lower_te_compute.candidate_name_); + return std::make_tuple(tensor_outs, constants, lower_te_compute.candidate_name_); } TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) { diff --git a/src/relay/backend/te_compiler_cache.h b/src/relay/backend/te_compiler_cache.h index 894a5f5be5f6..95c5bc974181 100644 --- a/src/relay/backend/te_compiler_cache.h +++ b/src/relay/backend/te_compiler_cache.h @@ -37,6 +37,7 @@ #include #include +#include #include #include @@ -215,10 +216,10 @@ Array GetShape(const Array& shape); * \param source_func The primitive function to be lowered. * \param target The target we want to create schedule for. * \param return_inputs If true, prepend input tensors to the output array of tensors. - * \return Pair of schedule and fused function name. + * \return Tuple of the lowered TE compute, constant raw data, and fused function name. */ -std::pair, std::string> LowerTECompute(const Function& source_func, Target target, - bool return_inputs = true); +std::tuple, Array, std::string> LowerTECompute( + const Function& source_func, Target target, bool return_inputs = true); /*! * \brief Create schedule for target. diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 55df71a8053e..4c1358f42519 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -17,6 +17,8 @@ * under the License. */ +#include "create_primfunc.h" + #include #include #include @@ -24,9 +26,13 @@ #include #include +#include +#include #include +#include #include "../../tir/ir/functor_common.h" +#include "../../tir/transforms/ir_utils.h" #include "../schedule/graph.h" namespace tvm { @@ -492,6 +498,12 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { return GenerateAndCompletePrimFunc(arg_list, root_stmts, &info); } +PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, + const Array& constants) { + PrimFunc func = CreatePrimFunc(arg_list); + return tir::BindParams(func, constants); +} + TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc); } // namespace tir diff --git a/src/te/operation/create_primfunc.h b/src/te/operation/create_primfunc.h index c3cddd83f57a..b68d30a2fb82 100644 --- a/src/te/operation/create_primfunc.h +++ b/src/te/operation/create_primfunc.h @@ -30,6 +30,14 @@ namespace tir { /*! \brief Use Tensor Expression to create a schedulable TensorIR func. */ PrimFunc CreatePrimFunc(const Array& arg_list); +/*! \brief The same as above but create a PrimFunc with AllocateConstNode. If the size of the + * constants array is N, the last N tensors in arg_list will be treated as constant tensors. + * Constant tensors will not be part of the parameters of the created PrimFunc, instead constants + * will be embedded in the body as AllocateConstNode. + */ +PrimFunc CreatePrimFuncWithConstants(const Array& arg_list, + const Array& constants); + } // namespace tir } // namespace tvm diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc index 44a58f8792c4..576476ae30aa 100644 --- a/src/tir/analysis/estimate_flops.cc +++ b/src/tir/analysis/estimate_flops.cc @@ -169,6 +169,7 @@ class FlopEstimator : private ExprFunctor, TResult VisitExpr_(const IntImmNode* op) override { return TResult(); } TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); } TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); } + TResult VisitStmt_(const AllocateConstNode* op) override { return VisitStmt(op->body); } TResult VisitStmt_(const SeqStmtNode* seq) override { TResult result; diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc index a739373ab329..1c21d770db30 100644 --- a/src/tir/schedule/transform.cc +++ b/src/tir/schedule/transform.cc @@ -220,9 +220,24 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_ } } if (const auto* block = sref->StmtAs()) { - if (const auto* seq = block->body.as()) { + auto body = block->body; + // Peel off AllocateConst nodes at the beginning of the block body. + std::vector allocs; + while (const auto* alloc = body.as()) { + allocs.push_back(alloc); + body = alloc->body; + } + if (const auto* seq = body.as()) { ObjectPtr n = make_object(*block); - n->body = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); + auto new_seq = RemoveFromSeqStmt(GetRef(seq), GetRef(last_stmt)); + // Re-attach AllocateConst nodes + auto new_body = new_seq; + for (int i = 0; i < static_cast(allocs.size()); ++i) { + auto alloc = allocs[allocs.size() - 1 - i]; + new_body = AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data, + new_body, alloc->annotations, alloc->span); + } + n->body = new_body; *src_stmt = GetRef(block); *tgt_stmt = Stmt(std::move(n)); return; diff --git a/src/tir/transforms/bind_params.cc b/src/tir/transforms/bind_params.cc index a5bc519e9a0e..0b71b2e8fa34 100644 --- a/src/tir/transforms/bind_params.cc +++ b/src/tir/transforms/bind_params.cc @@ -84,44 +84,57 @@ class ParamsCollector : public StmtExprVisitor { Map constant_map_; }; -namespace transform { +PrimFunc BindParams(PrimFunc f, const Array& constants) { + Map constant_map; -Pass BindParams(const Array& constants) { - auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { - Map constant_map; - - // Remove constants from the primfunc signature - size_t num_constants = constants.size(); - size_t start = f->params.size() - num_constants; - Array params; - for (unsigned i = 0; i < start; i++) { - params.push_back(f->params[i]); - } + // Remove constants from the primfunc signature + size_t num_constants = constants.size(); + size_t start = f->params.size() - num_constants; + Array params; + for (unsigned i = 0; i < start; i++) { + params.push_back(f->params[i]); + } - auto* n = f.CopyOnWrite(); - for (unsigned i = start; i < f->params.size(); i++) { - tir::Var p = n->params[i]; - tir::Var b = n->buffer_map[p]->data; - n->buffer_map.erase(p); - constant_map.Set(b, constants[i - start]); + auto* n = f.CopyOnWrite(); + for (unsigned i = start; i < f->params.size(); i++) { + tir::Var p = n->params[i]; + tir::Var b = n->buffer_map[p]->data; + n->buffer_map.erase(p); + constant_map.Set(b, constants[i - start]); + } + n->params = params; + auto constant_list = ParamsCollector(constant_map).CollectParams(n->body); + + // Allocate constants within the primfunc + for (auto i : constant_list) { + auto var = GetRef(i); + int ndim = constant_map[var]->ndim; + Array extents; + + for (int i = 0; i < ndim; i++) { + int shape = constant_map[var]->shape[i]; + extents.push_back(make_const(DataType::Int(32), shape)); } - n->params = params; - auto constant_list = ParamsCollector(constant_map).CollectParams(n->body); - - // Allocate constants within the primfunc - for (auto i : constant_list) { - auto var = GetRef(i); - int ndim = constant_map[var]->ndim; - Array extents; - - for (int i = 0; i < ndim; i++) { - int shape = constant_map[var]->shape[i]; - extents.push_back(make_const(DataType::Int(32), shape)); - } - DataType dtype = DataType(constant_map[var]->dtype); + DataType dtype = DataType(constant_map[var]->dtype); + + if (n->body->IsInstance()) { + auto* block_realize = n->body.as(); + auto block = block_realize->block; + block.CopyOnWrite()->body = + tir::AllocateConst(var, dtype, extents, constant_map[var], block->body); + n->body = BlockRealize(block_realize->iter_values, block_realize->predicate, block); + } else { n->body = tir::AllocateConst(var, dtype, extents, constant_map[var], n->body); } - return f; + } + return f; +} + +namespace transform { + +Pass BindParams(const Array& constants) { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return BindParams(f, constants); }; return CreatePrimFuncPass(pass_func, 0, "tir.BindParams", {}); } diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 227935bf72dd..ab1489d5ad4a 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -709,10 +709,7 @@ class PipelineRewriter : public StmtExprMutator { arith::Analyzer* ana_normalized) const { std::vector new_blocks = blocks; std::vector commit_group_indices(new_blocks.size(), -1); - for (const auto& kv : async_states_local) { - const int stage_id = kv.first; - const AsyncStateLocal& state = kv.second; - + for (const auto& [stage_id, state] : async_states_local) { if (!state.commit_groups.empty()) { for (size_t i = 0; i < state.commit_groups.size(); ++i) { for (size_t j = 0; j < state.commit_groups[i].size(); ++j) { diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index d89ee3619699..a54eebe4ed05 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -311,6 +311,16 @@ std::unordered_map GetTensorCoreFragmentInfo(const // attr::async_wait_queue_scope annotation. std::pair GetAsyncWaitAttributes(const AttrStmtNode* op); +/*! + * \brief Bind a subset of parameter tensors to constants, replacing them by AllocateConst nodes. + * \param f The function to bind constants to. + * \param constants Raw constant data. If the size of this array is N, the last N parameter tensors + * will be removed from the signature and instead AllocateConst nodes will be introduced in the + * function body. + * \return The updated function. + */ +PrimFunc BindParams(PrimFunc f, const Array& constants); + } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc index 81dfceb40d32..db59824bf1ce 100644 --- a/src/tir/transforms/plan_update_buffer_allocation_location.cc +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -31,10 +31,31 @@ namespace tvm { namespace tir { +class CollectUnmanagedAllocations : public StmtExprVisitor { + public: + void VisitStmt_(const AllocateNode* op) final { + unmanaged_allocations.insert(op->buffer_var.get()); + StmtExprVisitor::VisitStmt_(op); + } + + void VisitStmt_(const AllocateConstNode* op) final { + unmanaged_allocations.insert(op->buffer_var.get()); + StmtExprVisitor::VisitStmt_(op); + } + + /*! \brief Buffers that are allocated outside of the BlockNode, and should not be moved by + * BufferAllocationLocator. */ + std::unordered_set unmanaged_allocations; +}; + class BufferAllocationLocator : public StmtExprMutator { public: explicit BufferAllocationLocator(const PrimFunc& func) { Map> buffer_lca = DetectBufferAccessLCA(func); + CollectUnmanagedAllocations collector; + collector(func->body); + unmanaged_allocations_ = collector.unmanaged_allocations; + std::unordered_set arg_buffers; for (const auto& kv : func->buffer_map) { const Buffer& buffer = kv.second; @@ -48,7 +69,10 @@ class BufferAllocationLocator : public StmtExprMutator { if (arg_buffers.count(buffer.get())) { continue; } - alloc_buffers_[stmt].push_back(buffer); + if (!unmanaged_allocations_.count(buffer->data.get())) { + alloc_buffers_[stmt].push_back(buffer); + } + buffer_data_to_buffer_.Set(buffer->data, buffer); } } @@ -119,16 +143,6 @@ class BufferAllocationLocator : public StmtExprMutator { return Stmt(n); } - Stmt VisitStmt_(const AllocateNode* op) final { - unmanaged_allocations_.insert(op->buffer_var.get()); - return StmtExprMutator::VisitStmt_(op); - } - - Stmt VisitStmt_(const AllocateConstNode* op) final { - unmanaged_allocations_.insert(op->buffer_var.get()); - return StmtExprMutator::VisitStmt_(op); - } - Stmt VisitStmt_(const BufferRealizeNode* op) final { ICHECK(false) << "Internal Error: BufferRealizeNode is not allowed in TensorIR."; throw; diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py index 80c2fbaeb416..8e299dc935d5 100644 --- a/tests/python/unittest/test_link_params.py +++ b/tests/python/unittest/test_link_params.py @@ -19,16 +19,19 @@ import json import os import re -import sys +from io import StringIO +from contextlib import redirect_stderr import numpy as np -import pytest import tvm import tvm.relay import tvm.testing +from tvm import meta_schedule as ms +from tvm import relay from tvm.relay.backend import Executor, Runtime from tvm.contrib import utils +from tvm.meta_schedule.testing.utils import apply_fixed_schedules INPUT_SHAPE = (1, 3, 16, 16) @@ -382,5 +385,62 @@ def _run_unlinked(lib): np.testing.assert_allclose(unlinked_output, linked_output) +def test_tir_link_params(): + def get_dense(data_shape, weight_shape): + data = relay.var("data", shape=data_shape, dtype="float32") + weight = relay.var("weight", shape=weight_shape, dtype="float32") + dense = relay.nn.dense(data, weight) + return relay.Function([data, weight], dense) + + def get_ref_dense(data_np, weight_np): + return np.dot(data_np, np.transpose(weight_np)) + + def schedule_dense(sch): + dense = sch.get_block("T_matmul_NT") + _y, _x, _k = sch.get_loops(dense) + + M, N, K = 128, 128, 128 + data_shape = (M, K) + weight_shape = (N, K) + relay_mod = tvm.IRModule.from_expr(get_dense(data_shape, weight_shape)) + relay_mod = relay.transform.InferType()(relay_mod) + data_np = np.random.randn(*data_shape).astype("float32") + weight_np = np.random.randn(*weight_shape).astype("float32") + target = "llvm" + params = {"weight": weight_np} + + def schedule_fn(task, sch): + if "nn_dense" in task.task_name: + schedule_dense(sch) + return True + return False + + link_params = True + + with tvm.transform.PassContext(config={"relay.FuseOps.link_params": link_params}): + database = apply_fixed_schedules(relay_mod, target, params, schedule_fn) + + with StringIO() as stderr_buf, redirect_stderr(stderr_buf): + with ms.ApplyHistoryBest(database): + with tvm.transform.PassContext( + opt_level=3, + config={"relay.backend.use_meta_schedule": True}, + ): + executor = Executor("graph", {"link-params": link_params}) + lib = relay.build(relay_mod, target=target, executor=executor) + + # Workload look up should succeed. This does not work when the test is invoked from pytest. + assert not "Cannot find workload" in stderr_buf.getvalue() + + dev = tvm.device(target, 0) + runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev)) + runtime.set_input(**params) + runtime.set_input("data", data_np) + runtime.run() + out = runtime.get_output(0).numpy() + ref = get_ref_dense(data_np, weight_np) + tvm.testing.assert_allclose(out, ref, atol=1e-4, rtol=1e-4) + + if __name__ == "__main__": tvm.testing.main()