diff --git a/client/src/components/Chat/index.tsx b/client/src/components/Chat/index.tsx index 814117d0de..df714a58f3 100644 --- a/client/src/components/Chat/index.tsx +++ b/client/src/components/Chat/index.tsx @@ -94,8 +94,10 @@ const Chat = () => { queryIdToEdit ? `&parent_query_id=${queryIdToEdit}` : '' }` : '' - }&model=${ - preferredAnswerSpeed === 'normal' ? 'gpt-4' : 'gpt-3.5-turbo-finetuned' + }&answer_model=${ + preferredAnswerSpeed === 'normal' + ? 'gpt-4-turbo-24k' + : 'gpt-3.5-turbo-finetuned' }`; console.log(url); const eventSource = new EventSource(url); diff --git a/client/src/consts/codeStudio.ts b/client/src/consts/codeStudio.ts index 29de4c9550..a4a742edf8 100644 --- a/client/src/consts/codeStudio.ts +++ b/client/src/consts/codeStudio.ts @@ -1 +1 @@ -export const TOKEN_LIMIT = 7000; +export const TOKEN_LIMIT = 21000; diff --git a/server/bleep/src/agent.rs b/server/bleep/src/agent.rs index 133e365089..e2cdc40f8a 100644 --- a/server/bleep/src/agent.rs +++ b/server/bleep/src/agent.rs @@ -57,7 +57,8 @@ pub struct Agent { pub thread_id: uuid::Uuid, pub query_id: uuid::Uuid, - pub model: model::AnswerModel, + pub answer_model: model::LLMModel, + pub agent_model: model::LLMModel, /// Indicate whether the request was answered. 
/// @@ -220,7 +221,7 @@ impl Agent { ))]; history.extend(self.history()?); - let trimmed_history = trim_history(history.clone(), self.model)?; + let trimmed_history = trim_history(history.clone(), self.agent_model)?; let raw_response = self .llm_gateway @@ -484,7 +485,7 @@ impl Agent { fn trim_history( mut history: Vec<llm_gateway::api::Message>, - model: model::AnswerModel, + model: model::LLMModel, ) -> Result<Vec<llm_gateway::api::Message>> { const HIDDEN: &str = "[HIDDEN]"; diff --git a/server/bleep/src/agent/model.rs b/server/bleep/src/agent/model.rs index 70ad7177bb..917941e931 100644 --- a/server/bleep/src/agent/model.rs +++ b/server/bleep/src/agent/model.rs @@ -2,7 +2,7 @@ use crate::agent::prompts; use std::str::FromStr; #[derive(Debug, Copy, Clone)] -pub struct AnswerModel { +pub struct LLMModel { /// The name of this model according to tiktoken pub tokenizer: &'static str, @@ -22,7 +22,7 @@ pub struct AnswerModel { pub system_prompt: fn(&str) -> String, } -pub const GPT_3_5_TURBO_FINETUNED: AnswerModel = AnswerModel { +pub const GPT_3_5_TURBO_FINETUNED: LLMModel = LLMModel { tokenizer: "gpt-3.5-turbo-0613", model_name: "gpt-3.5-turbo-finetuned", answer_headroom: 512, @@ -31,7 +31,24 @@ pub const GPT_3_5_TURBO_FINETUNED: AnswerModel = AnswerModel { system_prompt: prompts::answer_article_prompt_finetuned, }; -pub const GPT_4: AnswerModel = AnswerModel { +// GPT-4 turbo has a context window of 128k tokens +const GPT_4_TURBO_MAX_TOKENS: usize = 128_000; +// We want to use only 24k tokens +const ACTUAL_MAX_TOKENS: usize = 24_000; +// 104k tokens should be left unused. This is done by adding 104k to the headrooms +// (tokens left unused for other purposes: answer, prompt, ...)
+const HEADROOM_CORRECTION: usize = GPT_4_TURBO_MAX_TOKENS - ACTUAL_MAX_TOKENS; +// PS: when we want to fully utilize the model max context window, the correction is 0 +pub const GPT_4_TURBO_24K: LLMModel = LLMModel { + tokenizer: "gpt-4-1106-preview", + model_name: "gpt-4-1106-preview", + answer_headroom: 1024 + HEADROOM_CORRECTION, + prompt_headroom: 2500 + HEADROOM_CORRECTION, + history_headroom: 2048 + HEADROOM_CORRECTION, + system_prompt: prompts::answer_article_prompt, +}; + +pub const GPT_4: LLMModel = LLMModel { tokenizer: "gpt-4-0613", model_name: "gpt-4-0613", answer_headroom: 1024, @@ -40,24 +57,25 @@ pub const GPT_4: AnswerModel = AnswerModel { system_prompt: prompts::answer_article_prompt, }; -impl FromStr for AnswerModel { +impl FromStr for LLMModel { type Err = (); fn from_str(s: &str) -> Result { #[allow(clippy::wildcard_in_or_patterns)] match s { "gpt-4" => Ok(GPT_4), + "gpt-4-turbo-24k" => Ok(GPT_4_TURBO_24K), "gpt-3.5-turbo-finetuned" | _ => Ok(GPT_3_5_TURBO_FINETUNED), } } } -impl<'de> serde::Deserialize<'de> for AnswerModel { +impl<'de> serde::Deserialize<'de> for LLMModel { fn deserialize(deserializer: D) -> Result where D: serde::Deserializer<'de>, { let s = String::deserialize(deserializer)?; - s.parse::() + s.parse::() .map_err(|_| serde::de::Error::custom("failed to deserialize")) } } diff --git a/server/bleep/src/agent/tools/answer.rs b/server/bleep/src/agent/tools/answer.rs index 249228d2af..7d40b13f48 100644 --- a/server/bleep/src/agent/tools/answer.rs +++ b/server/bleep/src/agent/tools/answer.rs @@ -41,16 +41,16 @@ impl Agent { } let context = self.answer_context(aliases).await?; - let system_prompt = (self.model.system_prompt)(&context); + let system_prompt = (self.answer_model.system_prompt)(&context); let system_message = llm_gateway::api::Message::system(&system_prompt); let history = { let h = self.utter_history().collect::>(); let system_headroom = tiktoken_rs::num_tokens_from_messages( - self.model.tokenizer, + 
self.answer_model.tokenizer, &[(&system_message).into()], )?; - let headroom = self.model.answer_headroom + system_headroom; - trim_utter_history(h, headroom, self.model)? + let headroom = self.answer_model.answer_headroom + system_headroom; + trim_utter_history(h, headroom, self.answer_model)? }; let messages = Some(system_message) .into_iter() @@ -60,12 +60,14 @@ impl Agent { let mut stream = pin!( self.llm_gateway .clone() - .model(self.model.model_name) - .frequency_penalty(if self.model.model_name == "gpt-3.5-turbo-finetuned" { - Some(0.2) - } else { - Some(0.0) - }) + .model(self.answer_model.model_name) + .frequency_penalty( + if self.answer_model.model_name == "gpt-3.5-turbo-finetuned" { + Some(0.2) + } else { + Some(0.0) + } + ) .chat_stream(&messages, None) .await? ); @@ -108,7 +110,7 @@ impl Agent { .with_payload("query_history", &history) .with_payload("response", &response) .with_payload("raw_prompt", &system_prompt) - .with_payload("model", self.model.model_name), + .with_payload("model", self.answer_model.model_name), ); Ok(()) @@ -145,9 +147,9 @@ impl Agent { // Sometimes, there are just too many code chunks in the context, and deduplication still // doesn't trim enough chunks. So, we enforce a hard limit here that stops adding tokens // early if we reach a heuristic limit. 
- let bpe = tiktoken_rs::get_bpe_from_model(self.model.tokenizer)?; + let bpe = tiktoken_rs::get_bpe_from_model(self.answer_model.tokenizer)?; let mut remaining_prompt_tokens = - tiktoken_rs::get_completion_max_tokens(self.model.tokenizer, &s)?; + tiktoken_rs::get_completion_max_tokens(self.answer_model.tokenizer, &s)?; // Select as many recent chunks as possible let mut recent_chunks = Vec::new(); @@ -166,7 +168,7 @@ impl Agent { let snippet_tokens = bpe.encode_ordinary(&formatted_snippet).len(); - if snippet_tokens >= remaining_prompt_tokens - self.model.prompt_headroom { + if snippet_tokens >= remaining_prompt_tokens - self.answer_model.prompt_headroom { info!("breaking at {} tokens", remaining_prompt_tokens); break; } @@ -251,8 +253,8 @@ impl Agent { /// Making this closure to 1 means that more of the context is taken up by source code. const CONTEXT_CODE_RATIO: f32 = 0.5; - let bpe = tiktoken_rs::get_bpe_from_model(self.model.tokenizer).unwrap(); - let context_size = tiktoken_rs::model::get_context_size(self.model.tokenizer); + let bpe = tiktoken_rs::get_bpe_from_model(self.answer_model.tokenizer).unwrap(); + let context_size = tiktoken_rs::model::get_context_size(self.answer_model.tokenizer); let max_tokens = (context_size as f32 * CONTEXT_CODE_RATIO) as usize; // Note: The end line number here is *not* inclusive. 
@@ -412,7 +414,7 @@ impl Agent { fn trim_utter_history( mut history: Vec, headroom: usize, - model: model::AnswerModel, + model: model::LLMModel, ) -> Result> { let mut tiktoken_msgs: Vec = history.iter().map(|m| m.into()).collect::>(); diff --git a/server/bleep/src/webserver/answer.rs b/server/bleep/src/webserver/answer.rs index 5b7b343cdf..8dce6b5a39 100644 --- a/server/bleep/src/webserver/answer.rs +++ b/server/bleep/src/webserver/answer.rs @@ -69,8 +69,10 @@ pub(super) async fn vote( pub struct Answer { pub q: String, pub repo_ref: RepoRef, - #[serde(default = "default_model")] - pub model: agent::model::AnswerModel, + #[serde(default = "default_answer_model")] + pub answer_model: agent::model::LLMModel, + #[serde(default = "default_agent_model")] + pub agent_model: agent::model::LLMModel, #[serde(default = "default_thread_id")] pub thread_id: uuid::Uuid, /// Optional id of the parent of the exchange to overwrite @@ -82,8 +84,12 @@ fn default_thread_id() -> uuid::Uuid { uuid::Uuid::new_v4() } -fn default_model() -> agent::model::AnswerModel { - agent::model::GPT_3_5_TURBO_FINETUNED +fn default_answer_model() -> agent::model::LLMModel { + agent::model::GPT_4_TURBO_24K +} + +fn default_agent_model() -> agent::model::LLMModel { + agent::model::GPT_4 } pub(super) async fn answer( @@ -207,12 +213,20 @@ async fn try_execute_agent( Sse> + Send>>>, > { QueryLog::new(&app.sql).insert(¶ms.q).await?; + let Answer { + thread_id, + repo_ref, + answer_model, + agent_model, + .. + } = params.clone(); let llm_gateway = user .llm_gateway(&app) .await? .temperature(0.0) - .session_reference_id(conversation_id.to_string()); + .session_reference_id(conversation_id.to_string()) + .model(agent_model.model_name); // confirm client compatibility with answer-api match llm_gateway @@ -243,12 +257,6 @@ async fn try_execute_agent( } }; - let Answer { - thread_id, - repo_ref, - model, - .. - } = params.clone(); let stream = async_stream::try_stream! 
{ let (exchange_tx, exchange_rx) = tokio::sync::mpsc::channel(10); @@ -262,7 +270,8 @@ async fn try_execute_agent( thread_id, query_id, exchange_state: ExchangeState::Pending, - model, + answer_model, + agent_model }; let mut exchange_rx = tokio_stream::wrappers::ReceiverStream::new(exchange_rx); @@ -339,7 +348,7 @@ async fn try_execute_agent( Ok(sse::Event::default() .json_data(json!({ "thread_id": params.thread_id.to_string(), - "query_id": query_id + "query_id": query_id, })) // This should never happen, so we force an unwrap. .expect("failed to serialize initialization object")) @@ -391,7 +400,8 @@ pub async fn explain( repo_ref: params.repo_ref, thread_id: params.thread_id, parent_exchange_id: None, - model: agent::model::GPT_4, + answer_model: agent::model::GPT_4_TURBO_24K, + agent_model: agent::model::GPT_4, }; let conversation_id = ConversationId { diff --git a/server/bleep/src/webserver/studio.rs b/server/bleep/src/webserver/studio.rs index 12edeed4a8..be9c838a98 100644 --- a/server/bleep/src/webserver/studio.rs +++ b/server/bleep/src/webserver/studio.rs @@ -32,7 +32,7 @@ use crate::{ mod diff; -const LLM_GATEWAY_MODEL: &str = "gpt-4-0613"; +const LLM_GATEWAY_MODEL: &str = "gpt-4-1106-preview"; fn no_user_id() -> Error { Error::user("didn't have user ID") @@ -472,7 +472,7 @@ async fn token_counts( }) .collect::>(); - let core_bpe = tiktoken_rs::get_bpe_from_model("gpt-4-0613").unwrap(); + let core_bpe = tiktoken_rs::get_bpe_from_model("gpt-4-1106-preview").unwrap(); let per_doc_file = stream::iter(doc_context) .map(|file| async { if file.hidden { @@ -633,14 +633,14 @@ pub async fn get_doc_file_token_count( .map(|sr| sr.text) .collect::(); - let core_bpe = tiktoken_rs::get_bpe_from_model("gpt-4-0613").unwrap(); + let core_bpe = tiktoken_rs::get_bpe_from_model("gpt-4-1106-preview").unwrap(); let token_count = core_bpe.encode_ordinary(&content).len(); Ok(Json(token_count)) } fn count_tokens_for_file(path: &str, body: &str, ranges: &[Range]) -> usize { - let 
core_bpe = tiktoken_rs::get_bpe_from_model("gpt-4-0613").unwrap(); + let core_bpe = tiktoken_rs::get_bpe_from_model("gpt-4-1106-preview").unwrap(); let mut chunks = Vec::new();