31 changes: 19 additions & 12 deletions colossalai/kernel/cuda_native/csrc/multihead_attention_1d.h
@@ -19,21 +19,25 @@
 template <typename T>
 class MultiHeadAttention {
  public:
-  MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len, int hidden_size,
-                     int num_heads, float attn_dropout_ratio, float hidden_output_dropout_ratio,
+  MultiHeadAttention(int layer_id, int max_batch_tokens, int _max_seq_len,
+                     int hidden_size, int num_heads, float attn_dropout_ratio,
+                     float hidden_output_dropout_ratio,
                      bool pre_or_postLayerNorm);
 
   virtual ~MultiHeadAttention();
 
   void Forward(const T *input_ptr, const T *input_mask_ptr, T *out_ptr);
 
-  void Backward(const T *grad_output_ptr, const T *input_ptr, const T *output_ptr,
-                const T *input_mask_ptr, T *grad_input_ptr);
+  void Backward(const T *grad_output_ptr, const T *input_ptr,
+                const T *output_ptr, const T *input_mask_ptr,
+                T *grad_input_ptr);
 
-  void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr, T *buffer);
+  void attn_layer_fw(const T *input_ptr, const T *input_mask_ptr, T *output_ptr,
+                     T *buffer);
 
-  void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr, const T *output_ptr,
-                     const T *grad_output_ptr, T *grad_input_attn_layer_bwptr, T *buffer);
+  void attn_layer_bw(const T *input_ptr, const T *input_mask_ptr,
+                     const T *output_ptr, const T *grad_output_ptr,
+                     T *grad_input_attn_layer_bwptr, T *buffer);
 
   void set_cur_batch_shape(int batch_size, int seq_len) {
     _batch_size = batch_size;
@@ -83,14 +87,17 @@ class MultiHeadAttention {
     }
 
     _qkv_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size * 3);
-    _soft_out_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
-    _ctx_bufB_ptr = cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
+    _soft_out_ptr =
+        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
+    _ctx_bufB_ptr =
+        cuda_malloc<T>(_max_batch_tokens * _heads / pg_size * _max_seq_len);
     _attn_o_inp_ptr = cuda_malloc<T>(_max_batch_tokens * _hidden_size);
 
     // buffer size needed by attn bw
-    size_t smem_size = 4 * _max_batch_tokens * _hidden_size / pg_size +
-                       std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
-                                _max_batch_tokens * _heads / pg_size * _max_seq_len);
+    size_t smem_size =
+        4 * _max_batch_tokens * _hidden_size / pg_size +
+        std::max(3 * _max_batch_tokens * _hidden_size / pg_size,
+                 _max_batch_tokens * _heads / pg_size * _max_seq_len);
 
     if (!_shared_mem_ptr) {
       cuda_free(_shared_mem_ptr);
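For orientation, here is a minimal usage sketch of the interface declared in this header: construct a layer, call set_cur_batch_shape for the current batch, then run Forward on device buffers. This is a sketch under stated assumptions, not taken from the PR: the parameter values, the [batch_size, seq_len] float mask layout, and the bare cudaMalloc setup are illustrative, and the full library also expects stream/handle wiring that this excerpt does not show.

// Minimal usage sketch of MultiHeadAttention<T> as declared above.
// Assumes a CUDA device is available; values below are examples only.
#include <cuda_runtime.h>

#include "multihead_attention_1d.h"

int main() {
  const int layer_id = 0;
  const int max_batch_tokens = 4096;  // assumed capacity: max tokens per batch
  const int max_seq_len = 512;
  const int hidden_size = 1024;
  const int num_heads = 16;

  // Construct one attention layer; dropout ratios and pre-LayerNorm flag
  // are example values, not defaults from this header.
  MultiHeadAttention<float> attn(layer_id, max_batch_tokens, max_seq_len,
                                 hidden_size, num_heads,
                                 /*attn_dropout_ratio=*/0.1f,
                                 /*hidden_output_dropout_ratio=*/0.1f,
                                 /*pre_or_postLayerNorm=*/true);

  // The layer must know the shape of the current batch before running.
  const int batch_size = 8;
  const int seq_len = 128;
  attn.set_cur_batch_shape(batch_size, seq_len);

  // Device buffers: input/output are [batch_size, seq_len, hidden_size];
  // the mask layout is an assumption, since the header does not spell it out.
  size_t n = static_cast<size_t>(batch_size) * seq_len * hidden_size;
  float *input = nullptr, *mask = nullptr, *output = nullptr;
  cudaMalloc(&input, n * sizeof(float));
  cudaMalloc(&mask, static_cast<size_t>(batch_size) * seq_len * sizeof(float));
  cudaMalloc(&output, n * sizeof(float));

  attn.Forward(input, mask, output);

  cudaFree(input);
  cudaFree(mask);
  cudaFree(output);
  return 0;
}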