From 6c15b261d90a1fd39b4bba1e494c97a298c0fce4 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 12 Mar 2025 23:53:02 -0400 Subject: [PATCH] [Codegen] Support codegen for vectorized tir.ShuffleNode This PR introduces the support for vectorized tir.ShuffleNode, which is useful for extracting bits and converting to float4, since float4 is sub-byte. Prior to this PR, ShuffleNode is not supported in vectorization. This PR allows vectorizing ShuffleNode subject to special patterns, and still throws error for ShuffleNodes that don't meet the pattern requirements. --- src/target/source/codegen_c.cc | 36 +++++++-- src/target/source/literal/cuda_half_t.h | 44 ++++++----- src/tir/ir/expr_functor.cc | 5 +- src/tir/transforms/vectorize_loop.cc | 68 +++++++++++++++- .../codegen/test_target_codegen_cuda_fp4.py | 79 +++++++++++++++++++ 5 files changed, 204 insertions(+), 28 deletions(-) diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index 575f52e2257a..a67cb80b917b 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -943,22 +943,43 @@ void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { // NOLINT( // NOTE: important to print expr first // in case each expr have their own nested expressions // print each elements - for (const PrimExpr& vec : op->vectors) { - std::string vec_value = this->PrintExpr(vec); - if (vec.dtype().lanes() == 1) { + if (op->vectors.size() > 1) { + for (const PrimExpr& vec : op->vectors) { + std::string vec_value = this->PrintExpr(vec); + if (vec.dtype().lanes() == 1) { + concat_vec.push_back(vec_value); + } else { + // print out each element + for (int i = 0; i < vec.dtype().lanes(); ++i) { + // access i-th element of each vector + std::ostringstream vec_elem_strm; + vec_elem_strm << vec_value << "[" << i << "]"; + concat_vec.push_back(vec_elem_strm.str()); + } + } + } + } else { + // Extract elements from a single vector-type value. + std::string vec_value = "(" + this->PrintExpr(op->vectors[0]) + ")"; + if (op->vectors[0].dtype().lanes() == 1) { concat_vec.push_back(vec_value); } else { // print out each element - for (int i = 0; i < vec.dtype().lanes(); ++i) { + for (int i = 0; i < op->vectors[0].dtype().lanes(); ++i) { // access i-th element of each vector std::ostringstream vec_elem_strm; - vec_elem_strm << vec_value << "[" << i << "]"; + PrintVecElemLoad(vec_value, op->vectors[0].dtype(), i, vec_elem_strm); concat_vec.push_back(vec_elem_strm.str()); } } } if (op->indices.size() == 1) { // This is an extract element + CHECK(op->indices[0]->IsInstance()) + << "The ShuffleNode indices are expected to be constants at codegen time. However, " + << "a non-constant index is " << op->indices[0] + << ". Please avoid using ShuffleNode or eliminate the ShuffleNode with loop unroll or " + << "vectorize."; int64_t idx = Downcast(op->indices[0])->value; ICHECK_LT(idx, concat_vec.size()); os << concat_vec[idx]; @@ -969,6 +990,11 @@ void CodeGenC::VisitExpr_(const ShuffleNode* op, std::ostream& os) { // NOLINT( os << '('; for (size_t i = 0; i < op->indices.size(); ++i) { if (i != 0) os << ", "; + CHECK(op->indices[i]->IsInstance()) + << "The ShuffleNode indices are expected to be constants at codegen time. However, " + << "a non-constant index is " << op->indices[i] + << ". Please avoid using ShuffleNode or eliminate the ShuffleNode with loop unroll or " + << "vectorize."; os << concat_vec[Downcast(op->indices[i])->value]; } os << ')'; diff --git a/src/target/source/literal/cuda_half_t.h b/src/target/source/literal/cuda_half_t.h index b095f5b8cf20..039d89b93feb 100644 --- a/src/target/source/literal/cuda_half_t.h +++ b/src/target/source/literal/cuda_half_t.h @@ -454,26 +454,6 @@ struct __align__(8) half4_bfloat164 { (static_cast<__uint32_t>(lo_part.__x) | (static_cast<__uint32_t>(hi_part.__x) << 16)); return result; } - __device__ __nv_fp8x2_e5m2 make_fp8x2_e5m2(__nv_fp8_storage_t x, __nv_fp8_storage_t y) { - __nv_fp8x2_e5m2 result; - result.__x = (x) | (y << 8); - return result; - } - __device__ __nv_fp8x4_e5m2 make_fp8x4_e5m2(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) { - __nv_fp8x4_e5m2 result; - result.__x = (a) | (b << 8) | (c << 16) | (d << 24); - return result; - } - __device__ __nv_fp8x2_e4m3 make_fp8x2_e4m3(__nv_fp8_storage_t x, __nv_fp8_storage_t y) { - __nv_fp8x2_e4m3 result; - result.__x = (x) | (y << 8); - return result; - } - __device__ __nv_fp8x4_e4m3 make_fp8x4_e4m3(__nv_fp8_storage_t a, __nv_fp8_storage_t b, __nv_fp8_storage_t c, __nv_fp8_storage_t d) { - __nv_fp8x4_e4m3 result; - result.__x = (a) | (b << 8) | (c << 16) | (d << 24); - return result; - } )"; } if (enable_fp4) { @@ -542,6 +522,30 @@ __host__ __device__ nv_bfloat162 cast_to_nv_bfloat162(const __nv_fp8x2_e4m3& fp8 )"; } } + if (enable_fp8) { + stream << R"( +__device__ __nv_fp8x2_e5m2 make___nv_fp8x2_e5m2(__nv_fp8_e5m2 x, __nv_fp8_e5m2 y) { + __nv_fp8x2_e5m2 result; + result.__x = (x.__x) | (y.__x << 8); + return result; +} +__device__ __nv_fp8x4_e5m2 make___nv_fp8x4_e5m2(__nv_fp8_e5m2 a, __nv_fp8_e5m2 b, __nv_fp8_e5m2 c, __nv_fp8_e5m2 d) { + __nv_fp8x4_e5m2 result; + result.__x = (a.__x) | (b.__x << 8) | (c.__x << 16) | (d.__x << 24); + return result; +} +__device__ __nv_fp8x2_e4m3 make___nv_fp8x2_e4m3(__nv_fp8_e4m3 x, __nv_fp8_e4m3 y) { + __nv_fp8x2_e4m3 result; + result.__x = (x.__x) | (y.__x << 8); + return result; +} +__device__ __nv_fp8x4_e4m3 make___nv_fp8x4_e4m3(__nv_fp8_e4m3 a, __nv_fp8_e4m3 b, __nv_fp8_e4m3 c, __nv_fp8_e4m3 d) { + __nv_fp8x4_e4m3 result; + result.__x = (a.__x) | (b.__x << 8) | (c.__x << 16) | (d.__x << 24); + return result; +} +)"; + } if (enable_fp4) { stream << R"( __device__ __nv_fp4x2_e2m1 make___nv_fp4x2_e2m1(__nv_fp4_e2m1 x, __nv_fp4_e2m1 y) { diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc index 34b46583d5ad..3c117b58a7a3 100644 --- a/src/tir/ir/expr_functor.cc +++ b/src/tir/ir/expr_functor.cc @@ -279,10 +279,11 @@ PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) { PrimExpr ExprMutator::VisitExpr_(const ShuffleNode* op) { auto fexpr = [this](const PrimExpr& e) { return this->VisitExpr(e); }; auto vectors = op->vectors.Map(fexpr); - if (vectors.same_as(op->vectors)) { + auto indices = op->indices.Map(fexpr); + if (vectors.same_as(op->vectors) && indices.same_as(op->indices)) { return GetRef(op); } else { - return Shuffle(vectors, op->indices); + return Shuffle(vectors, indices); } } diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index ec290e48d457..58ce6d61742a 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -503,7 +503,11 @@ class Vectorizer : public StmtMutator, public ExprFunctordtype.with_scalable_vscale_factor(lanes), op->op, {value}); } else { - return Call(op->dtype.with_lanes(lanes), op->op, {value}); + int new_lanes = (op->dtype != DataType::NVFloat4E2M1FN() && + op->args[0].dtype() != DataType::NVFloat4E2M1FN()) + ? (value.dtype().bits() * value.dtype().lanes()) / op->dtype.bits() + : value.dtype().lanes(); + return Call(op->dtype.with_lanes(new_lanes), op->op, {value}); } } } @@ -624,6 +628,68 @@ class Vectorizer : public StmtMutator, public ExprFunctorvectors.size() == 1 && op->indices.size() == 1) + << "Cannot vectorize ShuffleNode with multiple vectors or indices: the vector size is " + << op->vectors.size() << " and the index size is " << op->indices.size(); + int lane_vectors = 0; + int lane_indices = 0; + Array vectors = MutateArray(op->vectors, &lane_vectors); + Array indices = MutateArray(op->indices, &lane_indices); + if (vectors.same_as(op->vectors) && indices.same_as(op->indices)) { + return GetRef(op); + } + + int new_vec_length = Downcast(var_lanes_)->value / op->vectors[0].dtype().lanes(); + PrimExpr updated_index = indices[0]; + // Check that the indices satisfy the specific patterns. + auto f_check_index = [this, op](const PrimExpr& index) { + // Allowing Ramp(0, 1, var_lanes_) + if (const auto* ramp = index.as()) { + if (ramp->base->IsInstance() && Downcast(ramp->base)->value == 0 && + ramp->stride->IsInstance() && Downcast(ramp->stride)->value == 1 && + ramp->lanes->IsInstance() && + Downcast(ramp->lanes)->value == Downcast(var_lanes_)->value) { + return true; + } + } + // Allowing FloorMod(Ramp(0, 1, var_lanes_), Broadcast(op->vectors[0]->lanes, var_lanes_)) + if (const auto* floordiv = index.as()) { + if (const auto* ramp = floordiv->a.as()) { + if (const auto* broadcast = floordiv->b.as()) { + if (ramp->base->IsInstance() && Downcast(ramp->base)->value == 0 && + ramp->stride->IsInstance() && + Downcast(ramp->stride)->value == 1 && + ramp->lanes->IsInstance() && + Downcast(ramp->lanes)->value == Downcast(var_lanes_)->value && + broadcast->value->IsInstance() && + Downcast(broadcast->value)->value == op->vectors[0]->dtype.lanes() && + broadcast->lanes->IsInstance() && + Downcast(broadcast->lanes)->value == Downcast(var_lanes_)->value) { + return true; + } + } + } + } + + return false; + }; + CHECK(f_check_index(updated_index)); + + if (new_vec_length == 1) { + return tir::Substitute(op->vectors[0], {{var_, tvm::IntImm(var_->dtype, 0)}}); + } else { + PrimExpr prev_ramp = ramp_; + PrimExpr prev_var_lanes = var_lanes_; + ramp_ = Ramp(IntImm(var_->dtype, 0), IntImm(var_->dtype, 2), new_vec_length); + var_lanes_ = tvm::IntImm(var_lanes_.dtype(), new_vec_length); + lane_vectors = 0; + vectors = MutateArray(op->vectors, &lane_vectors); + ramp_ = prev_ramp; + var_lanes_ = prev_var_lanes; + return vectors[0]; + } + } // BufferStore Stmt VisitStmt_(const BufferStoreNode* op) final { auto store = GetRef(op); diff --git a/tests/python/codegen/test_target_codegen_cuda_fp4.py b/tests/python/codegen/test_target_codegen_cuda_fp4.py index 0a170026c96b..14820ec34f09 100644 --- a/tests/python/codegen/test_target_codegen_cuda_fp4.py +++ b/tests/python/codegen/test_target_codegen_cuda_fp4.py @@ -211,5 +211,84 @@ def reinterpret( ) +@tvm.testing.requires_cuda_compute_version(10) +def test_e2m1_dequantize(): + n = 128 + + dev = tvm.device("cuda", 0) + target = tvm.target.Target.from_device(dev) + num_elem_per_storage = 32 // 4 + + def get_reinterpret_mod(func_type, vector_length): + @T.prim_func + def shuffle_reinterpret( + A: T.Buffer((n // num_elem_per_storage,), "uint32"), + B: T.Buffer((n,), "float16"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(n): + with T.block("C"): + v_i = T.axis.spatial(n, i) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = T.Shuffle( + [ + T.reinterpret( + "float4_e2m1fnx2", + T.bitwise_and( + T.shift_right( + A[v_i // num_elem_per_storage], + ((v_i % num_elem_per_storage) // 2 * 4 * 2).astype( + "uint32" + ), + ), + T.uint32((1 << (4 * 2)) - 1), + ).astype("uint8"), + ).astype("float16x2") + ], + indices=[v_i % 2], + ) + + @T.prim_func + def scalar_reinterpret( + A: T.Buffer((n // num_elem_per_storage,), "uint32"), + B: T.Buffer((n,), "float16"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + for i in range(n): + with T.block("C"): + v_i = T.axis.spatial(n, i) + T.reads(A[v_i]) + T.writes(B[v_i]) + B[v_i] = T.reinterpret( + "float4_e2m1fn", + T.bitwise_and( + T.shift_right( + A[v_i // num_elem_per_storage], + (v_i % num_elem_per_storage * 4).astype("uint32"), + ), + T.uint32((1 << 4) - 1), + ).astype("uint8"), + ).astype("float16") + + func = shuffle_reinterpret if func_type == "shuffle" else scalar_reinterpret + sch = tvm.tir.Schedule(func) + block = sch.get_block("C") + b = sch.get_loops(block) + bx, tx, vec = sch.split(b[0], factors=[None, 32, vector_length]) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + return sch.mod + + # We only test the whether the code can be compiled. + for func_type, vector_length in product(["shuffle", "scalar"], [1, 2, 4]): + if func_type == "shuffle" and vector_length == 1: + # Vectorize is necessary for shuffle. + continue + mod = get_reinterpret_mod(func_type, vector_length) + tvm.compile(mod, target=target) + + if __name__ == "__main__": tvm.testing.main()