diff --git a/src/fastertransformer/models/llama/Llama.cc b/src/fastertransformer/models/llama/Llama.cc
index 1eac9fd20..4cbe593e6 100644
--- a/src/fastertransformer/models/llama/Llama.cc
+++ b/src/fastertransformer/models/llama/Llama.cc
@@ -108,9 +108,6 @@ void Llama<T>::allocateBuffer(
         padded_embedding_kernel_ =
             (T*)(allocator_->reMalloc(padded_embedding_kernel_, sizeof(T) * hidden_units_ * vocab_size_padded_, true));
         padded_embedding_kernel_ptr_ = padded_embedding_kernel_;
-
-        padded_embedding_bias_ =
-            (T*)(allocator_->reMalloc(padded_embedding_bias_, sizeof(T) * vocab_size_padded_, true));
     }
 
     input_attention_mask_ = (T*)(allocator_->reMalloc(
@@ -184,7 +181,6 @@ void Llama<T>::freeBuffer()
     if (vocab_size_ != vocab_size_padded_) {
         padded_embedding_kernel_ptr_ = nullptr;
         allocator_->free((void**)(&padded_embedding_kernel_));
-        allocator_->free((void**)(&padded_embedding_bias_));
     }
 
     allocator_->free((void**)(&input_attention_mask_));
@@ -820,11 +816,6 @@ void Llama<T>::forward(std::unordered_map<std::string, Tensor>* output_ten
                         sizeof(T) * vocab_size_ * hidden_units_,
                         cudaMemcpyDeviceToDevice,
                         stream_);
-        cudaMemcpyAsync(padded_embedding_bias_,
-                        gpt_weights->post_decoder_embedding.bias,
-                        sizeof(T) * vocab_size_,
-                        cudaMemcpyDeviceToDevice,
-                        stream_);
         sync_check_cuda_error();
     }
 
diff --git a/src/fastertransformer/models/llama/Llama.h b/src/fastertransformer/models/llama/Llama.h
index a0958280e..df621f009 100644
--- a/src/fastertransformer/models/llama/Llama.h
+++ b/src/fastertransformer/models/llama/Llama.h
@@ -84,7 +84,6 @@ class Llama: public BaseLayer {
 protected:
     T* padded_embedding_kernel_;
-    T* padded_embedding_bias_;
     const T* padded_embedding_kernel_ptr_;
 
     T* input_attention_mask_;
diff --git a/src/fastertransformer/models/llama/LlamaWeight.cc b/src/fastertransformer/models/llama/LlamaWeight.cc
index e9e11b6a1..84c4aaa08 100644
--- a/src/fastertransformer/models/llama/LlamaWeight.cc
+++ b/src/fastertransformer/models/llama/LlamaWeight.cc
@@ -89,6 +89,7 @@ LlamaWeight<T>::~LlamaWeight()
         post_decoder_layernorm.beta = nullptr;
         post_decoder_layernorm.gamma = nullptr;
         post_decoder_embedding.kernel = nullptr;
+        post_decoder_embedding.bias = nullptr;
         is_maintain_buffer = false;
     }
 }
@@ -196,6 +197,7 @@ void LlamaWeight<T>::setWeightPtr()
     post_decoder_layernorm.beta = weights_ptr[1];
     post_decoder_layernorm.gamma = weights_ptr[2];
     post_decoder_embedding.kernel = weights_ptr[3];
+    post_decoder_embedding.bias = nullptr;
 
     // prompt learning tables: set weight ptr
     if (malloc_load_prompt_weights_) {