Skip to content

Commit f34c614

Browse files
committed
Update attention_unfused required memory size
1 parent 8748c95 commit f34c614

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

csrc/transformer/inference/csrc/pt_binding.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -354,10 +354,7 @@ void attention_unfused(T* prev_key_cont,
354354
float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0;
355355
float alpha = norm_factor * norm_factor / layer_scale;
356356
float gemm_beta = 0.0;
357-
// Always use the tail workspace
358-
T* scratch = (T*)Context::Instance().GetWorkSpace();
359-
T *workspace = scratch + ((Context::Instance().get_workspace_size() / sizeof(T)) -
360-
bsz * heads * seq_len * soft_len);
357+
T* workspace = (T*)Context::Instance().GetAttentionUnfusedWorkspace();
361358

362359
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
363360
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),

csrc/transformer/inference/includes/inference_context.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ class Context {
5555
_curr_offset(0),
5656
_stream(0),
5757
_free_memory_size(0),
58-
_num_tokens(1)
58+
_num_tokens(1),
59+
_attention_unfused_workspace_offset(0)
5960
{
6061
if (cublasCreate(&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
6162
auto message = std::string("Fail to create cublas handle.");
@@ -101,7 +102,7 @@ class Context {
101102
if (!_free_memory_size) { cudaMemGetInfo(&_free_memory_size, &total_size); }
102103

103104
size_t activation_size = 16 * hidden_dim * batch_size;
104-
size_t temp_size = batch_size * num_heads * prompt_len * prompt_len * elem_size;
105+
size_t temp_size = batch_size * num_heads * prompt_len * prompt_len * elem_size / mp_size;
105106
size_t cache_size = num_layers * batch_size * (hidden_dim / mp_size) * 2;
106107
size_t minimal_requirements = temp_size + (_free_memory_size > GIGABYTE ? 500 : 100) * MEGABYTE;
107108
if (_free_memory_size < minimal_requirements) {
@@ -139,13 +140,15 @@ class Context {
139140
throw std::runtime_error("Workspace is null.");
140141
}
141142
_workSpaceSize = workSpaceSize;
143+
_attention_unfused_workspace_offset = workSpaceSize - temp_size;
142144
}
143145
inline size_t GetMaxTokenLenght() const { return _max_seq_len; }
144146

145147
cudaEvent_t GetCompEvent(int id) { return id == 1 ? _comp1_event : _comp2_event; }
146148

147149
size_t get_workspace_size() const { return _workSpaceSize; }
148150
void* GetWorkSpace() { return _workspace; }
151+
void* GetAttentionUnfusedWorkspace() { return _workspace + _attention_unfused_workspace_offset; }
149152

150153
inline unsigned new_token(unsigned layer_id)
151154
{
@@ -211,6 +214,8 @@ class Context {
211214
cudaEvent_t _comm_event;
212215

213216
void* _workspace;
217+
// offset from _workspace for attention unfused memory
218+
size_t _attention_unfused_workspace_offset;
214219
uint64_t _seed;
215220
uint64_t _curr_offset;
216221

0 commit comments

Comments
 (0)