diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 6311b435f197..24608ebc93f4 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -45,10 +45,15 @@ std::vector<uint32_t> CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::
       if (auto* ptr = arg->type_annotation.as<PointerTypeNode>()) {
         auto* prim = ptr->element_type.as<PrimTypeNode>();
         ICHECK(prim);
-        DataType value_type = prim->dtype;
+        DataType value_storage_type = prim->dtype;
+        if (value_storage_type == DataType::UInt(1)) {
+          // We need a physically addressable buffer type to support boolean tensors.
+          // The loaded byte is cast to bool inside the LoadNode visitor below.
+          value_storage_type = DataType::UInt(8);
+        }
         spirv::Value arg_value =
-            builder_->BufferArgument(builder_->GetSType(value_type), 0, num_buffer);
-        storage_info_[arg.get()].UpdateContentType(value_type);
+            builder_->BufferArgument(builder_->GetSType(value_storage_type), 0, num_buffer);
+        storage_info_[arg.get()].UpdateContentType(value_storage_type);
         var_map_[arg.get()] = arg_value;
       } else {
         LOG(FATAL) << "require all handles to be typed";
@@ -369,11 +374,18 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const LoadNode* op) {
     mask |= spv::MemoryAccessVolatileMask;
   }
   if (op->dtype.lanes() == 1) {
-    ICHECK_EQ(info.content_type, op->dtype)
-        << "Vulkan only allow one type access to the same buffer";
     spirv::Value index = MakeValue(op->index);
     spirv::Value ptr = builder_->StructArrayAccess(ptr_type, buffer, index);
-    return builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
+    spirv::Value loaded = builder_->MakeValue(spv::OpLoad, content_type, ptr, mask);
+    if (op->dtype == DataType::UInt(1)) {
+      // A bool tensor is backed by a byte buffer; we cast the loaded byte to bool here.
+      auto bool_ty = builder_->GetSType(DataType::UInt(1));
+      return builder_->Cast(bool_ty, loaded);
+    } else {
+      ICHECK_EQ(info.content_type, op->dtype)
+          << "Vulkan only allow one type access to the same buffer";
+      return loaded;
+    }
   } else {
     if (op->dtype.element_of() == info.content_type) {
       // because content type is element type, we can only do scalarize load.
diff --git a/tests/python/unittest/test_target_codegen_spirv.py b/tests/python/unittest/test_target_codegen_spirv.py
new file mode 100644
index 000000000000..2cbf0bea9257
--- /dev/null
+++ b/tests/python/unittest/test_target_codegen_spirv.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import te
+from tvm.topi.math import cast
+import numpy as np
+
+
+def test_bool_load():
+    def do_copy(A, B, n):
+        ib = tvm.tir.ir_builder.create()
+        A = ib.buffer_ptr(A)
+        B = ib.buffer_ptr(B)
+
+        tx = te.thread_axis("threadIdx.x")
+        bx = te.thread_axis("blockIdx.x")
+
+        max_threads = 32
+        ib.scope_attr(bx, "thread_extent", tvm.tir.indexdiv(n + max_threads - 1, max_threads))
+        ib.scope_attr(tx, "thread_extent", max_threads)
+        tid = bx * max_threads + tx
+
+        with ib.if_scope(tid < n):
+            B[tid] = cast(A[tid], "int32")
+
+        return ib.get()
+
+    n = 1024
+    A = te.placeholder((n,), name="A", dtype="bool")
+    B = te.placeholder((n,), name="B", dtype="int32")
+
+    target = "vulkan"
+
+    if not tvm.testing.device_enabled(target):
+        return
+
+    B = te.extern(
+        A.shape,
+        [A],
+        lambda ins, outs: do_copy(ins[0], outs[0], n),
+        name="bool_copy_ir",
+        dtype="int32",
+    )
+    s = te.create_schedule(B.op)
+
+    with tvm.transform.PassContext(opt_level=3):
+        func = tvm.build(s, [A, B], target)
+
+    ctx = tvm.context(target, 0)
+    a_np = np.random.uniform(size=n) > 0.5
+    b_np = np.zeros((n,), dtype="int32")
+    a = tvm.nd.array(a_np, ctx)
+    b = tvm.nd.array(b_np, ctx)
+    func(a, b)
+    ref = a_np.astype(np.int32)
+    tvm.testing.assert_allclose(b.asnumpy(), ref)
+
+
+if __name__ == "__main__":
+    test_bool_load()
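
Illustration (not part of the patch): the new UInt(1)-as-UInt(8) load path is also hit by ordinary te.compute workloads, not just ir_builder kernels, since any cast from a bool placeholder lowers to a LoadNode on the byte-backed buffer. Below is a minimal sketch of that path, using the same era of TVM API as the test above (tvm.context, asnumpy); the tensor name "C" and the split factor 32 are illustrative assumptions, not anything this change prescribes.

    import numpy as np
    import tvm
    import tvm.testing
    from tvm import te

    n = 1024
    A = te.placeholder((n,), name="A", dtype="bool")
    # A[i].astype("int32") lowers to Cast(Load) on the bool buffer, which the
    # SPIR-V codegen now declares as a UInt(8) buffer and casts back to bool.
    C = te.compute((n,), lambda i: A[i].astype("int32"), name="C")

    s = te.create_schedule(C.op)
    bx, tx = s[C].split(C.op.axis[0], factor=32)
    s[C].bind(bx, te.thread_axis("blockIdx.x"))
    s[C].bind(tx, te.thread_axis("threadIdx.x"))

    if tvm.testing.device_enabled("vulkan"):
        func = tvm.build(s, [A, C], "vulkan")
        ctx = tvm.context("vulkan", 0)
        a = tvm.nd.array(np.random.uniform(size=n) > 0.5, ctx)
        c = tvm.nd.array(np.zeros(n, dtype="int32"), ctx)
        func(a, c)
        np.testing.assert_array_equal(c.asnumpy(), a.asnumpy().astype("int32"))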