From 59eef99f0453fdde94906b5da0636097a8a30910 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 15 Aug 2023 18:03:17 -0700 Subject: [PATCH 1/7] Do not include logging macros in installed C headers Signed-off-by: Tim Moon --- transformer_engine/common/common.h | 3 +- .../fused_softmax/scaled_masked_softmax.cu | 4 +- .../scaled_upper_triang_masked_softmax.cu | 4 +- .../common/gemm/cublaslt_gemm.cu | 2 +- .../include/transformer_engine/logging.h | 74 ------------------ transformer_engine/common/util/logging.h | 76 +++++++++++++++++++ transformer_engine/jax/csrc/logging.h | 40 ++++++++++ transformer_engine/jax/csrc/modules.h | 5 +- transformer_engine/jax/csrc/utils.h | 7 +- transformer_engine/paddle/csrc/common.h | 10 ++- transformer_engine/paddle/csrc/logging.h | 36 +++++++++ transformer_engine/pytorch/csrc/common.h | 52 +++++++------ transformer_engine/pytorch/csrc/logging.h | 29 +++++++ 13 files changed, 227 insertions(+), 115 deletions(-) delete mode 100644 transformer_engine/common/include/transformer_engine/logging.h create mode 100644 transformer_engine/common/util/logging.h create mode 100644 transformer_engine/jax/csrc/logging.h create mode 100644 transformer_engine/paddle/csrc/logging.h create mode 100644 transformer_engine/pytorch/csrc/logging.h diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 04e7b15230..0aed8ab79e 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -8,7 +8,6 @@ #define TRANSFORMER_ENGINE_COMMON_COMMON_H_ #include -#include #include #include #include @@ -22,6 +21,8 @@ #include #include "nvtx.h" +#include "./util/logging.h" + namespace transformer_engine { struct SimpleTensor { diff --git a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu index 20c75a2125..c8b1662a40 100644 --- a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu @@ -5,7 +5,6 @@ ************************************************************************/ #include -#include #include #include #include @@ -14,8 +13,9 @@ #include #include #include -#include "../utils.cuh" #include "../common.h" +#include "../utils.cuh" +#include "../util/logging.h" namespace transformer_engine { diff --git a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu index 57d88dac07..3fe46d8d13 100644 --- a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu @@ -5,7 +5,6 @@ ************************************************************************/ #include -#include #include #include #include @@ -14,8 +13,9 @@ #include #include #include -#include "../utils.cuh" #include "../common.h" +#include "../utils.cuh" +#include "../util/logging.h" namespace transformer_engine { diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7f8b0b723d..a3e80326ff 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -5,11 +5,11 @@ ************************************************************************/ #include -#include #include #include #include #include "../common.h" +#include "../util/logging.h" namespace { diff --git a/transformer_engine/common/include/transformer_engine/logging.h b/transformer_engine/common/include/transformer_engine/logging.h deleted file mode 100644 index 9ac0bbbde2..0000000000 --- a/transformer_engine/common/include/transformer_engine/logging.h +++ /dev/null @@ -1,74 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_LOGGING_H_ -#define TRANSFORMER_ENGINE_LOGGING_H_ - -#include -#include -#include -#include -#include -#include - -#define NVTE_ERROR(x) \ - do { \ - throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \ - " in function " + __func__ + ": " + x); \ - } while (false) - -#define NVTE_CHECK(x, ...) \ - do { \ - if (!(x)) { \ - NVTE_ERROR(std::string("Assertion failed: " #x ". ") + std::string(__VA_ARGS__)); \ - } \ - } while (false) - -namespace { - -inline void check_cuda_(cudaError_t status) { - if ( status != cudaSuccess ) { - NVTE_ERROR("CUDA Error: " + std::string(cudaGetErrorString(status))); - } -} - -inline void check_cublas_(cublasStatus_t status) { - if ( status != CUBLAS_STATUS_SUCCESS ) { - NVTE_ERROR("CUBLAS Error: " + std::string(cublasGetStatusString(status))); - } -} - -inline void check_cudnn_(cudnnStatus_t status) { - if ( status != CUDNN_STATUS_SUCCESS ) { - std::string message; - message.reserve(1024); - message += "CUDNN Error: "; - message += cudnnGetErrorString(status); - message += (". " - "For more information, enable cuDNN error logging " - "by setting CUDNN_LOGERR_DBG=1 and " - "CUDNN_LOGDEST_DBG=stderr in the environment."); - NVTE_ERROR(message); - } -} - -inline void check_nvrtc_(nvrtcResult status) { - if ( status != NVRTC_SUCCESS ) { - NVTE_ERROR("NVRTC Error: " + std::string(nvrtcGetErrorString(status))); - } -} - -} // namespace - -#define NVTE_CHECK_CUDA(ans) { check_cuda_(ans); } - -#define NVTE_CHECK_CUBLAS(ans) { check_cublas_(ans); } - -#define NVTE_CHECK_CUDNN(ans) { check_cudnn_(ans); } - -#define NVTE_CHECK_NVRTC(ans) { check_nvrtc_(ans); } - -#endif // TRANSFORMER_ENGINE_LOGGING_H_ diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h new file mode 100644 index 0000000000..e95ca80830 --- /dev/null +++ b/transformer_engine/common/util/logging.h @@ -0,0 +1,76 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ + +#include + +#include +#include +#include +#include + +#include "../util/string.h" + +#define NVTE_ERROR(...) \ + do { \ + throw ::std::runtime_error( \ + ::transformer_engine::concat_strings( \ + __FILE__ ":", __LINE__, \ + " in function ", __func__, ": ", \ + ::transformer_engine::concat_strings(__VA_ARGS__))); \ + } while (false) + +#define NVTE_CHECK(expr, ...) \ + do { \ + if (!(expr)) { \ + NVTE_ERROR("Assertion failed: " #expr ". ", \ + ::transformer_engine::concat_strings(__VA_ARGS__)); \ + } \ + } while (false) + +#define NVTE_CHECK_CUDA(expr) \ + do { \ + const cudaError_t status_NVTE_CHECK_CUDA = (expr); \ + if (status_NVTE_CHECK_CUDA != cudaSuccess) { \ + NVTE_ERROR("CUDA Error: ", \ + cudaGetErrorString(status_NVTE_CHECK_CUDA)); \ + } \ + } while (false) + +#define NVTE_CHECK_CUBLAS(expr) \ + do { \ + const cublasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \ + if (status_NVTE_CHECK_CUBLAS != CUBLAS_STATUS_SUCCESS) { \ + NVTE_ERROR("cuBLAS Error: ", \ + cublasGetStatusString(status_NVTE_CHECK_CUBLAS)); \ + } \ + } while (false) + +#define NVTE_CHECK_CUDNN(expr) \ + do { \ + const cudnnStatus_t status_NVTE_CHECK_CUDNN = (expr); \ + if ( status_NVTE_CHECK_CUDNN != CUDNN_STATUS_SUCCESS ) { \ + NVTE_ERROR("cuDNN Error: ", \ + cudnnGetErrorString(status_NVTE_CHECK_CUDNN), \ + ". " \ + "For more information, enable cuDNN error logging " \ + "by setting CUDNN_LOGERR_DBG=1 and " \ + "CUDNN_LOGDEST_DBG=stderr in the environment."); \ + } \ + } while (false) + +#define NVTE_CHECK_NVRTC(expr) \ + do { \ + const nvrtcResult status_NVTE_CHECK_NVRTC = (expr); \ + if (status_NVTE_CHECK_NVRTC != NVRTC_SUCCESS) { \ + NVTE_ERROR("NVRTC Error: ", \ + nvrtcGetErrorString(status_NVTE_CHECK_NVRTC)); \ + } \ + } while (false) + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ diff --git a/transformer_engine/jax/csrc/logging.h b/transformer_engine/jax/csrc/logging.h new file mode 100644 index 0000000000..621e8a4f78 --- /dev/null +++ b/transformer_engine/jax/csrc/logging.h @@ -0,0 +1,40 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_JAX_CSRC_LOGGING_H_ +#define TRANSFORMER_ENGINE_JAX_CSRC_LOGGING_H_ + +#include +#include + +#include + +#define NVTE_ERROR(message) \ + do { \ + throw std::runtime_error(std::string(__FILE__ ":") \ + + std::to_string(__LINE__) \ + + " in function " + __func__ + ": " \ + + message); \ + } while (false) + +#define NVTE_CHECK(expr, ...) \ + do { \ + if (!(expr)) { \ + NVTE_ERROR(std::string("Assertion failed: " #x ". ") \ + + std::string(__VA_ARGS__)); \ + } \ + } while (false) + +#define NVTE_CHECK_CUDA(expr) \ + do { \ + const cudaError_t status_NVTE_CHECK_CUDA = (expr); \ + if (status_NVTE_CHECK_CUDA != cudaSuccess) { \ + NVTE_ERROR(std::string("CUDA Error: ") \ + + cudaGetErrorString(status_NVTE_CHECK_CUDA)); \ + } \ + } while (false) + +#endif // TRANSFORMER_ENGINE_JAX_CSRC_LOGGING_H_ diff --git a/transformer_engine/jax/csrc/modules.h b/transformer_engine/jax/csrc/modules.h index 75b4df574f..1c02690fb6 100644 --- a/transformer_engine/jax/csrc/modules.h +++ b/transformer_engine/jax/csrc/modules.h @@ -7,19 +7,18 @@ #ifndef TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ #define TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ -#include - #include #include #include #include +#include #include #include #include "transformer_engine/fused_attn.h" -#include "transformer_engine/logging.h" #include "transformer_engine/transformer_engine.h" +#include "logging.h" namespace transformer_engine { namespace jax { diff --git a/transformer_engine/jax/csrc/utils.h b/transformer_engine/jax/csrc/utils.h index 0ecd765b28..b508200b99 100644 --- a/transformer_engine/jax/csrc/utils.h +++ b/transformer_engine/jax/csrc/utils.h @@ -7,14 +7,15 @@ #ifndef TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_ #define TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_ -#include - #include #include #include #include + +#include + #include "transformer_engine/fused_attn.h" -#include "transformer_engine/logging.h" +#include "logging.h" namespace transformer_engine { namespace jax { diff --git a/transformer_engine/paddle/csrc/common.h b/transformer_engine/paddle/csrc/common.h index 3f24dac075..455bc42344 100644 --- a/transformer_engine/paddle/csrc/common.h +++ b/transformer_engine/paddle/csrc/common.h @@ -5,21 +5,23 @@ ************************************************************************/ #pragma once +#include +#include + #include +#include "paddle/extension.h" + #include #include #include #include #include -#include #include #include #include #include -#include -#include -#include "paddle/extension.h" +#include "logging.h" namespace transformer_engine { namespace paddle_ext { diff --git a/transformer_engine/paddle/csrc/logging.h b/transformer_engine/paddle/csrc/logging.h new file mode 100644 index 0000000000..8214f1d90b --- /dev/null +++ b/transformer_engine/paddle/csrc/logging.h @@ -0,0 +1,36 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ +#pragma once + +#include +#include + +#include + +#define NVTE_ERROR(message) \ + do { \ + throw std::runtime_error(std::string(__FILE__ ":") \ + + std::to_string(__LINE__) \ + + " in function " + __func__ + ": " \ + + message); \ + } while (false) + +#define NVTE_CHECK(expr, ...) \ + do { \ + if (!(expr)) { \ + NVTE_ERROR(std::string("Assertion failed: " #x ". ") \ + + std::string(__VA_ARGS__)); \ + } \ + } while (false) + +#define NVTE_CHECK_CUDA(expr) \ + do { \ + const cudaError_t status_NVTE_CHECK_CUDA = (expr); \ + if (status_NVTE_CHECK_CUDA != cudaSuccess) { \ + NVTE_ERROR(std::string("CUDA Error: ") \ + + cudaGetErrorString(status_NVTE_CHECK_CUDA)); \ + } \ + } while (false) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 7c17f1f34c..6c646a6a10 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -7,38 +7,40 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include + #include -#include -#include -#include #include -#include +#include #include #include -#include -#include +#include +#include +#include +#include #include -#include #include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "logging.h" namespace transformer_engine { diff --git a/transformer_engine/pytorch/csrc/logging.h b/transformer_engine/pytorch/csrc/logging.h new file mode 100644 index 0000000000..9ed7da1eba --- /dev/null +++ b/transformer_engine/pytorch/csrc/logging.h @@ -0,0 +1,29 @@ +/************************************************************************* + * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_LOGGING_H_ +#define TRANSFORMER_ENGINE_PYTORCH_CSRC_LOGGING_H_ + +#include +#include + +#define NVTE_ERROR(message) \ + do { \ + throw std::runtime_error(std::string(__FILE__ ":") \ + + std::to_string(__LINE__) \ + + " in function " + __func__ + ": " \ + + message); \ + } while (false) + +#define NVTE_CHECK(expr, ...) \ + do { \ + if (!(expr)) { \ + NVTE_ERROR(std::string("Assertion failed: " #x ". ") \ + + std::string(__VA_ARGS__)); \ + } \ + } while (false) + +#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_LOGGING_H_ From 39fcdf4e32217d7ffa91a4de131790721b3cee95 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 15 Aug 2023 20:20:38 -0700 Subject: [PATCH 2/7] Debug logging macros Signed-off-by: Tim Moon --- transformer_engine/jax/csrc/logging.h | 2 +- transformer_engine/paddle/csrc/logging.h | 2 +- transformer_engine/pytorch/csrc/logging.h | 6 +++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/csrc/logging.h b/transformer_engine/jax/csrc/logging.h index 621e8a4f78..b2faad58df 100644 --- a/transformer_engine/jax/csrc/logging.h +++ b/transformer_engine/jax/csrc/logging.h @@ -23,7 +23,7 @@ #define NVTE_CHECK(expr, ...) \ do { \ if (!(expr)) { \ - NVTE_ERROR(std::string("Assertion failed: " #x ". ") \ + NVTE_ERROR(std::string("Assertion failed: " #expr ". ") \ + std::string(__VA_ARGS__)); \ } \ } while (false) diff --git a/transformer_engine/paddle/csrc/logging.h b/transformer_engine/paddle/csrc/logging.h index 8214f1d90b..15f4b1214d 100644 --- a/transformer_engine/paddle/csrc/logging.h +++ b/transformer_engine/paddle/csrc/logging.h @@ -21,7 +21,7 @@ #define NVTE_CHECK(expr, ...) \ do { \ if (!(expr)) { \ - NVTE_ERROR(std::string("Assertion failed: " #x ". ") \ + NVTE_ERROR(std::string("Assertion failed: " #expr ". ") \ + std::string(__VA_ARGS__)); \ } \ } while (false) diff --git a/transformer_engine/pytorch/csrc/logging.h b/transformer_engine/pytorch/csrc/logging.h index 9ed7da1eba..6773959fcd 100644 --- a/transformer_engine/pytorch/csrc/logging.h +++ b/transformer_engine/pytorch/csrc/logging.h @@ -10,6 +10,7 @@ #include #include +#ifndef NVTE_ERROR #define NVTE_ERROR(message) \ do { \ throw std::runtime_error(std::string(__FILE__ ":") \ @@ -17,13 +18,16 @@ + " in function " + __func__ + ": " \ + message); \ } while (false) +#endif // NVTE_ERROR +#ifndef NVTE_CHECK #define NVTE_CHECK(expr, ...) \ do { \ if (!(expr)) { \ - NVTE_ERROR(std::string("Assertion failed: " #x ". ") \ + NVTE_ERROR(std::string("Assertion failed: " #expr ". ") \ + std::string(__VA_ARGS__)); \ } \ } while (false) +#endif // NVTE_CHECK #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_LOGGING_H_ From 9140a428559a963524ff651d5234e9879d92e544 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 16 Aug 2023 12:55:47 -0700 Subject: [PATCH 3/7] Debug C++ tests Use Google style for header includes. Signed-off-by: Tim Moon --- tests/cpp/operator/test_cast_transpose.cu | 17 ++++++++-------- .../cpp/operator/test_cast_transpose_dbias.cu | 17 ++++++++-------- .../test_cast_transpose_dbias_dgelu.cu | 17 ++++++++-------- .../operator/test_cast_transpose_dgeglu.cu | 11 +++++----- tests/cpp/operator/test_dgeglu.cu | 11 +++++----- tests/cpp/operator/test_geglu.cu | 11 +++++----- tests/cpp/operator/test_gelu.cu | 15 +++++++------- tests/cpp/operator/test_layernorm.cu | 18 +++++++++-------- .../cpp/operator/test_multi_cast_transpose.cu | 13 ++++++------ tests/cpp/operator/test_qdq.cu | 20 +++++++++---------- tests/cpp/operator/test_rmsnorm.cu | 12 ++++++----- tests/cpp/operator/test_transpose.cu | 17 ++++++++-------- tests/cpp/test_common.cu | 9 ++++++--- tests/cpp/test_common.h | 13 ++++++------ transformer_engine/common/common.h | 17 ++++++++-------- .../fused_softmax/scaled_masked_softmax.cu | 7 +++++-- .../scaled_upper_triang_masked_softmax.cu | 7 +++++-- .../common/gemm/cublaslt_gemm.cu | 5 +++-- transformer_engine/common/util/logging.h | 2 +- transformer_engine/jax/csrc/modules.h | 4 ++-- transformer_engine/jax/csrc/utils.h | 2 +- transformer_engine/paddle/csrc/common.h | 1 - transformer_engine/pytorch/csrc/common.h | 1 - 23 files changed, 135 insertions(+), 112 deletions(-) diff --git a/tests/cpp/operator/test_cast_transpose.cu b/tests/cpp/operator/test_cast_transpose.cu index 9d89f32484..2a63ff50fb 100644 --- a/tests/cpp/operator/test_cast_transpose.cu +++ b/tests/cpp/operator/test_cast_transpose.cu @@ -4,16 +4,17 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include -#include -#include -#include -#include -#include +#include #include +#include +#include #include -#include + +#include +#include +#include + +#include #include "../test_common.h" using namespace transformer_engine; diff --git a/tests/cpp/operator/test_cast_transpose_dbias.cu b/tests/cpp/operator/test_cast_transpose_dbias.cu index 24b940ef0b..8f722a1aa9 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias.cu @@ -4,17 +4,18 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include -#include -#include -#include +#include +#include #include -#include #include +#include #include -#include -#include + +#include +#include +#include + +#include #include "../test_common.h" using namespace transformer_engine; diff --git a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu index 1f8d93cb7a..aef645ffc2 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu @@ -4,17 +4,18 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include -#include -#include -#include +#include +#include #include -#include #include +#include #include -#include -#include + +#include +#include +#include + +#include #include "../test_common.h" using namespace transformer_engine; diff --git a/tests/cpp/operator/test_cast_transpose_dgeglu.cu b/tests/cpp/operator/test_cast_transpose_dgeglu.cu index 263a353e49..e301d5eda0 100644 --- a/tests/cpp/operator/test_cast_transpose_dgeglu.cu +++ b/tests/cpp/operator/test_cast_transpose_dgeglu.cu @@ -4,17 +4,18 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include -#include -#include -#include #include #include #include #include #include #include + +#include +#include +#include + +#include #include "../test_common.h" using namespace transformer_engine; diff --git a/tests/cpp/operator/test_dgeglu.cu b/tests/cpp/operator/test_dgeglu.cu index c00f220618..aa9e8e870a 100644 --- a/tests/cpp/operator/test_dgeglu.cu +++ b/tests/cpp/operator/test_dgeglu.cu @@ -4,11 +4,6 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include -#include -#include -#include #include #include #include @@ -16,6 +11,12 @@ #include #include #include + +#include +#include +#include + +#include #include "../test_common.h" using namespace transformer_engine; diff --git a/tests/cpp/operator/test_geglu.cu b/tests/cpp/operator/test_geglu.cu index 12c32a65cf..aa02cb0e73 100644 --- a/tests/cpp/operator/test_geglu.cu +++ b/tests/cpp/operator/test_geglu.cu @@ -4,11 +4,6 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include -#include -#include -#include #include #include #include @@ -16,6 +11,12 @@ #include #include #include + +#include +#include +#include + +#include #include "../test_common.h" using namespace transformer_engine; diff --git a/tests/cpp/operator/test_gelu.cu b/tests/cpp/operator/test_gelu.cu index 709b61d7a9..ba32da46a0 100644 --- a/tests/cpp/operator/test_gelu.cu +++ b/tests/cpp/operator/test_gelu.cu @@ -4,18 +4,19 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include -#include -#include -#include #include +#include #include -#include #include +#include #include -#include #include + +#include +#include +#include + +#include #include "../test_common.h" using namespace transformer_engine; diff --git a/tests/cpp/operator/test_layernorm.cu b/tests/cpp/operator/test_layernorm.cu index b869c8cd80..b6da6205fd 100644 --- a/tests/cpp/operator/test_layernorm.cu +++ b/tests/cpp/operator/test_layernorm.cu @@ -4,17 +4,19 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include -#include -#include -#include +#include +#include #include -#include #include +#include #include -#include -#include + +#include +#include +#include + +#include +#include #include "../test_common.h" using namespace transformer_engine; diff --git a/tests/cpp/operator/test_multi_cast_transpose.cu b/tests/cpp/operator/test_multi_cast_transpose.cu index a783111101..6527da90bb 100644 --- a/tests/cpp/operator/test_multi_cast_transpose.cu +++ b/tests/cpp/operator/test_multi_cast_transpose.cu @@ -4,17 +4,18 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include -#include -#include -#include #include -#include #include +#include #include #include #include + +#include +#include +#include + +#include #include "../test_common.h" using namespace transformer_engine; diff --git a/tests/cpp/operator/test_qdq.cu b/tests/cpp/operator/test_qdq.cu index 1dee213a00..f814e73cfa 100644 --- a/tests/cpp/operator/test_qdq.cu +++ b/tests/cpp/operator/test_qdq.cu @@ -4,19 +4,19 @@ * See LICENSE for license information. ************************************************************************/ -#include "gtest/gtest.h" -#include -#include -#include -#include -#include -#include -#include +#include #include +#include +#include #include -#include + +#include +#include +#include + +#include +#include #include "../test_common.h" -#include "transformer_engine/transformer_engine.h" using namespace transformer_engine; diff --git a/tests/cpp/operator/test_rmsnorm.cu b/tests/cpp/operator/test_rmsnorm.cu index a44bf33d5d..4c01b39c0d 100644 --- a/tests/cpp/operator/test_rmsnorm.cu +++ b/tests/cpp/operator/test_rmsnorm.cu @@ -4,17 +4,19 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include -#include -#include -#include #include #include #include #include #include #include + +#include +#include +#include + +#include +#include #include "../test_common.h" using namespace transformer_engine; diff --git a/tests/cpp/operator/test_transpose.cu b/tests/cpp/operator/test_transpose.cu index 7f631b6fa0..f7536dd067 100644 --- a/tests/cpp/operator/test_transpose.cu +++ b/tests/cpp/operator/test_transpose.cu @@ -4,16 +4,17 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include -#include -#include -#include -#include -#include +#include #include +#include +#include #include -#include + +#include +#include +#include + +#include #include "../test_common.h" using namespace transformer_engine; diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index bbb25bb2fc..ee579da987 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -6,13 +6,16 @@ #include "test_common.h" -#include "transformer_engine/logging.h" -#include "transformer_engine/transformer_engine.h" -#include + #include #include #include +#include + +#include +#include "util/logging.h" + namespace test { std::vector all_fp_types = {DType::kFloat32, diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 7278f1827b..087ba29931 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -6,15 +6,17 @@ #pragma once +#include #include -#include -#include -#include +#include + #include +#include #include #include -#include -#include + +#include +#include "util/logging.h" namespace test { using namespace transformer_engine; @@ -252,4 +254,3 @@ bool isFp8Type(DType type); default: \ NVTE_ERROR("Invalid type."); \ } - diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 0aed8ab79e..6b43330d6f 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -7,20 +7,21 @@ #ifndef TRANSFORMER_ENGINE_COMMON_COMMON_H_ #define TRANSFORMER_ENGINE_COMMON_COMMON_H_ -#include -#include -#include -#include -#include -#include -#include #include #include #include #include +#include +#include #include -#include "nvtx.h" +#include +#include +#include +#include + +#include +#include "./nvtx.h" #include "./util/logging.h" namespace transformer_engine { diff --git a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu index c8b1662a40..87666bc0cc 100644 --- a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu @@ -4,15 +4,18 @@ * See LICENSE for license information. ************************************************************************/ -#include #include #include + #include #include + #include -#include #include #include +#include + +#include #include "../common.h" #include "../utils.cuh" #include "../util/logging.h" diff --git a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu index 3fe46d8d13..4235b2231b 100644 --- a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu @@ -4,15 +4,18 @@ * See LICENSE for license information. ************************************************************************/ -#include #include #include + #include #include + #include -#include #include #include +#include + +#include #include "../common.h" #include "../utils.cuh" #include "../util/logging.h" diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index a3e80326ff..0f6f945006 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -4,10 +4,11 @@ * See LICENSE for license information. ************************************************************************/ -#include -#include #include #include + +#include +#include #include "../common.h" #include "../util/logging.h" diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index e95ca80830..0e1d47963b 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -9,8 +9,8 @@ #include -#include #include +#include #include #include diff --git a/transformer_engine/jax/csrc/modules.h b/transformer_engine/jax/csrc/modules.h index 1c02690fb6..a068fa637d 100644 --- a/transformer_engine/jax/csrc/modules.h +++ b/transformer_engine/jax/csrc/modules.h @@ -16,8 +16,8 @@ #include #include -#include "transformer_engine/fused_attn.h" -#include "transformer_engine/transformer_engine.h" +#include +#include #include "logging.h" namespace transformer_engine { diff --git a/transformer_engine/jax/csrc/utils.h b/transformer_engine/jax/csrc/utils.h index b508200b99..46f85eeeb8 100644 --- a/transformer_engine/jax/csrc/utils.h +++ b/transformer_engine/jax/csrc/utils.h @@ -14,7 +14,7 @@ #include -#include "transformer_engine/fused_attn.h" +#include #include "logging.h" namespace transformer_engine { diff --git a/transformer_engine/paddle/csrc/common.h b/transformer_engine/paddle/csrc/common.h index 455bc42344..c0a84f7641 100644 --- a/transformer_engine/paddle/csrc/common.h +++ b/transformer_engine/paddle/csrc/common.h @@ -20,7 +20,6 @@ #include #include #include - #include "logging.h" namespace transformer_engine { diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 6c646a6a10..605455f540 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -39,7 +39,6 @@ #include #include #include - #include "logging.h" namespace transformer_engine { From 01ce05c674a1f7a02d335000743c5c377b16724e Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 18 Aug 2023 14:58:25 -0700 Subject: [PATCH 4/7] Update CUDA driver macros Incorporating changes from #389. Co-authored-by: Tim Moon Co-authored-by: Jan Bielak Signed-off-by: Tim Moon --- .../common/gemm/cublaslt_gemm.cu | 9 ++-- .../include/transformer_engine/transpose.h | 3 -- transformer_engine/common/util/cuda_driver.h | 41 ++++++++----------- transformer_engine/common/util/logging.h | 2 +- 4 files changed, 23 insertions(+), 32 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 0f6f945006..8a6ede2d5f 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -230,9 +230,12 @@ void cublas_gemm(const Tensor *inputA, preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize))); - NVTE_CHECK_CUBLAS(cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, - Ddesc, preference, 1, &heuristicResult, - &returnedResults)); + const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, + Ddesc, preference, 1, &heuristicResult, + &returnedResults); + NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED, + "Unable to find suitable cuBLAS GEMM algorithm"); + NVTE_CHECK_CUBLAS(status); if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms"); diff --git a/transformer_engine/common/include/transformer_engine/transpose.h b/transformer_engine/common/include/transformer_engine/transpose.h index b12e3f8096..6eb653a359 100644 --- a/transformer_engine/common/include/transformer_engine/transpose.h +++ b/transformer_engine/common/include/transformer_engine/transpose.h @@ -146,9 +146,6 @@ void nvte_multi_cast_transpose(size_t num_tensors, * - `cast_output` is the result of the cast * - `transposed_output` is the transposed result of the cast. * - * Calling this function with workspace being an empty tensor will not perform the operation, - * but instead set the shape and type of the workspace tensor to the required values. - * * \param[in] input Input tensor of shape [N, H]. * \param[in] geglu_input Tensor used as input to the forward of GeGLU operation. * Shape [N, H * 2]. diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index 5d07e7a641..6b460a1335 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -43,30 +43,21 @@ inline CUresult call(const char *symbol, ArgTs... args) { } // namespace transformer_engine -namespace { - -/*! \brief Throw exception if CUDA driver call has failed */ -inline void check_cuda_driver_(CUresult status) { - if (status != CUDA_SUCCESS) { - const char *description; - transformer_engine::cuda_driver::call("cuGetErrorString", &description); - NVTE_ERROR(transformer_engine::concat_strings("CUDA Error: ", description)); - } -} - -/*! \brief Call CUDA driver function and throw exception if it fails */ -template -inline void call_and_check_cuda_driver_(const char *symbol, - ArgTs &&... args) { - check_cuda_driver_(transformer_engine::cuda_driver::call(symbol, - std::forward(args)...)); -} - -} // namespace - -#define NVTE_CHECK_CUDA_DRIVER(ans) { check_cuda_driver_(ans); } - -#define NVTE_CALL_CHECK_CUDA_DRIVER(func, ...) \ - { call_and_check_cuda_driver_(#func, __VA_ARGS__); } +#define NVTE_CHECK_CUDA_DRIVER(expr) \ + do { \ + const CUresult status_NVTE_CHECK_CUDA_DRIVER = (expr); \ + if (status_NVTE_CHECK_CUDA_DRIVER != CUDA_SUCCESS) { \ + const char *desc_NVTE_CHECK_CUDA_DRIVER; \ + ::transformer_engine::cuda_driver::call("cuGetErrorString", \ + &desc_NVTE_CHECK_CUDA_DRIVER); \ + NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER); \ + } \ + } while (false) + +#define NVTE_CALL_CHECK_CUDA_DRIVER(symbol, ...) \ + do { \ + NVTE_CHECK_CUDA_DRIVER( \ + ::transformer_engine::cuda_driver::call(#symbol, __VA_ARGS__)); \ + } while (false) #endif // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_DRIVER_H_ diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 0e1d47963b..b08a096888 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -54,7 +54,7 @@ #define NVTE_CHECK_CUDNN(expr) \ do { \ const cudnnStatus_t status_NVTE_CHECK_CUDNN = (expr); \ - if ( status_NVTE_CHECK_CUDNN != CUDNN_STATUS_SUCCESS ) { \ + if (status_NVTE_CHECK_CUDNN != CUDNN_STATUS_SUCCESS) { \ NVTE_ERROR("cuDNN Error: ", \ cudnnGetErrorString(status_NVTE_CHECK_CUDNN), \ ". " \ From 546bcfc1d01fe180740ce325f2afd2d6739dfb35 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 18 Aug 2023 15:21:55 -0700 Subject: [PATCH 5/7] Use core error checking macros in PyTorch extensions Hack to get around macro redefinition warning. Signed-off-by: Tim Moon --- .../pytorch/csrc/comm_gemm_overlap.h | 7 ++++-- transformer_engine/pytorch/csrc/logging.h | 23 +------------------ 2 files changed, 6 insertions(+), 24 deletions(-) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 5dd71e4758..50839f2512 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -4,17 +4,20 @@ * See LICENSE for license information. ************************************************************************/ +#include +#include + #include #include #include #include -#include -#include #include #include #include #include + #include "userbuffers/userbuffers.h" +#include "logging.h" #define HALF_BYTES 2 #define UB_MAX_SM 32 diff --git a/transformer_engine/pytorch/csrc/logging.h b/transformer_engine/pytorch/csrc/logging.h index 6773959fcd..ac6e182940 100644 --- a/transformer_engine/pytorch/csrc/logging.h +++ b/transformer_engine/pytorch/csrc/logging.h @@ -7,27 +7,6 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_LOGGING_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_LOGGING_H_ -#include -#include - -#ifndef NVTE_ERROR -#define NVTE_ERROR(message) \ - do { \ - throw std::runtime_error(std::string(__FILE__ ":") \ - + std::to_string(__LINE__) \ - + " in function " + __func__ + ": " \ - + message); \ - } while (false) -#endif // NVTE_ERROR - -#ifndef NVTE_CHECK -#define NVTE_CHECK(expr, ...) \ - do { \ - if (!(expr)) { \ - NVTE_ERROR(std::string("Assertion failed: " #expr ". ") \ - + std::string(__VA_ARGS__)); \ - } \ - } while (false) -#endif // NVTE_CHECK +#include "../util/logging.h" #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_LOGGING_H_ From 62b429ef155bd2085b0cfb151acbdd954885af33 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 18 Aug 2023 15:29:11 -0700 Subject: [PATCH 6/7] Fix missing arg when getting CUDA driver error string Signed-off-by: Tim Moon --- transformer_engine/common/util/cuda_driver.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index 6b460a1335..805bb48317 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -48,8 +48,10 @@ inline CUresult call(const char *symbol, ArgTs... args) { const CUresult status_NVTE_CHECK_CUDA_DRIVER = (expr); \ if (status_NVTE_CHECK_CUDA_DRIVER != CUDA_SUCCESS) { \ const char *desc_NVTE_CHECK_CUDA_DRIVER; \ - ::transformer_engine::cuda_driver::call("cuGetErrorString", \ - &desc_NVTE_CHECK_CUDA_DRIVER); \ + ::transformer_engine::cuda_driver::call( \ + "cuGetErrorString", \ + status_NVTE_CHECK_CUDA_DRIVER, \ + &desc_NVTE_CHECK_CUDA_DRIVER); \ NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER); \ } \ } while (false) From 598f5b50b09f1dcc2f280cc45f1bad9e011c2c7c Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 1 Sep 2023 17:27:32 -0700 Subject: [PATCH 7/7] Reuse logging header in frameworks Signed-off-by: Tim Moon --- setup.py | 2 + transformer_engine/jax/csrc/logging.h | 40 ------------------- transformer_engine/jax/csrc/modules.h | 2 +- transformer_engine/jax/csrc/utils.h | 2 +- transformer_engine/paddle/csrc/common.h | 2 +- transformer_engine/paddle/csrc/custom_ops.cu | 3 +- transformer_engine/paddle/csrc/logging.h | 36 ----------------- .../pytorch/csrc/comm_gemm_overlap.h | 2 +- transformer_engine/pytorch/csrc/common.h | 2 +- transformer_engine/pytorch/csrc/extensions.h | 2 +- transformer_engine/pytorch/csrc/logging.h | 12 ------ 11 files changed, 10 insertions(+), 95 deletions(-) delete mode 100644 transformer_engine/jax/csrc/logging.h delete mode 100644 transformer_engine/paddle/csrc/logging.h delete mode 100644 transformer_engine/pytorch/csrc/logging.h diff --git a/setup.py b/setup.py index 5959c2b941..740d6cb604 100644 --- a/setup.py +++ b/setup.py @@ -480,6 +480,7 @@ def setup_pytorch_extension() -> setuptools.Extension: include_dirs = [ root_path / "transformer_engine" / "common" / "include", root_path / "transformer_engine" / "pytorch" / "csrc", + root_path / "transformer_engine", root_path / "3rdparty" / "cudnn-frontend" / "include", ] @@ -552,6 +553,7 @@ def setup_paddle_extension() -> setuptools.Extension: include_dirs = [ root_path / "transformer_engine" / "common" / "include", root_path / "transformer_engine" / "paddle" / "csrc", + root_path / "transformer_engine", ] # Compiler flags diff --git a/transformer_engine/jax/csrc/logging.h b/transformer_engine/jax/csrc/logging.h deleted file mode 100644 index b2faad58df..0000000000 --- a/transformer_engine/jax/csrc/logging.h +++ /dev/null @@ -1,40 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_JAX_CSRC_LOGGING_H_ -#define TRANSFORMER_ENGINE_JAX_CSRC_LOGGING_H_ - -#include -#include - -#include - -#define NVTE_ERROR(message) \ - do { \ - throw std::runtime_error(std::string(__FILE__ ":") \ - + std::to_string(__LINE__) \ - + " in function " + __func__ + ": " \ - + message); \ - } while (false) - -#define NVTE_CHECK(expr, ...) \ - do { \ - if (!(expr)) { \ - NVTE_ERROR(std::string("Assertion failed: " #expr ". ") \ - + std::string(__VA_ARGS__)); \ - } \ - } while (false) - -#define NVTE_CHECK_CUDA(expr) \ - do { \ - const cudaError_t status_NVTE_CHECK_CUDA = (expr); \ - if (status_NVTE_CHECK_CUDA != cudaSuccess) { \ - NVTE_ERROR(std::string("CUDA Error: ") \ - + cudaGetErrorString(status_NVTE_CHECK_CUDA)); \ - } \ - } while (false) - -#endif // TRANSFORMER_ENGINE_JAX_CSRC_LOGGING_H_ diff --git a/transformer_engine/jax/csrc/modules.h b/transformer_engine/jax/csrc/modules.h index f4ee841221..aa475f73b4 100644 --- a/transformer_engine/jax/csrc/modules.h +++ b/transformer_engine/jax/csrc/modules.h @@ -16,9 +16,9 @@ #include #include +#include "common/util/logging.h" #include #include -#include "logging.h" namespace transformer_engine { namespace jax { diff --git a/transformer_engine/jax/csrc/utils.h b/transformer_engine/jax/csrc/utils.h index eb1fb14181..c18832429b 100644 --- a/transformer_engine/jax/csrc/utils.h +++ b/transformer_engine/jax/csrc/utils.h @@ -15,8 +15,8 @@ #include +#include "common/util/logging.h" #include -#include "logging.h" namespace transformer_engine { namespace jax { diff --git a/transformer_engine/paddle/csrc/common.h b/transformer_engine/paddle/csrc/common.h index c0a84f7641..4b9391853e 100644 --- a/transformer_engine/paddle/csrc/common.h +++ b/transformer_engine/paddle/csrc/common.h @@ -11,6 +11,7 @@ #include #include "paddle/extension.h" +#include "common/util/logging.h" #include #include #include @@ -20,7 +21,6 @@ #include #include #include -#include "logging.h" namespace transformer_engine { namespace paddle_ext { diff --git a/transformer_engine/paddle/csrc/custom_ops.cu b/transformer_engine/paddle/csrc/custom_ops.cu index 76f8987306..7c76878607 100644 --- a/transformer_engine/paddle/csrc/custom_ops.cu +++ b/transformer_engine/paddle/csrc/custom_ops.cu @@ -5,8 +5,9 @@ ************************************************************************/ #include -#include "../common.h" + #include "common.h" +#include "common/common.h" namespace transformer_engine { namespace paddle_ext { diff --git a/transformer_engine/paddle/csrc/logging.h b/transformer_engine/paddle/csrc/logging.h deleted file mode 100644 index 15f4b1214d..0000000000 --- a/transformer_engine/paddle/csrc/logging.h +++ /dev/null @@ -1,36 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ -#pragma once - -#include -#include - -#include - -#define NVTE_ERROR(message) \ - do { \ - throw std::runtime_error(std::string(__FILE__ ":") \ - + std::to_string(__LINE__) \ - + " in function " + __func__ + ": " \ - + message); \ - } while (false) - -#define NVTE_CHECK(expr, ...) \ - do { \ - if (!(expr)) { \ - NVTE_ERROR(std::string("Assertion failed: " #expr ". ") \ - + std::string(__VA_ARGS__)); \ - } \ - } while (false) - -#define NVTE_CHECK_CUDA(expr) \ - do { \ - const cudaError_t status_NVTE_CHECK_CUDA = (expr); \ - if (status_NVTE_CHECK_CUDA != cudaSuccess) { \ - NVTE_ERROR(std::string("CUDA Error: ") \ - + cudaGetErrorString(status_NVTE_CHECK_CUDA)); \ - } \ - } while (false) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 50839f2512..2b4cdaefab 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -16,8 +16,8 @@ #include #include +#include "common/util/logging.h" #include "userbuffers/userbuffers.h" -#include "logging.h" #define HALF_BYTES 2 #define UB_MAX_SM 32 diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 605455f540..836be00167 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -30,6 +30,7 @@ #include #include +#include "common/util/logging.h" #include #include #include @@ -39,7 +40,6 @@ #include #include #include -#include "logging.h" namespace transformer_engine { diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d06906b5a2..32f293646f 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -5,7 +5,7 @@ ************************************************************************/ #include "common.h" -#include "../common.h" +#include "common/common.h" NVTE_Fused_Attn_Backend get_fused_attn_backend( const transformer_engine::DType q_dtype, diff --git a/transformer_engine/pytorch/csrc/logging.h b/transformer_engine/pytorch/csrc/logging.h deleted file mode 100644 index ac6e182940..0000000000 --- a/transformer_engine/pytorch/csrc/logging.h +++ /dev/null @@ -1,12 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_LOGGING_H_ -#define TRANSFORMER_ENGINE_PYTORCH_CSRC_LOGGING_H_ - -#include "../util/logging.h" - -#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_LOGGING_H_