Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 190 additions & 2 deletions tests/cpp/operator/test_act.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <gtest/gtest.h>

#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include "../test_common.h"

using namespace transformer_engine;
Expand Down Expand Up @@ -48,15 +49,16 @@ void compute_ref_dact_cast(const IT *input_h,
const IT *grad_h,
OT *output_h,
const size_t N,
const size_t H) {
const size_t H,
float scale = 1.f) {
using CT = float;
#pragma omp parallel for schedule(static) proc_bind(spread)
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < H; j++) {
CT elt = static_cast<CT>(input_h[i * H + j]);
elt = dact(elt);
CT grad = static_cast<CT>(grad_h[i * H + j]);
output_h[i * H + j] = static_cast<OT>(grad * elt);
output_h[i * H + j] = static_cast<OT>(grad * elt * scale);
}
}
}
Expand Down Expand Up @@ -160,6 +162,123 @@ void performTest(const size_t N, const size_t H) {
}
}

// Returns the workspace shape reported by nvte_quantize_dbias_dgelu for a
// {batch_size, hidden_size} problem.
//
// The function builds placeholder tensors, calls nvte_quantize_dbias_dgelu
// with an empty workspace tensor, and then reads the shape back from that
// workspace tensor.
// NOTE(review): this presumes the library only fills in the workspace shape
// (and launches no kernel) when the workspace has no data -- confirm against
// the Transformer Engine API documentation.
//
// Parameters:
//   batch_size, hidden_size  dimensions of the input / activation input
//   in_dtype                 dtype of the input, activation input and dbias
//   out_dtype                dtype of the row-wise and column-wise output
std::vector<size_t> getDBiasWorkspaceShape(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype) {
  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto dact_input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
  auto dbias_shape = std::vector<size_t>{hidden_size};

  // Evil hack to specify TE impl
  // Note: nvte_quantize_dbias_dgelu chooses its internal impl based
  // on what pointers are allocated, e.g. whether to output with
  // column-wise data. However, we don't have access to any allocated
  // buffers in this function. We pass a dummy (non-null) pointer as a
  // workaround.
  // NOTE(review): the dummy pointer refers to a host stack variable and is
  // presumably never dereferenced during the workspace query -- confirm.
  int temp = 0;
  void *const dummy_ptr = reinterpret_cast<void *>(&temp);

  auto input_tensor = TensorWrapper(dummy_ptr, input_shape, in_dtype);
  auto dact_input_tensor = TensorWrapper(dummy_ptr, dact_input_shape, in_dtype);
  auto output_tensor = TensorWrapper();
  output_tensor.set_rowwise_data(dummy_ptr, out_dtype, output_shape);
  output_tensor.set_columnwise_data(dummy_ptr, out_dtype, output_trans_shape);
  auto dbias_tensor = TensorWrapper(dummy_ptr, dbias_shape, in_dtype);

  TensorWrapper dummy_workspace;

  // For now, all dbias_dact(-s) have the same workspace size
  nvte_quantize_dbias_dgelu(input_tensor.data(), dact_input_tensor.data(), output_tensor.data(),
                            dbias_tensor.data(), dummy_workspace.data(), nullptr);

  // Take a single snapshot of the shape so that the begin and end iterators
  // below are guaranteed to refer to the same object, instead of being
  // derived from separate shape() return values.
  const auto workspace_shape = dummy_workspace.shape();
  return std::vector<size_t>(workspace_shape.data,
                             workspace_shape.data + workspace_shape.ndim);
}

// Runs a fused dbias + dact + quantize entry point on an N x H problem whose
// incoming gradient is filled via fillCase(..., zeros), then verifies both
// the quantized input-gradient data and its scale_inv.
//
// Template parameters:
//   ref_dact   host reference for the activation derivative
//   nvte_dact  library entry point under test, called as
//              (grad, act_input, output, dbias, workspace, stream)
//   IType      element type of the input, gradient and dbias tensors
//   OType      element type of the quantized output tensor
template <float (*ref_dact)(const float),
          void (*nvte_dact)(const NVTETensor, const NVTETensor,
                            NVTETensor, NVTETensor, NVTETensor,
                            cudaStream_t),
          typename IType, typename OType>
void performTestDActZeroGradInput(const size_t N, const size_t H) {
  using namespace test;

  DType itype = TypeInfo<IType>::dtype;
  DType otype = TypeInfo<OType>::dtype;

  Tensor input({ N, H }, itype);
  Tensor igrad({ N, H }, otype);
  Tensor ograd({ N, H }, itype);
  Tensor dbias({ H }, itype);
  auto workspace_shape = getDBiasWorkspaceShape(N, H, itype, otype);
  Tensor workspace(workspace_shape, DType::kFloat32);

  fillUniform(&input);
  fillCase<IType>(&ograd, zeros);
  setRandomScale(&igrad);
  float iGradScale = igrad.scale();

  std::unique_ptr<OType[]> ref_igrad = std::make_unique<OType[]>(N*H);

  nvte_dact(ograd.data(), input.data(), igrad.data(), dbias.data(), workspace.data(), 0);

  // Host-side reference: dact(input) * grad * scale, cast to OType.
  compute_ref_dact_cast<ref_dact>(input.rowwise_cpu_dptr<IType>(), ograd.rowwise_cpu_dptr<IType>(),
                                  ref_igrad.get(), N, H, iGradScale);

  // Check the synchronization status itself instead of discarding it, then
  // make sure no other error is pending.
  auto err = cudaDeviceSynchronize();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
  err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  {
    // The reference above used to be computed but never checked; compare the
    // quantized output against it, as performTestDAct does.
    auto [atol, rtol] = getTolerances(otype);
    compareResults("igrad_act", igrad, ref_igrad.get(), atol, rtol);
  }
  {
    auto [atol, rtol] = getTolerances(DType::kFloat32);
    compareResults("scale_inv", *igrad.rowwise_cpu_scale_inv_ptr<fp32>(), 1.f / iGradScale, atol, rtol);
  }
}

// Runs a fused dbias + dact + quantize entry point on an N x H problem with
// uniformly filled input and gradient, and compares the quantized
// input-gradient and its scale_inv against a host reference.
//
// Template parameters:
//   ref_dact   host reference for the activation derivative
//   nvte_dact  library entry point under test, called as
//              (grad, act_input, output, dbias, workspace, stream)
//   IType      element type of the input, gradient and dbias tensors
//   OType      element type of the quantized output tensor
template <float (*ref_dact)(const float),
          void (*nvte_dact)(const NVTETensor, const NVTETensor,
                            NVTETensor, NVTETensor, NVTETensor,
                            cudaStream_t),
          typename IType, typename OType>
void performTestDAct(const size_t N, const size_t H) {
  using namespace test;

  const DType in_dtype = TypeInfo<IType>::dtype;
  const DType out_dtype = TypeInfo<OType>::dtype;

  // Device-side operands (construction and fill order kept as-is).
  Tensor input({ N, H }, in_dtype);
  Tensor igrad({ N, H }, out_dtype);
  Tensor ograd({ N, H }, in_dtype);
  Tensor dbias({ H }, in_dtype);
  Tensor workspace(getDBiasWorkspaceShape(N, H, in_dtype, out_dtype), DType::kFloat32);

  fillUniform(&input);
  fillUniform(&ograd);
  setRandomScale(&igrad);
  const float out_scale = igrad.scale();

  auto expected_igrad = std::make_unique<OType[]>(N * H);

  // Library call under test.
  nvte_dact(ograd.data(), input.data(), igrad.data(), dbias.data(), workspace.data(), 0);

  // Host reference: dact(input) * grad * scale, cast to OType.
  compute_ref_dact_cast<ref_dact>(input.rowwise_cpu_dptr<IType>(),
                                  ograd.rowwise_cpu_dptr<IType>(),
                                  expected_igrad.get(), N, H, out_scale);

  cudaDeviceSynchronize();
  const auto status = cudaGetLastError();
  ASSERT_EQ(status, cudaSuccess) << cudaGetErrorString(status);

  // Quantized data check, using tolerances of the output dtype.
  {
    auto [atol, rtol] = getTolerances(out_dtype);
    compareResults("igrad_act", igrad, expected_igrad.get(), atol, rtol);
  }
  // scale_inv must equal the reciprocal of the scale that was set.
  {
    auto [atol, rtol] = getTolerances(DType::kFloat32);
    const float expected_scale_inv = 1.f / out_scale;
    compareResults("scale_inv", *igrad.rowwise_cpu_scale_inv_ptr<fp32>(), expected_scale_inv,
                   atol, rtol);
  }
}

template <float (*ref_act)(const float),
float (*ref_dact)(const float),
void (*nvte_act)(const NVTETensor, NVTETensor, cudaStream_t),
Expand Down Expand Up @@ -384,6 +503,45 @@ TEST_P(ActTestSuite, TestSReGLU) {
OutputType>(size.first, size.second);););
}

// Parameterized fixture; each parameter is
// (input dtype, output dtype, {N, H}).
class DActZeroGradTestSuite
    : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                 transformer_engine::DType,
                                                 std::pair<size_t, size_t>>> {};

// Dispatches on the runtime dtypes and runs the zero-gradient dGELU + dbias
// test for the parameterized shape.
TEST_P(DActZeroGradTestSuite, TestDGELUDBias) {
  using namespace transformer_engine;
  using namespace test;

  const auto &params = GetParam();
  const DType in_dtype = std::get<0>(params);
  const DType out_dtype = std::get<1>(params);
  const auto dims = std::get<2>(params);

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTestDActZeroGradInput<dgelu, nvte_quantize_dbias_dgelu, InputType, OutputType>(dims.first, dims.second);
    );
  );
}

// Parameterized fixture; each parameter is
// (input dtype, output dtype, {N, H}).
class DActTestSuite
    : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
                                                 transformer_engine::DType,
                                                 std::pair<size_t, size_t>>> {};

// Dispatches on the runtime dtypes and runs the dGELU + dbias test for the
// parameterized shape.
TEST_P(DActTestSuite, TestDGELUDBias) {
  using namespace transformer_engine;
  using namespace test;

  const auto &params = GetParam();
  const DType in_dtype = std::get<0>(params);
  const DType out_dtype = std::get<1>(params);
  const auto dims = std::get<2>(params);

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(in_dtype, InputType,
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(out_dtype, OutputType,
      performTestDAct<dgelu, nvte_quantize_dbias_dgelu, InputType, OutputType>(dims.first, dims.second);
    );
  );
}


namespace {

std::vector<std::pair<size_t, size_t>> act_test_cases = {{2048, 12288},
Expand All @@ -410,3 +568,33 @@ INSTANTIATE_TEST_SUITE_P(
std::to_string(std::get<2>(info.param).second);
return name;
});

// Instantiates DActZeroGradTestSuite over three input dtypes and two FP8
// output dtypes on a single 128 x 128 shape. Test names have the form
// "<input type>X<output type>X<N>X<H>".
INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    DActZeroGradTestSuite,
    ::testing::Combine(
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3),
        ::testing::Values(std::make_pair<size_t, size_t>(128, 128))),
    [](const testing::TestParamInfo<DActZeroGradTestSuite::ParamType>& info) {
      const auto& dims = std::get<2>(info.param);
      std::string label = test::typeName(std::get<0>(info.param));
      label += "X";
      label += test::typeName(std::get<1>(info.param));
      label += "X";
      label += std::to_string(dims.first);
      label += "X";
      label += std::to_string(dims.second);
      return label;
    });

// Instantiates DActTestSuite over three input dtypes and two FP8 output
// dtypes on a single 128 x 128 shape. Test names have the form
// "<input type>X<output type>X<N>X<H>".
INSTANTIATE_TEST_SUITE_P(
    OperatorTest,
    DActTestSuite,
    ::testing::Combine(
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
        ::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3),
        ::testing::Values(std::make_pair<size_t, size_t>(128, 128))),
    [](const testing::TestParamInfo<DActTestSuite::ParamType>& info) {
      const auto& dims = std::get<2>(info.param);
      std::string label = test::typeName(std::get<0>(info.param));
      label += "X";
      label += test::typeName(std::get<1>(info.param));
      label += "X";
      label += std::to_string(dims.first);
      label += "X";
      label += std::to_string(dims.second);
      return label;
    });
4 changes: 2 additions & 2 deletions tests/cpp/operator/test_cast_mxfp8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ void performTest_x1(const ProcessingMethod processing_method,
break;
}
case ProcessingMethod::CAST_DACT: {
nvte_dgelu(act_input.data(), input.data(), output_c.data(), 0);
nvte_dgelu(input.data(), act_input.data(), output_c.data(), 0);
break;
}
case ProcessingMethod::CAST_ACT: {
Expand Down Expand Up @@ -381,7 +381,7 @@ void performTest_x2(const ProcessingMethod processing_method,
break;
}
case ProcessingMethod::CAST_DACT: {
nvte_dgelu(act_input.data(), input.data(), output.data(), 0);
nvte_dgelu(input.data(), act_input.data(), output.data(), 0);
break;
}
case ProcessingMethod::CAST_ACT: {
Expand Down
5 changes: 5 additions & 0 deletions tests/cpp/test_common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,11 @@ void fillCase(Tensor *t, const InputsFillCase fill_case) {
// Explicit instantiations of fillCase, one per supported element type.

// Floating-point element types.
template void fillCase<fp32>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<bf16>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp8e4m3>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<fp8e5m2>(Tensor *t, const InputsFillCase fill_case);

// Integer element types.
template void fillCase<int64_t>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<int32_t>(Tensor *t, const InputsFillCase fill_case);
template void fillCase<uint8_t>(Tensor *t, const InputsFillCase fill_case);

void setRandomScale(Tensor *t) {
static std::mt19937 gen(12345);
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/common/activation/activation_template.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
constexpr NVTETensor dbias = nullptr;
constexpr NVTETensor workspace = nullptr;

quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(input, grad, nullptr, output, dbias,
quantize_helper<IS_DBIAS, IS_DACT, IS_ACT, Empty, OP>(grad, input, nullptr, output, dbias,
workspace, stream);
}

Expand Down
7 changes: 4 additions & 3 deletions transformer_engine/common/util/cast_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -563,10 +563,10 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
const bool row_out_of_bounds = row >= rows;
const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds;

float elt = static_cast<float>(in_sh[buff][shmem_offset_y][shmem_offset_x]);
float elt = static_cast<float>(act_in_sh[buff][shmem_offset_y][shmem_offset_x]);
if constexpr (IS_DACT) {
elt = OP(elt, {});
elt *= static_cast<float>(act_in_sh[buff][shmem_offset_y][shmem_offset_x]);
elt *= static_cast<float>(in_sh[buff][shmem_offset_y][shmem_offset_x]);
}
if constexpr (IS_DBIAS) {
if constexpr (IS_DACT) {
Expand Down Expand Up @@ -1122,7 +1122,8 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons
stream);
} else {
// Unaligned
CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(act_input, input, output, stream);
NVTE_CHECK(act_input != nullptr, "Activation input must be provided for DAct.");
CastVectorizedUnaryGradKernelLauncher<ParamOP, OP>(&input, *act_input, output, stream);
}
} else {
cast_fp8_2D<IS_DBIAS, IS_DACT, ParamOP, OP>(input, act_input, output, dbias, workspace,
Expand Down