diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index b1fd0171910a..ab9aec077542 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -765,6 +765,8 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) {
   builder_->StartLabel(merge_label);
 }
 
+void CodeGenSPIRV::VisitStmt_(const DeclBufferNode* op) { VisitStmt(op->body); }
+
 void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
   ICHECK(!is_zero(op->condition));
   ICHECK(!op->dtype.is_handle());
diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h
index 3a0336120a8f..1e7b53558508 100644
--- a/src/target/spirv/codegen_spirv.h
+++ b/src/target/spirv/codegen_spirv.h
@@ -107,6 +107,7 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
   void VisitStmt_(const ForNode* op) override;
   void VisitStmt_(const WhileNode* op) override;
   void VisitStmt_(const IfThenElseNode* op) override;
+  void VisitStmt_(const DeclBufferNode* op) override;
   void VisitStmt_(const AllocateNode* op) override;
   void VisitStmt_(const AttrStmtNode* op) override;
   void VisitStmt_(const AssertStmtNode* op) override;
diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py
index 7057ff840637..a8d1719ff200 100644
--- a/tests/python/unittest/test_target_codegen_vulkan.py
+++ b/tests/python/unittest/test_target_codegen_vulkan.py
@@ -28,7 +28,7 @@
 import tvm.testing
 from tvm import relay, te
 from tvm.topi.math import cast
-from tvm.script import tir as T
+from tvm.script import tir as T, ir as I
 from tvm.tir import TensorIntrin, IntImm, Cast, Schedule
 from tvm.tir.tensor_intrin.cuda import (
     WMMA_LOAD_16x16x16_F16_A_INTRIN,
@@ -728,5 +728,22 @@ def tensorize_load(block, dim):
     tvm.testing.assert_allclose(C.numpy(), ref, rtol=1e-2, atol=1e-2)
 
 
+@tvm.testing.requires_vulkan(support_required="compile-only")
+def test_codegen_decl_buffer():
+    """The codegen should accept DeclBuffer nodes in its input"""
+
+    @I.ir_module
+    class mod:
+        @T.prim_func
+        def kernel():
+            T.func_attr({"calling_conv": 2, "global_symbol": "kernel", "tir.noalias": True})
+            A_data = T.allocate([256], dtype="float32", scope="local")
+            A_buf = T.decl_buffer([256], dtype="float32", scope="local", data=A_data)
+
+    target = tvm.target.Target("vulkan")
+    vulkan_codegen = tvm.get_global_func("target.build.vulkan")
+    vulkan_codegen(mod, target)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
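For reference, the sketch below extends the pattern used by the new unit test with an actual buffer access, so that a BufferStore is lowered through the DeclBuffer path as well as the bare declaration. It is a minimal, hypothetical example rather than part of this change: the names Module, data, and buf are illustrative, and it assumes a TVM build with Vulkan codegen enabled (compile-only, no Vulkan device required).

# Hypothetical sketch, not part of the diff above: a kernel that declares a
# buffer over an explicit allocation and then writes through it.
import tvm
from tvm.script import ir as I, tir as T


@I.ir_module
class Module:
    @T.prim_func
    def kernel():
        T.func_attr({"calling_conv": 2, "global_symbol": "kernel", "tir.noalias": True})
        # Explicit allocation, then a DeclBuffer view over the same backing data.
        data = T.allocate([16], dtype="float32", scope="local")
        buf = T.decl_buffer([16], dtype="float32", scope="local", data=data)
        # Stores resolve through the declared buffer to the underlying allocation.
        for i in range(16):
            buf[i] = T.float32(0)


# Same entry point used by the new unit test; assumes a Vulkan-enabled TVM build.
build_vulkan = tvm.get_global_func("target.build.vulkan")
build_vulkan(Module, tvm.target.Target("vulkan"))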