@@ -55,7 +55,8 @@ class Context {
5555 _curr_offset(0 ),
5656 _stream(0 ),
5757 _free_memory_size(0 ),
58- _num_tokens(1 )
58+ _num_tokens(1 ),
59+ _attention_unfused_workspace_offset(0 )
5960 {
6061 if (cublasCreate (&_cublasHandle) != CUBLAS_STATUS_SUCCESS) {
6162 auto message = std::string (" Fail to create cublas handle." );
@@ -101,7 +102,7 @@ class Context {
101102 if (!_free_memory_size) { cudaMemGetInfo (&_free_memory_size, &total_size); }
102103
103104 size_t activation_size = 16 * hidden_dim * batch_size;
104- size_t temp_size = batch_size * num_heads * prompt_len * prompt_len * elem_size;
105+ size_t temp_size = batch_size * num_heads * prompt_len * prompt_len * elem_size / mp_size ;
105106 size_t cache_size = num_layers * batch_size * (hidden_dim / mp_size) * 2 ;
106107 size_t minimal_requirements = temp_size + (_free_memory_size > GIGABYTE ? 500 : 100 ) * MEGABYTE;
107108 if (_free_memory_size < minimal_requirements) {
@@ -139,13 +140,15 @@ class Context {
139140 throw std::runtime_error (" Workspace is null." );
140141 }
141142 _workSpaceSize = workSpaceSize;
143+ _attention_unfused_workspace_offset = workSpaceSize - temp_size;
142144 }
143145 inline size_t GetMaxTokenLenght () const { return _max_seq_len; }
144146
145147 cudaEvent_t GetCompEvent (int id) { return id == 1 ? _comp1_event : _comp2_event; }
146148
147149 size_t get_workspace_size () const { return _workSpaceSize; }
148150 void * GetWorkSpace () { return _workspace; }
151+ void * GetAttentionUnfusedWorkspace () { return _workspace + _attention_unfused_workspace_offset; }
149152
150153 inline unsigned new_token (unsigned layer_id)
151154 {
@@ -211,6 +214,8 @@ class Context {
211214 cudaEvent_t _comm_event;
212215
213216 void * _workspace;
217+ // offset from _workspace for attention unfused memory
218+ size_t _attention_unfused_workspace_offset;
214219 uint64_t _seed;
215220 uint64_t _curr_offset;
216221
0 commit comments