From d7d61cd84c140c3f6226537ff93916bebd91ab77 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 6 Feb 2025 14:32:46 -0800 Subject: [PATCH 01/11] Ensure that each tensor is seeded differently Signed-off-by: Przemek Tredak --- tests/cpp/operator/test_act.cu | 8 +-- tests/cpp/operator/test_cast.cu | 4 +- tests/cpp/operator/test_cast_dbias.cu | 8 +-- tests/cpp/operator/test_cast_dbias_dgelu.cu | 10 ++-- tests/cpp/operator/test_cast_gated_swiglu.cu | 6 +-- tests/cpp/operator/test_cast_mxfp8.cu | 24 ++++----- .../operator/test_cast_mxfp8_gated_swiglu.cu | 14 ++--- tests/cpp/operator/test_cast_transpose.cu | 4 +- .../cpp/operator/test_cast_transpose_dbias.cu | 8 +-- .../test_cast_transpose_dbias_dgelu.cu | 10 ++-- .../operator/test_cast_transpose_dgeglu.cu | 6 +-- tests/cpp/operator/test_causal_softmax.cu | 10 ++-- tests/cpp/operator/test_dequantize_mxfp8.cu | 14 ++--- .../cpp/operator/test_multi_cast_transpose.cu | 5 +- tests/cpp/operator/test_multi_padding.cu | 5 +- tests/cpp/operator/test_normalization.cu | 28 +++++----- .../cpp/operator/test_normalization_mxfp8.cu | 18 +++---- tests/cpp/operator/test_qdq.cu | 8 +-- tests/cpp/operator/test_swizzle.cu | 4 +- tests/cpp/operator/test_transpose.cu | 4 +- tests/cpp/test_common.cu | 53 ++++++++++--------- tests/cpp/test_common.h | 14 +++-- 22 files changed, 140 insertions(+), 125 deletions(-) diff --git a/tests/cpp/operator/test_act.cu b/tests/cpp/operator/test_act.cu index 7a6f389c40..fa12e3d824 100644 --- a/tests/cpp/operator/test_act.cu +++ b/tests/cpp/operator/test_act.cu @@ -116,10 +116,10 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N, H }, itype); - Tensor output({ N, H }, otype); - Tensor igrad({ N, H }, itype); - Tensor ograd({ N, H }, itype); + Tensor input("input", { N, H }, itype); + Tensor output("output", { N, H }, otype); + Tensor igrad("igrad", { N, H }, itype); + Tensor ograd("ograd", { N, H }, itype); fillUniform(&input); fillUniform(&ograd); diff --git a/tests/cpp/operator/test_cast.cu b/tests/cpp/operator/test_cast.cu index be0b6acf04..f57d1f035d 100644 --- a/tests/cpp/operator/test_cast.cu +++ b/tests/cpp/operator/test_cast.cu @@ -44,8 +44,8 @@ void performTest(const std::vector& shape) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input(shape, itype); - Tensor output_c(shape, otype); + Tensor input("input", shape, itype); + Tensor output_c("output_c", shape, otype); std::unique_ptr ref_output_c = std::make_unique(full_size); diff --git a/tests/cpp/operator/test_cast_dbias.cu b/tests/cpp/operator/test_cast_dbias.cu index 20ae33e304..1f0a9305d8 100644 --- a/tests/cpp/operator/test_cast_dbias.cu +++ b/tests/cpp/operator/test_cast_dbias.cu @@ -66,11 +66,11 @@ void performTest(const std::vector& shape) { const size_t N = first_dimension(shape); const size_t H = last_dimension(shape); - Tensor input(shape, itype); + Tensor input("input", shape, itype); - Tensor output_c(shape, otype); + Tensor output_c("output_c", shape, otype); // dbias has the same data type with "output grad" - Tensor dbias({H}, itype); + Tensor dbias("dbias", {H}, itype); fillUniform(&input); setRandomScale(&output_c); @@ -94,7 +94,7 @@ void performTest(const std::vector& shape) { workspace.data(), 0); - workspace = Tensor(workspace.rowwise_shape(), workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_quantize_dbias(input.data(), output_c.data(), diff --git a/tests/cpp/operator/test_cast_dbias_dgelu.cu b/tests/cpp/operator/test_cast_dbias_dgelu.cu index 1fb6acf834..b951632ec5 100644 --- a/tests/cpp/operator/test_cast_dbias_dgelu.cu +++ b/tests/cpp/operator/test_cast_dbias_dgelu.cu @@ -74,12 +74,12 @@ void performTest(const std::vector& shape) { const size_t N = first_dimension(shape); const size_t H = last_dimension(shape); - Tensor input(shape, itype); - Tensor gelu_input(shape, itype); + Tensor input("input", shape, itype); + Tensor gelu_input("gelu_input", shape, itype); - Tensor output_c(shape, otype); + Tensor output_c("output_c", shape, otype); // dbias has the same data type with "output grad" - Tensor dbias({H}, itype); + Tensor dbias("dbias", {H}, itype); fillUniform(&input); fillUniform(&gelu_input); @@ -106,7 +106,7 @@ void performTest(const std::vector& shape) { workspace.data(), 0); - workspace = Tensor(workspace.rowwise_shape(), workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_quantize_dbias_dgelu(input.data(), diff --git a/tests/cpp/operator/test_cast_gated_swiglu.cu b/tests/cpp/operator/test_cast_gated_swiglu.cu index 5129a8fd19..35ae462106 100644 --- a/tests/cpp/operator/test_cast_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_gated_swiglu.cu @@ -72,9 +72,9 @@ void performTest(const std::vector& shape) { const size_t rows = first_dimension(shape); const size_t cols = last_dimension(shape); - Tensor grad(shape, itype); - Tensor input(input_shape, itype); - Tensor output_c(input_shape, otype); + Tensor grad("grad", shape, itype); + Tensor input("input", input_shape, itype); + Tensor output_c("output_c", input_shape, otype); fillUniform(&grad); fillUniform(&input); diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index 67f36b4f7e..5c03e59510 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -190,10 +190,10 @@ void performTest_x1(const ProcessingMethod processing_method, const size_t blocks_X = scale_dims[3]; const size_t scales_stride = blocks_X; - Tensor input(shape, itype); - Tensor act_input(shape, itype); - Tensor output_c(shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); - Tensor output_dbias({ cols }, itype); + Tensor input("input", shape, itype); + Tensor act_input("act_input", shape, itype); + Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); + Tensor output_dbias("output_dbias", { cols }, itype); std::unique_ptr ref_output_c = std::make_unique(rows * cols); std::unique_ptr ref_output_dbias = std::make_unique(cols); @@ -214,7 +214,7 @@ void performTest_x1(const ProcessingMethod processing_method, output_dbias.data(), workspace.data(), 0); - workspace = Tensor(workspace.rowwise_shape(), workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_quantize_dbias(input.data(), output_c.data(), @@ -230,7 +230,7 @@ void performTest_x1(const ProcessingMethod processing_method, output_dbias.data(), workspace.data(), 0); - workspace = Tensor(workspace.rowwise_shape(), workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_quantize_dbias_dgelu(input.data(), act_input.data(), @@ -328,10 +328,10 @@ void performTest_x2(const ProcessingMethod processing_method, const size_t blocks_X_colwise = scale_dims_colwise[3]; const size_t scales_stride_colwise = blocks_X_colwise; - Tensor input(shape, itype); - Tensor act_input(shape, itype); - Tensor output(shape, otype, true, true, NVTE_MXFP8_1D_SCALING); - Tensor output_dbias({ cols }, itype); + Tensor input("input", shape, itype); + Tensor act_input("act_input", shape, itype); + Tensor output("output", shape, otype, true, true, NVTE_MXFP8_1D_SCALING); + Tensor output_dbias("output_dbias", { cols }, itype); std::unique_ptr ref_output_c_rowwise = std::make_unique(rows * cols); std::unique_ptr ref_output_c_colwise = std::make_unique(rows * cols); @@ -354,7 +354,7 @@ void performTest_x2(const ProcessingMethod processing_method, output_dbias.data(), workspace.data(), 0); - workspace = Tensor(workspace.rowwise_shape(), workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_quantize_dbias(input.data(), output.data(), @@ -370,7 +370,7 @@ void performTest_x2(const ProcessingMethod processing_method, output_dbias.data(), workspace.data(), 0); - workspace = Tensor(workspace.rowwise_shape(), workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_quantize_dbias_dgelu(input.data(), act_input.data(), diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index e22a6d70ea..6acbdefeab 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -204,8 +204,8 @@ void performTest_x1(const size_t rows, // std::cout << "blocks_X: " << blocks_X << std::endl; // std::cout << "scales_stride: " << scales_stride << std::endl; - Tensor grad({ rows, cols }, itype); - Tensor input({ rows, cols * 2 }, itype); + Tensor grad("grad", { rows, cols }, itype); + Tensor input("input", { rows, cols * 2 }, itype); const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; @@ -218,7 +218,8 @@ void performTest_x1(const size_t rows, const size_t blocks_X = scale_dims[3]; const size_t scales_stride = blocks_X; - Tensor output(std::vector{ rows, output_cols }, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); + Tensor output("output", std::vector{ rows, output_cols }, otype, + rowwise, colwise, NVTE_MXFP8_1D_SCALING); std::unique_ptr ref_output = std::make_unique(rows * output_cols); std::unique_ptr ref_output_scales = std::make_unique(blocks_Y * blocks_X); @@ -288,8 +289,8 @@ void performTest_x2(const size_t rows, DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor grad({ rows, cols }, itype); - Tensor input({ rows, cols * 2 }, itype); + Tensor grad("grad", { rows, cols }, itype); + Tensor input("input", { rows, cols * 2 }, itype); const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; @@ -308,7 +309,8 @@ void performTest_x2(const size_t rows, const size_t blocks_X_colwise = scale_dims_colwise[3]; const size_t scales_stride_colwise = blocks_X_colwise; - Tensor output(std::vector{ rows, output_cols }, otype, true, true, NVTE_MXFP8_1D_SCALING); + Tensor output("output", std::vector{ rows, output_cols }, otype, + true, true, NVTE_MXFP8_1D_SCALING); std::unique_ptr ref_output_rowwise = std::make_unique(rows * output_cols); std::unique_ptr ref_output_colwise = std::make_unique(rows * output_cols); diff --git a/tests/cpp/operator/test_cast_transpose.cu b/tests/cpp/operator/test_cast_transpose.cu index e42671fe27..830682eec3 100644 --- a/tests/cpp/operator/test_cast_transpose.cu +++ b/tests/cpp/operator/test_cast_transpose.cu @@ -45,8 +45,8 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N, H }, itype); - Tensor output({ N, H }, otype, true, true); + Tensor input("input", { N, H }, itype); + Tensor output("output", { N, H }, otype, true, true); std::unique_ptr ref_output_c = std::make_unique(N * H); std::unique_ptr ref_output_t = std::make_unique(N * H); diff --git a/tests/cpp/operator/test_cast_transpose_dbias.cu b/tests/cpp/operator/test_cast_transpose_dbias.cu index 68126a1ea0..53918e2699 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias.cu @@ -65,11 +65,11 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({N, H}, itype); + Tensor input("input", {N, H}, itype); - Tensor output({N, H}, otype, true, true); + Tensor output("output", {N, H}, otype, true, true); // dbias has the same data type with "output grad" - Tensor dbias({H}, itype); + Tensor dbias("dbias", {H}, itype); fillUniform(&input); setRandomScale(&output); @@ -95,7 +95,7 @@ void performTest(const size_t N, const size_t H) { workspace.data(), 0); - workspace = Tensor(workspace.rowwise_shape(), workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_quantize_dbias(input.data(), diff --git a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu index ef38560418..e0e769257a 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu @@ -76,12 +76,12 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({N, H}, itype); - Tensor gelu_input({N, H}, itype); + Tensor input("input"{N, H}, itype); + Tensor gelu_input("gelu_input", {N, H}, itype); - Tensor output({N, H}, otype, true, true); + Tensor output("output", {N, H}, otype, true, true); // dbias has the same data type with "output grad" - Tensor dbias({H}, itype); + Tensor dbias("dbias", {H}, itype); fillUniform(&input); fillUniform(&gelu_input); @@ -110,7 +110,7 @@ void performTest(const size_t N, const size_t H) { workspace.data(), 0); - workspace = Tensor(workspace.rowwise_shape(), workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_cast_transpose_dbias_dgelu(input.data(), diff --git a/tests/cpp/operator/test_cast_transpose_dgeglu.cu b/tests/cpp/operator/test_cast_transpose_dgeglu.cu index f107829e0f..ae2da7bad2 100644 --- a/tests/cpp/operator/test_cast_transpose_dgeglu.cu +++ b/tests/cpp/operator/test_cast_transpose_dgeglu.cu @@ -74,9 +74,9 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor grad({N, H}, itype); - Tensor input({N, H * 2}, itype); - Tensor output({N, H * 2}, otype, true, true); + Tensor grad("grad", {N, H}, itype); + Tensor input("input", {N, H * 2}, itype); + Tensor output("output", {N, H * 2}, otype, true, true); fillUniform(&grad); fillUniform(&input); diff --git a/tests/cpp/operator/test_causal_softmax.cu b/tests/cpp/operator/test_causal_softmax.cu index d4c4154c17..2fdc0a524d 100644 --- a/tests/cpp/operator/test_causal_softmax.cu +++ b/tests/cpp/operator/test_causal_softmax.cu @@ -153,11 +153,11 @@ void performTest( DType itype = TypeInfo::dtype; - Tensor data_in({ batches, heads, rows, cols }, itype); - Tensor softmax_out({ batches, heads, rows, cols }, itype); - Tensor softmax_in({ batches, heads, rows, cols }, itype); - Tensor grads_in({ batches, heads, rows, cols }, itype); - Tensor grads_out({ batches, heads, rows, cols }, itype); + Tensor data_in("data_in", { batches, heads, rows, cols }, itype); + Tensor softmax_out("softmax_out", { batches, heads, rows, cols }, itype); + Tensor softmax_in("softmax_in", { batches, heads, rows, cols }, itype); + Tensor grads_in("grads_in", { batches, heads, rows, cols }, itype); + Tensor grads_out("grads_out", { batches, heads, rows, cols }, itype); const size_t elements_total = batches * heads * rows * cols; std::unique_ptr softmax_out_ref = std::make_unique(elements_total); diff --git a/tests/cpp/operator/test_dequantize_mxfp8.cu b/tests/cpp/operator/test_dequantize_mxfp8.cu index 1a090c3a5c..701deb38bb 100644 --- a/tests/cpp/operator/test_dequantize_mxfp8.cu +++ b/tests/cpp/operator/test_dequantize_mxfp8.cu @@ -194,10 +194,10 @@ void performTest_x1(const size_t rows, const size_t blocks_num = rowwise ? blocks_num_rowwise : blocks_num_colwise; const size_t scales_stride = rowwise ? blocks_X_rowwise : blocks_X_colwise; - Tensor input({ rows, cols }, itype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); + Tensor input("input", { rows, cols }, itype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); // Output data are written to the rowwise ptr regardless of the scaling direction - Tensor output({ rows, cols }, otype, true, false); + Tensor output("output", { rows, cols }, otype, true, false); std::unique_ptr ref_output = std::make_unique(rows * cols); std::unique_ptr scales = std::make_unique(blocks_num); @@ -247,11 +247,11 @@ void performTest_quantize_then_dequantize(const size_t rows, // input --> quantized --> output (dequantized) // input == output - Tensor input({ rows, cols }, in_type); - Tensor quantized({ rows, cols }, intermed_type, rowwise, colwise, NVTE_MXFP8_1D_SCALING); + Tensor input("input", { rows, cols }, in_type); + Tensor quantized("quantized", { rows, cols }, intermed_type, rowwise, colwise, NVTE_MXFP8_1D_SCALING); // Output data are written to the rowwise ptr regardless of the scaling direction - Tensor output({ rows, cols }, out_type, true, false); + Tensor output("output", { rows, cols }, out_type, true, false); // fillCase(&input, InputsFillCase::minNorm_to_maxNorm); fillCase(&input, InputsFillCase::uniform); @@ -313,8 +313,8 @@ void performTest_x2(const size_t rows, const size_t blocks_num_rowwise = blocks_Y_rowwise * blocks_X_rowwise; const size_t blocks_num_colwise = blocks_Y_colwise * blocks_X_colwise; - Tensor input({ rows, cols }, itype, true, true, NVTE_MXFP8_1D_SCALING); - Tensor output({ rows, cols }, otype); + Tensor input("input", { rows, cols }, itype, true, true, NVTE_MXFP8_1D_SCALING); + Tensor output("output", { rows, cols }, otype); std::unique_ptr ref_output_rowwise = std::make_unique(rows * cols); std::unique_ptr ref_output_colwise = std::make_unique(rows * cols); diff --git a/tests/cpp/operator/test_multi_cast_transpose.cu b/tests/cpp/operator/test_multi_cast_transpose.cu index 3a3aae1846..f07138caca 100644 --- a/tests/cpp/operator/test_multi_cast_transpose.cu +++ b/tests/cpp/operator/test_multi_cast_transpose.cu @@ -81,8 +81,9 @@ void performTest() { for (size_t tensor_id = 0; tensor_id < num_tensors; ++tensor_id) { const size_t height = tensor_dims[tensor_id].first; const size_t width = tensor_dims[tensor_id].second; - input_list.emplace_back(Tensor({ height, width }, itype)); - output_list.emplace_back(Tensor({ height, width }, otype, true, true)); + input_list.emplace_back(Tensor("input_" + std::to_string(tensor_id), { height, width }, itype)); + output_list.emplace_back(Tensor("output_" + std::to_string(tensor_id), + { height, width }, otype, true, true)); auto& input = input_list.back(); auto& output = output_list.back(); diff --git a/tests/cpp/operator/test_multi_padding.cu b/tests/cpp/operator/test_multi_padding.cu index f74c00e32a..b8475fe561 100644 --- a/tests/cpp/operator/test_multi_padding.cu +++ b/tests/cpp/operator/test_multi_padding.cu @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -84,8 +85,8 @@ void performTest() { const size_t height = tensor_dims[tensor_id].first; const size_t width = tensor_dims[tensor_id].second; const size_t padded_height = (height + align - 1) / align * align; - input_list.emplace_back(Tensor({ height, width }, itype)); - output_list.emplace_back(Tensor({ padded_height, width }, otype)); + input_list.emplace_back(Tensor("input_" + std::to_string(tensor_id), { height, width }, itype)); + output_list.emplace_back(Tensor("output_" + std::to_string(tensor_id), { padded_height, width }, otype)); auto& input = input_list.back(); auto& output = output_list.back(); diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu index a8b142a603..0004c2ce74 100644 --- a/tests/cpp/operator/test_normalization.cu +++ b/tests/cpp/operator/test_normalization.cu @@ -191,16 +191,16 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, return; } - Tensor input({ N, H }, itype); - Tensor z({ N, H }, otype); - Tensor gamma({ H }, wtype); - Tensor beta({ H }, wtype); - Tensor mu({ N }, DType::kFloat32); - Tensor rsigma({ N }, DType::kFloat32); - Tensor dz({ N, H }, wtype); - Tensor dx({ N, H }, itype); - Tensor dgamma({ H }, wtype); - Tensor dbeta({ H }, wtype); + Tensor input("input", { N, H }, itype); + Tensor z("z", { N, H }, otype); + Tensor gamma("gamma", { H }, wtype); + Tensor beta("beta", { H }, wtype); + Tensor mu("mu", { N }, DType::kFloat32); + Tensor rsigma("rsigma", { N }, DType::kFloat32); + Tensor dz("dz", { N, H }, wtype); + Tensor dx("dx", { N, H }, itype); + Tensor dgamma("dgamma", { H }, wtype); + Tensor dbeta("dbeta", { H }, wtype); Tensor workspace_fwd, workspace_bwd; fillUniform(&input); @@ -230,7 +230,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - workspace_fwd = Tensor(workspace_fwd.rowwise_shape(), workspace_fwd.dtype()); + workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype()); nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, z.data(), mu.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); @@ -240,7 +240,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, dx.data(), dgamma.data(), dbeta.data(), workspace_bwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - workspace_bwd = Tensor(workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); + workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); nvte_layernorm_bwd(dz.data(), input.data(), mu.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), dbeta.data(), @@ -250,7 +250,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - workspace_fwd = Tensor(workspace_fwd.rowwise_shape(), workspace_fwd.dtype()); + workspace_fwd = Tensor("workspace", workspace_fwd.rowwise_shape(), workspace_fwd.dtype()); nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); @@ -259,7 +259,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, dx.data(), dgamma.data(), workspace_bwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - workspace_bwd = Tensor(workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); + workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), dx.data(), dgamma.data(), workspace_bwd.data(), diff --git a/tests/cpp/operator/test_normalization_mxfp8.cu b/tests/cpp/operator/test_normalization_mxfp8.cu index 31fc430c11..d1bdb6203b 100644 --- a/tests/cpp/operator/test_normalization_mxfp8.cu +++ b/tests/cpp/operator/test_normalization_mxfp8.cu @@ -179,12 +179,12 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, DType wtype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N, H }, itype); - Tensor z({ N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING); - Tensor gamma({ H }, wtype); - Tensor beta({ H }, wtype); - Tensor mu({ N }, DType::kFloat32); - Tensor rsigma({ N }, DType::kFloat32); + Tensor input("input", { N, H }, itype); + Tensor z("z", { N, H }, otype, true, is_training, NVTE_MXFP8_1D_SCALING); + Tensor gamma("gamma", { H }, wtype); + Tensor beta("beta", { H }, wtype); + Tensor mu("mu", { N }, DType::kFloat32); + Tensor rsigma("rsigma", { N }, DType::kFloat32); Tensor workspace; @@ -199,7 +199,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, z.data(), mu.data(), rsigma.data(), workspace.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - workspace = Tensor(workspace.rowwise_shape(), workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_layernorm_fwd(input.data(), gamma.data(), beta.data(), epsilon, z.data(), mu.data(), rsigma.data(), workspace.data(), prop.multiProcessorCount, zero_centered_gamma, @@ -210,14 +210,14 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, prop.multiProcessorCount, zero_centered_gamma, 0); - workspace = Tensor(workspace.rowwise_shape(), workspace.dtype()); + workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_rmsnorm_fwd(input.data(), gamma.data(), epsilon, z.data(), rsigma.data(), workspace.data(), prop.multiProcessorCount, zero_centered_gamma, 0); } - Tensor dequantized_output({ N, H }, DType::kFloat32, true, true); + Tensor dequantized_output("dequantized_output", { N, H }, DType::kFloat32, true, true); dequantize_2x(z, dequantized_output, is_training); diff --git a/tests/cpp/operator/test_qdq.cu b/tests/cpp/operator/test_qdq.cu index cf73631c83..3c12cef865 100644 --- a/tests/cpp/operator/test_qdq.cu +++ b/tests/cpp/operator/test_qdq.cu @@ -58,8 +58,8 @@ void performTestQ(const size_t N) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N }, itype); - Tensor output({ N }, otype); + Tensor input("input", { N }, itype); + Tensor output("output", { N }, otype); std::unique_ptr ref_output = std::make_unique(N); @@ -89,8 +89,8 @@ void performTestDQ(const size_t N) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({ N }, itype); - Tensor output({ N }, otype); + Tensor input("input", { N }, itype); + Tensor output("output", { N }, otype); std::unique_ptr ref_output = std::make_unique(N); diff --git a/tests/cpp/operator/test_swizzle.cu b/tests/cpp/operator/test_swizzle.cu index 84f3f1a350..f6e0da057a 100644 --- a/tests/cpp/operator/test_swizzle.cu +++ b/tests/cpp/operator/test_swizzle.cu @@ -83,8 +83,8 @@ void performTestSwizzle1D(const int num_tiles_M, const int num_tiles_K, bool row const auto scale_shape = std::vector{data_shape[0] / SF_MODE_X, data_shape[1] /SF_MODE_Y}; std::vector scaling_mode = {SF_MODE_X, SF_MODE_Y, 0}; - Tensor input(data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); - Tensor output(data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + Tensor input("input", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); + Tensor output("output", data_shape, dtype, rowwise, columnwise, NVTE_MXFP8_1D_SCALING); fillUniform(&input); diff --git a/tests/cpp/operator/test_transpose.cu b/tests/cpp/operator/test_transpose.cu index 706091cde6..00dd241c92 100644 --- a/tests/cpp/operator/test_transpose.cu +++ b/tests/cpp/operator/test_transpose.cu @@ -37,8 +37,8 @@ void performTest(const size_t N, const size_t H) { DType dtype = TypeInfo::dtype; - Tensor input({ N, H }, dtype); - Tensor output({ H, N }, dtype); + Tensor input("input", { N, H }, dtype); + Tensor output("output", { H, N }, dtype); std::unique_ptr ref_output = std::make_unique(N * H); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index c03deb9a02..bf91d9f8a0 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -12,6 +12,7 @@ #include #include #include +#include #include #include @@ -21,6 +22,12 @@ namespace test { +size_t create_seed_from_tensor_name(const std::string& tensor_name) { + auto full_name = std::string(testing::UnitTest::GetInstance()->current_test_info()->name()) + + "/" + tensor_name; + return std::hash{}(full_name); +} + std::vector all_fp_types = {DType::kFloat32, DType::kFloat16, DType::kBFloat16, @@ -163,9 +170,13 @@ std::pair get_scales(const NVTEShape& shape, NVTE_ERROR("Invalid scaling mode!"); } -Tensor::Tensor(const NVTEShape &shape, const DType type, +Tensor::Tensor(const std::string& name, + const NVTEShape &shape, const DType type, const bool rowwise, const bool columnwise, const NVTEScalingMode &scaling_mode) { + name_ = name; + const size_t seed = create_seed_from_tensor_name(name); + gen_.seed(seed); rowwise_ = rowwise; columnwise_ = columnwise; size_t s = typeToSize(type); @@ -371,11 +382,10 @@ void Tensor::set_scale_inv(float scale_inv) { if (num_scales == 1){ rowwise_cpu_scale_inv_ptr()[0] = scale_inv; } else{ - static std::mt19937 gen(12345); std::uniform_int_distribution dis(0, 127); auto* scale_inv_ptr = rowwise_cpu_scale_inv_ptr(); for (size_t i = 0; i < num_scales; i++){ - scale_inv_ptr[i] = dis(gen); + scale_inv_ptr[i] = dis(gen_); } } } @@ -384,11 +394,10 @@ void Tensor::set_scale_inv(float scale_inv) { if (num_scales == 1){ columnwise_cpu_scale_inv_ptr()[0] = scale_inv; } else{ - static std::mt19937 gen(12345); std::uniform_int_distribution dis(0, 127); auto* scale_inv_ptr = columnwise_cpu_scale_inv_ptr(); for (size_t i = 0; i < num_scales; i++){ - scale_inv_ptr[i] = dis(gen); + scale_inv_ptr[i] = dis(gen_); } } } @@ -632,18 +641,18 @@ std::pair getTolerances(const DType type) { } template -void generate_data_uniformly(T* data, const size_t size) { - const int seed = 12345; +void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) { #pragma omp parallel proc_bind(spread) { - std::mt19937 gen(seed); - gen.discard(omp_get_thread_num() * 599); + std::mt19937 gen_local = *gen; + gen_local.discard(omp_get_thread_num() * 599); std::uniform_real_distribution<> dis(-2.0, 1.0); #pragma omp for schedule(static) for (size_t i = 0; i < size; ++i) { - data[i] = static_cast(dis(gen)); + data[i] = static_cast(dis(gen_local)); } } + gen->discard(size); } void fillUniform(Tensor *t) { @@ -652,7 +661,7 @@ void fillUniform(Tensor *t) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { T *data = t->rowwise_cpu_dptr(); - generate_data_uniformly(data, size); + generate_data_uniformly(data, size, &(t->gen())); } ); } else { @@ -660,18 +669,17 @@ void fillUniform(Tensor *t) { TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, { T *data = t->columnwise_cpu_dptr(); - generate_data_uniformly(data, size); + generate_data_uniformly(data, size, &(t->gen())); } ); } - static std::mt19937 gen(12345); std::uniform_real_distribution<> dis(-2.0, 1.0); - t->set_scale_inv(dis(gen)); + t->set_scale_inv(dis(t->gen())); t->from_cpu(); } template -void fillCase_special(Tensor *t) { +void fillCase_special(Tensor *t, const std::string& tensor_name) { const size_t size = product(t->rowwise_shape()); const size_t rows = t->rowwise_shape().data[0]; const size_t cols = t->rowwise_shape().data[1]; @@ -690,7 +698,6 @@ void fillCase_special(Tensor *t) { minAbs = Quantized_Limits::ranges[Case]; maxAbs = Quantized_Limits::ranges[Case + 1]; } - static std::mt19937 gen(12345); std::uniform_real_distribution<> dis(minAbs, maxAbs); std::uniform_real_distribution<> dis_sign(-1.0, 1.0); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(t->dtype(), InputType, { @@ -698,8 +705,8 @@ void fillCase_special(Tensor *t) { for (size_t i = 0; i < rows; ++i) { for (size_t j = 0; j < cols; ++j) { const size_t idx = i * cols + j; - const bool is_negative = (dis_sign(gen) < 0.0); - double val = dis(gen); + const bool is_negative = (dis_sign(t->gen()) < 0.0); + double val = dis(t->gen()); if (is_negative) { val = -val; } @@ -732,17 +739,15 @@ template void fillCase(Tensor *t, const InputsFillCase fill_case); template void fillCase(Tensor *t, const InputsFillCase fill_case); template void fillCase(Tensor *t, const InputsFillCase fill_case); -void setRandomScale(Tensor *t) { - static std::mt19937 gen(12345); +void setRandomScale(Tensor *t, const std::string& tensor_name) { std::uniform_real_distribution<> dis(-2.0, 1.0); - const float scale = dis(gen); + const float scale = dis(t->gen()); t->set_scale(scale); } -void setRandomScaleInv(Tensor *t) { - static std::mt19937 gen(12345); +void setRandomScaleInv(Tensor *t, const std::string& tensor_name) { std::uniform_real_distribution<> dis(-2.0, 1.0); - const float scale_inv = dis(gen); + const float scale_inv = dis(t->gen()); t->set_scale_inv(scale_inv); } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index f03649c138..dc515ccb8e 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -6,10 +6,10 @@ #pragma once -#include #include #include #include +#include #include #include @@ -97,17 +97,19 @@ struct TypeInfo{ class Tensor { public: - Tensor(const NVTEShape &shape, const DType type, + Tensor(const std::string& name, + const NVTEShape &shape, const DType type, const bool rowwise = true, const bool columnwise = false, const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING); - Tensor(const std::vector &shape, + Tensor(const std::string& name, + const std::vector &shape, const DType type, const bool rowwise = true, const bool columnwise = false, const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) : - Tensor(NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} + Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {} Tensor() {} @@ -260,6 +262,8 @@ class Tensor { void set_scale_inv(float scale_inv); void shareFP8Meta(const Tensor &other); + std::mt19937& gen() { return gen_; } + private: TensorWrapper tensor_; std::unique_ptr cpu_data_rowwise_; @@ -270,6 +274,8 @@ class Tensor { std::unique_ptr columnwise_scale_inv_cpu_data_; bool rowwise_; bool columnwise_; + std::string name_; + std::mt19937 gen_; }; constexpr uint32_t FP32_EXPONENT_BIAS = 127; From 5b34f013d69d8e1f7a9b0c61f66cd4acd5cbeb80 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 6 Feb 2025 14:45:28 -0800 Subject: [PATCH 02/11] Fix Signed-off-by: Przemek Tredak --- tests/cpp/test_common.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index bf91d9f8a0..f191992313 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -679,7 +679,7 @@ void fillUniform(Tensor *t) { } template -void fillCase_special(Tensor *t, const std::string& tensor_name) { +void fillCase_special(Tensor *t) { const size_t size = product(t->rowwise_shape()); const size_t rows = t->rowwise_shape().data[0]; const size_t cols = t->rowwise_shape().data[1]; From 736b242737762dd4d5b4a11e8234c35826006fc0 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 6 Feb 2025 14:46:38 -0800 Subject: [PATCH 03/11] Fix Signed-off-by: Przemek Tredak --- tests/cpp/operator/test_act.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/cpp/operator/test_act.cu b/tests/cpp/operator/test_act.cu index fa12e3d824..4224f199f4 100644 --- a/tests/cpp/operator/test_act.cu +++ b/tests/cpp/operator/test_act.cu @@ -171,10 +171,10 @@ void performTestGLU(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input({N, H * 2}, itype); - Tensor output({N, H}, otype); - Tensor igrad({ N, H * 2 }, itype); - Tensor ograd({ N, H }, itype); + Tensor input("input", {N, H * 2}, itype); + Tensor output("output", {N, H}, otype); + Tensor igrad("igrad", { N, H * 2 }, itype); + Tensor ograd("ograd", { N, H }, itype); fillUniform(&input); fillUniform(&ograd); From 3ff436e91dc36aee5edd4d04dba07b7f5f9174cd Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 6 Feb 2025 14:47:38 -0800 Subject: [PATCH 04/11] Fix Signed-off-by: Przemek Tredak --- tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu index e0e769257a..15c7d8d665 100644 --- a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu +++ b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu @@ -76,7 +76,7 @@ void performTest(const size_t N, const size_t H) { DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; - Tensor input("input"{N, H}, itype); + Tensor input("input", {N, H}, itype); Tensor gelu_input("gelu_input", {N, H}, itype); Tensor output("output", {N, H}, otype, true, true); From 1ab6b104cb649288894e98de20b704da45a97556 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 6 Feb 2025 14:49:13 -0800 Subject: [PATCH 05/11] Fix Signed-off-by: Przemek Tredak --- tests/cpp/test_common.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index f191992313..ec4a9bdbb7 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -739,13 +739,13 @@ template void fillCase(Tensor *t, const InputsFillCase fill_case); template void fillCase(Tensor *t, const InputsFillCase fill_case); template void fillCase(Tensor *t, const InputsFillCase fill_case); -void setRandomScale(Tensor *t, const std::string& tensor_name) { +void setRandomScale(Tensor *t) { std::uniform_real_distribution<> dis(-2.0, 1.0); const float scale = dis(t->gen()); t->set_scale(scale); } -void setRandomScaleInv(Tensor *t, const std::string& tensor_name) { +void setRandomScaleInv(Tensor *t) { std::uniform_real_distribution<> dis(-2.0, 1.0); const float scale_inv = dis(t->gen()); t->set_scale_inv(scale_inv); From d7254bf0caafdcf2090f0efa47af3d7c966c70c3 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 6 Feb 2025 15:04:04 -0800 Subject: [PATCH 06/11] Disambiguate (and fix) the C++ unit tests for dact Signed-off-by: Przemek Tredak --- tests/cpp/operator/test_cast_dbias_dgelu.cu | 20 ++++---- tests/cpp/operator/test_cast_mxfp8.cu | 56 ++++++++++----------- 2 files changed, 38 insertions(+), 38 deletions(-) diff --git a/tests/cpp/operator/test_cast_dbias_dgelu.cu b/tests/cpp/operator/test_cast_dbias_dgelu.cu index b951632ec5..20ea5c31f1 100644 --- a/tests/cpp/operator/test_cast_dbias_dgelu.cu +++ b/tests/cpp/operator/test_cast_dbias_dgelu.cu @@ -25,7 +25,7 @@ namespace { template void compute_ref_cast_dbias_dgelu(const IT *input, - const IT *gelu_input, + const IT *grad, const CT scale, OT *output_c, CT *amax_h, @@ -39,9 +39,9 @@ void compute_ref_cast_dbias_dgelu(const IT *input, for (size_t i = 0; i < N; i++) { for (size_t j = 0; j < H; j++) { CT in_elt = static_cast(input[i * H + j]); - const CT gelu_in = static_cast(gelu_input[i * H + j]); + const CT in_grad = static_cast(grad[i * H + j]); - const CT elt = in_elt * static_cast(dgelu(static_cast(gelu_in))); + const CT elt = in_grad * static_cast(dgelu(static_cast(in_elt))); const CT elt_abs = std::abs(elt); // update amax @@ -75,14 +75,14 @@ void performTest(const std::vector& shape) { const size_t H = last_dimension(shape); Tensor input("input", shape, itype); - Tensor gelu_input("gelu_input", shape, itype); + Tensor grad("grad", shape, itype); Tensor output_c("output_c", shape, otype); // dbias has the same data type with "output grad" Tensor dbias("dbias", {H}, itype); fillUniform(&input); - fillUniform(&gelu_input); + fillUniform(&grad); setRandomScale(&output_c); std::unique_ptr ref_output_c = std::make_unique(N*H); @@ -90,7 +90,7 @@ void performTest(const std::vector& shape) { CType ref_amax; compute_ref_cast_dbias_dgelu(input.rowwise_cpu_dptr(), - gelu_input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), output_c.scale(), ref_output_c.get(), &ref_amax, @@ -99,8 +99,8 @@ void performTest(const std::vector& shape) { Tensor workspace; - nvte_quantize_dbias_dgelu(input.data(), - gelu_input.data(), + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), output_c.data(), dbias.data(), workspace.data(), @@ -109,8 +109,8 @@ void performTest(const std::vector& shape) { workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_quantize_dbias_dgelu(input.data(), - gelu_input.data(), + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), output_c.data(), dbias.data(), workspace.data(), diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index 5c03e59510..ef9f787f9f 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -39,7 +39,7 @@ enum ActivationType { template void scale_block(const ProcessingMethod processing_method, const InputType* input, - const InputType* act_input, + const InputType* grad, OutputType* output_c, float* dbias, fp8e8m0* output_scales, @@ -62,7 +62,7 @@ void scale_block(const ProcessingMethod processing_method, } if (processing_method == ProcessingMethod::CAST_DACT || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { - elt *= static_cast(act_input[idx]); + elt *= static_cast(grad[idx]); } dbias[j] += elt; if (isinf(elt) || isnan(elt)) { @@ -87,7 +87,7 @@ void scale_block(const ProcessingMethod processing_method, } if (processing_method == ProcessingMethod::CAST_DACT || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { - elt *= static_cast(act_input[idx]); + elt *= static_cast(grad[idx]); } output_c[idx] = static_cast(elt * scale_reciprocal); } @@ -97,7 +97,7 @@ void scale_block(const ProcessingMethod processing_method, template void compute_ref_x1(const ProcessingMethod processing_method, const InputType* input, - const InputType* act_input, + const InputType* grad, OutputType* output_c, fp8e8m0* output_scales, InputType* output_dbias, @@ -120,7 +120,7 @@ void compute_ref_x1(const ProcessingMethod processing_method, const size_t j_max = std::min((jj + 1) * block_size_X, cols); const size_t scale_idx = ii * scales_stride + jj; scale_block( - processing_method, input, act_input, output_c, output_dbias_fp32.data(), + processing_method, input, grad, output_c, output_dbias_fp32.data(), output_scales, scale_idx, i_min, i_max, j_min, j_max, cols); } } @@ -132,7 +132,7 @@ void compute_ref_x1(const ProcessingMethod processing_method, template void compute_ref_x2(const ProcessingMethod processing_method, const InputType* input, - const InputType* act_input, + const InputType* grad, OutputType* output_rowwise, OutputType* output_colwise, fp8e8m0* scales_rowwise, @@ -145,10 +145,10 @@ void compute_ref_x2(const ProcessingMethod processing_method, const size_t scales_stride_rowwise, const size_t scales_stride_colwise) { compute_ref_x1( - processing_method, input, act_input, output_rowwise, scales_rowwise, output_dbias, + processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias, rows, cols, 1, block_size_X, scales_stride_rowwise); compute_ref_x1( - processing_method, input, act_input, output_colwise, scales_colwise, output_dbias, + processing_method, input, grad, output_colwise, scales_colwise, output_dbias, rows, cols, block_size_Y, 1, scales_stride_colwise); } @@ -191,7 +191,7 @@ void performTest_x1(const ProcessingMethod processing_method, const size_t scales_stride = blocks_X; Tensor input("input", shape, itype); - Tensor act_input("act_input", shape, itype); + Tensor grad("grad", shape, itype); Tensor output_c("output_c", shape, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING); Tensor output_dbias("output_dbias", { cols }, itype); @@ -200,7 +200,7 @@ void performTest_x1(const ProcessingMethod processing_method, std::unique_ptr ref_output_scales = std::make_unique(blocks_Y * blocks_X); fillCase(&input, fill_case); - fillUniform(&act_input); + fillUniform(&grad); Tensor workspace; switch (processing_method) { @@ -209,14 +209,14 @@ void performTest_x1(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS: { - nvte_quantize_dbias(input.data(), + nvte_quantize_dbias(grad.data(), output_c.data(), output_dbias.data(), workspace.data(), 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_quantize_dbias(input.data(), + nvte_quantize_dbias(grad.data(), output_c.data(), output_dbias.data(), workspace.data(), @@ -224,16 +224,16 @@ void performTest_x1(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS_DACT: { - nvte_quantize_dbias_dgelu(input.data(), - act_input.data(), + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), output_c.data(), output_dbias.data(), workspace.data(), 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_quantize_dbias_dgelu(input.data(), - act_input.data(), + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), output_c.data(), output_dbias.data(), workspace.data(), @@ -241,7 +241,7 @@ void performTest_x1(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DACT: { - nvte_dgelu(act_input.data(), input.data(), output_c.data(), 0); + nvte_dgelu(grad.data(), input.data(), output_c.data(), 0); break; } case ProcessingMethod::CAST_ACT: { @@ -256,7 +256,7 @@ void performTest_x1(const ProcessingMethod processing_method, compute_ref_x1(processing_method, input.rowwise_cpu_dptr(), - act_input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), ref_output_c.get(), ref_output_scales.get(), ref_output_dbias.get(), @@ -329,7 +329,7 @@ void performTest_x2(const ProcessingMethod processing_method, const size_t scales_stride_colwise = blocks_X_colwise; Tensor input("input", shape, itype); - Tensor act_input("act_input", shape, itype); + Tensor grad("grad", shape, itype); Tensor output("output", shape, otype, true, true, NVTE_MXFP8_1D_SCALING); Tensor output_dbias("output_dbias", { cols }, itype); @@ -340,7 +340,7 @@ void performTest_x2(const ProcessingMethod processing_method, std::unique_ptr ref_output_dbias = std::make_unique(cols); fillCase(&input, fill_case); - fillUniform(&act_input); + fillUniform(&grad); Tensor workspace; switch (processing_method) { @@ -349,14 +349,14 @@ void performTest_x2(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS: { - nvte_quantize_dbias(input.data(), + nvte_quantize_dbias(grad.data(), output.data(), output_dbias.data(), workspace.data(), 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_quantize_dbias(input.data(), + nvte_quantize_dbias(grad.data(), output.data(), output_dbias.data(), workspace.data(), @@ -364,16 +364,16 @@ void performTest_x2(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS_DACT: { - nvte_quantize_dbias_dgelu(input.data(), - act_input.data(), + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), output.data(), output_dbias.data(), workspace.data(), 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_quantize_dbias_dgelu(input.data(), - act_input.data(), + nvte_quantize_dbias_dgelu(grad.data(), + input.data(), output.data(), output_dbias.data(), workspace.data(), @@ -381,7 +381,7 @@ void performTest_x2(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DACT: { - nvte_dgelu(act_input.data(), input.data(), output.data(), 0); + nvte_dgelu(grad.data(), input.data(), output.data(), 0); break; } case ProcessingMethod::CAST_ACT: { @@ -396,7 +396,7 @@ void performTest_x2(const ProcessingMethod processing_method, compute_ref_x2(processing_method, input.rowwise_cpu_dptr(), - act_input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), ref_output_c_rowwise.get(), ref_output_c_colwise.get(), ref_scales_rowwise.get(), From e35f3016b08c736e7bc5f0ea8d177f002e830d55 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 6 Feb 2025 15:44:55 -0800 Subject: [PATCH 07/11] Fix tests Signed-off-by: Przemek Tredak --- .../common/activation/activation_template.h | 4 +- transformer_engine/common/util/cast.cu | 20 +++++----- .../common/util/cast_kernels.cuh | 39 ++++++++++++------- 3 files changed, 37 insertions(+), 26 deletions(-) diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 438c546a9a..ac70c5c161 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -30,9 +30,9 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { constexpr bool IS_ACT = true; constexpr NVTETensor dbias = nullptr; constexpr NVTETensor workspace = nullptr; - constexpr const NVTETensor activation_input = nullptr; + constexpr const NVTETensor grad = nullptr; - quantize_helper(input, activation_input, nullptr, output, + quantize_helper(input, grad, nullptr, output, dbias, workspace, stream); } diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index 2a80c82ef3..7c770328f8 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -33,10 +33,10 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea constexpr bool IS_ACT = false; constexpr NVTETensor dbias = nullptr; constexpr NVTETensor workspace = nullptr; - constexpr const NVTETensor activation_input = nullptr; + constexpr const NVTETensor grad = nullptr; detail::quantize_helper( - input, activation_input, nullptr, output, dbias, workspace, stream); + input, grad, nullptr, output, dbias, workspace, stream); } void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, @@ -49,10 +49,10 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no constexpr bool IS_ACT = false; constexpr NVTETensor dbias = nullptr; constexpr NVTETensor workspace = nullptr; - constexpr const NVTETensor activation_input = nullptr; + constexpr const NVTETensor grad = nullptr; detail::quantize_helper( - input, activation_input, noop, output, dbias, workspace, stream); + input, grad, noop, output, dbias, workspace, stream); } void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, @@ -66,7 +66,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d constexpr const NVTETensor activation_input = nullptr; detail::quantize_helper( - input, activation_input, nullptr, output, dbias, workspace, stream); + activation_input, input, nullptr, output, dbias, workspace, stream); } void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, @@ -80,7 +80,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - input, activation_input, nullptr, output, dbias, workspace, stream); + activation_input, input, nullptr, output, dbias, workspace, stream); } void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, @@ -94,7 +94,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - input, activation_input, nullptr, output, dbias, workspace, stream); + activation_input, input, nullptr, output, dbias, workspace, stream); } void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, @@ -108,7 +108,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati constexpr bool IS_ACT = false; detail::quantize_helper>( - input, activation_input, nullptr, output, dbias, workspace, stream); + activation_input, input, nullptr, output, dbias, workspace, stream); } void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, @@ -122,7 +122,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat constexpr bool IS_ACT = false; detail::quantize_helper>( - input, activation_input, nullptr, output, dbias, workspace, stream); + activation_input, input, nullptr, output, dbias, workspace, stream); } void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, @@ -136,7 +136,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat constexpr bool IS_ACT = false; detail::quantize_helper>( - input, activation_input, nullptr, output, dbias, workspace, stream); + activation_input, input, nullptr, output, dbias, workspace, stream); } void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 36387f8357..3844ca9a57 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -248,11 +248,12 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); float elt = static_cast(in.data.elt[j]); - if constexpr (IS_ACT || IS_DACT) { + if constexpr (IS_ACT) { elt = OP(elt, {}); } if constexpr (IS_DACT) { - elt *= static_cast(act_in.data.elt[j]); + float act_in_elt = static_cast(act_in.data.elt[j]); + elt *= OP(act_in_elt, {}); } if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { if (!out_of_bounds) { @@ -306,11 +307,12 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); float elt = static_cast(in_sh[buff][i][tid_colwise_X]); - if constexpr (IS_ACT || IS_DACT) { + if constexpr (IS_ACT) { elt = OP(elt, {}); } if constexpr (IS_DACT) { - elt *= static_cast(act_in_sh[buff][i][tid_colwise_X]); + float act_in_elt = static_cast(act_in_sh[buff][i][tid_colwise_X]); + elt *= OP(act_in_elt, {}); } if constexpr (IS_DBIAS) { if (!out_of_bounds) { @@ -565,8 +567,8 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); if constexpr (IS_DACT) { - elt = OP(elt, {}); - elt *= static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + float act_in_elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + elt *= OP(act_in_elt, {}); } if constexpr (IS_DBIAS) { if constexpr (IS_DACT) { @@ -1153,7 +1155,7 @@ void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const if (!IS_DACT) { CastVectorizedUnaryKernelLauncher(input, noop, output, stream); } else { - CastVectorizedUnaryGradKernelLauncher(act_input, input, output, stream); + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); } } @@ -1194,12 +1196,21 @@ namespace detail { template -void quantize_helper(const NVTETensor input, const NVTETensor activation_input, +void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop, NVTETensor output, NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { - const auto &input_tensor = *(reinterpret_cast(input)); + const Tensor *input_tensor; + const Tensor *activation_input_tensor; + if constexpr (IS_DBIAS || IS_DACT) { + // backward - input is incoming gradient + input_tensor = reinterpret_cast(grad); + activation_input_tensor = reinterpret_cast(input); + } else { + // forward = input is activation input + input_tensor = reinterpret_cast(input); + activation_input_tensor = nullptr; + } auto output_tensor = reinterpret_cast(output); - const auto activation_tensor = reinterpret_cast(activation_input); auto dbias_tensor = reinterpret_cast(dbias); auto workspace_tensor = reinterpret_cast(workspace); const auto noop_tensor = noop != nullptr ? *(reinterpret_cast(noop)) : Tensor(); @@ -1210,22 +1221,22 @@ void quantize_helper(const NVTETensor input, const NVTETensor activation_input, NVTE_CHECK(output_tensor->has_data(), "Quantizing in only the columnwise direction not supported yet!"); if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { - cast_transpose(input_tensor, noop_tensor, output_tensor, stream); + cast_transpose(*input_tensor, noop_tensor, output_tensor, stream); } else { cast_transpose_fused( - input_tensor, activation_tensor, output_tensor, dbias_tensor, workspace_tensor, + *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); } } else if (output_tensor->has_data()) { fp8_quantize( - input_tensor, activation_tensor, &noop_tensor, output_tensor, dbias_tensor, + *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); } break; } case NVTE_MXFP8_1D_SCALING: { mxfp8_quantize( - input_tensor, activation_tensor, &noop_tensor, output_tensor, dbias_tensor, + *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); break; } From 71426189b3782b55ea4e98a120d7d27c840035de Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 6 Feb 2025 15:50:42 -0800 Subject: [PATCH 08/11] Fix Signed-off-by: Przemek Tredak --- transformer_engine/common/util/cast_kernels.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 3844ca9a57..1af313001b 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1060,7 +1060,7 @@ void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, } template -void CastVectorizedUnaryGradKernelLauncher(const Tensor *grad, const Tensor &input, Tensor *output, +void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, cudaStream_t stream) { constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; const size_t N = product(input.data.shape); @@ -1072,8 +1072,8 @@ void CastVectorizedUnaryGradKernelLauncher(const Tensor *grad, const Tensor &inp is_delayed_tensor_scaling(output->scaling_mode)) { constexpr int nvec = 32 / sizeof(IType); VectorizedUnaryGradKernelLauncher( - reinterpret_cast(grad->data.dptr), - reinterpret_cast(input.data.dptr), + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input->data.dptr), reinterpret_cast(output->data.dptr), reinterpret_cast(output->scale.dptr), reinterpret_cast(output->amax.dptr), @@ -1124,7 +1124,7 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons stream); } else { // Unaligned - CastVectorizedUnaryGradKernelLauncher(act_input, input, output, stream); + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); } } else { cast_fp8_2D(input, act_input, output, dbias, workspace, From 30dc30a5db87e1a92b45bfbde9912204d7476e0c Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 6 Feb 2025 15:52:26 -0800 Subject: [PATCH 09/11] Fix Signed-off-by: Przemek Tredak --- transformer_engine/common/util/cast_kernels.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 1af313001b..1a8acca6ca 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1063,9 +1063,9 @@ template void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, cudaStream_t stream) { constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; - const size_t N = product(input.data.shape); + const size_t N = product(input->data.shape); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, + input->data.dtype, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( output->data.dtype, OType, if (!is_fp8_dtype(output->data.dtype) || From ed31060492ab001d6e986f4d957ead2300d89bff Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 6 Feb 2025 16:02:12 -0800 Subject: [PATCH 10/11] Fix MXFP8 dbias tests Signed-off-by: Przemek Tredak --- tests/cpp/operator/test_cast_mxfp8.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index ef9f787f9f..cb38a5a74a 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -56,6 +56,10 @@ void scale_block(const ProcessingMethod processing_method, for (size_t j = j_min; j < j_max; ++j) { const size_t idx = i * cols + j; float elt = static_cast(input[idx]); + if (processing_method == ProcessingMethod::CAST_DBIAS) { + // grad is the input + elt = static_cast(grad[idx]); + } if (processing_method != ProcessingMethod::CAST_ONLY && processing_method != ProcessingMethod::CAST_DBIAS) { elt = OP(elt); @@ -81,6 +85,10 @@ void scale_block(const ProcessingMethod processing_method, for (size_t j = j_min; j < j_max; ++j) { const size_t idx = i * cols + j; float elt = static_cast(input[idx]); + if (processing_method == ProcessingMethod::CAST_DBIAS) { + // grad is the input + elt = static_cast(grad[idx]); + } if (processing_method != ProcessingMethod::CAST_ONLY && processing_method != ProcessingMethod::CAST_DBIAS) { elt = OP(elt); From efde3a226ffc14113dbf4117a1a8473af7ac3df8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Feb 2025 00:28:02 +0000 Subject: [PATCH 11/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/activation/activation_template.h | 4 ++-- transformer_engine/common/util/cast.cu | 8 ++++---- transformer_engine/common/util/cast_kernels.cuh | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index ac70c5c161..708403f911 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -32,8 +32,8 @@ void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - quantize_helper(input, grad, nullptr, output, - dbias, workspace, stream); + quantize_helper(input, grad, nullptr, output, dbias, + workspace, stream); } template diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu index 7c770328f8..22a50025df 100644 --- a/transformer_engine/common/util/cast.cu +++ b/transformer_engine/common/util/cast.cu @@ -35,8 +35,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper( - input, grad, nullptr, output, dbias, workspace, stream); + detail::quantize_helper(input, grad, nullptr, output, + dbias, workspace, stream); } void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, @@ -51,8 +51,8 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no constexpr NVTETensor workspace = nullptr; constexpr const NVTETensor grad = nullptr; - detail::quantize_helper( - input, grad, noop, output, dbias, workspace, stream); + detail::quantize_helper(input, grad, noop, output, + dbias, workspace, stream); } void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 1a8acca6ca..404babc745 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1196,9 +1196,9 @@ namespace detail { template -void quantize_helper(const NVTETensor input, const NVTETensor grad, - const NVTETensor noop, NVTETensor output, NVTETensor dbias, - NVTETensor workspace, cudaStream_t stream) { +void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETensor noop, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { const Tensor *input_tensor; const Tensor *activation_input_tensor; if constexpr (IS_DBIAS || IS_DACT) {