From 16aaa78344aa24e5534ee3963358b1c15631ed30 Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Wed, 20 Sep 2023 15:56:36 -0700 Subject: [PATCH 1/7] Add update_scale_hysteresis --- csrc/amp_C_frontend.cpp | 12 ++++++ csrc/update_scale_hysteresis.cu | 70 +++++++++++++++++++++++++++++++++ setup.py | 1 + 3 files changed, 83 insertions(+) create mode 100644 csrc/update_scale_hysteresis.cu diff --git a/csrc/amp_C_frontend.cpp b/csrc/amp_C_frontend.cpp index 74d36487e..d39be60c1 100644 --- a/csrc/amp_C_frontend.cpp +++ b/csrc/amp_C_frontend.cpp @@ -178,6 +178,16 @@ void multi_tensor_lamb_mp_cuda( at::Tensor found_inf, at::Tensor inv_scale); +at::Tensor update_scale_hysteresis_cuda( + at::Tensor current_scale, + at::Tensor growth_tracker, + at::Tensor hysteresis_tracker, + at::Tensor found_inf, + const double growth_factor, + const double backoff_factor, + const int64_t growth_interval, + const int hysteresis); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("multi_tensor_scale", &multi_tensor_scale_cuda, "Fused overflow check + scale for a list of contiguous tensors"); @@ -211,4 +221,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Computes and apply update for LAMB optimizer"); m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda, "Computes and apply update for LAMB optimizer"); + m.def("update_scale_hysteresis", &update_scale_hysteresis_cuda, + "Updates scale while accounting for hysteresis"); } diff --git a/csrc/update_scale_hysteresis.cu b/csrc/update_scale_hysteresis.cu new file mode 100644 index 000000000..6209792ad --- /dev/null +++ b/csrc/update_scale_hysteresis.cu @@ -0,0 +1,70 @@ +#include +#include + +__global__ void update_scale_hysteresis_cuda_kernel(float* current_scale, + int* growth_tracker, + int* hysteresis_tracker, + const float* found_inf, + double growth_factor, + double backoff_factor, + int growth_interval, + int hysteresis) +{ + if (*found_inf > 0) { + *hysteresis_tracker -= 1; + + // Only reset the growth tracker when hysteresis is larger than zero + if (*hysteresis_tracker > 0) { + *growth_tracker = 0; + return; + } + } + + if (*found_inf) { + *current_scale = (*current_scale)*backoff_factor; + *growth_tracker = 0; + } else { + // Entering this branch means we just carried out a successful step, + // so growth_tracker is incremented before comparing to growth_interval. + auto successful = (*growth_tracker) + 1; + if (successful == growth_interval) { + auto new_scale = static_cast((*current_scale)*growth_factor); + // Do not grow the scale past fp32 bounds to inf. 
+ if (isfinite_ensure_cuda_math(new_scale)) { + *current_scale = new_scale; + } + *growth_tracker = 0; + } else { + *growth_tracker = successful; + } + } + + // Reset the hysteresis tracker if no infs are found + if (*found_inf <= 0) { + *hysteresis_tracker = hysteresis; + } +} + +at::Tensor update_scale_hysteresis_cuda(at::Tensor current_scale, + at::Tensor growth_tracker, + at::Tensor hysteresis_tracker, + at::Tensor found_inf, + const double growth_factor, + const double backoff_factor, + const int64_t growth_interval, + const int hysteresis) +{ + update_scale_hysteresis_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>( + current_scale.mutable_data_ptr(), + growth_tracker.mutable_data_ptr(), + hysteresis_tracker.mutable_data_ptr(), + found_inf.const_data_ptr(), + growth_factor, + backoff_factor, + growth_interval, + hysteresis); + + AT_CUDA_CHECK(cudaGetLastError()); + + return current_scale; +} diff --git a/setup.py b/setup.py index fd1602300..329f85646 100644 --- a/setup.py +++ b/setup.py @@ -195,6 +195,7 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int "csrc/multi_tensor_novograd.cu", "csrc/multi_tensor_lamb.cu", "csrc/multi_tensor_lamb_mp.cu", + "csrc/update_scale_hysteresis.cu", ], extra_compile_args={ "cxx": ["-O3"] + version_dependent_macros, From 79b7f4f256c33890af625e8e69f8c92a1c71405e Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Wed, 20 Sep 2023 23:42:14 -0700 Subject: [PATCH 2/7] Fix compile errors --- csrc/update_scale_hysteresis.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/update_scale_hysteresis.cu b/csrc/update_scale_hysteresis.cu index 6209792ad..2405130af 100644 --- a/csrc/update_scale_hysteresis.cu +++ b/csrc/update_scale_hysteresis.cu @@ -1,5 +1,6 @@ #include #include +#include __global__ void update_scale_hysteresis_cuda_kernel(float* current_scale, int* growth_tracker, @@ -30,7 +31,7 @@ __global__ void update_scale_hysteresis_cuda_kernel(float* current_scale, if (successful == growth_interval) { auto new_scale = static_cast((*current_scale)*growth_factor); // Do not grow the scale past fp32 bounds to inf. 
- if (isfinite_ensure_cuda_math(new_scale)) { + if (isfinite(new_scale)) { *current_scale = new_scale; } *growth_tracker = 0; From 97ae82eb5d3af4199659c054d28bb0c152370cda Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Thu, 28 Sep 2023 09:23:39 +0800 Subject: [PATCH 3/7] Massively reduce LayerNorm/RMSNorm GPU memory usage in modern networks by tricking torch autograd (#1715) * input grad checks out * adding clamp gamma * Both old and proposed implementation checks out * 2 tests not yet passed due to numerical issues * mem_eff works * fast-layer-norm done * Moving mem-eff to templates * Relax tolerance for memory efficient backward * Fix backward api of python --- apex/contrib/csrc/layer_norm/ln.h | 33 +- apex/contrib/csrc/layer_norm/ln_api.cpp | 53 +-- .../csrc/layer_norm/ln_bwd_kernels.cuh | 23 +- apex/contrib/layer_norm/layer_norm.py | 32 +- .../test/layer_norm/test_fast_layer_norm.py | 26 +- apex/normalization/fused_layer_norm.py | 142 +++++--- csrc/layer_norm_cuda.cpp | 93 +++-- csrc/layer_norm_cuda_kernel.cu | 342 ++++++++++++------ csrc/static_switch.h | 25 ++ .../test_fused_layer_norm.py | 127 ++++--- 10 files changed, 560 insertions(+), 336 deletions(-) create mode 100644 csrc/static_switch.h diff --git a/apex/contrib/csrc/layer_norm/ln.h b/apex/contrib/csrc/layer_norm/ln.h index 6ab709b09..cf0355c07 100644 --- a/apex/contrib/csrc/layer_norm/ln.h +++ b/apex/contrib/csrc/layer_norm/ln.h @@ -10,7 +10,7 @@ namespace layer_norm { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct LaunchParams{ size_t workspace_bytes; @@ -26,17 +26,20 @@ struct LaunchParams{ //////////////////////////////////////////////////////////////////////////////////////////////////// -struct ParamsBase { - ParamsBase() +struct FwdParams{ + FwdParams() : ctas_per_col(0) , rows(0) , cols(0) , x(nullptr) + , z(nullptr) , mu(nullptr) , rs(nullptr) , gamma(nullptr) + , beta(nullptr) , workspace(nullptr) , barrier(nullptr) + , epsilon(0.f) { } @@ -49,9 +52,11 @@ struct ParamsBase { // Common data pointers. void *x; + void *z; void *mu; void *rs; void *gamma; + void *beta; // Multi-CTA workspace in gmem. void *workspace; @@ -59,31 +64,15 @@ struct ParamsBase { // Multi-CTA sync barriers in gmem. int *barrier; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -struct FwdParams : public ParamsBase { - FwdParams() - : ParamsBase() - , z(nullptr) - , beta(nullptr) - , epsilon(0.f) - { - } - // Output of LN FWD. - void *z; - void *beta; float epsilon; - }; //////////////////////////////////////////////////////////////////////////////////////////////////// -struct BwdParams : public ParamsBase { +struct BwdParams : public FwdParams{ BwdParams() - : ParamsBase() + : FwdParams() , dz(nullptr) , dbeta_part(nullptr) , dgamma_part(nullptr) @@ -92,7 +81,6 @@ struct BwdParams : public ParamsBase { , dgamma(nullptr) { } - // Input: gradient wrt. LN FWD output. 
void *dz; @@ -200,3 +188,4 @@ struct BwdRegistrar{ //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace layer_norm + diff --git a/apex/contrib/csrc/layer_norm/ln_api.cpp b/apex/contrib/csrc/layer_norm/ln_api.cpp index 30e4a5fec..54f04e201 100644 --- a/apex/contrib/csrc/layer_norm/ln_api.cpp +++ b/apex/contrib/csrc/layer_norm/ln_api.cpp @@ -130,12 +130,12 @@ std::vector ln_fwd(const at::Tensor &x, // BxSxhidden_size layer_norm::FwdParams ¶ms = launch_params.params; params.rows = rows; params.cols = cols; - params.x = x.data_ptr(); + params.z = z.data_ptr(); params.mu = mu.data_ptr(); params.rs = rsigma.data_ptr(); params.gamma = gamma.data_ptr(); params.beta = beta.data_ptr(); - params.z = z.data_ptr(); + params.x = x.data_ptr(); params.epsilon = epsilon; if( launch_params.barrier_size > 0 ) { @@ -153,33 +153,39 @@ std::vector ln_fwd(const at::Tensor &x, // BxSxhidden_size } //////////////////////////////////////////////////////////////////////////////////////////////////// - -std::vector ln_bwd(const at::Tensor &dz, // BxSxhidden_size - const at::Tensor &x, // BxSxhidden_size - const at::Tensor &mu, // BxS, FP32! - const at::Tensor &rsigma, // BxS, FP32! - const at::Tensor &gamma // hidden_size +std::vector ln_bwd(const at::Tensor &dz, // BxSxhidden_size + const at::Tensor &x_or_z, // BxSxhidden_size + c10::optional &mu_, // BxS, FP32! + const at::Tensor &rsigma, // BxS, FP32! + const at::Tensor &gamma, // hidden_size + c10::optional&beta_, // hidden_size + bool memory_efficient ) { - auto itype = x.scalar_type(); + auto itype = x_or_z.scalar_type(); auto wtype = gamma.scalar_type(); auto otype = wtype; auto ctype = torch::kFloat32; TORCH_CHECK(dz.dtype() == otype); - TORCH_CHECK(mu.dtype() == ctype); TORCH_CHECK(rsigma.dtype() == ctype); + if (mu_.has_value()) { + TORCH_CHECK(mu_.value().dtype() == ctype); + } - TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(x_or_z.is_cuda()); TORCH_CHECK(dz.is_cuda()); - TORCH_CHECK(mu.is_cuda()); TORCH_CHECK(rsigma.is_cuda()); TORCH_CHECK(gamma.is_cuda()); + if (beta_.has_value()) { + TORCH_CHECK(beta_.value().is_cuda()); + TORCH_CHECK(beta_.value().dtype() == wtype); + } - TORCH_CHECK(x.is_contiguous()); + TORCH_CHECK(x_or_z.is_contiguous()); TORCH_CHECK(dz.is_contiguous()); - auto sizes = x.sizes(); + auto sizes = x_or_z.sizes(); TORCH_CHECK(sizes.size() == 2); TORCH_CHECK(dz.sizes() == sizes); auto rows = sizes[0]; @@ -187,14 +193,14 @@ std::vector ln_bwd(const at::Tensor &dz, // BxSxhidden_size auto hidden_size = gamma.numel(); - TORCH_CHECK(mu.numel() == rows); - TORCH_CHECK(mu.sizes() == rsigma.sizes()); - TORCH_CHECK(gamma.numel() == cols); + if (beta_.has_value()) { + TORCH_CHECK(beta_.value().numel() == cols); + } - auto options = x.options(); + auto options = x_or_z.options(); - auto dx = torch::empty_like(x); + auto dx = torch::empty_like(x_or_z); auto dgamma = torch::empty_like(gamma); auto dbeta = torch::empty_like(gamma); @@ -213,8 +219,13 @@ std::vector ln_bwd(const at::Tensor &dz, // BxSxhidden_size layer_norm::BwdParams ¶ms = launch_params.params; params.rows = rows; params.cols = cols; - params.x = x.data_ptr(); - params.mu = mu.data_ptr(); + if (memory_efficient) { + params.z = x_or_z.data_ptr(); + params.beta = beta_.value().data_ptr(); + } else { + params.x = x_or_z.data_ptr(); + params.mu = mu_.value().data_ptr(); + } params.rs = rsigma.data_ptr(); params.gamma = gamma.data_ptr(); params.dz = dz.data_ptr(); diff --git a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh 
b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh index 8595f5ed4..019764a38 100644 --- a/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh +++ b/apex/contrib/csrc/layer_norm/ln_bwd_kernels.cuh @@ -57,10 +57,14 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { constexpr float rn = 1.f / float(COLS); Wvec gamma[LDGS]; + Wvec beta[LDGS]; index_t idx = c; #pragma unroll for( int it = 0; it < LDGS; it++ ) { gamma[it].load_from(params.gamma, idx); + if (params.z != nullptr) { + beta[it].load_from(params.beta, idx); + } idx += Ktraits::VEC_COLS_PER_LDG; } // TODO if ROWS_PER_CTA does not divide rows, we might get divergence in the @@ -68,15 +72,19 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { // grid stride over rows #pragma unroll 1 for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { - const compute_t mu_r = static_cast(params.mu)[row]; + const compute_t mu_r = params.z == nullptr ? static_cast(params.mu)[row] : 0.f; const compute_t rs_r = static_cast(params.rs)[row]; - Ivec x[LDGS]; + Ivec x_or_z[LDGS]; Ovec dz[LDGS]; index_t idx = row * Ktraits::VEC_COLS + c; #pragma unroll for( int it = 0; it < LDGS; it++ ) { dz[it].load_from(params.dz, idx); - x[it].load_from(params.x, idx); + if (params.z != nullptr) { + x_or_z[it].load_from(params.z, idx); + } else { + x_or_z[it].load_from(params.x, idx); + } idx += Ktraits::VEC_COLS_PER_LDG; } @@ -89,10 +97,11 @@ void ln_bwd_kernel(layer_norm::BwdParams params) { for( int it = 0; it < LDGS; it++ ) { #pragma unroll for( int jt = 0; jt < NUM_ELTS; jt++ ) { - compute_t x_tmp = x[it].data.elt[jt]; - compute_t y_tmp = rs_r * (x_tmp - mu_r); - compute_t dy_tmp = compute_t(gamma[it].data.elt[jt]); - dy_tmp *= compute_t(dz[it].data.elt[jt]); + compute_t gamma_tmp = compute_t(gamma[it].data.elt[jt]); + compute_t beta_tmp = compute_t(beta[it].data.elt[jt]); + compute_t x_or_z_tmp = compute_t(x_or_z[it].data.elt[jt]); + compute_t y_tmp = params.z != nullptr ? (x_or_z_tmp - beta_tmp) / gamma_tmp : rs_r * (x_or_z_tmp - mu_r); + compute_t dy_tmp = compute_t(dz[it].data.elt[jt]) * gamma_tmp; compute_t dz_tmp = dz[it].data.elt[jt]; mdy_local += dy_tmp; diff --git a/apex/contrib/layer_norm/layer_norm.py b/apex/contrib/layer_norm/layer_norm.py index b084b1ace..1d79c561b 100644 --- a/apex/contrib/layer_norm/layer_norm.py +++ b/apex/contrib/layer_norm/layer_norm.py @@ -7,40 +7,44 @@ class FastLayerNormFN(torch.autograd.Function): @staticmethod - def forward(ctx, x, gamma, beta, epsilon): + def forward(ctx, x, gamma, beta, epsilon, memory_efficient): + ctx.x_shape = x.shape + ctx.memory_efficient = memory_efficient + x = x.contiguous() gamma = gamma.contiguous() beta = beta.contiguous() hidden_size = gamma.numel() xmat = x.view((-1, hidden_size)) ymat, mu, rsigma = fast_layer_norm.ln_fwd(xmat, gamma, beta, epsilon) - ctx.save_for_backward(x, gamma, mu, rsigma) + if ctx.memory_efficient: + ctx.save_for_backward(ymat, gamma, None, rsigma, beta) + else: + ctx.save_for_backward(xmat, gamma, mu, rsigma, None) return ymat.view(x.shape) @staticmethod def backward(ctx, dy): # assert dy.is_contiguous() dy = dy.contiguous() # this happens! 
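The memory-efficient backward introduced here does not need the saved input at all: because the forward computes y = (x - mu) * rsigma * gamma + beta, the normalized activation (x - mu) * rsigma can be recovered from the saved output as (y - beta) / gamma. A minimal PyTorch sketch of that identity (plain tensors only, no apex API assumed; all names are illustrative):

    import torch

    x = torch.randn(4, 8, dtype=torch.float64)
    gamma = torch.rand(8, dtype=torch.float64) + 0.5   # keep gamma away from zero for this sketch
    beta = torch.randn(8, dtype=torch.float64)

    mu = x.mean(dim=-1, keepdim=True)
    rsigma = torch.rsqrt(x.var(dim=-1, unbiased=False, keepdim=True) + 1e-5)
    y = (x - mu) * rsigma * gamma + beta

    # The backward only needs the normalized value, which can be rebuilt from y.
    xhat_from_input = (x - mu) * rsigma
    xhat_from_output = (y - beta) / gamma
    torch.testing.assert_close(xhat_from_input, xhat_from_output)

This is why the forward above saves (ymat, gamma, rsigma, beta) when memory_efficient is set, and why the backward below hands the saved output to ln_bwd in place of the input.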
- x, gamma, mu, rsigma = ctx.saved_tensors - - hidden_size = gamma.numel() - xmat = x.view((-1, hidden_size)) - dymat = dy.view(xmat.shape) - dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dymat, xmat, mu, rsigma, gamma) - dx = dxmat.view(x.shape) - return dx, dgamma, dbeta, None + x_or_y_mat, gamma, mu, rsigma, beta = ctx.saved_tensors + dymat = dy.view(x_or_y_mat.shape) + dxmat, dgamma, dbeta, _, _ = fast_layer_norm.ln_bwd(dymat, x_or_y_mat, mu, rsigma, gamma, beta, ctx.memory_efficient) + dx = dxmat.view(ctx.x_shape) + return dx, dgamma, dbeta, None, None -def _fast_layer_norm(x, weight, bias, epsilon): - args = _cast_if_autocast_enabled(x, weight, bias, epsilon) +def _fast_layer_norm(x, weight, bias, epsilon, memory_efficient): + args = _cast_if_autocast_enabled(x, weight, bias, epsilon, memory_efficient) with torch.cuda.amp.autocast(enabled=False): return FastLayerNormFN.apply(*args) class FastLayerNorm(torch.nn.Module): - def __init__(self, hidden_size, eps=1e-5): + def __init__(self, hidden_size, eps=1e-5, memory_efficient=False): super().__init__() self.epsilon = eps + self.memory_efficient = memory_efficient self.weight = torch.nn.Parameter(torch.empty(hidden_size)) self.bias = torch.nn.Parameter(torch.empty(hidden_size)) self.reset_parameters() @@ -50,4 +54,4 @@ def reset_parameters(self): init.zeros_(self.bias) def forward(self, x): - return _fast_layer_norm(x, self.weight, self.bias, self.epsilon) + return _fast_layer_norm(x, self.weight, self.bias, self.epsilon, self.memory_efficient) diff --git a/apex/contrib/test/layer_norm/test_fast_layer_norm.py b/apex/contrib/test/layer_norm/test_fast_layer_norm.py index 9f6ee7980..fede67e90 100644 --- a/apex/contrib/test/layer_norm/test_fast_layer_norm.py +++ b/apex/contrib/test/layer_norm/test_fast_layer_norm.py @@ -1,3 +1,4 @@ +import itertools import unittest import torch @@ -106,7 +107,7 @@ def benchmark_(S, B, hidden_size, itype, wtype, runs=100): timer.start() for r in range(runs): - dx, dgamma, dbeta, dbp, dgp = fln.ln_bwd(dz, x, mu, rsigma, gamma) + dx, dgamma, dbeta, dbp, dgp = fln.ln_bwd(dz, z, mu, rsigma, gamma, beta, True) timer.stop() timer.sync() @@ -126,7 +127,7 @@ def benchmark_(S, B, hidden_size, itype, wtype, runs=100): ) -def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32): +def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32, mem_eff=False): seed = 1243 torch.manual_seed(seed) @@ -134,7 +135,7 @@ def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32): otype = wtype print("========================================================") - print(f"S={S} B={B} Hidden={hidden_size} {itype} {wtype}") + print(f"S={S} B={B} Hidden={hidden_size} {itype} {wtype} Mem_Eff={mem_eff}") print("--------------------------------------------------------") x = torch.randn(S * B, hidden_size, dtype=itype, device=device) @@ -165,7 +166,10 @@ def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32): dx_ref, dg_ref, db_ref = backward_(dz, x, mu_ref, rs_ref, gamma) z, mu, rs = fln.ln_fwd(x, gamma, beta, epsilon) - dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, x, mu, rs, gamma) + if mem_eff: + dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, z, mu, rs, gamma, beta, True) + else: + dx, dg, db, dg_part, db_part = fln.ln_bwd(dz, x, mu, rs, gamma, beta, False) re_z, mse_z = metrics(z_ref, z) re_mu, mse_mu = metrics(mu_ref, mu) @@ -184,7 +188,7 @@ def _test_impl(S, B, hidden_size, itype, wtype, ctype=fp32): print(f"db: relerr={re_db:.4e} mse={mse_db:.4e}") def check_err(x, relerr): - tol = 1e-3 if x.dtype == torch.float16 else 
5e-6 + tol = 2e-2 if x.dtype in (torch.float16, torch.bfloat16) else 5e-6 return relerr < tol return [ @@ -233,13 +237,13 @@ def test_all_configs(self): 65536, ] - for h in hidden_sizes: + for (h, mem_eff) in itertools.product(hidden_sizes, (True, False)): with self.subTest(f"hidden_size={h}"): - self.assertAll(_test_impl(256, 2, h, fp32, fp32)) - self.assertAll(_test_impl(256, 2, h, fp16, fp16)) - self.assertAll(_test_impl(256, 2, h, fp32, fp16)) - self.assertAll(_test_impl(256, 2, h, bf16, bf16)) - self.assertAll(_test_impl(256, 2, h, fp32, bf16)) + self.assertAll(_test_impl(256, 2, h, fp32, fp32, mem_eff=mem_eff)) + self.assertAll(_test_impl(256, 2, h, fp16, fp16, mem_eff=mem_eff)) + self.assertAll(_test_impl(256, 2, h, fp32, fp16, mem_eff=mem_eff)) + self.assertAll(_test_impl(256, 2, h, bf16, bf16, mem_eff=mem_eff)) + self.assertAll(_test_impl(256, 2, h, fp32, bf16, mem_eff=mem_eff)) def test_run_benchmark(self): for (S, B, hidden_size, runs) in ( diff --git a/apex/normalization/fused_layer_norm.py b/apex/normalization/fused_layer_norm.py index d99e232ae..571b8b456 100644 --- a/apex/normalization/fused_layer_norm.py +++ b/apex/normalization/fused_layer_norm.py @@ -31,172 +31,198 @@ def manual_rms_norm(input, normalized_shape, weight, eps): class FusedLayerNormAffineFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, weight, bias, normalized_shape, eps): + def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() output, mean, invvar = fused_layer_norm_cuda.forward_affine( input_, ctx.normalized_shape, weight_, bias_, ctx.eps ) - ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, weight_, bias_, None, invvar) + else: + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) return output @staticmethod def backward(ctx, grad_output): - input_, weight_, bias_, mean, invvar = ctx.saved_tensors + input_or_output, weight_, bias_, mean, invvar = ctx.saved_tensors grad_input = grad_weight = grad_bias = None grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine( - grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, weight_, bias_, ctx.eps + grad_output.contiguous(), mean, invvar, input_or_output, + ctx.normalized_shape, weight_, bias_, ctx.eps, ctx.memory_efficient ) - return grad_input, grad_weight, grad_bias, None, None + return grad_input, grad_weight, grad_bias, None, None, None class FusedRMSNormAffineFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, weight, normalized_shape, eps): + def forward(ctx, input, weight, normalized_shape, eps, memory_efficient): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() weight_ = weight.contiguous() output, invvar = fused_layer_norm_cuda.rms_forward_affine( input_, ctx.normalized_shape, weight_, ctx.eps) - ctx.save_for_backward(input_, weight_, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, weight_, invvar) + else: + 
ctx.save_for_backward(input_, weight_, invvar) return output @staticmethod def backward(ctx, grad_output): - input_, weight_, invvar = ctx.saved_tensors + input_or_output, weight_, invvar = ctx.saved_tensors grad_input = grad_weight = None grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine( - grad_output.contiguous(), invvar, input_, ctx.normalized_shape, weight_, ctx.eps + grad_output.contiguous(), invvar, input_or_output, + ctx.normalized_shape, weight_, ctx.eps, ctx.memory_efficient ) - return grad_input, grad_weight, None, None + return grad_input, grad_weight, None, None, None class FusedLayerNormAffineMixedDtypesFunction(FusedLayerNormAffineFunction): @staticmethod - def forward(ctx, input, weight, bias, normalized_shape, eps): + def forward(ctx, input, weight, bias, normalized_shape, eps, memory_efficient): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() weight_ = weight.contiguous() bias_ = bias.contiguous() output, mean, invvar = fused_layer_norm_cuda.forward_affine_mixed_dtypes( input_, ctx.normalized_shape, weight_, bias_, ctx.eps ) - ctx.save_for_backward(input_, weight_, bias_, mean, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, weight_, bias_, None, invvar) + else: + ctx.save_for_backward(input_, weight_, bias_, mean, invvar) return output class FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction): @staticmethod - def forward(ctx, input, weight, normalized_shape, eps): + def forward(ctx, input, weight, normalized_shape, eps, memory_efficient): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() weight_ = weight.contiguous() output, invvar = fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes( input_, ctx.normalized_shape, weight_, ctx.eps ) - - ctx.save_for_backward(input_, weight_, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, weight_, invvar) + else: + ctx.save_for_backward(input_, weight_, invvar) return output class FusedLayerNormFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, normalized_shape, eps): + def forward(ctx, input, normalized_shape, eps, memory_efficient): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() output, mean, invvar = fused_layer_norm_cuda.forward(input_, ctx.normalized_shape, ctx.eps) - ctx.save_for_backward(input_, mean, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, None, invvar) + else: + ctx.save_for_backward(input_, mean, invvar) return output @staticmethod def backward(ctx, grad_output): - input_, mean, invvar = ctx.saved_tensors - grad_input = None + input_or_output, mean, invvar = ctx.saved_tensors grad_input = fused_layer_norm_cuda.backward( - grad_output.contiguous(), mean, invvar, input_, ctx.normalized_shape, ctx.eps + grad_output.contiguous(), mean, invvar, input_or_output, + ctx.normalized_shape, ctx.eps, ctx.memory_efficient ) - return grad_input, None, None + return grad_input, None, None, None 
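All of these autograd functions gain the same memory_efficient switch, and it is surfaced on the public modules further down in this file. A hedged usage sketch, assuming a CUDA build of apex that includes these changes (shapes and dtypes are arbitrary):

    import torch
    from apex.normalization import FusedLayerNorm, FusedRMSNorm

    hidden = 1024
    ln = FusedLayerNorm(hidden, memory_efficient=True).cuda()
    rms = FusedRMSNorm(hidden, memory_efficient=True).cuda()

    x = torch.randn(8, 512, hidden, device="cuda", requires_grad=True)
    # With memory_efficient=True the forward saves its output (plus weight/bias
    # and invvar) for backward instead of the input, so the input activation
    # does not have to be kept alive by this layer.
    y = ln(x)
    z = rms(y)
    z.sum().backward()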
class FusedRMSNormFunction(torch.autograd.Function): @staticmethod - def forward(ctx, input, normalized_shape, eps): + def forward(ctx, input, normalized_shape, eps, memory_efficient): global fused_layer_norm_cuda if fused_layer_norm_cuda is None: fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") ctx.normalized_shape = normalized_shape ctx.eps = eps + ctx.memory_efficient = memory_efficient input_ = input.contiguous() output, invvar = fused_layer_norm_cuda.rms_forward(input_, ctx.normalized_shape, ctx.eps) - ctx.save_for_backward(input_, invvar) + if ctx.memory_efficient: + ctx.save_for_backward(output, invvar) + else: + ctx.save_for_backward(input_, invvar) return output @staticmethod def backward(ctx, grad_output): - input_, invvar = ctx.saved_tensors + input_or_output, invvar = ctx.saved_tensors grad_input = None grad_input = fused_layer_norm_cuda.rms_backward( - grad_output.contiguous(), invvar, input_, ctx.normalized_shape, ctx.eps + grad_output.contiguous(), invvar, input_or_output, + ctx.normalized_shape, ctx.eps, ctx.memory_efficient ) - return grad_input, None, None + return grad_input, None, None, None -def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps) +def fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps, memory_efficient) with torch.cuda.amp.autocast(enabled=False): return FusedLayerNormAffineFunction.apply(*args) -def fused_layer_norm(input, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, normalized_shape, eps) +def fused_layer_norm(input, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, normalized_shape, eps, memory_efficient) with torch.cuda.amp.autocast(enabled=False): return FusedLayerNormFunction.apply(*args) -def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps) +def mixed_dtype_fused_layer_norm_affine(input, weight, bias, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, weight, bias, normalized_shape, eps, memory_efficient) with torch.cuda.amp.autocast(enabled=False): return FusedLayerNormAffineMixedDtypesFunction.apply(*args) -def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps) +def fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps, memory_efficient) with torch.cuda.amp.autocast(enabled=False): return FusedRMSNormAffineFunction.apply(*args) -def fused_rms_norm(input, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, normalized_shape, eps) +def fused_rms_norm(input, normalized_shape, eps=1e-6, memory_efficient=False): + args = _cast_if_autocast_enabled(input, normalized_shape, eps, memory_efficient) with torch.cuda.amp.autocast(enabled=False): return FusedRMSNormFunction.apply(*args) -def mixed_dtype_fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6): - args = _cast_if_autocast_enabled(input, weight, normalized_shape, eps) +def mixed_dtype_fused_rms_norm_affine(input, weight, normalized_shape, eps=1e-6, memory_efficient=False): + args = 
_cast_if_autocast_enabled(input, weight, normalized_shape, eps, memory_efficient) with torch.cuda.amp.autocast(enabled=False): return FusedRMSNormAffineMixedDtypesFunction.apply(*args) @@ -261,7 +287,7 @@ class FusedLayerNorm(torch.nn.Module): .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450 """ - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, memory_efficient=False): super().__init__() global fused_layer_norm_cuda @@ -272,6 +298,7 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): self.normalized_shape = torch.Size(normalized_shape) self.eps = eps self.elementwise_affine = elementwise_affine + self.memory_efficient = memory_efficient if self.elementwise_affine: self.weight = Parameter(torch.empty(*normalized_shape)) self.bias = Parameter(torch.empty(*normalized_shape)) @@ -289,9 +316,11 @@ def forward(self, input): if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda: return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) if self.elementwise_affine: - return fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps) + return fused_layer_norm_affine( + input, self.weight, self.bias, self.normalized_shape, self.eps, self.memory_efficient + ) else: - return fused_layer_norm(input, self.normalized_shape, self.eps) + return fused_layer_norm(input, self.normalized_shape, self.eps, self.memory_efficient) def extra_repr(self): return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) @@ -357,7 +386,7 @@ class FusedRMSNorm(torch.nn.Module): .. _`Root Mean Square Layer Normalization`: https://arxiv.org/pdf/1910.07467.pdf """ - def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, memory_efficient=False): super().__init__() global fused_layer_norm_cuda @@ -368,6 +397,7 @@ def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): self.normalized_shape = torch.Size(normalized_shape) self.eps = eps self.elementwise_affine = elementwise_affine + self.memory_efficient = memory_efficient if self.elementwise_affine: self.weight = Parameter(torch.empty(*normalized_shape)) else: @@ -383,9 +413,11 @@ def forward(self, input): return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps) if self.elementwise_affine: - return fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps) + return fused_rms_norm_affine( + input, self.weight, self.normalized_shape, self.eps, self.memory_efficient + ) else: - return fused_rms_norm(input, self.normalized_shape, self.eps) + return fused_rms_norm(input, self.normalized_shape, self.eps, self.memory_efficient) def extra_repr(self): return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__) @@ -397,7 +429,7 @@ def extra_repr(self): # See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp" class MixedFusedLayerNorm(FusedLayerNorm): - def __init__(self, normalized_shape, eps=1e-5, **kwargs): + def __init__(self, normalized_shape, eps=1e-5, *, memory_efficient=False, **kwargs): if "elementwise_affine" in kwargs: import warnings warnings.warn("MixedFusedLayerNorm does not support `elementwise_affine` argument") @@ -405,13 +437,16 @@ def __init__(self, normalized_shape, eps=1e-5, **kwargs): if not 
elementwise_affine: raise RuntimeError("MixedFusedLayerNorm does not support `elementwise_affine = False`") - super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True) - + super().__init__( + normalized_shape=normalized_shape, eps=eps, elementwise_affine=True, memory_efficient=memory_efficient + ) def forward(self, input: torch.Tensor): # NOTE (mkozuki): CPU path is here mainly for unittest sake. if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda: return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps) - return mixed_dtype_fused_layer_norm_affine(input, self.weight, self.bias, self.normalized_shape, self.eps) + return mixed_dtype_fused_layer_norm_affine( + input, self.weight, self.bias, self.normalized_shape, self.eps, self.memory_efficient + ) # MixedFusedLayerNorm differs from FusedLayerNorm in that this layer norm uses parameter's dtype @@ -419,7 +454,7 @@ def forward(self, input: torch.Tensor): # See: `layer_norm_affine` and `layer_norm_affine_mixed_dtypes` in "csrc/layer_norm_cuda.cpp" class MixedFusedRMSNorm(FusedRMSNorm): - def __init__(self, normalized_shape, eps=1e-5, **kwargs): + def __init__(self, normalized_shape, eps=1e-5, *, memory_efficient=False, **kwargs): if "elementwise_affine" in kwargs: import warnings warnings.warn("MixedFusedRMSNorm does not support `elementwise_affine` argument") @@ -427,11 +462,14 @@ def __init__(self, normalized_shape, eps=1e-5, **kwargs): if not elementwise_affine: raise RuntimeError("MixedFusedRMSNorm does not support `elementwise_affine = False`") - super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=True) - + super().__init__( + normalized_shape=normalized_shape, eps=eps, elementwise_affine=True, memory_efficient=memory_efficient + ) def forward(self, input: torch.Tensor): # NOTE (mkozuki): CPU path is here mainly for unittest sake. 
# TODO Manual RMS Norm Implementation Here if torch.jit.is_tracing() or torch.jit.is_scripting() or not input.is_cuda: return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps) - return mixed_dtype_fused_rms_norm_affine(input, self.weight, self.normalized_shape, self.eps) + return mixed_dtype_fused_rms_norm_affine( + input, self.weight, self.normalized_shape, self.eps, self.memory_efficient + ) diff --git a/csrc/layer_norm_cuda.cpp b/csrc/layer_norm_cuda.cpp index 005906103..588375f6f 100644 --- a/csrc/layer_norm_cuda.cpp +++ b/csrc/layer_norm_cuda.cpp @@ -214,7 +214,7 @@ void cuda_layer_norm_gradient( at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, #ifdef VERSION_GE_1_1 @@ -227,38 +227,45 @@ void cuda_layer_norm_gradient( double epsilon, at::Tensor* grad_input, at::Tensor* grad_gamma, - at::Tensor* grad_beta + at::Tensor* grad_beta, + bool memory_efficient ); at::Tensor layer_norm_gradient( at::Tensor dout, - at::Tensor mean, + c10::optional mean_, at::Tensor invvar, - at::Tensor input, + at::Tensor input_or_output, #ifdef VERSION_GE_1_1 at::IntArrayRef normalized_shape, #else at::IntList normalized_shape, #endif - double epsilon) { + double epsilon, + bool memory_efficient) { CHECK_INPUT(dout); - CHECK_INPUT(mean); CHECK_INPUT(invvar); - CHECK_INPUT(input); + CHECK_INPUT(input_or_output); int n1,n2; - check_args(input,normalized_shape,n1,n2); - at::Tensor grad_input = at::empty_like(input); - cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2, - normalized_shape,NULL,NULL,epsilon, - &grad_input,NULL,NULL); + check_args(input_or_output,normalized_shape,n1,n2); + at::Tensor grad_input = at::empty_like(input_or_output); + if (mean_.has_value()) { + cuda_layer_norm_gradient(&dout,&mean_.value(),&invvar,&input_or_output,n1,n2, + normalized_shape,NULL,NULL,epsilon, + &grad_input,NULL,NULL,memory_efficient); + } else { + cuda_layer_norm_gradient(&dout,NULL,&invvar,&input_or_output,n1,n2, + normalized_shape,NULL,NULL,epsilon, + &grad_input,NULL,NULL,memory_efficient); + } return grad_input; } std::vector layer_norm_gradient_affine( at::Tensor dout, - at::Tensor mean, + c10::optional mean_, at::Tensor invvar, - at::Tensor input, + at::Tensor input_or_output, #ifdef VERSION_GE_1_1 at::IntArrayRef normalized_shape, #else @@ -266,21 +273,28 @@ std::vector layer_norm_gradient_affine( #endif at::Tensor gamma, at::Tensor beta, - double epsilon) { + double epsilon, + bool memory_efficient) { CHECK_INPUT(dout); - CHECK_INPUT(mean); CHECK_INPUT(invvar); - CHECK_INPUT(input); + CHECK_INPUT(input_or_output); CHECK_INPUT(gamma); CHECK_INPUT(beta); int n1,n2; - check_args(input,normalized_shape,gamma,beta,n1,n2); - at::Tensor grad_input = at::empty_like(input); + check_args(input_or_output,normalized_shape,gamma,beta,n1,n2); + at::Tensor grad_input = at::empty_like(input_or_output); at::Tensor grad_gamma = at::empty_like(gamma); at::Tensor grad_beta = at::empty_like(beta); - cuda_layer_norm_gradient(&dout,&mean,&invvar,&input,n1,n2, - normalized_shape,&gamma,&beta,epsilon, - &grad_input,&grad_gamma,&grad_beta); +// at::Tensor *mean = mean_.has_value() ? 
&mean_.value() : NULL; + if (mean_.has_value()) { + cuda_layer_norm_gradient(&dout,&mean_.value(),&invvar,&input_or_output,n1,n2, + normalized_shape,&gamma,&beta,epsilon, + &grad_input,&grad_gamma,&grad_beta,memory_efficient); + } else { + cuda_layer_norm_gradient(&dout,NULL,&invvar,&input_or_output,n1,n2, + normalized_shape,&gamma,&beta,epsilon, + &grad_input,&grad_gamma,&grad_beta,memory_efficient); + } return {grad_input, grad_gamma, grad_beta}; } @@ -364,7 +378,7 @@ std::vector rms_norm_affine_mixed_dtypes( void cuda_rms_norm_gradient( at::Tensor* dout, at::Tensor* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, #ifdef VERSION_GE_1_1 @@ -375,52 +389,55 @@ void cuda_rms_norm_gradient( at::Tensor* gamma, double epsilon, at::Tensor* grad_input, - at::Tensor* grad_gamma); + at::Tensor* grad_gamma, + bool memory_efficient); at::Tensor rms_norm_gradient( at::Tensor dout, at::Tensor invvar, - at::Tensor input, + at::Tensor input_or_output, #ifdef VERSION_GE_1_1 at::IntArrayRef normalized_shape, #else at::IntList normalized_shape, #endif - double epsilon) { + double epsilon, + bool memory_efficient) { CHECK_INPUT(dout); CHECK_INPUT(invvar); - CHECK_INPUT(input); + CHECK_INPUT(input_or_output); int n1,n2; - check_args(input,normalized_shape,n1,n2); - at::Tensor grad_input = at::empty_like(input); - cuda_rms_norm_gradient(&dout,&invvar,&input,n1,n2, + check_args(input_or_output,normalized_shape,n1,n2); + at::Tensor grad_input = at::empty_like(input_or_output); + cuda_rms_norm_gradient(&dout,&invvar,&input_or_output,n1,n2, normalized_shape,NULL,epsilon, - &grad_input,NULL); + &grad_input,NULL,memory_efficient); return grad_input; } std::vector rms_norm_gradient_affine( at::Tensor dout, at::Tensor invvar, - at::Tensor input, + at::Tensor input_or_output, #ifdef VERSION_GE_1_1 at::IntArrayRef normalized_shape, #else at::IntList normalized_shape, #endif at::Tensor gamma, - double epsilon) { + double epsilon, + bool memory_efficient) { CHECK_INPUT(dout); CHECK_INPUT(invvar); - CHECK_INPUT(input); + CHECK_INPUT(input_or_output); CHECK_INPUT(gamma); int n1,n2; - check_args(input,normalized_shape,gamma,n1,n2); - at::Tensor grad_input = at::empty_like(input); + check_args(input_or_output,normalized_shape,gamma,n1,n2); + at::Tensor grad_input = at::empty_like(input_or_output); at::Tensor grad_gamma = at::empty_like(gamma); - cuda_rms_norm_gradient(&dout,&invvar,&input,n1,n2, + cuda_rms_norm_gradient(&dout,&invvar,&input_or_output,n1,n2, normalized_shape,&gamma,epsilon, - &grad_input,&grad_gamma); + &grad_input,&grad_gamma,memory_efficient); return {grad_input, grad_gamma}; } diff --git a/csrc/layer_norm_cuda_kernel.cu b/csrc/layer_norm_cuda_kernel.cu index 21366772c..4e80e057a 100644 --- a/csrc/layer_norm_cuda_kernel.cu +++ b/csrc/layer_norm_cuda_kernel.cu @@ -7,6 +7,7 @@ #include #include "type_shim.h" +#include "static_switch.h" template __device__ void cuWelfordOnlineSum( @@ -437,7 +438,28 @@ void cuApplyRMSNorm( cuApplyLayerNorm_(output_vals, NULL, invvar, vals, n1, n2, epsilon, gamma, NULL, true); } -template __device__ + +template __device__ +V clamp_by_magnitude(V curr_gamma, double eps) +{ + const V kMinGamma = V(eps); + if (curr_gamma >= 0) { + if (curr_gamma < kMinGamma) { + return kMinGamma; + } else { + return curr_gamma; + } + } else { + if (curr_gamma > -kMinGamma) { + return -kMinGamma; + } else { + return curr_gamma; + } + } +} + + +template __device__ void cuLoadWriteStridedInputs( const int i1_block, const int thr_load_row_off, @@ -446,34 +468,41 @@ void 
cuLoadWriteStridedInputs( const int row_stride, U* warp_buf1, U* warp_buf2, - const T* input, + const T* input_or_output, const V* dout, const int i1_end, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, + const V* __restrict__ gamma, + const V* __restrict__ beta, + const double eps, bool rms_only ) { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - U curr_mean; - if (!rms_only) { - curr_mean = mean[i1]; - } - U curr_invvar = invvar[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; int load_idx = i1*n2+i2; int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; if (i2(input[load_idx]); + U c_h = static_cast(input_or_output[load_idx]); U curr_dout = static_cast(dout[load_idx]); if (!rms_only) { warp_buf1[write_idx] = curr_dout; - warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + if (MemoryEfficient) { + U curr_beta = static_cast(beta[i2]); + warp_buf2[write_idx] = curr_dout * (c_h - curr_beta) / static_cast(clamp_by_magnitude(gamma[i2], eps)); + } else { + warp_buf2[write_idx] = curr_dout * (c_h - mean[i1]) * invvar[i1]; + } } else { - warp_buf2[write_idx] = curr_dout * (curr_input) * curr_invvar; + if (MemoryEfficient) { + warp_buf2[write_idx] = curr_dout * (c_h) / static_cast(clamp_by_magnitude(gamma[i2], eps)); + } else { + warp_buf2[write_idx] = curr_dout * (c_h) * invvar[i1]; + } } } else { if (!rms_only) { @@ -493,7 +522,7 @@ void cuLoadWriteStridedInputs( } } -template __device__ +template __device__ void cuLoadAddStridedInputs( const int i1_block, const int thr_load_row_off, @@ -502,34 +531,41 @@ void cuLoadAddStridedInputs( const int row_stride, U* warp_buf1, U* warp_buf2, - const T* input, + const T* input_or_output, const V* dout, const int i1_end, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, + const V* __restrict__ gamma, + const V* __restrict__ beta, + const double eps, bool rms_only ) { int i1 = i1_block+thr_load_row_off; if (i1 < i1_end) { - U curr_mean; - if (!rms_only) { - curr_mean = mean[i1]; - } - U curr_invvar = invvar[i1]; for (int k = 0; k < blockDim.y; ++k) { int i2 = i2_off + k; int load_idx = i1*n2+i2; int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; if (i2(input[load_idx]); + U c_h = static_cast(input_or_output[load_idx]); U curr_dout = static_cast(dout[load_idx]); if (!rms_only) { + U curr_beta = static_cast(beta[i2]); warp_buf1[write_idx] += curr_dout; - warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + if (MemoryEfficient) { + warp_buf2[write_idx] += curr_dout * (c_h - curr_beta) / static_cast(clamp_by_magnitude(gamma[i2], eps)); + } else { + warp_buf2[write_idx] += curr_dout * (c_h - mean[i1]) * invvar[i1]; + } } else { - warp_buf2[write_idx] += curr_dout * (curr_input) * curr_invvar; + if (MemoryEfficient) { + warp_buf2[write_idx] += curr_dout * (c_h) / static_cast(clamp_by_magnitude(gamma[i2], eps)); + } else { + warp_buf2[write_idx] += curr_dout * (c_h) * invvar[i1]; + } } } } @@ -537,17 +573,20 @@ void cuLoadAddStridedInputs( } -template __global__ +template __global__ void cuComputePartGradGammaBeta( const V* __restrict__ dout, - const T* __restrict__ input, + const T* __restrict__ input_or_output, const int n1, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta, U* part_grad_gamma, U* part_grad_beta, + const double eps, bool rms_only) { const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y); 
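The division by gamma in the MemoryEfficient branches above is only well-defined if gamma is kept away from zero; that is the job of clamp_by_magnitude, which preserves the sign of gamma while enforcing a minimum magnitude of eps. A rough Python equivalent of that device helper, for illustration only (the kernel applies it element by element on the device):

    def clamp_by_magnitude(curr_gamma: float, eps: float) -> float:
        # Keep the sign of gamma but never let |gamma| drop below eps,
        # so that (h - beta) / gamma stays finite in the backward pass.
        if curr_gamma >= 0.0:
            return max(curr_gamma, eps)
        return min(curr_gamma, -eps)

    assert clamp_by_magnitude(1e-12, 1e-5) == 1e-5
    assert clamp_by_magnitude(-1e-12, 1e-5) == -1e-5
    assert clamp_by_magnitude(0.3, 1e-5) == 0.3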
@@ -565,9 +604,9 @@ void cuComputePartGradGammaBeta( U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; // compute partial sums from strided inputs // do this to increase number of loads in flight - cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only); + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps, rms_only); for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { - cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar, rms_only); + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input_or_output,dout,i1_end,n2,mean,invvar,gamma,beta,eps, rms_only); } __syncthreads(); // inter-warp reductions @@ -675,78 +714,108 @@ void cuComputeGradGammaBeta( } -template __global__ +template __global__ void cuComputeGradInput( const V* __restrict__ dout, - const T* __restrict__ input, + const T* __restrict__ input_or_output, const int n1, const int n2, const U* __restrict__ mean, const U* __restrict__ invvar, U epsilon, const V* gamma, + const V* beta, T* grad_input, + const double eps, bool rms_only) { for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { U sum_loss1 = U(0); U sum_loss2 = U(0); - U c_mean; - if (!rms_only) { - c_mean = mean[i1]; - } - const U c_invvar = invvar[i1]; - const T* k_input = input + i1*n2; + const T* k_h = input_or_output + i1*n2; const V* k_dout = dout + i1*n2; + const U c_invvar = invvar[i1]; + const U c_mean = !MemoryEfficient ? 
mean[i1] : 0.; const int numx = blockDim.x * blockDim.y; const int thrx = threadIdx.x + threadIdx.y * blockDim.x; if (gamma != NULL) { int l = 4*thrx; for (; l+3 < n2; l+=4*numx) { for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l+k]); + const U c_h = static_cast(k_h[l+k]); const U c_loss = static_cast(k_dout[l+k]); if (!rms_only) { sum_loss1 += c_loss * gamma[l+k]; - sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * (c_h - beta[l+k]); + } else { + sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar; + } } else { - sum_loss2 += c_loss * gamma[l+k] * (c_h) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * gamma[l+k] * (c_h) * c_invvar; + } } } } for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); + const U c_h = static_cast(k_h[l]); const U c_loss = static_cast(k_dout[l]); if (!rms_only) { sum_loss1 += c_loss * gamma[l]; - sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * (c_h - beta[l]); + } else { + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } } else { - sum_loss2 += c_loss * gamma[l] * (c_h) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * gamma[l] * (c_h) * c_invvar; + } } - } } else { int l = 4*thrx; for (; l+3 < n2; l+=4*numx) { for (int k = 0; k < 4; ++k) { - const U c_h = static_cast(k_input[l+k]); + const U c_h = static_cast(k_h[l+k]); const U c_loss = static_cast(k_dout[l+k]); if (!rms_only) { sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } } else { - sum_loss2 += c_loss * (c_h) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * (c_h) * c_invvar; + } } } } for (; l < n2; ++l) { - const U c_h = static_cast(k_input[l]); + const U c_h = static_cast(k_h[l]); const U c_loss = static_cast(k_dout[l]); if (!rms_only) { sum_loss1 += c_loss; - sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } } else { - sum_loss2 += c_loss * (c_h) * c_invvar; + if (MemoryEfficient) { + sum_loss2 += c_loss * c_h; + } else { + sum_loss2 += c_loss * (c_h) * c_invvar; + } } } } @@ -801,28 +870,46 @@ void cuComputeGradInput( T* k_grad_input = grad_input + i1*n2; if (gamma != NULL) { for (int l = thrx; l < n2; l+=numx) { - const U c_h = static_cast(k_input[l]); + const U c_h = static_cast(k_h[l]); const U c_loss = static_cast(k_dout[l]); - U f_grad_input = fH * c_loss * gamma[l]; + const U k_gamma = static_cast(clamp_by_magnitude(gamma[l], eps)); + U f_grad_input = fH * c_loss * k_gamma; if (!rms_only) { + const U k_beta = beta[l]; f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + if (MemoryEfficient) { + f_grad_input -= (c_h - k_beta) / k_gamma * sum_loss2; + } else { + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } } else { - f_grad_input -= (c_h) * c_invvar * sum_loss2; + if (MemoryEfficient) { + f_grad_input -= c_h / k_gamma * sum_loss2; + } else { + f_grad_input -= c_h * c_invvar * sum_loss2; + } } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); } } else { for (int l = thrx; l < n2; l+=numx) { - const U c_h = static_cast(k_input[l]); + const U c_h = 
static_cast(k_h[l]); const U c_loss = static_cast(k_dout[l]); U f_grad_input = fH * c_loss; if (!rms_only) { f_grad_input -= sum_loss1; - f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + if (MemoryEfficient) { + f_grad_input -= c_h * sum_loss2; + } else { + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + } } else { - f_grad_input -= (c_h) * c_invvar * sum_loss2; + if (MemoryEfficient) { + f_grad_input -= c_h * sum_loss2; + } else { + f_grad_input -= c_h * c_invvar * sum_loss2; + } } f_grad_input *= term1; k_grad_input[l] = static_cast(f_grad_input); @@ -947,7 +1034,7 @@ void HostLayerNormGradient( const V* dout, const U* mean, const U* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, const V* gamma, @@ -955,7 +1042,8 @@ void HostLayerNormGradient( double epsilon, T* grad_input, V* grad_gamma, - V* grad_beta + V* grad_beta, + bool memory_efficient ) { auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -971,21 +1059,27 @@ void HostLayerNormGradient( // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that // the `cuda_layer_norm_gradient` doesn't support double. const auto part_grad_dtype = - (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? + (input_or_output->scalar_type() == at::ScalarType::Half || input_or_output->scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : - input->scalar_type(); - at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype)); + input_or_output->scalar_type(); + at::Tensor part_grad_gamma = at::empty({part_size,n2}, input_or_output->options().dtype(part_grad_dtype)); at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); - cuComputePartGradGammaBeta<<>>( - dout, - input->DATA_PTR(), - n1,n2, - mean, - invvar, - U(epsilon), - part_grad_gamma.DATA_PTR(), - part_grad_beta.DATA_PTR(), - false); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{ + auto kernel = &cuComputePartGradGammaBeta; + kernel<<>>( + dout, + input_or_output->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + gamma, + beta, + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR(), + epsilon, + false); + }); const dim3 threads3(32,8,1); const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); @@ -1008,29 +1102,35 @@ void HostLayerNormGradient( threads1.y > 1 ? threads1.y*threads1.x*sizeof(U) : 0; - cuComputeGradInput<<>>( - dout, - input->DATA_PTR(), - n1,n2, - mean, - invvar, - U(epsilon), - gamma, - grad_input, - false); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&] { + auto kernel = cuComputeGradInput; + kernel<<>>( + dout, + input_or_output->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + gamma, + beta, + grad_input, + epsilon, + false); + }); } template void HostRMSNormGradient( const V* dout, const U* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, const V* gamma, double epsilon, T* grad_input, - V* grad_gamma) + V* grad_gamma, + bool memory_efficient) { auto stream = at::cuda::getCurrentCUDAStream().stream(); @@ -1044,20 +1144,27 @@ void HostRMSNormGradient( // note (mkozuki): I can hard code part_grad_gamma's dtype as float given that // the `cuda_layer_norm_gradient` doesn't support double. const auto part_grad_dtype = - (input->scalar_type() == at::ScalarType::Half || input->scalar_type() == at::ScalarType::BFloat16) ? + (input_or_output->scalar_type() == at::ScalarType::Half || input_or_output->scalar_type() == at::ScalarType::BFloat16) ? 
at::ScalarType::Float : - input->scalar_type(); - at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(part_grad_dtype)); - cuComputePartGradGammaBeta<<>>( - dout, - input->DATA_PTR(), - n1,n2, - invvar, // unused - invvar, - U(epsilon), - part_grad_gamma.DATA_PTR(), - part_grad_gamma.DATA_PTR(), /* unused */ - true); + input_or_output->scalar_type(); + at::Tensor part_grad_gamma = at::empty({part_size,n2}, input_or_output->options().dtype(part_grad_dtype)); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&]{ + auto kernel = &cuComputePartGradGammaBeta; + kernel<<>>( + dout, + input_or_output->DATA_PTR(), + n1,n2, + invvar, /* unused */ + invvar, + U(epsilon), + gamma, + gamma, /* unused */ + part_grad_gamma.DATA_PTR(), + part_grad_gamma.DATA_PTR(), /* unused */ + epsilon, + true); + }); + const dim3 threads3(32,8,1); const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); @@ -1080,23 +1187,28 @@ void HostRMSNormGradient( threads1.y > 1 ? threads1.y*threads1.x*sizeof(U) : 0; - cuComputeGradInput<<>>( - dout, - input->DATA_PTR(), - n1,n2, - invvar, /* unused */ - invvar, - U(epsilon), - gamma, - grad_input, - true); + BOOL_SWITCH(memory_efficient, MemoryEfficient, [&] { + auto kernel = cuComputeGradInput; + kernel<<>>( + dout, + input_or_output->DATA_PTR(), + n1,n2, + invvar, /* unused */ + invvar, + U(epsilon), + gamma, + gamma, /* unused */ + grad_input, + epsilon, + true); + }); } void cuda_layer_norm_gradient( at::Tensor* dout, at::Tensor* mean, at::Tensor* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, #ifdef VERSION_GE_1_1 @@ -1109,18 +1221,19 @@ void cuda_layer_norm_gradient( double epsilon, at::Tensor* grad_input, at::Tensor* grad_gamma, - at::Tensor* grad_beta) + at::Tensor* grad_beta, + bool memory_efficient) { using namespace at; // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16 DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - input->scalar_type(), gamma == NULL ? input->scalar_type() : gamma->scalar_type(), "cuComputeGradInput", + input_or_output->scalar_type(), gamma == NULL ? input_or_output->scalar_type() : gamma->scalar_type(), "cuComputeGradInput", using accscalar_t = at::acc_type; HostLayerNormGradient( dout->DATA_PTR(), - mean->DATA_PTR(), + mean != NULL ? mean->DATA_PTR() : NULL, invvar->DATA_PTR(), - input, + input_or_output, n1,n2, // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta // if gamma Tensor is NULL on input. @@ -1129,14 +1242,15 @@ void cuda_layer_norm_gradient( epsilon, grad_input->DATA_PTR(), gamma != NULL ? grad_gamma->DATA_PTR() : NULL, - gamma != NULL ? grad_beta->DATA_PTR() : NULL); + gamma != NULL ? grad_beta->DATA_PTR() : NULL, + memory_efficient); ) } void cuda_rms_norm_gradient( at::Tensor* dout, at::Tensor* invvar, - at::Tensor* input, + at::Tensor* input_or_output, int n1, int n2, #ifdef VERSION_GE_1_1 @@ -1147,24 +1261,26 @@ void cuda_rms_norm_gradient( at::Tensor* gamma, double epsilon, at::Tensor* grad_input, - at::Tensor* grad_gamma) + at::Tensor* grad_gamma, + bool memory_efficient) { using namespace at; // we can do away with `accscalar_t` as there're only three dtypes: fp32, fp16, bf16 // DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( - input->scalar_type(), gamma == NULL ? input->scalar_type() : gamma->scalar_type(), "cuComputeGradInputRMS", + input_or_output->scalar_type(), gamma == NULL ? 
input_or_output->scalar_type() : gamma->scalar_type(), "cuComputeGradInputRMS", using accscalar_t = at::acc_type; HostRMSNormGradient( dout->DATA_PTR(), invvar->DATA_PTR(), - input, + input_or_output, n1,n2, // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta // if gamma Tensor is NULL on input. gamma != NULL ? gamma->DATA_PTR() : NULL, epsilon, grad_input->DATA_PTR(), - gamma != NULL ? grad_gamma->DATA_PTR() : NULL); + gamma != NULL ? grad_gamma->DATA_PTR() : NULL, + memory_efficient); ) } diff --git a/csrc/static_switch.h b/csrc/static_switch.h new file mode 100644 index 000000000..1ba09857b --- /dev/null +++ b/csrc/static_switch.h @@ -0,0 +1,25 @@ +// From +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py b/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py index 13dee874b..94c30057f 100644 --- a/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py +++ b/tests/L0/run_fused_layer_norm/test_fused_layer_norm.py @@ -21,7 +21,7 @@ def _prep_inputs(batch_size, normalized_shape, dtype): class TestFusedLayerNorm(common_utils.TestCase): def _test_fused_layer_norm( - self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, + self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, fwd_thresholds=dict(rtol=None, atol=None), bwd_thresholds=dict(rtol=None, atol=None) ): @@ -29,15 +29,19 @@ def _test_fused_layer_norm( if not mixed_fused: module_cpu_ = FusedLayerNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine).cpu() + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient + ).cpu() module_cuda_ = FusedLayerNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine).to(device="cuda", dtype=dtype) + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient + ).to(device="cuda", dtype=dtype) else: assert elementwise_affine module_cpu_ = MixedFusedLayerNorm( - normalized_shape=normalized_shape).cpu() + normalized_shape=normalized_shape, memory_efficient=memory_efficient + ).cpu() module_cuda_ = MixedFusedLayerNorm( - normalized_shape=normalized_shape).to(device="cuda", dtype=dtype) + normalized_shape=normalized_shape, memory_efficient=memory_efficient + ).to(device="cuda", dtype=dtype) torch.cuda.manual_seed(42) if contiguous: @@ -70,7 +74,7 @@ def _test_fused_layer_norm( input_.grad.to(device="cuda", dtype=dtype), input_cuda_.grad, **bwd_thresholds) def _test_fused_rms_norm( - self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, + self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, fwd_thresholds=dict(rtol=None, atol=None), bwd_thresholds=dict(rtol=None, atol=None) ): @@ -78,9 +82,11 @@ def _test_fused_rms_norm( if not mixed_fused: module_cpu_ = FusedRMSNorm( - normalized_shape=normalized_shape, 
elementwise_affine=elementwise_affine).cpu() + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient + ).cpu() module_cuda_ = FusedRMSNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine).to(device="cuda", dtype=dtype) + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient + ).to(device="cuda", dtype=dtype) else: assert elementwise_affine module_cpu_ = MixedFusedRMSNorm( @@ -123,87 +129,87 @@ def _test_fused_rms_norm( # layer norm tests @common_utils.parametrize( - "batch_size, contiguous, elementwise_affine, mixed_fused, dtype", - list(product((16, 65536), (True, False), (False,), (False,), (torch.float,))) + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16, 65536), (True, False), (False,), (False,), (torch.float,), (True, False))) ) - def test_layer_norm_regular(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype): - self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype) + def test_layer_norm_regular(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient) @common_utils.parametrize( - "batch_size, contiguous, elementwise_affine, mixed_fused, dtype", - list(product((16, 65536), (True, False), (True,), (False,), (torch.float,))) + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16, 65536), (True, False), (True,), (False,), (torch.float,), (True, False))) ) - def test_layer_norm_elemwise(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype): - self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype) + def test_layer_norm_elemwise(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient) @common_utils.parametrize( - "batch_size, contiguous, elementwise_affine, mixed_fused, dtype", - list(product((16, 65536), (True, False), (True,), (True,), (torch.float,))) + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16, 65536), (True, False), (True,), (True,), (torch.float,), (True, False))) ) - def test_layer_norm_mixed(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype): - self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype) + def test_layer_norm_mixed(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient) @common_utils.parametrize( - "batch_size, contiguous, elementwise_affine, mixed_fused, dtype", - list(product((16,), (True, False), (True,), (False,), (torch.half,))) + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16,), (True, False), (True,), (False,), (torch.half,), (True, False))) ) - def test_layer_norm_half(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype): - self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, + def test_layer_norm_half(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, 
memory_efficient): + self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, fwd_thresholds=dict(rtol=1e-3, atol=1e-3), bwd_thresholds=dict(rtol=1e-3, atol=1e-3)) @common_utils.parametrize( - "batch_size, contiguous, elementwise_affine, mixed_fused, dtype", - list(product((16,), (True, False), (True,), (False,), (torch.bfloat16,))) + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16,), (True, False), (True,), (False,), (torch.bfloat16,), (True, False))) ) - def test_layer_norm_bfloat16(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype): - self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, + def test_layer_norm_bfloat16(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_layer_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, fwd_thresholds=dict(rtol=1.6e-2, atol=3e-4), bwd_thresholds=dict(rtol=1.6e-2, atol=3e-3)) # rms norm tests @common_utils.parametrize( - "batch_size, contiguous, elementwise_affine, mixed_fused, dtype", - list(product((16, 65536), (True, False), (False,), (False,), (torch.float,))) + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16, 65536), (True, False), (False,), (False,), (torch.float,), (True, False))) ) - def test_rms_norm_regular(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype): - self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype) + def test_rms_norm_regular(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient) @common_utils.parametrize( - "batch_size, contiguous, elementwise_affine, mixed_fused, dtype", - list(product((16, 65536), (True, False), (True,), (False,), (torch.float,))) + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16, 65536), (True, False), (True,), (False,), (torch.float,), (True, False))) ) - def test_rms_norm_elemwise(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype): - self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, + def test_rms_norm_elemwise(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, bwd_thresholds=dict(rtol=2e-3, atol=2e-4)) @common_utils.parametrize( - "batch_size, contiguous, elementwise_affine, mixed_fused, dtype", - list(product((16, 65536), (True, False), (True,), (True,), (torch.float,))) + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16, 65536), (True, False), (True,), (True,), (torch.float,), (True, False))) ) - def test_rms_norm_mixed(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype): - self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, + def test_rms_norm_mixed(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, bwd_thresholds=dict(rtol=2e-3, atol=2e-4)) @common_utils.parametrize( - "batch_size, 
contiguous, elementwise_affine, mixed_fused, dtype", - list(product((16,), (True, False), (True,), (False,), (torch.half,))) + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16,), (True, False), (True,), (False,), (torch.half,), (True, False))) ) - def test_rms_norm_half(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype): - self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, + def test_rms_norm_half(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3)) @common_utils.parametrize( - "batch_size, contiguous, elementwise_affine, mixed_fused, dtype", - list(product((16,), (True, False), (True,), (False,), (torch.bfloat16,))) + "batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient", + list(product((16,), (True, False), (True,), (False,), (torch.bfloat16,), (True, False))) ) - def test_rms_norm_bfloat16(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype): - self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, + def test_rms_norm_bfloat16(self, batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient): + self._test_fused_rms_norm(batch_size, contiguous, elementwise_affine, mixed_fused, dtype, memory_efficient, fwd_thresholds=dict(rtol=1.6e-2, atol=3e-4), bwd_thresholds=dict(rtol=1.6e-2, atol=3e-2)) @common_utils.parametrize( - "dtype, elementwise_affine", - list(product(autocast_dtypes, (True, False))) + "dtype, elementwise_affine, memory_efficient", + list(product(autocast_dtypes, (True, False), (True, False))) ) - def test_autocast_fused_layer_norm(self, dtype, elementwise_affine): + def test_autocast_fused_layer_norm(self, dtype, elementwise_affine, memory_efficient): bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4) bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) batch_size = 16 @@ -212,7 +218,7 @@ def test_autocast_fused_layer_norm(self, dtype, elementwise_affine): normalized_shape=normalized_shape, elementwise_affine=elementwise_affine ).to(device="cuda", dtype=dtype) fused = FusedLayerNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient ).cuda() native_x, fused_x = _prep_inputs(batch_size, normalized_shape, dtype) @@ -230,22 +236,27 @@ def test_autocast_fused_layer_norm(self, dtype, elementwise_affine): expected.backward(g_native) actual.backward(g_fused) - tols = {'rtol': None, 'atol': None} if dtype == torch.half else bf16_bwd_thresholds + if dtype != torch.half: + tols = bf16_bwd_thresholds + elif memory_efficient: + tols = {'rtol': 1e-3, 'atol': 1e-4} + else: + tols = {'rtol': None, 'atol': None} torch.testing.assert_close(native_x.grad, fused_x.grad, **tols, check_dtype=False) @common_utils.parametrize( - "dtype, elementwise_affine", - list(product(autocast_dtypes, (True, False))) + "dtype, elementwise_affine, memory_efficient", + list(product(autocast_dtypes, (True, False), (True, False))) ) - def test_autocast_fused_rms_norm(self, dtype, elementwise_affine): + def test_autocast_fused_rms_norm(self, dtype, elementwise_affine, memory_efficient): bf16_fwd_thresholds = dict(rtol=1.6e-2, atol=3e-4) bf16_bwd_thresholds = dict(rtol=1.6e-2, atol=3e-3) 
batch_size = 16 normalized_shape = [32, 16] native = FusedRMSNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient, ).to(dtype=dtype) fused = FusedRMSNorm( - normalized_shape=normalized_shape, elementwise_affine=elementwise_affine + normalized_shape=normalized_shape, elementwise_affine=elementwise_affine, memory_efficient=memory_efficient, ).cuda() native_x, fused_x = _prep_inputs(batch_size, normalized_shape, dtype) From 82b7195e0b97a76b1bf57b4b06d222a2c0858aff Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 29 Sep 2023 09:40:46 -0700 Subject: [PATCH 4/7] Distributed optimizer infrastructure for FP8 parameters (#1723) * Add distopt support for param syncs with non-floating-point dtypes Signed-off-by: Tim Moon * Update apex/contrib/optimizers/distributed_fused_adam.py Co-authored-by: Masaki Kozuki --------- Signed-off-by: Tim Moon Co-authored-by: Masaki Kozuki --- .../optimizers/distributed_fused_adam.py | 348 +++++++++++++----- .../contrib/test/optimizers/test_dist_adam.py | 15 + 2 files changed, 280 insertions(+), 83 deletions(-) diff --git a/apex/contrib/optimizers/distributed_fused_adam.py b/apex/contrib/optimizers/distributed_fused_adam.py index f91a71a5c..565bdc1fe 100644 --- a/apex/contrib/optimizers/distributed_fused_adam.py +++ b/apex/contrib/optimizers/distributed_fused_adam.py @@ -6,7 +6,17 @@ import io import itertools import threading -from typing import Any, Callable, Iterable, List, Optional, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) import warnings import torch @@ -170,18 +180,28 @@ def _multi_tensor_copy( # Just copy bytes if dtypes are same buf_in = buf_in.view(torch.uint8) buf_out = buf_out.view(torch.uint8) - key = (buf_in.is_cuda, buf_in.dtype, buf_out.is_cuda, buf_out.dtype) + is_cuda = ( + _devices_match(buf_in.device, "cuda") + and _devices_match(buf_out.device, "cuda") + ) + is_contiguous = buf_in.is_contiguous() and buf_out.is_contiguous() + key = ( + buf_in.dtype, + buf_out.dtype, + is_cuda, + is_contiguous, + ) buffer_groups[key].append((buf_in, buf_out)) # Copy each group of buffers for key, buffers in buffer_groups.items(): # Check if buffers support fused kernel - is_cuda_in, dtype_in, is_cuda_out, dtype_out = key + dtype_in, dtype_out, is_cuda, is_contiguous = key supported_dtypes = (torch.float32, torch.float16) use_fused_kernel = ( dtype_in in supported_dtypes and dtype_out in supported_dtypes ) or (dtype_in == torch.uint8 and dtype_out == torch.uint8) - use_fused_kernel = use_fused_kernel and is_cuda_in and is_cuda_out + use_fused_kernel = use_fused_kernel and is_cuda and is_contiguous # Copy buffers if use_fused_kernel and _FOUND_DEPRECATED_FUSED_ADAM: @@ -437,6 +457,7 @@ def __init__( ) def dtypes(self) -> Tuple[torch.dtype, torch.dtype, torch.dtype]: + """Datatypes for the bucket's compute and communication""" return ( self.dtype, self.grad_sync_dtype, @@ -549,9 +570,8 @@ def __init__( if ( dtype not in supported_dtypes or grad_sync_dtype not in supported_dtypes - or param_sync_dtype not in supported_dtypes ): - raise RuntimeError( + raise ValueError( "Unsupported dtypes for DistributedFusedAdam " f"(dtype={dtype}, " f"grad_sync_dtype={grad_sync_dtype}, " @@ -1032,6 +1052,39 @@ def parameters(self) -> Iterable[torch.nn.Parameter]: group["params"] for group in self.param_groups ) + def 
parameter( + self, + *args: Union[int, ParameterFragment], + ) -> torch.nn.Parameter: + """Get optimizer parameter + + Can either accept two ints or one + DistributedFusedAdam.ParameterFragment. + + Arguments: + param_group_id (int): Parameter group index + param_id (int): Parameter index within parameter group + + """ + if ( + len(args) == 2 + and isinstance(args[0], int) + and isinstance(args[1], int) + ): + param_group_id = args[0] + param_id = args[1] + elif len(args) == 1 and isinstance(args[0], self.ParameterFragment): + fragment = args[0] + param_group_id = fragment.param_group_id + param_id = fragment.param_id + else: + raise TypeError( + "Expected input types are " + "[int, int] or [DistributedFusedAdam.ParameterFragment], " + f"but found {[type(arg).__name__ for arg in args]}" + ) + return self.param_groups[param_group_id]["params"][param_id] + def init_params( self, params: Optional[Iterable[torch.nn.Parameter]] = None, @@ -1189,9 +1242,23 @@ def _init_param_state( grad_sync_dtype = self.grad_sync_dtype if param_sync_dtype is None: param_sync_dtype = self.param_sync_dtype - assert ( - dtype == self.dtype - ), "Optimizer states with non-default dtypes are not supported" + if dtype != self.dtype: + raise ValueError( + "Optimizer states with non-default dtypes are not supported" + ) + supported_dtypes = (torch.float32, torch.float16, torch.bfloat16) + if ( + dtype not in supported_dtypes + or grad_sync_dtype not in supported_dtypes + ): + raise ValueError( + "Unsupported dtypes for DistributedFusedAdam " + f"(dtype={dtype}, " + f"grad_sync_dtype={grad_sync_dtype}, " + f"param_sync_dtype={param_sync_dtype}))" + ) + + # Store params or param remainders store_params = ( self.store_params or dtype != self.dtype @@ -1370,7 +1437,12 @@ def zero_grad(self, set_to_none: bool = False) -> None: ) def _grad_copy(self, param: torch.nn.Parameter) -> None: - """Copy parameter gradients to buckets""" + """Copy parameter gradients to gradient buckets + + Initializes gradient buckets if needed. The original parameter + gradient is set to None. + + """ # Initialize parameter if needed if "fragments" not in self.state[param]: @@ -1431,8 +1503,15 @@ def _grad_copy(self, param: torch.nn.Parameter) -> None: # Free param grad buffer param.grad = None - def _param_copy(self, params: torch.nn.Parameter) -> None: - """Update parameters with values from parameter buckets""" + def _param_copy( + self, + params: Union[torch.nn.Parameter, Iterable[torch.nn.Parameter]], + ) -> None: + """Update parameters with values from parameter buckets + + Synchronizes and deletes parameter buckets as needed. 
+ + """ # Get parameter fragments to be synchronized if isinstance(params, torch.Tensor): @@ -1446,6 +1525,10 @@ def _param_copy(self, params: torch.nn.Parameter) -> None: if fragment.bucket_id in self._params_buckets ) + # Return immediately if no fragments need to be synchronized + if not fragments: + return + # Make sure all needed buckets have been synchronized buckets = collections.OrderedDict() for fragment in fragments: @@ -1459,37 +1542,67 @@ def _param_copy(self, params: torch.nn.Parameter) -> None: self._finish_bucket_param_sync() # Copy values from bucket buffers to params - params_in = [] - params_out = [] - for fragment in fragments: - bucket_id = fragment.bucket_id - param_group_id = fragment.param_group_id - param_id = fragment.param_id - bucket_start, bucket_end = fragment.bucket_range - param_start, param_end = fragment.param_range - if param_end > param_start: - bucket = self._params_buckets[bucket_id] - param = self.param_groups[param_group_id]["params"][param_id] - params_in.append(bucket.params_bucket[bucket_start:bucket_end]) - params_out.append(param.detach().view(-1)[param_start:param_end]) - _multi_tensor_copy( - params_in, - params_out, - dummy_overflow_buf=self._dummy_overflow_buf, - ) + self._param_copy_fragments(fragments) # Delete buckets if possible for fragment in fragments: bucket_id = fragment.bucket_id bucket = self._params_buckets[bucket_id] + bucket.params_updated.add(self.parameter(fragment)) bucket_fragments = self.state["buckets"][bucket_id].fragments - param_group_id = fragment.param_group_id - param_id = fragment.param_id - param = self.param_groups[param_group_id]["params"][param_id] - bucket.params_updated.add(param) if len(bucket.params_updated) == len(bucket_fragments): del self._params_buckets[bucket_id] + def _param_copy_fragments( + self, + fragments: Iterable[ParameterFragment], + ) -> None: + """Update parameter fragments with values from parameter buckets""" + + # Figure out corresponding positions in param buckets and params + buffers_in = [] + buffers_out = [] + for fragment in fragments: + + # Check if fragment needs to be updated + bucket_id = fragment.bucket_id + bucket_start, bucket_end = fragment.bucket_range + param_start, param_end = fragment.param_range + if param_end <= param_start or bucket_id not in self._params_buckets: + continue + + # Corresponding positions in param bucket and param + bucket = self._params_buckets[bucket_id] + param = self.parameter(fragment) + buffer_in = bucket.params_bucket[bucket_start:bucket_end] + buffer_out = param.detach().view(-1)[param_start:param_end] + + if ( + torch.is_floating_point(buffer_in) + and torch.is_floating_point(buffer_out) + ): + # Cast between floating-point dtypes + buffers_in.append(buffer_in) + buffers_out.append(buffer_out) + else: + # Copy most significant bytes for non-floating-point + # dtypes + # Note: Assume dtypes are little-endian + in_bytes = buffer_in.unsqueeze(-1).view(torch.uint8) + out_bytes = buffer_out.unsqueeze(-1).view(torch.uint8) + copy_size = min(in_bytes.size(-1), out_bytes.size(-1)) + buffers_in.append(in_bytes[..., -copy_size:]) + buffers_out.append(out_bytes[..., -copy_size:]) + if copy_size < out_bytes.size(-1): + out_bytes[..., :-copy_size].zero_() + + # Copy data from parameter buckets to parameters + _multi_tensor_copy( + buffers_in, + buffers_out, + dummy_overflow_buf=self._dummy_overflow_buf, + ) + def grad_buffer_view(self, param: torch.nn.Parameter) -> torch.Tensor: """Construct view into grad buffer corresponding to param @@ -1725,11 +1838,11 
@@ def _try_start_bucket_param_sync( return for bucket_id, bucket in self._params_buckets.items(): if bucket.status == self.ParameterStatus.SHARDED: - fragment = self.state["buckets"][bucket_id].fragments[-1] - param_group_id = fragment.param_group_id - param_id = fragment.param_id - param = self.param_groups[param_group_id]["params"][param_id] - params.append(param) + params.append( + self.parameter( + self.state["buckets"][bucket_id].fragments[-1] + ) + ) break # Find buckets corresponding to params @@ -1773,9 +1886,11 @@ def _start_bucket_param_sync(self, buckets: List[ParameterBucket]) -> None: ] for bucket in buckets: bucket.status = self.ParameterStatus.SYNCING - if self.distributed_size == 1: + if bucket.params_bucket is not None: + pass + elif self.distributed_size == 1: bucket.params_bucket = bucket.params_shard - elif bucket.params_bucket is None: + else: shard_size = bucket.params_shard.numel() bucket_size = shard_size * self.distributed_size bucket.params_bucket = torch.empty( @@ -1850,9 +1965,7 @@ def grad_sync(self) -> None: """Ensure that all gradients are synchronized""" for bucket in self.state["buckets"]: for fragment in bucket.fragments: - param_group_id = fragment.param_group_id - param_id = fragment.param_id - param = self.param_groups[param_group_id]["params"][param_id] + param = self.parameter(fragment) if param.grad is not None: self._grad_copy(param) if not self.contiguous_grad_buffer: @@ -1870,10 +1983,7 @@ def param_sync(self) -> None: while self._params_buckets: bucket_id, bucket = next(iter((self._params_buckets.items()))) for fragment in reversed(self.state["buckets"][bucket_id].fragments): - param_id = fragment.param_id - param_group_id = fragment.param_group_id - param = self.param_groups[param_group_id]["params"][param_id] - self._param_copy(param) + self._param_copy(self.parameter(fragment)) self._params_buckets.clear() @torch.no_grad() @@ -1904,11 +2014,7 @@ def _local_grad_norm( all_params_set = set() for bucket in self.state["buckets"]: for fragment in bucket.fragments: - param_group_id = fragment.param_group_id - param_id = fragment.param_id - all_params_set.add( - self.param_groups[param_group_id]["params"][param_id] - ) + all_params_set.add(self.parameter(fragment)) if not params_set.issubset(all_params_set): raise RuntimeError( "Attempted to compute gradient norm for a parameter " @@ -2095,11 +2201,18 @@ def step( self._grad_scale = self._grad_scale.to(dtype=torch.float32, device=self.device) # Initialize param shard buffers + overlap_first_bucket = ( + self.distributed_size > 1 + and self.overlap_param_sync + and self.state["buckets"] + ) for bucket_id in reversed(range(len(self.state["buckets"]))): params_bucket = self.ParameterBucket() state_bucket = self.state["buckets"][bucket_id] shard_size = state_bucket.shard_size + param_sync_dtype = state_bucket.param_sync_dtype if self.contiguous_param_buffer: + # Construct view into contiguous param buffer if not self._param_buffers: self.init_param_buffer() bucket_size = state_bucket.bucket_size @@ -2112,32 +2225,39 @@ def step( params_bucket.params_shard = params_bucket.params_bucket[ bucket_start:bucket_end ] + elif not param_sync_dtype.is_floating_point: + # Allocate temporary buffer for param shard + # Note: Adam kernel only supports floating-point + # dtypes. 
+ params_bucket.params_shard = torch.empty( + [shard_size], + dtype=self.dtype, + device=self.device, + ) + overlap_first_bucket = False else: + # Allocate param shard buffer params_bucket.params_shard = torch.empty( [shard_size], - dtype=state_bucket.param_sync_dtype, + dtype=param_sync_dtype, device=self.device, ) self._params_buckets[bucket_id] = params_bucket - # Apply optimizer step and synchronize params + # Apply optimizer step self.state["step"] += 1 - if ( - self.distributed_size > 1 - and self.overlap_param_sync - and self.state["buckets"] - ): + if overlap_first_bucket: # Local step and non-blocking param sync # Note: Overlap param sync of first buckets with optimizer # step of remaining buckets. # Get buckets containing "first" parameter - fragment = self.state["buckets"][-1].fragments[-1] - param_group_id = fragment.param_group_id - param_id = fragment.param_id - param = self.param_groups[param_group_id]["params"][param_id] + first_param = self.parameter( + self.state["buckets"][-1].fragments[-1] + ) first_bucket_ids = sorted( - fragment.bucket_id for fragment in self.state[param]["fragments"] + fragment.bucket_id + for fragment in self.state[first_param]["fragments"] ) # Local step and launch param sync for first buckets @@ -2154,13 +2274,19 @@ def step( if bucket_id not in first_bucket_ids ) - # Enable pre-forward hook + else: + # Local step + self._local_step(list(range(len(self.state["buckets"])))) + self._check_params_shard_dtypes(self._params_buckets) + + # Synchronize params + if self.distributed_size > 1 and self.overlap_param_sync: + # Asynchronous param sync + self._try_start_bucket_param_sync() for param in self.parameters(): param._pre_forward_hook_is_enabled = True - else: - # Local step and blocking param sync - self._local_step(list(range(len(self.state["buckets"])))) + # Blocking param sync self.param_sync() return loss @@ -2213,10 +2339,8 @@ def _local_step(self, bucket_ids: Iterable[int]) -> None: if shard_end <= shard_start: continue shard_range = slice(shard_start, shard_end) - param_group_id = fragment.param_group_id - param_id = fragment.param_id if state_bucket.params_shard is None: - param = self.param_groups[param_group_id]["params"][param_id] + param = self.parameter(fragment) param_range = slice(*fragment.shard_param_range) param_fragment = param.detach().view(-1)[param_range] param_fragment = param_fragment.to( @@ -2226,7 +2350,7 @@ def _local_step(self, bucket_ids: Iterable[int]) -> None: params_shard = state_bucket.params_shard param_fragment = params_shard[shard_range] buffers_key = ( - param_group_id, + fragment.param_group_id, state_bucket.dtype, state_bucket.grad_sync_dtype, state_bucket.param_sync_dtype, @@ -2299,10 +2423,11 @@ def _local_step_with_param_remainders( if shard_end <= shard_start: continue shard_range = slice(shard_start, shard_end) - param_group_id = fragment.param_group_id - param_id = fragment.param_id - buffers_key = (param_group_id, state_bucket.grad_sync_dtype) - param = self.param_groups[param_group_id]["params"][param_id] + buffers_key = ( + fragment.param_group_id, + state_bucket.grad_sync_dtype, + ) + param = self.parameter(fragment) param_range = slice(*fragment.shard_param_range) param_fragment = param.detach().view(-1)[param_range] param_fragment = param_fragment.to( @@ -2338,6 +2463,67 @@ def _local_step_with_param_remainders( group["weight_decay"], ) + @torch.no_grad() + def _check_params_shard_dtypes( + self, + params_buckets: Dict[int, ParameterBucket], + ) -> None: + """Make sure local shards of parameters are 
in expected datatypes + + The Adam kernel only supports floating-point datatypes. If we + want to perform parameter synchronization with + non-floating-point dtypes, we need to allocate temporary + buffers that can accommodate the Adam kernel. This function is + responsible for converting these temporary buffers to the + parameter synchronization datatype. + + """ + + # Find param shards that require dtype conversion + buffers_in = [] + buffers_out = [] + for bucket_id, param_bucket in params_buckets.items(): + + # Check if param shard is already in expected dtype + state_bucket = self.state["buckets"][bucket_id] + param_sync_dtype = state_bucket.param_sync_dtype + if param_bucket.params_shard.dtype == param_sync_dtype: + continue + + # Allocate buffer with required dtype + buffer_in = param_bucket.params_shard + buffer_out = torch.empty_like( + param_bucket.params_shard, + dtype=param_sync_dtype, + ) + param_bucket.params_shard = buffer_out + + if ( + torch.is_floating_point(buffer_in) + and torch.is_floating_point(buffer_out) + ): + # Cast between floating-point dtypes + buffers_in.append(buffer_in) + buffers_out.append(buffer_out) + else: + # Copy most significant bytes for non-floating-point + # dtypes + # Note: Assume dtypes are little-endian + in_bytes = buffer_in.unsqueeze(-1).view(torch.uint8) + out_bytes = buffer_out.unsqueeze(-1).view(torch.uint8) + copy_size = min(in_bytes.size(-1), out_bytes.size(-1)) + buffers_in.append(in_bytes[..., -copy_size:]) + buffers_out.append(out_bytes[..., -copy_size:]) + if copy_size < out_bytes.size(-1): + out_bytes[..., :-copy_size].zero_() + + # Perform dtype conversions + _multi_tensor_copy( + buffers_in, + buffers_out, + dummy_overflow_buf=self._dummy_overflow_buf, + ) + def state_dict( self, *, @@ -2625,11 +2811,9 @@ def pack_param_shard(bucket_id: int) -> torch.Tensor: for fragment in bucket.fragments: if not fragment.in_local_shard: continue - param_id = fragment.param_id - param_group_id = fragment.param_group_id param_range = slice(*fragment.shard_param_range) shard_range = slice(*fragment.shard_range) - param = self.param_groups[param_group_id]["params"][param_id] + param = self.parameter(fragment) buffers_in.append(param.view(-1)[param_range]) buffers_out.append(shard_bf16[shard_range]) _multi_tensor_copy( @@ -2663,11 +2847,9 @@ def pack_param_shard(bucket_id: int) -> torch.Tensor: for fragment in bucket.fragments: if not fragment.in_local_shard: continue - param_id = fragment.param_id - param_group_id = fragment.param_group_id param_range = slice(*fragment.shard_param_range) shard_range = slice(*fragment.shard_range) - param = self.param_groups[param_group_id]["params"][param_id] + param = self.parameter(fragment) buffers_in.append(param.view(-1)[param_range]) buffers_out.append(shard[shard_range]) _multi_tensor_copy( diff --git a/apex/contrib/test/optimizers/test_dist_adam.py b/apex/contrib/test/optimizers/test_dist_adam.py index 4346be61f..298e5ec16 100644 --- a/apex/contrib/test/optimizers/test_dist_adam.py +++ b/apex/contrib/test/optimizers/test_dist_adam.py @@ -291,6 +291,21 @@ def init_optim(optim: DistributedFusedAdam): init_optim_func=init_optim, ) + def test_matches_pytorch_int64_param_sync(self): + self.test_matches_pytorch( + param_sync_dtype=torch.int64, + ) + + def test_matches_pytorch_uint8_param_sync(self): + self.test_matches_pytorch( + rtol=0.5, + atol=0.05, + model_dtype=torch.float16, + optim_dtype=torch.float16, + micro_batch_steps=1, + param_sync_dtype=torch.uint8, + ) + def test_raises_on_mismatch(self): 
torch.manual_seed(self.seed + self.rank) From 26577664645273cd3d84aca3cb55ccef4bf342f2 Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Fri, 29 Sep 2023 15:30:15 -0700 Subject: [PATCH 5/7] Add unit test --- .../run_amp/test_update_scale_hysteresis.py | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 tests/L0/run_amp/test_update_scale_hysteresis.py diff --git a/tests/L0/run_amp/test_update_scale_hysteresis.py b/tests/L0/run_amp/test_update_scale_hysteresis.py new file mode 100644 index 000000000..470974d53 --- /dev/null +++ b/tests/L0/run_amp/test_update_scale_hysteresis.py @@ -0,0 +1,105 @@ +import unittest +import random +import math + +from apex import amp +import torch + +from utils import common_init + +try: + import amp_C + from amp_C import update_scale_hysteresis + disabled = False +except ImportError as err: + print("amp_C fused kernels unavailable, disabling TestUpdateScaleHysteresis. ImportError was ", err) + disabled = True + +def isfinite(val): + return ((val >= torch.finfo(torch.float32).smallest_normal) and (val <= torch.finfo(torch.float32).max)) + +class TestUpdateScaleHysteresis(unittest.TestCase): + + def setUp(self): + common_init(self) + + def tearDown(self): + pass + + def update_scale_hysteresis_body(self, init_scale, growth_factor, backoff_factor, + growth_interval, hysteresis): + scale_ref = float(init_scale) + grow_tracker_ref = 0 + hysteresis_tracker_ref = 0 + + scale = torch.tensor([init_scale], dtype=torch.float32, device='cuda') + growth_tracker = torch.tensor([0], dtype=torch.int32, device='cuda') + hysteresis_tracker = torch.tensor([hysteresis], dtype=torch.int32, device='cuda') + + # Infs appear for hysteresis-1 iterations, scale shouldn't change + found_inf = torch.tensor([1], dtype=torch.float32, device='cuda') + for i in range(hysteresis-1): + update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker, + found_inf, growth_factor, backoff_factor, growth_interval, hysteresis) + self.assertTrue(scale.item() == init_scale) + + # No infs for growth_interval-1 iterations, scale shouldn't change + found_inf.zero_() + for i in range(growth_interval-1): + update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker, + found_inf, growth_factor, backoff_factor, growth_interval, hysteresis) + self.assertTrue(scale.item() == init_scale) + + # Infs appear for more than hysteresis iterations, scale should be backed off + found_inf.fill_(1) + extra_iters = random.randint(0, 1000) + scale_before = scale.detach().item() + scale_ref = scale_before + for i in range(hysteresis + extra_iters): + update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker, + found_inf, growth_factor, backoff_factor, growth_interval, hysteresis) + for i in range(1 + extra_iters): + # Scale is continuously backed off for each iteration with an inf + scale_new = scale_ref * backoff_factor + if isfinite(scale_new): + scale_ref = scale_new + else: + scale_ref = 0 # Scale update kernel does not check for underflow when backing off, which results in zero + self.assertTrue(scale.item() == scale_ref) + + # No infs for growth_interval iterations, scale should be increased + found_inf.fill_(0) + extra_iters = random.randint(0, 1000) + scale_before = scale.detach().item() + scale_ref = scale_before + for i in range(growth_interval + extra_iters): + update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker, + found_inf, growth_factor, backoff_factor, growth_interval, hysteresis) + for i in range(1 + int(math.floor(extra_iters / growth_interval))): + # 
Scale is grown every growth_interval iterations + scale_new = scale_ref * growth_factor + if isfinite(scale_new): + scale_ref = scale_new + self.assertTrue(scale.item() == scale_ref) + + + @unittest.skipIf(disabled, "amp_C is unavailable") + def test_fuzz(self): + init_scale_list = [1, 1024, 65536] + growth_factor_list = [1.0, 2.0, 4.0] + backoff_factor_list = [0.5, 0.25] + growth_interval_list = [10, 100] + hysteresis_list = [10, 100] + + for init_scale in init_scale_list: + for growth_factor in growth_factor_list: + for backoff_factor in backoff_factor_list: + for growth_interval in growth_interval_list: + for hysteresis in hysteresis_list: + self.update_scale_hysteresis_body(init_scale, growth_factor, + backoff_factor, growth_interval, hysteresis) + + + +if __name__ == '__main__': + unittest.main() From 28e09867543d2393dba7553891dff447e22bc2a5 Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Fri, 29 Sep 2023 16:55:17 -0700 Subject: [PATCH 6/7] Fix comment in unit test --- tests/L0/run_amp/test_update_scale_hysteresis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/L0/run_amp/test_update_scale_hysteresis.py b/tests/L0/run_amp/test_update_scale_hysteresis.py index 470974d53..c70680279 100644 --- a/tests/L0/run_amp/test_update_scale_hysteresis.py +++ b/tests/L0/run_amp/test_update_scale_hysteresis.py @@ -67,7 +67,7 @@ def update_scale_hysteresis_body(self, init_scale, growth_factor, backoff_factor scale_ref = 0 # Scale update kernel does not check for underflow when backing off, which results in zero self.assertTrue(scale.item() == scale_ref) - # No infs for growth_interval iterations, scale should be increased + # No infs for more than growth_interval iterations, scale should be increased found_inf.fill_(0) extra_iters = random.randint(0, 1000) scale_before = scale.detach().item() From 0992537e8038b3b3706abee2027a7a4713f91645 Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Fri, 29 Sep 2023 18:04:00 -0700 Subject: [PATCH 7/7] Remove unnecessary bits --- tests/L0/run_amp/test_update_scale_hysteresis.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/L0/run_amp/test_update_scale_hysteresis.py b/tests/L0/run_amp/test_update_scale_hysteresis.py index c70680279..6bb524003 100644 --- a/tests/L0/run_amp/test_update_scale_hysteresis.py +++ b/tests/L0/run_amp/test_update_scale_hysteresis.py @@ -2,11 +2,8 @@ import random import math -from apex import amp import torch -from utils import common_init - try: import amp_C from amp_C import update_scale_hysteresis @@ -21,7 +18,7 @@ def isfinite(val): class TestUpdateScaleHysteresis(unittest.TestCase): def setUp(self): - common_init(self) + pass def tearDown(self): pass
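
A note on the calling convention exercised by the test above: the sketch below shows how the fused update can be driven directly from Python once the amp_C extension is built. It mirrors the one-element CUDA tensors and the argument order used in tests/L0/run_amp/test_update_scale_hysteresis.py; everything else here (the helper name update_loss_scale, the initial scale, the growth/backoff/hysteresis constants, and how found_inf is populated) is illustrative only and not part of the patches in this series.

    import torch
    import amp_C

    # One-element CUDA tensors holding the scaler state, as in the unit test:
    # the current loss scale, the successful-step counter, and the hysteresis
    # counter that delays backoff until several overflowing steps accumulate.
    scale = torch.tensor([65536.0], dtype=torch.float32, device="cuda")
    growth_tracker = torch.tensor([0], dtype=torch.int32, device="cuda")
    hysteresis_tracker = torch.tensor([2], dtype=torch.int32, device="cuda")

    growth_factor, backoff_factor = 2.0, 0.5
    growth_interval, hysteresis = 2000, 2

    def update_loss_scale(found_inf: torch.Tensor) -> float:
        # found_inf: one-element float32 CUDA tensor, > 0 if any gradient
        # overflowed this step (how it is computed is up to the caller).
        amp_C.update_scale_hysteresis(
            scale, growth_tracker, hysteresis_tracker, found_inf,
            growth_factor, backoff_factor, growth_interval, hysteresis)
        return scale.item()

As the unit test checks, the scale is left unchanged for the first hysteresis-1 overflowing iterations, backed off once infs persist past the hysteresis budget, and grown by growth_factor after growth_interval inf-free iterations.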