12 changes: 12 additions & 0 deletions src/target/spirv/codegen_spirv.cc
@@ -110,6 +110,14 @@ runtime::VulkanShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::

builder_->CommitKernelFunction(func_ptr, name);

ICHECK_LE(shared_memory_bytes_used_, spirv_support_.max_shared_memory_per_block)
<< "Vulkan shader " << name << " uses " << shared_memory_bytes_used_
<< " bytes of shared memory, "
<< "but target supports only " << spirv_support_.max_shared_memory_per_block << " bytes. "
<< "If the device supports this allocation, "
<< "please add -max_shared_memory_per_block=NBYTES to the target, "
<< "or query all device parameters by adding -from_device=0.";

shader.data = builder_->Finalize();
return shader;
}
@@ -121,6 +129,7 @@ void CodeGenSPIRV::InitFuncState() {
analyzer_.reset(new arith::Analyzer());
builder_.reset(new spirv::IRBuilder(spirv_support_));
builder_->InitHeader();
shared_memory_bytes_used_ = 0;
}

spirv::Value CodeGenSPIRV::GetThreadIndex(const IterVar& iv, const PrimExpr& extent) {
@@ -642,6 +651,9 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
// Shared memory
buf =
builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassWorkgroup);

size_t num_bytes = op->dtype.bytes() * op->dtype.lanes() * static_cast<uint32_t>(constant_size);
shared_memory_bytes_used_ += num_bytes;
} else {
LOG(FATAL) << "Can only allocate shared or local memory inside kernel";
}
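The new check counts bytes rather than elements: the element width, times the vector lane count, times the constant element count of the allocation. A minimal Python sketch of the same arithmetic (the helper name is hypothetical, not part of the patch):

```python
# Hedged sketch mirroring the byte accounting added in VisitStmt_:
# bytes per element * vector lanes * number of elements allocated.
def shared_mem_bytes(dtype_bytes: int, lanes: int, constant_size: int) -> int:
    return dtype_bytes * lanes * constant_size

# A float32x4 buffer of 1024 elements: 4 * 4 * 1024 = 16384 bytes,
# exactly the minimum maxComputeSharedMemorySize the Vulkan spec requires.
assert shared_mem_bytes(4, 4, 1024) == 16384
```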
4 changes: 4 additions & 0 deletions src/target/spirv/codegen_spirv.h
@@ -214,6 +214,10 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,

// binding of let variables. Enables duplicate var defs that map to same value
std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_;

// Running total of the number of bytes of shared memory used.
// Checked against the max_shared_memory_per_block target attribute.
size_t shared_memory_bytes_used_{0};
};

} // namespace codegen
3 changes: 3 additions & 0 deletions src/target/spirv/spirv_support.cc
@@ -52,6 +52,9 @@ SPIRVSupport::SPIRVSupport(tvm::Target target) {
if (target->GetAttr<Integer>("max_storage_buffer_range")) {
max_storage_buffer_range = target->GetAttr<Integer>("max_storage_buffer_range").value();
}
if (target->GetAttr<Integer>("max_shared_memory_per_block")) {
max_shared_memory_per_block = target->GetAttr<Integer>("max_shared_memory_per_block").value();
}
if (target->GetAttr<Integer>("max_per_stage_descriptor_storage_buffer")) {
max_per_stage_descriptor_storage_buffers =
target->GetAttr<Integer>("max_per_stage_descriptor_storage_buffer").value();
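Because the limit is read from a target attribute, an application can raise it (or query it from the device via `-from_device=0`) without modifying TVM. A hedged sketch of declaring the attribute, using the same target-string form as the test below; the 49152-byte value is illustrative only:

```python
import tvm

# Hedged sketch: declare a larger shared-memory limit on a Vulkan target.
# Pass -from_device=0 instead to pick up the device's reported
# maxComputeSharedMemorySize rather than guessing a byte count.
target = tvm.target.Target("vulkan -max_shared_memory_per_block=49152")
assert int(target.attrs["max_shared_memory_per_block"]) == 49152
```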
16 changes: 16 additions & 0 deletions src/target/spirv/spirv_support.h
@@ -101,6 +101,22 @@ struct SPIRVSupport {
*/
uint32_t max_storage_buffer_range{1 << 27};

/*!
* \brief The maximum amount of shared memory usable by a shader
*
* Vulkan extension: N/A
* Vulkan struct: VkPhysicalDeviceLimits
* Device Property: maxComputeSharedMemorySize
* SPV Extension name: N/A
* SPV Capability: N/A
*
* The maximum amount of shared memory (Workgroup scope) that may be
* allocated by a shader. Default value is from Vulkan spec,
* "Required Limits" table. Implementations may have a larger
* limit.
*/
uint32_t max_shared_memory_per_block{16384};

/*!
* \brief The maximum number of storage buffers accessible by a single shader.
*
30 changes: 30 additions & 0 deletions tests/python/unittest/test_target_codegen_vulkan.py
@@ -527,5 +527,35 @@ def test_ramp_broadcast_index(self, target, dev, mod, ref_data):
tvm.testing.assert_allclose(b.numpy(), b_np)


@tvm.testing.parametrize_targets("vulkan -max_shared_memory_per_block=16384")
def test_shared_mem_alloc(target, dev):
alloc_nbytes = 16384 * 2

def do_compute(ins, outs):
ib = tvm.tir.ir_builder.create()
out = ib.buffer_ptr(outs[0])

ib.scope_attr(te.thread_axis("blockIdx.x"), "thread_extent", 0)

array = ib.allocate("int32", (alloc_nbytes,), name="array", scope="shared")
array[0] = 0
out[0] = array[0]

return ib.get()

Out = te.extern(
shape=(1,),
inputs=[],
fcompute=do_compute,
dtype="int32",
)
s = te.create_schedule(Out.op)

# Codegen should raise error when allocating more memory than the
# target supports.
with pytest.raises(tvm.TVMError):
tvm.build(s, [Out], target)


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))
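Note that `alloc_nbytes` in the test is an element count despite its name: the shared buffer holds `16384 * 2` int32 values, so the new codegen check compares 131072 bytes against the declared 16384-byte limit and aborts the build. A quick hedged sanity check of that arithmetic:

```python
# Hedged sketch: the test's shared allocation size in bytes.
num_elements = 16384 * 2  # "alloc_nbytes" in the test above
int32_bytes = 4           # bytes per int32 element
assert num_elements * int32_bytes == 131072  # well above the 16384-byte limit
```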