From fb416efefcb69a99f131f1342cc37adea9b7c3d5 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 25 Jul 2023 03:46:35 -0400 Subject: [PATCH] [Codegen][Metal] Support metal warp-level primitive This PR introduces the warp-level shuffle primitives used in Metal Shading Language, and uses them in the implementation of allreduce lowering. The introduced primitives are: * `simd_shuffle`, * `simd_shuffle_up`, * `simd_shuffle_down`. See section 6.9.2 of https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf for details. The correctness are validated by `test_allreduce_cuda` with the backend changed to Metal. Given we do not have Metal CI tests, the correctness is checked only locally. Given the Metal shuffle primitives do not support (or need) masking, the pass LowerThreadAllreduce is updated to support such backend which does not have masks. One unit test for metal is added to ensure that no mask is used. --- src/target/source/intrin_rule_metal.cc | 53 +++++++++ src/tir/transforms/lower_thread_allreduce.cc | 35 ++++-- ...t_tir_transform_lower_thread_all_reduce.py | 103 ++++++++++++++++++ 3 files changed, 180 insertions(+), 11 deletions(-) diff --git a/src/target/source/intrin_rule_metal.cc b/src/target/source/intrin_rule_metal.cc index dd924b925596..cc83eb1462c6 100644 --- a/src/target/source/intrin_rule_metal.cc +++ b/src/target/source/intrin_rule_metal.cc @@ -30,6 +30,28 @@ namespace codegen { namespace intrin { using tir::FLowerIntrinsic; +struct MetalWarpIntrinsic { + const Op operator()(DataType t, const Op& orig_op) const { + if (orig_op.same_as(builtin::tvm_warp_shuffle())) { + return Op::Get("tir.metal.simd_shuffle"); + } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) { + return Op::Get("tir.metal.simd_shuffle_up"); + } else { + ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down())); + return Op::Get("tir.metal.simd_shuffle_down"); + } + } +}; + +template +static PrimExpr DispatchMetalShuffle(const PrimExpr& e) { + const CallNode* call = e.as(); + ICHECK(call != nullptr); + ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size + Array metal_args{{call->args[1], call->args[2]}}; + return Call(call->dtype, T()(call->dtype, Downcast(call->op)), metal_args); +} + TVM_REGISTER_OP("tir.floor") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); @@ -95,6 +117,37 @@ TVM_REGISTER_OP("tir.cosh") TVM_REGISTER_OP("tir.erf").set_attr("metal.FLowerIntrinsic", DispatchFastErf); +TVM_REGISTER_OP("tir.tvm_warp_shuffle") + .set_attr("metal.FLowerIntrinsic", DispatchMetalShuffle); + +TVM_REGISTER_OP("tir.tvm_warp_shuffle_up") + .set_attr("metal.FLowerIntrinsic", DispatchMetalShuffle); + +TVM_REGISTER_OP("tir.tvm_warp_shuffle_down") + .set_attr("metal.FLowerIntrinsic", DispatchMetalShuffle); + +// Register low-level builtin ops. +TVM_REGISTER_OP("tir.metal.simd_shuffle") + .set_num_inputs(2) + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("lane", "Expr", "The source thread id.") + .set_attr("TGlobalSymbol", "simd_shuffle") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TVM_REGISTER_OP("tir.metal.simd_shuffle_up") + .set_num_inputs(2) + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("delta", "Expr", "The source lane id offset to be added.") + .set_attr("TGlobalSymbol", "simd_shuffle_up") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + +TVM_REGISTER_OP("tir.metal.simd_shuffle_down") + .set_num_inputs(2) + .add_argument("var", "Expr", "The variable to sync.") + .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") + .set_attr("TGlobalSymbol", "simd_shuffle_down") + .set_attr("TCallEffectKind", Integer(CallEffectKind::kOpaque)); + } // namespace intrin } // namespace codegen } // namespace tvm diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 438dccff0bdb..fba62a0c18ac 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -476,12 +476,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // The mask for this reducer, as this reducer may sit inside // a divergent control flow. Here it uses a variable to cache the current // active channels. - Buffer mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local"); - { - seq->emplace_back(BufferStore(mask_buffer, mask, zero_indices)); + Optional mask_buffer; + if (need_warp_shuffle_mask_) { + mask_buffer = decl_buffer(shape, mask->dtype, "mask", "local"); + seq->emplace_back(BufferStore(mask_buffer.value(), mask, zero_indices)); // Push the buffer description. Later this will have an // allocation built for it. - local_bufs.push_back(mask_buffer); + local_bufs.push_back(mask_buffer.value()); } // Emit reductions within a warp. @@ -698,9 +699,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } // Emit warp shuffle calls. - PrimExpr WarpShuffle(const Op& op, Buffer mask_buffer, PrimExpr val, PrimExpr delta_or_lane) { + PrimExpr WarpShuffle(const Op& op, Optional mask_buffer, PrimExpr val, + PrimExpr delta_or_lane) { Array indices = {0}; - PrimExpr mask = BufferLoad(mask_buffer, indices); + PrimExpr mask; + if (mask_buffer.defined()) { + mask = BufferLoad(mask_buffer.value(), indices); + } else { + mask = IntImm(DataType::Int(32), 0); + } PrimExpr width = IntImm(DataType::Int(32), warp_size_); Array args{mask, val, delta_or_lane, width, width}; return Call(val.dtype(), op, args); @@ -709,11 +716,15 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // Check if we can use warp level reduction. // // Note: The ROCm backend will only have warp reductions for now. - // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda). + // Also, the warp/wavefront size differs (64 on rocm, 32 on cuda and metal). bool IsWarpReduction(const std::vector& types, int group_extent, int reduce_extent, - int contiguous_reduce_extent) const { - // Only cuda target supports warp reductions. - if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm")) return false; + int contiguous_reduce_extent) { + if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") && + (target_->kind->name != "metal")) { + return false; + } + + need_warp_shuffle_mask_ = target_->kind->name != "metal"; // rocm only supports 32 bit operands for shuffling at the moment if ((target_->kind->name == "rocm") && @@ -745,7 +756,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // whether reduce_extent and group_extent are valid for warp reduction. if (target_->kind->name == "rocm") { return reduce_extent == warp_size_; - } else { // target_->kind->name == "cuda" + } else { if (reduce_extent == 1) { return false; // no need to warp reduce } else { @@ -769,6 +780,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { int warp_size_{1}; // The maximum number of threads of the device. "-1" denotes unknown. int max_num_threads_{-1}; + // A boolean indicating if the target supports warp-level masking. + bool need_warp_shuffle_mask_; // surrounding scope of thread extent. std::vector thread_extents_; diff --git a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py index 9d53b1f9dfb5..f797d35d47ca 100644 --- a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py +++ b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py @@ -702,5 +702,108 @@ def expected(A: T.Buffer((2, 70), "float32"), B: T.Buffer((2,), "float32")): B_1[threadIdx_y] = red_result_1[threadIdx_y] +class TestMetalNoMask(BaseCompare): + @T.prim_func + def before(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")): + T.func_attr( + { + "target": T.target( + { + "kind": "metal", + "max_threads_per_block": 1024, + "thread_warp_size": 32, + "host": "llvm", + } + ), + } + ) + blockIdx_x = T.launch_thread("blockIdx.x", 1) + cross_thread_B = T.allocate([1], "float32", "local") + threadIdx_z = T.launch_thread("threadIdx.z", 1) + threadIdx_y = T.launch_thread("threadIdx.y", 2) + threadIdx_x = T.launch_thread("threadIdx.x", 128) + cross_thread_B_1 = T.Buffer((1,), data=cross_thread_B, scope="local") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + A_1 = T.Buffer((256,), data=A.data) + T.tvm_thread_allreduce( + T.uint32(1), + A_1[threadIdx_y * 128 + threadIdx_x], + T.bool(True), + cross_thread_B_1[0], + threadIdx_x, + ) + if threadIdx_x == 0: + B_1 = T.Buffer((2,), data=B.data) + B_1[threadIdx_y] = cross_thread_B_1[0] + + @T.prim_func + def expected(A: T.Buffer((1, 1, 2, 128), "float32"), B: T.Buffer((1, 1, 2), "float32")): + T.func_attr( + { + "target": T.target( + { + "kind": "metal", + "max_threads_per_block": 1024, + "thread_warp_size": 32, + "host": "llvm", + } + ), + } + ) + blockIdx_x = T.launch_thread("blockIdx.x", 1) + red_result = T.allocate([2], "float32", "shared") + T.attr(red_result, "volatile_scope", 1) + threadIdx_z = T.launch_thread("threadIdx.z", 1) + threadIdx_y = T.launch_thread("threadIdx.y", 2) + threadIdx_x = T.launch_thread("threadIdx.x", 128) + red_result_1 = T.Buffer((2,), data=red_result, scope="shared") + with T.attr( + T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]), + "reduce_scope", + T.reinterpret("handle", T.uint64(0)), + ): + red_buf0 = T.allocate([1], "float32", "local") + t0 = T.allocate([1], "float32", "local") + red_buf0_1 = T.allocate([1], "float32", "local") + t0_1 = T.allocate([1], "float32", "local") + red_buf_staging = T.allocate([8], "float32", "shared") + red_buf0_2 = T.Buffer((1,), data=red_buf0_1, scope="local") + A_1 = T.Buffer((256,), data=A.data) + red_buf0_2[0] = A_1[threadIdx_y * 128 + threadIdx_x] + t0_2 = T.Buffer((1,), data=t0_1, scope="local") + t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 16, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 8, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 4, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 2, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + t0_2[0] = T.tvm_warp_shuffle_down(0, red_buf0_2[0], 1, 32, 32) + red_buf0_2[0] = red_buf0_2[0] + t0_2[0] + red_buf_staging_1 = T.Buffer((8,), data=red_buf_staging, scope="shared") + if threadIdx_x % 32 == 0: + red_buf_staging_1[threadIdx_y * 4 + threadIdx_x // 32] = red_buf0_2[0] + T.tvm_storage_sync("shared") + red_buf0_3 = T.Buffer((1,), data=red_buf0, scope="local") + if threadIdx_x < 4: + red_buf0_3[0] = red_buf_staging_1[threadIdx_y * 4 + threadIdx_x] + t0_3 = T.Buffer((1,), data=t0, scope="local") + t0_3[0] = T.tvm_warp_shuffle_down(0, red_buf0_3[0], 2, 32, 32) + red_buf0_3[0] = red_buf0_3[0] + t0_3[0] + t0_3[0] = T.tvm_warp_shuffle_down(0, red_buf0_3[0], 1, 32, 32) + red_buf0_3[0] = red_buf0_3[0] + t0_3[0] + if threadIdx_x == 0: + red_result_1[threadIdx_y] = red_buf0_3[0] + T.tvm_storage_sync("shared") + if threadIdx_x == 0: + B_1 = T.Buffer((2,), data=B.data) + B_1[threadIdx_y] = red_result_1[threadIdx_y] + + if __name__ == "__main__": tvm.testing.main()