From b8313ea8e4c622f16ff4d362d3077e71e0ec67f6 Mon Sep 17 00:00:00 2001 From: samuel100 Date: Mon, 9 Mar 2026 17:09:04 +0000 Subject: [PATCH 01/25] rusk sdk v2 init --- .github/workflows/build-rust-steps.yml | 63 +++ .github/workflows/foundry-local-sdk-build.yml | 44 ++ samples/rust/Cargo.toml | 5 +- samples/rust/README.md | 20 +- .../audio-transcription-example/Cargo.toml | 10 + .../audio-transcription-example/README.md | 25 + .../audio-transcription-example/src/main.rs | 74 +++ .../rust/foundry-local-webserver/Cargo.toml | 11 + .../rust/foundry-local-webserver/README.md | 25 + .../rust/foundry-local-webserver/src/main.rs | 105 ++++ samples/rust/hello-foundry-local/Cargo.toml | 13 - samples/rust/hello-foundry-local/README.md | 30 -- samples/rust/hello-foundry-local/src/main.rs | 85 ---- .../rust/native-chat-completions/Cargo.toml | 10 + .../rust/native-chat-completions/README.md | 25 + .../rust/native-chat-completions/src/main.rs | 92 ++++ .../tool-calling-foundry-local/Cargo.toml | 11 + .../rust/tool-calling-foundry-local/README.md | 25 + .../tool-calling-foundry-local/src/main.rs | 226 +++++++++ sdk_v2/rust/.clippy.toml | 2 + sdk_v2/rust/.rustfmt.toml | 3 + sdk_v2/rust/Cargo.toml | 44 ++ sdk_v2/rust/GENERATE-DOCS.md | 41 ++ sdk_v2/rust/README.md | 113 +++++ sdk_v2/rust/build.rs | 295 +++++++++++ sdk_v2/rust/examples/chat_completion.rs | 99 ++++ sdk_v2/rust/examples/interactive_chat.rs | 119 +++++ sdk_v2/rust/examples/tool_calling.rs | 205 ++++++++ sdk_v2/rust/src/catalog.rs | 184 +++++++ sdk_v2/rust/src/configuration.rs | 137 ++++++ sdk_v2/rust/src/detail/core_interop.rs | 459 ++++++++++++++++++ sdk_v2/rust/src/detail/mod.rs | 4 + sdk_v2/rust/src/detail/model_load_manager.rs | 73 +++ sdk_v2/rust/src/error.rs | 33 ++ sdk_v2/rust/src/foundry_local_manager.rs | 114 +++++ sdk_v2/rust/src/lib.rs | 56 +++ sdk_v2/rust/src/model.rs | 135 ++++++ sdk_v2/rust/src/model_variant.rs | 128 +++++ sdk_v2/rust/src/openai/audio_client.rs | 201 ++++++++ sdk_v2/rust/src/openai/chat_client.rs | 330 +++++++++++++ sdk_v2/rust/src/openai/mod.rs | 5 + sdk_v2/rust/src/types.rs | 124 +++++ sdk_v2/rust/tests/audio_client_test.rs | 138 ++++++ sdk_v2/rust/tests/catalog_test.rs | 137 ++++++ sdk_v2/rust/tests/chat_client_test.rs | 350 +++++++++++++ sdk_v2/rust/tests/common/mod.rs | 127 +++++ sdk_v2/rust/tests/manager_test.rs | 33 ++ sdk_v2/rust/tests/model_load_manager_test.rs | 149 ++++++ sdk_v2/rust/tests/model_test.rs | 70 +++ 49 files changed, 4672 insertions(+), 135 deletions(-) create mode 100644 .github/workflows/build-rust-steps.yml create mode 100644 samples/rust/audio-transcription-example/Cargo.toml create mode 100644 samples/rust/audio-transcription-example/README.md create mode 100644 samples/rust/audio-transcription-example/src/main.rs create mode 100644 samples/rust/foundry-local-webserver/Cargo.toml create mode 100644 samples/rust/foundry-local-webserver/README.md create mode 100644 samples/rust/foundry-local-webserver/src/main.rs delete mode 100644 samples/rust/hello-foundry-local/Cargo.toml delete mode 100644 samples/rust/hello-foundry-local/README.md delete mode 100644 samples/rust/hello-foundry-local/src/main.rs create mode 100644 samples/rust/native-chat-completions/Cargo.toml create mode 100644 samples/rust/native-chat-completions/README.md create mode 100644 samples/rust/native-chat-completions/src/main.rs create mode 100644 samples/rust/tool-calling-foundry-local/Cargo.toml create mode 100644 samples/rust/tool-calling-foundry-local/README.md create mode 100644 samples/rust/tool-calling-foundry-local/src/main.rs create mode 100644 sdk_v2/rust/.clippy.toml create mode 100644 sdk_v2/rust/.rustfmt.toml create mode 100644 sdk_v2/rust/Cargo.toml create mode 100644 sdk_v2/rust/GENERATE-DOCS.md create mode 100644 sdk_v2/rust/README.md create mode 100644 sdk_v2/rust/build.rs create mode 100644 sdk_v2/rust/examples/chat_completion.rs create mode 100644 sdk_v2/rust/examples/interactive_chat.rs create mode 100644 sdk_v2/rust/examples/tool_calling.rs create mode 100644 sdk_v2/rust/src/catalog.rs create mode 100644 sdk_v2/rust/src/configuration.rs create mode 100644 sdk_v2/rust/src/detail/core_interop.rs create mode 100644 sdk_v2/rust/src/detail/mod.rs create mode 100644 sdk_v2/rust/src/detail/model_load_manager.rs create mode 100644 sdk_v2/rust/src/error.rs create mode 100644 sdk_v2/rust/src/foundry_local_manager.rs create mode 100644 sdk_v2/rust/src/lib.rs create mode 100644 sdk_v2/rust/src/model.rs create mode 100644 sdk_v2/rust/src/model_variant.rs create mode 100644 sdk_v2/rust/src/openai/audio_client.rs create mode 100644 sdk_v2/rust/src/openai/chat_client.rs create mode 100644 sdk_v2/rust/src/openai/mod.rs create mode 100644 sdk_v2/rust/src/types.rs create mode 100644 sdk_v2/rust/tests/audio_client_test.rs create mode 100644 sdk_v2/rust/tests/catalog_test.rs create mode 100644 sdk_v2/rust/tests/chat_client_test.rs create mode 100644 sdk_v2/rust/tests/common/mod.rs create mode 100644 sdk_v2/rust/tests/manager_test.rs create mode 100644 sdk_v2/rust/tests/model_load_manager_test.rs create mode 100644 sdk_v2/rust/tests/model_test.rs diff --git a/.github/workflows/build-rust-steps.yml b/.github/workflows/build-rust-steps.yml new file mode 100644 index 00000000..4ee92aa6 --- /dev/null +++ b/.github/workflows/build-rust-steps.yml @@ -0,0 +1,63 @@ +name: Build Rust SDK + +on: + workflow_call: + inputs: + platform: + required: false + type: string + default: 'ubuntu' # or 'windows' or 'macos' + useWinML: + required: false + type: boolean + default: false + run-integration-tests: + required: false + type: boolean + default: false + +permissions: + contents: read + +jobs: + build: + runs-on: ${{ inputs.platform }}-latest + + defaults: + run: + working-directory: sdk_v2/rust + + env: + CARGO_FEATURES: ${{ inputs.useWinML && '--features winml' || '' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + clean: true + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + components: clippy, rustfmt + + - name: Cache cargo dependencies + uses: Swatinem/rust-cache@v2 + with: + workspaces: sdk_v2/rust -> target + + - name: Check formatting + run: cargo fmt --all -- --check + + - name: Run clippy + run: cargo clippy --all-targets ${{ env.CARGO_FEATURES }} -- -D warnings + + - name: Build + run: cargo build ${{ env.CARGO_FEATURES }} + + - name: Run unit tests + run: cargo test --lib ${{ env.CARGO_FEATURES }} + + - name: Run integration tests + if: ${{ inputs.run-integration-tests }} + run: cargo test --test '*' ${{ env.CARGO_FEATURES }} diff --git a/.github/workflows/foundry-local-sdk-build.yml b/.github/workflows/foundry-local-sdk-build.yml index 1190ac90..7639091b 100644 --- a/.github/workflows/foundry-local-sdk-build.yml +++ b/.github/workflows/foundry-local-sdk-build.yml @@ -56,4 +56,48 @@ jobs: with: version: '0.9.0.${{ github.run_number }}' platform: 'macos' + secrets: inherit + + build-rust-windows: + uses: ./.github/workflows/build-rust-steps.yml + with: + platform: 'windows' + secrets: inherit + build-rust-windows-WinML: + uses: ./.github/workflows/build-rust-steps.yml + with: + platform: 'windows' + useWinML: true + secrets: inherit + build-rust-ubuntu: + uses: ./.github/workflows/build-rust-steps.yml + with: + platform: 'ubuntu' + secrets: inherit + build-rust-macos: + uses: ./.github/workflows/build-rust-steps.yml + with: + platform: 'macos' + secrets: inherit + + integration-test-rust-ubuntu: + if: github.event_name == 'workflow_dispatch' || github.ref == 'refs/heads/main' + uses: ./.github/workflows/build-rust-steps.yml + with: + platform: 'ubuntu' + run-integration-tests: true + secrets: inherit + integration-test-rust-windows: + if: github.event_name == 'workflow_dispatch' || github.ref == 'refs/heads/main' + uses: ./.github/workflows/build-rust-steps.yml + with: + platform: 'windows' + run-integration-tests: true + secrets: inherit + integration-test-rust-macos: + if: github.event_name == 'workflow_dispatch' || github.ref == 'refs/heads/main' + uses: ./.github/workflows/build-rust-steps.yml + with: + platform: 'macos' + run-integration-tests: true secrets: inherit \ No newline at end of file diff --git a/samples/rust/Cargo.toml b/samples/rust/Cargo.toml index 97a5f824..bdc9ee44 100644 --- a/samples/rust/Cargo.toml +++ b/samples/rust/Cargo.toml @@ -1,5 +1,8 @@ [workspace] members = [ - "hello-foundry-local" + "foundry-local-webserver", + "tool-calling-foundry-local", + "native-chat-completions", + "audio-transcription-example", ] resolver = "2" diff --git a/samples/rust/README.md b/samples/rust/README.md index 9abb172d..3e824369 100644 --- a/samples/rust/README.md +++ b/samples/rust/README.md @@ -9,10 +9,18 @@ This directory contains samples demonstrating how to use the Foundry Local Rust ## Samples -### [Hello Foundry Local](./hello-foundry-local) +### [Foundry Local Web Server](./foundry-local-webserver) -A simple example that demonstrates how to: -- Start the Foundry Local service -- Download and load a model -- Send a prompt to the model using the OpenAI-compatible API -- Display the response from the model \ No newline at end of file +Demonstrates how to start a local OpenAI-compatible web server using the SDK, then call it with a standard HTTP client. + +### [Native Chat Completions](./native-chat-completions) + +Shows both non-streaming and streaming chat completions using the SDK's native chat client. + +### [Tool Calling with Foundry Local](./tool-calling-foundry-local) + +Demonstrates tool calling with streaming responses, multi-turn conversation, and local tool execution. + +### [Audio Transcription](./audio-transcription-example) + +Demonstrates audio transcription (non-streaming and streaming) using the `whisper` model. \ No newline at end of file diff --git a/samples/rust/audio-transcription-example/Cargo.toml b/samples/rust/audio-transcription-example/Cargo.toml new file mode 100644 index 00000000..2fa535b3 --- /dev/null +++ b/samples/rust/audio-transcription-example/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "audio-transcription-example" +version = "0.1.0" +edition = "2021" +description = "Audio transcription example using the Foundry Local Rust SDK" + +[dependencies] +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +tokio-stream = "0.1" diff --git a/samples/rust/audio-transcription-example/README.md b/samples/rust/audio-transcription-example/README.md new file mode 100644 index 00000000..240bd7df --- /dev/null +++ b/samples/rust/audio-transcription-example/README.md @@ -0,0 +1,25 @@ +# Sample: Audio Transcription + +This example demonstrates audio transcription (non-streaming and streaming) using the Foundry Local Rust SDK. It uses the `whisper` model to transcribe a WAV audio file. + +The `foundry-local-sdk` dependency is referenced via a local path. No crates.io publish is required: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +``` + +Run the application with a path to a WAV file: + +```bash +cargo run -- path/to/audio.wav +``` + +## Using WinML (Windows only) + +To use the WinML backend, enable the `winml` feature in `Cargo.toml`: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust", features = ["winml"] } +``` + +No code changes are needed — same API, different backend. diff --git a/samples/rust/audio-transcription-example/src/main.rs b/samples/rust/audio-transcription-example/src/main.rs new file mode 100644 index 00000000..4a602196 --- /dev/null +++ b/samples/rust/audio-transcription-example/src/main.rs @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use std::env; +use std::io::{self, Write}; + +use foundry_local_sdk::{FoundryLocalConfig, FoundryLocalManager}; +use tokio_stream::StreamExt; + +const ALIAS: &str = "whisper-tiny"; + +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("Audio Transcription Example"); + println!("===========================\n"); + + // Accept an audio file path as a CLI argument. + let audio_path = env::args().nth(1).unwrap_or_else(|| { + eprintln!("Usage: cargo run -- "); + std::process::exit(1); + }); + + // ── 1. Initialise the manager ──────────────────────────────────────── + let manager = FoundryLocalManager::create(FoundryLocalConfig { + app_name: "foundry_local_samples".into(), + ..Default::default() + })?; + + // ── 2. Pick the whisper model and ensure it is downloaded ──────────── + let model = manager.catalog().get_model(ALIAS).await?; + println!("Model: {} (id: {})", model.alias(), model.id()); + + if !model.is_cached().await? { + println!("Downloading model..."); + model + .download(Some(|progress: &str| { + print!("\r {progress:.1}%"); + io::stdout().flush().ok(); + })) + .await?; + println!(); + } + + println!("Loading model..."); + model.load().await?; + println!("✓ Model loaded\n"); + + // ── 3. Create an audio client ──────────────────────────────────────── + let audio_client = model.create_audio_client(); + + // ── 4. Non-streaming transcription ─────────────────────────────────── + println!("--- Non-streaming transcription ---"); + let result = audio_client.transcribe(&audio_path).await?; + println!("Transcription: {}", result.text); + + // ── 5. Streaming transcription ─────────────────────────────────────── + println!("--- Streaming transcription ---"); + print!("Transcription: "); + let mut stream = audio_client.transcribe_streaming(&audio_path).await?; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + print!("{}", chunk.text); + io::stdout().flush().ok(); + } + stream.close().await?; + println!("\n"); + + // ── 6. Unload the model ────────────────────────────────────────────── + println!("Unloading model..."); + model.unload().await?; + println!("Done."); + + Ok(()) +} diff --git a/samples/rust/foundry-local-webserver/Cargo.toml b/samples/rust/foundry-local-webserver/Cargo.toml new file mode 100644 index 00000000..4a20c1f8 --- /dev/null +++ b/samples/rust/foundry-local-webserver/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "foundry-local-webserver" +version = "0.1.0" +edition = "2021" +description = "Example of using the Foundry Local SDK with a local OpenAI-compatible web server" + +[dependencies] +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +serde_json = "1" +reqwest = { version = "0.12", features = ["json"] } diff --git a/samples/rust/foundry-local-webserver/README.md b/samples/rust/foundry-local-webserver/README.md new file mode 100644 index 00000000..f034f2a0 --- /dev/null +++ b/samples/rust/foundry-local-webserver/README.md @@ -0,0 +1,25 @@ +# Sample: Foundry Local Web Server + +This example demonstrates how to start a local OpenAI-compatible web server using the Foundry Local SDK, then call it with a standard HTTP client. This is useful when you want to use the OpenAI REST API directly or integrate with tools that expect an OpenAI-compatible endpoint. + +The `foundry-local-sdk` dependency is referenced via a local path. No crates.io publish is required: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +``` + +Run the application: + +```bash +cargo run +``` + +## Using WinML (Windows only) + +To use the WinML backend, enable the `winml` feature in `Cargo.toml`: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust", features = ["winml"] } +``` + +No code changes are needed — same API, different backend. diff --git a/samples/rust/foundry-local-webserver/src/main.rs b/samples/rust/foundry-local-webserver/src/main.rs new file mode 100644 index 00000000..3c9dbb75 --- /dev/null +++ b/samples/rust/foundry-local-webserver/src/main.rs @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Foundry Local Web Server example. +//! +//! Demonstrates how to start a local OpenAI-compatible web server using the +//! Foundry Local SDK, then call it with a standard HTTP client. This is useful +//! when you want to use the OpenAI REST API directly or integrate with tools +//! that expect an OpenAI-compatible endpoint. + +use std::io::{self, Write}; + +use serde_json::json; + +use foundry_local_sdk::{FoundryLocalConfig, FoundryLocalManager}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // ── 1. Initialise the SDK ──────────────────────────────────────────── + println!("Initializing Foundry Local SDK..."); + let manager = FoundryLocalManager::create(FoundryLocalConfig { + app_name: "foundry_local_samples".into(), + ..Default::default() + })?; + println!("✓ SDK initialized"); + + // ── 2. Download and load a model ───────────────────────────────────── + let model_alias = "qwen2.5-0.5b"; + let model = manager.catalog().get_model(model_alias).await?; + + if !model.is_cached().await? { + print!("Downloading model {model_alias}..."); + model + .download(Some(move |progress: &str| { + print!("\rDownloading model... {progress:.1}%"); + io::stdout().flush().ok(); + })) + .await?; + println!(); + } + + print!("Loading model {model_alias}..."); + model.load().await?; + println!("done."); + + // ── 3. Start the web service ───────────────────────────────────────── + print!("Starting web service..."); + let urls = manager.start_web_service().await?; + println!("done."); + + let endpoint = urls + .first() + .expect("Web service did not return an endpoint"); + println!("Web service listening on: {endpoint}"); + + // ── 4. Use the OpenAI-compatible REST API with streaming ──────────── + // Any HTTP client (or OpenAI SDK) can now talk to this endpoint. + let client = reqwest::Client::new(); + let base_url = endpoint.trim_end_matches('/'); + + let mut response = client + .post(format!("{base_url}/v1/chat/completions")) + .json(&json!({ + "model": model.id(), + "messages": [ + { "role": "user", "content": "Why is the sky blue?" } + ], + "stream": true + })) + .send() + .await?; + + print!("[ASSISTANT]: "); + while let Some(chunk) = response.chunk().await? { + let text = String::from_utf8_lossy(&chunk); + for line in text.lines() { + let line = line.trim(); + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + break; + } + if let Ok(parsed) = serde_json::from_str::(data) { + if let Some(content) = parsed + .pointer("/choices/0/delta/content") + .and_then(|v| v.as_str()) + { + print!("{content}"); + io::stdout().flush().ok(); + } + } + } + } + } + println!(); + + // ── 5. Clean up ────────────────────────────────────────────────────── + println!("\nStopping web service..."); + manager.stop_web_service().await?; + + println!("Unloading model..."); + model.unload().await?; + + println!("✓ Done."); + Ok(()) +} diff --git a/samples/rust/hello-foundry-local/Cargo.toml b/samples/rust/hello-foundry-local/Cargo.toml deleted file mode 100644 index 2d0d58f1..00000000 --- a/samples/rust/hello-foundry-local/Cargo.toml +++ /dev/null @@ -1,13 +0,0 @@ -[package] -name = "hello-foundry-local" -version = "0.1.0" -edition = "2021" -description = "A simple example of using the Foundry Local Rust SDK" - -[dependencies] -foundry-local = { path = "../../../sdk/rust" } -tokio = { version = "1", features = ["full"] } -anyhow = "1.0" -reqwest = { version = "0.11", features = ["json"] } -serde_json = "1.0" -env_logger = "0.10" \ No newline at end of file diff --git a/samples/rust/hello-foundry-local/README.md b/samples/rust/hello-foundry-local/README.md deleted file mode 100644 index 49d27d48..00000000 --- a/samples/rust/hello-foundry-local/README.md +++ /dev/null @@ -1,30 +0,0 @@ -# Hello Foundry Local (Rust) - -A simple example that demonstrates using the Foundry Local Rust SDK to interact with AI models locally. - -## Prerequisites - -- Rust 1.70.0 or later -- Foundry Local installed and available on PATH - -## Running the Sample - -1. Make sure Foundry Local is installed -2. Run the sample: - -```bash -cargo run -``` - -## What This Sample Does - -1. Creates a FoundryLocalManager instance -2. Starts the Foundry Local service if it's not already running -3. Downloads and loads the phi-3-mini-4k model -4. Sends a prompt to the model using the OpenAI-compatible API -5. Displays the response from the model - -## Code Structure - -- `src/main.rs` - The main application code -- `Cargo.toml` - Project configuration and dependencies \ No newline at end of file diff --git a/samples/rust/hello-foundry-local/src/main.rs b/samples/rust/hello-foundry-local/src/main.rs deleted file mode 100644 index 7320a2db..00000000 --- a/samples/rust/hello-foundry-local/src/main.rs +++ /dev/null @@ -1,85 +0,0 @@ -use anyhow::Result; -use foundry_local::FoundryLocalManager; - -#[tokio::main] -async fn main() -> Result<()> { - // Set up logging - env_logger::init_from_env( - env_logger::Env::default().filter_or(env_logger::DEFAULT_FILTER_ENV, "info"), - ); - - println!("Hello Foundry Local!"); - println!("==================="); - - // For this example, we will use the "phi-3-mini-4k" model which is 2.181 GB in size. - let model_to_use: &str = "phi-3-mini-4k"; - - // Create a FoundryLocalManager instance using the builder pattern - println!("\nInitializing Foundry Local manager..."); - let mut manager = FoundryLocalManager::builder() - // Alternatively to the checks below, you can specify the model to use directly during bootstrapping - // .alias_or_model_id(model_to_use) - .bootstrap(true) // Start the service if not running - .build() - .await?; - - // List all the models in the catalog - println!("\nAvailable models in catalog:"); - let models = manager.list_catalog_models().await?; - let model_in_catalog = models.iter().any(|m| m.alias == model_to_use); - for model in models { - println!("- {model}"); - } - // Check if the model is in the catalog - if !model_in_catalog { - println!("Model '{model_to_use}' not found in catalog. Exiting."); - return Ok(()); - } - - // List available models in the local cache - println!("\nAvailable models in local cache:"); - let models = manager.list_cached_models().await?; - let model_in_cache = models.iter().any(|m| m.alias == model_to_use); - for model in models { - println!("- {model}"); - } - - // Check if the model is already cached and download if not - if !model_in_cache { - println!("Model '{model_to_use}' not found in local cache. Downloading..."); - // Download the model if not in cache - // NOTE if you've bootstrapped with `alias_or_model_id`, you can use that directly and skip this check - manager.download_model(model_to_use, None, false).await?; - println!("Model '{model_to_use}' downloaded successfully."); - } - - // Get the model information - let model_info = manager.get_model_info(model_to_use, true).await?; - println!("\nUsing model: {model_info}"); - - // Build the prompt - let prompt = "What is the golden ratio?"; - println!("\nPrompt: {prompt}"); - - // Use the OpenAI compatible API to interact with the model - let client = reqwest::Client::new(); - let response = client - .post(format!("{}/chat/completions", manager.endpoint()?)) - .json(&serde_json::json!({ - "model": model_info.id, - "messages": [{"role": "user", "content": prompt}], - })) - .send() - .await?; - - // Parse and display the response - let result = response.json::().await?; - if let Some(content) = result["choices"][0]["message"]["content"].as_str() { - println!("\nResponse:\n{content}"); - } else { - println!("\nError: Failed to extract response content from API result"); - println!("Full API response: {result}"); - } - - Ok(()) -} diff --git a/samples/rust/native-chat-completions/Cargo.toml b/samples/rust/native-chat-completions/Cargo.toml new file mode 100644 index 00000000..183b99dd --- /dev/null +++ b/samples/rust/native-chat-completions/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "native-chat-completions" +version = "0.1.0" +edition = "2021" +description = "Native SDK chat completions (non-streaming and streaming) using the Foundry Local Rust SDK" + +[dependencies] +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +tokio-stream = "0.1" diff --git a/samples/rust/native-chat-completions/README.md b/samples/rust/native-chat-completions/README.md new file mode 100644 index 00000000..7645f642 --- /dev/null +++ b/samples/rust/native-chat-completions/README.md @@ -0,0 +1,25 @@ +# Sample: Native Chat Completions + +This example demonstrates both non-streaming and streaming chat completions using the Foundry Local Rust SDK's native chat client — no external HTTP libraries needed. + +The `foundry-local-sdk` dependency is referenced via a local path. No crates.io publish is required: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +``` + +Run the application: + +```bash +cargo run +``` + +## Using WinML (Windows only) + +To use the WinML backend, enable the `winml` feature in `Cargo.toml`: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust", features = ["winml"] } +``` + +No code changes are needed — same API, different backend. diff --git a/samples/rust/native-chat-completions/src/main.rs b/samples/rust/native-chat-completions/src/main.rs new file mode 100644 index 00000000..904c3934 --- /dev/null +++ b/samples/rust/native-chat-completions/src/main.rs @@ -0,0 +1,92 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use std::io::{self, Write}; + +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, FoundryLocalConfig, FoundryLocalManager, +}; +use tokio_stream::StreamExt; + +const ALIAS: &str = "qwen2.5-0.5b"; + +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("Native Chat Completions"); + println!("=======================\n"); + + // ── 1. Initialise the manager ──────────────────────────────────────── + let manager = FoundryLocalManager::create(FoundryLocalConfig { + app_name: "foundry_local_samples".into(), + ..Default::default() + })?; + + // ── 2. Pick a model and ensure it is downloaded ────────────────────── + let model = manager.catalog().get_model(ALIAS).await?; + println!("Model: {} (id: {})", model.alias(), model.id()); + + if !model.is_cached().await? { + println!("Downloading model..."); + model + .download(Some(|progress: &str| { + print!("\r {progress:.1}%"); + io::stdout().flush().ok(); + })) + .await?; + println!(); + } + + println!("Loading model..."); + model.load().await?; + println!("✓ Model loaded\n"); + + // ── 3. Create a chat client ────────────────────────────────────────── + let mut client = model.create_chat_client(); + client.temperature(0.7).max_tokens(256); + + // ── 4. Non-streaming chat completion ───────────────────────────────── + let messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("What is Rust's ownership model?").into(), + ]; + + println!("--- Non-streaming completion ---"); + let response = client.complete_chat(&messages, None).await?; + if let Some(choice) = response.choices.first() { + if let Some(ref content) = choice.message.content { + println!("Assistant: {content}"); + } + } + + // ── 5. Streaming chat completion ───────────────────────────────────── + let stream_messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("Explain the borrow checker in two sentences.") + .into(), + ]; + + println!("\n--- Streaming completion ---"); + print!("Assistant: "); + let mut stream = client + .complete_streaming_chat(&stream_messages, None) + .await?; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + print!("{content}"); + io::stdout().flush().ok(); + } + } + } + stream.close().await?; + println!("\n"); + + // ── 6. Unload the model ────────────────────────────────────────────── + println!("Unloading model..."); + model.unload().await?; + println!("Done."); + + Ok(()) +} diff --git a/samples/rust/tool-calling-foundry-local/Cargo.toml b/samples/rust/tool-calling-foundry-local/Cargo.toml new file mode 100644 index 00000000..73e59316 --- /dev/null +++ b/samples/rust/tool-calling-foundry-local/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "tool-calling-foundry-local" +version = "0.1.0" +edition = "2021" +description = "Tool calling example using the Foundry Local Rust SDK" + +[dependencies] +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +tokio-stream = "0.1" +serde_json = "1" diff --git a/samples/rust/tool-calling-foundry-local/README.md b/samples/rust/tool-calling-foundry-local/README.md new file mode 100644 index 00000000..70534c82 --- /dev/null +++ b/samples/rust/tool-calling-foundry-local/README.md @@ -0,0 +1,25 @@ +# Sample: Tool Calling with Foundry Local + +This is a simple example of how to use the Foundry Local Rust SDK to run a model locally and perform tool calling with it. The example demonstrates how to set up the SDK, initialize a model, and perform a generated tool call. + +The `foundry-local-sdk` dependency is referenced via a local path. No crates.io publish is required: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +``` + +Run the application: + +```bash +cargo run +``` + +## Using WinML (Windows only) + +To use the WinML backend, enable the `winml` feature in `Cargo.toml`: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust", features = ["winml"] } +``` + +No code changes are needed — same API, different backend. diff --git a/samples/rust/tool-calling-foundry-local/src/main.rs b/samples/rust/tool-calling-foundry-local/src/main.rs new file mode 100644 index 00000000..919fe33e --- /dev/null +++ b/samples/rust/tool-calling-foundry-local/src/main.rs @@ -0,0 +1,226 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use std::io::{self, Write}; + +use serde_json::{json, Value}; +use tokio_stream::StreamExt; + +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestToolMessage, ChatCompletionRequestUserMessage, ChatCompletionTools, + ChatToolChoice, FinishReason, FoundryLocalConfig, FoundryLocalManager, +}; + +// By using an alias, the most suitable model variant will be downloaded +// to your end-user's device. +const ALIAS: &str = "qwen2.5-0.5b"; + +/// A simple tool that multiplies two numbers. +fn multiply_numbers(first: f64, second: f64) -> f64 { + first * second +} + +/// Dispatch a tool call by name and parsed arguments. +fn invoke_tool(name: &str, args: &Value) -> String { + match name { + "multiply_numbers" => { + let first = args.get("first").and_then(|v| v.as_f64()).unwrap_or(0.0); + let second = args.get("second").and_then(|v| v.as_f64()).unwrap_or(0.0); + let result = multiply_numbers(first, second); + result.to_string() + } + _ => format!("Unknown tool: {name}"), + } +} + +/// Accumulated state from a streaming response that contains tool calls. +#[derive(Default)] +struct ToolCallState { + tool_calls: Vec, + current_tool_id: String, + current_tool_name: String, + current_tool_args: String, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("Tool Calling with Foundry Local"); + println!("===============================\n"); + + // ── 1. Initialise the manager ──────────────────────────────────────── + let manager = FoundryLocalManager::create(FoundryLocalConfig { + app_name: "foundry_local_samples".into(), + ..Default::default() + })?; + + // ── 2. Load a model ────────────────────────────────────────────────── + let model = manager.catalog().get_model(ALIAS).await?; + println!("Model: {} (id: {})", model.alias(), model.id()); + + if !model.is_cached().await? { + println!("Downloading model..."); + model + .download(Some(|progress: &str| { + print!("\r {progress:.1}%"); + io::stdout().flush().ok(); + })) + .await?; + println!(); + } + + println!("Loading model..."); + model.load().await?; + println!("✓ Model loaded\n"); + + // ── 3. Create a chat client with tool_choice = required ────────────── + let mut client = model.create_chat_client(); + client + .max_tokens(512) + .tool_choice(ChatToolChoice::Required); + + // Define the multiply_numbers tool. + let tools: Vec = serde_json::from_value(json!([{ + "type": "function", + "function": { + "name": "multiply_numbers", + "description": "A tool for multiplying two numbers.", + "parameters": { + "type": "object", + "properties": { + "first": { + "type": "integer", + "description": "The first number in the operation" + }, + "second": { + "type": "integer", + "description": "The second number in the operation" + } + }, + "required": ["first", "second"] + } + } + }]))?; + + // Prepare the initial conversation. + let mut messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from( + "You are a helpful AI assistant. If necessary, you can use any provided tools to answer the question.", + ) + .into(), + ChatCompletionRequestUserMessage::from("What is the answer to 7 multiplied by 6?").into(), + ]; + + // ── 4. First streaming call – expect tool_calls ────────────────────── + println!("Chat completion response:"); + + let mut state = ToolCallState::default(); + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await?; + + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + // Accumulate streamed content (if any). + if let Some(ref content) = choice.delta.content { + print!("{content}"); + io::stdout().flush().ok(); + } + + // Accumulate tool call fragments. + if let Some(ref tool_calls) = choice.delta.tool_calls { + for tc in tool_calls { + if let Some(ref id) = tc.id { + state.current_tool_id = id.clone(); + } + if let Some(ref func) = tc.function { + if let Some(ref name) = func.name { + state.current_tool_name = name.clone(); + } + if let Some(ref args) = func.arguments { + state.current_tool_args.push_str(args); + } + } + } + } + + // When the model signals finish_reason = ToolCalls, finalise. + if choice.finish_reason == Some(FinishReason::ToolCalls) { + let tc = json!({ + "id": state.current_tool_id.clone(), + "type": "function", + "function": { + "name": state.current_tool_name.clone(), + "arguments": state.current_tool_args.clone(), + } + }); + state.tool_calls.push(tc); + } + } + } + stream.close().await?; + println!(); + + // ── 5. Execute the tool(s) and append results ──────────────────────── + for tc in &state.tool_calls { + let func = &tc["function"]; + let name = func["name"].as_str().unwrap_or_default(); + let args_str = func["arguments"].as_str().unwrap_or("{}"); + let args: Value = serde_json::from_str(args_str).unwrap_or(json!({})); + + println!("\nInvoking tool: {name} with arguments {args}"); + let result = invoke_tool(name, &args); + println!("Tool response: {result}"); + + // Append the assistant's tool_calls message and the tool result. + let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ + "role": "assistant", + "content": null, + "tool_calls": [tc], + }))?; + messages.push(assistant_msg); + messages.push( + ChatCompletionRequestToolMessage { + content: result.into(), + tool_call_id: tc["id"].as_str().unwrap_or_default().to_string(), + } + .into(), + ); + } + + // ── 6. Continue the conversation with auto tool_choice ─────────────── + println!("\nTool calls completed. Prompting model to continue conversation...\n"); + + messages.push( + ChatCompletionRequestSystemMessage::from( + "Respond only with the answer generated by the tool.", + ) + .into(), + ); + + client.tool_choice(ChatToolChoice::Auto); + + print!("Chat completion response: "); + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await?; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + print!("{content}"); + io::stdout().flush().ok(); + } + } + } + stream.close().await?; + println!("\n"); + + // ── 7. Clean up ────────────────────────────────────────────────────── + println!("Unloading model..."); + model.unload().await?; + println!("Done."); + + Ok(()) +} diff --git a/sdk_v2/rust/.clippy.toml b/sdk_v2/rust/.clippy.toml new file mode 100644 index 00000000..1d42f2f1 --- /dev/null +++ b/sdk_v2/rust/.clippy.toml @@ -0,0 +1,2 @@ +# Clippy configuration for Foundry Local Rust SDK +msrv = "1.70" diff --git a/sdk_v2/rust/.rustfmt.toml b/sdk_v2/rust/.rustfmt.toml new file mode 100644 index 00000000..dce363ed --- /dev/null +++ b/sdk_v2/rust/.rustfmt.toml @@ -0,0 +1,3 @@ +edition = "2021" +max_width = 100 +use_field_init_shorthand = true diff --git a/sdk_v2/rust/Cargo.toml b/sdk_v2/rust/Cargo.toml new file mode 100644 index 00000000..fc035329 --- /dev/null +++ b/sdk_v2/rust/Cargo.toml @@ -0,0 +1,44 @@ +[package] +name = "foundry-local-sdk" +version = "0.1.0" +edition = "2021" +description = "Foundry Local Rust SDK - Local AI model inference" +license = "MIT" +readme = "README.md" + +[features] +default = [] +winml = [] +nightly = [] + +[dependencies] +libloading = "0.8" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +thiserror = "2" +tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync"] } +tokio-stream = "0.1" +futures-core = "0.3" +reqwest = { version = "0.12", features = ["json"] } +async-openai = { version = "0.33", default-features = false, features = ["chat-completion-types"] } + +[build-dependencies] +ureq = "3" +zip = "2" +serde_json = "1" +serde = { version = "1", features = ["derive"] } + +[[example]] +name = "chat_completion" +path = "examples/chat_completion.rs" + +[[example]] +name = "tool_calling" +path = "examples/tool_calling.rs" + +[[example]] +name = "interactive_chat" +path = "examples/interactive_chat.rs" + +[lints.clippy] +all = { level = "warn", priority = -1 } diff --git a/sdk_v2/rust/GENERATE-DOCS.md b/sdk_v2/rust/GENERATE-DOCS.md new file mode 100644 index 00000000..00ba2e0f --- /dev/null +++ b/sdk_v2/rust/GENERATE-DOCS.md @@ -0,0 +1,41 @@ +# Generating API Reference Docs + +The Rust SDK uses `cargo doc` to generate API documentation from `///` doc comments in the source code. + +## Viewing Docs Locally + +To generate and open the API docs in your browser: + +```bash +cd sdk_v2/rust +cargo doc --no-deps --open +``` + +This generates HTML documentation at `target/doc/foundry_local_sdk/index.html`. + +## Public API Surface + +The SDK re-exports all public types from the crate root. Key modules: + +| Module / Type | Description | +|---|---| +| `FoundryLocalManager` | Singleton entry point — SDK initialisation, web service lifecycle | +| `FoundryLocalConfig` | Configuration (app name, log level, service endpoint) | +| `Catalog` | Model discovery and lookup | +| `Model` | Grouped model (alias → best variant) | +| `ModelVariant` | Single variant — download, load, unload | +| `ChatClient` | OpenAI-compatible chat completions (sync + streaming) | +| `AudioClient` | OpenAI-compatible audio transcription (sync + streaming) | +| `CreateChatCompletionResponse` | Typed chat completion response (from `async-openai`) | +| `CreateChatCompletionStreamResponse` | Typed streaming chat chunk (from `async-openai`) | +| `AudioTranscriptionResponse` | Typed audio transcription response | +| `FoundryLocalError` | Error enum with variants for all failure modes | + +## Notes + +- Unlike the C# and JS SDKs which commit generated markdown docs, Rust's convention is to generate HTML docs on demand with `cargo doc`. +- Once the crate is published to crates.io, docs will be automatically hosted at [docs.rs](https://docs.rs). +- Use `--document-private-items` to include internal/private API in the generated docs: + ```bash + cargo doc --no-deps --document-private-items --open + ``` diff --git a/sdk_v2/rust/README.md b/sdk_v2/rust/README.md new file mode 100644 index 00000000..bf8fa508 --- /dev/null +++ b/sdk_v2/rust/README.md @@ -0,0 +1,113 @@ +# Foundry Local Rust SDK + +Rust bindings for [Foundry Local](https://github.com/microsoft/Foundry-Local) — run AI models locally with a simple API. + +The SDK dynamically loads the Foundry Local Core native engine and exposes a safe Rust interface for model management, chat completions, and audio processing. + +## Prerequisites + +- **Rust** 1.70+ (stable toolchain) +- An internet connection during first build (to download native libraries) + +## Installation + +```sh +cargo add foundry-local-sdk +``` + +Or add to your `Cargo.toml`: + +```toml +[dependencies] +foundry-local-sdk = "0.1" +``` + +## Feature Flags + +| Feature | Description | +|-----------|-------------| +| `winml` | Use the WinML backend (Windows only). Selects different ONNX Runtime and GenAI packages. | +| `nightly` | Resolve the latest nightly build of the Core package from the ORT-Nightly feed. | + +Enable features in `Cargo.toml`: + +```toml +[dependencies] +foundry-local-sdk = { version = "0.1", features = ["winml"] } +``` + +## Quick Start + +```rust +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, FoundryLocalConfig, FoundryLocalManager, +}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Initialize the manager — loads native libraries and starts the engine + let manager = FoundryLocalManager::create(FoundryLocalConfig { + app_name: "my_app".into(), + ..Default::default() + })?; + + // List available models + let models = manager.catalog().get_models().await?; + for model in &models { + println!("{} (id: {})", model.alias(), model.id()); + } + + // Pick a model and ensure it is loaded + let model = manager.catalog().get_model("phi-3.5-mini").await?; + model.load().await?; + + // Create a chat client and send a completion request + let mut client = model.create_chat_client(); + client.temperature(0.7).max_tokens(256); + + let messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("What is the capital of France?").into(), + ]; + + let response = client.complete_chat(&messages, None).await?; + if let Some(choice) = response.choices.first() { + if let Some(ref content) = choice.message.content { + println!("{content}"); + } + } + + Ok(()) +} +``` + +## How It Works + +### Native Library Download + +The `build.rs` build script automatically downloads the required native libraries at compile time: + +1. Queries NuGet/ORT-Nightly feeds for package metadata +2. Downloads `.nupkg` packages (zip archives) +3. Extracts platform-specific native libraries (`.dll`, `.so`, or `.dylib`) +4. Places them in Cargo's `OUT_DIR` for runtime discovery + +Downloaded libraries are cached — subsequent builds skip the download step. + +### Runtime Loading + +At runtime, the SDK uses `libloading` to dynamically load the Foundry Local Core library and resolve function pointers. No static linking or system-wide installation is required. + +## Platform Support + +| Platform | RID | Status | +|-----------------|------------|--------| +| Windows x64 | `win-x64` | ✅ | +| Windows ARM64 | `win-arm64`| ✅ | +| Linux x64 | `linux-x64`| ✅ | +| macOS ARM64 | `osx-arm64`| ✅ | + +## License + +MIT — see [LICENSE](../../LICENSE) for details. diff --git a/sdk_v2/rust/build.rs b/sdk_v2/rust/build.rs new file mode 100644 index 00000000..7e4db4ca --- /dev/null +++ b/sdk_v2/rust/build.rs @@ -0,0 +1,295 @@ +use std::env; +use std::fs; +use std::io::{self, Read}; +use std::path::{Path, PathBuf}; + +const NUGET_FEED: &str = "https://api.nuget.org/v3/index.json"; +const ORT_NIGHTLY_FEED: &str = "https://pkgs.dev.azure.com/aiinfra/PublicPackages/_packaging/ORT-Nightly/nuget/v3/index.json"; + +const CORE_VERSION: &str = "0.9.0.8-rc3"; +const ORT_VERSION: &str = "1.24.3"; +const GENAI_VERSION: &str = "0.12.2"; + +const WINML_ORT_VERSION: &str = "1.23.2.3"; + +struct NuGetPackage { + name: &'static str, + version: String, + feed_url: &'static str, +} + +fn get_rid() -> Option<&'static str> { + let os = env::consts::OS; + let arch = env::consts::ARCH; + match (os, arch) { + ("windows", "x86_64") => Some("win-x64"), + ("windows", "aarch64") => Some("win-arm64"), + ("linux", "x86_64") => Some("linux-x64"), + ("macos", "aarch64") => Some("osx-arm64"), + _ => None, + } +} + +fn native_lib_extension() -> &'static str { + match env::consts::OS { + "windows" => "dll", + "linux" => "so", + "macos" => "dylib", + _ => "so", + } +} + +fn get_packages(rid: &str) -> Vec { + let winml = env::var("CARGO_FEATURE_WINML").is_ok(); + let nightly = env::var("CARGO_FEATURE_NIGHTLY").is_ok(); + let is_linux = rid.starts_with("linux"); + + let core_version = if nightly { + resolve_latest_version("Microsoft.AI.Foundry.Local.Core", ORT_NIGHTLY_FEED) + .unwrap_or_else(|| CORE_VERSION.to_string()) + } else { + CORE_VERSION.to_string() + }; + + let mut packages = Vec::new(); + + if winml { + let winml_core_version = if nightly { + resolve_latest_version("Microsoft.AI.Foundry.Local.Core.WinML", ORT_NIGHTLY_FEED) + .unwrap_or_else(|| CORE_VERSION.to_string()) + } else { + CORE_VERSION.to_string() + }; + + packages.push(NuGetPackage { + name: "Microsoft.AI.Foundry.Local.Core.WinML", + version: winml_core_version, + feed_url: ORT_NIGHTLY_FEED, + }); + packages.push(NuGetPackage { + name: "Microsoft.ML.OnnxRuntime.Foundry", + version: WINML_ORT_VERSION.to_string(), + feed_url: NUGET_FEED, + }); + packages.push(NuGetPackage { + name: "Microsoft.ML.OnnxRuntimeGenAI.WinML", + version: GENAI_VERSION.to_string(), + feed_url: NUGET_FEED, + }); + } else { + packages.push(NuGetPackage { + name: "Microsoft.AI.Foundry.Local.Core", + version: core_version, + feed_url: ORT_NIGHTLY_FEED, + }); + + if is_linux { + packages.push(NuGetPackage { + name: "Microsoft.ML.OnnxRuntime.Gpu.Linux", + version: ORT_VERSION.to_string(), + feed_url: NUGET_FEED, + }); + } else { + packages.push(NuGetPackage { + name: "Microsoft.ML.OnnxRuntime.Foundry", + version: ORT_VERSION.to_string(), + feed_url: NUGET_FEED, + }); + } + + packages.push(NuGetPackage { + name: "Microsoft.ML.OnnxRuntimeGenAI.Foundry", + version: GENAI_VERSION.to_string(), + feed_url: NUGET_FEED, + }); + } + + packages +} + +/// Resolve the PackageBaseAddress from a NuGet v3 service index. +fn resolve_base_address(feed_url: &str) -> Result { + let body: String = ureq::get(feed_url) + .call() + .map_err(|e| format!("Failed to fetch NuGet feed index at {feed_url}: {e}"))? + .body_mut() + .read_to_string() + .map_err(|e| format!("Failed to read feed index response: {e}"))?; + + let index: serde_json::Value = serde_json::from_str(&body) + .map_err(|e| format!("Failed to parse feed index JSON: {e}"))?; + + let resources = index["resources"] + .as_array() + .ok_or("Feed index missing 'resources' array")?; + + for resource in resources { + let rtype = resource["@type"].as_str().unwrap_or(""); + if rtype == "PackageBaseAddress/3.0.0" { + if let Some(id) = resource["@id"].as_str() { + let base = if id.ends_with('/') { + id.to_string() + } else { + format!("{id}/") + }; + return Ok(base); + } + } + } + + Err(format!( + "Could not find PackageBaseAddress/3.0.0 in feed {feed_url}" + )) +} + +/// Resolve the latest version of a package from a NuGet feed. +fn resolve_latest_version(package_name: &str, feed_url: &str) -> Option { + let base_address = resolve_base_address(feed_url).ok()?; + let lower_name = package_name.to_lowercase(); + let index_url = format!("{base_address}{lower_name}/index.json"); + + let body: String = ureq::get(&index_url) + .call() + .ok()? + .body_mut() + .read_to_string() + .ok()?; + + let index: serde_json::Value = serde_json::from_str(&body).ok()?; + let versions = index["versions"].as_array()?; + versions.last()?.as_str().map(|s| s.to_string()) +} + +/// Download a .nupkg and extract native libraries for the given RID into `out_dir`. +fn download_and_extract( + pkg: &NuGetPackage, + rid: &str, + out_dir: &Path, +) -> Result<(), String> { + let base_address = resolve_base_address(pkg.feed_url)?; + let lower_name = pkg.name.to_lowercase(); + let lower_version = pkg.version.to_lowercase(); + let url = format!( + "{base_address}{lower_name}/{lower_version}/{lower_name}.{lower_version}.nupkg" + ); + + println!("cargo:warning=Downloading {name} {ver} from {feed}", + name = pkg.name, + ver = pkg.version, + feed = if pkg.feed_url == NUGET_FEED { "NuGet.org" } else { "ORT-Nightly" }, + ); + + let mut response = ureq::get(&url) + .call() + .map_err(|e| format!("Failed to download {}: {e}", pkg.name))?; + + let mut bytes = Vec::new(); + response + .body_mut() + .as_reader() + .read_to_end(&mut bytes) + .map_err(|e| format!("Failed to read response body for {}: {e}", pkg.name))?; + + let ext = native_lib_extension(); + let prefix = format!("runtimes/{rid}/native/"); + + let cursor = io::Cursor::new(&bytes); + let mut archive = zip::ZipArchive::new(cursor) + .map_err(|e| format!("Failed to open nupkg as zip for {}: {e}", pkg.name))?; + + let mut extracted = 0usize; + for i in 0..archive.len() { + let mut file = archive + .by_index(i) + .map_err(|e| format!("Failed to read zip entry: {e}"))?; + + let name = file.name().to_string(); + if !name.starts_with(&prefix) { + continue; + } + if !name.ends_with(&format!(".{ext}")) { + continue; + } + + let file_name = Path::new(&name) + .file_name() + .map(|n| n.to_string_lossy().to_string()) + .unwrap_or_default(); + + if file_name.is_empty() { + continue; + } + + let dest = out_dir.join(&file_name); + let mut out_file = fs::File::create(&dest) + .map_err(|e| format!("Failed to create {}: {e}", dest.display()))?; + io::copy(&mut file, &mut out_file) + .map_err(|e| format!("Failed to write {}: {e}", dest.display()))?; + + println!("cargo:warning= Extracted {file_name}"); + extracted += 1; + } + + if extracted == 0 { + println!( + "cargo:warning= No native libraries found for RID '{rid}' in {} {}", + pkg.name, pkg.version + ); + } + + Ok(()) +} + +/// Check whether we already have at least one native library in `out_dir`. +fn libs_already_present(out_dir: &Path) -> bool { + let ext = native_lib_extension(); + if let Ok(entries) = fs::read_dir(out_dir) { + for entry in entries.flatten() { + if let Some(name) = entry.file_name().to_str() { + if name.ends_with(&format!(".{ext}")) { + return true; + } + } + } + } + false +} + +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + + let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR not set")); + + let rid = match get_rid() { + Some(r) => r, + None => { + println!( + "cargo:warning=Unsupported platform: {} {}. Native libraries will not be downloaded.", + env::consts::OS, + env::consts::ARCH, + ); + return; + } + }; + + // Skip download if libraries already exist + if libs_already_present(&out_dir) { + println!("cargo:warning=Native libraries already present in OUT_DIR, skipping download."); + println!("cargo:rustc-link-search=native={}", out_dir.display()); + println!("cargo:rustc-env=FOUNDRY_NATIVE_DIR={}", out_dir.display()); + return; + } + + let packages = get_packages(rid); + + for pkg in &packages { + if let Err(e) = download_and_extract(pkg, rid, &out_dir) { + println!("cargo:warning=Error downloading {}: {e}", pkg.name); + println!("cargo:warning=Build will continue, but runtime loading may fail."); + println!("cargo:warning=You can manually place native libraries in the output directory."); + } + } + + println!("cargo:rustc-link-search=native={}", out_dir.display()); + println!("cargo:rustc-env=FOUNDRY_NATIVE_DIR={}", out_dir.display()); +} diff --git a/sdk_v2/rust/examples/chat_completion.rs b/sdk_v2/rust/examples/chat_completion.rs new file mode 100644 index 00000000..2d604417 --- /dev/null +++ b/sdk_v2/rust/examples/chat_completion.rs @@ -0,0 +1,99 @@ +//! Basic chat completion example demonstrating synchronous and streaming +//! usage of the Foundry Local SDK. + +use std::io::{self, Write}; + +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, FoundryLocalConfig, FoundryLocalError, FoundryLocalManager, +}; +use tokio_stream::StreamExt; + +/// Convenience alias matching the SDK's internal Result type. +type Result = std::result::Result; + +#[tokio::main] +async fn main() -> Result<()> { + // ── 1. Initialise the manager ──────────────────────────────────────── + let config = FoundryLocalConfig { + app_name: "foundry_local_samples".into(), + ..Default::default() + }; + + let manager = FoundryLocalManager::create(config)?; + + // ── 2. List available models ───────────────────────────────────────── + let models = manager.catalog().get_models().await?; + println!("Available models:"); + for model in &models { + println!(" • {} (id: {})", model.alias(), model.id()); + } + + // ── 3. Pick a model and ensure it is loaded ────────────────────────── + let model_alias = models + .first() + .map(|m| m.alias().to_string()) + .expect("No models available in the catalog"); + + let model = manager.catalog().get_model(&model_alias).await?; + + if !model.is_cached().await? { + println!("Downloading model '{}'…", model.alias()); + model + .download(Some(|progress: &str| { + println!(" {progress}"); + })) + .await?; + } + + println!("Loading model '{}'…", model.alias()); + model.load().await?; + + // ── 4. Synchronous chat completion ─────────────────────────────────── + let mut client = model.create_chat_client(); + client.temperature(0.7).max_tokens(256); + + let messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("What is Rust's ownership model?").into(), + ]; + + println!("\n--- Synchronous completion ---"); + let response = client.complete_chat(&messages, None).await?; + if let Some(choice) = response.choices.first() { + if let Some(ref content) = choice.message.content { + println!("Assistant: {content}"); + } + } + + // ── 5. Streaming chat completion ───────────────────────────────────── + println!("\n--- Streaming completion ---"); + let stream_messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("Explain the borrow checker in two sentences.") + .into(), + ]; + + print!("Assistant: "); + let mut stream = client + .complete_streaming_chat(&stream_messages, None) + .await?; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + print!("{content}"); + io::stdout().flush().ok(); + } + } + } + stream.close().await?; + println!(); + + // ── 6. Unload the model ────────────────────────────────────────────── + println!("\nUnloading model…"); + model.unload().await?; + println!("Done."); + + Ok(()) +} diff --git a/sdk_v2/rust/examples/interactive_chat.rs b/sdk_v2/rust/examples/interactive_chat.rs new file mode 100644 index 00000000..aabdae35 --- /dev/null +++ b/sdk_v2/rust/examples/interactive_chat.rs @@ -0,0 +1,119 @@ +//! Interactive chat example — a simple terminal chatbot powered by Foundry Local. +//! +//! Run with: `cargo run --example interactive_chat` + +use std::io::{self, Write}; + +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestAssistantMessage, + ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, + FoundryLocalConfig, FoundryLocalManager, +}; +use tokio_stream::StreamExt; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // ── Initialise ─────────────────────────────────────────────────────── + let manager = FoundryLocalManager::create(FoundryLocalConfig { + app_name: "foundry_local_samples".into(), + ..Default::default() + })?; + + // Pick the first available model (or change this to a specific alias) + let catalog = manager.catalog(); + let models = catalog.get_models().await?; + + println!("Available models:"); + for (i, m) in models.iter().enumerate() { + println!(" [{i}] {}", m.alias()); + } + + print!("\nSelect a model number (default 0): "); + io::stdout().flush()?; + let mut choice = String::new(); + io::stdin().read_line(&mut choice)?; + let idx: usize = choice.trim().parse().unwrap_or(0); + + let alias = models + .get(idx) + .map(|m| m.alias().to_string()) + .unwrap_or_else(|| models[0].alias().to_string()); + + let model = catalog.get_model(&alias).await?; + + // Download if needed + if !model.is_cached().await? { + println!("Downloading '{alias}'…"); + model + .download(Some(|p: &str| print!("\r {p}%"))) + .await?; + println!(); + } + + println!("Loading '{alias}'…"); + model.load().await?; + println!("Ready! Type your messages below. Press Ctrl-D (or type 'quit') to exit.\n"); + + // ── Chat loop ──────────────────────────────────────────────────────── + let mut client = model.create_chat_client(); + client.temperature(0.7).max_tokens(512); + + let mut messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful, concise assistant.").into(), + ]; + + loop { + print!("You: "); + io::stdout().flush()?; + + let mut input = String::new(); + if io::stdin().read_line(&mut input)? == 0 { + break; // EOF (Ctrl-D) + } + + let input = input.trim(); + if input.is_empty() { + continue; + } + if input.eq_ignore_ascii_case("quit") || input.eq_ignore_ascii_case("exit") { + break; + } + + messages.push(ChatCompletionRequestUserMessage::from(input).into()); + + // Stream the response token by token + print!("Assistant: "); + io::stdout().flush()?; + + let mut full_response = String::new(); + let mut stream = client.complete_streaming_chat(&messages, None).await?; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + print!("{content}"); + io::stdout().flush().ok(); + full_response.push_str(content); + } + } + } + stream.close().await?; + println!("\n"); + + // Add assistant reply to history for multi-turn conversation + messages.push( + ChatCompletionRequestAssistantMessage { + content: Some(full_response.into()), + ..Default::default() + } + .into(), + ); + } + + // ── Cleanup ────────────────────────────────────────────────────────── + println!("\nUnloading model…"); + model.unload().await?; + println!("Goodbye!"); + + Ok(()) +} diff --git a/sdk_v2/rust/examples/tool_calling.rs b/sdk_v2/rust/examples/tool_calling.rs new file mode 100644 index 00000000..5fd8bf51 --- /dev/null +++ b/sdk_v2/rust/examples/tool_calling.rs @@ -0,0 +1,205 @@ +//! Tool-calling example demonstrating how to define tools, handle +//! `tool_calls` in streaming responses, execute the tool locally, +//! and feed results back for a multi-turn conversation. + +use std::io::{self, Write}; + +use serde_json::{json, Value}; +use tokio_stream::StreamExt; + +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestToolMessage, ChatCompletionRequestUserMessage, ChatCompletionTools, + ChatToolChoice, FinishReason, FoundryLocalConfig, FoundryLocalError, FoundryLocalManager, +}; + +/// Convenience alias matching the SDK's internal Result type. +type Result = std::result::Result; + +/// A trivial tool that multiplies two numbers. +fn multiply(a: f64, b: f64) -> f64 { + a * b +} + +/// Dispatch a tool call by name and arguments. +fn invoke_tool(name: &str, arguments: &Value) -> Result { + match name { + "multiply" => { + let a = arguments + .get("a") + .and_then(|v| v.as_f64()) + .unwrap_or(0.0); + let b = arguments + .get("b") + .and_then(|v| v.as_f64()) + .unwrap_or(0.0); + let result = multiply(a, b); + Ok(result.to_string()) + } + _ => Ok(format!("Unknown tool: {name}")), + } +} + +#[derive(Default)] +struct ToolCallState { + tool_calls: Vec, + tool_call_args: String, + current_tool_name: String, + current_tool_id: String, +} + +#[tokio::main] +async fn main() -> Result<()> { + // ── 1. Initialise ──────────────────────────────────────────────────── + let config = FoundryLocalConfig { + app_name: "foundry_local_samples".into(), + ..Default::default() + }; + + let manager = FoundryLocalManager::create(config)?; + + // ── 2. Load a model ────────────────────────────────────────────────── + let models = manager.catalog().get_models().await?; + let model = models + .iter() + .find(|m| m.selected_variant().info().supports_tool_calling == Some(true)) + .or_else(|| models.first()) + .expect("No models available"); + + if !model.is_cached().await? { + println!("Downloading model '{}'…", model.alias()); + model + .download(Some(|p: &str| println!(" {p}"))) + .await?; + } + println!("Loading model '{}'…", model.alias()); + model.load().await?; + + // ── 3. Create a chat client with tool_choice = required ────────────── + let mut client = model.create_chat_client(); + client + .tool_choice(ChatToolChoice::Required) + .max_tokens(512); + + let tools: Vec = serde_json::from_value(json!([{ + "type": "function", + "function": { + "name": "multiply", + "description": "Multiply two numbers together.", + "parameters": { + "type": "object", + "properties": { + "a": { "type": "number", "description": "First operand" }, + "b": { "type": "number", "description": "Second operand" } + }, + "required": ["a", "b"] + } + } + }])) + .expect("Failed to parse tool definitions"); + + let mut messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from( + "You are a helpful calculator assistant. Use the multiply tool when asked to multiply.", + ) + .into(), + ChatCompletionRequestUserMessage::from("What is 6 times 7?").into(), + ]; + + // ── 4. First streaming call – expect tool_calls ────────────────────── + println!("Sending initial request…"); + + let mut state = ToolCallState::default(); + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await?; + + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref tool_calls) = choice.delta.tool_calls { + for tc in tool_calls { + if let Some(ref func) = tc.function { + if let Some(ref name) = func.name { + state.current_tool_name = name.clone(); + } + if let Some(ref args) = func.arguments { + state.tool_call_args.push_str(args); + } + } + if let Some(ref id) = tc.id { + state.current_tool_id = id.clone(); + } + } + } + + if choice.finish_reason == Some(FinishReason::ToolCalls) { + let tc = json!({ + "id": state.current_tool_id.clone(), + "type": "function", + "function": { + "name": state.current_tool_name.clone(), + "arguments": state.tool_call_args.clone(), + } + }); + state.tool_calls.push(tc); + } + } + } + stream.close().await?; + + // ── 5. Execute the tool(s) ─────────────────────────────────────────── + for tc in &state.tool_calls { + let func = &tc["function"]; + let name = func["name"].as_str().unwrap_or_default(); + let args_str = func["arguments"].as_str().unwrap_or("{}"); + let args: Value = serde_json::from_str(args_str).unwrap_or(json!({})); + + println!("Tool call: {name}({args})"); + let result = invoke_tool(name, &args)?; + println!("Tool result: {result}"); + + // Append the assistant's tool_calls message and the tool result. + let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ + "role": "assistant", + "content": null, + "tool_calls": [tc], + })) + .expect("Failed to construct assistant message"); + messages.push(assistant_msg); + messages.push( + ChatCompletionRequestToolMessage { + content: result.into(), + tool_call_id: tc["id"].as_str().unwrap_or_default().to_string(), + } + .into(), + ); + } + + // ── 6. Continue the conversation with auto tool_choice ─────────────── + client.tool_choice(ChatToolChoice::Auto); + + println!("\nContinuing conversation…"); + print!("Assistant: "); + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await?; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + print!("{content}"); + io::stdout().flush().ok(); + } + } + } + stream.close().await?; + println!(); + + // ── 7. Clean up ────────────────────────────────────────────────────── + println!("\nUnloading model…"); + model.unload().await?; + println!("Done."); + + Ok(()) +} diff --git a/sdk_v2/rust/src/catalog.rs b/sdk_v2/rust/src/catalog.rs new file mode 100644 index 00000000..7d253471 --- /dev/null +++ b/sdk_v2/rust/src/catalog.rs @@ -0,0 +1,184 @@ +//! Model catalog – discovers, caches, and looks up available models. + +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; + +use crate::detail::core_interop::CoreInterop; +use crate::detail::ModelLoadManager; +use crate::error::{FoundryLocalError, Result}; +use crate::model::Model; +use crate::model_variant::ModelVariant; +use crate::types::ModelInfo; + +/// How long the catalog cache remains valid before a refresh. +const CACHE_TTL: Duration = Duration::from_secs(6 * 60 * 60); // 6 hours + +/// The model catalog provides discovery and lookup for all available models. +pub struct Catalog { + core: Arc, + model_load_manager: Arc, + name: String, + models_by_alias: Mutex>, + variants_by_id: Mutex>, + last_refresh: Mutex>, +} + +impl Catalog { + pub(crate) fn new(core: Arc, model_load_manager: Arc) -> Result { + let name = core + .execute_command("get_catalog_name", None) + .unwrap_or_else(|_| "default".into()); + + let catalog = Self { + core, + model_load_manager, + name, + models_by_alias: Mutex::new(HashMap::new()), + variants_by_id: Mutex::new(HashMap::new()), + last_refresh: Mutex::new(None), + }; + + // Perform initial synchronous refresh during construction. + catalog.force_refresh_sync()?; + Ok(catalog) + } + + /// Catalog name as reported by the native core. + pub fn name(&self) -> &str { + &self.name + } + + /// Refresh the catalog from the native core if the cache has expired. + pub async fn update_models(&self) -> Result<()> { + { + let last = self.last_refresh.lock().unwrap(); + if let Some(ts) = *last { + if ts.elapsed() < CACHE_TTL { + return Ok(()); + } + } + } + + self.force_refresh().await + } + + /// Return all known models keyed by alias. + pub async fn get_models(&self) -> Result> { + self.update_models().await?; + let map = self.models_by_alias.lock().unwrap(); + Ok(map.values().cloned().collect()) + } + + /// Look up a model by its alias. + pub async fn get_model(&self, alias: &str) -> Result { + if alias.trim().is_empty() { + return Err(FoundryLocalError::Validation( + "Model alias must be a non-empty string".into(), + )); + } + self.update_models().await?; + let map = self.models_by_alias.lock().unwrap(); + map.get(alias).cloned().ok_or_else(|| { + let available: Vec<&String> = map.keys().collect(); + FoundryLocalError::ModelOperation(format!( + "Unknown model alias '{alias}'. Available: {available:?}" + )) + }) + } + + /// Look up a specific model variant by its unique id. + pub async fn get_model_variant(&self, id: &str) -> Result { + if id.trim().is_empty() { + return Err(FoundryLocalError::Validation( + "Variant id must be a non-empty string".into(), + )); + } + self.update_models().await?; + let map = self.variants_by_id.lock().unwrap(); + map.get(id).cloned().ok_or_else(|| { + let available: Vec<&String> = map.keys().collect(); + FoundryLocalError::ModelOperation(format!( + "Unknown variant id '{id}'. Available: {available:?}" + )) + }) + } + + /// Return only the model variants that are currently cached on disk. + /// + /// The native core returns a list of variant IDs. This method resolves + /// them against the internal cache, matching the JS SDK behaviour. + pub async fn get_cached_models(&self) -> Result> { + self.update_models().await?; + let raw = self + .core + .execute_command_async("get_cached_models".into(), None) + .await?; + if raw.trim().is_empty() { + return Ok(Vec::new()); + } + let cached_ids: Vec = serde_json::from_str(&raw)?; + let id_map = self.variants_by_id.lock().unwrap(); + Ok(cached_ids + .iter() + .filter_map(|id| id_map.get(id).cloned()) + .collect()) + } + + /// Return identifiers of models that are currently loaded into memory. + pub async fn get_loaded_models(&self) -> Result> { + self.model_load_manager.list_loaded().await + } + + async fn force_refresh(&self) -> Result<()> { + let raw = self + .core + .execute_command_async("get_model_list".into(), None) + .await?; + self.apply_model_list(&raw) + } + + /// Synchronous refresh used only during construction (before a tokio + /// runtime may be available). + fn force_refresh_sync(&self) -> Result<()> { + let raw = self.core.execute_command("get_model_list", None)?; + self.apply_model_list(&raw) + } + + fn apply_model_list(&self, raw: &str) -> Result<()> { + let infos: Vec = if raw.trim().is_empty() { + Vec::new() + } else { + serde_json::from_str(raw)? + }; + + let mut alias_map: HashMap = HashMap::new(); + let mut id_map: HashMap = HashMap::new(); + + for info in infos { + let variant = ModelVariant::new( + info.clone(), + Arc::clone(&self.core), + Arc::clone(&self.model_load_manager), + ); + id_map.insert(info.id.clone(), variant.clone()); + + alias_map + .entry(info.alias.clone()) + .or_insert_with(|| { + Model::new( + info.alias.clone(), + Arc::clone(&self.core), + Arc::clone(&self.model_load_manager), + ) + }) + .add_variant(variant); + } + + *self.models_by_alias.lock().unwrap() = alias_map; + *self.variants_by_id.lock().unwrap() = id_map; + *self.last_refresh.lock().unwrap() = Some(Instant::now()); + + Ok(()) + } +} diff --git a/sdk_v2/rust/src/configuration.rs b/sdk_v2/rust/src/configuration.rs new file mode 100644 index 00000000..0eb66e43 --- /dev/null +++ b/sdk_v2/rust/src/configuration.rs @@ -0,0 +1,137 @@ +use std::collections::HashMap; + +use crate::error::{FoundryLocalError, Result}; + +/// Log level for the Foundry Local service. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LogLevel { + Trace, + Debug, + Info, + Warn, + Error, + Fatal, +} + +impl LogLevel { + /// Returns the string value expected by the native core library. + fn as_core_str(&self) -> &'static str { + match self { + Self::Trace => "Verbose", + Self::Debug => "Debug", + Self::Info => "Information", + Self::Warn => "Warning", + Self::Error => "Error", + Self::Fatal => "Fatal", + } + } +} + +/// User-facing configuration for initializing the Foundry Local SDK. +#[derive(Debug, Clone, Default)] +pub struct FoundryLocalConfig { + pub app_name: String, + pub app_data_dir: Option, + pub model_cache_dir: Option, + pub logs_dir: Option, + pub log_level: Option, + pub web_service_urls: Option, + pub service_endpoint: Option, + pub library_path: Option, + pub additional_settings: Option>, +} + +/// Internal configuration object that converts [`FoundryLocalConfig`] into the +/// parameter map expected by the native core library. +#[derive(Debug, Clone)] +pub(crate) struct Configuration { + pub params: HashMap, +} + +impl Configuration { + /// Build a [`Configuration`] from the user-facing [`FoundryLocalConfig`]. + /// + /// # Errors + /// + /// Returns [`FoundryLocalError::InvalidConfiguration`] when `app_name` is + /// empty or blank. + pub fn new(config: FoundryLocalConfig) -> Result { + let app_name = config.app_name.trim().to_string(); + if app_name.is_empty() { + return Err(FoundryLocalError::InvalidConfiguration( + "app_name must be set and non-empty".into(), + )); + } + + let mut params = HashMap::new(); + params.insert("AppName".into(), app_name); + + if let Some(v) = config.app_data_dir { + params.insert("AppDataDir".into(), v); + } + if let Some(v) = config.model_cache_dir { + params.insert("ModelCacheDir".into(), v); + } + if let Some(v) = config.logs_dir { + params.insert("LogsDir".into(), v); + } + if let Some(level) = config.log_level { + params.insert("LogLevel".into(), level.as_core_str().into()); + } + if let Some(v) = config.web_service_urls { + params.insert("WebServiceUrls".into(), v); + } + if let Some(v) = config.service_endpoint { + params.insert("WebServiceExternalUrl".into(), v); + } + if let Some(v) = config.library_path { + params.insert("FoundryLocalCorePath".into(), v); + } + if let Some(extra) = config.additional_settings { + for (k, v) in extra { + params.insert(k, v); + } + } + + Ok(Self { params }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn valid_config() { + let cfg = FoundryLocalConfig { + app_name: "TestApp".into(), + app_data_dir: None, + model_cache_dir: None, + logs_dir: None, + log_level: Some(LogLevel::Debug), + web_service_urls: None, + service_endpoint: None, + library_path: None, + additional_settings: None, + }; + let c = Configuration::new(cfg).unwrap(); + assert_eq!(c.params["AppName"], "TestApp"); + assert_eq!(c.params["LogLevel"], "Debug"); + } + + #[test] + fn empty_app_name_fails() { + let cfg = FoundryLocalConfig { + app_name: " ".into(), + app_data_dir: None, + model_cache_dir: None, + logs_dir: None, + log_level: None, + web_service_urls: None, + service_endpoint: None, + library_path: None, + additional_settings: None, + }; + assert!(Configuration::new(cfg).is_err()); + } +} diff --git a/sdk_v2/rust/src/detail/core_interop.rs b/sdk_v2/rust/src/detail/core_interop.rs new file mode 100644 index 00000000..7223fc63 --- /dev/null +++ b/sdk_v2/rust/src/detail/core_interop.rs @@ -0,0 +1,459 @@ +//! FFI bridge to the `Microsoft.AI.Foundry.Local.Core` native library. +//! +//! Dynamically loads the shared library at runtime via [`libloading`] and +//! exposes two operations: +//! +//! * [`CoreInterop::execute_command`] – synchronous request/response. +//! * [`CoreInterop::execute_command_streaming`] – request with a streaming +//! callback that receives incremental chunks. + +use std::ffi::CString; +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use libloading::{Library, Symbol}; +use serde_json::Value; + +use crate::configuration::Configuration; +use crate::error::{FoundryLocalError, Result}; + +// ── FFI types ──────────────────────────────────────────────────────────────── + +/// Request buffer passed to the native library. +#[repr(C)] +struct RequestBuffer { + command: *const i8, + command_length: i32, + data: *const i8, + data_length: i32, +} + +/// Response buffer filled by the native library. +#[repr(C)] +struct ResponseBuffer { + data: *mut u8, + data_length: i32, + error: *mut u8, + error_length: i32, +} + +impl ResponseBuffer { + fn new() -> Self { + Self { + data: std::ptr::null_mut(), + data_length: 0, + error: std::ptr::null_mut(), + error_length: 0, + } + } +} + +/// Signature for `execute_command`. +type ExecuteCommandFn = unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer); + +/// Signature for the streaming callback invoked by the native library. +type CallbackFn = unsafe extern "C" fn(*const u8, i32, *mut std::ffi::c_void); + +/// Signature for `execute_command_with_callback`. +type ExecuteCommandWithCallbackFn = + unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer, CallbackFn, *mut std::ffi::c_void); + +// ── Library name helpers ───────────────────────────────────────────────────── + +#[cfg(target_os = "windows")] +const CORE_LIB_NAME: &str = "Microsoft.AI.Foundry.Local.Core.dll"; +#[cfg(target_os = "macos")] +const CORE_LIB_NAME: &str = "Microsoft.AI.Foundry.Local.Core.dylib"; +#[cfg(target_os = "linux")] +const CORE_LIB_NAME: &str = "Microsoft.AI.Foundry.Local.Core.so"; + +// ── Native buffer deallocation ──────────────────────────────────────────────── + +/// Free a buffer allocated by the native core library. +/// +/// The .NET native core allocates response buffers with +/// `Marshal.AllocHGlobal` which maps to `malloc` on Unix and +/// `CoTaskMemAlloc` on Windows. +unsafe fn free_native_buffer(ptr: *mut u8) { + if ptr.is_null() { + return; + } + #[cfg(unix)] + { + extern "C" { + fn free(ptr: *mut std::ffi::c_void); + } + free(ptr as *mut std::ffi::c_void); + } + #[cfg(windows)] + { + extern "system" { + fn CoTaskMemFree(pv: *mut std::ffi::c_void); + } + CoTaskMemFree(ptr as *mut std::ffi::c_void); + } +} + +// ── Trampoline for streaming callback ──────────────────────────────────────── + +/// C-ABI trampoline that forwards chunks from the native library into a Rust +/// closure stored behind `user_data`. +unsafe extern "C" fn streaming_trampoline( + data: *const u8, + length: i32, + user_data: *mut std::ffi::c_void, +) { + if data.is_null() || length <= 0 { + return; + } + let closure = &mut *(user_data as *mut Box); + let slice = std::slice::from_raw_parts(data, length as usize); + if let Ok(chunk) = std::str::from_utf8(slice) { + closure(chunk); + } +} + +// ── CoreInterop ────────────────────────────────────────────────────────────── + +/// Handle to the loaded native core library. +/// +/// This type is `Send + Sync` because the underlying native library is +/// expected to be thread-safe for distinct request/response pairs. +pub(crate) struct CoreInterop { + _library: Library, + #[cfg(target_os = "windows")] + _dependency_libs: Vec, + execute_command: unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer), + execute_command_with_callback: + unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer, CallbackFn, *mut std::ffi::c_void), +} + +impl std::fmt::Debug for CoreInterop { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CoreInterop").finish_non_exhaustive() + } +} + + +impl CoreInterop { + /// Load the native core library using the provided configuration to locate + /// it on disk. + /// + /// Discovery order: + /// 1. `FoundryLocalCorePath` key in `config.params`. + /// 2. `FOUNDRY_NATIVE_DIR` environment variable. + /// 3. Sibling directory of the current executable. + pub fn new(config: &Configuration) -> Result { + let lib_path = Self::resolve_library_path(config)?; + + #[cfg(target_os = "windows")] + let _dependency_libs = Self::load_windows_dependencies(&lib_path)?; + + let library = unsafe { + Library::new(&lib_path).map_err(|e| { + FoundryLocalError::LibraryLoad(format!( + "Failed to load native library at {}: {e}", + lib_path.display() + )) + })? + }; + + let execute_command: ExecuteCommandFn = unsafe { + let sym: Symbol = + library.get(b"execute_command\0").map_err(|e| { + FoundryLocalError::LibraryLoad(format!( + "Symbol 'execute_command' not found: {e}" + )) + })?; + *sym + }; + + let execute_command_with_callback: ExecuteCommandWithCallbackFn = unsafe { + let sym: Symbol = + library.get(b"execute_command_with_callback\0").map_err(|e| { + FoundryLocalError::LibraryLoad(format!( + "Symbol 'execute_command_with_callback' not found: {e}" + )) + })?; + *sym + }; + + Ok(Self { + _library: library, + #[cfg(target_os = "windows")] + _dependency_libs, + execute_command, + execute_command_with_callback, + }) + } + + /// Execute a synchronous command against the native core. + /// + /// `command` is the operation name (e.g. `"initialize"`, `"load_model"`). + /// `params` is an optional JSON value that will be serialised and sent as + /// the data payload. + pub fn execute_command(&self, command: &str, params: Option<&Value>) -> Result { + let cmd = CString::new(command).map_err(|e| { + FoundryLocalError::CommandExecution(format!("Invalid command string: {e}")) + })?; + + let data_json = match params { + Some(v) => serde_json::to_string(v)?, + None => String::new(), + }; + let data_cstr = CString::new(data_json.as_str()).map_err(|e| { + FoundryLocalError::CommandExecution(format!("Invalid data string: {e}")) + })?; + + let request = RequestBuffer { + command: cmd.as_ptr(), + command_length: cmd.as_bytes().len() as i32, + data: data_cstr.as_ptr(), + data_length: data_cstr.as_bytes().len() as i32, + }; + + let mut response = ResponseBuffer::new(); + + unsafe { + (self.execute_command)(&request, &mut response); + } + + Self::process_response(&response) + } + + /// Execute a command that streams results back via `callback`. + /// + /// Each chunk delivered by the native library is decoded as UTF-8 and + /// forwarded to `callback`. After the native call returns, any error in + /// the response buffer is raised. + pub fn execute_command_streaming( + &self, + command: &str, + params: Option<&Value>, + mut callback: F, + ) -> Result + where + F: FnMut(&str), + { + let cmd = CString::new(command).map_err(|e| { + FoundryLocalError::CommandExecution(format!("Invalid command string: {e}")) + })?; + + let data_json = match params { + Some(v) => serde_json::to_string(v)?, + None => String::new(), + }; + let data_cstr = CString::new(data_json.as_str()).map_err(|e| { + FoundryLocalError::CommandExecution(format!("Invalid data string: {e}")) + })?; + + let request = RequestBuffer { + command: cmd.as_ptr(), + command_length: cmd.as_bytes().len() as i32, + data: data_cstr.as_ptr(), + data_length: data_cstr.as_bytes().len() as i32, + }; + + let mut response = ResponseBuffer::new(); + + // Box the closure so we can pass a stable pointer through FFI. + let mut boxed: Box = Box::new(|chunk: &str| callback(chunk)); + let user_data = &mut boxed as *mut Box as *mut std::ffi::c_void; + + unsafe { + (self.execute_command_with_callback)( + &request, + &mut response, + streaming_trampoline, + user_data, + ); + } + + Self::process_response(&response) + } + + /// Async version of [`Self::execute_command`]. + /// + /// Runs the blocking FFI call on a dedicated thread via + /// [`tokio::task::spawn_blocking`] so the async runtime is never blocked. + pub async fn execute_command_async( + self: &Arc, + command: String, + params: Option, + ) -> Result { + let this = Arc::clone(self); + tokio::task::spawn_blocking(move || { + this.execute_command(&command, params.as_ref()) + }) + .await + .map_err(|e| FoundryLocalError::CommandExecution(format!("task join error: {e}")))? + } + + /// Async version of [`Self::execute_command_streaming`]. + /// + /// The `callback` is invoked on the blocking thread – it must be + /// [`Send`] + `'static`. + pub async fn execute_command_streaming_async( + self: &Arc, + command: String, + params: Option, + callback: F, + ) -> Result + where + F: FnMut(&str) + Send + 'static, + { + let this = Arc::clone(self); + tokio::task::spawn_blocking(move || { + this.execute_command_streaming(&command, params.as_ref(), callback) + }) + .await + .map_err(|e| FoundryLocalError::CommandExecution(format!("task join error: {e}")))? + } + + /// Async streaming variant that bridges the FFI callback into a + /// [`tokio::sync::mpsc`] channel. + /// + /// Returns a `Receiver` that yields each chunk as it arrives. + /// The FFI call runs on a dedicated blocking thread; the receiver can + /// be wrapped with [`tokio_stream::wrappers::ReceiverStream`] to get a + /// `Stream`. + pub async fn execute_command_streaming_channel( + self: &Arc, + command: String, + params: Option, + ) -> Result<(tokio::sync::mpsc::UnboundedReceiver, tokio::task::JoinHandle>)> { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); + let this = Arc::clone(self); + + let handle = tokio::task::spawn_blocking(move || { + this.execute_command_streaming(&command, params.as_ref(), move |chunk: &str| { + // Ignore send errors — the receiver was dropped. + let _ = tx.send(chunk.to_owned()); + }) + }); + + Ok((rx, handle)) + } + + /// Read the response buffer, free the native memory, and return the data + /// string or raise an error. + fn process_response(response: &ResponseBuffer) -> Result { + // Extract strings from the native pointers before freeing them. + let error_str = if !response.error.is_null() && response.error_length > 0 { + Some(unsafe { + let slice = + std::slice::from_raw_parts(response.error, response.error_length as usize); + String::from_utf8_lossy(slice).into_owned() + }) + } else { + None + }; + + let data_str = if !response.data.is_null() && response.data_length > 0 { + Some(unsafe { + let slice = + std::slice::from_raw_parts(response.data, response.data_length as usize); + String::from_utf8_lossy(slice).into_owned() + }) + } else { + None + }; + + // Free the heap-allocated response buffers (matches JS koffi.free() + // and C# Marshal.FreeHGlobal() behaviour). + unsafe { + free_native_buffer(response.data); + free_native_buffer(response.error); + } + + // Return error or data. + if let Some(err) = error_str { + Err(FoundryLocalError::CommandExecution(err)) + } else { + Ok(data_str.unwrap_or_default()) + } + } + + /// Resolve the path to the native core shared library. + fn resolve_library_path(config: &Configuration) -> Result { + // 1. Explicit path from configuration. + if let Some(dir) = config.params.get("FoundryLocalCorePath") { + let p = Path::new(dir).join(CORE_LIB_NAME); + if p.exists() { + return Ok(p); + } + // If the config value is the full path to the library itself. + let p = Path::new(dir); + if p.exists() && p.is_file() { + return Ok(p.to_path_buf()); + } + } + + // 2. Compile-time environment variable set by build.rs. + if let Some(dir) = option_env!("FOUNDRY_NATIVE_DIR") { + let p = Path::new(dir).join(CORE_LIB_NAME); + if p.exists() { + return Ok(p); + } + } + + // 3. Runtime environment variable (user override). + if let Ok(dir) = std::env::var("FOUNDRY_NATIVE_DIR") { + let p = Path::new(&dir).join(CORE_LIB_NAME); + if p.exists() { + return Ok(p); + } + } + + // 4. Next to the running executable. + if let Ok(exe) = std::env::current_exe() { + if let Some(dir) = exe.parent() { + let p = dir.join(CORE_LIB_NAME); + if p.exists() { + return Ok(p); + } + } + } + + Err(FoundryLocalError::LibraryLoad(format!( + "Could not locate native library '{CORE_LIB_NAME}'. \ + Set the FoundryLocalCorePath config option or the FOUNDRY_NATIVE_DIR \ + environment variable." + ))) + } + + /// On Windows, pre-load runtime dependencies so the core library can + /// resolve them. + #[cfg(target_os = "windows")] + fn load_windows_dependencies(core_lib_path: &Path) -> Result> { + let dir = core_lib_path + .parent() + .unwrap_or_else(|| Path::new(".")); + + let mut libs = Vec::new(); + + // Load WinML bootstrap if present. + let bootstrap = dir.join("Microsoft.WindowsAppRuntime.Bootstrap.dll"); + if bootstrap.exists() { + if let Ok(lib) = unsafe { Library::new(&bootstrap) } { + libs.push(lib); + } + } + + for dep in &["onnxruntime.dll", "onnxruntime-genai.dll"] { + let dep_path = dir.join(dep); + if dep_path.exists() { + let lib = unsafe { + Library::new(&dep_path).map_err(|e| { + FoundryLocalError::LibraryLoad(format!( + "Failed to load dependency {dep}: {e}" + )) + })? + }; + libs.push(lib); + } + } + + Ok(libs) + } +} diff --git a/sdk_v2/rust/src/detail/mod.rs b/sdk_v2/rust/src/detail/mod.rs new file mode 100644 index 00000000..3f7fd07c --- /dev/null +++ b/sdk_v2/rust/src/detail/mod.rs @@ -0,0 +1,4 @@ +pub(crate) mod core_interop; +mod model_load_manager; + +pub use model_load_manager::ModelLoadManager; diff --git a/sdk_v2/rust/src/detail/model_load_manager.rs b/sdk_v2/rust/src/detail/model_load_manager.rs new file mode 100644 index 00000000..639ec691 --- /dev/null +++ b/sdk_v2/rust/src/detail/model_load_manager.rs @@ -0,0 +1,73 @@ +//! Manages model loading and unloading. +//! +//! When an external service URL is configured the manager delegates to HTTP +//! endpoints (`models/load/{id}`, `models/unload/{id}`, `models/loaded`). +//! Otherwise it falls through to the native core library via [`CoreInterop`]. + +use std::sync::Arc; + +use serde_json::json; + +use crate::detail::core_interop::CoreInterop; +use crate::error::Result; + +/// Manages the lifecycle of loaded models. +#[derive(Debug)] +pub struct ModelLoadManager { + core: Arc, + external_service_url: Option, +} + +impl ModelLoadManager { + pub(crate) fn new(core: Arc, external_service_url: Option) -> Self { + Self { + core, + external_service_url, + } + } + + /// Load a model by its identifier. + pub async fn load(&self, model_id: &str) -> Result { + if let Some(base_url) = &self.external_service_url { + return Self::http_get(&format!("{base_url}/models/load/{model_id}")).await; + } + let params = json!({ "Params": { "Model": model_id } }); + self.core + .execute_command_async("load_model".into(), Some(params)) + .await + } + + /// Unload a previously loaded model. + pub async fn unload(&self, model_id: &str) -> Result { + if let Some(base_url) = &self.external_service_url { + return Self::http_get(&format!("{base_url}/models/unload/{model_id}")).await; + } + let params = json!({ "Params": { "Model": model_id } }); + self.core + .execute_command_async("unload_model".into(), Some(params)) + .await + } + + /// Return the list of currently loaded model identifiers. + pub async fn list_loaded(&self) -> Result> { + let raw = if let Some(base_url) = &self.external_service_url { + Self::http_get(&format!("{base_url}/models/loaded")).await? + } else { + self.core + .execute_command_async("list_loaded_models".into(), None) + .await? + }; + + if raw.trim().is_empty() { + return Ok(Vec::new()); + } + + let ids: Vec = serde_json::from_str(&raw)?; + Ok(ids) + } + + async fn http_get(url: &str) -> Result { + let body = reqwest::get(url).await?.text().await?; + Ok(body) + } +} diff --git a/sdk_v2/rust/src/error.rs b/sdk_v2/rust/src/error.rs new file mode 100644 index 00000000..a91453e5 --- /dev/null +++ b/sdk_v2/rust/src/error.rs @@ -0,0 +1,33 @@ +use thiserror::Error; + +/// Errors that can occur when using the Foundry Local SDK. +#[derive(Debug, Error)] +pub enum FoundryLocalError { + /// The native core library could not be loaded. + #[error("library load error: {0}")] + LibraryLoad(String), + /// A command executed against the native core returned an error. + #[error("command execution error: {0}")] + CommandExecution(String), + /// The provided configuration is invalid. + #[error("invalid configuration: {0}")] + InvalidConfiguration(String), + /// A model operation failed (load, unload, download, etc.). + #[error("model operation error: {0}")] + ModelOperation(String), + /// An HTTP request to the external service failed. + #[error("HTTP request error: {0}")] + HttpRequest(#[from] reqwest::Error), + /// Serialization or deserialization of JSON data failed. + #[error("serialization error: {0}")] + Serialization(#[from] serde_json::Error), + /// A validation check on user-supplied input failed. + #[error("validation error: {0}")] + Validation(String), + /// An I/O error occurred. + #[error("I/O error: {0}")] + Io(#[from] std::io::Error), +} + +/// Convenience alias used throughout the SDK. +pub type Result = std::result::Result; diff --git a/sdk_v2/rust/src/foundry_local_manager.rs b/sdk_v2/rust/src/foundry_local_manager.rs new file mode 100644 index 00000000..a3d38e23 --- /dev/null +++ b/sdk_v2/rust/src/foundry_local_manager.rs @@ -0,0 +1,114 @@ +//! Top-level entry point for the Foundry Local SDK. +//! +//! [`FoundryLocalManager`] is a singleton that initialises the native core +//! library, provides access to the model [`Catalog`], and can start / stop +//! the local web service. + +use std::sync::{Arc, OnceLock}; + +use serde_json::json; + +use crate::catalog::Catalog; +use crate::configuration::{Configuration, FoundryLocalConfig}; +use crate::detail::core_interop::CoreInterop; +use crate::detail::ModelLoadManager; +use crate::error::Result; + +/// Global singleton holder. +static INSTANCE: OnceLock = OnceLock::new(); + +/// Primary entry point for interacting with Foundry Local. +/// +/// Created once via [`FoundryLocalManager::create`]; subsequent calls return +/// the existing instance. +pub struct FoundryLocalManager { + _config: Configuration, + core: Arc, + _model_load_manager: Arc, + catalog: Catalog, + urls: std::sync::Mutex>, +} + +impl FoundryLocalManager { + /// Initialise the SDK. + /// + /// The first call creates the singleton, loads the native library, runs + /// the `initialize` command, and builds the model catalog. Subsequent + /// calls return a reference to the same instance (the provided config is + /// ignored after the first call). + pub fn create(config: FoundryLocalConfig) -> Result<&'static Self> { + // If already initialised, return the existing instance. + if let Some(mgr) = INSTANCE.get() { + return Ok(mgr); + } + + let internal_config = Configuration::new(config)?; + let core = Arc::new(CoreInterop::new(&internal_config)?); + + // Send the configuration map to the native core. + let init_params = json!({ "Params": internal_config.params }); + core.execute_command("initialize", Some(&init_params))?; + + let service_endpoint = internal_config + .params + .get("WebServiceExternalUrl") + .cloned(); + + let model_load_manager = Arc::new(ModelLoadManager::new( + Arc::clone(&core), + service_endpoint, + )); + + let catalog = Catalog::new(Arc::clone(&core), Arc::clone(&model_load_manager))?; + + let manager = Self { + _config: internal_config, + core, + _model_load_manager: model_load_manager, + catalog, + urls: std::sync::Mutex::new(Vec::new()), + }; + + // Attempt to store; if another thread raced us, return whichever won. + match INSTANCE.set(manager) { + Ok(()) => Ok(INSTANCE.get().unwrap()), + Err(_) => Ok(INSTANCE.get().unwrap()), + } + } + + /// Access the model catalog. + pub fn catalog(&self) -> &Catalog { + &self.catalog + } + + /// URLs that the local web service is listening on. + /// + /// Empty until [`Self::start_web_service`] has been called. + pub fn urls(&self) -> Vec { + self.urls.lock().unwrap().clone() + } + + /// Start the local web service and return the listening URLs. + pub async fn start_web_service(&self) -> Result> { + let raw = self + .core + .execute_command_async("start_service".into(), None) + .await?; + let parsed: Vec = if raw.trim().is_empty() { + Vec::new() + } else { + serde_json::from_str(&raw).unwrap_or_else(|_| vec![raw]) + }; + *self.urls.lock().unwrap() = parsed.clone(); + Ok(parsed) + } + + /// Stop the local web service. + pub async fn stop_web_service(&self) -> Result<()> { + self.core + .execute_command_async("stop_service".into(), None) + .await?; + self.urls.lock().unwrap().clear(); + Ok(()) + } +} diff --git a/sdk_v2/rust/src/lib.rs b/sdk_v2/rust/src/lib.rs new file mode 100644 index 00000000..a6472175 --- /dev/null +++ b/sdk_v2/rust/src/lib.rs @@ -0,0 +1,56 @@ +//! Foundry Local Rust SDK +//! +//! Local AI model inference powered by the Foundry Local Core engine. + +mod error; +mod types; +mod configuration; +mod foundry_local_manager; +mod catalog; +mod model; +mod model_variant; + +pub(crate) mod detail; +pub mod openai; + +pub use error::FoundryLocalError; +pub use types::*; +pub use configuration::{FoundryLocalConfig, LogLevel}; +pub use foundry_local_manager::FoundryLocalManager; +pub use catalog::Catalog; +pub use model::Model; +pub use model_variant::ModelVariant; +pub use detail::ModelLoadManager; + +// Re-export OpenAI request types so callers can construct typed messages. +pub use async_openai::types::chat::{ + ChatCompletionRequestMessage, + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + ChatCompletionRequestAssistantMessage, + ChatCompletionRequestToolMessage, + ChatCompletionTools, + ChatCompletionToolChoiceOption, + ChatCompletionNamedToolChoice, + FunctionObject, +}; + +// Re-export OpenAI response types for convenience. +pub use async_openai::types::chat::{ + CreateChatCompletionResponse, + CreateChatCompletionStreamResponse, + ChatChoice, + ChatChoiceStream, + ChatCompletionResponseMessage, + ChatCompletionStreamResponseDelta, + ChatCompletionMessageToolCall, + ChatCompletionMessageToolCalls, + ChatCompletionMessageToolCallChunk, + FunctionCall, + FunctionCallStream, + FinishReason, + CompletionUsage, +}; +pub use crate::openai::{ + AudioTranscriptionResponse, ChatCompletionStream, AudioTranscriptionStream, +}; diff --git a/sdk_v2/rust/src/model.rs b/sdk_v2/rust/src/model.rs new file mode 100644 index 00000000..54422c80 --- /dev/null +++ b/sdk_v2/rust/src/model.rs @@ -0,0 +1,135 @@ +//! High-level model abstraction that wraps one or more [`ModelVariant`]s +//! sharing the same alias. + +use std::sync::Arc; + +use crate::detail::core_interop::CoreInterop; +use crate::detail::ModelLoadManager; +use crate::error::{FoundryLocalError, Result}; +use crate::model_variant::ModelVariant; +use crate::openai::AudioClient; +use crate::openai::ChatClient; + +/// A model groups one or more [`ModelVariant`]s that share the same alias. +/// +/// By default the variant that is already cached locally is selected. You +/// can override the selection with [`Model::select_variant`]. +#[derive(Debug, Clone)] +pub struct Model { + alias: String, + core: Arc, + _model_load_manager: Arc, + variants: Vec, + selected_index: usize, +} + +impl Model { + pub(crate) fn new( + alias: String, + core: Arc, + model_load_manager: Arc, + ) -> Self { + Self { + alias, + core, + _model_load_manager: model_load_manager, + variants: Vec::new(), + selected_index: 0, + } + } + + /// Add a variant. If the new variant is cached and the current selection + /// is not, the new variant becomes the selected one. + pub(crate) fn add_variant(&mut self, variant: ModelVariant) { + self.variants.push(variant); + let new_idx = self.variants.len() - 1; + + // Prefer a cached variant over a non-cached one. + if self.variants[new_idx].info().cached && !self.variants[self.selected_index].info().cached + { + self.selected_index = new_idx; + } + } + + /// Select a variant by its unique id. + pub fn select_variant(&mut self, id: &str) -> Result<()> { + if let Some(pos) = self.variants.iter().position(|v| v.id() == id) { + self.selected_index = pos; + return Ok(()); + } + let available: Vec = self.variants.iter().map(|v| v.id().to_string()).collect(); + Err(FoundryLocalError::ModelOperation(format!( + "Variant '{id}' not found for model '{}'. Available: {available:?}", + self.alias + ))) + } + + /// Returns a reference to the currently selected variant. + pub fn selected_variant(&self) -> &ModelVariant { + &self.variants[self.selected_index] + } + + /// Returns all variants that belong to this model. + pub fn variants(&self) -> &[ModelVariant] { + &self.variants + } + + /// Alias shared by all variants in this model. + pub fn alias(&self) -> &str { + &self.alias + } + + /// Unique identifier of the selected variant. + pub fn id(&self) -> &str { + self.selected_variant().id() + } + + /// Whether the selected variant is cached on disk. + pub async fn is_cached(&self) -> Result { + self.selected_variant().is_cached().await + } + + /// Whether the selected variant is loaded into memory. + pub async fn is_loaded(&self) -> Result { + self.selected_variant().is_loaded().await + } + + /// Download the selected variant. If `progress` is provided, it receives + /// human-readable progress strings as they arrive from the native core. + pub async fn download(&self, progress: Option) -> Result<()> + where + F: FnMut(&str) + Send + 'static, + { + self.selected_variant().download(progress).await + } + + /// Return the local file-system path of the selected variant. + pub async fn path(&self) -> Result { + self.selected_variant().path().await + } + + /// Load the selected variant into memory. + pub async fn load(&self) -> Result { + self.selected_variant().load().await + } + + /// Unload the selected variant from memory. + pub async fn unload(&self) -> Result { + self.selected_variant().unload().await + } + + /// Remove the selected variant from the local cache. + pub async fn remove_from_cache(&self) -> Result { + self.selected_variant().remove_from_cache().await + } + + /// Create a [`ChatClient`] bound to the selected variant. + pub fn create_chat_client(&self) -> ChatClient { + ChatClient::new(self.id().to_string(), Arc::clone(&self.core)) + } + + /// Create an [`AudioClient`] bound to the selected variant. + pub fn create_audio_client(&self) -> AudioClient { + AudioClient::new(self.id().to_string(), Arc::clone(&self.core)) + } +} diff --git a/sdk_v2/rust/src/model_variant.rs b/sdk_v2/rust/src/model_variant.rs new file mode 100644 index 00000000..b3f73b28 --- /dev/null +++ b/sdk_v2/rust/src/model_variant.rs @@ -0,0 +1,128 @@ +//! A single model variant backed by [`ModelInfo`]. + +use std::sync::Arc; + +use serde_json::json; + +use crate::detail::core_interop::CoreInterop; +use crate::detail::ModelLoadManager; +use crate::error::Result; +use crate::openai::AudioClient; +use crate::openai::ChatClient; +use crate::types::ModelInfo; + +/// Represents one specific variant of a model (a particular id within an alias +/// group). +#[derive(Debug, Clone)] +pub struct ModelVariant { + info: ModelInfo, + core: Arc, + model_load_manager: Arc, +} + +impl ModelVariant { + pub(crate) fn new( + info: ModelInfo, + core: Arc, + model_load_manager: Arc, + ) -> Self { + Self { + info, + core, + model_load_manager, + } + } + + /// The full [`ModelInfo`] metadata for this variant. + pub fn info(&self) -> &ModelInfo { + &self.info + } + + /// Unique identifier. + pub fn id(&self) -> &str { + &self.info.id + } + + /// Alias shared with sibling variants. + pub fn alias(&self) -> &str { + &self.info.alias + } + + /// Check whether the variant is cached locally by querying the native + /// core. + pub async fn is_cached(&self) -> Result { + let raw = self + .core + .execute_command_async("get_cached_models".into(), None) + .await?; + if raw.trim().is_empty() { + return Ok(false); + } + let cached_ids: Vec = serde_json::from_str(&raw)?; + Ok(cached_ids.iter().any(|id| id == &self.info.id)) + } + + /// Check whether the variant is currently loaded into memory. + pub async fn is_loaded(&self) -> Result { + let loaded = self.model_load_manager.list_loaded().await?; + Ok(loaded.iter().any(|id| id == &self.info.id)) + } + + /// Download the model variant. If `progress` is provided, it receives + /// human-readable progress strings as the download proceeds. + pub async fn download(&self, progress: Option) -> Result<()> + where + F: FnMut(&str) + Send + 'static, + { + let params = json!({ "Params": { "Model": self.info.id } }); + match progress { + Some(cb) => { + self.core + .execute_command_streaming_async("download_model".into(), Some(params), cb) + .await?; + } + None => { + self.core + .execute_command_async("download_model".into(), Some(params)) + .await?; + } + } + Ok(()) + } + + /// Return the local file-system path where this variant is stored. + pub async fn path(&self) -> Result { + let params = json!({ "Params": { "Model": self.info.id } }); + self.core + .execute_command_async("get_model_path".into(), Some(params)) + .await + } + + /// Load the variant into memory. + pub async fn load(&self) -> Result { + self.model_load_manager.load(&self.info.id).await + } + + /// Unload the variant from memory. + pub async fn unload(&self) -> Result { + self.model_load_manager.unload(&self.info.id).await + } + + /// Remove the variant from the local cache. + pub async fn remove_from_cache(&self) -> Result { + let params = json!({ "Params": { "Model": self.info.id } }); + self.core + .execute_command_async("remove_cached_model".into(), Some(params)) + .await + } + + /// Create a [`ChatClient`] bound to this variant. + pub fn create_chat_client(&self) -> ChatClient { + ChatClient::new(self.info.id.clone(), Arc::clone(&self.core)) + } + + /// Create an [`AudioClient`] bound to this variant. + pub fn create_audio_client(&self) -> AudioClient { + AudioClient::new(self.info.id.clone(), Arc::clone(&self.core)) + } +} diff --git a/sdk_v2/rust/src/openai/audio_client.rs b/sdk_v2/rust/src/openai/audio_client.rs new file mode 100644 index 00000000..1a26afc7 --- /dev/null +++ b/sdk_v2/rust/src/openai/audio_client.rs @@ -0,0 +1,201 @@ +//! OpenAI-compatible audio transcription client. + +use std::path::Path; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use serde_json::{json, Value}; + +use crate::detail::core_interop::CoreInterop; +use crate::error::{FoundryLocalError, Result}; + +/// OpenAI-compatible audio transcription response. +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] +pub struct AudioTranscriptionResponse { + /// The transcribed text. + pub text: String, + /// The language of the input audio (if detected). + #[serde(skip_serializing_if = "Option::is_none")] + pub language: Option, + /// Duration of the input audio in seconds (if available). + #[serde(skip_serializing_if = "Option::is_none")] + pub duration: Option, + /// Segments of the transcription (if available). + #[serde(skip_serializing_if = "Option::is_none")] + pub segments: Option>, + /// Words with timestamps (if available). + #[serde(skip_serializing_if = "Option::is_none")] + pub words: Option>, +} + +/// Tuning knobs for audio transcription requests. +/// +/// Use the chainable setter methods to configure, e.g.: +/// +/// ```ignore +/// let mut client = model.create_audio_client(); +/// client.language("en").temperature(0.2); +/// ``` +#[derive(Debug, Clone, Default)] +pub struct AudioClientSettings { + language: Option, + temperature: Option, +} + +impl AudioClientSettings { + /// Serialise settings into the JSON fragment expected by the native core. + fn serialize(&self, model_id: &str, file_name: &str) -> Value { + let mut map = serde_json::Map::new(); + + map.insert("Model".into(), json!(model_id)); + map.insert("FileName".into(), json!(file_name)); + + if let Some(ref lang) = self.language { + map.insert("Language".into(), json!(lang)); + } + if let Some(temp) = self.temperature { + map.insert("Temperature".into(), json!(temp)); + } + + Value::Object(map) + } +} + +/// A stream of [`AudioTranscriptionResponse`] chunks. +/// +/// Returned by [`AudioClient::transcribe_streaming`]. +pub struct AudioTranscriptionStream { + rx: tokio::sync::mpsc::UnboundedReceiver, + handle: Option>>, +} + +impl futures_core::Stream for AudioTranscriptionStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.rx.poll_recv(cx) { + Poll::Ready(Some(chunk)) => { + if chunk.is_empty() { + cx.waker().wake_by_ref(); + Poll::Pending + } else { + let parsed = serde_json::from_str::(&chunk) + .map_err(FoundryLocalError::from); + Poll::Ready(Some(parsed)) + } + } + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl AudioTranscriptionStream { + /// Consume the stream and wait for the background FFI task to finish. + pub async fn close(mut self) -> Result<()> { + if let Some(handle) = self.handle.take() { + handle + .await + .map_err(|e| FoundryLocalError::CommandExecution(format!("task join error: {e}")))? + .map(|_| ()) + } else { + Ok(()) + } + } +} + +/// Client for OpenAI-compatible audio transcription backed by a local model. +pub struct AudioClient { + model_id: String, + core: Arc, + settings: AudioClientSettings, +} + +impl AudioClient { + pub(crate) fn new(model_id: String, core: Arc) -> Self { + Self { + model_id, + core, + settings: AudioClientSettings::default(), + } + } + + /// Set the language hint for transcription. + pub fn language(&mut self, lang: impl Into) -> &mut Self { + self.settings.language = Some(lang.into()); + self + } + + /// Set the sampling temperature. + pub fn temperature(&mut self, v: f64) -> &mut Self { + self.settings.temperature = Some(v); + self + } + + /// Transcribe an audio file. + pub async fn transcribe(&self, audio_file_path: impl AsRef) -> Result { + let path_str = audio_file_path + .as_ref() + .to_str() + .ok_or_else(|| FoundryLocalError::Validation("audio file path is not valid UTF-8".into()))?; + Self::validate_path(path_str)?; + + let request = self.settings.serialize(&self.model_id, path_str); + let params = json!({ + "Params": { + "OpenAICreateRequest": serde_json::to_string(&request)? + } + }); + + let raw = self + .core + .execute_command_async("audio_transcribe".into(), Some(params)) + .await?; + let parsed: AudioTranscriptionResponse = serde_json::from_str(&raw)?; + Ok(parsed) + } + + /// Transcribe an audio file with streaming results, returning an + /// [`AudioTranscriptionStream`]. + pub async fn transcribe_streaming( + &self, + audio_file_path: impl AsRef, + ) -> Result { + let path_str = audio_file_path + .as_ref() + .to_str() + .ok_or_else(|| FoundryLocalError::Validation("audio file path is not valid UTF-8".into()))?; + Self::validate_path(path_str)?; + + let mut request = self.settings.serialize(&self.model_id, path_str); + if let Some(map) = request.as_object_mut() { + map.insert("stream".into(), json!(true)); + } + + let params = json!({ + "Params": { + "OpenAICreateRequest": serde_json::to_string(&request)? + } + }); + + let (rx, handle) = self + .core + .execute_command_streaming_channel("audio_transcribe".into(), Some(params)) + .await?; + + Ok(AudioTranscriptionStream { + rx, + handle: Some(handle), + }) + } + + fn validate_path(path: &str) -> Result<()> { + if path.trim().is_empty() { + return Err(FoundryLocalError::Validation( + "audio_file_path must be a non-empty string".into(), + )); + } + Ok(()) + } +} diff --git a/sdk_v2/rust/src/openai/chat_client.rs b/sdk_v2/rust/src/openai/chat_client.rs new file mode 100644 index 00000000..98d575e7 --- /dev/null +++ b/sdk_v2/rust/src/openai/chat_client.rs @@ -0,0 +1,330 @@ +//! OpenAI-compatible chat completions client. + +use std::collections::HashMap; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use async_openai::types::chat::{ + ChatCompletionRequestMessage, ChatCompletionTools, CreateChatCompletionResponse, + CreateChatCompletionStreamResponse, +}; +use serde_json::{json, Value}; + +use crate::detail::core_interop::CoreInterop; +use crate::error::{FoundryLocalError, Result}; +use crate::types::{ChatResponseFormat, ChatToolChoice}; + +/// Tuning knobs for chat completion requests. +/// +/// Use the chainable setter methods to configure, e.g.: +/// +/// ```ignore +/// let mut client = model.create_chat_client(); +/// client.temperature(0.7).max_tokens(256); +/// ``` +#[derive(Debug, Clone, Default)] +pub struct ChatClientSettings { + frequency_penalty: Option, + max_tokens: Option, + n: Option, + temperature: Option, + presence_penalty: Option, + top_p: Option, + top_k: Option, + random_seed: Option, + response_format: Option, + tool_choice: Option, +} + +impl ChatClientSettings { + /// Serialise settings into the JSON fragment expected by the native core. + fn serialize(&self) -> Value { + let mut map = serde_json::Map::new(); + + if let Some(v) = self.frequency_penalty { + map.insert("frequency_penalty".into(), json!(v)); + } + if let Some(v) = self.max_tokens { + map.insert("max_tokens".into(), json!(v)); + } + if let Some(v) = self.n { + map.insert("n".into(), json!(v)); + } + if let Some(v) = self.presence_penalty { + map.insert("presence_penalty".into(), json!(v)); + } + if let Some(v) = self.temperature { + map.insert("temperature".into(), json!(v)); + } + if let Some(v) = self.top_p { + map.insert("top_p".into(), json!(v)); + } + + if let Some(ref rf) = self.response_format { + let mut rf_map = serde_json::Map::new(); + match rf { + ChatResponseFormat::Text => { + rf_map.insert("type".into(), json!("text")); + } + ChatResponseFormat::JsonObject => { + rf_map.insert("type".into(), json!("json_object")); + } + ChatResponseFormat::JsonSchema(schema) => { + rf_map.insert("type".into(), json!("json_schema")); + rf_map.insert("jsonSchema".into(), json!(schema)); + } + ChatResponseFormat::LarkGrammar(grammar) => { + rf_map.insert("type".into(), json!("lark_grammar")); + rf_map.insert("larkGrammar".into(), json!(grammar)); + } + } + map.insert("response_format".into(), Value::Object(rf_map)); + } + + if let Some(ref tc) = self.tool_choice { + let mut tc_map = serde_json::Map::new(); + match tc { + ChatToolChoice::None => { + tc_map.insert("type".into(), json!("none")); + } + ChatToolChoice::Auto => { + tc_map.insert("type".into(), json!("auto")); + } + ChatToolChoice::Required => { + tc_map.insert("type".into(), json!("required")); + } + ChatToolChoice::Function(name) => { + tc_map.insert("type".into(), json!("function")); + tc_map.insert("name".into(), json!(name)); + } + } + map.insert("tool_choice".into(), Value::Object(tc_map)); + } + + // Foundry-specific metadata for settings that don't map directly to + // the OpenAI spec. + let mut metadata: HashMap = HashMap::new(); + if let Some(k) = self.top_k { + metadata.insert("top_k".into(), k.to_string()); + } + if let Some(s) = self.random_seed { + metadata.insert("random_seed".into(), s.to_string()); + } + if !metadata.is_empty() { + map.insert("metadata".into(), json!(metadata)); + } + + Value::Object(map) + } +} + +/// A stream of [`CreateChatCompletionStreamResponse`] chunks. +/// +/// Returned by [`ChatClient::complete_streaming_chat`]. +pub struct ChatCompletionStream { + rx: tokio::sync::mpsc::UnboundedReceiver, + handle: Option>>, +} + +impl futures_core::Stream for ChatCompletionStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.rx.poll_recv(cx) { + Poll::Ready(Some(chunk)) => { + if chunk.is_empty() { + // Skip empty chunks and poll again. + cx.waker().wake_by_ref(); + Poll::Pending + } else { + let parsed = serde_json::from_str::(&chunk) + .map_err(FoundryLocalError::from); + Poll::Ready(Some(parsed)) + } + } + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +impl ChatCompletionStream { + /// Consume the stream and wait for the background FFI task to finish. + /// + /// Call this after the stream is exhausted to retrieve any error from + /// the native core response buffer. + pub async fn close(mut self) -> Result<()> { + if let Some(handle) = self.handle.take() { + handle + .await + .map_err(|e| FoundryLocalError::CommandExecution(format!("task join error: {e}")))? + .map(|_| ()) + } else { + Ok(()) + } + } +} + +/// Client for OpenAI-compatible chat completions backed by a local model. +pub struct ChatClient { + model_id: String, + core: Arc, + settings: ChatClientSettings, +} + +impl ChatClient { + pub(crate) fn new(model_id: String, core: Arc) -> Self { + Self { + model_id, + core, + settings: ChatClientSettings::default(), + } + } + + /// Set the frequency penalty. + pub fn frequency_penalty(&mut self, v: f64) -> &mut Self { + self.settings.frequency_penalty = Some(v); + self + } + + /// Set the maximum number of tokens to generate. + pub fn max_tokens(&mut self, v: u32) -> &mut Self { + self.settings.max_tokens = Some(v); + self + } + + /// Set the number of completions to generate. + pub fn n(&mut self, v: u32) -> &mut Self { + self.settings.n = Some(v); + self + } + + /// Set the sampling temperature. + pub fn temperature(&mut self, v: f64) -> &mut Self { + self.settings.temperature = Some(v); + self + } + + /// Set the presence penalty. + pub fn presence_penalty(&mut self, v: f64) -> &mut Self { + self.settings.presence_penalty = Some(v); + self + } + + /// Set the nucleus sampling probability. + pub fn top_p(&mut self, v: f64) -> &mut Self { + self.settings.top_p = Some(v); + self + } + + /// Set the top-k sampling parameter (Foundry extension). + pub fn top_k(&mut self, v: u32) -> &mut Self { + self.settings.top_k = Some(v); + self + } + + /// Set the random seed for reproducible results (Foundry extension). + pub fn random_seed(&mut self, v: u64) -> &mut Self { + self.settings.random_seed = Some(v); + self + } + + /// Set the desired response format. + pub fn response_format(&mut self, v: ChatResponseFormat) -> &mut Self { + self.settings.response_format = Some(v); + self + } + + /// Set the tool choice strategy. + pub fn tool_choice(&mut self, v: ChatToolChoice) -> &mut Self { + self.settings.tool_choice = Some(v); + self + } + + /// Perform a non-streaming chat completion. + pub async fn complete_chat( + &self, + messages: &[ChatCompletionRequestMessage], + tools: Option<&[ChatCompletionTools]>, + ) -> Result { + if messages.is_empty() { + return Err(FoundryLocalError::Validation( + "messages must be a non-empty array".into(), + )); + } + + let request = self.build_request(messages, tools, false)?; + let params = json!({ + "Params": { + "OpenAICreateRequest": serde_json::to_string(&request)? + } + }); + + let raw = self + .core + .execute_command_async("chat_completions".into(), Some(params)) + .await?; + let parsed: CreateChatCompletionResponse = serde_json::from_str(&raw)?; + Ok(parsed) + } + + /// Perform a streaming chat completion, returning a [`ChatCompletionStream`]. + /// + /// Use the stream with `futures_core::StreamExt::next()` or + /// `tokio_stream::StreamExt::next()`. + pub async fn complete_streaming_chat( + &self, + messages: &[ChatCompletionRequestMessage], + tools: Option<&[ChatCompletionTools]>, + ) -> Result { + if messages.is_empty() { + return Err(FoundryLocalError::Validation( + "messages must be a non-empty array".into(), + )); + } + + let request = self.build_request(messages, tools, true)?; + let params = json!({ + "Params": { + "OpenAICreateRequest": serde_json::to_string(&request)? + } + }); + + let (rx, handle) = self + .core + .execute_command_streaming_channel("chat_completions".into(), Some(params)) + .await?; + + Ok(ChatCompletionStream { + rx, + handle: Some(handle), + }) + } + + fn build_request( + &self, + messages: &[ChatCompletionRequestMessage], + tools: Option<&[ChatCompletionTools]>, + stream: bool, + ) -> Result { + let settings_value = self.settings.serialize(); + let mut map = match settings_value { + Value::Object(m) => m, + _ => serde_json::Map::new(), + }; + + map.insert("model".into(), json!(self.model_id)); + map.insert("messages".into(), serde_json::to_value(messages)?); + + if stream { + map.insert("stream".into(), json!(true)); + } + + if let Some(t) = tools { + map.insert("tools".into(), serde_json::to_value(t)?); + } + + Ok(Value::Object(map)) + } +} diff --git a/sdk_v2/rust/src/openai/mod.rs b/sdk_v2/rust/src/openai/mod.rs new file mode 100644 index 00000000..e1d52dda --- /dev/null +++ b/sdk_v2/rust/src/openai/mod.rs @@ -0,0 +1,5 @@ +mod chat_client; +mod audio_client; + +pub use chat_client::{ChatClient, ChatClientSettings, ChatCompletionStream}; +pub use audio_client::{AudioClient, AudioClientSettings, AudioTranscriptionResponse, AudioTranscriptionStream}; diff --git a/sdk_v2/rust/src/types.rs b/sdk_v2/rust/src/types.rs new file mode 100644 index 00000000..495d5aab --- /dev/null +++ b/sdk_v2/rust/src/types.rs @@ -0,0 +1,124 @@ +use serde::{Deserialize, Serialize}; + +/// Hardware device type for model execution. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum DeviceType { + Invalid, + CPU, + GPU, + NPU, +} + +impl Default for DeviceType { + fn default() -> Self { + Self::CPU + } +} + +/// Prompt template describing how messages are formatted for the model. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptTemplate { + #[serde(skip_serializing_if = "Option::is_none")] + pub system: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub assistant: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, +} + +/// Runtime information for a model (device type and execution provider). +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Runtime { + pub device_type: DeviceType, + pub execution_provider: String, +} + +/// A single parameter key-value pair used in model settings. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Parameter { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub value: Option, +} + +/// Model-level settings containing a list of parameters. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelSettings { + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option>, +} + +/// Full metadata for a model variant as returned by the catalog. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelInfo { + pub id: String, + pub name: String, + pub version: i64, + pub alias: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub display_name: Option, + pub provider_type: String, + pub uri: String, + pub model_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_template: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub publisher: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_settings: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub license: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub license_description: Option, + pub cached: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub task: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub runtime: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub file_size_mb: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub supports_tool_calling: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_fl_version: Option, + #[serde(default)] + pub created_at_unix: i64, +} + +/// Desired response format for chat completions. +/// +/// Extends the standard OpenAI formats with the Foundry-specific +/// `LarkGrammar` variant. +#[derive(Debug, Clone)] +pub enum ChatResponseFormat { + /// Plain text output (default). + Text, + /// JSON output (unstructured). + JsonObject, + /// JSON output constrained by a schema string. + JsonSchema(String), + /// Output constrained by a Lark grammar (Foundry extension). + LarkGrammar(String), +} + +/// Tool choice configuration for chat completions. +#[derive(Debug, Clone)] +pub enum ChatToolChoice { + /// Model will not call any tool. + None, + /// Model decides whether to call a tool. + Auto, + /// Model must call at least one tool. + Required, + /// Model must call the named function. + Function(String), +} diff --git a/sdk_v2/rust/tests/audio_client_test.rs b/sdk_v2/rust/tests/audio_client_test.rs new file mode 100644 index 00000000..4f3664f1 --- /dev/null +++ b/sdk_v2/rust/tests/audio_client_test.rs @@ -0,0 +1,138 @@ +//! Integration tests for the [`AudioClient`] transcription API (non-streaming +//! and streaming, with optional temperature). +//! +//! Mirrors `audioClient.test.ts` from the JavaScript SDK. + +mod common; + +use foundry_local_sdk::openai::AudioClient; +use tokio_stream::StreamExt; + +mod tests { + use super::*; + + // ── Helpers ────────────────────────────────────────────────────────── + + /// Load the whisper model and return an [`AudioClient`] ready for use. + async fn setup_audio_client() -> AudioClient { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::WHISPER_MODEL_ALIAS) + .await + .expect("get_model(whisper-tiny) failed"); + model.load().await.expect("model.load() failed"); + model.create_audio_client() + } + + fn audio_file() -> String { + common::get_audio_file_path() + .to_string_lossy() + .into_owned() + } + + // ── Non-streaming transcription ────────────────────────────────────── + + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_transcribe_audio_without_streaming() { + let client = setup_audio_client().await; + let response = client + .transcribe(&audio_file()) + .await + .expect("transcribe failed"); + + assert!( + response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Transcription should contain expected text, got: {}", response.text + ); + } + + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_transcribe_audio_without_streaming_with_temperature() { + let mut client = setup_audio_client().await; + client.language("en").temperature(0.0); + + let response = client + .transcribe(&audio_file()) + .await + .expect("transcribe with temperature failed"); + + assert!( + response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Transcription should contain expected text, got: {}", response.text + ); + } + + // ── Streaming transcription ────────────────────────────────────────── + + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_transcribe_audio_with_streaming() { + let client = setup_audio_client().await; + let mut full_text = String::new(); + + let mut stream = client + .transcribe_streaming(&audio_file()) + .await + .expect("transcribe_streaming setup failed"); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + full_text.push_str(&chunk.text); + } + stream.close().await.expect("stream close failed"); + + assert!( + full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Streamed transcription should contain expected text, got: {full_text}" + ); + } + + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_transcribe_audio_with_streaming_with_temperature() { + let mut client = setup_audio_client().await; + client.language("en").temperature(0.0); + + let mut full_text = String::new(); + + let mut stream = client + .transcribe_streaming(&audio_file()) + .await + .expect("transcribe_streaming with temperature setup failed"); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + full_text.push_str(&chunk.text); + } + stream.close().await.expect("stream close failed"); + + assert!( + full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Streamed transcription should contain expected text, got: {full_text}" + ); + } + + // ── Validation: empty file path ────────────────────────────────────── + + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_throw_when_transcribing_with_empty_audio_file_path() { + let client = setup_audio_client().await; + let result = client.transcribe("").await; + assert!(result.is_err(), "Expected error for empty audio file path"); + } + + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_throw_when_transcribing_streaming_with_empty_audio_file_path() { + let client = setup_audio_client().await; + let result = client.transcribe_streaming("").await; + assert!( + result.is_err(), + "Expected error for empty audio file path in streaming" + ); + } +} diff --git a/sdk_v2/rust/tests/catalog_test.rs b/sdk_v2/rust/tests/catalog_test.rs new file mode 100644 index 00000000..30fa10e5 --- /dev/null +++ b/sdk_v2/rust/tests/catalog_test.rs @@ -0,0 +1,137 @@ +//! Integration tests for the [`Catalog`] API. +//! +//! Mirrors `catalog.test.ts` from the JavaScript SDK. + +mod common; + +use foundry_local_sdk::Catalog; + +mod tests { + use super::*; + + // ── Helpers ────────────────────────────────────────────────────────── + + fn catalog() -> &'static Catalog { + common::get_test_manager().catalog() + } + + // ── Basic catalogue access ─────────────────────────────────────────── + + /// The catalog should expose a non-empty name after initialisation. + #[test] + #[ignore = "requires native Foundry Local library"] + fn should_initialize_with_catalog_name() { + let cat = catalog(); + let name = cat.name(); + assert!(!name.is_empty(), "Catalog name must not be empty"); + } + + /// `list_models()` should return at least one model and the test model + /// should be present among them. + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_list_models() { + let cat = catalog(); + let models = cat.get_models().await.expect("get_models failed"); + + assert!(!models.is_empty(), "Expected at least one model in the catalog"); + + let found = models.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); + assert!( + found, + "Test model '{}' not found in catalog", + common::TEST_MODEL_ALIAS + ); + } + + /// `get_model()` with a valid alias should return the corresponding model. + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_get_model_by_alias() { + let cat = catalog(); + let model = cat + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); + } + + // ── Validation: empty / invalid alias ──────────────────────────────── + + /// `get_model("")` should return an error containing + /// "Model alias must be a non-empty string". + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_throw_when_getting_model_with_empty_alias() { + let cat = catalog(); + let result = cat.get_model("").await; + assert!(result.is_err(), "Expected error for empty alias"); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("Model alias must be a non-empty string"), + "Unexpected error message: {err_msg}" + ); + } + + /// An unknown alias should produce an error mentioning "not found" and + /// listing available models. + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_throw_when_getting_model_with_unknown_alias() { + let cat = catalog(); + let result = cat.get_model("unknown-nonexistent-model-alias").await; + assert!(result.is_err(), "Expected error for unknown alias"); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("not found"), + "Error should mention 'not found': {err_msg}" + ); + assert!( + err_msg.contains("Available models"), + "Error should list available models: {err_msg}" + ); + } + + // ── Cached models ──────────────────────────────────────────────────── + + /// `get_cached_models()` should return at least one model and the test + /// model should be cached. + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_get_cached_models() { + let cat = catalog(); + let cached = cat.get_cached_models().await.expect("get_cached_models failed"); + + assert!(!cached.is_empty(), "Expected at least one cached model"); + + let found = cached.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); + assert!( + found, + "Test model '{}' should be in the cached models list", + common::TEST_MODEL_ALIAS + ); + } + + // ── Model variant validation ───────────────────────────────────────── + + /// `get_model_variant("")` should return a validation error. + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_throw_when_getting_model_variant_with_empty_id() { + let cat = catalog(); + let result = cat.get_model_variant("").await; + assert!(result.is_err(), "Expected error for empty variant ID"); + } + + /// `get_model_variant()` with an unknown ID should return an error. + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_throw_when_getting_model_variant_with_unknown_id() { + let cat = catalog(); + let result = cat.get_model_variant("unknown-nonexistent-variant-id").await; + assert!(result.is_err(), "Expected error for unknown variant ID"); + } +} diff --git a/sdk_v2/rust/tests/chat_client_test.rs b/sdk_v2/rust/tests/chat_client_test.rs new file mode 100644 index 00000000..c4350ca8 --- /dev/null +++ b/sdk_v2/rust/tests/chat_client_test.rs @@ -0,0 +1,350 @@ +//! Integration tests for the [`ChatClient`] (non-streaming, streaming, and +//! tool-calling variants). +//! +//! Mirrors `chatClient.test.ts` from the JavaScript SDK. + +mod common; + +use foundry_local_sdk::{ + ChatCompletionMessageToolCalls, ChatCompletionRequestMessage, + ChatCompletionRequestSystemMessage, ChatCompletionRequestToolMessage, + ChatCompletionRequestUserMessage, ChatToolChoice, +}; +use foundry_local_sdk::openai::ChatClient; +use serde_json::json; +use tokio_stream::StreamExt; + +mod tests { + use super::*; + + // ── Helpers ────────────────────────────────────────────────────────── + + /// Load the test model and return a [`ChatClient`] ready for use. + async fn setup_chat_client() -> ChatClient { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + model.load().await.expect("model.load() failed"); + + let mut client = model.create_chat_client(); + client.max_tokens(500).temperature(0.0); + client + } + + fn user_message(content: &str) -> ChatCompletionRequestMessage { + ChatCompletionRequestUserMessage::from(content).into() + } + + fn system_message(content: &str) -> ChatCompletionRequestMessage { + ChatCompletionRequestSystemMessage::from(content).into() + } + + fn assistant_message(content: &str) -> ChatCompletionRequestMessage { + serde_json::from_value(json!({ "role": "assistant", "content": content })) + .expect("failed to construct assistant message") + } + + // ── Non-streaming completion ───────────────────────────────────────── + + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_perform_chat_completion() { + let client = setup_chat_client().await; + let messages = vec![ + system_message("You are a helpful math assistant. Respond with just the answer."), + user_message("What is 7*6?"), + ]; + + let response = client.complete_chat(&messages, None).await.expect("complete_chat failed"); + let content = response.choices.first() + .and_then(|c| c.message.content.as_deref()) + .unwrap_or(""); + + assert!( + content.contains("42"), + "Expected response to contain '42', got: {content}" + ); + } + + // ── Streaming completion ───────────────────────────────────────────── + + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_perform_streaming_chat_completion() { + let client = setup_chat_client().await; + let mut messages = vec![ + system_message("You are a helpful math assistant. Respond with just the answer."), + user_message("What is 7*6?"), + ]; + + // First turn — expect "42" + let mut first_result = String::new(); + let mut stream = client + .complete_streaming_chat(&messages, None) + .await + .expect("streaming chat (first turn) setup failed"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + first_result.push_str(content); + } + } + } + stream.close().await.expect("stream close failed"); + + assert!( + first_result.contains("42"), + "First turn should contain '42', got: {first_result}" + ); + + // Follow-up turn — expect "67" + messages.push(assistant_message(&first_result)); + messages.push(user_message("Now add 25 to that result.")); + + let mut second_result = String::new(); + let mut stream = client + .complete_streaming_chat(&messages, None) + .await + .expect("streaming chat (follow-up) setup failed"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + second_result.push_str(content); + } + } + } + stream.close().await.expect("stream close failed"); + + assert!( + second_result.contains("67"), + "Follow-up should contain '67', got: {second_result}" + ); + } + + // ── Validation: empty / invalid messages ───────────────────────────── + + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_throw_when_completing_chat_with_empty_messages() { + let client = setup_chat_client().await; + let messages: Vec = vec![]; + + let result = client.complete_chat(&messages, None).await; + assert!(result.is_err(), "Expected error for empty messages"); + } + + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_throw_when_completing_streaming_chat_with_empty_messages() { + let client = setup_chat_client().await; + let messages: Vec = vec![]; + + let result = client.complete_streaming_chat(&messages, None).await; + assert!(result.is_err(), "Expected error for empty messages in streaming"); + } + + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_throw_when_completing_streaming_chat_with_invalid_callback() { + let client = setup_chat_client().await; + let messages: Vec = vec![]; + + let result = client.complete_streaming_chat(&messages, None).await; + assert!( + result.is_err(), + "Expected error even with empty messages" + ); + } + + // ── Tool calling (non-streaming) ───────────────────────────────────── + + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_perform_tool_calling_chat_completion_non_streaming() { + let mut client = setup_chat_client().await; + client.tool_choice(ChatToolChoice::Required); + + let tools = vec![common::get_multiply_tool()]; + let mut messages = vec![ + system_message("You are a math assistant. Use the multiply tool to answer."), + user_message("What is 6 times 7?"), + ]; + + // Step 1 — the model should request the multiply tool. + let response = client + .complete_chat(&messages, Some(&tools)) + .await + .expect("complete_chat with tools failed"); + + let choice = response.choices.first().expect("Expected at least one choice"); + let tool_calls = choice.message.tool_calls.as_ref().expect("Expected tool_calls"); + assert!( + !tool_calls.is_empty(), + "Expected at least one tool call in the response" + ); + + let tool_call = match &tool_calls[0] { + ChatCompletionMessageToolCalls::Function(tc) => tc, + _ => panic!("Expected a function tool call"), + }; + assert_eq!( + tool_call.function.name, + "multiply", + "Expected tool call to 'multiply'" + ); + + // Parse arguments and compute the result. + let args: serde_json::Value = + serde_json::from_str(&tool_call.function.arguments) + .expect("Failed to parse tool call arguments"); + let a = args["a"].as_f64().unwrap_or(0.0); + let b = args["b"].as_f64().unwrap_or(0.0); + let product = (a * b) as i64; + + // Step 2 — feed the tool result back and get the final answer. + let tool_call_id = &tool_call.id; + let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": tool_call_id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + } + }] + })).expect("failed to construct assistant message"); + messages.push(assistant_msg); + messages.push( + ChatCompletionRequestToolMessage { + content: product.to_string().into(), + tool_call_id: tool_call_id.clone(), + } + .into(), + ); + + // Switch to auto so the model can answer freely. + client.tool_choice(ChatToolChoice::Auto); + + let final_response = client + .complete_chat(&messages, Some(&tools)) + .await + .expect("follow-up complete_chat with tools failed"); + let content = final_response.choices.first() + .and_then(|c| c.message.content.as_deref()) + .unwrap_or(""); + + assert!( + content.contains("42"), + "Final answer should contain '42', got: {content}" + ); + } + + // ── Tool calling (streaming) ───────────────────────────────────────── + + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_perform_tool_calling_chat_completion_streaming() { + let mut client = setup_chat_client().await; + client.tool_choice(ChatToolChoice::Required); + + let tools = vec![common::get_multiply_tool()]; + let mut messages = vec![ + system_message("You are a math assistant. Use the multiply tool to answer."), + user_message("What is 6 times 7?"), + ]; + + // Step 1 — collect streaming tool call chunks. + let mut tool_call_name = String::new(); + let mut tool_call_args = String::new(); + let mut tool_call_id = String::new(); + + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await + .expect("streaming tool call setup failed"); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref tool_calls) = choice.delta.tool_calls { + for call in tool_calls { + if let Some(ref func) = call.function { + if let Some(ref name) = func.name { + tool_call_name.push_str(name); + } + if let Some(ref args) = func.arguments { + tool_call_args.push_str(args); + } + } + if let Some(ref id) = call.id { + tool_call_id = id.clone(); + } + } + } + } + } + stream.close().await.expect("stream close failed"); + + assert_eq!( + tool_call_name, "multiply", + "Expected streamed tool call to 'multiply'" + ); + + // Parse arguments and compute. + let args: serde_json::Value = + serde_json::from_str(&tool_call_args).unwrap_or_else(|_| json!({})); + let a = args["a"].as_f64().unwrap_or(0.0); + let b = args["b"].as_f64().unwrap_or(0.0); + let product = (a * b) as i64; + + // Step 2 — feed the result back and stream the final answer. + let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ + "role": "assistant", + "tool_calls": [{ + "id": tool_call_id, + "type": "function", + "function": { + "name": tool_call_name, + "arguments": tool_call_args + } + }] + })).expect("failed to construct assistant message"); + messages.push(assistant_msg); + messages.push( + ChatCompletionRequestToolMessage { + content: product.to_string().into(), + tool_call_id: tool_call_id.clone(), + } + .into(), + ); + + client.tool_choice(ChatToolChoice::Auto); + + let mut final_result = String::new(); + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await + .expect("streaming follow-up setup failed"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + final_result.push_str(content); + } + } + } + stream.close().await.expect("stream close failed"); + + assert!( + final_result.contains("42"), + "Streamed final answer should contain '42', got: {final_result}" + ); + } +} diff --git a/sdk_v2/rust/tests/common/mod.rs b/sdk_v2/rust/tests/common/mod.rs new file mode 100644 index 00000000..9a42e53d --- /dev/null +++ b/sdk_v2/rust/tests/common/mod.rs @@ -0,0 +1,127 @@ +//! Shared test utilities and configuration for Foundry Local SDK integration tests. +//! +//! Mirrors `testUtils.ts` from the JavaScript SDK test suite. + +use std::collections::HashMap; +use std::path::PathBuf; + +use foundry_local_sdk::{FoundryLocalConfig, FoundryLocalManager, LogLevel}; + +/// Default model alias used for chat-completion integration tests. +pub const TEST_MODEL_ALIAS: &str = "qwen2.5-0.5b"; + +/// Default model alias used for audio-transcription integration tests. +pub const WHISPER_MODEL_ALIAS: &str = "whisper-tiny"; + +/// Expected transcription text fragment for the shared audio test file. +pub const EXPECTED_TRANSCRIPTION_TEXT: &str = + " And lots of times you need to give people more than one link at a time"; + +// ── Environment helpers ────────────────────────────────────────────────────── + +/// Returns `true` when the tests are running inside a CI environment +/// (Azure DevOps or GitHub Actions). +pub fn is_running_in_ci() -> bool { + let azure_devops = std::env::var("TF_BUILD").unwrap_or_else(|_| "false".into()); + let github_actions = std::env::var("GITHUB_ACTIONS").unwrap_or_else(|_| "false".into()); + azure_devops.eq_ignore_ascii_case("true") || github_actions.eq_ignore_ascii_case("true") +} + +/// Walk upward from `CARGO_MANIFEST_DIR` until a `.git` directory is found. +pub fn get_git_repo_root() -> PathBuf { + let mut current = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + loop { + if current.join(".git").exists() { + return current; + } + if !current.pop() { + panic!( + "Could not locate git repo root starting from {}", + env!("CARGO_MANIFEST_DIR") + ); + } + } +} + +/// Path to the shared test-data directory that lives alongside the repo root. +pub fn get_test_data_shared_path() -> PathBuf { + let repo_root = get_git_repo_root(); + repo_root + .parent() + .expect("repo root has no parent") + .join("test-data-shared") +} + +/// Path to the shared audio test file used by audio-client tests. +pub fn get_audio_file_path() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("..") + .join("testdata") + .join("Recording.mp3") +} + +// ── Test configuration ─────────────────────────────────────────────────────── + +/// Build a [`FoundryLocalConfig`] suitable for integration tests. +/// +/// * `modelCacheDir` → `/../test-data-shared` +/// * `logsDir` → `/sdk_v2/rust/logs` +/// * `logLevel` → `Warn` +/// * `Bootstrap` → `false` (via additional settings) +pub fn test_config() -> FoundryLocalConfig { + let repo_root = get_git_repo_root(); + let logs_dir = repo_root.join("sdk_v2").join("rust").join("logs"); + + let mut additional = HashMap::new(); + additional.insert("Bootstrap".into(), "false".into()); + + FoundryLocalConfig { + app_name: "FoundryLocalTest".into(), + app_data_dir: None, + model_cache_dir: Some(get_test_data_shared_path().to_string_lossy().into_owned()), + logs_dir: Some(logs_dir.to_string_lossy().into_owned()), + log_level: Some(LogLevel::Warn), + web_service_urls: None, + service_endpoint: None, + library_path: None, + additional_settings: Some(additional), + } +} + +/// Create (or return the cached) [`FoundryLocalManager`] for tests. +/// +/// Panics if creation fails so that test set-up failures are immediately +/// visible. +pub fn get_test_manager() -> &'static FoundryLocalManager { + FoundryLocalManager::create(test_config()).expect("Failed to create FoundryLocalManager") +} + +// ── Tool definitions ───────────────────────────────────────────────────────── + +/// Returns a tool definition for a simple "multiply" function. +/// +/// Used by tool-calling chat-completion tests. +pub fn get_multiply_tool() -> foundry_local_sdk::ChatCompletionTools { + serde_json::from_value(serde_json::json!({ + "type": "function", + "function": { + "name": "multiply", + "description": "Multiply two numbers together", + "parameters": { + "type": "object", + "properties": { + "a": { + "type": "number", + "description": "The first number" + }, + "b": { + "type": "number", + "description": "The second number" + } + }, + "required": ["a", "b"] + } + } + })) + .expect("Failed to parse multiply tool definition") +} diff --git a/sdk_v2/rust/tests/manager_test.rs b/sdk_v2/rust/tests/manager_test.rs new file mode 100644 index 00000000..d05abd1f --- /dev/null +++ b/sdk_v2/rust/tests/manager_test.rs @@ -0,0 +1,33 @@ +//! Integration tests for [`FoundryLocalManager`] initialisation. +//! +//! Mirrors `foundryLocalManager.test.ts` from the JavaScript SDK. + +mod common; + +use foundry_local_sdk::FoundryLocalManager; + +mod tests { + use super::*; + + // ── Initialisation ─────────────────────────────────────────────────── + + /// The manager should initialise successfully with the test configuration. + #[test] + #[ignore = "requires native Foundry Local library"] + fn should_initialize_successfully() { + let config = common::test_config(); + let manager = FoundryLocalManager::create(config); + assert!(manager.is_ok(), "Manager creation failed: {:?}", manager.err()); + } + + /// The catalog obtained from a freshly-created manager should have a + /// non-empty name. + #[test] + #[ignore = "requires native Foundry Local library"] + fn should_return_catalog_with_non_empty_name() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let name = catalog.name(); + assert!(!name.is_empty(), "Catalog name should not be empty"); + } +} diff --git a/sdk_v2/rust/tests/model_load_manager_test.rs b/sdk_v2/rust/tests/model_load_manager_test.rs new file mode 100644 index 00000000..161c7c62 --- /dev/null +++ b/sdk_v2/rust/tests/model_load_manager_test.rs @@ -0,0 +1,149 @@ +//! Integration tests for model loading and unloading through the public API. +//! +//! Mirrors `modelLoadManager.test.ts` from the JavaScript SDK. +//! +//! **Note:** In the JavaScript SDK these tests access the private +//! `coreInterop` property via an `as any` cast. In Rust, `CoreInterop` and +//! `ModelLoadManager::new` are `pub(crate)` and cannot be reached from +//! integration tests. Instead, we exercise model loading and unloading +//! through the public [`Model`] and [`Catalog`] APIs which internally +//! delegate to `ModelLoadManager`. + +mod common; + +mod tests { + use super::*; + + // ── Helpers ────────────────────────────────────────────────────────── + + /// Return the test model from the catalog. + async fn get_test_model() -> foundry_local_sdk::Model { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed") + } + + // ── Core-interop path ──────────────────────────────────────────────── + + /// Loading a model via the core interop (in-process) path should succeed. + /// + /// Timeout note: the JS test uses a 120 s timeout. + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_load_model_using_core_interop() { + let model = get_test_model().await; + + model + .load() + .await + .expect("model.load() failed"); + } + + /// Unloading a previously loaded model via the core interop path should + /// succeed. + /// + /// Timeout note: the JS test uses a 120 s timeout. + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_unload_model_using_core_interop() { + let model = get_test_model().await; + + // Ensure the model is loaded first. + model + .load() + .await + .expect("model.load() failed"); + + model + .unload() + .await + .expect("model.unload() failed"); + } + + /// Listing loaded models via the core interop path should return a + /// collection (possibly empty, but the call itself must succeed). + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_list_loaded_models_using_core_interop() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + + let loaded = catalog + .get_loaded_models() + .await + .expect("catalog.get_loaded_models() failed"); + + // The result should be a valid (possibly empty) list of model IDs. + // (Vec is always valid; just ensure the call succeeded.) + let _ = loaded; + } + + // ── External web-service path ──────────────────────────────────────── + + /// Loading and unloading a model through the external HTTP service should + /// succeed. + /// + /// This test is skipped in CI because it requires a running web service. + /// + /// Timeout note: the JS test uses a 120 s timeout. + #[tokio::test] + #[ignore = "requires native Foundry Local library and running web service"] + async fn should_load_and_unload_model_using_external_service() { + if common::is_running_in_ci() { + eprintln!("Skipping external-service test in CI"); + return; + } + + let manager = common::get_test_manager(); + let model = get_test_model().await; + + // Start the web service so we can test the HTTP path. + let _urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + + // Load via the model API (delegates to ModelLoadManager internally). + model + .load() + .await + .expect("load via external service failed"); + + // Unload + model + .unload() + .await + .expect("unload via external service failed"); + } + + /// Listing loaded models through the external HTTP service should succeed. + /// + /// This test is skipped in CI because it requires a running web service. + #[tokio::test] + #[ignore = "requires native Foundry Local library and running web service"] + async fn should_list_loaded_models_using_external_service() { + if common::is_running_in_ci() { + eprintln!("Skipping external-service test in CI"); + return; + } + + let manager = common::get_test_manager(); + + let _urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + + let catalog = manager.catalog(); + let loaded = catalog + .get_loaded_models() + .await + .expect("get_loaded_models via external service failed"); + + // Vec is always a valid list; just ensure the call succeeded. + let _ = loaded; + } +} diff --git a/sdk_v2/rust/tests/model_test.rs b/sdk_v2/rust/tests/model_test.rs new file mode 100644 index 00000000..5238bc2a --- /dev/null +++ b/sdk_v2/rust/tests/model_test.rs @@ -0,0 +1,70 @@ +//! Integration tests for the [`Model`] lifecycle (cache verification, +//! load / unload). +//! +//! Mirrors `model.test.ts` from the JavaScript SDK. + +mod common; + + +mod tests { + use super::*; + + // ── Cache verification ─────────────────────────────────────────────── + + /// The shared test-data directory should contain pre-cached models for + /// both `qwen2.5-0.5b` and `whisper-tiny`. + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_verify_cached_models_from_test_data_shared() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let cached = catalog.get_cached_models().await.expect("get_cached_models failed"); + + // qwen2.5-0.5b must be cached + let has_qwen = cached.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); + assert!( + has_qwen, + "'{}' should be present in cached models", + common::TEST_MODEL_ALIAS + ); + + // whisper-tiny must be cached + let has_whisper = cached.iter().any(|m| m.alias() == common::WHISPER_MODEL_ALIAS); + assert!( + has_whisper, + "'{}' should be present in cached models", + common::WHISPER_MODEL_ALIAS + ); + } + + // ── Load / unload lifecycle ────────────────────────────────────────── + + /// Loading a model should mark it as loaded; unloading should mark it as + /// not loaded. + /// + /// Timeout note: the JS test uses a 120 s timeout for this test. + #[tokio::test] + #[ignore = "requires native Foundry Local library"] + async fn should_load_and_unload_model() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + // Load + model.load().await.expect("model.load() failed"); + assert!( + model.is_loaded().await.expect("is_loaded check failed"), + "Model should be loaded after load()" + ); + + // Unload + model.unload().await.expect("model.unload() failed"); + assert!( + !model.is_loaded().await.expect("is_loaded check failed"), + "Model should not be loaded after unload()" + ); + } +} From 5a24e71e8741e717b183da3f5909c3742d768f5e Mon Sep 17 00:00:00 2001 From: samuel100 Date: Mon, 9 Mar 2026 17:19:37 +0000 Subject: [PATCH 02/25] fix formatting --- .../tool-calling-foundry-local/src/main.rs | 4 +- sdk_v2/rust/build.rs | 31 ++++++------ sdk_v2/rust/examples/interactive_chat.rs | 10 ++-- sdk_v2/rust/examples/tool_calling.rs | 18 ++----- sdk_v2/rust/src/catalog.rs | 5 +- sdk_v2/rust/src/detail/core_interop.rs | 48 ++++++++++-------- sdk_v2/rust/src/foundry_local_manager.rs | 13 ++--- sdk_v2/rust/src/lib.rs | 49 +++++++----------- sdk_v2/rust/src/openai/audio_client.rs | 19 +++---- sdk_v2/rust/src/openai/mod.rs | 6 ++- sdk_v2/rust/tests/audio_client_test.rs | 10 ++-- sdk_v2/rust/tests/catalog_test.rs | 14 ++++-- sdk_v2/rust/tests/chat_client_test.rs | 50 ++++++++++++------- sdk_v2/rust/tests/manager_test.rs | 6 ++- sdk_v2/rust/tests/model_load_manager_test.rs | 15 ++---- sdk_v2/rust/tests/model_test.rs | 10 ++-- 16 files changed, 155 insertions(+), 153 deletions(-) diff --git a/samples/rust/tool-calling-foundry-local/src/main.rs b/samples/rust/tool-calling-foundry-local/src/main.rs index 919fe33e..43ff0a18 100644 --- a/samples/rust/tool-calling-foundry-local/src/main.rs +++ b/samples/rust/tool-calling-foundry-local/src/main.rs @@ -75,9 +75,7 @@ async fn main() -> Result<(), Box> { // ── 3. Create a chat client with tool_choice = required ────────────── let mut client = model.create_chat_client(); - client - .max_tokens(512) - .tool_choice(ChatToolChoice::Required); + client.max_tokens(512).tool_choice(ChatToolChoice::Required); // Define the multiply_numbers tool. let tools: Vec = serde_json::from_value(json!([{ diff --git a/sdk_v2/rust/build.rs b/sdk_v2/rust/build.rs index 7e4db4ca..d51c5805 100644 --- a/sdk_v2/rust/build.rs +++ b/sdk_v2/rust/build.rs @@ -4,7 +4,8 @@ use std::io::{self, Read}; use std::path::{Path, PathBuf}; const NUGET_FEED: &str = "https://api.nuget.org/v3/index.json"; -const ORT_NIGHTLY_FEED: &str = "https://pkgs.dev.azure.com/aiinfra/PublicPackages/_packaging/ORT-Nightly/nuget/v3/index.json"; +const ORT_NIGHTLY_FEED: &str = + "https://pkgs.dev.azure.com/aiinfra/PublicPackages/_packaging/ORT-Nightly/nuget/v3/index.json"; const CORE_VERSION: &str = "0.9.0.8-rc3"; const ORT_VERSION: &str = "1.24.3"; @@ -116,8 +117,8 @@ fn resolve_base_address(feed_url: &str) -> Result { .read_to_string() .map_err(|e| format!("Failed to read feed index response: {e}"))?; - let index: serde_json::Value = serde_json::from_str(&body) - .map_err(|e| format!("Failed to parse feed index JSON: {e}"))?; + let index: serde_json::Value = + serde_json::from_str(&body).map_err(|e| format!("Failed to parse feed index JSON: {e}"))?; let resources = index["resources"] .as_array() @@ -161,22 +162,22 @@ fn resolve_latest_version(package_name: &str, feed_url: &str) -> Option } /// Download a .nupkg and extract native libraries for the given RID into `out_dir`. -fn download_and_extract( - pkg: &NuGetPackage, - rid: &str, - out_dir: &Path, -) -> Result<(), String> { +fn download_and_extract(pkg: &NuGetPackage, rid: &str, out_dir: &Path) -> Result<(), String> { let base_address = resolve_base_address(pkg.feed_url)?; let lower_name = pkg.name.to_lowercase(); let lower_version = pkg.version.to_lowercase(); - let url = format!( - "{base_address}{lower_name}/{lower_version}/{lower_name}.{lower_version}.nupkg" - ); + let url = + format!("{base_address}{lower_name}/{lower_version}/{lower_name}.{lower_version}.nupkg"); - println!("cargo:warning=Downloading {name} {ver} from {feed}", + println!( + "cargo:warning=Downloading {name} {ver} from {feed}", name = pkg.name, ver = pkg.version, - feed = if pkg.feed_url == NUGET_FEED { "NuGet.org" } else { "ORT-Nightly" }, + feed = if pkg.feed_url == NUGET_FEED { + "NuGet.org" + } else { + "ORT-Nightly" + }, ); let mut response = ureq::get(&url) @@ -286,7 +287,9 @@ fn main() { if let Err(e) = download_and_extract(pkg, rid, &out_dir) { println!("cargo:warning=Error downloading {}: {e}", pkg.name); println!("cargo:warning=Build will continue, but runtime loading may fail."); - println!("cargo:warning=You can manually place native libraries in the output directory."); + println!( + "cargo:warning=You can manually place native libraries in the output directory." + ); } } diff --git a/sdk_v2/rust/examples/interactive_chat.rs b/sdk_v2/rust/examples/interactive_chat.rs index aabdae35..f2c7911f 100644 --- a/sdk_v2/rust/examples/interactive_chat.rs +++ b/sdk_v2/rust/examples/interactive_chat.rs @@ -5,9 +5,9 @@ use std::io::{self, Write}; use foundry_local_sdk::{ - ChatCompletionRequestMessage, ChatCompletionRequestAssistantMessage, - ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, - FoundryLocalConfig, FoundryLocalManager, + ChatCompletionRequestAssistantMessage, ChatCompletionRequestMessage, + ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, FoundryLocalConfig, + FoundryLocalManager, }; use tokio_stream::StreamExt; @@ -44,9 +44,7 @@ async fn main() -> Result<(), Box> { // Download if needed if !model.is_cached().await? { println!("Downloading '{alias}'…"); - model - .download(Some(|p: &str| print!("\r {p}%"))) - .await?; + model.download(Some(|p: &str| print!("\r {p}%"))).await?; println!(); } diff --git a/sdk_v2/rust/examples/tool_calling.rs b/sdk_v2/rust/examples/tool_calling.rs index 5fd8bf51..6acdeb76 100644 --- a/sdk_v2/rust/examples/tool_calling.rs +++ b/sdk_v2/rust/examples/tool_calling.rs @@ -25,14 +25,8 @@ fn multiply(a: f64, b: f64) -> f64 { fn invoke_tool(name: &str, arguments: &Value) -> Result { match name { "multiply" => { - let a = arguments - .get("a") - .and_then(|v| v.as_f64()) - .unwrap_or(0.0); - let b = arguments - .get("b") - .and_then(|v| v.as_f64()) - .unwrap_or(0.0); + let a = arguments.get("a").and_then(|v| v.as_f64()).unwrap_or(0.0); + let b = arguments.get("b").and_then(|v| v.as_f64()).unwrap_or(0.0); let result = multiply(a, b); Ok(result.to_string()) } @@ -68,18 +62,14 @@ async fn main() -> Result<()> { if !model.is_cached().await? { println!("Downloading model '{}'…", model.alias()); - model - .download(Some(|p: &str| println!(" {p}"))) - .await?; + model.download(Some(|p: &str| println!(" {p}"))).await?; } println!("Loading model '{}'…", model.alias()); model.load().await?; // ── 3. Create a chat client with tool_choice = required ────────────── let mut client = model.create_chat_client(); - client - .tool_choice(ChatToolChoice::Required) - .max_tokens(512); + client.tool_choice(ChatToolChoice::Required).max_tokens(512); let tools: Vec = serde_json::from_value(json!([{ "type": "function", diff --git a/sdk_v2/rust/src/catalog.rs b/sdk_v2/rust/src/catalog.rs index 7d253471..2061ba76 100644 --- a/sdk_v2/rust/src/catalog.rs +++ b/sdk_v2/rust/src/catalog.rs @@ -25,7 +25,10 @@ pub struct Catalog { } impl Catalog { - pub(crate) fn new(core: Arc, model_load_manager: Arc) -> Result { + pub(crate) fn new( + core: Arc, + model_load_manager: Arc, + ) -> Result { let name = core .execute_command("get_catalog_name", None) .unwrap_or_else(|_| "default".into()); diff --git a/sdk_v2/rust/src/detail/core_interop.rs b/sdk_v2/rust/src/detail/core_interop.rs index 7223fc63..46b1e261 100644 --- a/sdk_v2/rust/src/detail/core_interop.rs +++ b/sdk_v2/rust/src/detail/core_interop.rs @@ -55,8 +55,12 @@ type ExecuteCommandFn = unsafe extern "C" fn(*const RequestBuffer, *mut Response type CallbackFn = unsafe extern "C" fn(*const u8, i32, *mut std::ffi::c_void); /// Signature for `execute_command_with_callback`. -type ExecuteCommandWithCallbackFn = - unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer, CallbackFn, *mut std::ffi::c_void); +type ExecuteCommandWithCallbackFn = unsafe extern "C" fn( + *const RequestBuffer, + *mut ResponseBuffer, + CallbackFn, + *mut std::ffi::c_void, +); // ── Library name helpers ───────────────────────────────────────────────────── @@ -124,8 +128,12 @@ pub(crate) struct CoreInterop { #[cfg(target_os = "windows")] _dependency_libs: Vec, execute_command: unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer), - execute_command_with_callback: - unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer, CallbackFn, *mut std::ffi::c_void), + execute_command_with_callback: unsafe extern "C" fn( + *const RequestBuffer, + *mut ResponseBuffer, + CallbackFn, + *mut std::ffi::c_void, + ), } impl std::fmt::Debug for CoreInterop { @@ -134,7 +142,6 @@ impl std::fmt::Debug for CoreInterop { } } - impl CoreInterop { /// Load the native core library using the provided configuration to locate /// it on disk. @@ -159,18 +166,16 @@ impl CoreInterop { }; let execute_command: ExecuteCommandFn = unsafe { - let sym: Symbol = - library.get(b"execute_command\0").map_err(|e| { - FoundryLocalError::LibraryLoad(format!( - "Symbol 'execute_command' not found: {e}" - )) - })?; + let sym: Symbol = library.get(b"execute_command\0").map_err(|e| { + FoundryLocalError::LibraryLoad(format!("Symbol 'execute_command' not found: {e}")) + })?; *sym }; let execute_command_with_callback: ExecuteCommandWithCallbackFn = unsafe { - let sym: Symbol = - library.get(b"execute_command_with_callback\0").map_err(|e| { + let sym: Symbol = library + .get(b"execute_command_with_callback\0") + .map_err(|e| { FoundryLocalError::LibraryLoad(format!( "Symbol 'execute_command_with_callback' not found: {e}" )) @@ -282,11 +287,9 @@ impl CoreInterop { params: Option, ) -> Result { let this = Arc::clone(self); - tokio::task::spawn_blocking(move || { - this.execute_command(&command, params.as_ref()) - }) - .await - .map_err(|e| FoundryLocalError::CommandExecution(format!("task join error: {e}")))? + tokio::task::spawn_blocking(move || this.execute_command(&command, params.as_ref())) + .await + .map_err(|e| FoundryLocalError::CommandExecution(format!("task join error: {e}")))? } /// Async version of [`Self::execute_command_streaming`]. @@ -321,7 +324,10 @@ impl CoreInterop { self: &Arc, command: String, params: Option, - ) -> Result<(tokio::sync::mpsc::UnboundedReceiver, tokio::task::JoinHandle>)> { + ) -> Result<( + tokio::sync::mpsc::UnboundedReceiver, + tokio::task::JoinHandle>, + )> { let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); let this = Arc::clone(self); @@ -426,9 +432,7 @@ impl CoreInterop { /// resolve them. #[cfg(target_os = "windows")] fn load_windows_dependencies(core_lib_path: &Path) -> Result> { - let dir = core_lib_path - .parent() - .unwrap_or_else(|| Path::new(".")); + let dir = core_lib_path.parent().unwrap_or_else(|| Path::new(".")); let mut libs = Vec::new(); diff --git a/sdk_v2/rust/src/foundry_local_manager.rs b/sdk_v2/rust/src/foundry_local_manager.rs index a3d38e23..0ec10cac 100644 --- a/sdk_v2/rust/src/foundry_local_manager.rs +++ b/sdk_v2/rust/src/foundry_local_manager.rs @@ -49,15 +49,10 @@ impl FoundryLocalManager { let init_params = json!({ "Params": internal_config.params }); core.execute_command("initialize", Some(&init_params))?; - let service_endpoint = internal_config - .params - .get("WebServiceExternalUrl") - .cloned(); - - let model_load_manager = Arc::new(ModelLoadManager::new( - Arc::clone(&core), - service_endpoint, - )); + let service_endpoint = internal_config.params.get("WebServiceExternalUrl").cloned(); + + let model_load_manager = + Arc::new(ModelLoadManager::new(Arc::clone(&core), service_endpoint)); let catalog = Catalog::new(Arc::clone(&core), Arc::clone(&model_load_manager))?; diff --git a/sdk_v2/rust/src/lib.rs b/sdk_v2/rust/src/lib.rs index a6472175..df7e7cef 100644 --- a/sdk_v2/rust/src/lib.rs +++ b/sdk_v2/rust/src/lib.rs @@ -2,55 +2,42 @@ //! //! Local AI model inference powered by the Foundry Local Core engine. -mod error; -mod types; +mod catalog; mod configuration; +mod error; mod foundry_local_manager; -mod catalog; mod model; mod model_variant; +mod types; pub(crate) mod detail; pub mod openai; -pub use error::FoundryLocalError; -pub use types::*; +pub use catalog::Catalog; pub use configuration::{FoundryLocalConfig, LogLevel}; +pub use detail::ModelLoadManager; +pub use error::FoundryLocalError; pub use foundry_local_manager::FoundryLocalManager; -pub use catalog::Catalog; pub use model::Model; pub use model_variant::ModelVariant; -pub use detail::ModelLoadManager; +pub use types::*; // Re-export OpenAI request types so callers can construct typed messages. pub use async_openai::types::chat::{ - ChatCompletionRequestMessage, - ChatCompletionRequestSystemMessage, - ChatCompletionRequestUserMessage, - ChatCompletionRequestAssistantMessage, - ChatCompletionRequestToolMessage, - ChatCompletionTools, - ChatCompletionToolChoiceOption, - ChatCompletionNamedToolChoice, - FunctionObject, + ChatCompletionNamedToolChoice, ChatCompletionRequestAssistantMessage, + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestToolMessage, ChatCompletionRequestUserMessage, + ChatCompletionToolChoiceOption, ChatCompletionTools, FunctionObject, }; // Re-export OpenAI response types for convenience. +pub use crate::openai::{ + AudioTranscriptionResponse, AudioTranscriptionStream, ChatCompletionStream, +}; pub use async_openai::types::chat::{ - CreateChatCompletionResponse, - CreateChatCompletionStreamResponse, - ChatChoice, - ChatChoiceStream, - ChatCompletionResponseMessage, - ChatCompletionStreamResponseDelta, - ChatCompletionMessageToolCall, - ChatCompletionMessageToolCalls, - ChatCompletionMessageToolCallChunk, - FunctionCall, + ChatChoice, ChatChoiceStream, ChatCompletionMessageToolCall, + ChatCompletionMessageToolCallChunk, ChatCompletionMessageToolCalls, + ChatCompletionResponseMessage, ChatCompletionStreamResponseDelta, CompletionUsage, + CreateChatCompletionResponse, CreateChatCompletionStreamResponse, FinishReason, FunctionCall, FunctionCallStream, - FinishReason, - CompletionUsage, -}; -pub use crate::openai::{ - AudioTranscriptionResponse, ChatCompletionStream, AudioTranscriptionStream, }; diff --git a/sdk_v2/rust/src/openai/audio_client.rs b/sdk_v2/rust/src/openai/audio_client.rs index 1a26afc7..f909d544 100644 --- a/sdk_v2/rust/src/openai/audio_client.rs +++ b/sdk_v2/rust/src/openai/audio_client.rs @@ -134,11 +134,13 @@ impl AudioClient { } /// Transcribe an audio file. - pub async fn transcribe(&self, audio_file_path: impl AsRef) -> Result { - let path_str = audio_file_path - .as_ref() - .to_str() - .ok_or_else(|| FoundryLocalError::Validation("audio file path is not valid UTF-8".into()))?; + pub async fn transcribe( + &self, + audio_file_path: impl AsRef, + ) -> Result { + let path_str = audio_file_path.as_ref().to_str().ok_or_else(|| { + FoundryLocalError::Validation("audio file path is not valid UTF-8".into()) + })?; Self::validate_path(path_str)?; let request = self.settings.serialize(&self.model_id, path_str); @@ -162,10 +164,9 @@ impl AudioClient { &self, audio_file_path: impl AsRef, ) -> Result { - let path_str = audio_file_path - .as_ref() - .to_str() - .ok_or_else(|| FoundryLocalError::Validation("audio file path is not valid UTF-8".into()))?; + let path_str = audio_file_path.as_ref().to_str().ok_or_else(|| { + FoundryLocalError::Validation("audio file path is not valid UTF-8".into()) + })?; Self::validate_path(path_str)?; let mut request = self.settings.serialize(&self.model_id, path_str); diff --git a/sdk_v2/rust/src/openai/mod.rs b/sdk_v2/rust/src/openai/mod.rs index e1d52dda..190c6268 100644 --- a/sdk_v2/rust/src/openai/mod.rs +++ b/sdk_v2/rust/src/openai/mod.rs @@ -1,5 +1,7 @@ -mod chat_client; mod audio_client; +mod chat_client; +pub use audio_client::{ + AudioClient, AudioClientSettings, AudioTranscriptionResponse, AudioTranscriptionStream, +}; pub use chat_client::{ChatClient, ChatClientSettings, ChatCompletionStream}; -pub use audio_client::{AudioClient, AudioClientSettings, AudioTranscriptionResponse, AudioTranscriptionStream}; diff --git a/sdk_v2/rust/tests/audio_client_test.rs b/sdk_v2/rust/tests/audio_client_test.rs index 4f3664f1..4fb7e1d0 100644 --- a/sdk_v2/rust/tests/audio_client_test.rs +++ b/sdk_v2/rust/tests/audio_client_test.rs @@ -26,9 +26,7 @@ mod tests { } fn audio_file() -> String { - common::get_audio_file_path() - .to_string_lossy() - .into_owned() + common::get_audio_file_path().to_string_lossy().into_owned() } // ── Non-streaming transcription ────────────────────────────────────── @@ -44,7 +42,8 @@ mod tests { assert!( response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), - "Transcription should contain expected text, got: {}", response.text + "Transcription should contain expected text, got: {}", + response.text ); } @@ -61,7 +60,8 @@ mod tests { assert!( response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), - "Transcription should contain expected text, got: {}", response.text + "Transcription should contain expected text, got: {}", + response.text ); } diff --git a/sdk_v2/rust/tests/catalog_test.rs b/sdk_v2/rust/tests/catalog_test.rs index 30fa10e5..f6b8d8de 100644 --- a/sdk_v2/rust/tests/catalog_test.rs +++ b/sdk_v2/rust/tests/catalog_test.rs @@ -34,7 +34,10 @@ mod tests { let cat = catalog(); let models = cat.get_models().await.expect("get_models failed"); - assert!(!models.is_empty(), "Expected at least one model in the catalog"); + assert!( + !models.is_empty(), + "Expected at least one model in the catalog" + ); let found = models.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); assert!( @@ -103,7 +106,10 @@ mod tests { #[ignore = "requires native Foundry Local library"] async fn should_get_cached_models() { let cat = catalog(); - let cached = cat.get_cached_models().await.expect("get_cached_models failed"); + let cached = cat + .get_cached_models() + .await + .expect("get_cached_models failed"); assert!(!cached.is_empty(), "Expected at least one cached model"); @@ -131,7 +137,9 @@ mod tests { #[ignore = "requires native Foundry Local library"] async fn should_throw_when_getting_model_variant_with_unknown_id() { let cat = catalog(); - let result = cat.get_model_variant("unknown-nonexistent-variant-id").await; + let result = cat + .get_model_variant("unknown-nonexistent-variant-id") + .await; assert!(result.is_err(), "Expected error for unknown variant ID"); } } diff --git a/sdk_v2/rust/tests/chat_client_test.rs b/sdk_v2/rust/tests/chat_client_test.rs index c4350ca8..5308091f 100644 --- a/sdk_v2/rust/tests/chat_client_test.rs +++ b/sdk_v2/rust/tests/chat_client_test.rs @@ -5,12 +5,12 @@ mod common; +use foundry_local_sdk::openai::ChatClient; use foundry_local_sdk::{ ChatCompletionMessageToolCalls, ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, ChatCompletionRequestToolMessage, ChatCompletionRequestUserMessage, ChatToolChoice, }; -use foundry_local_sdk::openai::ChatClient; use serde_json::json; use tokio_stream::StreamExt; @@ -58,8 +58,13 @@ mod tests { user_message("What is 7*6?"), ]; - let response = client.complete_chat(&messages, None).await.expect("complete_chat failed"); - let content = response.choices.first() + let response = client + .complete_chat(&messages, None) + .await + .expect("complete_chat failed"); + let content = response + .choices + .first() .and_then(|c| c.message.content.as_deref()) .unwrap_or(""); @@ -145,7 +150,10 @@ mod tests { let messages: Vec = vec![]; let result = client.complete_streaming_chat(&messages, None).await; - assert!(result.is_err(), "Expected error for empty messages in streaming"); + assert!( + result.is_err(), + "Expected error for empty messages in streaming" + ); } #[tokio::test] @@ -155,10 +163,7 @@ mod tests { let messages: Vec = vec![]; let result = client.complete_streaming_chat(&messages, None).await; - assert!( - result.is_err(), - "Expected error even with empty messages" - ); + assert!(result.is_err(), "Expected error even with empty messages"); } // ── Tool calling (non-streaming) ───────────────────────────────────── @@ -181,8 +186,15 @@ mod tests { .await .expect("complete_chat with tools failed"); - let choice = response.choices.first().expect("Expected at least one choice"); - let tool_calls = choice.message.tool_calls.as_ref().expect("Expected tool_calls"); + let choice = response + .choices + .first() + .expect("Expected at least one choice"); + let tool_calls = choice + .message + .tool_calls + .as_ref() + .expect("Expected tool_calls"); assert!( !tool_calls.is_empty(), "Expected at least one tool call in the response" @@ -193,15 +205,13 @@ mod tests { _ => panic!("Expected a function tool call"), }; assert_eq!( - tool_call.function.name, - "multiply", + tool_call.function.name, "multiply", "Expected tool call to 'multiply'" ); // Parse arguments and compute the result. - let args: serde_json::Value = - serde_json::from_str(&tool_call.function.arguments) - .expect("Failed to parse tool call arguments"); + let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments) + .expect("Failed to parse tool call arguments"); let a = args["a"].as_f64().unwrap_or(0.0); let b = args["b"].as_f64().unwrap_or(0.0); let product = (a * b) as i64; @@ -219,7 +229,8 @@ mod tests { "arguments": tool_call.function.arguments, } }] - })).expect("failed to construct assistant message"); + })) + .expect("failed to construct assistant message"); messages.push(assistant_msg); messages.push( ChatCompletionRequestToolMessage { @@ -236,7 +247,9 @@ mod tests { .complete_chat(&messages, Some(&tools)) .await .expect("follow-up complete_chat with tools failed"); - let content = final_response.choices.first() + let content = final_response + .choices + .first() .and_then(|c| c.message.content.as_deref()) .unwrap_or(""); @@ -315,7 +328,8 @@ mod tests { "arguments": tool_call_args } }] - })).expect("failed to construct assistant message"); + })) + .expect("failed to construct assistant message"); messages.push(assistant_msg); messages.push( ChatCompletionRequestToolMessage { diff --git a/sdk_v2/rust/tests/manager_test.rs b/sdk_v2/rust/tests/manager_test.rs index d05abd1f..ef49399c 100644 --- a/sdk_v2/rust/tests/manager_test.rs +++ b/sdk_v2/rust/tests/manager_test.rs @@ -17,7 +17,11 @@ mod tests { fn should_initialize_successfully() { let config = common::test_config(); let manager = FoundryLocalManager::create(config); - assert!(manager.is_ok(), "Manager creation failed: {:?}", manager.err()); + assert!( + manager.is_ok(), + "Manager creation failed: {:?}", + manager.err() + ); } /// The catalog obtained from a freshly-created manager should have a diff --git a/sdk_v2/rust/tests/model_load_manager_test.rs b/sdk_v2/rust/tests/model_load_manager_test.rs index 161c7c62..dc5dd5fa 100644 --- a/sdk_v2/rust/tests/model_load_manager_test.rs +++ b/sdk_v2/rust/tests/model_load_manager_test.rs @@ -36,10 +36,7 @@ mod tests { async fn should_load_model_using_core_interop() { let model = get_test_model().await; - model - .load() - .await - .expect("model.load() failed"); + model.load().await.expect("model.load() failed"); } /// Unloading a previously loaded model via the core interop path should @@ -52,15 +49,9 @@ mod tests { let model = get_test_model().await; // Ensure the model is loaded first. - model - .load() - .await - .expect("model.load() failed"); + model.load().await.expect("model.load() failed"); - model - .unload() - .await - .expect("model.unload() failed"); + model.unload().await.expect("model.unload() failed"); } /// Listing loaded models via the core interop path should return a diff --git a/sdk_v2/rust/tests/model_test.rs b/sdk_v2/rust/tests/model_test.rs index 5238bc2a..524ab398 100644 --- a/sdk_v2/rust/tests/model_test.rs +++ b/sdk_v2/rust/tests/model_test.rs @@ -5,7 +5,6 @@ mod common; - mod tests { use super::*; @@ -18,7 +17,10 @@ mod tests { async fn should_verify_cached_models_from_test_data_shared() { let manager = common::get_test_manager(); let catalog = manager.catalog(); - let cached = catalog.get_cached_models().await.expect("get_cached_models failed"); + let cached = catalog + .get_cached_models() + .await + .expect("get_cached_models failed"); // qwen2.5-0.5b must be cached let has_qwen = cached.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); @@ -29,7 +31,9 @@ mod tests { ); // whisper-tiny must be cached - let has_whisper = cached.iter().any(|m| m.alias() == common::WHISPER_MODEL_ALIAS); + let has_whisper = cached + .iter() + .any(|m| m.alias() == common::WHISPER_MODEL_ALIAS); assert!( has_whisper, "'{}' should be present in cached models", From 563d1543cdfbf2b2597bc09488b627835568e219 Mon Sep 17 00:00:00 2001 From: samuel100 Date: Mon, 9 Mar 2026 17:24:42 +0000 Subject: [PATCH 03/25] fix clippy error --- sdk_v2/rust/src/types.rs | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/sdk_v2/rust/src/types.rs b/sdk_v2/rust/src/types.rs index 495d5aab..66bbba92 100644 --- a/sdk_v2/rust/src/types.rs +++ b/sdk_v2/rust/src/types.rs @@ -1,20 +1,15 @@ use serde::{Deserialize, Serialize}; /// Hardware device type for model execution. -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] pub enum DeviceType { Invalid, + #[default] CPU, GPU, NPU, } -impl Default for DeviceType { - fn default() -> Self { - Self::CPU - } -} - /// Prompt template describing how messages are formatted for the model. #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] From 2a05d30a462de1d606006092e3fceff9f3cd96d7 Mon Sep 17 00:00:00 2001 From: samuel100 Date: Mon, 9 Mar 2026 17:30:41 +0000 Subject: [PATCH 04/25] fix clippy treating unused items as errors --- sdk_v2/rust/tests/common/mod.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sdk_v2/rust/tests/common/mod.rs b/sdk_v2/rust/tests/common/mod.rs index 9a42e53d..45d66fe1 100644 --- a/sdk_v2/rust/tests/common/mod.rs +++ b/sdk_v2/rust/tests/common/mod.rs @@ -2,6 +2,8 @@ //! //! Mirrors `testUtils.ts` from the JavaScript SDK test suite. +#![allow(dead_code)] + use std::collections::HashMap; use std::path::PathBuf; From c08c2fb3064a94e9116ce23ab670897e95e45548 Mon Sep 17 00:00:00 2001 From: samuel100 Date: Mon, 9 Mar 2026 18:23:06 +0000 Subject: [PATCH 05/25] enable integration tests --- sdk_v2/rust/tests/audio_client_test.rs | 6 ------ sdk_v2/rust/tests/catalog_test.rs | 8 -------- sdk_v2/rust/tests/chat_client_test.rs | 7 ------- sdk_v2/rust/tests/manager_test.rs | 2 -- sdk_v2/rust/tests/model_load_manager_test.rs | 3 --- sdk_v2/rust/tests/model_test.rs | 2 -- 6 files changed, 28 deletions(-) diff --git a/sdk_v2/rust/tests/audio_client_test.rs b/sdk_v2/rust/tests/audio_client_test.rs index 4fb7e1d0..aff45905 100644 --- a/sdk_v2/rust/tests/audio_client_test.rs +++ b/sdk_v2/rust/tests/audio_client_test.rs @@ -32,7 +32,6 @@ mod tests { // ── Non-streaming transcription ────────────────────────────────────── #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_transcribe_audio_without_streaming() { let client = setup_audio_client().await; let response = client @@ -48,7 +47,6 @@ mod tests { } #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_transcribe_audio_without_streaming_with_temperature() { let mut client = setup_audio_client().await; client.language("en").temperature(0.0); @@ -68,7 +66,6 @@ mod tests { // ── Streaming transcription ────────────────────────────────────────── #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_transcribe_audio_with_streaming() { let client = setup_audio_client().await; let mut full_text = String::new(); @@ -91,7 +88,6 @@ mod tests { } #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_transcribe_audio_with_streaming_with_temperature() { let mut client = setup_audio_client().await; client.language("en").temperature(0.0); @@ -118,7 +114,6 @@ mod tests { // ── Validation: empty file path ────────────────────────────────────── #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_throw_when_transcribing_with_empty_audio_file_path() { let client = setup_audio_client().await; let result = client.transcribe("").await; @@ -126,7 +121,6 @@ mod tests { } #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_throw_when_transcribing_streaming_with_empty_audio_file_path() { let client = setup_audio_client().await; let result = client.transcribe_streaming("").await; diff --git a/sdk_v2/rust/tests/catalog_test.rs b/sdk_v2/rust/tests/catalog_test.rs index f6b8d8de..9f6b762e 100644 --- a/sdk_v2/rust/tests/catalog_test.rs +++ b/sdk_v2/rust/tests/catalog_test.rs @@ -19,7 +19,6 @@ mod tests { /// The catalog should expose a non-empty name after initialisation. #[test] - #[ignore = "requires native Foundry Local library"] fn should_initialize_with_catalog_name() { let cat = catalog(); let name = cat.name(); @@ -29,7 +28,6 @@ mod tests { /// `list_models()` should return at least one model and the test model /// should be present among them. #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_list_models() { let cat = catalog(); let models = cat.get_models().await.expect("get_models failed"); @@ -49,7 +47,6 @@ mod tests { /// `get_model()` with a valid alias should return the corresponding model. #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_get_model_by_alias() { let cat = catalog(); let model = cat @@ -65,7 +62,6 @@ mod tests { /// `get_model("")` should return an error containing /// "Model alias must be a non-empty string". #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_throw_when_getting_model_with_empty_alias() { let cat = catalog(); let result = cat.get_model("").await; @@ -81,7 +77,6 @@ mod tests { /// An unknown alias should produce an error mentioning "not found" and /// listing available models. #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_throw_when_getting_model_with_unknown_alias() { let cat = catalog(); let result = cat.get_model("unknown-nonexistent-model-alias").await; @@ -103,7 +98,6 @@ mod tests { /// `get_cached_models()` should return at least one model and the test /// model should be cached. #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_get_cached_models() { let cat = catalog(); let cached = cat @@ -125,7 +119,6 @@ mod tests { /// `get_model_variant("")` should return a validation error. #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_throw_when_getting_model_variant_with_empty_id() { let cat = catalog(); let result = cat.get_model_variant("").await; @@ -134,7 +127,6 @@ mod tests { /// `get_model_variant()` with an unknown ID should return an error. #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_throw_when_getting_model_variant_with_unknown_id() { let cat = catalog(); let result = cat diff --git a/sdk_v2/rust/tests/chat_client_test.rs b/sdk_v2/rust/tests/chat_client_test.rs index 5308091f..81c5a2ea 100644 --- a/sdk_v2/rust/tests/chat_client_test.rs +++ b/sdk_v2/rust/tests/chat_client_test.rs @@ -50,7 +50,6 @@ mod tests { // ── Non-streaming completion ───────────────────────────────────────── #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_perform_chat_completion() { let client = setup_chat_client().await; let messages = vec![ @@ -77,7 +76,6 @@ mod tests { // ── Streaming completion ───────────────────────────────────────────── #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_perform_streaming_chat_completion() { let client = setup_chat_client().await; let mut messages = vec![ @@ -134,7 +132,6 @@ mod tests { // ── Validation: empty / invalid messages ───────────────────────────── #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_throw_when_completing_chat_with_empty_messages() { let client = setup_chat_client().await; let messages: Vec = vec![]; @@ -144,7 +141,6 @@ mod tests { } #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_throw_when_completing_streaming_chat_with_empty_messages() { let client = setup_chat_client().await; let messages: Vec = vec![]; @@ -157,7 +153,6 @@ mod tests { } #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_throw_when_completing_streaming_chat_with_invalid_callback() { let client = setup_chat_client().await; let messages: Vec = vec![]; @@ -169,7 +164,6 @@ mod tests { // ── Tool calling (non-streaming) ───────────────────────────────────── #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_perform_tool_calling_chat_completion_non_streaming() { let mut client = setup_chat_client().await; client.tool_choice(ChatToolChoice::Required); @@ -262,7 +256,6 @@ mod tests { // ── Tool calling (streaming) ───────────────────────────────────────── #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_perform_tool_calling_chat_completion_streaming() { let mut client = setup_chat_client().await; client.tool_choice(ChatToolChoice::Required); diff --git a/sdk_v2/rust/tests/manager_test.rs b/sdk_v2/rust/tests/manager_test.rs index ef49399c..256ce420 100644 --- a/sdk_v2/rust/tests/manager_test.rs +++ b/sdk_v2/rust/tests/manager_test.rs @@ -13,7 +13,6 @@ mod tests { /// The manager should initialise successfully with the test configuration. #[test] - #[ignore = "requires native Foundry Local library"] fn should_initialize_successfully() { let config = common::test_config(); let manager = FoundryLocalManager::create(config); @@ -27,7 +26,6 @@ mod tests { /// The catalog obtained from a freshly-created manager should have a /// non-empty name. #[test] - #[ignore = "requires native Foundry Local library"] fn should_return_catalog_with_non_empty_name() { let manager = common::get_test_manager(); let catalog = manager.catalog(); diff --git a/sdk_v2/rust/tests/model_load_manager_test.rs b/sdk_v2/rust/tests/model_load_manager_test.rs index dc5dd5fa..3b0bf0c4 100644 --- a/sdk_v2/rust/tests/model_load_manager_test.rs +++ b/sdk_v2/rust/tests/model_load_manager_test.rs @@ -32,7 +32,6 @@ mod tests { /// /// Timeout note: the JS test uses a 120 s timeout. #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_load_model_using_core_interop() { let model = get_test_model().await; @@ -44,7 +43,6 @@ mod tests { /// /// Timeout note: the JS test uses a 120 s timeout. #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_unload_model_using_core_interop() { let model = get_test_model().await; @@ -57,7 +55,6 @@ mod tests { /// Listing loaded models via the core interop path should return a /// collection (possibly empty, but the call itself must succeed). #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_list_loaded_models_using_core_interop() { let manager = common::get_test_manager(); let catalog = manager.catalog(); diff --git a/sdk_v2/rust/tests/model_test.rs b/sdk_v2/rust/tests/model_test.rs index 524ab398..8eedc5fb 100644 --- a/sdk_v2/rust/tests/model_test.rs +++ b/sdk_v2/rust/tests/model_test.rs @@ -13,7 +13,6 @@ mod tests { /// The shared test-data directory should contain pre-cached models for /// both `qwen2.5-0.5b` and `whisper-tiny`. #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_verify_cached_models_from_test_data_shared() { let manager = common::get_test_manager(); let catalog = manager.catalog(); @@ -48,7 +47,6 @@ mod tests { /// /// Timeout note: the JS test uses a 120 s timeout for this test. #[tokio::test] - #[ignore = "requires native Foundry Local library"] async fn should_load_and_unload_model() { let manager = common::get_test_manager(); let catalog = manager.catalog(); From 7a9fd6bc014104578877390a9764504f5ad3ddbc Mon Sep 17 00:00:00 2001 From: samuel100 Date: Mon, 9 Mar 2026 18:38:03 +0000 Subject: [PATCH 06/25] added test data step to rust build step --- .github/workflows/build-rust-steps.yml | 31 ++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/.github/workflows/build-rust-steps.yml b/.github/workflows/build-rust-steps.yml index 4ee92aa6..6b655310 100644 --- a/.github/workflows/build-rust-steps.yml +++ b/.github/workflows/build-rust-steps.yml @@ -46,6 +46,37 @@ jobs: with: workspaces: sdk_v2/rust -> target + - name: Checkout test-data-shared from Azure DevOps + if: ${{ inputs.run-integration-tests }} + shell: pwsh + working-directory: ${{ github.workspace }}/.. + run: | + $pat = "${{ secrets.AZURE_DEVOPS_PAT }}" + $encodedPat = [Convert]::ToBase64String([Text.Encoding]::ASCII.GetBytes(":$pat")) + + # Configure git to use the PAT + git config --global http.https://dev.azure.com.extraheader "AUTHORIZATION: Basic $encodedPat" + + # Clone with LFS to parent directory + git lfs install + git clone --depth 1 https://dev.azure.com/microsoft/windows.ai.toolkit/_git/test-data-shared test-data-shared + + Write-Host "Clone completed successfully to ${{ github.workspace }}/../test-data-shared" + + - name: Checkout specific commit in test-data-shared + if: ${{ inputs.run-integration-tests }} + shell: pwsh + working-directory: ${{ github.workspace }}/../test-data-shared + run: | + Write-Host "Current directory: $(Get-Location)" + git checkout 231f820fe285145b7ea4a449b112c1228ce66a41 + if ($LASTEXITCODE -ne 0) { + Write-Error "Git checkout failed." + exit 1 + } + Write-Host "`nDirectory contents:" + Get-ChildItem -Recurse -Depth 2 | ForEach-Object { Write-Host " $($_.FullName)" } + - name: Check formatting run: cargo fmt --all -- --check From bef92eb7b4a33dab422e8c44239756bdb25fd17f Mon Sep 17 00:00:00 2001 From: samuel100 Date: Tue, 10 Mar 2026 10:09:00 +0000 Subject: [PATCH 07/25] ensure integration tests run and remove redundant workflow --- .github/workflows/foundry-local-sdk-build.yml | 24 ++------------- .github/workflows/rustfmt.yml | 29 ------------------- 2 files changed, 3 insertions(+), 50 deletions(-) delete mode 100644 .github/workflows/rustfmt.yml diff --git a/.github/workflows/foundry-local-sdk-build.yml b/.github/workflows/foundry-local-sdk-build.yml index 7639091b..e38fc251 100644 --- a/.github/workflows/foundry-local-sdk-build.yml +++ b/.github/workflows/foundry-local-sdk-build.yml @@ -62,40 +62,22 @@ jobs: uses: ./.github/workflows/build-rust-steps.yml with: platform: 'windows' + run-integration-tests: true secrets: inherit build-rust-windows-WinML: uses: ./.github/workflows/build-rust-steps.yml with: platform: 'windows' useWinML: true + run-integration-tests: true secrets: inherit build-rust-ubuntu: - uses: ./.github/workflows/build-rust-steps.yml - with: - platform: 'ubuntu' - secrets: inherit - build-rust-macos: - uses: ./.github/workflows/build-rust-steps.yml - with: - platform: 'macos' - secrets: inherit - - integration-test-rust-ubuntu: - if: github.event_name == 'workflow_dispatch' || github.ref == 'refs/heads/main' uses: ./.github/workflows/build-rust-steps.yml with: platform: 'ubuntu' run-integration-tests: true secrets: inherit - integration-test-rust-windows: - if: github.event_name == 'workflow_dispatch' || github.ref == 'refs/heads/main' - uses: ./.github/workflows/build-rust-steps.yml - with: - platform: 'windows' - run-integration-tests: true - secrets: inherit - integration-test-rust-macos: - if: github.event_name == 'workflow_dispatch' || github.ref == 'refs/heads/main' + build-rust-macos: uses: ./.github/workflows/build-rust-steps.yml with: platform: 'macos' diff --git a/.github/workflows/rustfmt.yml b/.github/workflows/rustfmt.yml deleted file mode 100644 index 28e61acc..00000000 --- a/.github/workflows/rustfmt.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: Rust-fmt - -on: - pull_request: - paths: - - 'sdk/rust/**' - - 'samples/rust/**' - push: - paths: - - 'sdk/rust/**' - - 'samples/rust/**' - branches: - - main - workflow_dispatch: - -jobs: - check: - runs-on: ubuntu-22.04 - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Update toolchain - run: rustup update --no-self-update stable && rustup default stable - - name: Check SDK - working-directory: sdk/rust - run: cargo fmt --all -- --check - - name: Check Samples - working-directory: samples/rust - run: cargo fmt --all -- --check From 759f83b0d2e0ec9c482b22a5891796f753c5541f Mon Sep 17 00:00:00 2001 From: samuel100 Date: Tue, 10 Mar 2026 10:16:39 +0000 Subject: [PATCH 08/25] fix: consolidate integration tests into single binary MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cargo compiles each tests/*.rs as a separate binary. The native .NET core has global state — initializing from a second binary causes 'FoundryLocalCore has already been initialized'. Consolidating all test modules into tests/integration.rs ensures the OnceLock singleton initializes once and all 30 tests share it. Also adds --test-threads=1 to CI since tests share model load/unload state and cannot safely run in parallel. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/workflows/build-rust-steps.yml | 2 +- sdk_v2/rust/tests/audio_client_test.rs | 132 ---- sdk_v2/rust/tests/catalog_test.rs | 137 ---- sdk_v2/rust/tests/chat_client_test.rs | 357 --------- sdk_v2/rust/tests/integration.rs | 728 +++++++++++++++++++ sdk_v2/rust/tests/manager_test.rs | 35 - sdk_v2/rust/tests/model_load_manager_test.rs | 137 ---- sdk_v2/rust/tests/model_test.rs | 72 -- 8 files changed, 729 insertions(+), 871 deletions(-) delete mode 100644 sdk_v2/rust/tests/audio_client_test.rs delete mode 100644 sdk_v2/rust/tests/catalog_test.rs delete mode 100644 sdk_v2/rust/tests/chat_client_test.rs create mode 100644 sdk_v2/rust/tests/integration.rs delete mode 100644 sdk_v2/rust/tests/manager_test.rs delete mode 100644 sdk_v2/rust/tests/model_load_manager_test.rs delete mode 100644 sdk_v2/rust/tests/model_test.rs diff --git a/.github/workflows/build-rust-steps.yml b/.github/workflows/build-rust-steps.yml index 6b655310..2dc67cdc 100644 --- a/.github/workflows/build-rust-steps.yml +++ b/.github/workflows/build-rust-steps.yml @@ -91,4 +91,4 @@ jobs: - name: Run integration tests if: ${{ inputs.run-integration-tests }} - run: cargo test --test '*' ${{ env.CARGO_FEATURES }} + run: cargo test --test '*' ${{ env.CARGO_FEATURES }} -- --test-threads=1 diff --git a/sdk_v2/rust/tests/audio_client_test.rs b/sdk_v2/rust/tests/audio_client_test.rs deleted file mode 100644 index aff45905..00000000 --- a/sdk_v2/rust/tests/audio_client_test.rs +++ /dev/null @@ -1,132 +0,0 @@ -//! Integration tests for the [`AudioClient`] transcription API (non-streaming -//! and streaming, with optional temperature). -//! -//! Mirrors `audioClient.test.ts` from the JavaScript SDK. - -mod common; - -use foundry_local_sdk::openai::AudioClient; -use tokio_stream::StreamExt; - -mod tests { - use super::*; - - // ── Helpers ────────────────────────────────────────────────────────── - - /// Load the whisper model and return an [`AudioClient`] ready for use. - async fn setup_audio_client() -> AudioClient { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - let model = catalog - .get_model(common::WHISPER_MODEL_ALIAS) - .await - .expect("get_model(whisper-tiny) failed"); - model.load().await.expect("model.load() failed"); - model.create_audio_client() - } - - fn audio_file() -> String { - common::get_audio_file_path().to_string_lossy().into_owned() - } - - // ── Non-streaming transcription ────────────────────────────────────── - - #[tokio::test] - async fn should_transcribe_audio_without_streaming() { - let client = setup_audio_client().await; - let response = client - .transcribe(&audio_file()) - .await - .expect("transcribe failed"); - - assert!( - response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), - "Transcription should contain expected text, got: {}", - response.text - ); - } - - #[tokio::test] - async fn should_transcribe_audio_without_streaming_with_temperature() { - let mut client = setup_audio_client().await; - client.language("en").temperature(0.0); - - let response = client - .transcribe(&audio_file()) - .await - .expect("transcribe with temperature failed"); - - assert!( - response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), - "Transcription should contain expected text, got: {}", - response.text - ); - } - - // ── Streaming transcription ────────────────────────────────────────── - - #[tokio::test] - async fn should_transcribe_audio_with_streaming() { - let client = setup_audio_client().await; - let mut full_text = String::new(); - - let mut stream = client - .transcribe_streaming(&audio_file()) - .await - .expect("transcribe_streaming setup failed"); - - while let Some(chunk) = stream.next().await { - let chunk = chunk.expect("stream chunk error"); - full_text.push_str(&chunk.text); - } - stream.close().await.expect("stream close failed"); - - assert!( - full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), - "Streamed transcription should contain expected text, got: {full_text}" - ); - } - - #[tokio::test] - async fn should_transcribe_audio_with_streaming_with_temperature() { - let mut client = setup_audio_client().await; - client.language("en").temperature(0.0); - - let mut full_text = String::new(); - - let mut stream = client - .transcribe_streaming(&audio_file()) - .await - .expect("transcribe_streaming with temperature setup failed"); - - while let Some(chunk) = stream.next().await { - let chunk = chunk.expect("stream chunk error"); - full_text.push_str(&chunk.text); - } - stream.close().await.expect("stream close failed"); - - assert!( - full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), - "Streamed transcription should contain expected text, got: {full_text}" - ); - } - - // ── Validation: empty file path ────────────────────────────────────── - - #[tokio::test] - async fn should_throw_when_transcribing_with_empty_audio_file_path() { - let client = setup_audio_client().await; - let result = client.transcribe("").await; - assert!(result.is_err(), "Expected error for empty audio file path"); - } - - #[tokio::test] - async fn should_throw_when_transcribing_streaming_with_empty_audio_file_path() { - let client = setup_audio_client().await; - let result = client.transcribe_streaming("").await; - assert!( - result.is_err(), - "Expected error for empty audio file path in streaming" - ); - } -} diff --git a/sdk_v2/rust/tests/catalog_test.rs b/sdk_v2/rust/tests/catalog_test.rs deleted file mode 100644 index 9f6b762e..00000000 --- a/sdk_v2/rust/tests/catalog_test.rs +++ /dev/null @@ -1,137 +0,0 @@ -//! Integration tests for the [`Catalog`] API. -//! -//! Mirrors `catalog.test.ts` from the JavaScript SDK. - -mod common; - -use foundry_local_sdk::Catalog; - -mod tests { - use super::*; - - // ── Helpers ────────────────────────────────────────────────────────── - - fn catalog() -> &'static Catalog { - common::get_test_manager().catalog() - } - - // ── Basic catalogue access ─────────────────────────────────────────── - - /// The catalog should expose a non-empty name after initialisation. - #[test] - fn should_initialize_with_catalog_name() { - let cat = catalog(); - let name = cat.name(); - assert!(!name.is_empty(), "Catalog name must not be empty"); - } - - /// `list_models()` should return at least one model and the test model - /// should be present among them. - #[tokio::test] - async fn should_list_models() { - let cat = catalog(); - let models = cat.get_models().await.expect("get_models failed"); - - assert!( - !models.is_empty(), - "Expected at least one model in the catalog" - ); - - let found = models.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); - assert!( - found, - "Test model '{}' not found in catalog", - common::TEST_MODEL_ALIAS - ); - } - - /// `get_model()` with a valid alias should return the corresponding model. - #[tokio::test] - async fn should_get_model_by_alias() { - let cat = catalog(); - let model = cat - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - - assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); - } - - // ── Validation: empty / invalid alias ──────────────────────────────── - - /// `get_model("")` should return an error containing - /// "Model alias must be a non-empty string". - #[tokio::test] - async fn should_throw_when_getting_model_with_empty_alias() { - let cat = catalog(); - let result = cat.get_model("").await; - assert!(result.is_err(), "Expected error for empty alias"); - - let err_msg = result.unwrap_err().to_string(); - assert!( - err_msg.contains("Model alias must be a non-empty string"), - "Unexpected error message: {err_msg}" - ); - } - - /// An unknown alias should produce an error mentioning "not found" and - /// listing available models. - #[tokio::test] - async fn should_throw_when_getting_model_with_unknown_alias() { - let cat = catalog(); - let result = cat.get_model("unknown-nonexistent-model-alias").await; - assert!(result.is_err(), "Expected error for unknown alias"); - - let err_msg = result.unwrap_err().to_string(); - assert!( - err_msg.contains("not found"), - "Error should mention 'not found': {err_msg}" - ); - assert!( - err_msg.contains("Available models"), - "Error should list available models: {err_msg}" - ); - } - - // ── Cached models ──────────────────────────────────────────────────── - - /// `get_cached_models()` should return at least one model and the test - /// model should be cached. - #[tokio::test] - async fn should_get_cached_models() { - let cat = catalog(); - let cached = cat - .get_cached_models() - .await - .expect("get_cached_models failed"); - - assert!(!cached.is_empty(), "Expected at least one cached model"); - - let found = cached.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); - assert!( - found, - "Test model '{}' should be in the cached models list", - common::TEST_MODEL_ALIAS - ); - } - - // ── Model variant validation ───────────────────────────────────────── - - /// `get_model_variant("")` should return a validation error. - #[tokio::test] - async fn should_throw_when_getting_model_variant_with_empty_id() { - let cat = catalog(); - let result = cat.get_model_variant("").await; - assert!(result.is_err(), "Expected error for empty variant ID"); - } - - /// `get_model_variant()` with an unknown ID should return an error. - #[tokio::test] - async fn should_throw_when_getting_model_variant_with_unknown_id() { - let cat = catalog(); - let result = cat - .get_model_variant("unknown-nonexistent-variant-id") - .await; - assert!(result.is_err(), "Expected error for unknown variant ID"); - } -} diff --git a/sdk_v2/rust/tests/chat_client_test.rs b/sdk_v2/rust/tests/chat_client_test.rs deleted file mode 100644 index 81c5a2ea..00000000 --- a/sdk_v2/rust/tests/chat_client_test.rs +++ /dev/null @@ -1,357 +0,0 @@ -//! Integration tests for the [`ChatClient`] (non-streaming, streaming, and -//! tool-calling variants). -//! -//! Mirrors `chatClient.test.ts` from the JavaScript SDK. - -mod common; - -use foundry_local_sdk::openai::ChatClient; -use foundry_local_sdk::{ - ChatCompletionMessageToolCalls, ChatCompletionRequestMessage, - ChatCompletionRequestSystemMessage, ChatCompletionRequestToolMessage, - ChatCompletionRequestUserMessage, ChatToolChoice, -}; -use serde_json::json; -use tokio_stream::StreamExt; - -mod tests { - use super::*; - - // ── Helpers ────────────────────────────────────────────────────────── - - /// Load the test model and return a [`ChatClient`] ready for use. - async fn setup_chat_client() -> ChatClient { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - let model = catalog - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - model.load().await.expect("model.load() failed"); - - let mut client = model.create_chat_client(); - client.max_tokens(500).temperature(0.0); - client - } - - fn user_message(content: &str) -> ChatCompletionRequestMessage { - ChatCompletionRequestUserMessage::from(content).into() - } - - fn system_message(content: &str) -> ChatCompletionRequestMessage { - ChatCompletionRequestSystemMessage::from(content).into() - } - - fn assistant_message(content: &str) -> ChatCompletionRequestMessage { - serde_json::from_value(json!({ "role": "assistant", "content": content })) - .expect("failed to construct assistant message") - } - - // ── Non-streaming completion ───────────────────────────────────────── - - #[tokio::test] - async fn should_perform_chat_completion() { - let client = setup_chat_client().await; - let messages = vec![ - system_message("You are a helpful math assistant. Respond with just the answer."), - user_message("What is 7*6?"), - ]; - - let response = client - .complete_chat(&messages, None) - .await - .expect("complete_chat failed"); - let content = response - .choices - .first() - .and_then(|c| c.message.content.as_deref()) - .unwrap_or(""); - - assert!( - content.contains("42"), - "Expected response to contain '42', got: {content}" - ); - } - - // ── Streaming completion ───────────────────────────────────────────── - - #[tokio::test] - async fn should_perform_streaming_chat_completion() { - let client = setup_chat_client().await; - let mut messages = vec![ - system_message("You are a helpful math assistant. Respond with just the answer."), - user_message("What is 7*6?"), - ]; - - // First turn — expect "42" - let mut first_result = String::new(); - let mut stream = client - .complete_streaming_chat(&messages, None) - .await - .expect("streaming chat (first turn) setup failed"); - while let Some(chunk) = stream.next().await { - let chunk = chunk.expect("stream chunk error"); - if let Some(choice) = chunk.choices.first() { - if let Some(ref content) = choice.delta.content { - first_result.push_str(content); - } - } - } - stream.close().await.expect("stream close failed"); - - assert!( - first_result.contains("42"), - "First turn should contain '42', got: {first_result}" - ); - - // Follow-up turn — expect "67" - messages.push(assistant_message(&first_result)); - messages.push(user_message("Now add 25 to that result.")); - - let mut second_result = String::new(); - let mut stream = client - .complete_streaming_chat(&messages, None) - .await - .expect("streaming chat (follow-up) setup failed"); - while let Some(chunk) = stream.next().await { - let chunk = chunk.expect("stream chunk error"); - if let Some(choice) = chunk.choices.first() { - if let Some(ref content) = choice.delta.content { - second_result.push_str(content); - } - } - } - stream.close().await.expect("stream close failed"); - - assert!( - second_result.contains("67"), - "Follow-up should contain '67', got: {second_result}" - ); - } - - // ── Validation: empty / invalid messages ───────────────────────────── - - #[tokio::test] - async fn should_throw_when_completing_chat_with_empty_messages() { - let client = setup_chat_client().await; - let messages: Vec = vec![]; - - let result = client.complete_chat(&messages, None).await; - assert!(result.is_err(), "Expected error for empty messages"); - } - - #[tokio::test] - async fn should_throw_when_completing_streaming_chat_with_empty_messages() { - let client = setup_chat_client().await; - let messages: Vec = vec![]; - - let result = client.complete_streaming_chat(&messages, None).await; - assert!( - result.is_err(), - "Expected error for empty messages in streaming" - ); - } - - #[tokio::test] - async fn should_throw_when_completing_streaming_chat_with_invalid_callback() { - let client = setup_chat_client().await; - let messages: Vec = vec![]; - - let result = client.complete_streaming_chat(&messages, None).await; - assert!(result.is_err(), "Expected error even with empty messages"); - } - - // ── Tool calling (non-streaming) ───────────────────────────────────── - - #[tokio::test] - async fn should_perform_tool_calling_chat_completion_non_streaming() { - let mut client = setup_chat_client().await; - client.tool_choice(ChatToolChoice::Required); - - let tools = vec![common::get_multiply_tool()]; - let mut messages = vec![ - system_message("You are a math assistant. Use the multiply tool to answer."), - user_message("What is 6 times 7?"), - ]; - - // Step 1 — the model should request the multiply tool. - let response = client - .complete_chat(&messages, Some(&tools)) - .await - .expect("complete_chat with tools failed"); - - let choice = response - .choices - .first() - .expect("Expected at least one choice"); - let tool_calls = choice - .message - .tool_calls - .as_ref() - .expect("Expected tool_calls"); - assert!( - !tool_calls.is_empty(), - "Expected at least one tool call in the response" - ); - - let tool_call = match &tool_calls[0] { - ChatCompletionMessageToolCalls::Function(tc) => tc, - _ => panic!("Expected a function tool call"), - }; - assert_eq!( - tool_call.function.name, "multiply", - "Expected tool call to 'multiply'" - ); - - // Parse arguments and compute the result. - let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments) - .expect("Failed to parse tool call arguments"); - let a = args["a"].as_f64().unwrap_or(0.0); - let b = args["b"].as_f64().unwrap_or(0.0); - let product = (a * b) as i64; - - // Step 2 — feed the tool result back and get the final answer. - let tool_call_id = &tool_call.id; - let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ - "role": "assistant", - "content": null, - "tool_calls": [{ - "id": tool_call_id, - "type": "function", - "function": { - "name": tool_call.function.name, - "arguments": tool_call.function.arguments, - } - }] - })) - .expect("failed to construct assistant message"); - messages.push(assistant_msg); - messages.push( - ChatCompletionRequestToolMessage { - content: product.to_string().into(), - tool_call_id: tool_call_id.clone(), - } - .into(), - ); - - // Switch to auto so the model can answer freely. - client.tool_choice(ChatToolChoice::Auto); - - let final_response = client - .complete_chat(&messages, Some(&tools)) - .await - .expect("follow-up complete_chat with tools failed"); - let content = final_response - .choices - .first() - .and_then(|c| c.message.content.as_deref()) - .unwrap_or(""); - - assert!( - content.contains("42"), - "Final answer should contain '42', got: {content}" - ); - } - - // ── Tool calling (streaming) ───────────────────────────────────────── - - #[tokio::test] - async fn should_perform_tool_calling_chat_completion_streaming() { - let mut client = setup_chat_client().await; - client.tool_choice(ChatToolChoice::Required); - - let tools = vec![common::get_multiply_tool()]; - let mut messages = vec![ - system_message("You are a math assistant. Use the multiply tool to answer."), - user_message("What is 6 times 7?"), - ]; - - // Step 1 — collect streaming tool call chunks. - let mut tool_call_name = String::new(); - let mut tool_call_args = String::new(); - let mut tool_call_id = String::new(); - - let mut stream = client - .complete_streaming_chat(&messages, Some(&tools)) - .await - .expect("streaming tool call setup failed"); - - while let Some(chunk) = stream.next().await { - let chunk = chunk.expect("stream chunk error"); - if let Some(choice) = chunk.choices.first() { - if let Some(ref tool_calls) = choice.delta.tool_calls { - for call in tool_calls { - if let Some(ref func) = call.function { - if let Some(ref name) = func.name { - tool_call_name.push_str(name); - } - if let Some(ref args) = func.arguments { - tool_call_args.push_str(args); - } - } - if let Some(ref id) = call.id { - tool_call_id = id.clone(); - } - } - } - } - } - stream.close().await.expect("stream close failed"); - - assert_eq!( - tool_call_name, "multiply", - "Expected streamed tool call to 'multiply'" - ); - - // Parse arguments and compute. - let args: serde_json::Value = - serde_json::from_str(&tool_call_args).unwrap_or_else(|_| json!({})); - let a = args["a"].as_f64().unwrap_or(0.0); - let b = args["b"].as_f64().unwrap_or(0.0); - let product = (a * b) as i64; - - // Step 2 — feed the result back and stream the final answer. - let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ - "role": "assistant", - "tool_calls": [{ - "id": tool_call_id, - "type": "function", - "function": { - "name": tool_call_name, - "arguments": tool_call_args - } - }] - })) - .expect("failed to construct assistant message"); - messages.push(assistant_msg); - messages.push( - ChatCompletionRequestToolMessage { - content: product.to_string().into(), - tool_call_id: tool_call_id.clone(), - } - .into(), - ); - - client.tool_choice(ChatToolChoice::Auto); - - let mut final_result = String::new(); - let mut stream = client - .complete_streaming_chat(&messages, Some(&tools)) - .await - .expect("streaming follow-up setup failed"); - while let Some(chunk) = stream.next().await { - let chunk = chunk.expect("stream chunk error"); - if let Some(choice) = chunk.choices.first() { - if let Some(ref content) = choice.delta.content { - final_result.push_str(content); - } - } - } - stream.close().await.expect("stream close failed"); - - assert!( - final_result.contains("42"), - "Streamed final answer should contain '42', got: {final_result}" - ); - } -} diff --git a/sdk_v2/rust/tests/integration.rs b/sdk_v2/rust/tests/integration.rs new file mode 100644 index 00000000..07383df0 --- /dev/null +++ b/sdk_v2/rust/tests/integration.rs @@ -0,0 +1,728 @@ +//! Single integration test binary for the Foundry Local Rust SDK. +//! +//! All test modules are compiled into one binary so the native core is only +//! initialised once (via the `OnceLock` singleton in `FoundryLocalManager`). +//! Running them as separate binaries causes "already initialized" errors +//! because the .NET native runtime retains state across process-level +//! library loads. + +mod common; + +mod manager_tests { + use super::common; + use foundry_local_sdk::FoundryLocalManager; + + #[test] + fn should_initialize_successfully() { + let config = common::test_config(); + let manager = FoundryLocalManager::create(config); + assert!( + manager.is_ok(), + "Manager creation failed: {:?}", + manager.err() + ); + } + + #[test] + fn should_return_catalog_with_non_empty_name() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let name = catalog.name(); + assert!(!name.is_empty(), "Catalog name should not be empty"); + } +} + +mod catalog_tests { + use super::common; + use foundry_local_sdk::Catalog; + + fn catalog() -> &'static Catalog { + common::get_test_manager().catalog() + } + + #[test] + fn should_initialize_with_catalog_name() { + let cat = catalog(); + let name = cat.name(); + assert!(!name.is_empty(), "Catalog name must not be empty"); + } + + #[tokio::test] + async fn should_list_models() { + let cat = catalog(); + let models = cat.get_models().await.expect("get_models failed"); + + assert!( + !models.is_empty(), + "Expected at least one model in the catalog" + ); + + let found = models.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); + assert!( + found, + "Test model '{}' not found in catalog", + common::TEST_MODEL_ALIAS + ); + } + + #[tokio::test] + async fn should_get_model_by_alias() { + let cat = catalog(); + let model = cat + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); + } + + #[tokio::test] + async fn should_throw_when_getting_model_with_empty_alias() { + let cat = catalog(); + let result = cat.get_model("").await; + assert!(result.is_err(), "Expected error for empty alias"); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("Model alias must be a non-empty string"), + "Unexpected error message: {err_msg}" + ); + } + + #[tokio::test] + async fn should_throw_when_getting_model_with_unknown_alias() { + let cat = catalog(); + let result = cat.get_model("unknown-nonexistent-model-alias").await; + assert!(result.is_err(), "Expected error for unknown alias"); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("not found"), + "Error should mention 'not found': {err_msg}" + ); + assert!( + err_msg.contains("Available models"), + "Error should list available models: {err_msg}" + ); + } + + #[tokio::test] + async fn should_get_cached_models() { + let cat = catalog(); + let cached = cat + .get_cached_models() + .await + .expect("get_cached_models failed"); + + assert!(!cached.is_empty(), "Expected at least one cached model"); + + let found = cached.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); + assert!( + found, + "Test model '{}' should be in the cached models list", + common::TEST_MODEL_ALIAS + ); + } + + #[tokio::test] + async fn should_throw_when_getting_model_variant_with_empty_id() { + let cat = catalog(); + let result = cat.get_model_variant("").await; + assert!(result.is_err(), "Expected error for empty variant ID"); + } + + #[tokio::test] + async fn should_throw_when_getting_model_variant_with_unknown_id() { + let cat = catalog(); + let result = cat + .get_model_variant("unknown-nonexistent-variant-id") + .await; + assert!(result.is_err(), "Expected error for unknown variant ID"); + } +} + +mod model_tests { + use super::common; + + #[tokio::test] + async fn should_verify_cached_models_from_test_data_shared() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let cached = catalog + .get_cached_models() + .await + .expect("get_cached_models failed"); + + let has_qwen = cached.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); + assert!( + has_qwen, + "'{}' should be present in cached models", + common::TEST_MODEL_ALIAS + ); + + let has_whisper = cached + .iter() + .any(|m| m.alias() == common::WHISPER_MODEL_ALIAS); + assert!( + has_whisper, + "'{}' should be present in cached models", + common::WHISPER_MODEL_ALIAS + ); + } + + #[tokio::test] + async fn should_load_and_unload_model() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + model.load().await.expect("model.load() failed"); + assert!( + model.is_loaded().await.expect("is_loaded check failed"), + "Model should be loaded after load()" + ); + + model.unload().await.expect("model.unload() failed"); + assert!( + !model.is_loaded().await.expect("is_loaded check failed"), + "Model should not be loaded after unload()" + ); + } +} + +mod model_load_manager_tests { + use super::common; + + async fn get_test_model() -> foundry_local_sdk::Model { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed") + } + + #[tokio::test] + async fn should_load_model_using_core_interop() { + let model = get_test_model().await; + model.load().await.expect("model.load() failed"); + } + + #[tokio::test] + async fn should_unload_model_using_core_interop() { + let model = get_test_model().await; + model.load().await.expect("model.load() failed"); + model.unload().await.expect("model.unload() failed"); + } + + #[tokio::test] + async fn should_list_loaded_models_using_core_interop() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + + let loaded = catalog + .get_loaded_models() + .await + .expect("catalog.get_loaded_models() failed"); + + let _ = loaded; + } + + #[tokio::test] + #[ignore = "requires running web service"] + async fn should_load_and_unload_model_using_external_service() { + if common::is_running_in_ci() { + eprintln!("Skipping external-service test in CI"); + return; + } + + let manager = common::get_test_manager(); + let model = get_test_model().await; + + let _urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + + model + .load() + .await + .expect("load via external service failed"); + + model + .unload() + .await + .expect("unload via external service failed"); + } + + #[tokio::test] + #[ignore = "requires running web service"] + async fn should_list_loaded_models_using_external_service() { + if common::is_running_in_ci() { + eprintln!("Skipping external-service test in CI"); + return; + } + + let manager = common::get_test_manager(); + + let _urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + + let catalog = manager.catalog(); + let loaded = catalog + .get_loaded_models() + .await + .expect("get_loaded_models via external service failed"); + + let _ = loaded; + } +} + +mod chat_client_tests { + use super::common; + use foundry_local_sdk::openai::ChatClient; + use foundry_local_sdk::{ + ChatCompletionMessageToolCalls, ChatCompletionRequestMessage, + ChatCompletionRequestSystemMessage, ChatCompletionRequestToolMessage, + ChatCompletionRequestUserMessage, ChatToolChoice, + }; + use serde_json::json; + use tokio_stream::StreamExt; + + async fn setup_chat_client() -> ChatClient { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + model.load().await.expect("model.load() failed"); + + let mut client = model.create_chat_client(); + client.max_tokens(500).temperature(0.0); + client + } + + fn user_message(content: &str) -> ChatCompletionRequestMessage { + ChatCompletionRequestUserMessage::from(content).into() + } + + fn system_message(content: &str) -> ChatCompletionRequestMessage { + ChatCompletionRequestSystemMessage::from(content).into() + } + + fn assistant_message(content: &str) -> ChatCompletionRequestMessage { + serde_json::from_value(json!({ "role": "assistant", "content": content })) + .expect("failed to construct assistant message") + } + + #[tokio::test] + async fn should_perform_chat_completion() { + let client = setup_chat_client().await; + let messages = vec![ + system_message("You are a helpful math assistant. Respond with just the answer."), + user_message("What is 7*6?"), + ]; + + let response = client + .complete_chat(&messages, None) + .await + .expect("complete_chat failed"); + let content = response + .choices + .first() + .and_then(|c| c.message.content.as_deref()) + .unwrap_or(""); + + assert!( + content.contains("42"), + "Expected response to contain '42', got: {content}" + ); + } + + #[tokio::test] + async fn should_perform_streaming_chat_completion() { + let client = setup_chat_client().await; + let mut messages = vec![ + system_message("You are a helpful math assistant. Respond with just the answer."), + user_message("What is 7*6?"), + ]; + + let mut first_result = String::new(); + let mut stream = client + .complete_streaming_chat(&messages, None) + .await + .expect("streaming chat (first turn) setup failed"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + first_result.push_str(content); + } + } + } + stream.close().await.expect("stream close failed"); + + assert!( + first_result.contains("42"), + "First turn should contain '42', got: {first_result}" + ); + + messages.push(assistant_message(&first_result)); + messages.push(user_message("Now add 25 to that result.")); + + let mut second_result = String::new(); + let mut stream = client + .complete_streaming_chat(&messages, None) + .await + .expect("streaming chat (follow-up) setup failed"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + second_result.push_str(content); + } + } + } + stream.close().await.expect("stream close failed"); + + assert!( + second_result.contains("67"), + "Follow-up should contain '67', got: {second_result}" + ); + } + + #[tokio::test] + async fn should_throw_when_completing_chat_with_empty_messages() { + let client = setup_chat_client().await; + let messages: Vec = vec![]; + + let result = client.complete_chat(&messages, None).await; + assert!(result.is_err(), "Expected error for empty messages"); + } + + #[tokio::test] + async fn should_throw_when_completing_streaming_chat_with_empty_messages() { + let client = setup_chat_client().await; + let messages: Vec = vec![]; + + let result = client.complete_streaming_chat(&messages, None).await; + assert!( + result.is_err(), + "Expected error for empty messages in streaming" + ); + } + + #[tokio::test] + async fn should_throw_when_completing_streaming_chat_with_invalid_callback() { + let client = setup_chat_client().await; + let messages: Vec = vec![]; + + let result = client.complete_streaming_chat(&messages, None).await; + assert!(result.is_err(), "Expected error even with empty messages"); + } + + #[tokio::test] + async fn should_perform_tool_calling_chat_completion_non_streaming() { + let mut client = setup_chat_client().await; + client.tool_choice(ChatToolChoice::Required); + + let tools = vec![common::get_multiply_tool()]; + let mut messages = vec![ + system_message("You are a math assistant. Use the multiply tool to answer."), + user_message("What is 6 times 7?"), + ]; + + let response = client + .complete_chat(&messages, Some(&tools)) + .await + .expect("complete_chat with tools failed"); + + let choice = response + .choices + .first() + .expect("Expected at least one choice"); + let tool_calls = choice + .message + .tool_calls + .as_ref() + .expect("Expected tool_calls"); + assert!( + !tool_calls.is_empty(), + "Expected at least one tool call in the response" + ); + + let tool_call = match &tool_calls[0] { + ChatCompletionMessageToolCalls::Function(tc) => tc, + _ => panic!("Expected a function tool call"), + }; + assert_eq!( + tool_call.function.name, "multiply", + "Expected tool call to 'multiply'" + ); + + let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments) + .expect("Failed to parse tool call arguments"); + let a = args["a"].as_f64().unwrap_or(0.0); + let b = args["b"].as_f64().unwrap_or(0.0); + let product = (a * b) as i64; + + let tool_call_id = &tool_call.id; + let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": tool_call_id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + } + }] + })) + .expect("failed to construct assistant message"); + messages.push(assistant_msg); + messages.push( + ChatCompletionRequestToolMessage { + content: product.to_string().into(), + tool_call_id: tool_call_id.clone(), + } + .into(), + ); + + client.tool_choice(ChatToolChoice::Auto); + + let final_response = client + .complete_chat(&messages, Some(&tools)) + .await + .expect("follow-up complete_chat with tools failed"); + let content = final_response + .choices + .first() + .and_then(|c| c.message.content.as_deref()) + .unwrap_or(""); + + assert!( + content.contains("42"), + "Final answer should contain '42', got: {content}" + ); + } + + #[tokio::test] + async fn should_perform_tool_calling_chat_completion_streaming() { + let mut client = setup_chat_client().await; + client.tool_choice(ChatToolChoice::Required); + + let tools = vec![common::get_multiply_tool()]; + let mut messages = vec![ + system_message("You are a math assistant. Use the multiply tool to answer."), + user_message("What is 6 times 7?"), + ]; + + let mut tool_call_name = String::new(); + let mut tool_call_args = String::new(); + let mut tool_call_id = String::new(); + + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await + .expect("streaming tool call setup failed"); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref tool_calls) = choice.delta.tool_calls { + for call in tool_calls { + if let Some(ref func) = call.function { + if let Some(ref name) = func.name { + tool_call_name.push_str(name); + } + if let Some(ref args) = func.arguments { + tool_call_args.push_str(args); + } + } + if let Some(ref id) = call.id { + tool_call_id = id.clone(); + } + } + } + } + } + stream.close().await.expect("stream close failed"); + + assert_eq!( + tool_call_name, "multiply", + "Expected streamed tool call to 'multiply'" + ); + + let args: serde_json::Value = + serde_json::from_str(&tool_call_args).unwrap_or_else(|_| json!({})); + let a = args["a"].as_f64().unwrap_or(0.0); + let b = args["b"].as_f64().unwrap_or(0.0); + let product = (a * b) as i64; + + let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ + "role": "assistant", + "tool_calls": [{ + "id": tool_call_id, + "type": "function", + "function": { + "name": tool_call_name, + "arguments": tool_call_args + } + }] + })) + .expect("failed to construct assistant message"); + messages.push(assistant_msg); + messages.push( + ChatCompletionRequestToolMessage { + content: product.to_string().into(), + tool_call_id: tool_call_id.clone(), + } + .into(), + ); + + client.tool_choice(ChatToolChoice::Auto); + + let mut final_result = String::new(); + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await + .expect("streaming follow-up setup failed"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + final_result.push_str(content); + } + } + } + stream.close().await.expect("stream close failed"); + + assert!( + final_result.contains("42"), + "Streamed final answer should contain '42', got: {final_result}" + ); + } +} + +mod audio_client_tests { + use super::common; + use foundry_local_sdk::openai::AudioClient; + use tokio_stream::StreamExt; + + async fn setup_audio_client() -> AudioClient { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::WHISPER_MODEL_ALIAS) + .await + .expect("get_model(whisper-tiny) failed"); + model.load().await.expect("model.load() failed"); + model.create_audio_client() + } + + fn audio_file() -> String { + common::get_audio_file_path().to_string_lossy().into_owned() + } + + #[tokio::test] + async fn should_transcribe_audio_without_streaming() { + let client = setup_audio_client().await; + let response = client + .transcribe(&audio_file()) + .await + .expect("transcribe failed"); + + assert!( + response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Transcription should contain expected text, got: {}", + response.text + ); + } + + #[tokio::test] + async fn should_transcribe_audio_without_streaming_with_temperature() { + let mut client = setup_audio_client().await; + client.language("en").temperature(0.0); + + let response = client + .transcribe(&audio_file()) + .await + .expect("transcribe with temperature failed"); + + assert!( + response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Transcription should contain expected text, got: {}", + response.text + ); + } + + #[tokio::test] + async fn should_transcribe_audio_with_streaming() { + let client = setup_audio_client().await; + let mut full_text = String::new(); + + let mut stream = client + .transcribe_streaming(&audio_file()) + .await + .expect("transcribe_streaming setup failed"); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + full_text.push_str(&chunk.text); + } + stream.close().await.expect("stream close failed"); + + assert!( + full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Streamed transcription should contain expected text, got: {full_text}" + ); + } + + #[tokio::test] + async fn should_transcribe_audio_with_streaming_with_temperature() { + let mut client = setup_audio_client().await; + client.language("en").temperature(0.0); + + let mut full_text = String::new(); + + let mut stream = client + .transcribe_streaming(&audio_file()) + .await + .expect("transcribe_streaming with temperature setup failed"); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + full_text.push_str(&chunk.text); + } + stream.close().await.expect("stream close failed"); + + assert!( + full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Streamed transcription should contain expected text, got: {full_text}" + ); + } + + #[tokio::test] + async fn should_throw_when_transcribing_with_empty_audio_file_path() { + let client = setup_audio_client().await; + let result = client.transcribe("").await; + assert!(result.is_err(), "Expected error for empty audio file path"); + } + + #[tokio::test] + async fn should_throw_when_transcribing_streaming_with_empty_audio_file_path() { + let client = setup_audio_client().await; + let result = client.transcribe_streaming("").await; + assert!( + result.is_err(), + "Expected error for empty audio file path in streaming" + ); + } +} diff --git a/sdk_v2/rust/tests/manager_test.rs b/sdk_v2/rust/tests/manager_test.rs deleted file mode 100644 index 256ce420..00000000 --- a/sdk_v2/rust/tests/manager_test.rs +++ /dev/null @@ -1,35 +0,0 @@ -//! Integration tests for [`FoundryLocalManager`] initialisation. -//! -//! Mirrors `foundryLocalManager.test.ts` from the JavaScript SDK. - -mod common; - -use foundry_local_sdk::FoundryLocalManager; - -mod tests { - use super::*; - - // ── Initialisation ─────────────────────────────────────────────────── - - /// The manager should initialise successfully with the test configuration. - #[test] - fn should_initialize_successfully() { - let config = common::test_config(); - let manager = FoundryLocalManager::create(config); - assert!( - manager.is_ok(), - "Manager creation failed: {:?}", - manager.err() - ); - } - - /// The catalog obtained from a freshly-created manager should have a - /// non-empty name. - #[test] - fn should_return_catalog_with_non_empty_name() { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - let name = catalog.name(); - assert!(!name.is_empty(), "Catalog name should not be empty"); - } -} diff --git a/sdk_v2/rust/tests/model_load_manager_test.rs b/sdk_v2/rust/tests/model_load_manager_test.rs deleted file mode 100644 index 3b0bf0c4..00000000 --- a/sdk_v2/rust/tests/model_load_manager_test.rs +++ /dev/null @@ -1,137 +0,0 @@ -//! Integration tests for model loading and unloading through the public API. -//! -//! Mirrors `modelLoadManager.test.ts` from the JavaScript SDK. -//! -//! **Note:** In the JavaScript SDK these tests access the private -//! `coreInterop` property via an `as any` cast. In Rust, `CoreInterop` and -//! `ModelLoadManager::new` are `pub(crate)` and cannot be reached from -//! integration tests. Instead, we exercise model loading and unloading -//! through the public [`Model`] and [`Catalog`] APIs which internally -//! delegate to `ModelLoadManager`. - -mod common; - -mod tests { - use super::*; - - // ── Helpers ────────────────────────────────────────────────────────── - - /// Return the test model from the catalog. - async fn get_test_model() -> foundry_local_sdk::Model { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - catalog - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed") - } - - // ── Core-interop path ──────────────────────────────────────────────── - - /// Loading a model via the core interop (in-process) path should succeed. - /// - /// Timeout note: the JS test uses a 120 s timeout. - #[tokio::test] - async fn should_load_model_using_core_interop() { - let model = get_test_model().await; - - model.load().await.expect("model.load() failed"); - } - - /// Unloading a previously loaded model via the core interop path should - /// succeed. - /// - /// Timeout note: the JS test uses a 120 s timeout. - #[tokio::test] - async fn should_unload_model_using_core_interop() { - let model = get_test_model().await; - - // Ensure the model is loaded first. - model.load().await.expect("model.load() failed"); - - model.unload().await.expect("model.unload() failed"); - } - - /// Listing loaded models via the core interop path should return a - /// collection (possibly empty, but the call itself must succeed). - #[tokio::test] - async fn should_list_loaded_models_using_core_interop() { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - - let loaded = catalog - .get_loaded_models() - .await - .expect("catalog.get_loaded_models() failed"); - - // The result should be a valid (possibly empty) list of model IDs. - // (Vec is always valid; just ensure the call succeeded.) - let _ = loaded; - } - - // ── External web-service path ──────────────────────────────────────── - - /// Loading and unloading a model through the external HTTP service should - /// succeed. - /// - /// This test is skipped in CI because it requires a running web service. - /// - /// Timeout note: the JS test uses a 120 s timeout. - #[tokio::test] - #[ignore = "requires native Foundry Local library and running web service"] - async fn should_load_and_unload_model_using_external_service() { - if common::is_running_in_ci() { - eprintln!("Skipping external-service test in CI"); - return; - } - - let manager = common::get_test_manager(); - let model = get_test_model().await; - - // Start the web service so we can test the HTTP path. - let _urls = manager - .start_web_service() - .await - .expect("start_web_service failed"); - - // Load via the model API (delegates to ModelLoadManager internally). - model - .load() - .await - .expect("load via external service failed"); - - // Unload - model - .unload() - .await - .expect("unload via external service failed"); - } - - /// Listing loaded models through the external HTTP service should succeed. - /// - /// This test is skipped in CI because it requires a running web service. - #[tokio::test] - #[ignore = "requires native Foundry Local library and running web service"] - async fn should_list_loaded_models_using_external_service() { - if common::is_running_in_ci() { - eprintln!("Skipping external-service test in CI"); - return; - } - - let manager = common::get_test_manager(); - - let _urls = manager - .start_web_service() - .await - .expect("start_web_service failed"); - - let catalog = manager.catalog(); - let loaded = catalog - .get_loaded_models() - .await - .expect("get_loaded_models via external service failed"); - - // Vec is always a valid list; just ensure the call succeeded. - let _ = loaded; - } -} diff --git a/sdk_v2/rust/tests/model_test.rs b/sdk_v2/rust/tests/model_test.rs deleted file mode 100644 index 8eedc5fb..00000000 --- a/sdk_v2/rust/tests/model_test.rs +++ /dev/null @@ -1,72 +0,0 @@ -//! Integration tests for the [`Model`] lifecycle (cache verification, -//! load / unload). -//! -//! Mirrors `model.test.ts` from the JavaScript SDK. - -mod common; - -mod tests { - use super::*; - - // ── Cache verification ─────────────────────────────────────────────── - - /// The shared test-data directory should contain pre-cached models for - /// both `qwen2.5-0.5b` and `whisper-tiny`. - #[tokio::test] - async fn should_verify_cached_models_from_test_data_shared() { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - let cached = catalog - .get_cached_models() - .await - .expect("get_cached_models failed"); - - // qwen2.5-0.5b must be cached - let has_qwen = cached.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); - assert!( - has_qwen, - "'{}' should be present in cached models", - common::TEST_MODEL_ALIAS - ); - - // whisper-tiny must be cached - let has_whisper = cached - .iter() - .any(|m| m.alias() == common::WHISPER_MODEL_ALIAS); - assert!( - has_whisper, - "'{}' should be present in cached models", - common::WHISPER_MODEL_ALIAS - ); - } - - // ── Load / unload lifecycle ────────────────────────────────────────── - - /// Loading a model should mark it as loaded; unloading should mark it as - /// not loaded. - /// - /// Timeout note: the JS test uses a 120 s timeout for this test. - #[tokio::test] - async fn should_load_and_unload_model() { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - let model = catalog - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - - // Load - model.load().await.expect("model.load() failed"); - assert!( - model.is_loaded().await.expect("is_loaded check failed"), - "Model should be loaded after load()" - ); - - // Unload - model.unload().await.expect("model.unload() failed"); - assert!( - !model.is_loaded().await.expect("is_loaded check failed"), - "Model should not be loaded after unload()" - ); - } -} From e45b40513febef7528ae90e304456f807036bec9 Mon Sep 17 00:00:00 2001 From: samuel100 Date: Tue, 10 Mar 2026 10:26:11 +0000 Subject: [PATCH 09/25] fix: deterministic audio tests and correct error assertion - Audio tests without explicit temperature produced non-deterministic transcription output. Set language('en') and temperature(0.0) on all audio test variants for consistent results across platforms. - Fix catalog unknown alias assertion to match actual error message ('Unknown model alias' not 'not found'). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk_v2/rust/tests/integration.rs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/sdk_v2/rust/tests/integration.rs b/sdk_v2/rust/tests/integration.rs index 07383df0..7da6e95e 100644 --- a/sdk_v2/rust/tests/integration.rs +++ b/sdk_v2/rust/tests/integration.rs @@ -97,11 +97,11 @@ mod catalog_tests { let err_msg = result.unwrap_err().to_string(); assert!( - err_msg.contains("not found"), - "Error should mention 'not found': {err_msg}" + err_msg.contains("Unknown model alias"), + "Error should mention unknown alias: {err_msg}" ); assert!( - err_msg.contains("Available models"), + err_msg.contains("Available"), "Error should list available models: {err_msg}" ); } @@ -633,7 +633,8 @@ mod audio_client_tests { #[tokio::test] async fn should_transcribe_audio_without_streaming() { - let client = setup_audio_client().await; + let mut client = setup_audio_client().await; + client.language("en").temperature(0.0); let response = client .transcribe(&audio_file()) .await @@ -665,7 +666,8 @@ mod audio_client_tests { #[tokio::test] async fn should_transcribe_audio_with_streaming() { - let client = setup_audio_client().await; + let mut client = setup_audio_client().await; + client.language("en").temperature(0.0); let mut full_text = String::new(); let mut stream = client From f7e5b148b10bbcaf3615a4a068b1458e7c51d497 Mon Sep 17 00:00:00 2001 From: samuel100 Date: Tue, 10 Mar 2026 10:33:53 +0000 Subject: [PATCH 10/25] fix: link ole32 on Windows for CoTaskMemFree The free_native_buffer function uses CoTaskMemFree on Windows, which lives in ole32.lib. Without this link directive, the integration test binary fails to link with LNK2019 unresolved external symbol. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk_v2/rust/build.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sdk_v2/rust/build.rs b/sdk_v2/rust/build.rs index d51c5805..c12f6712 100644 --- a/sdk_v2/rust/build.rs +++ b/sdk_v2/rust/build.rs @@ -278,6 +278,8 @@ fn main() { println!("cargo:warning=Native libraries already present in OUT_DIR, skipping download."); println!("cargo:rustc-link-search=native={}", out_dir.display()); println!("cargo:rustc-env=FOUNDRY_NATIVE_DIR={}", out_dir.display()); + #[cfg(windows)] + println!("cargo:rustc-link-lib=ole32"); return; } @@ -295,4 +297,8 @@ fn main() { println!("cargo:rustc-link-search=native={}", out_dir.display()); println!("cargo:rustc-env=FOUNDRY_NATIVE_DIR={}", out_dir.display()); + + // CoTaskMemFree (used to free native-allocated buffers) lives in ole32.lib on Windows. + #[cfg(windows)] + println!("cargo:rustc-link-lib=ole32"); } From 5dfb2bb52ddf9faf4e7c35ce72a5ea531d11efbf Mon Sep 17 00:00:00 2001 From: samuel100 Date: Tue, 10 Mar 2026 10:42:43 +0000 Subject: [PATCH 11/25] fix: unload models in integration tests to prevent OGA resource leaks All chat and audio tests now properly unload models after use, matching the try/finally pattern in the JS SDK tests. Setup helpers return the Model alongside the client so tests can call model.unload() at the end. Also fixes model_load_manager should_load_model test to unload after loading. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk_v2/rust/tests/integration.rs | 61 +++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 17 deletions(-) diff --git a/sdk_v2/rust/tests/integration.rs b/sdk_v2/rust/tests/integration.rs index 7da6e95e..6caf5a20 100644 --- a/sdk_v2/rust/tests/integration.rs +++ b/sdk_v2/rust/tests/integration.rs @@ -209,6 +209,7 @@ mod model_load_manager_tests { async fn should_load_model_using_core_interop() { let model = get_test_model().await; model.load().await.expect("model.load() failed"); + model.unload().await.expect("model.unload() failed"); } #[tokio::test] @@ -294,7 +295,7 @@ mod chat_client_tests { use serde_json::json; use tokio_stream::StreamExt; - async fn setup_chat_client() -> ChatClient { + async fn setup_chat_client() -> (ChatClient, foundry_local_sdk::Model) { let manager = common::get_test_manager(); let catalog = manager.catalog(); let model = catalog @@ -305,7 +306,7 @@ mod chat_client_tests { let mut client = model.create_chat_client(); client.max_tokens(500).temperature(0.0); - client + (client, model) } fn user_message(content: &str) -> ChatCompletionRequestMessage { @@ -323,7 +324,7 @@ mod chat_client_tests { #[tokio::test] async fn should_perform_chat_completion() { - let client = setup_chat_client().await; + let (client, model) = setup_chat_client().await; let messages = vec![ system_message("You are a helpful math assistant. Respond with just the answer."), user_message("What is 7*6?"), @@ -343,11 +344,13 @@ mod chat_client_tests { content.contains("42"), "Expected response to contain '42', got: {content}" ); + + model.unload().await.expect("model.unload() failed"); } #[tokio::test] async fn should_perform_streaming_chat_completion() { - let client = setup_chat_client().await; + let (client, model) = setup_chat_client().await; let mut messages = vec![ system_message("You are a helpful math assistant. Respond with just the answer."), user_message("What is 7*6?"), @@ -395,20 +398,24 @@ mod chat_client_tests { second_result.contains("67"), "Follow-up should contain '67', got: {second_result}" ); + + model.unload().await.expect("model.unload() failed"); } #[tokio::test] async fn should_throw_when_completing_chat_with_empty_messages() { - let client = setup_chat_client().await; + let (client, model) = setup_chat_client().await; let messages: Vec = vec![]; let result = client.complete_chat(&messages, None).await; assert!(result.is_err(), "Expected error for empty messages"); + + model.unload().await.expect("model.unload() failed"); } #[tokio::test] async fn should_throw_when_completing_streaming_chat_with_empty_messages() { - let client = setup_chat_client().await; + let (client, model) = setup_chat_client().await; let messages: Vec = vec![]; let result = client.complete_streaming_chat(&messages, None).await; @@ -416,20 +423,24 @@ mod chat_client_tests { result.is_err(), "Expected error for empty messages in streaming" ); + + model.unload().await.expect("model.unload() failed"); } #[tokio::test] async fn should_throw_when_completing_streaming_chat_with_invalid_callback() { - let client = setup_chat_client().await; + let (client, model) = setup_chat_client().await; let messages: Vec = vec![]; let result = client.complete_streaming_chat(&messages, None).await; assert!(result.is_err(), "Expected error even with empty messages"); + + model.unload().await.expect("model.unload() failed"); } #[tokio::test] async fn should_perform_tool_calling_chat_completion_non_streaming() { - let mut client = setup_chat_client().await; + let (mut client, model) = setup_chat_client().await; client.tool_choice(ChatToolChoice::Required); let tools = vec![common::get_multiply_tool()]; @@ -511,11 +522,13 @@ mod chat_client_tests { content.contains("42"), "Final answer should contain '42', got: {content}" ); + + model.unload().await.expect("model.unload() failed"); } #[tokio::test] async fn should_perform_tool_calling_chat_completion_streaming() { - let mut client = setup_chat_client().await; + let (mut client, model) = setup_chat_client().await; client.tool_choice(ChatToolChoice::Required); let tools = vec![common::get_multiply_tool()]; @@ -608,6 +621,8 @@ mod chat_client_tests { final_result.contains("42"), "Streamed final answer should contain '42', got: {final_result}" ); + + model.unload().await.expect("model.unload() failed"); } } @@ -616,7 +631,7 @@ mod audio_client_tests { use foundry_local_sdk::openai::AudioClient; use tokio_stream::StreamExt; - async fn setup_audio_client() -> AudioClient { + async fn setup_audio_client() -> (AudioClient, foundry_local_sdk::Model) { let manager = common::get_test_manager(); let catalog = manager.catalog(); let model = catalog @@ -624,7 +639,7 @@ mod audio_client_tests { .await .expect("get_model(whisper-tiny) failed"); model.load().await.expect("model.load() failed"); - model.create_audio_client() + (model.create_audio_client(), model) } fn audio_file() -> String { @@ -633,7 +648,7 @@ mod audio_client_tests { #[tokio::test] async fn should_transcribe_audio_without_streaming() { - let mut client = setup_audio_client().await; + let (mut client, model) = setup_audio_client().await; client.language("en").temperature(0.0); let response = client .transcribe(&audio_file()) @@ -645,11 +660,13 @@ mod audio_client_tests { "Transcription should contain expected text, got: {}", response.text ); + + model.unload().await.expect("model.unload() failed"); } #[tokio::test] async fn should_transcribe_audio_without_streaming_with_temperature() { - let mut client = setup_audio_client().await; + let (mut client, model) = setup_audio_client().await; client.language("en").temperature(0.0); let response = client @@ -662,11 +679,13 @@ mod audio_client_tests { "Transcription should contain expected text, got: {}", response.text ); + + model.unload().await.expect("model.unload() failed"); } #[tokio::test] async fn should_transcribe_audio_with_streaming() { - let mut client = setup_audio_client().await; + let (mut client, model) = setup_audio_client().await; client.language("en").temperature(0.0); let mut full_text = String::new(); @@ -685,11 +704,13 @@ mod audio_client_tests { full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), "Streamed transcription should contain expected text, got: {full_text}" ); + + model.unload().await.expect("model.unload() failed"); } #[tokio::test] async fn should_transcribe_audio_with_streaming_with_temperature() { - let mut client = setup_audio_client().await; + let (mut client, model) = setup_audio_client().await; client.language("en").temperature(0.0); let mut full_text = String::new(); @@ -709,22 +730,28 @@ mod audio_client_tests { full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), "Streamed transcription should contain expected text, got: {full_text}" ); + + model.unload().await.expect("model.unload() failed"); } #[tokio::test] async fn should_throw_when_transcribing_with_empty_audio_file_path() { - let client = setup_audio_client().await; + let (client, model) = setup_audio_client().await; let result = client.transcribe("").await; assert!(result.is_err(), "Expected error for empty audio file path"); + + model.unload().await.expect("model.unload() failed"); } #[tokio::test] async fn should_throw_when_transcribing_streaming_with_empty_audio_file_path() { - let client = setup_audio_client().await; + let (client, model) = setup_audio_client().await; let result = client.transcribe_streaming("").await; assert!( result.is_err(), "Expected error for empty audio file path in streaming" ); + + model.unload().await.expect("model.unload() failed"); } } From 4dd7729480d4570bd04bd8be953d9d1adfbc9cf3 Mon Sep 17 00:00:00 2001 From: samuel100 Date: Tue, 10 Mar 2026 10:57:42 +0000 Subject: [PATCH 12/25] test: add model introspection integration tests Covers previously untested Model public API: alias(), id(), variants(), selected_variant(), is_cached(), path(), select_variant() (success and error paths). Total integration tests: 38. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk_v2/rust/tests/integration.rs | 137 +++++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) diff --git a/sdk_v2/rust/tests/integration.rs b/sdk_v2/rust/tests/integration.rs index 6caf5a20..f222a757 100644 --- a/sdk_v2/rust/tests/integration.rs +++ b/sdk_v2/rust/tests/integration.rs @@ -191,6 +191,143 @@ mod model_tests { "Model should not be loaded after unload()" ); } + + // ── Introspection ──────────────────────────────────────────────────── + + #[tokio::test] + async fn should_expose_alias() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); + } + + #[tokio::test] + async fn should_expose_non_empty_id() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + assert!( + !model.id().is_empty(), + "Model id() should be a non-empty string" + ); + } + + #[tokio::test] + async fn should_have_at_least_one_variant() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let variants = model.variants(); + assert!( + !variants.is_empty(), + "Model should have at least one variant" + ); + } + + #[tokio::test] + async fn should_have_selected_variant_matching_id() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let selected = model.selected_variant(); + assert_eq!( + selected.id(), + model.id(), + "selected_variant().id() should match model.id()" + ); + } + + #[tokio::test] + async fn should_report_cached_model_as_cached() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let cached = model.is_cached().await.expect("is_cached() should succeed"); + assert!( + cached, + "Test model '{}' should be cached (from test-data-shared)", + common::TEST_MODEL_ALIAS + ); + } + + #[tokio::test] + async fn should_return_non_empty_path_for_cached_model() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let path = model.path().await.expect("path() should succeed"); + assert!( + !path.is_empty(), + "Cached model should have a non-empty path" + ); + } + + #[tokio::test] + async fn should_select_variant_by_id() { + let manager = common::get_test_manager(); + let mut model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let first_variant_id = model.variants()[0].id().to_string(); + model + .select_variant(&first_variant_id) + .expect("select_variant should succeed"); + assert_eq!( + model.id(), + first_variant_id, + "After select_variant, id() should match the selected variant" + ); + } + + #[tokio::test] + async fn should_fail_to_select_unknown_variant() { + let manager = common::get_test_manager(); + let mut model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let result = model.select_variant("nonexistent-variant-id"); + assert!( + result.is_err(), + "select_variant with unknown ID should fail" + ); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("not found"), + "Error should mention 'not found': {err_msg}" + ); + } } mod model_load_manager_tests { From 6ce7d5e896484bbfc7f84cf4b84a90a7c8275b1f Mon Sep 17 00:00:00 2001 From: samuel100 Date: Tue, 10 Mar 2026 11:13:38 +0000 Subject: [PATCH 13/25] test: add web service integration tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tests the full web server lifecycle: start service → load model → REST call to v1/chat/completions (non-streaming and SSE streaming) → verify response → stop service → unload model. Also tests that urls() returns the correct addresses after start. Total integration tests: 41. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk_v2/rust/tests/integration.rs | 161 +++++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) diff --git a/sdk_v2/rust/tests/integration.rs b/sdk_v2/rust/tests/integration.rs index f222a757..a611d36c 100644 --- a/sdk_v2/rust/tests/integration.rs +++ b/sdk_v2/rust/tests/integration.rs @@ -892,3 +892,164 @@ mod audio_client_tests { model.unload().await.expect("model.unload() failed"); } } + +mod web_service_tests { + use super::common; + use serde_json::json; + + /// Start the web service, make a non-streaming POST to v1/chat/completions, + /// verify we get a valid response, then stop the service. + #[tokio::test] + async fn should_complete_chat_via_rest_api() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + model.load().await.expect("model.load() failed"); + + let urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + let base_url = urls.first().expect("no URL returned").trim_end_matches('/'); + + let client = reqwest::Client::new(); + let resp = client + .post(format!("{base_url}/v1/chat/completions")) + .json(&json!({ + "model": model.id(), + "messages": [ + { "role": "system", "content": "You are a helpful math assistant. Respond with just the answer." }, + { "role": "user", "content": "What is 7*6?" } + ], + "max_tokens": 500, + "temperature": 0.0, + "stream": false + })) + .send() + .await + .expect("HTTP request failed"); + + assert!( + resp.status().is_success(), + "Expected 2xx, got {}", + resp.status() + ); + + let body: serde_json::Value = resp.json().await.expect("failed to parse response JSON"); + let content = body + .pointer("/choices/0/message/content") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + assert!( + content.contains("42"), + "Expected response to contain '42', got: {content}" + ); + + manager + .stop_web_service() + .await + .expect("stop_web_service failed"); + model.unload().await.expect("model.unload() failed"); + } + + /// Start the web service, make a streaming POST to v1/chat/completions, + /// collect SSE chunks, verify we get a valid streamed response. + #[tokio::test] + async fn should_stream_chat_via_rest_api() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + model.load().await.expect("model.load() failed"); + + let urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + let base_url = urls.first().expect("no URL returned").trim_end_matches('/'); + + let client = reqwest::Client::new(); + let mut response = client + .post(format!("{base_url}/v1/chat/completions")) + .json(&json!({ + "model": model.id(), + "messages": [ + { "role": "system", "content": "You are a helpful math assistant. Respond with just the answer." }, + { "role": "user", "content": "What is 7*6?" } + ], + "max_tokens": 500, + "temperature": 0.0, + "stream": true + })) + .send() + .await + .expect("HTTP request failed"); + + assert!( + response.status().is_success(), + "Expected 2xx, got {}", + response.status() + ); + + let mut full_text = String::new(); + while let Some(chunk) = response.chunk().await.expect("chunk read failed") { + let text = String::from_utf8_lossy(&chunk); + for line in text.lines() { + let line = line.trim(); + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + break; + } + if let Ok(parsed) = serde_json::from_str::(data) { + if let Some(content) = parsed + .pointer("/choices/0/delta/content") + .and_then(|v| v.as_str()) + { + full_text.push_str(content); + } + } + } + } + } + + assert!( + full_text.contains("42"), + "Expected streamed response to contain '42', got: {full_text}" + ); + + manager + .stop_web_service() + .await + .expect("stop_web_service failed"); + model.unload().await.expect("model.unload() failed"); + } + + /// urls() should return the listening addresses after start_web_service. + #[tokio::test] + async fn should_expose_urls_after_start() { + let manager = common::get_test_manager(); + + let urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + assert!(!urls.is_empty(), "start_web_service should return URLs"); + + let cached_urls = manager.urls(); + assert_eq!( + urls, cached_urls, + "urls() should match what start_web_service returned" + ); + + manager + .stop_web_service() + .await + .expect("stop_web_service failed"); + } +} From 589777860d8b5ed54fa65cb994bd537ec264fed9 Mon Sep 17 00:00:00 2001 From: samuel100 Date: Tue, 10 Mar 2026 11:15:18 +0000 Subject: [PATCH 14/25] ci: add artifact upload for Rust SDK builds Packages the crate with cargo package and uploads the .crate file, matching the pattern used by CS (.nupkg) and JS (.tgz) builds. Also uploads flcore logs on all runs (including failures). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/workflows/build-rust-steps.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.github/workflows/build-rust-steps.yml b/.github/workflows/build-rust-steps.yml index 2dc67cdc..90f5e193 100644 --- a/.github/workflows/build-rust-steps.yml +++ b/.github/workflows/build-rust-steps.yml @@ -92,3 +92,19 @@ jobs: - name: Run integration tests if: ${{ inputs.run-integration-tests }} run: cargo test --test '*' ${{ env.CARGO_FEATURES }} -- --test-threads=1 + + - name: Package crate + run: cargo package ${{ env.CARGO_FEATURES }} --allow-dirty + + - name: Upload SDK artifact + uses: actions/upload-artifact@v4 + with: + name: rust-sdk-${{ inputs.platform }}${{ inputs.useWinML == true && '-winml' || '' }} + path: sdk_v2/rust/target/package/*.crate + + - name: Upload flcore logs + uses: actions/upload-artifact@v4 + if: always() + with: + name: rust-sdk-${{ inputs.platform }}${{ inputs.useWinML == true && '-winml' || '' }}-logs + path: sdk_v2/rust/logs/** From c9b5c6862eedabf1c9576897405ecfcd72c0fee7 Mon Sep 17 00:00:00 2001 From: samuel100 Date: Tue, 10 Mar 2026 11:26:42 +0000 Subject: [PATCH 15/25] test: add verbose output logging to integration tests All tests that produce response content now println! the actual output (chat completions, transcriptions, REST responses, model metadata). Output is visible with --nocapture (added to CI integration test step). Helps diagnose failures by showing actual vs expected values. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/workflows/build-rust-steps.yml | 2 +- sdk_v2/rust/tests/integration.rs | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build-rust-steps.yml b/.github/workflows/build-rust-steps.yml index 90f5e193..b675b4b6 100644 --- a/.github/workflows/build-rust-steps.yml +++ b/.github/workflows/build-rust-steps.yml @@ -91,7 +91,7 @@ jobs: - name: Run integration tests if: ${{ inputs.run-integration-tests }} - run: cargo test --test '*' ${{ env.CARGO_FEATURES }} -- --test-threads=1 + run: cargo test --test '*' ${{ env.CARGO_FEATURES }} -- --test-threads=1 --nocapture - name: Package crate run: cargo package ${{ env.CARGO_FEATURES }} --allow-dirty diff --git a/sdk_v2/rust/tests/integration.rs b/sdk_v2/rust/tests/integration.rs index a611d36c..5013d8b8 100644 --- a/sdk_v2/rust/tests/integration.rs +++ b/sdk_v2/rust/tests/integration.rs @@ -215,6 +215,8 @@ mod model_tests { .await .expect("get_model failed"); + println!("Model id: {}", model.id()); + assert!( !model.id().is_empty(), "Model id() should be a non-empty string" @@ -231,6 +233,8 @@ mod model_tests { .expect("get_model failed"); let variants = model.variants(); + println!("Model has {} variant(s)", variants.len()); + assert!( !variants.is_empty(), "Model should have at least one variant" @@ -281,6 +285,8 @@ mod model_tests { .expect("get_model failed"); let path = model.path().await.expect("path() should succeed"); + println!("Model path: {path}"); + assert!( !path.is_empty(), "Cached model should have a non-empty path" @@ -476,6 +482,9 @@ mod chat_client_tests { .first() .and_then(|c| c.message.content.as_deref()) .unwrap_or(""); + println!("Response: {content}"); + + println!("REST response: {content}"); assert!( content.contains("42"), @@ -508,6 +517,8 @@ mod chat_client_tests { } stream.close().await.expect("stream close failed"); + println!("First turn: {first_result}"); + assert!( first_result.contains("42"), "First turn should contain '42', got: {first_result}" @@ -531,6 +542,8 @@ mod chat_client_tests { } stream.close().await.expect("stream close failed"); + println!("Follow-up: {second_result}"); + assert!( second_result.contains("67"), "Follow-up should contain '67', got: {second_result}" @@ -655,6 +668,8 @@ mod chat_client_tests { .and_then(|c| c.message.content.as_deref()) .unwrap_or(""); + println!("Tool call result: {content}"); + assert!( content.contains("42"), "Final answer should contain '42', got: {content}" @@ -754,6 +769,8 @@ mod chat_client_tests { } stream.close().await.expect("stream close failed"); + println!("Streamed tool call result: {final_result}"); + assert!( final_result.contains("42"), "Streamed final answer should contain '42', got: {final_result}" @@ -837,6 +854,8 @@ mod audio_client_tests { } stream.close().await.expect("stream close failed"); + println!("Streamed transcription: {full_text}"); + assert!( full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), "Streamed transcription should contain expected text, got: {full_text}" @@ -863,6 +882,8 @@ mod audio_client_tests { } stream.close().await.expect("stream close failed"); + println!("Streamed transcription: {full_text}"); + assert!( full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), "Streamed transcription should contain expected text, got: {full_text}" @@ -944,6 +965,8 @@ mod web_service_tests { .and_then(|v| v.as_str()) .unwrap_or(""); + println!("REST response: {content}"); + assert!( content.contains("42"), "Expected response to contain '42', got: {content}" @@ -1018,6 +1041,8 @@ mod web_service_tests { } } + println!("REST streamed response: {full_text}"); + assert!( full_text.contains("42"), "Expected streamed response to contain '42', got: {full_text}" @@ -1039,6 +1064,7 @@ mod web_service_tests { .start_web_service() .await .expect("start_web_service failed"); + println!("Web service URLs: {urls:?}"); assert!(!urls.is_empty(), "start_web_service should return URLs"); let cached_urls = manager.urls(); From 93ed1773906df09aa513dd093f7be328d998510d Mon Sep 17 00:00:00 2001 From: samuel100 Date: Tue, 10 Mar 2026 14:48:44 +0000 Subject: [PATCH 16/25] refactor(rust): apply Canonical Rust Best Practices to sdk_v2/rust - Replace `pub use types::*` with explicit re-exports of all 8 types - Add `self::` prefix to all child module re-exports in lib.rs, detail/mod.rs, and openai/mod.rs - Add `#[non_exhaustive]` to FoundryLocalConfig with a new `FoundryLocalConfig::new(app_name)` constructor - Change error variants from positional `(String)` to named `{ reason: String }` fields per the guide's error type conventions - Update all examples, samples, and tests to use the new constructor Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../audio-transcription-example/src/main.rs | 5 +- .../rust/foundry-local-webserver/src/main.rs | 5 +- .../rust/native-chat-completions/src/main.rs | 5 +- .../tool-calling-foundry-local/src/main.rs | 5 +- sdk_v2/rust/README.md | 5 +- sdk_v2/rust/examples/chat_completion.rs | 5 +- sdk_v2/rust/examples/interactive_chat.rs | 5 +- sdk_v2/rust/examples/tool_calling.rs | 5 +- sdk_v2/rust/src/catalog.rs | 24 +++---- sdk_v2/rust/src/configuration.rs | 45 +++++++----- sdk_v2/rust/src/detail/core_interop.rs | 71 ++++++++++--------- sdk_v2/rust/src/detail/mod.rs | 2 +- sdk_v2/rust/src/error.rs | 20 +++--- sdk_v2/rust/src/lib.rs | 19 ++--- sdk_v2/rust/src/model.rs | 10 +-- sdk_v2/rust/src/openai/audio_client.rs | 30 +++++--- sdk_v2/rust/src/openai/chat_client.rs | 16 +++-- sdk_v2/rust/src/openai/mod.rs | 4 +- sdk_v2/rust/tests/common/mod.rs | 17 ++--- 19 files changed, 150 insertions(+), 148 deletions(-) diff --git a/samples/rust/audio-transcription-example/src/main.rs b/samples/rust/audio-transcription-example/src/main.rs index 4a602196..9dc64dc0 100644 --- a/samples/rust/audio-transcription-example/src/main.rs +++ b/samples/rust/audio-transcription-example/src/main.rs @@ -21,10 +21,7 @@ async fn main() -> Result<(), Box> { }); // ── 1. Initialise the manager ──────────────────────────────────────── - let manager = FoundryLocalManager::create(FoundryLocalConfig { - app_name: "foundry_local_samples".into(), - ..Default::default() - })?; + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("foundry_local_samples"))?; // ── 2. Pick the whisper model and ensure it is downloaded ──────────── let model = manager.catalog().get_model(ALIAS).await?; diff --git a/samples/rust/foundry-local-webserver/src/main.rs b/samples/rust/foundry-local-webserver/src/main.rs index 3c9dbb75..a3b5f326 100644 --- a/samples/rust/foundry-local-webserver/src/main.rs +++ b/samples/rust/foundry-local-webserver/src/main.rs @@ -18,10 +18,7 @@ use foundry_local_sdk::{FoundryLocalConfig, FoundryLocalManager}; async fn main() -> Result<(), Box> { // ── 1. Initialise the SDK ──────────────────────────────────────────── println!("Initializing Foundry Local SDK..."); - let manager = FoundryLocalManager::create(FoundryLocalConfig { - app_name: "foundry_local_samples".into(), - ..Default::default() - })?; + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("foundry_local_samples"))?; println!("✓ SDK initialized"); // ── 2. Download and load a model ───────────────────────────────────── diff --git a/samples/rust/native-chat-completions/src/main.rs b/samples/rust/native-chat-completions/src/main.rs index 904c3934..61df92d7 100644 --- a/samples/rust/native-chat-completions/src/main.rs +++ b/samples/rust/native-chat-completions/src/main.rs @@ -17,10 +17,7 @@ async fn main() -> Result<(), Box> { println!("=======================\n"); // ── 1. Initialise the manager ──────────────────────────────────────── - let manager = FoundryLocalManager::create(FoundryLocalConfig { - app_name: "foundry_local_samples".into(), - ..Default::default() - })?; + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("foundry_local_samples"))?; // ── 2. Pick a model and ensure it is downloaded ────────────────────── let model = manager.catalog().get_model(ALIAS).await?; diff --git a/samples/rust/tool-calling-foundry-local/src/main.rs b/samples/rust/tool-calling-foundry-local/src/main.rs index 43ff0a18..53ada9c5 100644 --- a/samples/rust/tool-calling-foundry-local/src/main.rs +++ b/samples/rust/tool-calling-foundry-local/src/main.rs @@ -49,10 +49,7 @@ async fn main() -> Result<(), Box> { println!("===============================\n"); // ── 1. Initialise the manager ──────────────────────────────────────── - let manager = FoundryLocalManager::create(FoundryLocalConfig { - app_name: "foundry_local_samples".into(), - ..Default::default() - })?; + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("foundry_local_samples"))?; // ── 2. Load a model ────────────────────────────────────────────────── let model = manager.catalog().get_model(ALIAS).await?; diff --git a/sdk_v2/rust/README.md b/sdk_v2/rust/README.md index bf8fa508..191cce07 100644 --- a/sdk_v2/rust/README.md +++ b/sdk_v2/rust/README.md @@ -47,10 +47,7 @@ use foundry_local_sdk::{ #[tokio::main] async fn main() -> Result<(), Box> { // Initialize the manager — loads native libraries and starts the engine - let manager = FoundryLocalManager::create(FoundryLocalConfig { - app_name: "my_app".into(), - ..Default::default() - })?; + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("my_app"))?; // List available models let models = manager.catalog().get_models().await?; diff --git a/sdk_v2/rust/examples/chat_completion.rs b/sdk_v2/rust/examples/chat_completion.rs index 2d604417..e3ae1884 100644 --- a/sdk_v2/rust/examples/chat_completion.rs +++ b/sdk_v2/rust/examples/chat_completion.rs @@ -15,10 +15,7 @@ type Result = std::result::Result; #[tokio::main] async fn main() -> Result<()> { // ── 1. Initialise the manager ──────────────────────────────────────── - let config = FoundryLocalConfig { - app_name: "foundry_local_samples".into(), - ..Default::default() - }; + let config = FoundryLocalConfig::new("foundry_local_samples"); let manager = FoundryLocalManager::create(config)?; diff --git a/sdk_v2/rust/examples/interactive_chat.rs b/sdk_v2/rust/examples/interactive_chat.rs index f2c7911f..951b9997 100644 --- a/sdk_v2/rust/examples/interactive_chat.rs +++ b/sdk_v2/rust/examples/interactive_chat.rs @@ -14,10 +14,7 @@ use tokio_stream::StreamExt; #[tokio::main] async fn main() -> Result<(), Box> { // ── Initialise ─────────────────────────────────────────────────────── - let manager = FoundryLocalManager::create(FoundryLocalConfig { - app_name: "foundry_local_samples".into(), - ..Default::default() - })?; + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("foundry_local_samples"))?; // Pick the first available model (or change this to a specific alias) let catalog = manager.catalog(); diff --git a/sdk_v2/rust/examples/tool_calling.rs b/sdk_v2/rust/examples/tool_calling.rs index 6acdeb76..8bca619d 100644 --- a/sdk_v2/rust/examples/tool_calling.rs +++ b/sdk_v2/rust/examples/tool_calling.rs @@ -45,10 +45,7 @@ struct ToolCallState { #[tokio::main] async fn main() -> Result<()> { // ── 1. Initialise ──────────────────────────────────────────────────── - let config = FoundryLocalConfig { - app_name: "foundry_local_samples".into(), - ..Default::default() - }; + let config = FoundryLocalConfig::new("foundry_local_samples"); let manager = FoundryLocalManager::create(config)?; diff --git a/sdk_v2/rust/src/catalog.rs b/sdk_v2/rust/src/catalog.rs index 2061ba76..20967fbe 100644 --- a/sdk_v2/rust/src/catalog.rs +++ b/sdk_v2/rust/src/catalog.rs @@ -76,34 +76,34 @@ impl Catalog { /// Look up a model by its alias. pub async fn get_model(&self, alias: &str) -> Result { if alias.trim().is_empty() { - return Err(FoundryLocalError::Validation( - "Model alias must be a non-empty string".into(), - )); + return Err(FoundryLocalError::Validation { + reason: "Model alias must be a non-empty string".into(), + }); } self.update_models().await?; let map = self.models_by_alias.lock().unwrap(); map.get(alias).cloned().ok_or_else(|| { let available: Vec<&String> = map.keys().collect(); - FoundryLocalError::ModelOperation(format!( - "Unknown model alias '{alias}'. Available: {available:?}" - )) + FoundryLocalError::ModelOperation { + reason: format!("Unknown model alias '{alias}'. Available: {available:?}"), + } }) } /// Look up a specific model variant by its unique id. pub async fn get_model_variant(&self, id: &str) -> Result { if id.trim().is_empty() { - return Err(FoundryLocalError::Validation( - "Variant id must be a non-empty string".into(), - )); + return Err(FoundryLocalError::Validation { + reason: "Variant id must be a non-empty string".into(), + }); } self.update_models().await?; let map = self.variants_by_id.lock().unwrap(); map.get(id).cloned().ok_or_else(|| { let available: Vec<&String> = map.keys().collect(); - FoundryLocalError::ModelOperation(format!( - "Unknown variant id '{id}'. Available: {available:?}" - )) + FoundryLocalError::ModelOperation { + reason: format!("Unknown variant id '{id}'. Available: {available:?}"), + } }) } diff --git a/sdk_v2/rust/src/configuration.rs b/sdk_v2/rust/src/configuration.rs index 0eb66e43..8645b3e7 100644 --- a/sdk_v2/rust/src/configuration.rs +++ b/sdk_v2/rust/src/configuration.rs @@ -29,6 +29,7 @@ impl LogLevel { /// User-facing configuration for initializing the Foundry Local SDK. #[derive(Debug, Clone, Default)] +#[non_exhaustive] pub struct FoundryLocalConfig { pub app_name: String, pub app_data_dir: Option, @@ -41,6 +42,26 @@ pub struct FoundryLocalConfig { pub additional_settings: Option>, } +impl FoundryLocalConfig { + /// Create a new configuration with the given application name. + /// + /// All other fields default to `None`. Use the struct update syntax to + /// override specific options: + /// + /// ```ignore + /// let config = FoundryLocalConfig { + /// log_level: Some(LogLevel::Debug), + /// ..FoundryLocalConfig::new("my_app") + /// }; + /// ``` + pub fn new(app_name: impl Into) -> Self { + Self { + app_name: app_name.into(), + ..Self::default() + } + } +} + /// Internal configuration object that converts [`FoundryLocalConfig`] into the /// parameter map expected by the native core library. #[derive(Debug, Clone)] @@ -58,9 +79,9 @@ impl Configuration { pub fn new(config: FoundryLocalConfig) -> Result { let app_name = config.app_name.trim().to_string(); if app_name.is_empty() { - return Err(FoundryLocalError::InvalidConfiguration( - "app_name must be set and non-empty".into(), - )); + return Err(FoundryLocalError::InvalidConfiguration { + reason: "app_name must be set and non-empty".into(), + }); } let mut params = HashMap::new(); @@ -104,15 +125,8 @@ mod tests { #[test] fn valid_config() { let cfg = FoundryLocalConfig { - app_name: "TestApp".into(), - app_data_dir: None, - model_cache_dir: None, - logs_dir: None, log_level: Some(LogLevel::Debug), - web_service_urls: None, - service_endpoint: None, - library_path: None, - additional_settings: None, + ..FoundryLocalConfig::new("TestApp") }; let c = Configuration::new(cfg).unwrap(); assert_eq!(c.params["AppName"], "TestApp"); @@ -123,14 +137,7 @@ mod tests { fn empty_app_name_fails() { let cfg = FoundryLocalConfig { app_name: " ".into(), - app_data_dir: None, - model_cache_dir: None, - logs_dir: None, - log_level: None, - web_service_urls: None, - service_endpoint: None, - library_path: None, - additional_settings: None, + ..FoundryLocalConfig::default() }; assert!(Configuration::new(cfg).is_err()); } diff --git a/sdk_v2/rust/src/detail/core_interop.rs b/sdk_v2/rust/src/detail/core_interop.rs index 46b1e261..08430823 100644 --- a/sdk_v2/rust/src/detail/core_interop.rs +++ b/sdk_v2/rust/src/detail/core_interop.rs @@ -157,28 +157,29 @@ impl CoreInterop { let _dependency_libs = Self::load_windows_dependencies(&lib_path)?; let library = unsafe { - Library::new(&lib_path).map_err(|e| { - FoundryLocalError::LibraryLoad(format!( + Library::new(&lib_path).map_err(|e| FoundryLocalError::LibraryLoad { + reason: format!( "Failed to load native library at {}: {e}", lib_path.display() - )) + ), })? }; let execute_command: ExecuteCommandFn = unsafe { - let sym: Symbol = library.get(b"execute_command\0").map_err(|e| { - FoundryLocalError::LibraryLoad(format!("Symbol 'execute_command' not found: {e}")) - })?; + let sym: Symbol = + library + .get(b"execute_command\0") + .map_err(|e| FoundryLocalError::LibraryLoad { + reason: format!("Symbol 'execute_command' not found: {e}"), + })?; *sym }; let execute_command_with_callback: ExecuteCommandWithCallbackFn = unsafe { let sym: Symbol = library .get(b"execute_command_with_callback\0") - .map_err(|e| { - FoundryLocalError::LibraryLoad(format!( - "Symbol 'execute_command_with_callback' not found: {e}" - )) + .map_err(|e| FoundryLocalError::LibraryLoad { + reason: format!("Symbol 'execute_command_with_callback' not found: {e}"), })?; *sym }; @@ -198,17 +199,18 @@ impl CoreInterop { /// `params` is an optional JSON value that will be serialised and sent as /// the data payload. pub fn execute_command(&self, command: &str, params: Option<&Value>) -> Result { - let cmd = CString::new(command).map_err(|e| { - FoundryLocalError::CommandExecution(format!("Invalid command string: {e}")) + let cmd = CString::new(command).map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Invalid command string: {e}"), })?; let data_json = match params { Some(v) => serde_json::to_string(v)?, None => String::new(), }; - let data_cstr = CString::new(data_json.as_str()).map_err(|e| { - FoundryLocalError::CommandExecution(format!("Invalid data string: {e}")) - })?; + let data_cstr = + CString::new(data_json.as_str()).map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Invalid data string: {e}"), + })?; let request = RequestBuffer { command: cmd.as_ptr(), @@ -240,17 +242,18 @@ impl CoreInterop { where F: FnMut(&str), { - let cmd = CString::new(command).map_err(|e| { - FoundryLocalError::CommandExecution(format!("Invalid command string: {e}")) + let cmd = CString::new(command).map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Invalid command string: {e}"), })?; let data_json = match params { Some(v) => serde_json::to_string(v)?, None => String::new(), }; - let data_cstr = CString::new(data_json.as_str()).map_err(|e| { - FoundryLocalError::CommandExecution(format!("Invalid data string: {e}")) - })?; + let data_cstr = + CString::new(data_json.as_str()).map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Invalid data string: {e}"), + })?; let request = RequestBuffer { command: cmd.as_ptr(), @@ -289,7 +292,9 @@ impl CoreInterop { let this = Arc::clone(self); tokio::task::spawn_blocking(move || this.execute_command(&command, params.as_ref())) .await - .map_err(|e| FoundryLocalError::CommandExecution(format!("task join error: {e}")))? + .map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("task join error: {e}"), + })? } /// Async version of [`Self::execute_command_streaming`]. @@ -310,7 +315,9 @@ impl CoreInterop { this.execute_command_streaming(&command, params.as_ref(), callback) }) .await - .map_err(|e| FoundryLocalError::CommandExecution(format!("task join error: {e}")))? + .map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("task join error: {e}"), + })? } /// Async streaming variant that bridges the FFI callback into a @@ -374,7 +381,7 @@ impl CoreInterop { // Return error or data. if let Some(err) = error_str { - Err(FoundryLocalError::CommandExecution(err)) + Err(FoundryLocalError::CommandExecution { reason: err }) } else { Ok(data_str.unwrap_or_default()) } @@ -421,11 +428,13 @@ impl CoreInterop { } } - Err(FoundryLocalError::LibraryLoad(format!( - "Could not locate native library '{CORE_LIB_NAME}'. \ - Set the FoundryLocalCorePath config option or the FOUNDRY_NATIVE_DIR \ - environment variable." - ))) + Err(FoundryLocalError::LibraryLoad { + reason: format!( + "Could not locate native library '{CORE_LIB_NAME}'. \ + Set the FoundryLocalCorePath config option or the FOUNDRY_NATIVE_DIR \ + environment variable." + ), + }) } /// On Windows, pre-load runtime dependencies so the core library can @@ -448,10 +457,8 @@ impl CoreInterop { let dep_path = dir.join(dep); if dep_path.exists() { let lib = unsafe { - Library::new(&dep_path).map_err(|e| { - FoundryLocalError::LibraryLoad(format!( - "Failed to load dependency {dep}: {e}" - )) + Library::new(&dep_path).map_err(|e| FoundryLocalError::LibraryLoad { + reason: format!("Failed to load dependency {dep}: {e}"), })? }; libs.push(lib); diff --git a/sdk_v2/rust/src/detail/mod.rs b/sdk_v2/rust/src/detail/mod.rs index 3f7fd07c..c7f2fd32 100644 --- a/sdk_v2/rust/src/detail/mod.rs +++ b/sdk_v2/rust/src/detail/mod.rs @@ -1,4 +1,4 @@ pub(crate) mod core_interop; mod model_load_manager; -pub use model_load_manager::ModelLoadManager; +pub use self::model_load_manager::ModelLoadManager; diff --git a/sdk_v2/rust/src/error.rs b/sdk_v2/rust/src/error.rs index a91453e5..226139b1 100644 --- a/sdk_v2/rust/src/error.rs +++ b/sdk_v2/rust/src/error.rs @@ -4,17 +4,17 @@ use thiserror::Error; #[derive(Debug, Error)] pub enum FoundryLocalError { /// The native core library could not be loaded. - #[error("library load error: {0}")] - LibraryLoad(String), + #[error("library load error: {reason}")] + LibraryLoad { reason: String }, /// A command executed against the native core returned an error. - #[error("command execution error: {0}")] - CommandExecution(String), + #[error("command execution error: {reason}")] + CommandExecution { reason: String }, /// The provided configuration is invalid. - #[error("invalid configuration: {0}")] - InvalidConfiguration(String), + #[error("invalid configuration: {reason}")] + InvalidConfiguration { reason: String }, /// A model operation failed (load, unload, download, etc.). - #[error("model operation error: {0}")] - ModelOperation(String), + #[error("model operation error: {reason}")] + ModelOperation { reason: String }, /// An HTTP request to the external service failed. #[error("HTTP request error: {0}")] HttpRequest(#[from] reqwest::Error), @@ -22,8 +22,8 @@ pub enum FoundryLocalError { #[error("serialization error: {0}")] Serialization(#[from] serde_json::Error), /// A validation check on user-supplied input failed. - #[error("validation error: {0}")] - Validation(String), + #[error("validation error: {reason}")] + Validation { reason: String }, /// An I/O error occurred. #[error("I/O error: {0}")] Io(#[from] std::io::Error), diff --git a/sdk_v2/rust/src/lib.rs b/sdk_v2/rust/src/lib.rs index df7e7cef..f3564145 100644 --- a/sdk_v2/rust/src/lib.rs +++ b/sdk_v2/rust/src/lib.rs @@ -13,14 +13,17 @@ mod types; pub(crate) mod detail; pub mod openai; -pub use catalog::Catalog; -pub use configuration::{FoundryLocalConfig, LogLevel}; -pub use detail::ModelLoadManager; -pub use error::FoundryLocalError; -pub use foundry_local_manager::FoundryLocalManager; -pub use model::Model; -pub use model_variant::ModelVariant; -pub use types::*; +pub use self::catalog::Catalog; +pub use self::configuration::{FoundryLocalConfig, LogLevel}; +pub use self::detail::ModelLoadManager; +pub use self::error::FoundryLocalError; +pub use self::foundry_local_manager::FoundryLocalManager; +pub use self::model::Model; +pub use self::model_variant::ModelVariant; +pub use self::types::{ + ChatResponseFormat, ChatToolChoice, DeviceType, ModelInfo, ModelSettings, Parameter, + PromptTemplate, Runtime, +}; // Re-export OpenAI request types so callers can construct typed messages. pub use async_openai::types::chat::{ diff --git a/sdk_v2/rust/src/model.rs b/sdk_v2/rust/src/model.rs index 54422c80..6856ac0d 100644 --- a/sdk_v2/rust/src/model.rs +++ b/sdk_v2/rust/src/model.rs @@ -58,10 +58,12 @@ impl Model { return Ok(()); } let available: Vec = self.variants.iter().map(|v| v.id().to_string()).collect(); - Err(FoundryLocalError::ModelOperation(format!( - "Variant '{id}' not found for model '{}'. Available: {available:?}", - self.alias - ))) + Err(FoundryLocalError::ModelOperation { + reason: format!( + "Variant '{id}' not found for model '{}'. Available: {available:?}", + self.alias + ), + }) } /// Returns a reference to the currently selected variant. diff --git a/sdk_v2/rust/src/openai/audio_client.rs b/sdk_v2/rust/src/openai/audio_client.rs index f909d544..7fc25817 100644 --- a/sdk_v2/rust/src/openai/audio_client.rs +++ b/sdk_v2/rust/src/openai/audio_client.rs @@ -97,7 +97,9 @@ impl AudioTranscriptionStream { if let Some(handle) = self.handle.take() { handle .await - .map_err(|e| FoundryLocalError::CommandExecution(format!("task join error: {e}")))? + .map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("task join error: {e}"), + })? .map(|_| ()) } else { Ok(()) @@ -138,9 +140,13 @@ impl AudioClient { &self, audio_file_path: impl AsRef, ) -> Result { - let path_str = audio_file_path.as_ref().to_str().ok_or_else(|| { - FoundryLocalError::Validation("audio file path is not valid UTF-8".into()) - })?; + let path_str = + audio_file_path + .as_ref() + .to_str() + .ok_or_else(|| FoundryLocalError::Validation { + reason: "audio file path is not valid UTF-8".into(), + })?; Self::validate_path(path_str)?; let request = self.settings.serialize(&self.model_id, path_str); @@ -164,9 +170,13 @@ impl AudioClient { &self, audio_file_path: impl AsRef, ) -> Result { - let path_str = audio_file_path.as_ref().to_str().ok_or_else(|| { - FoundryLocalError::Validation("audio file path is not valid UTF-8".into()) - })?; + let path_str = + audio_file_path + .as_ref() + .to_str() + .ok_or_else(|| FoundryLocalError::Validation { + reason: "audio file path is not valid UTF-8".into(), + })?; Self::validate_path(path_str)?; let mut request = self.settings.serialize(&self.model_id, path_str); @@ -193,9 +203,9 @@ impl AudioClient { fn validate_path(path: &str) -> Result<()> { if path.trim().is_empty() { - return Err(FoundryLocalError::Validation( - "audio_file_path must be a non-empty string".into(), - )); + return Err(FoundryLocalError::Validation { + reason: "audio_file_path must be a non-empty string".into(), + }); } Ok(()) } diff --git a/sdk_v2/rust/src/openai/chat_client.rs b/sdk_v2/rust/src/openai/chat_client.rs index 98d575e7..26fae831 100644 --- a/sdk_v2/rust/src/openai/chat_client.rs +++ b/sdk_v2/rust/src/openai/chat_client.rs @@ -158,7 +158,9 @@ impl ChatCompletionStream { if let Some(handle) = self.handle.take() { handle .await - .map_err(|e| FoundryLocalError::CommandExecution(format!("task join error: {e}")))? + .map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("task join error: {e}"), + })? .map(|_| ()) } else { Ok(()) @@ -249,9 +251,9 @@ impl ChatClient { tools: Option<&[ChatCompletionTools]>, ) -> Result { if messages.is_empty() { - return Err(FoundryLocalError::Validation( - "messages must be a non-empty array".into(), - )); + return Err(FoundryLocalError::Validation { + reason: "messages must be a non-empty array".into(), + }); } let request = self.build_request(messages, tools, false)?; @@ -279,9 +281,9 @@ impl ChatClient { tools: Option<&[ChatCompletionTools]>, ) -> Result { if messages.is_empty() { - return Err(FoundryLocalError::Validation( - "messages must be a non-empty array".into(), - )); + return Err(FoundryLocalError::Validation { + reason: "messages must be a non-empty array".into(), + }); } let request = self.build_request(messages, tools, true)?; diff --git a/sdk_v2/rust/src/openai/mod.rs b/sdk_v2/rust/src/openai/mod.rs index 190c6268..7a800c67 100644 --- a/sdk_v2/rust/src/openai/mod.rs +++ b/sdk_v2/rust/src/openai/mod.rs @@ -1,7 +1,7 @@ mod audio_client; mod chat_client; -pub use audio_client::{ +pub use self::audio_client::{ AudioClient, AudioClientSettings, AudioTranscriptionResponse, AudioTranscriptionStream, }; -pub use chat_client::{ChatClient, ChatClientSettings, ChatCompletionStream}; +pub use self::chat_client::{ChatClient, ChatClientSettings, ChatCompletionStream}; diff --git a/sdk_v2/rust/tests/common/mod.rs b/sdk_v2/rust/tests/common/mod.rs index 45d66fe1..dbe3414d 100644 --- a/sdk_v2/rust/tests/common/mod.rs +++ b/sdk_v2/rust/tests/common/mod.rs @@ -77,17 +77,12 @@ pub fn test_config() -> FoundryLocalConfig { let mut additional = HashMap::new(); additional.insert("Bootstrap".into(), "false".into()); - FoundryLocalConfig { - app_name: "FoundryLocalTest".into(), - app_data_dir: None, - model_cache_dir: Some(get_test_data_shared_path().to_string_lossy().into_owned()), - logs_dir: Some(logs_dir.to_string_lossy().into_owned()), - log_level: Some(LogLevel::Warn), - web_service_urls: None, - service_endpoint: None, - library_path: None, - additional_settings: Some(additional), - } + let mut config = FoundryLocalConfig::new("FoundryLocalTest"); + config.model_cache_dir = Some(get_test_data_shared_path().to_string_lossy().into_owned()); + config.logs_dir = Some(logs_dir.to_string_lossy().into_owned()); + config.log_level = Some(LogLevel::Warn); + config.additional_settings = Some(additional); + config } /// Create (or return the cached) [`FoundryLocalManager`] for tests. From 38649d9a7a2e57826afcaa6c09ec596e4229bb6b Mon Sep 17 00:00:00 2001 From: samuel100 Date: Tue, 10 Mar 2026 14:57:23 +0000 Subject: [PATCH 17/25] update ordering and description (style guide) --- sdk_v2/rust/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk_v2/rust/Cargo.toml b/sdk_v2/rust/Cargo.toml index fc035329..cb2cfecf 100644 --- a/sdk_v2/rust/Cargo.toml +++ b/sdk_v2/rust/Cargo.toml @@ -2,9 +2,9 @@ name = "foundry-local-sdk" version = "0.1.0" edition = "2021" -description = "Foundry Local Rust SDK - Local AI model inference" license = "MIT" readme = "README.md" +description = "Local AI model inference powered by the Foundry Local Core engine" [features] default = [] From 553690ff349f7ce6b5f156deb40f5317e9bd221c Mon Sep 17 00:00:00 2001 From: samuel100 Date: Tue, 10 Mar 2026 15:07:46 +0000 Subject: [PATCH 18/25] update to sample README. --- samples/rust/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/samples/rust/README.md b/samples/rust/README.md index 3e824369..c5399b3d 100644 --- a/samples/rust/README.md +++ b/samples/rust/README.md @@ -5,7 +5,6 @@ This directory contains samples demonstrating how to use the Foundry Local Rust ## Prerequisites - Rust 1.70.0 or later -- Foundry Local installed and available on PATH ## Samples From 9183d44d13ca0c3e2c1d7e71990285ae94ccc032 Mon Sep 17 00:00:00 2001 From: samkemp Date: Thu, 12 Mar 2026 12:47:17 +0000 Subject: [PATCH 19/25] Expand Rust SDK README with features and usage examples Add comprehensive documentation matching the quality and structure of the JS and C# SDK READMEs. The expanded README now includes: - Feature overview listing all SDK capabilities - Installation instructions with tokio dependency guidance - Quick Start example with numbered steps - Usage sections for catalog browsing, model lifecycle, chat completions, streaming responses, tool calling, response format options, audio transcription, and embedded web service - Chat client settings reference table - Error handling guide with FoundryLocalError enum variants - Configuration reference table with all FoundryLocalConfig fields - Links to sample applications in samples/rust/ Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk_v2/rust/README.md | 402 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 384 insertions(+), 18 deletions(-) diff --git a/sdk_v2/rust/README.md b/sdk_v2/rust/README.md index 191cce07..e751681b 100644 --- a/sdk_v2/rust/README.md +++ b/sdk_v2/rust/README.md @@ -1,8 +1,22 @@ # Foundry Local Rust SDK -Rust bindings for [Foundry Local](https://github.com/microsoft/Foundry-Local) — run AI models locally with a simple API. - -The SDK dynamically loads the Foundry Local Core native engine and exposes a safe Rust interface for model management, chat completions, and audio processing. +The Foundry Local Rust SDK provides an async Rust interface for running AI models locally on your machine. Discover, download, load, and run inference — all without cloud dependencies. + +## Features + +- **Local-first AI** — Run models entirely on your machine with no cloud calls +- **Model catalog** — Browse and discover available models; check what's cached or loaded +- **Automatic model management** — Download, load, unload, and remove models from cache +- **Chat completions** — OpenAI-compatible chat API with both non-streaming and streaming responses +- **Audio transcription** — Transcribe audio files locally with streaming support +- **Tool calling** — Function/tool calling with streaming, multi-turn conversation support +- **Response format control** — Text, JSON, JSON Schema, and Lark grammar constrained output +- **Multi-variant models** — Models can have multiple variants (e.g., different quantizations) with automatic selection of the best cached variant +- **Embedded web service** — Start a local HTTP server for OpenAI-compatible API access +- **WinML support** — Automatic execution provider download on Windows for NPU/GPU acceleration +- **Configurable inference** — Control temperature, max tokens, top-k, top-p, frequency penalty, random seed, and more +- **Async-first** — Every operation is `async`; designed for use with the `tokio` runtime +- **Safe FFI** — Dynamically loads the native Foundry Local Core engine with a safe Rust wrapper ## Prerequisites @@ -22,11 +36,19 @@ Or add to your `Cargo.toml`: foundry-local-sdk = "0.1" ``` -## Feature Flags +You also need an async runtime. Most examples use [tokio](https://crates.io/crates/tokio): + +```toml +[dependencies] +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +tokio-stream = "0.1" # for StreamExt on streaming responses +``` + +### Feature Flags | Feature | Description | |-----------|-------------| -| `winml` | Use the WinML backend (Windows only). Selects different ONNX Runtime and GenAI packages. | +| `winml` | Use the WinML backend (Windows only). Selects different ONNX Runtime and GenAI packages for NPU/GPU acceleration. | | `nightly` | Resolve the latest nightly build of the Core package from the ORT-Nightly feed. | Enable features in `Cargo.toml`: @@ -36,6 +58,8 @@ Enable features in `Cargo.toml`: foundry-local-sdk = { version = "0.1", features = ["winml"] } ``` +> **Note:** The `winml` feature is only relevant on Windows. On macOS and Linux, the standard build is used regardless. No code changes are needed — your application code stays the same. + ## Quick Start ```rust @@ -46,20 +70,14 @@ use foundry_local_sdk::{ #[tokio::main] async fn main() -> Result<(), Box> { - // Initialize the manager — loads native libraries and starts the engine + // 1. Initialize the manager — loads native libraries and starts the engine let manager = FoundryLocalManager::create(FoundryLocalConfig::new("my_app"))?; - // List available models - let models = manager.catalog().get_models().await?; - for model in &models { - println!("{} (id: {})", model.alias(), model.id()); - } - - // Pick a model and ensure it is loaded + // 2. Get a model from the catalog and load it let model = manager.catalog().get_model("phi-3.5-mini").await?; model.load().await?; - // Create a chat client and send a completion request + // 3. Create a chat client and run inference let mut client = model.create_chat_client(); client.temperature(0.7).max_tokens(256); @@ -69,16 +87,346 @@ async fn main() -> Result<(), Box> { ]; let response = client.complete_chat(&messages, None).await?; - if let Some(choice) = response.choices.first() { - if let Some(ref content) = choice.message.content { - println!("{content}"); + println!("{}", response.choices[0].message.content.as_deref().unwrap_or("")); + + // 4. Clean up + model.unload().await?; + + Ok(()) +} +``` + +## Usage + +### Browsing the Model Catalog + +The `Catalog` lets you discover what models are available, which are already cached locally, and which are currently loaded in memory. + +```rust +let catalog = manager.catalog(); + +// List all available models +let models = catalog.get_models().await?; +for model in &models { + println!("{} (id: {})", model.alias(), model.id()); +} + +// Look up a specific model by alias +let model = catalog.get_model("phi-3.5-mini").await?; + +// Look up a specific variant by its unique model ID +let variant = catalog.get_model_variant("phi-3.5-mini-generic-gpu-4").await?; + +// See what's already downloaded +let cached = catalog.get_cached_models().await?; + +// See what's currently loaded in memory +let loaded = catalog.get_loaded_models().await?; +``` + +### Model Lifecycle + +Each `Model` wraps one or more `ModelVariant` entries (different quantizations, hardware targets). The SDK auto-selects the best available variant, preferring cached versions. + +```rust +let mut model = catalog.get_model("phi-3.5-mini").await?; + +// Inspect available variants +println!("Selected: {}", model.selected_variant().id()); +for v in model.variants() { + println!(" {} (cached: {})", v.id(), v.is_cached().await?); +} + +// Switch to a specific variant +model.select_variant("phi-3.5-mini-generic-gpu-4")?; +``` + +Download, load, and unload: + +```rust +// Download with progress reporting +model.download(Some(|progress: &str| { + print!("\r{progress}"); + std::io::Write::flush(&mut std::io::stdout()).ok(); +})).await?; + +// Load into memory +model.load().await?; + +// Unload when done +model.unload().await?; + +// Remove from local cache entirely +model.remove_from_cache().await?; +``` + +### Chat Completions + +The `ChatClient` follows the OpenAI Chat Completion API structure. + +```rust +let mut client = model.create_chat_client(); + +// Configure generation settings (builder pattern) +client + .temperature(0.7) + .max_tokens(256) + .top_p(0.9) + .frequency_penalty(0.5); + +// Non-streaming completion +let response = client.complete_chat( + &[ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("Explain Rust's ownership model.").into(), + ], + None, +).await?; + +println!("{}", response.choices[0].message.content.as_deref().unwrap_or("")); +``` + +### Streaming Responses + +For real-time token-by-token output, use streaming: + +```rust +use tokio_stream::StreamExt; + +let mut stream = client.complete_streaming_chat( + &[ChatCompletionRequestUserMessage::from("Write a short poem about Rust.").into()], + None, +).await?; + +while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(content) = &chunk.choices[0].delta.content { + print!("{content}"); + } +} + +// Always close the stream to finalize the native session +stream.close().await?; +``` + +### Tool Calling + +Define functions the model can call and handle the multi-turn conversation: + +```rust +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestToolMessage, + ChatCompletionTools, ChatToolChoice, FinishReason, +}; +use serde_json::json; + +// Define available tools +let tools: Vec = serde_json::from_value(json!([{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": { "type": "string", "description": "City name" } + }, + "required": ["location"] } } +}]))?; + +let mut client = model.create_chat_client(); +client.max_tokens(512).tool_choice(ChatToolChoice::Auto); + +let mut messages: Vec = vec![ + ChatCompletionRequestUserMessage::from("What's the weather in Seattle?").into(), +]; + +// First request — model may call a tool +let response = client.complete_chat(&messages, Some(&tools)).await?; +let choice = &response.choices[0]; + +if choice.finish_reason == Some(FinishReason::ToolCalls) { + if let Some(tool_calls) = &choice.message.tool_calls { + for tc in tool_calls { + // Execute the tool (your application logic) + let result = execute_tool(&tc.function.name, &tc.function.arguments); + + // Add assistant message with tool calls, then the tool result + messages.push(serde_json::from_value(json!({ + "role": "assistant", + "content": null, + "tool_calls": [{ "id": tc.id, "type": "function", + "function": { "name": tc.function.name, + "arguments": tc.function.arguments } }] + }))?); + messages.push(ChatCompletionRequestToolMessage { + content: result.into(), + tool_call_id: tc.id.clone(), + }.into()); + } - Ok(()) + // Continue the conversation with tool results + let final_response = client.complete_chat(&messages, Some(&tools)).await?; + println!("{}", final_response.choices[0].message.content.as_deref().unwrap_or("")); + } +} +``` + +Tool calling also works with streaming via `complete_streaming_chat` — accumulate tool call fragments during streaming and check for `FinishReason::ToolCalls`. + +### Response Format Options + +Control the output format of chat completions: + +```rust +use foundry_local_sdk::ChatResponseFormat; + +let mut client = model.create_chat_client(); + +// Plain text (default) +client.response_format(ChatResponseFormat::Text); + +// Unstructured JSON output +client.response_format(ChatResponseFormat::JsonObject); + +// JSON constrained to a schema +client.response_format(ChatResponseFormat::JsonSchema(r#"{ + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { "type": "integer" } + }, + "required": ["name", "age"] +}"#.to_string())); + +// Output constrained by a Lark grammar (Foundry extension) +client.response_format(ChatResponseFormat::LarkGrammar(grammar.to_string())); +``` + +### Audio Transcription + +Transcribe audio files locally using the `AudioClient`: + +```rust +let model = manager.catalog().get_model("whisper-tiny").await?; +model.load().await?; + +let mut audio_client = model.create_audio_client(); +audio_client.language("en"); + +// Non-streaming transcription +let result = audio_client.transcribe("recording.wav").await?; +println!("{}", result.text); +``` + +#### Streaming Transcription + +```rust +use tokio_stream::StreamExt; + +let mut stream = audio_client.transcribe_streaming("recording.wav").await?; +while let Some(chunk) = stream.next().await { + print!("{}", chunk?.text); +} +stream.close().await?; +``` + +### Embedded Web Service + +Start a local HTTP server that exposes an OpenAI-compatible REST API: + +```rust +let urls = manager.start_web_service().await?; +println!("Service running at: {:?}", urls); + +// Any OpenAI-compatible client or tool can now connect to the endpoint. +// ... + +manager.stop_web_service().await?; +``` + +### Chat Client Settings + +All settings are configured via chainable builder methods on `ChatClient`: + +| Method | Type | Description | +|--------|------|-------------| +| `temperature(v)` | `f64` | Sampling temperature (0.0–2.0; higher = more random) | +| `max_tokens(v)` | `u32` | Maximum number of tokens to generate | +| `top_p(v)` | `f64` | Nucleus sampling probability (0.0–1.0) | +| `top_k(v)` | `u32` | Top-k sampling parameter (Foundry extension) | +| `frequency_penalty(v)` | `f64` | Frequency penalty | +| `presence_penalty(v)` | `f64` | Presence penalty | +| `n(v)` | `u32` | Number of completions to generate | +| `random_seed(v)` | `u64` | Random seed for reproducible results (Foundry extension) | +| `response_format(v)` | `ChatResponseFormat` | Output format (Text, JsonObject, JsonSchema, LarkGrammar) | +| `tool_choice(v)` | `ChatToolChoice` | Tool selection strategy (None, Auto, Required, Function) | + +## Error Handling + +All fallible operations return `foundry_local_sdk::Result`, which is an alias for `std::result::Result`. + +```rust +use foundry_local_sdk::FoundryLocalError; + +match manager.catalog().get_model("nonexistent").await { + Ok(model) => { /* use model */ } + Err(FoundryLocalError::ModelOperation { reason }) => { + eprintln!("Model error: {reason}"); + } + Err(FoundryLocalError::CommandExecution { reason }) => { + eprintln!("Core engine error: {reason}"); + } + Err(e) => { + eprintln!("Unexpected error: {e}"); + } } ``` +### Error Variants + +| Variant | Description | +|---------|-------------| +| `LibraryLoad { reason }` | The native core library could not be loaded | +| `CommandExecution { reason }` | A command executed against native core returned an error | +| `InvalidConfiguration { reason }` | The provided configuration is invalid | +| `ModelOperation { reason }` | A model operation failed (load, unload, download, etc.) | +| `HttpRequest(reqwest::Error)` | An HTTP request to an external service failed | +| `Serialization(serde_json::Error)` | JSON serialization/deserialization failed | +| `Validation { reason }` | A validation check on user-supplied input failed | +| `Io(std::io::Error)` | An I/O error occurred | + +## Configuration + +The SDK is configured via `FoundryLocalConfig` when creating the manager: + +```rust +use foundry_local_sdk::{FoundryLocalConfig, LogLevel}; + +let config = FoundryLocalConfig { + log_level: Some(LogLevel::Info), + model_cache_dir: Some("/path/to/cache".into()), + web_service_urls: Some("http://127.0.0.1:5000".into()), + ..FoundryLocalConfig::new("my_app") +}; + +let manager = FoundryLocalManager::create(config)?; +``` + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `app_name` | `String` | **(required)** | Your application name | +| `app_data_dir` | `Option` | `~/.{app_name}` | Application data directory | +| `model_cache_dir` | `Option` | `{app_data_dir}/cache/models` | Where models are stored locally | +| `logs_dir` | `Option` | `{app_data_dir}/logs` | Log output directory | +| `log_level` | `Option` | `Warn` | `Trace`, `Debug`, `Info`, `Warn`, `Error`, `Fatal` | +| `web_service_urls` | `Option` | `None` | Bind address for the embedded web service | +| `service_endpoint` | `Option` | `None` | URL of an existing external service to connect to | +| `library_path` | `Option` | Auto-discovered | Path to native Foundry Local Core libraries | +| `additional_settings` | `Option>` | `None` | Extra key-value settings passed to Core | + ## How It Works ### Native Library Download @@ -105,6 +453,24 @@ At runtime, the SDK uses `libloading` to dynamically load the Foundry Local Core | Linux x64 | `linux-x64`| ✅ | | macOS ARM64 | `osx-arm64`| ✅ | +## Running Examples + +Sample applications are available in [`samples/rust/`](../../samples/rust/): + +| Sample | Description | +|--------|-------------| +| `native-chat-completions` | Non-streaming and streaming chat completions | +| `tool-calling-foundry-local` | Function/tool calling with multi-turn conversations | +| `audio-transcription-example` | Audio transcription (non-streaming and streaming) | +| `foundry-local-webserver` | Embedded OpenAI-compatible REST API server | + +Run a sample with: + +```sh +cd samples/rust +cargo run -p native-chat-completions +``` + ## License MIT — see [LICENSE](../../LICENSE) for details. From 06da7f0bea9a96f3382994bd6bb96f4f57ded8c4 Mon Sep 17 00:00:00 2001 From: samkemp Date: Thu, 12 Mar 2026 12:51:06 +0000 Subject: [PATCH 20/25] refactor(rust): split integration tests into separate files Split the monolithic integration.rs into separate test modules while keeping a single test binary (tests/integration/main.rs) to avoid .NET native runtime re-initialization errors: - manager_test.rs: FoundryLocalManager tests - catalog_test.rs: Catalog operation tests - model_test.rs: Model operations (load, unload, cache, introspection) - chat_client_test.rs: ChatClient (completions, streaming, tool calling) - audio_client_test.rs: AudioClient (transcription, streaming) - web_service_test.rs: REST API tests Additional fixes: - Remove duplicate 'invalid callback' test (was identical to empty messages streaming test) - Use temperature(0.5) instead of 0.0 in with_temperature audio tests so they actually validate non-default temperature behavior Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk_v2/rust/tests/integration.rs | 1081 ----------------- .../tests/integration/audio_client_test.rs | 131 ++ sdk_v2/rust/tests/integration/catalog_test.rs | 106 ++ .../tests/integration/chat_client_test.rs | 342 ++++++ .../tests/{ => integration}/common/mod.rs | 0 sdk_v2/rust/tests/integration/main.rs | 16 + sdk_v2/rust/tests/integration/manager_test.rs | 21 + sdk_v2/rust/tests/integration/model_test.rs | 285 +++++ .../tests/integration/web_service_test.rs | 163 +++ 9 files changed, 1064 insertions(+), 1081 deletions(-) delete mode 100644 sdk_v2/rust/tests/integration.rs create mode 100644 sdk_v2/rust/tests/integration/audio_client_test.rs create mode 100644 sdk_v2/rust/tests/integration/catalog_test.rs create mode 100644 sdk_v2/rust/tests/integration/chat_client_test.rs rename sdk_v2/rust/tests/{ => integration}/common/mod.rs (100%) create mode 100644 sdk_v2/rust/tests/integration/main.rs create mode 100644 sdk_v2/rust/tests/integration/manager_test.rs create mode 100644 sdk_v2/rust/tests/integration/model_test.rs create mode 100644 sdk_v2/rust/tests/integration/web_service_test.rs diff --git a/sdk_v2/rust/tests/integration.rs b/sdk_v2/rust/tests/integration.rs deleted file mode 100644 index 5013d8b8..00000000 --- a/sdk_v2/rust/tests/integration.rs +++ /dev/null @@ -1,1081 +0,0 @@ -//! Single integration test binary for the Foundry Local Rust SDK. -//! -//! All test modules are compiled into one binary so the native core is only -//! initialised once (via the `OnceLock` singleton in `FoundryLocalManager`). -//! Running them as separate binaries causes "already initialized" errors -//! because the .NET native runtime retains state across process-level -//! library loads. - -mod common; - -mod manager_tests { - use super::common; - use foundry_local_sdk::FoundryLocalManager; - - #[test] - fn should_initialize_successfully() { - let config = common::test_config(); - let manager = FoundryLocalManager::create(config); - assert!( - manager.is_ok(), - "Manager creation failed: {:?}", - manager.err() - ); - } - - #[test] - fn should_return_catalog_with_non_empty_name() { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - let name = catalog.name(); - assert!(!name.is_empty(), "Catalog name should not be empty"); - } -} - -mod catalog_tests { - use super::common; - use foundry_local_sdk::Catalog; - - fn catalog() -> &'static Catalog { - common::get_test_manager().catalog() - } - - #[test] - fn should_initialize_with_catalog_name() { - let cat = catalog(); - let name = cat.name(); - assert!(!name.is_empty(), "Catalog name must not be empty"); - } - - #[tokio::test] - async fn should_list_models() { - let cat = catalog(); - let models = cat.get_models().await.expect("get_models failed"); - - assert!( - !models.is_empty(), - "Expected at least one model in the catalog" - ); - - let found = models.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); - assert!( - found, - "Test model '{}' not found in catalog", - common::TEST_MODEL_ALIAS - ); - } - - #[tokio::test] - async fn should_get_model_by_alias() { - let cat = catalog(); - let model = cat - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - - assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); - } - - #[tokio::test] - async fn should_throw_when_getting_model_with_empty_alias() { - let cat = catalog(); - let result = cat.get_model("").await; - assert!(result.is_err(), "Expected error for empty alias"); - - let err_msg = result.unwrap_err().to_string(); - assert!( - err_msg.contains("Model alias must be a non-empty string"), - "Unexpected error message: {err_msg}" - ); - } - - #[tokio::test] - async fn should_throw_when_getting_model_with_unknown_alias() { - let cat = catalog(); - let result = cat.get_model("unknown-nonexistent-model-alias").await; - assert!(result.is_err(), "Expected error for unknown alias"); - - let err_msg = result.unwrap_err().to_string(); - assert!( - err_msg.contains("Unknown model alias"), - "Error should mention unknown alias: {err_msg}" - ); - assert!( - err_msg.contains("Available"), - "Error should list available models: {err_msg}" - ); - } - - #[tokio::test] - async fn should_get_cached_models() { - let cat = catalog(); - let cached = cat - .get_cached_models() - .await - .expect("get_cached_models failed"); - - assert!(!cached.is_empty(), "Expected at least one cached model"); - - let found = cached.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); - assert!( - found, - "Test model '{}' should be in the cached models list", - common::TEST_MODEL_ALIAS - ); - } - - #[tokio::test] - async fn should_throw_when_getting_model_variant_with_empty_id() { - let cat = catalog(); - let result = cat.get_model_variant("").await; - assert!(result.is_err(), "Expected error for empty variant ID"); - } - - #[tokio::test] - async fn should_throw_when_getting_model_variant_with_unknown_id() { - let cat = catalog(); - let result = cat - .get_model_variant("unknown-nonexistent-variant-id") - .await; - assert!(result.is_err(), "Expected error for unknown variant ID"); - } -} - -mod model_tests { - use super::common; - - #[tokio::test] - async fn should_verify_cached_models_from_test_data_shared() { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - let cached = catalog - .get_cached_models() - .await - .expect("get_cached_models failed"); - - let has_qwen = cached.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); - assert!( - has_qwen, - "'{}' should be present in cached models", - common::TEST_MODEL_ALIAS - ); - - let has_whisper = cached - .iter() - .any(|m| m.alias() == common::WHISPER_MODEL_ALIAS); - assert!( - has_whisper, - "'{}' should be present in cached models", - common::WHISPER_MODEL_ALIAS - ); - } - - #[tokio::test] - async fn should_load_and_unload_model() { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - let model = catalog - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - - model.load().await.expect("model.load() failed"); - assert!( - model.is_loaded().await.expect("is_loaded check failed"), - "Model should be loaded after load()" - ); - - model.unload().await.expect("model.unload() failed"); - assert!( - !model.is_loaded().await.expect("is_loaded check failed"), - "Model should not be loaded after unload()" - ); - } - - // ── Introspection ──────────────────────────────────────────────────── - - #[tokio::test] - async fn should_expose_alias() { - let manager = common::get_test_manager(); - let model = manager - .catalog() - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - - assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); - } - - #[tokio::test] - async fn should_expose_non_empty_id() { - let manager = common::get_test_manager(); - let model = manager - .catalog() - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - - println!("Model id: {}", model.id()); - - assert!( - !model.id().is_empty(), - "Model id() should be a non-empty string" - ); - } - - #[tokio::test] - async fn should_have_at_least_one_variant() { - let manager = common::get_test_manager(); - let model = manager - .catalog() - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - - let variants = model.variants(); - println!("Model has {} variant(s)", variants.len()); - - assert!( - !variants.is_empty(), - "Model should have at least one variant" - ); - } - - #[tokio::test] - async fn should_have_selected_variant_matching_id() { - let manager = common::get_test_manager(); - let model = manager - .catalog() - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - - let selected = model.selected_variant(); - assert_eq!( - selected.id(), - model.id(), - "selected_variant().id() should match model.id()" - ); - } - - #[tokio::test] - async fn should_report_cached_model_as_cached() { - let manager = common::get_test_manager(); - let model = manager - .catalog() - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - - let cached = model.is_cached().await.expect("is_cached() should succeed"); - assert!( - cached, - "Test model '{}' should be cached (from test-data-shared)", - common::TEST_MODEL_ALIAS - ); - } - - #[tokio::test] - async fn should_return_non_empty_path_for_cached_model() { - let manager = common::get_test_manager(); - let model = manager - .catalog() - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - - let path = model.path().await.expect("path() should succeed"); - println!("Model path: {path}"); - - assert!( - !path.is_empty(), - "Cached model should have a non-empty path" - ); - } - - #[tokio::test] - async fn should_select_variant_by_id() { - let manager = common::get_test_manager(); - let mut model = manager - .catalog() - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - - let first_variant_id = model.variants()[0].id().to_string(); - model - .select_variant(&first_variant_id) - .expect("select_variant should succeed"); - assert_eq!( - model.id(), - first_variant_id, - "After select_variant, id() should match the selected variant" - ); - } - - #[tokio::test] - async fn should_fail_to_select_unknown_variant() { - let manager = common::get_test_manager(); - let mut model = manager - .catalog() - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - - let result = model.select_variant("nonexistent-variant-id"); - assert!( - result.is_err(), - "select_variant with unknown ID should fail" - ); - - let err_msg = result.unwrap_err().to_string(); - assert!( - err_msg.contains("not found"), - "Error should mention 'not found': {err_msg}" - ); - } -} - -mod model_load_manager_tests { - use super::common; - - async fn get_test_model() -> foundry_local_sdk::Model { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - catalog - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed") - } - - #[tokio::test] - async fn should_load_model_using_core_interop() { - let model = get_test_model().await; - model.load().await.expect("model.load() failed"); - model.unload().await.expect("model.unload() failed"); - } - - #[tokio::test] - async fn should_unload_model_using_core_interop() { - let model = get_test_model().await; - model.load().await.expect("model.load() failed"); - model.unload().await.expect("model.unload() failed"); - } - - #[tokio::test] - async fn should_list_loaded_models_using_core_interop() { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - - let loaded = catalog - .get_loaded_models() - .await - .expect("catalog.get_loaded_models() failed"); - - let _ = loaded; - } - - #[tokio::test] - #[ignore = "requires running web service"] - async fn should_load_and_unload_model_using_external_service() { - if common::is_running_in_ci() { - eprintln!("Skipping external-service test in CI"); - return; - } - - let manager = common::get_test_manager(); - let model = get_test_model().await; - - let _urls = manager - .start_web_service() - .await - .expect("start_web_service failed"); - - model - .load() - .await - .expect("load via external service failed"); - - model - .unload() - .await - .expect("unload via external service failed"); - } - - #[tokio::test] - #[ignore = "requires running web service"] - async fn should_list_loaded_models_using_external_service() { - if common::is_running_in_ci() { - eprintln!("Skipping external-service test in CI"); - return; - } - - let manager = common::get_test_manager(); - - let _urls = manager - .start_web_service() - .await - .expect("start_web_service failed"); - - let catalog = manager.catalog(); - let loaded = catalog - .get_loaded_models() - .await - .expect("get_loaded_models via external service failed"); - - let _ = loaded; - } -} - -mod chat_client_tests { - use super::common; - use foundry_local_sdk::openai::ChatClient; - use foundry_local_sdk::{ - ChatCompletionMessageToolCalls, ChatCompletionRequestMessage, - ChatCompletionRequestSystemMessage, ChatCompletionRequestToolMessage, - ChatCompletionRequestUserMessage, ChatToolChoice, - }; - use serde_json::json; - use tokio_stream::StreamExt; - - async fn setup_chat_client() -> (ChatClient, foundry_local_sdk::Model) { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - let model = catalog - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - model.load().await.expect("model.load() failed"); - - let mut client = model.create_chat_client(); - client.max_tokens(500).temperature(0.0); - (client, model) - } - - fn user_message(content: &str) -> ChatCompletionRequestMessage { - ChatCompletionRequestUserMessage::from(content).into() - } - - fn system_message(content: &str) -> ChatCompletionRequestMessage { - ChatCompletionRequestSystemMessage::from(content).into() - } - - fn assistant_message(content: &str) -> ChatCompletionRequestMessage { - serde_json::from_value(json!({ "role": "assistant", "content": content })) - .expect("failed to construct assistant message") - } - - #[tokio::test] - async fn should_perform_chat_completion() { - let (client, model) = setup_chat_client().await; - let messages = vec![ - system_message("You are a helpful math assistant. Respond with just the answer."), - user_message("What is 7*6?"), - ]; - - let response = client - .complete_chat(&messages, None) - .await - .expect("complete_chat failed"); - let content = response - .choices - .first() - .and_then(|c| c.message.content.as_deref()) - .unwrap_or(""); - println!("Response: {content}"); - - println!("REST response: {content}"); - - assert!( - content.contains("42"), - "Expected response to contain '42', got: {content}" - ); - - model.unload().await.expect("model.unload() failed"); - } - - #[tokio::test] - async fn should_perform_streaming_chat_completion() { - let (client, model) = setup_chat_client().await; - let mut messages = vec![ - system_message("You are a helpful math assistant. Respond with just the answer."), - user_message("What is 7*6?"), - ]; - - let mut first_result = String::new(); - let mut stream = client - .complete_streaming_chat(&messages, None) - .await - .expect("streaming chat (first turn) setup failed"); - while let Some(chunk) = stream.next().await { - let chunk = chunk.expect("stream chunk error"); - if let Some(choice) = chunk.choices.first() { - if let Some(ref content) = choice.delta.content { - first_result.push_str(content); - } - } - } - stream.close().await.expect("stream close failed"); - - println!("First turn: {first_result}"); - - assert!( - first_result.contains("42"), - "First turn should contain '42', got: {first_result}" - ); - - messages.push(assistant_message(&first_result)); - messages.push(user_message("Now add 25 to that result.")); - - let mut second_result = String::new(); - let mut stream = client - .complete_streaming_chat(&messages, None) - .await - .expect("streaming chat (follow-up) setup failed"); - while let Some(chunk) = stream.next().await { - let chunk = chunk.expect("stream chunk error"); - if let Some(choice) = chunk.choices.first() { - if let Some(ref content) = choice.delta.content { - second_result.push_str(content); - } - } - } - stream.close().await.expect("stream close failed"); - - println!("Follow-up: {second_result}"); - - assert!( - second_result.contains("67"), - "Follow-up should contain '67', got: {second_result}" - ); - - model.unload().await.expect("model.unload() failed"); - } - - #[tokio::test] - async fn should_throw_when_completing_chat_with_empty_messages() { - let (client, model) = setup_chat_client().await; - let messages: Vec = vec![]; - - let result = client.complete_chat(&messages, None).await; - assert!(result.is_err(), "Expected error for empty messages"); - - model.unload().await.expect("model.unload() failed"); - } - - #[tokio::test] - async fn should_throw_when_completing_streaming_chat_with_empty_messages() { - let (client, model) = setup_chat_client().await; - let messages: Vec = vec![]; - - let result = client.complete_streaming_chat(&messages, None).await; - assert!( - result.is_err(), - "Expected error for empty messages in streaming" - ); - - model.unload().await.expect("model.unload() failed"); - } - - #[tokio::test] - async fn should_throw_when_completing_streaming_chat_with_invalid_callback() { - let (client, model) = setup_chat_client().await; - let messages: Vec = vec![]; - - let result = client.complete_streaming_chat(&messages, None).await; - assert!(result.is_err(), "Expected error even with empty messages"); - - model.unload().await.expect("model.unload() failed"); - } - - #[tokio::test] - async fn should_perform_tool_calling_chat_completion_non_streaming() { - let (mut client, model) = setup_chat_client().await; - client.tool_choice(ChatToolChoice::Required); - - let tools = vec![common::get_multiply_tool()]; - let mut messages = vec![ - system_message("You are a math assistant. Use the multiply tool to answer."), - user_message("What is 6 times 7?"), - ]; - - let response = client - .complete_chat(&messages, Some(&tools)) - .await - .expect("complete_chat with tools failed"); - - let choice = response - .choices - .first() - .expect("Expected at least one choice"); - let tool_calls = choice - .message - .tool_calls - .as_ref() - .expect("Expected tool_calls"); - assert!( - !tool_calls.is_empty(), - "Expected at least one tool call in the response" - ); - - let tool_call = match &tool_calls[0] { - ChatCompletionMessageToolCalls::Function(tc) => tc, - _ => panic!("Expected a function tool call"), - }; - assert_eq!( - tool_call.function.name, "multiply", - "Expected tool call to 'multiply'" - ); - - let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments) - .expect("Failed to parse tool call arguments"); - let a = args["a"].as_f64().unwrap_or(0.0); - let b = args["b"].as_f64().unwrap_or(0.0); - let product = (a * b) as i64; - - let tool_call_id = &tool_call.id; - let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ - "role": "assistant", - "content": null, - "tool_calls": [{ - "id": tool_call_id, - "type": "function", - "function": { - "name": tool_call.function.name, - "arguments": tool_call.function.arguments, - } - }] - })) - .expect("failed to construct assistant message"); - messages.push(assistant_msg); - messages.push( - ChatCompletionRequestToolMessage { - content: product.to_string().into(), - tool_call_id: tool_call_id.clone(), - } - .into(), - ); - - client.tool_choice(ChatToolChoice::Auto); - - let final_response = client - .complete_chat(&messages, Some(&tools)) - .await - .expect("follow-up complete_chat with tools failed"); - let content = final_response - .choices - .first() - .and_then(|c| c.message.content.as_deref()) - .unwrap_or(""); - - println!("Tool call result: {content}"); - - assert!( - content.contains("42"), - "Final answer should contain '42', got: {content}" - ); - - model.unload().await.expect("model.unload() failed"); - } - - #[tokio::test] - async fn should_perform_tool_calling_chat_completion_streaming() { - let (mut client, model) = setup_chat_client().await; - client.tool_choice(ChatToolChoice::Required); - - let tools = vec![common::get_multiply_tool()]; - let mut messages = vec![ - system_message("You are a math assistant. Use the multiply tool to answer."), - user_message("What is 6 times 7?"), - ]; - - let mut tool_call_name = String::new(); - let mut tool_call_args = String::new(); - let mut tool_call_id = String::new(); - - let mut stream = client - .complete_streaming_chat(&messages, Some(&tools)) - .await - .expect("streaming tool call setup failed"); - - while let Some(chunk) = stream.next().await { - let chunk = chunk.expect("stream chunk error"); - if let Some(choice) = chunk.choices.first() { - if let Some(ref tool_calls) = choice.delta.tool_calls { - for call in tool_calls { - if let Some(ref func) = call.function { - if let Some(ref name) = func.name { - tool_call_name.push_str(name); - } - if let Some(ref args) = func.arguments { - tool_call_args.push_str(args); - } - } - if let Some(ref id) = call.id { - tool_call_id = id.clone(); - } - } - } - } - } - stream.close().await.expect("stream close failed"); - - assert_eq!( - tool_call_name, "multiply", - "Expected streamed tool call to 'multiply'" - ); - - let args: serde_json::Value = - serde_json::from_str(&tool_call_args).unwrap_or_else(|_| json!({})); - let a = args["a"].as_f64().unwrap_or(0.0); - let b = args["b"].as_f64().unwrap_or(0.0); - let product = (a * b) as i64; - - let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ - "role": "assistant", - "tool_calls": [{ - "id": tool_call_id, - "type": "function", - "function": { - "name": tool_call_name, - "arguments": tool_call_args - } - }] - })) - .expect("failed to construct assistant message"); - messages.push(assistant_msg); - messages.push( - ChatCompletionRequestToolMessage { - content: product.to_string().into(), - tool_call_id: tool_call_id.clone(), - } - .into(), - ); - - client.tool_choice(ChatToolChoice::Auto); - - let mut final_result = String::new(); - let mut stream = client - .complete_streaming_chat(&messages, Some(&tools)) - .await - .expect("streaming follow-up setup failed"); - while let Some(chunk) = stream.next().await { - let chunk = chunk.expect("stream chunk error"); - if let Some(choice) = chunk.choices.first() { - if let Some(ref content) = choice.delta.content { - final_result.push_str(content); - } - } - } - stream.close().await.expect("stream close failed"); - - println!("Streamed tool call result: {final_result}"); - - assert!( - final_result.contains("42"), - "Streamed final answer should contain '42', got: {final_result}" - ); - - model.unload().await.expect("model.unload() failed"); - } -} - -mod audio_client_tests { - use super::common; - use foundry_local_sdk::openai::AudioClient; - use tokio_stream::StreamExt; - - async fn setup_audio_client() -> (AudioClient, foundry_local_sdk::Model) { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - let model = catalog - .get_model(common::WHISPER_MODEL_ALIAS) - .await - .expect("get_model(whisper-tiny) failed"); - model.load().await.expect("model.load() failed"); - (model.create_audio_client(), model) - } - - fn audio_file() -> String { - common::get_audio_file_path().to_string_lossy().into_owned() - } - - #[tokio::test] - async fn should_transcribe_audio_without_streaming() { - let (mut client, model) = setup_audio_client().await; - client.language("en").temperature(0.0); - let response = client - .transcribe(&audio_file()) - .await - .expect("transcribe failed"); - - assert!( - response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), - "Transcription should contain expected text, got: {}", - response.text - ); - - model.unload().await.expect("model.unload() failed"); - } - - #[tokio::test] - async fn should_transcribe_audio_without_streaming_with_temperature() { - let (mut client, model) = setup_audio_client().await; - client.language("en").temperature(0.0); - - let response = client - .transcribe(&audio_file()) - .await - .expect("transcribe with temperature failed"); - - assert!( - response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), - "Transcription should contain expected text, got: {}", - response.text - ); - - model.unload().await.expect("model.unload() failed"); - } - - #[tokio::test] - async fn should_transcribe_audio_with_streaming() { - let (mut client, model) = setup_audio_client().await; - client.language("en").temperature(0.0); - let mut full_text = String::new(); - - let mut stream = client - .transcribe_streaming(&audio_file()) - .await - .expect("transcribe_streaming setup failed"); - - while let Some(chunk) = stream.next().await { - let chunk = chunk.expect("stream chunk error"); - full_text.push_str(&chunk.text); - } - stream.close().await.expect("stream close failed"); - - println!("Streamed transcription: {full_text}"); - - assert!( - full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), - "Streamed transcription should contain expected text, got: {full_text}" - ); - - model.unload().await.expect("model.unload() failed"); - } - - #[tokio::test] - async fn should_transcribe_audio_with_streaming_with_temperature() { - let (mut client, model) = setup_audio_client().await; - client.language("en").temperature(0.0); - - let mut full_text = String::new(); - - let mut stream = client - .transcribe_streaming(&audio_file()) - .await - .expect("transcribe_streaming with temperature setup failed"); - - while let Some(chunk) = stream.next().await { - let chunk = chunk.expect("stream chunk error"); - full_text.push_str(&chunk.text); - } - stream.close().await.expect("stream close failed"); - - println!("Streamed transcription: {full_text}"); - - assert!( - full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), - "Streamed transcription should contain expected text, got: {full_text}" - ); - - model.unload().await.expect("model.unload() failed"); - } - - #[tokio::test] - async fn should_throw_when_transcribing_with_empty_audio_file_path() { - let (client, model) = setup_audio_client().await; - let result = client.transcribe("").await; - assert!(result.is_err(), "Expected error for empty audio file path"); - - model.unload().await.expect("model.unload() failed"); - } - - #[tokio::test] - async fn should_throw_when_transcribing_streaming_with_empty_audio_file_path() { - let (client, model) = setup_audio_client().await; - let result = client.transcribe_streaming("").await; - assert!( - result.is_err(), - "Expected error for empty audio file path in streaming" - ); - - model.unload().await.expect("model.unload() failed"); - } -} - -mod web_service_tests { - use super::common; - use serde_json::json; - - /// Start the web service, make a non-streaming POST to v1/chat/completions, - /// verify we get a valid response, then stop the service. - #[tokio::test] - async fn should_complete_chat_via_rest_api() { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - let model = catalog - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - model.load().await.expect("model.load() failed"); - - let urls = manager - .start_web_service() - .await - .expect("start_web_service failed"); - let base_url = urls.first().expect("no URL returned").trim_end_matches('/'); - - let client = reqwest::Client::new(); - let resp = client - .post(format!("{base_url}/v1/chat/completions")) - .json(&json!({ - "model": model.id(), - "messages": [ - { "role": "system", "content": "You are a helpful math assistant. Respond with just the answer." }, - { "role": "user", "content": "What is 7*6?" } - ], - "max_tokens": 500, - "temperature": 0.0, - "stream": false - })) - .send() - .await - .expect("HTTP request failed"); - - assert!( - resp.status().is_success(), - "Expected 2xx, got {}", - resp.status() - ); - - let body: serde_json::Value = resp.json().await.expect("failed to parse response JSON"); - let content = body - .pointer("/choices/0/message/content") - .and_then(|v| v.as_str()) - .unwrap_or(""); - - println!("REST response: {content}"); - - assert!( - content.contains("42"), - "Expected response to contain '42', got: {content}" - ); - - manager - .stop_web_service() - .await - .expect("stop_web_service failed"); - model.unload().await.expect("model.unload() failed"); - } - - /// Start the web service, make a streaming POST to v1/chat/completions, - /// collect SSE chunks, verify we get a valid streamed response. - #[tokio::test] - async fn should_stream_chat_via_rest_api() { - let manager = common::get_test_manager(); - let catalog = manager.catalog(); - let model = catalog - .get_model(common::TEST_MODEL_ALIAS) - .await - .expect("get_model failed"); - model.load().await.expect("model.load() failed"); - - let urls = manager - .start_web_service() - .await - .expect("start_web_service failed"); - let base_url = urls.first().expect("no URL returned").trim_end_matches('/'); - - let client = reqwest::Client::new(); - let mut response = client - .post(format!("{base_url}/v1/chat/completions")) - .json(&json!({ - "model": model.id(), - "messages": [ - { "role": "system", "content": "You are a helpful math assistant. Respond with just the answer." }, - { "role": "user", "content": "What is 7*6?" } - ], - "max_tokens": 500, - "temperature": 0.0, - "stream": true - })) - .send() - .await - .expect("HTTP request failed"); - - assert!( - response.status().is_success(), - "Expected 2xx, got {}", - response.status() - ); - - let mut full_text = String::new(); - while let Some(chunk) = response.chunk().await.expect("chunk read failed") { - let text = String::from_utf8_lossy(&chunk); - for line in text.lines() { - let line = line.trim(); - if let Some(data) = line.strip_prefix("data: ") { - if data == "[DONE]" { - break; - } - if let Ok(parsed) = serde_json::from_str::(data) { - if let Some(content) = parsed - .pointer("/choices/0/delta/content") - .and_then(|v| v.as_str()) - { - full_text.push_str(content); - } - } - } - } - } - - println!("REST streamed response: {full_text}"); - - assert!( - full_text.contains("42"), - "Expected streamed response to contain '42', got: {full_text}" - ); - - manager - .stop_web_service() - .await - .expect("stop_web_service failed"); - model.unload().await.expect("model.unload() failed"); - } - - /// urls() should return the listening addresses after start_web_service. - #[tokio::test] - async fn should_expose_urls_after_start() { - let manager = common::get_test_manager(); - - let urls = manager - .start_web_service() - .await - .expect("start_web_service failed"); - println!("Web service URLs: {urls:?}"); - assert!(!urls.is_empty(), "start_web_service should return URLs"); - - let cached_urls = manager.urls(); - assert_eq!( - urls, cached_urls, - "urls() should match what start_web_service returned" - ); - - manager - .stop_web_service() - .await - .expect("stop_web_service failed"); - } -} diff --git a/sdk_v2/rust/tests/integration/audio_client_test.rs b/sdk_v2/rust/tests/integration/audio_client_test.rs new file mode 100644 index 00000000..6cc1d0cd --- /dev/null +++ b/sdk_v2/rust/tests/integration/audio_client_test.rs @@ -0,0 +1,131 @@ +use super::common; +use foundry_local_sdk::openai::AudioClient; +use tokio_stream::StreamExt; + +async fn setup_audio_client() -> (AudioClient, foundry_local_sdk::Model) { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::WHISPER_MODEL_ALIAS) + .await + .expect("get_model(whisper-tiny) failed"); + model.load().await.expect("model.load() failed"); + (model.create_audio_client(), model) +} + +fn audio_file() -> String { + common::get_audio_file_path().to_string_lossy().into_owned() +} + +#[tokio::test] +async fn should_transcribe_audio_without_streaming() { + let (mut client, model) = setup_audio_client().await; + client.language("en").temperature(0.0); + let response = client + .transcribe(&audio_file()) + .await + .expect("transcribe failed"); + + assert!( + response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Transcription should contain expected text, got: {}", + response.text + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_transcribe_audio_without_streaming_with_temperature() { + let (mut client, model) = setup_audio_client().await; + client.language("en").temperature(0.5); + + let response = client + .transcribe(&audio_file()) + .await + .expect("transcribe with temperature failed"); + + assert!( + response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Transcription should contain expected text, got: {}", + response.text + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_transcribe_audio_with_streaming() { + let (mut client, model) = setup_audio_client().await; + client.language("en").temperature(0.0); + let mut full_text = String::new(); + + let mut stream = client + .transcribe_streaming(&audio_file()) + .await + .expect("transcribe_streaming setup failed"); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + full_text.push_str(&chunk.text); + } + stream.close().await.expect("stream close failed"); + + println!("Streamed transcription: {full_text}"); + + assert!( + full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Streamed transcription should contain expected text, got: {full_text}" + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_transcribe_audio_with_streaming_with_temperature() { + let (mut client, model) = setup_audio_client().await; + client.language("en").temperature(0.5); + + let mut full_text = String::new(); + + let mut stream = client + .transcribe_streaming(&audio_file()) + .await + .expect("transcribe_streaming with temperature setup failed"); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + full_text.push_str(&chunk.text); + } + stream.close().await.expect("stream close failed"); + + println!("Streamed transcription: {full_text}"); + + assert!( + full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Streamed transcription should contain expected text, got: {full_text}" + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_throw_when_transcribing_with_empty_audio_file_path() { + let (client, model) = setup_audio_client().await; + let result = client.transcribe("").await; + assert!(result.is_err(), "Expected error for empty audio file path"); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_throw_when_transcribing_streaming_with_empty_audio_file_path() { + let (client, model) = setup_audio_client().await; + let result = client.transcribe_streaming("").await; + assert!( + result.is_err(), + "Expected error for empty audio file path in streaming" + ); + + model.unload().await.expect("model.unload() failed"); +} diff --git a/sdk_v2/rust/tests/integration/catalog_test.rs b/sdk_v2/rust/tests/integration/catalog_test.rs new file mode 100644 index 00000000..d418c7a7 --- /dev/null +++ b/sdk_v2/rust/tests/integration/catalog_test.rs @@ -0,0 +1,106 @@ +use super::common; +use foundry_local_sdk::Catalog; + +fn catalog() -> &'static Catalog { + common::get_test_manager().catalog() +} + +#[test] +fn should_initialize_with_catalog_name() { + let cat = catalog(); + let name = cat.name(); + assert!(!name.is_empty(), "Catalog name must not be empty"); +} + +#[tokio::test] +async fn should_list_models() { + let cat = catalog(); + let models = cat.get_models().await.expect("get_models failed"); + + assert!( + !models.is_empty(), + "Expected at least one model in the catalog" + ); + + let found = models.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); + assert!( + found, + "Test model '{}' not found in catalog", + common::TEST_MODEL_ALIAS + ); +} + +#[tokio::test] +async fn should_get_model_by_alias() { + let cat = catalog(); + let model = cat + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); +} + +#[tokio::test] +async fn should_throw_when_getting_model_with_empty_alias() { + let cat = catalog(); + let result = cat.get_model("").await; + assert!(result.is_err(), "Expected error for empty alias"); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("Model alias must be a non-empty string"), + "Unexpected error message: {err_msg}" + ); +} + +#[tokio::test] +async fn should_throw_when_getting_model_with_unknown_alias() { + let cat = catalog(); + let result = cat.get_model("unknown-nonexistent-model-alias").await; + assert!(result.is_err(), "Expected error for unknown alias"); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("Unknown model alias"), + "Error should mention unknown alias: {err_msg}" + ); + assert!( + err_msg.contains("Available"), + "Error should list available models: {err_msg}" + ); +} + +#[tokio::test] +async fn should_get_cached_models() { + let cat = catalog(); + let cached = cat + .get_cached_models() + .await + .expect("get_cached_models failed"); + + assert!(!cached.is_empty(), "Expected at least one cached model"); + + let found = cached.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); + assert!( + found, + "Test model '{}' should be in the cached models list", + common::TEST_MODEL_ALIAS + ); +} + +#[tokio::test] +async fn should_throw_when_getting_model_variant_with_empty_id() { + let cat = catalog(); + let result = cat.get_model_variant("").await; + assert!(result.is_err(), "Expected error for empty variant ID"); +} + +#[tokio::test] +async fn should_throw_when_getting_model_variant_with_unknown_id() { + let cat = catalog(); + let result = cat + .get_model_variant("unknown-nonexistent-variant-id") + .await; + assert!(result.is_err(), "Expected error for unknown variant ID"); +} diff --git a/sdk_v2/rust/tests/integration/chat_client_test.rs b/sdk_v2/rust/tests/integration/chat_client_test.rs new file mode 100644 index 00000000..90f53709 --- /dev/null +++ b/sdk_v2/rust/tests/integration/chat_client_test.rs @@ -0,0 +1,342 @@ +use super::common; +use foundry_local_sdk::openai::ChatClient; +use foundry_local_sdk::{ + ChatCompletionMessageToolCalls, ChatCompletionRequestMessage, + ChatCompletionRequestSystemMessage, ChatCompletionRequestToolMessage, + ChatCompletionRequestUserMessage, ChatToolChoice, +}; +use serde_json::json; +use tokio_stream::StreamExt; + +async fn setup_chat_client() -> (ChatClient, foundry_local_sdk::Model) { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + model.load().await.expect("model.load() failed"); + + let mut client = model.create_chat_client(); + client.max_tokens(500).temperature(0.0); + (client, model) +} + +fn user_message(content: &str) -> ChatCompletionRequestMessage { + ChatCompletionRequestUserMessage::from(content).into() +} + +fn system_message(content: &str) -> ChatCompletionRequestMessage { + ChatCompletionRequestSystemMessage::from(content).into() +} + +fn assistant_message(content: &str) -> ChatCompletionRequestMessage { + serde_json::from_value(json!({ "role": "assistant", "content": content })) + .expect("failed to construct assistant message") +} + +#[tokio::test] +async fn should_perform_chat_completion() { + let (client, model) = setup_chat_client().await; + let messages = vec![ + system_message("You are a helpful math assistant. Respond with just the answer."), + user_message("What is 7*6?"), + ]; + + let response = client + .complete_chat(&messages, None) + .await + .expect("complete_chat failed"); + let content = response + .choices + .first() + .and_then(|c| c.message.content.as_deref()) + .unwrap_or(""); + println!("Response: {content}"); + + println!("REST response: {content}"); + + assert!( + content.contains("42"), + "Expected response to contain '42', got: {content}" + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_perform_streaming_chat_completion() { + let (client, model) = setup_chat_client().await; + let mut messages = vec![ + system_message("You are a helpful math assistant. Respond with just the answer."), + user_message("What is 7*6?"), + ]; + + let mut first_result = String::new(); + let mut stream = client + .complete_streaming_chat(&messages, None) + .await + .expect("streaming chat (first turn) setup failed"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + first_result.push_str(content); + } + } + } + stream.close().await.expect("stream close failed"); + + println!("First turn: {first_result}"); + + assert!( + first_result.contains("42"), + "First turn should contain '42', got: {first_result}" + ); + + messages.push(assistant_message(&first_result)); + messages.push(user_message("Now add 25 to that result.")); + + let mut second_result = String::new(); + let mut stream = client + .complete_streaming_chat(&messages, None) + .await + .expect("streaming chat (follow-up) setup failed"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + second_result.push_str(content); + } + } + } + stream.close().await.expect("stream close failed"); + + println!("Follow-up: {second_result}"); + + assert!( + second_result.contains("67"), + "Follow-up should contain '67', got: {second_result}" + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_throw_when_completing_chat_with_empty_messages() { + let (client, model) = setup_chat_client().await; + let messages: Vec = vec![]; + + let result = client.complete_chat(&messages, None).await; + assert!(result.is_err(), "Expected error for empty messages"); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_throw_when_completing_streaming_chat_with_empty_messages() { + let (client, model) = setup_chat_client().await; + let messages: Vec = vec![]; + + let result = client.complete_streaming_chat(&messages, None).await; + assert!( + result.is_err(), + "Expected error for empty messages in streaming" + ); + + model.unload().await.expect("model.unload() failed"); +} + +// Note: The "invalid callback" test was removed because it was an exact +// duplicate of should_throw_when_completing_streaming_chat_with_empty_messages. + +#[tokio::test] +async fn should_perform_tool_calling_chat_completion_non_streaming() { + let (mut client, model) = setup_chat_client().await; + client.tool_choice(ChatToolChoice::Required); + + let tools = vec![common::get_multiply_tool()]; + let mut messages = vec![ + system_message("You are a math assistant. Use the multiply tool to answer."), + user_message("What is 6 times 7?"), + ]; + + let response = client + .complete_chat(&messages, Some(&tools)) + .await + .expect("complete_chat with tools failed"); + + let choice = response + .choices + .first() + .expect("Expected at least one choice"); + let tool_calls = choice + .message + .tool_calls + .as_ref() + .expect("Expected tool_calls"); + assert!( + !tool_calls.is_empty(), + "Expected at least one tool call in the response" + ); + + let tool_call = match &tool_calls[0] { + ChatCompletionMessageToolCalls::Function(tc) => tc, + _ => panic!("Expected a function tool call"), + }; + assert_eq!( + tool_call.function.name, "multiply", + "Expected tool call to 'multiply'" + ); + + let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments) + .expect("Failed to parse tool call arguments"); + let a = args["a"].as_f64().unwrap_or(0.0); + let b = args["b"].as_f64().unwrap_or(0.0); + let product = (a * b) as i64; + + let tool_call_id = &tool_call.id; + let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": tool_call_id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + } + }] + })) + .expect("failed to construct assistant message"); + messages.push(assistant_msg); + messages.push( + ChatCompletionRequestToolMessage { + content: product.to_string().into(), + tool_call_id: tool_call_id.clone(), + } + .into(), + ); + + client.tool_choice(ChatToolChoice::Auto); + + let final_response = client + .complete_chat(&messages, Some(&tools)) + .await + .expect("follow-up complete_chat with tools failed"); + let content = final_response + .choices + .first() + .and_then(|c| c.message.content.as_deref()) + .unwrap_or(""); + + println!("Tool call result: {content}"); + + assert!( + content.contains("42"), + "Final answer should contain '42', got: {content}" + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_perform_tool_calling_chat_completion_streaming() { + let (mut client, model) = setup_chat_client().await; + client.tool_choice(ChatToolChoice::Required); + + let tools = vec![common::get_multiply_tool()]; + let mut messages = vec![ + system_message("You are a math assistant. Use the multiply tool to answer."), + user_message("What is 6 times 7?"), + ]; + + let mut tool_call_name = String::new(); + let mut tool_call_args = String::new(); + let mut tool_call_id = String::new(); + + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await + .expect("streaming tool call setup failed"); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref tool_calls) = choice.delta.tool_calls { + for call in tool_calls { + if let Some(ref func) = call.function { + if let Some(ref name) = func.name { + tool_call_name.push_str(name); + } + if let Some(ref args) = func.arguments { + tool_call_args.push_str(args); + } + } + if let Some(ref id) = call.id { + tool_call_id = id.clone(); + } + } + } + } + } + stream.close().await.expect("stream close failed"); + + assert_eq!( + tool_call_name, "multiply", + "Expected streamed tool call to 'multiply'" + ); + + let args: serde_json::Value = + serde_json::from_str(&tool_call_args).unwrap_or_else(|_| json!({})); + let a = args["a"].as_f64().unwrap_or(0.0); + let b = args["b"].as_f64().unwrap_or(0.0); + let product = (a * b) as i64; + + let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ + "role": "assistant", + "tool_calls": [{ + "id": tool_call_id, + "type": "function", + "function": { + "name": tool_call_name, + "arguments": tool_call_args + } + }] + })) + .expect("failed to construct assistant message"); + messages.push(assistant_msg); + messages.push( + ChatCompletionRequestToolMessage { + content: product.to_string().into(), + tool_call_id: tool_call_id.clone(), + } + .into(), + ); + + client.tool_choice(ChatToolChoice::Auto); + + let mut final_result = String::new(); + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await + .expect("streaming follow-up setup failed"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + final_result.push_str(content); + } + } + } + stream.close().await.expect("stream close failed"); + + println!("Streamed tool call result: {final_result}"); + + assert!( + final_result.contains("42"), + "Streamed final answer should contain '42', got: {final_result}" + ); + + model.unload().await.expect("model.unload() failed"); +} diff --git a/sdk_v2/rust/tests/common/mod.rs b/sdk_v2/rust/tests/integration/common/mod.rs similarity index 100% rename from sdk_v2/rust/tests/common/mod.rs rename to sdk_v2/rust/tests/integration/common/mod.rs diff --git a/sdk_v2/rust/tests/integration/main.rs b/sdk_v2/rust/tests/integration/main.rs new file mode 100644 index 00000000..a18a29a6 --- /dev/null +++ b/sdk_v2/rust/tests/integration/main.rs @@ -0,0 +1,16 @@ +//! Single integration test binary for the Foundry Local Rust SDK. +//! +//! All test modules are compiled into one binary so the native core is only +//! initialised once (via the `OnceLock` singleton in `FoundryLocalManager`). +//! Running them as separate binaries causes "already initialized" errors +//! because the .NET native runtime retains state across process-level +//! library loads. + +mod common; + +mod manager_test; +mod catalog_test; +mod model_test; +mod chat_client_test; +mod audio_client_test; +mod web_service_test; diff --git a/sdk_v2/rust/tests/integration/manager_test.rs b/sdk_v2/rust/tests/integration/manager_test.rs new file mode 100644 index 00000000..aa3e0614 --- /dev/null +++ b/sdk_v2/rust/tests/integration/manager_test.rs @@ -0,0 +1,21 @@ +use super::common; +use foundry_local_sdk::FoundryLocalManager; + +#[test] +fn should_initialize_successfully() { + let config = common::test_config(); + let manager = FoundryLocalManager::create(config); + assert!( + manager.is_ok(), + "Manager creation failed: {:?}", + manager.err() + ); +} + +#[test] +fn should_return_catalog_with_non_empty_name() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let name = catalog.name(); + assert!(!name.is_empty(), "Catalog name should not be empty"); +} diff --git a/sdk_v2/rust/tests/integration/model_test.rs b/sdk_v2/rust/tests/integration/model_test.rs new file mode 100644 index 00000000..8730d5bd --- /dev/null +++ b/sdk_v2/rust/tests/integration/model_test.rs @@ -0,0 +1,285 @@ +use super::common; + +// ── Cached model verification ──────────────────────────────────────────────── + +#[tokio::test] +async fn should_verify_cached_models_from_test_data_shared() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let cached = catalog + .get_cached_models() + .await + .expect("get_cached_models failed"); + + let has_qwen = cached.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); + assert!( + has_qwen, + "'{}' should be present in cached models", + common::TEST_MODEL_ALIAS + ); + + let has_whisper = cached + .iter() + .any(|m| m.alias() == common::WHISPER_MODEL_ALIAS); + assert!( + has_whisper, + "'{}' should be present in cached models", + common::WHISPER_MODEL_ALIAS + ); +} + +// ── Load / Unload ──────────────────────────────────────────────────────────── + +#[tokio::test] +async fn should_load_and_unload_model() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + model.load().await.expect("model.load() failed"); + assert!( + model.is_loaded().await.expect("is_loaded check failed"), + "Model should be loaded after load()" + ); + + model.unload().await.expect("model.unload() failed"); + assert!( + !model.is_loaded().await.expect("is_loaded check failed"), + "Model should not be loaded after unload()" + ); +} + +// ── Introspection ──────────────────────────────────────────────────────────── + +#[tokio::test] +async fn should_expose_alias() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); +} + +#[tokio::test] +async fn should_expose_non_empty_id() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + println!("Model id: {}", model.id()); + + assert!( + !model.id().is_empty(), + "Model id() should be a non-empty string" + ); +} + +#[tokio::test] +async fn should_have_at_least_one_variant() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let variants = model.variants(); + println!("Model has {} variant(s)", variants.len()); + + assert!( + !variants.is_empty(), + "Model should have at least one variant" + ); +} + +#[tokio::test] +async fn should_have_selected_variant_matching_id() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let selected = model.selected_variant(); + assert_eq!( + selected.id(), + model.id(), + "selected_variant().id() should match model.id()" + ); +} + +#[tokio::test] +async fn should_report_cached_model_as_cached() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let cached = model.is_cached().await.expect("is_cached() should succeed"); + assert!( + cached, + "Test model '{}' should be cached (from test-data-shared)", + common::TEST_MODEL_ALIAS + ); +} + +#[tokio::test] +async fn should_return_non_empty_path_for_cached_model() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let path = model.path().await.expect("path() should succeed"); + println!("Model path: {path}"); + + assert!( + !path.is_empty(), + "Cached model should have a non-empty path" + ); +} + +#[tokio::test] +async fn should_select_variant_by_id() { + let manager = common::get_test_manager(); + let mut model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let first_variant_id = model.variants()[0].id().to_string(); + model + .select_variant(&first_variant_id) + .expect("select_variant should succeed"); + assert_eq!( + model.id(), + first_variant_id, + "After select_variant, id() should match the selected variant" + ); +} + +#[tokio::test] +async fn should_fail_to_select_unknown_variant() { + let manager = common::get_test_manager(); + let mut model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let result = model.select_variant("nonexistent-variant-id"); + assert!( + result.is_err(), + "select_variant with unknown ID should fail" + ); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("not found"), + "Error should mention 'not found': {err_msg}" + ); +} + +// ── Load manager (core interop) ────────────────────────────────────────────── + +async fn get_test_model() -> foundry_local_sdk::Model { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed") +} + +#[tokio::test] +async fn should_load_model_using_core_interop() { + let model = get_test_model().await; + model.load().await.expect("model.load() failed"); + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_unload_model_using_core_interop() { + let model = get_test_model().await; + model.load().await.expect("model.load() failed"); + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_list_loaded_models_using_core_interop() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + + let loaded = catalog + .get_loaded_models() + .await + .expect("catalog.get_loaded_models() failed"); + + let _ = loaded; +} + +#[tokio::test] +#[ignore = "requires running web service"] +async fn should_load_and_unload_model_using_external_service() { + if common::is_running_in_ci() { + eprintln!("Skipping external-service test in CI"); + return; + } + + let manager = common::get_test_manager(); + let model = get_test_model().await; + + let _urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + + model + .load() + .await + .expect("load via external service failed"); + + model + .unload() + .await + .expect("unload via external service failed"); +} + +#[tokio::test] +#[ignore = "requires running web service"] +async fn should_list_loaded_models_using_external_service() { + if common::is_running_in_ci() { + eprintln!("Skipping external-service test in CI"); + return; + } + + let manager = common::get_test_manager(); + + let _urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + + let catalog = manager.catalog(); + let loaded = catalog + .get_loaded_models() + .await + .expect("get_loaded_models via external service failed"); + + let _ = loaded; +} diff --git a/sdk_v2/rust/tests/integration/web_service_test.rs b/sdk_v2/rust/tests/integration/web_service_test.rs new file mode 100644 index 00000000..cd9ccfce --- /dev/null +++ b/sdk_v2/rust/tests/integration/web_service_test.rs @@ -0,0 +1,163 @@ +use super::common; +use serde_json::json; + +/// Start the web service, make a non-streaming POST to v1/chat/completions, +/// verify we get a valid response, then stop the service. +#[tokio::test] +async fn should_complete_chat_via_rest_api() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + model.load().await.expect("model.load() failed"); + + let urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + let base_url = urls.first().expect("no URL returned").trim_end_matches('/'); + + let client = reqwest::Client::new(); + let resp = client + .post(format!("{base_url}/v1/chat/completions")) + .json(&json!({ + "model": model.id(), + "messages": [ + { "role": "system", "content": "You are a helpful math assistant. Respond with just the answer." }, + { "role": "user", "content": "What is 7*6?" } + ], + "max_tokens": 500, + "temperature": 0.0, + "stream": false + })) + .send() + .await + .expect("HTTP request failed"); + + assert!( + resp.status().is_success(), + "Expected 2xx, got {}", + resp.status() + ); + + let body: serde_json::Value = resp.json().await.expect("failed to parse response JSON"); + let content = body + .pointer("/choices/0/message/content") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + println!("REST response: {content}"); + + assert!( + content.contains("42"), + "Expected response to contain '42', got: {content}" + ); + + manager + .stop_web_service() + .await + .expect("stop_web_service failed"); + model.unload().await.expect("model.unload() failed"); +} + +/// Start the web service, make a streaming POST to v1/chat/completions, +/// collect SSE chunks, verify we get a valid streamed response. +#[tokio::test] +async fn should_stream_chat_via_rest_api() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + model.load().await.expect("model.load() failed"); + + let urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + let base_url = urls.first().expect("no URL returned").trim_end_matches('/'); + + let client = reqwest::Client::new(); + let mut response = client + .post(format!("{base_url}/v1/chat/completions")) + .json(&json!({ + "model": model.id(), + "messages": [ + { "role": "system", "content": "You are a helpful math assistant. Respond with just the answer." }, + { "role": "user", "content": "What is 7*6?" } + ], + "max_tokens": 500, + "temperature": 0.0, + "stream": true + })) + .send() + .await + .expect("HTTP request failed"); + + assert!( + response.status().is_success(), + "Expected 2xx, got {}", + response.status() + ); + + let mut full_text = String::new(); + while let Some(chunk) = response.chunk().await.expect("chunk read failed") { + let text = String::from_utf8_lossy(&chunk); + for line in text.lines() { + let line = line.trim(); + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + break; + } + if let Ok(parsed) = serde_json::from_str::(data) { + if let Some(content) = parsed + .pointer("/choices/0/delta/content") + .and_then(|v| v.as_str()) + { + full_text.push_str(content); + } + } + } + } + } + + println!("REST streamed response: {full_text}"); + + assert!( + full_text.contains("42"), + "Expected streamed response to contain '42', got: {full_text}" + ); + + manager + .stop_web_service() + .await + .expect("stop_web_service failed"); + model.unload().await.expect("model.unload() failed"); +} + +/// urls() should return the listening addresses after start_web_service. +#[tokio::test] +async fn should_expose_urls_after_start() { + let manager = common::get_test_manager(); + + let urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + println!("Web service URLs: {urls:?}"); + assert!(!urls.is_empty(), "start_web_service should return URLs"); + + let cached_urls = manager.urls(); + assert_eq!( + urls, cached_urls, + "urls() should match what start_web_service returned" + ); + + manager + .stop_web_service() + .await + .expect("stop_web_service failed"); +} From 93ae47a07b3427929c2545883749b6f9b5e75f26 Mon Sep 17 00:00:00 2001 From: samkemp Date: Thu, 12 Mar 2026 14:48:09 +0000 Subject: [PATCH 21/25] Address PR #500 review feedback (37 threads) and fix WinAppSDK bootstrapping Resolve all 37 review comments from nenad1002, prathikr, and copilot-bot: Safety & correctness: - Fix Windows buffer deallocation: CoTaskMemFree -> LocalFree (matches FreeHGlobal) - Wrap FFI streaming callback in catch_unwind to prevent UB from panics - Fix singleton TOCTOU race with Once + OnceLock (stable Rust compatible) - Replace all .unwrap() on mutex locks with FoundryLocalError::Internal API design: - Wrap Model/ModelVariant in Arc for efficient sharing (no full clones) - Change ResponseBuffer length to u32, move by value (consume-once) - Return PathBuf for file system paths, remove unused struct fields - Public APIs return Result instead of swallowing errors - Define only platform extension, build native lib filename dynamically - Reduce resolve_library_path to 3 strategies (config, build.rs, exe-sibling) CI/workflow: - Weave Rust jobs into platform-specific blocks (match cs/js format) - Remove ubuntu (known issues), enable tests on all platforms - Fix cargo test command (--tests, --include-ignored), fix boolean default - Add comments for clippy and --allow-dirty Tests: - Split monolithic integration.rs into 6 per-feature test modules - Remove duplicate invalid-callback test - Use non-default temperature (0.5) in with_temperature tests Samples: - Fix {progress:.1}% -> {progress}% in all 4 samples (string truncation bug) Consistency & performance: - Remove stream:true from audio transcription (match JS/C# SDKs) - Isolate unsafe code with SAFETY comments throughout core_interop - Reuse reqwest::Client in ModelLoadManager (connection pooling) - Fix libs_already_present to check specific core library file Bug fix (not in review): - Auto-detect WinAppSDK Bootstrap DLL and set Bootstrap=true in config params before native initialize (matches JS SDK behavior, required for WinML/OpenVINO execution providers) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .github/workflows/build-rust-steps.yml | 8 +- .github/workflows/foundry-local-sdk-build.yml | 35 ++-- .../audio-transcription-example/src/main.rs | 2 +- .../rust/foundry-local-webserver/src/main.rs | 2 +- .../rust/native-chat-completions/src/main.rs | 2 +- .../tool-calling-foundry-local/src/main.rs | 2 +- sdk_v2/rust/build.rs | 24 +-- sdk_v2/rust/src/catalog.rs | 58 ++++-- sdk_v2/rust/src/detail/core_interop.rs | 173 ++++++++++++------ sdk_v2/rust/src/detail/model_load_manager.rs | 18 +- sdk_v2/rust/src/error.rs | 3 + sdk_v2/rust/src/foundry_local_manager.rs | 95 +++++----- sdk_v2/rust/src/model.rs | 9 +- sdk_v2/rust/src/model_variant.rs | 11 +- sdk_v2/rust/src/openai/audio_client.rs | 5 +- .../tests/integration/audio_client_test.rs | 3 +- .../tests/integration/chat_client_test.rs | 3 +- sdk_v2/rust/tests/integration/model_test.rs | 17 +- .../tests/integration/web_service_test.rs | 2 +- 19 files changed, 280 insertions(+), 192 deletions(-) diff --git a/.github/workflows/build-rust-steps.yml b/.github/workflows/build-rust-steps.yml index b675b4b6..ef4a349f 100644 --- a/.github/workflows/build-rust-steps.yml +++ b/.github/workflows/build-rust-steps.yml @@ -14,7 +14,7 @@ on: run-integration-tests: required: false type: boolean - default: false + default: true permissions: contents: read @@ -28,7 +28,7 @@ jobs: working-directory: sdk_v2/rust env: - CARGO_FEATURES: ${{ inputs.useWinML && '--features winml' || '' }} + CARGO_FEATURES: ${{ inputs.useWinML && '--features winml' || false }} steps: - name: Checkout repository @@ -80,6 +80,7 @@ jobs: - name: Check formatting run: cargo fmt --all -- --check + # Run Clippy - Rust's official linter for catching common mistakes, enforcing idioms, and improving code quality - name: Run clippy run: cargo clippy --all-targets ${{ env.CARGO_FEATURES }} -- -D warnings @@ -91,8 +92,9 @@ jobs: - name: Run integration tests if: ${{ inputs.run-integration-tests }} - run: cargo test --test '*' ${{ env.CARGO_FEATURES }} -- --test-threads=1 --nocapture + run: cargo test --tests ${{ env.CARGO_FEATURES }} -- --include-ignored --test-threads=1 --nocapture + # --allow-dirty allows publishing with uncommitted changes, needed because the build process modifies generated files - name: Package crate run: cargo package ${{ env.CARGO_FEATURES }} --allow-dirty diff --git a/.github/workflows/foundry-local-sdk-build.yml b/.github/workflows/foundry-local-sdk-build.yml index e38fc251..604cb2d3 100644 --- a/.github/workflows/foundry-local-sdk-build.yml +++ b/.github/workflows/foundry-local-sdk-build.yml @@ -29,6 +29,12 @@ jobs: version: '0.9.0.${{ github.run_number }}' platform: 'windows' secrets: inherit + build-rust-windows: + uses: ./.github/workflows/build-rust-steps.yml + with: + platform: 'windows' + run-integration-tests: true + secrets: inherit build-cs-windows-WinML: uses: ./.github/workflows/build-cs-steps.yml @@ -44,7 +50,14 @@ jobs: platform: 'windows' useWinML: true secrets: inherit - + build-rust-windows-WinML: + uses: ./.github/workflows/build-rust-steps.yml + with: + platform: 'windows' + useWinML: true + run-integration-tests: true + secrets: inherit + build-cs-macos: uses: ./.github/workflows/build-cs-steps.yml with: @@ -57,26 +70,6 @@ jobs: version: '0.9.0.${{ github.run_number }}' platform: 'macos' secrets: inherit - - build-rust-windows: - uses: ./.github/workflows/build-rust-steps.yml - with: - platform: 'windows' - run-integration-tests: true - secrets: inherit - build-rust-windows-WinML: - uses: ./.github/workflows/build-rust-steps.yml - with: - platform: 'windows' - useWinML: true - run-integration-tests: true - secrets: inherit - build-rust-ubuntu: - uses: ./.github/workflows/build-rust-steps.yml - with: - platform: 'ubuntu' - run-integration-tests: true - secrets: inherit build-rust-macos: uses: ./.github/workflows/build-rust-steps.yml with: diff --git a/samples/rust/audio-transcription-example/src/main.rs b/samples/rust/audio-transcription-example/src/main.rs index 9dc64dc0..bd2141c1 100644 --- a/samples/rust/audio-transcription-example/src/main.rs +++ b/samples/rust/audio-transcription-example/src/main.rs @@ -31,7 +31,7 @@ async fn main() -> Result<(), Box> { println!("Downloading model..."); model .download(Some(|progress: &str| { - print!("\r {progress:.1}%"); + print!("\r {progress}%"); io::stdout().flush().ok(); })) .await?; diff --git a/samples/rust/foundry-local-webserver/src/main.rs b/samples/rust/foundry-local-webserver/src/main.rs index a3b5f326..e5ed3ae8 100644 --- a/samples/rust/foundry-local-webserver/src/main.rs +++ b/samples/rust/foundry-local-webserver/src/main.rs @@ -29,7 +29,7 @@ async fn main() -> Result<(), Box> { print!("Downloading model {model_alias}..."); model .download(Some(move |progress: &str| { - print!("\rDownloading model... {progress:.1}%"); + print!("\rDownloading model... {progress}%"); io::stdout().flush().ok(); })) .await?; diff --git a/samples/rust/native-chat-completions/src/main.rs b/samples/rust/native-chat-completions/src/main.rs index 61df92d7..68dec925 100644 --- a/samples/rust/native-chat-completions/src/main.rs +++ b/samples/rust/native-chat-completions/src/main.rs @@ -27,7 +27,7 @@ async fn main() -> Result<(), Box> { println!("Downloading model..."); model .download(Some(|progress: &str| { - print!("\r {progress:.1}%"); + print!("\r {progress}%"); io::stdout().flush().ok(); })) .await?; diff --git a/samples/rust/tool-calling-foundry-local/src/main.rs b/samples/rust/tool-calling-foundry-local/src/main.rs index 53ada9c5..5a70a2e4 100644 --- a/samples/rust/tool-calling-foundry-local/src/main.rs +++ b/samples/rust/tool-calling-foundry-local/src/main.rs @@ -59,7 +59,7 @@ async fn main() -> Result<(), Box> { println!("Downloading model..."); model .download(Some(|progress: &str| { - print!("\r {progress:.1}%"); + print!("\r {progress}%"); io::stdout().flush().ok(); })) .await?; diff --git a/sdk_v2/rust/build.rs b/sdk_v2/rust/build.rs index c12f6712..e2365bc3 100644 --- a/sdk_v2/rust/build.rs +++ b/sdk_v2/rust/build.rs @@ -241,19 +241,15 @@ fn download_and_extract(pkg: &NuGetPackage, rid: &str, out_dir: &Path) -> Result Ok(()) } -/// Check whether we already have at least one native library in `out_dir`. +/// Check whether the core native library is already present in `out_dir`. fn libs_already_present(out_dir: &Path) -> bool { - let ext = native_lib_extension(); - if let Ok(entries) = fs::read_dir(out_dir) { - for entry in entries.flatten() { - if let Some(name) = entry.file_name().to_str() { - if name.ends_with(&format!(".{ext}")) { - return true; - } - } - } - } - false + let core_lib = match env::consts::OS { + "windows" => "foundry_local_core.dll", + "linux" => "libfoundry_local_core.so", + "macos" => "libfoundry_local_core.dylib", + _ => return false, + }; + out_dir.join(core_lib).exists() } fn main() { @@ -298,7 +294,7 @@ fn main() { println!("cargo:rustc-link-search=native={}", out_dir.display()); println!("cargo:rustc-env=FOUNDRY_NATIVE_DIR={}", out_dir.display()); - // CoTaskMemFree (used to free native-allocated buffers) lives in ole32.lib on Windows. + // LocalFree (used to free native-allocated buffers) lives in kernel32.lib on Windows. #[cfg(windows)] - println!("cargo:rustc-link-lib=ole32"); + println!("cargo:rustc-link-lib=kernel32"); } diff --git a/sdk_v2/rust/src/catalog.rs b/sdk_v2/rust/src/catalog.rs index 20967fbe..e1338fde 100644 --- a/sdk_v2/rust/src/catalog.rs +++ b/sdk_v2/rust/src/catalog.rs @@ -19,8 +19,8 @@ pub struct Catalog { core: Arc, model_load_manager: Arc, name: String, - models_by_alias: Mutex>, - variants_by_id: Mutex>, + models_by_alias: Mutex>>, + variants_by_id: Mutex>>, last_refresh: Mutex>, } @@ -55,7 +55,9 @@ impl Catalog { /// Refresh the catalog from the native core if the cache has expired. pub async fn update_models(&self) -> Result<()> { { - let last = self.last_refresh.lock().unwrap(); + let last = self.last_refresh.lock().map_err(|_| FoundryLocalError::Internal { + reason: "last_refresh mutex poisoned".into(), + })?; if let Some(ts) = *last { if ts.elapsed() < CACHE_TTL { return Ok(()); @@ -67,21 +69,25 @@ impl Catalog { } /// Return all known models keyed by alias. - pub async fn get_models(&self) -> Result> { + pub async fn get_models(&self) -> Result>> { self.update_models().await?; - let map = self.models_by_alias.lock().unwrap(); + let map = self.models_by_alias.lock().map_err(|_| FoundryLocalError::Internal { + reason: "models_by_alias mutex poisoned".into(), + })?; Ok(map.values().cloned().collect()) } /// Look up a model by its alias. - pub async fn get_model(&self, alias: &str) -> Result { + pub async fn get_model(&self, alias: &str) -> Result> { if alias.trim().is_empty() { return Err(FoundryLocalError::Validation { reason: "Model alias must be a non-empty string".into(), }); } self.update_models().await?; - let map = self.models_by_alias.lock().unwrap(); + let map = self.models_by_alias.lock().map_err(|_| FoundryLocalError::Internal { + reason: "models_by_alias mutex poisoned".into(), + })?; map.get(alias).cloned().ok_or_else(|| { let available: Vec<&String> = map.keys().collect(); FoundryLocalError::ModelOperation { @@ -91,14 +97,16 @@ impl Catalog { } /// Look up a specific model variant by its unique id. - pub async fn get_model_variant(&self, id: &str) -> Result { + pub async fn get_model_variant(&self, id: &str) -> Result> { if id.trim().is_empty() { return Err(FoundryLocalError::Validation { reason: "Variant id must be a non-empty string".into(), }); } self.update_models().await?; - let map = self.variants_by_id.lock().unwrap(); + let map = self.variants_by_id.lock().map_err(|_| FoundryLocalError::Internal { + reason: "variants_by_id mutex poisoned".into(), + })?; map.get(id).cloned().ok_or_else(|| { let available: Vec<&String> = map.keys().collect(); FoundryLocalError::ModelOperation { @@ -111,7 +119,7 @@ impl Catalog { /// /// The native core returns a list of variant IDs. This method resolves /// them against the internal cache, matching the JS SDK behaviour. - pub async fn get_cached_models(&self) -> Result> { + pub async fn get_cached_models(&self) -> Result>> { self.update_models().await?; let raw = self .core @@ -121,7 +129,9 @@ impl Catalog { return Ok(Vec::new()); } let cached_ids: Vec = serde_json::from_str(&raw)?; - let id_map = self.variants_by_id.lock().unwrap(); + let id_map = self.variants_by_id.lock().map_err(|_| FoundryLocalError::Internal { + reason: "variants_by_id mutex poisoned".into(), + })?; Ok(cached_ids .iter() .filter_map(|id| id_map.get(id).cloned()) @@ -155,8 +165,8 @@ impl Catalog { serde_json::from_str(raw)? }; - let mut alias_map: HashMap = HashMap::new(); - let mut id_map: HashMap = HashMap::new(); + let mut alias_map_build: HashMap = HashMap::new(); + let mut id_map: HashMap> = HashMap::new(); for info in infos { let variant = ModelVariant::new( @@ -164,23 +174,33 @@ impl Catalog { Arc::clone(&self.core), Arc::clone(&self.model_load_manager), ); - id_map.insert(info.id.clone(), variant.clone()); + id_map.insert(info.id.clone(), Arc::new(variant.clone())); - alias_map + alias_map_build .entry(info.alias.clone()) .or_insert_with(|| { Model::new( info.alias.clone(), Arc::clone(&self.core), - Arc::clone(&self.model_load_manager), ) }) .add_variant(variant); } - *self.models_by_alias.lock().unwrap() = alias_map; - *self.variants_by_id.lock().unwrap() = id_map; - *self.last_refresh.lock().unwrap() = Some(Instant::now()); + let alias_map: HashMap> = alias_map_build + .into_iter() + .map(|(k, v)| (k, Arc::new(v))) + .collect(); + + *self.models_by_alias.lock().map_err(|_| FoundryLocalError::Internal { + reason: "models_by_alias mutex poisoned".into(), + })? = alias_map; + *self.variants_by_id.lock().map_err(|_| FoundryLocalError::Internal { + reason: "variants_by_id mutex poisoned".into(), + })? = id_map; + *self.last_refresh.lock().map_err(|_| FoundryLocalError::Internal { + reason: "last_refresh mutex poisoned".into(), + })? = Some(Instant::now()); Ok(()) } diff --git a/sdk_v2/rust/src/detail/core_interop.rs b/sdk_v2/rust/src/detail/core_interop.rs index 08430823..965c1d78 100644 --- a/sdk_v2/rust/src/detail/core_interop.rs +++ b/sdk_v2/rust/src/detail/core_interop.rs @@ -32,9 +32,9 @@ struct RequestBuffer { #[repr(C)] struct ResponseBuffer { data: *mut u8, - data_length: i32, + data_length: u32, error: *mut u8, - error_length: i32, + error_length: u32, } impl ResponseBuffer { @@ -65,11 +65,11 @@ type ExecuteCommandWithCallbackFn = unsafe extern "C" fn( // ── Library name helpers ───────────────────────────────────────────────────── #[cfg(target_os = "windows")] -const CORE_LIB_NAME: &str = "Microsoft.AI.Foundry.Local.Core.dll"; +const LIB_EXTENSION: &str = "dll"; #[cfg(target_os = "macos")] -const CORE_LIB_NAME: &str = "Microsoft.AI.Foundry.Local.Core.dylib"; +const LIB_EXTENSION: &str = "dylib"; #[cfg(target_os = "linux")] -const CORE_LIB_NAME: &str = "Microsoft.AI.Foundry.Local.Core.so"; +const LIB_EXTENSION: &str = "so"; // ── Native buffer deallocation ──────────────────────────────────────────────── @@ -77,7 +77,14 @@ const CORE_LIB_NAME: &str = "Microsoft.AI.Foundry.Local.Core.so"; /// /// The .NET native core allocates response buffers with /// `Marshal.AllocHGlobal` which maps to `malloc` on Unix and -/// `CoTaskMemAlloc` on Windows. +/// `LocalAlloc` (process heap) on Windows. The corresponding +/// free functions are `free` and `LocalFree` respectively. +/// +/// # Safety +/// +/// `ptr` must be null or a valid pointer previously allocated by the native +/// core library via the corresponding platform allocator. Calling this with +/// any other pointer is undefined behaviour. unsafe fn free_native_buffer(ptr: *mut u8) { if ptr.is_null() { return; @@ -87,14 +94,17 @@ unsafe fn free_native_buffer(ptr: *mut u8) { extern "C" { fn free(ptr: *mut std::ffi::c_void); } + // SAFETY: `ptr` was allocated by the native core via `malloc` on Unix. free(ptr as *mut std::ffi::c_void); } #[cfg(windows)] { extern "system" { - fn CoTaskMemFree(pv: *mut std::ffi::c_void); + fn LocalFree(hMem: *mut std::ffi::c_void) -> *mut std::ffi::c_void; } - CoTaskMemFree(ptr as *mut std::ffi::c_void); + // SAFETY: `ptr` was allocated by the native core via `LocalAlloc` + // (Marshal.AllocHGlobal) on Windows. + LocalFree(ptr as *mut std::ffi::c_void); } } @@ -102,6 +112,18 @@ unsafe fn free_native_buffer(ptr: *mut u8) { /// C-ABI trampoline that forwards chunks from the native library into a Rust /// closure stored behind `user_data`. +/// +/// Wrapped in [`std::panic::catch_unwind`] so that a panic inside the Rust +/// closure cannot unwind across the FFI boundary (which is undefined behaviour). +/// +/// # Safety +/// +/// * `data` must be a valid pointer to `length` bytes of UTF-8 (or at least +/// valid memory) allocated by the native core, valid for the duration of +/// this call. +/// * `user_data` must point to a live `Box>` that was +/// created by [`CoreInterop::execute_command_streaming`] and has not been +/// dropped. unsafe extern "C" fn streaming_trampoline( data: *const u8, length: i32, @@ -110,11 +132,19 @@ unsafe extern "C" fn streaming_trampoline( if data.is_null() || length <= 0 { return; } - let closure = &mut *(user_data as *mut Box); - let slice = std::slice::from_raw_parts(data, length as usize); - if let Ok(chunk) = std::str::from_utf8(slice) { - closure(chunk); - } + // catch_unwind prevents UB if the closure panics across the FFI boundary. + let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + // SAFETY: `user_data` is a pointer to `Box` kept alive + // by the caller of `execute_command_with_callback` for the duration of + // the native call. + let closure = &mut *(user_data as *mut Box); + // SAFETY: `data` is valid for `length` bytes as guaranteed by the native + // core's callback contract. + let slice = std::slice::from_raw_parts(data, length as usize); + if let Ok(chunk) = std::str::from_utf8(slice) { + closure(chunk); + } + })); } // ── CoreInterop ────────────────────────────────────────────────────────────── @@ -148,14 +178,28 @@ impl CoreInterop { /// /// Discovery order: /// 1. `FoundryLocalCorePath` key in `config.params`. - /// 2. `FOUNDRY_NATIVE_DIR` environment variable. - /// 3. Sibling directory of the current executable. - pub fn new(config: &Configuration) -> Result { + /// 2. Sibling directory of the current executable. + pub fn new(config: &mut Configuration) -> Result { let lib_path = Self::resolve_library_path(config)?; + // Auto-detect WinAppSDK Bootstrap DLL next to the core library. + // If present, tell the native core to run the bootstrapper during + // initialisation — this is required for WinML execution providers. + #[cfg(target_os = "windows")] + if !config.params.contains_key("Bootstrap") { + if let Some(dir) = lib_path.parent() { + if dir.join("Microsoft.WindowsAppRuntime.Bootstrap.dll").exists() { + config.params.insert("Bootstrap".into(), "true".into()); + } + } + } + #[cfg(target_os = "windows")] let _dependency_libs = Self::load_windows_dependencies(&lib_path)?; + // SAFETY: `lib_path` has been verified to exist on disk. Loading a + // shared library is inherently unsafe (it executes foreign code), but + // the path is resolved from trusted configuration sources. let library = unsafe { Library::new(&lib_path).map_err(|e| FoundryLocalError::LibraryLoad { reason: format!( @@ -165,6 +209,8 @@ impl CoreInterop { })? }; + // SAFETY: We trust the loaded library to export these symbols with the + // correct C-ABI signatures as defined by the Foundry Local native core. let execute_command: ExecuteCommandFn = unsafe { let sym: Symbol = library @@ -175,6 +221,7 @@ impl CoreInterop { *sym }; + // SAFETY: Same as above — symbol must match `ExecuteCommandWithCallbackFn`. let execute_command_with_callback: ExecuteCommandWithCallbackFn = unsafe { let sym: Symbol = library .get(b"execute_command_with_callback\0") @@ -221,11 +268,14 @@ impl CoreInterop { let mut response = ResponseBuffer::new(); + // SAFETY: `request` fields point into `cmd` and `data_cstr` which are + // alive for the duration of this call. The native function writes into + // `response` using its documented C ABI. unsafe { (self.execute_command)(&request, &mut response); } - Self::process_response(&response) + Self::process_response(response) } /// Execute a command that streams results back via `callback`. @@ -268,6 +318,11 @@ impl CoreInterop { let mut boxed: Box = Box::new(|chunk: &str| callback(chunk)); let user_data = &mut boxed as *mut Box as *mut std::ffi::c_void; + // SAFETY: `request` fields point into `cmd` and `data_cstr` which are + // alive for the duration of this call. `user_data` points to `boxed` + // which lives on this stack frame and outlives the native call. + // `streaming_trampoline` will only cast `user_data` back to the same + // `Box` type. unsafe { (self.execute_command_with_callback)( &request, @@ -277,7 +332,7 @@ impl CoreInterop { ); } - Self::process_response(&response) + Self::process_response(response) } /// Async version of [`Self::execute_command`]. @@ -348,32 +403,36 @@ impl CoreInterop { Ok((rx, handle)) } + /// Read a native response buffer field as a Rust `String`. + /// + /// # Safety + /// + /// `ptr` must be null **or** a valid pointer to at least `len` bytes of + /// memory allocated by the native core. The memory must remain valid for + /// the duration of this call. + unsafe fn read_native_buffer(ptr: *mut u8, len: u32) -> Option { + if ptr.is_null() || len == 0 { + return None; + } + // SAFETY: caller guarantees `ptr` is valid for `len` bytes. + let slice = std::slice::from_raw_parts(ptr, len as usize); + Some(String::from_utf8_lossy(slice).into_owned()) + } + /// Read the response buffer, free the native memory, and return the data /// string or raise an error. - fn process_response(response: &ResponseBuffer) -> Result { - // Extract strings from the native pointers before freeing them. - let error_str = if !response.error.is_null() && response.error_length > 0 { - Some(unsafe { - let slice = - std::slice::from_raw_parts(response.error, response.error_length as usize); - String::from_utf8_lossy(slice).into_owned() - }) - } else { - None - }; - - let data_str = if !response.data.is_null() && response.data_length > 0 { - Some(unsafe { - let slice = - std::slice::from_raw_parts(response.data, response.data_length as usize); - String::from_utf8_lossy(slice).into_owned() - }) - } else { - None - }; - - // Free the heap-allocated response buffers (matches JS koffi.free() - // and C# Marshal.FreeHGlobal() behaviour). + /// + /// Takes the buffer by value so it can only be consumed once. + fn process_response(response: ResponseBuffer) -> Result { + // SAFETY: response fields are either null or valid native-allocated + // pointers filled by the preceding FFI call. + let error_str = unsafe { Self::read_native_buffer(response.error, response.error_length) }; + let data_str = unsafe { Self::read_native_buffer(response.data, response.data_length) }; + + // SAFETY: Free the heap-allocated response buffers (matches JS + // koffi.free() and C# Marshal.FreeHGlobal() behaviour). Each pointer + // is either null (handled inside free_native_buffer) or was allocated + // by the native core's platform allocator. unsafe { free_native_buffer(response.data); free_native_buffer(response.error); @@ -389,9 +448,11 @@ impl CoreInterop { /// Resolve the path to the native core shared library. fn resolve_library_path(config: &Configuration) -> Result { + let lib_name = format!("Microsoft.AI.Foundry.Local.Core.{LIB_EXTENSION}"); + // 1. Explicit path from configuration. if let Some(dir) = config.params.get("FoundryLocalCorePath") { - let p = Path::new(dir).join(CORE_LIB_NAME); + let p = Path::new(dir).join(&lib_name); if p.exists() { return Ok(p); } @@ -402,26 +463,19 @@ impl CoreInterop { } } - // 2. Compile-time environment variable set by build.rs. + // 2. Compile-time path set by build.rs (points at the OUT_DIR where + // native NuGet packages are extracted during `cargo build`). if let Some(dir) = option_env!("FOUNDRY_NATIVE_DIR") { - let p = Path::new(dir).join(CORE_LIB_NAME); - if p.exists() { - return Ok(p); - } - } - - // 3. Runtime environment variable (user override). - if let Ok(dir) = std::env::var("FOUNDRY_NATIVE_DIR") { - let p = Path::new(&dir).join(CORE_LIB_NAME); + let p = Path::new(dir).join(&lib_name); if p.exists() { return Ok(p); } } - // 4. Next to the running executable. + // 3. Next to the running executable (default search path). if let Ok(exe) = std::env::current_exe() { if let Some(dir) = exe.parent() { - let p = dir.join(CORE_LIB_NAME); + let p = dir.join(&lib_name); if p.exists() { return Ok(p); } @@ -430,9 +484,8 @@ impl CoreInterop { Err(FoundryLocalError::LibraryLoad { reason: format!( - "Could not locate native library '{CORE_LIB_NAME}'. \ - Set the FoundryLocalCorePath config option or the FOUNDRY_NATIVE_DIR \ - environment variable." + "Could not locate native library '{lib_name}'. \ + Set the FoundryLocalCorePath config option." ), }) } @@ -448,6 +501,8 @@ impl CoreInterop { // Load WinML bootstrap if present. let bootstrap = dir.join("Microsoft.WindowsAppRuntime.Bootstrap.dll"); if bootstrap.exists() { + // SAFETY: Pre-loading a known dependency DLL from the same trusted + // directory as the core library. if let Ok(lib) = unsafe { Library::new(&bootstrap) } { libs.push(lib); } @@ -456,6 +511,8 @@ impl CoreInterop { for dep in &["onnxruntime.dll", "onnxruntime-genai.dll"] { let dep_path = dir.join(dep); if dep_path.exists() { + // SAFETY: Pre-loading a known dependency DLL from the same + // trusted directory as the core library. let lib = unsafe { Library::new(&dep_path).map_err(|e| FoundryLocalError::LibraryLoad { reason: format!("Failed to load dependency {dep}: {e}"), diff --git a/sdk_v2/rust/src/detail/model_load_manager.rs b/sdk_v2/rust/src/detail/model_load_manager.rs index 639ec691..f6f05fd1 100644 --- a/sdk_v2/rust/src/detail/model_load_manager.rs +++ b/sdk_v2/rust/src/detail/model_load_manager.rs @@ -16,6 +16,7 @@ use crate::error::Result; pub struct ModelLoadManager { core: Arc, external_service_url: Option, + client: reqwest::Client, } impl ModelLoadManager { @@ -23,24 +24,27 @@ impl ModelLoadManager { Self { core, external_service_url, + client: reqwest::Client::new(), } } /// Load a model by its identifier. - pub async fn load(&self, model_id: &str) -> Result { + pub async fn load(&self, model_id: &str) -> Result<()> { if let Some(base_url) = &self.external_service_url { - return Self::http_get(&format!("{base_url}/models/load/{model_id}")).await; + self.http_get(&format!("{base_url}/models/load/{model_id}")).await?; + return Ok(()); } let params = json!({ "Params": { "Model": model_id } }); self.core .execute_command_async("load_model".into(), Some(params)) - .await + .await?; + Ok(()) } /// Unload a previously loaded model. pub async fn unload(&self, model_id: &str) -> Result { if let Some(base_url) = &self.external_service_url { - return Self::http_get(&format!("{base_url}/models/unload/{model_id}")).await; + return self.http_get(&format!("{base_url}/models/unload/{model_id}")).await; } let params = json!({ "Params": { "Model": model_id } }); self.core @@ -51,7 +55,7 @@ impl ModelLoadManager { /// Return the list of currently loaded model identifiers. pub async fn list_loaded(&self) -> Result> { let raw = if let Some(base_url) = &self.external_service_url { - Self::http_get(&format!("{base_url}/models/loaded")).await? + self.http_get(&format!("{base_url}/models/loaded")).await? } else { self.core .execute_command_async("list_loaded_models".into(), None) @@ -66,8 +70,8 @@ impl ModelLoadManager { Ok(ids) } - async fn http_get(url: &str) -> Result { - let body = reqwest::get(url).await?.text().await?; + async fn http_get(&self, url: &str) -> Result { + let body = self.client.get(url).send().await?.text().await?; Ok(body) } } diff --git a/sdk_v2/rust/src/error.rs b/sdk_v2/rust/src/error.rs index 226139b1..c99dbfbf 100644 --- a/sdk_v2/rust/src/error.rs +++ b/sdk_v2/rust/src/error.rs @@ -27,6 +27,9 @@ pub enum FoundryLocalError { /// An I/O error occurred. #[error("I/O error: {0}")] Io(#[from] std::io::Error), + /// An internal SDK error (e.g. poisoned lock). + #[error("internal error: {reason}")] + Internal { reason: String }, } /// Convenience alias used throughout the SDK. diff --git a/sdk_v2/rust/src/foundry_local_manager.rs b/sdk_v2/rust/src/foundry_local_manager.rs index 0ec10cac..a87733b5 100644 --- a/sdk_v2/rust/src/foundry_local_manager.rs +++ b/sdk_v2/rust/src/foundry_local_manager.rs @@ -4,7 +4,7 @@ //! library, provides access to the model [`Catalog`], and can start / stop //! the local web service. -use std::sync::{Arc, OnceLock}; +use std::sync::{Arc, Mutex, OnceLock, Once}; use serde_json::json; @@ -12,21 +12,20 @@ use crate::catalog::Catalog; use crate::configuration::{Configuration, FoundryLocalConfig}; use crate::detail::core_interop::CoreInterop; use crate::detail::ModelLoadManager; -use crate::error::Result; +use crate::error::{FoundryLocalError, Result}; /// Global singleton holder. -static INSTANCE: OnceLock = OnceLock::new(); +static INSTANCE: OnceLock> = OnceLock::new(); +static INIT_ONCE: Once = Once::new(); /// Primary entry point for interacting with Foundry Local. /// /// Created once via [`FoundryLocalManager::create`]; subsequent calls return /// the existing instance. pub struct FoundryLocalManager { - _config: Configuration, core: Arc, - _model_load_manager: Arc, catalog: Catalog, - urls: std::sync::Mutex>, + urls: Mutex>, } impl FoundryLocalManager { @@ -37,37 +36,42 @@ impl FoundryLocalManager { /// calls return a reference to the same instance (the provided config is /// ignored after the first call). pub fn create(config: FoundryLocalConfig) -> Result<&'static Self> { - // If already initialised, return the existing instance. - if let Some(mgr) = INSTANCE.get() { - return Ok(mgr); - } - - let internal_config = Configuration::new(config)?; - let core = Arc::new(CoreInterop::new(&internal_config)?); - - // Send the configuration map to the native core. - let init_params = json!({ "Params": internal_config.params }); - core.execute_command("initialize", Some(&init_params))?; - - let service_endpoint = internal_config.params.get("WebServiceExternalUrl").cloned(); - - let model_load_manager = - Arc::new(ModelLoadManager::new(Arc::clone(&core), service_endpoint)); - - let catalog = Catalog::new(Arc::clone(&core), Arc::clone(&model_load_manager))?; - - let manager = Self { - _config: internal_config, - core, - _model_load_manager: model_load_manager, - catalog, - urls: std::sync::Mutex::new(Vec::new()), - }; - - // Attempt to store; if another thread raced us, return whichever won. - match INSTANCE.set(manager) { - Ok(()) => Ok(INSTANCE.get().unwrap()), - Err(_) => Ok(INSTANCE.get().unwrap()), + // Use `Once` + `OnceLock` to ensure initialisation runs at most once, + // eliminating the TOCTOU race between `get()` and `set()`. + INIT_ONCE.call_once(|| { + let result = (|| -> Result { + let mut internal_config = Configuration::new(config)?; + let core = Arc::new(CoreInterop::new(&mut internal_config)?); + + // Send the configuration map to the native core. + let init_params = json!({ "Params": internal_config.params }); + core.execute_command("initialize", Some(&init_params))?; + + let service_endpoint = internal_config.params.get("WebServiceExternalUrl").cloned(); + + let model_load_manager = + Arc::new(ModelLoadManager::new(Arc::clone(&core), service_endpoint)); + + let catalog = Catalog::new(Arc::clone(&core), Arc::clone(&model_load_manager))?; + + Ok(FoundryLocalManager { + core, + catalog, + urls: Mutex::new(Vec::new()), + }) + })(); + + let _ = INSTANCE.set(result.map_err(|e| e.to_string())); + }); + + match INSTANCE.get() { + Some(Ok(manager)) => Ok(manager), + Some(Err(msg)) => Err(FoundryLocalError::CommandExecution { + reason: format!("SDK initialization failed: {msg}"), + }), + None => Err(FoundryLocalError::CommandExecution { + reason: "SDK initialization not completed".into(), + }), } } @@ -79,8 +83,11 @@ impl FoundryLocalManager { /// URLs that the local web service is listening on. /// /// Empty until [`Self::start_web_service`] has been called. - pub fn urls(&self) -> Vec { - self.urls.lock().unwrap().clone() + pub fn urls(&self) -> Result> { + let lock = self.urls.lock().map_err(|_| FoundryLocalError::Internal { + reason: "Failed to acquire urls lock".into(), + })?; + Ok(lock.clone()) } /// Start the local web service and return the listening URLs. @@ -92,9 +99,11 @@ impl FoundryLocalManager { let parsed: Vec = if raw.trim().is_empty() { Vec::new() } else { - serde_json::from_str(&raw).unwrap_or_else(|_| vec![raw]) + serde_json::from_str(&raw)? }; - *self.urls.lock().unwrap() = parsed.clone(); + *self.urls.lock().map_err(|_| FoundryLocalError::Internal { + reason: "Failed to acquire urls lock".into(), + })? = parsed.clone(); Ok(parsed) } @@ -103,7 +112,9 @@ impl FoundryLocalManager { self.core .execute_command_async("stop_service".into(), None) .await?; - self.urls.lock().unwrap().clear(); + self.urls.lock().map_err(|_| FoundryLocalError::Internal { + reason: "Failed to acquire urls lock".into(), + })?.clear(); Ok(()) } } diff --git a/sdk_v2/rust/src/model.rs b/sdk_v2/rust/src/model.rs index 6856ac0d..4313c5c9 100644 --- a/sdk_v2/rust/src/model.rs +++ b/sdk_v2/rust/src/model.rs @@ -1,10 +1,10 @@ //! High-level model abstraction that wraps one or more [`ModelVariant`]s //! sharing the same alias. +use std::path::PathBuf; use std::sync::Arc; use crate::detail::core_interop::CoreInterop; -use crate::detail::ModelLoadManager; use crate::error::{FoundryLocalError, Result}; use crate::model_variant::ModelVariant; use crate::openai::AudioClient; @@ -18,7 +18,6 @@ use crate::openai::ChatClient; pub struct Model { alias: String, core: Arc, - _model_load_manager: Arc, variants: Vec, selected_index: usize, } @@ -27,12 +26,10 @@ impl Model { pub(crate) fn new( alias: String, core: Arc, - model_load_manager: Arc, ) -> Self { Self { alias, core, - _model_load_manager: model_load_manager, variants: Vec::new(), selected_index: 0, } @@ -106,12 +103,12 @@ impl Model { } /// Return the local file-system path of the selected variant. - pub async fn path(&self) -> Result { + pub async fn path(&self) -> Result { self.selected_variant().path().await } /// Load the selected variant into memory. - pub async fn load(&self) -> Result { + pub async fn load(&self) -> Result<()> { self.selected_variant().load().await } diff --git a/sdk_v2/rust/src/model_variant.rs b/sdk_v2/rust/src/model_variant.rs index b3f73b28..55bc2e5f 100644 --- a/sdk_v2/rust/src/model_variant.rs +++ b/sdk_v2/rust/src/model_variant.rs @@ -1,5 +1,6 @@ //! A single model variant backed by [`ModelInfo`]. +use std::path::PathBuf; use std::sync::Arc; use serde_json::json; @@ -91,15 +92,17 @@ impl ModelVariant { } /// Return the local file-system path where this variant is stored. - pub async fn path(&self) -> Result { + pub async fn path(&self) -> Result { let params = json!({ "Params": { "Model": self.info.id } }); - self.core + let path_str = self + .core .execute_command_async("get_model_path".into(), Some(params)) - .await + .await?; + Ok(PathBuf::from(path_str)) } /// Load the variant into memory. - pub async fn load(&self) -> Result { + pub async fn load(&self) -> Result<()> { self.model_load_manager.load(&self.info.id).await } diff --git a/sdk_v2/rust/src/openai/audio_client.rs b/sdk_v2/rust/src/openai/audio_client.rs index 7fc25817..6d8ab8e2 100644 --- a/sdk_v2/rust/src/openai/audio_client.rs +++ b/sdk_v2/rust/src/openai/audio_client.rs @@ -179,10 +179,7 @@ impl AudioClient { })?; Self::validate_path(path_str)?; - let mut request = self.settings.serialize(&self.model_id, path_str); - if let Some(map) = request.as_object_mut() { - map.insert("stream".into(), json!(true)); - } + let request = self.settings.serialize(&self.model_id, path_str); let params = json!({ "Params": { diff --git a/sdk_v2/rust/tests/integration/audio_client_test.rs b/sdk_v2/rust/tests/integration/audio_client_test.rs index 6cc1d0cd..da28b7cd 100644 --- a/sdk_v2/rust/tests/integration/audio_client_test.rs +++ b/sdk_v2/rust/tests/integration/audio_client_test.rs @@ -1,8 +1,9 @@ +use std::sync::Arc; use super::common; use foundry_local_sdk::openai::AudioClient; use tokio_stream::StreamExt; -async fn setup_audio_client() -> (AudioClient, foundry_local_sdk::Model) { +async fn setup_audio_client() -> (AudioClient, Arc) { let manager = common::get_test_manager(); let catalog = manager.catalog(); let model = catalog diff --git a/sdk_v2/rust/tests/integration/chat_client_test.rs b/sdk_v2/rust/tests/integration/chat_client_test.rs index 90f53709..69f70129 100644 --- a/sdk_v2/rust/tests/integration/chat_client_test.rs +++ b/sdk_v2/rust/tests/integration/chat_client_test.rs @@ -1,3 +1,4 @@ +use std::sync::Arc; use super::common; use foundry_local_sdk::openai::ChatClient; use foundry_local_sdk::{ @@ -8,7 +9,7 @@ use foundry_local_sdk::{ use serde_json::json; use tokio_stream::StreamExt; -async fn setup_chat_client() -> (ChatClient, foundry_local_sdk::Model) { +async fn setup_chat_client() -> (ChatClient, Arc) { let manager = common::get_test_manager(); let catalog = manager.catalog(); let model = catalog diff --git a/sdk_v2/rust/tests/integration/model_test.rs b/sdk_v2/rust/tests/integration/model_test.rs index 8730d5bd..cd11f881 100644 --- a/sdk_v2/rust/tests/integration/model_test.rs +++ b/sdk_v2/rust/tests/integration/model_test.rs @@ -1,3 +1,4 @@ +use std::sync::Arc; use super::common; // ── Cached model verification ──────────────────────────────────────────────── @@ -145,10 +146,10 @@ async fn should_return_non_empty_path_for_cached_model() { .expect("get_model failed"); let path = model.path().await.expect("path() should succeed"); - println!("Model path: {path}"); + println!("Model path: {}", path.display()); assert!( - !path.is_empty(), + !path.as_os_str().is_empty(), "Cached model should have a non-empty path" ); } @@ -156,11 +157,12 @@ async fn should_return_non_empty_path_for_cached_model() { #[tokio::test] async fn should_select_variant_by_id() { let manager = common::get_test_manager(); - let mut model = manager + let mut model = (*manager .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed")) + .clone(); let first_variant_id = model.variants()[0].id().to_string(); model @@ -176,11 +178,12 @@ async fn should_select_variant_by_id() { #[tokio::test] async fn should_fail_to_select_unknown_variant() { let manager = common::get_test_manager(); - let mut model = manager + let mut model = (*manager .catalog() .get_model(common::TEST_MODEL_ALIAS) .await - .expect("get_model failed"); + .expect("get_model failed")) + .clone(); let result = model.select_variant("nonexistent-variant-id"); assert!( @@ -197,7 +200,7 @@ async fn should_fail_to_select_unknown_variant() { // ── Load manager (core interop) ────────────────────────────────────────────── -async fn get_test_model() -> foundry_local_sdk::Model { +async fn get_test_model() -> Arc { let manager = common::get_test_manager(); let catalog = manager.catalog(); catalog diff --git a/sdk_v2/rust/tests/integration/web_service_test.rs b/sdk_v2/rust/tests/integration/web_service_test.rs index cd9ccfce..41f04e49 100644 --- a/sdk_v2/rust/tests/integration/web_service_test.rs +++ b/sdk_v2/rust/tests/integration/web_service_test.rs @@ -150,7 +150,7 @@ async fn should_expose_urls_after_start() { println!("Web service URLs: {urls:?}"); assert!(!urls.is_empty(), "start_web_service should return URLs"); - let cached_urls = manager.urls(); + let cached_urls = manager.urls().expect("urls() should succeed"); assert_eq!( urls, cached_urls, "urls() should match what start_web_service returned" From b3e485764ab1ee952c91f480da217985384fa6f3 Mon Sep 17 00:00:00 2001 From: samkemp Date: Thu, 12 Mar 2026 16:42:17 +0000 Subject: [PATCH 22/25] fix: propagate FFI errors from streaming JoinHandle instead of swallowing them When the mpsc channel closes in poll_next, the Stream implementations now check the JoinHandle for errors from the native core response buffer. Previously, these errors were silently dropped by tokio if the consumer only iterated the stream without calling close(). The fix polls the JoinHandle when Ready(None) is received from the channel: - Ok(Ok(_)): stream ends normally - Ok(Err(e)): surfaces the FFI error as the final stream item - Err(e): surfaces the tokio join error as the final stream item Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk_v2/rust/src/openai/audio_client.rs | 27 ++++++++++++++++++++++++- sdk_v2/rust/src/openai/chat_client.rs | 28 +++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/sdk_v2/rust/src/openai/audio_client.rs b/sdk_v2/rust/src/openai/audio_client.rs index 6d8ab8e2..74f973b1 100644 --- a/sdk_v2/rust/src/openai/audio_client.rs +++ b/sdk_v2/rust/src/openai/audio_client.rs @@ -5,6 +5,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use futures_core::Future; use serde_json::{json, Value}; use crate::detail::core_interop::CoreInterop; @@ -85,7 +86,31 @@ impl futures_core::Stream for AudioTranscriptionStream { Poll::Ready(Some(parsed)) } } - Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(None) => { + // Channel closed — check the JoinHandle for FFI errors that + // would otherwise be swallowed by tokio. + if let Some(handle) = self.handle.as_mut() { + match Pin::new(handle).poll(cx) { + Poll::Ready(Ok(Ok(_))) => { + self.handle.take(); + Poll::Ready(None) + } + Poll::Ready(Ok(Err(e))) => { + self.handle.take(); + Poll::Ready(Some(Err(e))) + } + Poll::Ready(Err(e)) => { + self.handle.take(); + Poll::Ready(Some(Err(FoundryLocalError::CommandExecution { + reason: format!("task join error: {e}"), + }))) + } + Poll::Pending => Poll::Pending, + } + } else { + Poll::Ready(None) + } + } Poll::Pending => Poll::Pending, } } diff --git a/sdk_v2/rust/src/openai/chat_client.rs b/sdk_v2/rust/src/openai/chat_client.rs index 26fae831..f7ac0d48 100644 --- a/sdk_v2/rust/src/openai/chat_client.rs +++ b/sdk_v2/rust/src/openai/chat_client.rs @@ -5,6 +5,8 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use futures_core::Future; + use async_openai::types::chat::{ ChatCompletionRequestMessage, ChatCompletionTools, CreateChatCompletionResponse, CreateChatCompletionStreamResponse, @@ -143,7 +145,31 @@ impl futures_core::Stream for ChatCompletionStream { Poll::Ready(Some(parsed)) } } - Poll::Ready(None) => Poll::Ready(None), + Poll::Ready(None) => { + // Channel closed — check the JoinHandle for FFI errors that + // would otherwise be swallowed by tokio. + if let Some(handle) = self.handle.as_mut() { + match Pin::new(handle).poll(cx) { + Poll::Ready(Ok(Ok(_))) => { + self.handle.take(); + Poll::Ready(None) + } + Poll::Ready(Ok(Err(e))) => { + self.handle.take(); + Poll::Ready(Some(Err(e))) + } + Poll::Ready(Err(e)) => { + self.handle.take(); + Poll::Ready(Some(Err(FoundryLocalError::CommandExecution { + reason: format!("task join error: {e}"), + }))) + } + Poll::Pending => Poll::Pending, + } + } else { + Poll::Ready(None) + } + } Poll::Pending => Poll::Pending, } } From 92dd72364b9952c2df78129c7cd9af38508fa8a8 Mon Sep 17 00:00:00 2001 From: samkemp Date: Thu, 12 Mar 2026 16:58:05 +0000 Subject: [PATCH 23/25] refactor: send FFI errors through channel instead of JoinHandle Move error propagation into execute_command_streaming_channel itself: the channel now carries Result items, and the blocking task sends any error from the native core response buffer as the final channel item. This eliminates the JoinHandle> from the return type, making it impossible for consumers to silently swallow errors. Simplifies ChatCompletionStream and AudioTranscriptionStream: - Remove handle field and close() method - poll_next now just forwards Ok/Err items from the channel - Errors surface naturally as stream items Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- sdk_v2/rust/README.md | 5 +- sdk_v2/rust/examples/chat_completion.rs | 3 +- sdk_v2/rust/examples/interactive_chat.rs | 1 - sdk_v2/rust/examples/tool_calling.rs | 7 +-- sdk_v2/rust/src/detail/core_interop.rs | 39 +++++++----- sdk_v2/rust/src/openai/audio_client.rs | 56 ++--------------- sdk_v2/rust/src/openai/chat_client.rs | 60 ++----------------- .../tests/integration/audio_client_test.rs | 2 - .../tests/integration/chat_client_test.rs | 8 --- 9 files changed, 41 insertions(+), 140 deletions(-) diff --git a/sdk_v2/rust/README.md b/sdk_v2/rust/README.md index e751681b..cf4ddddf 100644 --- a/sdk_v2/rust/README.md +++ b/sdk_v2/rust/README.md @@ -205,8 +205,8 @@ while let Some(chunk) = stream.next().await { } } -// Always close the stream to finalize the native session -stream.close().await?; +// Errors from the native core are delivered as stream items — +// no separate close() call needed. ``` ### Tool Calling @@ -330,7 +330,6 @@ let mut stream = audio_client.transcribe_streaming("recording.wav").await?; while let Some(chunk) = stream.next().await { print!("{}", chunk?.text); } -stream.close().await?; ``` ### Embedded Web Service diff --git a/sdk_v2/rust/examples/chat_completion.rs b/sdk_v2/rust/examples/chat_completion.rs index e3ae1884..fd9d2b04 100644 --- a/sdk_v2/rust/examples/chat_completion.rs +++ b/sdk_v2/rust/examples/chat_completion.rs @@ -84,10 +84,9 @@ async fn main() -> Result<()> { } } } - stream.close().await?; println!(); - // ── 6. Unload the model ────────────────────────────────────────────── + // ── 6. Unload the model────────────────────────────────────────────── println!("\nUnloading model…"); model.unload().await?; println!("Done."); diff --git a/sdk_v2/rust/examples/interactive_chat.rs b/sdk_v2/rust/examples/interactive_chat.rs index 951b9997..bb699cd4 100644 --- a/sdk_v2/rust/examples/interactive_chat.rs +++ b/sdk_v2/rust/examples/interactive_chat.rs @@ -92,7 +92,6 @@ async fn main() -> Result<(), Box> { } } } - stream.close().await?; println!("\n"); // Add assistant reply to history for multi-turn conversation diff --git a/sdk_v2/rust/examples/tool_calling.rs b/sdk_v2/rust/examples/tool_calling.rs index 8bca619d..e807bf49 100644 --- a/sdk_v2/rust/examples/tool_calling.rs +++ b/sdk_v2/rust/examples/tool_calling.rs @@ -133,9 +133,7 @@ async fn main() -> Result<()> { } } } - stream.close().await?; - - // ── 5. Execute the tool(s) ─────────────────────────────────────────── + // ── 5. Execute the tool(s)─────────────────────────────────────────── for tc in &state.tool_calls { let func = &tc["function"]; let name = func["name"].as_str().unwrap_or_default(); @@ -180,10 +178,9 @@ async fn main() -> Result<()> { } } } - stream.close().await?; println!(); - // ── 7. Clean up ────────────────────────────────────────────────────── + // ── 7. Clean up────────────────────────────────────────────────────── println!("\nUnloading model…"); model.unload().await?; println!("Done."); diff --git a/sdk_v2/rust/src/detail/core_interop.rs b/sdk_v2/rust/src/detail/core_interop.rs index 965c1d78..1d53bab4 100644 --- a/sdk_v2/rust/src/detail/core_interop.rs +++ b/sdk_v2/rust/src/detail/core_interop.rs @@ -378,29 +378,38 @@ impl CoreInterop { /// Async streaming variant that bridges the FFI callback into a /// [`tokio::sync::mpsc`] channel. /// - /// Returns a `Receiver` that yields each chunk as it arrives. - /// The FFI call runs on a dedicated blocking thread; the receiver can - /// be wrapped with [`tokio_stream::wrappers::ReceiverStream`] to get a - /// `Stream`. + /// Returns a `Receiver>` that yields each chunk as it + /// arrives. After all chunks have been delivered the final result from + /// the native core response buffer is sent through the same channel — + /// if the native core reported an error it will appear as an `Err` item. + /// The receiver can be wrapped with + /// [`tokio_stream::wrappers::ReceiverStream`] to get a `Stream`. pub async fn execute_command_streaming_channel( self: &Arc, command: String, params: Option, - ) -> Result<( - tokio::sync::mpsc::UnboundedReceiver, - tokio::task::JoinHandle>, - )> { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); + ) -> Result>> { + let (tx, rx) = tokio::sync::mpsc::unbounded_channel::>(); let this = Arc::clone(self); - let handle = tokio::task::spawn_blocking(move || { - this.execute_command_streaming(&command, params.as_ref(), move |chunk: &str| { - // Ignore send errors — the receiver was dropped. - let _ = tx.send(chunk.to_owned()); - }) + tokio::task::spawn_blocking(move || { + let tx_chunk = tx.clone(); + let result = this.execute_command_streaming( + &command, + params.as_ref(), + move |chunk: &str| { + let _ = tx_chunk.send(Ok(chunk.to_owned())); + }, + ); + + // Surface any error from the native core response buffer through + // the channel so it cannot be silently swallowed. + if let Err(e) = result { + let _ = tx.send(Err(e)); + } }); - Ok((rx, handle)) + Ok(rx) } /// Read a native response buffer field as a Rust `String`. diff --git a/sdk_v2/rust/src/openai/audio_client.rs b/sdk_v2/rust/src/openai/audio_client.rs index 74f973b1..b9577f81 100644 --- a/sdk_v2/rust/src/openai/audio_client.rs +++ b/sdk_v2/rust/src/openai/audio_client.rs @@ -5,7 +5,6 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use futures_core::Future; use serde_json::{json, Value}; use crate::detail::core_interop::CoreInterop; @@ -67,8 +66,7 @@ impl AudioClientSettings { /// /// Returned by [`AudioClient::transcribe_streaming`]. pub struct AudioTranscriptionStream { - rx: tokio::sync::mpsc::UnboundedReceiver, - handle: Option>>, + rx: tokio::sync::mpsc::UnboundedReceiver>, } impl futures_core::Stream for AudioTranscriptionStream { @@ -76,7 +74,7 @@ impl futures_core::Stream for AudioTranscriptionStream { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.rx.poll_recv(cx) { - Poll::Ready(Some(chunk)) => { + Poll::Ready(Some(Ok(chunk))) => { if chunk.is_empty() { cx.waker().wake_by_ref(); Poll::Pending @@ -86,52 +84,13 @@ impl futures_core::Stream for AudioTranscriptionStream { Poll::Ready(Some(parsed)) } } - Poll::Ready(None) => { - // Channel closed — check the JoinHandle for FFI errors that - // would otherwise be swallowed by tokio. - if let Some(handle) = self.handle.as_mut() { - match Pin::new(handle).poll(cx) { - Poll::Ready(Ok(Ok(_))) => { - self.handle.take(); - Poll::Ready(None) - } - Poll::Ready(Ok(Err(e))) => { - self.handle.take(); - Poll::Ready(Some(Err(e))) - } - Poll::Ready(Err(e)) => { - self.handle.take(); - Poll::Ready(Some(Err(FoundryLocalError::CommandExecution { - reason: format!("task join error: {e}"), - }))) - } - Poll::Pending => Poll::Pending, - } - } else { - Poll::Ready(None) - } - } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, } } } -impl AudioTranscriptionStream { - /// Consume the stream and wait for the background FFI task to finish. - pub async fn close(mut self) -> Result<()> { - if let Some(handle) = self.handle.take() { - handle - .await - .map_err(|e| FoundryLocalError::CommandExecution { - reason: format!("task join error: {e}"), - })? - .map(|_| ()) - } else { - Ok(()) - } - } -} - /// Client for OpenAI-compatible audio transcription backed by a local model. pub struct AudioClient { model_id: String, @@ -212,15 +171,12 @@ impl AudioClient { } }); - let (rx, handle) = self + let rx = self .core .execute_command_streaming_channel("audio_transcribe".into(), Some(params)) .await?; - Ok(AudioTranscriptionStream { - rx, - handle: Some(handle), - }) + Ok(AudioTranscriptionStream { rx }) } fn validate_path(path: &str) -> Result<()> { diff --git a/sdk_v2/rust/src/openai/chat_client.rs b/sdk_v2/rust/src/openai/chat_client.rs index f7ac0d48..974671c7 100644 --- a/sdk_v2/rust/src/openai/chat_client.rs +++ b/sdk_v2/rust/src/openai/chat_client.rs @@ -5,8 +5,6 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use futures_core::Future; - use async_openai::types::chat::{ ChatCompletionRequestMessage, ChatCompletionTools, CreateChatCompletionResponse, CreateChatCompletionStreamResponse, @@ -125,8 +123,7 @@ impl ChatClientSettings { /// /// Returned by [`ChatClient::complete_streaming_chat`]. pub struct ChatCompletionStream { - rx: tokio::sync::mpsc::UnboundedReceiver, - handle: Option>>, + rx: tokio::sync::mpsc::UnboundedReceiver>, } impl futures_core::Stream for ChatCompletionStream { @@ -134,7 +131,7 @@ impl futures_core::Stream for ChatCompletionStream { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match self.rx.poll_recv(cx) { - Poll::Ready(Some(chunk)) => { + Poll::Ready(Some(Ok(chunk))) => { if chunk.is_empty() { // Skip empty chunks and poll again. cx.waker().wake_by_ref(); @@ -145,55 +142,13 @@ impl futures_core::Stream for ChatCompletionStream { Poll::Ready(Some(parsed)) } } - Poll::Ready(None) => { - // Channel closed — check the JoinHandle for FFI errors that - // would otherwise be swallowed by tokio. - if let Some(handle) = self.handle.as_mut() { - match Pin::new(handle).poll(cx) { - Poll::Ready(Ok(Ok(_))) => { - self.handle.take(); - Poll::Ready(None) - } - Poll::Ready(Ok(Err(e))) => { - self.handle.take(); - Poll::Ready(Some(Err(e))) - } - Poll::Ready(Err(e)) => { - self.handle.take(); - Poll::Ready(Some(Err(FoundryLocalError::CommandExecution { - reason: format!("task join error: {e}"), - }))) - } - Poll::Pending => Poll::Pending, - } - } else { - Poll::Ready(None) - } - } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), Poll::Pending => Poll::Pending, } } } -impl ChatCompletionStream { - /// Consume the stream and wait for the background FFI task to finish. - /// - /// Call this after the stream is exhausted to retrieve any error from - /// the native core response buffer. - pub async fn close(mut self) -> Result<()> { - if let Some(handle) = self.handle.take() { - handle - .await - .map_err(|e| FoundryLocalError::CommandExecution { - reason: format!("task join error: {e}"), - })? - .map(|_| ()) - } else { - Ok(()) - } - } -} - /// Client for OpenAI-compatible chat completions backed by a local model. pub struct ChatClient { model_id: String, @@ -319,15 +274,12 @@ impl ChatClient { } }); - let (rx, handle) = self + let rx = self .core .execute_command_streaming_channel("chat_completions".into(), Some(params)) .await?; - Ok(ChatCompletionStream { - rx, - handle: Some(handle), - }) + Ok(ChatCompletionStream { rx }) } fn build_request( diff --git a/sdk_v2/rust/tests/integration/audio_client_test.rs b/sdk_v2/rust/tests/integration/audio_client_test.rs index da28b7cd..1e895609 100644 --- a/sdk_v2/rust/tests/integration/audio_client_test.rs +++ b/sdk_v2/rust/tests/integration/audio_client_test.rs @@ -70,7 +70,6 @@ async fn should_transcribe_audio_with_streaming() { let chunk = chunk.expect("stream chunk error"); full_text.push_str(&chunk.text); } - stream.close().await.expect("stream close failed"); println!("Streamed transcription: {full_text}"); @@ -98,7 +97,6 @@ async fn should_transcribe_audio_with_streaming_with_temperature() { let chunk = chunk.expect("stream chunk error"); full_text.push_str(&chunk.text); } - stream.close().await.expect("stream close failed"); println!("Streamed transcription: {full_text}"); diff --git a/sdk_v2/rust/tests/integration/chat_client_test.rs b/sdk_v2/rust/tests/integration/chat_client_test.rs index 69f70129..9b3e55d8 100644 --- a/sdk_v2/rust/tests/integration/chat_client_test.rs +++ b/sdk_v2/rust/tests/integration/chat_client_test.rs @@ -86,8 +86,6 @@ async fn should_perform_streaming_chat_completion() { } } } - stream.close().await.expect("stream close failed"); - println!("First turn: {first_result}"); assert!( @@ -111,8 +109,6 @@ async fn should_perform_streaming_chat_completion() { } } } - stream.close().await.expect("stream close failed"); - println!("Follow-up: {second_result}"); assert!( @@ -281,8 +277,6 @@ async fn should_perform_tool_calling_chat_completion_streaming() { } } } - stream.close().await.expect("stream close failed"); - assert_eq!( tool_call_name, "multiply", "Expected streamed tool call to 'multiply'" @@ -330,8 +324,6 @@ async fn should_perform_tool_calling_chat_completion_streaming() { } } } - stream.close().await.expect("stream close failed"); - println!("Streamed tool call result: {final_result}"); assert!( From 2e06a0d630af7995825b768720d4b788267242f4 Mon Sep 17 00:00:00 2001 From: samkemp Date: Thu, 12 Mar 2026 17:12:43 +0000 Subject: [PATCH 24/25] fix: remove close() calls from Rust samples MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Follow-up to the channel restructure — samples also referenced the now-removed close() method on ChatCompletionStream and AudioTranscriptionStream. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- samples/rust/audio-transcription-example/src/main.rs | 3 +-- samples/rust/native-chat-completions/Cargo.toml | 2 +- samples/rust/native-chat-completions/src/main.rs | 3 +-- samples/rust/tool-calling-foundry-local/src/main.rs | 6 ++---- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/samples/rust/audio-transcription-example/src/main.rs b/samples/rust/audio-transcription-example/src/main.rs index bd2141c1..6f9b3e9e 100644 --- a/samples/rust/audio-transcription-example/src/main.rs +++ b/samples/rust/audio-transcription-example/src/main.rs @@ -59,10 +59,9 @@ async fn main() -> Result<(), Box> { print!("{}", chunk.text); io::stdout().flush().ok(); } - stream.close().await?; println!("\n"); - // ── 6. Unload the model ────────────────────────────────────────────── + // ── 6. Unload the model────────────────────────────────────────────── println!("Unloading model..."); model.unload().await?; println!("Done."); diff --git a/samples/rust/native-chat-completions/Cargo.toml b/samples/rust/native-chat-completions/Cargo.toml index 183b99dd..bec8e734 100644 --- a/samples/rust/native-chat-completions/Cargo.toml +++ b/samples/rust/native-chat-completions/Cargo.toml @@ -5,6 +5,6 @@ edition = "2021" description = "Native SDK chat completions (non-streaming and streaming) using the Foundry Local Rust SDK" [dependencies] -foundry-local-sdk = { path = "../../../sdk_v2/rust" } +foundry-local-sdk = { path = "../../../sdk_v2/rust", features = ["winml"] } tokio = { version = "1", features = ["rt-multi-thread", "macros"] } tokio-stream = "0.1" diff --git a/samples/rust/native-chat-completions/src/main.rs b/samples/rust/native-chat-completions/src/main.rs index 68dec925..4b311d8b 100644 --- a/samples/rust/native-chat-completions/src/main.rs +++ b/samples/rust/native-chat-completions/src/main.rs @@ -77,10 +77,9 @@ async fn main() -> Result<(), Box> { } } } - stream.close().await?; println!("\n"); - // ── 6. Unload the model ────────────────────────────────────────────── + // ── 6. Unload the model────────────────────────────────────────────── println!("Unloading model..."); model.unload().await?; println!("Done."); diff --git a/samples/rust/tool-calling-foundry-local/src/main.rs b/samples/rust/tool-calling-foundry-local/src/main.rs index 5a70a2e4..21be3e6e 100644 --- a/samples/rust/tool-calling-foundry-local/src/main.rs +++ b/samples/rust/tool-calling-foundry-local/src/main.rs @@ -154,10 +154,9 @@ async fn main() -> Result<(), Box> { } } } - stream.close().await?; println!(); - // ── 5. Execute the tool(s) and append results ──────────────────────── + // ── 5. Execute the tool(s)and append results ──────────────────────── for tc in &state.tool_calls { let func = &tc["function"]; let name = func["name"].as_str().unwrap_or_default(); @@ -209,10 +208,9 @@ async fn main() -> Result<(), Box> { } } } - stream.close().await?; println!("\n"); - // ── 7. Clean up ────────────────────────────────────────────────────── + // ── 7. Clean up────────────────────────────────────────────────────── println!("Unloading model..."); model.unload().await?; println!("Done."); From b0f38787ee79722f1e6ab584d301b1eb8f02ba84 Mon Sep 17 00:00:00 2001 From: samkemp Date: Thu, 12 Mar 2026 17:27:47 +0000 Subject: [PATCH 25/25] remove winml in sample --- samples/rust/native-chat-completions/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/samples/rust/native-chat-completions/Cargo.toml b/samples/rust/native-chat-completions/Cargo.toml index bec8e734..183b99dd 100644 --- a/samples/rust/native-chat-completions/Cargo.toml +++ b/samples/rust/native-chat-completions/Cargo.toml @@ -5,6 +5,6 @@ edition = "2021" description = "Native SDK chat completions (non-streaming and streaming) using the Foundry Local Rust SDK" [dependencies] -foundry-local-sdk = { path = "../../../sdk_v2/rust", features = ["winml"] } +foundry-local-sdk = { path = "../../../sdk_v2/rust" } tokio = { version = "1", features = ["rt-multi-thread", "macros"] } tokio-stream = "0.1"