From 455ec7226d3f58ac622b10a694156f2aff482022 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Tue, 2 Mar 2021 23:03:55 +0000 Subject: [PATCH 1/3] fixing buffers in transformer kernel when gelu-checkpoint is enabled --- csrc/transformer/ds_transformer_cuda.cpp | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) mode change 100755 => 100644 csrc/transformer/ds_transformer_cuda.cpp diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp old mode 100755 new mode 100644 index b4e28915011b..2e0962a53227 --- a/csrc/transformer/ds_transformer_cuda.cpp +++ b/csrc/transformer/ds_transformer_cuda.cpp @@ -31,7 +31,8 @@ size_t get_workspace_size(int maxBatchSize, if (training) { workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size), 2 * (size_t(maxBatchSize) * heads * seq_len * seq_len))); - if (gelu_checkpoint) workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size); + if (gelu_checkpoint) + workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size); } return workSpacesize; // * sizeof(T); } @@ -178,9 +179,12 @@ void BertTransformerLayer::Forward(int bsz, size_t small_buf_size = bsz * _seq_length * _hidden_size; T* buf_0 = workspace; T* buf_1 = buf_0 + small_buf_size; + T* buf_2 = buf_1; if (_normalize_invertible) add_res_ptr = buf_1 + 3 * small_buf_size; - if (_attn_dropout_checkpoint) ctx_bufB_ptr = buf_1 + 4 * small_buf_size; + if (_gelu_checkpoint) buf_2 += small_buf_size; + if (_attn_dropout_checkpoint) + ctx_bufB_ptr = buf_2 + (_intermediate_size / _hidden_size) * small_buf_size; int bsz_seq = bsz * _seq_length; @@ -257,14 +261,11 @@ void BertTransformerLayer::Forward(int bsz, _gelu.ForwardWithBiasAdd(bsz_seq, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, - (_gelu_checkpoint ? ctx_bufB_ptr : ff2_inp_ptr), + (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), _stream); - _ff2.Forward(bsz_seq, - (_gelu_checkpoint ? ctx_bufB_ptr : ff2_inp_ptr), - output_w_ptr, - out_ptr, - _cublasHandle); + _ff2.Forward( + bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _cublasHandle); // layer output dropout. if (_pre_or_postLayerNorm) @@ -336,7 +337,7 @@ void BertTransformerLayer::Backward(int bsz, T* buf_2 = buf_1 + small_buf_size; T* buf_3 = buf_2 + small_buf_size; - T* ff2_buf = (_gelu_checkpoint ? buf_2 + (bsz * _seq_length * _intermediate_size) + T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size) : buf_3 + small_buf_size); T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads); From 9e5ca614c427f1d2e6589a16adf488b1a770cd34 Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Tue, 2 Mar 2021 23:25:27 +0000 Subject: [PATCH 2/3] fixing the test issue for other memory optimization flags --- csrc/transformer/ds_transformer_cuda.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) mode change 100644 => 100755 csrc/transformer/ds_transformer_cuda.cpp diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp old mode 100644 new mode 100755 index 2e0962a53227..d791d4d3a07a --- a/csrc/transformer/ds_transformer_cuda.cpp +++ b/csrc/transformer/ds_transformer_cuda.cpp @@ -181,7 +181,10 @@ void BertTransformerLayer::Forward(int bsz, T* buf_1 = buf_0 + small_buf_size; T* buf_2 = buf_1; - if (_normalize_invertible) add_res_ptr = buf_1 + 3 * small_buf_size; + if (_normalize_invertible) { + add_res_ptr = buf_1 + 3 * small_buf_size; + buf_2 = add_res_ptr; + } if (_gelu_checkpoint) buf_2 += small_buf_size; if (_attn_dropout_checkpoint) ctx_bufB_ptr = buf_2 + (_intermediate_size / _hidden_size) * small_buf_size; From 83acfad85cec212eba62e1d0dfc52fd4755b040c Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Tue, 2 Mar 2021 23:53:04 +0000 Subject: [PATCH 3/3] fixing a bug for when attn_dropout_checkpoint is enabled --- csrc/transformer/ds_transformer_cuda.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) mode change 100755 => 100644 csrc/transformer/ds_transformer_cuda.cpp diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp old mode 100755 new mode 100644 index d791d4d3a07a..8e605d7558d5 --- a/csrc/transformer/ds_transformer_cuda.cpp +++ b/csrc/transformer/ds_transformer_cuda.cpp @@ -187,7 +187,9 @@ void BertTransformerLayer::Forward(int bsz, } if (_gelu_checkpoint) buf_2 += small_buf_size; if (_attn_dropout_checkpoint) - ctx_bufB_ptr = buf_2 + (_intermediate_size / _hidden_size) * small_buf_size; + ctx_bufB_ptr = + (_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size) + : (buf_1 + 4 * small_buf_size)); int bsz_seq = bsz * _seq_length;