From ede8630053145aa76744705a9ee3c273828522f3 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 5 Feb 2025 19:02:43 -0800 Subject: [PATCH 1/2] Fix DAct input ordering of gradient input and activation input Signed-off-by: Jeremy Berchtold --- tests/cpp/operator/test_act.cu | 195 ++++++++++++++++++ tests/cpp/operator/test_cast_mxfp8.cu | 4 +- tests/cpp/test_common.cu | 5 + .../common/activation/activation_template.h | 2 +- .../common/util/cast_kernels.cuh | 7 +- 5 files changed, 207 insertions(+), 6 deletions(-) diff --git a/tests/cpp/operator/test_act.cu b/tests/cpp/operator/test_act.cu index 7a6f389c40..370c02a974 100644 --- a/tests/cpp/operator/test_act.cu +++ b/tests/cpp/operator/test_act.cu @@ -17,6 +17,7 @@ #include #include +#include #include "../test_common.h" using namespace transformer_engine; @@ -160,6 +161,131 @@ void performTest(const size_t N, const size_t H) { } } +std::vector getDBiasWorkspaceShape(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype) { + auto input_shape = std::vector{batch_size, hidden_size}; + auto dact_input_shape = std::vector{batch_size, hidden_size}; + auto output_shape = std::vector{batch_size, hidden_size}; + auto output_trans_shape = std::vector{hidden_size, batch_size}; + auto dbias_shape = std::vector{hidden_size}; + + // Evil hack to specify TE impl + // Note: nvte_quantize_dbias_dgelu chooses its internal impl based + // on what pointers are allocated, e.g. whether to output with + // column-wise data. However, we don't have access to any allocated + // buffers in this function. We pass a dummy pointer as a + // workaround. + int temp = 0; + + auto input_tensor = TensorWrapper(reinterpret_cast(&temp), input_shape, in_dtype); + auto dact_input_tensor = + TensorWrapper(reinterpret_cast(&temp), dact_input_shape, in_dtype); + auto output_tensor = TensorWrapper(); + output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); + output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, output_trans_shape); + auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); + + TensorWrapper dummy_workspace; + + // For now, all dbias_dact(-s) have the same workspace size + nvte_quantize_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), output_tensor.data(), + dbias_tensor.data(), dummy_workspace.data(), nullptr); + + auto work_shape = std::vector(dummy_workspace.shape().data, dummy_workspace.shape().data + dummy_workspace.shape().ndim); + return work_shape; + // return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); +} + +template +void performTestDActZeroGradInput(const size_t N, const size_t H) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + // const NVTETensor input, const NVTETensor activation_input, + // NVTETensor output, NVTETensor dbias, NVTETensor workspace, + // cudaStream_t stream + + Tensor input({ N, H }, itype); + Tensor igrad({ N, H }, otype); + Tensor ograd({ N, H }, itype); + Tensor dbias({ H }, itype); + auto workspace_shape = getDBiasWorkspaceShape(N, H, itype, otype); + Tensor workspace(workspace_shape, DType::kFloat32); + + fillUniform(&input); + // fillUniform(&ograd); + igrad.set_scale(1.f); + fillCase(&ograd, zeros); + + std::unique_ptr ref_igrad = std::make_unique(N*H); + + nvte_dact(ograd.data(), input.data(), igrad.data(), dbias.data(), workspace.data(), 0); + + compute_ref_dact_cast(input.rowwise_cpu_dptr(), ograd.rowwise_cpu_dptr(), + ref_igrad.get(), N, H); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + { + auto [atol, rtol] = getTolerances(otype); + compareResults("igrad_act", igrad, ref_igrad.get(), atol, rtol); + + // TODO compare amax, scale_inv + } +} + +template +void performTestDAct(const size_t N, const size_t H) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = TypeInfo::dtype; + + // const NVTETensor input, const NVTETensor activation_input, + // NVTETensor output, NVTETensor dbias, NVTETensor workspace, + // cudaStream_t stream + + Tensor input({ N, H }, itype); + Tensor igrad({ N, H }, otype); + Tensor ograd({ N, H }, itype); + Tensor dbias({ H }, itype); + auto workspace_shape = getDBiasWorkspaceShape(N, H, itype, otype); + Tensor workspace(workspace_shape, DType::kFloat32); + + fillUniform(&input); + fillUniform(&ograd); + igrad.set_scale(1.f); + + std::unique_ptr ref_igrad = std::make_unique(N*H); + + nvte_dact(ograd.data(), input.data(), igrad.data(), dbias.data(), workspace.data(), 0); + + compute_ref_dact_cast(input.rowwise_cpu_dptr(), ograd.rowwise_cpu_dptr(), + ref_igrad.get(), N, H); + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + { + auto [atol, rtol] = getTolerances(otype); + compareResults("igrad_act", igrad, ref_igrad.get(), atol, rtol); + + // TODO compare amax, scale_inv + } +} + template (size.first, size.second););); } +class DActZeroGradTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(DActZeroGradTestSuite, TestDGELUDBias) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTestDActZeroGradInput(size.first, size.second); + ); + ); +} + +class DActTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(DActTestSuite, TestDGELUDBias) { + using namespace transformer_engine; + using namespace test; + + const DType input_type = std::get<0>(GetParam()); + const DType output_type = std::get<1>(GetParam()); + const auto size = std::get<2>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTestDAct(size.first, size.second); + ); + ); +} + + namespace { std::vector> act_test_cases = {{2048, 12288}, @@ -410,3 +575,33 @@ INSTANTIATE_TEST_SUITE_P( std::to_string(std::get<2>(info.param).second); return name; }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + DActZeroGradTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3), + ::testing::Values(std::make_pair(128, 128))), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)) + "X" + + std::to_string(std::get<2>(info.param).first) + "X" + + std::to_string(std::get<2>(info.param).second); + return name; + }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + DActTestSuite, + ::testing::Combine( + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3), + ::testing::Values(std::make_pair(128, 128))), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)) + "X" + + test::typeName(std::get<1>(info.param)) + "X" + + std::to_string(std::get<2>(info.param).first) + "X" + + std::to_string(std::get<2>(info.param).second); + return name; + }); diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index 67f36b4f7e..6e05b6c228 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -241,7 +241,7 @@ void performTest_x1(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DACT: { - nvte_dgelu(act_input.data(), input.data(), output_c.data(), 0); + nvte_dgelu(input.data(), act_input.data(), output_c.data(), 0); break; } case ProcessingMethod::CAST_ACT: { @@ -381,7 +381,7 @@ void performTest_x2(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DACT: { - nvte_dgelu(act_input.data(), input.data(), output.data(), 0); + nvte_dgelu(input.data(), act_input.data(), output.data(), 0); break; } case ProcessingMethod::CAST_ACT: { diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index c03deb9a02..d07e7a8079 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -731,6 +731,11 @@ void fillCase(Tensor *t, const InputsFillCase fill_case) { template void fillCase(Tensor *t, const InputsFillCase fill_case); template void fillCase(Tensor *t, const InputsFillCase fill_case); template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); void setRandomScale(Tensor *t) { static std::mt19937 gen(12345); diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 438c546a9a..040a435048 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -46,7 +46,7 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, constexpr NVTETensor dbias = nullptr; constexpr NVTETensor workspace = nullptr; - quantize_helper(input, grad, nullptr, output, dbias, + quantize_helper(grad, input, nullptr, output, dbias, workspace, stream); } diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 36387f8357..998cc1ab63 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -563,10 +563,10 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) const bool row_out_of_bounds = row >= rows; const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; - float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); + float elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); if constexpr (IS_DACT) { elt = OP(elt, {}); - elt *= static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + elt *= static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); } if constexpr (IS_DBIAS) { if constexpr (IS_DACT) { @@ -1122,7 +1122,8 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons stream); } else { // Unaligned - CastVectorizedUnaryGradKernelLauncher(act_input, input, output, stream); + NVTE_CHECK(act_input != nullptr, "Activation input must be provided for DAct."); + CastVectorizedUnaryGradKernelLauncher(&input, *act_input, output, stream); } } else { cast_fp8_2D(input, act_input, output, dbias, workspace, From 42f35609208f1e69698ced9e3348b3f3a5892447 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 6 Feb 2025 09:37:11 -0800 Subject: [PATCH 2/2] Use random scale in new tests Signed-off-by: Jeremy Berchtold --- tests/cpp/operator/test_act.cu | 37 ++++++++++++++-------------------- 1 file changed, 15 insertions(+), 22 deletions(-) diff --git a/tests/cpp/operator/test_act.cu b/tests/cpp/operator/test_act.cu index 370c02a974..13d026799b 100644 --- a/tests/cpp/operator/test_act.cu +++ b/tests/cpp/operator/test_act.cu @@ -49,7 +49,8 @@ void compute_ref_dact_cast(const IT *input_h, const IT *grad_h, OT *output_h, const size_t N, - const size_t H) { + const size_t H, + float scale = 1.f) { using CT = float; #pragma omp parallel for schedule(static) proc_bind(spread) for (size_t i = 0; i < N; i++) { @@ -57,7 +58,7 @@ void compute_ref_dact_cast(const IT *input_h, CT elt = static_cast(input_h[i * H + j]); elt = dact(elt); CT grad = static_cast(grad_h[i * H + j]); - output_h[i * H + j] = static_cast(grad * elt); + output_h[i * H + j] = static_cast(grad * elt * scale); } } } @@ -192,7 +193,6 @@ std::vector getDBiasWorkspaceShape(size_t batch_size, size_t hidden_size auto work_shape = std::vector(dummy_workspace.shape().data, dummy_workspace.shape().data + dummy_workspace.shape().ndim); return work_shape; - // return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); } template ::dtype; DType otype = TypeInfo::dtype; - // const NVTETensor input, const NVTETensor activation_input, - // NVTETensor output, NVTETensor dbias, NVTETensor workspace, - // cudaStream_t stream - Tensor input({ N, H }, itype); Tensor igrad({ N, H }, otype); Tensor ograd({ N, H }, itype); @@ -218,26 +214,24 @@ void performTestDActZeroGradInput(const size_t N, const size_t H) { Tensor workspace(workspace_shape, DType::kFloat32); fillUniform(&input); - // fillUniform(&ograd); - igrad.set_scale(1.f); fillCase(&ograd, zeros); + setRandomScale(&igrad); + float iGradScale = igrad.scale(); std::unique_ptr ref_igrad = std::make_unique(N*H); nvte_dact(ograd.data(), input.data(), igrad.data(), dbias.data(), workspace.data(), 0); compute_ref_dact_cast(input.rowwise_cpu_dptr(), ograd.rowwise_cpu_dptr(), - ref_igrad.get(), N, H); + ref_igrad.get(), N, H, iGradScale); cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); { - auto [atol, rtol] = getTolerances(otype); - compareResults("igrad_act", igrad, ref_igrad.get(), atol, rtol); - - // TODO compare amax, scale_inv + auto [atol, rtol] = getTolerances(DType::kFloat32); + compareResults("scale_inv", *igrad.rowwise_cpu_scale_inv_ptr(), 1.f / iGradScale, atol, rtol); } } @@ -252,10 +246,6 @@ void performTestDAct(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - // const NVTETensor input, const NVTETensor activation_input, - // NVTETensor output, NVTETensor dbias, NVTETensor workspace, - // cudaStream_t stream - Tensor input({ N, H }, itype); Tensor igrad({ N, H }, otype); Tensor ograd({ N, H }, itype); @@ -265,14 +255,15 @@ void performTestDAct(const size_t N, const size_t H) { fillUniform(&input); fillUniform(&ograd); - igrad.set_scale(1.f); + setRandomScale(&igrad); + float iGradScale = igrad.scale(); std::unique_ptr ref_igrad = std::make_unique(N*H); nvte_dact(ograd.data(), input.data(), igrad.data(), dbias.data(), workspace.data(), 0); compute_ref_dact_cast(input.rowwise_cpu_dptr(), ograd.rowwise_cpu_dptr(), - ref_igrad.get(), N, H); + ref_igrad.get(), N, H, iGradScale); cudaDeviceSynchronize(); auto err = cudaGetLastError(); @@ -281,8 +272,10 @@ void performTestDAct(const size_t N, const size_t H) { { auto [atol, rtol] = getTolerances(otype); compareResults("igrad_act", igrad, ref_igrad.get(), atol, rtol); - - // TODO compare amax, scale_inv + } + { + auto [atol, rtol] = getTolerances(DType::kFloat32); + compareResults("scale_inv", *igrad.rowwise_cpu_scale_inv_ptr(), 1.f / iGradScale, atol, rtol); } }