diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index b76b05470d5d..83079a9f0756 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -433,6 +433,27 @@ void CodeGenWebGPU::VisitExpr_(const SelectNode* op, std::ostream& os) { // NOL << PrintExpr(op->condition) << ")"; } +void CodeGenWebGPU::VisitExpr_(const LetNode* op, std::ostream& os) { // NOLINT(*) + // use ssa form. + if (print_ssa_form_) { + std::string value = PrintExpr(op->value); + ICHECK(!var_idmap_.count(op->var.get())); + var_idmap_[op->var.get()] = value; + } else { + PrintIndent(); + std::string value = PrintExpr(op->value); + this->stream << "let " << AllocVarID(op->var.get()) << " : "; + PrintType(op->var.dtype(), this->stream); + this->stream << " = " << value << ";\n"; + } + os << PrintExpr(op->body); + // Pop the defined var from var_idmap when exiting its scope. + // We do this because it is hard to completely avoid a same LetNode appearing + // at different places. + bool removed = var_idmap_.erase(op->var.get()); + ICHECK(removed); +} + void CodeGenWebGPU::VisitExpr_(const IntImmNode* op, std::ostream& os) { // NOLINT(*) if (op->dtype.bits() == 32) { std::ostringstream temp; diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h index a100396b25a2..09f99fb88600 100644 --- a/src/target/source/codegen_webgpu.h +++ b/src/target/source/codegen_webgpu.h @@ -63,7 +63,8 @@ class CodeGenWebGPU final : public CodeGenC { void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const BufferLoadNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const CastNode* op, std::ostream& os) final; // NOLINT(*) - void VisitExpr_(const SelectNode* op, std::ostream& os) override; // NOLINT(*) + void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*) + void VisitExpr_(const LetNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const IntImmNode* op, std::ostream& os) final; // NOLINT(*)