From 99a844cc7b434e1175277c7e438cbf1c348dca89 Mon Sep 17 00:00:00 2001
From: yandai
Date: Sun, 9 Jul 2023 00:47:58 +0800
Subject: [PATCH 1/2] fix llama

---
 src/fastertransformer/models/llama/Llama.cc | 17 ++++++++++++-----
 .../models/llama/LlamaWeight.cc             |  2 ++
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/src/fastertransformer/models/llama/Llama.cc b/src/fastertransformer/models/llama/Llama.cc
index 1eac9fd20..3c1bfc9e3 100644
--- a/src/fastertransformer/models/llama/Llama.cc
+++ b/src/fastertransformer/models/llama/Llama.cc
@@ -820,11 +820,18 @@ void Llama<T>::forward(std::unordered_map<std::string, Tensor>* output_ten
                         sizeof(T) * vocab_size_ * hidden_units_,
                         cudaMemcpyDeviceToDevice,
                         stream_);
-        cudaMemcpyAsync(padded_embedding_bias_,
-                        gpt_weights->post_decoder_embedding.bias,
-                        sizeof(T) * vocab_size_,
-                        cudaMemcpyDeviceToDevice,
-                        stream_);
+        if (gpt_weights->post_decoder_embedding.bias) {
+            cudaMemcpyAsync(padded_embedding_bias_,
+                            gpt_weights->post_decoder_embedding.bias,
+                            sizeof(T) * vocab_size_,
+                            cudaMemcpyDeviceToDevice,
+                            stream_);
+        } else {
+            cudaMemsetAsync(padded_embedding_bias_,
+                            0,
+                            sizeof(T) * vocab_size_,
+                            stream_);
+        }
         sync_check_cuda_error();
     }

diff --git a/src/fastertransformer/models/llama/LlamaWeight.cc b/src/fastertransformer/models/llama/LlamaWeight.cc
index e9e11b6a1..84c4aaa08 100644
--- a/src/fastertransformer/models/llama/LlamaWeight.cc
+++ b/src/fastertransformer/models/llama/LlamaWeight.cc
@@ -89,6 +89,7 @@ LlamaWeight<T>::~LlamaWeight()
         post_decoder_layernorm.beta   = nullptr;
         post_decoder_layernorm.gamma  = nullptr;
         post_decoder_embedding.kernel = nullptr;
+        post_decoder_embedding.bias   = nullptr;
         is_maintain_buffer            = false;
     }
 }
@@ -196,6 +197,7 @@ void LlamaWeight<T>::setWeightPtr()
     post_decoder_layernorm.beta   = weights_ptr[1];
     post_decoder_layernorm.gamma  = weights_ptr[2];
     post_decoder_embedding.kernel = weights_ptr[3];
+    post_decoder_embedding.bias   = nullptr;

     // prompt learning tables: set weight ptr
     if (malloc_load_prompt_weights_) {

From f1dd8fb4bd4952eadef90585daf6d5d1705c84ac Mon Sep 17 00:00:00 2001
From: yandai
Date: Sun, 16 Jul 2023 17:09:38 +0800
Subject: [PATCH 2/2] remove padded_embedding_bias_

---
 src/fastertransformer/models/llama/Llama.cc | 16 ----------------
 src/fastertransformer/models/llama/Llama.h  |  1 -
 2 files changed, 17 deletions(-)

diff --git a/src/fastertransformer/models/llama/Llama.cc b/src/fastertransformer/models/llama/Llama.cc
index 3c1bfc9e3..4cbe593e6 100644
--- a/src/fastertransformer/models/llama/Llama.cc
+++ b/src/fastertransformer/models/llama/Llama.cc
@@ -108,9 +108,6 @@ void Llama<T>::allocateBuffer(
         padded_embedding_kernel_ =
             (T*)(allocator_->reMalloc(padded_embedding_kernel_, sizeof(T) * hidden_units_ * vocab_size_padded_, true));
         padded_embedding_kernel_ptr_ = padded_embedding_kernel_;
-
-        padded_embedding_bias_ =
-            (T*)(allocator_->reMalloc(padded_embedding_bias_, sizeof(T) * vocab_size_padded_, true));
     }

     input_attention_mask_ = (T*)(allocator_->reMalloc(
@@ -184,7 +181,6 @@ void Llama<T>::freeBuffer()
     if (vocab_size_ != vocab_size_padded_) {
         padded_embedding_kernel_ptr_ = nullptr;
         allocator_->free((void**)(&padded_embedding_kernel_));
-        allocator_->free((void**)(&padded_embedding_bias_));
     }

     allocator_->free((void**)(&input_attention_mask_));
@@ -820,18 +816,6 @@ void Llama<T>::forward(std::unordered_map<std::string, Tensor>* output_ten
                         sizeof(T) * vocab_size_ * hidden_units_,
                         cudaMemcpyDeviceToDevice,
                         stream_);
-        if (gpt_weights->post_decoder_embedding.bias) {
-            cudaMemcpyAsync(padded_embedding_bias_,
-                            gpt_weights->post_decoder_embedding.bias,
-                            sizeof(T) * vocab_size_,
-                            cudaMemcpyDeviceToDevice,
-                            stream_);
-        } else {
-            cudaMemsetAsync(padded_embedding_bias_,
-                            0,
-                            sizeof(T) * vocab_size_,
-                            stream_);
-        }
         sync_check_cuda_error();
     }

diff --git a/src/fastertransformer/models/llama/Llama.h b/src/fastertransformer/models/llama/Llama.h
index a0958280e..df621f009 100644
--- a/src/fastertransformer/models/llama/Llama.h
+++ b/src/fastertransformer/models/llama/Llama.h
@@ -84,7 +84,6 @@ class Llama: public BaseLayer {

 protected:
     T*       padded_embedding_kernel_;
-    T*       padded_embedding_bias_;
     const T* padded_embedding_kernel_ptr_;

     T*       input_attention_mask_;
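Note on the pattern used by the two patches above: patch 1 guards the copy of an optional post-decoder embedding bias (a LLaMA checkpoint ships no such bias, so the pointer stays null) and zero-fills the padded buffer instead; patch 2 then drops the padded bias buffer altogether. The sketch below is a minimal, standalone illustration of the copy-or-zero step from patch 1 only; the helper name and parameters are illustrative and not part of the repository.

// Minimal sketch (assumed names, not the repository's API): copy an optional
// device-side bias into a padded buffer if it was loaded, otherwise zero-fill
// the buffer so downstream logits kernels read defined values.
#include <cuda_runtime.h>
#include <cstddef>

template<typename T>
void copyOrZeroBias(T* padded_bias, const T* bias, std::size_t vocab_size, cudaStream_t stream)
{
    if (bias != nullptr) {
        // Bias weight exists: copy it device-to-device on the given stream.
        cudaMemcpyAsync(padded_bias, bias, sizeof(T) * vocab_size, cudaMemcpyDeviceToDevice, stream);
    }
    else {
        // No bias in the checkpoint: zero the buffer instead of copying from null.
        cudaMemsetAsync(padded_bias, 0, sizeof(T) * vocab_size, stream);
    }
}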