Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/target/spirv/ir_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,14 @@ Value IRBuilder::DeclarePushConstant(const std::vector<SType>& value_types) {
DataType t = value_types[i].type;
uint32_t nbits = t.bits() * t.lanes();
ICHECK_EQ(nbits % 8, 0);
offset += nbits / 8;
uint32_t bytes = (nbits / 8);
if (t.bits() == 32) {
// In our Vulkan runtime, each push constant always occupies 64 bit.
offset += bytes * 2;
} else {
ICHECK_EQ(t.bits(), 64);
offset += bytes;
}
}
// Decorate push constants as UBO
this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
Expand Down
34 changes: 34 additions & 0 deletions tests/python/unittest/test_target_codegen_spirv.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import tvm
import tvm.testing
from tvm import te
from tvm import relay
from tvm.topi.math import cast
import numpy as np

Expand Down Expand Up @@ -71,5 +72,38 @@ def do_copy(A, B, n):
tvm.testing.assert_allclose(b.asnumpy(), ref)


def test_pushconstants():
if not tvm.testing.device_enabled("vulkan"):
return

def check_mod(mod, x_np, res_np):
target = "vulkan"
ctx = tvm.context(target, 0)
ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
res = ex.evaluate()(x_np).asnumpy()
tvm.testing.assert_allclose(res, res_np, atol=1e-5)

# Three 32 bit pushconstants: any_dim, stride, stride
dtype = "float32"
x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
mod = tvm.IRModule()
mod["main"] = relay.Function([x], relay.sqrt(x))
x_np = np.random.uniform(size=(10,)).astype(dtype)
res_np = np.sqrt(x_np)

check_mod(mod, x_np, res_np)

# One 64 bit and one 32 bit constants
dtype = "int32"
x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
mod = tvm.IRModule()
mod["main"] = relay.Function([x], relay.argsort(x))
x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
res_np = np.argsort(x_np)

check_mod(mod, x_np, res_np)


if __name__ == "__main__":
test_bool_load()
test_pushconstants()