From 31903241ed80d6166d1c3cffbf80c9605c9c12c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Ser=C3=B3dio?= Date: Thu, 14 Nov 2024 03:04:31 -0300 Subject: [PATCH 1/3] increase precision when printing 64 bit constants (FloatImmNode visit) in CUDA codegen --- src/target/source/codegen_cuda.cc | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index bd2804830172..eae25fe048ab 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1381,26 +1381,38 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) } // Type code is kFloat switch (op->dtype.bits()) { - case 64: - case 32: { + case 64: { std::ostringstream temp; if (std::isinf(op->value)) { if (op->value < 0) { temp << "-"; } - temp << ((op->dtype.bits() == 32) ? "CUDART_INF_F" : "CUDART_INF"); + temp << "CUDART_INF"; p->need_math_constants_h_ = true; } else if (std::isnan(op->value)) { - temp << ((op->dtype.bits() == 32) ? "CUDART_NAN_F" : "CUDART_NAN"); + temp << "CUDART_NAN"; p->need_math_constants_h_ = true; } else { - temp << std::scientific << op->value; - if (op->dtype.bits() == 32) temp << 'f'; + temp << std::fixed << std::setprecision(15) << op->value; } p->MarkConst(temp.str()); os << temp.str(); break; } + case 32: { + std::ostringstream temp; + if (std::isinf(op->value)) { + if (op->value < 0) { + temp << "-"; + } + temp << "CUDART_INF_F"; + p->need_math_constants_h_ = true; + } else if (std::isnan(op->value)) { + temp << "CUDART_NAN_F"; + p->need_math_constants_h_ = true; + } else { + temp << std::scientific << op->value << 'f'; + } case 16: { os << "__float2half_rn" << '('; FloatImm const_f32 = FloatImm(DataType::Float(32), op->value); From ec4b856516d0c64f4f9622af68e3b51387d38d03 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Ser=C3=B3dio?= Date: Tue, 26 Nov 2024 12:57:19 -0300 Subject: [PATCH 2/3] add missing bracket from switch case --- src/target/source/codegen_cuda.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index eae25fe048ab..4dc5d8d406aa 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1413,6 +1413,7 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) } else { temp << std::scientific << op->value << 'f'; } + } case 16: { os << "__float2half_rn" << '('; FloatImm const_f32 = FloatImm(DataType::Float(32), op->value); From 901c37ea327e36d923dcf2db1b7707c9e8e64bcc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Ser=C3=B3dio?= Date: Tue, 26 Nov 2024 16:54:26 -0300 Subject: [PATCH 3/3] fix switch case 32 --- src/target/source/codegen_cuda.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 4dc5d8d406aa..826b3d94e0d9 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -1413,6 +1413,9 @@ inline void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p) } else { temp << std::scientific << op->value << 'f'; } + p->MarkConst(temp.str()); + os << temp.str(); + break; } case 16: { os << "__float2half_rn" << '(';