diff --git a/csrc/transformer/ds_transformer_cuda.cpp b/csrc/transformer/ds_transformer_cuda.cpp old mode 100755 new mode 100644 index b4e28915011b..8e605d7558d5 --- a/csrc/transformer/ds_transformer_cuda.cpp +++ b/csrc/transformer/ds_transformer_cuda.cpp @@ -31,7 +31,8 @@ size_t get_workspace_size(int maxBatchSize, if (training) { workSpacesize += ((std::max)((size_t(maxBatchSize) * seq_len * intermediate_size), 2 * (size_t(maxBatchSize) * heads * seq_len * seq_len))); - if (gelu_checkpoint) workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * hidden_size); + if (gelu_checkpoint) + workSpacesize += 2 * (size_t(maxBatchSize) * seq_len * intermediate_size); } return workSpacesize; // * sizeof(T); } @@ -178,9 +179,17 @@ void BertTransformerLayer::Forward(int bsz, size_t small_buf_size = bsz * _seq_length * _hidden_size; T* buf_0 = workspace; T* buf_1 = buf_0 + small_buf_size; + T* buf_2 = buf_1; - if (_normalize_invertible) add_res_ptr = buf_1 + 3 * small_buf_size; - if (_attn_dropout_checkpoint) ctx_bufB_ptr = buf_1 + 4 * small_buf_size; + if (_normalize_invertible) { + add_res_ptr = buf_1 + 3 * small_buf_size; + buf_2 = add_res_ptr; + } + if (_gelu_checkpoint) buf_2 += small_buf_size; + if (_attn_dropout_checkpoint) + ctx_bufB_ptr = + (_gelu_checkpoint ? (buf_2 + (_intermediate_size / _hidden_size) * small_buf_size) + : (buf_1 + 4 * small_buf_size)); int bsz_seq = bsz * _seq_length; @@ -257,14 +266,11 @@ void BertTransformerLayer::Forward(int bsz, _gelu.ForwardWithBiasAdd(bsz_seq, (_gelu_checkpoint ? ff2_inp_ptr : gelu_inp_ptr), inter_b_ptr, - (_gelu_checkpoint ? ctx_bufB_ptr : ff2_inp_ptr), + (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), _stream); - _ff2.Forward(bsz_seq, - (_gelu_checkpoint ? ctx_bufB_ptr : ff2_inp_ptr), - output_w_ptr, - out_ptr, - _cublasHandle); + _ff2.Forward( + bsz_seq, (_gelu_checkpoint ? buf_2 : ff2_inp_ptr), output_w_ptr, out_ptr, _cublasHandle); // layer output dropout. if (_pre_or_postLayerNorm) @@ -336,7 +342,7 @@ void BertTransformerLayer::Backward(int bsz, T* buf_2 = buf_1 + small_buf_size; T* buf_3 = buf_2 + small_buf_size; - T* ff2_buf = (_gelu_checkpoint ? buf_2 + (bsz * _seq_length * _intermediate_size) + T* ff2_buf = (_gelu_checkpoint ? buf_3 + (bsz * _seq_length * _intermediate_size) : buf_3 + small_buf_size); T* ctx_bufB_ptr_recomp = ff2_buf + (_seq_length * _seq_length * bsz * _heads);