From 98798c709efee7fc74c436115a828e14621a80e0 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Thu, 22 Sep 2022 19:06:17 -0700 Subject: [PATCH 1/3] Also remap VarNode to USMP-allocated buffer. * Fixes cases where a USMP buffer is used with tensorization intrinsics. In these cases, the buffer data var is referenced directly in the call_extern args. Before this patch, ConvertPoolAllocationsToOffsets would generate TIR like the following: let dense_let: Pointer(global int32) = @tir.address_of(global_workspace_37_buffer_var[69952], dtype=handle) for (k.outer: int32, 0, 64) { @tir.call_extern("gemm_1x1x1_update_UKVNAEBL", ..., dense, ...) } T_multiply[ax1] = @tir.q_multiply_shift(((dense: Buffer(dense_let, int32, [10], [], align=32)[ax1], ...) This caused CodegenSourceBase to later fail with this error: "src/target/source/codegen_source_base.cc", line 67 Check failed: (it != var_idmap_.end()) is false: Find undefined Variable dense After this patch, "dense" in the call_extern is changed to read "dense_let". 
--- .../convert_pool_allocations_to_offsets.cc | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index 601e34719632..14a0dd18e34e 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -96,6 +96,7 @@ class PoolAllocationToOffsetConverter : public StmtExprMutator { private: PrimExpr VisitExpr_(const CallNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; + PrimExpr VisitExpr_(const VarNode* op) override; PrimExpr VisitExpr_(const BufferLoadNode* op) override; Stmt VisitStmt_(const BufferStoreNode* op) override; @@ -343,8 +344,14 @@ LetStmt PoolAllocationToOffsetConverter::ToLetStmt(const PoolAllocation& pool_al let_var_type = PointerType(Downcast(let_var_type)->element_type); } Var let_var(buffer_var->name_hint + "_let", let_var_type); + if (buffer_var->name_hint == "dense") { + LOG(INFO) << " Set buffer var: " << let_var; + } allocate_var_to_let_var_.Set(buffer_var, let_var); Stmt new_body = VisitStmt(body); + if (buffer_var->name_hint == "dense") { + LOG(INFO) << " Erase buffer var: " << let_var; + } allocate_var_to_let_var_.erase(buffer_var); return LetStmt(let_var, address_of_load, new_body); } @@ -395,6 +402,18 @@ PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const BufferLoadNode* op) { return std::move(load); } +PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const VarNode* op) { + auto it = allocate_var_to_let_var_.find(GetRef(op)); + if (op->name_hint == "dense") { + LOG(INFO) << "Lookup dense var: " << (it != allocate_var_to_let_var_.end()); + } + if (it != allocate_var_to_let_var_.end()) { + return (*it).second; + } + + return StmtExprMutator::VisitExpr_(op); +} + Buffer PoolAllocationToOffsetConverter::GetRemappedBuffer(Buffer original) { { auto it = original_buf_to_let_buf_.find(original); 
From de55da77ed984a1e03ff1b1d4c7a14e2a7c7e3ec Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Mon, 26 Sep 2022 15:12:08 -0700 Subject: [PATCH 2/3] Revert LOG(INFO), add tests. --- .../convert_pool_allocations_to_offsets.cc | 9 -- ...orm_convert_pool_allocations_to_offsets.py | 93 +++++++++++++++++++ 2 files changed, 93 insertions(+), 9 deletions(-) diff --git a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc index 14a0dd18e34e..56aba654b59e 100644 --- a/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc +++ b/src/tir/usmp/transform/convert_pool_allocations_to_offsets.cc @@ -344,14 +344,8 @@ LetStmt PoolAllocationToOffsetConverter::ToLetStmt(const PoolAllocation& pool_al let_var_type = PointerType(Downcast(let_var_type)->element_type); } Var let_var(buffer_var->name_hint + "_let", let_var_type); - if (buffer_var->name_hint == "dense") { - LOG(INFO) << " Set buffer var: " << let_var; - } allocate_var_to_let_var_.Set(buffer_var, let_var); Stmt new_body = VisitStmt(body); - if (buffer_var->name_hint == "dense") { - LOG(INFO) << " Erase buffer var: " << let_var; - } allocate_var_to_let_var_.erase(buffer_var); return LetStmt(let_var, address_of_load, new_body); } @@ -404,9 +398,6 @@ PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const BufferLoadNode* op) { PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const VarNode* op) { auto it = allocate_var_to_let_var_.find(GetRef(op)); - if (op->name_hint == "dense") { - LOG(INFO) << "Lookup dense var: " << (it != allocate_var_to_let_var_.end()); - } if (it != allocate_var_to_let_var_.end()) { return (*it).second; } diff --git a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py index fdda400a779f..31cc6e07dec3 100644 --- a/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py +++ 
b/tests/python/unittest/test_tir_usmp_transform_convert_pool_allocations_to_offsets.py @@ -600,5 +600,98 @@ def test_resnet_subgraph(): tvm.ir.assert_structural_equal(actual_func, ref_func) +@tvm.script.ir_module +class TensorIntrinStructure: + @T.prim_func + def tensor_intrin_primfunc() -> None: + dense_data = T.allocate([10], "int32", "global") + T.evaluate( + T.call_extern( + "intrin_function", + T.tvm_access_ptr( + T.type_annotation(dtype="int32"), dense_data, 0, 1, 2, dtype="handle" + ), + dtype="int32", + ) + ) + + dense = T.buffer_decl([10], "int32", data=dense_data) + dense[0] = T.q_multiply_shift(dense[0], 1608879842, 31, -7, dtype="int32") + + @T.prim_func + def __tvm_main__(input: T.handle, output: T.handle) -> None: + T.evaluate(T.call_extern("tensor_intrin_primfunc", dtype="int32")) + + +@tvm.script.ir_module +class TensorIntrinStructurePlanned: + @T.prim_func + def tensor_intrin_primfunc(global_workspace_1_var: T.Ptr[T.uint8]) -> None: + global_workspace_1_buffer_var = T.match_buffer( + global_workspace_1_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16 + ) + T.preflattened_buffer( + global_workspace_1_buffer_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16 + ) + dense_let = T.buffer_decl([10], "int32") + with T.let(dense_let.data, T.address_of(global_workspace_1_buffer_var[0], dtype="handle")): + T.evaluate( + T.call_extern( + "intrin_function", + T.tvm_access_ptr( + T.type_annotation(dtype="int32"), dense_let.data, 0, 1, 2, dtype="handle" + ), + dtype="int32", + ) + ) + dense_let[0] = T.q_multiply_shift(dense_let[0], 1608879842, 31, -7, dtype="int32") + + @T.prim_func + def __tvm_main__( + input: T.handle, global_workspace_1_var: T.Ptr[T.uint8], output: T.handle + ) -> None: + global_workspace_1_buffer_var = T.match_buffer( + global_workspace_1_var, [40], dtype="uint8", strides=[1], elem_offset=0, align=16 + ) + T.evaluate( + T.call_extern( + "tensor_intrin_primfunc", global_workspace_1_buffer_var.data, dtype="int32" + ) 
+ ) + + +def test_tensor_intrin(): + target = Target("c") + global_workspace_pool = WorkspacePoolInfo( + "global_workspace", + [target], + ) + + tir_mod = TensorIntrinStructure + tir_mod = _assign_targets_to_primfuncs_irmodule(tir_mod, target) + tir_mod = assign_poolinfos_to_allocates_in_irmodule(tir_mod, [global_workspace_pool]) + main_func = tir_mod["__tvm_main__"] + buffer_analysis = tvm.tir.usmp.analysis.extract_buffer_info(main_func, tir_mod) + buffer_info_map = buffer_analysis.buffer_info_stmts + + fcreate_array_bi = tvm.get_global_func("tir.usmp.CreateArrayBufferInfo") + buffer_info_arr = fcreate_array_bi(buffer_info_map) + fusmp_algo_greedy_by_size = tvm.get_global_func("tir.usmp.algo.greedy_by_size") + buffer_pool_allocations = fusmp_algo_greedy_by_size( + buffer_info_arr, buffer_analysis.memory_pressure + ) + fassign_stmt_pool_allocations = tvm.get_global_func("tir.usmp.AssignStmtPoolAllocations") + pool_allocations = fassign_stmt_pool_allocations(buffer_info_map, buffer_pool_allocations) + tir_mod_with_offsets = tvm.tir.usmp.transform.convert_pool_allocations_to_offsets( + pool_allocations, emit_tvmscript_printable=True + )(tir_mod) + + expected = TensorIntrinStructurePlanned + + for gv, ref_func in expected.functions.items(): + actual_func = tir_mod_with_offsets[gv.name_hint] + tvm.ir.assert_structural_equal(actual_func, ref_func) + + if __name__ == "__main__": pytest.main([__file__] + sys.argv[1:]) From 9c538e045063dc7e11d34c7dab455da02cfac503 Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Mon, 26 Sep 2022 15:12:19 -0700 Subject: [PATCH 3/3] Don't crash on 0-arg PrimFunc in ExtractBufferInfo. 
--- src/tir/usmp/analysis/extract_buffer_info.cc | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/tir/usmp/analysis/extract_buffer_info.cc b/src/tir/usmp/analysis/extract_buffer_info.cc index 74d428f6dddf..268058945750 100644 --- a/src/tir/usmp/analysis/extract_buffer_info.cc +++ b/src/tir/usmp/analysis/extract_buffer_info.cc @@ -429,15 +429,17 @@ void BufferInfoExtractor::VisitExpr_(const VarNode* op) { Array static GetMatchedBuffers(const PrimFunc& func) { Array buffer_vars; - for (unsigned int i = 0; i < func->params.size() - 1; i++) { - Var param = func->params[i]; - buffer_vars.push_back(func->buffer_map[param]->data); - } - Var last_param = func->params.back(); - // Checks whether last var is present in the buffer map - // because it could be the resource handle - if (func->buffer_map.find(last_param) != func->buffer_map.end()) { - buffer_vars.push_back(func->buffer_map[last_param]->data); + if (func->params.size() > 0) { + for (unsigned int i = 0; i < func->params.size() - 1; i++) { + Var param = func->params[i]; + buffer_vars.push_back(func->buffer_map[param]->data); + } + Var last_param = func->params.back(); + // Checks whether last var is present in the buffer map + // because it could be the resource handle + if (func->buffer_map.find(last_param) != func->buffer_map.end()) { + buffer_vars.push_back(func->buffer_map[last_param]->data); + } } return buffer_vars; }