diff --git a/.github/workflows/build-rust-steps.yml b/.github/workflows/build-rust-steps.yml new file mode 100644 index 00000000..ef4a349f --- /dev/null +++ b/.github/workflows/build-rust-steps.yml @@ -0,0 +1,112 @@ +name: Build Rust SDK + +on: + workflow_call: + inputs: + platform: + required: false + type: string + default: 'ubuntu' # or 'windows' or 'macos' + useWinML: + required: false + type: boolean + default: false + run-integration-tests: + required: false + type: boolean + default: true + +permissions: + contents: read + +jobs: + build: + runs-on: ${{ inputs.platform }}-latest + + defaults: + run: + working-directory: sdk_v2/rust + + env: + CARGO_FEATURES: ${{ inputs.useWinML && '--features winml' || '' }} + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + clean: true + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + components: clippy, rustfmt + + - name: Cache cargo dependencies + uses: Swatinem/rust-cache@v2 + with: + workspaces: sdk_v2/rust -> target + + - name: Checkout test-data-shared from Azure DevOps + if: ${{ inputs.run-integration-tests }} + shell: pwsh + working-directory: ${{ github.workspace }}/.. 
+ run: | + $pat = "${{ secrets.AZURE_DEVOPS_PAT }}" + $encodedPat = [Convert]::ToBase64String([Text.Encoding]::ASCII.GetBytes(":$pat")) + + # Configure git to use the PAT + git config --global http.https://dev.azure.com.extraheader "AUTHORIZATION: Basic $encodedPat" + + # Clone with LFS to parent directory + git lfs install + git clone --depth 1 https://dev.azure.com/microsoft/windows.ai.toolkit/_git/test-data-shared test-data-shared + + Write-Host "Clone completed successfully to ${{ github.workspace }}/../test-data-shared" + + - name: Checkout specific commit in test-data-shared + if: ${{ inputs.run-integration-tests }} + shell: pwsh + working-directory: ${{ github.workspace }}/../test-data-shared + run: | + Write-Host "Current directory: $(Get-Location)" + git checkout 231f820fe285145b7ea4a449b112c1228ce66a41 + if ($LASTEXITCODE -ne 0) { + Write-Error "Git checkout failed." + exit 1 + } + Write-Host "`nDirectory contents:" + Get-ChildItem -Recurse -Depth 2 | ForEach-Object { Write-Host " $($_.FullName)" } + + - name: Check formatting + run: cargo fmt --all -- --check + + # Run Clippy - Rust's official linter for catching common mistakes, enforcing idioms, and improving code quality + - name: Run clippy + run: cargo clippy --all-targets ${{ env.CARGO_FEATURES }} -- -D warnings + + - name: Build + run: cargo build ${{ env.CARGO_FEATURES }} + + - name: Run unit tests + run: cargo test --lib ${{ env.CARGO_FEATURES }} + + - name: Run integration tests + if: ${{ inputs.run-integration-tests }} + run: cargo test --tests ${{ env.CARGO_FEATURES }} -- --include-ignored --test-threads=1 --nocapture + + # --allow-dirty allows publishing with uncommitted changes, needed because the build process modifies generated files + - name: Package crate + run: cargo package ${{ env.CARGO_FEATURES }} --allow-dirty + + - name: Upload SDK artifact + uses: actions/upload-artifact@v4 + with: + name: rust-sdk-${{ inputs.platform }}${{ inputs.useWinML == true && '-winml' || '' }} + path: 
sdk_v2/rust/target/package/*.crate + + - name: Upload flcore logs + uses: actions/upload-artifact@v4 + if: always() + with: + name: rust-sdk-${{ inputs.platform }}${{ inputs.useWinML == true && '-winml' || '' }}-logs + path: sdk_v2/rust/logs/** diff --git a/.github/workflows/foundry-local-sdk-build.yml b/.github/workflows/foundry-local-sdk-build.yml index 1190ac90..604cb2d3 100644 --- a/.github/workflows/foundry-local-sdk-build.yml +++ b/.github/workflows/foundry-local-sdk-build.yml @@ -29,6 +29,12 @@ jobs: version: '0.9.0.${{ github.run_number }}' platform: 'windows' secrets: inherit + build-rust-windows: + uses: ./.github/workflows/build-rust-steps.yml + with: + platform: 'windows' + run-integration-tests: true + secrets: inherit build-cs-windows-WinML: uses: ./.github/workflows/build-cs-steps.yml @@ -44,7 +50,14 @@ jobs: platform: 'windows' useWinML: true secrets: inherit - + build-rust-windows-WinML: + uses: ./.github/workflows/build-rust-steps.yml + with: + platform: 'windows' + useWinML: true + run-integration-tests: true + secrets: inherit + build-cs-macos: uses: ./.github/workflows/build-cs-steps.yml with: @@ -56,4 +69,10 @@ jobs: with: version: '0.9.0.${{ github.run_number }}' platform: 'macos' + secrets: inherit + build-rust-macos: + uses: ./.github/workflows/build-rust-steps.yml + with: + platform: 'macos' + run-integration-tests: true secrets: inherit \ No newline at end of file diff --git a/.github/workflows/rustfmt.yml b/.github/workflows/rustfmt.yml deleted file mode 100644 index 28e61acc..00000000 --- a/.github/workflows/rustfmt.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: Rust-fmt - -on: - pull_request: - paths: - - 'sdk/rust/**' - - 'samples/rust/**' - push: - paths: - - 'sdk/rust/**' - - 'samples/rust/**' - branches: - - main - workflow_dispatch: - -jobs: - check: - runs-on: ubuntu-22.04 - steps: - - name: Checkout - uses: actions/checkout@v4 - - name: Update toolchain - run: rustup update --no-self-update stable && rustup default stable - - name: 
Check SDK - working-directory: sdk/rust - run: cargo fmt --all -- --check - - name: Check Samples - working-directory: samples/rust - run: cargo fmt --all -- --check diff --git a/samples/rust/Cargo.toml b/samples/rust/Cargo.toml index 97a5f824..bdc9ee44 100644 --- a/samples/rust/Cargo.toml +++ b/samples/rust/Cargo.toml @@ -1,5 +1,8 @@ [workspace] members = [ - "hello-foundry-local" + "foundry-local-webserver", + "tool-calling-foundry-local", + "native-chat-completions", + "audio-transcription-example", ] resolver = "2" diff --git a/samples/rust/README.md b/samples/rust/README.md index 9abb172d..c5399b3d 100644 --- a/samples/rust/README.md +++ b/samples/rust/README.md @@ -5,14 +5,21 @@ This directory contains samples demonstrating how to use the Foundry Local Rust ## Prerequisites - Rust 1.70.0 or later -- Foundry Local installed and available on PATH ## Samples -### [Hello Foundry Local](./hello-foundry-local) +### [Foundry Local Web Server](./foundry-local-webserver) -A simple example that demonstrates how to: -- Start the Foundry Local service -- Download and load a model -- Send a prompt to the model using the OpenAI-compatible API -- Display the response from the model \ No newline at end of file +Demonstrates how to start a local OpenAI-compatible web server using the SDK, then call it with a standard HTTP client. + +### [Native Chat Completions](./native-chat-completions) + +Shows both non-streaming and streaming chat completions using the SDK's native chat client. + +### [Tool Calling with Foundry Local](./tool-calling-foundry-local) + +Demonstrates tool calling with streaming responses, multi-turn conversation, and local tool execution. + +### [Audio Transcription](./audio-transcription-example) + +Demonstrates audio transcription (non-streaming and streaming) using the `whisper` model. 
\ No newline at end of file diff --git a/samples/rust/audio-transcription-example/Cargo.toml b/samples/rust/audio-transcription-example/Cargo.toml new file mode 100644 index 00000000..2fa535b3 --- /dev/null +++ b/samples/rust/audio-transcription-example/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "audio-transcription-example" +version = "0.1.0" +edition = "2021" +description = "Audio transcription example using the Foundry Local Rust SDK" + +[dependencies] +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +tokio-stream = "0.1" diff --git a/samples/rust/audio-transcription-example/README.md b/samples/rust/audio-transcription-example/README.md new file mode 100644 index 00000000..240bd7df --- /dev/null +++ b/samples/rust/audio-transcription-example/README.md @@ -0,0 +1,25 @@ +# Sample: Audio Transcription + +This example demonstrates audio transcription (non-streaming and streaming) using the Foundry Local Rust SDK. It uses the `whisper` model to transcribe a WAV audio file. + +The `foundry-local-sdk` dependency is referenced via a local path. No crates.io publish is required: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +``` + +Run the application with a path to a WAV file: + +```bash +cargo run -- path/to/audio.wav +``` + +## Using WinML (Windows only) + +To use the WinML backend, enable the `winml` feature in `Cargo.toml`: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust", features = ["winml"] } +``` + +No code changes are needed — same API, different backend. diff --git a/samples/rust/audio-transcription-example/src/main.rs b/samples/rust/audio-transcription-example/src/main.rs new file mode 100644 index 00000000..6f9b3e9e --- /dev/null +++ b/samples/rust/audio-transcription-example/src/main.rs @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +use std::env; +use std::io::{self, Write}; + +use foundry_local_sdk::{FoundryLocalConfig, FoundryLocalManager}; +use tokio_stream::StreamExt; + +const ALIAS: &str = "whisper-tiny"; + +#[tokio::main] +async fn main() -> Result<(), Box<dyn std::error::Error>> { + println!("Audio Transcription Example"); + println!("===========================\n"); + + // Accept an audio file path as a CLI argument. + let audio_path = env::args().nth(1).unwrap_or_else(|| { + eprintln!("Usage: cargo run -- <path-to-audio.wav>"); + std::process::exit(1); + }); + + // ── 1. Initialise the manager ──────────────────────────────────────── + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("foundry_local_samples"))?; + + // ── 2. Pick the whisper model and ensure it is downloaded ──────────── + let model = manager.catalog().get_model(ALIAS).await?; + println!("Model: {} (id: {})", model.alias(), model.id()); + + if !model.is_cached().await? { + println!("Downloading model..."); + model + .download(Some(|progress: &str| { + print!("\r {progress}%"); + io::stdout().flush().ok(); + })) + .await?; + println!(); + } + + println!("Loading model..."); + model.load().await?; + println!("✓ Model loaded\n"); + + // ── 3. Create an audio client ──────────────────────────────────────── + let audio_client = model.create_audio_client(); + + // ── 4. Non-streaming transcription ─────────────────────────────────── + println!("--- Non-streaming transcription ---"); + let result = audio_client.transcribe(&audio_path).await?; + println!("Transcription: {}", result.text); + + // ── 5. Streaming transcription ─────────────────────────────────────── + println!("--- Streaming transcription ---"); + print!("Transcription: "); + let mut stream = audio_client.transcribe_streaming(&audio_path).await?; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + print!("{}", chunk.text); + io::stdout().flush().ok(); + } + println!("\n"); + + // ── 6. 
Unload the model────────────────────────────────────────────── + println!("Unloading model..."); + model.unload().await?; + println!("Done."); + + Ok(()) +} diff --git a/samples/rust/foundry-local-webserver/Cargo.toml b/samples/rust/foundry-local-webserver/Cargo.toml new file mode 100644 index 00000000..4a20c1f8 --- /dev/null +++ b/samples/rust/foundry-local-webserver/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "foundry-local-webserver" +version = "0.1.0" +edition = "2021" +description = "Example of using the Foundry Local SDK with a local OpenAI-compatible web server" + +[dependencies] +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +serde_json = "1" +reqwest = { version = "0.12", features = ["json"] } diff --git a/samples/rust/foundry-local-webserver/README.md b/samples/rust/foundry-local-webserver/README.md new file mode 100644 index 00000000..f034f2a0 --- /dev/null +++ b/samples/rust/foundry-local-webserver/README.md @@ -0,0 +1,25 @@ +# Sample: Foundry Local Web Server + +This example demonstrates how to start a local OpenAI-compatible web server using the Foundry Local SDK, then call it with a standard HTTP client. This is useful when you want to use the OpenAI REST API directly or integrate with tools that expect an OpenAI-compatible endpoint. + +The `foundry-local-sdk` dependency is referenced via a local path. No crates.io publish is required: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +``` + +Run the application: + +```bash +cargo run +``` + +## Using WinML (Windows only) + +To use the WinML backend, enable the `winml` feature in `Cargo.toml`: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust", features = ["winml"] } +``` + +No code changes are needed — same API, different backend. 
diff --git a/samples/rust/foundry-local-webserver/src/main.rs b/samples/rust/foundry-local-webserver/src/main.rs new file mode 100644 index 00000000..e5ed3ae8 --- /dev/null +++ b/samples/rust/foundry-local-webserver/src/main.rs @@ -0,0 +1,102 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Foundry Local Web Server example. +//! +//! Demonstrates how to start a local OpenAI-compatible web server using the +//! Foundry Local SDK, then call it with a standard HTTP client. This is useful +//! when you want to use the OpenAI REST API directly or integrate with tools +//! that expect an OpenAI-compatible endpoint. + +use std::io::{self, Write}; + +use serde_json::json; + +use foundry_local_sdk::{FoundryLocalConfig, FoundryLocalManager}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // ── 1. Initialise the SDK ──────────────────────────────────────────── + println!("Initializing Foundry Local SDK..."); + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("foundry_local_samples"))?; + println!("✓ SDK initialized"); + + // ── 2. Download and load a model ───────────────────────────────────── + let model_alias = "qwen2.5-0.5b"; + let model = manager.catalog().get_model(model_alias).await?; + + if !model.is_cached().await? { + print!("Downloading model {model_alias}..."); + model + .download(Some(move |progress: &str| { + print!("\rDownloading model... {progress}%"); + io::stdout().flush().ok(); + })) + .await?; + println!(); + } + + print!("Loading model {model_alias}..."); + model.load().await?; + println!("done."); + + // ── 3. Start the web service ───────────────────────────────────────── + print!("Starting web service..."); + let urls = manager.start_web_service().await?; + println!("done."); + + let endpoint = urls + .first() + .expect("Web service did not return an endpoint"); + println!("Web service listening on: {endpoint}"); + + // ── 4. 
Use the OpenAI-compatible REST API with streaming ──────────── + // Any HTTP client (or OpenAI SDK) can now talk to this endpoint. + let client = reqwest::Client::new(); + let base_url = endpoint.trim_end_matches('/'); + + let mut response = client + .post(format!("{base_url}/v1/chat/completions")) + .json(&json!({ + "model": model.id(), + "messages": [ + { "role": "user", "content": "Why is the sky blue?" } + ], + "stream": true + })) + .send() + .await?; + + print!("[ASSISTANT]: "); + while let Some(chunk) = response.chunk().await? { + let text = String::from_utf8_lossy(&chunk); + for line in text.lines() { + let line = line.trim(); + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + break; + } + if let Ok(parsed) = serde_json::from_str::(data) { + if let Some(content) = parsed + .pointer("/choices/0/delta/content") + .and_then(|v| v.as_str()) + { + print!("{content}"); + io::stdout().flush().ok(); + } + } + } + } + } + println!(); + + // ── 5. Clean up ────────────────────────────────────────────────────── + println!("\nStopping web service..."); + manager.stop_web_service().await?; + + println!("Unloading model..."); + model.unload().await?; + + println!("✓ Done."); + Ok(()) +} diff --git a/samples/rust/hello-foundry-local/Cargo.toml b/samples/rust/hello-foundry-local/Cargo.toml deleted file mode 100644 index 2d0d58f1..00000000 --- a/samples/rust/hello-foundry-local/Cargo.toml +++ /dev/null @@ -1,13 +0,0 @@ -[package] -name = "hello-foundry-local" -version = "0.1.0" -edition = "2021" -description = "A simple example of using the Foundry Local Rust SDK" - -[dependencies] -foundry-local = { path = "../../../sdk/rust" } -tokio = { version = "1", features = ["full"] } -anyhow = "1.0" -reqwest = { version = "0.11", features = ["json"] } -serde_json = "1.0" -env_logger = "0.10" \ No newline at end of file diff --git a/samples/rust/hello-foundry-local/README.md b/samples/rust/hello-foundry-local/README.md deleted file mode 100644 index 
49d27d48..00000000 --- a/samples/rust/hello-foundry-local/README.md +++ /dev/null @@ -1,30 +0,0 @@ -# Hello Foundry Local (Rust) - -A simple example that demonstrates using the Foundry Local Rust SDK to interact with AI models locally. - -## Prerequisites - -- Rust 1.70.0 or later -- Foundry Local installed and available on PATH - -## Running the Sample - -1. Make sure Foundry Local is installed -2. Run the sample: - -```bash -cargo run -``` - -## What This Sample Does - -1. Creates a FoundryLocalManager instance -2. Starts the Foundry Local service if it's not already running -3. Downloads and loads the phi-3-mini-4k model -4. Sends a prompt to the model using the OpenAI-compatible API -5. Displays the response from the model - -## Code Structure - -- `src/main.rs` - The main application code -- `Cargo.toml` - Project configuration and dependencies \ No newline at end of file diff --git a/samples/rust/hello-foundry-local/src/main.rs b/samples/rust/hello-foundry-local/src/main.rs deleted file mode 100644 index 7320a2db..00000000 --- a/samples/rust/hello-foundry-local/src/main.rs +++ /dev/null @@ -1,85 +0,0 @@ -use anyhow::Result; -use foundry_local::FoundryLocalManager; - -#[tokio::main] -async fn main() -> Result<()> { - // Set up logging - env_logger::init_from_env( - env_logger::Env::default().filter_or(env_logger::DEFAULT_FILTER_ENV, "info"), - ); - - println!("Hello Foundry Local!"); - println!("==================="); - - // For this example, we will use the "phi-3-mini-4k" model which is 2.181 GB in size. 
- let model_to_use: &str = "phi-3-mini-4k"; - - // Create a FoundryLocalManager instance using the builder pattern - println!("\nInitializing Foundry Local manager..."); - let mut manager = FoundryLocalManager::builder() - // Alternatively to the checks below, you can specify the model to use directly during bootstrapping - // .alias_or_model_id(model_to_use) - .bootstrap(true) // Start the service if not running - .build() - .await?; - - // List all the models in the catalog - println!("\nAvailable models in catalog:"); - let models = manager.list_catalog_models().await?; - let model_in_catalog = models.iter().any(|m| m.alias == model_to_use); - for model in models { - println!("- {model}"); - } - // Check if the model is in the catalog - if !model_in_catalog { - println!("Model '{model_to_use}' not found in catalog. Exiting."); - return Ok(()); - } - - // List available models in the local cache - println!("\nAvailable models in local cache:"); - let models = manager.list_cached_models().await?; - let model_in_cache = models.iter().any(|m| m.alias == model_to_use); - for model in models { - println!("- {model}"); - } - - // Check if the model is already cached and download if not - if !model_in_cache { - println!("Model '{model_to_use}' not found in local cache. 
Downloading..."); - // Download the model if not in cache - // NOTE if you've bootstrapped with `alias_or_model_id`, you can use that directly and skip this check - manager.download_model(model_to_use, None, false).await?; - println!("Model '{model_to_use}' downloaded successfully."); - } - - // Get the model information - let model_info = manager.get_model_info(model_to_use, true).await?; - println!("\nUsing model: {model_info}"); - - // Build the prompt - let prompt = "What is the golden ratio?"; - println!("\nPrompt: {prompt}"); - - // Use the OpenAI compatible API to interact with the model - let client = reqwest::Client::new(); - let response = client - .post(format!("{}/chat/completions", manager.endpoint()?)) - .json(&serde_json::json!({ - "model": model_info.id, - "messages": [{"role": "user", "content": prompt}], - })) - .send() - .await?; - - // Parse and display the response - let result = response.json::().await?; - if let Some(content) = result["choices"][0]["message"]["content"].as_str() { - println!("\nResponse:\n{content}"); - } else { - println!("\nError: Failed to extract response content from API result"); - println!("Full API response: {result}"); - } - - Ok(()) -} diff --git a/samples/rust/native-chat-completions/Cargo.toml b/samples/rust/native-chat-completions/Cargo.toml new file mode 100644 index 00000000..183b99dd --- /dev/null +++ b/samples/rust/native-chat-completions/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "native-chat-completions" +version = "0.1.0" +edition = "2021" +description = "Native SDK chat completions (non-streaming and streaming) using the Foundry Local Rust SDK" + +[dependencies] +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +tokio-stream = "0.1" diff --git a/samples/rust/native-chat-completions/README.md b/samples/rust/native-chat-completions/README.md new file mode 100644 index 00000000..7645f642 --- /dev/null +++ 
b/samples/rust/native-chat-completions/README.md @@ -0,0 +1,25 @@ +# Sample: Native Chat Completions + +This example demonstrates both non-streaming and streaming chat completions using the Foundry Local Rust SDK's native chat client — no external HTTP libraries needed. + +The `foundry-local-sdk` dependency is referenced via a local path. No crates.io publish is required: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +``` + +Run the application: + +```bash +cargo run +``` + +## Using WinML (Windows only) + +To use the WinML backend, enable the `winml` feature in `Cargo.toml`: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust", features = ["winml"] } +``` + +No code changes are needed — same API, different backend. diff --git a/samples/rust/native-chat-completions/src/main.rs b/samples/rust/native-chat-completions/src/main.rs new file mode 100644 index 00000000..4b311d8b --- /dev/null +++ b/samples/rust/native-chat-completions/src/main.rs @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use std::io::{self, Write}; + +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, FoundryLocalConfig, FoundryLocalManager, +}; +use tokio_stream::StreamExt; + +const ALIAS: &str = "qwen2.5-0.5b"; + +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("Native Chat Completions"); + println!("=======================\n"); + + // ── 1. Initialise the manager ──────────────────────────────────────── + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("foundry_local_samples"))?; + + // ── 2. Pick a model and ensure it is downloaded ────────────────────── + let model = manager.catalog().get_model(ALIAS).await?; + println!("Model: {} (id: {})", model.alias(), model.id()); + + if !model.is_cached().await? 
{ + println!("Downloading model..."); + model + .download(Some(|progress: &str| { + print!("\r {progress}%"); + io::stdout().flush().ok(); + })) + .await?; + println!(); + } + + println!("Loading model..."); + model.load().await?; + println!("✓ Model loaded\n"); + + // ── 3. Create a chat client ────────────────────────────────────────── + let mut client = model.create_chat_client(); + client.temperature(0.7).max_tokens(256); + + // ── 4. Non-streaming chat completion ───────────────────────────────── + let messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("What is Rust's ownership model?").into(), + ]; + + println!("--- Non-streaming completion ---"); + let response = client.complete_chat(&messages, None).await?; + if let Some(choice) = response.choices.first() { + if let Some(ref content) = choice.message.content { + println!("Assistant: {content}"); + } + } + + // ── 5. Streaming chat completion ───────────────────────────────────── + let stream_messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("Explain the borrow checker in two sentences.") + .into(), + ]; + + println!("\n--- Streaming completion ---"); + print!("Assistant: "); + let mut stream = client + .complete_streaming_chat(&stream_messages, None) + .await?; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + print!("{content}"); + io::stdout().flush().ok(); + } + } + } + println!("\n"); + + // ── 6. 
Unload the model────────────────────────────────────────────── + println!("Unloading model..."); + model.unload().await?; + println!("Done."); + + Ok(()) +} diff --git a/samples/rust/tool-calling-foundry-local/Cargo.toml b/samples/rust/tool-calling-foundry-local/Cargo.toml new file mode 100644 index 00000000..73e59316 --- /dev/null +++ b/samples/rust/tool-calling-foundry-local/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "tool-calling-foundry-local" +version = "0.1.0" +edition = "2021" +description = "Tool calling example using the Foundry Local Rust SDK" + +[dependencies] +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +tokio-stream = "0.1" +serde_json = "1" diff --git a/samples/rust/tool-calling-foundry-local/README.md b/samples/rust/tool-calling-foundry-local/README.md new file mode 100644 index 00000000..70534c82 --- /dev/null +++ b/samples/rust/tool-calling-foundry-local/README.md @@ -0,0 +1,25 @@ +# Sample: Tool Calling with Foundry Local + +This is a simple example of how to use the Foundry Local Rust SDK to run a model locally and perform tool calling with it. The example demonstrates how to set up the SDK, initialize a model, and perform a generated tool call. + +The `foundry-local-sdk` dependency is referenced via a local path. No crates.io publish is required: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust" } +``` + +Run the application: + +```bash +cargo run +``` + +## Using WinML (Windows only) + +To use the WinML backend, enable the `winml` feature in `Cargo.toml`: + +```toml +foundry-local-sdk = { path = "../../../sdk_v2/rust", features = ["winml"] } +``` + +No code changes are needed — same API, different backend. 
diff --git a/samples/rust/tool-calling-foundry-local/src/main.rs b/samples/rust/tool-calling-foundry-local/src/main.rs new file mode 100644 index 00000000..21be3e6e --- /dev/null +++ b/samples/rust/tool-calling-foundry-local/src/main.rs @@ -0,0 +1,219 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use std::io::{self, Write}; + +use serde_json::{json, Value}; +use tokio_stream::StreamExt; + +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestToolMessage, ChatCompletionRequestUserMessage, ChatCompletionTools, + ChatToolChoice, FinishReason, FoundryLocalConfig, FoundryLocalManager, +}; + +// By using an alias, the most suitable model variant will be downloaded +// to your end-user's device. +const ALIAS: &str = "qwen2.5-0.5b"; + +/// A simple tool that multiplies two numbers. +fn multiply_numbers(first: f64, second: f64) -> f64 { + first * second +} + +/// Dispatch a tool call by name and parsed arguments. +fn invoke_tool(name: &str, args: &Value) -> String { + match name { + "multiply_numbers" => { + let first = args.get("first").and_then(|v| v.as_f64()).unwrap_or(0.0); + let second = args.get("second").and_then(|v| v.as_f64()).unwrap_or(0.0); + let result = multiply_numbers(first, second); + result.to_string() + } + _ => format!("Unknown tool: {name}"), + } +} + +/// Accumulated state from a streaming response that contains tool calls. +#[derive(Default)] +struct ToolCallState { + tool_calls: Vec, + current_tool_id: String, + current_tool_name: String, + current_tool_args: String, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("Tool Calling with Foundry Local"); + println!("===============================\n"); + + // ── 1. Initialise the manager ──────────────────────────────────────── + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("foundry_local_samples"))?; + + // ── 2. 
Load a model ────────────────────────────────────────────────── + let model = manager.catalog().get_model(ALIAS).await?; + println!("Model: {} (id: {})", model.alias(), model.id()); + + if !model.is_cached().await? { + println!("Downloading model..."); + model + .download(Some(|progress: &str| { + print!("\r {progress}%"); + io::stdout().flush().ok(); + })) + .await?; + println!(); + } + + println!("Loading model..."); + model.load().await?; + println!("✓ Model loaded\n"); + + // ── 3. Create a chat client with tool_choice = required ────────────── + let mut client = model.create_chat_client(); + client.max_tokens(512).tool_choice(ChatToolChoice::Required); + + // Define the multiply_numbers tool. + let tools: Vec = serde_json::from_value(json!([{ + "type": "function", + "function": { + "name": "multiply_numbers", + "description": "A tool for multiplying two numbers.", + "parameters": { + "type": "object", + "properties": { + "first": { + "type": "integer", + "description": "The first number in the operation" + }, + "second": { + "type": "integer", + "description": "The second number in the operation" + } + }, + "required": ["first", "second"] + } + } + }]))?; + + // Prepare the initial conversation. + let mut messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from( + "You are a helpful AI assistant. If necessary, you can use any provided tools to answer the question.", + ) + .into(), + ChatCompletionRequestUserMessage::from("What is the answer to 7 multiplied by 6?").into(), + ]; + + // ── 4. First streaming call – expect tool_calls ────────────────────── + println!("Chat completion response:"); + + let mut state = ToolCallState::default(); + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await?; + + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + // Accumulate streamed content (if any). 
+ if let Some(ref content) = choice.delta.content { + print!("{content}"); + io::stdout().flush().ok(); + } + + // Accumulate tool call fragments. + if let Some(ref tool_calls) = choice.delta.tool_calls { + for tc in tool_calls { + if let Some(ref id) = tc.id { + state.current_tool_id = id.clone(); + } + if let Some(ref func) = tc.function { + if let Some(ref name) = func.name { + state.current_tool_name = name.clone(); + } + if let Some(ref args) = func.arguments { + state.current_tool_args.push_str(args); + } + } + } + } + + // When the model signals finish_reason = ToolCalls, finalise. + if choice.finish_reason == Some(FinishReason::ToolCalls) { + let tc = json!({ + "id": state.current_tool_id.clone(), + "type": "function", + "function": { + "name": state.current_tool_name.clone(), + "arguments": state.current_tool_args.clone(), + } + }); + state.tool_calls.push(tc); + } + } + } + println!(); + + // ── 5. Execute the tool(s)and append results ──────────────────────── + for tc in &state.tool_calls { + let func = &tc["function"]; + let name = func["name"].as_str().unwrap_or_default(); + let args_str = func["arguments"].as_str().unwrap_or("{}"); + let args: Value = serde_json::from_str(args_str).unwrap_or(json!({})); + + println!("\nInvoking tool: {name} with arguments {args}"); + let result = invoke_tool(name, &args); + println!("Tool response: {result}"); + + // Append the assistant's tool_calls message and the tool result. + let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ + "role": "assistant", + "content": null, + "tool_calls": [tc], + }))?; + messages.push(assistant_msg); + messages.push( + ChatCompletionRequestToolMessage { + content: result.into(), + tool_call_id: tc["id"].as_str().unwrap_or_default().to_string(), + } + .into(), + ); + } + + // ── 6. Continue the conversation with auto tool_choice ─────────────── + println!("\nTool calls completed. 
Prompting model to continue conversation...\n"); + + messages.push( + ChatCompletionRequestSystemMessage::from( + "Respond only with the answer generated by the tool.", + ) + .into(), + ); + + client.tool_choice(ChatToolChoice::Auto); + + print!("Chat completion response: "); + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await?; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + print!("{content}"); + io::stdout().flush().ok(); + } + } + } + println!("\n"); + + // ── 7. Clean up────────────────────────────────────────────────────── + println!("Unloading model..."); + model.unload().await?; + println!("Done."); + + Ok(()) +} diff --git a/sdk_v2/rust/.clippy.toml b/sdk_v2/rust/.clippy.toml new file mode 100644 index 00000000..1d42f2f1 --- /dev/null +++ b/sdk_v2/rust/.clippy.toml @@ -0,0 +1,2 @@ +# Clippy configuration for Foundry Local Rust SDK +msrv = "1.70" diff --git a/sdk_v2/rust/.rustfmt.toml b/sdk_v2/rust/.rustfmt.toml new file mode 100644 index 00000000..dce363ed --- /dev/null +++ b/sdk_v2/rust/.rustfmt.toml @@ -0,0 +1,3 @@ +edition = "2021" +max_width = 100 +use_field_init_shorthand = true diff --git a/sdk_v2/rust/Cargo.toml b/sdk_v2/rust/Cargo.toml new file mode 100644 index 00000000..cb2cfecf --- /dev/null +++ b/sdk_v2/rust/Cargo.toml @@ -0,0 +1,44 @@ +[package] +name = "foundry-local-sdk" +version = "0.1.0" +edition = "2021" +license = "MIT" +readme = "README.md" +description = "Local AI model inference powered by the Foundry Local Core engine" + +[features] +default = [] +winml = [] +nightly = [] + +[dependencies] +libloading = "0.8" +serde = { version = "1", features = ["derive"] } +serde_json = "1" +thiserror = "2" +tokio = { version = "1", features = ["rt-multi-thread", "macros", "sync"] } +tokio-stream = "0.1" +futures-core = "0.3" +reqwest = { version = "0.12", features = ["json"] } +async-openai 
= { version = "0.33", default-features = false, features = ["chat-completion-types"] } + +[build-dependencies] +ureq = "3" +zip = "2" +serde_json = "1" +serde = { version = "1", features = ["derive"] } + +[[example]] +name = "chat_completion" +path = "examples/chat_completion.rs" + +[[example]] +name = "tool_calling" +path = "examples/tool_calling.rs" + +[[example]] +name = "interactive_chat" +path = "examples/interactive_chat.rs" + +[lints.clippy] +all = { level = "warn", priority = -1 } diff --git a/sdk_v2/rust/GENERATE-DOCS.md b/sdk_v2/rust/GENERATE-DOCS.md new file mode 100644 index 00000000..00ba2e0f --- /dev/null +++ b/sdk_v2/rust/GENERATE-DOCS.md @@ -0,0 +1,41 @@ +# Generating API Reference Docs + +The Rust SDK uses `cargo doc` to generate API documentation from `///` doc comments in the source code. + +## Viewing Docs Locally + +To generate and open the API docs in your browser: + +```bash +cd sdk_v2/rust +cargo doc --no-deps --open +``` + +This generates HTML documentation at `target/doc/foundry_local_sdk/index.html`. + +## Public API Surface + +The SDK re-exports all public types from the crate root. 
Key modules: + +| Module / Type | Description | +|---|---| +| `FoundryLocalManager` | Singleton entry point — SDK initialisation, web service lifecycle | +| `FoundryLocalConfig` | Configuration (app name, log level, service endpoint) | +| `Catalog` | Model discovery and lookup | +| `Model` | Grouped model (alias → best variant) | +| `ModelVariant` | Single variant — download, load, unload | +| `ChatClient` | OpenAI-compatible chat completions (sync + streaming) | +| `AudioClient` | OpenAI-compatible audio transcription (sync + streaming) | +| `CreateChatCompletionResponse` | Typed chat completion response (from `async-openai`) | +| `CreateChatCompletionStreamResponse` | Typed streaming chat chunk (from `async-openai`) | +| `AudioTranscriptionResponse` | Typed audio transcription response | +| `FoundryLocalError` | Error enum with variants for all failure modes | + +## Notes + +- Unlike the C# and JS SDKs which commit generated markdown docs, Rust's convention is to generate HTML docs on demand with `cargo doc`. +- Once the crate is published to crates.io, docs will be automatically hosted at [docs.rs](https://docs.rs). +- Use `--document-private-items` to include internal/private API in the generated docs: + ```bash + cargo doc --no-deps --document-private-items --open + ``` diff --git a/sdk_v2/rust/README.md b/sdk_v2/rust/README.md new file mode 100644 index 00000000..cf4ddddf --- /dev/null +++ b/sdk_v2/rust/README.md @@ -0,0 +1,475 @@ +# Foundry Local Rust SDK + +The Foundry Local Rust SDK provides an async Rust interface for running AI models locally on your machine. Discover, download, load, and run inference — all without cloud dependencies. 
+ +## Features + +- **Local-first AI** — Run models entirely on your machine with no cloud calls +- **Model catalog** — Browse and discover available models; check what's cached or loaded +- **Automatic model management** — Download, load, unload, and remove models from cache +- **Chat completions** — OpenAI-compatible chat API with both non-streaming and streaming responses +- **Audio transcription** — Transcribe audio files locally with streaming support +- **Tool calling** — Function/tool calling with streaming, multi-turn conversation support +- **Response format control** — Text, JSON, JSON Schema, and Lark grammar constrained output +- **Multi-variant models** — Models can have multiple variants (e.g., different quantizations) with automatic selection of the best cached variant +- **Embedded web service** — Start a local HTTP server for OpenAI-compatible API access +- **WinML support** — Automatic execution provider download on Windows for NPU/GPU acceleration +- **Configurable inference** — Control temperature, max tokens, top-k, top-p, frequency penalty, random seed, and more +- **Async-first** — Every operation is `async`; designed for use with the `tokio` runtime +- **Safe FFI** — Dynamically loads the native Foundry Local Core engine with a safe Rust wrapper + +## Prerequisites + +- **Rust** 1.70+ (stable toolchain) +- An internet connection during first build (to download native libraries) + +## Installation + +```sh +cargo add foundry-local-sdk +``` + +Or add to your `Cargo.toml`: + +```toml +[dependencies] +foundry-local-sdk = "0.1" +``` + +You also need an async runtime. Most examples use [tokio](https://crates.io/crates/tokio): + +```toml +[dependencies] +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } +tokio-stream = "0.1" # for StreamExt on streaming responses +``` + +### Feature Flags + +| Feature | Description | +|-----------|-------------| +| `winml` | Use the WinML backend (Windows only). 
Selects different ONNX Runtime and GenAI packages for NPU/GPU acceleration. | +| `nightly` | Resolve the latest nightly build of the Core package from the ORT-Nightly feed. | + +Enable features in `Cargo.toml`: + +```toml +[dependencies] +foundry-local-sdk = { version = "0.1", features = ["winml"] } +``` + +> **Note:** The `winml` feature is only relevant on Windows. On macOS and Linux, the standard build is used regardless. No code changes are needed — your application code stays the same. + +## Quick Start + +```rust +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, FoundryLocalConfig, FoundryLocalManager, +}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // 1. Initialize the manager — loads native libraries and starts the engine + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("my_app"))?; + + // 2. Get a model from the catalog and load it + let model = manager.catalog().get_model("phi-3.5-mini").await?; + model.load().await?; + + // 3. Create a chat client and run inference + let mut client = model.create_chat_client(); + client.temperature(0.7).max_tokens(256); + + let messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("What is the capital of France?").into(), + ]; + + let response = client.complete_chat(&messages, None).await?; + println!("{}", response.choices[0].message.content.as_deref().unwrap_or("")); + + // 4. Clean up + model.unload().await?; + + Ok(()) +} +``` + +## Usage + +### Browsing the Model Catalog + +The `Catalog` lets you discover what models are available, which are already cached locally, and which are currently loaded in memory. 
+ +```rust +let catalog = manager.catalog(); + +// List all available models +let models = catalog.get_models().await?; +for model in &models { + println!("{} (id: {})", model.alias(), model.id()); +} + +// Look up a specific model by alias +let model = catalog.get_model("phi-3.5-mini").await?; + +// Look up a specific variant by its unique model ID +let variant = catalog.get_model_variant("phi-3.5-mini-generic-gpu-4").await?; + +// See what's already downloaded +let cached = catalog.get_cached_models().await?; + +// See what's currently loaded in memory +let loaded = catalog.get_loaded_models().await?; +``` + +### Model Lifecycle + +Each `Model` wraps one or more `ModelVariant` entries (different quantizations, hardware targets). The SDK auto-selects the best available variant, preferring cached versions. + +```rust +let mut model = catalog.get_model("phi-3.5-mini").await?; + +// Inspect available variants +println!("Selected: {}", model.selected_variant().id()); +for v in model.variants() { + println!(" {} (cached: {})", v.id(), v.is_cached().await?); +} + +// Switch to a specific variant +model.select_variant("phi-3.5-mini-generic-gpu-4")?; +``` + +Download, load, and unload: + +```rust +// Download with progress reporting +model.download(Some(|progress: &str| { + print!("\r{progress}"); + std::io::Write::flush(&mut std::io::stdout()).ok(); +})).await?; + +// Load into memory +model.load().await?; + +// Unload when done +model.unload().await?; + +// Remove from local cache entirely +model.remove_from_cache().await?; +``` + +### Chat Completions + +The `ChatClient` follows the OpenAI Chat Completion API structure. 
+ +```rust +let mut client = model.create_chat_client(); + +// Configure generation settings (builder pattern) +client + .temperature(0.7) + .max_tokens(256) + .top_p(0.9) + .frequency_penalty(0.5); + +// Non-streaming completion +let response = client.complete_chat( + &[ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("Explain Rust's ownership model.").into(), + ], + None, +).await?; + +println!("{}", response.choices[0].message.content.as_deref().unwrap_or("")); +``` + +### Streaming Responses + +For real-time token-by-token output, use streaming: + +```rust +use tokio_stream::StreamExt; + +let mut stream = client.complete_streaming_chat( + &[ChatCompletionRequestUserMessage::from("Write a short poem about Rust.").into()], + None, +).await?; + +while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(content) = &chunk.choices[0].delta.content { + print!("{content}"); + } +} + +// Errors from the native core are delivered as stream items — +// no separate close() call needed. 
+``` + +### Tool Calling + +Define functions the model can call and handle the multi-turn conversation: + +```rust +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestToolMessage, + ChatCompletionTools, ChatToolChoice, FinishReason, +}; +use serde_json::json; + +// Define available tools +let tools: Vec = serde_json::from_value(json!([{ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": { "type": "string", "description": "City name" } + }, + "required": ["location"] + } + } +}]))?; + +let mut client = model.create_chat_client(); +client.max_tokens(512).tool_choice(ChatToolChoice::Auto); + +let mut messages: Vec = vec![ + ChatCompletionRequestUserMessage::from("What's the weather in Seattle?").into(), +]; + +// First request — model may call a tool +let response = client.complete_chat(&messages, Some(&tools)).await?; +let choice = &response.choices[0]; + +if choice.finish_reason == Some(FinishReason::ToolCalls) { + if let Some(tool_calls) = &choice.message.tool_calls { + for tc in tool_calls { + // Execute the tool (your application logic) + let result = execute_tool(&tc.function.name, &tc.function.arguments); + + // Add assistant message with tool calls, then the tool result + messages.push(serde_json::from_value(json!({ + "role": "assistant", + "content": null, + "tool_calls": [{ "id": tc.id, "type": "function", + "function": { "name": tc.function.name, + "arguments": tc.function.arguments } }] + }))?); + messages.push(ChatCompletionRequestToolMessage { + content: result.into(), + tool_call_id: tc.id.clone(), + }.into()); + } + + // Continue the conversation with tool results + let final_response = client.complete_chat(&messages, Some(&tools)).await?; + println!("{}", final_response.choices[0].message.content.as_deref().unwrap_or("")); + } +} +``` + +Tool calling also works with streaming via 
`complete_streaming_chat` — accumulate tool call fragments during streaming and check for `FinishReason::ToolCalls`. + +### Response Format Options + +Control the output format of chat completions: + +```rust +use foundry_local_sdk::ChatResponseFormat; + +let mut client = model.create_chat_client(); + +// Plain text (default) +client.response_format(ChatResponseFormat::Text); + +// Unstructured JSON output +client.response_format(ChatResponseFormat::JsonObject); + +// JSON constrained to a schema +client.response_format(ChatResponseFormat::JsonSchema(r#"{ + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { "type": "integer" } + }, + "required": ["name", "age"] +}"#.to_string())); + +// Output constrained by a Lark grammar (Foundry extension) +client.response_format(ChatResponseFormat::LarkGrammar(grammar.to_string())); +``` + +### Audio Transcription + +Transcribe audio files locally using the `AudioClient`: + +```rust +let model = manager.catalog().get_model("whisper-tiny").await?; +model.load().await?; + +let mut audio_client = model.create_audio_client(); +audio_client.language("en"); + +// Non-streaming transcription +let result = audio_client.transcribe("recording.wav").await?; +println!("{}", result.text); +``` + +#### Streaming Transcription + +```rust +use tokio_stream::StreamExt; + +let mut stream = audio_client.transcribe_streaming("recording.wav").await?; +while let Some(chunk) = stream.next().await { + print!("{}", chunk?.text); +} +``` + +### Embedded Web Service + +Start a local HTTP server that exposes an OpenAI-compatible REST API: + +```rust +let urls = manager.start_web_service().await?; +println!("Service running at: {:?}", urls); + +// Any OpenAI-compatible client or tool can now connect to the endpoint. +// ... 
+ +manager.stop_web_service().await?; +``` + +### Chat Client Settings + +All settings are configured via chainable builder methods on `ChatClient`: + +| Method | Type | Description | +|--------|------|-------------| +| `temperature(v)` | `f64` | Sampling temperature (0.0–2.0; higher = more random) | +| `max_tokens(v)` | `u32` | Maximum number of tokens to generate | +| `top_p(v)` | `f64` | Nucleus sampling probability (0.0–1.0) | +| `top_k(v)` | `u32` | Top-k sampling parameter (Foundry extension) | +| `frequency_penalty(v)` | `f64` | Frequency penalty | +| `presence_penalty(v)` | `f64` | Presence penalty | +| `n(v)` | `u32` | Number of completions to generate | +| `random_seed(v)` | `u64` | Random seed for reproducible results (Foundry extension) | +| `response_format(v)` | `ChatResponseFormat` | Output format (Text, JsonObject, JsonSchema, LarkGrammar) | +| `tool_choice(v)` | `ChatToolChoice` | Tool selection strategy (None, Auto, Required, Function) | + +## Error Handling + +All fallible operations return `foundry_local_sdk::Result`, which is an alias for `std::result::Result`. + +```rust +use foundry_local_sdk::FoundryLocalError; + +match manager.catalog().get_model("nonexistent").await { + Ok(model) => { /* use model */ } + Err(FoundryLocalError::ModelOperation { reason }) => { + eprintln!("Model error: {reason}"); + } + Err(FoundryLocalError::CommandExecution { reason }) => { + eprintln!("Core engine error: {reason}"); + } + Err(e) => { + eprintln!("Unexpected error: {e}"); + } +} +``` + +### Error Variants + +| Variant | Description | +|---------|-------------| +| `LibraryLoad { reason }` | The native core library could not be loaded | +| `CommandExecution { reason }` | A command executed against native core returned an error | +| `InvalidConfiguration { reason }` | The provided configuration is invalid | +| `ModelOperation { reason }` | A model operation failed (load, unload, download, etc.) 
| +| `HttpRequest(reqwest::Error)` | An HTTP request to an external service failed | +| `Serialization(serde_json::Error)` | JSON serialization/deserialization failed | +| `Validation { reason }` | A validation check on user-supplied input failed | +| `Io(std::io::Error)` | An I/O error occurred | + +## Configuration + +The SDK is configured via `FoundryLocalConfig` when creating the manager: + +```rust +use foundry_local_sdk::{FoundryLocalConfig, LogLevel}; + +let config = FoundryLocalConfig { + log_level: Some(LogLevel::Info), + model_cache_dir: Some("/path/to/cache".into()), + web_service_urls: Some("http://127.0.0.1:5000".into()), + ..FoundryLocalConfig::new("my_app") +}; + +let manager = FoundryLocalManager::create(config)?; +``` + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `app_name` | `String` | **(required)** | Your application name | +| `app_data_dir` | `Option` | `~/.{app_name}` | Application data directory | +| `model_cache_dir` | `Option` | `{app_data_dir}/cache/models` | Where models are stored locally | +| `logs_dir` | `Option` | `{app_data_dir}/logs` | Log output directory | +| `log_level` | `Option` | `Warn` | `Trace`, `Debug`, `Info`, `Warn`, `Error`, `Fatal` | +| `web_service_urls` | `Option` | `None` | Bind address for the embedded web service | +| `service_endpoint` | `Option` | `None` | URL of an existing external service to connect to | +| `library_path` | `Option` | Auto-discovered | Path to native Foundry Local Core libraries | +| `additional_settings` | `Option>` | `None` | Extra key-value settings passed to Core | + +## How It Works + +### Native Library Download + +The `build.rs` build script automatically downloads the required native libraries at compile time: + +1. Queries NuGet/ORT-Nightly feeds for package metadata +2. Downloads `.nupkg` packages (zip archives) +3. Extracts platform-specific native libraries (`.dll`, `.so`, or `.dylib`) +4. 
Places them in Cargo's `OUT_DIR` for runtime discovery + +Downloaded libraries are cached — subsequent builds skip the download step. + +### Runtime Loading + +At runtime, the SDK uses `libloading` to dynamically load the Foundry Local Core library and resolve function pointers. No static linking or system-wide installation is required. + +## Platform Support + +| Platform | RID | Status | +|-----------------|------------|--------| +| Windows x64 | `win-x64` | ✅ | +| Windows ARM64 | `win-arm64`| ✅ | +| Linux x64 | `linux-x64`| ✅ | +| macOS ARM64 | `osx-arm64`| ✅ | + +## Running Examples + +Sample applications are available in [`samples/rust/`](../../samples/rust/): + +| Sample | Description | +|--------|-------------| +| `native-chat-completions` | Non-streaming and streaming chat completions | +| `tool-calling-foundry-local` | Function/tool calling with multi-turn conversations | +| `audio-transcription-example` | Audio transcription (non-streaming and streaming) | +| `foundry-local-webserver` | Embedded OpenAI-compatible REST API server | + +Run a sample with: + +```sh +cd samples/rust +cargo run -p native-chat-completions +``` + +## License + +MIT — see [LICENSE](../../LICENSE) for details. 
diff --git a/sdk_v2/rust/build.rs b/sdk_v2/rust/build.rs new file mode 100644 index 00000000..e2365bc3 --- /dev/null +++ b/sdk_v2/rust/build.rs @@ -0,0 +1,300 @@ +use std::env; +use std::fs; +use std::io::{self, Read}; +use std::path::{Path, PathBuf}; + +const NUGET_FEED: &str = "https://api.nuget.org/v3/index.json"; +const ORT_NIGHTLY_FEED: &str = + "https://pkgs.dev.azure.com/aiinfra/PublicPackages/_packaging/ORT-Nightly/nuget/v3/index.json"; + +const CORE_VERSION: &str = "0.9.0.8-rc3"; +const ORT_VERSION: &str = "1.24.3"; +const GENAI_VERSION: &str = "0.12.2"; + +const WINML_ORT_VERSION: &str = "1.23.2.3"; + +struct NuGetPackage { + name: &'static str, + version: String, + feed_url: &'static str, +} + +fn get_rid() -> Option<&'static str> { + let os = env::consts::OS; + let arch = env::consts::ARCH; + match (os, arch) { + ("windows", "x86_64") => Some("win-x64"), + ("windows", "aarch64") => Some("win-arm64"), + ("linux", "x86_64") => Some("linux-x64"), + ("macos", "aarch64") => Some("osx-arm64"), + _ => None, + } +} + +fn native_lib_extension() -> &'static str { + match env::consts::OS { + "windows" => "dll", + "linux" => "so", + "macos" => "dylib", + _ => "so", + } +} + +fn get_packages(rid: &str) -> Vec { + let winml = env::var("CARGO_FEATURE_WINML").is_ok(); + let nightly = env::var("CARGO_FEATURE_NIGHTLY").is_ok(); + let is_linux = rid.starts_with("linux"); + + let core_version = if nightly { + resolve_latest_version("Microsoft.AI.Foundry.Local.Core", ORT_NIGHTLY_FEED) + .unwrap_or_else(|| CORE_VERSION.to_string()) + } else { + CORE_VERSION.to_string() + }; + + let mut packages = Vec::new(); + + if winml { + let winml_core_version = if nightly { + resolve_latest_version("Microsoft.AI.Foundry.Local.Core.WinML", ORT_NIGHTLY_FEED) + .unwrap_or_else(|| CORE_VERSION.to_string()) + } else { + CORE_VERSION.to_string() + }; + + packages.push(NuGetPackage { + name: "Microsoft.AI.Foundry.Local.Core.WinML", + version: winml_core_version, + feed_url: ORT_NIGHTLY_FEED, + 
}); + packages.push(NuGetPackage { + name: "Microsoft.ML.OnnxRuntime.Foundry", + version: WINML_ORT_VERSION.to_string(), + feed_url: NUGET_FEED, + }); + packages.push(NuGetPackage { + name: "Microsoft.ML.OnnxRuntimeGenAI.WinML", + version: GENAI_VERSION.to_string(), + feed_url: NUGET_FEED, + }); + } else { + packages.push(NuGetPackage { + name: "Microsoft.AI.Foundry.Local.Core", + version: core_version, + feed_url: ORT_NIGHTLY_FEED, + }); + + if is_linux { + packages.push(NuGetPackage { + name: "Microsoft.ML.OnnxRuntime.Gpu.Linux", + version: ORT_VERSION.to_string(), + feed_url: NUGET_FEED, + }); + } else { + packages.push(NuGetPackage { + name: "Microsoft.ML.OnnxRuntime.Foundry", + version: ORT_VERSION.to_string(), + feed_url: NUGET_FEED, + }); + } + + packages.push(NuGetPackage { + name: "Microsoft.ML.OnnxRuntimeGenAI.Foundry", + version: GENAI_VERSION.to_string(), + feed_url: NUGET_FEED, + }); + } + + packages +} + +/// Resolve the PackageBaseAddress from a NuGet v3 service index. +fn resolve_base_address(feed_url: &str) -> Result { + let body: String = ureq::get(feed_url) + .call() + .map_err(|e| format!("Failed to fetch NuGet feed index at {feed_url}: {e}"))? + .body_mut() + .read_to_string() + .map_err(|e| format!("Failed to read feed index response: {e}"))?; + + let index: serde_json::Value = + serde_json::from_str(&body).map_err(|e| format!("Failed to parse feed index JSON: {e}"))?; + + let resources = index["resources"] + .as_array() + .ok_or("Feed index missing 'resources' array")?; + + for resource in resources { + let rtype = resource["@type"].as_str().unwrap_or(""); + if rtype == "PackageBaseAddress/3.0.0" { + if let Some(id) = resource["@id"].as_str() { + let base = if id.ends_with('/') { + id.to_string() + } else { + format!("{id}/") + }; + return Ok(base); + } + } + } + + Err(format!( + "Could not find PackageBaseAddress/3.0.0 in feed {feed_url}" + )) +} + +/// Resolve the latest version of a package from a NuGet feed. 
+fn resolve_latest_version(package_name: &str, feed_url: &str) -> Option { + let base_address = resolve_base_address(feed_url).ok()?; + let lower_name = package_name.to_lowercase(); + let index_url = format!("{base_address}{lower_name}/index.json"); + + let body: String = ureq::get(&index_url) + .call() + .ok()? + .body_mut() + .read_to_string() + .ok()?; + + let index: serde_json::Value = serde_json::from_str(&body).ok()?; + let versions = index["versions"].as_array()?; + versions.last()?.as_str().map(|s| s.to_string()) +} + +/// Download a .nupkg and extract native libraries for the given RID into `out_dir`. +fn download_and_extract(pkg: &NuGetPackage, rid: &str, out_dir: &Path) -> Result<(), String> { + let base_address = resolve_base_address(pkg.feed_url)?; + let lower_name = pkg.name.to_lowercase(); + let lower_version = pkg.version.to_lowercase(); + let url = + format!("{base_address}{lower_name}/{lower_version}/{lower_name}.{lower_version}.nupkg"); + + println!( + "cargo:warning=Downloading {name} {ver} from {feed}", + name = pkg.name, + ver = pkg.version, + feed = if pkg.feed_url == NUGET_FEED { + "NuGet.org" + } else { + "ORT-Nightly" + }, + ); + + let mut response = ureq::get(&url) + .call() + .map_err(|e| format!("Failed to download {}: {e}", pkg.name))?; + + let mut bytes = Vec::new(); + response + .body_mut() + .as_reader() + .read_to_end(&mut bytes) + .map_err(|e| format!("Failed to read response body for {}: {e}", pkg.name))?; + + let ext = native_lib_extension(); + let prefix = format!("runtimes/{rid}/native/"); + + let cursor = io::Cursor::new(&bytes); + let mut archive = zip::ZipArchive::new(cursor) + .map_err(|e| format!("Failed to open nupkg as zip for {}: {e}", pkg.name))?; + + let mut extracted = 0usize; + for i in 0..archive.len() { + let mut file = archive + .by_index(i) + .map_err(|e| format!("Failed to read zip entry: {e}"))?; + + let name = file.name().to_string(); + if !name.starts_with(&prefix) { + continue; + } + if 
!name.ends_with(&format!(".{ext}")) { + continue; + } + + let file_name = Path::new(&name) + .file_name() + .map(|n| n.to_string_lossy().to_string()) + .unwrap_or_default(); + + if file_name.is_empty() { + continue; + } + + let dest = out_dir.join(&file_name); + let mut out_file = fs::File::create(&dest) + .map_err(|e| format!("Failed to create {}: {e}", dest.display()))?; + io::copy(&mut file, &mut out_file) + .map_err(|e| format!("Failed to write {}: {e}", dest.display()))?; + + println!("cargo:warning= Extracted {file_name}"); + extracted += 1; + } + + if extracted == 0 { + println!( + "cargo:warning= No native libraries found for RID '{rid}' in {} {}", + pkg.name, pkg.version + ); + } + + Ok(()) +} + +/// Check whether the core native library is already present in `out_dir`. +fn libs_already_present(out_dir: &Path) -> bool { + let core_lib = match env::consts::OS { + "windows" => "foundry_local_core.dll", + "linux" => "libfoundry_local_core.so", + "macos" => "libfoundry_local_core.dylib", + _ => return false, + }; + out_dir.join(core_lib).exists() +} + +fn main() { + println!("cargo:rerun-if-changed=build.rs"); + + let out_dir = PathBuf::from(env::var("OUT_DIR").expect("OUT_DIR not set")); + + let rid = match get_rid() { + Some(r) => r, + None => { + println!( + "cargo:warning=Unsupported platform: {} {}. 
Native libraries will not be downloaded.", + env::consts::OS, + env::consts::ARCH, + ); + return; + } + }; + + // Skip download if libraries already exist + if libs_already_present(&out_dir) { + println!("cargo:warning=Native libraries already present in OUT_DIR, skipping download."); + println!("cargo:rustc-link-search=native={}", out_dir.display()); + println!("cargo:rustc-env=FOUNDRY_NATIVE_DIR={}", out_dir.display()); + #[cfg(windows)] + println!("cargo:rustc-link-lib=ole32"); + return; + } + + let packages = get_packages(rid); + + for pkg in &packages { + if let Err(e) = download_and_extract(pkg, rid, &out_dir) { + println!("cargo:warning=Error downloading {}: {e}", pkg.name); + println!("cargo:warning=Build will continue, but runtime loading may fail."); + println!( + "cargo:warning=You can manually place native libraries in the output directory." + ); + } + } + + println!("cargo:rustc-link-search=native={}", out_dir.display()); + println!("cargo:rustc-env=FOUNDRY_NATIVE_DIR={}", out_dir.display()); + + // LocalFree (used to free native-allocated buffers) lives in kernel32.lib on Windows. + #[cfg(windows)] + println!("cargo:rustc-link-lib=kernel32"); +} diff --git a/sdk_v2/rust/examples/chat_completion.rs b/sdk_v2/rust/examples/chat_completion.rs new file mode 100644 index 00000000..fd9d2b04 --- /dev/null +++ b/sdk_v2/rust/examples/chat_completion.rs @@ -0,0 +1,95 @@ +//! Basic chat completion example demonstrating synchronous and streaming +//! usage of the Foundry Local SDK. + +use std::io::{self, Write}; + +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, FoundryLocalConfig, FoundryLocalError, FoundryLocalManager, +}; +use tokio_stream::StreamExt; + +/// Convenience alias matching the SDK's internal Result type. +type Result = std::result::Result; + +#[tokio::main] +async fn main() -> Result<()> { + // ── 1. 
Initialise the manager ──────────────────────────────────────── + let config = FoundryLocalConfig::new("foundry_local_samples"); + + let manager = FoundryLocalManager::create(config)?; + + // ── 2. List available models ───────────────────────────────────────── + let models = manager.catalog().get_models().await?; + println!("Available models:"); + for model in &models { + println!(" • {} (id: {})", model.alias(), model.id()); + } + + // ── 3. Pick a model and ensure it is loaded ────────────────────────── + let model_alias = models + .first() + .map(|m| m.alias().to_string()) + .expect("No models available in the catalog"); + + let model = manager.catalog().get_model(&model_alias).await?; + + if !model.is_cached().await? { + println!("Downloading model '{}'…", model.alias()); + model + .download(Some(|progress: &str| { + println!(" {progress}"); + })) + .await?; + } + + println!("Loading model '{}'…", model.alias()); + model.load().await?; + + // ── 4. Synchronous chat completion ─────────────────────────────────── + let mut client = model.create_chat_client(); + client.temperature(0.7).max_tokens(256); + + let messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("What is Rust's ownership model?").into(), + ]; + + println!("\n--- Synchronous completion ---"); + let response = client.complete_chat(&messages, None).await?; + if let Some(choice) = response.choices.first() { + if let Some(ref content) = choice.message.content { + println!("Assistant: {content}"); + } + } + + // ── 5. 
Streaming chat completion ───────────────────────────────────── + println!("\n--- Streaming completion ---"); + let stream_messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful assistant.").into(), + ChatCompletionRequestUserMessage::from("Explain the borrow checker in two sentences.") + .into(), + ]; + + print!("Assistant: "); + let mut stream = client + .complete_streaming_chat(&stream_messages, None) + .await?; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + print!("{content}"); + io::stdout().flush().ok(); + } + } + } + println!(); + + // ── 6. Unload the model────────────────────────────────────────────── + println!("\nUnloading model…"); + model.unload().await?; + println!("Done."); + + Ok(()) +} diff --git a/sdk_v2/rust/examples/interactive_chat.rs b/sdk_v2/rust/examples/interactive_chat.rs new file mode 100644 index 00000000..bb699cd4 --- /dev/null +++ b/sdk_v2/rust/examples/interactive_chat.rs @@ -0,0 +1,113 @@ +//! Interactive chat example — a simple terminal chatbot powered by Foundry Local. +//! +//! 
Run with: `cargo run --example interactive_chat` + +use std::io::{self, Write}; + +use foundry_local_sdk::{ + ChatCompletionRequestAssistantMessage, ChatCompletionRequestMessage, + ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage, FoundryLocalConfig, + FoundryLocalManager, +}; +use tokio_stream::StreamExt; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // ── Initialise ─────────────────────────────────────────────────────── + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("foundry_local_samples"))?; + + // Pick the first available model (or change this to a specific alias) + let catalog = manager.catalog(); + let models = catalog.get_models().await?; + + println!("Available models:"); + for (i, m) in models.iter().enumerate() { + println!(" [{i}] {}", m.alias()); + } + + print!("\nSelect a model number (default 0): "); + io::stdout().flush()?; + let mut choice = String::new(); + io::stdin().read_line(&mut choice)?; + let idx: usize = choice.trim().parse().unwrap_or(0); + + let alias = models + .get(idx) + .map(|m| m.alias().to_string()) + .unwrap_or_else(|| models[0].alias().to_string()); + + let model = catalog.get_model(&alias).await?; + + // Download if needed + if !model.is_cached().await? { + println!("Downloading '{alias}'…"); + model.download(Some(|p: &str| print!("\r {p}%"))).await?; + println!(); + } + + println!("Loading '{alias}'…"); + model.load().await?; + println!("Ready! Type your messages below. Press Ctrl-D (or type 'quit') to exit.\n"); + + // ── Chat loop ──────────────────────────────────────────────────────── + let mut client = model.create_chat_client(); + client.temperature(0.7).max_tokens(512); + + let mut messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from("You are a helpful, concise assistant.").into(), + ]; + + loop { + print!("You: "); + io::stdout().flush()?; + + let mut input = String::new(); + if io::stdin().read_line(&mut input)? 
== 0 { + break; // EOF (Ctrl-D) + } + + let input = input.trim(); + if input.is_empty() { + continue; + } + if input.eq_ignore_ascii_case("quit") || input.eq_ignore_ascii_case("exit") { + break; + } + + messages.push(ChatCompletionRequestUserMessage::from(input).into()); + + // Stream the response token by token + print!("Assistant: "); + io::stdout().flush()?; + + let mut full_response = String::new(); + let mut stream = client.complete_streaming_chat(&messages, None).await?; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + print!("{content}"); + io::stdout().flush().ok(); + full_response.push_str(content); + } + } + } + println!("\n"); + + // Add assistant reply to history for multi-turn conversation + messages.push( + ChatCompletionRequestAssistantMessage { + content: Some(full_response.into()), + ..Default::default() + } + .into(), + ); + } + + // ── Cleanup ────────────────────────────────────────────────────────── + println!("\nUnloading model…"); + model.unload().await?; + println!("Goodbye!"); + + Ok(()) +} diff --git a/sdk_v2/rust/examples/tool_calling.rs b/sdk_v2/rust/examples/tool_calling.rs new file mode 100644 index 00000000..e807bf49 --- /dev/null +++ b/sdk_v2/rust/examples/tool_calling.rs @@ -0,0 +1,189 @@ +//! Tool-calling example demonstrating how to define tools, handle +//! `tool_calls` in streaming responses, execute the tool locally, +//! and feed results back for a multi-turn conversation. 
+ +use std::io::{self, Write}; + +use serde_json::{json, Value}; +use tokio_stream::StreamExt; + +use foundry_local_sdk::{ + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestToolMessage, ChatCompletionRequestUserMessage, ChatCompletionTools, + ChatToolChoice, FinishReason, FoundryLocalConfig, FoundryLocalError, FoundryLocalManager, +}; + +/// Convenience alias matching the SDK's internal Result type. +type Result = std::result::Result; + +/// A trivial tool that multiplies two numbers. +fn multiply(a: f64, b: f64) -> f64 { + a * b +} + +/// Dispatch a tool call by name and arguments. +fn invoke_tool(name: &str, arguments: &Value) -> Result { + match name { + "multiply" => { + let a = arguments.get("a").and_then(|v| v.as_f64()).unwrap_or(0.0); + let b = arguments.get("b").and_then(|v| v.as_f64()).unwrap_or(0.0); + let result = multiply(a, b); + Ok(result.to_string()) + } + _ => Ok(format!("Unknown tool: {name}")), + } +} + +#[derive(Default)] +struct ToolCallState { + tool_calls: Vec, + tool_call_args: String, + current_tool_name: String, + current_tool_id: String, +} + +#[tokio::main] +async fn main() -> Result<()> { + // ── 1. Initialise ──────────────────────────────────────────────────── + let config = FoundryLocalConfig::new("foundry_local_samples"); + + let manager = FoundryLocalManager::create(config)?; + + // ── 2. Load a model ────────────────────────────────────────────────── + let models = manager.catalog().get_models().await?; + let model = models + .iter() + .find(|m| m.selected_variant().info().supports_tool_calling == Some(true)) + .or_else(|| models.first()) + .expect("No models available"); + + if !model.is_cached().await? { + println!("Downloading model '{}'…", model.alias()); + model.download(Some(|p: &str| println!(" {p}"))).await?; + } + println!("Loading model '{}'…", model.alias()); + model.load().await?; + + // ── 3. 
Create a chat client with tool_choice = required ────────────── + let mut client = model.create_chat_client(); + client.tool_choice(ChatToolChoice::Required).max_tokens(512); + + let tools: Vec = serde_json::from_value(json!([{ + "type": "function", + "function": { + "name": "multiply", + "description": "Multiply two numbers together.", + "parameters": { + "type": "object", + "properties": { + "a": { "type": "number", "description": "First operand" }, + "b": { "type": "number", "description": "Second operand" } + }, + "required": ["a", "b"] + } + } + }])) + .expect("Failed to parse tool definitions"); + + let mut messages: Vec = vec![ + ChatCompletionRequestSystemMessage::from( + "You are a helpful calculator assistant. Use the multiply tool when asked to multiply.", + ) + .into(), + ChatCompletionRequestUserMessage::from("What is 6 times 7?").into(), + ]; + + // ── 4. First streaming call – expect tool_calls ────────────────────── + println!("Sending initial request…"); + + let mut state = ToolCallState::default(); + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await?; + + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref tool_calls) = choice.delta.tool_calls { + for tc in tool_calls { + if let Some(ref func) = tc.function { + if let Some(ref name) = func.name { + state.current_tool_name = name.clone(); + } + if let Some(ref args) = func.arguments { + state.tool_call_args.push_str(args); + } + } + if let Some(ref id) = tc.id { + state.current_tool_id = id.clone(); + } + } + } + + if choice.finish_reason == Some(FinishReason::ToolCalls) { + let tc = json!({ + "id": state.current_tool_id.clone(), + "type": "function", + "function": { + "name": state.current_tool_name.clone(), + "arguments": state.tool_call_args.clone(), + } + }); + state.tool_calls.push(tc); + } + } + } + // ── 5. 
Execute the tool(s)─────────────────────────────────────────── + for tc in &state.tool_calls { + let func = &tc["function"]; + let name = func["name"].as_str().unwrap_or_default(); + let args_str = func["arguments"].as_str().unwrap_or("{}"); + let args: Value = serde_json::from_str(args_str).unwrap_or(json!({})); + + println!("Tool call: {name}({args})"); + let result = invoke_tool(name, &args)?; + println!("Tool result: {result}"); + + // Append the assistant's tool_calls message and the tool result. + let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ + "role": "assistant", + "content": null, + "tool_calls": [tc], + })) + .expect("Failed to construct assistant message"); + messages.push(assistant_msg); + messages.push( + ChatCompletionRequestToolMessage { + content: result.into(), + tool_call_id: tc["id"].as_str().unwrap_or_default().to_string(), + } + .into(), + ); + } + + // ── 6. Continue the conversation with auto tool_choice ─────────────── + client.tool_choice(ChatToolChoice::Auto); + + println!("\nContinuing conversation…"); + print!("Assistant: "); + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await?; + while let Some(chunk) = stream.next().await { + let chunk = chunk?; + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + print!("{content}"); + io::stdout().flush().ok(); + } + } + } + println!(); + + // ── 7. Clean up────────────────────────────────────────────────────── + println!("\nUnloading model…"); + model.unload().await?; + println!("Done."); + + Ok(()) +} diff --git a/sdk_v2/rust/src/catalog.rs b/sdk_v2/rust/src/catalog.rs new file mode 100644 index 00000000..e1338fde --- /dev/null +++ b/sdk_v2/rust/src/catalog.rs @@ -0,0 +1,207 @@ +//! Model catalog – discovers, caches, and looks up available models. 
+ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; + +use crate::detail::core_interop::CoreInterop; +use crate::detail::ModelLoadManager; +use crate::error::{FoundryLocalError, Result}; +use crate::model::Model; +use crate::model_variant::ModelVariant; +use crate::types::ModelInfo; + +/// How long the catalog cache remains valid before a refresh. +const CACHE_TTL: Duration = Duration::from_secs(6 * 60 * 60); // 6 hours + +/// The model catalog provides discovery and lookup for all available models. +pub struct Catalog { + core: Arc, + model_load_manager: Arc, + name: String, + models_by_alias: Mutex>>, + variants_by_id: Mutex>>, + last_refresh: Mutex>, +} + +impl Catalog { + pub(crate) fn new( + core: Arc, + model_load_manager: Arc, + ) -> Result { + let name = core + .execute_command("get_catalog_name", None) + .unwrap_or_else(|_| "default".into()); + + let catalog = Self { + core, + model_load_manager, + name, + models_by_alias: Mutex::new(HashMap::new()), + variants_by_id: Mutex::new(HashMap::new()), + last_refresh: Mutex::new(None), + }; + + // Perform initial synchronous refresh during construction. + catalog.force_refresh_sync()?; + Ok(catalog) + } + + /// Catalog name as reported by the native core. + pub fn name(&self) -> &str { + &self.name + } + + /// Refresh the catalog from the native core if the cache has expired. + pub async fn update_models(&self) -> Result<()> { + { + let last = self.last_refresh.lock().map_err(|_| FoundryLocalError::Internal { + reason: "last_refresh mutex poisoned".into(), + })?; + if let Some(ts) = *last { + if ts.elapsed() < CACHE_TTL { + return Ok(()); + } + } + } + + self.force_refresh().await + } + + /// Return all known models keyed by alias. 
+ pub async fn get_models(&self) -> Result>> { + self.update_models().await?; + let map = self.models_by_alias.lock().map_err(|_| FoundryLocalError::Internal { + reason: "models_by_alias mutex poisoned".into(), + })?; + Ok(map.values().cloned().collect()) + } + + /// Look up a model by its alias. + pub async fn get_model(&self, alias: &str) -> Result> { + if alias.trim().is_empty() { + return Err(FoundryLocalError::Validation { + reason: "Model alias must be a non-empty string".into(), + }); + } + self.update_models().await?; + let map = self.models_by_alias.lock().map_err(|_| FoundryLocalError::Internal { + reason: "models_by_alias mutex poisoned".into(), + })?; + map.get(alias).cloned().ok_or_else(|| { + let available: Vec<&String> = map.keys().collect(); + FoundryLocalError::ModelOperation { + reason: format!("Unknown model alias '{alias}'. Available: {available:?}"), + } + }) + } + + /// Look up a specific model variant by its unique id. + pub async fn get_model_variant(&self, id: &str) -> Result> { + if id.trim().is_empty() { + return Err(FoundryLocalError::Validation { + reason: "Variant id must be a non-empty string".into(), + }); + } + self.update_models().await?; + let map = self.variants_by_id.lock().map_err(|_| FoundryLocalError::Internal { + reason: "variants_by_id mutex poisoned".into(), + })?; + map.get(id).cloned().ok_or_else(|| { + let available: Vec<&String> = map.keys().collect(); + FoundryLocalError::ModelOperation { + reason: format!("Unknown variant id '{id}'. Available: {available:?}"), + } + }) + } + + /// Return only the model variants that are currently cached on disk. + /// + /// The native core returns a list of variant IDs. This method resolves + /// them against the internal cache, matching the JS SDK behaviour. 
+ pub async fn get_cached_models(&self) -> Result>> { + self.update_models().await?; + let raw = self + .core + .execute_command_async("get_cached_models".into(), None) + .await?; + if raw.trim().is_empty() { + return Ok(Vec::new()); + } + let cached_ids: Vec = serde_json::from_str(&raw)?; + let id_map = self.variants_by_id.lock().map_err(|_| FoundryLocalError::Internal { + reason: "variants_by_id mutex poisoned".into(), + })?; + Ok(cached_ids + .iter() + .filter_map(|id| id_map.get(id).cloned()) + .collect()) + } + + /// Return identifiers of models that are currently loaded into memory. + pub async fn get_loaded_models(&self) -> Result> { + self.model_load_manager.list_loaded().await + } + + async fn force_refresh(&self) -> Result<()> { + let raw = self + .core + .execute_command_async("get_model_list".into(), None) + .await?; + self.apply_model_list(&raw) + } + + /// Synchronous refresh used only during construction (before a tokio + /// runtime may be available). + fn force_refresh_sync(&self) -> Result<()> { + let raw = self.core.execute_command("get_model_list", None)?; + self.apply_model_list(&raw) + } + + fn apply_model_list(&self, raw: &str) -> Result<()> { + let infos: Vec = if raw.trim().is_empty() { + Vec::new() + } else { + serde_json::from_str(raw)? 
+ }; + + let mut alias_map_build: HashMap = HashMap::new(); + let mut id_map: HashMap> = HashMap::new(); + + for info in infos { + let variant = ModelVariant::new( + info.clone(), + Arc::clone(&self.core), + Arc::clone(&self.model_load_manager), + ); + id_map.insert(info.id.clone(), Arc::new(variant.clone())); + + alias_map_build + .entry(info.alias.clone()) + .or_insert_with(|| { + Model::new( + info.alias.clone(), + Arc::clone(&self.core), + ) + }) + .add_variant(variant); + } + + let alias_map: HashMap> = alias_map_build + .into_iter() + .map(|(k, v)| (k, Arc::new(v))) + .collect(); + + *self.models_by_alias.lock().map_err(|_| FoundryLocalError::Internal { + reason: "models_by_alias mutex poisoned".into(), + })? = alias_map; + *self.variants_by_id.lock().map_err(|_| FoundryLocalError::Internal { + reason: "variants_by_id mutex poisoned".into(), + })? = id_map; + *self.last_refresh.lock().map_err(|_| FoundryLocalError::Internal { + reason: "last_refresh mutex poisoned".into(), + })? = Some(Instant::now()); + + Ok(()) + } +} diff --git a/sdk_v2/rust/src/configuration.rs b/sdk_v2/rust/src/configuration.rs new file mode 100644 index 00000000..8645b3e7 --- /dev/null +++ b/sdk_v2/rust/src/configuration.rs @@ -0,0 +1,144 @@ +use std::collections::HashMap; + +use crate::error::{FoundryLocalError, Result}; + +/// Log level for the Foundry Local service. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum LogLevel { + Trace, + Debug, + Info, + Warn, + Error, + Fatal, +} + +impl LogLevel { + /// Returns the string value expected by the native core library. + fn as_core_str(&self) -> &'static str { + match self { + Self::Trace => "Verbose", + Self::Debug => "Debug", + Self::Info => "Information", + Self::Warn => "Warning", + Self::Error => "Error", + Self::Fatal => "Fatal", + } + } +} + +/// User-facing configuration for initializing the Foundry Local SDK. 
+#[derive(Debug, Clone, Default)] +#[non_exhaustive] +pub struct FoundryLocalConfig { + pub app_name: String, + pub app_data_dir: Option, + pub model_cache_dir: Option, + pub logs_dir: Option, + pub log_level: Option, + pub web_service_urls: Option, + pub service_endpoint: Option, + pub library_path: Option, + pub additional_settings: Option>, +} + +impl FoundryLocalConfig { + /// Create a new configuration with the given application name. + /// + /// All other fields default to `None`. Use the struct update syntax to + /// override specific options: + /// + /// ```ignore + /// let config = FoundryLocalConfig { + /// log_level: Some(LogLevel::Debug), + /// ..FoundryLocalConfig::new("my_app") + /// }; + /// ``` + pub fn new(app_name: impl Into) -> Self { + Self { + app_name: app_name.into(), + ..Self::default() + } + } +} + +/// Internal configuration object that converts [`FoundryLocalConfig`] into the +/// parameter map expected by the native core library. +#[derive(Debug, Clone)] +pub(crate) struct Configuration { + pub params: HashMap, +} + +impl Configuration { + /// Build a [`Configuration`] from the user-facing [`FoundryLocalConfig`]. + /// + /// # Errors + /// + /// Returns [`FoundryLocalError::InvalidConfiguration`] when `app_name` is + /// empty or blank. 
+ pub fn new(config: FoundryLocalConfig) -> Result { + let app_name = config.app_name.trim().to_string(); + if app_name.is_empty() { + return Err(FoundryLocalError::InvalidConfiguration { + reason: "app_name must be set and non-empty".into(), + }); + } + + let mut params = HashMap::new(); + params.insert("AppName".into(), app_name); + + if let Some(v) = config.app_data_dir { + params.insert("AppDataDir".into(), v); + } + if let Some(v) = config.model_cache_dir { + params.insert("ModelCacheDir".into(), v); + } + if let Some(v) = config.logs_dir { + params.insert("LogsDir".into(), v); + } + if let Some(level) = config.log_level { + params.insert("LogLevel".into(), level.as_core_str().into()); + } + if let Some(v) = config.web_service_urls { + params.insert("WebServiceUrls".into(), v); + } + if let Some(v) = config.service_endpoint { + params.insert("WebServiceExternalUrl".into(), v); + } + if let Some(v) = config.library_path { + params.insert("FoundryLocalCorePath".into(), v); + } + if let Some(extra) = config.additional_settings { + for (k, v) in extra { + params.insert(k, v); + } + } + + Ok(Self { params }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn valid_config() { + let cfg = FoundryLocalConfig { + log_level: Some(LogLevel::Debug), + ..FoundryLocalConfig::new("TestApp") + }; + let c = Configuration::new(cfg).unwrap(); + assert_eq!(c.params["AppName"], "TestApp"); + assert_eq!(c.params["LogLevel"], "Debug"); + } + + #[test] + fn empty_app_name_fails() { + let cfg = FoundryLocalConfig { + app_name: " ".into(), + ..FoundryLocalConfig::default() + }; + assert!(Configuration::new(cfg).is_err()); + } +} diff --git a/sdk_v2/rust/src/detail/core_interop.rs b/sdk_v2/rust/src/detail/core_interop.rs new file mode 100644 index 00000000..1d53bab4 --- /dev/null +++ b/sdk_v2/rust/src/detail/core_interop.rs @@ -0,0 +1,536 @@ +//! FFI bridge to the `Microsoft.AI.Foundry.Local.Core` native library. +//! +//! 
Dynamically loads the shared library at runtime via [`libloading`] and +//! exposes two operations: +//! +//! * [`CoreInterop::execute_command`] – synchronous request/response. +//! * [`CoreInterop::execute_command_streaming`] – request with a streaming +//! callback that receives incremental chunks. + +use std::ffi::CString; +use std::path::{Path, PathBuf}; +use std::sync::Arc; + +use libloading::{Library, Symbol}; +use serde_json::Value; + +use crate::configuration::Configuration; +use crate::error::{FoundryLocalError, Result}; + +// ── FFI types ──────────────────────────────────────────────────────────────── + +/// Request buffer passed to the native library. +#[repr(C)] +struct RequestBuffer { + command: *const i8, + command_length: i32, + data: *const i8, + data_length: i32, +} + +/// Response buffer filled by the native library. +#[repr(C)] +struct ResponseBuffer { + data: *mut u8, + data_length: u32, + error: *mut u8, + error_length: u32, +} + +impl ResponseBuffer { + fn new() -> Self { + Self { + data: std::ptr::null_mut(), + data_length: 0, + error: std::ptr::null_mut(), + error_length: 0, + } + } +} + +/// Signature for `execute_command`. +type ExecuteCommandFn = unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer); + +/// Signature for the streaming callback invoked by the native library. +type CallbackFn = unsafe extern "C" fn(*const u8, i32, *mut std::ffi::c_void); + +/// Signature for `execute_command_with_callback`. 
+type ExecuteCommandWithCallbackFn = unsafe extern "C" fn( + *const RequestBuffer, + *mut ResponseBuffer, + CallbackFn, + *mut std::ffi::c_void, +); + +// ── Library name helpers ───────────────────────────────────────────────────── + +#[cfg(target_os = "windows")] +const LIB_EXTENSION: &str = "dll"; +#[cfg(target_os = "macos")] +const LIB_EXTENSION: &str = "dylib"; +#[cfg(target_os = "linux")] +const LIB_EXTENSION: &str = "so"; + +// ── Native buffer deallocation ──────────────────────────────────────────────── + +/// Free a buffer allocated by the native core library. +/// +/// The .NET native core allocates response buffers with +/// `Marshal.AllocHGlobal` which maps to `malloc` on Unix and +/// `LocalAlloc` (process heap) on Windows. The corresponding +/// free functions are `free` and `LocalFree` respectively. +/// +/// # Safety +/// +/// `ptr` must be null or a valid pointer previously allocated by the native +/// core library via the corresponding platform allocator. Calling this with +/// any other pointer is undefined behaviour. +unsafe fn free_native_buffer(ptr: *mut u8) { + if ptr.is_null() { + return; + } + #[cfg(unix)] + { + extern "C" { + fn free(ptr: *mut std::ffi::c_void); + } + // SAFETY: `ptr` was allocated by the native core via `malloc` on Unix. + free(ptr as *mut std::ffi::c_void); + } + #[cfg(windows)] + { + extern "system" { + fn LocalFree(hMem: *mut std::ffi::c_void) -> *mut std::ffi::c_void; + } + // SAFETY: `ptr` was allocated by the native core via `LocalAlloc` + // (Marshal.AllocHGlobal) on Windows. + LocalFree(ptr as *mut std::ffi::c_void); + } +} + +// ── Trampoline for streaming callback ──────────────────────────────────────── + +/// C-ABI trampoline that forwards chunks from the native library into a Rust +/// closure stored behind `user_data`. +/// +/// Wrapped in [`std::panic::catch_unwind`] so that a panic inside the Rust +/// closure cannot unwind across the FFI boundary (which is undefined behaviour). 
+/// +/// # Safety +/// +/// * `data` must be a valid pointer to `length` bytes of UTF-8 (or at least +/// valid memory) allocated by the native core, valid for the duration of +/// this call. +/// * `user_data` must point to a live `Box>` that was +/// created by [`CoreInterop::execute_command_streaming`] and has not been +/// dropped. +unsafe extern "C" fn streaming_trampoline( + data: *const u8, + length: i32, + user_data: *mut std::ffi::c_void, +) { + if data.is_null() || length <= 0 { + return; + } + // catch_unwind prevents UB if the closure panics across the FFI boundary. + let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + // SAFETY: `user_data` is a pointer to `Box` kept alive + // by the caller of `execute_command_with_callback` for the duration of + // the native call. + let closure = &mut *(user_data as *mut Box); + // SAFETY: `data` is valid for `length` bytes as guaranteed by the native + // core's callback contract. + let slice = std::slice::from_raw_parts(data, length as usize); + if let Ok(chunk) = std::str::from_utf8(slice) { + closure(chunk); + } + })); +} + +// ── CoreInterop ────────────────────────────────────────────────────────────── + +/// Handle to the loaded native core library. +/// +/// This type is `Send + Sync` because the underlying native library is +/// expected to be thread-safe for distinct request/response pairs. 
+pub(crate) struct CoreInterop { + _library: Library, + #[cfg(target_os = "windows")] + _dependency_libs: Vec, + execute_command: unsafe extern "C" fn(*const RequestBuffer, *mut ResponseBuffer), + execute_command_with_callback: unsafe extern "C" fn( + *const RequestBuffer, + *mut ResponseBuffer, + CallbackFn, + *mut std::ffi::c_void, + ), +} + +impl std::fmt::Debug for CoreInterop { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CoreInterop").finish_non_exhaustive() + } +} + +impl CoreInterop { + /// Load the native core library using the provided configuration to locate + /// it on disk. + /// + /// Discovery order: + /// 1. `FoundryLocalCorePath` key in `config.params`. + /// 2. Sibling directory of the current executable. + pub fn new(config: &mut Configuration) -> Result { + let lib_path = Self::resolve_library_path(config)?; + + // Auto-detect WinAppSDK Bootstrap DLL next to the core library. + // If present, tell the native core to run the bootstrapper during + // initialisation — this is required for WinML execution providers. + #[cfg(target_os = "windows")] + if !config.params.contains_key("Bootstrap") { + if let Some(dir) = lib_path.parent() { + if dir.join("Microsoft.WindowsAppRuntime.Bootstrap.dll").exists() { + config.params.insert("Bootstrap".into(), "true".into()); + } + } + } + + #[cfg(target_os = "windows")] + let _dependency_libs = Self::load_windows_dependencies(&lib_path)?; + + // SAFETY: `lib_path` has been verified to exist on disk. Loading a + // shared library is inherently unsafe (it executes foreign code), but + // the path is resolved from trusted configuration sources. + let library = unsafe { + Library::new(&lib_path).map_err(|e| FoundryLocalError::LibraryLoad { + reason: format!( + "Failed to load native library at {}: {e}", + lib_path.display() + ), + })? 
+ }; + + // SAFETY: We trust the loaded library to export these symbols with the + // correct C-ABI signatures as defined by the Foundry Local native core. + let execute_command: ExecuteCommandFn = unsafe { + let sym: Symbol = + library + .get(b"execute_command\0") + .map_err(|e| FoundryLocalError::LibraryLoad { + reason: format!("Symbol 'execute_command' not found: {e}"), + })?; + *sym + }; + + // SAFETY: Same as above — symbol must match `ExecuteCommandWithCallbackFn`. + let execute_command_with_callback: ExecuteCommandWithCallbackFn = unsafe { + let sym: Symbol = library + .get(b"execute_command_with_callback\0") + .map_err(|e| FoundryLocalError::LibraryLoad { + reason: format!("Symbol 'execute_command_with_callback' not found: {e}"), + })?; + *sym + }; + + Ok(Self { + _library: library, + #[cfg(target_os = "windows")] + _dependency_libs, + execute_command, + execute_command_with_callback, + }) + } + + /// Execute a synchronous command against the native core. + /// + /// `command` is the operation name (e.g. `"initialize"`, `"load_model"`). + /// `params` is an optional JSON value that will be serialised and sent as + /// the data payload. + pub fn execute_command(&self, command: &str, params: Option<&Value>) -> Result { + let cmd = CString::new(command).map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Invalid command string: {e}"), + })?; + + let data_json = match params { + Some(v) => serde_json::to_string(v)?, + None => String::new(), + }; + let data_cstr = + CString::new(data_json.as_str()).map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Invalid data string: {e}"), + })?; + + let request = RequestBuffer { + command: cmd.as_ptr(), + command_length: cmd.as_bytes().len() as i32, + data: data_cstr.as_ptr(), + data_length: data_cstr.as_bytes().len() as i32, + }; + + let mut response = ResponseBuffer::new(); + + // SAFETY: `request` fields point into `cmd` and `data_cstr` which are + // alive for the duration of this call. 
The native function writes into + // `response` using its documented C ABI. + unsafe { + (self.execute_command)(&request, &mut response); + } + + Self::process_response(response) + } + + /// Execute a command that streams results back via `callback`. + /// + /// Each chunk delivered by the native library is decoded as UTF-8 and + /// forwarded to `callback`. After the native call returns, any error in + /// the response buffer is raised. + pub fn execute_command_streaming( + &self, + command: &str, + params: Option<&Value>, + mut callback: F, + ) -> Result + where + F: FnMut(&str), + { + let cmd = CString::new(command).map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Invalid command string: {e}"), + })?; + + let data_json = match params { + Some(v) => serde_json::to_string(v)?, + None => String::new(), + }; + let data_cstr = + CString::new(data_json.as_str()).map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("Invalid data string: {e}"), + })?; + + let request = RequestBuffer { + command: cmd.as_ptr(), + command_length: cmd.as_bytes().len() as i32, + data: data_cstr.as_ptr(), + data_length: data_cstr.as_bytes().len() as i32, + }; + + let mut response = ResponseBuffer::new(); + + // Box the closure so we can pass a stable pointer through FFI. + let mut boxed: Box = Box::new(|chunk: &str| callback(chunk)); + let user_data = &mut boxed as *mut Box as *mut std::ffi::c_void; + + // SAFETY: `request` fields point into `cmd` and `data_cstr` which are + // alive for the duration of this call. `user_data` points to `boxed` + // which lives on this stack frame and outlives the native call. + // `streaming_trampoline` will only cast `user_data` back to the same + // `Box` type. + unsafe { + (self.execute_command_with_callback)( + &request, + &mut response, + streaming_trampoline, + user_data, + ); + } + + Self::process_response(response) + } + + /// Async version of [`Self::execute_command`]. 
+ /// + /// Runs the blocking FFI call on a dedicated thread via + /// [`tokio::task::spawn_blocking`] so the async runtime is never blocked. + pub async fn execute_command_async( + self: &Arc, + command: String, + params: Option, + ) -> Result { + let this = Arc::clone(self); + tokio::task::spawn_blocking(move || this.execute_command(&command, params.as_ref())) + .await + .map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("task join error: {e}"), + })? + } + + /// Async version of [`Self::execute_command_streaming`]. + /// + /// The `callback` is invoked on the blocking thread – it must be + /// [`Send`] + `'static`. + pub async fn execute_command_streaming_async( + self: &Arc, + command: String, + params: Option, + callback: F, + ) -> Result + where + F: FnMut(&str) + Send + 'static, + { + let this = Arc::clone(self); + tokio::task::spawn_blocking(move || { + this.execute_command_streaming(&command, params.as_ref(), callback) + }) + .await + .map_err(|e| FoundryLocalError::CommandExecution { + reason: format!("task join error: {e}"), + })? + } + + /// Async streaming variant that bridges the FFI callback into a + /// [`tokio::sync::mpsc`] channel. + /// + /// Returns a `Receiver>` that yields each chunk as it + /// arrives. After all chunks have been delivered the final result from + /// the native core response buffer is sent through the same channel — + /// if the native core reported an error it will appear as an `Err` item. + /// The receiver can be wrapped with + /// [`tokio_stream::wrappers::ReceiverStream`] to get a `Stream`. 
+    pub async fn execute_command_streaming_channel(
+        self: &Arc<Self>,
+        command: String,
+        params: Option<serde_json::Value>,
+    ) -> Result<tokio::sync::mpsc::UnboundedReceiver<Result<String>>> {
+        let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<Result<String>>();
+        let this = Arc::clone(self);
+
+        tokio::task::spawn_blocking(move || {
+            let tx_chunk = tx.clone();
+            let result = this.execute_command_streaming(
+                &command,
+                params.as_ref(),
+                move |chunk: &str| {
+                    let _ = tx_chunk.send(Ok(chunk.to_owned()));
+                },
+            );
+
+            // Surface any error from the native core response buffer through
+            // the channel so it cannot be silently swallowed.
+            if let Err(e) = result {
+                let _ = tx.send(Err(e));
+            }
+        });
+
+        Ok(rx)
+    }
+
+    /// Read a native response buffer field as a Rust `String`.
+    ///
+    /// # Safety
+    ///
+    /// `ptr` must be null **or** a valid pointer to at least `len` bytes of
+    /// memory allocated by the native core. The memory must remain valid for
+    /// the duration of this call.
+    unsafe fn read_native_buffer(ptr: *mut u8, len: u32) -> Option<String> {
+        if ptr.is_null() || len == 0 {
+            return None;
+        }
+        // SAFETY: caller guarantees `ptr` is valid for `len` bytes.
+        let slice = std::slice::from_raw_parts(ptr, len as usize);
+        Some(String::from_utf8_lossy(slice).into_owned())
+    }
+
+    /// Read the response buffer, free the native memory, and return the data
+    /// string or raise an error.
+    ///
+    /// Takes the buffer by value so it can only be consumed once.
+    fn process_response(response: ResponseBuffer) -> Result<String> {
+        // SAFETY: response fields are either null or valid native-allocated
+        // pointers filled by the preceding FFI call.
+        let error_str = unsafe { Self::read_native_buffer(response.error, response.error_length) };
+        let data_str = unsafe { Self::read_native_buffer(response.data, response.data_length) };
+
+        // SAFETY: Free the heap-allocated response buffers (matches JS
+        // koffi.free() and C# Marshal.FreeHGlobal() behaviour). Each pointer
+        // is either null (handled inside free_native_buffer) or was allocated
+        // by the native core's platform allocator.
+        unsafe {
+            free_native_buffer(response.data);
+            free_native_buffer(response.error);
+        }
+
+        // Return error or data.
+        if let Some(err) = error_str {
+            Err(FoundryLocalError::CommandExecution { reason: err })
+        } else {
+            Ok(data_str.unwrap_or_default())
+        }
+    }
+
+    /// Resolve the path to the native core shared library.
+    fn resolve_library_path(config: &Configuration) -> Result<PathBuf> {
+        let lib_name = format!("Microsoft.AI.Foundry.Local.Core.{LIB_EXTENSION}");
+
+        // 1. Explicit path from configuration.
+        if let Some(dir) = config.params.get("FoundryLocalCorePath") {
+            let p = Path::new(dir).join(&lib_name);
+            if p.exists() {
+                return Ok(p);
+            }
+            // If the config value is the full path to the library itself.
+            let p = Path::new(dir);
+            if p.exists() && p.is_file() {
+                return Ok(p.to_path_buf());
+            }
+        }
+
+        // 2. Compile-time path set by build.rs (points at the OUT_DIR where
+        // native NuGet packages are extracted during `cargo build`).
+        if let Some(dir) = option_env!("FOUNDRY_NATIVE_DIR") {
+            let p = Path::new(dir).join(&lib_name);
+            if p.exists() {
+                return Ok(p);
+            }
+        }
+
+        // 3. Next to the running executable (default search path).
+        if let Ok(exe) = std::env::current_exe() {
+            if let Some(dir) = exe.parent() {
+                let p = dir.join(&lib_name);
+                if p.exists() {
+                    return Ok(p);
+                }
+            }
+        }
+
+        Err(FoundryLocalError::LibraryLoad {
+            reason: format!(
+                "Could not locate native library '{lib_name}'. \
+                 Set the FoundryLocalCorePath config option."
+            ),
+        })
+    }
+
+    /// On Windows, pre-load runtime dependencies so the core library can
+    /// resolve them.
+    #[cfg(target_os = "windows")]
+    fn load_windows_dependencies(core_lib_path: &Path) -> Result<Vec<Library>> {
+        let dir = core_lib_path.parent().unwrap_or_else(|| Path::new("."));
+
+        let mut libs = Vec::new();
+
+        // Load WinML bootstrap if present.
+ let bootstrap = dir.join("Microsoft.WindowsAppRuntime.Bootstrap.dll"); + if bootstrap.exists() { + // SAFETY: Pre-loading a known dependency DLL from the same trusted + // directory as the core library. + if let Ok(lib) = unsafe { Library::new(&bootstrap) } { + libs.push(lib); + } + } + + for dep in &["onnxruntime.dll", "onnxruntime-genai.dll"] { + let dep_path = dir.join(dep); + if dep_path.exists() { + // SAFETY: Pre-loading a known dependency DLL from the same + // trusted directory as the core library. + let lib = unsafe { + Library::new(&dep_path).map_err(|e| FoundryLocalError::LibraryLoad { + reason: format!("Failed to load dependency {dep}: {e}"), + })? + }; + libs.push(lib); + } + } + + Ok(libs) + } +} diff --git a/sdk_v2/rust/src/detail/mod.rs b/sdk_v2/rust/src/detail/mod.rs new file mode 100644 index 00000000..c7f2fd32 --- /dev/null +++ b/sdk_v2/rust/src/detail/mod.rs @@ -0,0 +1,4 @@ +pub(crate) mod core_interop; +mod model_load_manager; + +pub use self::model_load_manager::ModelLoadManager; diff --git a/sdk_v2/rust/src/detail/model_load_manager.rs b/sdk_v2/rust/src/detail/model_load_manager.rs new file mode 100644 index 00000000..f6f05fd1 --- /dev/null +++ b/sdk_v2/rust/src/detail/model_load_manager.rs @@ -0,0 +1,77 @@ +//! Manages model loading and unloading. +//! +//! When an external service URL is configured the manager delegates to HTTP +//! endpoints (`models/load/{id}`, `models/unload/{id}`, `models/loaded`). +//! Otherwise it falls through to the native core library via [`CoreInterop`]. + +use std::sync::Arc; + +use serde_json::json; + +use crate::detail::core_interop::CoreInterop; +use crate::error::Result; + +/// Manages the lifecycle of loaded models. 
+#[derive(Debug)]
+pub struct ModelLoadManager {
+    core: Arc<CoreInterop>,
+    external_service_url: Option<String>,
+    client: reqwest::Client,
+}
+
+impl ModelLoadManager {
+    pub(crate) fn new(core: Arc<CoreInterop>, external_service_url: Option<String>) -> Self {
+        Self {
+            core,
+            external_service_url,
+            client: reqwest::Client::new(),
+        }
+    }
+
+    /// Load a model by its identifier.
+    pub async fn load(&self, model_id: &str) -> Result<()> {
+        if let Some(base_url) = &self.external_service_url {
+            self.http_get(&format!("{base_url}/models/load/{model_id}")).await?;
+            return Ok(());
+        }
+        let params = json!({ "Params": { "Model": model_id } });
+        self.core
+            .execute_command_async("load_model".into(), Some(params))
+            .await?;
+        Ok(())
+    }
+
+    /// Unload a previously loaded model.
+    pub async fn unload(&self, model_id: &str) -> Result<String> {
+        if let Some(base_url) = &self.external_service_url {
+            return self.http_get(&format!("{base_url}/models/unload/{model_id}")).await;
+        }
+        let params = json!({ "Params": { "Model": model_id } });
+        self.core
+            .execute_command_async("unload_model".into(), Some(params))
+            .await
+    }
+
+    /// Return the list of currently loaded model identifiers.
+    pub async fn list_loaded(&self) -> Result<Vec<String>> {
+        let raw = if let Some(base_url) = &self.external_service_url {
+            self.http_get(&format!("{base_url}/models/loaded")).await?
+        } else {
+            self.core
+                .execute_command_async("list_loaded_models".into(), None)
+                .await?
+        };
+
+        if raw.trim().is_empty() {
+            return Ok(Vec::new());
+        }
+
+        let ids: Vec<String> = serde_json::from_str(&raw)?;
+        Ok(ids)
+    }
+
+    async fn http_get(&self, url: &str) -> Result<String> {
+        let body = self.client.get(url).send().await?.text().await?;
+        Ok(body)
+    }
+}
diff --git a/sdk_v2/rust/src/error.rs b/sdk_v2/rust/src/error.rs
new file mode 100644
index 00000000..c99dbfbf
--- /dev/null
+++ b/sdk_v2/rust/src/error.rs
@@ -0,0 +1,36 @@
+use thiserror::Error;
+
+/// Errors that can occur when using the Foundry Local SDK.
+#[derive(Debug, Error)]
+pub enum FoundryLocalError {
+    /// The native core library could not be loaded.
+    #[error("library load error: {reason}")]
+    LibraryLoad { reason: String },
+    /// A command executed against the native core returned an error.
+    #[error("command execution error: {reason}")]
+    CommandExecution { reason: String },
+    /// The provided configuration is invalid.
+    #[error("invalid configuration: {reason}")]
+    InvalidConfiguration { reason: String },
+    /// A model operation failed (load, unload, download, etc.).
+    #[error("model operation error: {reason}")]
+    ModelOperation { reason: String },
+    /// An HTTP request to the external service failed.
+    #[error("HTTP request error: {0}")]
+    HttpRequest(#[from] reqwest::Error),
+    /// Serialization or deserialization of JSON data failed.
+    #[error("serialization error: {0}")]
+    Serialization(#[from] serde_json::Error),
+    /// A validation check on user-supplied input failed.
+    #[error("validation error: {reason}")]
+    Validation { reason: String },
+    /// An I/O error occurred.
+    #[error("I/O error: {0}")]
+    Io(#[from] std::io::Error),
+    /// An internal SDK error (e.g. poisoned lock).
+    #[error("internal error: {reason}")]
+    Internal { reason: String },
+}
+
+/// Convenience alias used throughout the SDK.
+pub type Result<T> = std::result::Result<T, FoundryLocalError>;
diff --git a/sdk_v2/rust/src/foundry_local_manager.rs b/sdk_v2/rust/src/foundry_local_manager.rs
new file mode 100644
index 00000000..a87733b5
--- /dev/null
+++ b/sdk_v2/rust/src/foundry_local_manager.rs
@@ -0,0 +1,120 @@
+//! Top-level entry point for the Foundry Local SDK.
+//!
+//! [`FoundryLocalManager`] is a singleton that initialises the native core
+//! library, provides access to the model [`Catalog`], and can start / stop
+//! the local web service.
+ +use std::sync::{Arc, Mutex, OnceLock, Once}; + +use serde_json::json; + +use crate::catalog::Catalog; +use crate::configuration::{Configuration, FoundryLocalConfig}; +use crate::detail::core_interop::CoreInterop; +use crate::detail::ModelLoadManager; +use crate::error::{FoundryLocalError, Result}; + +/// Global singleton holder. +static INSTANCE: OnceLock> = OnceLock::new(); +static INIT_ONCE: Once = Once::new(); + +/// Primary entry point for interacting with Foundry Local. +/// +/// Created once via [`FoundryLocalManager::create`]; subsequent calls return +/// the existing instance. +pub struct FoundryLocalManager { + core: Arc, + catalog: Catalog, + urls: Mutex>, +} + +impl FoundryLocalManager { + /// Initialise the SDK. + /// + /// The first call creates the singleton, loads the native library, runs + /// the `initialize` command, and builds the model catalog. Subsequent + /// calls return a reference to the same instance (the provided config is + /// ignored after the first call). + pub fn create(config: FoundryLocalConfig) -> Result<&'static Self> { + // Use `Once` + `OnceLock` to ensure initialisation runs at most once, + // eliminating the TOCTOU race between `get()` and `set()`. + INIT_ONCE.call_once(|| { + let result = (|| -> Result { + let mut internal_config = Configuration::new(config)?; + let core = Arc::new(CoreInterop::new(&mut internal_config)?); + + // Send the configuration map to the native core. 
+ let init_params = json!({ "Params": internal_config.params }); + core.execute_command("initialize", Some(&init_params))?; + + let service_endpoint = internal_config.params.get("WebServiceExternalUrl").cloned(); + + let model_load_manager = + Arc::new(ModelLoadManager::new(Arc::clone(&core), service_endpoint)); + + let catalog = Catalog::new(Arc::clone(&core), Arc::clone(&model_load_manager))?; + + Ok(FoundryLocalManager { + core, + catalog, + urls: Mutex::new(Vec::new()), + }) + })(); + + let _ = INSTANCE.set(result.map_err(|e| e.to_string())); + }); + + match INSTANCE.get() { + Some(Ok(manager)) => Ok(manager), + Some(Err(msg)) => Err(FoundryLocalError::CommandExecution { + reason: format!("SDK initialization failed: {msg}"), + }), + None => Err(FoundryLocalError::CommandExecution { + reason: "SDK initialization not completed".into(), + }), + } + } + + /// Access the model catalog. + pub fn catalog(&self) -> &Catalog { + &self.catalog + } + + /// URLs that the local web service is listening on. + /// + /// Empty until [`Self::start_web_service`] has been called. + pub fn urls(&self) -> Result> { + let lock = self.urls.lock().map_err(|_| FoundryLocalError::Internal { + reason: "Failed to acquire urls lock".into(), + })?; + Ok(lock.clone()) + } + + /// Start the local web service and return the listening URLs. + pub async fn start_web_service(&self) -> Result> { + let raw = self + .core + .execute_command_async("start_service".into(), None) + .await?; + let parsed: Vec = if raw.trim().is_empty() { + Vec::new() + } else { + serde_json::from_str(&raw)? + }; + *self.urls.lock().map_err(|_| FoundryLocalError::Internal { + reason: "Failed to acquire urls lock".into(), + })? = parsed.clone(); + Ok(parsed) + } + + /// Stop the local web service. 
+ pub async fn stop_web_service(&self) -> Result<()> { + self.core + .execute_command_async("stop_service".into(), None) + .await?; + self.urls.lock().map_err(|_| FoundryLocalError::Internal { + reason: "Failed to acquire urls lock".into(), + })?.clear(); + Ok(()) + } +} diff --git a/sdk_v2/rust/src/lib.rs b/sdk_v2/rust/src/lib.rs new file mode 100644 index 00000000..f3564145 --- /dev/null +++ b/sdk_v2/rust/src/lib.rs @@ -0,0 +1,46 @@ +//! Foundry Local Rust SDK +//! +//! Local AI model inference powered by the Foundry Local Core engine. + +mod catalog; +mod configuration; +mod error; +mod foundry_local_manager; +mod model; +mod model_variant; +mod types; + +pub(crate) mod detail; +pub mod openai; + +pub use self::catalog::Catalog; +pub use self::configuration::{FoundryLocalConfig, LogLevel}; +pub use self::detail::ModelLoadManager; +pub use self::error::FoundryLocalError; +pub use self::foundry_local_manager::FoundryLocalManager; +pub use self::model::Model; +pub use self::model_variant::ModelVariant; +pub use self::types::{ + ChatResponseFormat, ChatToolChoice, DeviceType, ModelInfo, ModelSettings, Parameter, + PromptTemplate, Runtime, +}; + +// Re-export OpenAI request types so callers can construct typed messages. +pub use async_openai::types::chat::{ + ChatCompletionNamedToolChoice, ChatCompletionRequestAssistantMessage, + ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, + ChatCompletionRequestToolMessage, ChatCompletionRequestUserMessage, + ChatCompletionToolChoiceOption, ChatCompletionTools, FunctionObject, +}; + +// Re-export OpenAI response types for convenience. 
+pub use crate::openai::{ + AudioTranscriptionResponse, AudioTranscriptionStream, ChatCompletionStream, +}; +pub use async_openai::types::chat::{ + ChatChoice, ChatChoiceStream, ChatCompletionMessageToolCall, + ChatCompletionMessageToolCallChunk, ChatCompletionMessageToolCalls, + ChatCompletionResponseMessage, ChatCompletionStreamResponseDelta, CompletionUsage, + CreateChatCompletionResponse, CreateChatCompletionStreamResponse, FinishReason, FunctionCall, + FunctionCallStream, +}; diff --git a/sdk_v2/rust/src/model.rs b/sdk_v2/rust/src/model.rs new file mode 100644 index 00000000..4313c5c9 --- /dev/null +++ b/sdk_v2/rust/src/model.rs @@ -0,0 +1,134 @@ +//! High-level model abstraction that wraps one or more [`ModelVariant`]s +//! sharing the same alias. + +use std::path::PathBuf; +use std::sync::Arc; + +use crate::detail::core_interop::CoreInterop; +use crate::error::{FoundryLocalError, Result}; +use crate::model_variant::ModelVariant; +use crate::openai::AudioClient; +use crate::openai::ChatClient; + +/// A model groups one or more [`ModelVariant`]s that share the same alias. +/// +/// By default the variant that is already cached locally is selected. You +/// can override the selection with [`Model::select_variant`]. +#[derive(Debug, Clone)] +pub struct Model { + alias: String, + core: Arc, + variants: Vec, + selected_index: usize, +} + +impl Model { + pub(crate) fn new( + alias: String, + core: Arc, + ) -> Self { + Self { + alias, + core, + variants: Vec::new(), + selected_index: 0, + } + } + + /// Add a variant. If the new variant is cached and the current selection + /// is not, the new variant becomes the selected one. + pub(crate) fn add_variant(&mut self, variant: ModelVariant) { + self.variants.push(variant); + let new_idx = self.variants.len() - 1; + + // Prefer a cached variant over a non-cached one. 
+ if self.variants[new_idx].info().cached && !self.variants[self.selected_index].info().cached + { + self.selected_index = new_idx; + } + } + + /// Select a variant by its unique id. + pub fn select_variant(&mut self, id: &str) -> Result<()> { + if let Some(pos) = self.variants.iter().position(|v| v.id() == id) { + self.selected_index = pos; + return Ok(()); + } + let available: Vec = self.variants.iter().map(|v| v.id().to_string()).collect(); + Err(FoundryLocalError::ModelOperation { + reason: format!( + "Variant '{id}' not found for model '{}'. Available: {available:?}", + self.alias + ), + }) + } + + /// Returns a reference to the currently selected variant. + pub fn selected_variant(&self) -> &ModelVariant { + &self.variants[self.selected_index] + } + + /// Returns all variants that belong to this model. + pub fn variants(&self) -> &[ModelVariant] { + &self.variants + } + + /// Alias shared by all variants in this model. + pub fn alias(&self) -> &str { + &self.alias + } + + /// Unique identifier of the selected variant. + pub fn id(&self) -> &str { + self.selected_variant().id() + } + + /// Whether the selected variant is cached on disk. + pub async fn is_cached(&self) -> Result { + self.selected_variant().is_cached().await + } + + /// Whether the selected variant is loaded into memory. + pub async fn is_loaded(&self) -> Result { + self.selected_variant().is_loaded().await + } + + /// Download the selected variant. If `progress` is provided, it receives + /// human-readable progress strings as they arrive from the native core. + pub async fn download(&self, progress: Option) -> Result<()> + where + F: FnMut(&str) + Send + 'static, + { + self.selected_variant().download(progress).await + } + + /// Return the local file-system path of the selected variant. + pub async fn path(&self) -> Result { + self.selected_variant().path().await + } + + /// Load the selected variant into memory. 
+ pub async fn load(&self) -> Result<()> { + self.selected_variant().load().await + } + + /// Unload the selected variant from memory. + pub async fn unload(&self) -> Result { + self.selected_variant().unload().await + } + + /// Remove the selected variant from the local cache. + pub async fn remove_from_cache(&self) -> Result { + self.selected_variant().remove_from_cache().await + } + + /// Create a [`ChatClient`] bound to the selected variant. + pub fn create_chat_client(&self) -> ChatClient { + ChatClient::new(self.id().to_string(), Arc::clone(&self.core)) + } + + /// Create an [`AudioClient`] bound to the selected variant. + pub fn create_audio_client(&self) -> AudioClient { + AudioClient::new(self.id().to_string(), Arc::clone(&self.core)) + } +} diff --git a/sdk_v2/rust/src/model_variant.rs b/sdk_v2/rust/src/model_variant.rs new file mode 100644 index 00000000..55bc2e5f --- /dev/null +++ b/sdk_v2/rust/src/model_variant.rs @@ -0,0 +1,131 @@ +//! A single model variant backed by [`ModelInfo`]. + +use std::path::PathBuf; +use std::sync::Arc; + +use serde_json::json; + +use crate::detail::core_interop::CoreInterop; +use crate::detail::ModelLoadManager; +use crate::error::Result; +use crate::openai::AudioClient; +use crate::openai::ChatClient; +use crate::types::ModelInfo; + +/// Represents one specific variant of a model (a particular id within an alias +/// group). +#[derive(Debug, Clone)] +pub struct ModelVariant { + info: ModelInfo, + core: Arc, + model_load_manager: Arc, +} + +impl ModelVariant { + pub(crate) fn new( + info: ModelInfo, + core: Arc, + model_load_manager: Arc, + ) -> Self { + Self { + info, + core, + model_load_manager, + } + } + + /// The full [`ModelInfo`] metadata for this variant. + pub fn info(&self) -> &ModelInfo { + &self.info + } + + /// Unique identifier. + pub fn id(&self) -> &str { + &self.info.id + } + + /// Alias shared with sibling variants. 
+ pub fn alias(&self) -> &str { + &self.info.alias + } + + /// Check whether the variant is cached locally by querying the native + /// core. + pub async fn is_cached(&self) -> Result { + let raw = self + .core + .execute_command_async("get_cached_models".into(), None) + .await?; + if raw.trim().is_empty() { + return Ok(false); + } + let cached_ids: Vec = serde_json::from_str(&raw)?; + Ok(cached_ids.iter().any(|id| id == &self.info.id)) + } + + /// Check whether the variant is currently loaded into memory. + pub async fn is_loaded(&self) -> Result { + let loaded = self.model_load_manager.list_loaded().await?; + Ok(loaded.iter().any(|id| id == &self.info.id)) + } + + /// Download the model variant. If `progress` is provided, it receives + /// human-readable progress strings as the download proceeds. + pub async fn download(&self, progress: Option) -> Result<()> + where + F: FnMut(&str) + Send + 'static, + { + let params = json!({ "Params": { "Model": self.info.id } }); + match progress { + Some(cb) => { + self.core + .execute_command_streaming_async("download_model".into(), Some(params), cb) + .await?; + } + None => { + self.core + .execute_command_async("download_model".into(), Some(params)) + .await?; + } + } + Ok(()) + } + + /// Return the local file-system path where this variant is stored. + pub async fn path(&self) -> Result { + let params = json!({ "Params": { "Model": self.info.id } }); + let path_str = self + .core + .execute_command_async("get_model_path".into(), Some(params)) + .await?; + Ok(PathBuf::from(path_str)) + } + + /// Load the variant into memory. + pub async fn load(&self) -> Result<()> { + self.model_load_manager.load(&self.info.id).await + } + + /// Unload the variant from memory. + pub async fn unload(&self) -> Result { + self.model_load_manager.unload(&self.info.id).await + } + + /// Remove the variant from the local cache. 
+ pub async fn remove_from_cache(&self) -> Result { + let params = json!({ "Params": { "Model": self.info.id } }); + self.core + .execute_command_async("remove_cached_model".into(), Some(params)) + .await + } + + /// Create a [`ChatClient`] bound to this variant. + pub fn create_chat_client(&self) -> ChatClient { + ChatClient::new(self.info.id.clone(), Arc::clone(&self.core)) + } + + /// Create an [`AudioClient`] bound to this variant. + pub fn create_audio_client(&self) -> AudioClient { + AudioClient::new(self.info.id.clone(), Arc::clone(&self.core)) + } +} diff --git a/sdk_v2/rust/src/openai/audio_client.rs b/sdk_v2/rust/src/openai/audio_client.rs new file mode 100644 index 00000000..b9577f81 --- /dev/null +++ b/sdk_v2/rust/src/openai/audio_client.rs @@ -0,0 +1,190 @@ +//! OpenAI-compatible audio transcription client. + +use std::path::Path; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use serde_json::{json, Value}; + +use crate::detail::core_interop::CoreInterop; +use crate::error::{FoundryLocalError, Result}; + +/// OpenAI-compatible audio transcription response. +#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)] +pub struct AudioTranscriptionResponse { + /// The transcribed text. + pub text: String, + /// The language of the input audio (if detected). + #[serde(skip_serializing_if = "Option::is_none")] + pub language: Option, + /// Duration of the input audio in seconds (if available). + #[serde(skip_serializing_if = "Option::is_none")] + pub duration: Option, + /// Segments of the transcription (if available). + #[serde(skip_serializing_if = "Option::is_none")] + pub segments: Option>, + /// Words with timestamps (if available). + #[serde(skip_serializing_if = "Option::is_none")] + pub words: Option>, +} + +/// Tuning knobs for audio transcription requests. 
+/// +/// Use the chainable setter methods to configure, e.g.: +/// +/// ```ignore +/// let mut client = model.create_audio_client(); +/// client.language("en").temperature(0.2); +/// ``` +#[derive(Debug, Clone, Default)] +pub struct AudioClientSettings { + language: Option, + temperature: Option, +} + +impl AudioClientSettings { + /// Serialise settings into the JSON fragment expected by the native core. + fn serialize(&self, model_id: &str, file_name: &str) -> Value { + let mut map = serde_json::Map::new(); + + map.insert("Model".into(), json!(model_id)); + map.insert("FileName".into(), json!(file_name)); + + if let Some(ref lang) = self.language { + map.insert("Language".into(), json!(lang)); + } + if let Some(temp) = self.temperature { + map.insert("Temperature".into(), json!(temp)); + } + + Value::Object(map) + } +} + +/// A stream of [`AudioTranscriptionResponse`] chunks. +/// +/// Returned by [`AudioClient::transcribe_streaming`]. +pub struct AudioTranscriptionStream { + rx: tokio::sync::mpsc::UnboundedReceiver>, +} + +impl futures_core::Stream for AudioTranscriptionStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.rx.poll_recv(cx) { + Poll::Ready(Some(Ok(chunk))) => { + if chunk.is_empty() { + cx.waker().wake_by_ref(); + Poll::Pending + } else { + let parsed = serde_json::from_str::(&chunk) + .map_err(FoundryLocalError::from); + Poll::Ready(Some(parsed)) + } + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +/// Client for OpenAI-compatible audio transcription backed by a local model. +pub struct AudioClient { + model_id: String, + core: Arc, + settings: AudioClientSettings, +} + +impl AudioClient { + pub(crate) fn new(model_id: String, core: Arc) -> Self { + Self { + model_id, + core, + settings: AudioClientSettings::default(), + } + } + + /// Set the language hint for transcription. 
+ pub fn language(&mut self, lang: impl Into) -> &mut Self { + self.settings.language = Some(lang.into()); + self + } + + /// Set the sampling temperature. + pub fn temperature(&mut self, v: f64) -> &mut Self { + self.settings.temperature = Some(v); + self + } + + /// Transcribe an audio file. + pub async fn transcribe( + &self, + audio_file_path: impl AsRef, + ) -> Result { + let path_str = + audio_file_path + .as_ref() + .to_str() + .ok_or_else(|| FoundryLocalError::Validation { + reason: "audio file path is not valid UTF-8".into(), + })?; + Self::validate_path(path_str)?; + + let request = self.settings.serialize(&self.model_id, path_str); + let params = json!({ + "Params": { + "OpenAICreateRequest": serde_json::to_string(&request)? + } + }); + + let raw = self + .core + .execute_command_async("audio_transcribe".into(), Some(params)) + .await?; + let parsed: AudioTranscriptionResponse = serde_json::from_str(&raw)?; + Ok(parsed) + } + + /// Transcribe an audio file with streaming results, returning an + /// [`AudioTranscriptionStream`]. + pub async fn transcribe_streaming( + &self, + audio_file_path: impl AsRef, + ) -> Result { + let path_str = + audio_file_path + .as_ref() + .to_str() + .ok_or_else(|| FoundryLocalError::Validation { + reason: "audio file path is not valid UTF-8".into(), + })?; + Self::validate_path(path_str)?; + + let request = self.settings.serialize(&self.model_id, path_str); + + let params = json!({ + "Params": { + "OpenAICreateRequest": serde_json::to_string(&request)? 
+ } + }); + + let rx = self + .core + .execute_command_streaming_channel("audio_transcribe".into(), Some(params)) + .await?; + + Ok(AudioTranscriptionStream { rx }) + } + + fn validate_path(path: &str) -> Result<()> { + if path.trim().is_empty() { + return Err(FoundryLocalError::Validation { + reason: "audio_file_path must be a non-empty string".into(), + }); + } + Ok(()) + } +} diff --git a/sdk_v2/rust/src/openai/chat_client.rs b/sdk_v2/rust/src/openai/chat_client.rs new file mode 100644 index 00000000..974671c7 --- /dev/null +++ b/sdk_v2/rust/src/openai/chat_client.rs @@ -0,0 +1,310 @@ +//! OpenAI-compatible chat completions client. + +use std::collections::HashMap; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use async_openai::types::chat::{ + ChatCompletionRequestMessage, ChatCompletionTools, CreateChatCompletionResponse, + CreateChatCompletionStreamResponse, +}; +use serde_json::{json, Value}; + +use crate::detail::core_interop::CoreInterop; +use crate::error::{FoundryLocalError, Result}; +use crate::types::{ChatResponseFormat, ChatToolChoice}; + +/// Tuning knobs for chat completion requests. +/// +/// Use the chainable setter methods to configure, e.g.: +/// +/// ```ignore +/// let mut client = model.create_chat_client(); +/// client.temperature(0.7).max_tokens(256); +/// ``` +#[derive(Debug, Clone, Default)] +pub struct ChatClientSettings { + frequency_penalty: Option, + max_tokens: Option, + n: Option, + temperature: Option, + presence_penalty: Option, + top_p: Option, + top_k: Option, + random_seed: Option, + response_format: Option, + tool_choice: Option, +} + +impl ChatClientSettings { + /// Serialise settings into the JSON fragment expected by the native core. 
+ fn serialize(&self) -> Value { + let mut map = serde_json::Map::new(); + + if let Some(v) = self.frequency_penalty { + map.insert("frequency_penalty".into(), json!(v)); + } + if let Some(v) = self.max_tokens { + map.insert("max_tokens".into(), json!(v)); + } + if let Some(v) = self.n { + map.insert("n".into(), json!(v)); + } + if let Some(v) = self.presence_penalty { + map.insert("presence_penalty".into(), json!(v)); + } + if let Some(v) = self.temperature { + map.insert("temperature".into(), json!(v)); + } + if let Some(v) = self.top_p { + map.insert("top_p".into(), json!(v)); + } + + if let Some(ref rf) = self.response_format { + let mut rf_map = serde_json::Map::new(); + match rf { + ChatResponseFormat::Text => { + rf_map.insert("type".into(), json!("text")); + } + ChatResponseFormat::JsonObject => { + rf_map.insert("type".into(), json!("json_object")); + } + ChatResponseFormat::JsonSchema(schema) => { + rf_map.insert("type".into(), json!("json_schema")); + rf_map.insert("jsonSchema".into(), json!(schema)); + } + ChatResponseFormat::LarkGrammar(grammar) => { + rf_map.insert("type".into(), json!("lark_grammar")); + rf_map.insert("larkGrammar".into(), json!(grammar)); + } + } + map.insert("response_format".into(), Value::Object(rf_map)); + } + + if let Some(ref tc) = self.tool_choice { + let mut tc_map = serde_json::Map::new(); + match tc { + ChatToolChoice::None => { + tc_map.insert("type".into(), json!("none")); + } + ChatToolChoice::Auto => { + tc_map.insert("type".into(), json!("auto")); + } + ChatToolChoice::Required => { + tc_map.insert("type".into(), json!("required")); + } + ChatToolChoice::Function(name) => { + tc_map.insert("type".into(), json!("function")); + tc_map.insert("name".into(), json!(name)); + } + } + map.insert("tool_choice".into(), Value::Object(tc_map)); + } + + // Foundry-specific metadata for settings that don't map directly to + // the OpenAI spec. 
+ let mut metadata: HashMap = HashMap::new(); + if let Some(k) = self.top_k { + metadata.insert("top_k".into(), k.to_string()); + } + if let Some(s) = self.random_seed { + metadata.insert("random_seed".into(), s.to_string()); + } + if !metadata.is_empty() { + map.insert("metadata".into(), json!(metadata)); + } + + Value::Object(map) + } +} + +/// A stream of [`CreateChatCompletionStreamResponse`] chunks. +/// +/// Returned by [`ChatClient::complete_streaming_chat`]. +pub struct ChatCompletionStream { + rx: tokio::sync::mpsc::UnboundedReceiver>, +} + +impl futures_core::Stream for ChatCompletionStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.rx.poll_recv(cx) { + Poll::Ready(Some(Ok(chunk))) => { + if chunk.is_empty() { + // Skip empty chunks and poll again. + cx.waker().wake_by_ref(); + Poll::Pending + } else { + let parsed = serde_json::from_str::(&chunk) + .map_err(FoundryLocalError::from); + Poll::Ready(Some(parsed)) + } + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, + } + } +} + +/// Client for OpenAI-compatible chat completions backed by a local model. +pub struct ChatClient { + model_id: String, + core: Arc, + settings: ChatClientSettings, +} + +impl ChatClient { + pub(crate) fn new(model_id: String, core: Arc) -> Self { + Self { + model_id, + core, + settings: ChatClientSettings::default(), + } + } + + /// Set the frequency penalty. + pub fn frequency_penalty(&mut self, v: f64) -> &mut Self { + self.settings.frequency_penalty = Some(v); + self + } + + /// Set the maximum number of tokens to generate. + pub fn max_tokens(&mut self, v: u32) -> &mut Self { + self.settings.max_tokens = Some(v); + self + } + + /// Set the number of completions to generate. + pub fn n(&mut self, v: u32) -> &mut Self { + self.settings.n = Some(v); + self + } + + /// Set the sampling temperature. 
+ pub fn temperature(&mut self, v: f64) -> &mut Self { + self.settings.temperature = Some(v); + self + } + + /// Set the presence penalty. + pub fn presence_penalty(&mut self, v: f64) -> &mut Self { + self.settings.presence_penalty = Some(v); + self + } + + /// Set the nucleus sampling probability. + pub fn top_p(&mut self, v: f64) -> &mut Self { + self.settings.top_p = Some(v); + self + } + + /// Set the top-k sampling parameter (Foundry extension). + pub fn top_k(&mut self, v: u32) -> &mut Self { + self.settings.top_k = Some(v); + self + } + + /// Set the random seed for reproducible results (Foundry extension). + pub fn random_seed(&mut self, v: u64) -> &mut Self { + self.settings.random_seed = Some(v); + self + } + + /// Set the desired response format. + pub fn response_format(&mut self, v: ChatResponseFormat) -> &mut Self { + self.settings.response_format = Some(v); + self + } + + /// Set the tool choice strategy. + pub fn tool_choice(&mut self, v: ChatToolChoice) -> &mut Self { + self.settings.tool_choice = Some(v); + self + } + + /// Perform a non-streaming chat completion. + pub async fn complete_chat( + &self, + messages: &[ChatCompletionRequestMessage], + tools: Option<&[ChatCompletionTools]>, + ) -> Result { + if messages.is_empty() { + return Err(FoundryLocalError::Validation { + reason: "messages must be a non-empty array".into(), + }); + } + + let request = self.build_request(messages, tools, false)?; + let params = json!({ + "Params": { + "OpenAICreateRequest": serde_json::to_string(&request)? + } + }); + + let raw = self + .core + .execute_command_async("chat_completions".into(), Some(params)) + .await?; + let parsed: CreateChatCompletionResponse = serde_json::from_str(&raw)?; + Ok(parsed) + } + + /// Perform a streaming chat completion, returning a [`ChatCompletionStream`]. + /// + /// Use the stream with `futures_core::StreamExt::next()` or + /// `tokio_stream::StreamExt::next()`. 
+ pub async fn complete_streaming_chat( + &self, + messages: &[ChatCompletionRequestMessage], + tools: Option<&[ChatCompletionTools]>, + ) -> Result { + if messages.is_empty() { + return Err(FoundryLocalError::Validation { + reason: "messages must be a non-empty array".into(), + }); + } + + let request = self.build_request(messages, tools, true)?; + let params = json!({ + "Params": { + "OpenAICreateRequest": serde_json::to_string(&request)? + } + }); + + let rx = self + .core + .execute_command_streaming_channel("chat_completions".into(), Some(params)) + .await?; + + Ok(ChatCompletionStream { rx }) + } + + fn build_request( + &self, + messages: &[ChatCompletionRequestMessage], + tools: Option<&[ChatCompletionTools]>, + stream: bool, + ) -> Result { + let settings_value = self.settings.serialize(); + let mut map = match settings_value { + Value::Object(m) => m, + _ => serde_json::Map::new(), + }; + + map.insert("model".into(), json!(self.model_id)); + map.insert("messages".into(), serde_json::to_value(messages)?); + + if stream { + map.insert("stream".into(), json!(true)); + } + + if let Some(t) = tools { + map.insert("tools".into(), serde_json::to_value(t)?); + } + + Ok(Value::Object(map)) + } +} diff --git a/sdk_v2/rust/src/openai/mod.rs b/sdk_v2/rust/src/openai/mod.rs new file mode 100644 index 00000000..7a800c67 --- /dev/null +++ b/sdk_v2/rust/src/openai/mod.rs @@ -0,0 +1,7 @@ +mod audio_client; +mod chat_client; + +pub use self::audio_client::{ + AudioClient, AudioClientSettings, AudioTranscriptionResponse, AudioTranscriptionStream, +}; +pub use self::chat_client::{ChatClient, ChatClientSettings, ChatCompletionStream}; diff --git a/sdk_v2/rust/src/types.rs b/sdk_v2/rust/src/types.rs new file mode 100644 index 00000000..66bbba92 --- /dev/null +++ b/sdk_v2/rust/src/types.rs @@ -0,0 +1,119 @@ +use serde::{Deserialize, Serialize}; + +/// Hardware device type for model execution. 
+#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub enum DeviceType { + Invalid, + #[default] + CPU, + GPU, + NPU, +} + +/// Prompt template describing how messages are formatted for the model. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptTemplate { + #[serde(skip_serializing_if = "Option::is_none")] + pub system: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub user: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub assistant: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt: Option, +} + +/// Runtime information for a model (device type and execution provider). +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Runtime { + pub device_type: DeviceType, + pub execution_provider: String, +} + +/// A single parameter key-value pair used in model settings. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Parameter { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub value: Option, +} + +/// Model-level settings containing a list of parameters. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelSettings { + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option>, +} + +/// Full metadata for a model variant as returned by the catalog. 
+#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelInfo { + pub id: String, + pub name: String, + pub version: i64, + pub alias: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub display_name: Option, + pub provider_type: String, + pub uri: String, + pub model_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_template: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub publisher: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model_settings: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub license: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub license_description: Option, + pub cached: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub task: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub runtime: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub file_size_mb: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub supports_tool_calling: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_fl_version: Option, + #[serde(default)] + pub created_at_unix: i64, +} + +/// Desired response format for chat completions. +/// +/// Extends the standard OpenAI formats with the Foundry-specific +/// `LarkGrammar` variant. +#[derive(Debug, Clone)] +pub enum ChatResponseFormat { + /// Plain text output (default). + Text, + /// JSON output (unstructured). + JsonObject, + /// JSON output constrained by a schema string. + JsonSchema(String), + /// Output constrained by a Lark grammar (Foundry extension). + LarkGrammar(String), +} + +/// Tool choice configuration for chat completions. +#[derive(Debug, Clone)] +pub enum ChatToolChoice { + /// Model will not call any tool. + None, + /// Model decides whether to call a tool. 
+ Auto, + /// Model must call at least one tool. + Required, + /// Model must call the named function. + Function(String), +} diff --git a/sdk_v2/rust/tests/integration/audio_client_test.rs b/sdk_v2/rust/tests/integration/audio_client_test.rs new file mode 100644 index 00000000..1e895609 --- /dev/null +++ b/sdk_v2/rust/tests/integration/audio_client_test.rs @@ -0,0 +1,130 @@ +use std::sync::Arc; +use super::common; +use foundry_local_sdk::openai::AudioClient; +use tokio_stream::StreamExt; + +async fn setup_audio_client() -> (AudioClient, Arc) { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::WHISPER_MODEL_ALIAS) + .await + .expect("get_model(whisper-tiny) failed"); + model.load().await.expect("model.load() failed"); + (model.create_audio_client(), model) +} + +fn audio_file() -> String { + common::get_audio_file_path().to_string_lossy().into_owned() +} + +#[tokio::test] +async fn should_transcribe_audio_without_streaming() { + let (mut client, model) = setup_audio_client().await; + client.language("en").temperature(0.0); + let response = client + .transcribe(&audio_file()) + .await + .expect("transcribe failed"); + + assert!( + response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Transcription should contain expected text, got: {}", + response.text + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_transcribe_audio_without_streaming_with_temperature() { + let (mut client, model) = setup_audio_client().await; + client.language("en").temperature(0.5); + + let response = client + .transcribe(&audio_file()) + .await + .expect("transcribe with temperature failed"); + + assert!( + response.text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Transcription should contain expected text, got: {}", + response.text + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn 
should_transcribe_audio_with_streaming() { + let (mut client, model) = setup_audio_client().await; + client.language("en").temperature(0.0); + let mut full_text = String::new(); + + let mut stream = client + .transcribe_streaming(&audio_file()) + .await + .expect("transcribe_streaming setup failed"); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + full_text.push_str(&chunk.text); + } + + println!("Streamed transcription: {full_text}"); + + assert!( + full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Streamed transcription should contain expected text, got: {full_text}" + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_transcribe_audio_with_streaming_with_temperature() { + let (mut client, model) = setup_audio_client().await; + client.language("en").temperature(0.5); + + let mut full_text = String::new(); + + let mut stream = client + .transcribe_streaming(&audio_file()) + .await + .expect("transcribe_streaming with temperature setup failed"); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + full_text.push_str(&chunk.text); + } + + println!("Streamed transcription: {full_text}"); + + assert!( + full_text.contains(common::EXPECTED_TRANSCRIPTION_TEXT), + "Streamed transcription should contain expected text, got: {full_text}" + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_throw_when_transcribing_with_empty_audio_file_path() { + let (client, model) = setup_audio_client().await; + let result = client.transcribe("").await; + assert!(result.is_err(), "Expected error for empty audio file path"); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_throw_when_transcribing_streaming_with_empty_audio_file_path() { + let (client, model) = setup_audio_client().await; + let result = 
client.transcribe_streaming("").await; + assert!( + result.is_err(), + "Expected error for empty audio file path in streaming" + ); + + model.unload().await.expect("model.unload() failed"); +} diff --git a/sdk_v2/rust/tests/integration/catalog_test.rs b/sdk_v2/rust/tests/integration/catalog_test.rs new file mode 100644 index 00000000..d418c7a7 --- /dev/null +++ b/sdk_v2/rust/tests/integration/catalog_test.rs @@ -0,0 +1,106 @@ +use super::common; +use foundry_local_sdk::Catalog; + +fn catalog() -> &'static Catalog { + common::get_test_manager().catalog() +} + +#[test] +fn should_initialize_with_catalog_name() { + let cat = catalog(); + let name = cat.name(); + assert!(!name.is_empty(), "Catalog name must not be empty"); +} + +#[tokio::test] +async fn should_list_models() { + let cat = catalog(); + let models = cat.get_models().await.expect("get_models failed"); + + assert!( + !models.is_empty(), + "Expected at least one model in the catalog" + ); + + let found = models.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); + assert!( + found, + "Test model '{}' not found in catalog", + common::TEST_MODEL_ALIAS + ); +} + +#[tokio::test] +async fn should_get_model_by_alias() { + let cat = catalog(); + let model = cat + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); +} + +#[tokio::test] +async fn should_throw_when_getting_model_with_empty_alias() { + let cat = catalog(); + let result = cat.get_model("").await; + assert!(result.is_err(), "Expected error for empty alias"); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("Model alias must be a non-empty string"), + "Unexpected error message: {err_msg}" + ); +} + +#[tokio::test] +async fn should_throw_when_getting_model_with_unknown_alias() { + let cat = catalog(); + let result = cat.get_model("unknown-nonexistent-model-alias").await; + assert!(result.is_err(), "Expected error for unknown alias"); + + let 
err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("Unknown model alias"), + "Error should mention unknown alias: {err_msg}" + ); + assert!( + err_msg.contains("Available"), + "Error should list available models: {err_msg}" + ); +} + +#[tokio::test] +async fn should_get_cached_models() { + let cat = catalog(); + let cached = cat + .get_cached_models() + .await + .expect("get_cached_models failed"); + + assert!(!cached.is_empty(), "Expected at least one cached model"); + + let found = cached.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); + assert!( + found, + "Test model '{}' should be in the cached models list", + common::TEST_MODEL_ALIAS + ); +} + +#[tokio::test] +async fn should_throw_when_getting_model_variant_with_empty_id() { + let cat = catalog(); + let result = cat.get_model_variant("").await; + assert!(result.is_err(), "Expected error for empty variant ID"); +} + +#[tokio::test] +async fn should_throw_when_getting_model_variant_with_unknown_id() { + let cat = catalog(); + let result = cat + .get_model_variant("unknown-nonexistent-variant-id") + .await; + assert!(result.is_err(), "Expected error for unknown variant ID"); +} diff --git a/sdk_v2/rust/tests/integration/chat_client_test.rs b/sdk_v2/rust/tests/integration/chat_client_test.rs new file mode 100644 index 00000000..9b3e55d8 --- /dev/null +++ b/sdk_v2/rust/tests/integration/chat_client_test.rs @@ -0,0 +1,335 @@ +use std::sync::Arc; +use super::common; +use foundry_local_sdk::openai::ChatClient; +use foundry_local_sdk::{ + ChatCompletionMessageToolCalls, ChatCompletionRequestMessage, + ChatCompletionRequestSystemMessage, ChatCompletionRequestToolMessage, + ChatCompletionRequestUserMessage, ChatToolChoice, +}; +use serde_json::json; +use tokio_stream::StreamExt; + +async fn setup_chat_client() -> (ChatClient, Arc) { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + 
.expect("get_model failed");
+    model.load().await.expect("model.load() failed");
+
+    let mut client = model.create_chat_client();
+    client.max_tokens(500).temperature(0.0);
+    (client, model)
+}
+
+/// Wrap plain text in a `user`-role request message.
+fn user_message(content: &str) -> ChatCompletionRequestMessage {
+    ChatCompletionRequestUserMessage::from(content).into()
+}
+
+/// Wrap plain text in a `system`-role request message.
+fn system_message(content: &str) -> ChatCompletionRequestMessage {
+    ChatCompletionRequestSystemMessage::from(content).into()
+}
+
+/// Build an `assistant`-role message by JSON round-trip, since the SDK does
+/// not expose a direct constructor for assistant request messages.
+fn assistant_message(content: &str) -> ChatCompletionRequestMessage {
+    serde_json::from_value(json!({ "role": "assistant", "content": content }))
+        .expect("failed to construct assistant message")
+}
+
+#[tokio::test]
+async fn should_perform_chat_completion() {
+    let (client, model) = setup_chat_client().await;
+    let messages = vec![
+        system_message("You are a helpful math assistant. Respond with just the answer."),
+        user_message("What is 7*6?"),
+    ];
+
+    let response = client
+        .complete_chat(&messages, None)
+        .await
+        .expect("complete_chat failed");
+    let content = response
+        .choices
+        .first()
+        .and_then(|c| c.message.content.as_deref())
+        .unwrap_or("");
+    // (fix) content was previously logged twice via a duplicated println!
+    // ("Response:" then "REST response:"); the copy-paste duplicate is removed.
+    println!("Response: {content}");
+
+    assert!(
+        content.contains("42"),
+        "Expected response to contain '42', got: {content}"
+    );
+
+    model.unload().await.expect("model.unload() failed");
+}
+
+#[tokio::test]
+async fn should_perform_streaming_chat_completion() {
+    let (client, model) = setup_chat_client().await;
+    let mut messages = vec![
+        system_message("You are a helpful math assistant. 
Respond with just the answer."), + user_message("What is 7*6?"), + ]; + + let mut first_result = String::new(); + let mut stream = client + .complete_streaming_chat(&messages, None) + .await + .expect("streaming chat (first turn) setup failed"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + first_result.push_str(content); + } + } + } + println!("First turn: {first_result}"); + + assert!( + first_result.contains("42"), + "First turn should contain '42', got: {first_result}" + ); + + messages.push(assistant_message(&first_result)); + messages.push(user_message("Now add 25 to that result.")); + + let mut second_result = String::new(); + let mut stream = client + .complete_streaming_chat(&messages, None) + .await + .expect("streaming chat (follow-up) setup failed"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + second_result.push_str(content); + } + } + } + println!("Follow-up: {second_result}"); + + assert!( + second_result.contains("67"), + "Follow-up should contain '67', got: {second_result}" + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_throw_when_completing_chat_with_empty_messages() { + let (client, model) = setup_chat_client().await; + let messages: Vec = vec![]; + + let result = client.complete_chat(&messages, None).await; + assert!(result.is_err(), "Expected error for empty messages"); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_throw_when_completing_streaming_chat_with_empty_messages() { + let (client, model) = setup_chat_client().await; + let messages: Vec = vec![]; + + let result = client.complete_streaming_chat(&messages, None).await; + assert!( + 
result.is_err(), + "Expected error for empty messages in streaming" + ); + + model.unload().await.expect("model.unload() failed"); +} + +// Note: The "invalid callback" test was removed because it was an exact +// duplicate of should_throw_when_completing_streaming_chat_with_empty_messages. + +#[tokio::test] +async fn should_perform_tool_calling_chat_completion_non_streaming() { + let (mut client, model) = setup_chat_client().await; + client.tool_choice(ChatToolChoice::Required); + + let tools = vec![common::get_multiply_tool()]; + let mut messages = vec![ + system_message("You are a math assistant. Use the multiply tool to answer."), + user_message("What is 6 times 7?"), + ]; + + let response = client + .complete_chat(&messages, Some(&tools)) + .await + .expect("complete_chat with tools failed"); + + let choice = response + .choices + .first() + .expect("Expected at least one choice"); + let tool_calls = choice + .message + .tool_calls + .as_ref() + .expect("Expected tool_calls"); + assert!( + !tool_calls.is_empty(), + "Expected at least one tool call in the response" + ); + + let tool_call = match &tool_calls[0] { + ChatCompletionMessageToolCalls::Function(tc) => tc, + _ => panic!("Expected a function tool call"), + }; + assert_eq!( + tool_call.function.name, "multiply", + "Expected tool call to 'multiply'" + ); + + let args: serde_json::Value = serde_json::from_str(&tool_call.function.arguments) + .expect("Failed to parse tool call arguments"); + let a = args["a"].as_f64().unwrap_or(0.0); + let b = args["b"].as_f64().unwrap_or(0.0); + let product = (a * b) as i64; + + let tool_call_id = &tool_call.id; + let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ + "role": "assistant", + "content": null, + "tool_calls": [{ + "id": tool_call_id, + "type": "function", + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + } + }] + })) + .expect("failed to construct assistant message"); + 
messages.push(assistant_msg); + messages.push( + ChatCompletionRequestToolMessage { + content: product.to_string().into(), + tool_call_id: tool_call_id.clone(), + } + .into(), + ); + + client.tool_choice(ChatToolChoice::Auto); + + let final_response = client + .complete_chat(&messages, Some(&tools)) + .await + .expect("follow-up complete_chat with tools failed"); + let content = final_response + .choices + .first() + .and_then(|c| c.message.content.as_deref()) + .unwrap_or(""); + + println!("Tool call result: {content}"); + + assert!( + content.contains("42"), + "Final answer should contain '42', got: {content}" + ); + + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_perform_tool_calling_chat_completion_streaming() { + let (mut client, model) = setup_chat_client().await; + client.tool_choice(ChatToolChoice::Required); + + let tools = vec![common::get_multiply_tool()]; + let mut messages = vec![ + system_message("You are a math assistant. Use the multiply tool to answer."), + user_message("What is 6 times 7?"), + ]; + + let mut tool_call_name = String::new(); + let mut tool_call_args = String::new(); + let mut tool_call_id = String::new(); + + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await + .expect("streaming tool call setup failed"); + + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref tool_calls) = choice.delta.tool_calls { + for call in tool_calls { + if let Some(ref func) = call.function { + if let Some(ref name) = func.name { + tool_call_name.push_str(name); + } + if let Some(ref args) = func.arguments { + tool_call_args.push_str(args); + } + } + if let Some(ref id) = call.id { + tool_call_id = id.clone(); + } + } + } + } + } + assert_eq!( + tool_call_name, "multiply", + "Expected streamed tool call to 'multiply'" + ); + + let args: serde_json::Value = + 
serde_json::from_str(&tool_call_args).unwrap_or_else(|_| json!({})); + let a = args["a"].as_f64().unwrap_or(0.0); + let b = args["b"].as_f64().unwrap_or(0.0); + let product = (a * b) as i64; + + let assistant_msg: ChatCompletionRequestMessage = serde_json::from_value(json!({ + "role": "assistant", + "tool_calls": [{ + "id": tool_call_id, + "type": "function", + "function": { + "name": tool_call_name, + "arguments": tool_call_args + } + }] + })) + .expect("failed to construct assistant message"); + messages.push(assistant_msg); + messages.push( + ChatCompletionRequestToolMessage { + content: product.to_string().into(), + tool_call_id: tool_call_id.clone(), + } + .into(), + ); + + client.tool_choice(ChatToolChoice::Auto); + + let mut final_result = String::new(); + let mut stream = client + .complete_streaming_chat(&messages, Some(&tools)) + .await + .expect("streaming follow-up setup failed"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.expect("stream chunk error"); + if let Some(choice) = chunk.choices.first() { + if let Some(ref content) = choice.delta.content { + final_result.push_str(content); + } + } + } + println!("Streamed tool call result: {final_result}"); + + assert!( + final_result.contains("42"), + "Streamed final answer should contain '42', got: {final_result}" + ); + + model.unload().await.expect("model.unload() failed"); +} diff --git a/sdk_v2/rust/tests/integration/common/mod.rs b/sdk_v2/rust/tests/integration/common/mod.rs new file mode 100644 index 00000000..dbe3414d --- /dev/null +++ b/sdk_v2/rust/tests/integration/common/mod.rs @@ -0,0 +1,124 @@ +//! Shared test utilities and configuration for Foundry Local SDK integration tests. +//! +//! Mirrors `testUtils.ts` from the JavaScript SDK test suite. 
+ +#![allow(dead_code)] + +use std::collections::HashMap; +use std::path::PathBuf; + +use foundry_local_sdk::{FoundryLocalConfig, FoundryLocalManager, LogLevel}; + +/// Default model alias used for chat-completion integration tests. +pub const TEST_MODEL_ALIAS: &str = "qwen2.5-0.5b"; + +/// Default model alias used for audio-transcription integration tests. +pub const WHISPER_MODEL_ALIAS: &str = "whisper-tiny"; + +/// Expected transcription text fragment for the shared audio test file. +pub const EXPECTED_TRANSCRIPTION_TEXT: &str = + " And lots of times you need to give people more than one link at a time"; + +// ── Environment helpers ────────────────────────────────────────────────────── + +/// Returns `true` when the tests are running inside a CI environment +/// (Azure DevOps or GitHub Actions). +pub fn is_running_in_ci() -> bool { + let azure_devops = std::env::var("TF_BUILD").unwrap_or_else(|_| "false".into()); + let github_actions = std::env::var("GITHUB_ACTIONS").unwrap_or_else(|_| "false".into()); + azure_devops.eq_ignore_ascii_case("true") || github_actions.eq_ignore_ascii_case("true") +} + +/// Walk upward from `CARGO_MANIFEST_DIR` until a `.git` directory is found. +pub fn get_git_repo_root() -> PathBuf { + let mut current = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + loop { + if current.join(".git").exists() { + return current; + } + if !current.pop() { + panic!( + "Could not locate git repo root starting from {}", + env!("CARGO_MANIFEST_DIR") + ); + } + } +} + +/// Path to the shared test-data directory that lives alongside the repo root. +pub fn get_test_data_shared_path() -> PathBuf { + let repo_root = get_git_repo_root(); + repo_root + .parent() + .expect("repo root has no parent") + .join("test-data-shared") +} + +/// Path to the shared audio test file used by audio-client tests. 
+pub fn get_audio_file_path() -> PathBuf {
+    // NOTE(review): this resolves to `<manifest-dir>/../testdata/Recording.mp3`,
+    // i.e. a file next to the crate, NOT the `test-data-shared` checkout
+    // returned by `get_test_data_shared_path()` — confirm that is intended.
+    PathBuf::from(env!("CARGO_MANIFEST_DIR"))
+        .join("..")
+        .join("testdata")
+        .join("Recording.mp3")
+}
+
+// ── Test configuration ───────────────────────────────────────────────────────
+
+/// Build a [`FoundryLocalConfig`] suitable for integration tests.
+///
+/// * `modelCacheDir` → `<repo-root>/../test-data-shared`
+/// * `logsDir`       → `<repo-root>/sdk_v2/rust/logs`
+/// * `logLevel`      → `Warn`
+/// * `Bootstrap`     → `false` (via additional settings)
+pub fn test_config() -> FoundryLocalConfig {
+    let repo_root = get_git_repo_root();
+    let logs_dir = repo_root.join("sdk_v2").join("rust").join("logs");
+
+    // "Bootstrap" is forwarded as an opaque string setting; presumably it
+    // stops the native core from bootstrapping/downloading on start — TODO confirm.
+    let mut additional = HashMap::new();
+    additional.insert("Bootstrap".into(), "false".into());
+
+    let mut config = FoundryLocalConfig::new("FoundryLocalTest");
+    config.model_cache_dir = Some(get_test_data_shared_path().to_string_lossy().into_owned());
+    config.logs_dir = Some(logs_dir.to_string_lossy().into_owned());
+    config.log_level = Some(LogLevel::Warn);
+    config.additional_settings = Some(additional);
+    config
+}
+
+/// Create (or return the cached) [`FoundryLocalManager`] for tests.
+///
+/// Caching happens inside `FoundryLocalManager::create` itself (the `OnceLock`
+/// singleton described in `tests/integration/main.rs`), so calling this from
+/// every test is cheap.
+///
+/// Panics if creation fails so that test set-up failures are immediately
+/// visible.
+pub fn get_test_manager() -> &'static FoundryLocalManager {
+    FoundryLocalManager::create(test_config()).expect("Failed to create FoundryLocalManager")
+}
+
+// ── Tool definitions ─────────────────────────────────────────────────────────
+
+/// Returns a tool definition for a simple "multiply" function.
+///
+/// Used by tool-calling chat-completion tests.
+pub fn get_multiply_tool() -> foundry_local_sdk::ChatCompletionTools {
+    // Built as raw JSON because the schema follows the OpenAI function-calling
+    // tool format; `from_value` converts it into the SDK's typed definition.
+    serde_json::from_value(serde_json::json!({
+        "type": "function",
+        "function": {
+            "name": "multiply",
+            "description": "Multiply two numbers together",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "a": {
+                        "type": "number",
+                        "description": "The first number"
+                    },
+                    "b": {
+                        "type": "number",
+                        "description": "The second number"
+                    }
+                },
+                "required": ["a", "b"]
+            }
+        }
+    }))
+    .expect("Failed to parse multiply tool definition")
+}
diff --git a/sdk_v2/rust/tests/integration/main.rs b/sdk_v2/rust/tests/integration/main.rs
new file mode 100644
index 00000000..a18a29a6
--- /dev/null
+++ b/sdk_v2/rust/tests/integration/main.rs
@@ -0,0 +1,16 @@
+//! Single integration test binary for the Foundry Local Rust SDK.
+//!
+//! All test modules are compiled into one binary so the native core is only
+//! initialised once (via the `OnceLock` singleton in `FoundryLocalManager`).
+//! Running them as separate binaries causes "already initialized" errors
+//! because the .NET native runtime retains state across process-level
+//! library loads.
+ +mod common; + +mod manager_test; +mod catalog_test; +mod model_test; +mod chat_client_test; +mod audio_client_test; +mod web_service_test; diff --git a/sdk_v2/rust/tests/integration/manager_test.rs b/sdk_v2/rust/tests/integration/manager_test.rs new file mode 100644 index 00000000..aa3e0614 --- /dev/null +++ b/sdk_v2/rust/tests/integration/manager_test.rs @@ -0,0 +1,21 @@ +use super::common; +use foundry_local_sdk::FoundryLocalManager; + +#[test] +fn should_initialize_successfully() { + let config = common::test_config(); + let manager = FoundryLocalManager::create(config); + assert!( + manager.is_ok(), + "Manager creation failed: {:?}", + manager.err() + ); +} + +#[test] +fn should_return_catalog_with_non_empty_name() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let name = catalog.name(); + assert!(!name.is_empty(), "Catalog name should not be empty"); +} diff --git a/sdk_v2/rust/tests/integration/model_test.rs b/sdk_v2/rust/tests/integration/model_test.rs new file mode 100644 index 00000000..cd11f881 --- /dev/null +++ b/sdk_v2/rust/tests/integration/model_test.rs @@ -0,0 +1,288 @@ +use std::sync::Arc; +use super::common; + +// ── Cached model verification ──────────────────────────────────────────────── + +#[tokio::test] +async fn should_verify_cached_models_from_test_data_shared() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let cached = catalog + .get_cached_models() + .await + .expect("get_cached_models failed"); + + let has_qwen = cached.iter().any(|m| m.alias() == common::TEST_MODEL_ALIAS); + assert!( + has_qwen, + "'{}' should be present in cached models", + common::TEST_MODEL_ALIAS + ); + + let has_whisper = cached + .iter() + .any(|m| m.alias() == common::WHISPER_MODEL_ALIAS); + assert!( + has_whisper, + "'{}' should be present in cached models", + common::WHISPER_MODEL_ALIAS + ); +} + +// ── Load / Unload ──────────────────────────────────────────────────────────── + 
+#[tokio::test] +async fn should_load_and_unload_model() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + model.load().await.expect("model.load() failed"); + assert!( + model.is_loaded().await.expect("is_loaded check failed"), + "Model should be loaded after load()" + ); + + model.unload().await.expect("model.unload() failed"); + assert!( + !model.is_loaded().await.expect("is_loaded check failed"), + "Model should not be loaded after unload()" + ); +} + +// ── Introspection ──────────────────────────────────────────────────────────── + +#[tokio::test] +async fn should_expose_alias() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + assert_eq!(model.alias(), common::TEST_MODEL_ALIAS); +} + +#[tokio::test] +async fn should_expose_non_empty_id() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + println!("Model id: {}", model.id()); + + assert!( + !model.id().is_empty(), + "Model id() should be a non-empty string" + ); +} + +#[tokio::test] +async fn should_have_at_least_one_variant() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let variants = model.variants(); + println!("Model has {} variant(s)", variants.len()); + + assert!( + !variants.is_empty(), + "Model should have at least one variant" + ); +} + +#[tokio::test] +async fn should_have_selected_variant_matching_id() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let selected = model.selected_variant(); + assert_eq!( + 
selected.id(), + model.id(), + "selected_variant().id() should match model.id()" + ); +} + +#[tokio::test] +async fn should_report_cached_model_as_cached() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let cached = model.is_cached().await.expect("is_cached() should succeed"); + assert!( + cached, + "Test model '{}' should be cached (from test-data-shared)", + common::TEST_MODEL_ALIAS + ); +} + +#[tokio::test] +async fn should_return_non_empty_path_for_cached_model() { + let manager = common::get_test_manager(); + let model = manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + + let path = model.path().await.expect("path() should succeed"); + println!("Model path: {}", path.display()); + + assert!( + !path.as_os_str().is_empty(), + "Cached model should have a non-empty path" + ); +} + +#[tokio::test] +async fn should_select_variant_by_id() { + let manager = common::get_test_manager(); + let mut model = (*manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed")) + .clone(); + + let first_variant_id = model.variants()[0].id().to_string(); + model + .select_variant(&first_variant_id) + .expect("select_variant should succeed"); + assert_eq!( + model.id(), + first_variant_id, + "After select_variant, id() should match the selected variant" + ); +} + +#[tokio::test] +async fn should_fail_to_select_unknown_variant() { + let manager = common::get_test_manager(); + let mut model = (*manager + .catalog() + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed")) + .clone(); + + let result = model.select_variant("nonexistent-variant-id"); + assert!( + result.is_err(), + "select_variant with unknown ID should fail" + ); + + let err_msg = result.unwrap_err().to_string(); + assert!( + err_msg.contains("not found"), + "Error should mention 'not found': 
{err_msg}" + ); +} + +// ── Load manager (core interop) ────────────────────────────────────────────── + +async fn get_test_model() -> Arc { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed") +} + +#[tokio::test] +async fn should_load_model_using_core_interop() { + let model = get_test_model().await; + model.load().await.expect("model.load() failed"); + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_unload_model_using_core_interop() { + let model = get_test_model().await; + model.load().await.expect("model.load() failed"); + model.unload().await.expect("model.unload() failed"); +} + +#[tokio::test] +async fn should_list_loaded_models_using_core_interop() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + + let loaded = catalog + .get_loaded_models() + .await + .expect("catalog.get_loaded_models() failed"); + + let _ = loaded; +} + +#[tokio::test] +#[ignore = "requires running web service"] +async fn should_load_and_unload_model_using_external_service() { + if common::is_running_in_ci() { + eprintln!("Skipping external-service test in CI"); + return; + } + + let manager = common::get_test_manager(); + let model = get_test_model().await; + + let _urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + + model + .load() + .await + .expect("load via external service failed"); + + model + .unload() + .await + .expect("unload via external service failed"); +} + +#[tokio::test] +#[ignore = "requires running web service"] +async fn should_list_loaded_models_using_external_service() { + if common::is_running_in_ci() { + eprintln!("Skipping external-service test in CI"); + return; + } + + let manager = common::get_test_manager(); + + let _urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + + let catalog = 
manager.catalog(); + let loaded = catalog + .get_loaded_models() + .await + .expect("get_loaded_models via external service failed"); + + let _ = loaded; +} diff --git a/sdk_v2/rust/tests/integration/web_service_test.rs b/sdk_v2/rust/tests/integration/web_service_test.rs new file mode 100644 index 00000000..41f04e49 --- /dev/null +++ b/sdk_v2/rust/tests/integration/web_service_test.rs @@ -0,0 +1,163 @@ +use super::common; +use serde_json::json; + +/// Start the web service, make a non-streaming POST to v1/chat/completions, +/// verify we get a valid response, then stop the service. +#[tokio::test] +async fn should_complete_chat_via_rest_api() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + model.load().await.expect("model.load() failed"); + + let urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + let base_url = urls.first().expect("no URL returned").trim_end_matches('/'); + + let client = reqwest::Client::new(); + let resp = client + .post(format!("{base_url}/v1/chat/completions")) + .json(&json!({ + "model": model.id(), + "messages": [ + { "role": "system", "content": "You are a helpful math assistant. Respond with just the answer." }, + { "role": "user", "content": "What is 7*6?" 
} + ], + "max_tokens": 500, + "temperature": 0.0, + "stream": false + })) + .send() + .await + .expect("HTTP request failed"); + + assert!( + resp.status().is_success(), + "Expected 2xx, got {}", + resp.status() + ); + + let body: serde_json::Value = resp.json().await.expect("failed to parse response JSON"); + let content = body + .pointer("/choices/0/message/content") + .and_then(|v| v.as_str()) + .unwrap_or(""); + + println!("REST response: {content}"); + + assert!( + content.contains("42"), + "Expected response to contain '42', got: {content}" + ); + + manager + .stop_web_service() + .await + .expect("stop_web_service failed"); + model.unload().await.expect("model.unload() failed"); +} + +/// Start the web service, make a streaming POST to v1/chat/completions, +/// collect SSE chunks, verify we get a valid streamed response. +#[tokio::test] +async fn should_stream_chat_via_rest_api() { + let manager = common::get_test_manager(); + let catalog = manager.catalog(); + let model = catalog + .get_model(common::TEST_MODEL_ALIAS) + .await + .expect("get_model failed"); + model.load().await.expect("model.load() failed"); + + let urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + let base_url = urls.first().expect("no URL returned").trim_end_matches('/'); + + let client = reqwest::Client::new(); + let mut response = client + .post(format!("{base_url}/v1/chat/completions")) + .json(&json!({ + "model": model.id(), + "messages": [ + { "role": "system", "content": "You are a helpful math assistant. Respond with just the answer." }, + { "role": "user", "content": "What is 7*6?" 
} + ], + "max_tokens": 500, + "temperature": 0.0, + "stream": true + })) + .send() + .await + .expect("HTTP request failed"); + + assert!( + response.status().is_success(), + "Expected 2xx, got {}", + response.status() + ); + + let mut full_text = String::new(); + while let Some(chunk) = response.chunk().await.expect("chunk read failed") { + let text = String::from_utf8_lossy(&chunk); + for line in text.lines() { + let line = line.trim(); + if let Some(data) = line.strip_prefix("data: ") { + if data == "[DONE]" { + break; + } + if let Ok(parsed) = serde_json::from_str::(data) { + if let Some(content) = parsed + .pointer("/choices/0/delta/content") + .and_then(|v| v.as_str()) + { + full_text.push_str(content); + } + } + } + } + } + + println!("REST streamed response: {full_text}"); + + assert!( + full_text.contains("42"), + "Expected streamed response to contain '42', got: {full_text}" + ); + + manager + .stop_web_service() + .await + .expect("stop_web_service failed"); + model.unload().await.expect("model.unload() failed"); +} + +/// urls() should return the listening addresses after start_web_service. +#[tokio::test] +async fn should_expose_urls_after_start() { + let manager = common::get_test_manager(); + + let urls = manager + .start_web_service() + .await + .expect("start_web_service failed"); + println!("Web service URLs: {urls:?}"); + assert!(!urls.is_empty(), "start_web_service should return URLs"); + + let cached_urls = manager.urls().expect("urls() should succeed"); + assert_eq!( + urls, cached_urls, + "urls() should match what start_web_service returned" + ); + + manager + .stop_web_service() + .await + .expect("stop_web_service failed"); +}