diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 42d0027a326f..66952dae269e 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -110,6 +110,14 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::string& name) {
 
   builder_->CommitKernelFunction(func_ptr, name);
 
+  ICHECK_LE(shared_memory_bytes_used_, spirv_support_.max_shared_memory_per_block)
+      << "Vulkan shader " << name << " uses " << shared_memory_bytes_used_
+      << " bytes of shared memory, "
+      << "but target supports only " << spirv_support_.max_shared_memory_per_block << " bytes. "
+      << "If the device supports this allocation, "
+      << "please add -max_shared_memory_per_block=NBYTES to the target, "
+      << "or query all device parameters by adding -from_device=0.";
+
   shader.data = builder_->Finalize();
   return shader;
 }
@@ -121,6 +129,7 @@ void CodeGenSPIRV::InitFuncState() {
   analyzer_.reset(new arith::Analyzer());
   builder_.reset(new spirv::IRBuilder(spirv_support_));
   builder_->InitHeader();
+  shared_memory_bytes_used_ = 0;
 }
 
 spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) {
@@ -642,6 +651,9 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
       // Shared memory
       buf = builder_->Allocate(etype, static_cast<uint32_t>(constant_size),
                                spv::StorageClassWorkgroup);
+
+      size_t num_bytes = op->dtype.bytes() * op->dtype.lanes() * static_cast<size_t>(constant_size);
+      shared_memory_bytes_used_ += num_bytes;
     } else {
       LOG(FATAL) << "Can only allocate shared or local memory inside kernel";
     }
diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h
index 8b14754f617f..74b62e7613d1 100644
--- a/src/target/spirv/codegen_spirv.h
+++ b/src/target/spirv/codegen_spirv.h
@@ -214,6 +214,10 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
 
   // binding of let variables. Enables duplicate var defs that map to same value
   std::unordered_map<const VarNode*, const LetNode*> let_binding_;
+
+  // Running total of the number of bytes of shared memory used.
+  // Checked against max_shared_memory_per_block.
+  size_t shared_memory_bytes_used_{0};
 };
 
 }  // namespace codegen
diff --git a/src/target/spirv/spirv_support.cc b/src/target/spirv/spirv_support.cc
index 4a294d56bd9c..0f1207f3e9a8 100644
--- a/src/target/spirv/spirv_support.cc
+++ b/src/target/spirv/spirv_support.cc
@@ -52,6 +52,9 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) {
   if (target->GetAttr<Integer>("max_storage_buffer_range")) {
     max_storage_buffer_range = target->GetAttr<Integer>("max_storage_buffer_range").value();
   }
+  if (target->GetAttr<Integer>("max_shared_memory_per_block")) {
+    max_shared_memory_per_block = target->GetAttr<Integer>("max_shared_memory_per_block").value();
+  }
   if (target->GetAttr<Integer>("max_per_stage_descriptor_storage_buffer")) {
     max_per_stage_descriptor_storage_buffers =
         target->GetAttr<Integer>("max_per_stage_descriptor_storage_buffer").value();
diff --git a/src/target/spirv/spirv_support.h b/src/target/spirv/spirv_support.h
index 1497c7c6333a..04d13cca5031 100644
--- a/src/target/spirv/spirv_support.h
+++ b/src/target/spirv/spirv_support.h
@@ -101,6 +101,22 @@ struct SPIRVSupport {
    */
   uint32_t max_storage_buffer_range{1 << 27};
 
+  /*!
+   * \brief The maximum amount of shared memory usable by a shader
+   *
+   * Vulkan extension: N/A
+   * Vulkan struct: VkPhysicalDeviceLimits
+   * Device Property: maxComputeSharedMemorySize
+   * SPV Extension name: N/A
+   * SPV Capability: N/A
+   *
+   * The maximum amount of shared memory (Workgroup scope) that may be
+   * allocated by a shader. Default value is from the Vulkan spec,
+   * "Required Limits" table. Implementations may have a larger
+   * limit.
+   */
+  uint32_t max_shared_memory_per_block{16384};
+
   /*!
    * \brief The maximum number of storage buffers accessible by a single shader.
    *
diff --git a/tests/python/unittest/test_target_codegen_vulkan.py b/tests/python/unittest/test_target_codegen_vulkan.py
index 85e9cb12d8d2..01f734beb8fd 100644
--- a/tests/python/unittest/test_target_codegen_vulkan.py
+++ b/tests/python/unittest/test_target_codegen_vulkan.py
@@ -527,5 +527,35 @@ def test_ramp_broadcast_index(self, target, dev, mod, ref_data):
         tvm.testing.assert_allclose(b.numpy(), b_np)
 
 
+@tvm.testing.parametrize_targets("vulkan -max_shared_memory_per_block=16384")
+def test_shared_mem_alloc(target, dev):
+    alloc_nbytes = 16384 * 2
+
+    def do_compute(ins, outs):
+        ib = tvm.tir.ir_builder.create()
+        out = ib.buffer_ptr(outs[0])
+
+        ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0)
+
+        array = ib.allocate("int32", (alloc_nbytes,), name="array", scope="shared")
+        array[0] = 0
+        out[0] = array[0]
+
+        return ib.get()
+
+    Out = te.extern(
+        shape=(1,),
+        inputs=[],
+        fcompute=do_compute,
+        dtype="int32",
+    )
+    s = te.create_schedule(Out.op)
+
+    # Codegen should raise an error when allocating more shared memory
+    # than the target supports.
+    with pytest.raises(tvm.TVMError):
+        tvm.build(s, [Out], target)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main(sys.argv))
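
Reviewer note (not part of the patch): a minimal sketch of how the new attribute would be set from Python, following the two options named in the error message above. The 49152-byte value is illustrative only, assuming a device whose maxComputeSharedMemorySize exceeds the spec-required 16384 bytes:

    import tvm

    # Raise the codegen limit for a device known to support more shared
    # memory than the Vulkan-spec minimum of 16384 bytes.
    target = tvm.target.Target("vulkan -max_shared_memory_per_block=49152")

    # Or query all device parameters, including this one, from the first
    # detected Vulkan device, as the new error message suggests.
    target = tvm.target.Target("vulkan -from_device=0")

Either target can then be passed to tvm.build, as in the new test above.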