From 62c486944d8c11da5d604cfeec50fe2373b151d6 Mon Sep 17 00:00:00 2001 From: Adrian Cole Date: Tue, 10 Feb 2026 18:12:11 +0800 Subject: [PATCH] feat(claude-code): dynamic model listing and mid-session model switching Before, the claude provider hard-coded 'claude-sonnet-4-20250514' instead of using "default", didn't support model switching mid-session, and didn't support listing available models dynamically. This uses the stream-json control protocol to fetch models via initialize and switch models via set_model. Also removes buggy ContextLengthExceeded handling as claude-code responses are unpredictable for oversized context. Signed-off-by: Adrian Cole --- crates/goose/src/providers/claude_code.rs | 408 ++++++++++++++++++---- crates/goose/tests/providers.rs | 127 ++++++- 2 files changed, 457 insertions(+), 78 deletions(-) diff --git a/crates/goose/src/providers/claude_code.rs b/crates/goose/src/providers/claude_code.rs index ebc89ae0309f..73e5205ad8b9 100644 --- a/crates/goose/src/providers/claude_code.rs +++ b/crates/goose/src/providers/claude_code.rs @@ -20,17 +20,102 @@ use crate::subprocess::configure_subprocess; use rmcp::model::Tool; const CLAUDE_CODE_PROVIDER_NAME: &str = "claude-code"; -pub const CLAUDE_CODE_DEFAULT_MODEL: &str = "claude-sonnet-4-20250514"; -pub const CLAUDE_CODE_KNOWN_MODELS: &[&str] = &["sonnet", "opus"]; +pub const CLAUDE_CODE_DEFAULT_MODEL: &str = "default"; pub const CLAUDE_CODE_DOC_URL: &str = "https://code.claude.com/docs/en/setup"; -#[derive(Debug)] struct CliProcess { child: tokio::process::Child, - stdin: tokio::process::ChildStdin, - reader: BufReader, + stdin: Box, + reader: BufReader>, #[allow(dead_code)] stderr_handle: tokio::task::JoinHandle, + current_model: String, + log_model_update: bool, + next_request_id: u64, +} + +impl std::fmt::Debug for CliProcess { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CliProcess") + .field("current_model", &self.current_model) + .field("next_request_id", &self.next_request_id) + .finish_non_exhaustive() + } +} + +impl CliProcess { + fn next_request_id(&mut self) -> String { + let id = self.next_request_id; + self.next_request_id += 1; + format!("req_{id}") + } + + /// Send a `set_model` control request and wait for the response before returning. + /// Skips the request if the model is already active. + async fn send_set_model(&mut self, model: &str) -> Result<(), ProviderError> { + if model == self.current_model { + return Ok(()); + } + + let request_id = self.next_request_id(); + let req = json!({ + "type": "control_request", + "request_id": request_id, + "request": {"subtype": "set_model", "model": model} + }); + let mut req_str = serde_json::to_string(&req).unwrap(); + req_str.push('\n'); + self.stdin + .write_all(req_str.as_bytes()) + .await + .map_err(|e| { + ProviderError::RequestFailed(format!("Failed to write set_model request: {e}")) + })?; + + // Read lines until we get the control_response for our request. + let mut line = String::new(); + loop { + line.clear(); + match self.reader.read_line(&mut line).await { + Ok(0) => { + return Err(ProviderError::RequestFailed( + "CLI process terminated while waiting for set_model response".to_string(), + )); + } + Ok(_) => { + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + if let Ok(parsed) = serde_json::from_str::(trimmed) { + if parsed.get("type").and_then(|t| t.as_str()) == Some("control_response") { + let success = + parsed.pointer("/response/subtype").and_then(|s| s.as_str()) + == Some("success"); + if success { + self.current_model = model.to_string(); + self.log_model_update = true; + return Ok(()); + } else { + let err = parsed + .pointer("/response/error") + .and_then(|e| e.as_str()) + .unwrap_or("unknown"); + return Err(ProviderError::RequestFailed(format!( + "set_model failed: {err}" + ))); + } + } + } + } + Err(e) => { + return Err(ProviderError::RequestFailed(format!( + "Failed to read set_model response: {e}" + ))); + } + } + } + } } impl Drop for CliProcess { @@ -110,6 +195,20 @@ impl ClaudeCodeProvider { blocks } + fn build_stream_json_command(&self) -> Command { + let mut cmd = Command::new(&self.command); + configure_subprocess(&mut cmd); + cmd.arg("--input-format") + .arg("stream-json") + .arg("--output-format") + .arg("stream-json") + .arg("--verbose") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + cmd + } + fn apply_permission_flags(cmd: &mut Command) -> Result<(), ProviderError> { let config = Config::global(); let goose_mode = config.get_goose_mode().unwrap_or(GooseMode::Auto); @@ -208,11 +307,6 @@ impl ClaudeCodeProvider { .get("error") .and_then(|e| e.as_str()) .unwrap_or("Unknown error"); - if error_msg.contains("context") && error_msg.contains("exceeded") { - return Err(ProviderError::ContextLengthExceeded( - error_msg.to_string(), - )); - } return Err(ProviderError::RequestFailed(format!( "Claude CLI error: {}", error_msg @@ -226,9 +320,6 @@ impl ClaudeCodeProvider { // Combine all text content into a single message let combined_text = all_text_content.join("\n\n"); - if combined_text.contains("Prompt is too long") { - return Err(ProviderError::ContextLengthExceeded(combined_text)); - } if combined_text.is_empty() { return Err(ProviderError::RequestFailed( "No text content found in response".to_string(), @@ -252,6 +343,7 @@ impl ClaudeCodeProvider { messages: &[Message], _tools: &[Tool], session_id: &str, + model: &str, ) -> Result, ProviderError> { let filtered_system = filter_extensions_from_system_prompt(system); @@ -271,30 +363,14 @@ impl ClaudeCodeProvider { let process_mutex = self .cli_process .get_or_try_init(|| async { - let mut cmd = Command::new(&self.command); - // NO -p flag — persistent mode - configure_subprocess(&mut cmd); - cmd.arg("--input-format") - .arg("stream-json") - .arg("--output-format") - .arg("stream-json") - .arg("--verbose") - // System prompt is set once at process start. The provider - // instance is not reused across sessions with different prompts. - .arg("--system-prompt") - .arg(&filtered_system); - - // Only pass model parameter if it's in the known models list - if CLAUDE_CODE_KNOWN_MODELS.contains(&self.model.model_name.as_str()) { - cmd.arg("--model").arg(&self.model.model_name); - } + let mut cmd = self.build_stream_json_command(); + // System prompt is set once at process start and cannot be updated at runtime. + cmd.arg("--system-prompt").arg(&filtered_system); - // Add permission mode based on GOOSE_MODE setting - Self::apply_permission_flags(&mut cmd)?; + // The initial model can be updated later. + cmd.arg("--model").arg(&self.model.model_name); - cmd.stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .stderr(Stdio::piped()); + Self::apply_permission_flags(&mut cmd)?; let mut child = cmd.spawn().map_err(|e| { ProviderError::RequestFailed(format!( @@ -323,15 +399,21 @@ impl ClaudeCodeProvider { Ok::<_, ProviderError>(tokio::sync::Mutex::new(CliProcess { child, - stdin, - reader: BufReader::new(stdout), + stdin: Box::new(stdin), + reader: BufReader::new(Box::new(stdout)), stderr_handle, + current_model: String::new(), + log_model_update: false, + next_request_id: 0, })) }) .await?; let mut process = process_mutex.lock().await; + // Switch model if it differs from what the CLI is currently using. + process.send_set_model(model).await?; + let blocks = self.last_user_content_blocks(messages); // Write NDJSON line to stdin @@ -367,11 +449,26 @@ impl ClaudeCodeProvider { } lines.push(trimmed.to_string()); - // Check if this is a result event (end of turn) if let Ok(parsed) = serde_json::from_str::(trimmed) { match parsed.get("type").and_then(|t| t.as_str()) { Some("result") => break, Some("error") => break, + // The system init with the resolved model arrives here, + // not in send_set_model (which only sees control_response): + // send_set_model: {"type":"control_response",...} + // execute_command: {"type":"system",...,"model":"claude-sonnet-4-5-20250929",...} + Some("system") if process.log_model_update => { + if let Some(resolved) = parsed.get("model").and_then(|m| m.as_str()) + { + if std::env::var("GOOSE_CLAUDE_CODE_DEBUG").is_ok() { + println!( + "set_model: {} resolved to {}", + process.current_model, resolved + ); + } + } + process.log_model_update = false; + } _ => {} } } @@ -439,6 +536,32 @@ impl ClaudeCodeProvider { } } +/// Extract model aliases from the CLI's initialize control_response. +fn parse_models_from_lines(lines: &[String]) -> Vec { + for line in lines { + if let Ok(parsed) = serde_json::from_str::(line) { + if parsed.get("type").and_then(|t| t.as_str()) != Some("control_response") { + continue; + } + let success = + parsed.pointer("/response/subtype").and_then(|s| s.as_str()) == Some("success"); + if !success { + continue; + } + if let Some(models) = parsed + .pointer("/response/response/models") + .and_then(|m| m.as_array()) + { + return models + .iter() + .filter_map(|m| m.get("value").and_then(|v| v.as_str()).map(String::from)) + .collect(); + } + } + } + Vec::new() +} + fn build_stream_json_input(content_blocks: &[Value], session_id: &str) -> String { let msg = json!({"type":"user","session_id":session_id,"message":{"role":"user","content":content_blocks}}); serde_json::to_string(&msg).expect("serializing JSON content blocks cannot fail") @@ -454,10 +577,14 @@ impl ProviderDef for ClaudeCodeProvider { "Claude Code CLI", "Requires claude CLI installed, no MCPs. Use Anthropic provider for full features.", CLAUDE_CODE_DEFAULT_MODEL, - CLAUDE_CODE_KNOWN_MODELS.to_vec(), + // Only a few agentic choices; fetched dynamically via fetch_supported_models. + vec![], CLAUDE_CODE_DOC_URL, vec![ConfigKey::from_value_type::(true, false)], ) + // The model list only returns aliases the `claude` CLI uses, such as "default" + // and "haiku". There is no listing that includes full names like + // "claude-sonnet-4-5-20250929". However, they are permitted. .with_unlisted_models() } @@ -489,10 +616,62 @@ impl Provider for ClaudeCodeProvider { } async fn fetch_supported_models(&self) -> Result, ProviderError> { - Ok(CLAUDE_CODE_KNOWN_MODELS - .iter() - .map(|s| s.to_string()) - .collect()) + // Uses a separate short-lived process because --system-prompt is a CLI-only + // flag with no NDJSON equivalent. The persistent process needs it at spawn, + // but it's unavailable during model listing. + // See: https://code.claude.com/docs/en/cli-reference#system-prompt-flags + let mut cmd = self.build_stream_json_command(); + let mut child = cmd.spawn().map_err(|e| { + ProviderError::RequestFailed(format!("Failed to spawn CLI for model listing: {e}")) + })?; + + let mut stdin = child + .stdin + .take() + .ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdin".to_string()))?; + let stdout = child + .stdout + .take() + .ok_or_else(|| ProviderError::RequestFailed("Failed to capture stdout".to_string()))?; + + let request = json!({ + "type": "control_request", + "request_id": "model_list", + "request": {"subtype": "initialize"} + }); + let mut request_str = serde_json::to_string(&request).unwrap(); + request_str.push('\n'); + stdin.write_all(request_str.as_bytes()).await.map_err(|e| { + ProviderError::RequestFailed(format!("Failed to write initialize request: {e}")) + })?; + + let mut reader = BufReader::new(stdout); + let mut lines = Vec::new(); + let mut line = String::new(); + + // Read until we see a control_response or hit EOF + loop { + line.clear(); + match reader.read_line(&mut line).await { + Ok(0) => break, + Ok(_) => { + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + lines.push(trimmed.to_string()); + if let Ok(parsed) = serde_json::from_str::(trimmed) { + if parsed.get("type").and_then(|t| t.as_str()) == Some("control_response") { + break; + } + } + } + Err(_) => break, + } + } + + let _ = child.start_kill(); + Ok(parse_models_from_lines(&lines)) } #[tracing::instrument( @@ -514,7 +693,9 @@ impl Provider for ClaudeCodeProvider { // session_id is None before a session is created (e.g. model listing). let sid = session_id.unwrap_or("default"); - let json_lines = self.execute_command(system, messages, tools, sid).await?; + let json_lines = self + .execute_command(system, messages, tools, sid, &model_config.model_name) + .await?; let (message, usage) = self.parse_claude_response(&json_lines)?; @@ -648,11 +829,11 @@ mod tests { #[test_case( &[ - r#"{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"The answer is 2."}],"usage":{"input_tokens":100,"output_tokens":20}}}"#, - r#"{"type":"result","subtype":"success","result":"The answer is 2.","session_id":"abc"}"#, + r#"{"type":"assistant","message":{"content":[{"type":"text","text":"Hello! How can I help you today?"}],"usage":{"input_tokens":3,"output_tokens":3}}}"#, + r#"{"type":"result","usage":{"input_tokens":3,"output_tokens":16}}"#, ], - "The answer is 2.", - Some(100), Some(20) + "Hello! How can I help you today?", + Some(3), Some(3) ; "assistant_with_usage" )] #[test_case( @@ -665,12 +846,12 @@ mod tests { )] #[test_case( &[ - r#"{"type":"system","subtype":"init","session_id":"abc"}"#, - r#"{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"Hello"}]}}"#, - r#"{"type":"result","subtype":"success","result":"Hello","session_id":"abc"}"#, + r#"{"type":"system","model":"claude-opus-4-6"}"#, + r#"{"type":"assistant","message":{"content":[{"type":"text","text":"Hello!"}],"usage":{"input_tokens":3,"output_tokens":3}}}"#, + r#"{"type":"result","usage":{"input_tokens":3,"output_tokens":16}}"#, ], - "Hello", - None, None + "Hello!", + Some(3), Some(3) ; "system_init_filtered" )] fn test_parse_claude_response_ok( @@ -699,7 +880,7 @@ mod tests { )] #[test_case( &[r#"{"type":"error","error":"context window exceeded"}"#], - ProviderError::ContextLengthExceeded("context window exceeded".into()) + ProviderError::RequestFailed("Claude CLI error: context window exceeded".into()) ; "context_length" )] #[test_case( @@ -707,11 +888,6 @@ mod tests { ProviderError::RequestFailed("Claude CLI error: Model not supported".into()) ; "generic_error" )] - #[test_case( - &[r#"{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"Prompt is too long"}]}}"#], - ProviderError::ContextLengthExceeded("Prompt is too long".into()) - ; "prompt_too_long_exact" - )] fn test_parse_claude_response_err(lines: &[&str], expected: ProviderError) { let provider = make_provider(); let lines: Vec = lines.iter().map(|s| s.to_string()).collect(); @@ -721,12 +897,124 @@ mod tests { ); } + #[test_case( + &[ + r#"{"type":"control_response","response":{"subtype":"success","request_id":"model_list","response":{"models":[{"value":"default","displayName":"Default (recommended)","description":"Opus 4.6 · Most capable for complex work"},{"value":"sonnet","displayName":"Sonnet","description":"Sonnet 4.5 · Best for everyday tasks"},{"value":"haiku","displayName":"Haiku","description":"Haiku 4.5 · Fastest for quick answers"}]}}}"#, + ], + &["default", "sonnet", "haiku"] + ; "success" + )] + #[test_case( + &[ + r#"{"type":"control_response","response":{"subtype":"success","request_id":"model_list","response":{"models":[{"value":"default","displayName":"Default","description":"..."},{"value":null,"displayName":"Bad","description":"..."}]}}}"#, + ], + &["default"] + ; "filters_null_values" + )] + #[test_case( + &[r#"{"type":"system","subtype":"init","session_id":"abc"}"#], + &[] + ; "no_control_response" + )] + #[test_case( + &[r#"{"type":"control_response","response":{"subtype":"error","request_id":"req_1","error":"fail"}}"#], + &[] + ; "error_response" + )] + fn test_parse_models_from_lines(lines: &[&str], expected: &[&str]) { + let lines: Vec = lines.iter().map(|s| s.to_string()).collect(); + let result = parse_models_from_lines(&lines); + let expected: Vec = expected.iter().map(|s| s.to_string()).collect(); + assert_eq!(result, expected); + } + fn make_provider() -> ClaudeCodeProvider { ClaudeCodeProvider { command: PathBuf::from("claude"), - model: ModelConfig::new("sonnet").unwrap(), + model: ModelConfig::new(CLAUDE_CODE_DEFAULT_MODEL).unwrap(), name: "claude-code".to_string(), cli_process: tokio::sync::OnceCell::new(), } } + + fn make_test_process(canned_stdout: &str) -> (CliProcess, tokio::io::DuplexStream) { + let child = tokio::process::Command::new("true") + .spawn() + .expect("failed to spawn `true`"); + let (stdin_writer, stdin_reader) = tokio::io::duplex(1024); + let process = CliProcess { + child, + stdin: Box::new(stdin_writer), + reader: BufReader::new(Box::new(std::io::Cursor::new( + canned_stdout.as_bytes().to_vec(), + ))), + stderr_handle: tokio::spawn(async { String::new() }), + current_model: String::new(), + log_model_update: false, + next_request_id: 0, + }; + (process, stdin_reader) + } + + #[test_case( + &[r#"{"type":"control_response","response":{"subtype":"success","request_id":"req_0"}}"#], + Some("default"), "sonnet", + Ok(()), + "{\"type\":\"control_request\",\"request_id\":\"req_0\",\"request\":{\"subtype\":\"set_model\",\"model\":\"sonnet\"}}\n" + ; "default_to_sonnet" + )] + #[test_case( + &[r#"{"type":"control_response","response":{"subtype":"success","request_id":"req_0"}}"#], + Some("sonnet"), "default", + Ok(()), + "{\"type\":\"control_request\",\"request_id\":\"req_0\",\"request\":{\"subtype\":\"set_model\",\"model\":\"default\"}}\n" + ; "sonnet_to_default" + )] + #[test_case( + &[r#"{"type":"control_response","response":{"subtype":"error","request_id":"req_0","error":"bad model"}}"#], + None, "bad", + Err(ProviderError::RequestFailed("set_model failed: bad model".into())), + "{\"type\":\"control_request\",\"request_id\":\"req_0\",\"request\":{\"subtype\":\"set_model\",\"model\":\"bad\"}}\n" + ; "failure" + )] + #[test_case( + &[], + Some("sonnet"), "sonnet", + Ok(()), "" + ; "skip_when_same_model" + )] + #[test_case( + &[], + None, "sonnet", + Err(ProviderError::RequestFailed("CLI process terminated while waiting for set_model response".into())), + "{\"type\":\"control_request\",\"request_id\":\"req_0\",\"request\":{\"subtype\":\"set_model\",\"model\":\"sonnet\"}}\n" + ; "eof" + )] + #[tokio::test] + async fn test_send_set_model( + lines: &[&str], + initial_model: Option<&str>, + target_model: &str, + expected: Result<(), ProviderError>, + expected_stdin: &str, + ) { + use tokio::io::AsyncReadExt; + + let stdout = lines.join("\n"); + let (mut process, mut stdin_reader) = make_test_process(&stdout); + if let Some(m) = initial_model { + process.current_model = m.to_string(); + } + + let result = process.send_set_model(target_model).await; + process.stdin = Box::new(tokio::io::sink()); + let mut stdin_bytes = Vec::new(); + stdin_reader.read_to_end(&mut stdin_bytes).await.unwrap(); + + assert_eq!(result, expected); + if expected.is_ok() { + assert_eq!(process.current_model, target_model); + } + assert_eq!(String::from_utf8(stdin_bytes).unwrap(), expected_stdin); + } } diff --git a/crates/goose/tests/providers.rs b/crates/goose/tests/providers.rs index a077a71e53b3..9feea377fb34 100644 --- a/crates/goose/tests/providers.rs +++ b/crates/goose/tests/providers.rs @@ -92,6 +92,7 @@ struct ProviderTester { name: String, extension_manager: Arc, is_cli_provider: bool, + model_switch_name: Option, } impl ProviderTester { @@ -100,12 +101,14 @@ impl ProviderTester { name: String, extension_manager: Arc, is_cli_provider: bool, + model_switch_name: Option, ) -> Self { Self { provider, name, extension_manager, is_cli_provider, + model_switch_name, } } @@ -220,17 +223,7 @@ impl ProviderTester { "hello ".repeat(300_000) }; - let messages = vec![ - Message::user().with_text("hi there. what is 2 + 2?"), - Message::assistant().with_text("hey! I think it's 4."), - Message::user().with_text(&large_message_content), - Message::assistant().with_text("heyy!!"), - Message::user().with_text("what's the meaning of life?"), - Message::assistant().with_text("the meaning of life is 42"), - Message::user().with_text( - "did I ask you what's 2+2 in this message history? just respond with 'yes' or 'no'", - ), - ]; + let messages = vec![Message::user().with_text(&large_message_content)]; let result = self .provider @@ -281,6 +274,42 @@ impl ProviderTester { Ok(()) } + async fn test_model_switch(&self, session_id: &str) -> Result<()> { + // The process is already running with the default model from test_basic_response. + // Switch to model_switch_name and call complete_with_model to exercise send_set_model. + let default = &self.provider.get_model_config().model_name; + let alt = self + .model_switch_name + .as_deref() + .expect("model_switch_name required for test_model_switch"); + let alt_config = goose::model::ModelConfig::new(alt)?; + + let message = Message::user().with_text("Just say hello!"); + let (response, _) = self + .provider + .complete_with_model( + Some(session_id), + &alt_config, + "You are a helpful assistant.", + &[message], + &[], + ) + .await?; + + assert!( + matches!(response.content[0], MessageContent::Text(_)), + "Expected text response after model switch" + ); + println!( + "=== {}::model_switch ({} -> {}) === {}", + self.name, + default, + alt, + response.as_concat_text() + ); + Ok(()) + } + async fn test_model_listing(&self) -> Result<()> { let models = self.provider.fetch_supported_models().await?; @@ -300,6 +329,15 @@ impl ProviderTester { "Expected model '{}' in supported models", model_name ); + if let Some(alt) = &self.model_switch_name { + assert!( + models + .iter() + .any(|m| m == alt || m.contains(alt.as_str()) || alt.contains(m.as_str())), + "Expected model_switch_name '{}' in supported models", + alt + ); + } Ok(()) } @@ -322,8 +360,16 @@ impl ProviderTester { self.test_image_content_support(&self.session_id_for_test("image_content")) .await?; } - self.test_context_length_exceeded_error(&self.session_id_for_test("context_length")) - .await?; + if self.model_switch_name.is_some() { + self.test_model_switch(&self.session_id_for_test("model_switch")) + .await?; + } + // claude-code responds unpredictably to oversized context: + // sometimes "no", sometimes "Prompt is too long". + if self.name != "claude-code" { + self.test_context_length_exceeded_error(&self.session_id_for_test("context_length")) + .await?; + } Ok(()) } } @@ -337,6 +383,7 @@ fn load_env() { async fn test_provider( name: &str, model_name: &str, + model_switch_name: Option<&str>, required_vars: &[&str], env_modifications: Option>>, // CLI providers cannot propagate the agent-session-id header to MCP servers. @@ -435,6 +482,7 @@ async fn test_provider( name.to_string(), extension_manager, is_cli_provider, + model_switch_name.map(String::from), ); let _mcp = mcp; let result = tester.run_test_suite().await; @@ -457,6 +505,7 @@ async fn test_openai_provider() -> Result<()> { test_provider( "openai", OPEN_AI_DEFAULT_MODEL, + None, &["OPENAI_API_KEY"], None, false, @@ -469,6 +518,7 @@ async fn test_azure_provider() -> Result<()> { test_provider( "Azure", AZURE_DEFAULT_MODEL, + None, &[ "AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", @@ -485,6 +535,7 @@ async fn test_bedrock_provider_long_term_credentials() -> Result<()> { test_provider( "aws_bedrock", BEDROCK_DEFAULT_MODEL, + None, &["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"], None, false, @@ -500,6 +551,7 @@ async fn test_bedrock_provider_aws_profile_credentials() -> Result<()> { test_provider( "aws_bedrock", BEDROCK_DEFAULT_MODEL, + None, &["AWS_PROFILE"], Some(env_mods), false, @@ -519,6 +571,7 @@ async fn test_bedrock_provider_bearer_token() -> Result<()> { test_provider( "aws_bedrock", BEDROCK_DEFAULT_MODEL, + None, &["AWS_BEARER_TOKEN_BEDROCK", "AWS_REGION"], Some(env_mods), false, @@ -531,6 +584,7 @@ async fn test_databricks_provider() -> Result<()> { test_provider( "Databricks", DATABRICKS_DEFAULT_MODEL, + None, &["DATABRICKS_HOST", "DATABRICKS_TOKEN"], None, false, @@ -541,7 +595,15 @@ async fn test_databricks_provider() -> Result<()> { #[tokio::test] async fn test_ollama_provider() -> Result<()> { // qwen3-vl supports text, tools, and vision (needed for image test) - test_provider("Ollama", "qwen3-vl", &["OLLAMA_HOST"], None, false).await + test_provider( + "Ollama", + "qwen3-vl", + Some("qwen3"), + &["OLLAMA_HOST"], + None, + false, + ) + .await } #[tokio::test] @@ -549,6 +611,7 @@ async fn test_anthropic_provider() -> Result<()> { test_provider( "Anthropic", ANTHROPIC_DEFAULT_MODEL, + None, &["ANTHROPIC_API_KEY"], None, false, @@ -561,6 +624,7 @@ async fn test_openrouter_provider() -> Result<()> { test_provider( "OpenRouter", OPEN_AI_DEFAULT_MODEL, + None, &["OPENROUTER_API_KEY"], None, false, @@ -573,6 +637,7 @@ async fn test_google_provider() -> Result<()> { test_provider( "Google", GOOGLE_DEFAULT_MODEL, + None, &["GOOGLE_API_KEY"], None, false, @@ -585,6 +650,7 @@ async fn test_snowflake_provider() -> Result<()> { test_provider( "Snowflake", SNOWFLAKE_DEFAULT_MODEL, + None, &["SNOWFLAKE_HOST", "SNOWFLAKE_TOKEN"], None, false, @@ -597,6 +663,7 @@ async fn test_sagemaker_tgi_provider() -> Result<()> { test_provider( "SageMakerTgi", SAGEMAKER_TGI_DEFAULT_MODEL, + None, &["SAGEMAKER_ENDPOINT_NAME"], None, false, @@ -617,12 +684,28 @@ async fn test_litellm_provider() -> Result<()> { ("LITELLM_API_KEY", Some("".to_string())), ]); - test_provider("LiteLLM", LITELLM_DEFAULT_MODEL, &[], Some(env_mods), false).await + test_provider( + "LiteLLM", + LITELLM_DEFAULT_MODEL, + None, + &[], + Some(env_mods), + false, + ) + .await } #[tokio::test] async fn test_xai_provider() -> Result<()> { - test_provider("Xai", XAI_DEFAULT_MODEL, &["XAI_API_KEY"], None, false).await + test_provider( + "Xai", + XAI_DEFAULT_MODEL, + None, + &["XAI_API_KEY"], + None, + false, + ) + .await } #[tokio::test] @@ -632,7 +715,15 @@ async fn test_claude_code_provider() -> Result<()> { TEST_REPORT.record_skip("claude-code"); return Ok(()); } - test_provider("claude-code", CLAUDE_CODE_DEFAULT_MODEL, &[], None, true).await + test_provider( + "claude-code", + CLAUDE_CODE_DEFAULT_MODEL, + Some("sonnet"), + &[], + None, + true, + ) + .await } #[tokio::test] @@ -642,7 +733,7 @@ async fn test_codex_provider() -> Result<()> { TEST_REPORT.record_skip("codex"); return Ok(()); } - test_provider("codex", CODEX_DEFAULT_MODEL, &[], None, true).await + test_provider("codex", CODEX_DEFAULT_MODEL, None, &[], None, true).await } #[ctor::dtor]