4 changes: 3 additions & 1 deletion .gitignore
@@ -13,6 +13,8 @@ __pycache__/
/models
/notebooks
**/.ipynb_checkpoints/
.DS_Store

/3rdparty/NeMo/
/3rdparty/apex/
/3rdparty/apex/

19 changes: 16 additions & 3 deletions examples/cpp/llama/llama_example.cc
@@ -81,6 +81,7 @@ void llama_example(const INIReader reader)

int tensor_para_size = reader.GetInteger("ft_instance_hyperparameter", "tensor_para_size");
int pipeline_para_size = reader.GetInteger("ft_instance_hyperparameter", "pipeline_para_size");
int int8_mode = reader.GetInteger("ft_instance_hyperparameter", "int8_mode", 0);

const size_t head_num = reader.GetInteger(model_name, "head_num");
const size_t size_per_head = reader.GetInteger(model_name, "size_per_head");
@@ -177,6 +178,7 @@ void llama_example(const INIReader reader)
tiled_stop_words.insert(tiled_stop_words.end(), stop_words.begin(), stop_words.end());
}


int* d_stop_words = nullptr;
deviceMalloc(&d_stop_words, tiled_stop_words.size(), false);
cudaH2Dcpy(d_stop_words, tiled_stop_words.data(), tiled_stop_words.size());
@@ -193,6 +195,7 @@
1,
"../examples/cpp/llama/start_ids.csv");


int* d_input_ids;
int* d_input_lengths;
if (max_input_len == 0) {
@@ -285,6 +288,7 @@ void llama_example(const INIReader reader)
pipeline_para.world_size_,
pipeline_para.rank_,
use_gptj_residual,
int8_mode,
prompt_learning_type,
prefix_prompt_table_pair);

@@ -331,12 +335,19 @@ void llama_example(const INIReader reader)
&allocator,
false,
&prop,
attention_type);
attention_type,
int8_mode,
nullptr,
0,
1.0f);

int* d_output_ids;
int* d_sequence_lengths;


deviceMalloc(&d_output_ids, request_batch_size * beam_width * total_output_len, false);
deviceMalloc(&d_sequence_lengths, request_batch_size * beam_width, false);

std::vector<uint32_t> output_seq_len(request_batch_size, total_output_len);
std::unordered_map<std::string, Tensor> input_tensors = std::unordered_map<std::string, Tensor>{
{"input_ids",
@@ -411,15 +422,18 @@ void llama_example(const INIReader reader)
ite = 1;
ft_nvtx::setScope("warmup_time");
PUSH_RANGE("warmup time")

for (int i = 0; i < ite; ++i) {
gpt.forward(&output_tensors, &input_tensors, &gpt_weights);
}

cudaDeviceSynchronize();
mpi::barrier();

POP_RANGE;
ft_nvtx::resetScope();


if (rank == 0) {

std::string fName = "out";
@@ -430,6 +444,7 @@ void llama_example(const INIReader reader)
else {
size_t outCount = total_output_len * request_batch_size * beam_width;
int* hBuf = new int[outCount];

cudaD2Hcpy(hBuf, d_output_ids, outCount);

{
@@ -468,7 +483,6 @@ void llama_example(const INIReader reader)
for (int i = 0; i < ite; ++i) {
gpt.forward(&output_tensors, &input_tensors, &gpt_weights);
}

cudaDeviceSynchronize();
mpi::barrier();

@@ -509,6 +523,5 @@ void llama_example(const INIReader reader)
if (d_sequence_lengths != nullptr) {
deviceFree(d_sequence_lengths);
}

return;
}
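Reviewer note on this file, sketched under the assumption that the call site updated above resolves to the Llama<T> declarations shown later in this diff (Llama.h); the include path and helper name below are illustrative only, not part of the change.

#include "3rdparty/INIReader.h"  // illustrative include path; the example already depends on INIReader

// The new key is read with the defaulted GetInteger call used above, so configs
// written before this change still parse and fall back to int8_mode = 0 (int8 off).
static int read_int8_mode(const INIReader& reader)
{
    return reader.GetInteger("ft_instance_hyperparameter", "int8_mode", 0);
}

// The trailing literals added to the gpt constructor call above line up with the
// defaulted parameters declared in Llama.h further down in this diff:
//   attention_type -> AttentionType attention_type
//   int8_mode      -> int int8_mode                                           (new)
//   nullptr        -> std::shared_ptr<AbstractCustomComm> custom_all_reduce_comm
//   0              -> int enable_custom_all_reduce
//   1.0f           -> float shared_contexts_ratio                             (new)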
3 changes: 2 additions & 1 deletion src/fastertransformer/layers/FfnLayer.cc
@@ -684,6 +684,7 @@ SiluFfnLayer<T>::SiluFfnLayer(size_t max_batch_size,
IAllocator* allocator,
bool is_free_buffer_after_forward,
bool sparse,
int int8_mode,
bool use_gated_activation):
FfnLayer<T>(max_batch_size,
max_seq_len,
@@ -696,7 +697,7 @@ SiluFfnLayer<T>::SiluFfnLayer(size_t max_batch_size,
allocator,
is_free_buffer_after_forward,
sparse,
0,
int8_mode,
use_gated_activation)
{
}
1 change: 1 addition & 0 deletions src/fastertransformer/layers/FfnLayer.h
@@ -210,6 +210,7 @@ class SiluFfnLayer: public FfnLayer<T> {
IAllocator* allocator,
bool is_free_buffer_after_forward,
bool sparse = false,
int int8_mode = 0,
bool use_gated_activation = false);

SiluFfnLayer(SiluFfnLayer<T> const& ffn_layer);
4 changes: 3 additions & 1 deletion src/fastertransformer/layers/TensorParallelSiluFfnLayer.cc
@@ -76,7 +76,8 @@ TensorParallelSiluFfnLayer<T>::TensorParallelSiluFfnLayer(size_t max_b
bool is_sparse,
bool use_gated_activation,
std::shared_ptr<AbstractCustomComm> custom_all_reduce_comm,
int enable_custom_all_reduce):
int enable_custom_all_reduce,
int int8_mode):
SiluFfnLayer<T>(max_batch_size,
max_seq_len,
head_num,
@@ -88,6 +89,7 @@ TensorParallelSiluFfnLayer<T>::TensorParallelSiluFfnLayer(size_t max_b
allocator,
is_free_buffer_after_forward,
is_sparse,
int8_mode,
use_gated_activation),
tensor_para_(tensor_para),
custom_all_reduce_comm_(custom_all_reduce_comm),
3 changes: 2 additions & 1 deletion src/fastertransformer/layers/TensorParallelSiluFfnLayer.h
@@ -47,7 +47,8 @@ class TensorParallelSiluFfnLayer: public SiluFfnLayer<T> {
bool is_sparse,
bool use_gated_activation = false,
std::shared_ptr<AbstractCustomComm> custom_all_reduce_comm = nullptr,
int enable_custom_all_reduce = 0);
int enable_custom_all_reduce = 0,
int int8_mode = 0);

TensorParallelSiluFfnLayer(TensorParallelSiluFfnLayer<T> const& ffn_layer);

@@ -88,7 +88,8 @@ LinearAdapterLayer<T>::LinearAdapterLayer(LinearAdapterConfig const& co
is_sparse,
false,
custom_all_reduce_comm,
enable_custom_all_reduce)},
enable_custom_all_reduce,
0)},
layer_norm_type_{config.layerNormType()},
layer_norm_eps_{layer_norm_eps},
max_token_size_{max_batch_size * max_seq_len},
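Side note on the FFN signature changes: TensorParallelSiluFfnLayer appends int8_mode after enable_custom_all_reduce with a default of 0, so positional callers that stop early keep compiling, while a caller that already spells out every argument, such as the LinearAdapterLayer constructor in the hunk above, appends an explicit 0, presumably to pin the adapter FFN to the non-int8 path. A self-contained toy illustration of that pattern (the struct and names below are invented for illustration and are not FasterTransformer code):

#include <cassert>

// A defaulted trailing parameter added to an existing constructor: old positional
// call sites keep compiling, fully-specified call sites append the new argument.
struct FfnLike {
    FfnLike(int head_num, int enable_custom_all_reduce = 0, int int8_mode = 0):
        head_num(head_num), enable_custom_all_reduce(enable_custom_all_reduce), int8_mode(int8_mode)
    {
    }
    int head_num;
    int enable_custom_all_reduce;
    int int8_mode;
};

int main()
{
    FfnLike legacy_caller(32);          // pre-existing call site: int8_mode silently defaults to 0
    FfnLike explicit_caller(32, 0, 0);  // caller that passes every argument, like LinearAdapterLayer
    assert(legacy_caller.int8_mode == explicit_caller.int8_mode);
    return 0;
}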
66 changes: 59 additions & 7 deletions src/fastertransformer/models/llama/Llama.cc
@@ -42,6 +42,7 @@ void Llama<T>::initialize()
is_free_buffer_after_forward_,
is_context_qk_buf_float_,
attention_type_,
int8_mode_,
custom_all_reduce_comm_,
enable_custom_all_reduce_);

@@ -59,6 +60,7 @@ void Llama<T>::initialize()
cublas_wrapper_,
allocator_,
is_free_buffer_after_forward_,
int8_mode_,
custom_all_reduce_comm_,
enable_custom_all_reduce_);

@@ -165,6 +167,13 @@ void Llama<T>::allocateBuffer(

generation_should_stop_ = (bool*)allocator_->reMalloc(generation_should_stop_, sizeof(bool), true, true);

if (shared_contexts_ratio_ > 0.0f) {
shared_contexts_idx_ = (int*)allocator_->reMalloc(shared_contexts_idx_, batch_size * sizeof(int), false);
batch_to_compact_idx_ = (int*)allocator_->reMalloc(batch_to_compact_idx_, batchxbeam * sizeof(int), false);
compact_idx_ = (int*)allocator_->reMalloc(compact_idx_, batch_size * sizeof(int), false);
compact_size_ = (int*)allocator_->reMalloc(compact_size_, sizeof(int), false);
}

is_allocate_buffer_ = true;
}

@@ -216,6 +225,11 @@ void Llama<T>::freeBuffer()

allocator_->free((void**)(&generation_should_stop_), true);

if (shared_contexts_ratio_ > 0.0f) {
allocator_->free((void**)(&shared_contexts_idx_));
allocator_->free((void**)(&compact_size_));
}

is_allocate_buffer_ = false;
}
}
@@ -246,8 +260,10 @@ Llama<T>::Llama(size_t head_num,
bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop,
AttentionType attention_type,
int int8_mode,
std::shared_ptr<AbstractCustomComm> custom_all_reduce_comm,
int enable_custom_all_reduce):
int enable_custom_all_reduce,
float shared_contexts_ratio):
BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop),
head_num_(head_num),
size_per_head_(size_per_head),
@@ -263,7 +279,9 @@ Llama<T>::Llama(size_t head_num,
use_gptj_residual_(use_gptj_residual),
hidden_units_(head_num * size_per_head),
local_head_num_(head_num / 1),
attention_type_(attention_type)
attention_type_(attention_type),
int8_mode_(int8_mode),
shared_contexts_ratio_(shared_contexts_ratio)
{
tensor_para_.world_size_ = 1;
tensor_para_.rank_ = 0;
@@ -310,8 +328,10 @@ Llama<T>::Llama(size_t head_num,
bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop,
AttentionType attention_type,
int int8_mode,
std::shared_ptr<AbstractCustomComm> custom_all_reduce_comm,
int enable_custom_all_reduce):
int enable_custom_all_reduce,
float shared_contexts_ratio):
BaseLayer(stream, cublas_wrapper, allocator, is_free_buffer_after_forward, cuda_device_prop),
head_num_(head_num),
size_per_head_(size_per_head),
@@ -331,7 +351,9 @@ Llama<T>::Llama(size_t head_num,
local_head_num_(head_num / tensor_para.world_size_),
custom_all_reduce_comm_(custom_all_reduce_comm),
enable_custom_all_reduce_(enable_custom_all_reduce),
attention_type_(attention_type)
attention_type_(attention_type),
int8_mode_(int8_mode),
shared_contexts_ratio_(shared_contexts_ratio)
{
int local_vacab_size = ceil(vocab_size_ / 1.f / tensor_para_.world_size_);
if (std::is_same<half, T>::value) {
@@ -363,7 +385,9 @@ Llama<T>::Llama(Llama<T> const& gpt):
vocab_size_padded_(gpt.vocab_size_padded_),
custom_all_reduce_comm_(gpt.custom_all_reduce_comm_),
enable_custom_all_reduce_(gpt.enable_custom_all_reduce_),
attention_type_(gpt.attention_type_)
attention_type_(gpt.attention_type_),
int8_mode_(gpt.int8_mode_),
shared_contexts_ratio_(gpt.shared_contexts_ratio_)
{
initialize();
}
@@ -585,6 +609,23 @@ void Llama<T>::forward(std::unordered_map<std::string, Tensor>* output_ten
cudaMemsetAsync(cache_indirections_[0], 0, 2 * sizeof(int) * batch_size * beam_width * max_seq_len, stream_);
}

int compact_size;
bool use_shared_contexts = (shared_contexts_ratio_ > 0.0f) && (max_input_length >= 1) && (batch_size > 1);
if (use_shared_contexts) {
invokeFindContextDups(shared_contexts_idx_,
batch_to_compact_idx_,
compact_idx_,
compact_size_,
input_tensors->at("input_ids").getPtr<int>(),
batch_size,
beam_width,
max_input_length,
stream_);
cudaD2Hcpy(&compact_size, compact_size_, 1);
use_shared_contexts = compact_size <= shared_contexts_ratio_ * batch_size;
sync_check_cuda_error();
}

// Prefix prompts
if (has_prefix_prompt_) {
cudaMemcpyAsync(prompt_learning_weight_batch_,
@@ -686,6 +727,14 @@ void Llama<T>::forward(std::unordered_map<std::string, Tensor>* output_ten
{batch_size * beam_width},
has_prefix_prompt_ ? tiled_prompt_lengths_buf_ : nullptr}}};

if (use_shared_contexts) {
decoder_input_tensors.insert(
{"compact_idx", Tensor(MEMORY_GPU, TYPE_INT32, {(size_t)compact_size}, compact_idx_)});
decoder_input_tensors.insert(
{"batch_to_compact_idx",
Tensor(MEMORY_GPU, TYPE_INT32, {batch_size * beam_width}, batch_to_compact_idx_)});
}

std::unordered_map<std::string, Tensor> decoder_output_tensors{
{"decoder_output",
Tensor{MEMORY_GPU,
@@ -877,6 +926,7 @@ void Llama<T>::forward(std::unordered_map<std::string, Tensor>* output_ten
stream_);
sync_check_cuda_error();


if (tensor_para_.world_size_ == 1) {
float alpha = 1.0f;
float beta = 0.0f;
@@ -924,6 +974,8 @@ void Llama<T>::forward(std::unordered_map<std::string, Tensor>* output_ten
local_vocab_size, /* n */
CUDA_R_32F,
cublasGemmAlgo_t(-1));


ftNcclAllGather(nccl_logits_buf_ + vocab_size_units_offset,
nccl_logits_buf_ + vocab_size_units_offset,
local_batch_size * beam_width * local_vocab_size,
@@ -937,7 +989,8 @@ void Llama<T>::forward(std::unordered_map<std::string, Tensor>* output_ten
local_vocab_size,
stream_);
}



int tmp_local_batch_size = local_batch_size;
bool is_initialize_random_table = step == max_input_length;
std::unordered_map<std::string, Tensor> dynamic_decode_input_tensors{
@@ -1229,5 +1282,4 @@ template class Llama<half>;
#ifdef ENABLE_BF16
template class Llama<__nv_bfloat16>;
#endif

} // namespace fastertransformer
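For context on the shared-contexts hunks above: invokeFindContextDups is implemented elsewhere (its kernel is not part of this diff) and fills the four new index buffers so that identical prompts in a batch can share a single context pass. The CPU sketch below only illustrates the indexing contract as these buffers appear to be used here, assuming row equality over the padded input ids; the device kernel may differ in detail.

#include <string>
#include <unordered_map>
#include <vector>

// Illustrative CPU analogue of the index layout consumed by the compacted context
// path: shared_contexts_idx maps each batch row to its compact slot, compact_idx
// maps each compact slot back to a representative batch row, batch_to_compact_idx
// expands the mapping across beams, and compact_size counts the unique prompts.
void find_context_dups_cpu(const int* input_ids,  // [batch_size, max_input_length], padded
                           int batch_size,
                           int beam_width,
                           int max_input_length,
                           std::vector<int>& shared_contexts_idx,   // [batch_size]
                           std::vector<int>& compact_idx,           // [compact_size]
                           std::vector<int>& batch_to_compact_idx,  // [batch_size * beam_width]
                           int& compact_size)
{
    std::unordered_map<std::string, int> first_seen;  // prompt bytes -> compact slot
    shared_contexts_idx.assign(batch_size, 0);
    compact_idx.clear();
    for (int b = 0; b < batch_size; ++b) {
        std::string key(reinterpret_cast<const char*>(input_ids + b * max_input_length),
                        sizeof(int) * max_input_length);
        auto it = first_seen.find(key);
        if (it == first_seen.end()) {
            it = first_seen.emplace(key, static_cast<int>(compact_idx.size())).first;
            compact_idx.push_back(b);  // first occurrence becomes the representative row
        }
        shared_contexts_idx[b] = it->second;
    }
    compact_size = static_cast<int>(compact_idx.size());
    batch_to_compact_idx.assign(batch_size * beam_width, 0);
    for (int b = 0; b < batch_size; ++b) {
        for (int w = 0; w < beam_width; ++w) {
            batch_to_compact_idx[b * beam_width + w] = shared_contexts_idx[b];
        }
    }
}

After the kernel runs, forward() copies compact_size back to the host and only keeps the compacted path when compact_size <= shared_contexts_ratio_ * batch_size, which is exactly the check added in the hunk above.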
15 changes: 13 additions & 2 deletions src/fastertransformer/models/llama/Llama.h
@@ -41,6 +41,7 @@ class Llama: public BaseLayer {
float layernorm_eps_;

static constexpr bool neox_rotary_style_ = true;
float shared_contexts_ratio_;

int start_id_;
int end_id_;
@@ -54,6 +55,7 @@ class Llama: public BaseLayer {
int enable_custom_all_reduce_;

AttentionType attention_type_;
const int int8_mode_ = 0;

size_t vocab_size_padded_;
const bool is_context_qk_buf_float_ =
@@ -120,6 +122,11 @@ class Llama: public BaseLayer {

bool* generation_should_stop_ = nullptr;

int* shared_contexts_idx_ = nullptr;
int* compact_idx_ = nullptr;
int* batch_to_compact_idx_ = nullptr;
int* compact_size_ = nullptr;

T* context_decoder_input_buf_;
T* context_decoder_output_buf_;
float* output_log_probs_buf_;
@@ -165,8 +172,10 @@ class Llama: public BaseLayer {
bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop = nullptr,
AttentionType attention_type = AttentionType::UNFUSED_MHA,
int int8_mode = 0,
std::shared_ptr<AbstractCustomComm> custom_all_reduce_comm = nullptr,
int enable_custom_all_reduce = 0);
int enable_custom_all_reduce = 0,
float shared_contexts_ratio = 1.0f);

Llama(size_t head_num,
size_t size_per_head,
@@ -195,8 +204,10 @@ class Llama: public BaseLayer {
bool is_free_buffer_after_forward,
cudaDeviceProp* cuda_device_prop = nullptr,
AttentionType attention_type = AttentionType::UNFUSED_MHA,
int int8_mode = 0,
std::shared_ptr<AbstractCustomComm> custom_all_reduce_comm = nullptr,
int enable_custom_all_reduce = 0);
int enable_custom_all_reduce = 0,
float shared_contexts_ratio = 1.0f);

Llama(Llama<T> const& Llama);
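Reading the new defaults together with the forward() changes earlier in this diff: int8_mode = 0 keeps the int8 paths disabled, and with the default shared_contexts_ratio = 1.0f the compacted context pass is kept for any multi-sequence batch with non-empty inputs, since the deduplicated count can never exceed the batch size. A condensed restatement of that gate, with the same conditions as the forward() hunk pulled into one function for readability (the function name is mine, not the library's):

// First clause decides whether invokeFindContextDups runs at all; second clause
// keeps the compacted path only if enough prompts collapsed into shared contexts.
bool should_use_shared_contexts(float shared_contexts_ratio,
                                int   max_input_length,
                                int   batch_size,
                                int   compact_size)
{
    const bool candidate = (shared_contexts_ratio > 0.0f) && (max_input_length >= 1) && (batch_size > 1);
    return candidate && (compact_size <= shared_contexts_ratio * batch_size);
}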
