Patch summary (free text before the first "diff --git" header):
  * src/target/spirv/ir_builder.cc: in DeclarePushConstant, a 32-bit value now
    advances the member offset by twice its byte size; any other value must be
    64-bit (ICHECK) and advances the offset by its own byte size.
  * tests/python/unittest/test_target_codegen_spirv.py: adds test_pushconstants,
    which returns early unless the "vulkan" device is enabled, and calls it
    from the __main__ guard.

diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc
index 273fc48c3e30..3a9de4e077dc 100644
--- a/src/target/spirv/ir_builder.cc
+++ b/src/target/spirv/ir_builder.cc
@@ -222,7 +222,14 @@ Value IRBuilder::DeclarePushConstant(const std::vector& value_types) {
     DataType t = value_types[i].type;
     uint32_t nbits = t.bits() * t.lanes();
     ICHECK_EQ(nbits % 8, 0);
-    offset += nbits / 8;
+    uint32_t bytes = (nbits / 8);
+    if (t.bits() == 32) {
+      // In our Vulkan runtime, each push constant always occupies 64 bit.
+      offset += bytes * 2;
+    } else {
+      ICHECK_EQ(t.bits(), 64);
+      offset += bytes;
+    }
   }
   // Decorate push constants as UBO
   this->Decorate(spv::OpDecorate, struct_type, spv::DecorationBlock);
diff --git a/tests/python/unittest/test_target_codegen_spirv.py b/tests/python/unittest/test_target_codegen_spirv.py
index 2cbf0bea9257..68be5c480358 100644
--- a/tests/python/unittest/test_target_codegen_spirv.py
+++ b/tests/python/unittest/test_target_codegen_spirv.py
@@ -17,6 +17,7 @@
 import tvm
 import tvm.testing
 from tvm import te
+from tvm import relay
 from tvm.topi.math import cast
 import numpy as np
 
@@ -71,5 +72,38 @@ def do_copy(A, B, n):
     tvm.testing.assert_allclose(b.asnumpy(), ref)
 
 
+def test_pushconstants():
+    if not tvm.testing.device_enabled("vulkan"):
+        return
+
+    def check_mod(mod, x_np, res_np):
+        target = "vulkan"
+        ctx = tvm.context(target, 0)
+        ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
+        res = ex.evaluate()(x_np).asnumpy()
+        tvm.testing.assert_allclose(res, res_np, atol=1e-5)
+
+    # Three 32 bit pushconstants: any_dim, stride, stride
+    dtype = "float32"
+    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([x], relay.sqrt(x))
+    x_np = np.random.uniform(size=(10,)).astype(dtype)
+    res_np = np.sqrt(x_np)
+
+    check_mod(mod, x_np, res_np)
+
+    # One 64 bit and one 32 bit constants
+    dtype = "int32"
+    x = relay.var("x", shape=(relay.Any(),), dtype=dtype)
+    mod = tvm.IRModule()
+    mod["main"] = relay.Function([x], relay.argsort(x))
+    x_np = np.random.randint(0, high=10, size=(10,)).astype(dtype)
+    res_np = np.argsort(x_np)
+
+    check_mod(mod, x_np, res_np)
+
+
 if __name__ == "__main__":
     test_bool_load()
+    test_pushconstants()