From be752b833d273c4006c7ba36cdb815ae5b3ecb7c Mon Sep 17 00:00:00 2001 From: Li Xiangtian <1193027052@qq.com> Date: Tue, 31 Mar 2026 02:07:59 +0800 Subject: [PATCH] Add runtime session header support for OpenAI configs --- core/src/agent_api.rs | 112 +++++++++++++---------- core/src/config.rs | 198 ++++++++++++++++++++++++++++++++++++++++ core/src/llm/factory.rs | 45 ++++++++- core/src/llm/openai.rs | 46 ++++++++-- core/src/llm/tests.rs | 63 +++++++++++++ 5 files changed, 411 insertions(+), 53 deletions(-) diff --git a/core/src/agent_api.rs b/core/src/agent_api.rs index 75a3421..49f5021 100644 --- a/core/src/agent_api.rs +++ b/core/src/agent_api.rs @@ -783,39 +783,9 @@ impl Agent { ) -> Result { let opts = options.unwrap_or_default(); - let llm_client = if let Some(ref model) = opts.model { - let (provider_name, model_id) = model - .split_once('/') - .context("model format must be 'provider/model' (e.g., 'openai/gpt-4o')")?; - - let mut llm_config = self - .code_config - .llm_config(provider_name, model_id) - .with_context(|| { - format!("provider '{provider_name}' or model '{model_id}' not found in config") - })?; - - if let Some(temp) = opts.temperature { - llm_config = llm_config.with_temperature(temp); - } - if let Some(budget) = opts.thinking_budget { - llm_config = llm_config.with_thinking_budget(budget); - } - - crate::llm::create_client_with_config(llm_config) - } else { - if opts.temperature.is_some() || opts.thinking_budget.is_some() { - tracing::warn!( - "temperature/thinking_budget set without model override — these will be ignored. \ - Use with_model() to apply LLM parameter overrides." - ); - } - self.llm_client.clone() - }; - // Merge global MCP manager with any session-level one from opts. // If both exist, session-level servers are added into the global manager. - let merged_opts = match (&self.global_mcp, &opts.mcp_manager) { + let mut merged_opts = match (&self.global_mcp, &opts.mcp_manager) { (Some(global), Some(session)) => { let global = Arc::clone(global); let session_mgr = Arc::clone(session); @@ -857,6 +827,13 @@ impl Agent { _ => opts, }; + let session_id = merged_opts + .session_id + .clone() + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); + merged_opts.session_id = Some(session_id.clone()); + let llm_client = self.resolve_session_llm_client(&merged_opts, Some(&session_id))?; + self.build_session(workspace.into(), llm_client, &merged_opts) } @@ -963,21 +940,7 @@ impl Agent { // Build session with the saved workspace let mut opts = options; opts.session_id = Some(data.id.clone()); - - let llm_client = if let Some(ref model) = opts.model { - let (provider_name, model_id) = model - .split_once('/') - .context("model format must be 'provider/model'")?; - let llm_config = self - .code_config - .llm_config(provider_name, model_id) - .with_context(|| { - format!("provider '{provider_name}' or model '{model_id}' not found") - })?; - crate::llm::create_client_with_config(llm_config) - } else { - self.llm_client.clone() - }; + let llm_client = self.resolve_session_llm_client(&opts, Some(&data.id))?; let session = self.build_session(data.config.workspace.clone(), llm_client, &opts)?; @@ -987,6 +950,53 @@ impl Agent { Ok(session) } + fn resolve_session_llm_client( + &self, + opts: &SessionOptions, + session_id: Option<&str>, + ) -> Result> { + let model_ref = if let Some(ref model) = opts.model { + model.as_str() + } else { + if opts.temperature.is_some() || opts.thinking_budget.is_some() { + tracing::warn!( + "temperature/thinking_budget set without model override — these will be ignored. \ + Use with_model() to apply LLM parameter overrides." + ); + } + self.code_config + .default_model + .as_deref() + .context("default_model must be set in 'provider/model' format")? + }; + + let (provider_name, model_id) = model_ref + .split_once('/') + .context("model format must be 'provider/model' (e.g., 'openai/gpt-4o')")?; + + let mut llm_config = self + .code_config + .llm_config(provider_name, model_id) + .with_context(|| { + format!("provider '{provider_name}' or model '{model_id}' not found in config") + })?; + + if opts.model.is_some() { + if let Some(temp) = opts.temperature { + llm_config = llm_config.with_temperature(temp); + } + if let Some(budget) = opts.thinking_budget { + llm_config = llm_config.with_thinking_budget(budget); + } + } + + if let Some(session_id) = session_id { + llm_config = llm_config.with_session_id(session_id); + } + + Ok(crate::llm::create_client_with_config(llm_config)) + } + fn build_session( &self, workspace: String, @@ -2710,12 +2720,16 @@ mod tests { name: "anthropic".to_string(), api_key: Some("test-key".to_string()), base_url: None, + headers: std::collections::HashMap::new(), + session_id_header: None, models: vec![ModelConfig { id: "claude-sonnet-4-20250514".to_string(), name: "Claude Sonnet 4".to_string(), family: "claude-sonnet".to_string(), api_key: None, base_url: None, + headers: std::collections::HashMap::new(), + session_id_header: None, attachment: false, reasoning: false, tool_call: true, @@ -2730,12 +2744,16 @@ mod tests { name: "openai".to_string(), api_key: Some("test-openai-key".to_string()), base_url: None, + headers: std::collections::HashMap::new(), + session_id_header: None, models: vec![ModelConfig { id: "gpt-4o".to_string(), name: "GPT-4o".to_string(), family: "gpt-4".to_string(), api_key: None, base_url: None, + headers: std::collections::HashMap::new(), + session_id_header: None, attachment: false, reasoning: false, tool_call: true, @@ -3387,6 +3405,8 @@ dir content name: "anthropic".to_string(), api_key: Some("test-key".to_string()), base_url: None, + headers: std::collections::HashMap::new(), + session_id_header: None, models: vec![], }], ..Default::default() diff --git a/core/src/config.rs b/core/src/config.rs index 3678811..af95b03 100644 --- a/core/src/config.rs +++ b/core/src/config.rs @@ -14,6 +14,7 @@ use crate::llm::LlmConfig; use crate::memory::MemoryConfig; use serde::{Deserialize, Serialize}; use serde_json::Value as JsonValue; +use std::collections::HashMap; use std::path::{Path, PathBuf}; // ============================================================================ @@ -78,6 +79,12 @@ pub struct ModelConfig { /// Per-model base URL override #[serde(default)] pub base_url: Option, + /// Static HTTP headers for this model + #[serde(default)] + pub headers: HashMap, + /// Header name to receive the runtime session ID + #[serde(default)] + pub session_id_header: Option, /// Supports file attachments #[serde(default)] pub attachment: bool, @@ -120,6 +127,12 @@ pub struct ProviderConfig { /// Base URL for the API #[serde(default)] pub base_url: Option, + /// Static HTTP headers for this provider + #[serde(default)] + pub headers: HashMap, + /// Header name to receive the runtime session ID + #[serde(default)] + pub session_id_header: Option, /// Available models #[serde(default)] pub models: Vec, @@ -171,6 +184,21 @@ impl ProviderConfig { pub fn get_base_url<'a>(&'a self, model: &'a ModelConfig) -> Option<&'a str> { model.base_url.as_deref().or(self.base_url.as_deref()) } + + /// Get the effective static headers for a model (provider defaults with model overrides) + pub fn get_headers(&self, model: &ModelConfig) -> HashMap { + let mut headers = self.headers.clone(); + headers.extend(model.headers.clone()); + headers + } + + /// Get the header name that should carry the runtime session ID. + pub fn get_session_id_header<'a>(&'a self, model: &'a ModelConfig) -> Option<&'a str> { + model + .session_id_header + .as_deref() + .or(self.session_id_header.as_deref()) + } } // ============================================================================ @@ -600,11 +628,19 @@ impl CodeConfig { let (provider, model) = self.default_model_config()?; let api_key = provider.get_api_key(model)?; let base_url = provider.get_base_url(model); + let headers = provider.get_headers(model); + let session_id_header = provider.get_session_id_header(model); let mut config = LlmConfig::new(&provider.name, &model.id, api_key); if let Some(url) = base_url { config = config.with_base_url(url); } + if !headers.is_empty() { + config = config.with_headers(headers); + } + if let Some(header_name) = session_id_header { + config = config.with_session_id_header(header_name); + } config = apply_model_caps(config, model, self.thinking_budget); Some(config) } @@ -617,11 +653,19 @@ impl CodeConfig { let model = provider.find_model(model_id)?; let api_key = provider.get_api_key(model)?; let base_url = provider.get_base_url(model); + let headers = provider.get_headers(model); + let session_id_header = provider.get_session_id_header(model); let mut config = LlmConfig::new(&provider.name, &model.id, api_key); if let Some(url) = base_url { config = config.with_base_url(url); } + if !headers.is_empty() { + config = config.with_headers(headers); + } + if let Some(header_name) = session_id_header { + config = config.with_session_id_header(header_name); + } config = apply_model_caps(config, model, self.thinking_budget); Some(config) } @@ -907,12 +951,16 @@ mod tests { name: "anthropic".to_string(), api_key: Some("key1".to_string()), base_url: None, + headers: HashMap::new(), + session_id_header: None, models: vec![], }, ProviderConfig { name: "openai".to_string(), api_key: Some("key2".to_string()), base_url: None, + headers: HashMap::new(), + session_id_header: None, models: vec![], }, ], @@ -932,12 +980,16 @@ mod tests { name: "anthropic".to_string(), api_key: Some("test-api-key".to_string()), base_url: Some("https://api.anthropic.com".to_string()), + headers: HashMap::new(), + session_id_header: None, models: vec![ModelConfig { id: "claude-sonnet-4".to_string(), name: "Claude Sonnet 4".to_string(), family: "claude-sonnet".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, attachment: false, reasoning: false, tool_call: true, @@ -967,6 +1019,8 @@ mod tests { name: "openai".to_string(), api_key: Some("provider-key".to_string()), base_url: Some("https://api.openai.com".to_string()), + headers: HashMap::new(), + session_id_header: None, models: vec![ ModelConfig { id: "gpt-4".to_string(), @@ -974,6 +1028,8 @@ mod tests { family: "gpt".to_string(), api_key: None, // Uses provider key base_url: None, + headers: HashMap::new(), + session_id_header: None, attachment: false, reasoning: false, tool_call: true, @@ -989,6 +1045,8 @@ mod tests { family: "custom".to_string(), api_key: Some("model-specific-key".to_string()), // Override base_url: Some("https://custom.api.com".to_string()), // Override + headers: HashMap::new(), + session_id_header: None, attachment: false, reasoning: false, tool_call: true, @@ -1026,6 +1084,8 @@ mod tests { name: "anthropic".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, models: vec![ ModelConfig { id: "claude-1".to_string(), @@ -1033,6 +1093,8 @@ mod tests { family: "claude".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, attachment: false, reasoning: false, tool_call: true, @@ -1048,6 +1110,8 @@ mod tests { family: "claude".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, attachment: false, reasoning: false, tool_call: true, @@ -1063,12 +1127,16 @@ mod tests { name: "openai".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, models: vec![ModelConfig { id: "gpt-4".to_string(), name: "GPT-4".to_string(), family: "gpt".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, attachment: false, reasoning: false, tool_call: true, @@ -1115,6 +1183,8 @@ mod tests { name: "test".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, models: vec![], }], ..Default::default() @@ -1233,6 +1303,8 @@ mod tests { family: "gpt-4".to_string(), api_key: Some("sk-test".to_string()), base_url: None, + headers: HashMap::new(), + session_id_header: None, attachment: true, reasoning: false, tool_call: true, @@ -1294,6 +1366,8 @@ mod tests { name: "anthropic".to_string(), api_key: Some("sk-test".to_string()), base_url: Some("https://api.anthropic.com".to_string()), + headers: HashMap::new(), + session_id_header: None, models: vec![], }; let json = serde_json::to_string(&provider).unwrap(); @@ -1317,12 +1391,16 @@ mod tests { name: "anthropic".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, models: vec![ModelConfig { id: "claude-sonnet-4".to_string(), name: "Claude Sonnet 4".to_string(), family: "claude-sonnet".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, attachment: false, reasoning: false, tool_call: true, @@ -1348,6 +1426,8 @@ mod tests { name: "anthropic".to_string(), api_key: Some("provider-key".to_string()), base_url: None, + headers: HashMap::new(), + session_id_header: None, models: vec![], }; @@ -1357,6 +1437,8 @@ mod tests { family: "".to_string(), api_key: Some("model-key".to_string()), base_url: None, + headers: HashMap::new(), + session_id_header: None, attachment: false, reasoning: false, tool_call: true, @@ -1373,6 +1455,8 @@ mod tests { family: "".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, attachment: false, reasoning: false, tool_call: true, @@ -1390,6 +1474,98 @@ mod tests { ); } + #[test] + fn test_provider_config_get_headers_and_session_id_header() { + let mut provider_headers = HashMap::new(); + provider_headers.insert("X-Provider".to_string(), "provider".to_string()); + provider_headers.insert("X-Shared".to_string(), "provider".to_string()); + + let mut model_headers = HashMap::new(); + model_headers.insert("X-Model".to_string(), "model".to_string()); + model_headers.insert("X-Shared".to_string(), "model".to_string()); + + let provider = ProviderConfig { + name: "openai".to_string(), + api_key: Some("provider-key".to_string()), + base_url: None, + headers: provider_headers, + session_id_header: Some("X-Session-Id".to_string()), + models: vec![], + }; + + let model = ModelConfig { + id: "gpt-4o".to_string(), + name: "".to_string(), + family: "".to_string(), + api_key: None, + base_url: None, + headers: model_headers, + session_id_header: Some("X-Model-Session".to_string()), + attachment: false, + reasoning: false, + tool_call: true, + temperature: true, + release_date: None, + modalities: ModelModalities::default(), + cost: ModelCost::default(), + limit: ModelLimit::default(), + }; + + let headers = provider.get_headers(&model); + assert_eq!(headers.get("X-Provider"), Some(&"provider".to_string())); + assert_eq!(headers.get("X-Model"), Some(&"model".to_string())); + assert_eq!(headers.get("X-Shared"), Some(&"model".to_string())); + assert_eq!( + provider.get_session_id_header(&model), + Some("X-Model-Session") + ); + } + + #[test] + fn test_llm_config_includes_headers_and_runtime_session_header() { + let mut provider_headers = HashMap::new(); + provider_headers.insert("X-Provider".to_string(), "provider".to_string()); + + let config = CodeConfig { + default_model: Some("openai/gpt-4o".to_string()), + providers: vec![ProviderConfig { + name: "openai".to_string(), + api_key: Some("sk-test".to_string()), + base_url: Some("https://api.example.com".to_string()), + headers: provider_headers, + session_id_header: Some("X-Session-Id".to_string()), + models: vec![ModelConfig { + id: "gpt-4o".to_string(), + name: "".to_string(), + family: "".to_string(), + api_key: None, + base_url: None, + headers: HashMap::new(), + session_id_header: None, + attachment: false, + reasoning: false, + tool_call: true, + temperature: true, + release_date: None, + modalities: ModelModalities::default(), + cost: ModelCost::default(), + limit: ModelLimit::default(), + }], + }], + ..Default::default() + }; + + let llm_config = config.default_llm_config().unwrap(); + assert_eq!( + llm_config.headers.get("X-Provider"), + Some(&"provider".to_string()) + ); + assert_eq!( + llm_config.session_id_header.as_deref(), + Some("X-Session-Id") + ); + } + #[test] fn test_code_config_default_provider_config() { let config = CodeConfig { @@ -1398,6 +1574,8 @@ mod tests { name: "anthropic".to_string(), api_key: Some("sk-test".to_string()), base_url: None, + headers: HashMap::new(), + session_id_header: None, models: vec![], }], ..Default::default() @@ -1416,12 +1594,16 @@ mod tests { name: "anthropic".to_string(), api_key: Some("sk-test".to_string()), base_url: None, + headers: HashMap::new(), + session_id_header: None, models: vec![ModelConfig { id: "claude-sonnet-4".to_string(), name: "Claude Sonnet 4".to_string(), family: "claude-sonnet".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, attachment: false, reasoning: false, tool_call: true, @@ -1450,12 +1632,16 @@ mod tests { name: "anthropic".to_string(), api_key: Some("sk-test".to_string()), base_url: Some("https://api.anthropic.com".to_string()), + headers: HashMap::new(), + session_id_header: None, models: vec![ModelConfig { id: "claude-sonnet-4".to_string(), name: "Claude Sonnet 4".to_string(), family: "claude-sonnet".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, attachment: false, reasoning: false, tool_call: true, @@ -1481,12 +1667,16 @@ mod tests { name: "anthropic".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, models: vec![ModelConfig { id: "claude-sonnet-4".to_string(), name: "".to_string(), family: "".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, attachment: false, reasoning: false, tool_call: true, @@ -1501,12 +1691,16 @@ mod tests { name: "openai".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, models: vec![ModelConfig { id: "gpt-4o".to_string(), name: "".to_string(), family: "".to_string(), api_key: None, base_url: None, + headers: HashMap::new(), + session_id_header: None, attachment: false, reasoning: false, tool_call: true, @@ -1538,6 +1732,8 @@ mod tests { name: "anthropic".to_string(), api_key: Some("sk-test".to_string()), base_url: None, + headers: HashMap::new(), + session_id_header: None, models: vec![model], }], ..Default::default() @@ -1563,6 +1759,8 @@ mod tests { name: "anthropic".to_string(), api_key: Some("sk-test".to_string()), base_url: None, + headers: HashMap::new(), + session_id_header: None, models: vec![], }], ..Default::default() diff --git a/core/src/llm/factory.rs b/core/src/llm/factory.rs index 5ece7c1..4516bb5 100644 --- a/core/src/llm/factory.rs +++ b/core/src/llm/factory.rs @@ -6,6 +6,7 @@ use super::types::SecretString; use super::zhipu::ZhipuClient; use super::LlmClient; use crate::retry::RetryConfig; +use std::collections::HashMap; use std::sync::Arc; /// LLM client configuration @@ -15,6 +16,9 @@ pub struct LlmConfig { pub model: String, pub api_key: SecretString, pub base_url: Option, + pub headers: HashMap, + pub session_id_header: Option, + pub session_id: Option, pub retry_config: Option, /// Sampling temperature (0.0–1.0). None uses the provider default. pub temperature: Option, @@ -33,6 +37,12 @@ impl std::fmt::Debug for LlmConfig { .field("model", &self.model) .field("api_key", &"[REDACTED]") .field("base_url", &self.base_url) + .field("headers", &self.headers.keys().collect::>()) + .field("session_id_header", &self.session_id_header) + .field( + "session_id", + &self.session_id.as_ref().map(|_| "[REDACTED]"), + ) .field("retry_config", &self.retry_config) .field("temperature", &self.temperature) .field("max_tokens", &self.max_tokens) @@ -53,6 +63,9 @@ impl LlmConfig { model: model.into(), api_key: SecretString::new(api_key.into()), base_url: None, + headers: HashMap::new(), + session_id_header: None, + session_id: None, retry_config: None, temperature: None, max_tokens: None, @@ -66,6 +79,21 @@ impl LlmConfig { self } + pub fn with_headers(mut self, headers: HashMap) -> Self { + self.headers = headers; + self + } + + pub fn with_session_id_header(mut self, header_name: impl Into) -> Self { + self.session_id_header = Some(header_name.into()); + self + } + + pub fn with_session_id(mut self, session_id: impl Into) -> Self { + self.session_id = Some(session_id.into()); + self + } + pub fn with_retry_config(mut self, retry_config: RetryConfig) -> Self { self.retry_config = Some(retry_config); self @@ -85,12 +113,21 @@ impl LlmConfig { self.thinking_budget = Some(budget); self } + + pub(crate) fn resolved_headers(&self) -> HashMap { + let mut headers = self.headers.clone(); + if let (Some(header_name), Some(session_id)) = (&self.session_id_header, &self.session_id) { + headers.insert(header_name.clone(), session_id.clone()); + } + headers + } } /// Create LLM client with full configuration (supports custom base_url) pub fn create_client_with_config(config: LlmConfig) -> Arc { - let retry = config.retry_config.unwrap_or_default(); + let retry = config.retry_config.clone().unwrap_or_default(); let api_key = config.api_key.expose().to_string(); + let headers = config.resolved_headers(); match config.provider.as_str() { "anthropic" | "claude" => { @@ -120,6 +157,9 @@ pub fn create_client_with_config(config: LlmConfig) -> Arc { if let Some(base_url) = config.base_url { client = client.with_base_url(base_url); } + if !headers.is_empty() { + client = client.with_headers(headers.clone()); + } if !config.disable_temperature { if let Some(temp) = config.temperature { client = client.with_temperature(temp); @@ -157,6 +197,9 @@ pub fn create_client_with_config(config: LlmConfig) -> Arc { if let Some(base_url) = config.base_url { client = client.with_base_url(base_url); } + if !headers.is_empty() { + client = client.with_headers(headers.clone()); + } if !config.disable_temperature { if let Some(temp) = config.temperature { client = client.with_temperature(temp); diff --git a/core/src/llm/openai.rs b/core/src/llm/openai.rs index 54405ea..fbbc79e 100644 --- a/core/src/llm/openai.rs +++ b/core/src/llm/openai.rs @@ -9,6 +9,7 @@ use anyhow::{Context, Result}; use async_trait::async_trait; use futures::StreamExt; use serde::Deserialize; +use std::collections::HashMap; use std::sync::Arc; use std::time::Instant; use tokio::sync::mpsc; @@ -20,6 +21,7 @@ pub struct OpenAiClient { pub(crate) model: String, pub(crate) base_url: String, pub(crate) chat_completions_path: String, + pub(crate) headers: HashMap, pub(crate) temperature: Option, pub(crate) max_tokens: Option, pub(crate) http: Arc, @@ -76,6 +78,7 @@ impl OpenAiClient { model, base_url: "https://api.openai.com".to_string(), chat_completions_path: "/v1/chat/completions".to_string(), + headers: HashMap::new(), temperature: None, max_tokens: None, http: default_http_client(), @@ -108,6 +111,11 @@ impl OpenAiClient { self } + pub fn with_headers(mut self, headers: HashMap) -> Self { + self.headers = headers; + self + } + pub fn with_max_tokens(mut self, max_tokens: usize) -> Self { self.max_tokens = Some(max_tokens); self @@ -123,6 +131,26 @@ impl OpenAiClient { self } + pub(crate) fn request_headers(&self) -> Vec<(String, String)> { + let mut headers = Vec::with_capacity(self.headers.len() + 1); + let has_authorization = self + .headers + .keys() + .any(|key| key.eq_ignore_ascii_case("authorization")); + if !has_authorization { + headers.push(( + "Authorization".to_string(), + format!("Bearer {}", self.api_key.expose()), + )); + } + headers.extend( + self.headers + .iter() + .map(|(key, value)| (key.clone(), value.clone())), + ); + headers + } + pub(crate) fn convert_messages(&self, messages: &[Message]) -> Vec { messages .iter() @@ -282,15 +310,18 @@ impl LlmClient for OpenAiClient { } let url = format!("{}{}", self.base_url, self.chat_completions_path); - let auth_header = format!("Bearer {}", self.api_key.expose()); - let headers = vec![("Authorization", auth_header.as_str())]; + let request_headers = self.request_headers(); let response = crate::retry::with_retry(&self.retry_config, |_attempt| { let http = &self.http; let url = &url; - let headers = headers.clone(); + let request_headers = request_headers.clone(); let request = &request; async move { + let headers = request_headers + .iter() + .map(|(key, value)| (key.as_str(), value.as_str())) + .collect::>(); match http.post(url, headers, request).await { Ok(resp) => { let status = reqwest::StatusCode::from_u16(resp.status) @@ -430,15 +461,18 @@ impl LlmClient for OpenAiClient { } let url = format!("{}{}", self.base_url, self.chat_completions_path); - let auth_header = format!("Bearer {}", self.api_key.expose()); - let headers = vec![("Authorization", auth_header.as_str())]; + let request_headers = self.request_headers(); let streaming_resp = crate::retry::with_retry(&self.retry_config, |_attempt| { let http = &self.http; let url = &url; - let headers = headers.clone(); + let request_headers = request_headers.clone(); let request = &request; async move { + let headers = request_headers + .iter() + .map(|(key, value)| (key.as_str(), value.as_str())) + .collect::>(); match http.post_streaming(url, headers, request).await { Ok(resp) => { let status = reqwest::StatusCode::from_u16(resp.status) diff --git a/core/src/llm/tests.rs b/core/src/llm/tests.rs index 2fc4424..c006647 100644 --- a/core/src/llm/tests.rs +++ b/core/src/llm/tests.rs @@ -1537,6 +1537,51 @@ mod extra_llm_tests2 { assert_eq!(client.base_url, "https://custom.openai.com"); } + #[test] + fn test_openai_client_request_headers_with_custom_headers() { + let mut headers = std::collections::HashMap::new(); + headers.insert("X-Session-Id".to_string(), "sess-123".to_string()); + headers.insert("X-Test".to_string(), "value".to_string()); + + let client = + OpenAiClient::new("key".to_string(), "model".to_string()).with_headers(headers); + let request_headers = client.request_headers(); + + assert!(request_headers + .iter() + .any(|(key, value)| key == "Authorization" && value == "Bearer key")); + assert!(request_headers + .iter() + .any(|(key, value)| key == "X-Session-Id" && value == "sess-123")); + assert!(request_headers + .iter() + .any(|(key, value)| key == "X-Test" && value == "value")); + } + + #[test] + fn test_openai_client_request_headers_respects_custom_authorization() { + let mut headers = std::collections::HashMap::new(); + headers.insert( + "Authorization".to_string(), + "Bearer override-token".to_string(), + ); + + let client = + OpenAiClient::new("key".to_string(), "model".to_string()).with_headers(headers); + let request_headers = client.request_headers(); + + assert_eq!( + request_headers + .iter() + .filter(|(key, _)| key.eq_ignore_ascii_case("authorization")) + .count(), + 1 + ); + assert!(request_headers.iter().any(|(key, value)| { + key.eq_ignore_ascii_case("authorization") && value == "Bearer override-token" + })); + } + #[test] fn test_anthropic_client_with_base_url() { let client = AnthropicClient::new("key".to_string(), "model".to_string()) @@ -1907,6 +1952,24 @@ mod extra_llm_tests2 { assert_eq!(config.base_url, Some("https://custom.api.com".to_string())); } + #[test] + fn test_llm_config_resolved_headers_with_runtime_session() { + let mut headers = std::collections::HashMap::new(); + headers.insert("X-Test".to_string(), "value".to_string()); + + let config = LlmConfig::new("openai", "gpt-4", "key") + .with_headers(headers) + .with_session_id_header("X-Session-Id") + .with_session_id("sess-456"); + + let resolved_headers = config.resolved_headers(); + assert_eq!( + resolved_headers.get("X-Session-Id"), + Some(&"sess-456".to_string()) + ); + assert_eq!(resolved_headers.get("X-Test"), Some(&"value".to_string())); + } + #[test] fn test_llm_config_with_retry_config() { let retry = RetryConfig::default();