diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc
index 6a6712a4ce26..4c0107e43802 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -349,6 +349,47 @@ void CodeGenWebGPU::VisitExpr_(const BufferLoadNode* op, std::ostream& os) { //
   }
 }
 
+void CodeGenWebGPU::VisitExpr_(const ShuffleNode* op, std::ostream& os) {  // NOLINT(*)
+  // First concatenate op->vectors, then return those indexed by op->indices.
+  // 1. concat_vec = concat(op->vectors)
+  // 2. Return (concat_vec[op->indices[0]], concat_vec[op->indices[1]], ...)
+
+  // 1. Print each expr first, in case it contains nested expressions of its own.
+  //    e.g. if op->vectors is [f32 a, vec3<f32> b, vec2<f32> c]
+  //    then concat_vec is ["a", "b[0]", "b[1]", "b[2]", "c[0]", "c[1]"]
+  std::vector<std::string> concat_vec;
+  for (PrimExpr vec : op->vectors) {
+    std::string vec_value = this->PrintExpr(vec);
+    if (vec.dtype().lanes() == 1) {
+      concat_vec.push_back(vec_value);
+    } else {
+      // Print out each element of vec
+      for (int i = 0; i < vec.dtype().lanes(); ++i) {
+        std::ostringstream vec_elem_strm;
+        vec_elem_strm << vec_value << "[" << i << "]";
+        concat_vec.push_back(vec_elem_strm.str());
+      }
+    }
+  }
+
+  // 2. Print out shuffle
+  if (op->indices.size() == 1) {
+    // If only accessing one element (ExtractElement), directly print the value
+    // e.g. if op->indices is [1], then print "b[0]"
+    os << concat_vec[Downcast<IntImm>(op->indices[0])->value];
+  } else {
+    // Otherwise, print the shuffle as a vector constructor
+    // e.g. if op->indices is [0, 3, 5], then print "vec3<f32>(a, b[2], c[1])"
+    PrintType(op->dtype, os);
+    os << '(';
+    for (size_t i = 0; i < op->indices.size(); ++i) {
+      if (i != 0) os << ", ";
+      os << concat_vec[Downcast<IntImm>(op->indices[i])->value];
+    }
+    os << ')';
+  }
+}
+
 void CodeGenWebGPU::VisitStmt_(const LetStmtNode* op) {
   // use ssa form.
   if (print_ssa_form_) {
diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h
index 6ae942a3ad49..53247866252c 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -67,6 +67,7 @@ class CodeGenWebGPU final : public CodeGenC {
   void VisitExpr_(const SelectNode* op, std::ostream& os) override;  // NOLINT(*)
   void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;   // NOLINT(*)
   void VisitExpr_(const IntImmNode* op, std::ostream& os) final;     // NOLINT(*)
+  void VisitExpr_(const ShuffleNode* op, std::ostream& os) final;    // NOLINT(*)
 
   // stmt printing
   void VisitStmt_(const LetStmtNode* op) final;