diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md
index 013305399..95c180b28 100644
--- a/.github/copilot-instructions.md
+++ b/.github/copilot-instructions.md
@@ -48,7 +48,13 @@ ## Where to add new code or tests 🧭
-- SDK code: `nodejs/src`, `python/copilot`, `go`, `dotnet/src`
-- Unit tests: `nodejs/test`, `python/*`, `go/*`, `dotnet/test`
+- SDK code: `nodejs/src`, `python/copilot`, `go`, `dotnet/src`, `rust/src`
+- Unit tests: `nodejs/test`, `python/*`, `go/*`, `dotnet/test`, `rust/tests`
 - E2E tests: `*/e2e/` folders that use the shared replay proxy and `test/snapshots/`
 - Generated types: update schema in `@github/copilot` then run `cd nodejs && npm run generate:session-types` and commit generated files in `src/generated` or language generated location.
+
+## Skills 🛠️
+
+Repo-scoped skills live under `.github/skills/<skill-name>/` and are auto-discovered by Copilot. Load the relevant skill before editing the matching file types.
+
+- **`rust-coding-skill`** (`.github/skills/rust-coding-skill/SKILL.md`) — load before editing any `*.rs` file in `rust/`. Covers error handling, async/concurrency, tracing, codegen workflow, and Rust SDK-specific trait patterns.
diff --git a/.github/skills/rust-coding-skill/SKILL.md b/.github/skills/rust-coding-skill/SKILL.md
new file mode 100644
index 000000000..25c54695e
--- /dev/null
+++ b/.github/skills/rust-coding-skill/SKILL.md
@@ -0,0 +1,254 @@
+---
+name: rust-coding-skill
+description: "Use this skill whenever editing `*.rs` files in the `rust/` SDK in order to write idiomatic, efficient, well-structured Rust code"
+---
+
+# Rust Coding Skill
+
+Opinionated Rust rules for the Copilot Rust SDK (`rust/`). Priority order:
+
+1. **Readable code** — every line should earn its place
+2. **Correct code** — especially in concurrent/async contexts
+3. **Performant code** — think about allocations, data structures, hot paths
+
+## Error handling
+
+The SDK's public error type is `crate::Error` (`rust/src/error.rs`).
Add new +variants there rather than introducing parallel error enums per module — every +public failure mode is part of the API contract and should be expressible in one +type. Internal modules can use `thiserror` enums when a richer local taxonomy +helps; convert at the boundary. + +`anyhow` is reserved for binaries and example code. Library code never returns +`anyhow::Result` — callers can't pattern-match on `anyhow::Error`, so it would +prevent them from handling specific failures. + +In production code, prefer `?`, `let-else`, and `if let`. Reach for `expect("…")` +when an invariant cannot fail and the message would help debug a future +regression. `unwrap()` belongs in tests only — Clippy enforces this in the SDK +via `#![cfg_attr(test, allow(clippy::unwrap_used))]` in `lib.rs`. + +When you need to log on the way through, prefer +`.inspect_err(|e| warn!(error = ?e, "context"))?` over a `match` that logs and +re-wraps. It reads top-to-bottom and keeps the happy path uncluttered. + +## Async and concurrency + +The default for request-scoped I/O is `async fn` plus `.await` — futures +inherit cancellation from their parent task and can borrow local references. +Reach for `tokio::spawn` only when you genuinely need background work (an event +loop, a long-lived watcher) and track the `JoinHandle` so you can cancel or join +it on shutdown. Fire-and-forget spawns silently swallow panics and outlive the +session; don't. + +Blocking calls (filesystem, subprocess wait) belong in +`tokio::task::spawn_blocking`, *not* on the async runtime. The blocking pool is +bounded, so for genuinely long-lived workers (think: file watchers that run for +the lifetime of a session) prefer `std::thread::spawn` with a channel back into +async land. + +Lock choice matters. `tokio::sync::Mutex` is correct when you must hold the +guard across `.await`; `parking_lot::Mutex` (or `RwLock`) is faster on hot +synchronous paths and is what `session.rs` uses for capability state. 
+`std::sync::Mutex` is rarely the right answer in this crate — its poisoning +semantics buy us nothing and it's slower than `parking_lot`. Never hold a +`std::sync::Mutex` guard across an `.await`; Clippy will catch this, but the +fix is to move the await out, not silence the lint. + +For lazy statics use `std::sync::LazyLock`. The `once_cell` crate is no longer +needed. + +## Traits and conversions + +Plain functions on a type beat traits for navigability — IDE "Go to definition" +on an inherent method jumps directly to the implementation, while a trait method +hops to the trait declaration first. Use that as the default. + +There are four intentional exceptions where the SDK exposes a trait because it +*is* an extension point — code paths consumers must be able to plug behaviour +into: + +- **`SessionHandler`** (`rust/src/handler.rs`) — single `on_event()` dispatches + CLI events. Notification-triggered events (`permission.requested`, + `external_tool.requested`, `elicitation.requested`) are dispatched on spawned + tasks, so implementations must be safe for concurrent invocation. Use + `ApproveAllHandler` in tests and examples. +- **`SessionHooks`** (`rust/src/hooks.rs`) — optional lifecycle callbacks. The + SDK auto-enables hooks (`config.hooks = Some(true)`) when an impl is supplied + to `create_session` / `resume_session`. +- **`SystemMessageTransform`** (`rust/src/system_message.rs`) — declare + `section_ids()` and return content from `transform_section()`. +- **`ToolHandler`** (`rust/src/tool.rs`) — client-side tool implementations. + Dispatch by name via `ToolHandlerRouter`. + +Don't add new traits without a clear extension story. In particular, don't +implement `From`/`Into` for SDK-internal conversions: they can't take extra +parameters, can't return `Result`, and hide which conversion is happening at +call sites. Prefer named methods like `to_info(&self)` or +`MyType::from_record(record, ctx)`. 
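
As a short sketch of why a named conversion beats a `From` impl here (the types below are invented for illustration, not SDK items): the named constructor can thread a context parameter through and surface failure as a `Result`, neither of which `From::from` can do.

```rust
// Hypothetical types, for illustration only (not actual SDK items).
struct Record { name: String }
struct Ctx { prefix: &'static str }
struct Info { label: String }

impl Info {
    // A named conversion can take extra context and return Result,
    // and the call site names exactly which conversion is running.
    fn from_record(record: &Record, ctx: &Ctx) -> Result<Info, String> {
        if record.name.is_empty() {
            return Err("record has no name".to_string());
        }
        Ok(Info { label: format!("{}:{}", ctx.prefix, record.name) })
    }
}

fn main() {
    let record = Record { name: "session".to_string() };
    let info = Info::from_record(&record, &Ctx { prefix: "sdk" })
        .expect("non-empty name converts");
    assert_eq!(info.label, "sdk:session");
}
```

With a `From` impl, the context would have to be smuggled in via a tuple and the empty-name case would have to panic; the named method keeps both explicit at the call site.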
+ +Trivial field re-shaping ("flatten this struct into that one") is best inlined +at the call site. A free-standing `map_x_to_y(x) -> Y` adds a hop without +adding clarity. + +Closures should stay short — under ~10 lines is a good rule. Long anonymous +closures show up as opaque frames in stack traces. Extract them to named +functions when they grow. Visitor patterns are a closure-fest in disguise; +expose an `iter()` method instead and let the consumer drive the traversal. + +## Tracing — `#[tracing::instrument]` is banned + +Banned via `clippy.toml`. Use manual spans with `error_span!`: + +- **Almost always use `error_span!`**, not `info_span!`. Span level controls + the *minimum* filter at which the span appears. An `info_span` disappears when + the filter is `warn` or `error` — taking all child events with it, even + errors. `error_span!` ensures the span is always present. +- **Spawned tasks lose parent context.** Attach a span with `.instrument()` or + events inside won't correlate. +- **Never hold `span.enter()` guards across `.await`** — use `.instrument(span)` + instead (also enforced by Clippy). + +```rust +use tracing::Instrument; + +async fn send_message(&self, session_id: &str, prompt: &str) -> Result<(), Error> { + let span = tracing::error_span!("send_message", session_id = %session_id); + async { /* body */ }.instrument(span).await +} + +let span = tracing::error_span!("event_loop", session_id = %id); +tokio::spawn(async move { run_loop().await }.instrument(span)); +``` + +Log with structured fields: `info!(session_id = %id, "Session created")`. +Static messages stay greppable; dynamic data goes in named fields, not +interpolated into the message string. + +## Idioms that don't port from Go or Node + +The most common pitfall when adapting code from the Node and Go SDKs is the +event subscription pattern. 
Those SDKs expose `client.on(handler)` callback +registration; the Rust SDK uses typed channels (`tokio::sync::broadcast` for +fan-out, `tokio::sync::mpsc` for single-consumer streams). Don't try to +recreate observer-style callbacks — drop the consumer onto a channel and let +each subscriber `.recv()` on its own task. See `Session::events_subscribe()` for +the canonical example. + +Similarly, contexts and cancellation in Go/Node map to dropping a future or +calling `JoinHandle::abort()` — there is no `ctx.Done()` analogue to plumb +through every call site. Optional fields use `Option`, not nullable +pointers; defaults come from `Default` impls, not constructors that accept +zero values. JSON tag attributes become `#[serde(rename_all = "camelCase")]` at +the type level plus `#[serde(rename = "…")]` on the occasional outlier. + +## Code organization + +- **Public API:** every `pub` item in the crate is part of the SDK's contract. + Adding a field to a `pub struct` is a breaking change unless the struct is + `#[non_exhaustive]` or constructors hide field-by-field literals. Prefer + `Default + ..Default::default()` patterns and document new fields with + rustdoc. +- **Generated code lives in `rust/src/generated/`** and must not be + hand-edited. Regenerate with `cd scripts/codegen && npm run generate:rust`. + When a generated type lacks a field the schema doesn't yet describe (e.g. + `Tool::overrides_built_in_tool`), hand-author the user-facing type in + `rust/src/types.rs` and stop re-exporting the generated one. +- **`#[expect(dead_code)]`** instead of `#[allow(dead_code)]` on individual + fields — it forces a cleanup once the field gets used. +- **`..Default::default()`** — avoid in production code (be explicit about + which fields you're setting); prefer it in tests and doc examples to keep + the focus on the values that matter for the test. 
+- **Import grouping** — three blocks separated by blank lines: + (1) `std`/`core`/`alloc`, (2) external crates, (3) + `crate::`/`super::`/`self::`. Enforced by nightly `cargo fmt` via + `rust/.rustfmt.nightly.toml`. +- **`pub(crate)` vs `pub`** — most modules in `lib.rs` are private (`mod`), so + `pub` items inside them are already crate-private. Use `pub(crate)` only when + you want to be explicit that an item must not become part of the public API. + +## Testing + +- **No mock testing.** Depend on real implementations, spin up lightweight + versions (e.g. `MockServer` in tests), or restructure code so the logic + under test takes its dependency's output as input. +- `assert_eq!(actual, expected)` — actual first, for readable diffs. +- Tests at end of file: `#[cfg(test)] mod tests`. Never place production code + after the test module. +- Keep tests concurrent-safe — unique temp dirs (`tempfile::tempdir()`), + unique data, no global state. +- `ApproveAllHandler` is the standard test handler for sessions that don't + exercise permission logic — see `rust/src/handler.rs:174`. + +## Cross-platform + +The SDK ships on macOS, Windows, and Linux; CI exercises all three. Construct +paths with `Path::join` rather than string concatenation — `/` and `\` are not +interchangeable, and string equality breaks on Windows UNC paths. Log paths +with `path.display()`; serialize with `to_string_lossy()` only when you need a +`String`. + +Process spawning needs care. The SDK applies `CREATE_NO_WINDOW` on Windows +when launching the CLI (see `Client::build_command`); preserve that if you +touch process spawning. Subprocess stdout often contains `\r` on Windows — strip +or split on `\r?\n` rather than assuming `\n`. + +Tests must use `tempfile::tempdir()`, never hardcoded `/tmp/`, and any test +that asserts on a path string needs to normalize separators or use +`std::path::MAIN_SEPARATOR`. + +## Build speed + +Specify Tokio features explicitly — never `features = ["full"]`. 
Iterate with +`cargo check`; reach for `cargo build` only when you need the binary. Audit +new dependency feature flags with `cargo tree` before committing. + +## Comments + +Explain **why**, never **what**. No comments that restate code. No decorative +banners (`// ── Section ────────`). + +**Never compare to other SDKs in code comments or rustdoc.** Don't write +"Mirrors Node's `Foo`", "Like Go's `Bar`", "Unlike Python's `Baz`", or include +file/line citations into other SDKs (`nodejs/src/types.ts:1592`, `go/types.go:14`). +The Rust SDK seeks parity with the Node, Python, Go, and .NET SDKs, and that +fact is stated once at the top of `rust/README.md`. Intentional divergences +live in the README's "Differences From Other SDKs" section. Repeating the +relationship per-symbol is unscalable, drifts as the other SDKs evolve, and +adds noise to consumer-facing rustdoc — Rust users care about the Rust API, +not its lineage. Self-references within the Rust crate (e.g. "Mirrors +[`from_streams`] but adds…") are fine. + +## Toolchain + +The SDK is pinned to `rust 1.94.0` via `rust/rust-toolchain.toml`. Formatting +uses nightly (`nightly-2026-04-14`) so unstable rustfmt options like grouped +imports work — see `rust/.rustfmt.nightly.toml`. CI runs: + +```bash +cd rust +cargo +nightly-2026-04-14 fmt --check +cargo clippy --all-features --all-targets -- -D warnings +cargo test --all-features +``` + +Match those exact commands locally before pushing. + +## Codegen + +JSON-RPC and session-event types are generated from the Copilot CLI schema: + +| Source | Output | +|---|---| +| `nodejs/node_modules/@github/copilot/schemas/api.schema.json` | `rust/src/generated/api_types.rs` | +| `nodejs/node_modules/@github/copilot/schemas/session-events.schema.json` | `rust/src/generated/session_events.rs` | + +Regenerate with: + +```bash +cd scripts/codegen && npm run generate:rust +``` + +Never hand-edit files under `rust/src/generated/`. 
If a generated type needs a field the schema lacks, hand-author the user-facing type in `rust/src/types.rs` and stop re-exporting the generated one.
diff --git a/.github/skills/rust-coding-skill/examples.md b/.github/skills/rust-coding-skill/examples.md
new file mode 100644
index 000000000..ef4d7b1a1
--- /dev/null
+++ b/.github/skills/rust-coding-skill/examples.md
@@ -0,0 +1,170 @@
+# Rust Coding Skill — Examples
+
+Patterns specific to the Rust SDK in this repo (`rust/`) that aren't obvious
+from general Rust knowledge.
+
+## Defining a tool
+
+### Anti-pattern — building the wire payload by hand
+
+```rust
+let raw = serde_json::json!({
+    "name": "get_weather",
+    "description": "...",
+    "parameters": { "type": "object", ... },
+});
+config.tools = Some(vec![serde_json::from_value(raw)?]);
+```
+
+### Preferred — implement `ToolHandler`, route via `ToolHandlerRouter`
+
+```rust
+use copilot::tool::{Tool, ToolHandler, ToolHandlerRouter, ToolInvocation, ToolResult};
+use copilot::Error;
+
+struct GetWeatherTool;
+
+#[async_trait::async_trait]
+impl ToolHandler for GetWeatherTool {
+    fn tool(&self) -> Tool {
+        Tool {
+            name: "get_weather".to_string(),
+            description: "Get the current weather for a city.".to_string(),
+            // ..Default::default() — leaves namespaced_name, instructions,
+            // overrides_built_in_tool, skip_permission at their defaults.
+            ..Default::default()
+        }
+    }
+
+    async fn call(&self, invocation: ToolInvocation) -> Result<ToolResult, Error> {
+        // ...
+        Ok(ToolResult::Text("...".into()))
+    }
+}
+
+use copilot::handler::ApproveAllHandler;
+use std::sync::Arc;
+
+let router = ToolHandlerRouter::new(
+    vec![Box::new(GetWeatherTool)],
+    Arc::new(ApproveAllHandler),
+);
+```
+
+## Spans for spawned event loops
+
+The session event loop is spawned per session. Always attach a span so events
+emitted inside it correlate.

### Anti-pattern — losing parent context

```rust
tokio::spawn(async move {
    while let Some(event) = rx.recv().await {
        info!("event {:?}", event); // No span — can't filter by session
    }
});
```

### Preferred — `error_span!` + `.instrument()`

```rust
use tracing::Instrument;

let span = tracing::error_span!("session_event_loop", session_id = %id);
tokio::spawn(async move {
    while let Some(event) = rx.recv().await {
        info!(event_type = ?event.kind, "session event");
    }
}.instrument(span));
```

## Concurrent permission handlers

`HandlerEvent::PermissionRequest` and `HandlerEvent::ExternalTool` are dispatched
on spawned tasks (see `rust/src/session.rs:973` and `:1022`). Implementations
must be safe for concurrent invocation.

### Anti-pattern — non-`Send` mutable state in the handler

```rust
struct MyHandler {
    last_request: std::cell::RefCell<Option<PermissionRequest>>, // not thread-safe
}
```

### Preferred — `parking_lot::Mutex` or atomics

```rust
struct MyHandler {
    last_request: parking_lot::Mutex<Option<PermissionRequest>>,
}
```

## Adding a field to a public struct

Adding a field to a public, exhaustive struct is a breaking change because
existing callers' struct literals stop compiling. Two patterns soften this:

### Pattern 1 — `Default` + `..Default::default()` in docs

```rust
#[derive(Default)]
pub struct Tool {
    pub name: String,
    pub description: String,
    // new field
    pub overrides_built_in_tool: bool,
}

// In docs and examples:
let t = Tool {
    name: "x".into(),
    description: "y".into(),
    ..Default::default()
};
```

### Pattern 2 — `#[non_exhaustive]` for types callers shouldn't construct

Use sparingly — only for types that are *only* meant to be received from the
SDK, never built by users.

```rust
#[non_exhaustive]
pub struct CreateSessionResult {
    pub session_id: SessionId,
    // ...
+} +``` + +## Test handler for non-permission scenarios + +When a test doesn't exercise the permission flow, use the SDK's built-in +`ApproveAllHandler` instead of writing a custom one: + +```rust +use copilot::handler::ApproveAllHandler; +use copilot::types::SessionConfig; +use std::sync::Arc; + +let session = client + .create_session(SessionConfig::default().with_handler(Arc::new(ApproveAllHandler))) + .await?; +``` + +## Regenerating types after a schema bump + +```bash +# 1. Update schema (usually arrives with @github/copilot package update) +cd nodejs && npm install @github/copilot@latest && cd .. + +# 2. Regenerate Rust types +cd scripts/codegen && npm run generate:rust + +# 3. Verify +cd ../../rust && cargo check --all-features +``` + +If a generated type changes shape, hand-fix any user-facing wrappers in +`rust/src/types.rs` rather than monkey-patching the generated file. diff --git a/.github/workflows/codegen-check.yml b/.github/workflows/codegen-check.yml index 9fd7f0542..d48b6a491 100644 --- a/.github/workflows/codegen-check.yml +++ b/.github/workflows/codegen-check.yml @@ -13,6 +13,7 @@ on: - 'python/copilot/generated/**' - 'go/generated_*.go' - 'go/rpc/**' + - 'rust/src/generated/**' - '.github/workflows/codegen-check.yml' workflow_dispatch: @@ -34,6 +35,24 @@ jobs: with: go-version: '1.22' + # Rust generator runs `cargo fmt` on the output, so we need a toolchain with rustfmt. + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: "1.94.0" + components: rustfmt + + # Nightly rustfmt for unstable format options (group_imports, + # imports_granularity, reorder_impl_items) — pinned in + # `rust/.rustfmt.nightly.toml`. The Rust generator emits unconsolidated + # imports under stable rustfmt; nightly fmt consolidates them to match + # the canonical committed form. 
+ - name: Install nightly rustfmt + uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly-2026-04-14 + components: rustfmt + - name: Install nodejs SDK dependencies working-directory: ./nodejs run: npm ci @@ -46,6 +65,10 @@ jobs: working-directory: ./scripts/codegen run: npm run generate + - name: Apply nightly rustfmt to generated Rust output + working-directory: ./rust + run: cargo +nightly-2026-04-14 fmt --all -- --config-path .rustfmt.nightly.toml + - name: Check for uncommitted changes run: | if [ -n "$(git status --porcelain)" ]; then diff --git a/.github/workflows/rust-publish-release.yml b/.github/workflows/rust-publish-release.yml new file mode 100644 index 000000000..348d2acf0 --- /dev/null +++ b/.github/workflows/rust-publish-release.yml @@ -0,0 +1,54 @@ +name: "Rust SDK: Publish Release" + +# Publishes the `copilot-sdk` crate to crates.io when a release-plz +# version-bump PR is merged to `main`. See rust/RELEASING.md for the +# full release process and one-time setup (CARGO_REGISTRY_TOKEN, etc). 
+ +on: + push: + branches: + - main + paths: + - 'rust/Cargo.toml' + - 'rust/Cargo.lock' + - 'rust/release-plz.toml' + workflow_dispatch: + +permissions: + contents: write + +concurrency: + group: rust-release-plz-publish + cancel-in-progress: false + +jobs: + publish: + name: Publish to crates.io + runs-on: ubuntu-latest + defaults: + run: + working-directory: ./rust + steps: + - uses: actions/checkout@v6.0.2 + with: + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: "1.94.0" + + - uses: Swatinem/rust-cache@v2 + with: + workspaces: "rust" + + - name: Run release-plz release + uses: release-plz/action@v0.5 + with: + command: release + manifest_path: rust/Cargo.toml + config: rust/release-plz.toml + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + CARGO_REGISTRY_TOKEN: ${{ secrets.CARGO_REGISTRY_TOKEN }} diff --git a/.github/workflows/rust-release-pr.yml b/.github/workflows/rust-release-pr.yml new file mode 100644 index 000000000..41420f3e4 --- /dev/null +++ b/.github/workflows/rust-release-pr.yml @@ -0,0 +1,56 @@ +name: "Rust SDK: Create Release PR" + +# release-plz opens a PR that bumps the `copilot-sdk` version in +# `rust/Cargo.toml` and updates `rust/CHANGELOG.md` based on +# conventional-commit history since the last `rust-vX.Y.Z` tag. +# +# Review and merge that PR on the maintainer's schedule. Publishing to +# crates.io happens separately in `rust-publish-release.yml` once the +# version bump lands on `main`. +# +# Runs manually only — we don't want a PR to race with every push. 
+ +on: + workflow_dispatch: + +permissions: + contents: write + pull-requests: write + +concurrency: + group: rust-release-plz-pr + cancel-in-progress: false + +jobs: + release-pr: + name: Create Release PR + runs-on: ubuntu-latest + defaults: + run: + working-directory: ./rust + steps: + - uses: actions/checkout@v6.0.2 + with: + fetch-depth: 0 + token: ${{ secrets.GITHUB_TOKEN }} + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: "1.94.0" + + - uses: Swatinem/rust-cache@v2 + with: + workspaces: "rust" + + - name: Run release-plz release-pr + uses: release-plz/action@v0.5 + with: + command: release-pr + manifest_path: rust/Cargo.toml + config: rust/release-plz.toml + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + # CARGO_REGISTRY_TOKEN is not required for release-pr (no publish), + # but release-plz inspects the crate on crates.io to compute the + # next version. Public crate inspection doesn't need auth. diff --git a/.github/workflows/rust-sdk-tests.yml b/.github/workflows/rust-sdk-tests.yml new file mode 100644 index 000000000..201841784 --- /dev/null +++ b/.github/workflows/rust-sdk-tests.yml @@ -0,0 +1,170 @@ +name: "Rust SDK Tests" + +on: + push: + branches: + - main + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + paths: + - 'rust/**' + - 'test/**' + - 'nodejs/package.json' + - '.github/workflows/rust-sdk-tests.yml' + - '.github/actions/setup-copilot/**' + - '!**/*.md' + - '!**/LICENSE*' + - '!**/.gitignore' + - '!**/.editorconfig' + - '!**/*.png' + - '!**/*.jpg' + - '!**/*.jpeg' + - '!**/*.gif' + - '!**/*.svg' + workflow_dispatch: + merge_group: + +permissions: + contents: read + +jobs: + test: + name: "Rust SDK Tests" + env: + POWERSHELL_UPDATECHECK: Off + CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash + working-directory: ./rust + 
steps: + - uses: actions/checkout@v6.0.2 + + - uses: ./.github/actions/setup-copilot + id: setup-copilot + + # rust-toolchain.toml in rust/ pins the stable channel + components. + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: "1.94.0" + components: rustfmt, clippy + + # Nightly rustfmt for unstable format options (group_imports, + # imports_granularity, reorder_impl_items) — pinned in + # `.rustfmt.nightly.toml`. + - name: Install nightly rustfmt + if: runner.os == 'Linux' + uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly-2026-04-14 + components: rustfmt + + - uses: Swatinem/rust-cache@v2 + with: + workspaces: "rust" + + - name: cargo fmt --check (nightly) + if: runner.os == 'Linux' + run: cargo +nightly-2026-04-14 fmt --all -- --config-path .rustfmt.nightly.toml --check + + - name: cargo clippy + if: runner.os == 'Linux' + run: cargo clippy --all-targets --features test-support -- --no-deps -D warnings -D clippy::unwrap_used -D clippy::disallowed_macros -D clippy::await_holding_invalid_type + + - name: cargo doc + if: runner.os == 'Linux' + env: + RUSTDOCFLAGS: "-D warnings" + run: cargo doc --no-deps --all-features + + - name: Install test harness dependencies + working-directory: ./test/harness + run: npm ci --ignore-scripts + + - name: Warm up PowerShell + if: runner.os == 'Windows' + run: pwsh.exe -Command "Write-Host 'PowerShell ready'" + + - name: cargo test + env: + COPILOT_HMAC_KEY: ${{ secrets.COPILOT_DEVELOPER_CLI_INTEGRATION_HMAC_KEY }} + COPILOT_CLI_PATH: ${{ steps.setup-copilot.outputs.cli-path }} + run: cargo test --features test-support + + # Detects accidental public-API breakage against the crate's last + # published version on crates.io. Non-blocking until the crate has + # a first published release — once a 0.1.0 ships, flip + # `continue-on-error` to `false` to enforce SemVer. 
+ - name: cargo semver-checks + if: runner.os == 'Linux' + continue-on-error: true + uses: obi1kenobi/cargo-semver-checks-action@v2 + with: + package: github-copilot-sdk + manifest-path: rust/Cargo.toml + + # Validates the `embedded-cli` build path on all three supported + # platforms. This is the only place `build.rs` actually runs (the + # default `cargo test` job above has `COPILOT_CLI_VERSION` unset, so + # `build.rs` returns immediately). Catches regressions in the + # download / verify / extract / embed pipeline before they ship to + # crates.io and before bundling consumers (e.g. github-app's + # bundled-CLI release pipeline) hit them downstream. + bundle: + name: "Rust SDK Bundled CLI Build" + env: + CARGO_TERM_COLOR: always + RUST_BACKTRACE: 1 + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash + working-directory: ./rust + steps: + - uses: actions/checkout@v6.0.2 + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: "1.94.0" + + - uses: Swatinem/rust-cache@v2 + with: + workspaces: "rust" + key: bundled-cli + + - name: Read pinned @github/copilot CLI version + id: cli-version + working-directory: ./nodejs + run: | + version=$(node -p "require('./package.json').dependencies['@github/copilot'].replace(/^[\^~]/, '')") + echo "version=$version" >> "$GITHUB_OUTPUT" + echo "Pinned CLI version: $version" + + # Cache the downloaded archive across runs so we don't refetch + # ~130 MB on every CI invocation. Keyed by OS + CLI version; on + # cache miss the bundle job exercises the full ureq download + + # SHA-256 + retry path, which is exactly the regression surface + # we want validated. 
+ - name: Cache bundled CLI tarball + uses: actions/cache@v4 + with: + path: ./rust/.bundled-cli-cache + key: bundled-cli-${{ matrix.os }}-${{ steps.cli-version.outputs.version }} + + - name: cargo build --features embedded-cli + env: + COPILOT_CLI_VERSION: ${{ steps.cli-version.outputs.version }} + BUNDLED_CLI_CACHE_DIR: ${{ github.workspace }}/rust/.bundled-cli-cache + run: cargo build --features embedded-cli diff --git a/.github/workflows/scenario-builds.yml b/.github/workflows/scenario-builds.yml index ae368075c..923560aba 100644 --- a/.github/workflows/scenario-builds.yml +++ b/.github/workflows/scenario-builds.yml @@ -9,6 +9,8 @@ on: - "python/copilot/**" - "go/**/*.go" - "dotnet/src/**" + - "rust/src/**" + - "rust/Cargo.toml" - ".github/workflows/scenario-builds.yml" push: branches: @@ -185,3 +187,46 @@ jobs: echo -e "Failures:$FAILURES" exit 1 fi + + # ── Rust ──────────────────────────────────────────────────────────── + build-rust: + name: "Rust scenarios" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + + - uses: dtolnay/rust-toolchain@1.94.0 + + - uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + test/scenarios/**/rust/target + key: ${{ runner.os }}-cargo-scenarios-${{ hashFiles('rust/Cargo.toml', 'test/scenarios/**/rust/Cargo.toml') }} + restore-keys: | + ${{ runner.os }}-cargo-scenarios- + + - name: Build all Rust scenarios + run: | + PASS=0; FAIL=0; FAILURES="" + for manifest in $(find test/scenarios -path '*/rust/Cargo.toml' | sort); do + dir=$(dirname "$manifest") + scenario="${dir#test/scenarios/}" + echo "::group::$scenario" + if (cd "$dir" && cargo build --quiet 2>&1); then + echo "✅ $scenario" + PASS=$((PASS + 1)) + else + echo "❌ $scenario" + FAIL=$((FAIL + 1)) + FAILURES="$FAILURES\n $scenario" + fi + echo "::endgroup::" + done + echo "" + echo "Rust builds: $PASS passed, $FAIL failed" + if [ "$FAIL" -gt 0 ]; then + echo -e "Failures:$FAILURES" + exit 1 + fi diff --git 
a/.github/workflows/update-copilot-dependency.yml b/.github/workflows/update-copilot-dependency.yml index a39d0575e..05833bf73 100644 --- a/.github/workflows/update-copilot-dependency.yml +++ b/.github/workflows/update-copilot-dependency.yml @@ -40,6 +40,22 @@ jobs: with: dotnet-version: "10.0.x" + # Rust generator runs `cargo fmt` on its output under stable rustfmt; + # nightly rustfmt is needed for unstable format options (group_imports, + # imports_granularity, reorder_impl_items) pinned in + # `rust/.rustfmt.nightly.toml`. See codegen-check.yml for the same step. + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@stable + with: + toolchain: "1.94.0" + components: rustfmt + + - name: Install nightly rustfmt + uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly-2026-04-14 + components: rustfmt + - name: Update @github/copilot in nodejs env: VERSION: ${{ inputs.version }} @@ -68,6 +84,7 @@ jobs: run: | cd nodejs && npx prettier --write "src/generated/**/*.ts" cd ../dotnet && dotnet format src/GitHub.Copilot.SDK.csproj + cd ../rust && cargo +nightly-2026-04-14 fmt --all -- --config-path .rustfmt.nightly.toml - name: Create pull request env: diff --git a/.gitignore b/.gitignore index a445051c6..ba3ebfcd0 100644 --- a/.gitignore +++ b/.gitignore @@ -3,5 +3,9 @@ docs/.validation/ .DS_Store +# Rust scenario build artifacts +test/scenarios/**/rust/target/ +test/scenarios/**/rust/Cargo.lock + # Visual Studio .vs/ diff --git a/justfile b/justfile index 5bb0ce0fa..ab97c1d3d 100644 --- a/justfile +++ b/justfile @@ -3,13 +3,13 @@ default: @just --list # Format all code across all languages -format: format-go format-python format-nodejs format-dotnet +format: format-go format-python format-nodejs format-dotnet format-rust # Lint all code across all languages -lint: lint-go lint-python lint-nodejs lint-dotnet +lint: lint-go lint-python lint-nodejs lint-dotnet lint-rust # Run tests for all languages -test: test-go test-python test-nodejs test-dotnet 
test-corrections +test: test-go test-python test-nodejs test-dotnet test-rust test-corrections # Format Go code format-go: @@ -71,6 +71,27 @@ test-dotnet: @echo "=== Testing .NET code ===" @cd dotnet && dotnet test test/GitHub.Copilot.SDK.Test.csproj +# Format Rust code (uses nightly for unstable formatting options) +format-rust: + @echo "=== Formatting Rust code ===" + @cd rust && cargo +nightly-2026-04-14 fmt --all -- --config-path .rustfmt.nightly.toml + +# Lint Rust code +lint-rust: + @echo "=== Linting Rust code ===" + @cd rust && cargo +nightly-2026-04-14 fmt --all -- --config-path .rustfmt.nightly.toml --check + @cd rust && cargo clippy --all-targets --features test-support -- --no-deps -D warnings -D clippy::unwrap_used -D clippy::disallowed_macros -D clippy::await_holding_invalid_type + +# Test Rust code +test-rust: + @echo "=== Testing Rust code ===" + @cd rust && cargo test --features test-support + +# Generate Rust types from JSON schemas +generate-rust: + @echo "=== Generating Rust types ===" + @cd scripts/codegen && npm run generate:rust + # Test correction collection scripts test-corrections: @echo "=== Testing correction scripts ===" diff --git a/nodejs/scripts/update-protocol-version.ts b/nodejs/scripts/update-protocol-version.ts index a18a560c7..ef3ac9a2f 100644 --- a/nodejs/scripts/update-protocol-version.ts +++ b/nodejs/scripts/update-protocol-version.ts @@ -117,4 +117,22 @@ internal static class SdkProtocolVersion fs.writeFileSync(path.join(rootDir, "dotnet", "src", "SdkProtocolVersion.cs"), csharpCode); console.log(" ✓ dotnet/src/SdkProtocolVersion.cs"); +// Generate Rust +const rustCode = `// Code generated by update-protocol-version.ts. DO NOT EDIT. + +//! The SDK protocol version. Must match the version expected by the +//! copilot-agent-runtime server. + +/// The SDK protocol version. +pub const SDK_PROTOCOL_VERSION: u32 = ${version}; + +/// Returns the SDK protocol version. 
+#[must_use] +pub const fn get_sdk_protocol_version() -> u32 { + SDK_PROTOCOL_VERSION +} +`; +fs.writeFileSync(path.join(rootDir, "rust", "src", "sdk_protocol_version.rs"), rustCode); +console.log(" ✓ rust/src/sdk_protocol_version.rs"); + console.log("Done!"); diff --git a/rust/.gitignore b/rust/.gitignore new file mode 100644 index 000000000..c17da7f58 --- /dev/null +++ b/rust/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock.bak diff --git a/rust/.rustfmt.nightly.toml b/rust/.rustfmt.nightly.toml new file mode 100644 index 000000000..677b79658 --- /dev/null +++ b/rust/.rustfmt.nightly.toml @@ -0,0 +1,7 @@ +# These options are only available in nightly, but it should be fine to use nightly for just formatting. +group_imports = "StdExternalCrate" +imports_granularity = "Module" +reorder_impl_items = true + +# stable options +edition = "2024" diff --git a/rust/.rustfmt.toml b/rust/.rustfmt.toml new file mode 100644 index 000000000..f3fb29261 --- /dev/null +++ b/rust/.rustfmt.toml @@ -0,0 +1,15 @@ +# This is not yet in stable, so we should keep an eye on it and enable it when it is. +# https://rust-lang.github.io/rustfmt/?version=v1.4.32&search=#group_imports +# In the mean time it is commented out because it will cause warnings. +#group_imports = "StdExternalCrate" + +# This is not yet in stable, so we should keep an eye on it and enable it when it is. +# https://rust-lang.github.io/rustfmt/?version=v1.5.1&search=#imports_granularity +# In the mean time it is commented out because it will cause warnings. +#imports_granularity = "Module" + +# This is not yet in stable, so we should keep an eye on it and enable it when it is. +# https://rust-lang.github.io/rustfmt/?version=v1.4.36&search=order#reorder_impl_items +# In the mean time it is commented out because it will cause warnings. 
+#reorder_impl_items = true +edition = "2024" diff --git a/rust/CHANGELOG.md b/rust/CHANGELOG.md new file mode 100644 index 000000000..d9f8439aa --- /dev/null +++ b/rust/CHANGELOG.md @@ -0,0 +1,480 @@ +# Changelog + +All notable changes to the `github-copilot-sdk` crate will be documented in this file. + +The format follows [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +After 0.1.0 ships, [release-plz](https://release-plz.dev/) will prepend new +entries from conventional-commit history. The Unreleased entry below is +hand-curated so that crates.io readers get a usable summary of the public +surface on first publish, not a flat list of merge commits — release-plz +will rename `[Unreleased]` to `[0.1.0] - ` and add a fresh empty +`[Unreleased]` above it when it cuts the first release PR. + +## [Unreleased] + +Initial public release. Programmatic Rust access to the GitHub Copilot CLI +over JSON-RPC 2.0 (stdio or TCP), with handler-based event dispatch, typed +tool/permission/elicitation helpers, and runtime session management. + +This is a **technical preview**. The crate is pre-1.0 and the public API may +change in breaking ways before 1.0. The rendered docs on +[docs.rs](https://docs.rs/github-copilot-sdk) are the canonical reference for the +public surface. + +### Added + +#### Client lifecycle +- `Client::start` — spawn and manage a GitHub Copilot CLI child process. +- `Client::from_streams` — connect to a CLI server over caller-supplied + `AsyncRead`/`AsyncWrite` (testing, custom transports). +- `Client::stop` / `Client::force_stop` — graceful and immediate shutdown. +- `Client::state` returning `ConnectionState` (`Connecting`, `Connected`, + `Disconnecting`, `Disconnected`). +- `Client::subscribe_lifecycle` returning a `LifecycleSubscription` for + runtime observation of created / destroyed / foreground / background + events. 
Implements `tokio_stream::Stream` and offers an inherent + `recv()`; drop the value to unsubscribe. +- `Client::ping(message)` returning typed `PingResponse` and + `Client::verify_protocol_version` for handshake validation. +- `Client::list_sessions`, `get_session_metadata`, `delete_session`, + `get_last_session_id`, `get_foreground_session_id`, + `set_foreground_session_id`. +- `Client::list_models`, `get_status` (typed `GetStatusResponse`), + `get_auth_status` (typed `GetAuthStatusResponse`), `get_quota`, + `send_telemetry`. + +#### Sessions +- `Client::create_session` and `Client::resume_session` accepting + `SessionConfig` with handler, capabilities, system message, mode, model, + permission policy, working directory, and resume parameters. +- `Session::send` returning the assigned message ID for + correlation with later events. +- `Session::send_and_wait` for synchronous prompt → final-event flows. +- `Session::subscribe` returning an `EventSubscription` for observe-only + access to the session's event stream. Implements `tokio_stream::Stream` + and offers an inherent `recv()`; drop the value to unsubscribe. +- Mode + model controls: `get_mode` / `set_mode`, `get_model` / + `set_model(model, SetModelOptions)` with `reasoning_effort` and + `model_capabilities` overrides. +- Plan helpers: `read_plan`, `delete_plan`. +- Workspace helpers: `list_workspace_files`, `read_workspace_file`, + `create_workspace_file`, `cwd`, `remote_url`. +- UI primitives: `session.ui().elicitation()`, `confirm()`, `select()`, + `input()` — grouped under a `SessionUi` sub-API to mirror .NET / Python / + Go. +- `Session::log(message, LogOptions)` with optional severity and + ephemeral flag. +- `Session::send_telemetry`, `start_fleet`, `abort`, + `set_approve_all_permissions`, `set_name`. +- `Session::disconnect` (canonical) and `Session::destroy` (alias) + preserve on-disk session state for later resume. +- `Session::stop_event_loop` for shutting down the per-session loop. 
+ +#### Handlers + helpers +- `SessionHandler` trait with default fallback impls for each event + (permissions, external tools, elicitation, plan-mode prompts). +- `ApproveAllHandler` / `DenyAllHandler` reference handlers. +- Permission policy helpers: `permission::approve_all`, + `permission::deny_all`, `permission::approve_if`, plus chainable + builders on `SessionConfig` (`approve_all_permissions`, + `deny_all_permissions`, `approve_if`). +- `PermissionResult` is `#[non_exhaustive]` and supports `Approved`, + `Denied`, `Deferred` (handler will resolve via + `handlePendingPermissionRequest` itself — notification path only; + direct RPC falls back to `Approved`), and + `Custom(serde_json::Value)` for response shapes beyond + `{ "kind": "approve-once" | "reject" }` (e.g. allowlist payloads). +- All extension-point and protocol-evolving public enums are + `#[non_exhaustive]` so future variants are additive (non-breaking): + `Error`, `ProtocolError`, `SessionError`, `Transport`, `Attachment`, + `ToolResult`, `ElicitationMode`, `InputFormat`, `GitHubReferenceType`, + `SessionLifecycleEventType`, plus the handler/hook event/response enums. + Closed taxonomies (`LogLevel`, `ConnectionState`, `CliProgram`) remain + exhaustive so callers benefit from compile-time exhaustiveness checks. +- Tool helpers: `tool::DefineTool`, `tool::tool_schema_for`, + `tool::ToolHandlerRouter`, derive support via `derive` feature. + `ToolHandlerRouter` overrides each `SessionHandler` per-event method + directly, so callers can use the narrow-typed entry points (e.g. + `router.on_external_tool(invocation).await -> ToolResult`) instead of + unwrapping a `HandlerResponse` from `on_event`. The default `on_event` + still routes correctly through the per-event methods, so legacy + callers are unaffected. +- Hooks API for instrumenting send/receive flows (`github_copilot_sdk::hooks`). 
+- `SessionHandler::on_auto_mode_switch` — typed handler for the CLI's + rate-limit-recovery prompt (`autoModeSwitch.request` JSON-RPC + callback, added in copilot-agent-runtime PR #7024). Returns a typed + [`AutoModeSwitchResponse`] enum with `Yes`, `YesAlways`, `No` + variants (`#[serde(rename_all = "snake_case")]`, wire values byte- + identical to the runtime's `"yes" | "yes_always" | "no"` schema). + Default impl declines (`No`); override only if your application + surfaces a UX for the prompt. `SessionConfig::request_auto_mode_switch` + and `ResumeSessionConfig::request_auto_mode_switch` default to + `Some(true)` so the CLI advertises the callback to the SDK out of the + box. **Cross-SDK divergence:** the typed handler is Rust-only as of 0.1.0. + Node, Python, Go, and .NET observe the request as a raw JSON-RPC + callback today; parity ports for those SDKs are post-release follow-up + work. +- New session-event fields surfaced by the `@github/copilot ^1.0.39` + schema bump: + - `SessionErrorData.eligible_for_auto_switch: Option<bool>` — set on + `errorType: "rate_limit"` to signal that the runtime will follow with an + `auto_mode_switch.requested` event. UI clients can suppress + duplicate rendering of the rate-limit error when they show their + own auto-mode-switch prompt. + - `SessionErrorData.error_code: Option<String>` — fine-grained + upstream provider error code (e.g. + `"user_weekly_rate_limited"`, `"integration_rate_limited"`). + - `SessionModelChangeData.cause: Option<String>` — + `"rate_limit_auto_switch"` for changes triggered by the + auto-mode-switch recovery path. Lets the UI render contextual copy. + - `AutoModeSwitchRequestedData.retry_after_seconds: Option` — + seconds until the rate limit resets, when known. Clients can + render a humanized reset time alongside the prompt.
(The request- + callback path's `retry_after_seconds` parameter on + [`SessionHandler::on_auto_mode_switch`](crate::handler::SessionHandler::on_auto_mode_switch) + uses `Option` for HTTP `Retry-After` `delta-seconds` + semantics.) + +#### Types +- Newtype `SessionId`, plus generated RPC types under `github_copilot_sdk::generated`. +- `LogLevel`, `LogOptions`, `SetModelOptions`, `PingResponse`, + `SessionLifecycleEvent`, `SessionLifecycleEventType`, `ConnectionState`, + `SessionTelemetryEvent`, `ServerTelemetryEvent`, `SystemMessageConfig`, + `MessageOptions`, `SectionOverride`, `Attachment`, + `InputFormat`, `InputOptions`. +- Strongly-typed `Error` and `ProtocolError` with `is_transport_failure` + classifier and `error_codes` constants. + +#### Typed RPC namespace +- `Client::rpc()` and `Session::rpc()` accessors exposing a generated, typed + view over the full GitHub Copilot CLI JSON-RPC API. Sub-namespaces mirror the + schema (e.g. `client.rpc().models().list()`, `session.rpc().workspaces() + .list_files()`, `session.rpc().agent().list()`, + `session.rpc().tasks().list()`). +- All hand-authored helpers (`list_workspace_files`, `read_plan`, `set_mode`, + `list_models`, `get_quota`, etc.) are now thin one-line delegations over + this namespace. Wire-method strings exist in exactly one place + (`generated/rpc.rs`), making typo bugs like the `session.workspace.*` + → `session.workspaces.*` regression structurally impossible. Public + helper signatures are unchanged. + +#### Configuration parity +- All remaining public configuration types are now `#[non_exhaustive]` + for forward-compatibility — adding fields post-1.0 is non-breaking on + consumers that construct via `Default::default()` plus field + assignment or the `with_*` builders. Affected: `SessionConfig`, + `ResumeSessionConfig`, `ClientOptions`, `ProviderConfig`, + `McpServerConfig`, `Tool`, `CustomAgentConfig`, + `InfiniteSessionConfig`, `SystemMessageConfig`, `ConnectionState`. 
+ (`HookEvent`, `HookOutput`, `MessageOptions`, `TelemetryConfig`, + `SessionFsConfig`, `FsError`, `FileInfo`, `DirEntry`, `ToolInvocation`, + `Error`, `Transport`, `DeliveryMode` were already marked.) Callers + using exhaustive struct literals must switch to + `let mut x = Type::default(); x.field = ...;` or the available `with_*` + builders; `..Default::default()` no longer compiles for these types + outside the defining crate. +- `MessageOptions::mode` is now typed `Option<DeliveryMode>` (was + `Option<String>`). `DeliveryMode` is `#[non_exhaustive]` and serializes + to the wire strings `"enqueue"` (default) and `"immediate"`. The prior + rustdoc incorrectly described this field as a permission mode; the + field controls how the prompt is delivered relative to in-flight work. + `MessageOptions::with_mode` now takes `DeliveryMode` directly. Callers + that previously passed `"agent"` or `"autopilot"` were already silently + no-ops at the CLI level — switch to a `DeliveryMode` variant or omit + the field entirely. +- `SessionConfig::default()` and `ResumeSessionConfig::new()` now set the + four permission-flow flags (`request_user_input`, `request_permission`, + `request_exit_plan_mode`, `request_elicitation`) to `Some(true)` instead + of `None`. Mirrors Node's `client.ts` behavior of always advertising the + permission surface and deriving handler presence from the + `SessionHandler` impl. The default `DenyAllHandler` refuses all + permission requests so the wire surface is safe out-of-the-box; callers + that want the wire surface fully disabled set the flags explicitly to + `Some(false)`. +- `SessionListFilter` — typed filter for `Client::list_sessions` covering + `cwd`, `git_root`, `repository`, and `branch`. Replaces the prior + `Option` parameter. +- `McpServerConfig` tagged enum (`Stdio` / `Http` / `Sse`) with + `McpStdioServerConfig` and `McpHttpServerConfig` payload structs.
+ `SessionConfig::mcp_servers`, `ResumeSessionConfig::mcp_servers`, and + `CustomAgentConfig::mcp_servers` are now `Option<HashMap<String, + McpServerConfig>>` instead of typeless `Value` maps. Stdio configurations + serialized by older callers (no explicit `type`, or `type: "local"`) are + accepted on the deserialize path. +- `PermissionRequestData` gains typed `kind: Option<PermissionRequestKind>` + and `tool_call_id: Option<String>` fields covering the eight CLI + permission categories (`shell`, `write`, `read`, `url`, `mcp`, + `custom-tool`, `memory`, `hook`); unknown values fall through to + `PermissionRequestKind::Unknown` for forward compatibility. The original + params object is still available via the existing `extra: Value` flatten. +- `PermissionResult` gains `UserNotAvailable` (sent as + `{ "kind": "user-not-available" }`) and `NoResult` (sent as + `{ "kind": "no-result" }`) variants for headless agents and explicit + fall-through-to-CLI-default responses. +- `Client::stop` cooperatively shuts down active sessions before killing + the CLI child: walks every session still registered with the client, + sends `session.destroy` for each, then kills the child. Errors from + per-session destroys and the terminal child-kill are collected into a + new `StopErrors` aggregate (`Result<(), StopErrors>`) instead of + short-circuiting on the first failure, mirroring the Node SDK's + `Error[]` return shape. `StopErrors` implements `std::error::Error` + and exposes `errors()` / `into_errors()` for inspection. Callers that + previously used `client.stop().await?` should switch to + `client.stop().await.ok();` (best-effort) or match on the aggregate. +- `ResumeSessionConfig::disable_resume: Option<bool>` — force-fail resume + if the session does not exist on disk, instead of silently starting a + new session.
+- `SessionConfig` and `ResumeSessionConfig` gain six configuration knobs + matching the Node SDK shape (Bucket B.1): + - `session_id: Option<SessionId>` (SessionConfig only — required on + resume, where it remains `SessionId`) — supply a custom session ID + instead of letting the CLI generate one. + - `working_directory: Option` — per-session cwd override, + independent of [`ClientOptions::cwd`](crate::ClientOptions::cwd). + - `config_dir: Option` — override the default configuration + directory location for this session. + - `model_capabilities: Option` — per-property + overrides for model capabilities, deep-merged over runtime defaults. + The same type was previously available only on + `SetModelOptions::model_capabilities`. + - `github_token: Option<String>` — per-session GitHub token. Distinct + from [`ClientOptions::github_token`], which authenticates the CLI + process; this token determines the GitHub identity used for content + exclusion, model routing, and quota checks for this session. The + field is redacted from the `Debug` output. + - `include_sub_agent_streaming_events: Option<bool>` — forward streaming + delta events from sub-agents to this connection (Node default: true). +- `ClientOptions` gains the simple subset of Node's + `CopilotClientOptions` knobs (Bucket B.2): + - `log_level: Option<LogLevel>` — typed enum (`None`, `Error`, `Warning`, + `Info`, `Debug`, `All`) replacing the previously hard-coded + `--log-level info` argument. When unset, the SDK still passes + `--log-level info` for parity with prior behavior. + - `session_idle_timeout_seconds: Option` — server-wide idle + timeout for sessions in seconds. When `Some(n)` with `n > 0`, the + SDK passes `--session-idle-timeout <n>`. `None` or `Some(0)` leaves + sessions running indefinitely (the CLI default).
+ - The Node knob `isChildProcess` (sub-CLI parent-stdio mode) and + `autoStart` (lazy-init pattern) are intentionally **not** ported — + `isChildProcess` requires a transport variant the Rust SDK does not + yet support; `autoStart` does not apply because [`Client::start`] is + a single explicit constructor rather than a deferred-init pattern. + - `on_list_models: Option<Arc<dyn ListModelsHandler>>` — BYOK escape + hatch matching Node's `onListModels`. When set, [`Client::list_models`] + returns the handler's result without making a `models.list` RPC. + `ListModelsHandler` is a new public `async_trait` (mirrors the shape + of `SessionHandler` / `SessionHooks`) with a single + `async fn list_models(&self) -> Result, Error>` method. + `ClientOptions` switched from `#[derive(Debug)]` to a manual `Debug` + impl that prints the handler as `` / `None` (same precedent as + `SessionConfig::handler` and `github_token`). +- `MessageOptions` gains `request_headers: Option<HashMap<String, String>>` + with a corresponding [`MessageOptions::with_request_headers`] builder + method, matching Node's `MessageOptions.requestHeaders` and Go's + `MessageOptions.RequestHeaders`. Custom HTTP headers are forwarded to + the CLI via the `requestHeaders` field on `session.send`. The field is + omitted from the wire when `None` or empty (matches Node's + `omitempty` semantics). +- Slash command registration: new [`CommandHandler`] async trait, + [`CommandDefinition`] (with `new`/`with_description` builders), and + [`CommandContext`] (`session_id`, `command`, `command_name`, `args`) + hand-authored in `crate::types`. `SessionConfig::commands` and + `ResumeSessionConfig::commands` accept a `Vec<CommandDefinition>` via + the new `with_commands` builder, matching Node's + `SessionConfig.commands`, Python's `SessionConfig.commands`, and Go's + `SessionConfig.Commands`.
The SDK serializes only `{name, description?}` + on the wire (handlers stay client-side), and dispatches incoming + `command.execute` events to the registered handler — acking with no + error on success, `error: <message>` on `Err`, and + `error: "Unknown command: <name>"` when the name is unregistered. + `CommandContext` and `CommandDefinition` are `#[non_exhaustive]` so + forward-compatible fields (e.g. aliases, completion providers) can land + without breaking callers. +- Custom session filesystem: new [`SessionFsProvider`] async trait, + [`SessionFsConfig`], [`FsError`], [`FileInfo`], [`DirEntry`], + [`DirEntryKind`], and [`SessionFsConventions`] in `crate::session_fs` + (also re-exported from `crate::types`). When [`ClientOptions::session_fs`] + is set, [`Client::start`] calls `sessionFs.setProvider` on the CLI to + delegate per-session filesystem operations to a provider supplied via + [`SessionConfig::with_session_fs_provider`] / + [`ResumeSessionConfig::with_session_fs_provider`]. Inbound `sessionFs.*` + requests dispatch to the provider; `FsError::NotFound` maps to the wire + `ENOENT` code and other `FsError` values map to `UNKNOWN`. + `From<std::io::Error>` is provided so handlers backed by `std::fs` / + `tokio::fs` can propagate errors with `?`. All trait methods have + default implementations returning `Err(FsError::Other("not supported"))`, + so providers only override the methods they need and forward-compatible + schema additions land without breaking existing implementations. + Diverges from Node/Python/Go's factory-closure pattern in favor of + direct `Arc<dyn SessionFsProvider>` registration. +- W3C Trace Context propagation: new [`TraceContext`] struct and + [`TraceContextProvider`] async trait in `crate::trace_context` (also + re-exported from `crate::types`).
Hybrid shape combines Node's + callback-based `onGetTraceContext` and Go's per-turn + `MessageOptions.Traceparent` / `Tracestate`: + [`ClientOptions::on_get_trace_context`] supplies an ambient provider that + injects `traceparent` / `tracestate` on `session.create`, + `session.resume`, and `session.send`, while + [`MessageOptions::with_traceparent`], [`MessageOptions::with_tracestate`], + and [`MessageOptions::with_trace_context`] override per-turn (override + wins; provider is not invoked when MessageOptions carries trace headers). + [`ToolInvocation`] is now `#[non_exhaustive]` and exposes inbound + `traceparent` / `tracestate` populated from `external_tool.requested` + events, plus a [`ToolInvocation::trace_context`] helper. Wire fields are + omitted when unset (matches Node/Go `omitempty` semantics). +- `ToolInvocation` and `SessionId` now derive `Default`. Production code + never constructs `ToolInvocation` literals (it's a CLI-emitted read-only + type), but downstream test scaffolding can now use + `ToolInvocation { tool_name: "...".into(), ..Default::default() }` and + absorb future `#[non_exhaustive]` field additions automatically. +- OpenTelemetry env-var passthrough: new [`TelemetryConfig`] struct and + [`OtelExporterType`] enum (both `#[non_exhaustive]`), wired on + [`ClientOptions::telemetry`]. When `Some(...)`, the SDK injects + `COPILOT_OTEL_ENABLED=true` plus `OTEL_EXPORTER_OTLP_ENDPOINT`, + `COPILOT_OTEL_FILE_EXPORTER_PATH`, `COPILOT_OTEL_EXPORTER_TYPE`, + `COPILOT_OTEL_SOURCE_NAME`, and + `OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT` into the spawned CLI + process — verbatim env-var names matching Node/Python/Go. Pure + passthrough: no `opentelemetry-rust` dependency; the CLI itself owns the + exporter. `exporter_type` is a typed enum (`OtlpHttp` / `File`) following + the [`LogLevel`](LogLevel) precedent for finite, enumerated CLI knobs; + serialized verbatim as `"otlp-http"` / `"file"`. 
User-supplied + `ClientOptions::env` continues to win over telemetry-injected values. + +### Documentation +- `README.md` with quickstart, architecture diagram, and feature matrix. +- Examples under `examples/`: `chat`, `hooks`, `tool_server`, + `lifecycle_observer`. +- `RELEASING.md` operational runbook for maintainers. + +#### Builder ergonomics +- `ClientOptions::new()` plus a chainable `with_*` builder per public + field (`with_program`, `with_prefix_args`, `with_cwd`, `with_env`, + `with_env_remove`, `with_extra_args`, `with_transport`, + `with_github_token`, `with_use_logged_in_user`, `with_log_level`, + `with_session_idle_timeout_seconds`, `with_list_models_handler`, + `with_session_fs`, `with_trace_context_provider`, `with_telemetry`). + Mirrors the existing [`MessageOptions::new`] / `with_*` shape and + closes the cross-crate ergonomics gap on `#[non_exhaustive]` — + external callers no longer need to write + `let mut opts = ClientOptions::default(); opts.field = ...;` for + every field they touch. Existing `ClientOptions::default()` and + mut-let-and-assign continue to work unchanged. +- `Tool::new(name)` plus `with_namespaced_name`, `with_description`, + `with_instructions`, `with_parameters`, `with_overrides_built_in_tool`, + `with_skip_permission` for tool definitions. Same rationale — + `Tool` is the most-instantiated `#[non_exhaustive]` type at consumer + call sites in real-world consumer code, where the + builder shape replaces the per-consumer `make_tool(name, desc, + params)` helper that consumers were writing to smooth over the + mut-let pattern. +- Per-field `with_*` builder methods on `SessionConfig` and + `ResumeSessionConfig` covering every public scalar, vector, and + optional-struct field (~30 new methods on each). Mirrors the + `ClientOptions` / `Tool` shape; existing closure-installing + chains (`with_handler`, `with_hooks`, `with_transform`, + `with_commands`, `with_session_fs_provider`, + `approve_all_permissions`, etc.) 
continue to work unchanged. The + primary win: external session-construction sites collapse from + `let mut cfg = ResumeSessionConfig::new(id); cfg.client_name = + Some("...".into()); cfg.streaming = Some(true); ...` (10-15 + lines per site) to a single fluent chain. +- Round out builder coverage on the remaining consumer-facing + config structs: `CustomAgentConfig::new(name, prompt)` plus + `with_display_name`, `with_description`, `with_tools`, + `with_mcp_servers`, `with_infer`, `with_skills`; + `InfiniteSessionConfig::new()` plus `with_enabled`, + `with_background_compaction_threshold`, + `with_buffer_exhaustion_threshold`; + `ProviderConfig::new(base_url)` plus `with_provider_type`, + `with_wire_api`, `with_api_key`, `with_bearer_token`, + `with_azure`, `with_headers`; `SystemMessageConfig::new()` plus + `with_mode`, `with_content`, `with_sections`; + `TelemetryConfig::new()` plus `with_otlp_endpoint`, + `with_file_path`, `with_exporter_type`, `with_source_name`, + `with_capture_content`. `TraceContext` also gains a symmetric + `new()` + `with_traceparent` pair alongside the existing + `from_traceparent` shorthand. +- Documented the direct-field-assignment escape hatch on + `SessionConfig` and `ResumeSessionConfig` for callers forwarding + `Option` values from upstream code (matches the + `http::request::Parts` / `hyper::Body::Builder` convention; per- + field `with_*_opt` setters intentionally omitted to keep the + primary API surface small). + +#### Build infrastructure +- `build.rs` no longer shells out to `curl` for the bundled-CLI + download. The `embedded-cli` feature now downloads the + `SHA256SUMS.txt` and platform tarball through `ureq` (rustls TLS, + pure-Rust, no system dependencies). Removes the implicit `curl`- + on-PATH requirement that previously broke the build on minimal + Windows / container environments. 
Includes bounded retries with + exponential backoff (1s/2s/4s) on transient failures (5xx, + connect/read timeouts, transport errors) — 4xx responses still + fail fast as before. + +### Fixed +- `Session::user_input` no longer double-dispatches when the CLI sends + both a `user_input.requested` notification (for observers) and a + `userInput.request` JSON-RPC call (the actual prompt) for the same + prompt. The notification path is now a no-op; the JSON-RPC path + remains authoritative. Matches Python / Go / .NET / Node SDK + behavior, all of which only register the JSON-RPC handler. Fixes + github/github-app#4249, where consumers saw duplicate `ask_user` + and `exit_plan` widgets on every prompt. +- `SessionUi::elicitation` (and the `confirm` / `select` / `input` + convenience helpers that delegate through it) now sends the user-supplied + JSON Schema as `requestedSchema` on the wire, matching the + `session.ui.elicitation` request shape that all other SDKs ship and that + this crate's own generated `UIElicitationRequest` type expects. The + hand-authored convenience layer was sending it as `schema`, so every UI + helper call was effectively dead — the CLI saw a missing required + `requestedSchema` field. The mock-server test for elicitation + round-tripped through the same misnamed field, so the bug slipped past + unit tests; the test now asserts on `requestedSchema` and explicitly + rejects a stray `schema` key. +- `Client::list_sessions` now wraps the optional filter under `params.filter` + on the wire, matching the `session.list` request shape that Node, Python, + Go, and .NET ship. The hand-authored implementation was flattening the + filter fields directly onto `params`, which the runtime silently ignored + — so `list_sessions(Some(filter))` was functionally equivalent to + `list_sessions(None)` in 0.0.x. 
Same class of bug as the elicitation + wire fix above: the existing mock-server test asserted on the flat shape + it observed rather than the schema's wrapped shape, so the bug + round-tripped through both ends. The test now asserts the wrapped path + (`params.filter.repository`) and explicitly rejects the flattened + fallback (`params.repository`). +- `Client::get_status` and `Client::get_auth_status` now use the + correct wire method names (`status.get` and `auth.getStatus`) + matching Node, Go, Python, and .NET. The hand-authored + implementation was sending `getStatus` and `getAuthStatus` — names + that aren't registered on the CLI runtime — so both calls would + have returned a "method not found" error (or a misleading no-such- + method log) instead of the expected status payload. Same class of + bug as the elicitation `requestedSchema` and `list_sessions` + filter-wrapping fixes above: the mock-server test for these + methods asserted on the wrong-name strings the implementation + used, so the bugs round-tripped through both ends. The test now + asserts on the canonical wire names AND explicitly rejects the + hand-authored aliases (`assert_ne!(request["method"], "getStatus")` + / `"getAuthStatus"`). + +### Notes +- Minimum supported Rust version (MSRV): 1.94.0 (pinned via + `rust-toolchain.toml`). +- No `Client::actual_port` accessor — this SDK is strictly stream-based, + so the concept doesn't apply. See `Client::from_streams` rustdoc. +- `cargo semver-checks` runs in `continue-on-error` mode for 0.1.0; will + flip to blocking once 0.1.0 is published and serves as the baseline. +- `infinite_sessions: Option` is wired on both + `SessionConfig` and `ResumeSessionConfig` and follows the same + default-omit-on-the-wire semantics as Node/Go: when `None`, the field + is skipped and the CLI applies its own default. No behavioral + divergence from the other SDKs. 
+- `Client::stop` returns `Result<(), StopErrors>` and now cooperatively + shuts down each active session via `session.destroy` before killing + the CLI child, aggregating all per-session and child-kill errors into + the returned `StopErrors`. See the entry under "Configuration parity" + above for the migration note. diff --git a/rust/Cargo.lock b/rust/Cargo.lock new file mode 100644 index 000000000..6f12279b8 --- /dev/null +++ b/rust/Cargo.lock @@ -0,0 +1,1775 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "adler2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d036a3c4ab069c7b410a2ce876bd74808d2d0888a82667669f8e783a898bf1" +dependencies = [ + "derive_arbitrary", +] + +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + +[[package]] +name = "bitflags" +version = "2.11.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" + +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bumpalo" +version = "3.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cc" +version = "1.2.61" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d" +dependencies = [ + "find-msvc-tools", + "jobserver", + "libc", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "cpufeatures" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + +[[package]] +name = "crypto-common" +version = "0.1.7" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" +dependencies = [ + "generic-array", + "typenum", +] + +[[package]] +name = "derive_arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e567bd82dcff979e4b03460c307b3cdc9e96fde3d73bed1496d2bc75d9dd62a" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", +] + +[[package]] +name = "dirs" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.48.0", +] + +[[package]] +name = "displaydoc" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "dyn-clone" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" + +[[package]] +name = "equivalent" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "fastrand" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f1f227452a390804cdb637b74a86990f2a7d7ba4b7d5693aac9b4dd6defd8d6" + +[[package]] +name = "filetime" +version = "0.2.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f98844151eee8917efc50bd9e8318cb963ae8b297431495d3f758616ea5c57db" +dependencies = [ + "cfg-if", + "libc", + "libredox", +] + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "flate2" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + +[[package]] +name = "form_urlencoded" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "futures-core" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" + +[[package]] +name = "futures-executor" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf29c38818342a3b26b5b923639e7b1f4a61fc5e76102d4b1981c6dc7a7579d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = 
"futures-sink" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" + +[[package]] +name = "futures-task" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" + +[[package]] +name = "futures-util" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" +dependencies = [ + "futures-core", + "futures-task", + "pin-project-lite", + "slab", +] + +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi 5.3.0", + "wasip2", +] + +[[package]] +name = "getrandom" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0de51e6874e94e7bf76d726fc5d13ba782deca734ff60d5bb2fb2607c7406555" +dependencies = [ + "cfg-if", + "libc", + "r-efi 6.0.0", + "wasip2", + "wasip3", +] + +[[package]] +name = "github-copilot-sdk" +version = "0.1.0" +dependencies = [ + "async-trait", + "dirs", + "flate2", + "parking_lot", + "regex", + "schemars", + "serde", + "serde_json", + "serial_test", + "sha2", + "tar", + "tempfile", + "thiserror 2.0.18", 
+ "tokio", + "tokio-stream", + "tracing", + "ureq", + "zip", + "zstd", +] + +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + +[[package]] +name = "hashbrown" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "icu_collections" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2984d1cd16c883d7935b9e07e44071dca8d917fd52ecc02c04d5fa0b5a3f191c" +dependencies = [ + "displaydoc", + "potential_utf", + "utf8_iter", + "yoke", + "zerofrom", + "zerovec", +] + +[[package]] +name = "icu_locale_core" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92219b62b3e2b4d88ac5119f8904c10f8f61bf7e95b640d25ba3075e6cac2c29" +dependencies = [ + "displaydoc", + "litemap", + "tinystr", + "writeable", + "zerovec", +] + +[[package]] +name = "icu_normalizer" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c56e5ee99d6e3d33bd91c5d85458b6005a22140021cc324cea84dd0e72cff3b4" +dependencies = [ + "icu_collections", + "icu_normalizer_data", + "icu_properties", + "icu_provider", + "smallvec", + "zerovec", +] + +[[package]] +name = "icu_normalizer_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da3be0ae77ea334f4da67c12f149704f19f81d1adf7c51cf482943e84a2bad38" + +[[package]] +name = "icu_properties" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"bee3b67d0ea5c2cca5003417989af8996f8604e34fb9ddf96208a033901e70de" +dependencies = [ + "icu_collections", + "icu_locale_core", + "icu_properties_data", + "icu_provider", + "zerotrie", + "zerovec", +] + +[[package]] +name = "icu_properties_data" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e2bbb201e0c04f7b4b3e14382af113e17ba4f63e2c9d2ee626b720cbce54a14" + +[[package]] +name = "icu_provider" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139c4cf31c8b5f33d7e199446eff9c1e02decfc2f0eec2c8d71f65befa45b421" +dependencies = [ + "displaydoc", + "icu_locale_core", + "writeable", + "yoke", + "zerofrom", + "zerotrie", + "zerovec", +] + +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + +[[package]] +name = "idna" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" +dependencies = [ + "idna_adapter", + "smallvec", + "utf8_iter", +] + +[[package]] +name = "idna_adapter" +version = "1.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb68373c0d6620ef8105e855e7745e18b0d00d3bdb07fb532e434244cdb9a714" +dependencies = [ + "icu_normalizer", + "icu_properties", +] + +[[package]] +name = "indexmap" +version = "2.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" +dependencies = [ + "equivalent", + "hashbrown 0.17.0", + "serde", + "serde_core", +] + +[[package]] +name = "itoa" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" + +[[package]] +name = "jobserver" +version = "0.1.34" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + +[[package]] +name = "libc" +version = "0.2.186" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" + +[[package]] +name = "libredox" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e02f3bb43d335493c96bf3fd3a321600bf6bd07ed34bc64118e9293bdffea46c" +dependencies = [ + "bitflags", + "libc", + "plain", + "redox_syscall 0.7.4", +] + +[[package]] +name = "linux-raw-sys" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" + +[[package]] +name = "litemap" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92daf443525c4cce67b150400bc2316076100ce0b3686209eb8cf3c31612e6f0" + +[[package]] +name = "lock_api" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" +dependencies = [ + "scopeguard", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "miniz_oxide" +version = "0.8.9" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" +dependencies = [ + "adler2", + "simd-adler32", +] + +[[package]] +name = "mio" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "once_cell" +version = "1.21.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" + +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + +[[package]] +name = "parking_lot" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93857453250e3077bd71ff98b6a65ea6621a19bb0f559a85248955ac12c45a1a" +dependencies = [ + "lock_api", + "parking_lot_core", +] + +[[package]] +name = "parking_lot_core" +version = "0.9.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall 0.5.18", + "smallvec", + "windows-link", +] + +[[package]] +name = "percent-encoding" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" + +[[package]] +name = "pin-project-lite" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a89322df9ebe1c1578d689c92318e070967d1042b512afbe49518723f4e6d5cd" + +[[package]] +name = "pkg-config" +version = "0.3.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"19f132c84eca552bf34cab8ec81f1c1dcc229b811638f9d283dceabe58c5569e" + +[[package]] +name = "plain" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" + +[[package]] +name = "potential_utf" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0103b1cef7ec0cf76490e969665504990193874ea05c85ff9bab8b911d0a0564" +dependencies = [ + "zerovec", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41f2619966050689382d2b44f664f4bc593e129785a36d6ee376ddf37259b924" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "r-efi" +version = "6.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8dcc9c7d52a811697d2151c701e0d08956f92b0e24136cf4cf27b57a6a0d9bf" + +[[package]] +name = "redox_syscall" +version = "0.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" +dependencies = [ + "bitflags", +] + +[[package]] +name = "redox_syscall" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"f450ad9c3b1da563fb6948a8e0fb0fb9269711c9c73d9ea1de5058c79c8d643a" +dependencies = [ + "bitflags", +] + +[[package]] +name = "redox_users" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" +dependencies = [ + "getrandom 0.2.17", + "libredox", + "thiserror 1.0.69", +] + +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.17", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustix" +version = 
"1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls" +version = "0.23.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef86cd5876211988985292b91c96a8f2d298df24e75989a43a3c73f2d4d8168b" +dependencies = [ + "log", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30a7197ae7eb376e574fe940d068c30fe0462554a3ddbe4eca7838e049c937a9" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + +[[package]] +name = "scc" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46e6f046b7fef48e2660c57ed794263155d713de679057f2d0c169bfc6e756cc" +dependencies = [ + "sdd", +] + +[[package]] +name = "schemars" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2b42f36aa1cd011945615b92222f6bf73c599a102a300334cd7f8dbeec726cc" +dependencies = [ + "dyn-clone", + "ref-cast", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d115b50f4aaeea07e79c1912f645c7513d81715d0420f8bc77a18c6260b307f" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn", +] + +[[package]] +name = "scopeguard" +version = "1.2.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" + +[[package]] +name = "sdd" +version = "3.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490dcfcbfef26be6800d11870ff2df8774fa6e86d047e3e8c8a76b25655e41ca" + +[[package]] +name = "semver" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a7852d02fc848982e0c167ef163aaff9cd91dc640ba85e263cb1ce46fae51cd" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "serial_test" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"911bd979bf1070a3f3aa7b691a3b3e9968f339ceeec89e08c280a8a22207a32f" +dependencies = [ + "futures-executor", + "futures-util", + "log", + "once_cell", + "parking_lot", + "scc", + "serial_test_derive", +] + +[[package]] +name = "serial_test_derive" +version = "3.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a7d91949b85b0d2fb687445e448b40d322b6b3e4af6b44a29b21d9a5f33e6d9" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "sha2" +version = "0.10.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + +[[package]] +name = "simd-adler32" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "703d5c7ef118737c72f1af64ad2f6f8c5e1921f818cdcb97b8fe6fc69bf66214" + +[[package]] +name = "slab" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c790de23124f9ab44544d7ac05d60440adc586479ce501c1d6d7da3cd8c9cf5" + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +dependencies = [ + "libc", + 
"windows-sys 0.61.2", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" + +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "synstructure" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tar" +version = "0.4.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22692a6476a21fa75fdfc11d452fda482af402c008cdbaf3476414e122040973" +dependencies = [ + "filetime", + "libc", + "xattr", +] + +[[package]] +name = "tempfile" +version = "3.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" +dependencies = [ + "fastrand", + "getrandom 0.4.2", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + +[[package]] +name = "thiserror" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52" +dependencies = [ + "thiserror-impl 1.0.69", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl 
2.0.18", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.69" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinystr" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8323304221c2a851516f22236c5722a72eaa19749016521d6dff0824447d96d" +dependencies = [ + "displaydoc", + "zerovec", +] + +[[package]] +name = "tokio" +version = "1.52.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b67dee974fe86fd92cc45b7a95fdd2f99a36a6d7b0d431a231178d3d670bbcc6" +dependencies = [ + "bytes", + "libc", + "mio", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "385a6cb71ab9ab790c5fe8d67f1645e6c450a7ce006a33de03daa956cf70a496" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-stream" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", + "tokio-util", +] + +[[package]] +name = "tokio-util" +version = "0.7.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ae9cec805b01e8fc3fd2fe289f89149a9b66dd16786abd8b19cfa7b48cb0098" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + 
+[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "typenum" +version = "1.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + +[[package]] +name = "ureq" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" +dependencies = [ + "base64", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "url", + "webpki-roots 0.26.11", +] + +[[package]] +name = "url" +version = "2.5.8" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff67a8a4397373c3ef660812acab3268222035010ab8680ec4215f38ba3d0eed" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", + "serde", +] + +[[package]] +name = "utf8_iter" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.3+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20064672db26d7cdc89c7798c48a0fdfac8213434a1186e5ef29fd560ae223d6" +dependencies = [ + "wit-bindgen 0.57.1", +] + +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen 0.51.0", +] + +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.7", +] + +[[package]] +name = "webpki-roots" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52f5ee44c96cf55f1b349600768e3ece3a8f26010c05265ab73f945bb1a2eb9d" +dependencies = [ + "rustls-pki-types", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 
0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + 
+[[package]] +name = "wit-bindgen" +version = "0.57.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ebf944e87a7c253233ad6766e082e3cd714b5d03812acc24c318f549614536e" + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck", + "indexmap", + "prettyplease", + "syn", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] + +[[package]] +name = "writeable" +version = "0.6.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ffae5123b2d3fc086436f8834ae3ab053a283cfac8fe0a0b8eaae044768a4c4" + +[[package]] +name = "xattr" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" +dependencies = [ + "libc", + "rustix", +] + +[[package]] +name = "yoke" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abe8c5fda708d9ca3df187cae8bfb9ceda00dd96231bed36e445a1a48e66f9ca" +dependencies = [ + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de844c262c8848816172cef550288e7dc6c7b7814b4ee56b3e1553f275f1858e" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerofrom" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11532158c46691caf0f2593ea8358fed6bbf68a0315e80aae9bd41fbade684a1" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zerotrie" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f9152d31db0792fa83f70fb2f83148effb5c1f5b8c7686c3459e361d9bc20bf" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", +] + +[[package]] +name = "zerovec" +version = "0.11.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "90f911cbc359ab6af17377d242225f4d75119aec87ea711a880987b18cd7b239" +dependencies = [ + "yoke", + "zerofrom", + "zerovec-derive", +] + +[[package]] +name = "zerovec-derive" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "625dc425cab0dca6dc3c3319506e6593dcb08a9f387ea3b284dbd52a92c40555" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zip" +version = "2.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fabe6324e908f85a1c52063ce7aa26b68dcb7eb6dbc83a2d148403c9bc3eba50" +dependencies = [ + "arbitrary", + "crc32fast", + "crossbeam-utils", + "displaydoc", + "flate2", + "indexmap", + "memchr", + "thiserror 2.0.18", + "zopfli", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" + +[[package]] +name = "zopfli" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f05cd8797d63865425ff89b5c4a48804f35ba0ce8d125800027ad6017d2b5249" +dependencies = [ + "bumpalo", + "crc32fast", + "log", + "simd-adler32", +] + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.16+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +dependencies = [ + "cc", + "pkg-config", 
+] diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 000000000..217a87cb7 --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,69 @@ +[package] +name = "github-copilot-sdk" +version = "0.1.0" +edition = "2024" +rust-version = "1.94.0" +description = "Rust SDK for programmatic control of the GitHub Copilot CLI via JSON-RPC. Technical preview, pre-1.0." +keywords = ["copilot", "github", "ai", "json-rpc", "sdk"] +categories = ["api-bindings", "development-tools"] +repository = "https://github.com/github/copilot-sdk" +homepage = "https://github.com/github/copilot-sdk" +documentation = "https://docs.rs/github-copilot-sdk" +readme = "README.md" +license = "MIT" +exclude = [ + "RELEASING.md", + "release-plz.toml", + "rust-toolchain.toml", + ".rustfmt.toml", + ".rustfmt.nightly.toml", + "clippy.toml", + ".gitignore", +] + +[lib] +name = "github_copilot_sdk" + +[features] +default = [] +embedded-cli = ["dep:sha2", "dep:zstd"] +derive = ["dep:schemars"] +test-support = [] + +# Build docs.rs documentation with all features so feature-gated APIs +# (e.g. `define_tool`, `schema_for`) appear and intra-doc links resolve. +# Mirror this locally with: `cargo doc --no-deps --all-features`. 
+[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] + +[dependencies] +async-trait = "0.1" +schemars = { version = "1", optional = true } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +thiserror = "2" +tokio = { version = "1", features = ["io-util", "sync", "rt", "process", "net", "time", "macros"] } +tokio-stream = { version = "0.1", features = ["sync"] } +tracing = "0.1" +dirs = "5" +parking_lot = "0.12" +regex = "1" +sha2 = { version = "0.10", optional = true } +zstd = { version = "0.13", optional = true } + +[dev-dependencies] +schemars = "1" +serial_test = "3" +tempfile = "3" +sha2 = "0.10" +tokio = { version = "1", features = ["rt-multi-thread"] } +zstd = "0.13" + +[build-dependencies] +flate2 = "1" +sha2 = "0.10" +tar = "0.4" +ureq = { version = "2", default-features = false, features = ["tls"] } +zip = { version = "2", default-features = false, features = ["deflate"] } +zstd = "0.13" diff --git a/rust/LICENSE b/rust/LICENSE new file mode 120000 index 000000000..ea5b60640 --- /dev/null +++ b/rust/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/rust/README.md b/rust/README.md new file mode 100644 index 000000000..8222f8630 --- /dev/null +++ b/rust/README.md @@ -0,0 +1,800 @@ +# GitHub Copilot CLI SDK for Rust + +A Rust SDK for programmatic access to the GitHub Copilot CLI. + +> **Note:** This SDK is in technical preview and may change in breaking ways. + +See [github/copilot-sdk](https://github.com/github/copilot-sdk) for the equivalent SDKs in TypeScript, Python, Go, and .NET. The Rust SDK seeks parity with those SDKs; see [Differences From Other SDKs](#differences-from-other-sdks) below for the small set of intentional divergences. 
+
+## Quick Start
+
+```rust,no_run
+use std::sync::Arc;
+use github_copilot_sdk::{Client, ClientOptions, SessionConfig};
+use github_copilot_sdk::handler::ApproveAllHandler;
+
+# async fn example() -> Result<(), github_copilot_sdk::Error> {
+let client = Client::start(ClientOptions::default()).await?;
+let session = client.create_session(
+    SessionConfig::default().with_handler(Arc::new(ApproveAllHandler)),
+).await?;
+let _message_id = session.send("Hello!").await?;
+session.disconnect().await?;
+client.stop().await.ok();
+# Ok(())
+# }
+```
+
+## Architecture
+
+```text
+Your Application
+    ↓
+  github_copilot_sdk::Client (manages CLI process lifecycle)
+    ↓
+  github_copilot_sdk::Session (per-session event loop + handler dispatch)
+    ↓ JSON-RPC over stdio or TCP
+  copilot --server --stdio
+```
+
+The SDK manages the CLI process lifecycle: spawning, health-checking, and graceful shutdown. Communication uses [JSON-RPC 2.0](https://www.jsonrpc.org/specification) over stdin/stdout with `Content-Length` framing (the same protocol used by LSP). TCP transport is also supported.
+
+## API Reference
+
+### Client
+
+```rust,ignore
+// Start a client (spawns CLI process)
+let client = Client::start(options).await?;
+
+// Create a new session
+let session = client.create_session(config.with_handler(handler)).await?;
+
+// Resume an existing session
+let session = client.resume_session(config.with_handler(handler)).await?;
+
+// Low-level RPC
+let result = client.call("method.name", Some(params)).await?;
+let response = client.send_request("method.name", Some(params)).await?;
+
+// Health check (echoes message back, returns typed PingResponse)
+let pong = client.ping("hello").await?;
+
+// Shutdown
+client.stop().await?;
+```
+
+**`ClientOptions`:**
+
+| Field | Type | Description |
+|---|---|---|
+| `program` | `CliProgram` | `Resolve` (default: auto-detect) or `Path(PathBuf)` (explicit) |
+| `prefix_args` | `Vec<OsString>` | Args before `--server` (e.g. script path for node) |
+| `cwd` | `PathBuf` | Working directory for CLI process |
+| `env` | `Vec<(OsString, OsString)>` | Environment variables for CLI process |
+| `env_remove` | `Vec<OsString>` | Environment variables to remove |
+| `extra_args` | `Vec<OsString>` | Extra CLI flags |
+| `transport` | `Transport` | `Stdio` (default), `Tcp { port }`, or `External { host, port }` |
+
+With the default `CliProgram::Resolve`, `Client::start()` automatically resolves the binary via `github_copilot_sdk::resolve::copilot_binary()` — checking `COPILOT_CLI_PATH`, the [embedded CLI](#embedded-cli), and then the system PATH. Use `CliProgram::Path(path)` to skip resolution.
+
+### Session
+
+Created via `Client::create_session` or `Client::resume_session`. Owns an internal event loop that dispatches events to the `SessionHandler`.
+
+```rust,ignore
+use github_copilot_sdk::MessageOptions;
+
+// Simple send — &str / String convert into MessageOptions automatically.
+// Returns the assigned message ID for correlation with later events.
+let _id = session.send("Fix the bug in auth.rs").await?; + +// Send with mode and attachments +let _id = session + .send( + MessageOptions::new("What's in this image?") + .with_mode("autopilot") + .with_attachments(attachments), + ) + .await?; + +// Message history +let messages = session.get_messages().await?; + +// Abort the current agent turn +session.abort().await?; + +// Model management +let model = session.get_model().await?; +session.set_model("claude-sonnet-4.5", None).await?; + +// Mode management (interactive, plan, autopilot) +let mode = session.get_mode().await?; +session.set_mode("autopilot").await?; + +// Workspace files +let files = session.list_workspace_files().await?; +let content = session.read_workspace_file("plan.md").await?; + +// Plan management +let (exists, content) = session.read_plan().await?; +session.update_plan("Updated plan content").await?; + +// Fleet (sub-agents) +session.start_fleet(Some("Implement the auth module")).await?; + +// Cleanup (preserves on-disk session state for later resume) +session.disconnect().await?; +``` + +#### Typed RPC namespace + +The ergonomic helpers above are convenience wrappers over a fully-typed +JSON-RPC namespace generated from the GitHub Copilot CLI schema. `Client::rpc()` +and `Session::rpc()` give direct access to every method on the wire, +including ones with no helper today, with strongly-typed request and +response structs. + +```rust,ignore +// Methods with helpers — wire strings live in one generated place. +let files = session.rpc().workspaces().list_files().await?.files; +let models = client.rpc().models().list().await?.models; + +// Methods with no helper — full schema-typed access. 
+let agents = session.rpc().agent().list().await?.agents; +let tasks = session.rpc().tasks().list().await?.tasks; +let forked = client + .rpc() + .sessions() + .fork(github_copilot_sdk::generated::api_types::SessionsForkRequest { + session_id: "session-id".to_string(), + from_message_id: None, + }) + .await?; +``` + +New RPCs land in the namespace immediately as the schema regenerates; +helpers are added on top only when an ergonomic story is worth the +maintenance. + +### SessionHandler + +Implement this trait to control how a session responds to CLI events. Two styles are supported: + +**1. Per-event methods (recommended).** Override only the callbacks you care about; every method has a safe default (permission → deny, user input → none, external tool → "no handler", elicitation → cancel, exit plan → default). This is the `serenity::EventHandler` pattern. + +```rust,ignore +use async_trait::async_trait; +use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionId}; + +struct MyHandler; + +#[async_trait] +impl SessionHandler for MyHandler { + async fn on_permission_request( + &self, + _sid: SessionId, + _rid: RequestId, + data: PermissionRequestData, + ) -> PermissionResult { + if data.extra.get("tool").and_then(|v| v.as_str()) == Some("view") { + PermissionResult::Approved + } else { + PermissionResult::Denied + } + } + + async fn on_session_event(&self, sid: SessionId, event: github_copilot_sdk::types::SessionEvent) { + println!("[{sid}] {}", event.event_type); + } +} +``` + +**2. Single `on_event` method.** Override `on_event` directly and `match` on `HandlerEvent` — useful for logging middleware, custom routing, or when you want one exhaustive dispatch point. 
+ +```rust,ignore +use github_copilot_sdk::handler::*; +use async_trait::async_trait; + +#[async_trait] +impl SessionHandler for MyRouter { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::SessionEvent { session_id, event } => { + println!("[{session_id}] {}", event.event_type); + HandlerResponse::Ok + } + HandlerEvent::PermissionRequest { .. } => { + HandlerResponse::Permission(PermissionResult::Approved) + } + HandlerEvent::UserInput { question, .. } => { + HandlerResponse::UserInput(Some(UserInputResponse { + answer: prompt_user(&question), + was_freeform: true, + })) + } + _ => HandlerResponse::Ok, + } + } +} +``` + +The default `on_event` dispatches to the per-event methods, so overriding `on_event` short-circuits them entirely — pick one style per handler. + +Events are processed serially per session — blocking in a handler method pauses that session's event loop (which is correct, since the CLI is also waiting for the response). Other sessions are unaffected. + +> **Note:** Notification-triggered events (`PermissionRequest` via `permission.requested`, `ExternalTool` via `external_tool.requested`) are dispatched on spawned tasks and may run concurrently with the serial event loop. See the trait-level docs on `SessionHandler` for details. + +### SessionConfig + +```rust,ignore +let config = SessionConfig { + model: Some("gpt-5".into()), + system_message: Some(SystemMessageConfig { + content: Some("Always explain your reasoning.".into()), + ..Default::default() + }), + request_elicitation: Some(true), // enable elicitation provider + ..Default::default() +}; +let session = client.create_session(config.with_handler(handler)).await?; +``` + +### Session Hooks + +Hooks intercept CLI behavior at lifecycle points — tool use, prompt submission, session start/end, and errors. Install a `SessionHooks` impl with [`SessionConfig::with_hooks`] — the SDK auto-enables `hooks` in `SessionConfig` when one is set. 
+ +```rust,ignore +use std::sync::Arc; +use github_copilot_sdk::hooks::*; +use async_trait::async_trait; + +struct MyHooks; + +#[async_trait] +impl SessionHooks for MyHooks { + async fn on_hook(&self, event: HookEvent) -> HookOutput { + match event { + HookEvent::PreToolUse { input, ctx } => { + if input.tool_name == "dangerous_tool" { + HookOutput::PreToolUse(PreToolUseOutput { + permission_decision: Some("deny".to_string()), + permission_decision_reason: Some("blocked by policy".to_string()), + ..Default::default() + }) + } else { + HookOutput::None // pass through + } + } + HookEvent::SessionStart { input, .. } => { + HookOutput::SessionStart(SessionStartOutput { + additional_context: Some("Extra system context".to_string()), + ..Default::default() + }) + } + _ => HookOutput::None, + } + } +} + +let session = client + .create_session( + config + .with_handler(handler) + .with_hooks(Arc::new(MyHooks)), + ) + .await?; +``` + +**Hook events:** `PreToolUse`, `PostToolUse`, `UserPromptSubmitted`, `SessionStart`, `SessionEnd`, `ErrorOccurred`. Each carries typed input/output structs. Return `HookOutput::None` for events you don't handle. + +### System Message Transforms + +Transforms customize system message sections during session creation. The SDK injects `action: "transform"` entries for each section ID your transform handles. 
+
+```rust,ignore
+use std::sync::Arc;
+use github_copilot_sdk::transforms::*;
+use async_trait::async_trait;
+
+struct MyTransform;
+
+#[async_trait]
+impl SystemMessageTransform for MyTransform {
+    fn section_ids(&self) -> Vec<String> {
+        vec!["instructions".to_string()]
+    }
+
+    async fn transform_section(
+        &self,
+        _section_id: &str,
+        content: &str,
+        _ctx: TransformContext,
+    ) -> Option<String> {
+        Some(format!("{content}\n\nAlways be concise."))
+    }
+}
+
+let session = client
+    .create_session(
+        config
+            .with_handler(handler)
+            .with_transform(Arc::new(MyTransform)),
+    )
+    .await?;
+```
+
+### Tool Registration
+
+Define client-side tools as named types with `ToolHandler`, then route them with `ToolHandlerRouter`. Enable the `derive` feature for `schema_for::<T>()` — it generates JSON Schema from Rust types via `schemars`.
+
+```rust,ignore
+use std::sync::Arc;
+use github_copilot_sdk::handler::ApproveAllHandler;
+use github_copilot_sdk::tool::{
+    schema_for, tool_parameters, JsonSchema, ToolHandler, ToolHandlerRouter,
+};
+use github_copilot_sdk::{Error, SessionConfig, Tool, ToolInvocation, ToolResult};
+use serde::Deserialize;
+use async_trait::async_trait;
+
+#[derive(Deserialize, JsonSchema)]
+struct GetWeatherParams {
+    /// City name
+    city: String,
+    /// Temperature unit
+    unit: Option<String>,
+}
+
+struct GetWeatherTool;
+
+#[async_trait]
+impl ToolHandler for GetWeatherTool {
+    fn tool(&self) -> Tool {
+        Tool {
+            name: "get_weather".to_string(),
+            namespaced_name: None,
+            description: "Get weather for a city".to_string(),
+            parameters: tool_parameters(schema_for::<GetWeatherParams>()),
+            instructions: None,
+        }
+    }
+
+    async fn call(&self, inv: ToolInvocation) -> Result<ToolResult, Error> {
+        let params: GetWeatherParams = serde_json::from_value(inv.arguments)?;
+        Ok(ToolResult::Text(format!("Weather in {}: sunny", params.city)))
+    }
+}
+
+// Build a router that dispatches tool calls by name
+let router = ToolHandlerRouter::new(
+    vec![Box::new(GetWeatherTool)],
+    Arc::new(ApproveAllHandler),
+);
+
+let config = SessionConfig {
tools: Some(router.tools()), + ..Default::default() +} +.with_handler(Arc::new(router)); +let session = client.create_session(config).await?; +``` + +Tools are named types (not closures) — visible in stack traces and navigable via "go to definition". The router implements `SessionHandler`, forwarding unrecognized tools and non-tool events to the inner handler. + +For trivial tools that don't need a named type, [`define_tool`](crate::tool::define_tool) collapses the definition to a single expression: + +```rust,ignore +use github_copilot_sdk::tool::{define_tool, JsonSchema, ToolHandlerRouter}; +use github_copilot_sdk::ToolResult; +use serde::Deserialize; + +#[derive(Deserialize, JsonSchema)] +struct GetWeatherParams { city: String } + +let router = ToolHandlerRouter::new( + vec![define_tool( + "get_weather", + "Get weather for a city", + |_inv, params: GetWeatherParams| async move { + Ok(ToolResult::Text(format!("Sunny in {}", params.city))) + }, + )], + Arc::new(ApproveAllHandler), +); +``` + +The closure receives the full [`ToolInvocation`](crate::types::ToolInvocation) alongside the deserialized parameters, so handlers that need `inv.session_id` or `inv.tool_call_id` for telemetry, streaming updates, or scoped lookups can use them directly. Use `_inv` when you don't need the metadata. + +Reach for the `ToolHandler` trait directly when you need shared state across multiple methods or want a named type that shows up by name in stack traces. + +### Permission Policies + +Set a permission policy directly on `SessionConfig` with the chainable builders. They wrap whatever handler you've installed (defaulting to `DenyAllHandler` if none) so only permission requests are intercepted; every other event flows through unchanged. 
+ +```rust,ignore +let session = client + .create_session( + SessionConfig::default() + .with_handler(Arc::new(my_handler)) + .approve_all_permissions(), + // or .deny_all_permissions() + // or .approve_permissions_if(|data| { + // data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") + // }) + ) + .await?; +``` + +> Call the policy method **after** `with_handler` — `with_handler` overwrites the handler field, so `approve_all_permissions().with_handler(...)` discards the wrap. + +For composing a policy onto a handler outside the builder chain (e.g. when wrapping a `ToolHandlerRouter` you've built elsewhere), the `permission` module exposes the same primitives as free functions: + +```rust,ignore +use github_copilot_sdk::permission; + +let router = ToolHandlerRouter::new(tools, Arc::new(MyHandler)); +let handler = permission::approve_all(Arc::new(router)); +// or permission::deny_all(...) / permission::approve_if(..., predicate) + +let session = client.create_session(config.with_handler(handler)).await?; +``` + +### Capabilities & Elicitation + +The SDK negotiates capabilities with the CLI after session creation. Enable elicitation to let the agent present structured UI dialogs (forms, URL prompts) to the user. + +```rust,ignore +let config = SessionConfig { + request_elicitation: Some(true), + ..Default::default() +}; +``` + +The handler receives `HandlerEvent::ElicitationRequest` with a message, optional JSON Schema for form fields, and an optional mode. Known modes include `Form` and `Url`, but the mode may be absent or an unknown future value. Return `HandlerResponse::Elicitation(result)`. + +### User Input Requests + +Some sessions ask the user free-form questions (or multiple-choice prompts) outside the elicitation flow. 
Implement `SessionHandler::on_user_input` and the SDK will forward `userInput.request` callbacks:
+
+```rust,ignore
+async fn on_user_input(
+    &self,
+    _session_id: SessionId,
+    question: String,
+    choices: Option<Vec<String>>,
+    _allow_freeform: Option<bool>,
+) -> Option<UserInputResponse> {
+    // Render `question` + `choices` to your UI, then:
+    Some(UserInputResponse {
+        answer: "Yes".to_string(),
+        was_freeform: false,
+    })
+}
+```
+
+Return `None` to signal "no answer available" (the CLI falls back to its own prompt). Enable via `SessionConfig::request_user_input` (defaults to `Some(true)`).
+
+### Slash Commands
+
+Register named commands so users can invoke them as `/name args` from the TUI:
+
+```rust,ignore
+use std::sync::Arc;
+use github_copilot_sdk::types::{CommandContext, CommandDefinition, CommandHandler};
+use async_trait::async_trait;
+
+struct DeployCommand;
+
+#[async_trait]
+impl CommandHandler for DeployCommand {
+    async fn on_command(&self, ctx: CommandContext) -> Result<(), github_copilot_sdk::Error> {
+        println!("deploy {}", ctx.args);
+        Ok(())
+    }
+}
+
+let mut config = SessionConfig::default();
+config.commands = Some(vec![
+    CommandDefinition::new("deploy", Arc::new(DeployCommand))
+        .with_description("Deploy the application"),
+]);
+```
+
+Only `name` and `description` are sent over the wire; the handler stays in your process. Returning `Err(_)` surfaces the message back through the TUI.
+ +### Streaming + +Set `streaming: true` to receive incremental delta events alongside finalized messages: + +```rust,ignore +let mut config = SessionConfig::default(); +config.streaming = Some(true); + +let mut events = session.subscribe(); +while let Ok(event) = events.recv().await { + match event.event_type.as_str() { + "assistant.message_delta" | "assistant.reasoning_delta" => { + if let Some(d) = event.data.get("delta").and_then(|v| v.as_str()) { + print!("{d}"); + } + } + "assistant.message" => println!(), // final + _ => {} + } +} +``` + +When streaming is off (the default), only the final `assistant.message` and `assistant.reasoning` events fire. Delta events arrive in order; concatenating their `delta` text payloads reproduces the final message. + +### Infinite Sessions + +Enable the SDK's session-store integration so conversations persist across CLI restarts and grow beyond the model's context window via automatic compaction: + +```rust,ignore +use github_copilot_sdk::types::InfiniteSessionConfig; + +let mut infinite = InfiniteSessionConfig::default(); +infinite.workspace_path = Some("/path/to/workspace".into()); + +let mut config = SessionConfig::default(); +config.infinite_sessions = Some(infinite); +``` + +The CLI emits `session.compaction_start` / `session.compaction_complete` events around each compaction. The session id remains stable across compactions; resume with `Client::resume_session` to pick up a prior conversation. Workspace state lives under `~/.copilot/session-state/{sessionId}` by default — override with `workspace_path` to relocate. 
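+
+Compaction is observable through the same `subscribe` channel used for streaming. A minimal sketch, assuming only the two event names documented above (the shape of each event's `data` payload is not assumed):
+
+```rust,ignore
+// Log compaction boundaries while the session runs elsewhere.
+let mut events = session.subscribe();
+tokio::spawn(async move {
+    while let Ok(event) = events.recv().await {
+        match event.event_type.as_str() {
+            "session.compaction_start" => eprintln!("compacting history..."),
+            "session.compaction_complete" => eprintln!("compaction complete"),
+            _ => {}
+        }
+    }
+});
+```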
+ +### Custom Providers (BYOK) + +Route model traffic through your own inference endpoint instead of GitHub's hosted models: + +```rust,ignore +use github_copilot_sdk::types::ProviderConfig; + +let mut provider = ProviderConfig::default(); +provider.provider_type = Some("openai".to_string()); +provider.base_url = "https://my-proxy.example.com/v1".to_string(); +provider.bearer_token = Some(std::env::var("OPENAI_API_KEY")?); + +let mut config = SessionConfig::default(); +config.provider = Some(provider); +``` + +Provider types include `"openai"`, `"azure"`, and `"anthropic"`. Set `wire_api` to `"completions"` or `"responses"` (OpenAI/Azure only). Custom headers go in `provider.headers`. The SDK forwards the configuration to the CLI verbatim — the CLI handles the upstream call, including authentication. + +### Telemetry + +Forward OpenTelemetry signals from the spawned CLI process to your collector: + +```rust,ignore +use github_copilot_sdk::{ClientOptions, OtelExporterType, TelemetryConfig}; + +let mut telem = TelemetryConfig::default(); +telem.exporter_type = Some(OtelExporterType::OtlpHttp); +telem.otlp_endpoint = Some("http://localhost:4318".to_string()); +telem.source_name = Some("my-app".to_string()); + +let mut opts = ClientOptions::default(); +opts.telemetry = Some(telem); +let client = Client::start(opts).await?; +``` + +The SDK injects the appropriate environment variables (`COPILOT_OTEL_EXPORTER_TYPE`, `OTEL_EXPORTER_OTLP_ENDPOINT`, ...) into the spawned CLI process. The SDK takes no OpenTelemetry dependency; the CLI itself owns the exporter pipeline. Caller-supplied `ClientOptions::env` entries override telemetry-injected values. 
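+
+For example, to pin the collector endpoint regardless of what the telemetry config injects, push an entry onto the `env` field of `ClientOptions` before starting the client (the endpoint URL here is illustrative):
+
+```rust,ignore
+use std::ffi::OsString;
+
+// Caller-supplied entries win over telemetry-injected variables.
+opts.env.push((
+    OsString::from("OTEL_EXPORTER_OTLP_ENDPOINT"),
+    OsString::from("http://collector.internal:4318"),
+));
+let client = Client::start(opts).await?;
+```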
+

### Progress Reporting (`send_and_wait`)

For simple request/response messaging where you block until the agent finishes:

```rust,ignore
use std::time::Duration;
use github_copilot_sdk::MessageOptions;

// Sends a message and blocks until session.idle or session.error
session
    .send_and_wait(
        MessageOptions::new("Fix the bug").with_wait_timeout(Duration::from_secs(120)),
    )
    .await?;
```

Default timeout is 60 seconds. Only one `send_and_wait` can be active per session — concurrent calls return an error.

### Newtypes

**`SessionId`** — a newtype wrapper around `String` that prevents accidentally passing workspace IDs or request IDs where session IDs are expected. Transparent serialization (`#[serde(transparent)]`), zero-cost `Deref`, and ergonomic comparisons with `&str` and `String`.

```rust,ignore
use github_copilot_sdk::SessionId;

let id = SessionId::new("sess-abc123");
assert_eq!(id, "sess-abc123"); // compare with &str
let raw: String = id.into_inner(); // unwrap when needed
```

### Error Handling

The SDK uses a typed error enum:

```rust,ignore
pub enum Error {
    Protocol(ProtocolError),            // JSON-RPC framing, CLI startup, version mismatch
    Rpc { code: i32, message: String }, // CLI returned an error response
    Session(SessionError),              // Session not found, agent error, timeout, conflicts
    Io(std::io::Error),                 // Transport I/O error
    Json(serde_json::Error),            // Serialization error
    BinaryNotFound { name, hint },      // CLI binary not found
}

// Check if the transport is broken (caller should discard the client)
if err.is_transport_failure() {
    client = Client::start(options).await?;
}
```

## Differences From Other SDKs

The Rust SDK aligns closely with the Node, Python, and Go SDKs but diverges
in a few places where Rust idiom or the type system gives a clearly better
shape, and exposes a small additional surface where the language affords
ergonomics the dynamically-typed SDKs don't.
+

### Shape divergence

- **`SessionFsProvider` registration is direct, not factory-closure.** Where
  Node/Python/Go accept a closure that the runtime calls on each
  session-create to build a fresh provider, the Rust SDK takes
  `Arc<dyn SessionFsProvider>` directly via
  [`SessionConfig::with_session_fs_provider`]. The factory pattern doesn't
  translate cleanly to Rust at the session-config call site — there is no
  `Session` value to thread in, and the SDK already prefers traits over
  boxed closures for handler-shaped APIs (`SessionHandler`, `SessionHooks`,
  `ToolHandler`).

```rust,ignore
use std::sync::Arc;
use github_copilot_sdk::session_fs::{SessionFsConfig, SessionFsConventions};

let mut options = ClientOptions::default();
options.session_fs = Some(SessionFsConfig::new(
    "/workspace",
    "/workspace/.copilot",
    SessionFsConventions::Posix,
));
let client = Client::start(options).await?;

let session = client
    .create_session(
        SessionConfig::default()
            .with_handler(Arc::new(ApproveAllHandler))
            .with_session_fs_provider(Arc::new(MyProvider::new())),
    )
    .await?;
```

See [`examples/session_fs.rs`](examples/session_fs.rs) for a complete
in-memory provider implementation.

### Rust-only API

A handful of conveniences exist only on the Rust SDK as of 0.1.0. These
are surface areas where Rust idiom (newtypes, enums, trait objects)
gives a clearly nicer shape than Node/Python/Go currently expose. Rust
gets to be Rust here — cross-SDK parity for these is a post-release
conversation, not a release blocker. None of these are deprecated and
none of them are scheduled for removal.

- **`Client::get_quota`** — top-level convenience wrapper for fetching
  account-level request quota snapshots. Rust-only as of 0.1.0; the other
  SDKs do not expose a client-level shortcut.
The underlying + `account.getQuota` JSON-RPC endpoint itself is available cross-SDK via + each SDK's typed `rpc()` namespace (Node + `client.rpc().account().getQuota()`, Python + `client.rpc().account.get_quota()`, Go + `client.Rpc().Account().GetQuota()`, .NET + `client.Rpc().Account().GetQuotaAsync()`), including in Rust at + `client.rpc().account().get_quota()`. +- **First-class `Session` convenience methods** — `set_mode` / `get_mode`, + `set_name` / `get_name`, `get_model`, `read_plan` / `update_plan` / + `delete_plan`, `start_fleet`, `list_workspace_files` / + `read_workspace_file` / `create_workspace_file`. The other SDKs require + the consumer to drive the typed JSON-RPC namespace directly for these. +- **`Client::send_telemetry` / `Session::send_telemetry`** — top-level + and session-scoped telemetry emission via `sendTelemetry` / + `session.sendTelemetry`. Other SDKs do not currently expose these RPC + endpoints in their public APIs (not even via the typed namespace). +- **Typed newtypes** — `SessionId` and `RequestId` are `#[serde(transparent)]` + newtypes around `String`, so the type system distinguishes a session + identifier from an arbitrary `String` at compile time. Node/Python/Go + use bare strings. +- **Permission policy builders** — `permission::approve_all`, + `permission::deny_all`, and `permission::approve_if(handler, predicate)` + in `crate::permission` provide composable, no-handler-needed permission + shortcuts that wrap an existing `SessionHandler`. Other SDKs require a + full handler implementation for these patterns. +- **`Client::from_streams`** — connect to a CLI server over arbitrary + caller-supplied `AsyncRead` / `AsyncWrite`. Useful for testing, + in-process embedding, or custom transports. Other SDKs are spawn-only + or fixed-stdio. +- **`enum Transport { Stdio, Tcp, External }`** — explicit, exhaustive + transport selector on `ClientOptions::transport`. Node/Python/Go rely + on conditional config field combinations instead. 
+
- **Split `prefix_args` / `extra_args`** on `ClientOptions` — separate
  arg vectors for "prepend before subcommand" vs "append after the
  built-in flags", giving precise control over CLI invocation order
  without string-splicing.
- **`SessionHandler::on_auto_mode_switch`** — typed handler for the CLI's
  rate-limit-recovery prompt (CLI's `autoModeSwitch.request` callback,
  added in copilot-agent-runtime PR #7024). Returns
  `AutoModeSwitchResponse::{Yes, YesAlways, No}`. Default impl declines.
  Cross-SDK parity is post-release follow-up — Node / Python / Go / .NET
  consumers currently observe the request as a raw event and must drive
  the wire response themselves.

## Layout

| File | Description |
|---|---|
| `lib.rs` | `Client`, `ClientOptions`, `CliProgram`, `Transport`, `Error` |
| `session.rs` | `Session` struct, event loop, `send`/`send_and_wait`, `Client::create_session`/`resume_session` |
| `subscription.rs` | `EventSubscription` / `LifecycleSubscription` (`Stream`-able observer handles for `subscribe()` / `subscribe_lifecycle()`) |
| `handler.rs` | `SessionHandler` trait, `HandlerEvent`/`HandlerResponse` enums, `ApproveAllHandler` |
| `hooks.rs` | `SessionHooks` trait, `HookEvent`/`HookOutput` enums, typed hook inputs/outputs |
| `transforms.rs` | `SystemMessageTransform` trait, section-level system message customization |
| `tool.rs` | `ToolHandler` trait, `ToolHandlerRouter`, `schema_for::<T>()` (with `derive` feature) |
| `types.rs` | CLI protocol types (`SessionId`, `SessionEvent`, `SessionConfig`, `Tool`, etc.) |
| `resolve.rs` | Binary resolution (`copilot_binary`, `node_binary`, `extended_path`) |
| `embeddedcli.rs` | Embedded CLI extraction (`embedded-cli` feature) |
| `router.rs` | Internal per-session event demux |
| `jsonrpc.rs` | Internal Content-Length framed JSON-RPC transport |

## Embedded CLI

By default, `copilot_binary()` searches `COPILOT_CLI_PATH`, the system PATH, and common install locations.
To **ship with a specific CLI version** embedded in the binary, set `COPILOT_CLI_VERSION` at build time: + +```bash +COPILOT_CLI_VERSION=1.0.15 cargo build +``` + +### How it works + +1. **Build time:** The SDK's `build.rs` detects `COPILOT_CLI_VERSION`, downloads the platform-appropriate archive from the [`github/copilot-cli` GitHub Releases](https://github.com/github/copilot-cli/releases) (`copilot-{platform}.tar.gz` on macOS/Linux, `.zip` on Windows), verifies the archive's SHA-256 against the release's `SHA256SUMS.txt`, extracts the `copilot` binary, compresses it with zstd, and embeds via `include_bytes!()`. No extra steps or tools needed — just the env var. + +2. **Runtime:** On the first call to `github_copilot_sdk::resolve::copilot_binary()`, the embedded binary is lazily extracted to `~/.cache/github-copilot-sdk-{version}/copilot` (or `copilot.exe` on Windows), SHA-256 verified, and cached. Subsequent calls return the cached path. + +3. **Dev builds:** Without the env var, `build.rs` does nothing. The binary is resolved from PATH as usual — zero friction. + +### Resolution priority + +`copilot_binary()` checks these sources in order: + +1. `COPILOT_CLI_PATH` environment variable +2. Embedded CLI (build-time, via `COPILOT_CLI_VERSION`) +3. System PATH + common install locations + +### Platforms + +Supported: `darwin-arm64`, `darwin-x64`, `linux-x64`, `linux-arm64`, `win32-x64`, `win32-arm64`. The target platform is auto-detected from `CARGO_CFG_TARGET_OS` and `CARGO_CFG_TARGET_ARCH` (cross-compilation works). + +## Features + +No features are enabled by default — the bare SDK resolves the CLI from `COPILOT_CLI_PATH` or the system PATH without pulling in additional feature-gated dependencies. + +| Feature | Default | Description | +|---|---|---| +| `embedded-cli` | — | Build-time CLI embedding via `COPILOT_CLI_VERSION` (adds `sha2`, `zstd`). Enable when you need to ship a self-contained binary with a pinned CLI version. 
|
| `derive` | — | `schema_for::<T>()` for generating JSON Schema from Rust types (adds `schemars`). Enable when defining [tool parameters](#tool-registration). |

```toml
# These examples use registry syntax for illustration; until the crate is
# published, use a path or git dependency instead.

# Minimal — resolve CLI from PATH
github-copilot-sdk = "0.1"

# Ship a pinned CLI version in your binary
github-copilot-sdk = { version = "0.1", features = ["embedded-cli"] }

# Derive JSON Schema for tool parameters
github-copilot-sdk = { version = "0.1", features = ["derive"] }

# Both
github-copilot-sdk = { version = "0.1", features = ["embedded-cli", "derive"] }
```
diff --git a/rust/RELEASING.md b/rust/RELEASING.md
new file mode 100644
index 000000000..5361591d2
--- /dev/null
+++ b/rust/RELEASING.md
@@ -0,0 +1,192 @@
# Releasing `github-copilot-sdk`

This document describes how to cut a release of the `github-copilot-sdk` Rust crate
and publish it to [crates.io]. It is the operational counterpart to the
workflow files under `../.github/workflows/rust-*.yml` (which run the actual
mechanics).

If you are adding code to the SDK, you do not need to read this. This is for
maintainers cutting a release.

[crates.io]: https://crates.io/crates/github-copilot-sdk

---

## TL;DR

1. Land your changes on `main` using conventional-commit messages.
2. Trigger the **Rust SDK: Create Release PR** workflow manually
   (`workflow_dispatch`).
3. Review and merge the PR that release-plz opens.
4. The **Rust SDK: Publish Release** workflow runs automatically when that
   PR merges, publishes to crates.io, tags `rust-vX.Y.Z`, and creates a
   GitHub Release.

The first 0.1.0 publish requires a one-time `CARGO_REGISTRY_TOKEN` secret
setup — see [First-time setup](#first-time-setup) below.

---

## How releases are cut

The crate uses [release-plz] in a two-step workflow.
Both steps run unattended
through GitHub Actions; the only manual step is reviewing and merging the
release PR.

[release-plz]: https://release-plz.dev/

### Step 1 — `release-plz release-pr`

Workflow: `.github/workflows/rust-release-pr.yml` (`workflow_dispatch` only).

When you trigger it, release-plz:

- Reads conventional-commit history since the last `rust-vX.Y.Z` tag.
- Decides the next version (patch / minor / major) per SemVer rules.
- Bumps `rust/Cargo.toml`'s `version` field.
- Renames `## [Unreleased]` in `rust/CHANGELOG.md` to `## [X.Y.Z] - <date>`
  and prepends a fresh empty `## [Unreleased]` above it.
- Opens a PR with those changes.

Review the PR. The CHANGELOG entry is the one users see on crates.io and on
the GitHub Release page, so make sure it reads well. Edit the PR directly if
the auto-generated entry needs tweaking.

> **First-publish note.** The hand-curated 0.1.0 entry currently lives
> under `## [Unreleased]` so release-plz will rename it cleanly on the
> first run. If release-plz instead generates a *second* entry from
> conventional commits and prepends it above the curated one (depends on
> the configured `body` template), delete the auto-generated stub in the
> PR and keep the curated entry — you only want one 0.1.0 section.

### Step 2 — `release-plz release` (publish)

Workflow: `.github/workflows/rust-publish-release.yml` (auto-runs on push
to `main` when `rust/Cargo.toml`, `rust/Cargo.lock`, or `rust/release-plz.toml`
changes).

When the release PR from step 1 merges, this workflow detects that
`rust/Cargo.toml`'s version is newer than the latest `rust-vX.Y.Z` tag and:

- Runs `cargo publish` to upload to crates.io.
- Creates a `rust-vX.Y.Z` git tag.
- Creates a GitHub Release with the CHANGELOG entry as the body.

The workflow is a no-op on non-release commits, so it's safe to run on every
push.

---

## First-time setup

Before the first 0.1.0 publish, complete this checklist exactly once:

1.
**Reserve the crate name.** Have a maintainer with crates.io 2FA log in
   and publish a minimal stub with `cargo publish`; publishing is what
   claims the name (crates.io has no separate name-reservation form). The
   owner account should be a service account (preferred) or a senior
   maintainer.
2. **Generate a scoped API token.** crates.io → Account Settings → API
   Tokens → New Token. Scope it to publish `github-copilot-sdk` *only* — do not
   issue an unscoped token.
3. **Add the secret.** GitHub repo Settings → Secrets and variables →
   Actions → New repository secret named `CARGO_REGISTRY_TOKEN`, value =
   the token from step 2.
4. **Rotation.** Rotate the token annually and whenever the maintainer set
   changes. There's no automated reminder for this — set a calendar event.

Until this checklist is complete, `cargo publish` in the workflow will fail.
That's intentional: it keeps accidental publishes from happening before the
repo is ready.

---

## Versioning policy

The crate follows [SemVer]. Pre-1.0 we treat **0.x.0** as breaking and
**0.x.y** as additive — same as the Rust ecosystem convention.

[SemVer]: https://semver.org/

One CI check defends the API surface:

- **`cargo semver-checks`** (`.github/workflows/rust-sdk-tests.yml`) —
  detects breaking changes against the latest *published* version on
  crates.io. Currently `continue-on-error: true` because there's no
  baseline yet. **Flip it to `false` after 0.1.0 ships** to make SemVer
  enforcement blocking.

For ad-hoc public-surface inspection, `cargo public-api -sss --features
derive,test-support` is handy — but the surface is not snapshotted in the
repo. The rendered docs on [docs.rs](https://docs.rs/github-copilot-sdk) are the
canonical reference; `cargo-semver-checks` is the gate.
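In consumer terms, cargo's default caret requirement already encodes the pre-1.0 convention above. A generic manifest fragment (illustrative, not specific to this repo's manifests):

```toml
# "0.1" means ^0.1: matches >=0.1.0, <0.2.0. An additive 0.1.y is picked
# up automatically; a breaking 0.2.0 requires an explicit opt-in.
github-copilot-sdk = "0.1"
```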
+ +For 0.x → 1.0, do an explicit API review pass (compare against the +language siblings under `../{nodejs,python,go,dotnet}/`), +remove anything `#[doc(hidden)]` you don't intend to keep public, and +write out the 1.0 commitment in the CHANGELOG. + +--- + +## Public-disclosure gate + +The Rust SDK release-prep work happens on `tclem/rust-sdk-release-prep` +and is held *unpushed* until product/comms gives explicit OK. Do not push +the branch, open a PR, or otherwise expose the work without that signal — +even if CI looks ready. + +Ways to keep moving without pushing: + +- Land work in local commits on the prep branch. +- Use `cargo publish --dry-run --allow-dirty` to validate package contents. +- Use `cargo public-api -sss --features derive,test-support` for ad-hoc + surface inspection. + +When the gate opens: + +1. Push `tclem/rust-sdk-release-prep`. +2. Open a PR titled "Rust SDK: prepare for 0.1.0 release" (or similar). +3. Once it merges, trigger the **Rust SDK: Create Release PR** workflow and + proceed with the publish flow above. + +--- + +## Manual publish (emergency only) + +If GitHub Actions is unavailable, a maintainer with crates.io credentials +can publish locally: + +```sh +cd rust + +# Verify the package contents first. +cargo publish --dry-run + +# Publish for real. +cargo publish + +# Tag and push. +git tag rust-v$(cargo metadata --no-deps --format-version=1 \ + | jq -r '.packages[] | select(.name=="github-copilot-sdk") | .version' | head -1) +git push origin --tags +``` + +Manual publishes skip the release-PR review step, so write the CHANGELOG +entry by hand before publishing and commit it on `main` first. 
+ +--- + +## Yanking a release + +If a published version contains a critical bug (security, data loss, panic +on common input), yank it from crates.io to prevent new installs: + +```sh +cargo yank --version X.Y.Z github-copilot-sdk +``` + +Yanking does *not* delete the version — existing `Cargo.lock` files keep +working — but it stops new resolutions from picking it. Follow up with a +patch release that fixes the bug, and add a note to the yanked version's +GitHub Release explaining why. + +Reverse with `cargo yank --undo --version X.Y.Z github-copilot-sdk` if the yank +was a mistake. diff --git a/rust/build.rs b/rust/build.rs new file mode 100644 index 000000000..22463c9a9 --- /dev/null +++ b/rust/build.rs @@ -0,0 +1,340 @@ +use std::io::Read; +use std::path::Path; +use std::time::Duration; + +use sha2::Digest; + +fn main() { + println!("cargo:rerun-if-env-changed=COPILOT_CLI_VERSION"); + println!("cargo:rerun-if-env-changed=BUNDLED_CLI_CACHE_DIR"); + println!("cargo::rustc-check-cfg=cfg(has_bundled_cli)"); + + let Ok(version) = std::env::var("COPILOT_CLI_VERSION") else { + return; + }; + + let Some(platform) = target_platform() else { + println!( + "cargo:warning=COPILOT_CLI_VERSION set but unsupported target platform, skipping CLI bundling" + ); + return; + }; + + let out_dir = std::env::var("OUT_DIR").expect("OUT_DIR is always set by cargo"); + let out = Path::new(&out_dir); + + let base_url = format!("https://github.com/github/copilot-cli/releases/download/v{version}"); + let cache_dir = std::env::var("BUNDLED_CLI_CACHE_DIR") + .ok() + .map(std::path::PathBuf::from); + + // Download SHA256SUMS and find the expected hash for our platform's tarball. + let asset_name = platform.asset_name; + println!("cargo:warning=Bundling GitHub Copilot CLI v{version} ({asset_name})"); + // Download checksums and find the expected hash for our platform's archive. 
+ let checksums_url = format!("{base_url}/SHA256SUMS.txt"); + let checksums = download_with_retry(&checksums_url); + let checksums_text = + std::str::from_utf8(&checksums).expect("checksums file is not valid UTF-8"); + let expected_hash = find_sha256_for_asset(checksums_text, asset_name); + + // Use a versioned cache key since copilot asset names don't include the version. + let cache_key = format!("v{version}-{asset_name}"); + + // Download the archive (or read from cache) and verify integrity. + let archive = cached_download( + &format!("{base_url}/{asset_name}"), + &cache_key, + &expected_hash, + &cache_dir, + ); + println!("cargo:warning=SHA-256 verified ({expected_hash})"); + + // Extract the binary from the archive. + let binary = extract_binary(&archive, platform.binary_name, platform.is_zip); + println!( + "cargo:warning=Extracted {} ({} bytes)", + platform.binary_name, + binary.len() + ); + + // Compress and embed. + let hash = sha256(&binary); + let compressed = zstd::encode_all(&binary[..], 19).expect("zstd compression failed"); + println!( + "cargo:warning=Compressed to {} bytes ({:.1}%)", + compressed.len(), + compressed.len() as f64 / binary.len() as f64 * 100.0 + ); + + std::fs::write(out.join("copilot_cli.zst"), &compressed) + .expect("failed to write copilot_cli.zst"); + + let hash_tokens: Vec = hash.iter().map(|b| format!("0x{b:02x}")).collect(); + let generated = format!( + r#"// Auto-generated by github-copilot-sdk build.rs. Do not edit. 
+pub(super) static CLI_BYTES: &[u8] = include_bytes!("copilot_cli.zst"); +pub(super) static CLI_HASH: [u8; 32] = [{}]; +pub(super) static CLI_VERSION: &str = "{version}"; +"#, + hash_tokens.join(", ") + ); + + std::fs::write(out.join("bundled_cli.rs"), generated).expect("failed to write bundled_cli.rs"); + + println!("cargo:rustc-cfg=has_bundled_cli"); +} + +struct Platform { + asset_name: &'static str, + binary_name: &'static str, + is_zip: bool, +} + +fn target_platform() -> Option { + let os = std::env::var("CARGO_CFG_TARGET_OS").ok()?; + let arch = std::env::var("CARGO_CFG_TARGET_ARCH").ok()?; + + match (os.as_str(), arch.as_str()) { + ("macos", "aarch64") => Some(Platform { + asset_name: "copilot-darwin-arm64.tar.gz", + binary_name: "copilot", + is_zip: false, + }), + ("macos", "x86_64") => Some(Platform { + asset_name: "copilot-darwin-x64.tar.gz", + binary_name: "copilot", + is_zip: false, + }), + ("linux", "x86_64") => Some(Platform { + asset_name: "copilot-linux-x64.tar.gz", + binary_name: "copilot", + is_zip: false, + }), + ("linux", "aarch64") => Some(Platform { + asset_name: "copilot-linux-arm64.tar.gz", + binary_name: "copilot", + is_zip: false, + }), + ("windows", "x86_64") => Some(Platform { + asset_name: "copilot-win32-x64.zip", + binary_name: "copilot.exe", + is_zip: true, + }), + ("windows", "aarch64") => Some(Platform { + asset_name: "copilot-win32-arm64.zip", + binary_name: "copilot.exe", + is_zip: true, + }), + _ => None, + } +} + +/// Read a file from the download cache, or download it (with retries) and save +/// to cache. Verifies SHA-256 on every path. Evicts stale/corrupt cache entries +/// automatically. Cache I/O failures are treated as cache misses — they never +/// break the build. 
+fn cached_download( + url: &str, + cache_key: &str, + expected_hash: &str, + cache_dir: &Option, +) -> Vec { + if let Some(dir) = cache_dir { + let cached_path = dir.join(cache_key); + if cached_path.is_file() { + match std::fs::read(&cached_path) { + Ok(data) if hex_sha256(&data) == expected_hash => { + println!( + "cargo:warning=Using cached archive: {}", + cached_path.display() + ); + return data; + } + Ok(_) => { + println!("cargo:warning=Cached archive hash mismatch, re-downloading"); + let _ = std::fs::remove_file(&cached_path); + } + Err(e) => { + println!( + "cargo:warning=Failed to read cache {}, re-downloading: {e}", + cached_path.display() + ); + } + } + } + } + + let data = download_with_retry(url); + let actual_hash = hex_sha256(&data); + if actual_hash != expected_hash { + panic!( + "Archive integrity check failed for {url}!\n expected: {expected_hash}\n actual: {actual_hash}\n \ + This could indicate a corrupted download or a supply-chain attack." + ); + } + + if let Some(dir) = cache_dir { + if let Err(e) = std::fs::create_dir_all(dir) { + println!( + "cargo:warning=Failed to create cache directory {}: {e}", + dir.display() + ); + } else { + let cached_path = dir.join(cache_key); + if let Err(e) = std::fs::write(&cached_path, &data) { + println!( + "cargo:warning=Failed to write cache file {}: {e}", + cached_path.display() + ); + } else { + println!("cargo:warning=Cached archive to: {}", cached_path.display()); + } + } + } + + data +} + +/// Maximum number of HTTP attempts (one initial + this many retries on transient errors). +const MAX_RETRIES: u32 = 3; + +/// Download `url` with bounded retries on transient network errors. Backoff is +/// exponential starting at 1s. 4xx responses fail fast; 5xx and connect/read +/// errors are retried. 
+fn download_with_retry(url: &str) -> Vec { + let mut attempt = 0u32; + loop { + attempt += 1; + match try_download(url) { + Ok(bytes) => return bytes, + Err(err) if err.transient && attempt <= MAX_RETRIES => { + let backoff = Duration::from_secs(1u64 << (attempt - 1)); + println!( + "cargo:warning=Transient download failure for {url} (attempt {attempt}/{}): {} — retrying in {}s", + MAX_RETRIES + 1, + err.message, + backoff.as_secs(), + ); + std::thread::sleep(backoff); + } + Err(err) => panic!("Failed to download {url}: {}", err.message), + } + } +} + +struct DownloadError { + message: String, + transient: bool, +} + +fn try_download(url: &str) -> Result, DownloadError> { + let agent = ureq::AgentBuilder::new() + .timeout_connect(Duration::from_secs(30)) + .timeout_read(Duration::from_secs(120)) + .build(); + + match agent.get(url).call() { + Ok(response) => { + let mut bytes = Vec::new(); + response + .into_reader() + .read_to_end(&mut bytes) + .map_err(|e| DownloadError { + message: format!("read error: {e}"), + transient: true, + })?; + Ok(bytes) + } + // 5xx — server-side, treat as transient. + Err(ureq::Error::Status(code, response)) if (500..600).contains(&code) => { + Err(DownloadError { + message: format!("HTTP {code} {}", response.status_text()), + transient: true, + }) + } + // 4xx — client-side, fail fast. + Err(ureq::Error::Status(code, response)) => Err(DownloadError { + message: format!("HTTP {code} {}", response.status_text()), + transient: false, + }), + // Transport-layer (DNS, connect, TLS, read timeout) — treat as transient. 
+ Err(ureq::Error::Transport(t)) => Err(DownloadError { + message: format!("transport error: {t}"), + transient: true, + }), + } +} + +fn find_sha256_for_asset(sums: &str, asset_name: &str) -> String { + for line in sums.lines() { + // Format: " " (two spaces) + if let Some((hash, name)) = line.split_once(" ") + && name.trim() == asset_name + { + return hash.trim().to_string(); + } + } + panic!("SHA256SUMS.txt does not contain an entry for {asset_name}"); +} + +fn extract_binary(archive_bytes: &[u8], binary_name: &str, is_zip: bool) -> Vec { + if is_zip { + extract_from_zip(archive_bytes, binary_name) + } else { + extract_from_tarball(archive_bytes, binary_name) + } +} + +fn extract_from_tarball(tarball: &[u8], binary_name: &str) -> Vec { + let gz = flate2::read::GzDecoder::new(tarball); + let mut archive = tar::Archive::new(gz); + + for entry in archive.entries().expect("failed to read tarball entries") { + let mut entry = entry.expect("failed to read tarball entry"); + let path = entry + .path() + .expect("entry has no path") + .to_string_lossy() + .to_string(); + if path == binary_name || path.ends_with(&format!("/{binary_name}")) { + let mut bytes = Vec::new(); + entry + .read_to_end(&mut bytes) + .expect("failed to read binary from tarball"); + return bytes; + } + } + + panic!("'{binary_name}' not found in tarball"); +} + +fn extract_from_zip(zip_bytes: &[u8], binary_name: &str) -> Vec { + // Minimal zip extraction — find the binary by name. + // The Windows assets are .zip files with just copilot.exe at the root. 
+ let cursor = std::io::Cursor::new(zip_bytes); + let mut archive = zip::ZipArchive::new(cursor).expect("failed to read zip archive"); + + for i in 0..archive.len() { + let mut file = archive.by_index(i).expect("failed to read zip entry"); + let name = file.name().to_string(); + if name == binary_name || name.ends_with(&format!("/{binary_name}")) { + let mut bytes = Vec::new(); + file.read_to_end(&mut bytes) + .expect("failed to read binary from zip"); + return bytes; + } + } + + panic!("'{binary_name}' not found in zip"); +} + +fn sha256(data: &[u8]) -> [u8; 32] { + let mut hasher = sha2::Sha256::new(); + hasher.update(data); + hasher.finalize().into() +} + +fn hex_sha256(data: &[u8]) -> String { + sha256(data).iter().map(|b| format!("{b:02x}")).collect() +} diff --git a/rust/clippy.toml b/rust/clippy.toml new file mode 100644 index 000000000..22781c472 --- /dev/null +++ b/rust/clippy.toml @@ -0,0 +1,8 @@ +await-holding-invalid-types = [ + { path = "tracing::span::Entered", reason = "generates incorrect spans when held across 'await' points" }, + { path = "tracing::span::EnteredSpan", reason = "generates incorrect spans when held across 'await' points" }, +] + +disallowed-macros = [ + { path = "tracing::instrument", reason = "tracing::instrument is error-prone. Use tracing::error_span! in the method body instead." }, +] diff --git a/rust/examples/chat.rs b/rust/examples/chat.rs new file mode 100644 index 000000000..37293c6bc --- /dev/null +++ b/rust/examples/chat.rs @@ -0,0 +1,122 @@ +//! Interactive chat with GitHub Copilot. +//! +//! Starts a GitHub Copilot CLI server, creates a session, and enters a read-eval-print +//! loop where each line you type is sent to the agent. Streaming is enabled so +//! response tokens print to stdout incrementally as they arrive. +//! +//! ```sh +//! cargo run -p github-copilot-sdk --example chat +//! 
``` + +use std::io::{self, BufRead, Write}; +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use github_copilot_sdk::handler::{ + HandlerEvent, HandlerResponse, PermissionResult, SessionHandler, UserInputResponse, +}; +use github_copilot_sdk::types::{MessageOptions, SessionConfig, SessionEvent}; +use github_copilot_sdk::{Client, ClientOptions}; + +/// Handler that prints assistant message deltas as they stream in +/// and auto-approves permissions. +struct ChatHandler; + +#[async_trait] +impl SessionHandler for ChatHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::SessionEvent { event, .. } => { + print_event(&event); + HandlerResponse::Ok + } + HandlerEvent::PermissionRequest { .. } => { + HandlerResponse::Permission(PermissionResult::Approved) + } + HandlerEvent::UserInput { question, .. } => { + // Prompt the user on behalf of the agent. + print!("\n[agent asks] {question}\n> "); + io::stdout().flush().ok(); + let answer = read_line().unwrap_or_default(); + HandlerResponse::UserInput(Some(UserInputResponse { + answer, + was_freeform: true, + })) + } + _ => HandlerResponse::Ok, + } + } +} + +fn print_event(event: &SessionEvent) { + match event.event_type.as_str() { + "assistant.message_delta" => { + let text = event + .data + .get("deltaContent") + .and_then(|c| c.as_str()) + .unwrap_or(""); + print!("{text}"); + io::stdout().flush().ok(); + } + "assistant.message" => { + // Final message — print a newline to terminate the streamed output. 
+ println!(); + } + "session.error" => { + let msg = event + .data + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("unknown error"); + eprintln!("\n[error] {msg}"); + } + _ => {} + } +} + +fn read_line() -> Option { + let stdin = io::stdin(); + let mut line = String::new(); + stdin.lock().read_line(&mut line).ok()?; + if line.is_empty() { + return None; // EOF + } + Some(line.trim_end_matches(&['\n', '\r'][..]).to_string()) +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let client = Client::start(ClientOptions::default()).await?; + + let config = { + let mut cfg = SessionConfig::default(); + cfg.streaming = Some(true); + cfg.with_handler(Arc::new(ChatHandler)) + }; + let session = client.create_session(config).await?; + + println!( + "Session {} started. Type a message (Ctrl-D to quit).\n", + session.id() + ); + + loop { + print!("> "); + io::stdout().flush().ok(); + + let Some(line) = read_line() else { break }; + if line.is_empty() { + continue; + } + + session + .send_and_wait(MessageOptions::new(line).with_wait_timeout(Duration::from_secs(120))) + .await?; + } + + println!("\nGoodbye."); + session.destroy().await?; + Ok(()) +} diff --git a/rust/examples/hooks.rs b/rust/examples/hooks.rs new file mode 100644 index 000000000..86f6ceadc --- /dev/null +++ b/rust/examples/hooks.rs @@ -0,0 +1,133 @@ +//! Session hooks for logging and auditing. +//! +//! Demonstrates `SessionHooks` to intercept lifecycle events — logging every +//! tool invocation, summarizing prompts, and recording session start/end +//! for audit purposes. +//! +//! ```sh +//! cargo run -p github-copilot-sdk --example hooks +//! 
``` + +use std::sync::Arc; +use std::time::Duration; + +use async_trait::async_trait; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::hooks::{ + HookEvent, HookOutput, PostToolUseOutput, PreToolUseOutput, SessionEndOutput, SessionHooks, + SessionStartOutput, +}; +use github_copilot_sdk::types::{MessageOptions, SessionConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +/// Hooks implementation that logs lifecycle events to stdout. +struct AuditHooks; + +#[async_trait] +impl SessionHooks for AuditHooks { + async fn on_hook(&self, event: HookEvent) -> HookOutput { + match event { + HookEvent::SessionStart { input, ctx } => { + println!( + "[audit] session {} started (source={}, cwd={})", + ctx.session_id, + input.source, + input.cwd.display(), + ); + HookOutput::SessionStart(SessionStartOutput { + additional_context: Some("You are being audited. Be concise.".to_string()), + ..Default::default() + }) + } + + HookEvent::PreToolUse { input, ctx } => { + println!( + "[audit] session {} — pre tool use: {} (args: {})", + ctx.session_id, input.tool_name, input.tool_args, + ); + // Example: deny a specific tool by name. 
+ if input.tool_name == "dangerous_tool" { + return HookOutput::PreToolUse(PreToolUseOutput { + permission_decision: Some("deny".to_string()), + permission_decision_reason: Some("blocked by audit policy".to_string()), + ..Default::default() + }); + } + HookOutput::None + } + + HookEvent::PostToolUse { input, ctx } => { + println!( + "[audit] session {} — post tool use: {} (result: {})", + ctx.session_id, input.tool_name, input.tool_result, + ); + HookOutput::PostToolUse(PostToolUseOutput::default()) + } + + HookEvent::UserPromptSubmitted { input, ctx } => { + println!( + "[audit] session {} — user prompt ({} chars)", + ctx.session_id, + input.prompt.len(), + ); + HookOutput::None + } + + HookEvent::SessionEnd { input, ctx } => { + println!( + "[audit] session {} ended (reason={})", + ctx.session_id, input.reason, + ); + HookOutput::SessionEnd(SessionEndOutput { + session_summary: Some("Audited session complete.".to_string()), + ..Default::default() + }) + } + + HookEvent::ErrorOccurred { input, ctx } => { + eprintln!( + "[audit] session {} — error in {}: {} (recoverable={})", + ctx.session_id, input.error_context, input.error, input.recoverable, + ); + HookOutput::None + } + + _ => HookOutput::None, + } + } +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let client = Client::start(ClientOptions::default()).await?; + + // hooks: true is set automatically when a hooks handler is provided. + let config = SessionConfig::default() + .with_handler(Arc::new(ApproveAllHandler)) + .with_hooks(Arc::new(AuditHooks)); + let session = client.create_session(config).await?; + + println!( + "Session {} with audit hooks. 
Sending a message...\n", + session.id() + ); + + let response = session + .send_and_wait( + MessageOptions::new("Say hello in three languages.") + .with_wait_timeout(Duration::from_secs(60)), + ) + .await?; + + if let Some(event) = response { + let text = event + .data + .get("content") + .and_then(|c| c.as_str()) + .unwrap_or(""); + println!("\n{text}"); + } + + session.destroy().await?; + Ok(()) +} diff --git a/rust/examples/lifecycle_observer.rs b/rust/examples/lifecycle_observer.rs new file mode 100644 index 000000000..612792073 --- /dev/null +++ b/rust/examples/lifecycle_observer.rs @@ -0,0 +1,120 @@ +//! Observe lifecycle and event traffic without owning permission decisions. +//! +//! Demonstrates the channel-based observer APIs: +//! +//! - [`Client::subscribe_lifecycle`] — `tokio::sync::broadcast::Receiver` of +//! every `session.lifecycle` notification (created / destroyed / errored / +//! foreground / background). Filter by matching on `event.event_type` in +//! the consumer. +//! - [`Session::subscribe`] — receiver for the per-session `session.event` +//! stream (assistant messages, tool calls, permission prompts, etc.). +//! Observe-only — the constructor handler still owns permission decisions. +//! - [`Client::state`] — current connection state without polling. +//! - [`Client::get_session_metadata`] — inspect a session without resuming +//! it. +//! - [`Client::force_stop`] — synchronous shutdown for cleanup paths. +//! +//! Drop the receiver to unsubscribe — there is no separate cancel handle. +//! Slow consumers receive `RecvError::Lagged(n)` and resync on the next +//! event; they do not block the producer. +//! +//! ```sh +//! cargo run -p github-copilot-sdk --example lifecycle_observer +//! ``` +//! +//! [`Client::subscribe_lifecycle`]: github_copilot_sdk::Client::subscribe_lifecycle +//! [`Session::subscribe`]: github_copilot_sdk::session::Session::subscribe +//! [`Client::state`]: github_copilot_sdk::Client::state +//! 
[`Client::get_session_metadata`]: github_copilot_sdk::Client::get_session_metadata +//! [`Client::force_stop`]: github_copilot_sdk::Client::force_stop + +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::time::Duration; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{MessageOptions, SessionConfig, SessionLifecycleEventType}; +use github_copilot_sdk::{Client, ClientOptions}; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let client = Client::start(ClientOptions::default()).await?; + println!("[client] state: {:?}", client.state()); + + // Wildcard lifecycle subscriber: see every session.lifecycle event, + // counting deletions inline by filtering on event_type. + let mut lifecycle_rx = client.subscribe_lifecycle(); + let deleted = Arc::new(AtomicUsize::new(0)); + let deleted_clone = Arc::clone(&deleted); + let lifecycle_task = tokio::spawn(async move { + while let Ok(event) = lifecycle_rx.recv().await { + let summary = event + .metadata + .as_ref() + .and_then(|m| m.summary.as_deref()) + .unwrap_or(""); + println!( + "[lifecycle:*] {:?} session={} summary={}", + event.event_type, event.session_id, summary, + ); + if event.event_type == SessionLifecycleEventType::Deleted { + deleted_clone.fetch_add(1, Ordering::Relaxed); + } + } + }); + + let config = SessionConfig::default().with_handler(Arc::new(ApproveAllHandler)); + let session = client.create_session(config).await?; + println!("[client] state after create: {:?}", client.state()); + + // Per-session observer: see every assistant message, tool call, etc. + // Subscribers fire alongside the constructor handler; they're great for + // logging or metrics that should run regardless of how the handler + // decides to respond. 
+ let mut session_rx = session.subscribe(); + let session_events = Arc::new(AtomicUsize::new(0)); + let session_events_clone = Arc::clone(&session_events); + let session_task = tokio::spawn(async move { + while let Ok(event) = session_rx.recv().await { + session_events_clone.fetch_add(1, Ordering::Relaxed); + println!("[session-event] {}", event.event_type); + } + }); + + if let Some(metadata) = client.get_session_metadata(session.id()).await? { + println!( + "[metadata] id={} modified={} summary={}", + metadata.session_id, + metadata.modified_time, + metadata.summary.as_deref().unwrap_or(""), + ); + } + + session + .send_and_wait( + MessageOptions::new("Say hello in five words or fewer.") + .with_wait_timeout(Duration::from_secs(60)), + ) + .await?; + + session.destroy().await?; + + // Synchronous shutdown — useful in panicking-cleanup paths or tests + // where you don't have an async runtime available to await `stop()`. + // For graceful shutdown in normal flow, prefer `client.stop().await`. + client.force_stop(); + println!("[client] state after force_stop: {:?}", client.state()); + + // Stopping the client closes the broadcast senders, so the consumer + // tasks observe `RecvError::Closed` and exit cleanly. + let _ = lifecycle_task.await; + let _ = session_task.await; + + println!( + "\n[summary] session_events={} sessions_deleted={}", + session_events.load(Ordering::Relaxed), + deleted.load(Ordering::Relaxed), + ); + + Ok(()) +} diff --git a/rust/examples/session_fs.rs b/rust/examples/session_fs.rs new file mode 100644 index 000000000..0dbbb3414 --- /dev/null +++ b/rust/examples/session_fs.rs @@ -0,0 +1,139 @@ +//! Custom `SessionFsProvider` backed by an in-memory map. +//! +//! Demonstrates registering a [`SessionFsProvider`] so the CLI delegates all +//! per-session filesystem operations to your code. Useful for sandboxed +//! sessions, projecting files into virtual storage, or applying permission +//! policies before bytes are read or written. +//! +//! 
```sh +//! cargo run -p github-copilot-sdk --example session_fs +//! ``` + +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::session_fs::{ + DirEntry, DirEntryKind, FileInfo, FsError, SessionFsConfig, SessionFsConventions, + SessionFsProvider, +}; +use github_copilot_sdk::types::{MessageOptions, SessionConfig}; +use github_copilot_sdk::{Client, ClientOptions}; +use parking_lot::Mutex; + +struct InMemoryProvider { + files: Mutex<HashMap<String, String>>, +} + +impl InMemoryProvider { + fn new() -> Self { + let mut seed = HashMap::new(); + seed.insert( + "/workspace/README.md".to_string(), + "# Demo project\n\nThis file lives in memory.\n".to_string(), + ); + Self { + files: Mutex::new(seed), + } + } +} + +#[async_trait] +impl SessionFsProvider for InMemoryProvider { + async fn read_file(&self, path: &str) -> Result<String, FsError> { + self.files + .lock() + .get(path) + .cloned() + .ok_or_else(|| FsError::NotFound(path.to_string())) + } + + async fn write_file( + &self, + path: &str, + content: &str, + _mode: Option<String>, + ) -> Result<(), FsError> { + self.files + .lock() + .insert(path.to_string(), content.to_string()); + Ok(()) + } + + async fn exists(&self, path: &str) -> Result<bool, FsError> { + Ok(self.files.lock().contains_key(path)) + } + + async fn stat(&self, path: &str) -> Result<FileInfo, FsError> { + let files = self.files.lock(); + let content = files + .get(path) + .ok_or_else(|| FsError::NotFound(path.to_string()))?; + Ok(FileInfo::new( + true, + false, + content.len() as i64, + "2025-01-01T00:00:00Z", + "2025-01-01T00:00:00Z", + )) + } + + async fn readdir_with_types(&self, path: &str) -> Result<Vec<DirEntry>, FsError> { + let prefix = if path.ends_with('/') { + path.to_string() + } else { + format!("{path}/") + }; + let names: Vec<DirEntry> = self + .files + .lock() + .keys() + .filter_map(|k| k.strip_prefix(&prefix)) + .filter(|rest| !rest.is_empty()) + .map(|rest| { + let name = rest.split('/').next().unwrap_or(rest); + 
DirEntry::new(name, DirEntryKind::File) + }) + .collect(); + Ok(names) + } + + async fn rm(&self, path: &str, _recursive: bool, force: bool) -> Result<(), FsError> { + if self.files.lock().remove(path).is_none() && !force { + return Err(FsError::NotFound(path.to_string())); + } + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box<dyn std::error::Error>> { + let provider: Arc<dyn SessionFsProvider> = Arc::new(InMemoryProvider::new()); + + let options = { + let mut opts = ClientOptions::default(); + opts.session_fs = Some(SessionFsConfig::new( + "/workspace", + "/workspace/.copilot", + SessionFsConventions::Posix, + )); + opts + }; + + let client = Client::start(options).await?; + let session = client + .create_session( + SessionConfig::default() + .with_handler(Arc::new(ApproveAllHandler)) + .with_session_fs_provider(provider), + ) + .await?; + + let response = session + .send(MessageOptions::new("Summarize README.md.")) + .await?; + println!("Assistant: {response}"); + + Ok(()) +} diff --git a/rust/examples/tool_server.rs b/rust/examples/tool_server.rs new file mode 100644 index 000000000..55bacbbe6 --- /dev/null +++ b/rust/examples/tool_server.rs @@ -0,0 +1,187 @@ +//! Define custom tools and expose them to the Copilot agent. +//! +//! Registers two tools — `get_weather` (typed params via schemars) and +//! `roll_dice` (manual schema) — then asks the agent a question that +//! triggers tool use. +//! +//! Requires the `derive` feature for typed parameter schemas: +//! +//! ```sh +//! cargo run -p github-copilot-sdk --example tool_server --features derive +//! ``` + +// Gate the entire example behind the `derive` feature so it compiles +// (as a stub that prints the required feature flag) when clippy/check +// runs without the feature. 
+#[cfg(not(feature = "derive"))] +fn main() { + eprintln!("This example requires the `derive` feature:"); + eprintln!(" cargo run -p github-copilot-sdk --example tool_server --features derive"); + std::process::exit(1); +} + +#[cfg(feature = "derive")] +use std::sync::Arc; +#[cfg(feature = "derive")] +use std::time::Duration; + +#[cfg(feature = "derive")] +use async_trait::async_trait; +#[cfg(feature = "derive")] +use github_copilot_sdk::handler::ApproveAllHandler; +#[cfg(feature = "derive")] +use github_copilot_sdk::tool::{ + JsonSchema, ToolHandler, ToolHandlerRouter, schema_for, tool_parameters, +}; +#[cfg(feature = "derive")] +use github_copilot_sdk::types::{MessageOptions, SessionConfig, Tool, ToolInvocation, ToolResult}; +#[cfg(feature = "derive")] +use github_copilot_sdk::{Client, ClientOptions, Error}; +#[cfg(feature = "derive")] +use serde::Deserialize; + +// --------------------------------------------------------------------------- +// Tool 1: get_weather — typed parameters derived from a Rust struct +// --------------------------------------------------------------------------- + +#[cfg(feature = "derive")] +#[derive(Deserialize, JsonSchema)] +struct GetWeatherParams { + /// City name (e.g. "Seattle"). + city: String, + /// Temperature unit: "celsius" or "fahrenheit". + unit: Option<String>, +} + +#[cfg(feature = "derive")] +struct GetWeatherTool; + +#[cfg(feature = "derive")] +#[async_trait] +impl ToolHandler for GetWeatherTool { + fn tool(&self) -> Tool { + let mut tool = Tool::default(); + tool.name = "get_weather".to_string(); + tool.description = "Get the current weather for a city.".to_string(); + tool.parameters = tool_parameters(schema_for::<GetWeatherParams>()); + tool + } + + async fn call(&self, invocation: ToolInvocation) -> Result<ToolResult, Error> { + let params: GetWeatherParams = serde_json::from_value(invocation.arguments)?; + let unit = params.unit.as_deref().unwrap_or("celsius"); + // Stub response — a real implementation would call a weather API. 
+ let reply = format!( + "Weather in {}: 18°{}, partly cloudy", + params.city, + if unit == "fahrenheit" { "F" } else { "C" }, + ); + Ok(ToolResult::Text(reply)) + } +} + +// --------------------------------------------------------------------------- +// Tool 2: roll_dice — manual JSON Schema +// --------------------------------------------------------------------------- + +#[cfg(feature = "derive")] +struct RollDiceTool; + +#[cfg(feature = "derive")] +#[async_trait] +impl ToolHandler for RollDiceTool { + fn tool(&self) -> Tool { + let mut tool = Tool::default(); + tool.name = "roll_dice".to_string(); + tool.description = "Roll one or more dice and return the total.".to_string(); + tool.parameters = tool_parameters(serde_json::json!({ + "type": "object", + "properties": { + "sides": { "type": "integer", "description": "Number of sides per die (default 6, max 1000)." }, + "count": { "type": "integer", "description": "Number of dice to roll (default 1, max 100)." } + } + })); + tool + } + + async fn call(&self, invocation: ToolInvocation) -> Result<ToolResult, Error> { + let sides = invocation + .arguments + .get("sides") + .and_then(|v| v.as_u64()) + .unwrap_or(6) + .clamp(1, 1000) as u32; + let count = invocation + .arguments + .get("count") + .and_then(|v| v.as_u64()) + .unwrap_or(1) + .clamp(1, 100) as u32; + + let mut total = 0u32; + let mut rolls = Vec::with_capacity(count as usize); + for _ in 0..count { + // Simple deterministic "random" for the example. 
+ let roll = (std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .subsec_nanos() + % sides) + + 1; + rolls.push(roll); + total += roll; + } + + Ok(ToolResult::Text(format!( + "Rolled {count}d{sides}: {rolls:?} = {total}" + ))) + } +} + +// --------------------------------------------------------------------------- +// Main +// --------------------------------------------------------------------------- + +#[cfg(feature = "derive")] +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let router = ToolHandlerRouter::new( + vec![Box::new(GetWeatherTool), Box::new(RollDiceTool)], + Arc::new(ApproveAllHandler), + ); + let tools = router.tools(); + let handler = Arc::new(router); + + let client = Client::start(ClientOptions::default()).await?; + + let config = { + let mut cfg = SessionConfig::default(); + cfg.tools = Some(tools); + cfg.with_handler(handler) + }; + let session = client.create_session(config).await?; + + println!( + "Session {} — asking about weather + dice...\n", + session.id() + ); + + let response = session + .send_and_wait( + MessageOptions::new("What's the weather in Seattle? Also roll 3d20 for me.") + .with_wait_timeout(Duration::from_secs(60)), + ) + .await?; + + if let Some(event) = response { + let text = event + .data + .get("content") + .and_then(|c| c.as_str()) + .unwrap_or(""); + println!("{text}"); + } + + session.destroy().await?; + Ok(()) +} diff --git a/rust/release-plz.toml b/rust/release-plz.toml new file mode 100644 index 000000000..82c38ffd7 --- /dev/null +++ b/rust/release-plz.toml @@ -0,0 +1,35 @@ +[workspace] +# release-plz config for the Rust github-copilot-sdk crate. +# +# The crate lives in the `rust/` subdirectory of the monorepo, so +# invoke release-plz from this directory (via the release-plz workflows +# under `.github/workflows/`). release-plz will: +# +# 1. 
`release-plz release-pr`: open a PR updating `rust/Cargo.toml`'s +# version and `rust/CHANGELOG.md` based on conventional-commit +# history on `tclem/rust-sdk-release-prep`-style branches. +# 2. `release-plz release`: after that PR is merged to main, publish +# the tagged version to crates.io and create a `rust-vX.Y.Z` git +# tag. +# +# Publishing requires a `CARGO_REGISTRY_TOKEN` repository secret scoped +# to the `github-copilot-sdk` crate owner account. See +# `.github/workflows/rust-publish-release.yml` for the setup checklist. +# +# Reference: https://release-plz.dev/docs/config +changelog_update = true +dependencies_update = false +git_release_enable = true +# Prefix crate git tags so they don't collide with the monorepo's +# top-level `vX.Y.Z` tags used by the other SDKs. +git_tag_name = "rust-v{{ version }}" +git_release_name = "rust-v{{ version }}" + +[[package]] +name = "github-copilot-sdk" +changelog_path = "CHANGELOG.md" +# Mark pre-1.0 publishes as prereleases on the GitHub release page so +# consumers don't pick them up as "stable" by default. Maintainers +# should flip this (or remove it) when cutting 1.0. 
+git_release_type = "pre" + diff --git a/rust/rust-toolchain.toml b/rust/rust-toolchain.toml new file mode 100644 index 000000000..2259b2c8a --- /dev/null +++ b/rust/rust-toolchain.toml @@ -0,0 +1,4 @@ +[toolchain] +channel = "1.94.0" +components = ["clippy", "rust-analyzer", "rustfmt"] +profile = "default" diff --git a/rust/src/embeddedcli.rs b/rust/src/embeddedcli.rs new file mode 100644 index 000000000..d0e5ea9ff --- /dev/null +++ b/rust/src/embeddedcli.rs @@ -0,0 +1,278 @@ +#[cfg(any(has_bundled_cli, test))] +use std::fs; +#[cfg(any(has_bundled_cli, test))] +use std::io::{self, Read, Write}; +#[cfg(any(has_bundled_cli, test))] +use std::path::Path; +use std::path::PathBuf; +use std::sync::OnceLock; + +#[cfg(has_bundled_cli)] +use tracing::{info, warn}; + +// When the SDK is built with COPILOT_CLI_VERSION set, build.rs generates +// bundled_cli.rs with the compressed binary bytes, hash, and version. +#[cfg(has_bundled_cli)] +mod build_time { + include!(concat!(env!("OUT_DIR"), "/bundled_cli.rs")); +} + +static INSTALLED_PATH: OnceLock<Option<PathBuf>> = OnceLock::new(); + +/// Returns the bundled CLI version string, if one was embedded at build time. +pub fn bundled_version() -> Option<&'static str> { + #[cfg(has_bundled_cli)] + { + Some(build_time::CLI_VERSION) + } + #[cfg(not(has_bundled_cli))] + { + None + } +} + +/// Returns the path to the installed CLI binary, lazily extracting on first call. +/// +/// When the SDK was built with `COPILOT_CLI_VERSION` set, this extracts the +/// embedded binary to `~/.cache/github-copilot-sdk-{version}/copilot` (or +/// `copilot.exe` on Windows), verifies the SHA-256 hash, and returns the +/// path. Subsequent calls return the cached result. +/// +/// Returns `None` if no CLI was embedded at build time. 
+pub fn path() -> Option<PathBuf> { + INSTALLED_PATH + .get_or_init(|| { + #[cfg(has_bundled_cli)] + { + match install( + build_time::CLI_BYTES, + build_time::CLI_HASH, + build_time::CLI_VERSION, + ) { + Ok(path) => { + info!(path = %path.display(), version = build_time::CLI_VERSION, "embedded CLI installed"); + return Some(path); + } + Err(e) => { + warn!(error = %e, "embedded CLI installation failed"); + } + } + } + None + }) + .clone() +} + +#[cfg(has_bundled_cli)] +fn install( + compressed: &[u8], + expected_hash: [u8; 32], + version: &str, +) -> Result<PathBuf, EmbeddedCliError> { + let verbose = std::env::var("COPILOT_CLI_INSTALL_VERBOSE").ok().as_deref() == Some("1"); + + let cache = dirs::cache_dir().unwrap_or_else(std::env::temp_dir); + // Use a versioned directory so multiple versions can coexist, + // but keep the binary named `copilot` — the CLI checks argv[0] + // for this exact name. + let install_dir = if version.is_empty() { + cache.join("github-copilot-sdk") + } else { + cache.join(format!("github-copilot-sdk-{}", sanitize_version(version))) + }; + fs::create_dir_all(&install_dir).map_err(EmbeddedCliError::CreateDir)?; + + let binary_name = binary_name(); + let final_path = install_dir.join(&binary_name); + + // If the binary already exists and hash matches, skip extraction. 
+ if final_path.is_file() { + let existing_hash = hash_file(&final_path)?; + if existing_hash == expected_hash { + if verbose { + eprintln!("embedded CLI already installed at {}", final_path.display()); + } + return Ok(final_path); + } + if verbose { + eprintln!("embedded CLI hash mismatch, reinstalling"); + } + } + + let start = std::time::Instant::now(); + let decompressed = decompress(compressed)?; + + let actual_hash = sha256(&decompressed); + if actual_hash != expected_hash { + return Err(EmbeddedCliError::HashMismatch); + } + + write_binary(&final_path, &decompressed)?; + + if verbose { + eprintln!( + "embedded CLI installed at {} in {:?}", + final_path.display(), + start.elapsed() + ); + } + + Ok(final_path) +} + +#[cfg(any(has_bundled_cli, test))] +fn binary_name() -> String { + if cfg!(target_os = "windows") { + "copilot.exe".to_string() + } else { + "copilot".to_string() + } +} + +#[cfg(has_bundled_cli)] +fn sanitize_version(version: &str) -> String { + version + .chars() + .map(|c| match c { + 'a'..='z' | 'A'..='Z' | '0'..='9' | '.' 
| '-' | '_' => c, + _ => '_', + }) + .collect() +} + +#[cfg(any(has_bundled_cli, test))] +fn decompress(data: &[u8]) -> Result<Vec<u8>, EmbeddedCliError> { + let mut decoder = zstd::Decoder::new(data).map_err(EmbeddedCliError::Decompress)?; + let mut out = Vec::new(); + decoder + .read_to_end(&mut out) + .map_err(EmbeddedCliError::Decompress)?; + Ok(out) +} + +#[cfg(any(has_bundled_cli, test))] +fn sha256(data: &[u8]) -> [u8; 32] { + use sha2::Digest; + let mut hasher = sha2::Sha256::new(); + hasher.update(data); + hasher.finalize().into() +} + +#[cfg(has_bundled_cli)] +fn hash_file(path: &Path) -> Result<[u8; 32], EmbeddedCliError> { + use sha2::Digest; + let mut file = fs::File::open(path).map_err(EmbeddedCliError::Io)?; + let mut hasher = sha2::Sha256::new(); + let mut buf = [0u8; 8192]; + loop { + let n = file.read(&mut buf).map_err(EmbeddedCliError::Io)?; + if n == 0 { + break; + } + hasher.update(&buf[..n]); + } + Ok(hasher.finalize().into()) +} + +#[cfg(any(has_bundled_cli, test))] +fn write_binary(path: &Path, data: &[u8]) -> Result<(), EmbeddedCliError> { + let mut file = fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(path) + .map_err(EmbeddedCliError::Io)?; + + file.write_all(data).map_err(EmbeddedCliError::Io)?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + fs::set_permissions(path, fs::Permissions::from_mode(0o755)) + .map_err(EmbeddedCliError::Io)?; + } + + Ok(()) +} + +#[cfg(any(has_bundled_cli, test))] +#[derive(Debug, thiserror::Error)] +#[allow(dead_code)] +enum EmbeddedCliError { + #[error("failed to create install directory: {0}")] + CreateDir(io::Error), + + #[error("decompression failed: {0}")] + Decompress(io::Error), + + #[error("SHA-256 hash of decompressed binary does not match expected hash")] + HashMismatch, + + #[error("I/O error: {0}")] + Io(io::Error), +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn install_extracts_to_cache_dir() { + let temp = 
tempfile::tempdir().expect("should create temp dir"); + let original = b"fake copilot binary"; + let hash = sha256(original); + let compressed = zstd::encode_all(&original[..], 3).expect("compression should succeed"); + + // Install into an isolated temp dir instead of the global cache. + let path = install_to_dir(&temp, &compressed, hash); + let expected_name = binary_name(); + assert!(path.is_file()); + assert_eq!( + path.file_name().and_then(|s| s.to_str()), + Some(expected_name.as_str()) + ); + + let installed_content = fs::read(&path).expect("should read installed binary"); + assert_eq!(installed_content, original); + + // Second install should be idempotent (hash matches, skips extraction). + let path2 = install_to_dir(&temp, &compressed, hash); + assert_eq!(path, path2); + } + + #[test] + fn install_rejects_hash_mismatch() { + let temp = tempfile::tempdir().expect("should create temp dir"); + let original = b"fake copilot binary"; + let wrong_hash = [0u8; 32]; + let compressed = zstd::encode_all(&original[..], 3).expect("compression should succeed"); + + let result = install_to_dir_result(&temp, &compressed, wrong_hash); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("SHA-256")); + } + + // Test helpers that install to a specific directory instead of the global cache. 
+ fn install_to_dir(temp: &tempfile::TempDir, compressed: &[u8], hash: [u8; 32]) -> PathBuf { + install_to_dir_result(temp, compressed, hash).expect("install should succeed") + } + + fn install_to_dir_result( + temp: &tempfile::TempDir, + compressed: &[u8], + hash: [u8; 32], + ) -> Result<PathBuf, EmbeddedCliError> { + let install_dir = temp.path().to_path_buf(); + fs::create_dir_all(&install_dir).expect("create dir"); + let binary_name = binary_name(); + let final_path = install_dir.join(&binary_name); + + let decompressed = decompress(compressed)?; + let actual_hash = sha256(&decompressed); + if actual_hash != hash { + return Err(EmbeddedCliError::HashMismatch); + } + write_binary(&final_path, &decompressed)?; + Ok(final_path) + } +} diff --git a/rust/src/generated/api_types.rs b/rust/src/generated/api_types.rs new file mode 100644 index 000000000..1b5eb433f --- /dev/null +++ b/rust/src/generated/api_types.rs @@ -0,0 +1,3175 @@ +//! Auto-generated from api.schema.json — do not edit manually. + +#![allow(clippy::large_enum_variant)] + +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::types::{RequestId, SessionId}; + +/// JSON-RPC method name constants. 
+pub mod rpc_methods { + /// `ping` + pub const PING: &str = "ping"; + /// `models.list` + pub const MODELS_LIST: &str = "models.list"; + /// `tools.list` + pub const TOOLS_LIST: &str = "tools.list"; + /// `account.getQuota` + pub const ACCOUNT_GETQUOTA: &str = "account.getQuota"; + /// `mcp.config.list` + pub const MCP_CONFIG_LIST: &str = "mcp.config.list"; + /// `mcp.config.add` + pub const MCP_CONFIG_ADD: &str = "mcp.config.add"; + /// `mcp.config.update` + pub const MCP_CONFIG_UPDATE: &str = "mcp.config.update"; + /// `mcp.config.remove` + pub const MCP_CONFIG_REMOVE: &str = "mcp.config.remove"; + /// `mcp.config.enable` + pub const MCP_CONFIG_ENABLE: &str = "mcp.config.enable"; + /// `mcp.config.disable` + pub const MCP_CONFIG_DISABLE: &str = "mcp.config.disable"; + /// `mcp.discover` + pub const MCP_DISCOVER: &str = "mcp.discover"; + /// `skills.config.setDisabledSkills` + pub const SKILLS_CONFIG_SETDISABLEDSKILLS: &str = "skills.config.setDisabledSkills"; + /// `skills.discover` + pub const SKILLS_DISCOVER: &str = "skills.discover"; + /// `sessionFs.setProvider` + pub const SESSIONFS_SETPROVIDER: &str = "sessionFs.setProvider"; + /// `sessions.fork` + pub const SESSIONS_FORK: &str = "sessions.fork"; + /// `session.auth.getStatus` + pub const SESSION_AUTH_GETSTATUS: &str = "session.auth.getStatus"; + /// `session.model.getCurrent` + pub const SESSION_MODEL_GETCURRENT: &str = "session.model.getCurrent"; + /// `session.model.switchTo` + pub const SESSION_MODEL_SWITCHTO: &str = "session.model.switchTo"; + /// `session.mode.get` + pub const SESSION_MODE_GET: &str = "session.mode.get"; + /// `session.mode.set` + pub const SESSION_MODE_SET: &str = "session.mode.set"; + /// `session.name.get` + pub const SESSION_NAME_GET: &str = "session.name.get"; + /// `session.name.set` + pub const SESSION_NAME_SET: &str = "session.name.set"; + /// `session.plan.read` + pub const SESSION_PLAN_READ: &str = "session.plan.read"; + /// `session.plan.update` + pub const 
SESSION_PLAN_UPDATE: &str = "session.plan.update";
+    /// `session.plan.delete`
+    pub const SESSION_PLAN_DELETE: &str = "session.plan.delete";
+    /// `session.workspaces.getWorkspace`
+    pub const SESSION_WORKSPACES_GETWORKSPACE: &str = "session.workspaces.getWorkspace";
+    /// `session.workspaces.listFiles`
+    pub const SESSION_WORKSPACES_LISTFILES: &str = "session.workspaces.listFiles";
+    /// `session.workspaces.readFile`
+    pub const SESSION_WORKSPACES_READFILE: &str = "session.workspaces.readFile";
+    /// `session.workspaces.createFile`
+    pub const SESSION_WORKSPACES_CREATEFILE: &str = "session.workspaces.createFile";
+    /// `session.instructions.getSources`
+    pub const SESSION_INSTRUCTIONS_GETSOURCES: &str = "session.instructions.getSources";
+    /// `session.fleet.start`
+    pub const SESSION_FLEET_START: &str = "session.fleet.start";
+    /// `session.agent.list`
+    pub const SESSION_AGENT_LIST: &str = "session.agent.list";
+    /// `session.agent.getCurrent`
+    pub const SESSION_AGENT_GETCURRENT: &str = "session.agent.getCurrent";
+    /// `session.agent.select`
+    pub const SESSION_AGENT_SELECT: &str = "session.agent.select";
+    /// `session.agent.deselect`
+    pub const SESSION_AGENT_DESELECT: &str = "session.agent.deselect";
+    /// `session.agent.reload`
+    pub const SESSION_AGENT_RELOAD: &str = "session.agent.reload";
+    /// `session.tasks.startAgent`
+    pub const SESSION_TASKS_STARTAGENT: &str = "session.tasks.startAgent";
+    /// `session.tasks.list`
+    pub const SESSION_TASKS_LIST: &str = "session.tasks.list";
+    /// `session.tasks.promoteToBackground`
+    pub const SESSION_TASKS_PROMOTETOBACKGROUND: &str = "session.tasks.promoteToBackground";
+    /// `session.tasks.cancel`
+    pub const SESSION_TASKS_CANCEL: &str = "session.tasks.cancel";
+    /// `session.tasks.remove`
+    pub const SESSION_TASKS_REMOVE: &str = "session.tasks.remove";
+    /// `session.skills.list`
+    pub const SESSION_SKILLS_LIST: &str = "session.skills.list";
+    /// `session.skills.enable`
+    pub const SESSION_SKILLS_ENABLE: &str = "session.skills.enable";
+    /// `session.skills.disable`
+    pub const SESSION_SKILLS_DISABLE: &str = "session.skills.disable";
+    /// `session.skills.reload`
+    pub const SESSION_SKILLS_RELOAD: &str = "session.skills.reload";
+    /// `session.mcp.list`
+    pub const SESSION_MCP_LIST: &str = "session.mcp.list";
+    /// `session.mcp.enable`
+    pub const SESSION_MCP_ENABLE: &str = "session.mcp.enable";
+    /// `session.mcp.disable`
+    pub const SESSION_MCP_DISABLE: &str = "session.mcp.disable";
+    /// `session.mcp.reload`
+    pub const SESSION_MCP_RELOAD: &str = "session.mcp.reload";
+    /// `session.mcp.oauth.login`
+    pub const SESSION_MCP_OAUTH_LOGIN: &str = "session.mcp.oauth.login";
+    /// `session.plugins.list`
+    pub const SESSION_PLUGINS_LIST: &str = "session.plugins.list";
+    /// `session.extensions.list`
+    pub const SESSION_EXTENSIONS_LIST: &str = "session.extensions.list";
+    /// `session.extensions.enable`
+    pub const SESSION_EXTENSIONS_ENABLE: &str = "session.extensions.enable";
+    /// `session.extensions.disable`
+    pub const SESSION_EXTENSIONS_DISABLE: &str = "session.extensions.disable";
+    /// `session.extensions.reload`
+    pub const SESSION_EXTENSIONS_RELOAD: &str = "session.extensions.reload";
+    /// `session.tools.handlePendingToolCall`
+    pub const SESSION_TOOLS_HANDLEPENDINGTOOLCALL: &str = "session.tools.handlePendingToolCall";
+    /// `session.commands.handlePendingCommand`
+    pub const SESSION_COMMANDS_HANDLEPENDINGCOMMAND: &str = "session.commands.handlePendingCommand";
+    /// `session.ui.elicitation`
+    pub const SESSION_UI_ELICITATION: &str = "session.ui.elicitation";
+    /// `session.ui.handlePendingElicitation`
+    pub const SESSION_UI_HANDLEPENDINGELICITATION: &str = "session.ui.handlePendingElicitation";
+    /// `session.permissions.handlePendingPermissionRequest`
+    pub const SESSION_PERMISSIONS_HANDLEPENDINGPERMISSIONREQUEST: &str =
+        "session.permissions.handlePendingPermissionRequest";
+    /// `session.permissions.setApproveAll`
+    pub const SESSION_PERMISSIONS_SETAPPROVEALL: &str = "session.permissions.setApproveAll";
+    /// `session.permissions.resetSessionApprovals`
+    pub const SESSION_PERMISSIONS_RESETSESSIONAPPROVALS: &str =
+        "session.permissions.resetSessionApprovals";
+    /// `session.log`
+    pub const SESSION_LOG: &str = "session.log";
+    /// `session.shell.exec`
+    pub const SESSION_SHELL_EXEC: &str = "session.shell.exec";
+    /// `session.shell.kill`
+    pub const SESSION_SHELL_KILL: &str = "session.shell.kill";
+    /// `session.history.compact`
+    pub const SESSION_HISTORY_COMPACT: &str = "session.history.compact";
+    /// `session.history.truncate`
+    pub const SESSION_HISTORY_TRUNCATE: &str = "session.history.truncate";
+    /// `session.usage.getMetrics`
+    pub const SESSION_USAGE_GETMETRICS: &str = "session.usage.getMetrics";
+    /// `sessionFs.readFile`
+    pub const SESSIONFS_READFILE: &str = "sessionFs.readFile";
+    /// `sessionFs.writeFile`
+    pub const SESSIONFS_WRITEFILE: &str = "sessionFs.writeFile";
+    /// `sessionFs.appendFile`
+    pub const SESSIONFS_APPENDFILE: &str = "sessionFs.appendFile";
+    /// `sessionFs.exists`
+    pub const SESSIONFS_EXISTS: &str = "sessionFs.exists";
+    /// `sessionFs.stat`
+    pub const SESSIONFS_STAT: &str = "sessionFs.stat";
+    /// `sessionFs.mkdir`
+    pub const SESSIONFS_MKDIR: &str = "sessionFs.mkdir";
+    /// `sessionFs.readdir`
+    pub const SESSIONFS_READDIR: &str = "sessionFs.readdir";
+    /// `sessionFs.readdirWithTypes`
+    pub const SESSIONFS_READDIRWITHTYPES: &str = "sessionFs.readdirWithTypes";
+    /// `sessionFs.rm`
+    pub const SESSIONFS_RM: &str = "sessionFs.rm";
+    /// `sessionFs.rename`
+    pub const SESSIONFS_RENAME: &str = "sessionFs.rename";
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AccountGetQuotaRequest {
+    /// GitHub token for per-user quota lookup. When provided, resolves this token to determine the user's quota instead of using the global auth.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub git_hub_token: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AccountQuotaSnapshot {
+    /// Number of requests included in the entitlement
+    pub entitlement_requests: i64,
+    /// Whether the user has an unlimited usage entitlement
+    pub is_unlimited_entitlement: bool,
+    /// Number of overage requests made this period
+    pub overage: f64,
+    /// Whether overage is allowed when quota is exhausted
+    pub overage_allowed_with_exhausted_quota: bool,
+    /// Percentage of entitlement remaining
+    pub remaining_percentage: f64,
+    /// Date when the quota resets (ISO 8601 string)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reset_date: Option<String>,
+    /// Whether usage is still permitted after quota exhaustion
+    pub usage_allowed_with_exhausted_quota: bool,
+    /// Number of requests used so far this period
+    pub used_requests: i64,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AccountGetQuotaResult {
+    /// Quota snapshots keyed by type (e.g., chat, completions, premium_interactions)
+    pub quota_snapshots: HashMap<String, AccountQuotaSnapshot>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AgentInfo {
+    /// Description of the agent's purpose
+    pub description: String,
+    /// Human-readable display name
+    pub display_name: String,
+    /// Unique identifier of the custom agent
+    pub name: String,
+    /// Absolute local file path of the agent definition. Only set for file-based agents loaded from disk; remote agents do not have a path.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub path: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AgentGetCurrentResult {
+    /// Currently selected custom agent, or null if using the default agent
+    pub agent: AgentInfo,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AgentList {
+    /// Available custom agents
+    pub agents: Vec<AgentInfo>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AgentReloadResult {
+    /// Reloaded custom agents
+    pub agents: Vec<AgentInfo>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AgentSelectRequest {
+    /// Name of the custom agent to select
+    pub name: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AgentSelectResult {
+    /// The newly selected custom agent
+    pub agent: AgentInfo,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CommandsHandlePendingCommandRequest {
+    /// Error message if the command handler failed
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<String>,
+    /// Request ID from the command invocation event
+    pub request_id: RequestId,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CommandsHandlePendingCommandResult {
+    /// Whether the command was handled successfully
+    pub success: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CurrentModel {
+    /// Currently active model identifier
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model_id: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct DiscoveredMcpServer {
+    /// Whether the server is enabled (not in the disabled list)
+    pub enabled: bool,
+    /// Server name (config key)
+    pub name: String,
+    /// Configuration source
+    pub source: DiscoveredMcpServerSource,
+    /// Server transport type: stdio, http, sse, or memory (local configs are normalized to stdio)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub r#type: Option,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Extension {
+    /// Source-qualified ID (e.g., 'project:my-ext', 'user:auth-helper')
+    pub id: String,
+    /// Extension name (directory name)
+    pub name: String,
+    /// Process ID if the extension is running
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub pid: Option<i64>,
+    /// Discovery source: project (.github/extensions/) or user (~/.copilot/extensions/)
+    pub source: ExtensionSource,
+    /// Current status: running, disabled, failed, or starting
+    pub status: ExtensionStatus,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ExtensionList {
+    /// Discovered extensions and their current status
+    pub extensions: Vec<Extension>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ExtensionsDisableRequest {
+    /// Source-qualified extension ID to disable
+    pub id: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ExtensionsEnableRequest {
+    /// Source-qualified extension ID to enable
+    pub id: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct FleetStartRequest {
+    /// Optional user prompt to combine with fleet instructions
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prompt: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct FleetStartResult {
+    /// Whether fleet mode was successfully activated
+    pub started: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct HandleToolCallResult {
+    /// Whether the tool call result was handled successfully
+    pub success: bool,
+}
+
+/// Post-compaction context window usage breakdown
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct HistoryCompactContextWindow {
+    /// Token count from non-system messages (user, assistant, tool)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub conversation_tokens: Option<i64>,
+    /// Current total tokens in the context window (system + conversation + tool definitions)
+    pub current_tokens: i64,
+    /// Current number of messages in the conversation
+    pub messages_length: i64,
+    /// Token count from system message(s)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_tokens: Option<i64>,
+    /// Maximum token count for the model's context window
+    pub token_limit: i64,
+    /// Token count from tool definitions
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_definitions_tokens: Option<i64>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct HistoryCompactResult {
+    /// Post-compaction context window usage breakdown
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub context_window: Option<HistoryCompactContextWindow>,
+    /// Number of messages removed during compaction
+    pub messages_removed: i64,
+    /// Whether compaction completed successfully
+    pub success: bool,
+    /// Number of tokens freed by compaction
+    pub tokens_removed: i64,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct HistoryTruncateRequest {
+    /// Event ID to truncate to. This event and all events after it are removed from the session.
+    pub event_id: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct HistoryTruncateResult {
+    /// Number of events that were removed
+    pub events_removed: i64,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct InstructionsSources {
+    /// Glob pattern from frontmatter — when set, this instruction applies only to matching files
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub apply_to: Option<String>,
+    /// Raw content of the instruction file
+    pub content: String,
+    /// Short description (body after frontmatter) for use in instruction tables
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    /// Unique identifier for this source (used for toggling)
+    pub id: String,
+    /// Human-readable label
+    pub label: String,
+    /// Where this source lives — used for UI grouping
+    pub location: InstructionsSourcesLocation,
+    /// File path relative to repo or absolute for home
+    pub source_path: String,
+    /// Category of instruction source — used for merge logic
+    pub r#type: InstructionsSourcesType,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct InstructionsGetSourcesResult {
+    /// Instruction sources for the session
+    pub sources: Vec<InstructionsSources>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct LogRequest {
+    /// When true, the message is transient and not persisted to the session event log on disk
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub ephemeral: Option<bool>,
+    /// Log severity level. Determines how the message is displayed in the timeline. Defaults to "info".
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub level: Option,
+    /// Human-readable message
+    pub message: String,
+    /// Optional URL the user can open in their browser for more details
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub url: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct LogResult {
+    /// The unique identifier of the emitted session event
+    pub event_id: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpConfigAddRequest {
+    /// MCP server configuration (local/stdio or remote/http)
+    pub config: serde_json::Value,
+    /// Unique name for the MCP server
+    pub name: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpConfigDisableRequest {
+    /// Names of MCP servers to disable. Each server is added to the persisted disabled list so new sessions skip it. Already-disabled names are ignored. Active sessions keep their current connections until they end.
+    pub names: Vec<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpConfigEnableRequest {
+    /// Names of MCP servers to enable. Each server is removed from the persisted disabled list so new sessions spawn it. Unknown or already-enabled names are ignored.
+    pub names: Vec<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpConfigList {
+    /// All MCP servers from user config, keyed by name
+    pub servers: HashMap,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpConfigRemoveRequest {
+    /// Name of the MCP server to remove
+    pub name: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpConfigUpdateRequest {
+    /// MCP server configuration (local/stdio or remote/http)
+    pub config: serde_json::Value,
+    /// Name of the MCP server to update
+    pub name: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpDisableRequest {
+    /// Name of the MCP server to disable
+    pub server_name: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpDiscoverRequest {
+    /// Working directory used as context for discovery (e.g., plugin resolution)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub working_directory: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpDiscoverResult {
+    /// MCP servers discovered from all sources
+    pub servers: Vec<DiscoveredMcpServer>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpEnableRequest {
+    /// Name of the MCP server to enable
+    pub server_name: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpOauthLoginRequest {
+    /// Optional override for the body text shown on the OAuth loopback callback success page. When omitted, the runtime applies a neutral fallback; callers driving interactive auth should pass surface-specific copy telling the user where to return.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub callback_success_message: Option<String>,
+    /// Optional override for the OAuth client display name shown on the consent screen. Applies to newly registered dynamic clients only — existing registrations keep the name they were created with. When omitted, the runtime applies a neutral fallback; callers driving interactive auth should pass their own surface-specific label so the consent screen matches the product the user sees.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub client_name: Option<String>,
+    /// When true, clears any cached OAuth token for the server and runs a full new authorization. Use when the user explicitly wants to switch accounts or believes their session is stuck.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub force_reauth: Option<bool>,
+    /// Name of the remote MCP server to authenticate
+    pub server_name: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpOauthLoginResult {
+    /// URL the caller should open in a browser to complete OAuth. Omitted when cached tokens were still valid and no browser interaction was needed — the server is already reconnected in that case. When present, the runtime starts the callback listener before returning and continues the flow in the background; completion is signaled via session.mcp_server_status_changed.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub authorization_url: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpServer {
+    /// Error message if the server failed to connect
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<String>,
+    /// Server name (config key)
+    pub name: String,
+    /// Configuration source: user, workspace, plugin, or builtin
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub source: Option,
+    /// Connection status: connected, failed, needs-auth, pending, disabled, or not_configured
+    pub status: McpServerStatus,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpServerConfigHttp {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub filter_mapping: Option,
+    #[serde(default)]
+    pub headers: HashMap<String, String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub is_default_server: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub oauth_client_id: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub oauth_public_client: Option<bool>,
+    /// Timeout in milliseconds for tool calls to this server.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub timeout: Option<i64>,
+    /// Tools to include. Defaults to all tools if not specified.
+    #[serde(default)]
+    pub tools: Vec<String>,
+    /// Remote transport type. Defaults to "http" when omitted.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub r#type: Option,
+    pub url: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpServerConfigLocal {
+    pub args: Vec<String>,
+    pub command: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cwd: Option<String>,
+    #[serde(default)]
+    pub env: HashMap<String, String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub filter_mapping: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub is_default_server: Option<bool>,
+    /// Timeout in milliseconds for tool calls to this server.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub timeout: Option<i64>,
+    /// Tools to include. Defaults to all tools if not specified.
+    #[serde(default)]
+    pub tools: Vec<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub r#type: Option,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpServerList {
+    /// Configured MCP servers
+    pub servers: Vec<McpServer>,
+}
+
+/// Billing information
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelBilling {
+    /// Billing cost multiplier relative to the base rate
+    pub multiplier: f64,
+}
+
+/// Vision-specific limits
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelCapabilitiesLimitsVision {
+    /// Maximum image size in bytes
+    #[serde(rename = "max_prompt_image_size")]
+    pub max_prompt_image_size: i64,
+    /// Maximum number of images per prompt
+    #[serde(rename = "max_prompt_images")]
+    pub max_prompt_images: i64,
+    /// MIME types the model accepts
+    #[serde(rename = "supported_media_types")]
+    pub supported_media_types: Vec<String>,
+}
+
+/// Token limits for prompts, outputs, and context window
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelCapabilitiesLimits {
+    /// Maximum total context window size in tokens
+    #[serde(
+        rename = "max_context_window_tokens",
+        skip_serializing_if = "Option::is_none"
+    )]
+    pub max_context_window_tokens: Option<i64>,
+    /// Maximum number of output/completion tokens
+    #[serde(rename = "max_output_tokens", skip_serializing_if = "Option::is_none")]
+    pub max_output_tokens: Option<i64>,
+    /// Maximum number of prompt/input tokens
+    #[serde(rename = "max_prompt_tokens", skip_serializing_if = "Option::is_none")]
+    pub max_prompt_tokens: Option<i64>,
+    /// Vision-specific limits
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub vision: Option<ModelCapabilitiesLimitsVision>,
+}
+
+/// Feature flags indicating what the model supports
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelCapabilitiesSupports {
+    /// Whether this model supports reasoning effort configuration
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_effort: Option<bool>,
+    /// Whether this model supports vision/image input
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub vision: Option<bool>,
+}
+
+/// Model capabilities and limits
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelCapabilities {
+    /// Token limits for prompts, outputs, and context window
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub limits: Option<ModelCapabilitiesLimits>,
+    /// Feature flags indicating what the model supports
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub supports: Option<ModelCapabilitiesSupports>,
+}
+
+/// Policy state (if applicable)
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelPolicy {
+    /// Current policy state for this model
+    pub state: String,
+    /// Usage terms or conditions for this model
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub terms: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Model {
+    /// Billing information
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub billing: Option<ModelBilling>,
+    /// Model capabilities and limits
+    pub capabilities: ModelCapabilities,
+    /// Default reasoning effort level (only present if model supports reasoning effort)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub default_reasoning_effort: Option,
+    /// Model identifier (e.g., "claude-sonnet-4.5")
+    pub id: String,
+    /// Display name
+    pub name: String,
+    /// Policy state (if applicable)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub policy: Option<ModelPolicy>,
+    /// Supported reasoning effort levels (only present if model supports reasoning effort)
+    #[serde(default)]
+    pub supported_reasoning_efforts: Vec,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelCapabilitiesOverrideLimitsVision {
+    /// Maximum image size in bytes
+    #[serde(
+        rename = "max_prompt_image_size",
+        skip_serializing_if = "Option::is_none"
+    )]
+    pub max_prompt_image_size: Option<i64>,
+    /// Maximum number of images per prompt
+    #[serde(rename = "max_prompt_images", skip_serializing_if = "Option::is_none")]
+    pub max_prompt_images: Option<i64>,
+    /// MIME types the model accepts
+    #[serde(rename = "supported_media_types", default)]
+    pub supported_media_types: Vec<String>,
+}
+
+/// Token limits for prompts, outputs, and context window
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelCapabilitiesOverrideLimits {
+    /// Maximum total context window size in tokens
+    #[serde(
+        rename = "max_context_window_tokens",
+        skip_serializing_if = "Option::is_none"
+    )]
+    pub max_context_window_tokens: Option<i64>,
+    #[serde(rename = "max_output_tokens", skip_serializing_if = "Option::is_none")]
+    pub max_output_tokens: Option<i64>,
+    #[serde(rename = "max_prompt_tokens", skip_serializing_if = "Option::is_none")]
+    pub max_prompt_tokens: Option<i64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub vision: Option<ModelCapabilitiesOverrideLimitsVision>,
+}
+
+/// Feature flags indicating what the model supports
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelCapabilitiesOverrideSupports {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_effort: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub vision: Option<bool>,
+}
+
+/// Override individual model capabilities resolved by the runtime
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelCapabilitiesOverride {
+    /// Token limits for prompts, outputs, and context window
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub limits: Option<ModelCapabilitiesOverrideLimits>,
+    /// Feature flags indicating what the model supports
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub supports: Option<ModelCapabilitiesOverrideSupports>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelList {
+    /// List of available models with full metadata
+    pub models: Vec<Model>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelsListRequest {
+    /// GitHub token for per-user model listing. When provided, resolves this token to determine the user's Copilot plan and available models instead of using the global auth.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub git_hub_token: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelSwitchToRequest {
+    /// Override individual model capabilities resolved by the runtime
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model_capabilities: Option<ModelCapabilitiesOverride>,
+    /// Model identifier to switch to
+    pub model_id: String,
+    /// Reasoning effort level to use for the model
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_effort: Option,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelSwitchToResult {
+    /// Currently active model identifier after the switch
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model_id: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModeSetRequest {
+    /// The agent mode. Valid values: "interactive", "plan", "autopilot".
+ pub mode: SessionMode, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct NameGetResult { + /// The session name (user-set or auto-generated), or null if not yet set + pub name: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct NameSetRequest { + /// New session name (1–100 characters, trimmed of leading/trailing whitespace) + pub name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForLocationApprovalCommands { + pub command_identifiers: Vec, + pub kind: PermissionDecisionApproveForLocationApprovalCommandsKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForLocationApprovalRead { + pub kind: PermissionDecisionApproveForLocationApprovalReadKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForLocationApprovalWrite { + pub kind: PermissionDecisionApproveForLocationApprovalWriteKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForLocationApprovalMcp { + pub kind: PermissionDecisionApproveForLocationApprovalMcpKind, + pub server_name: String, + pub tool_name: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForLocationApprovalMcpSampling { + pub kind: PermissionDecisionApproveForLocationApprovalMcpSamplingKind, + pub server_name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForLocationApprovalMemory { + pub kind: PermissionDecisionApproveForLocationApprovalMemoryKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = 
"camelCase")] +pub struct PermissionDecisionApproveForLocationApprovalCustomTool { + pub kind: PermissionDecisionApproveForLocationApprovalCustomToolKind, + pub tool_name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForLocation { + /// The approval to persist for this location + pub approval: PermissionDecisionApproveForLocationApproval, + /// Approved and persisted for this project location + pub kind: PermissionDecisionApproveForLocationKind, + /// The location key (git root or cwd) to persist the approval to + pub location_key: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForSessionApprovalCommands { + pub command_identifiers: Vec, + pub kind: PermissionDecisionApproveForSessionApprovalCommandsKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForSessionApprovalRead { + pub kind: PermissionDecisionApproveForSessionApprovalReadKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForSessionApprovalWrite { + pub kind: PermissionDecisionApproveForSessionApprovalWriteKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForSessionApprovalMcp { + pub kind: PermissionDecisionApproveForSessionApprovalMcpKind, + pub server_name: String, + pub tool_name: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForSessionApprovalMcpSampling { + pub kind: PermissionDecisionApproveForSessionApprovalMcpSamplingKind, + pub server_name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct 
PermissionDecisionApproveForSessionApprovalMemory { + pub kind: PermissionDecisionApproveForSessionApprovalMemoryKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForSessionApprovalCustomTool { + pub kind: PermissionDecisionApproveForSessionApprovalCustomToolKind, + pub tool_name: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveForSession { + /// The approval to add as a session-scoped rule + pub approval: PermissionDecisionApproveForSessionApproval, + /// Approved and remembered for the rest of the session + pub kind: PermissionDecisionApproveForSessionKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionApproveOnce { + /// The permission request was approved for this one instance + pub kind: PermissionDecisionApproveOnceKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionReject { + /// Optional feedback from the user explaining the denial + #[serde(skip_serializing_if = "Option::is_none")] + pub feedback: Option, + /// Denied by the user during an interactive prompt + pub kind: PermissionDecisionRejectKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionUserNotAvailable { + /// Denied because user confirmation was unavailable + pub kind: PermissionDecisionUserNotAvailableKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionDecisionRequest { + /// Request ID of the pending permission request + pub request_id: RequestId, + pub result: PermissionDecision, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionRequestResult { + /// Whether the permission request 
was handled successfully + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionsResetSessionApprovalsRequest {} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionsResetSessionApprovalsResult { + /// Whether the operation succeeded + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionsSetApproveAllRequest { + /// Whether to auto-approve all tool permission requests + pub enabled: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PermissionsSetApproveAllResult { + /// Whether the operation succeeded + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PingRequest { + /// Optional message to echo back + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PingResult { + /// Echoed message (or default greeting) + pub message: String, + /// Server protocol version number + pub protocol_version: i64, + /// Server timestamp in milliseconds + pub timestamp: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PlanReadResult { + /// The content of the plan file, or null if it does not exist + pub content: Option, + /// Whether the plan file exists in the workspace + pub exists: bool, + /// Absolute file path of the plan file, or null if workspace is not enabled + pub path: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PlanUpdateRequest { + /// The new content for the plan file + pub content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct 
Plugin {
+    /// Whether the plugin is currently enabled
+    pub enabled: bool,
+    /// Marketplace the plugin came from
+    pub marketplace: String,
+    /// Plugin name
+    pub name: String,
+    /// Installed version
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub version: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PluginList {
+    /// Installed plugins
+    pub plugins: Vec<Plugin>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ServerSkill {
+    /// Description of what the skill does
+    pub description: String,
+    /// Whether the skill is currently enabled (based on global config)
+    pub enabled: bool,
+    /// Unique identifier for the skill
+    pub name: String,
+    /// Absolute path to the skill file
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub path: Option<String>,
+    /// The project path this skill belongs to (only for project/inherited skills)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub project_path: Option<String>,
+    /// Source location type (e.g., project, personal-copilot, plugin, builtin)
+    pub source: String,
+    /// Whether the skill can be invoked by the user as a slash command
+    pub user_invocable: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ServerSkillList {
+    /// All discovered skills across all sources
+    pub skills: Vec<ServerSkill>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionAuthStatus {
+    /// Authentication type
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub auth_type: Option<String>,
+    /// Copilot plan tier (e.g., individual_pro, business)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub copilot_plan: Option<String>,
+    /// Authentication host URL
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub host: Option<String>,
+    /// Whether the session has resolved authentication
+    pub is_authenticated: bool,
+    /// Authenticated
login/username, if available
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub login: Option<String>,
+    /// Human-readable authentication status description
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub status_message: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsAppendFileRequest {
+    /// Content to append
+    pub content: String,
+    /// Optional POSIX-style mode for newly created files
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub mode: Option<i64>,
+    /// Path using SessionFs conventions
+    pub path: String,
+}
+
+/// Describes a filesystem error.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsError {
+    /// Error classification
+    pub code: SessionFsErrorCode,
+    /// Free-form detail about the error, for logging/diagnostics
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub message: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsExistsRequest {
+    /// Path using SessionFs conventions
+    pub path: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsExistsResult {
+    /// Whether the path exists
+    pub exists: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsMkdirRequest {
+    /// Optional POSIX-style mode for newly created directories
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub mode: Option<i64>,
+    /// Path using SessionFs conventions
+    pub path: String,
+    /// Create parent directories as needed
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub recursive: Option<bool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsReaddirRequest {
+    /// Path using SessionFs conventions
+    pub path: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsReaddirResult {
+    /// Entry names in the directory
+    pub entries: Vec<String>,
+    /// Describes a filesystem error.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<SessionFsError>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsReaddirWithTypesEntry {
+    /// Entry name
+    pub name: String,
+    /// Entry type
+    pub r#type: SessionFsReaddirWithTypesEntryType,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsReaddirWithTypesRequest {
+    /// Path using SessionFs conventions
+    pub path: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsReaddirWithTypesResult {
+    /// Directory entries with type information
+    pub entries: Vec<SessionFsReaddirWithTypesEntry>,
+    /// Describes a filesystem error.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<SessionFsError>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsReadFileRequest {
+    /// Path using SessionFs conventions
+    pub path: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsReadFileResult {
+    /// File content as UTF-8 string
+    pub content: String,
+    /// Describes a filesystem error.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<SessionFsError>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsRenameRequest {
+    /// Destination path using SessionFs conventions
+    pub dest: String,
+    /// Source path using SessionFs conventions
+    pub src: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsRmRequest {
+    /// Ignore errors if the path does not exist
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub force: Option<bool>,
+    /// Path using SessionFs conventions
+    pub path: String,
+    /// Remove directories and their contents recursively
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub recursive: Option<bool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsSetProviderRequest {
+    /// Path conventions used by this filesystem
+    pub conventions: SessionFsSetProviderConventions,
+    /// Initial working directory for sessions
+    pub initial_cwd: String,
+    /// Path within each session's SessionFs where the runtime stores files for that session
+    pub session_state_path: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsSetProviderResult {
+    /// Whether the provider was set successfully
+    pub success: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsStatRequest {
+    /// Path using SessionFs conventions
+    pub path: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsStatResult {
+    /// ISO 8601 timestamp of creation
+    pub birthtime: String,
+    /// Describes a filesystem error.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<SessionFsError>,
+    /// Whether the path is a directory
+    pub is_directory: bool,
+    /// Whether the path is a file
+    pub is_file: bool,
+    /// ISO 8601 timestamp of last modification
+    pub mtime: String,
+    /// File size in bytes
+    pub size: i64,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionFsWriteFileRequest {
+    /// Content to write
+    pub content: String,
+    /// Optional POSIX-style mode for newly created files
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub mode: Option<i64>,
+    /// Path using SessionFs conventions
+    pub path: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionsForkRequest {
+    /// Source session ID to fork from
+    pub session_id: SessionId,
+    /// Optional event ID boundary. When provided, the fork includes only events before this ID (exclusive). When omitted, all events are included.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub to_event_id: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionsForkResult {
+    /// The new forked session's ID
+    pub session_id: SessionId,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ShellExecRequest {
+    /// Shell command to execute
+    pub command: String,
+    /// Working directory (defaults to session working directory)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cwd: Option<String>,
+    /// Timeout in milliseconds (default: 30000)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub timeout: Option<i64>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ShellExecResult {
+    /// Unique identifier for tracking streamed output
+    pub process_id: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct
ShellKillRequest {
+    /// Process identifier returned by shell.exec
+    pub process_id: String,
+    /// Signal to send (default: SIGTERM)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub signal: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ShellKillResult {
+    /// Whether the signal was sent successfully
+    pub killed: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Skill {
+    /// Description of what the skill does
+    pub description: String,
+    /// Whether the skill is currently enabled
+    pub enabled: bool,
+    /// Unique identifier for the skill
+    pub name: String,
+    /// Absolute path to the skill file
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub path: Option<String>,
+    /// Source location type (e.g., project, personal, plugin)
+    pub source: String,
+    /// Whether the skill can be invoked by the user as a slash command
+    pub user_invocable: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SkillList {
+    /// Available skills
+    pub skills: Vec<Skill>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SkillsConfigSetDisabledSkillsRequest {
+    /// List of skill names to disable
+    pub disabled_skills: Vec<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SkillsDisableRequest {
+    /// Name of the skill to disable
+    pub name: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SkillsDiscoverRequest {
+    /// Optional list of project directory paths to scan for project-scoped skills
+    #[serde(default)]
+    pub project_paths: Vec<String>,
+    /// Optional list of additional skill directory paths to include
+    #[serde(default)]
+    pub skill_directories: Vec<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub
struct SkillsEnableRequest {
+    /// Name of the skill to enable
+    pub name: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TaskAgentInfo {
+    /// ISO 8601 timestamp when the current active period began
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub active_started_at: Option<String>,
+    /// Accumulated active execution time in milliseconds
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub active_time_ms: Option<i64>,
+    /// Type of agent running this task
+    pub agent_type: String,
+    /// Whether the task is currently in the original sync wait and can be moved to background mode. False once it is already backgrounded, idle, finished, or no longer has a promotable sync waiter.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub can_promote_to_background: Option<bool>,
+    /// ISO 8601 timestamp when the task finished
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub completed_at: Option<String>,
+    /// Short description of the task
+    pub description: String,
+    /// Error message when the task failed
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<String>,
+    /// How the agent is currently being managed by the runtime
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub execution_mode: Option,
+    /// Unique task identifier
+    pub id: String,
+    /// ISO 8601 timestamp when the agent entered idle state
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub idle_since: Option<String>,
+    /// Most recent response text from the agent
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub latest_response: Option<String>,
+    /// Model used for the task when specified
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model: Option<String>,
+    /// Prompt passed to the agent
+    pub prompt: String,
+    /// Result text from the task when available
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub result: Option<String>,
+    /// ISO 8601 timestamp when the task was started
+    pub started_at: String,
+    /// Current lifecycle
status of the task
+    pub status: TaskAgentInfoStatus,
+    /// Tool call ID associated with this agent task
+    pub tool_call_id: String,
+    /// Task kind
+    pub r#type: TaskAgentInfoType,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TaskList {
+    /// Currently tracked tasks
+    pub tasks: Vec,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TasksCancelRequest {
+    /// Task identifier
+    pub id: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TasksCancelResult {
+    /// Whether the task was successfully cancelled
+    pub cancelled: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TaskShellInfo {
+    /// Whether the shell runs inside a managed PTY session or as an independent background process
+    pub attachment_mode: TaskShellInfoAttachmentMode,
+    /// Whether this shell task can be promoted to background mode
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub can_promote_to_background: Option<bool>,
+    /// Command being executed
+    pub command: String,
+    /// ISO 8601 timestamp when the task finished
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub completed_at: Option<String>,
+    /// Short description of the task
+    pub description: String,
+    /// Whether the shell command is currently sync-waited or background-managed
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub execution_mode: Option,
+    /// Unique task identifier
+    pub id: String,
+    /// Path to the detached shell log, when available
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub log_path: Option<String>,
+    /// Process ID when available
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub pid: Option<i64>,
+    /// ISO 8601 timestamp when the task was started
+    pub started_at: String,
+    /// Current lifecycle status of the task
+    pub status: TaskShellInfoStatus,
+    /// Task kind
+    pub r#type:
TaskShellInfoType,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TasksPromoteToBackgroundRequest {
+    /// Task identifier
+    pub id: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TasksPromoteToBackgroundResult {
+    /// Whether the task was successfully promoted to background mode
+    pub promoted: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TasksRemoveRequest {
+    /// Task identifier
+    pub id: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TasksRemoveResult {
+    /// Whether the task was removed. Returns false if the task does not exist or is still running/idle (cancel it first).
+    pub removed: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TasksStartAgentRequest {
+    /// Type of agent to start (e.g., 'explore', 'task', 'general-purpose')
+    pub agent_type: String,
+    /// Short description of the task
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    /// Optional model override
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model: Option<String>,
+    /// Short name for the agent, used to generate a human-readable ID
+    pub name: String,
+    /// Task prompt for the agent
+    pub prompt: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TasksStartAgentResult {
+    /// Generated agent ID for the background task
+    pub agent_id: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Tool {
+    /// Description of what the tool does
+    pub description: String,
+    /// Optional instructions for how to use this tool effectively
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub instructions: Option<String>,
+    /// Tool identifier (e.g., "bash", "grep",
"str_replace_editor")
+    pub name: String,
+    /// Optional namespaced name for declarative filtering (e.g., "playwright/navigate" for MCP tools)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub namespaced_name: Option<String>,
+    /// JSON Schema for the tool's input parameters
+    #[serde(default)]
+    pub parameters: HashMap<String, serde_json::Value>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ToolCallResult {
+    /// Error message if the tool call failed
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<String>,
+    /// Type of the tool result
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub result_type: Option,
+    /// Text result to send back to the LLM
+    pub text_result_for_llm: String,
+    /// Telemetry data from tool execution
+    #[serde(default)]
+    pub tool_telemetry: HashMap<String, serde_json::Value>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ToolList {
+    /// List of available built-in tools with metadata
+    pub tools: Vec<Tool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ToolsHandlePendingToolCallRequest {
+    /// Error message if the tool call failed
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<String>,
+    /// Request ID of the pending tool call
+    pub request_id: RequestId,
+    /// Tool call result (string or expanded result object)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub result: Option,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ToolsListRequest {
+    /// Optional model ID — when provided, the returned tool list reflects model-specific overrides
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UIElicitationArrayAnyOfFieldItemsAnyOf {
+    pub r#const: String,
+    pub title: String,
+}
+
+#[derive(Debug, Clone, Serialize,
Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UIElicitationArrayAnyOfFieldItems {
+    pub any_of: Vec<UIElicitationArrayAnyOfFieldItemsAnyOf>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UIElicitationArrayAnyOfField {
+    #[serde(default)]
+    pub default: Vec<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    pub items: UIElicitationArrayAnyOfFieldItems,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_items: Option<i64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub min_items: Option<i64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub title: Option<String>,
+    pub r#type: UIElicitationArrayAnyOfFieldType,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UIElicitationArrayEnumFieldItems {
+    pub r#enum: Vec<String>,
+    pub r#type: UIElicitationArrayEnumFieldItemsType,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UIElicitationArrayEnumField {
+    #[serde(default)]
+    pub default: Vec<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    pub items: UIElicitationArrayEnumFieldItems,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_items: Option<i64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub min_items: Option<i64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub title: Option<String>,
+    pub r#type: UIElicitationArrayEnumFieldType,
+}
+
+/// JSON Schema describing the form fields to present to the user
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UIElicitationSchema {
+    /// Form field definitions, keyed by field name
+    pub properties: HashMap,
+    /// List of required field names
+    #[serde(default)]
+    pub required: Vec<String>,
+    /// Schema type indicator (always 'object')
+    pub r#type: UIElicitationSchemaType,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub
struct UIElicitationRequest {
+    /// Message describing what information is needed from the user
+    pub message: String,
+    /// JSON Schema describing the form fields to present to the user
+    pub requested_schema: UIElicitationSchema,
+}
+
+/// The elicitation response (accept with form values, decline, or cancel)
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UIElicitationResponse {
+    /// The user's response: accept (submitted), decline (rejected), or cancel (dismissed)
+    pub action: UIElicitationResponseAction,
+    /// The form values submitted by the user (present when action is 'accept')
+    #[serde(default)]
+    pub content: HashMap<String, serde_json::Value>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UIElicitationResult {
+    /// Whether the response was accepted. False if the request was already resolved by another client.
+    pub success: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UIElicitationSchemaPropertyBoolean {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub default: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub title: Option<String>,
+    pub r#type: UIElicitationSchemaPropertyBooleanType,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UIElicitationSchemaPropertyNumber {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub default: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub maximum: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub minimum: Option<f64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub title: Option<String>,
+    pub r#type: UIElicitationSchemaPropertyNumberType,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all =
"camelCase")]
+pub struct UIElicitationSchemaPropertyString {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub default: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub format: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub max_length: Option<i64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub min_length: Option<i64>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub title: Option<String>,
+    pub r#type: UIElicitationSchemaPropertyStringType,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UIElicitationStringEnumField {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub default: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    pub r#enum: Vec<String>,
+    #[serde(default)]
+    pub enum_names: Vec<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub title: Option<String>,
+    pub r#type: UIElicitationStringEnumFieldType,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UIElicitationStringOneOfFieldOneOf {
+    pub r#const: String,
+    pub title: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UIElicitationStringOneOfField {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub default: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    pub one_of: Vec<UIElicitationStringOneOfFieldOneOf>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub title: Option<String>,
+    pub r#type: UIElicitationStringOneOfFieldType,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UIHandlePendingElicitationRequest {
+    /// The unique request ID from the elicitation.requested event
+    pub request_id: RequestId,
+    /// The elicitation response (accept with form values, decline, or cancel)
+    pub result: UIElicitationResponse,
+}
+
+/// Aggregated code change metrics
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UsageMetricsCodeChanges {
+    /// Number of distinct files modified
+    pub files_modified_count: i64,
+    /// Total lines of code added
+    pub lines_added: i64,
+    /// Total lines of code removed
+    pub lines_removed: i64,
+}
+
+/// Request count and cost metrics for this model
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UsageMetricsModelMetricRequests {
+    /// User-initiated premium request cost (with multiplier applied)
+    pub cost: f64,
+    /// Number of API requests made with this model
+    pub count: i64,
+}
+
+/// Token usage metrics for this model
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UsageMetricsModelMetricUsage {
+    /// Total tokens read from prompt cache
+    pub cache_read_tokens: i64,
+    /// Total tokens written to prompt cache
+    pub cache_write_tokens: i64,
+    /// Total input tokens consumed
+    pub input_tokens: i64,
+    /// Total output tokens produced
+    pub output_tokens: i64,
+    /// Total output tokens used for reasoning
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_tokens: Option<i64>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UsageMetricsModelMetric {
+    /// Request count and cost metrics for this model
+    pub requests: UsageMetricsModelMetricRequests,
+    /// Token usage metrics for this model
+    pub usage: UsageMetricsModelMetricUsage,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UsageGetMetricsResult {
+    /// Aggregated code change metrics
+    pub code_changes: UsageMetricsCodeChanges,
+    /// Currently active model identifier
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub current_model: Option<String>,
+    /// Input tokens from the most recent main-agent API call
+    pub last_call_input_tokens: i64,
+
/// Output tokens from the most recent main-agent API call
+    pub last_call_output_tokens: i64,
+    /// Per-model token and request metrics, keyed by model identifier
+    pub model_metrics: HashMap<String, UsageMetricsModelMetric>,
+    /// Session start timestamp (epoch milliseconds)
+    pub session_start_time: i64,
+    /// Total time spent in model API calls (milliseconds)
+    pub total_api_duration_ms: f64,
+    /// Total user-initiated premium request cost across all models (may be fractional due to multipliers)
+    pub total_premium_request_cost: f64,
+    /// Raw count of user-initiated API requests
+    pub total_user_requests: i64,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct WorkspacesCreateFileRequest {
+    /// File content to write as a UTF-8 string
+    pub content: String,
+    /// Relative path within the workspace files directory
+    pub path: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct WorkspacesGetWorkspaceResultWorkspace {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub branch: Option<String>,
+    #[serde(
+        rename = "chronicle_sync_dismissed",
+        skip_serializing_if = "Option::is_none"
+    )]
+    pub chronicle_sync_dismissed: Option<bool>,
+    #[serde(rename = "created_at", skip_serializing_if = "Option::is_none")]
+    pub created_at: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cwd: Option<String>,
+    #[serde(rename = "git_root", skip_serializing_if = "Option::is_none")]
+    pub git_root: Option<String>,
+    #[serde(rename = "host_type", skip_serializing_if = "Option::is_none")]
+    pub host_type: Option<String>,
+    pub id: String,
+    #[serde(rename = "mc_last_event_id", skip_serializing_if = "Option::is_none")]
+    pub mc_last_event_id: Option<String>,
+    #[serde(rename = "mc_session_id", skip_serializing_if = "Option::is_none")]
+    pub mc_session_id: Option<String>,
+    #[serde(rename = "mc_task_id", skip_serializing_if = "Option::is_none")]
+    pub mc_task_id: Option<String>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub name: Option<String>,
+    #[serde(rename = "remote_steerable", skip_serializing_if = "Option::is_none")]
+    pub remote_steerable: Option<bool>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub repository: Option<String>,
+    #[serde(rename = "session_sync_level", skip_serializing_if = "Option::is_none")]
+    pub session_sync_level: Option,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub summary: Option<String>,
+    #[serde(rename = "summary_count", skip_serializing_if = "Option::is_none")]
+    pub summary_count: Option<i64>,
+    #[serde(rename = "updated_at", skip_serializing_if = "Option::is_none")]
+    pub updated_at: Option<String>,
+    #[serde(rename = "user_named", skip_serializing_if = "Option::is_none")]
+    pub user_named: Option<bool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct WorkspacesGetWorkspaceResult {
+    /// Current workspace metadata, or null if not available
+    pub workspace: Option<WorkspacesGetWorkspaceResultWorkspace>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct WorkspacesListFilesResult {
+    /// Relative file paths in the workspace files directory
+    pub files: Vec<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct WorkspacesReadFileRequest {
+    /// Relative path within the workspace files directory
+    pub path: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct WorkspacesReadFileResult {
+    /// File content as a UTF-8 string
+    pub content: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ModelsListResult {
+    /// List of available models with full metadata
+    pub models: Vec,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ToolsListResult {
+    /// List of available built-in tools with metadata
+    pub tools: Vec<Tool>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct
McpConfigListResult {
+    /// All MCP servers from user config, keyed by name
+    pub servers: HashMap,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SkillsDiscoverResult {
+    /// All discovered skills across all sources
+    pub skills: Vec<ServerSkill>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionAuthGetStatusParams {
+    /// Target session identifier
+    pub session_id: SessionId,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionAuthGetStatusResult {
+    /// Authentication type
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub auth_type: Option<String>,
+    /// Copilot plan tier (e.g., individual_pro, business)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub copilot_plan: Option<String>,
+    /// Authentication host URL
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub host: Option<String>,
+    /// Whether the session has resolved authentication
+    pub is_authenticated: bool,
+    /// Authenticated login/username, if available
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub login: Option<String>,
+    /// Human-readable authentication status description
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub status_message: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionModelGetCurrentParams {
+    /// Target session identifier
+    pub session_id: SessionId,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionModelGetCurrentResult {
+    /// Currently active model identifier
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model_id: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionModelSwitchToResult {
+    /// Currently active model identifier after the switch
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model_id:
Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionModeGetParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionNameGetParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionNameGetResult { + /// The session name (user-set or auto-generated), or null if not yet set + pub name: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPlanReadParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPlanReadResult { + /// The content of the plan file, or null if it does not exist + pub content: Option, + /// Whether the plan file exists in the workspace + pub exists: bool, + /// Absolute file path of the plan file, or null if workspace is not enabled + pub path: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPlanDeleteParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionWorkspacesGetWorkspaceParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionWorkspacesGetWorkspaceResultWorkspace { + #[serde(skip_serializing_if = "Option::is_none")] + pub branch: Option, + #[serde( + rename = "chronicle_sync_dismissed", + skip_serializing_if = "Option::is_none" + )] + pub chronicle_sync_dismissed: Option, + #[serde(rename = "created_at", skip_serializing_if = "Option::is_none")] + 
pub created_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub cwd: Option, + #[serde(rename = "git_root", skip_serializing_if = "Option::is_none")] + pub git_root: Option, + #[serde(rename = "host_type", skip_serializing_if = "Option::is_none")] + pub host_type: Option, + pub id: String, + #[serde(rename = "mc_last_event_id", skip_serializing_if = "Option::is_none")] + pub mc_last_event_id: Option, + #[serde(rename = "mc_session_id", skip_serializing_if = "Option::is_none")] + pub mc_session_id: Option, + #[serde(rename = "mc_task_id", skip_serializing_if = "Option::is_none")] + pub mc_task_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(rename = "remote_steerable", skip_serializing_if = "Option::is_none")] + pub remote_steerable: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub repository: Option, + #[serde(rename = "session_sync_level", skip_serializing_if = "Option::is_none")] + pub session_sync_level: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, + #[serde(rename = "summary_count", skip_serializing_if = "Option::is_none")] + pub summary_count: Option, + #[serde(rename = "updated_at", skip_serializing_if = "Option::is_none")] + pub updated_at: Option, + #[serde(rename = "user_named", skip_serializing_if = "Option::is_none")] + pub user_named: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionWorkspacesGetWorkspaceResult { + /// Current workspace metadata, or null if not available + pub workspace: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionWorkspacesListFilesParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionWorkspacesListFilesResult { + /// Relative file paths in the workspace 
files directory + pub files: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionWorkspacesReadFileResult { + /// File content as a UTF-8 string + pub content: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionInstructionsGetSourcesParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionInstructionsGetSourcesResult { + /// Instruction sources for the session + pub sources: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionFleetStartResult { + /// Whether fleet mode was successfully activated + pub started: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentListParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentListResult { + /// Available custom agents + pub agents: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentGetCurrentParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentGetCurrentResult { + /// Currently selected custom agent, or null if using the default agent + pub agent: AgentInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentSelectResult { + /// The newly selected custom agent + pub agent: AgentInfo, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentDeselectParams { + /// Target session identifier + pub session_id: 
SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentReloadParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionAgentReloadResult { + /// Reloaded custom agents + pub agents: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTasksStartAgentResult { + /// Generated agent ID for the background task + pub agent_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTasksListParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTasksListResult { + /// Currently tracked tasks + pub tasks: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTasksPromoteToBackgroundResult { + /// Whether the task was successfully promoted to background mode + pub promoted: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTasksCancelResult { + /// Whether the task was successfully cancelled + pub cancelled: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTasksRemoveResult { + /// Whether the task was removed. Returns false if the task does not exist or is still running/idle (cancel it first). 
+ pub removed: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionSkillsListParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionSkillsListResult { + /// Available skills + pub skills: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionSkillsReloadParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionMcpListParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionMcpListResult { + /// Configured MCP servers + pub servers: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionMcpReloadParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionMcpOauthLoginResult { + /// URL the caller should open in a browser to complete OAuth. Omitted when cached tokens were still valid and no browser interaction was needed — the server is already reconnected in that case. When present, the runtime starts the callback listener before returning and continues the flow in the background; completion is signaled via session.mcp_server_status_changed. 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub authorization_url: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPluginsListParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPluginsListResult { + /// Installed plugins + pub plugins: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionExtensionsListParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionExtensionsListResult { + /// Discovered extensions and their current status + pub extensions: Vec, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionExtensionsReloadParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionToolsHandlePendingToolCallResult { + /// Whether the tool call result was handled successfully + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionCommandsHandlePendingCommandResult { + /// Whether the command was handled successfully + pub success: bool, +} + +/// The elicitation response (accept with form values, decline, or cancel) +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionUiElicitationResult { + /// The user's response: accept (submitted), decline (rejected), or cancel (dismissed) + pub action: UIElicitationResponseAction, + /// The form values submitted by the user (present when action is 'accept') + #[serde(default)] + pub content: HashMap, +} + +#[derive(Debug, Clone, Serialize, 
Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionUiHandlePendingElicitationResult { + /// Whether the response was accepted. False if the request was already resolved by another client. + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPermissionsHandlePendingPermissionRequestResult { + /// Whether the permission request was handled successfully + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPermissionsSetApproveAllResult { + /// Whether the operation succeeded + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionPermissionsResetSessionApprovalsResult { + /// Whether the operation succeeded + pub success: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionLogResult { + /// The unique identifier of the emitted session event + pub event_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionShellExecResult { + /// Unique identifier for tracking streamed output + pub process_id: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionShellKillResult { + /// Whether the signal was sent successfully + pub killed: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionHistoryCompactParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionHistoryCompactResult { + /// Post-compaction context window usage breakdown + #[serde(skip_serializing_if = "Option::is_none")] + pub context_window: Option, + /// Number of messages removed during compaction + pub 
messages_removed: i64, + /// Whether compaction completed successfully + pub success: bool, + /// Number of tokens freed by compaction + pub tokens_removed: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionHistoryTruncateResult { + /// Number of events that were removed + pub events_removed: i64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionUsageGetMetricsParams { + /// Target session identifier + pub session_id: SessionId, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionUsageGetMetricsResult { + /// Aggregated code change metrics + pub code_changes: UsageMetricsCodeChanges, + /// Currently active model identifier + #[serde(skip_serializing_if = "Option::is_none")] + pub current_model: Option, + /// Input tokens from the most recent main-agent API call + pub last_call_input_tokens: i64, + /// Output tokens from the most recent main-agent API call + pub last_call_output_tokens: i64, + /// Per-model token and request metrics, keyed by model identifier + pub model_metrics: HashMap, + /// Session start timestamp (epoch milliseconds) + pub session_start_time: i64, + /// Total time spent in model API calls (milliseconds) + pub total_api_duration_ms: f64, + /// Total user-initiated premium request cost across all models (may be fractional due to multipliers) + pub total_premium_request_cost: f64, + /// Raw count of user-initiated API requests + pub total_user_requests: i64, +} + +/// Authentication type +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum AuthInfoType { + #[serde(rename = "hmac")] + Hmac, + #[serde(rename = "env")] + Env, + #[serde(rename = "user")] + User, + #[serde(rename = "gh-cli")] + GhCli, + #[serde(rename = "api-key")] + ApiKey, + #[serde(rename = "token")] + Token, + #[serde(rename = "copilot-api-token")] + CopilotApiToken, + /// Unknown 
variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Configuration source +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum DiscoveredMcpServerSource { + #[serde(rename = "user")] + User, + #[serde(rename = "workspace")] + Workspace, + #[serde(rename = "plugin")] + Plugin, + #[serde(rename = "builtin")] + Builtin, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Server transport type: stdio, http, sse, or memory (local configs are normalized to stdio) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum DiscoveredMcpServerType { + #[serde(rename = "stdio")] + Stdio, + #[serde(rename = "http")] + Http, + #[serde(rename = "sse")] + Sse, + #[serde(rename = "memory")] + Memory, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Discovery source: project (.github/extensions/) or user (~/.copilot/extensions/) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExtensionSource { + #[serde(rename = "project")] + Project, + #[serde(rename = "user")] + User, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Current status: running, disabled, failed, or starting +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExtensionStatus { + #[serde(rename = "running")] + Running, + #[serde(rename = "disabled")] + Disabled, + #[serde(rename = "failed")] + Failed, + #[serde(rename = "starting")] + Starting, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum FilterMappingString { + #[serde(rename = "none")] + None, + #[serde(rename = "markdown")] + Markdown, + #[serde(rename = "hidden_characters")] + HiddenCharacters, + /// Unknown variant for forward compatibility. 
+ #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum FilterMappingValue { + #[serde(rename = "none")] + None, + #[serde(rename = "markdown")] + Markdown, + #[serde(rename = "hidden_characters")] + HiddenCharacters, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Where this source lives — used for UI grouping +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum InstructionsSourcesLocation { + #[serde(rename = "user")] + User, + #[serde(rename = "repository")] + Repository, + #[serde(rename = "working-directory")] + WorkingDirectory, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Category of instruction source — used for merge logic +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum InstructionsSourcesType { + #[serde(rename = "home")] + Home, + #[serde(rename = "repo")] + Repo, + #[serde(rename = "model")] + Model, + #[serde(rename = "vscode")] + Vscode, + #[serde(rename = "nested-agents")] + NestedAgents, + #[serde(rename = "child-instructions")] + ChildInstructions, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Log severity level. Determines how the message is displayed in the timeline. Defaults to "info". +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionLogLevel { + #[serde(rename = "info")] + Info, + #[serde(rename = "warning")] + Warning, + #[serde(rename = "error")] + Error, + /// Unknown variant for forward compatibility. 
+ #[serde(other)] + Unknown, +} + +/// Configuration source: user, workspace, plugin, or builtin +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpServerSource { + #[serde(rename = "user")] + User, + #[serde(rename = "workspace")] + Workspace, + #[serde(rename = "plugin")] + Plugin, + #[serde(rename = "builtin")] + Builtin, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Connection status: connected, failed, needs-auth, pending, disabled, or not_configured +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpServerStatus { + #[serde(rename = "connected")] + Connected, + #[serde(rename = "failed")] + Failed, + #[serde(rename = "needs-auth")] + NeedsAuth, + #[serde(rename = "pending")] + Pending, + #[serde(rename = "disabled")] + Disabled, + #[serde(rename = "not_configured")] + NotConfigured, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Remote transport type. Defaults to "http" when omitted. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpServerConfigHttpType { + #[serde(rename = "http")] + Http, + #[serde(rename = "sse")] + Sse, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpServerConfigLocalType { + #[serde(rename = "local")] + Local, + #[serde(rename = "stdio")] + Stdio, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// The agent mode. Valid values: "interactive", "plan", "autopilot". +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionMode { + #[serde(rename = "interactive")] + Interactive, + #[serde(rename = "plan")] + Plan, + #[serde(rename = "autopilot")] + Autopilot, + /// Unknown variant for forward compatibility. 
+ #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationApprovalCommandsKind { + #[serde(rename = "commands")] + Commands, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationApprovalReadKind { + #[serde(rename = "read")] + Read, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationApprovalWriteKind { + #[serde(rename = "write")] + Write, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationApprovalMcpKind { + #[serde(rename = "mcp")] + Mcp, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationApprovalMcpSamplingKind { + #[serde(rename = "mcp-sampling")] + McpSampling, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationApprovalMemoryKind { + #[serde(rename = "memory")] + Memory, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationApprovalCustomToolKind { + #[serde(rename = "custom-tool")] + CustomTool, +} + +/// The approval to persist for this location +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PermissionDecisionApproveForLocationApproval { + Commands(PermissionDecisionApproveForLocationApprovalCommands), + Read(PermissionDecisionApproveForLocationApprovalRead), + Write(PermissionDecisionApproveForLocationApprovalWrite), + Mcp(PermissionDecisionApproveForLocationApprovalMcp), + McpSampling(PermissionDecisionApproveForLocationApprovalMcpSampling), + Memory(PermissionDecisionApproveForLocationApprovalMemory), + CustomTool(PermissionDecisionApproveForLocationApprovalCustomTool), +} + +/// Approved and persisted for this project location +#[derive(Debug, Clone, PartialEq, Eq, 
Serialize, Deserialize)] +pub enum PermissionDecisionApproveForLocationKind { + #[serde(rename = "approve-for-location")] + ApproveForLocation, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionApprovalCommandsKind { + #[serde(rename = "commands")] + Commands, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionApprovalReadKind { + #[serde(rename = "read")] + Read, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionApprovalWriteKind { + #[serde(rename = "write")] + Write, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionApprovalMcpKind { + #[serde(rename = "mcp")] + Mcp, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionApprovalMcpSamplingKind { + #[serde(rename = "mcp-sampling")] + McpSampling, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionApprovalMemoryKind { + #[serde(rename = "memory")] + Memory, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionApprovalCustomToolKind { + #[serde(rename = "custom-tool")] + CustomTool, +} + +/// The approval to add as a session-scoped rule +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PermissionDecisionApproveForSessionApproval { + Commands(PermissionDecisionApproveForSessionApprovalCommands), + Read(PermissionDecisionApproveForSessionApprovalRead), + Write(PermissionDecisionApproveForSessionApprovalWrite), + Mcp(PermissionDecisionApproveForSessionApprovalMcp), + McpSampling(PermissionDecisionApproveForSessionApprovalMcpSampling), + Memory(PermissionDecisionApproveForSessionApprovalMemory), + 
CustomTool(PermissionDecisionApproveForSessionApprovalCustomTool), +} + +/// Approved and remembered for the rest of the session +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveForSessionKind { + #[serde(rename = "approve-for-session")] + ApproveForSession, +} + +/// The permission request was approved for this one instance +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionApproveOnceKind { + #[serde(rename = "approve-once")] + ApproveOnce, +} + +/// Denied by the user during an interactive prompt +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionRejectKind { + #[serde(rename = "reject")] + Reject, +} + +/// Denied because user confirmation was unavailable +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionDecisionUserNotAvailableKind { + #[serde(rename = "user-not-available")] + UserNotAvailable, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PermissionDecision { + ApproveOnce(PermissionDecisionApproveOnce), + ApproveForSession(PermissionDecisionApproveForSession), + ApproveForLocation(PermissionDecisionApproveForLocation), + Reject(PermissionDecisionReject), + UserNotAvailable(PermissionDecisionUserNotAvailable), +} + +/// Error classification +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionFsErrorCode { + ENOENT, + UNKNOWN, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Entry type +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionFsReaddirWithTypesEntryType { + #[serde(rename = "file")] + File, + #[serde(rename = "directory")] + Directory, + /// Unknown variant for forward compatibility. 
+ #[serde(other)] + Unknown, +} + +/// Path conventions used by this filesystem +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionFsSetProviderConventions { + #[serde(rename = "windows")] + Windows, + #[serde(rename = "posix")] + Posix, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Signal to send (default: SIGTERM) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ShellKillSignal { + SIGTERM, + SIGKILL, + SIGINT, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// How the agent is currently being managed by the runtime +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TaskAgentInfoExecutionMode { + #[serde(rename = "sync")] + Sync, + #[serde(rename = "background")] + Background, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Current lifecycle status of the task +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TaskAgentInfoStatus { + #[serde(rename = "running")] + Running, + #[serde(rename = "idle")] + Idle, + #[serde(rename = "completed")] + Completed, + #[serde(rename = "failed")] + Failed, + #[serde(rename = "cancelled")] + Cancelled, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Task kind +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TaskAgentInfoType { + #[serde(rename = "agent")] + Agent, +} + +/// Whether the shell runs inside a managed PTY session or as an independent background process +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TaskShellInfoAttachmentMode { + #[serde(rename = "attached")] + Attached, + #[serde(rename = "detached")] + Detached, + /// Unknown variant for forward compatibility. 
+ #[serde(other)] + Unknown, +} + +/// Whether the shell command is currently sync-waited or background-managed +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TaskShellInfoExecutionMode { + #[serde(rename = "sync")] + Sync, + #[serde(rename = "background")] + Background, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Current lifecycle status of the task +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TaskShellInfoStatus { + #[serde(rename = "running")] + Running, + #[serde(rename = "idle")] + Idle, + #[serde(rename = "completed")] + Completed, + #[serde(rename = "failed")] + Failed, + #[serde(rename = "cancelled")] + Cancelled, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Task kind +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum TaskShellInfoType { + #[serde(rename = "shell")] + Shell, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationArrayAnyOfFieldType { + #[serde(rename = "array")] + Array, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationArrayEnumFieldItemsType { + #[serde(rename = "string")] + String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationArrayEnumFieldType { + #[serde(rename = "array")] + Array, +} + +/// Schema type indicator (always 'object') +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationSchemaType { + #[serde(rename = "object")] + Object, +} + +/// The user's response: accept (submitted), decline (rejected), or cancel (dismissed) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationResponseAction { + #[serde(rename = "accept")] + Accept, + #[serde(rename = "decline")] + Decline, + #[serde(rename = "cancel")] + Cancel, + /// Unknown variant for forward compatibility. 
+ #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationSchemaPropertyBooleanType { + #[serde(rename = "boolean")] + Boolean, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationSchemaPropertyNumberType { + #[serde(rename = "number")] + Number, + #[serde(rename = "integer")] + Integer, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationSchemaPropertyStringFormat { + #[serde(rename = "email")] + Email, + #[serde(rename = "uri")] + Uri, + #[serde(rename = "date")] + Date, + #[serde(rename = "date-time")] + DateTime, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationSchemaPropertyStringType { + #[serde(rename = "string")] + String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationStringEnumFieldType { + #[serde(rename = "string")] + String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum UIElicitationStringOneOfFieldType { + #[serde(rename = "string")] + String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum WorkspacesGetWorkspaceResultWorkspaceHostType { + #[serde(rename = "github")] + Github, + #[serde(rename = "ado")] + Ado, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum WorkspacesGetWorkspaceResultWorkspaceSessionSyncLevel { + #[serde(rename = "local")] + Local, + #[serde(rename = "user")] + User, + #[serde(rename = "repo_and_user")] + RepoAndUser, + /// Unknown variant for forward compatibility. 
+ #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionWorkspacesGetWorkspaceResultWorkspaceHostType { + #[serde(rename = "github")] + Github, + #[serde(rename = "ado")] + Ado, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionWorkspacesGetWorkspaceResultWorkspaceSessionSyncLevel { + #[serde(rename = "local")] + Local, + #[serde(rename = "user")] + User, + #[serde(rename = "repo_and_user")] + RepoAndUser, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} diff --git a/rust/src/generated/mod.rs b/rust/src/generated/mod.rs new file mode 100644 index 000000000..5466a5e35 --- /dev/null +++ b/rust/src/generated/mod.rs @@ -0,0 +1,15 @@ +//! Auto-generated protocol types — do not edit manually. +//! +//! Generated from the Copilot protocol JSON Schemas by `scripts/codegen/rust.ts`. +#![allow(missing_docs)] +#![allow(rustdoc::bare_urls)] + +pub mod api_types; +pub mod rpc; +pub mod session_events; + +// Re-export session event types at the module root — no conflicts with +// hand-written types. API types are kept namespaced under `api_types::` +// because some names (Tool, ModelCapabilities, etc.) overlap with the +// hand-written SDK API types in `types.rs`. +pub use session_events::*; diff --git a/rust/src/generated/rpc.rs b/rust/src/generated/rpc.rs new file mode 100644 index 000000000..6f42c73da --- /dev/null +++ b/rust/src/generated/rpc.rs @@ -0,0 +1,1359 @@ +//! Auto-generated typed JSON-RPC namespace — do not edit manually. +//! +//! Generated from `api.schema.json` by `scripts/codegen/rust.ts`. The +//! [`ClientRpc`] and [`SessionRpc`] view structs let callers reach every +//! protocol method through a typed namespace tree, so wire method names +//! and request/response shapes live in exactly one place — this file. 
+ +#![allow(missing_docs)] +#![allow(clippy::too_many_arguments)] + +use super::api_types::{rpc_methods, *}; +use crate::session::Session; +use crate::{Client, Error}; + +/// Typed view over the [`Client`]'s server-level RPC namespace. +#[derive(Clone, Copy)] +pub struct ClientRpc<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpc<'a> { + /// `account.*` sub-namespace. + pub fn account(&self) -> ClientRpcAccount<'a> { + ClientRpcAccount { + client: self.client, + } + } + + /// `mcp.*` sub-namespace. + pub fn mcp(&self) -> ClientRpcMcp<'a> { + ClientRpcMcp { + client: self.client, + } + } + + /// `models.*` sub-namespace. + pub fn models(&self) -> ClientRpcModels<'a> { + ClientRpcModels { + client: self.client, + } + } + + /// `sessionFs.*` sub-namespace. + pub fn session_fs(&self) -> ClientRpcSessionFs<'a> { + ClientRpcSessionFs { + client: self.client, + } + } + + /// `sessions.*` sub-namespace. + pub fn sessions(&self) -> ClientRpcSessions<'a> { + ClientRpcSessions { + client: self.client, + } + } + + /// `skills.*` sub-namespace. + pub fn skills(&self) -> ClientRpcSkills<'a> { + ClientRpcSkills { + client: self.client, + } + } + + /// `tools.*` sub-namespace. + pub fn tools(&self) -> ClientRpcTools<'a> { + ClientRpcTools { + client: self.client, + } + } + + /// Wire method: `ping`. + pub async fn ping(&self, params: PingRequest) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::PING, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `account.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcAccount<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcAccount<'a> { + /// Wire method: `account.getQuota`. + pub async fn get_quota(&self) -> Result { + let wire_params = serde_json::json!({}); + let _value = self + .client + .call(rpc_methods::ACCOUNT_GETQUOTA, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) 
+ } +} + +/// `mcp.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcMcp<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcMcp<'a> { + /// `mcp.config.*` sub-namespace. + pub fn config(&self) -> ClientRpcMcpConfig<'a> { + ClientRpcMcpConfig { + client: self.client, + } + } + + /// Wire method: `mcp.discover`. + pub async fn discover(&self, params: McpDiscoverRequest) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::MCP_DISCOVER, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `mcp.config.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcMcpConfig<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcMcpConfig<'a> { + /// Wire method: `mcp.config.list`. + pub async fn list(&self) -> Result { + let wire_params = serde_json::json!({}); + let _value = self + .client + .call(rpc_methods::MCP_CONFIG_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `mcp.config.add`. + pub async fn add(&self, params: McpConfigAddRequest) -> Result<(), Error> { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::MCP_CONFIG_ADD, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `mcp.config.update`. + pub async fn update(&self, params: McpConfigUpdateRequest) -> Result<(), Error> { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::MCP_CONFIG_UPDATE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `mcp.config.remove`. + pub async fn remove(&self, params: McpConfigRemoveRequest) -> Result<(), Error> { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::MCP_CONFIG_REMOVE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `mcp.config.enable`. 
+ pub async fn enable(&self, params: McpConfigEnableRequest) -> Result<(), Error> { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::MCP_CONFIG_ENABLE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `mcp.config.disable`. + pub async fn disable(&self, params: McpConfigDisableRequest) -> Result<(), Error> { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::MCP_CONFIG_DISABLE, Some(wire_params)) + .await?; + Ok(()) + } +} + +/// `models.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcModels<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcModels<'a> { + /// Wire method: `models.list`. + pub async fn list(&self) -> Result { + let wire_params = serde_json::json!({}); + let _value = self + .client + .call(rpc_methods::MODELS_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `sessionFs.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcSessionFs<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcSessionFs<'a> { + /// Wire method: `sessionFs.setProvider`. + pub async fn set_provider( + &self, + params: SessionFsSetProviderRequest, + ) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::SESSIONFS_SETPROVIDER, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `sessions.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcSessions<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcSessions<'a> { + /// Wire method: `sessions.fork`. + /// Stability: `experimental`. + pub async fn fork(&self, params: SessionsForkRequest) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::SESSIONS_FORK, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `skills.*` RPCs. 
+#[derive(Clone, Copy)] +pub struct ClientRpcSkills<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcSkills<'a> { + /// `skills.config.*` sub-namespace. + pub fn config(&self) -> ClientRpcSkillsConfig<'a> { + ClientRpcSkillsConfig { + client: self.client, + } + } + + /// Wire method: `skills.discover`. + pub async fn discover(&self, params: SkillsDiscoverRequest) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::SKILLS_DISCOVER, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `skills.config.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcSkillsConfig<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcSkillsConfig<'a> { + /// Wire method: `skills.config.setDisabledSkills`. + pub async fn set_disabled_skills( + &self, + params: SkillsConfigSetDisabledSkillsRequest, + ) -> Result<(), Error> { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call( + rpc_methods::SKILLS_CONFIG_SETDISABLEDSKILLS, + Some(wire_params), + ) + .await?; + Ok(()) + } +} + +/// `tools.*` RPCs. +#[derive(Clone, Copy)] +pub struct ClientRpcTools<'a> { + pub(crate) client: &'a Client, +} + +impl<'a> ClientRpcTools<'a> { + /// Wire method: `tools.list`. + pub async fn list(&self, params: ToolsListRequest) -> Result { + let wire_params = serde_json::to_value(params)?; + let _value = self + .client + .call(rpc_methods::TOOLS_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// Typed view over a [`Session`]'s RPC namespace. +#[derive(Clone, Copy)] +pub struct SessionRpc<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpc<'a> { + /// `session.agent.*` sub-namespace. + pub fn agent(&self) -> SessionRpcAgent<'a> { + SessionRpcAgent { + session: self.session, + } + } + + /// `session.auth.*` sub-namespace. 
+ pub fn auth(&self) -> SessionRpcAuth<'a> { + SessionRpcAuth { + session: self.session, + } + } + + /// `session.commands.*` sub-namespace. + pub fn commands(&self) -> SessionRpcCommands<'a> { + SessionRpcCommands { + session: self.session, + } + } + + /// `session.extensions.*` sub-namespace. + pub fn extensions(&self) -> SessionRpcExtensions<'a> { + SessionRpcExtensions { + session: self.session, + } + } + + /// `session.fleet.*` sub-namespace. + pub fn fleet(&self) -> SessionRpcFleet<'a> { + SessionRpcFleet { + session: self.session, + } + } + + /// `session.history.*` sub-namespace. + pub fn history(&self) -> SessionRpcHistory<'a> { + SessionRpcHistory { + session: self.session, + } + } + + /// `session.instructions.*` sub-namespace. + pub fn instructions(&self) -> SessionRpcInstructions<'a> { + SessionRpcInstructions { + session: self.session, + } + } + + /// `session.mcp.*` sub-namespace. + pub fn mcp(&self) -> SessionRpcMcp<'a> { + SessionRpcMcp { + session: self.session, + } + } + + /// `session.mode.*` sub-namespace. + pub fn mode(&self) -> SessionRpcMode<'a> { + SessionRpcMode { + session: self.session, + } + } + + /// `session.model.*` sub-namespace. + pub fn model(&self) -> SessionRpcModel<'a> { + SessionRpcModel { + session: self.session, + } + } + + /// `session.name.*` sub-namespace. + pub fn name(&self) -> SessionRpcName<'a> { + SessionRpcName { + session: self.session, + } + } + + /// `session.permissions.*` sub-namespace. + pub fn permissions(&self) -> SessionRpcPermissions<'a> { + SessionRpcPermissions { + session: self.session, + } + } + + /// `session.plan.*` sub-namespace. + pub fn plan(&self) -> SessionRpcPlan<'a> { + SessionRpcPlan { + session: self.session, + } + } + + /// `session.plugins.*` sub-namespace. + pub fn plugins(&self) -> SessionRpcPlugins<'a> { + SessionRpcPlugins { + session: self.session, + } + } + + /// `session.shell.*` sub-namespace. 
+ pub fn shell(&self) -> SessionRpcShell<'a> { + SessionRpcShell { + session: self.session, + } + } + + /// `session.skills.*` sub-namespace. + pub fn skills(&self) -> SessionRpcSkills<'a> { + SessionRpcSkills { + session: self.session, + } + } + + /// `session.tasks.*` sub-namespace. + pub fn tasks(&self) -> SessionRpcTasks<'a> { + SessionRpcTasks { + session: self.session, + } + } + + /// `session.tools.*` sub-namespace. + pub fn tools(&self) -> SessionRpcTools<'a> { + SessionRpcTools { + session: self.session, + } + } + + /// `session.ui.*` sub-namespace. + pub fn ui(&self) -> SessionRpcUi<'a> { + SessionRpcUi { + session: self.session, + } + } + + /// `session.usage.*` sub-namespace. + pub fn usage(&self) -> SessionRpcUsage<'a> { + SessionRpcUsage { + session: self.session, + } + } + + /// `session.workspaces.*` sub-namespace. + pub fn workspaces(&self) -> SessionRpcWorkspaces<'a> { + SessionRpcWorkspaces { + session: self.session, + } + } + + /// Wire method: `session.log`. + pub async fn log(&self, params: LogRequest) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_LOG, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.agent.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcAgent<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcAgent<'a> { + /// Wire method: `session.agent.list`. + /// Stability: `experimental`. + pub async fn list(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_AGENT_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.agent.getCurrent`. + /// Stability: `experimental`. 
+ pub async fn get_current(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_AGENT_GETCURRENT, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.agent.select`. + /// Stability: `experimental`. + pub async fn select(&self, params: AgentSelectRequest) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_AGENT_SELECT, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.agent.deselect`. + /// Stability: `experimental`. + pub async fn deselect(&self) -> Result<(), Error> { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_AGENT_DESELECT, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `session.agent.reload`. + /// Stability: `experimental`. + pub async fn reload(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_AGENT_RELOAD, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.auth.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcAuth<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcAuth<'a> { + /// Wire method: `session.auth.getStatus`. + pub async fn get_status(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_AUTH_GETSTATUS, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.commands.*` RPCs. 
+#[derive(Clone, Copy)] +pub struct SessionRpcCommands<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcCommands<'a> { + /// Wire method: `session.commands.handlePendingCommand`. + pub async fn handle_pending_command( + &self, + params: CommandsHandlePendingCommandRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_COMMANDS_HANDLEPENDINGCOMMAND, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.extensions.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcExtensions<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcExtensions<'a> { + /// Wire method: `session.extensions.list`. + /// Stability: `experimental`. + pub async fn list(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_EXTENSIONS_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.extensions.enable`. + /// Stability: `experimental`. + pub async fn enable(&self, params: ExtensionsEnableRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_EXTENSIONS_ENABLE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `session.extensions.disable`. + /// Stability: `experimental`. 
+ pub async fn disable(&self, params: ExtensionsDisableRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_EXTENSIONS_DISABLE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `session.extensions.reload`. + /// Stability: `experimental`. + pub async fn reload(&self) -> Result<(), Error> { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_EXTENSIONS_RELOAD, Some(wire_params)) + .await?; + Ok(()) + } +} + +/// `session.fleet.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcFleet<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcFleet<'a> { + /// Wire method: `session.fleet.start`. + /// Stability: `experimental`. + pub async fn start(&self, params: FleetStartRequest) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_FLEET_START, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.history.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcHistory<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcHistory<'a> { + /// Wire method: `session.history.compact`. + /// Stability: `experimental`. + pub async fn compact(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_HISTORY_COMPACT, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.history.truncate`. + /// Stability: `experimental`. 
+ pub async fn truncate( + &self, + params: HistoryTruncateRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_HISTORY_TRUNCATE, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.instructions.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcInstructions<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcInstructions<'a> { + /// Wire method: `session.instructions.getSources`. + pub async fn get_sources(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_INSTRUCTIONS_GETSOURCES, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.mcp.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcMcp<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcMcp<'a> { + /// `session.mcp.oauth.*` sub-namespace. + pub fn oauth(&self) -> SessionRpcMcpOauth<'a> { + SessionRpcMcpOauth { + session: self.session, + } + } + + /// Wire method: `session.mcp.list`. + /// Stability: `experimental`. + pub async fn list(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MCP_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.mcp.enable`. + /// Stability: `experimental`. 
+ pub async fn enable(&self, params: McpEnableRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MCP_ENABLE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `session.mcp.disable`. + /// Stability: `experimental`. + pub async fn disable(&self, params: McpDisableRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MCP_DISABLE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `session.mcp.reload`. + /// Stability: `experimental`. + pub async fn reload(&self) -> Result<(), Error> { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MCP_RELOAD, Some(wire_params)) + .await?; + Ok(()) + } +} + +/// `session.mcp.oauth.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcMcpOauth<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcMcpOauth<'a> { + /// Wire method: `session.mcp.oauth.login`. + /// Stability: `experimental`. + pub async fn login(&self, params: McpOauthLoginRequest) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MCP_OAUTH_LOGIN, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.mode.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcMode<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcMode<'a> { + /// Wire method: `session.mode.get`. 
+ pub async fn get(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MODE_GET, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.mode.set`. + pub async fn set(&self, params: ModeSetRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MODE_SET, Some(wire_params)) + .await?; + Ok(()) + } +} + +/// `session.model.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcModel<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcModel<'a> { + /// Wire method: `session.model.getCurrent`. + pub async fn get_current(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MODEL_GETCURRENT, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.model.switchTo`. + pub async fn switch_to( + &self, + params: ModelSwitchToRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_MODEL_SWITCHTO, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.name.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcName<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcName<'a> { + /// Wire method: `session.name.get`. 
+ pub async fn get(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_NAME_GET, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.name.set`. + pub async fn set(&self, params: NameSetRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_NAME_SET, Some(wire_params)) + .await?; + Ok(()) + } +} + +/// `session.permissions.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcPermissions<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcPermissions<'a> { + /// Wire method: `session.permissions.handlePendingPermissionRequest`. + pub async fn handle_pending_permission_request( + &self, + params: PermissionDecisionRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_PERMISSIONS_HANDLEPENDINGPERMISSIONREQUEST, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.permissions.setApproveAll`. + pub async fn set_approve_all( + &self, + params: PermissionsSetApproveAllRequest, + ) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_PERMISSIONS_SETAPPROVEALL, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.permissions.resetSessionApprovals`. 
+ pub async fn reset_session_approvals( + &self, + ) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call( + rpc_methods::SESSION_PERMISSIONS_RESETSESSIONAPPROVALS, + Some(wire_params), + ) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.plan.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcPlan<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcPlan<'a> { + /// Wire method: `session.plan.read`. + pub async fn read(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_PLAN_READ, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.plan.update`. + pub async fn update(&self, params: PlanUpdateRequest) -> Result<(), Error> { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_PLAN_UPDATE, Some(wire_params)) + .await?; + Ok(()) + } + + /// Wire method: `session.plan.delete`. + pub async fn delete(&self) -> Result<(), Error> { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_PLAN_DELETE, Some(wire_params)) + .await?; + Ok(()) + } +} + +/// `session.plugins.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcPlugins<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcPlugins<'a> { + /// Wire method: `session.plugins.list`. + /// Stability: `experimental`. 
+ pub async fn list(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_PLUGINS_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.shell.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcShell<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcShell<'a> { + /// Wire method: `session.shell.exec`. + pub async fn exec(&self, params: ShellExecRequest) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_SHELL_EXEC, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.shell.kill`. + pub async fn kill(&self, params: ShellKillRequest) -> Result { + let mut wire_params = serde_json::to_value(params)?; + wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string()); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_SHELL_KILL, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } +} + +/// `session.skills.*` RPCs. +#[derive(Clone, Copy)] +pub struct SessionRpcSkills<'a> { + pub(crate) session: &'a Session, +} + +impl<'a> SessionRpcSkills<'a> { + /// Wire method: `session.skills.list`. + /// Stability: `experimental`. + pub async fn list(&self) -> Result { + let wire_params = serde_json::json!({ "sessionId": self.session.id() }); + let _value = self + .session + .client() + .call(rpc_methods::SESSION_SKILLS_LIST, Some(wire_params)) + .await?; + Ok(serde_json::from_value(_value)?) + } + + /// Wire method: `session.skills.enable`. + /// Stability: `experimental`. 
+    pub async fn enable(&self, params: SkillsEnableRequest) -> Result<(), Error> {
+        let mut wire_params = serde_json::to_value(params)?;
+        wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string());
+        let _value = self
+            .session
+            .client()
+            .call(rpc_methods::SESSION_SKILLS_ENABLE, Some(wire_params))
+            .await?;
+        Ok(())
+    }
+
+    /// Wire method: `session.skills.disable`.
+    /// Stability: `experimental`.
+    pub async fn disable(&self, params: SkillsDisableRequest) -> Result<(), Error> {
+        let mut wire_params = serde_json::to_value(params)?;
+        wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string());
+        let _value = self
+            .session
+            .client()
+            .call(rpc_methods::SESSION_SKILLS_DISABLE, Some(wire_params))
+            .await?;
+        Ok(())
+    }
+
+    /// Wire method: `session.skills.reload`.
+    /// Stability: `experimental`.
+    pub async fn reload(&self) -> Result<(), Error> {
+        let wire_params = serde_json::json!({ "sessionId": self.session.id() });
+        let _value = self
+            .session
+            .client()
+            .call(rpc_methods::SESSION_SKILLS_RELOAD, Some(wire_params))
+            .await?;
+        Ok(())
+    }
+}
+
+/// `session.tasks.*` RPCs.
+#[derive(Clone, Copy)]
+pub struct SessionRpcTasks<'a> {
+    pub(crate) session: &'a Session,
+}
+
+impl<'a> SessionRpcTasks<'a> {
+    /// Wire method: `session.tasks.startAgent`.
+    /// Stability: `experimental`.
+    pub async fn start_agent(
+        &self,
+        params: TasksStartAgentRequest,
+    ) -> Result<TasksStartAgentResponse, Error> {
+        let mut wire_params = serde_json::to_value(params)?;
+        wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string());
+        let _value = self
+            .session
+            .client()
+            .call(rpc_methods::SESSION_TASKS_STARTAGENT, Some(wire_params))
+            .await?;
+        Ok(serde_json::from_value(_value)?)
+    }
+
+    /// Wire method: `session.tasks.list`.
+    /// Stability: `experimental`.
+    pub async fn list(&self) -> Result<TasksListResponse, Error> {
+        let wire_params = serde_json::json!({ "sessionId": self.session.id() });
+        let _value = self
+            .session
+            .client()
+            .call(rpc_methods::SESSION_TASKS_LIST, Some(wire_params))
+            .await?;
+        Ok(serde_json::from_value(_value)?)
+    }
+
+    /// Wire method: `session.tasks.promoteToBackground`.
+    /// Stability: `experimental`.
+    pub async fn promote_to_background(
+        &self,
+        params: TasksPromoteToBackgroundRequest,
+    ) -> Result<TasksPromoteToBackgroundResponse, Error> {
+        let mut wire_params = serde_json::to_value(params)?;
+        wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string());
+        let _value = self
+            .session
+            .client()
+            .call(
+                rpc_methods::SESSION_TASKS_PROMOTETOBACKGROUND,
+                Some(wire_params),
+            )
+            .await?;
+        Ok(serde_json::from_value(_value)?)
+    }
+
+    /// Wire method: `session.tasks.cancel`.
+    /// Stability: `experimental`.
+    pub async fn cancel(&self, params: TasksCancelRequest) -> Result<TasksCancelResponse, Error> {
+        let mut wire_params = serde_json::to_value(params)?;
+        wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string());
+        let _value = self
+            .session
+            .client()
+            .call(rpc_methods::SESSION_TASKS_CANCEL, Some(wire_params))
+            .await?;
+        Ok(serde_json::from_value(_value)?)
+    }
+
+    /// Wire method: `session.tasks.remove`.
+    /// Stability: `experimental`.
+    pub async fn remove(&self, params: TasksRemoveRequest) -> Result<TasksRemoveResponse, Error> {
+        let mut wire_params = serde_json::to_value(params)?;
+        wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string());
+        let _value = self
+            .session
+            .client()
+            .call(rpc_methods::SESSION_TASKS_REMOVE, Some(wire_params))
+            .await?;
+        Ok(serde_json::from_value(_value)?)
+    }
+}
+
+/// `session.tools.*` RPCs.
+#[derive(Clone, Copy)]
+pub struct SessionRpcTools<'a> {
+    pub(crate) session: &'a Session,
+}
+
+impl<'a> SessionRpcTools<'a> {
+    /// Wire method: `session.tools.handlePendingToolCall`.
+    pub async fn handle_pending_tool_call(
+        &self,
+        params: ToolsHandlePendingToolCallRequest,
+    ) -> Result<ToolsHandlePendingToolCallResponse, Error> {
+        let mut wire_params = serde_json::to_value(params)?;
+        wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string());
+        let _value = self
+            .session
+            .client()
+            .call(
+                rpc_methods::SESSION_TOOLS_HANDLEPENDINGTOOLCALL,
+                Some(wire_params),
+            )
+            .await?;
+        Ok(serde_json::from_value(_value)?)
+    }
+}
+
+/// `session.ui.*` RPCs.
+#[derive(Clone, Copy)]
+pub struct SessionRpcUi<'a> {
+    pub(crate) session: &'a Session,
+}
+
+impl<'a> SessionRpcUi<'a> {
+    /// Wire method: `session.ui.elicitation`.
+    pub async fn elicitation(
+        &self,
+        params: UIElicitationRequest,
+    ) -> Result<UIElicitationResponse, Error> {
+        let mut wire_params = serde_json::to_value(params)?;
+        wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string());
+        let _value = self
+            .session
+            .client()
+            .call(rpc_methods::SESSION_UI_ELICITATION, Some(wire_params))
+            .await?;
+        Ok(serde_json::from_value(_value)?)
+    }
+
+    /// Wire method: `session.ui.handlePendingElicitation`.
+    pub async fn handle_pending_elicitation(
+        &self,
+        params: UIHandlePendingElicitationRequest,
+    ) -> Result<UIHandlePendingElicitationResponse, Error> {
+        let mut wire_params = serde_json::to_value(params)?;
+        wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string());
+        let _value = self
+            .session
+            .client()
+            .call(
+                rpc_methods::SESSION_UI_HANDLEPENDINGELICITATION,
+                Some(wire_params),
+            )
+            .await?;
+        Ok(serde_json::from_value(_value)?)
+    }
+}
+
+/// `session.usage.*` RPCs.
+#[derive(Clone, Copy)]
+pub struct SessionRpcUsage<'a> {
+    pub(crate) session: &'a Session,
+}
+
+impl<'a> SessionRpcUsage<'a> {
+    /// Wire method: `session.usage.getMetrics`.
+    /// Stability: `experimental`.
+    pub async fn get_metrics(&self) -> Result<UsageGetMetricsResponse, Error> {
+        let wire_params = serde_json::json!({ "sessionId": self.session.id() });
+        let _value = self
+            .session
+            .client()
+            .call(rpc_methods::SESSION_USAGE_GETMETRICS, Some(wire_params))
+            .await?;
+        Ok(serde_json::from_value(_value)?)
+    }
+}
+
+/// `session.workspaces.*` RPCs.
+#[derive(Clone, Copy)]
+pub struct SessionRpcWorkspaces<'a> {
+    pub(crate) session: &'a Session,
+}
+
+impl<'a> SessionRpcWorkspaces<'a> {
+    /// Wire method: `session.workspaces.getWorkspace`.
+    pub async fn get_workspace(&self) -> Result<WorkspacesGetWorkspaceResponse, Error> {
+        let wire_params = serde_json::json!({ "sessionId": self.session.id() });
+        let _value = self
+            .session
+            .client()
+            .call(
+                rpc_methods::SESSION_WORKSPACES_GETWORKSPACE,
+                Some(wire_params),
+            )
+            .await?;
+        Ok(serde_json::from_value(_value)?)
+    }
+
+    /// Wire method: `session.workspaces.listFiles`.
+    pub async fn list_files(&self) -> Result<WorkspacesListFilesResponse, Error> {
+        let wire_params = serde_json::json!({ "sessionId": self.session.id() });
+        let _value = self
+            .session
+            .client()
+            .call(rpc_methods::SESSION_WORKSPACES_LISTFILES, Some(wire_params))
+            .await?;
+        Ok(serde_json::from_value(_value)?)
+    }
+
+    /// Wire method: `session.workspaces.readFile`.
+    pub async fn read_file(
+        &self,
+        params: WorkspacesReadFileRequest,
+    ) -> Result<WorkspacesReadFileResponse, Error> {
+        let mut wire_params = serde_json::to_value(params)?;
+        wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string());
+        let _value = self
+            .session
+            .client()
+            .call(rpc_methods::SESSION_WORKSPACES_READFILE, Some(wire_params))
+            .await?;
+        Ok(serde_json::from_value(_value)?)
+    }
+
+    /// Wire method: `session.workspaces.createFile`.
+    pub async fn create_file(&self, params: WorkspacesCreateFileRequest) -> Result<(), Error> {
+        let mut wire_params = serde_json::to_value(params)?;
+        wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string());
+        let _value = self
+            .session
+            .client()
+            .call(
+                rpc_methods::SESSION_WORKSPACES_CREATEFILE,
+                Some(wire_params),
+            )
+            .await?;
+        Ok(())
+    }
+}
diff --git a/rust/src/generated/session_events.rs b/rust/src/generated/session_events.rs
new file mode 100644
index 000000000..d8ad34f2f
--- /dev/null
+++ b/rust/src/generated/session_events.rs
@@ -0,0 +1,2767 @@
+//! Auto-generated from session-events.schema.json — do not edit manually.
+
+use std::collections::HashMap;
+
+use serde::{Deserialize, Serialize};
+
+use crate::types::{RequestId, SessionId};
+
+/// Identifies the kind of session event.
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub enum SessionEventType {
+    #[serde(rename = "session.start")]
+    SessionStart,
+    #[serde(rename = "session.resume")]
+    SessionResume,
+    #[serde(rename = "session.remote_steerable_changed")]
+    SessionRemoteSteerableChanged,
+    #[serde(rename = "session.error")]
+    SessionError,
+    #[serde(rename = "session.idle")]
+    SessionIdle,
+    #[serde(rename = "session.title_changed")]
+    SessionTitleChanged,
+    #[serde(rename = "session.info")]
+    SessionInfo,
+    #[serde(rename = "session.warning")]
+    SessionWarning,
+    #[serde(rename = "session.model_change")]
+    SessionModelChange,
+    #[serde(rename = "session.mode_changed")]
+    SessionModeChanged,
+    #[serde(rename = "session.plan_changed")]
+    SessionPlanChanged,
+    #[serde(rename = "session.workspace_file_changed")]
+    SessionWorkspaceFileChanged,
+    #[serde(rename = "session.handoff")]
+    SessionHandoff,
+    #[serde(rename = "session.truncation")]
+    SessionTruncation,
+    #[serde(rename = "session.snapshot_rewind")]
+    SessionSnapshotRewind,
+    #[serde(rename = "session.shutdown")]
+    SessionShutdown,
+    #[serde(rename = "session.context_changed")]
+    SessionContextChanged,
+    #[serde(rename = "session.usage_info")]
+    SessionUsageInfo,
+    #[serde(rename = "session.compaction_start")]
+    SessionCompactionStart,
+    #[serde(rename = "session.compaction_complete")]
+    SessionCompactionComplete,
+    #[serde(rename = "session.task_complete")]
+    SessionTaskComplete,
+    #[serde(rename = "user.message")]
+    UserMessage,
+    #[serde(rename = "pending_messages.modified")]
+    PendingMessagesModified,
+    #[serde(rename = "assistant.turn_start")]
+    AssistantTurnStart,
+    #[serde(rename = "assistant.intent")]
+    AssistantIntent,
+    #[serde(rename = "assistant.reasoning")]
+    AssistantReasoning,
+    #[serde(rename = "assistant.reasoning_delta")]
+    AssistantReasoningDelta,
+    #[serde(rename = "assistant.streaming_delta")]
+    AssistantStreamingDelta,
+    #[serde(rename = "assistant.message")]
+    AssistantMessage,
+    #[serde(rename = "assistant.message_delta")]
+    AssistantMessageDelta,
+    #[serde(rename = "assistant.turn_end")]
+    AssistantTurnEnd,
+    #[serde(rename = "assistant.usage")]
+    AssistantUsage,
+    #[serde(rename = "model.call_failure")]
+    ModelCallFailure,
+    #[serde(rename = "abort")]
+    Abort,
+    #[serde(rename = "tool.user_requested")]
+    ToolUserRequested,
+    #[serde(rename = "tool.execution_start")]
+    ToolExecutionStart,
+    #[serde(rename = "tool.execution_partial_result")]
+    ToolExecutionPartialResult,
+    #[serde(rename = "tool.execution_progress")]
+    ToolExecutionProgress,
+    #[serde(rename = "tool.execution_complete")]
+    ToolExecutionComplete,
+    #[serde(rename = "skill.invoked")]
+    SkillInvoked,
+    #[serde(rename = "subagent.started")]
+    SubagentStarted,
+    #[serde(rename = "subagent.completed")]
+    SubagentCompleted,
+    #[serde(rename = "subagent.failed")]
+    SubagentFailed,
+    #[serde(rename = "subagent.selected")]
+    SubagentSelected,
+    #[serde(rename = "subagent.deselected")]
+    SubagentDeselected,
+    #[serde(rename = "hook.start")]
+    HookStart,
+    #[serde(rename = "hook.end")]
+    HookEnd,
+    #[serde(rename = "system.message")]
+    SystemMessage,
+    #[serde(rename = "system.notification")]
+    SystemNotification,
+    #[serde(rename = "permission.requested")]
+    PermissionRequested,
+    #[serde(rename = "permission.completed")]
+    PermissionCompleted,
+    #[serde(rename = "user_input.requested")]
+    UserInputRequested,
+    #[serde(rename = "user_input.completed")]
+    UserInputCompleted,
+    #[serde(rename = "elicitation.requested")]
+    ElicitationRequested,
+    #[serde(rename = "elicitation.completed")]
+    ElicitationCompleted,
+    #[serde(rename = "sampling.requested")]
+    SamplingRequested,
+    #[serde(rename = "sampling.completed")]
+    SamplingCompleted,
+    #[serde(rename = "mcp.oauth_required")]
+    McpOauthRequired,
+    #[serde(rename = "mcp.oauth_completed")]
+    McpOauthCompleted,
+    #[serde(rename = "external_tool.requested")]
+    ExternalToolRequested,
+    #[serde(rename = "external_tool.completed")]
+    ExternalToolCompleted,
+    #[serde(rename = "command.queued")]
+    CommandQueued,
+    #[serde(rename = "command.execute")]
+    CommandExecute,
+    #[serde(rename = "command.completed")]
+    CommandCompleted,
+    #[serde(rename = "auto_mode_switch.requested")]
+    AutoModeSwitchRequested,
+    #[serde(rename = "auto_mode_switch.completed")]
+    AutoModeSwitchCompleted,
+    #[serde(rename = "commands.changed")]
+    CommandsChanged,
+    #[serde(rename = "capabilities.changed")]
+    CapabilitiesChanged,
+    #[serde(rename = "exit_plan_mode.requested")]
+    ExitPlanModeRequested,
+    #[serde(rename = "exit_plan_mode.completed")]
+    ExitPlanModeCompleted,
+    #[serde(rename = "session.tools_updated")]
+    SessionToolsUpdated,
+    #[serde(rename = "session.background_tasks_changed")]
+    SessionBackgroundTasksChanged,
+    #[serde(rename = "session.skills_loaded")]
+    SessionSkillsLoaded,
+    #[serde(rename = "session.custom_agents_updated")]
+    SessionCustomAgentsUpdated,
+    #[serde(rename = "session.mcp_servers_loaded")]
+    SessionMcpServersLoaded,
+    #[serde(rename = "session.mcp_server_status_changed")]
+    SessionMcpServerStatusChanged,
+    #[serde(rename = "session.extensions_loaded")]
+    SessionExtensionsLoaded,
+    /// Unknown event type for forward compatibility.
+    #[serde(other)]
+    Unknown,
+}
+
+/// Typed session event data, discriminated by the event `type` field.
+///
+/// Use with [`TypedSessionEvent`] for fully typed event handling.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", content = "data")]
+pub enum SessionEventData {
+    #[serde(rename = "session.start")]
+    SessionStart(SessionStartData),
+    #[serde(rename = "session.resume")]
+    SessionResume(SessionResumeData),
+    #[serde(rename = "session.remote_steerable_changed")]
+    SessionRemoteSteerableChanged(SessionRemoteSteerableChangedData),
+    #[serde(rename = "session.error")]
+    SessionError(SessionErrorData),
+    #[serde(rename = "session.idle")]
+    SessionIdle(SessionIdleData),
+    #[serde(rename = "session.title_changed")]
+    SessionTitleChanged(SessionTitleChangedData),
+    #[serde(rename = "session.info")]
+    SessionInfo(SessionInfoData),
+    #[serde(rename = "session.warning")]
+    SessionWarning(SessionWarningData),
+    #[serde(rename = "session.model_change")]
+    SessionModelChange(SessionModelChangeData),
+    #[serde(rename = "session.mode_changed")]
+    SessionModeChanged(SessionModeChangedData),
+    #[serde(rename = "session.plan_changed")]
+    SessionPlanChanged(SessionPlanChangedData),
+    #[serde(rename = "session.workspace_file_changed")]
+    SessionWorkspaceFileChanged(SessionWorkspaceFileChangedData),
+    #[serde(rename = "session.handoff")]
+    SessionHandoff(SessionHandoffData),
+    #[serde(rename = "session.truncation")]
+    SessionTruncation(SessionTruncationData),
+    #[serde(rename = "session.snapshot_rewind")]
+    SessionSnapshotRewind(SessionSnapshotRewindData),
+    #[serde(rename = "session.shutdown")]
+    SessionShutdown(SessionShutdownData),
+    #[serde(rename = "session.context_changed")]
+    SessionContextChanged(SessionContextChangedData),
+    #[serde(rename = "session.usage_info")]
+    SessionUsageInfo(SessionUsageInfoData),
+    #[serde(rename = "session.compaction_start")]
+    SessionCompactionStart(SessionCompactionStartData),
+    #[serde(rename = "session.compaction_complete")]
+    SessionCompactionComplete(SessionCompactionCompleteData),
+    #[serde(rename = "session.task_complete")]
+    SessionTaskComplete(SessionTaskCompleteData),
+    #[serde(rename = "user.message")]
+    UserMessage(UserMessageData),
+    #[serde(rename = "pending_messages.modified")]
+    PendingMessagesModified(PendingMessagesModifiedData),
+    #[serde(rename = "assistant.turn_start")]
+    AssistantTurnStart(AssistantTurnStartData),
+    #[serde(rename = "assistant.intent")]
+    AssistantIntent(AssistantIntentData),
+    #[serde(rename = "assistant.reasoning")]
+    AssistantReasoning(AssistantReasoningData),
+    #[serde(rename = "assistant.reasoning_delta")]
+    AssistantReasoningDelta(AssistantReasoningDeltaData),
+    #[serde(rename = "assistant.streaming_delta")]
+    AssistantStreamingDelta(AssistantStreamingDeltaData),
+    #[serde(rename = "assistant.message")]
+    AssistantMessage(AssistantMessageData),
+    #[serde(rename = "assistant.message_delta")]
+    AssistantMessageDelta(AssistantMessageDeltaData),
+    #[serde(rename = "assistant.turn_end")]
+    AssistantTurnEnd(AssistantTurnEndData),
+    #[serde(rename = "assistant.usage")]
+    AssistantUsage(AssistantUsageData),
+    #[serde(rename = "model.call_failure")]
+    ModelCallFailure(ModelCallFailureData),
+    #[serde(rename = "abort")]
+    Abort(AbortData),
+    #[serde(rename = "tool.user_requested")]
+    ToolUserRequested(ToolUserRequestedData),
+    #[serde(rename = "tool.execution_start")]
+    ToolExecutionStart(ToolExecutionStartData),
+    #[serde(rename = "tool.execution_partial_result")]
+    ToolExecutionPartialResult(ToolExecutionPartialResultData),
+    #[serde(rename = "tool.execution_progress")]
+    ToolExecutionProgress(ToolExecutionProgressData),
+    #[serde(rename = "tool.execution_complete")]
+    ToolExecutionComplete(ToolExecutionCompleteData),
+    #[serde(rename = "skill.invoked")]
+    SkillInvoked(SkillInvokedData),
+    #[serde(rename = "subagent.started")]
+    SubagentStarted(SubagentStartedData),
+    #[serde(rename = "subagent.completed")]
+    SubagentCompleted(SubagentCompletedData),
+    #[serde(rename = "subagent.failed")]
+    SubagentFailed(SubagentFailedData),
+    #[serde(rename = "subagent.selected")]
+    SubagentSelected(SubagentSelectedData),
+    #[serde(rename = "subagent.deselected")]
+    SubagentDeselected(SubagentDeselectedData),
+    #[serde(rename = "hook.start")]
+    HookStart(HookStartData),
+    #[serde(rename = "hook.end")]
+    HookEnd(HookEndData),
+    #[serde(rename = "system.message")]
+    SystemMessage(SystemMessageData),
+    #[serde(rename = "system.notification")]
+    SystemNotification(SystemNotificationData),
+    #[serde(rename = "permission.requested")]
+    PermissionRequested(PermissionRequestedData),
+    #[serde(rename = "permission.completed")]
+    PermissionCompleted(PermissionCompletedData),
+    #[serde(rename = "user_input.requested")]
+    UserInputRequested(UserInputRequestedData),
+    #[serde(rename = "user_input.completed")]
+    UserInputCompleted(UserInputCompletedData),
+    #[serde(rename = "elicitation.requested")]
+    ElicitationRequested(ElicitationRequestedData),
+    #[serde(rename = "elicitation.completed")]
+    ElicitationCompleted(ElicitationCompletedData),
+    #[serde(rename = "sampling.requested")]
+    SamplingRequested(SamplingRequestedData),
+    #[serde(rename = "sampling.completed")]
+    SamplingCompleted(SamplingCompletedData),
+    #[serde(rename = "mcp.oauth_required")]
+    McpOauthRequired(McpOauthRequiredData),
+    #[serde(rename = "mcp.oauth_completed")]
+    McpOauthCompleted(McpOauthCompletedData),
+    #[serde(rename = "external_tool.requested")]
+    ExternalToolRequested(ExternalToolRequestedData),
+    #[serde(rename = "external_tool.completed")]
+    ExternalToolCompleted(ExternalToolCompletedData),
+    #[serde(rename = "command.queued")]
+    CommandQueued(CommandQueuedData),
+    #[serde(rename = "command.execute")]
+    CommandExecute(CommandExecuteData),
+    #[serde(rename = "command.completed")]
+    CommandCompleted(CommandCompletedData),
+    #[serde(rename = "auto_mode_switch.requested")]
+    AutoModeSwitchRequested(AutoModeSwitchRequestedData),
+    #[serde(rename = "auto_mode_switch.completed")]
+    AutoModeSwitchCompleted(AutoModeSwitchCompletedData),
+    #[serde(rename = "commands.changed")]
+    CommandsChanged(CommandsChangedData),
+    #[serde(rename = "capabilities.changed")]
+    CapabilitiesChanged(CapabilitiesChangedData),
+    #[serde(rename = "exit_plan_mode.requested")]
+    ExitPlanModeRequested(ExitPlanModeRequestedData),
+    #[serde(rename = "exit_plan_mode.completed")]
+    ExitPlanModeCompleted(ExitPlanModeCompletedData),
+    #[serde(rename = "session.tools_updated")]
+    SessionToolsUpdated(SessionToolsUpdatedData),
+    #[serde(rename = "session.background_tasks_changed")]
+    SessionBackgroundTasksChanged(SessionBackgroundTasksChangedData),
+    #[serde(rename = "session.skills_loaded")]
+    SessionSkillsLoaded(SessionSkillsLoadedData),
+    #[serde(rename = "session.custom_agents_updated")]
+    SessionCustomAgentsUpdated(SessionCustomAgentsUpdatedData),
+    #[serde(rename = "session.mcp_servers_loaded")]
+    SessionMcpServersLoaded(SessionMcpServersLoadedData),
+    #[serde(rename = "session.mcp_server_status_changed")]
+    SessionMcpServerStatusChanged(SessionMcpServerStatusChangedData),
+    #[serde(rename = "session.extensions_loaded")]
+    SessionExtensionsLoaded(SessionExtensionsLoadedData),
+}
+
+/// A session event with typed data payload.
+///
+/// The common event fields (id, timestamp, parentId, ephemeral) are
+/// available directly. The event-specific data is in the `payload` field
+/// as a [`SessionEventData`] enum.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct TypedSessionEvent {
+    /// Unique event identifier (UUID v4).
+    pub id: String,
+    /// ISO 8601 timestamp when the event was created.
+    pub timestamp: String,
+    /// ID of the preceding event in the chain.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parent_id: Option<String>,
+    /// When true, the event is transient and not persisted.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub ephemeral: Option<bool>,
+    /// The typed event payload (discriminated by event type).
+    #[serde(flatten)]
+    pub payload: SessionEventData,
+}
+
+/// Working directory and git context at session start
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct WorkingDirectoryContext {
+    /// Base commit of current git branch at session start time
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub base_commit: Option<String>,
+    /// Current git branch name
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub branch: Option<String>,
+    /// Current working directory path
+    pub cwd: String,
+    /// Root directory of the git repository, resolved via git rev-parse
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub git_root: Option<String>,
+    /// Head commit of current git branch at session start time
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub head_commit: Option<String>,
+    /// Hosting platform type of the repository (github or ado)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub host_type: Option<String>,
+    /// Repository identifier derived from the git remote URL ("owner/name" for GitHub, "org/project/repo" for Azure DevOps)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub repository: Option<String>,
+    /// Raw host string from the git remote URL (e.g. "github.com", "mycompany.ghe.com", "dev.azure.com")
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub repository_host: Option<String>,
+}
+
+/// Session initialization metadata including context and configuration
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionStartData {
+    /// Whether the session was already in use by another client at start time
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub already_in_use: Option<bool>,
+    /// Working directory and git context at session start
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub context: Option<WorkingDirectoryContext>,
+    /// Version string of the Copilot application
+    pub copilot_version: String,
+    /// Identifier of the software producing the events (e.g., "copilot-agent")
+    pub producer: String,
+    /// Reasoning effort level used for model calls, if applicable (e.g. "low", "medium", "high", "xhigh")
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_effort: Option<String>,
+    /// Whether this session supports remote steering via Mission Control
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub remote_steerable: Option<bool>,
+    /// Model selected at session creation time, if any
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub selected_model: Option<String>,
+    /// Unique identifier for the session
+    pub session_id: SessionId,
+    /// ISO 8601 timestamp when the session was created
+    pub start_time: String,
+    /// Schema version number for the session event format
+    pub version: f64,
+}
+
+/// Session resume metadata including current context and event count
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionResumeData {
+    /// Whether the session was already in use by another client at resume time
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub already_in_use: Option<bool>,
+    /// Updated working directory and git context at resume time
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub context: Option<WorkingDirectoryContext>,
+    /// Total number of persisted events in the session at the time of resume
+    pub event_count: f64,
+    /// Reasoning effort level used for model calls, if applicable (e.g. "low", "medium", "high", "xhigh")
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_effort: Option<String>,
+    /// Whether this session supports remote steering via Mission Control
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub remote_steerable: Option<bool>,
+    /// ISO 8601 timestamp when the session was resumed
+    pub resume_time: String,
+    /// Model currently selected at resume time
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub selected_model: Option<String>,
+}
+
+/// Notifies Mission Control that the session's remote steering capability has changed
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionRemoteSteerableChangedData {
+    /// Whether this session now supports remote steering via Mission Control
+    pub remote_steerable: bool,
+}
+
+/// Error details for timeline display including message and optional diagnostic information
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionErrorData {
+    /// Only set on `errorType: "rate_limit"`. When `true`, the runtime will follow this error with an `auto_mode_switch.requested` event (or silently switch if `continueOnAutoMode` is enabled). UI clients can use this flag to suppress duplicate rendering of the rate-limit error when they show their own auto-mode-switch prompt.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub eligible_for_auto_switch: Option<bool>,
+    /// Fine-grained error code from the upstream provider, when available. For `errorType: "rate_limit"`, this is one of the `RateLimitErrorCode` values (e.g., `"user_weekly_rate_limited"`, `"user_global_rate_limited"`, `"rate_limited"`, `"user_model_rate_limited"`, `"integration_rate_limited"`).
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error_code: Option<String>,
+    /// Category of error (e.g., "authentication", "authorization", "quota", "rate_limit", "context_limit", "query")
+    pub error_type: String,
+    /// Human-readable error message
+    pub message: String,
+    /// GitHub request tracing ID (x-github-request-id header) for correlating with server-side logs
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub provider_call_id: Option<String>,
+    /// Error stack trace, when available
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stack: Option<String>,
+    /// HTTP status code from the upstream request, if applicable
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub status_code: Option<f64>,
+    /// Optional URL associated with this error that the user can open in a browser
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub url: Option<String>,
+}
+
+/// Payload indicating the session is idle with no background agents in flight
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionIdleData {
+    /// True when the preceding agentic loop was cancelled via abort signal
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub aborted: Option<bool>,
+}
+
+/// Session title change payload containing the new display title
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionTitleChangedData {
+    /// The new display title for the session
+    pub title: String,
+}
+
+/// Informational message for timeline display with categorization
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionInfoData {
+    /// Category of informational message (e.g., "notification", "timing", "context_window", "mcp", "snapshot", "configuration", "authentication", "model")
+    pub info_type: String,
+    /// Human-readable informational message for display in the timeline
+    pub message: String,
+    /// Optional actionable tip displayed with this message
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tip: Option<String>,
+    /// Optional URL associated with this message that the user can open in a browser
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub url: Option<String>,
+}
+
+/// Warning message for timeline display with categorization
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionWarningData {
+    /// Human-readable warning message for display in the timeline
+    pub message: String,
+    /// Optional URL associated with this warning that the user can open in a browser
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub url: Option<String>,
+    /// Category of warning (e.g., "subscription", "policy", "mcp")
+    pub warning_type: String,
+}
+
+/// Model change details including previous and new model identifiers
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionModelChangeData {
+    /// Reason the change happened, when not user-initiated. Currently `"rate_limit_auto_switch"` for changes triggered by the auto-mode-switch rate-limit recovery path. UI clients can use this to render contextual copy.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cause: Option<String>,
+    /// Newly selected model identifier
+    pub new_model: String,
+    /// Model that was previously selected, if any
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub previous_model: Option<String>,
+    /// Reasoning effort level before the model change, if applicable
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub previous_reasoning_effort: Option<String>,
+    /// Reasoning effort level after the model change, if applicable
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_effort: Option<String>,
+}
+
+/// Agent mode change details including previous and new modes
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionModeChangedData {
+    /// Agent mode after the change (e.g., "interactive", "plan", "autopilot")
+    pub new_mode: String,
+    /// Agent mode before the change (e.g., "interactive", "plan", "autopilot")
+    pub previous_mode: String,
+}
+
+/// Plan file operation details indicating what changed
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionPlanChangedData {
+    /// The type of operation performed on the plan file
+    pub operation: PlanChangedOperation,
+}
+
+/// Workspace file change details including path and operation type
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionWorkspaceFileChangedData {
+    /// Whether the file was newly created or updated
+    pub operation: WorkspaceFileChangedOperation,
+    /// Relative path within the session workspace files directory
+    pub path: String,
+}
+
+/// Repository context for the handed-off session
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct HandoffRepository {
+    /// Git branch name, if applicable
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub branch: Option<String>,
+    /// Repository name
+    pub name: String,
+    /// Repository owner (user or organization)
+    pub owner: String,
+}
+
+/// Session handoff metadata including source, context, and repository information
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionHandoffData {
+    /// Additional context information for the handoff
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub context: Option<String>,
+    /// ISO 8601 timestamp when the handoff occurred
+    pub handoff_time: String,
+    /// GitHub host URL for the source session (e.g., https://github.com or https://tenant.ghe.com)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub host: Option<String>,
+    /// Session ID of the remote session being handed off
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub remote_session_id: Option<String>,
+    /// Repository context for the handed-off session
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub repository: Option<HandoffRepository>,
+    /// Origin type of the session being handed off
+    pub source_type: HandoffSourceType,
+    /// Summary of the work done in the source session
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub summary: Option<String>,
+}
+
+/// Conversation truncation statistics including token counts and removed content metrics
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionTruncationData {
+    /// Number of messages removed by truncation
+    pub messages_removed_during_truncation: f64,
+    /// Identifier of the component that performed truncation (e.g., "BasicTruncator")
+    pub performed_by: String,
+    /// Number of conversation messages after truncation
+    pub post_truncation_messages_length: f64,
+    /// Total tokens in conversation messages after truncation
+    pub post_truncation_tokens_in_messages: f64,
+    /// Number of conversation messages before truncation
+    pub pre_truncation_messages_length: f64,
+    /// Total tokens in conversation messages before truncation
+    pub pre_truncation_tokens_in_messages: f64,
+    /// Maximum token count for the model's context window
+    pub token_limit: f64,
+    /// Number of tokens removed by truncation
+    pub tokens_removed_during_truncation: f64,
+}
+
+/// Session rewind details including target event and count of removed events
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionSnapshotRewindData {
+    /// Number of events that were removed by the rewind
+    pub events_removed: f64,
+    /// Event ID that was rewound to; this event and all after it were removed
+    pub up_to_event_id: String,
+}
+
+/// Aggregate code change metrics for the session
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ShutdownCodeChanges {
+    /// List of file paths that were modified during the session
+    pub files_modified: Vec<String>,
+    /// Total number of lines added during the session
+    pub lines_added: f64,
+    /// Total number of lines removed during the session
+    pub lines_removed: f64,
+}
+
+/// Request count and cost metrics
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ShutdownModelMetricRequests {
+    /// Cumulative cost multiplier for requests to this model
+    pub cost: f64,
+    /// Total number of API requests made to this model
+    pub count: f64,
+}
+
+/// Token usage breakdown
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ShutdownModelMetricUsage {
+    /// Total tokens read from prompt cache across all requests
+    pub cache_read_tokens: f64,
+    /// Total tokens written to prompt cache across all requests
+    pub cache_write_tokens: f64,
+    /// Total input tokens consumed across all requests to this model
+    pub input_tokens: f64,
+    /// Total output tokens produced across all requests to this model
+    pub output_tokens: f64,
+    /// Total reasoning tokens produced across all requests to this model
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reasoning_tokens: Option<f64>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ShutdownModelMetric {
+    /// Request count and cost metrics
+    pub requests: ShutdownModelMetricRequests,
+    /// Token usage breakdown
+    pub usage: ShutdownModelMetricUsage,
+}
+
+/// Session termination metrics including usage statistics, code changes, and shutdown reason
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionShutdownData {
+    /// Aggregate code change metrics for the session
+    pub code_changes: ShutdownCodeChanges,
+    /// Non-system message token count at shutdown
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub conversation_tokens: Option<f64>,
+    /// Model that was selected at the time of shutdown
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub current_model: Option<String>,
+    /// Total tokens in context window at shutdown
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub current_tokens: Option<f64>,
+    /// Error description when shutdownType is "error"
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error_reason: Option<String>,
+    /// Per-model usage breakdown, keyed by model identifier
+    pub model_metrics: HashMap<String, ShutdownModelMetric>,
+    /// Unix timestamp (milliseconds) when the session started
+    pub session_start_time: f64,
+    /// Whether the session ended normally ("routine") or due to a crash/fatal error ("error")
+    pub shutdown_type: ShutdownType,
+    /// System message token count at shutdown
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_tokens: Option<f64>,
+    /// Tool definitions token count at shutdown
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_definitions_tokens: Option<f64>,
+    /// Cumulative time spent in API calls during the session, in milliseconds
+    pub total_api_duration_ms: f64,
+    /// Total number of premium API requests used during the session
+    pub total_premium_requests: f64,
+}
+
+/// Working directory and git context at session start
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionContextChangedData {
+    /// Base commit of current git branch at session start time
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub base_commit: Option<String>,
+    /// Current git branch name
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub branch: Option<String>,
+    /// Current working directory path
+    pub cwd: String,
+    /// Root directory of the git repository, resolved via git rev-parse
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub git_root: Option<String>,
+    /// Head commit of current git branch at session start time
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub head_commit: Option<String>,
+    /// Hosting platform type of the repository (github or ado)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub host_type: Option<String>,
+    /// Repository identifier derived from the git remote URL ("owner/name" for GitHub, "org/project/repo" for Azure DevOps)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub repository: Option<String>,
+    /// Raw host string from the git remote URL (e.g. "github.com", "mycompany.ghe.com", "dev.azure.com")
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub repository_host: Option<String>,
+}
+
+/// Current context window usage statistics including token and message counts
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionUsageInfoData {
+    /// Token count from non-system messages (user, assistant, tool)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub conversation_tokens: Option<f64>,
+    /// Current number of tokens in the context window
+    pub current_tokens: f64,
+    /// Whether this is the first usage_info event emitted in this session
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub is_initial: Option<bool>,
+    /// Current number of messages in the conversation
+    pub messages_length: f64,
+    /// Token count from system message(s)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_tokens: Option<f64>,
+    /// Maximum token count for the model's context window
+    pub token_limit: f64,
+    /// Token count from tool definitions
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_definitions_tokens: Option<f64>,
+}
+
+/// Context window breakdown at the start of LLM-powered conversation compaction
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionCompactionStartData {
+    /// Token count from non-system messages (user, assistant, tool) at compaction start
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub conversation_tokens: Option<f64>,
+    /// Token count from system message(s) at compaction start
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_tokens: Option<f64>,
+    /// Token count from tool definitions at compaction start
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_definitions_tokens: Option<f64>,
+}
+
+/// Token usage detail for a single billing category
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct
CompactionCompleteCompactionTokensUsedCopilotUsageTokenDetail { + /// Number of tokens in this billing batch + pub batch_size: f64, + /// Cost per batch of tokens + pub cost_per_batch: f64, + /// Total token count for this entry + pub token_count: f64, + /// Token category (e.g., "input", "output") + pub token_type: String, +} + +/// Per-request cost and usage data from the CAPI copilot_usage response field +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CompactionCompleteCompactionTokensUsedCopilotUsage { + /// Itemized token usage breakdown + pub token_details: Vec, + /// Total cost in nano-AIU (AI Units) for this request + pub total_nano_aiu: f64, +} + +/// Token usage breakdown for the compaction LLM call (aligned with assistant.usage format) +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CompactionCompleteCompactionTokensUsed { + /// Cached input tokens reused in the compaction LLM call + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_read_tokens: Option, + /// Tokens written to prompt cache in the compaction LLM call + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_write_tokens: Option, + /// Per-request cost and usage data from the CAPI copilot_usage response field + #[serde(skip_serializing_if = "Option::is_none")] + pub copilot_usage: Option, + /// Duration of the compaction LLM call in milliseconds + #[serde(skip_serializing_if = "Option::is_none")] + pub duration: Option, + /// Input tokens consumed by the compaction LLM call + #[serde(skip_serializing_if = "Option::is_none")] + pub input_tokens: Option, + /// Model identifier used for the compaction LLM call + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// Output tokens produced by the compaction LLM call + #[serde(skip_serializing_if = "Option::is_none")] + pub output_tokens: Option, +} + +/// Conversation compaction results including success 
status, metrics, and optional error details +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionCompactionCompleteData { + /// Checkpoint snapshot number created for recovery + #[serde(skip_serializing_if = "Option::is_none")] + pub checkpoint_number: Option, + /// File path where the checkpoint was stored + #[serde(skip_serializing_if = "Option::is_none")] + pub checkpoint_path: Option, + /// Token usage breakdown for the compaction LLM call (aligned with assistant.usage format) + #[serde(skip_serializing_if = "Option::is_none")] + pub compaction_tokens_used: Option, + /// Token count from non-system messages (user, assistant, tool) after compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub conversation_tokens: Option, + /// Error message if compaction failed + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// Number of messages removed during compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub messages_removed: Option, + /// Total tokens in conversation after compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub post_compaction_tokens: Option, + /// Number of messages before compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub pre_compaction_messages_length: Option, + /// Total tokens in conversation before compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub pre_compaction_tokens: Option, + /// GitHub request tracing ID (x-github-request-id header) for the compaction LLM call + #[serde(skip_serializing_if = "Option::is_none")] + pub request_id: Option, + /// Whether compaction completed successfully + pub success: bool, + /// LLM-generated summary of the compacted conversation history + #[serde(skip_serializing_if = "Option::is_none")] + pub summary_content: Option, + /// Token count from system message(s) after compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub system_tokens: Option, + 
/// Number of tokens removed during compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub tokens_removed: Option, + /// Token count from tool definitions after compaction + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_definitions_tokens: Option, +} + +/// Task completion notification with summary from the agent +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SessionTaskCompleteData { + /// Whether the tool call succeeded. False when validation failed (e.g., invalid arguments) + #[serde(skip_serializing_if = "Option::is_none")] + pub success: Option, + /// Summary of the completed task, provided by the agent + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UserMessageData { + /// The agent mode that was active when this message was sent + #[serde(skip_serializing_if = "Option::is_none")] + pub agent_mode: Option, + /// Files, selections, or GitHub references attached to the message + #[serde(default)] + pub attachments: Vec, + /// The user's message text as displayed in the timeline + pub content: String, + /// CAPI interaction ID for correlating this user message with its turn + #[serde(skip_serializing_if = "Option::is_none")] + pub interaction_id: Option, + /// Path-backed native document attachments that stayed on the tagged_files path flow because native upload would exceed the request size limit + #[serde(default)] + pub native_document_path_fallback_paths: Vec, + /// Origin of this message, used for timeline filtering (e.g., "skill-pdf" for skill-injected messages that should be hidden from the user) + #[serde(skip_serializing_if = "Option::is_none")] + pub source: Option, + /// Normalized document MIME types that were sent natively instead of through tagged_files XML + #[serde(default)] + pub supported_native_document_mime_types: Vec, + /// Transformed 
version of the message sent to the model, with XML wrapping, timestamps, and other augmentations for prompt caching + #[serde(skip_serializing_if = "Option::is_none")] + pub transformed_content: Option, +} + +/// Empty payload; the event signals that the pending message queue has changed +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PendingMessagesModifiedData {} + +/// Turn initialization metadata including identifier and interaction tracking +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantTurnStartData { + /// CAPI interaction ID for correlating this turn with upstream telemetry + #[serde(skip_serializing_if = "Option::is_none")] + pub interaction_id: Option, + /// Identifier for this turn within the agentic loop, typically a stringified turn number + pub turn_id: String, +} + +/// Agent intent description for current activity or plan +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantIntentData { + /// Short description of what the agent is currently doing or planning to do + pub intent: String, +} + +/// Assistant reasoning content for timeline display with complete thinking text +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantReasoningData { + /// The complete extended thinking text from the model + pub content: String, + /// Unique identifier for this reasoning block + pub reasoning_id: String, +} + +/// Streaming reasoning delta for incremental extended thinking updates +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantReasoningDeltaData { + /// Incremental text chunk to append to the reasoning content + pub delta_content: String, + /// Reasoning block ID this delta belongs to, matching the corresponding assistant.reasoning event + pub reasoning_id: String, +} + +/// Streaming response 
progress with cumulative byte count +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantStreamingDeltaData { + /// Cumulative total bytes received from the streaming response so far + pub total_response_size_bytes: f64, +} + +/// A tool invocation request from the assistant +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantMessageToolRequest { + /// Arguments to pass to the tool, format depends on the tool + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, + /// Resolved intention summary describing what this specific call does + #[serde(skip_serializing_if = "Option::is_none")] + pub intention_summary: Option, + /// Name of the MCP server hosting this tool, when the tool is an MCP tool + #[serde(skip_serializing_if = "Option::is_none")] + pub mcp_server_name: Option, + /// Name of the tool being invoked + pub name: String, + /// Unique identifier for this tool call + pub tool_call_id: String, + /// Human-readable display title for the tool + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_title: Option, + /// Tool call type: "function" for standard tool calls, "custom" for grammar-based tool calls. Defaults to "function" when absent. + #[serde(skip_serializing_if = "Option::is_none")] + pub r#type: Option, +} + +/// Assistant response containing text content, optional tool requests, and interaction metadata +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantMessageData { + /// The assistant's text response content + pub content: String, + /// Encrypted reasoning content from OpenAI models. Session-bound and stripped on resume. 
+ #[serde(skip_serializing_if = "Option::is_none")] + pub encrypted_content: Option, + /// CAPI interaction ID for correlating this message with upstream telemetry + #[serde(skip_serializing_if = "Option::is_none")] + pub interaction_id: Option, + /// Unique identifier for this assistant message + pub message_id: String, + /// Actual output token count from the API response (completion_tokens), used for accurate token accounting + #[serde(skip_serializing_if = "Option::is_none")] + pub output_tokens: Option, + /// Tool call ID of the parent tool invocation when this event originates from a sub-agent + #[deprecated] + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_tool_call_id: Option, + /// Generation phase for phased-output models (e.g., thinking vs. response phases) + #[serde(skip_serializing_if = "Option::is_none")] + pub phase: Option, + /// Opaque/encrypted extended thinking data from Anthropic models. Session-bound and stripped on resume. + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_opaque: Option, + /// Readable reasoning text from the model's extended thinking + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_text: Option, + /// GitHub request tracing ID (x-github-request-id header) for correlating with server-side logs + #[serde(skip_serializing_if = "Option::is_none")] + pub request_id: Option, + /// Tool invocations requested by the assistant in this message + #[serde(default)] + pub tool_requests: Vec, +} + +/// Streaming assistant message delta for incremental response updates +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantMessageDeltaData { + /// Incremental text chunk to append to the message content + pub delta_content: String, + /// Message ID this delta belongs to, matching the corresponding assistant.message event + pub message_id: String, + /// Tool call ID of the parent tool invocation when this event originates from a sub-agent + 
#[deprecated] + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_tool_call_id: Option, +} + +/// Turn completion metadata including the turn identifier +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantTurnEndData { + /// Identifier of the turn that has ended, matching the corresponding assistant.turn_start event + pub turn_id: String, +} + +/// Token usage detail for a single billing category +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantUsageCopilotUsageTokenDetail { + /// Number of tokens in this billing batch + pub batch_size: f64, + /// Cost per batch of tokens + pub cost_per_batch: f64, + /// Total token count for this entry + pub token_count: f64, + /// Token category (e.g., "input", "output") + pub token_type: String, +} + +/// Per-request cost and usage data from the CAPI copilot_usage response field +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantUsageCopilotUsage { + /// Itemized token usage breakdown + pub token_details: Vec, + /// Total cost in nano-AIU (AI Units) for this request + pub total_nano_aiu: f64, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantUsageQuotaSnapshot { + /// Total requests allowed by the entitlement + pub entitlement_requests: f64, + /// Whether the user has an unlimited usage entitlement + pub is_unlimited_entitlement: bool, + /// Number of requests over the entitlement limit + pub overage: f64, + /// Whether overage is allowed when quota is exhausted + pub overage_allowed_with_exhausted_quota: bool, + /// Percentage of quota remaining (0.0 to 1.0) + pub remaining_percentage: f64, + /// Date when the quota resets + #[serde(skip_serializing_if = "Option::is_none")] + pub reset_date: Option, + /// Whether usage is still permitted after quota exhaustion + pub 
usage_allowed_with_exhausted_quota: bool, + /// Number of requests already consumed + pub used_requests: f64, +} + +/// LLM API call usage metrics including tokens, costs, quotas, and billing information +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AssistantUsageData { + /// Completion ID from the model provider (e.g., chatcmpl-abc123) + #[serde(skip_serializing_if = "Option::is_none")] + pub api_call_id: Option, + /// Number of tokens read from prompt cache + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_read_tokens: Option, + /// Number of tokens written to prompt cache + #[serde(skip_serializing_if = "Option::is_none")] + pub cache_write_tokens: Option, + /// Per-request cost and usage data from the CAPI copilot_usage response field + #[serde(skip_serializing_if = "Option::is_none")] + pub copilot_usage: Option, + /// Model multiplier cost for billing purposes + #[serde(skip_serializing_if = "Option::is_none")] + pub cost: Option, + /// Duration of the API call in milliseconds + #[serde(skip_serializing_if = "Option::is_none")] + pub duration: Option, + /// What initiated this API call (e.g., "sub-agent", "mcp-sampling"); absent for user-initiated calls + #[serde(skip_serializing_if = "Option::is_none")] + pub initiator: Option, + /// Number of input tokens consumed + #[serde(skip_serializing_if = "Option::is_none")] + pub input_tokens: Option, + /// Average inter-token latency in milliseconds. 
Only available for streaming requests + #[serde(skip_serializing_if = "Option::is_none")] + pub inter_token_latency_ms: Option, + /// Model identifier used for this API call + pub model: String, + /// Number of output tokens produced + #[serde(skip_serializing_if = "Option::is_none")] + pub output_tokens: Option, + /// Parent tool call ID when this usage originates from a sub-agent + #[deprecated] + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_tool_call_id: Option, + /// GitHub request tracing ID (x-github-request-id header) for server-side log correlation + #[serde(skip_serializing_if = "Option::is_none")] + pub provider_call_id: Option, + /// Per-quota resource usage snapshots, keyed by quota identifier + #[serde(default)] + pub quota_snapshots: HashMap, + /// Reasoning effort level used for model calls, if applicable (e.g. "low", "medium", "high", "xhigh") + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, + /// Number of output tokens used for reasoning (e.g., chain-of-thought) + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_tokens: Option, + /// Time to first token in milliseconds. 
Only available for streaming requests + #[serde(skip_serializing_if = "Option::is_none")] + pub ttft_ms: Option, +} + +/// Failed LLM API call metadata for telemetry +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ModelCallFailureData { + /// Completion ID from the model provider (e.g., chatcmpl-abc123) + #[serde(skip_serializing_if = "Option::is_none")] + pub api_call_id: Option, + /// Duration of the failed API call in milliseconds + #[serde(skip_serializing_if = "Option::is_none")] + pub duration_ms: Option, + /// Raw provider/runtime error message for restricted telemetry + #[serde(skip_serializing_if = "Option::is_none")] + pub error_message: Option, + /// What initiated this API call (e.g., "sub-agent", "mcp-sampling"); absent for user-initiated calls + #[serde(skip_serializing_if = "Option::is_none")] + pub initiator: Option, + /// Model identifier used for the failed API call + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// GitHub request tracing ID (x-github-request-id header) for server-side log correlation + #[serde(skip_serializing_if = "Option::is_none")] + pub provider_call_id: Option, + /// Where the failed model call originated + pub source: ModelCallFailureSource, + /// HTTP status code from the failed request + #[serde(skip_serializing_if = "Option::is_none")] + pub status_code: Option, +} + +/// Turn abort information including the reason for termination +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AbortData { + /// Reason the current turn was aborted (e.g., "user initiated") + pub reason: String, +} + +/// User-initiated tool invocation request with tool name and arguments +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolUserRequestedData { + /// Arguments for the tool invocation + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, + /// 
Unique identifier for this tool call + pub tool_call_id: String, + /// Name of the tool the user wants to invoke + pub tool_name: String, +} + +/// Tool execution startup details including MCP server information when applicable +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolExecutionStartData { + /// Arguments passed to the tool + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, + /// Name of the MCP server hosting this tool, when the tool is an MCP tool + #[serde(skip_serializing_if = "Option::is_none")] + pub mcp_server_name: Option, + /// Original tool name on the MCP server, when the tool is an MCP tool + #[serde(skip_serializing_if = "Option::is_none")] + pub mcp_tool_name: Option, + /// Tool call ID of the parent tool invocation when this event originates from a sub-agent + #[deprecated] + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_tool_call_id: Option, + /// Unique identifier for this tool call + pub tool_call_id: String, + /// Name of the tool being executed + pub tool_name: String, +} + +/// Streaming tool execution output for incremental result display +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolExecutionPartialResultData { + /// Incremental output chunk from the running tool + pub partial_output: String, + /// Tool call ID this partial result belongs to + pub tool_call_id: String, +} + +/// Tool execution progress notification with status message +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolExecutionProgressData { + /// Human-readable progress status message (e.g., from an MCP server) + pub progress_message: String, + /// Tool call ID this progress notification belongs to + pub tool_call_id: String, +} + +/// Error details when the tool execution failed +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub 
struct ToolExecutionCompleteError { + /// Machine-readable error code + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, + /// Human-readable error message + pub message: String, +} + +/// Tool execution result on success +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolExecutionCompleteResult { + /// Concise tool result text sent to the LLM for chat completion, potentially truncated for token efficiency + pub content: String, + /// Structured content blocks (text, images, audio, resources) returned by the tool in their native format + #[serde(default)] + pub contents: Vec, + /// Full detailed tool result for UI/timeline display, preserving complete content such as diffs. Falls back to content when absent. + #[serde(skip_serializing_if = "Option::is_none")] + pub detailed_content: Option, +} + +/// Tool execution completion results including success status, detailed output, and error information +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolExecutionCompleteData { + /// Error details when the tool execution failed + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + /// CAPI interaction ID for correlating this tool execution with upstream telemetry + #[serde(skip_serializing_if = "Option::is_none")] + pub interaction_id: Option, + /// Whether this tool call was explicitly requested by the user rather than the assistant + #[serde(skip_serializing_if = "Option::is_none")] + pub is_user_requested: Option, + /// Model identifier that generated this tool call + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// Tool call ID of the parent tool invocation when this event originates from a sub-agent + #[deprecated] + #[serde(skip_serializing_if = "Option::is_none")] + pub parent_tool_call_id: Option, + /// Tool execution result on success + #[serde(skip_serializing_if = "Option::is_none")] + pub 
result: Option, + /// Whether the tool execution completed successfully + pub success: bool, + /// Unique identifier for the completed tool call + pub tool_call_id: String, + /// Tool-specific telemetry data (e.g., CodeQL check counts, grep match counts) + #[serde(default)] + pub tool_telemetry: HashMap, +} + +/// Skill invocation details including content, allowed tools, and plugin metadata +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SkillInvokedData { + /// Tool names that should be auto-approved when this skill is active + #[serde(default)] + pub allowed_tools: Vec, + /// Full content of the skill file, injected into the conversation for the model + pub content: String, + /// Description of the skill from its SKILL.md frontmatter + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Name of the invoked skill + pub name: String, + /// File path to the SKILL.md definition + pub path: String, + /// Name of the plugin this skill originated from, when applicable + #[serde(skip_serializing_if = "Option::is_none")] + pub plugin_name: Option, + /// Version of the plugin this skill originated from, when applicable + #[serde(skip_serializing_if = "Option::is_none")] + pub plugin_version: Option, +} + +/// Sub-agent startup details including parent tool call and agent information +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SubagentStartedData { + /// Description of what the sub-agent does + pub agent_description: String, + /// Human-readable display name of the sub-agent + pub agent_display_name: String, + /// Internal name of the sub-agent + pub agent_name: String, + /// Tool call ID of the parent tool invocation that spawned this sub-agent + pub tool_call_id: String, +} + +/// Sub-agent completion details for successful execution +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct 
SubagentCompletedData { + /// Human-readable display name of the sub-agent + pub agent_display_name: String, + /// Internal name of the sub-agent + pub agent_name: String, + /// Wall-clock duration of the sub-agent execution in milliseconds + #[serde(skip_serializing_if = "Option::is_none")] + pub duration_ms: Option, + /// Model used by the sub-agent + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// Tool call ID of the parent tool invocation that spawned this sub-agent + pub tool_call_id: String, + /// Total tokens (input + output) consumed by the sub-agent + #[serde(skip_serializing_if = "Option::is_none")] + pub total_tokens: Option, + /// Total number of tool calls made by the sub-agent + #[serde(skip_serializing_if = "Option::is_none")] + pub total_tool_calls: Option, +} + +/// Sub-agent failure details including error message and agent information +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SubagentFailedData { + /// Human-readable display name of the sub-agent + pub agent_display_name: String, + /// Internal name of the sub-agent + pub agent_name: String, + /// Wall-clock duration of the sub-agent execution in milliseconds + #[serde(skip_serializing_if = "Option::is_none")] + pub duration_ms: Option, + /// Error message describing why the sub-agent failed + pub error: String, + /// Model used by the sub-agent (if any model calls succeeded before failure) + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// Tool call ID of the parent tool invocation that spawned this sub-agent + pub tool_call_id: String, + /// Total tokens (input + output) consumed before the sub-agent failed + #[serde(skip_serializing_if = "Option::is_none")] + pub total_tokens: Option, + /// Total number of tool calls made before the sub-agent failed + #[serde(skip_serializing_if = "Option::is_none")] + pub total_tool_calls: Option, +} + +/// Custom agent selection details including 
name and available tools
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SubagentSelectedData {
+    /// Human-readable display name of the selected custom agent
+    pub agent_display_name: String,
+    /// Internal name of the selected custom agent
+    pub agent_name: String,
+    /// List of tool names available to this agent, or null for all tools
+    pub tools: Option<Vec<String>>,
+}
+
+/// Empty payload; the event signals that the custom agent was deselected, returning to the default agent
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SubagentDeselectedData {}
+
+/// Hook invocation start details including type and input data
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct HookStartData {
+    /// Unique identifier for this hook invocation
+    pub hook_invocation_id: String,
+    /// Type of hook being invoked (e.g., "preToolUse", "postToolUse", "sessionStart")
+    pub hook_type: String,
+    /// Input data passed to the hook
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub input: Option<serde_json::Value>,
+}
+
+/// Error details when the hook failed
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct HookEndError {
+    /// Human-readable error message
+    pub message: String,
+    /// Error stack trace, when available
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub stack: Option<String>,
+}
+
+/// Hook invocation completion details including output, success status, and error information
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct HookEndData {
+    /// Error details when the hook failed
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<HookEndError>,
+    /// Identifier matching the corresponding hook.start event
+    pub hook_invocation_id: String,
+    /// Type of hook that was invoked (e.g., "preToolUse", "postToolUse", "sessionStart")
+    pub hook_type: String,
+    /// Output data produced by the hook
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub output: Option<serde_json::Value>,
+    /// Whether the hook completed successfully
+    pub success: bool,
+}
+
+/// Metadata about the prompt template and its construction
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SystemMessageMetadata {
+    /// Version identifier of the prompt template used
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prompt_version: Option<String>,
+    /// Template variables used when constructing the prompt
+    #[serde(default)]
+    pub variables: HashMap<String, serde_json::Value>,
+}
+
+/// System/developer instruction content with role and optional template metadata
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SystemMessageData {
+    /// The system or developer prompt text sent as model input
+    pub content: String,
+    /// Metadata about the prompt template and its construction
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<SystemMessageMetadata>,
+    /// Optional name identifier for the message source
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub name: Option<String>,
+    /// Message role: "system" for system prompts, "developer" for developer-injected instructions
+    pub role: SystemMessageRole,
+}
+
+/// System-generated notification for runtime events like background task completion
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SystemNotificationData {
+    /// The notification text, typically wrapped in XML tags
+    pub content: String,
+    /// Structured metadata identifying what triggered this notification
+    pub kind: serde_json::Value,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionRequestShellCommand {
+    /// Command identifier (e.g., executable name)
+    pub identifier: String,
+    /// Whether this command is read-only (no side effects)
+    pub read_only: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionRequestShellPossibleUrl {
+    /// URL that may be accessed by the command
+    pub url: String,
+}
+
+/// Shell command permission request
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionRequestShell {
+    /// Whether the UI can offer session-wide approval for this command pattern
+    pub can_offer_session_approval: bool,
+    /// Parsed command identifiers found in the command text
+    pub commands: Vec<PermissionRequestShellCommand>,
+    /// The complete shell command text to be executed
+    pub full_command_text: String,
+    /// Whether the command includes a file write redirection (e.g., > or >>)
+    pub has_write_file_redirection: bool,
+    /// Human-readable description of what the command intends to do
+    pub intention: String,
+    /// Permission kind discriminator
+    pub kind: PermissionRequestShellKind,
+    /// File paths that may be read or written by the command
+    pub possible_paths: Vec<String>,
+    /// URLs that may be accessed by the command
+    pub possible_urls: Vec<PermissionRequestShellPossibleUrl>,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+    /// Optional warning message about risks of running this command
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub warning: Option<String>,
+}
+
+/// File write permission request
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionRequestWrite {
+    /// Whether the UI can offer session-wide approval for file write operations
+    pub can_offer_session_approval: bool,
+    /// Unified diff showing the proposed changes
+    pub diff: String,
+    /// Path of the file being written to
+    pub file_name: String,
+    /// Human-readable description of the intended file change
+    pub intention: String,
+    /// Permission kind discriminator
+    pub kind: PermissionRequestWriteKind,
+    /// Complete new file contents for newly created files
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub new_file_contents: Option<String>,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+}
+
+/// File or directory read permission request
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionRequestRead {
+    /// Human-readable description of why the file is being read
+    pub intention: String,
+    /// Permission kind discriminator
+    pub kind: PermissionRequestReadKind,
+    /// Path of the file or directory being read
+    pub path: String,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+}
+
+/// MCP tool invocation permission request
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionRequestMcp {
+    /// Arguments to pass to the MCP tool
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub args: Option<serde_json::Value>,
+    /// Permission kind discriminator
+    pub kind: PermissionRequestMcpKind,
+    /// Whether this MCP tool is read-only (no side effects)
+    pub read_only: bool,
+    /// Name of the MCP server providing the tool
+    pub server_name: String,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+    /// Internal name of the MCP tool
+    pub tool_name: String,
+    /// Human-readable title of the MCP tool
+    pub tool_title: String,
+}
+
+/// URL access permission request
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionRequestUrl {
+    /// Human-readable description of why the URL is being accessed
+    pub intention: String,
+    /// Permission kind discriminator
+    pub kind: PermissionRequestUrlKind,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+    /// URL to be fetched
+    pub url: String,
+}
+
+/// Memory operation permission request
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionRequestMemory {
+    /// Whether this is a store or vote memory operation
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub action: Option<PermissionRequestMemoryAction>,
+    /// Source references for the stored fact (store only)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub citations: Option<Vec<String>>,
+    /// Vote direction (vote only)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub direction: Option<PermissionRequestMemoryDirection>,
+    /// The fact being stored or voted on
+    pub fact: String,
+    /// Permission kind discriminator
+    pub kind: PermissionRequestMemoryKind,
+    /// Reason for the vote (vote only)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reason: Option<String>,
+    /// Topic or subject of the memory (store only)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub subject: Option<String>,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+}
+
+/// Custom tool invocation permission request
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionRequestCustomTool {
+    /// Arguments to pass to the custom tool
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub args: Option<serde_json::Value>,
+    /// Permission kind discriminator
+    pub kind: PermissionRequestCustomToolKind,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+    /// Description of what the custom tool does
+    pub tool_description: String,
+    /// Name of the custom tool
+    pub tool_name: String,
+}
+
+/// Hook confirmation permission request
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionRequestHook {
+    /// Optional message from the hook explaining why confirmation is needed
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub hook_message: Option<String>,
+    /// Permission kind discriminator
+    pub kind: PermissionRequestHookKind,
+    /// Arguments of the tool call being gated
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_args: Option<serde_json::Value>,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+    /// Name of the tool the hook is gating
+    pub tool_name: String,
+}
+
+/// Shell command permission prompt
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionPromptRequestCommands {
+    /// Whether the UI can offer session-wide approval for this command pattern
+    pub can_offer_session_approval: bool,
+    /// Command identifiers covered by this approval prompt
+    pub command_identifiers: Vec<String>,
+    /// The complete shell command text to be executed
+    pub full_command_text: String,
+    /// Human-readable description of what the command intends to do
+    pub intention: String,
+    /// Prompt kind discriminator
+    pub kind: PermissionPromptRequestCommandsKind,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+    /// Optional warning message about risks of running this command
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub warning: Option<String>,
+}
+
+/// File write permission prompt
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionPromptRequestWrite {
+    /// Whether the UI can offer session-wide approval for file write operations
+    pub can_offer_session_approval: bool,
+    /// Unified diff showing the proposed changes
+    pub diff: String,
+    /// Path of the file being written to
+    pub file_name: String,
+    /// Human-readable description of the intended file change
+    pub intention: String,
+    /// Prompt kind discriminator
+    pub kind: PermissionPromptRequestWriteKind,
+    /// Complete new file contents for newly created files
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub new_file_contents: Option<String>,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+}
+
+/// File read permission prompt
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionPromptRequestRead {
+    /// Human-readable description of why the file is being read
+    pub intention: String,
+    /// Prompt kind discriminator
+    pub kind: PermissionPromptRequestReadKind,
+    /// Path of the file or directory being read
+    pub path: String,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+}
+
+/// MCP tool invocation permission prompt
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionPromptRequestMcp {
+    /// Arguments to pass to the MCP tool
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub args: Option<serde_json::Value>,
+    /// Prompt kind discriminator
+    pub kind: PermissionPromptRequestMcpKind,
+    /// Name of the MCP server providing the tool
+    pub server_name: String,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+    /// Internal name of the MCP tool
+    pub tool_name: String,
+    /// Human-readable title of the MCP tool
+    pub tool_title: String,
+}
+
+/// URL access permission prompt
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionPromptRequestUrl {
+    /// Human-readable description of why the URL is being accessed
+    pub intention: String,
+    /// Prompt kind discriminator
+    pub kind: PermissionPromptRequestUrlKind,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+    /// URL to be fetched
+    pub url: String,
+}
+
+/// Memory operation permission prompt
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionPromptRequestMemory {
+    /// Whether this is a store or vote memory operation
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub action: Option<PermissionPromptRequestMemoryAction>,
+    /// Source references for the stored fact (store only)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub citations: Option<Vec<String>>,
+    /// Vote direction (vote only)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub direction: Option<PermissionPromptRequestMemoryDirection>,
+    /// The fact being stored or voted on
+    pub fact: String,
+    /// Prompt kind discriminator
+    pub kind: PermissionPromptRequestMemoryKind,
+    /// Reason for the vote (vote only)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub reason: Option<String>,
+    /// Topic or subject of the memory (store only)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub subject: Option<String>,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+}
+
+/// Custom tool invocation permission prompt
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionPromptRequestCustomTool {
+    /// Arguments to pass to the custom tool
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub args: Option<serde_json::Value>,
+    /// Prompt kind discriminator
+    pub kind: PermissionPromptRequestCustomToolKind,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+    /// Description of what the custom tool does
+    pub tool_description: String,
+    /// Name of the custom tool
+    pub tool_name: String,
+}
+
+/// Path access permission prompt
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionPromptRequestPath {
+    /// Underlying permission kind that needs path approval
+    pub access_kind: PermissionPromptRequestPathAccessKind,
+    /// Prompt kind discriminator
+    pub kind: PermissionPromptRequestPathKind,
+    /// File paths that require explicit approval
+    pub paths: Vec<String>,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+}
+
+/// Hook confirmation permission prompt
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionPromptRequestHook {
+    /// Optional message from the hook explaining why confirmation is needed
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub hook_message: Option<String>,
+    /// Prompt kind discriminator
+    pub kind: PermissionPromptRequestHookKind,
+    /// Arguments of the tool call being gated
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_args: Option<serde_json::Value>,
+    /// Tool call ID that triggered this permission request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+    /// Name of the tool the hook is gating
+    pub tool_name: String,
+}
+
+/// Permission request notification requiring client approval with request details
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionRequestedData {
+    /// Details of the permission being requested
+    pub permission_request: PermissionRequest,
+    /// Derived user-facing permission prompt details for UI consumers
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub prompt_request: Option<PermissionPromptRequest>,
+    /// Unique identifier for this permission request; used to respond via session.respondToPermission()
+    pub request_id: RequestId,
+    /// When true, this permission was already resolved by a permissionRequest hook and requires no client action
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub resolved_by_hook: Option<bool>,
+}
+
+/// The result of the permission request
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionCompletedResult {
+    /// The outcome of the permission request
+    pub kind: PermissionCompletedKind,
+}
+
+/// Permission request completion notification signaling UI dismissal
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionCompletedData {
+    /// Request ID of the resolved permission request; clients should dismiss any UI for this request
+    pub request_id: RequestId,
+    /// The result of the permission request
+    pub result: PermissionCompletedResult,
+    /// Optional tool call ID associated with this permission prompt; clients may use it to correlate UI created from tool-scoped prompts
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+}
+
+/// User input request notification with question and optional predefined choices
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UserInputRequestedData {
+    /// Whether the user can provide a free-form text response in addition to predefined choices
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub allow_freeform: Option<bool>,
+    /// Predefined choices for the user to select from, if applicable
+    #[serde(default)]
+    pub choices: Vec<String>,
+    /// The question or prompt to present to the user
+    pub question: String,
+    /// Unique identifier for this input request; used to respond via session.respondToUserInput()
+    pub request_id: RequestId,
+    /// The LLM-assigned tool call ID that triggered this request; used by remote UIs to correlate responses
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+}
+
+/// User input request completion with the user's response
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UserInputCompletedData {
+    /// The user's answer to the input request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub answer: Option<String>,
+    /// Request ID of the resolved user input request; clients should dismiss any UI for this request
+    pub request_id: RequestId,
+    /// Whether the answer was typed as free-form text rather than selected from choices
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub was_freeform: Option<bool>,
+}
+
+/// JSON Schema describing the form fields to present to the user (form mode only)
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ElicitationRequestedSchema {
+    /// Form field definitions, keyed by field name
+    pub properties: HashMap<String, serde_json::Value>,
+    /// List of required field names
+    #[serde(default)]
+    pub required: Vec<String>,
+    /// Schema type indicator (always 'object')
+    pub r#type: ElicitationRequestedSchemaType,
+}
+
+/// Elicitation request; may be form-based (structured input) or URL-based (browser redirect)
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ElicitationRequestedData {
+    /// The source that initiated the request (MCP server name, or absent for agent-initiated)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub elicitation_source: Option<String>,
+    /// Message describing what information is needed from the user
+    pub message: String,
+    /// Elicitation mode; "form" for structured input, "url" for browser-based. Defaults to "form" when absent.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub mode: Option<ElicitationRequestedMode>,
+    /// JSON Schema describing the form fields to present to the user (form mode only)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub requested_schema: Option<ElicitationRequestedSchema>,
+    /// Unique identifier for this elicitation request; used to respond via session.respondToElicitation()
+    pub request_id: RequestId,
+    /// Tool call ID from the LLM completion; used to correlate with CompletionChunk.toolCall.id for remote UIs
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+    /// URL to open in the user's browser (url mode only)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub url: Option<String>,
+}
+
+/// Elicitation request completion with the user's response
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ElicitationCompletedData {
+    /// The user action: "accept" (submitted form), "decline" (explicitly refused), or "cancel" (dismissed)
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub action: Option<ElicitationCompletedAction>,
+    /// The submitted form data when action is 'accept'; keys match the requested schema fields
+    #[serde(default)]
+    pub content: HashMap<String, serde_json::Value>,
+    /// Request ID of the resolved elicitation request; clients should dismiss any UI for this request
+    pub request_id: RequestId,
+}
+
+/// Sampling request from an MCP server; contains the server name and a requestId for correlation
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SamplingRequestedData {
+    /// The JSON-RPC request ID from the MCP protocol
+    pub mcp_request_id: serde_json::Value,
+    /// Unique identifier for this sampling request; used to respond via session.respondToSampling()
+    pub request_id: RequestId,
+    /// Name of the MCP server that initiated the sampling request
+    pub server_name: String,
+}
+
+/// Sampling request completion notification signaling UI dismissal
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SamplingCompletedData {
+    /// Request ID of the resolved sampling request; clients should dismiss any UI for this request
+    pub request_id: RequestId,
+}
+
+/// Static OAuth client configuration, if the server specifies one
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpOauthRequiredStaticClientConfig {
+    /// OAuth client ID for the server
+    pub client_id: String,
+    /// Whether this is a public OAuth client
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub public_client: Option<bool>,
+}
+
+/// OAuth authentication request for an MCP server
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpOauthRequiredData {
+    /// Unique identifier for this OAuth request; used to respond via session.respondToMcpOAuth()
+    pub request_id: RequestId,
+    /// Display name of the MCP server that requires OAuth
+    pub server_name: String,
+    /// URL of the MCP server that requires OAuth
+    pub server_url: String,
+    /// Static OAuth client configuration, if the server specifies one
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub static_client_config: Option<McpOauthRequiredStaticClientConfig>,
+}
+
+/// MCP OAuth request completion notification
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpOauthCompletedData {
+    /// Request ID of the resolved OAuth request
+    pub request_id: RequestId,
+}
+
+/// External tool invocation request for client-side tool execution
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ExternalToolRequestedData {
+    /// Arguments to pass to the external tool
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub arguments: Option<serde_json::Value>,
+    /// Unique identifier for this request; used to respond via session.respondToExternalTool()
+    pub request_id: RequestId,
+    /// Session ID that this external tool request belongs to
+    pub session_id: SessionId,
+    /// Tool call ID assigned to this external tool invocation
+    pub tool_call_id: String,
+    /// Name of the external tool to invoke
+    pub tool_name: String,
+    /// W3C Trace Context traceparent header for the execute_tool span
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub traceparent: Option<String>,
+    /// W3C Trace Context tracestate header for the execute_tool span
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tracestate: Option<String>,
+}
+
+/// External tool completion notification signaling UI dismissal
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ExternalToolCompletedData {
+    /// Request ID of the resolved external tool request; clients should dismiss any UI for this request
+    pub request_id: RequestId,
+}
+
+/// Queued slash command dispatch request for client execution
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CommandQueuedData {
+    /// The slash command text to be executed (e.g., /help, /clear)
+    pub command: String,
+    /// Unique identifier for this request; used to respond via session.respondToQueuedCommand()
+    pub request_id: RequestId,
+}
+
+/// Registered command dispatch request routed to the owning client
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CommandExecuteData {
+    /// Raw argument string after the command name
+    pub args: String,
+    /// The full command text (e.g., /deploy production)
+    pub command: String,
+    /// Command name without leading /
+    pub command_name: String,
+    /// Unique identifier; used to respond via session.commands.handlePendingCommand()
+    pub request_id: RequestId,
+}
+
+/// Queued command completion notification signaling UI dismissal
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CommandCompletedData {
+    /// Request ID of the resolved command request; clients should dismiss any UI for this request
+    pub request_id: RequestId,
+}
+
+/// Auto mode switch request notification requiring user approval
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AutoModeSwitchRequestedData {
+    /// The rate limit error code that triggered this request
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error_code: Option<String>,
+    /// Unique identifier for this request; used to respond via session.respondToAutoModeSwitch()
+    pub request_id: RequestId,
+    /// Seconds until the rate limit resets, when known. Lets clients render a humanized reset time alongside the prompt.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub retry_after_seconds: Option<f64>,
+}
+
+/// Auto mode switch completion notification
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AutoModeSwitchCompletedData {
+    /// Request ID of the resolved request; clients should dismiss any UI for this request
+    pub request_id: RequestId,
+    /// The user's choice: 'yes', 'yes_always', or 'no'
+    pub response: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CommandsChangedCommand {
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    pub name: String,
+}
+
+/// SDK command registration change notification
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CommandsChangedData {
+    /// Current list of registered SDK commands
+    pub commands: Vec<CommandsChangedCommand>,
+}
+
+/// UI capability changes
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CapabilitiesChangedUI {
+    /// Whether elicitation is now supported
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub elicitation: Option<bool>,
+}
+
+/// Session capability change notification
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CapabilitiesChangedData {
+    /// UI capability changes
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub ui: Option<CapabilitiesChangedUI>,
+}
+
+/// Plan approval request with plan content and available user actions
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ExitPlanModeRequestedData {
+    /// Available actions the user can take (e.g., approve, edit, reject)
+    pub actions: Vec<String>,
+    /// Full content of the plan file
+    pub plan_content: String,
+    /// The recommended action for the user to take
+    pub recommended_action: String,
+    /// Unique identifier for this request; used to respond via session.respondToExitPlanMode()
+    pub request_id: RequestId,
+    /// Summary of the plan that was created
+    pub summary: String,
+}
+
+/// Plan mode exit completion with the user's approval decision and optional feedback
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ExitPlanModeCompletedData {
+    /// Whether the plan was approved by the user
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub approved: Option<bool>,
+    /// Whether edits should be auto-approved without confirmation
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub auto_approve_edits: Option<bool>,
+    /// Free-form feedback from the user if they requested changes to the plan
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub feedback: Option<String>,
+    /// Request ID of the resolved exit plan mode request; clients should dismiss any UI for this request
+    pub request_id: RequestId,
+    /// Which action the user selected (e.g. 'autopilot', 'interactive', 'exit_only')
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub selected_action: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionToolsUpdatedData {
+    pub model: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionBackgroundTasksChangedData {}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SkillsLoadedSkill {
+    /// Description of what the skill does
+    pub description: String,
+    /// Whether the skill is currently enabled
+    pub enabled: bool,
+    /// Unique identifier for the skill
+    pub name: String,
+    /// Absolute path to the skill file, if available
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub path: Option<String>,
+    /// Source location type of the skill (e.g., project, personal, plugin)
+    pub source: String,
+    /// Whether the skill can be invoked by the user as a slash command
+    pub user_invocable: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionSkillsLoadedData {
+    /// Array of resolved skill metadata
+    pub skills: Vec<SkillsLoadedSkill>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CustomAgentsUpdatedAgent {
+    /// Description of what the agent does
+    pub description: String,
+    /// Human-readable display name
+    pub display_name: String,
+    /// Unique identifier for the agent
+    pub id: String,
+    /// Model override for this agent, if set
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model: Option<String>,
+    /// Internal name of the agent
+    pub name: String,
+    /// Source location: user, project, inherited, remote, or plugin
+    pub source: String,
+    /// List of tool names available to this agent
+    pub tools: Vec<String>,
+    /// Whether the agent can be selected by the user
+    pub user_invocable: bool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionCustomAgentsUpdatedData {
+    /// Array of loaded custom agent metadata
+    pub agents: Vec<CustomAgentsUpdatedAgent>,
+    /// Fatal errors from agent loading
+    pub errors: Vec<String>,
+    /// Non-fatal warnings from agent loading
+    pub warnings: Vec<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpServersLoadedServer {
+    /// Error message if the server failed to connect
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<String>,
+    /// Server name (config key)
+    pub name: String,
+    /// Configuration source: user, workspace, plugin, or builtin
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub source: Option<String>,
+    /// Connection status: connected, failed, needs-auth, pending, disabled, or not_configured
+    pub status: McpServersLoadedServerStatus,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionMcpServersLoadedData {
+    /// Array of MCP server status summaries
+    pub servers: Vec<McpServersLoadedServer>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionMcpServerStatusChangedData {
+    /// Name of the MCP server whose status changed
+    pub server_name: String,
+    /// New connection status: connected, failed, needs-auth, pending, disabled, or not_configured
+    pub status: McpServerStatusChangedStatus,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ExtensionsLoadedExtension {
+    /// Source-qualified extension ID (e.g., 'project:my-ext', 'user:auth-helper')
+    pub id: String,
+    /// Extension name (directory name)
+    pub name: String,
+    /// Discovery source
+    pub source: ExtensionsLoadedExtensionSource,
+    /// Current status: running, disabled, failed, or starting
+    pub status: ExtensionsLoadedExtensionStatus,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionExtensionsLoadedData {
+    /// Array of discovered extensions and their status
+    pub extensions: Vec<ExtensionsLoadedExtension>,
+}
+
+/// Hosting platform type of the repository (github or ado)
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum WorkingDirectoryContextHostType {
+    #[serde(rename = "github")]
+    Github,
+    #[serde(rename = "ado")]
+    Ado,
+    /// Unknown variant for forward compatibility.
+    #[serde(other)]
+    Unknown,
+}
+
+/// The type of operation performed on the plan file
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum PlanChangedOperation {
+    #[serde(rename = "create")]
+    Create,
+    #[serde(rename = "update")]
+    Update,
+    #[serde(rename = "delete")]
+    Delete,
+    /// Unknown variant for forward compatibility.
+    #[serde(other)]
+    Unknown,
+}
+
+/// Whether the file was newly created or updated
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum WorkspaceFileChangedOperation {
+    #[serde(rename = "create")]
+    Create,
+    #[serde(rename = "update")]
+    Update,
+    /// Unknown variant for forward compatibility.
+    #[serde(other)]
+    Unknown,
+}
+
+/// Origin type of the session being handed off
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum HandoffSourceType {
+    #[serde(rename = "remote")]
+    Remote,
+    #[serde(rename = "local")]
+    Local,
+    /// Unknown variant for forward compatibility.
+    #[serde(other)]
+    Unknown,
+}
+
+/// Whether the session ended normally ("routine") or due to a crash/fatal error ("error")
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum ShutdownType {
+    #[serde(rename = "routine")]
+    Routine,
+    #[serde(rename = "error")]
+    Error,
+    /// Unknown variant for forward compatibility.
+    #[serde(other)]
+    Unknown,
+}
+
+/// The agent mode that was active when this message was sent
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum UserMessageAgentMode {
+    #[serde(rename = "interactive")]
+    Interactive,
+    #[serde(rename = "plan")]
+    Plan,
+    #[serde(rename = "autopilot")]
+    Autopilot,
+    #[serde(rename = "shell")]
+    Shell,
+    /// Unknown variant for forward compatibility.
+    #[serde(other)]
+    Unknown,
+}
+
+/// Tool call type: "function" for standard tool calls, "custom" for grammar-based tool calls. Defaults to "function" when absent.
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum AssistantMessageToolRequestType {
+    #[serde(rename = "function")]
+    Function,
+    #[serde(rename = "custom")]
+    Custom,
+    /// Unknown variant for forward compatibility.
+    #[serde(other)]
+    Unknown,
+}
+
+/// Where the failed model call originated
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum ModelCallFailureSource {
+    #[serde(rename = "top_level")]
+    TopLevel,
+    #[serde(rename = "subagent")]
+    Subagent,
+    #[serde(rename = "mcp_sampling")]
+    McpSampling,
+    /// Unknown variant for forward compatibility.
+    #[serde(other)]
+    Unknown,
+}
+
+/// Message role: "system" for system prompts, "developer" for developer-injected instructions
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum SystemMessageRole {
+    #[serde(rename = "system")]
+    System,
+    #[serde(rename = "developer")]
+    Developer,
+    /// Unknown variant for forward compatibility.
+    #[serde(other)]
+    Unknown,
+}
+
+/// Permission kind discriminator
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum PermissionRequestShellKind {
+    #[serde(rename = "shell")]
+    Shell,
+}
+
+/// Permission kind discriminator
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum PermissionRequestWriteKind {
+    #[serde(rename = "write")]
+    Write,
+}
+
+/// Permission kind discriminator
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum PermissionRequestReadKind {
+    #[serde(rename = "read")]
+    Read,
+}
+
+/// Permission kind discriminator
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum PermissionRequestMcpKind {
+    #[serde(rename = "mcp")]
+    Mcp,
+}
+
+/// Permission kind discriminator
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum PermissionRequestUrlKind {
+    #[serde(rename = "url")]
+    Url,
+}
+
+/// Whether this is a store or vote memory operation
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum PermissionRequestMemoryAction {
+    #[serde(rename = "store")]
+    Store,
+    #[serde(rename = "vote")]
+    Vote,
+    /// Unknown variant for forward compatibility.
+    #[serde(other)]
+    Unknown,
+}
+
+/// Vote direction (vote only)
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub enum PermissionRequestMemoryDirection {
+    #[serde(rename = "upvote")]
+    Upvote,
+    #[serde(rename = "downvote")]
+    Downvote,
+    /// Unknown variant for forward compatibility.
+ #[serde(other)] + Unknown, +} + +/// Permission kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionRequestMemoryKind { + #[serde(rename = "memory")] + Memory, +} + +/// Permission kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionRequestCustomToolKind { + #[serde(rename = "custom-tool")] + CustomTool, +} + +/// Permission kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionRequestHookKind { + #[serde(rename = "hook")] + Hook, +} + +/// Details of the permission being requested +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PermissionRequest { + Shell(PermissionRequestShell), + Write(PermissionRequestWrite), + Read(PermissionRequestRead), + Mcp(PermissionRequestMcp), + Url(PermissionRequestUrl), + Memory(PermissionRequestMemory), + CustomTool(PermissionRequestCustomTool), + Hook(PermissionRequestHook), +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestCommandsKind { + #[serde(rename = "commands")] + Commands, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestWriteKind { + #[serde(rename = "write")] + Write, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestReadKind { + #[serde(rename = "read")] + Read, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestMcpKind { + #[serde(rename = "mcp")] + Mcp, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestUrlKind { + #[serde(rename = "url")] + Url, +} + +/// Whether this is a store or vote memory operation +#[derive(Debug, Clone, 
PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestMemoryAction { + #[serde(rename = "store")] + Store, + #[serde(rename = "vote")] + Vote, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Vote direction (vote only) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestMemoryDirection { + #[serde(rename = "upvote")] + Upvote, + #[serde(rename = "downvote")] + Downvote, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestMemoryKind { + #[serde(rename = "memory")] + Memory, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestCustomToolKind { + #[serde(rename = "custom-tool")] + CustomTool, +} + +/// Underlying permission kind that needs path approval +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestPathAccessKind { + #[serde(rename = "read")] + Read, + #[serde(rename = "shell")] + Shell, + #[serde(rename = "write")] + Write, + /// Unknown variant for forward compatibility. 
+ #[serde(other)] + Unknown, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestPathKind { + #[serde(rename = "path")] + Path, +} + +/// Prompt kind discriminator +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionPromptRequestHookKind { + #[serde(rename = "hook")] + Hook, +} + +/// Derived user-facing permission prompt details for UI consumers +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum PermissionPromptRequest { + Commands(PermissionPromptRequestCommands), + Write(PermissionPromptRequestWrite), + Read(PermissionPromptRequestRead), + Mcp(PermissionPromptRequestMcp), + Url(PermissionPromptRequestUrl), + Memory(PermissionPromptRequestMemory), + CustomTool(PermissionPromptRequestCustomTool), + Path(PermissionPromptRequestPath), + Hook(PermissionPromptRequestHook), +} + +/// The outcome of the permission request +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum PermissionCompletedKind { + #[serde(rename = "approved")] + Approved, + #[serde(rename = "approved-for-session")] + ApprovedForSession, + #[serde(rename = "approved-for-location")] + ApprovedForLocation, + #[serde(rename = "denied-by-rules")] + DeniedByRules, + #[serde(rename = "denied-no-approval-rule-and-could-not-request-from-user")] + DeniedNoApprovalRuleAndCouldNotRequestFromUser, + #[serde(rename = "denied-interactively-by-user")] + DeniedInteractivelyByUser, + #[serde(rename = "denied-by-content-exclusion-policy")] + DeniedByContentExclusionPolicy, + #[serde(rename = "denied-by-permission-request-hook")] + DeniedByPermissionRequestHook, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Elicitation mode; "form" for structured input, "url" for browser-based. Defaults to "form" when absent. 
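The `#[serde(other)] Unknown` variants on the generated enums above mean downstream code must always keep a fallback match arm. A minimal, dependency-free sketch of the pattern's behavior (a local stand-in enum and a hand-rolled parse function, not the generated serde types):

```rust
// Local stand-in mirroring a generated enum such as `ShutdownType`.
#[derive(Debug, PartialEq)]
enum ShutdownType {
    Routine,
    Error,
    Unknown,
}

// Mirrors what `#[serde(other)]` does during deserialization: any
// unrecognized wire value maps to `Unknown` instead of failing.
fn parse_shutdown_type(raw: &str) -> ShutdownType {
    match raw {
        "routine" => ShutdownType::Routine,
        "error" => ShutdownType::Error,
        _ => ShutdownType::Unknown,
    }
}

fn describe(t: &ShutdownType) -> &'static str {
    match t {
        ShutdownType::Routine => "clean shutdown",
        ShutdownType::Error => "crash or fatal error",
        // Required fallback: values added by a newer CLI land here.
        ShutdownType::Unknown => "unrecognized shutdown type",
    }
}

fn main() {
    assert_eq!(parse_shutdown_type("routine"), ShutdownType::Routine);
    assert_eq!(parse_shutdown_type("something-new"), ShutdownType::Unknown);
    println!("{}", describe(&parse_shutdown_type("error")));
}
```

This is why the SDK can bump the CLI without a breaking change: new wire values degrade to `Unknown` rather than a deserialization error.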
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ElicitationRequestedMode { + #[serde(rename = "form")] + Form, + #[serde(rename = "url")] + Url, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Schema type indicator (always 'object') +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ElicitationRequestedSchemaType { + #[serde(rename = "object")] + Object, +} + +/// The user action: "accept" (submitted form), "decline" (explicitly refused), or "cancel" (dismissed) +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ElicitationCompletedAction { + #[serde(rename = "accept")] + Accept, + #[serde(rename = "decline")] + Decline, + #[serde(rename = "cancel")] + Cancel, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Connection status: connected, failed, needs-auth, pending, disabled, or not_configured +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpServersLoadedServerStatus { + #[serde(rename = "connected")] + Connected, + #[serde(rename = "failed")] + Failed, + #[serde(rename = "needs-auth")] + NeedsAuth, + #[serde(rename = "pending")] + Pending, + #[serde(rename = "disabled")] + Disabled, + #[serde(rename = "not_configured")] + NotConfigured, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// New connection status: connected, failed, needs-auth, pending, disabled, or not_configured +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum McpServerStatusChangedStatus { + #[serde(rename = "connected")] + Connected, + #[serde(rename = "failed")] + Failed, + #[serde(rename = "needs-auth")] + NeedsAuth, + #[serde(rename = "pending")] + Pending, + #[serde(rename = "disabled")] + Disabled, + #[serde(rename = "not_configured")] + NotConfigured, + /// Unknown variant for forward compatibility. 
+ #[serde(other)] + Unknown, +} + +/// Discovery source +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExtensionsLoadedExtensionSource { + #[serde(rename = "project")] + Project, + #[serde(rename = "user")] + User, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} + +/// Current status: running, disabled, failed, or starting +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum ExtensionsLoadedExtensionStatus { + #[serde(rename = "running")] + Running, + #[serde(rename = "disabled")] + Disabled, + #[serde(rename = "failed")] + Failed, + #[serde(rename = "starting")] + Starting, + /// Unknown variant for forward compatibility. + #[serde(other)] + Unknown, +} diff --git a/rust/src/handler.rs b/rust/src/handler.rs new file mode 100644 index 000000000..79c7d381d --- /dev/null +++ b/rust/src/handler.rs @@ -0,0 +1,608 @@ +//! Event handler traits for session lifecycle. +//! +//! The [`SessionHandler`](crate::handler::SessionHandler) trait is the primary extension point — implement +//! [`on_event`](crate::handler::SessionHandler::on_event) to control how sessions respond to +//! CLI events, permission requests, tool calls, and user input prompts. + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +use crate::types::{ + ElicitationRequest, ElicitationResult, ExitPlanModeData, PermissionRequestData, RequestId, + SessionEvent, SessionId, ToolInvocation, ToolResult, +}; + +/// Events dispatched by the SDK session event loop to the handler. +/// +/// The handler returns a [`HandlerResponse`] indicating how the SDK should +/// respond to the CLI. For fire-and-forget events (`SessionEvent`), the +/// response is ignored. +#[non_exhaustive] +#[derive(Debug)] +pub enum HandlerEvent { + /// Informational session event from the timeline (e.g. assistant.message_delta, + /// session.idle, tool.execution_start). Fire-and-forget — return `HandlerResponse::Ok`. 
+ SessionEvent { + /// The session that emitted this event. + session_id: SessionId, + /// The event payload. + event: SessionEvent, + }, + + /// The CLI requests permission for an action. Return `HandlerResponse::Permission(..)`. + PermissionRequest { + /// The requesting session. + session_id: SessionId, + /// Unique ID to correlate the response. + request_id: RequestId, + /// Permission request payload. + data: PermissionRequestData, + }, + + /// The CLI requests user input. Return `HandlerResponse::UserInput(..)`. + /// The handler may block (e.g. awaiting a UI dialog) — this is expected. + UserInput { + /// The requesting session. + session_id: SessionId, + /// The question text to present. + question: String, + /// Optional multiple-choice options. + choices: Option<Vec<String>>, + /// Whether free-form text input is allowed. + allow_freeform: Option<bool>, + }, + + /// The CLI requests execution of a client-defined tool. + /// Return `HandlerResponse::ToolResult(..)`. + ExternalTool { + /// The tool call to execute. + invocation: ToolInvocation, + }, + + /// The CLI broadcasts an elicitation request for the provider to handle. + /// Return `HandlerResponse::Elicitation(..)`. + ElicitationRequest { + /// The requesting session. + session_id: SessionId, + /// Unique ID to correlate the response. + request_id: RequestId, + /// The elicitation request payload. + request: ElicitationRequest, + }, + + /// The CLI requests exiting plan mode. Return `HandlerResponse::ExitPlanMode(..)`. + ExitPlanMode { + /// The requesting session. + session_id: SessionId, + /// Plan mode exit payload. + data: ExitPlanModeData, + }, + + /// The CLI asks whether to switch to auto model when an eligible rate + /// limit is hit. Return [`HandlerResponse::AutoModeSwitch`]. + AutoModeSwitch { + /// The requesting session. + session_id: SessionId, + /// The specific rate-limit error code that triggered the request, + /// if known (e.g. `user_weekly_rate_limited`, `user_global_rate_limited`). 
+ error_code: Option<String>, + /// Seconds until the rate limit resets, when known. Per RFC 9110's + /// `Retry-After` `delta-seconds` form, this is an integer count of + /// seconds. Handlers can use it to render a humanized reset time + /// alongside the prompt. + retry_after_seconds: Option<u64>, + }, +} + +/// Response from the handler back to the SDK, used to construct the +/// JSON-RPC reply sent to the CLI. +#[non_exhaustive] +#[derive(Debug)] +pub enum HandlerResponse { + /// No response needed (used for fire-and-forget `SessionEvent`s). + Ok, + /// Permission decision. + Permission(PermissionResult), + /// User input response (or `None` to signal no input available). + UserInput(Option<UserInputResponse>), + /// Result of a tool execution. + ToolResult(ToolResult), + /// Elicitation result (accept/decline/cancel with optional form data). + Elicitation(ElicitationResult), + /// Exit plan mode decision. + ExitPlanMode(ExitPlanModeResult), + /// Auto-mode-switch decision. + AutoModeSwitch(AutoModeSwitchResponse), +} + +/// Result of a permission request. +/// +/// `#[non_exhaustive]` so future variants can be added without a major +/// version bump. Match arms must include a `_` fallback. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum PermissionResult { + /// Permission granted. + Approved, + /// Permission denied. + Denied, + /// Defer the response. The handler will resolve this request itself + /// later — typically after a UI prompt — by calling + /// `session.permissions.handlePendingPermissionRequest` directly. The + /// SDK will not send a response for this request. + /// + /// **Notification path only** (`permission.requested`). On the direct + /// RPC path (`permission.request`), `Deferred` falls back to + /// [`Approved`](Self::Approved) because that path must return a value + /// to satisfy the JSON-RPC reply contract. + Deferred, + /// Provide the full response payload. 
The SDK passes the value as-is + /// in the `result` field of `handlePendingPermissionRequest` + /// (notification path) or as the JSON-RPC `result` directly (direct + /// RPC path). + /// + /// Use this for response shapes beyond `{ "kind": "approve-once" }` + /// or `{ "kind": "reject" }` — for example, "approve and remember" + /// with allowlist data. + Custom(serde_json::Value), + /// No user is available to respond — for example, headless agents + /// without an interactive session. Sent as + /// `{ "kind": "user-not-available" }`. + UserNotAvailable, + /// The handler has no result to provide and the CLI should fall back + /// to its default policy. Sent as `{ "kind": "no-result" }`. Distinct + /// from [`Deferred`](Self::Deferred), which suppresses the reply + /// entirely so the handler can resolve later out-of-band. + NoResult, +} + +/// Response to a user input request. +#[derive(Debug, Clone)] +pub struct UserInputResponse { + /// The user's answer text. + pub answer: String, + /// Whether the answer was free-form (not a preset choice). + pub was_freeform: bool, +} + +/// Result of an exit-plan-mode request. +#[derive(Debug, Clone)] +pub struct ExitPlanModeResult { + /// Whether the user approved exiting plan mode. + pub approved: bool, + /// The action the user selected (if any). + pub selected_action: Option<String>, + /// Optional feedback text from the user. + pub feedback: Option<String>, +} + +impl Default for ExitPlanModeResult { + fn default() -> Self { + Self { + approved: true, + selected_action: None, + feedback: None, + } + } +} + +/// Response to a [`HandlerEvent::AutoModeSwitch`] request. +/// +/// Wire serialization matches the CLI's `autoModeSwitch.request` response +/// schema: `"yes"`, `"yes_always"`, or `"no"`. +#[non_exhaustive] +#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum AutoModeSwitchResponse { + /// Approve the auto-mode switch for this rate-limit cycle only. 
+ Yes, + /// Approve and remember — auto-accept future auto-mode switches in this + /// session without prompting. + YesAlways, + /// Decline the auto-mode switch. The session stays on the current model + /// and surfaces the rate-limit error. + No, +} + +/// Callback trait for session events. +/// +/// Implement this trait to control how a session responds to CLI events, +/// permission requests, tool calls, user input prompts, elicitations, and +/// plan-mode exits. There are two styles of implementation — pick whichever +/// fits your use case: +/// +/// 1. **Per-event methods (recommended for most handlers).** Override the +/// specific `on_*` methods you care about; every method has a safe +/// default so you only write what you need. This is the pattern used by +/// [`serenity::EventHandler`][serenity], `lapin`, and most Rust SDKs +/// that dispatch broker/client callbacks. +/// 2. **Single [`on_event`](Self::on_event) method.** Override this one +/// method and `match` on [`HandlerEvent`] yourself. Useful for logging +/// middleware, custom routing, or when you want an exhaustiveness check +/// across all variants. +/// +/// When you override [`on_event`](Self::on_event) directly, the per-event methods are not +/// called — your implementation is entirely responsible for dispatch. The +/// default [`on_event`](Self::on_event) fans out to the per-event methods. +/// +/// [serenity]: https://docs.rs/serenity/latest/serenity/client/trait.EventHandler.html +/// +/// # Default behavior +/// +/// - Permission requests → **denied** (safe default). +/// - User input → `None` (no answer available). +/// - External tool calls → failure result with "no handler registered". +/// - Elicitation → `"cancel"`. +/// - Exit plan mode → [`ExitPlanModeResult::default`]. +/// - Auto-mode-switch → [`AutoModeSwitchResponse::No`] (decline by default; the +/// session stays on its current model and surfaces the rate-limit error). +/// - Session events → ignored (fire-and-forget). 
+/// +/// # Concurrency +/// +/// **Request-triggered events** (`UserInput`, `ExternalTool` via `tool.call`, +/// `ExitPlanMode`, `PermissionRequest` via `permission.request`) are awaited +/// inline in the event loop and therefore processed **serially** per session. +/// Blocking here pauses that session's event loop — which is correct, since +/// the CLI is also blocked waiting for the response. +/// +/// **Notification-triggered events** (`PermissionRequest` via +/// `permission.requested`, `ExternalTool` via `external_tool.requested`) are +/// dispatched on spawned tasks and may run **concurrently** with each other +/// and with the serial event loop. Implementations must be safe for +/// concurrent invocation. +/// +/// # Example +/// +/// ```no_run +/// use async_trait::async_trait; +/// use github_copilot_sdk::handler::{PermissionResult, SessionHandler}; +/// use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionId}; +/// +/// struct ApproveReadsOnly; +/// +/// #[async_trait] +/// impl SessionHandler for ApproveReadsOnly { +/// async fn on_permission_request( +/// &self, +/// _sid: SessionId, +/// _rid: RequestId, +/// data: PermissionRequestData, +/// ) -> PermissionResult { +/// match data.extra.get("tool").and_then(|v| v.as_str()) { +/// Some("view") | Some("ls") | Some("grep") => PermissionResult::Approved, +/// _ => PermissionResult::Denied, +/// } +/// } +/// } +/// ``` +#[async_trait] +pub trait SessionHandler: Send + Sync + 'static { + /// Handle an event from the session. + /// + /// The default implementation destructures `event` and calls the + /// matching per-event method (e.g. [`on_permission_request`](Self::on_permission_request) + /// for [`HandlerEvent::PermissionRequest`]). Override this method only + /// if you want a single dispatch point with exhaustive matching — most + /// handlers should override the per-event methods instead. 
+ /// + /// See the [trait-level docs](SessionHandler#concurrency) for details on + /// which events may be dispatched concurrently. + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::SessionEvent { session_id, event } => { + self.on_session_event(session_id, event).await; + HandlerResponse::Ok + } + HandlerEvent::PermissionRequest { + session_id, + request_id, + data, + } => HandlerResponse::Permission( + self.on_permission_request(session_id, request_id, data) + .await, + ), + HandlerEvent::UserInput { + session_id, + question, + choices, + allow_freeform, + } => HandlerResponse::UserInput( + self.on_user_input(session_id, question, choices, allow_freeform) + .await, + ), + HandlerEvent::ExternalTool { invocation } => { + HandlerResponse::ToolResult(self.on_external_tool(invocation).await) + } + HandlerEvent::ElicitationRequest { + session_id, + request_id, + request, + } => HandlerResponse::Elicitation( + self.on_elicitation(session_id, request_id, request).await, + ), + HandlerEvent::ExitPlanMode { session_id, data } => { + HandlerResponse::ExitPlanMode(self.on_exit_plan_mode(session_id, data).await) + } + HandlerEvent::AutoModeSwitch { + session_id, + error_code, + retry_after_seconds, + } => HandlerResponse::AutoModeSwitch( + self.on_auto_mode_switch(session_id, error_code, retry_after_seconds) + .await, + ), + } + } + + /// Informational timeline event (assistant messages, tool execution + /// markers, session idle, etc.). Fire-and-forget — the return value is + /// ignored. + /// + /// Default: do nothing. + async fn on_session_event(&self, _session_id: SessionId, _event: SessionEvent) {} + + /// The CLI is asking whether the agent may perform a privileged action. + /// + /// Default: [`PermissionResult::Denied`]. 
The default-deny posture + /// matches the CLI's safety model; override to implement your own + /// policy (see the [`permission`](crate::permission) module for common + /// wrappers like `approve_all` / `approve_if`). + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Denied + } + + /// The CLI is asking the user a question (optionally with a list of + /// choices). + /// + /// Default: `None` — the CLI interprets this as "no answer available" + /// and falls back to its own prompt behavior. + async fn on_user_input( + &self, + _session_id: SessionId, + _question: String, + _choices: Option<Vec<String>>, + _allow_freeform: Option<bool>, + ) -> Option<UserInputResponse> { + None + } + + /// The CLI wants to invoke a client-defined ("external") tool. + /// + /// Default: a failure [`ToolResult`] indicating no tool handler is + /// registered. Typical implementations route to a + /// [`ToolHandlerRouter`](crate::tool::ToolHandlerRouter) which + /// dispatches to tools registered via + /// [`define_tool`](crate::tool::define_tool) or custom + /// [`ToolHandler`](crate::tool::ToolHandler) impls. + async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult { + let msg = format!("No handler registered for tool '{}'", invocation.tool_name); + ToolResult::Expanded(crate::types::ToolResultExpanded { + text_result_for_llm: msg.clone(), + result_type: "failure".to_string(), + session_log: None, + error: Some(msg), + }) + } + + /// The CLI is requesting an elicitation (structured form / URL prompt). + /// + /// Default: cancel. + async fn on_elicitation( + &self, + _session_id: SessionId, + _request_id: RequestId, + _request: ElicitationRequest, + ) -> ElicitationResult { + ElicitationResult { + action: "cancel".to_string(), + content: None, + } + } + + /// The CLI is asking the user whether to exit plan mode. 
+ /// + /// Default: [`ExitPlanModeResult::default`] (approved with no action). + async fn on_exit_plan_mode( + &self, + _session_id: SessionId, + _data: ExitPlanModeData, + ) -> ExitPlanModeResult { + ExitPlanModeResult::default() + } + + /// The CLI is asking whether to switch to auto model after an eligible + /// rate limit. + /// + /// `retry_after_seconds`, when present, is the number of seconds until the + /// rate limit resets (RFC 9110 `Retry-After` `delta-seconds`). Handlers + /// can use it to render a humanized reset time alongside the prompt. + /// + /// Default: [`AutoModeSwitchResponse::No`] — decline. Override only if + /// your application surfaces a UX for the rate-limit-recovery prompt. + async fn on_auto_mode_switch( + &self, + _session_id: SessionId, + _error_code: Option<String>, + _retry_after_seconds: Option<u64>, + ) -> AutoModeSwitchResponse { + AutoModeSwitchResponse::No + } +} + +/// A [`SessionHandler`] that auto-approves all permissions and ignores all events. +/// +/// Useful for CLI tools, scripts, and tests that don't need interactive +/// permission prompts or custom tool handling. +#[derive(Debug, Clone)] +pub struct ApproveAllHandler; + +#[async_trait] +impl SessionHandler for ApproveAllHandler { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Approved + } +} + +/// A [`SessionHandler`] that denies all permission requests and otherwise +/// relies on the trait's default fallback responses for every other event +/// (e.g. tool invocations return "unhandled", elicitations cancel, plan-mode +/// prompts decline). This is the safe default used when no handler is set on +/// [`SessionConfig::handler`](crate::types::SessionConfig::handler) — sessions +/// will not stall on permission prompts (they're denied immediately) but no +/// privileged actions will be taken without an explicit opt-in. 
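The `retry_after_seconds` hint passed to `on_auto_mode_switch` is a plain RFC 9110 `Retry-After` delta-seconds count. A std-only sketch of the "humanized reset time" rendering the docs suggest (the function name and output format are illustrative, not SDK API):

```rust
// Illustrative helper (not part of the SDK): format a delta-seconds
// rate-limit hint as a short human-readable duration for a prompt.
fn humanize_retry_after(seconds: Option<u64>) -> String {
    match seconds {
        // The CLI may omit the hint entirely.
        None => "at an unknown time".to_string(),
        Some(s) if s < 60 => format!("in {s}s"),
        Some(s) if s < 3600 => format!("in {}m", s / 60),
        Some(s) => format!("in {}h {}m", s / 3600, (s % 3600) / 60),
    }
}

fn main() {
    assert_eq!(humanize_retry_after(Some(45)), "in 45s");
    assert_eq!(humanize_retry_after(Some(7500)), "in 2h 5m");
    assert_eq!(humanize_retry_after(None), "at an unknown time");
    println!("rate limit resets {}", humanize_retry_after(Some(7500)));
}
```

A handler that surfaces the auto-mode-switch prompt could interpolate this string into the question it shows the user before returning `Yes`, `YesAlways`, or `No`.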
+#[derive(Debug, Clone)] +pub struct DenyAllHandler; + +#[async_trait] +impl SessionHandler for DenyAllHandler { + // All defaults are already safe: permissions deny, everything else is a + // sensible fallback. We just reuse them here for clarity. +} + +#[cfg(test)] +mod tests { + use serde_json::Value; + + use super::*; + use crate::types::{PermissionRequestData, RequestId, SessionId}; + + fn perm_data() -> PermissionRequestData { + PermissionRequestData::default() + } + + // A handler that overrides only `on_permission_request` (per-method style). + struct ApproveViaPerMethod; + + #[async_trait] + impl SessionHandler for ApproveViaPerMethod { + async fn on_permission_request( + &self, + _: SessionId, + _: RequestId, + _: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Approved + } + } + + // A handler that overrides `on_event` directly (legacy / routing style). + struct ApproveViaOnEvent; + + #[async_trait] + impl SessionHandler for ApproveViaOnEvent { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::PermissionRequest { .. 
} => { + HandlerResponse::Permission(PermissionResult::Approved) + } + _ => HandlerResponse::Ok, + } + } + } + + #[tokio::test] + async fn per_method_override_dispatches_via_default_on_event() { + let h = ApproveViaPerMethod; + let resp = h + .on_event(HandlerEvent::PermissionRequest { + session_id: SessionId::from("s1".to_string()), + request_id: RequestId::new("r1"), + data: perm_data(), + }) + .await; + assert!(matches!( + resp, + HandlerResponse::Permission(PermissionResult::Approved) + )); + } + + #[tokio::test] + async fn on_event_override_short_circuits_per_method_defaults() { + let h = ApproveViaOnEvent; + let resp = h + .on_event(HandlerEvent::PermissionRequest { + session_id: SessionId::from("s1".to_string()), + request_id: RequestId::new("r1"), + data: perm_data(), + }) + .await; + assert!(matches!( + resp, + HandlerResponse::Permission(PermissionResult::Approved) + )); + } + + #[tokio::test] + async fn deny_all_handler_uses_default_permission_deny() { + let h = DenyAllHandler; + let resp = h + .on_event(HandlerEvent::PermissionRequest { + session_id: SessionId::from("s1".to_string()), + request_id: RequestId::new("r1"), + data: perm_data(), + }) + .await; + assert!(matches!( + resp, + HandlerResponse::Permission(PermissionResult::Denied) + )); + } + + #[tokio::test] + async fn default_on_external_tool_returns_failure() { + let h = DenyAllHandler; + let resp = h + .on_event(HandlerEvent::ExternalTool { + invocation: crate::types::ToolInvocation { + session_id: SessionId::from("s1".to_string()), + tool_call_id: "tc1".to_string(), + tool_name: "missing".to_string(), + arguments: Value::Null, + traceparent: None, + tracestate: None, + }, + }) + .await; + match resp { + HandlerResponse::ToolResult(crate::types::ToolResult::Expanded(exp)) => { + assert_eq!(exp.result_type, "failure"); + assert!(exp.text_result_for_llm.contains("missing")); + assert_eq!(exp.error.as_deref(), Some(exp.text_result_for_llm.as_str())); + } + other => panic!("unexpected response: 
{other:?}"), + } + } + + #[tokio::test] + async fn default_on_elicitation_returns_cancel() { + let h = DenyAllHandler; + let resp = h + .on_event(HandlerEvent::ElicitationRequest { + session_id: SessionId::from("s1".to_string()), + request_id: RequestId::new("r1"), + request: crate::types::ElicitationRequest { + message: "test".to_string(), + requested_schema: None, + mode: Some(crate::types::ElicitationMode::Form), + elicitation_source: None, + url: None, + }, + }) + .await; + match resp { + HandlerResponse::Elicitation(r) => assert_eq!(r.action, "cancel"), + other => panic!("unexpected response: {other:?}"), + } + } +} diff --git a/rust/src/hooks.rs b/rust/src/hooks.rs new file mode 100644 index 000000000..ca755c6f9 --- /dev/null +++ b/rust/src/hooks.rs @@ -0,0 +1,715 @@ +//! Lifecycle hook callbacks invoked at key session points. +//! +//! Hooks let you intercept and modify CLI behavior — approve or deny tool +//! use, rewrite user prompts, inject context at session start, and handle +//! errors. Implement [`SessionHooks`](crate::hooks::SessionHooks) and pass it to +//! [`Client::create_session`](crate::Client::create_session). + +use std::path::PathBuf; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use crate::types::SessionId; + +/// Context provided to every hook invocation. +#[derive(Debug, Clone)] +pub struct HookContext { + /// The session this hook was triggered in. + pub session_id: SessionId, +} + +/// Input for the `preToolUse` hook — received before a tool executes. +#[derive(Debug, Clone, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PreToolUseInput { + /// Unix timestamp (ms). + pub timestamp: i64, + /// Working directory. + pub cwd: PathBuf, + /// Name of the tool about to execute. + pub tool_name: String, + /// Arguments passed to the tool. + pub tool_args: Value, +} + +/// Output for the `preToolUse` hook. 
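The `preToolUse` output pairs a `permission_decision` ("allow" or "deny") with an optional human-readable reason. A std-only sketch of the decision logic a hook might run before filling those fields (the blocked-tool list is hypothetical, purely for demonstration):

```rust
// Illustrative policy (not part of the SDK): compute the
// ("allow"/"deny", reason) pair that would populate
// `permission_decision` / `permission_decision_reason`.
fn pre_tool_use_decision(tool_name: &str) -> (&'static str, String) {
    // Hypothetical local policy: block these tool names outright.
    const BLOCKED: [&str; 2] = ["shell", "write"];
    if BLOCKED.contains(&tool_name) {
        ("deny", format!("tool '{tool_name}' is blocked by local policy"))
    } else {
        ("allow", format!("tool '{tool_name}' permitted"))
    }
}

fn main() {
    let (decision, reason) = pre_tool_use_decision("shell");
    assert_eq!(decision, "deny");
    assert!(reason.contains("shell"));
    assert_eq!(pre_tool_use_decision("view").0, "allow");
}
```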
+#[derive(Debug, Clone, Default, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PreToolUseOutput {
+    /// "allow" or "deny".
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub permission_decision: Option<String>,
+    /// Reason for the decision (shown to the agent).
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub permission_decision_reason: Option<String>,
+    /// Replacement arguments for the tool.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub modified_args: Option<Value>,
+    /// Extra context injected into the agent's prompt.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub additional_context: Option<String>,
+    /// Suppress the hook's output from the session log.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub suppress_output: Option<bool>,
+}
+
+/// Input for the `postToolUse` hook — received after a tool executes.
+#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PostToolUseInput {
+    /// Unix timestamp (ms).
+    pub timestamp: i64,
+    /// Working directory.
+    pub cwd: PathBuf,
+    /// Name of the tool that executed.
+    pub tool_name: String,
+    /// Arguments that were passed to the tool.
+    pub tool_args: Value,
+    /// Result returned by the tool.
+    pub tool_result: Value,
+}
+
+/// Output for the `postToolUse` hook.
+#[derive(Debug, Clone, Default, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PostToolUseOutput {
+    /// Replacement result for the tool.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub modified_result: Option<Value>,
+    /// Extra context injected into the agent's prompt.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub additional_context: Option<String>,
+    /// Suppress the hook's output from the session log.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub suppress_output: Option<bool>,
+}
+
+/// Input for the `userPromptSubmitted` hook — received when the user sends a message.
+#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UserPromptSubmittedInput {
+    /// Unix timestamp (ms).
+    pub timestamp: i64,
+    /// Working directory.
+    pub cwd: PathBuf,
+    /// The user's message text.
+    pub prompt: String,
+}
+
+/// Output for the `userPromptSubmitted` hook.
+#[derive(Debug, Clone, Default, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UserPromptSubmittedOutput {
+    /// Replacement prompt text.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub modified_prompt: Option<String>,
+    /// Extra context injected into the agent's prompt.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub additional_context: Option<String>,
+    /// Suppress the hook's output from the session log.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub suppress_output: Option<bool>,
+}
+
+/// Input for the `sessionStart` hook.
+#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionStartInput {
+    /// Unix timestamp (ms).
+    pub timestamp: i64,
+    /// Working directory.
+    pub cwd: PathBuf,
+    /// How the session was started: `"startup"`, `"resume"`, or `"new"`.
+    pub source: String,
+    /// The first user message, if any.
+    #[serde(default)]
+    pub initial_prompt: Option<String>,
+}
+
+/// Output for the `sessionStart` hook.
+#[derive(Debug, Clone, Default, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionStartOutput {
+    /// Extra context injected at session start.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub additional_context: Option<String>,
+    /// Config overrides applied to the session.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub modified_config: Option<Value>,
+}
+
+/// Input for the `sessionEnd` hook.
+#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionEndInput {
+    /// Unix timestamp (ms).
+    pub timestamp: i64,
+    /// Working directory.
+    pub cwd: PathBuf,
+    /// Why the session ended: `"complete"`, `"error"`, `"abort"`, `"timeout"`, `"user_exit"`.
+    pub reason: String,
+    /// The last assistant message.
+    #[serde(default)]
+    pub final_message: Option<String>,
+    /// Error message, if the session ended due to an error.
+    #[serde(default)]
+    pub error: Option<String>,
+}
+
+/// Output for the `sessionEnd` hook.
+#[derive(Debug, Clone, Default, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionEndOutput {
+    /// Suppress the hook's output from the session log.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub suppress_output: Option<bool>,
+    /// Actions to run during cleanup.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cleanup_actions: Option<Vec<String>>,
+    /// Summary text for the session.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub session_summary: Option<String>,
+}
+
+/// Input for the `errorOccurred` hook.
+#[derive(Debug, Clone, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ErrorOccurredInput {
+    /// Unix timestamp (ms).
+    pub timestamp: i64,
+    /// Working directory.
+    pub cwd: PathBuf,
+    /// The error message.
+    pub error: String,
+    /// Context where the error occurred: `"model_call"`, `"tool_execution"`, `"system"`, `"user_input"`.
+    pub error_context: String,
+    /// Whether the error is recoverable.
+    pub recoverable: bool,
+}
+
+/// Output for the `errorOccurred` hook.
+#[derive(Debug, Clone, Default, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ErrorOccurredOutput {
+    /// Suppress the hook's output from the session log.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub suppress_output: Option<bool>,
+    /// How to handle the error: `"retry"`, `"skip"`, or `"abort"`.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error_handling: Option<String>,
+    /// Number of retries to attempt.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub retry_count: Option<u32>,
+    /// Message to show the user.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub user_notification: Option<String>,
+}
+
+/// Events dispatched to [`SessionHooks::on_hook`] at CLI lifecycle points.
+///
+/// Each variant carries the typed input for that hook plus the shared
+/// [`HookContext`]. The handler returns a matching [`HookOutput`] variant
+/// (or [`HookOutput::None`] to signal "no hook registered").
+#[non_exhaustive]
+#[derive(Debug)]
+pub enum HookEvent {
+    /// Fired before a tool executes.
+    PreToolUse {
+        /// Typed input data.
+        input: PreToolUseInput,
+        /// Session context.
+        ctx: HookContext,
+    },
+    /// Fired after a tool executes.
+    PostToolUse {
+        /// Typed input data.
+        input: PostToolUseInput,
+        /// Session context.
+        ctx: HookContext,
+    },
+    /// Fired when the user sends a message.
+    UserPromptSubmitted {
+        /// Typed input data.
+        input: UserPromptSubmittedInput,
+        /// Session context.
+        ctx: HookContext,
+    },
+    /// Fired at session creation or resume.
+    SessionStart {
+        /// Typed input data.
+        input: SessionStartInput,
+        /// Session context.
+        ctx: HookContext,
+    },
+    /// Fired when the session ends.
+    SessionEnd {
+        /// Typed input data.
+        input: SessionEndInput,
+        /// Session context.
+        ctx: HookContext,
+    },
+    /// Fired when an error occurs.
+    ErrorOccurred {
+        /// Typed input data.
+        input: ErrorOccurredInput,
+        /// Session context.
+        ctx: HookContext,
+    },
+}
+
+/// Response from [`SessionHooks::on_hook`] back to the SDK.
+///
+/// Return the variant matching the [`HookEvent`] you received, or
+/// [`HookOutput::None`] to indicate no hook is registered for that event.
+#[non_exhaustive]
+#[derive(Debug)]
+pub enum HookOutput {
+    /// No hook registered — the SDK returns an empty output object to the CLI.
+    None,
+    /// Response for a pre-tool-use hook.
+    PreToolUse(PreToolUseOutput),
+    /// Response for a post-tool-use hook.
+    PostToolUse(PostToolUseOutput),
+    /// Response for a user-prompt-submitted hook.
+    UserPromptSubmitted(UserPromptSubmittedOutput),
+    /// Response for a session-start hook.
+    SessionStart(SessionStartOutput),
+    /// Response for a session-end hook.
+    SessionEnd(SessionEndOutput),
+    /// Response for an error-occurred hook.
+    ErrorOccurred(ErrorOccurredOutput),
+}
+
+impl HookOutput {
+    fn variant_name(&self) -> &'static str {
+        match self {
+            Self::None => "None",
+            Self::PreToolUse(_) => "PreToolUse",
+            Self::PostToolUse(_) => "PostToolUse",
+            Self::UserPromptSubmitted(_) => "UserPromptSubmitted",
+            Self::SessionStart(_) => "SessionStart",
+            Self::SessionEnd(_) => "SessionEnd",
+            Self::ErrorOccurred(_) => "ErrorOccurred",
+        }
+    }
+}
+
+/// Callback trait for session hooks — invoked by the CLI at key lifecycle
+/// points (tool use, prompt submission, session start/end, errors).
+///
+/// Implement this trait to intercept and modify CLI behavior at hook points.
+/// There are two styles of implementation — pick whichever fits:
+///
+/// 1. **Per-hook methods (recommended).** Override the specific `on_*` hook
+///    methods you care about; every hook has a default that returns `None`
+///    (meaning "no hook registered, use CLI default behavior").
+/// 2. **Single [`on_hook`](Self::on_hook) method.** Override this one and
+///    `match` on [`HookEvent`] yourself — useful for logging middleware or
+///    shared dispatch logic.
+///
+/// Hooks only fire when hooks are enabled on the session (via
+/// [`SessionConfig::hooks = Some(true)`](crate::types::SessionConfig::hooks),
+/// which [`SessionConfig::with_hooks`](crate::types::SessionConfig::with_hooks)
+/// sets automatically).
+#[async_trait]
+pub trait SessionHooks: Send + Sync + 'static {
+    /// Top-level dispatch. The default implementation fans out to the
+    /// per-hook methods below; override this only if you want a single
+    /// matching point across all hook types.
+    async fn on_hook(&self, event: HookEvent) -> HookOutput {
+        match event {
+            HookEvent::PreToolUse { input, ctx } => self
+                .on_pre_tool_use(input, ctx)
+                .await
+                .map(HookOutput::PreToolUse)
+                .unwrap_or(HookOutput::None),
+            HookEvent::PostToolUse { input, ctx } => self
+                .on_post_tool_use(input, ctx)
+                .await
+                .map(HookOutput::PostToolUse)
+                .unwrap_or(HookOutput::None),
+            HookEvent::UserPromptSubmitted { input, ctx } => self
+                .on_user_prompt_submitted(input, ctx)
+                .await
+                .map(HookOutput::UserPromptSubmitted)
+                .unwrap_or(HookOutput::None),
+            HookEvent::SessionStart { input, ctx } => self
+                .on_session_start(input, ctx)
+                .await
+                .map(HookOutput::SessionStart)
+                .unwrap_or(HookOutput::None),
+            HookEvent::SessionEnd { input, ctx } => self
+                .on_session_end(input, ctx)
+                .await
+                .map(HookOutput::SessionEnd)
+                .unwrap_or(HookOutput::None),
+            HookEvent::ErrorOccurred { input, ctx } => self
+                .on_error_occurred(input, ctx)
+                .await
+                .map(HookOutput::ErrorOccurred)
+                .unwrap_or(HookOutput::None),
+        }
+    }
+
+    /// Called before a tool executes. Return `Some(output)` to approve/deny
+    /// or modify the call, or `None` (default) to pass through unchanged.
+    async fn on_pre_tool_use(
+        &self,
+        _input: PreToolUseInput,
+        _ctx: HookContext,
+    ) -> Option<PreToolUseOutput> {
+        None
+    }
+
+    /// Called after a tool executes. Return `Some(output)` to inject
+    /// additional context or signal post-processing decisions; `None`
+    /// (default) means no follow-up.
+    async fn on_post_tool_use(
+        &self,
+        _input: PostToolUseInput,
+        _ctx: HookContext,
+    ) -> Option<PostToolUseOutput> {
+        None
+    }
+
+    /// Called when the user submits a prompt. Return `Some(output)` to
+    /// rewrite the prompt or inject extra context; `None` (default) passes
+    /// through unchanged.
+    async fn on_user_prompt_submitted(
+        &self,
+        _input: UserPromptSubmittedInput,
+        _ctx: HookContext,
+    ) -> Option<UserPromptSubmittedOutput> {
+        None
+    }
+
+    /// Called at session creation or resume. Return `Some(output)` to
+    /// inject startup context.
+    async fn on_session_start(
+        &self,
+        _input: SessionStartInput,
+        _ctx: HookContext,
+    ) -> Option<SessionStartOutput> {
+        None
+    }
+
+    /// Called when the session ends. Return `Some(output)` if your hook
+    /// needs to signal cleanup behavior.
+    async fn on_session_end(
+        &self,
+        _input: SessionEndInput,
+        _ctx: HookContext,
+    ) -> Option<SessionEndOutput> {
+        None
+    }
+
+    /// Called when the CLI reports an error. Return `Some(output)` to
+    /// influence retry behavior or surface a user-facing notification.
+    async fn on_error_occurred(
+        &self,
+        _input: ErrorOccurredInput,
+        _ctx: HookContext,
+    ) -> Option<ErrorOccurredOutput> {
+        None
+    }
+}
+
+/// Dispatches a `hooks.invoke` request to [`SessionHooks::on_hook`].
+///
+/// Returns `Ok(Value)` shaped like `{ "output": ... }` on success.
+/// If no hook is registered ([`HookOutput::None`]), the output is an empty
+/// object: `{ "output": {} }`.
+pub(crate) async fn dispatch_hook(
+    hooks: &dyn SessionHooks,
+    session_id: &SessionId,
+    hook_type: &str,
+    raw_input: Value,
+) -> Result<Value, serde_json::Error> {
+    let ctx = HookContext {
+        session_id: session_id.clone(),
+    };
+
+    let event = match hook_type {
+        "preToolUse" => {
+            let input: PreToolUseInput = serde_json::from_value(raw_input)?;
+            HookEvent::PreToolUse { input, ctx }
+        }
+        "postToolUse" => {
+            let input: PostToolUseInput = serde_json::from_value(raw_input)?;
+            HookEvent::PostToolUse { input, ctx }
+        }
+        "userPromptSubmitted" => {
+            let input: UserPromptSubmittedInput = serde_json::from_value(raw_input)?;
+            HookEvent::UserPromptSubmitted { input, ctx }
+        }
+        "sessionStart" => {
+            let input: SessionStartInput = serde_json::from_value(raw_input)?;
+            HookEvent::SessionStart { input, ctx }
+        }
+        "sessionEnd" => {
+            let input: SessionEndInput = serde_json::from_value(raw_input)?;
+            HookEvent::SessionEnd { input, ctx }
+        }
+        "errorOccurred" => {
+            let input: ErrorOccurredInput = serde_json::from_value(raw_input)?;
+            HookEvent::ErrorOccurred { input, ctx }
+        }
+        _ => {
+            tracing::warn!(
+                hook_type = hook_type,
+                session_id = %session_id,
+                "unknown hook type"
+            );
+            return Ok(serde_json::json!({ "output": {} }));
+        }
+    };
+
+    let output = hooks.on_hook(event).await;
+
+    // Validate that the output variant matches the dispatched hook type.
+    // A mismatched return (e.g. HookOutput::SessionEnd for a preToolUse
+    // event) is treated as "no hook registered" to avoid sending the CLI
+    // a semantically wrong response.
+    let output_value = match (hook_type, &output) {
+        (_, HookOutput::None) => None,
+        ("preToolUse", HookOutput::PreToolUse(o)) => Some(serde_json::to_value(o)?),
+        ("postToolUse", HookOutput::PostToolUse(o)) => Some(serde_json::to_value(o)?),
+        ("userPromptSubmitted", HookOutput::UserPromptSubmitted(o)) => {
+            Some(serde_json::to_value(o)?)
+        }
+        ("sessionStart", HookOutput::SessionStart(o)) => Some(serde_json::to_value(o)?),
+        ("sessionEnd", HookOutput::SessionEnd(o)) => Some(serde_json::to_value(o)?),
+        ("errorOccurred", HookOutput::ErrorOccurred(o)) => Some(serde_json::to_value(o)?),
+        _ => {
+            tracing::warn!(
+                hook_type = hook_type,
+                session_id = %session_id,
+                output_variant = output.variant_name(),
+                "hook returned mismatched output variant, treating as unregistered"
+            );
+            None
+        }
+    };
+
+    Ok(serde_json::json!({ "output": output_value.unwrap_or(Value::Object(Default::default())) }))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    struct TestHooks;
+
+    #[async_trait]
+    impl SessionHooks for TestHooks {
+        async fn on_hook(&self, event: HookEvent) -> HookOutput {
+            match event {
+                HookEvent::PreToolUse { input, .. } => {
+                    if input.tool_name == "dangerous_tool" {
+                        HookOutput::PreToolUse(PreToolUseOutput {
+                            permission_decision: Some("deny".to_string()),
+                            permission_decision_reason: Some("blocked by policy".to_string()),
+                            ..Default::default()
+                        })
+                    } else {
+                        HookOutput::None
+                    }
+                }
+                HookEvent::UserPromptSubmitted { input, .. } => {
+                    HookOutput::UserPromptSubmitted(UserPromptSubmittedOutput {
+                        modified_prompt: Some(format!("[prefixed] {}", input.prompt)),
+                        ..Default::default()
+                    })
+                }
+                _ => HookOutput::None,
+            }
+        }
+    }
+
+    #[tokio::test]
+    async fn dispatch_pre_tool_use_deny() {
+        let hooks = TestHooks;
+        let input = serde_json::json!({
+            "timestamp": 1234567890,
+            "cwd": "/tmp",
+            "toolName": "dangerous_tool",
+            "toolArgs": {}
+        });
+        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
+            .await
+            .unwrap();
+        let output = &result["output"];
+        assert_eq!(output["permissionDecision"], "deny");
+        assert_eq!(output["permissionDecisionReason"], "blocked by policy");
+    }
+
+    #[tokio::test]
+    async fn dispatch_pre_tool_use_passthrough() {
+        let hooks = TestHooks;
+        let input = serde_json::json!({
+            "timestamp": 1234567890,
+            "cwd": "/tmp",
+            "toolName": "safe_tool",
+            "toolArgs": {"key": "value"}
+        });
+        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
+            .await
+            .unwrap();
+        // No hook registered for this tool — output should be empty object
+        assert_eq!(result["output"], serde_json::json!({}));
+    }
+
+    #[tokio::test]
+    async fn dispatch_user_prompt_submitted() {
+        let hooks = TestHooks;
+        let input = serde_json::json!({
+            "timestamp": 1234567890,
+            "cwd": "/tmp",
+            "prompt": "hello world"
+        });
+        let result = dispatch_hook(
+            &hooks,
+            &SessionId::new("sess-1"),
+            "userPromptSubmitted",
+            input,
+        )
+        .await
+        .unwrap();
+        assert_eq!(result["output"]["modifiedPrompt"], "[prefixed] hello world");
+    }
+
+    #[tokio::test]
+    async fn dispatch_unregistered_hook_returns_empty() {
+        let hooks = TestHooks;
+        let input = serde_json::json!({
+            "timestamp": 1234567890,
+            "cwd": "/tmp",
+            "reason": "complete"
+        });
+        // TestHooks doesn't handle SessionEnd
+        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "sessionEnd", input)
+            .await
+            .unwrap();
+        assert_eq!(result["output"], serde_json::json!({}));
+    }
+
+    #[tokio::test]
+    async fn dispatch_unknown_hook_type() {
+        let hooks = TestHooks;
+        let input = serde_json::json!({});
+        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "unknownHook", input)
+            .await
+            .unwrap();
+        assert_eq!(result["output"], serde_json::json!({}));
+    }
+
+    #[tokio::test]
+    async fn dispatch_mismatched_output_returns_empty() {
+        struct MismatchHooks;
+        #[async_trait]
+        impl SessionHooks for MismatchHooks {
+            async fn on_hook(&self, _event: HookEvent) -> HookOutput {
+                // Always return SessionEnd output regardless of event type
+                HookOutput::SessionEnd(SessionEndOutput {
+                    session_summary: Some("oops".to_string()),
+                    ..Default::default()
+                })
+            }
+        }
+
+        let hooks = MismatchHooks;
+        let input = serde_json::json!({
+            "timestamp": 1234567890,
+            "cwd": "/tmp",
+            "toolName": "some_tool",
+            "toolArgs": {}
+        });
+        // preToolUse event gets a SessionEnd output — should be treated as empty
+        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "preToolUse", input)
+            .await
+            .unwrap();
+        assert_eq!(result["output"], serde_json::json!({}));
+    }
+
+    #[tokio::test]
+    async fn dispatch_post_tool_use_default() {
+        let hooks = TestHooks;
+        let input = serde_json::json!({
+            "timestamp": 1234567890,
+            "cwd": "/tmp",
+            "toolName": "some_tool",
+            "toolArgs": {},
+            "toolResult": "success"
+        });
+        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "postToolUse", input)
+            .await
+            .unwrap();
+        assert_eq!(result["output"], serde_json::json!({}));
+    }
+
+    #[tokio::test]
+    async fn dispatch_session_start() {
+        struct StartHooks;
+        #[async_trait]
+        impl SessionHooks for StartHooks {
+            async fn on_hook(&self, event: HookEvent) -> HookOutput {
+                match event {
+                    HookEvent::SessionStart { .. } => {
+                        HookOutput::SessionStart(SessionStartOutput {
+                            additional_context: Some("extra context".to_string()),
+                            ..Default::default()
+                        })
+                    }
+                    _ => HookOutput::None,
+                }
+            }
+        }
+
+        let hooks = StartHooks;
+        let input = serde_json::json!({
+            "timestamp": 1234567890,
+            "cwd": "/tmp",
+            "source": "new"
+        });
+        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "sessionStart", input)
+            .await
+            .unwrap();
+        assert_eq!(result["output"]["additionalContext"], "extra context");
+    }
+
+    #[tokio::test]
+    async fn dispatch_error_occurred() {
+        struct ErrorHooks;
+        #[async_trait]
+        impl SessionHooks for ErrorHooks {
+            async fn on_hook(&self, event: HookEvent) -> HookOutput {
+                match event {
+                    HookEvent::ErrorOccurred { .. } => {
+                        HookOutput::ErrorOccurred(ErrorOccurredOutput {
+                            error_handling: Some("retry".to_string()),
+                            retry_count: Some(3),
+                            ..Default::default()
+                        })
+                    }
+                    _ => HookOutput::None,
+                }
+            }
+        }
+
+        let hooks = ErrorHooks;
+        let input = serde_json::json!({
+            "timestamp": 1234567890,
+            "cwd": "/tmp",
+            "error": "model timeout",
+            "errorContext": "model_call",
+            "recoverable": true
+        });
+        let result = dispatch_hook(&hooks, &SessionId::new("sess-1"), "errorOccurred", input)
+            .await
+            .unwrap();
+        assert_eq!(result["output"]["errorHandling"], "retry");
+        assert_eq!(result["output"]["retryCount"], 3);
+    }
+}
diff --git a/rust/src/jsonrpc.rs b/rust/src/jsonrpc.rs
new file mode 100644
index 000000000..5f6d95612
--- /dev/null
+++ b/rust/src/jsonrpc.rs
@@ -0,0 +1,549 @@
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicU64, Ordering};
+
+use parking_lot::RwLock;
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader};
+use tokio::sync::{broadcast, mpsc, oneshot};
+use tracing::{Instrument, error, warn};
+
+use crate::{Error, ProtocolError};
+
+/// A JSON-RPC 2.0 request message.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct JsonRpcRequest {
+    /// Protocol version (always `"2.0"`).
+    pub jsonrpc: String,
+    /// Request ID for correlating responses.
+    pub id: u64,
+    /// RPC method name.
+    pub method: String,
+    /// Optional method parameters.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub params: Option<Value>,
+}
+
+/// A JSON-RPC 2.0 response message.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct JsonRpcResponse {
+    /// Protocol version (always `"2.0"`).
+    pub jsonrpc: String,
+    /// Request ID this response correlates to.
+    pub id: u64,
+    /// Success payload (mutually exclusive with `error`).
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub result: Option<Value>,
+    /// Error payload (mutually exclusive with `result`).
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<JsonRpcError>,
+}
+
+/// A JSON-RPC 2.0 error object.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct JsonRpcError {
+    /// Numeric error code.
+    pub code: i32,
+    /// Human-readable error description.
+    pub message: String,
+    /// Optional structured error data.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub data: Option<Value>,
+}
+
+/// Standard JSON-RPC 2.0 error codes.
+pub mod error_codes {
+    /// Method not found (-32601).
+    pub const METHOD_NOT_FOUND: i32 = -32601;
+    /// Invalid method parameters (-32602).
+    pub const INVALID_PARAMS: i32 = -32602;
+    /// Internal server error (-32603).
+    #[allow(dead_code, reason = "standard JSON-RPC code, reserved for future use")]
+    pub const INTERNAL_ERROR: i32 = -32603;
+}
+
+/// A JSON-RPC 2.0 notification (no `id`, no response expected).
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct JsonRpcNotification {
+    /// Protocol version (always `"2.0"`).
+    pub jsonrpc: String,
+    /// Notification method name.
+    pub method: String,
+    /// Optional notification parameters.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub params: Option<Value>,
+}
+
+/// A parsed JSON-RPC 2.0 message — request, response, or notification.
+#[derive(Debug, Clone, Serialize)]
+pub enum JsonRpcMessage {
+    /// An incoming or outgoing request.
+    Request(JsonRpcRequest),
+    /// A response to a previous request.
+    Response(JsonRpcResponse),
+    /// A fire-and-forget notification.
+    Notification(JsonRpcNotification),
+}
+
+/// Custom deserializer that dispatches based on field presence instead of
+/// `#[serde(untagged)]`, which tries each variant sequentially (3× parse
+/// attempts for Notification — the hot-path streaming variant).
+///
+/// Dispatch logic:
+/// - has `id` + has `method` → Request
+/// - has `id` + no `method` → Response
+/// - no `id` → Notification
+impl<'de> Deserialize<'de> for JsonRpcMessage {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: serde::Deserializer<'de>,
+    {
+        let value = Value::deserialize(deserializer)?;
+        let obj = value
+            .as_object()
+            .ok_or_else(|| serde::de::Error::custom("expected a JSON object"))?;
+
+        let has_id = obj.contains_key("id");
+        let has_method = obj.contains_key("method");
+
+        if has_id && has_method {
+            JsonRpcRequest::deserialize(value)
+                .map(JsonRpcMessage::Request)
+                .map_err(serde::de::Error::custom)
+        } else if has_id {
+            JsonRpcResponse::deserialize(value)
+                .map(JsonRpcMessage::Response)
+                .map_err(serde::de::Error::custom)
+        } else {
+            JsonRpcNotification::deserialize(value)
+                .map(JsonRpcMessage::Notification)
+                .map_err(serde::de::Error::custom)
+        }
+    }
+}
+
+impl JsonRpcRequest {
+    /// Create a new JSON-RPC request with the given ID, method, and params.
+    pub fn new(id: u64, method: &str, params: Option<Value>) -> Self {
+        Self {
+            jsonrpc: "2.0".to_string(),
+            id,
+            method: method.to_string(),
+            params,
+        }
+    }
+}
+
+impl JsonRpcResponse {
+    /// Returns `true` if this response contains an error.
+    #[allow(dead_code)]
+    pub fn is_error(&self) -> bool {
+        self.error.is_some()
+    }
+}
+
+const CONTENT_LENGTH_HEADER: &str = "Content-Length: ";
+
+/// One framed JSON-RPC message handed to the writer actor.
+///
+/// `frame` is the fully serialized bytes (header + body); the caller pays
+/// the serde cost synchronously before enqueueing, so the actor never sees a
+/// `Result` from JSON encoding. `ack` resolves once the bytes have been
+/// fully written and flushed (or the underlying I/O reports an error). If
+/// the caller drops the `oneshot::Receiver`, the actor still completes the
+/// frame — caller cancellation cannot desync the wire.
+struct WriteCommand {
+    frame: Vec<u8>,
+    ack: oneshot::Sender<Result<(), std::io::Error>>,
+}
+
+/// Low-level JSON-RPC 2.0 client over Content-Length-framed streams.
+///
+/// # Cancel safety
+///
+/// All public methods (`write`, `send_request`) are **cancel-safe**: the
+/// actual bytes hit the wire on a dedicated background actor task, so
+/// dropping the caller's future after `await` returns `Pending` cannot
+/// produce a partial frame on the wire. Frames either land atomically or
+/// the underlying I/O fails. See the `cancel-safety review` artifact for the
+/// full RFD-400 reasoning.
+pub struct JsonRpcClient {
+    request_id: AtomicU64,
+    /// Sender side of the writer actor's command queue. Public methods
+    /// pre-serialize their frames and enqueue here; the background actor
+    /// drains the queue and serializes writes onto the underlying
+    /// `AsyncWrite`. Unbounded by design — RFD 400 explicitly permits this
+    /// for cancel-safety, and JSON-RPC frames are small relative to the
+    /// natural request/response back-pressure of the wire.
+    write_tx: mpsc::UnboundedSender<WriteCommand>,
+    pending_requests: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
+    notification_tx: broadcast::Sender<JsonRpcNotification>,
+    request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
+}
+
+impl JsonRpcClient {
+    /// Create a new client from async read/write streams.
+    ///
+    /// Spawns two background tasks: a reader that dispatches incoming
+    /// messages to pending request channels, the notification broadcast,
+    /// or the request-forwarding channel; and a writer actor that owns the
+    /// underlying `AsyncWrite` and serializes frames atomically.
+    pub fn new(
+        writer: impl AsyncWrite + Unpin + Send + 'static,
+        reader: impl AsyncRead + Unpin + Send + 'static,
+        notification_tx: broadcast::Sender<JsonRpcNotification>,
+        request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
+    ) -> Self {
+        let (write_tx, write_rx) = mpsc::unbounded_channel::<WriteCommand>();
+
+        let writer_span = tracing::error_span!("jsonrpc_write_loop");
+        tokio::spawn(Self::write_loop(writer, write_rx).instrument(writer_span));
+
+        let client = Self {
+            request_id: AtomicU64::new(1),
+            write_tx,
+            pending_requests: Arc::new(RwLock::new(HashMap::new())),
+            notification_tx,
+            request_tx,
+        };
+
+        let pending_requests = client.pending_requests.clone();
+        let notification_tx_clone = client.notification_tx.clone();
+        let request_tx_clone = client.request_tx.clone();
+        let reader_span = tracing::error_span!("jsonrpc_read_loop");
+
+        tokio::spawn(
+            async move {
+                Self::read_loop(
+                    reader,
+                    pending_requests,
+                    notification_tx_clone,
+                    request_tx_clone,
+                )
+                .await;
+            }
+            .instrument(reader_span),
+        );
+
+        client
+    }
+
+    /// Writer-actor task. Owns the `AsyncWrite`, drains the command queue,
+    /// and writes each frame atomically (header + body + flush) before
+    /// signaling the ack.
+    ///
+    /// Caller-side cancellation cannot interrupt a write in progress:
+    /// dropping the ack `oneshot::Receiver` does not cancel the in-flight
+    /// I/O. Once a `WriteCommand` is enqueued, the frame is committed to land
+    /// on the wire (or surface an `io::Error` to the ack receiver if the
+    /// transport is broken).
+    ///
+    /// Exits cleanly when all senders drop (channel closes), flushing any
+    /// final buffered bytes.
+    async fn write_loop(
+        mut writer: impl AsyncWrite + Unpin + Send + 'static,
+        mut rx: mpsc::UnboundedReceiver<WriteCommand>,
+    ) {
+        while let Some(WriteCommand { frame, ack }) = rx.recv().await {
+            let result = async {
+                writer.write_all(&frame).await?;
+                writer.flush().await?;
+                Ok::<_, std::io::Error>(())
+            }
+            .await;
+
+            // Caller may have dropped the ack receiver (e.g. their
+            // `await` was cancelled); that's fine — we still completed
+            // the write, which was the whole point.
+            let _ = ack.send(result);
+        }
+    }
+
+    async fn read_loop(
+        reader: impl AsyncRead + Unpin + Send,
+        pending_requests: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
+        notification_tx: broadcast::Sender<JsonRpcNotification>,
+        request_tx: mpsc::UnboundedSender<JsonRpcRequest>,
+    ) {
+        let mut reader = BufReader::new(reader);
+
+        loop {
+            match Self::read_message(&mut reader).await {
+                Ok(Some(message)) => match message {
+                    JsonRpcMessage::Response(response) => {
+                        let id = response.id;
+                        let tx = pending_requests.write().remove(&id);
+                        if let Some(tx) = tx {
+                            if tx.send(response).is_err() {
+                                warn!(request_id = %id, "failed to send response for request");
+                            }
+                        } else {
+                            warn!(request_id = %id, "received response for unknown request id");
+                        }
+                    }
+                    JsonRpcMessage::Notification(notification) => {
+                        let _ = notification_tx.send(notification);
+                    }
+                    JsonRpcMessage::Request(request) => {
+                        if request_tx.send(request).is_err() {
+                            warn!("failed to forward JSON-RPC request, channel closed");
+                        }
+                    }
+                },
+                Ok(None) => {
+                    break;
+                }
+                Err(e) => {
+                    error!(error = %e, "error reading from CLI");
+                    break;
+                }
+            }
+        }
+
+        // Drain in-flight requests so callers observe cancellation
+        // instead of hanging on a oneshot receiver.
+ let mut pending = pending_requests.write(); + if !pending.is_empty() { + warn!( + count = pending.len(), + "draining pending requests after read loop exit" + ); + pending.clear(); + } + } + + async fn read_message( + reader: &mut BufReader, + ) -> Result, Error> { + let mut line = String::new(); + let mut content_length = None; + + loop { + line.clear(); + if reader.read_line(&mut line).await? == 0 { + return Ok(None); + } + + let trimmed = line.trim(); + if trimmed.is_empty() { + break; + } + + if let Some(value) = trimmed.strip_prefix(CONTENT_LENGTH_HEADER) { + content_length = Some(value.trim().parse::().map_err(|_| { + Error::Protocol(ProtocolError::InvalidContentLength( + value.trim().to_string(), + )) + })?); + } + } + + let Some(length) = content_length else { + return Err(Error::Protocol(ProtocolError::MissingContentLength)); + }; + + let mut body = vec![0u8; length]; + reader.read_exact(&mut body).await?; + + let message: JsonRpcMessage = serde_json::from_slice(&body)?; + Ok(Some(message)) + } + + /// Send a JSON-RPC request and wait for the matching response. + /// + /// # Cancel safety + /// + /// **Cancel-safe.** The frame is committed to the wire via the writer + /// actor before this future yields; cancelling the await drops the + /// response oneshot but does not desync the transport. The pending- + /// requests map is cleaned up automatically (the `PendingGuard` drop + /// removes the entry, and the read loop's response handling tolerates + /// a missing entry). + pub async fn send_request( + &self, + method: &str, + params: Option, + ) -> Result { + let id = self.request_id.fetch_add(1, Ordering::SeqCst); + let request = JsonRpcRequest::new(id, method, params); + + let (tx, rx) = oneshot::channel(); + self.pending_requests.write().insert(id, tx); + + // RAII guard that removes the pending entry if this future is + // dropped before the response arrives. 
Disarmed below before the
+        // success return so the read loop owns the cleanup on the happy
+        // path.
+        let mut guard = PendingGuard {
+            map: &self.pending_requests,
+            id,
+            armed: true,
+        };
+
+        // PendingGuard's drop removes the entry on every error path and
+        // on cancellation.
+        self.write(&request).await?;
+
+        let response = rx
+            .await
+            .map_err(|_| Error::Protocol(ProtocolError::RequestCancelled))?;
+        guard.disarm();
+        Ok(response)
+    }
+
+    /// Write a Content-Length-framed JSON-RPC message to the transport.
+    ///
+    /// # Cancel safety
+    ///
+    /// **Cancel-safe.** Pre-serializes the body, enqueues it on the writer
+    /// actor's command channel, and awaits an ack. Caller cancellation
+    /// drops the ack receiver; the actor still completes the frame and
+    /// flushes. A partial frame can never appear on the wire.
+    pub async fn write<T: Serialize>(&self, message: &T) -> Result<(), Error> {
+        let body = serde_json::to_vec(message)?;
+        let mut frame = Vec::with_capacity(CONTENT_LENGTH_HEADER.len() + 16 + body.len() + 4);
+        frame.extend_from_slice(CONTENT_LENGTH_HEADER.as_bytes());
+        frame.extend_from_slice(body.len().to_string().as_bytes());
+        frame.extend_from_slice(b"\r\n\r\n");
+        frame.extend_from_slice(&body);
+
+        let (ack_tx, ack_rx) = oneshot::channel();
+        self.write_tx
+            .send(WriteCommand { frame, ack: ack_tx })
+            .map_err(|_| {
+                Error::Io(std::io::Error::new(
+                    std::io::ErrorKind::BrokenPipe,
+                    "writer actor has shut down",
+                ))
+            })?;
+
+        match ack_rx.await {
+            Ok(Ok(())) => Ok(()),
+            Ok(Err(e)) => Err(Error::Io(e)),
+            Err(_) => Err(Error::Io(std::io::Error::new(
+                std::io::ErrorKind::BrokenPipe,
+                "writer actor dropped ack without responding",
+            ))),
+        }
+    }
+}
+
+/// RAII guard that removes a pending-request entry from the map if the
+/// owning future is dropped before the response arrives.
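The disarm-able guard pattern used here can be sketched standalone with a `Mutex<HashMap>`. The `Guard` type below is a hypothetical miniature, not the SDK's `PendingGuard`:

```rust
use std::collections::HashMap;
use std::sync::Mutex;

// Removes its map entry on drop unless the happy path disarmed it first.
struct Guard<'a> {
    map: &'a Mutex<HashMap<u64, &'static str>>,
    id: u64,
    armed: bool,
}

impl Guard<'_> {
    fn disarm(&mut self) {
        self.armed = false;
    }
}

impl Drop for Guard<'_> {
    fn drop(&mut self) {
        if self.armed {
            self.map.lock().unwrap().remove(&self.id);
        }
    }
}

fn main() {
    let map: Mutex<HashMap<u64, &str>> =
        Mutex::new(HashMap::from([(1, "pending"), (2, "pending")]));

    // Dropped while armed (think: a cancelled future) — entry 1 is cleaned up.
    drop(Guard { map: &map, id: 1, armed: true });
    assert!(!map.lock().unwrap().contains_key(&1));

    // Disarmed (think: the response arrived) — entry 2 survives.
    let mut g = Guard { map: &map, id: 2, armed: true };
    g.disarm();
    drop(g);
    assert!(map.lock().unwrap().contains_key(&2));
}
```

Because `Drop` runs on every exit path, including future cancellation, the cleanup cannot be forgotten; `disarm` is the one explicit opt-out on success.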
Disarmed on the
+/// happy path so the read loop's response handling owns the cleanup.
+struct PendingGuard<'a> {
+    map: &'a RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>,
+    id: u64,
+    armed: bool,
+}
+
+impl PendingGuard<'_> {
+    fn disarm(&mut self) {
+        self.armed = false;
+    }
+}
+
+impl Drop for PendingGuard<'_> {
+    fn drop(&mut self) {
+        if self.armed {
+            self.map.write().remove(&self.id);
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn deserialize_notification() {
+        let json = r#"{"jsonrpc":"2.0","method":"session.event","params":{"id":"e1"}}"#;
+        let msg: JsonRpcMessage = serde_json::from_str(json).unwrap();
+        assert!(matches!(msg, JsonRpcMessage::Notification(n) if n.method == "session.event"));
+    }
+
+    #[test]
+    fn deserialize_request() {
+        let json =
+            r#"{"jsonrpc":"2.0","id":5,"method":"permission.request","params":{"kind":"shell"}}"#;
+        let msg: JsonRpcMessage = serde_json::from_str(json).unwrap();
+        assert!(
+            matches!(msg, JsonRpcMessage::Request(r) if r.id == 5 && r.method == "permission.request")
+        );
+    }
+
+    #[test]
+    fn deserialize_response_with_result() {
+        let json = r#"{"jsonrpc":"2.0","id":3,"result":{"ok":true}}"#;
+        let msg: JsonRpcMessage = serde_json::from_str(json).unwrap();
+        assert!(matches!(msg, JsonRpcMessage::Response(r) if r.id == 3 && !r.is_error()));
+    }
+
+    #[test]
+    fn deserialize_error_response() {
+        let json =
+            r#"{"jsonrpc":"2.0","id":7,"error":{"code":-32600,"message":"Invalid Request"}}"#;
+        let msg: JsonRpcMessage = serde_json::from_str(json).unwrap();
+        match msg {
+            JsonRpcMessage::Response(r) => {
+                assert!(r.is_error());
+                let err = r.error.unwrap();
+                assert_eq!(err.code, -32600);
+                assert_eq!(err.message, "Invalid Request");
+            }
+            other => panic!("expected Response, got {other:?}"),
+        }
+    }
+
+    #[test]
+    fn deserialize_rejects_non_object() {
+        let result = serde_json::from_str::<JsonRpcMessage>(r#""not an object""#);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn request_new_sets_version() {
+        let req = JsonRpcRequest::new(42, "test.method",
None);
+        assert_eq!(req.jsonrpc, "2.0");
+        assert_eq!(req.id, 42);
+        assert_eq!(req.method, "test.method");
+        assert!(req.params.is_none());
+    }
+
+    #[test]
+    fn request_serializes_camel_case() {
+        let req = JsonRpcRequest::new(1, "ping", Some(serde_json::json!({})));
+        let json = serde_json::to_string(&req).unwrap();
+        assert!(json.contains(r#""jsonrpc":"2.0""#));
+        assert!(json.contains(r#""id":1"#));
+        assert!(json.contains(r#""method":"ping""#));
+    }
+
+    #[test]
+    fn notification_without_params_omits_field() {
+        let n = JsonRpcNotification {
+            jsonrpc: "2.0".into(),
+            method: "ping".into(),
+            params: None,
+        };
+        let json = serde_json::to_string(&n).unwrap();
+        assert!(!json.contains("params"));
+    }
+
+    #[test]
+    fn response_without_error_omits_field() {
+        let r = JsonRpcResponse {
+            jsonrpc: "2.0".into(),
+            id: 1,
+            result: Some(serde_json::json!(true)),
+            error: None,
+        };
+        let json = serde_json::to_string(&r).unwrap();
+        assert!(!json.contains("error"));
+    }
+}
diff --git a/rust/src/lib.rs b/rust/src/lib.rs
new file mode 100644
index 000000000..fa03ebe87
--- /dev/null
+++ b/rust/src/lib.rs
@@ -0,0 +1,2198 @@
+#![doc = include_str!("../README.md")]
+#![warn(missing_docs)]
+#![deny(rustdoc::broken_intra_doc_links)]
+#![cfg_attr(test, allow(clippy::unwrap_used))]
+
+/// Bundled CLI binary extraction and caching.
+pub mod embeddedcli;
+/// Event handler traits for session lifecycle.
+pub mod handler;
+/// Lifecycle hook callbacks (pre/post tool use, prompt submission, session start/end).
+pub mod hooks;
+mod jsonrpc;
+/// Permission-policy helpers that wrap an existing [`handler::SessionHandler`].
+pub mod permission;
+/// GitHub Copilot CLI binary resolution (env var, embedded, PATH search).
+pub mod resolve;
+mod router;
+/// Session management — create, resume, send messages, and interact with the agent.
+pub mod session;
+/// Custom session filesystem provider (virtualizable filesystem layer).
+pub mod session_fs;
+mod session_fs_dispatch;
+/// Event subscription handles returned by `subscribe()` methods.
+pub mod subscription;
+/// Typed tool definition framework and dispatch router.
+pub mod tool;
+/// W3C Trace Context propagation for distributed tracing.
+pub mod trace_context;
+/// System message transform callbacks for customizing agent prompts.
+pub mod transforms;
+/// Protocol types shared between the SDK and the GitHub Copilot CLI.
+pub mod types;
+
+/// Auto-generated protocol types from Copilot JSON Schemas.
+pub mod generated;
+
+use std::ffi::OsString;
+use std::path::{Path, PathBuf};
+use std::process::Stdio;
+use std::sync::{Arc, OnceLock};
+
+use async_trait::async_trait;
+// JSON-RPC wire types are internal transport details (like Go SDK's internal/jsonrpc2/).
+// External callers interact via Client/Session methods, not raw RPC.
+pub(crate) use jsonrpc::{
+    JsonRpcClient, JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, error_codes,
+};
+
+/// Re-exported JSON-RPC internals for integration tests (requires `test-support` feature).
+#[cfg(feature = "test-support")]
+pub mod test_support {
+    pub use crate::jsonrpc::{
+        JsonRpcClient, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse,
+        error_codes,
+    };
+}
+use serde::{Deserialize, Serialize};
+use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, BufReader};
+use tokio::net::TcpStream;
+use tokio::process::{Child, Command};
+use tokio::sync::{broadcast, mpsc, oneshot};
+use tracing::{Instrument, debug, error, info, warn};
+pub use types::*;
+
+mod sdk_protocol_version;
+pub use sdk_protocol_version::{SDK_PROTOCOL_VERSION, get_sdk_protocol_version};
+pub use subscription::{EventSubscription, Lagged, LifecycleSubscription, RecvError};
+
+/// Minimum protocol version this SDK can communicate with.
+const MIN_PROTOCOL_VERSION: u32 = 2;
+
+/// Errors returned by the SDK.
+#[derive(Debug, thiserror::Error)]
+#[non_exhaustive]
+pub enum Error {
+    /// JSON-RPC transport or protocol violation.
+    #[error("protocol error: {0}")]
+    Protocol(ProtocolError),
+
+    /// The CLI returned a JSON-RPC error response.
+    #[error("RPC error {code}: {message}")]
+    Rpc {
+        /// JSON-RPC error code.
+        code: i32,
+        /// Human-readable error message.
+        message: String,
+    },
+
+    /// Session-scoped error (not found, agent error, timeout, etc.).
+    #[error("session error: {0}")]
+    Session(SessionError),
+
+    /// I/O error on the stdio transport or during process spawn.
+    #[error(transparent)]
+    Io(#[from] std::io::Error),
+
+    /// Failed to serialize or deserialize a JSON-RPC message.
+    #[error(transparent)]
+    Json(#[from] serde_json::Error),
+
+    /// A required binary was not found on the system.
+    #[error("binary not found: {name} ({hint})")]
+    BinaryNotFound {
+        /// Binary name that was searched for.
+        name: &'static str,
+        /// Guidance on how to install or configure the binary.
+        hint: &'static str,
+    },
+}
+
+impl Error {
+    /// Returns true if this error indicates the transport is broken — the CLI
+    /// process exited, the connection was lost, or an I/O failure occurred.
+    /// Callers should discard the client and create a fresh one.
+    pub fn is_transport_failure(&self) -> bool {
+        matches!(
+            self,
+            Error::Protocol(ProtocolError::RequestCancelled) | Error::Io(_)
+        )
+    }
+}
+
+/// Aggregate of errors collected during [`Client::stop`].
+///
+/// `Client::stop` performs cooperative shutdown across every active
+/// session before killing the CLI child process. Errors from any
+/// per-session `session.destroy` RPC and from the terminal child-kill
+/// step are collected here rather than short-circuiting on the first
+/// failure, so callers see the full picture of what went wrong during
+/// teardown.
+///
+/// Implements [`std::error::Error`] and forwards to `Display` for the
+/// first error, with a count suffix when there are more.
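The Display behavior described above (the first error's message, plus a count suffix when more were collected) can be mirrored in a standalone sketch. `Aggregate` here is a hypothetical stand-in for `StopErrors`, using `String` in place of the SDK error type:

```rust
use std::fmt;

// Hypothetical miniature of the StopErrors Display rules.
struct Aggregate(Vec<String>);

impl fmt::Display for Aggregate {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self.0.as_slice() {
            [] => write!(f, "stop completed with no errors"),
            [only] => write!(f, "stop failed: {only}"),
            [first, rest @ ..] => write!(
                f,
                "stop failed with {n} errors; first: {first}",
                n = 1 + rest.len(),
            ),
        }
    }
}

fn main() {
    assert_eq!(Aggregate(vec![]).to_string(), "stop completed with no errors");
    assert_eq!(Aggregate(vec!["io".into()]).to_string(), "stop failed: io");
    assert_eq!(
        Aggregate(vec!["io".into(), "rpc".into()]).to_string(),
        "stop failed with 2 errors; first: io"
    );
}
```

The slice-pattern match keeps the three cases (none, one, many) explicit, which is why the real type can promise a stable `Display` shape for logging.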
+#[derive(Debug)]
+pub struct StopErrors(Vec<Error>);
+
+impl StopErrors {
+    /// Borrow the collected errors as a slice, in the order they
+    /// occurred (per-session destroys first, then child-kill last).
+    pub fn errors(&self) -> &[Error] {
+        &self.0
+    }
+
+    /// Consume the aggregate and return the underlying error vector.
+    pub fn into_errors(self) -> Vec<Error> {
+        self.0
+    }
+}
+
+impl std::fmt::Display for StopErrors {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self.0.as_slice() {
+            [] => write!(f, "stop completed with no errors"),
+            [only] => write!(f, "stop failed: {only}"),
+            [first, rest @ ..] => write!(
+                f,
+                "stop failed with {n} errors; first: {first}",
+                n = 1 + rest.len(),
+            ),
+        }
+    }
+}
+
+impl std::error::Error for StopErrors {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        self.0
+            .first()
+            .map(|e| e as &(dyn std::error::Error + 'static))
+    }
+}
+
+/// Specific protocol-level errors in the JSON-RPC transport or CLI lifecycle.
+#[derive(Debug, thiserror::Error)]
+#[non_exhaustive]
+pub enum ProtocolError {
+    /// Missing `Content-Length` header in a JSON-RPC message.
+    #[error("missing Content-Length header")]
+    MissingContentLength,
+
+    /// Invalid `Content-Length` header value.
+    #[error("invalid Content-Length value: \"{0}\"")]
+    InvalidContentLength(String),
+
+    /// A pending JSON-RPC request was cancelled (e.g. the response channel was dropped).
+    #[error("request cancelled")]
+    RequestCancelled,
+
+    /// The CLI process did not report a listening port within the timeout.
+    #[error("timed out waiting for CLI to report listening port")]
+    CliStartupTimeout,
+
+    /// The CLI process exited before reporting a listening port.
+    #[error("CLI exited before reporting listening port")]
+    CliStartupFailed,
+
+    /// The CLI server's protocol version is outside the SDK's supported range.
+ #[error("version mismatch: server={server}, supported={min}–{max}")] + VersionMismatch { + /// Version reported by the server. + server: u32, + /// Minimum version supported by this SDK. + min: u32, + /// Maximum version supported by this SDK. + max: u32, + }, + + /// The CLI server's protocol version changed between calls. + #[error("version changed: was {previous}, now {current}")] + VersionChanged { + /// Previously negotiated version. + previous: u32, + /// Newly reported version. + current: u32, + }, +} + +/// Session-scoped errors. +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum SessionError { + /// The CLI could not find the requested session. + #[error("session not found: {0}")] + NotFound(SessionId), + + /// The CLI reported an error during agent execution (via `session.error` event). + #[error("{0}")] + AgentError(String), + + /// A `send_and_wait` call exceeded its timeout. + #[error("timed out after {0:?}")] + Timeout(std::time::Duration), + + /// `send` was called while a `send_and_wait` is in flight. + #[error("cannot send while send_and_wait is in flight")] + SendWhileWaiting, + + /// The session event loop exited before a pending `send_and_wait` completed. + #[error("event loop closed before session reached idle")] + EventLoopClosed, + + /// Elicitation is not supported by the host. + /// Check `session.capabilities().ui.elicitation` before calling UI methods. + #[error( + "elicitation not supported by host — check session.capabilities().ui.elicitation first" + )] + ElicitationNotSupported, + + /// The client was started with [`ClientOptions::session_fs`] but this + /// session was created without a [`SessionFsProvider`]. Set one via + /// [`SessionConfig::with_session_fs_provider`] (or + /// [`ResumeSessionConfig::with_session_fs_provider`]). 
+    #[error(
+        "session was created on a client with session_fs configured but no SessionFsProvider was supplied"
+    )]
+    SessionFsProviderRequired,
+
+    /// [`ClientOptions::session_fs`] was provided with empty or invalid
+    /// fields. All of `initial_cwd` and `session_state_path` must be
+    /// non-empty.
+    #[error("invalid SessionFsConfig: {0}")]
+    InvalidSessionFsConfig(String),
+}
+
+/// How the SDK communicates with the CLI server.
+#[derive(Debug, Default)]
+#[non_exhaustive]
+pub enum Transport {
+    /// Communicate over stdin/stdout pipes (default).
+    #[default]
+    Stdio,
+    /// Spawn the CLI with `--port` and connect via TCP.
+    Tcp {
+        /// Port to listen on (0 for OS-assigned).
+        port: u16,
+    },
+    /// Connect to an already-running CLI server (no process spawning).
+    External {
+        /// Hostname or IP of the running server.
+        host: String,
+        /// Port of the running server.
+        port: u16,
+    },
+}
+
+/// How the SDK locates the GitHub Copilot CLI binary.
+#[derive(Debug, Clone, Default)]
+pub enum CliProgram {
+    /// Auto-resolve: `COPILOT_CLI_PATH` → embedded CLI → PATH + common locations.
+    /// This is the default.
+    #[default]
+    Resolve,
+    /// Use an explicit binary path (skips resolution).
+    Path(PathBuf),
+}
+
+impl From<PathBuf> for CliProgram {
+    fn from(path: PathBuf) -> Self {
+        Self::Path(path)
+    }
+}
+
+/// Options for starting a [`Client`].
+///
+/// When `program` is [`CliProgram::Resolve`] (the default),
+/// [`Client::start`] automatically resolves the binary via
+/// [`resolve::copilot_binary()`] — checking `COPILOT_CLI_PATH`, the
+/// embedded CLI, and then the system PATH and common install locations.
+///
+/// Set `program` to [`CliProgram::Path`] to use an explicit binary.
+#[non_exhaustive]
+pub struct ClientOptions {
+    /// How to locate the CLI binary.
+    pub program: CliProgram,
+    /// Arguments prepended before `--server` (e.g. the script path for node).
+    pub prefix_args: Vec<OsString>,
+    /// Working directory for the CLI process.
+    pub cwd: PathBuf,
+    /// Environment variables set on the child process.
+    pub env: Vec<(OsString, OsString)>,
+    /// Environment variable names to remove from the child process.
+    pub env_remove: Vec<OsString>,
+    /// Extra CLI flags appended after the transport-specific arguments.
+    pub extra_args: Vec<OsString>,
+    /// Transport mode used to communicate with the CLI server.
+    pub transport: Transport,
+    /// GitHub token for authentication. When set, the SDK passes the token
+    /// to the CLI via `--auth-token-env COPILOT_SDK_AUTH_TOKEN` and exports
+    /// the token in that env var. When set, the CLI defaults to *not*
+    /// using the logged-in user (override with [`Self::use_logged_in_user`]).
+    pub github_token: Option<String>,
+    /// Whether the CLI should fall back to the logged-in `gh` user when no
+    /// token is provided. `None` means use the runtime default (true unless
+    /// [`Self::github_token`] is set, in which case false).
+    pub use_logged_in_user: Option<bool>,
+    /// Log level passed to the CLI server via `--log-level`. When `None`,
+    /// the SDK uses [`LogLevel::Info`].
+    pub log_level: Option<LogLevel>,
+    /// Server-wide idle timeout for sessions, in seconds. When set to a
+    /// positive value, the SDK passes `--session-idle-timeout <secs>` to
+    /// the CLI; sessions without activity for this duration are
+    /// automatically cleaned up. `None` or `Some(0)` leaves sessions
+    /// running indefinitely (the CLI default).
+    pub session_idle_timeout_seconds: Option<u64>,
+    /// Optional override for [`Client::list_models`].
+    ///
+    /// When set, [`Client::list_models`] returns the handler's result
+    /// without making a `models.list` RPC. This is the BYOK escape hatch
+    /// for environments where the model catalog is provisioned separately
+    /// from the GitHub Copilot CLI (e.g. external inference servers selected via
+    /// [`Transport::External`]).
+    pub on_list_models: Option<Arc<dyn ListModelsHandler>>,
+    /// Custom session filesystem provider configuration.
+    ///
+    /// When set, the SDK calls `sessionFs.setProvider` during
+    /// [`Client::start`] to register a virtualizable filesystem layer with
+    /// the CLI. Each session created on this client must supply its own
+    /// [`SessionFsProvider`] via
+    /// [`SessionConfig::with_session_fs_provider`](crate::SessionConfig::with_session_fs_provider).
+    pub session_fs: Option<SessionFsConfig>,
+    /// Optional [`TraceContextProvider`] used to inject W3C Trace Context
+    /// headers (`traceparent` / `tracestate`) on outbound `session.create`,
+    /// `session.resume`, and `session.send` requests.
+    ///
+    /// When [`MessageOptions`] carries a per-turn override (set via
+    /// [`MessageOptions::with_trace_context`](crate::types::MessageOptions::with_trace_context)
+    /// or the underlying fields), it takes precedence over this provider.
+    ///
+    /// [`MessageOptions`]: crate::types::MessageOptions
+    pub on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
+    /// OpenTelemetry config forwarded to the spawned CLI process. See
+    /// [`TelemetryConfig`] for the env-var mapping. The SDK takes no
+    /// OpenTelemetry dependency — this is pure spawn-time env injection.
+    pub telemetry: Option<TelemetryConfig>,
+}
+
+impl std::fmt::Debug for ClientOptions {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ClientOptions")
+            .field("program", &self.program)
+            .field("prefix_args", &self.prefix_args)
+            .field("cwd", &self.cwd)
+            .field("env", &self.env)
+            .field("env_remove", &self.env_remove)
+            .field("extra_args", &self.extra_args)
+            .field("transport", &self.transport)
+            .field(
+                "github_token",
+                &self.github_token.as_ref().map(|_| "<redacted>"),
+            )
+            .field("use_logged_in_user", &self.use_logged_in_user)
+            .field("log_level", &self.log_level)
+            .field(
+                "session_idle_timeout_seconds",
+                &self.session_idle_timeout_seconds,
+            )
+            .field(
+                "on_list_models",
+                &self.on_list_models.as_ref().map(|_| "<handler>"),
+            )
+            .field("session_fs", &self.session_fs)
+            .field(
+                "on_get_trace_context",
+                &self.on_get_trace_context.as_ref().map(|_| "<provider>"),
+            )
+            .field("telemetry", &self.telemetry)
+            .finish()
+    }
+}
+
+/// Custom handler for [`Client::list_models`].
+///
+/// Implementations override the default `models.list` RPC, returning a
+/// caller-supplied catalog of models. Set via [`ClientOptions::on_list_models`].
+///
+/// Implementations must be `Send + Sync` because [`Client`] is shared across
+/// tasks. Errors returned by [`list_models`](Self::list_models) are propagated
+/// from [`Client::list_models`] unchanged.
+#[async_trait]
+pub trait ListModelsHandler: Send + Sync + 'static {
+    /// Return the list of available models.
+    async fn list_models(&self) -> Result<Vec<Model>, Error>;
+}
+
+/// Log verbosity for the CLI server (passed via `--log-level`).
+#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum LogLevel {
+    /// Suppress all CLI logs.
+    None,
+    /// Errors only.
+    Error,
+    /// Warnings and errors.
+    Warning,
+    /// Default. Info and above.
+    Info,
+    /// Debug, info, warnings, errors.
+    Debug,
+    /// Everything, including trace output.
+    All,
+}
+
+impl LogLevel {
+    /// CLI argument value (e.g. `"info"`, `"debug"`).
+    pub fn as_str(self) -> &'static str {
+        match self {
+            Self::None => "none",
+            Self::Error => "error",
+            Self::Warning => "warning",
+            Self::Info => "info",
+            Self::Debug => "debug",
+            Self::All => "all",
+        }
+    }
+}
+
+impl std::fmt::Display for LogLevel {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(self.as_str())
+    }
+}
+
+/// Backend exporter for the CLI's OpenTelemetry pipeline.
+///
+/// Maps to the `COPILOT_OTEL_EXPORTER_TYPE` environment variable on the
+/// spawned CLI process. Wire values are `"otlp-http"` and `"file"`.
+#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
+#[serde(rename_all = "kebab-case")]
+#[non_exhaustive]
+pub enum OtelExporterType {
+    /// Export via OTLP HTTP to the endpoint configured by
+    /// [`TelemetryConfig::otlp_endpoint`].
+    OtlpHttp,
+    /// Export to a JSON-lines file at the path configured by
+    /// [`TelemetryConfig::file_path`].
+    File,
+}
+
+impl OtelExporterType {
+    /// Environment-variable value (`"otlp-http"` or `"file"`).
+    pub fn as_str(self) -> &'static str {
+        match self {
+            Self::OtlpHttp => "otlp-http",
+            Self::File => "file",
+        }
+    }
+}
+
+/// OpenTelemetry configuration forwarded to the spawned GitHub Copilot CLI
+/// process.
+///
+/// When [`ClientOptions::telemetry`] is `Some(...)`, the SDK sets
+/// `COPILOT_OTEL_ENABLED=true` plus any populated fields below as the
+/// corresponding `OTEL_*` / `COPILOT_OTEL_*` environment variables. The
+/// CLI's built-in OpenTelemetry exporter consumes these at startup. The
+/// SDK itself takes no OpenTelemetry dependency.
+///
+/// Environment-variable mapping:
+///
+/// | Field                | Variable                                              |
+/// |----------------------|-------------------------------------------------------|
+/// | (any field set)      | `COPILOT_OTEL_ENABLED=true`                           |
+/// | [`otlp_endpoint`]    | `OTEL_EXPORTER_OTLP_ENDPOINT`                         |
+/// | [`file_path`]        | `COPILOT_OTEL_FILE_EXPORTER_PATH`                     |
+/// | [`exporter_type`]    | `COPILOT_OTEL_EXPORTER_TYPE`                          |
+/// | [`source_name`]      | `COPILOT_OTEL_SOURCE_NAME`                            |
+/// | [`capture_content`]  | `OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT`  |
+///
+/// Caller-supplied entries in [`ClientOptions::env`] override these, so a
+/// developer can pin any individual variable to a different value while
+/// keeping the rest of the config managed by [`TelemetryConfig`].
+///
+/// Marked `#[non_exhaustive]` so future CLI-side telemetry knobs can be
+/// added without breaking callers.
+///
+/// [`otlp_endpoint`]: Self::otlp_endpoint
+/// [`file_path`]: Self::file_path
+/// [`exporter_type`]: Self::exporter_type
+/// [`source_name`]: Self::source_name
+/// [`capture_content`]: Self::capture_content
+#[derive(Debug, Clone, Default)]
+#[non_exhaustive]
+pub struct TelemetryConfig {
+    /// OTLP HTTP endpoint URL for trace/metric export.
+    pub otlp_endpoint: Option<String>,
+    /// File path for JSON-lines trace output.
+    pub file_path: Option<PathBuf>,
+    /// Exporter backend type. Typically [`OtelExporterType::OtlpHttp`] or
+    /// [`OtelExporterType::File`].
+    pub exporter_type: Option<OtelExporterType>,
+    /// Instrumentation scope name. Useful for distinguishing this
+    /// embedder's traces from other Copilot-CLI consumers exporting to the
+    /// same backend.
+    pub source_name: Option<String>,
+    /// Whether the CLI captures GenAI message content (prompts and
+    /// responses) on emitted spans. `Some(true)` opts in; `Some(false)`
+    /// opts out; `None` leaves the CLI default (typically off).
+    pub capture_content: Option<bool>,
+}
+
+impl TelemetryConfig {
+    /// Construct an empty [`TelemetryConfig`]; all fields default to
+    /// unset (`is_empty()` returns `true`).
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Set the OTLP HTTP endpoint URL for trace/metric export.
+    pub fn with_otlp_endpoint(mut self, endpoint: impl Into<String>) -> Self {
+        self.otlp_endpoint = Some(endpoint.into());
+        self
+    }
+
+    /// Set the file path for JSON-lines trace output.
+    pub fn with_file_path(mut self, path: impl Into<PathBuf>) -> Self {
+        self.file_path = Some(path.into());
+        self
+    }
+
+    /// Set the exporter backend type.
+    pub fn with_exporter_type(mut self, exporter_type: OtelExporterType) -> Self {
+        self.exporter_type = Some(exporter_type);
+        self
+    }
+
+    /// Set the instrumentation scope name. Useful for distinguishing
+    /// this embedder's traces from other Copilot-CLI consumers
+    /// exporting to the same backend.
+    pub fn with_source_name(mut self, source_name: impl Into<String>) -> Self {
+        self.source_name = Some(source_name.into());
+        self
+    }
+
+    /// Opt in or out of GenAI message content capture on emitted spans.
+    /// `true` opts in; `false` opts out. Leaving this unset preserves
+    /// the CLI default (typically off).
+    pub fn with_capture_content(mut self, capture: bool) -> Self {
+        self.capture_content = Some(capture);
+        self
+    }
+
+    /// Returns `true` if all fields are unset. Used by [`Client::start`]
+    /// to decide whether to set `COPILOT_OTEL_ENABLED`.
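The env-var mapping in the table above can be sketched as a plain function; `telemetry_env` is a hypothetical helper using string stand-ins for the typed fields, not SDK code, but the variable names come from the table:

```rust
// Build the child-process env entries that a populated telemetry
// config would imply, per the mapping table above.
fn telemetry_env(
    otlp_endpoint: Option<&str>,
    file_path: Option<&str>,
    exporter_type: Option<&str>,
) -> Vec<(String, String)> {
    // Any telemetry config at all turns the exporter on.
    let mut env = vec![("COPILOT_OTEL_ENABLED".to_string(), "true".to_string())];
    if let Some(v) = otlp_endpoint {
        env.push(("OTEL_EXPORTER_OTLP_ENDPOINT".to_string(), v.to_string()));
    }
    if let Some(v) = file_path {
        env.push(("COPILOT_OTEL_FILE_EXPORTER_PATH".to_string(), v.to_string()));
    }
    if let Some(v) = exporter_type {
        env.push(("COPILOT_OTEL_EXPORTER_TYPE".to_string(), v.to_string()));
    }
    env
}

fn main() {
    let env = telemetry_env(Some("http://localhost:4318"), None, Some("otlp-http"));
    assert!(env.contains(&("COPILOT_OTEL_ENABLED".into(), "true".into())));
    assert_eq!(env.len(), 3);
}
```

Because caller-supplied `ClientOptions::env` entries are applied last, any of these generated variables can still be pinned to a different value by the embedder.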
+    pub fn is_empty(&self) -> bool {
+        self.otlp_endpoint.is_none()
+            && self.file_path.is_none()
+            && self.exporter_type.is_none()
+            && self.source_name.is_none()
+            && self.capture_content.is_none()
+    }
+}
+
+impl Default for ClientOptions {
+    fn default() -> Self {
+        Self {
+            program: CliProgram::Resolve,
+            prefix_args: Vec::new(),
+            cwd: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
+            env: Vec::new(),
+            env_remove: Vec::new(),
+            extra_args: Vec::new(),
+            transport: Transport::default(),
+            github_token: None,
+            use_logged_in_user: None,
+            log_level: None,
+            session_idle_timeout_seconds: None,
+            on_list_models: None,
+            session_fs: None,
+            on_get_trace_context: None,
+            telemetry: None,
+        }
+    }
+}
+
+impl ClientOptions {
+    /// Construct a new [`ClientOptions`] with default values.
+    ///
+    /// Equivalent to [`ClientOptions::default`]; provided as a documented
+    /// construction entry point for the builder chain. The struct is
+    /// `#[non_exhaustive]`, so external callers cannot use struct-literal
+    /// syntax — use this builder or [`Default::default`] plus mut-let.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// # use github_copilot_sdk::{ClientOptions, LogLevel};
+    /// let opts = ClientOptions::new()
+    ///     .with_log_level(LogLevel::Debug)
+    ///     .with_github_token("ghp_…");
+    /// ```
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// How to locate the CLI binary. See [`CliProgram`].
+    pub fn with_program(mut self, program: impl Into<CliProgram>) -> Self {
+        self.program = program.into();
+        self
+    }
+
+    /// Arguments prepended before `--server` (e.g. the script path for node).
+    pub fn with_prefix_args<I, S>(mut self, args: I) -> Self
+    where
+        I: IntoIterator<Item = S>,
+        S: Into<OsString>,
+    {
+        self.prefix_args = args.into_iter().map(Into::into).collect();
+        self
+    }
+
+    /// Working directory for the CLI process.
+    pub fn with_cwd(mut self, cwd: impl Into<PathBuf>) -> Self {
+        self.cwd = cwd.into();
+        self
+    }
+
+    /// Environment variables to set on the child process.
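The `with_*` methods above and below all follow the same consuming-builder convention: take `mut self`, set one field, return `self`, so calls chain. A minimal standalone sketch of that convention, with `Opts` as an illustrative type rather than the SDK's:

```rust
#[derive(Debug, Default)]
struct Opts {
    cwd: String,
    log_level: Option<String>,
}

impl Opts {
    // Consuming builder: each method moves self, mutates one field,
    // and hands self back so the next call can chain.
    fn with_cwd(mut self, cwd: impl Into<String>) -> Self {
        self.cwd = cwd.into();
        self
    }

    fn with_log_level(mut self, level: impl Into<String>) -> Self {
        self.log_level = Some(level.into());
        self
    }
}

fn main() {
    let opts = Opts::default().with_cwd("/tmp").with_log_level("debug");
    assert_eq!(opts.cwd, "/tmp");
    assert_eq!(opts.log_level.as_deref(), Some("debug"));
}
```

The `impl Into<...>` parameters are what let callers pass `&str`, `String`, or `PathBuf`-like values interchangeably, which is why the real builder uses them throughout.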
+    pub fn with_env<I, K, V>(mut self, env: I) -> Self
+    where
+        I: IntoIterator<Item = (K, V)>,
+        K: Into<OsString>,
+        V: Into<OsString>,
+    {
+        self.env = env.into_iter().map(|(k, v)| (k.into(), v.into())).collect();
+        self
+    }
+
+    /// Environment variable names to remove from the child process.
+    pub fn with_env_remove<I, S>(mut self, names: I) -> Self
+    where
+        I: IntoIterator<Item = S>,
+        S: Into<OsString>,
+    {
+        self.env_remove = names.into_iter().map(Into::into).collect();
+        self
+    }
+
+    /// Extra CLI flags appended after the transport-specific arguments.
+    pub fn with_extra_args<I, S>(mut self, args: I) -> Self
+    where
+        I: IntoIterator<Item = S>,
+        S: Into<OsString>,
+    {
+        self.extra_args = args.into_iter().map(Into::into).collect();
+        self
+    }
+
+    /// Transport mode used to communicate with the CLI server. See [`Transport`].
+    pub fn with_transport(mut self, transport: Transport) -> Self {
+        self.transport = transport;
+        self
+    }
+
+    /// GitHub token for authentication. The SDK passes the token to the
+    /// CLI via `--auth-token-env COPILOT_SDK_AUTH_TOKEN`.
+    pub fn with_github_token(mut self, token: impl Into<String>) -> Self {
+        self.github_token = Some(token.into());
+        self
+    }
+
+    /// Whether the CLI should fall back to the logged-in `gh` user when
+    /// no token is provided. See the field docs for default semantics.
+    pub fn with_use_logged_in_user(mut self, use_logged_in: bool) -> Self {
+        self.use_logged_in_user = Some(use_logged_in);
+        self
+    }
+
+    /// Log level passed to the CLI server via `--log-level`.
+    pub fn with_log_level(mut self, level: LogLevel) -> Self {
+        self.log_level = Some(level);
+        self
+    }
+
+    /// Server-wide idle timeout for sessions (seconds). Pass `0` to leave
+    /// sessions running indefinitely (the CLI default).
+    pub fn with_session_idle_timeout_seconds(mut self, seconds: u64) -> Self {
+        self.session_idle_timeout_seconds = Some(seconds);
+        self
+    }
+
+    /// Override [`Client::list_models`] with a caller-supplied handler.
+    /// The handler is wrapped in `Arc` internally.
+    pub fn with_list_models_handler<H>(mut self, handler: H) -> Self
+    where
+        H: ListModelsHandler + 'static,
+    {
+        self.on_list_models = Some(Arc::new(handler));
+        self
+    }
+
+    /// Custom session filesystem provider configuration.
+    pub fn with_session_fs(mut self, config: SessionFsConfig) -> Self {
+        self.session_fs = Some(config);
+        self
+    }
+
+    /// Set the [`TraceContextProvider`] used to inject W3C Trace Context
+    /// headers on outbound `session.create` / `session.resume` /
+    /// `session.send` requests. The provider is wrapped in `Arc` internally.
+    pub fn with_trace_context_provider<P>(mut self, provider: P) -> Self
+    where
+        P: TraceContextProvider + 'static,
+    {
+        self.on_get_trace_context = Some(Arc::new(provider));
+        self
+    }
+
+    /// OpenTelemetry config forwarded to the spawned CLI process.
+    pub fn with_telemetry(mut self, config: TelemetryConfig) -> Self {
+        self.telemetry = Some(config);
+        self
+    }
+}
+
+/// Validate a [`SessionFsConfig`] before sending `sessionFs.setProvider`.
+fn validate_session_fs_config(cfg: &SessionFsConfig) -> Result<(), Error> {
+    if cfg.initial_cwd.trim().is_empty() {
+        return Err(Error::Session(SessionError::InvalidSessionFsConfig(
+            "initial_cwd must not be empty".to_string(),
+        )));
+    }
+    if cfg.session_state_path.trim().is_empty() {
+        return Err(Error::Session(SessionError::InvalidSessionFsConfig(
+            "session_state_path must not be empty".to_string(),
+        )));
+    }
+    Ok(())
+}
+
+/// Connection to a GitHub Copilot CLI server (stdio, TCP, or external).
+///
+/// Cheaply cloneable — cloning shares the underlying connection.
+/// The child process (if any) is killed when the last clone drops.
+#[derive(Clone)]
+pub struct Client {
+    inner: Arc<ClientInner>,
+}
+
+impl std::fmt::Debug for Client {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Client")
+            .field("cwd", &self.inner.cwd)
+            .field("pid", &self.pid())
+            .finish()
+    }
+}
+
+struct ClientInner {
+    child: parking_lot::Mutex<Option<Child>>,
+    rpc: JsonRpcClient,
+    cwd: PathBuf,
+    request_rx: parking_lot::Mutex<Option<mpsc::UnboundedReceiver<JsonRpcRequest>>>,
+    notification_tx: broadcast::Sender<JsonRpcNotification>,
+    router: router::SessionRouter,
+    negotiated_protocol_version: OnceLock<u32>,
+    server_telemetry_method: parking_lot::Mutex<Option<ServerTelemetryRpcMethod>>,
+    state: parking_lot::Mutex<ClientState>,
+    lifecycle_tx: broadcast::Sender<LifecycleEvent>,
+    on_list_models: Option<Arc<dyn ListModelsHandler>>,
+    session_fs_configured: bool,
+    on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
+}
+
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+enum ServerTelemetryRpcMethod {
+    SendTelemetry,
+    NamespacedSendTelemetry,
+}
+
+impl ServerTelemetryRpcMethod {
+    fn as_str(self) -> &'static str {
+        match self {
+            Self::SendTelemetry => "sendTelemetry",
+            Self::NamespacedSendTelemetry => "server.sendTelemetry",
+        }
+    }
+}
+
+impl Client {
+    /// Start a CLI server process with the given options.
+    ///
+    /// For [`Transport::Stdio`], spawns the CLI with `--stdio` and communicates
+    /// over stdin/stdout pipes. For [`Transport::Tcp`], spawns with `--port`
+    /// and connects via TCP once the server reports it is listening. For
+    /// [`Transport::External`], connects to an already-running server.
+    ///
+    /// After establishing the connection, calls [`verify_protocol_version`](Self::verify_protocol_version)
+    /// to ensure the CLI server speaks a compatible protocol version.
+    /// When [`ClientOptions::session_fs`] is set, also calls
+    /// `sessionFs.setProvider` to register the SDK as the filesystem
+    /// backend.
+    pub async fn start(options: ClientOptions) -> Result<Self, Error> {
+        if let Some(cfg) = &options.session_fs {
+            validate_session_fs_config(cfg)?;
+        }
+        let session_fs_config = options.session_fs.clone();
+        let program = match &options.program {
+            CliProgram::Path(path) => {
+                info!(path = %path.display(), "using explicit copilot CLI path");
+                path.clone()
+            }
+            CliProgram::Resolve => {
+                let resolved = resolve::copilot_binary()?;
+                info!(path = %resolved.display(), "resolved copilot CLI");
+                #[cfg(windows)]
+                {
+                    if let Some(ext) = resolved.extension().and_then(|e| e.to_str()) {
+                        if ext.eq_ignore_ascii_case("cmd") || ext.eq_ignore_ascii_case("bat") {
+                            warn!(
+                                path = %resolved.display(),
+                                ext = %ext,
+                                "resolved copilot CLI is a .cmd/.bat wrapper; \
+                                 this may cause console window flashes on Windows"
+                            );
+                        }
+                    }
+                }
+                resolved
+            }
+        };
+
+        let client = match options.transport {
+            Transport::External { ref host, port } => {
+                info!(host = %host, port = %port, "connecting to external CLI server");
+                let stream = TcpStream::connect((host.as_str(), port)).await?;
+                let (reader, writer) = tokio::io::split(stream);
+                Self::from_transport(
+                    reader,
+                    writer,
+                    None,
+                    options.cwd,
+                    options.on_list_models,
+                    session_fs_config.is_some(),
+                    options.on_get_trace_context,
+                )?
+            }
+            Transport::Tcp { port } => {
+                let (mut child, actual_port) = Self::spawn_tcp(&program, &options, port).await?;
+                let stream = TcpStream::connect(("127.0.0.1", actual_port)).await?;
+                let (reader, writer) = tokio::io::split(stream);
+                Self::drain_stderr(&mut child);
+                Self::from_transport(
+                    reader,
+                    writer,
+                    Some(child),
+                    options.cwd,
+                    options.on_list_models,
+                    session_fs_config.is_some(),
+                    options.on_get_trace_context,
+                )?
+            }
+            Transport::Stdio => {
+                let mut child = Self::spawn_stdio(&program, &options)?;
+                let stdin = child.stdin.take().expect("stdin is piped");
+                let stdout = child.stdout.take().expect("stdout is piped");
+                Self::drain_stderr(&mut child);
+                Self::from_transport(
+                    stdout,
+                    stdin,
+                    Some(child),
+                    options.cwd,
+                    options.on_list_models,
+                    session_fs_config.is_some(),
+                    options.on_get_trace_context,
+                )?
+            }
+        };
+
+        client.verify_protocol_version().await?;
+        if let Some(cfg) = session_fs_config {
+            let request = crate::generated::api_types::SessionFsSetProviderRequest {
+                conventions: cfg.conventions.into_wire(),
+                initial_cwd: cfg.initial_cwd,
+                session_state_path: cfg.session_state_path,
+            };
+            client.rpc().session_fs().set_provider(request).await?;
+        }
+        Ok(client)
+    }
+
+    /// Create a Client from raw async streams (no child process).
+    ///
+    /// Useful for testing or connecting to a server over a custom transport.
+    pub fn from_streams(
+        reader: impl AsyncRead + Unpin + Send + 'static,
+        writer: impl AsyncWrite + Unpin + Send + 'static,
+        cwd: PathBuf,
+    ) -> Result<Self, Error> {
+        Self::from_transport(reader, writer, None, cwd, None, false, None)
+    }
+
+    /// Construct a [`Client`] from raw streams with a
+    /// [`TraceContextProvider`] preset, for integration testing.
+    ///
+    /// Mirrors [`from_streams`](Self::from_streams) but exposes the
+    /// `on_get_trace_context` plumbing so tests can verify outbound
+    /// `traceparent` / `tracestate` injection on `session.create`,
+    /// `session.resume`, and `session.send`.
+    #[cfg(any(test, feature = "test-support"))]
+    pub fn from_streams_with_trace_provider(
+        reader: impl AsyncRead + Unpin + Send + 'static,
+        writer: impl AsyncWrite + Unpin + Send + 'static,
+        cwd: PathBuf,
+        provider: Arc<dyn TraceContextProvider>,
+    ) -> Result<Self, Error> {
+        Self::from_transport(reader, writer, None, cwd, None, false, Some(provider))
+    }
+
+    fn from_transport(
+        reader: impl AsyncRead + Unpin + Send + 'static,
+        writer: impl AsyncWrite + Unpin + Send + 'static,
+        child: Option<Child>,
+        cwd: PathBuf,
+        on_list_models: Option<Arc<dyn ListModelsHandler>>,
+        session_fs_configured: bool,
+        on_get_trace_context: Option<Arc<dyn TraceContextProvider>>,
+    ) -> Result<Self, Error> {
+        let (request_tx, request_rx) = mpsc::unbounded_channel::<JsonRpcRequest>();
+        let (notification_broadcast_tx, _) = broadcast::channel::<JsonRpcNotification>(1024);
+        let rpc = JsonRpcClient::new(
+            writer,
+            reader,
+            notification_broadcast_tx.clone(),
+            request_tx,
+        );
+
+        let pid = child.as_ref().and_then(|c| c.id());
+        info!(pid = ?pid, "copilot CLI client ready");
+
+        let client = Self {
+            inner: Arc::new(ClientInner {
+                child: parking_lot::Mutex::new(child),
+                rpc,
+                cwd,
+                request_rx: parking_lot::Mutex::new(Some(request_rx)),
+                notification_tx: notification_broadcast_tx,
+                router: router::SessionRouter::new(),
+                negotiated_protocol_version: OnceLock::new(),
+                server_telemetry_method: parking_lot::Mutex::new(None),
+                state: parking_lot::Mutex::new(ConnectionState::Connected),
+                lifecycle_tx: broadcast::channel(256).0,
+                on_list_models,
+                session_fs_configured,
+                on_get_trace_context,
+            }),
+        };
+        client.spawn_lifecycle_dispatcher();
+        Ok(client)
+    }
+
+    /// Spawn the background task that re-broadcasts `session.lifecycle`
+    /// notifications via [`ClientInner::lifecycle_tx`] to subscribers
+    /// returned by [`Self::subscribe_lifecycle`].
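The dispatcher spawned below filters the shared notification stream down to `session.lifecycle` frames: unrelated methods are skipped, and payload-less or undecodable frames are dropped with a warning instead of tearing the loop down. A std-only sketch of that filter step (the `Notification` struct here is a simplified stand-in for the SDK's JSON-RPC notification type, not the real wire type):

```rust
// Simplified stand-in for a JSON-RPC notification; the real type carries
// a serde_json payload rather than a static string.
struct Notification {
    method: &'static str,
    params: Option<&'static str>,
}

// Mirrors the dispatcher's filter: keep only `session.lifecycle` frames
// that actually carry a payload, skip everything else.
fn dispatch(notifications: &[Notification]) -> Vec<&'static str> {
    let mut forwarded = Vec::new();
    for n in notifications {
        if n.method != "session.lifecycle" {
            continue; // unrelated notification
        }
        let Some(params) = n.params else {
            continue; // lifecycle frame with no payload
        };
        forwarded.push(params);
    }
    forwarded
}

fn main() {
    let input = [
        Notification { method: "session.event", params: Some("ev") },
        Notification { method: "session.lifecycle", params: None },
        Notification { method: "session.lifecycle", params: Some("started") },
    ];
    assert_eq!(dispatch(&input), vec!["started"]);
}
```

In the real dispatcher the kept payload is then deserialized and re-broadcast; a failed deserialize is logged and skipped, exactly like the payload-less case here.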
+ fn spawn_lifecycle_dispatcher(&self) { + let inner = Arc::clone(&self.inner); + let mut notif_rx = inner.notification_tx.subscribe(); + tokio::spawn(async move { + loop { + match notif_rx.recv().await { + Ok(notification) => { + if notification.method != "session.lifecycle" { + continue; + } + let Some(params) = notification.params.as_ref() else { + continue; + }; + let event: SessionLifecycleEvent = + match serde_json::from_value(params.clone()) { + Ok(e) => e, + Err(e) => { + warn!( + error = %e, + "failed to deserialize session.lifecycle notification" + ); + continue; + } + }; + // `send` only errors when there are no subscribers — that's + // the normal case before any consumer calls subscribe_lifecycle. + let _ = inner.lifecycle_tx.send(event); + } + Err(tokio::sync::broadcast::error::RecvError::Lagged(n)) => { + warn!(missed = n, "lifecycle dispatcher lagged"); + } + Err(tokio::sync::broadcast::error::RecvError::Closed) => break, + } + } + }); + } + + fn build_command(program: &Path, options: &ClientOptions) -> Command { + let mut command = Command::new(program); + for arg in &options.prefix_args { + command.arg(arg); + } + // Inject the SDK auth token first so explicit `env` / `env_remove` + // entries can override or strip it. + if let Some(token) = &options.github_token { + command.env("COPILOT_SDK_AUTH_TOKEN", token); + } + // Inject telemetry env vars before user env so callers can still + // override individual variables via `options.env`. 
+ if let Some(telemetry) = &options.telemetry { + command.env("COPILOT_OTEL_ENABLED", "true"); + if let Some(endpoint) = &telemetry.otlp_endpoint { + command.env("OTEL_EXPORTER_OTLP_ENDPOINT", endpoint); + } + if let Some(path) = &telemetry.file_path { + command.env("COPILOT_OTEL_FILE_EXPORTER_PATH", path); + } + if let Some(exporter) = telemetry.exporter_type { + command.env("COPILOT_OTEL_EXPORTER_TYPE", exporter.as_str()); + } + if let Some(source) = &telemetry.source_name { + command.env("COPILOT_OTEL_SOURCE_NAME", source); + } + if let Some(capture) = telemetry.capture_content { + command.env( + "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT", + if capture { "true" } else { "false" }, + ); + } + } + for (key, value) in &options.env { + command.env(key, value); + } + for key in &options.env_remove { + command.env_remove(key); + } + command + .current_dir(&options.cwd) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + #[cfg(windows)] + { + use std::os::windows::process::CommandExt; + const CREATE_NO_WINDOW: u32 = 0x08000000; + command.as_std_mut().creation_flags(CREATE_NO_WINDOW); + } + + command + } + + /// Returns the CLI auth flags derived from [`ClientOptions::github_token`] + /// and [`ClientOptions::use_logged_in_user`]. + /// + /// When a token is set, adds `--auth-token-env COPILOT_SDK_AUTH_TOKEN`. + /// When the effective `use_logged_in_user` is `false` (either explicitly + /// or because a token was provided without an override), adds + /// `--no-auto-login`. 
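The auth-flag rules documented above reduce to a small pure function. A free-standing sketch (hypothetical signature taking the two relevant values directly; the real `auth_args` reads the same fields off `ClientOptions`):

```rust
// Standalone sketch of the auth-flag derivation described above.
fn auth_args(github_token: Option<&str>, use_logged_in_user: Option<bool>) -> Vec<&'static str> {
    let mut args = Vec::new();
    if github_token.is_some() {
        args.extend(["--auth-token-env", "COPILOT_SDK_AUTH_TOKEN"]);
    }
    // A provided token implies "don't auto-login" unless explicitly overridden.
    let use_logged_in = use_logged_in_user.unwrap_or(github_token.is_none());
    if !use_logged_in {
        args.push("--no-auto-login");
    }
    args
}

fn main() {
    // No token, no override: no extra flags.
    assert_eq!(auth_args(None, None), Vec::<&str>::new());
    // Token alone also disables auto-login.
    assert_eq!(
        auth_args(Some("tok"), None),
        vec!["--auth-token-env", "COPILOT_SDK_AUTH_TOKEN", "--no-auto-login"]
    );
    // Explicit override keeps auto-login even with a token.
    assert_eq!(
        auth_args(Some("tok"), Some(true)),
        vec!["--auth-token-env", "COPILOT_SDK_AUTH_TOKEN"]
    );
    assert_eq!(auth_args(None, Some(false)), vec!["--no-auto-login"]);
}
```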
+    fn auth_args(options: &ClientOptions) -> Vec<&'static str> {
+        let mut args: Vec<&'static str> = Vec::new();
+        if options.github_token.is_some() {
+            args.push("--auth-token-env");
+            args.push("COPILOT_SDK_AUTH_TOKEN");
+        }
+        let use_logged_in = options
+            .use_logged_in_user
+            .unwrap_or(options.github_token.is_none());
+        if !use_logged_in {
+            args.push("--no-auto-login");
+        }
+        args
+    }
+
+    /// Returns `--session-idle-timeout <seconds>` when
+    /// [`ClientOptions::session_idle_timeout_seconds`] is `Some(n)` with
+    /// `n > 0`. Otherwise returns an empty vector.
+    fn session_idle_timeout_args(options: &ClientOptions) -> Vec<String> {
+        match options.session_idle_timeout_seconds {
+            Some(secs) if secs > 0 => {
+                vec!["--session-idle-timeout".to_string(), secs.to_string()]
+            }
+            _ => Vec::new(),
+        }
+    }
+
+    fn spawn_stdio(program: &Path, options: &ClientOptions) -> Result<Child, Error> {
+        info!(cwd = ?options.cwd, program = %program.display(), "spawning copilot CLI (stdio)");
+        let mut command = Self::build_command(program, options);
+        let log_level = options.log_level.unwrap_or(LogLevel::Info);
+        command
+            .args([
+                "--server",
+                "--stdio",
+                "--no-auto-update",
+                "--log-level",
+                log_level.as_str(),
+            ])
+            .args(Self::auth_args(options))
+            .args(Self::session_idle_timeout_args(options))
+            .args(&options.extra_args)
+            .stdin(Stdio::piped());
+        Ok(command.spawn()?)
+    }
+
+    async fn spawn_tcp(
+        program: &Path,
+        options: &ClientOptions,
+        port: u16,
+    ) -> Result<(Child, u16), Error> {
+        info!(cwd = ?options.cwd, program = %program.display(), port = %port, "spawning copilot CLI (tcp)");
+        let mut command = Self::build_command(program, options);
+        let log_level = options.log_level.unwrap_or(LogLevel::Info);
+        command
+            .args([
+                "--server",
+                "--port",
+                &port.to_string(),
+                "--no-auto-update",
+                "--log-level",
+                log_level.as_str(),
+            ])
+            .args(Self::auth_args(options))
+            .args(Self::session_idle_timeout_args(options))
+            .args(&options.extra_args)
+            .stdin(Stdio::null());
+        let mut child = command.spawn()?;
+        let stdout = child.stdout.take().expect("stdout is piped");
+
+        let (port_tx, port_rx) = oneshot::channel::<u16>();
+        let span = tracing::error_span!("copilot_cli_port_scan");
+        tokio::spawn(
+            async move {
+                // Scan stdout for the port announcement.
+                let port_re = regex::Regex::new(r"listening on port (\d+)").expect("valid regex");
+                let mut lines = BufReader::new(stdout).lines();
+                let mut port_tx = Some(port_tx);
+                while let Ok(Some(line)) = lines.next_line().await {
+                    debug!(line = %line, "CLI stdout");
+                    if let Some(tx) = port_tx.take() {
+                        if let Some(caps) = port_re.captures(&line)
+                            && let Some(p) =
+                                caps.get(1).and_then(|m| m.as_str().parse::<u16>().ok())
+                        {
+                            let _ = tx.send(p);
+                            continue;
+                        }
+                        // Not the port line — put tx back
+                        port_tx = Some(tx);
+                    }
+                }
+            }
+            .instrument(span),
+        );
+
+        let actual_port = tokio::time::timeout(std::time::Duration::from_secs(10), port_rx)
+            .await
+            .map_err(|_| Error::Protocol(ProtocolError::CliStartupTimeout))?
+            .map_err(|_| Error::Protocol(ProtocolError::CliStartupFailed))?;
+
+        info!(port = %actual_port, "CLI server listening");
+        Ok((child, actual_port))
+    }
+
+    fn drain_stderr(child: &mut Child) {
+        if let Some(stderr) = child.stderr.take() {
+            let span = tracing::error_span!("copilot_cli");
+            tokio::spawn(
+                async move {
+                    let mut reader = BufReader::new(stderr).lines();
+                    while let Ok(Some(line)) = reader.next_line().await {
+                        warn!(line = %line, "CLI stderr");
+                    }
+                }
+                .instrument(span),
+            );
+        }
+    }
+
+    /// Returns the working directory of the CLI process.
+    pub fn cwd(&self) -> &PathBuf {
+        &self.inner.cwd
+    }
+
+    /// Typed RPC namespace for server-level methods.
+    ///
+    /// Every protocol method lives here under its schema-aligned path —
+    /// e.g. `client.rpc().models().list()`. Wire method names and request/
+    /// response types are generated from the protocol schema, so the typed
+    /// namespace can't drift from the wire contract.
+    ///
+    /// The hand-authored helpers on [`Client`] delegate to this namespace
+    /// and remain the recommended entry point for everyday use; reach for
+    /// `rpc()` when you want a method without a hand-written wrapper.
+    pub fn rpc(&self) -> crate::generated::rpc::ClientRpc<'_> {
+        crate::generated::rpc::ClientRpc { client: self }
+    }
+
+    /// Send a JSON-RPC request and wait for the response.
+    pub(crate) async fn send_request(
+        &self,
+        method: &str,
+        params: Option<serde_json::Value>,
+    ) -> Result<JsonRpcResponse, Error> {
+        self.inner.rpc.send_request(method, params).await
+    }
+
+    /// Send a JSON-RPC request, check for errors, and return the result value.
+    ///
+    /// This is the primary method for session-level RPC calls. It wraps
+    /// the internal send/receive cycle with error checking so callers
+    /// don't need to inspect the response manually.
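The port scan in `spawn_tcp` above waits (with a 10-second timeout) for the CLI to print a `listening on port N` line and extracts `N` with the regex `listening on port (\d+)`. The same extraction can be sketched without the regex dependency, using std string parsing; an out-of-range number fails the `u16` parse exactly as it would fail the SDK's `parse::<u16>()` on the captured group:

```rust
// Std-only sketch of the port-announcement scan: find the marker text,
// take the digit run after it, and parse it as a u16.
fn parse_port_line(line: &str) -> Option<u16> {
    let rest = line.split("listening on port ").nth(1)?;
    let digits: String = rest.chars().take_while(|c| c.is_ascii_digit()).collect();
    digits.parse().ok()
}

fn main() {
    assert_eq!(parse_port_line("copilot: listening on port 8123"), Some(8123));
    assert_eq!(parse_port_line("listening on port 80 (tcp)"), Some(80));
    // Lines without the announcement are skipped by the scanner.
    assert_eq!(parse_port_line("starting up"), None);
    // Out-of-range ports fail the u16 parse.
    assert_eq!(parse_port_line("listening on port 99999"), None);
}
```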
+    ///
+    /// # Cancel safety
+    ///
+    /// **Cancel-safe.** The frame is committed to the wire via the
+    /// writer-actor task before the future yields; cancelling the await
+    /// (via `tokio::time::timeout`, `select!`, or dropped JoinHandle)
+    /// drops the response oneshot but does not desync the transport.
+    /// The pending-requests entry is cleaned up by an RAII guard.
+    /// However, the call's *side effect* on the CLI may still occur —
+    /// the CLI receives the request and processes it; the caller just
+    /// won't see the response. For idempotent methods this is fine; for
+    /// non-idempotent methods (e.g. `session.create`) the caller should
+    /// avoid wrapping the call in a timeout shorter than the expected
+    /// CLI processing window.
+    pub async fn call(
+        &self,
+        method: &str,
+        params: Option<serde_json::Value>,
+    ) -> Result<serde_json::Value, Error> {
+        let session_id: Option<SessionId> = params
+            .as_ref()
+            .and_then(|p| p.get("sessionId"))
+            .and_then(|v| v.as_str())
+            .map(SessionId::from);
+        let response = self.send_request(method, params).await?;
+        if let Some(err) = response.error {
+            if err.message.contains("Session not found") {
+                return Err(Error::Session(SessionError::NotFound(
+                    session_id.unwrap_or_else(|| "unknown".into()),
+                )));
+            }
+            return Err(Error::Rpc {
+                code: err.code,
+                message: err.message,
+            });
+        }
+        Ok(response.result.unwrap_or(serde_json::Value::Null))
+    }
+
+    /// Send a JSON-RPC response back to the CLI (e.g. for permission or tool call requests).
+    pub(crate) async fn send_response(&self, response: &JsonRpcResponse) -> Result<(), Error> {
+        self.inner.rpc.write(response).await
+    }
+
+    /// Take the receiver for incoming JSON-RPC requests from the CLI.
+    ///
+    /// Can only be called once — subsequent calls return `None`.
+    #[expect(dead_code, reason = "reserved for future pub(crate) use")]
+    pub(crate) fn take_request_rx(&self) -> Option<mpsc::UnboundedReceiver<JsonRpcRequest>> {
+        self.inner.request_rx.lock().take()
+    }
+
+    /// Register a session to receive filtered events and requests.
+    ///
+    /// Returns per-session channels for notifications and requests, routed
+    /// by `sessionId`. Starts the internal router on first call.
+    ///
+    /// When done, call [`unregister_session`](Self::unregister_session) to
+    /// clean up (typically on session destroy).
+    pub(crate) fn register_session(
+        &self,
+        session_id: &SessionId,
+    ) -> crate::router::SessionChannels {
+        self.inner
+            .router
+            .ensure_started(&self.inner.notification_tx, &self.inner.request_rx);
+        self.inner.router.register(session_id)
+    }
+
+    /// Unregister a session, dropping its per-session channels.
+    pub(crate) fn unregister_session(&self, session_id: &SessionId) {
+        self.inner.router.unregister(session_id);
+    }
+
+    /// Returns the protocol version negotiated with the CLI server, if any.
+    ///
+    /// Set during [`start`](Self::start). Returns `None` if the server didn't
+    /// report a version, or if the client was created via
+    /// [`from_streams`](Self::from_streams) without calling
+    /// [`verify_protocol_version`](Self::verify_protocol_version).
+    pub fn protocol_version(&self) -> Option<u32> {
+        self.inner.negotiated_protocol_version.get().copied()
+    }
+
+    /// Verify the CLI server's protocol version is within the supported range.
+    ///
+    /// Called automatically by [`start`](Self::start). Call manually after
+    /// [`from_streams`](Self::from_streams) if you need version verification
+    /// on a custom transport.
+    ///
+    /// Sends a `ping` RPC and checks the `protocolVersion` field in the
+    /// response. Returns an error if the version is outside
+    /// `MIN_PROTOCOL_VERSION`..=[`SDK_PROTOCOL_VERSION`]. If the server
+    /// doesn't report a version, logs a warning and succeeds (backward
+    /// compatibility with older CLI versions).
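The acceptance rule documented above (no reported version passes with a warning; anything outside `MIN_PROTOCOL_VERSION..=SDK_PROTOCOL_VERSION` fails) is at its core an inclusive range check. A standalone sketch with placeholder constants and a simplified result enum standing in for the SDK's error variants:

```rust
// Placeholder bounds standing in for MIN_PROTOCOL_VERSION / SDK_PROTOCOL_VERSION.
const MIN: u32 = 2;
const MAX: u32 = 5;

#[derive(Debug, PartialEq)]
enum VersionCheck {
    Skipped,       // server reported no version: warn and accept
    Ok(u32),       // within the supported range
    Mismatch(u32), // outside MIN..=MAX: reject
}

// Mirrors the three match arms of verify_protocol_version.
fn check(server: Option<u32>) -> VersionCheck {
    match server {
        None => VersionCheck::Skipped,
        Some(v) if !(MIN..=MAX).contains(&v) => VersionCheck::Mismatch(v),
        Some(v) => VersionCheck::Ok(v),
    }
}

fn main() {
    assert_eq!(check(None), VersionCheck::Skipped);
    assert_eq!(check(Some(1)), VersionCheck::Mismatch(1));
    // Both range endpoints are accepted (inclusive on each side).
    assert_eq!(check(Some(2)), VersionCheck::Ok(2));
    assert_eq!(check(Some(5)), VersionCheck::Ok(5));
    assert_eq!(check(Some(6)), VersionCheck::Mismatch(6));
}
```

The real method additionally pins the first accepted version in a `OnceLock` and rejects a later, different report as `VersionChanged`.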
+    pub async fn verify_protocol_version(&self) -> Result<(), Error> {
+        let response = self.ping(None).await?;
+        let server_version = response.protocol_version;
+
+        match server_version {
+            None => {
+                warn!("CLI server did not report protocolVersion; skipping version check");
+            }
+            Some(v) if !(MIN_PROTOCOL_VERSION..=SDK_PROTOCOL_VERSION).contains(&v) => {
+                return Err(Error::Protocol(ProtocolError::VersionMismatch {
+                    server: v,
+                    min: MIN_PROTOCOL_VERSION,
+                    max: SDK_PROTOCOL_VERSION,
+                }));
+            }
+            Some(v) => {
+                if let Some(&existing) = self.inner.negotiated_protocol_version.get() {
+                    if existing != v {
+                        return Err(Error::Protocol(ProtocolError::VersionChanged {
+                            previous: existing,
+                            current: v,
+                        }));
+                    }
+                } else {
+                    let _ = self.inner.negotiated_protocol_version.set(v);
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Send a `ping` RPC and return the typed [`PingResponse`].
+    ///
+    /// Pass `Some(message)` to have the server echo it back; pass `None` for
+    /// a bare health check. The response includes a `protocolVersion` when
+    /// the CLI reports one.
+    ///
+    /// [`PingResponse`]: crate::types::PingResponse
+    pub async fn ping(&self, message: Option<&str>) -> Result<PingResponse, Error> {
+        let params = match message {
+            Some(m) => serde_json::json!({ "message": m }),
+            None => serde_json::json!({}),
+        };
+        let value = self
+            .call(generated::api_types::rpc_methods::PING, Some(params))
+            .await?;
+        Ok(serde_json::from_value(value)?)
+    }
+
+    /// List persisted sessions, optionally filtered by working directory,
+    /// repository, or git context.
+    pub async fn list_sessions(
+        &self,
+        filter: Option<SessionFilter>,
+    ) -> Result<Vec<SessionMetadata>, Error> {
+        let params = match filter {
+            Some(f) => serde_json::json!({ "filter": f }),
+            None => serde_json::json!({}),
+        };
+        let result = self.call("session.list", Some(params)).await?;
+        let response: ListSessionsResponse = serde_json::from_value(result)?;
+        Ok(response.sessions)
+    }
+
+    /// Fetch metadata for a specific persisted session by ID.
+    ///
+    /// Returns `Ok(None)` if no session with the given ID exists. More
+    /// efficient than calling [`list_sessions`](Self::list_sessions) and
+    /// filtering when you only need data for a single session.
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// # async fn example(client: &github_copilot_sdk::Client) -> Result<(), github_copilot_sdk::Error> {
+    /// use github_copilot_sdk::types::SessionId;
+    /// if let Some(metadata) = client.get_session_metadata(&SessionId::new("session-123")).await? {
+    ///     println!("Session started at: {}", metadata.start_time);
+    /// }
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub async fn get_session_metadata(
+        &self,
+        session_id: &SessionId,
+    ) -> Result<Option<SessionMetadata>, Error> {
+        let result = self
+            .call(
+                "session.getMetadata",
+                Some(serde_json::json!({ "sessionId": session_id })),
+            )
+            .await?;
+        let response: GetSessionMetadataResponse = serde_json::from_value(result)?;
+        Ok(response.session)
+    }
+
+    /// Delete a persisted session by ID.
+    pub async fn delete_session(&self, session_id: &SessionId) -> Result<(), Error> {
+        self.call(
+            "session.delete",
+            Some(serde_json::json!({ "sessionId": session_id })),
+        )
+        .await?;
+        Ok(())
+    }
+
+    /// Return the ID of the most recently updated session, if any.
+    ///
+    /// Useful for resuming the last conversation when the session ID was
+    /// not stored. Returns `Ok(None)` if no sessions exist.
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// # async fn example(client: &github_copilot_sdk::Client) -> Result<(), github_copilot_sdk::Error> {
+    /// if let Some(last_id) = client.get_last_session_id().await? {
+    ///     println!("Last session: {last_id}");
+    /// }
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub async fn get_last_session_id(&self) -> Result<Option<SessionId>, Error> {
+        let result = self
+            .call("session.getLastId", Some(serde_json::json!({})))
+            .await?;
+        let response: GetLastSessionIdResponse = serde_json::from_value(result)?;
+        Ok(response.session_id)
+    }
+
+    /// Return the ID of the session currently displayed in the TUI, if any.
+    ///
+    /// Only meaningful when connected to a server running in TUI+server mode
+    /// (`--ui-server`). Returns `Ok(None)` if no foreground session is set.
+    pub async fn get_foreground_session_id(&self) -> Result<Option<SessionId>, Error> {
+        let result = self
+            .call("session.getForeground", Some(serde_json::json!({})))
+            .await?;
+        let response: GetForegroundSessionResponse = serde_json::from_value(result)?;
+        Ok(response.session_id)
+    }
+
+    /// Request that the TUI switch to displaying the specified session.
+    ///
+    /// Only meaningful when connected to a server running in TUI+server mode
+    /// (`--ui-server`).
+    pub async fn set_foreground_session_id(&self, session_id: &SessionId) -> Result<(), Error> {
+        self.call(
+            "session.setForeground",
+            Some(serde_json::json!({ "sessionId": session_id })),
+        )
+        .await?;
+        Ok(())
+    }
+
+    /// Get the CLI server status.
+    pub async fn get_status(&self) -> Result<StatusResponse, Error> {
+        let result = self.call("status.get", Some(serde_json::json!({}))).await?;
+        Ok(serde_json::from_value(result)?)
+    }
+
+    /// Get authentication status.
+    pub async fn get_auth_status(&self) -> Result<AuthStatusResponse, Error> {
+        let result = self
+            .call("auth.getStatus", Some(serde_json::json!({})))
+            .await?;
+        Ok(serde_json::from_value(result)?)
+    }
+
+    /// List available models.
+    ///
+    /// When [`ClientOptions::on_list_models`] is set, returns the handler's
+    /// result without making a `models.list` RPC. Otherwise queries the CLI.
+    pub async fn list_models(&self) -> Result<Vec<Model>, Error> {
+        if let Some(handler) = &self.inner.on_list_models {
+            return handler.list_models().await;
+        }
+        Ok(self.rpc().models().list().await?.models)
+    }
+
+    /// Invoke [`ClientOptions::on_get_trace_context`] when configured,
+    /// otherwise return [`TraceContext::default()`].
+    pub(crate) async fn resolve_trace_context(&self) -> TraceContext {
+        if let Some(provider) = &self.inner.on_get_trace_context {
+            provider.get_trace_context().await
+        } else {
+            TraceContext::default()
+        }
+    }
+
+    /// Send a top-level telemetry event via `sendTelemetry`.
+    pub async fn send_telemetry(&self, event: ServerTelemetryEvent) -> Result<(), Error> {
+        let params = serde_json::to_value(event)?;
+        let cached_method = { *self.inner.server_telemetry_method.lock() };
+        if let Some(method) = cached_method {
+            match self.call(method.as_str(), Some(params.clone())).await {
+                Ok(_) => return Ok(()),
+                Err(Error::Rpc { code, .. })
+                    if code == error_codes::METHOD_NOT_FOUND
+                        && method == ServerTelemetryRpcMethod::SendTelemetry =>
+                {
+                    self.call(
+                        ServerTelemetryRpcMethod::NamespacedSendTelemetry.as_str(),
+                        Some(params),
+                    )
+                    .await?;
+                    *self.inner.server_telemetry_method.lock() =
+                        Some(ServerTelemetryRpcMethod::NamespacedSendTelemetry);
+                    return Ok(());
+                }
+                Err(error) => return Err(error),
+            }
+        }
+
+        match self
+            .call(
+                ServerTelemetryRpcMethod::SendTelemetry.as_str(),
+                Some(params.clone()),
+            )
+            .await
+        {
+            Ok(_) => {
+                *self.inner.server_telemetry_method.lock() =
+                    Some(ServerTelemetryRpcMethod::SendTelemetry);
+                Ok(())
+            }
+            Err(Error::Rpc { code, .. }) if code == error_codes::METHOD_NOT_FOUND => {
+                self.call(
+                    ServerTelemetryRpcMethod::NamespacedSendTelemetry.as_str(),
+                    Some(params),
+                )
+                .await?;
+                *self.inner.server_telemetry_method.lock() =
+                    Some(ServerTelemetryRpcMethod::NamespacedSendTelemetry);
+                Ok(())
+            }
+            Err(error) => Err(error),
+        }
+    }
+
+    /// Fetch account-level quota snapshots (request-based usage).
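`send_telemetry` above probes the bare `sendTelemetry` method name first, falls back to the namespaced `server.sendTelemetry` on a METHOD_NOT_FOUND error, and caches whichever name succeeded so later sends skip the probe. A standalone simulation of that negotiation (the two method names and the -32601 code mirror the code above; the `Telemetry` struct and closure-based "server" are scaffolding for illustration):

```rust
// JSON-RPC "method not found" error code, as used by the fallback above.
const METHOD_NOT_FOUND: i64 = -32601;

struct Telemetry {
    // Whichever method name last succeeded; None until the first send.
    cached: Option<&'static str>,
}

impl Telemetry {
    // `call` simulates the JSON-RPC layer: Ok on success, Err(code) otherwise.
    fn send(&mut self, call: impl Fn(&str) -> Result<(), i64>) -> Result<(), i64> {
        let first = self.cached.unwrap_or("sendTelemetry");
        match call(first) {
            Ok(()) => {
                self.cached = Some(first);
                Ok(())
            }
            // Only the bare name falls back; a cached namespaced name
            // failing with -32601 is a real error.
            Err(METHOD_NOT_FOUND) if first == "sendTelemetry" => {
                call("server.sendTelemetry")?;
                self.cached = Some("server.sendTelemetry");
                Ok(())
            }
            Err(code) => Err(code),
        }
    }
}

fn main() {
    // A server that only understands the namespaced method name.
    let server = |method: &str| {
        if method == "server.sendTelemetry" { Ok(()) } else { Err(METHOD_NOT_FOUND) }
    };
    let mut t = Telemetry { cached: None };
    t.send(server).unwrap();
    assert_eq!(t.cached, Some("server.sendTelemetry"));
    // Subsequent sends go straight to the cached namespaced name.
    t.send(server).unwrap();
}
```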
+    ///
+    /// This top-level convenience wrapper is Rust-only as of 0.1.0; the Node,
+    /// Python, Go, and .NET SDKs do not expose a client-level shortcut for
+    /// quota lookup. The underlying `account.getQuota` JSON-RPC endpoint is
+    /// itself available cross-SDK via each SDK's typed `rpc()` namespace
+    /// (Node `client.rpc().account().getQuota()`, Python
+    /// `client.rpc().account.get_quota()`, Go `client.Rpc().Account().GetQuota()`,
+    /// .NET `client.Rpc().Account().GetQuotaAsync()`), and in Rust at
+    /// `client.rpc().account().get_quota()`. This wrapper is a thin shortcut
+    /// for that same call.
+    pub async fn get_quota(&self) -> Result<GetQuotaResponse, Error> {
+        self.rpc().account().get_quota().await
+    }
+
+    /// Return the OS process ID of the CLI child process, if one was spawned.
+    pub fn pid(&self) -> Option<u32> {
+        self.inner.child.lock().as_ref().and_then(|c| c.id())
+    }
+
+    /// Cooperatively shut down the client and the CLI child process.
+    ///
+    /// Walks every still-registered session and sends `session.destroy`
+    /// for each one, then kills the CLI child. Errors from per-session
+    /// destroys and the final child-kill are collected into
+    /// [`StopErrors`] rather than short-circuiting on the first failure
+    /// — so callers see the full picture of teardown.
+    ///
+    /// If you have already called [`Session::disconnect`] on every
+    /// session this client created, the per-session destroy step is a
+    /// no-op (the router map is empty); only the child-kill remains.
+    ///
+    /// [`Session::disconnect`]: crate::session::Session::disconnect
+    ///
+    /// # Cancel safety
+    ///
+    /// **Cancel-unsafe but recoverable.** The body sequentially destroys
+    /// every registered session (each via [`Client::call`](Self::call),
+    /// individually cancel-safe) before killing the child. Cancelling
+    /// `stop()` mid-loop leaves some sessions still in the router map
+    /// and the child still running. Recovery: call [`force_stop`](Self::force_stop)
+    /// (sync, kills the child unconditionally and clears router state)
+    /// or call `stop()` again with a fresh future. The documented
+    /// `tokio::time::timeout(..., client.stop())` pattern in the example
+    /// below uses `force_stop` as the fallback for exactly this case.
+    pub async fn stop(&self) -> Result<(), StopErrors> {
+        let pid = self.pid();
+        info!(pid = ?pid, "stopping CLI process");
+        let mut errors: Vec<Error> = Vec::new();
+
+        // Snapshot the registered session IDs without holding the router
+        // lock across the destroy RPCs.
+        for session_id in self.inner.router.session_ids() {
+            match self
+                .call(
+                    "session.destroy",
+                    Some(serde_json::json!({ "sessionId": session_id })),
+                )
+                .await
+            {
+                Ok(_) => {}
+                Err(e) => {
+                    warn!(
+                        session_id = %session_id,
+                        error = %e,
+                        "session.destroy failed during Client::stop",
+                    );
+                    errors.push(e);
+                }
+            }
+            self.inner.router.unregister(&session_id);
+        }
+
+        let child = self.inner.child.lock().take();
+        *self.inner.state.lock() = ConnectionState::Disconnected;
+        if let Some(mut child) = child
+            && let Err(e) = child.kill().await
+        {
+            errors.push(Error::Io(e));
+        }
+
+        info!(pid = ?pid, errors = errors.len(), "CLI process stopped");
+        if errors.is_empty() {
+            Ok(())
+        } else {
+            Err(StopErrors(errors))
+        }
+    }
+
+    /// Forcibly stop the CLI process without waiting for it to exit.
+    ///
+    /// Synchronous fallback when [`stop`](Self::stop) is unsuitable — for
+    /// example when the awaiting tokio runtime is shutting down or the
+    /// process is wedged on I/O. Sends a kill signal without awaiting
+    /// reaper completion and immediately drops all per-session router
+    /// state so dependent tasks observe a closed channel rather than a
+    /// hang.
+    ///
+    /// # Cancel safety
+    ///
+    /// **Synchronous and infallible by construction.** Not async; cannot
+    /// be cancelled. Designed as the recovery path when [`stop`](Self::stop)
+    /// is wrapped in a timeout that elapses.
+ /// + /// # Example + /// + /// ```no_run + /// # async fn example(client: github_copilot_sdk::Client) { + /// // Try graceful shutdown first; fall back to force_stop if hung. + /// match tokio::time::timeout( + /// std::time::Duration::from_secs(5), + /// client.stop(), + /// ).await { + /// Ok(_) => {} + /// Err(_) => client.force_stop(), + /// } + /// # } + /// ``` + pub fn force_stop(&self) { + let pid = self.pid(); + info!(pid = ?pid, "force-stopping CLI process"); + if let Some(mut child) = self.inner.child.lock().take() + && let Err(e) = child.start_kill() + { + error!(pid = ?pid, error = %e, "failed to send kill signal"); + } + // Drop all session channels so any awaiters see a closed channel + // instead of waiting for responses that will never arrive. + self.inner.router.clear(); + *self.inner.state.lock() = ConnectionState::Disconnected; + } + + /// Subscribe to lifecycle events. + /// + /// Returns a [`LifecycleSubscription`] that yields every + /// [`SessionLifecycleEvent`] sent by the CLI. Drop the value to + /// unsubscribe; there is no separate cancel handle. + /// + /// The returned handle implements both an inherent + /// [`recv`](LifecycleSubscription::recv) method and [`Stream`](tokio_stream::Stream), + /// so callers can use a `while let` loop or any combinator from + /// `tokio_stream::StreamExt` / `futures::StreamExt`. + /// + /// Each subscriber maintains its own queue. If a consumer cannot keep + /// up, the oldest events are dropped and `recv` returns + /// [`RecvError::Lagged`] with the count of skipped events; consumers + /// should match on it and continue. Slow consumers do not block the + /// producer. + /// + /// To filter by event type, match on `event.event_type` in the + /// consumer task. There is no built-in typed filter — `match` is more + /// flexible and keeps the API surface small. 
+ /// + /// # Example + /// + /// ```no_run + /// # async fn example(client: github_copilot_sdk::Client) { + /// let mut events = client.subscribe_lifecycle(); + /// tokio::spawn(async move { + /// while let Ok(event) = events.recv().await { + /// println!("session {} -> {:?}", event.session_id, event.event_type); + /// } + /// }); + /// # } + /// ``` + pub fn subscribe_lifecycle(&self) -> LifecycleSubscription { + LifecycleSubscription::new(self.inner.lifecycle_tx.subscribe()) + } + + /// Return the current [`ConnectionState`]. + /// + /// The state advances to [`Connected`](ConnectionState::Connected) once + /// [`Client::start`] / [`Client::from_streams`] returns successfully and + /// drops to [`Disconnected`](ConnectionState::Disconnected) after + /// [`stop`](Self::stop) or [`force_stop`](Self::force_stop). + pub fn state(&self) -> ConnectionState { + *self.inner.state.lock() + } +} + +impl Drop for ClientInner { + fn drop(&mut self) { + if let Some(ref mut child) = *self.child.lock() { + let pid = child.id(); + if let Err(e) = child.start_kill() { + error!(pid = ?pid, error = %e, "failed to kill CLI process on drop"); + } else { + info!(pid = ?pid, "kill signal sent for CLI process on drop"); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn is_transport_failure_matches_request_cancelled() { + let err = Error::Protocol(ProtocolError::RequestCancelled); + assert!(err.is_transport_failure()); + } + + #[test] + fn is_transport_failure_matches_io_error() { + let err = Error::Io(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "gone")); + assert!(err.is_transport_failure()); + } + + #[test] + fn is_transport_failure_rejects_rpc_error() { + let err = Error::Rpc { + code: -1, + message: "bad".into(), + }; + assert!(!err.is_transport_failure()); + } + + #[test] + fn is_transport_failure_rejects_session_error() { + let err = Error::Session(SessionError::NotFound("s1".into())); + assert!(!err.is_transport_failure()); + } + + #[test] + fn 
client_options_builder_composes() { + let opts = ClientOptions::new() + .with_program(CliProgram::Path(PathBuf::from("/usr/local/bin/copilot"))) + .with_prefix_args(["node"]) + .with_cwd(PathBuf::from("/tmp")) + .with_env([("KEY", "value")]) + .with_env_remove(["UNWANTED"]) + .with_extra_args(["--quiet"]) + .with_github_token("ghp_test") + .with_use_logged_in_user(false) + .with_log_level(LogLevel::Debug) + .with_session_idle_timeout_seconds(120); + assert!(matches!(opts.program, CliProgram::Path(_))); + assert_eq!(opts.prefix_args, vec![std::ffi::OsString::from("node")]); + assert_eq!(opts.cwd, PathBuf::from("/tmp")); + assert_eq!( + opts.env, + vec![( + std::ffi::OsString::from("KEY"), + std::ffi::OsString::from("value") + )] + ); + assert_eq!(opts.env_remove, vec![std::ffi::OsString::from("UNWANTED")]); + assert_eq!(opts.extra_args, vec!["--quiet".to_string()]); + assert_eq!(opts.github_token.as_deref(), Some("ghp_test")); + assert_eq!(opts.use_logged_in_user, Some(false)); + assert!(matches!(opts.log_level, Some(LogLevel::Debug))); + assert_eq!(opts.session_idle_timeout_seconds, Some(120)); + } + + #[test] + fn is_transport_failure_rejects_other_protocol_errors() { + let err = Error::Protocol(ProtocolError::CliStartupTimeout); + assert!(!err.is_transport_failure()); + } + + #[test] + fn build_command_lets_env_remove_strip_injected_token() { + let opts = ClientOptions { + github_token: Some("secret".to_string()), + env_remove: vec![std::ffi::OsString::from("COPILOT_SDK_AUTH_TOKEN")], + ..Default::default() + }; + let cmd = Client::build_command(Path::new("/bin/echo"), &opts); + // get_envs() iter yields the latest action per key — None means removed. 
+        let action = cmd
+            .as_std()
+            .get_envs()
+            .find(|(k, _)| *k == std::ffi::OsStr::new("COPILOT_SDK_AUTH_TOKEN"))
+            .map(|(_, v)| v);
+        assert_eq!(
+            action,
+            Some(None),
+            "env_remove should win over github_token"
+        );
+    }
+
+    #[test]
+    fn build_command_lets_env_override_injected_token() {
+        let opts = ClientOptions {
+            github_token: Some("from-options".to_string()),
+            env: vec![(
+                std::ffi::OsString::from("COPILOT_SDK_AUTH_TOKEN"),
+                std::ffi::OsString::from("from-env"),
+            )],
+            ..Default::default()
+        };
+        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
+        let value = cmd
+            .as_std()
+            .get_envs()
+            .find(|(k, _)| *k == std::ffi::OsStr::new("COPILOT_SDK_AUTH_TOKEN"))
+            .and_then(|(_, v)| v);
+        assert_eq!(value, Some(std::ffi::OsStr::new("from-env")));
+    }
+
+    #[test]
+    fn build_command_injects_github_token_by_default() {
+        let opts = ClientOptions {
+            github_token: Some("just-the-token".to_string()),
+            ..Default::default()
+        };
+        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
+        let value = cmd
+            .as_std()
+            .get_envs()
+            .find(|(k, _)| *k == std::ffi::OsStr::new("COPILOT_SDK_AUTH_TOKEN"))
+            .and_then(|(_, v)| v);
+        assert_eq!(value, Some(std::ffi::OsStr::new("just-the-token")));
+    }
+
+    fn env_value<'a>(cmd: &'a tokio::process::Command, key: &str) -> Option<&'a std::ffi::OsStr> {
+        cmd.as_std()
+            .get_envs()
+            .find(|(k, _)| *k == std::ffi::OsStr::new(key))
+            .and_then(|(_, v)| v)
+    }
+
+    #[test]
+    fn telemetry_config_builder_composes() {
+        let cfg = TelemetryConfig::new()
+            .with_otlp_endpoint("http://collector:4318")
+            .with_file_path(PathBuf::from("/var/log/copilot.jsonl"))
+            .with_exporter_type(OtelExporterType::OtlpHttp)
+            .with_source_name("my-app")
+            .with_capture_content(true);
+
+        assert_eq!(cfg.otlp_endpoint.as_deref(), Some("http://collector:4318"));
+        assert_eq!(
+            cfg.file_path.as_deref(),
+            Some(Path::new("/var/log/copilot.jsonl")),
+        );
+        assert_eq!(cfg.exporter_type, Some(OtelExporterType::OtlpHttp));
+        assert_eq!(cfg.source_name.as_deref(), Some("my-app"));
+        assert_eq!(cfg.capture_content, Some(true));
+        assert!(!cfg.is_empty());
+        assert!(TelemetryConfig::new().is_empty());
+    }
+
+    #[test]
+    fn build_command_sets_otel_env_when_telemetry_enabled() {
+        let opts = ClientOptions {
+            telemetry: Some(TelemetryConfig {
+                otlp_endpoint: Some("http://collector:4318".to_string()),
+                file_path: Some(PathBuf::from("/var/log/copilot.jsonl")),
+                exporter_type: Some(OtelExporterType::OtlpHttp),
+                source_name: Some("my-app".to_string()),
+                capture_content: Some(true),
+            }),
+            ..Default::default()
+        };
+        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
+        assert_eq!(
+            env_value(&cmd, "COPILOT_OTEL_ENABLED"),
+            Some(std::ffi::OsStr::new("true")),
+        );
+        assert_eq!(
+            env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT"),
+            Some(std::ffi::OsStr::new("http://collector:4318")),
+        );
+        assert_eq!(
+            env_value(&cmd, "COPILOT_OTEL_FILE_EXPORTER_PATH"),
+            Some(std::ffi::OsStr::new("/var/log/copilot.jsonl")),
+        );
+        assert_eq!(
+            env_value(&cmd, "COPILOT_OTEL_EXPORTER_TYPE"),
+            Some(std::ffi::OsStr::new("otlp-http")),
+        );
+        assert_eq!(
+            env_value(&cmd, "COPILOT_OTEL_SOURCE_NAME"),
+            Some(std::ffi::OsStr::new("my-app")),
+        );
+        assert_eq!(
+            env_value(&cmd, "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"),
+            Some(std::ffi::OsStr::new("true")),
+        );
+    }
+
+    #[test]
+    fn build_command_omits_otel_env_when_telemetry_none() {
+        let opts = ClientOptions::default();
+        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
+        for key in [
+            "COPILOT_OTEL_ENABLED",
+            "OTEL_EXPORTER_OTLP_ENDPOINT",
+            "COPILOT_OTEL_FILE_EXPORTER_PATH",
+            "COPILOT_OTEL_EXPORTER_TYPE",
+            "COPILOT_OTEL_SOURCE_NAME",
+            "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT",
+        ] {
+            assert!(
+                env_value(&cmd, key).is_none(),
+                "expected {key} to be unset when telemetry is None",
+            );
+        }
+    }
+
+    #[test]
+    fn build_command_omits_unset_telemetry_fields() {
+        let opts = ClientOptions {
+            telemetry: Some(TelemetryConfig {
+                otlp_endpoint: Some("http://collector:4318".to_string()),
+                ..Default::default()
+            }),
+            ..Default::default()
+        };
+        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
+        // The one set field plus the implicit enabled flag should propagate.
+        assert_eq!(
+            env_value(&cmd, "COPILOT_OTEL_ENABLED"),
+            Some(std::ffi::OsStr::new("true")),
+        );
+        assert_eq!(
+            env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT"),
+            Some(std::ffi::OsStr::new("http://collector:4318")),
+        );
+        // None of the other fields should leak as env vars.
+        for key in [
+            "COPILOT_OTEL_FILE_EXPORTER_PATH",
+            "COPILOT_OTEL_EXPORTER_TYPE",
+            "COPILOT_OTEL_SOURCE_NAME",
+            "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT",
+        ] {
+            assert!(env_value(&cmd, key).is_none(), "{key} should be unset");
+        }
+    }
+
+    #[test]
+    fn build_command_lets_user_env_override_telemetry() {
+        let opts = ClientOptions {
+            telemetry: Some(TelemetryConfig {
+                otlp_endpoint: Some("http://from-config:4318".to_string()),
+                ..Default::default()
+            }),
+            env: vec![(
+                std::ffi::OsString::from("OTEL_EXPORTER_OTLP_ENDPOINT"),
+                std::ffi::OsString::from("http://from-user-env:4318"),
+            )],
+            ..Default::default()
+        };
+        let cmd = Client::build_command(Path::new("/bin/echo"), &opts);
+        assert_eq!(
+            env_value(&cmd, "OTEL_EXPORTER_OTLP_ENDPOINT"),
+            Some(std::ffi::OsStr::new("http://from-user-env:4318")),
+            "user-supplied options.env should override telemetry config",
+        );
+    }
+
+    #[test]
+    fn telemetry_config_capture_content_serializes_as_lowercase_bool() {
+        let opts_true = ClientOptions {
+            telemetry: Some(TelemetryConfig {
+                capture_content: Some(true),
+                ..Default::default()
+            }),
+            ..Default::default()
+        };
+        let opts_false = ClientOptions {
+            telemetry: Some(TelemetryConfig {
+                capture_content: Some(false),
+                ..Default::default()
+            }),
+            ..Default::default()
+        };
+        let cmd_true = Client::build_command(Path::new("/bin/echo"), &opts_true);
+        let cmd_false = Client::build_command(Path::new("/bin/echo"), &opts_false);
+        assert_eq!(
+            env_value(
+                &cmd_true,
+                "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"
+            ),
+            Some(std::ffi::OsStr::new("true")),
+        );
+        assert_eq!(
+            env_value(
+                &cmd_false,
+                "OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"
+            ),
+            Some(std::ffi::OsStr::new("false")),
+        );
+    }
+
+    #[test]
+    fn session_idle_timeout_args_are_omitted_by_default() {
+        let opts = ClientOptions::default();
+        assert!(Client::session_idle_timeout_args(&opts).is_empty());
+    }
+
+    #[test]
+    fn session_idle_timeout_args_omitted_for_zero() {
+        let opts = ClientOptions {
+            session_idle_timeout_seconds: Some(0),
+            ..Default::default()
+        };
+        assert!(Client::session_idle_timeout_args(&opts).is_empty());
+    }
+
+    #[test]
+    fn session_idle_timeout_args_emit_flag_for_positive_value() {
+        let opts = ClientOptions {
+            session_idle_timeout_seconds: Some(300),
+            ..Default::default()
+        };
+        assert_eq!(
+            Client::session_idle_timeout_args(&opts),
+            vec!["--session-idle-timeout".to_string(), "300".to_string()]
+        );
+    }
+
+    #[test]
+    fn log_level_str_round_trips() {
+        for level in [
+            LogLevel::None,
+            LogLevel::Error,
+            LogLevel::Warning,
+            LogLevel::Info,
+            LogLevel::Debug,
+            LogLevel::All,
+        ] {
+            let s = level.as_str();
+            let json = serde_json::to_string(&level).unwrap();
+            assert_eq!(json, format!("\"{s}\""));
+            let parsed: LogLevel = serde_json::from_str(&json).unwrap();
+            assert_eq!(parsed, level);
+        }
+    }
+
+    #[test]
+    fn client_options_debug_redacts_handler() {
+        struct StubHandler;
+        #[async_trait]
+        impl ListModelsHandler for StubHandler {
+            async fn list_models(&self) -> Result<Vec<Model>, Error> {
+                Ok(vec![])
+            }
+        }
+        let opts = ClientOptions {
+            on_list_models: Some(Arc::new(StubHandler)),
+            github_token: Some("secret-token".into()),
+            ..Default::default()
+        };
+        let debug = format!("{opts:?}");
+        assert!(debug.contains("on_list_models: Some(\"<redacted>\")"));
+        assert!(debug.contains("github_token: Some(\"<redacted>\")"));
+        assert!(!debug.contains("secret-token"));
+    }
+
+    #[tokio::test]
+    async fn list_models_uses_on_list_models_handler_when_set() {
+        use std::sync::atomic::{AtomicUsize, Ordering};
+
+        struct CountingHandler {
+            calls: Arc<AtomicUsize>,
+            models: Vec<Model>,
+        }
+        #[async_trait]
+        impl ListModelsHandler for CountingHandler {
+            async fn list_models(&self) -> Result<Vec<Model>, Error> {
+                self.calls.fetch_add(1, Ordering::SeqCst);
+                Ok(self.models.clone())
+            }
+        }
+
+        let calls = Arc::new(AtomicUsize::new(0));
+        let model = Model {
+            billing: None,
+            capabilities: ModelCapabilities {
+                limits: None,
+                supports: None,
+            },
+            default_reasoning_effort: None,
+            id: "byok-gpt-4".into(),
+            name: "BYOK GPT-4".into(),
+            policy: None,
+            supported_reasoning_efforts: Vec::new(),
+        };
+        let handler = Arc::new(CountingHandler {
+            calls: Arc::clone(&calls),
+            models: vec![model.clone()],
+        });
+
+        // We can't call list_models() through Client::start without a CLI, but we
+        // can exercise the override path by directly constructing a Client whose
+        // inner has the handler set. This is the same dispatch path as the real
+        // call; from_streams's None default is replaced via inner construction.
+        let inner = ClientInner {
+            child: parking_lot::Mutex::new(None),
+            rpc: {
+                let (req_tx, _req_rx) = mpsc::unbounded_channel();
+                let (notif_tx, _notif_rx) = broadcast::channel(16);
+                let (read_pipe, _write_pipe) = tokio::io::duplex(64);
+                let (_unused_read, write_pipe) = tokio::io::duplex(64);
+                JsonRpcClient::new(write_pipe, read_pipe, notif_tx, req_tx)
+            },
+            cwd: PathBuf::from("."),
+            request_rx: parking_lot::Mutex::new(None),
+            notification_tx: broadcast::channel(16).0,
+            router: router::SessionRouter::new(),
+            negotiated_protocol_version: OnceLock::new(),
+            server_telemetry_method: parking_lot::Mutex::new(None),
+            state: parking_lot::Mutex::new(ConnectionState::Connected),
+            lifecycle_tx: broadcast::channel(16).0,
+            on_list_models: Some(handler),
+            session_fs_configured: false,
+            on_get_trace_context: None,
+        };
+        let client = Client {
+            inner: Arc::new(inner),
+        };
+
+        let result = client.list_models().await.unwrap();
+        assert_eq!(result.len(), 1);
+        assert_eq!(result[0].id, "byok-gpt-4");
+        assert_eq!(calls.load(Ordering::SeqCst), 1);
+    }
+}
diff --git a/rust/src/permission.rs b/rust/src/permission.rs
new file mode 100644
index 000000000..02db23e06
--- /dev/null
+++ b/rust/src/permission.rs
@@ -0,0 +1,166 @@
+//! Permission-policy helpers that compose with an existing
+//! [`SessionHandler`](crate::handler::SessionHandler).
+//!
+//! These wrap an inner handler and override **only** permission requests,
+//! forwarding every other event (tool calls, user input, elicitation,
+//! session events) to the inner handler. Use them when you have a custom
+//! tool handler, typically a [`ToolHandlerRouter`](crate::tool::ToolHandlerRouter),
+//! but want a one-line policy for permission prompts.
+//!
+//! For a full handler that approves or denies everything, see
+//! [`ApproveAllHandler`](crate::handler::ApproveAllHandler) and
+//! [`DenyAllHandler`](crate::handler::DenyAllHandler).
+//!
+//! # Example
+//!
+//! ```rust,no_run
+//! # use std::sync::Arc;
+//! # use github_copilot_sdk::handler::ApproveAllHandler;
+//! # use github_copilot_sdk::permission;
+//! # use github_copilot_sdk::tool::ToolHandlerRouter;
+//! let router = ToolHandlerRouter::new(vec![], Arc::new(ApproveAllHandler));
+//! // Inherit the router's tool dispatch but auto-approve all permission prompts:
+//! let handler = permission::approve_all(Arc::new(router));
+//! ```
+
+use std::sync::Arc;
+
+use async_trait::async_trait;
+
+use crate::handler::{HandlerEvent, HandlerResponse, PermissionResult, SessionHandler};
+use crate::types::PermissionRequestData;
+
+/// Wrap `inner` so that every [`HandlerEvent::PermissionRequest`] is
+/// auto-approved. All other events are forwarded to `inner`.
+pub fn approve_all(inner: Arc<dyn SessionHandler>) -> Arc<dyn SessionHandler> {
+    Arc::new(PermissionOverrideHandler {
+        inner,
+        policy: Policy::ApproveAll,
+    })
+}
+
+/// Wrap `inner` so that every [`HandlerEvent::PermissionRequest`] is
+/// auto-denied. All other events are forwarded to `inner`.
+pub fn deny_all(inner: Arc<dyn SessionHandler>) -> Arc<dyn SessionHandler> {
+    Arc::new(PermissionOverrideHandler {
+        inner,
+        policy: Policy::DenyAll,
+    })
+}
+
+/// Wrap `inner` with a closure-based policy: `predicate` is called for each
+/// permission request; `true` approves, `false` denies. All other events
+/// are forwarded to `inner`.
+///
+/// ```rust,no_run
+/// # use std::sync::Arc;
+/// # use github_copilot_sdk::handler::ApproveAllHandler;
+/// # use github_copilot_sdk::permission;
+/// let inner = Arc::new(ApproveAllHandler);
+/// let handler = permission::approve_if(inner, |data| {
+///     // Inspect data.extra (the raw JSON payload) for custom policy.
+///     data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell")
+/// });
+/// # let _ = handler;
+/// ```
+pub fn approve_if<F>(inner: Arc<dyn SessionHandler>, predicate: F) -> Arc<dyn SessionHandler>
+where
+    F: Fn(&PermissionRequestData) -> bool + Send + Sync + 'static,
+{
+    Arc::new(PermissionOverrideHandler {
+        inner,
+        policy: Policy::Predicate(Arc::new(predicate)),
+    })
+}
+
+enum Policy {
+    ApproveAll,
+    DenyAll,
+    Predicate(Arc<dyn Fn(&PermissionRequestData) -> bool + Send + Sync>),
+}
+
+struct PermissionOverrideHandler {
+    inner: Arc<dyn SessionHandler>,
+    policy: Policy,
+}
+
+#[async_trait]
+impl SessionHandler for PermissionOverrideHandler {
+    async fn on_event(&self, event: HandlerEvent) -> HandlerResponse {
+        match event {
+            HandlerEvent::PermissionRequest { ref data, .. } => {
+                let approved = match &self.policy {
+                    Policy::ApproveAll => true,
+                    Policy::DenyAll => false,
+                    Policy::Predicate(f) => f(data),
+                };
+                HandlerResponse::Permission(if approved {
+                    PermissionResult::Approved
+                } else {
+                    PermissionResult::Denied
+                })
+            }
+            other => self.inner.on_event(other).await,
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::handler::ApproveAllHandler;
+    use crate::types::{RequestId, SessionId};
+
+    fn request() -> HandlerEvent {
+        HandlerEvent::PermissionRequest {
+            session_id: SessionId::from("s1"),
+            request_id: RequestId::new("1"),
+            data: PermissionRequestData {
+                extra: serde_json::json!({"tool": "shell"}),
+                ..Default::default()
+            },
+        }
+    }
+
+    #[tokio::test]
+    async fn approve_all_approves_permission_requests() {
+        let h = approve_all(Arc::new(ApproveAllHandler));
+        match h.on_event(request()).await {
+            HandlerResponse::Permission(PermissionResult::Approved) => {}
+            other => panic!("expected Approved, got {other:?}"),
+        }
+    }
+
+    #[tokio::test]
+    async fn deny_all_denies_permission_requests() {
+        let h = deny_all(Arc::new(ApproveAllHandler));
+        match h.on_event(request()).await {
+            HandlerResponse::Permission(PermissionResult::Denied) => {}
+            other => panic!("expected Denied, got {other:?}"),
+        }
+    }
+
+    #[tokio::test]
+    async fn approve_if_consults_predicate() {
+        let h = approve_if(Arc::new(ApproveAllHandler), |data| {
+            data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell")
+        });
+        match h.on_event(request()).await {
+            HandlerResponse::Permission(PermissionResult::Denied) => {}
+            other => panic!("expected Denied for shell, got {other:?}"),
+        }
+    }
+
+    #[tokio::test]
+    async fn non_permission_events_forward_to_inner() {
+        let h = deny_all(Arc::new(ApproveAllHandler));
+        let event = HandlerEvent::ExitPlanMode {
+            session_id: SessionId::from("s1"),
+            data: crate::types::ExitPlanModeData::default(),
+        };
+        match h.on_event(event).await {
+            HandlerResponse::ExitPlanMode(_) => {}
+            other => panic!("expected ExitPlanMode forwarded, got {other:?}"),
+        }
+    }
+}
diff --git a/rust/src/resolve.rs b/rust/src/resolve.rs
new file mode 100644
index 000000000..8521a4b55
--- /dev/null
+++ b/rust/src/resolve.rs
@@ -0,0 +1,677 @@
+use std::collections::HashSet;
+use std::env;
+use std::ffi::OsStr;
+use std::path::{Path, PathBuf};
+
+use serde::Serialize;
+use tracing::warn;
+
+use crate::Error;
+
+/// How the copilot binary was resolved.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
+#[serde(rename_all = "snake_case")]
+pub enum BinarySource {
+    /// Extracted from the build-time embedded binary.
+    Bundled,
+    /// Set via `COPILOT_CLI_PATH` environment variable.
+    EnvOverride,
+    /// Found on PATH or standard search locations.
+    Local,
+}
+
+/// Find the `copilot` CLI binary on the system.
+///
+/// Checks `COPILOT_CLI_PATH` env var first, then searches PATH and common
+/// install locations (homebrew, nvm, nodenv, fnm, volta, cargo, etc.).
+/// Use `COPILOT_CLI_NAME` to override the binary name (default: `copilot`).
+pub fn copilot_binary() -> Result<PathBuf, Error> {
+    copilot_binary_with_source().map(|(path, _)| path)
+}
+
+/// Like [`copilot_binary`] but also reports how the binary was resolved.
+pub fn copilot_binary_with_source() -> Result<(PathBuf, BinarySource), Error> {
+    if let Ok(value) = env::var("COPILOT_CLI_PATH") {
+        let candidate = PathBuf::from(value);
+        if candidate.is_file() {
+            return Ok((candidate, BinarySource::EnvOverride));
+        }
+        if candidate.is_dir()
+            && let Some(found) = find_copilot_in_dir(&candidate)
+        {
+            return Ok((found, BinarySource::EnvOverride));
+        }
+        warn!(path = %candidate.display(), "COPILOT_CLI_PATH set but not usable");
+    }
+
+    if let Some(path) = crate::embeddedcli::path() {
+        return Ok((path, BinarySource::Bundled));
+    }
+
+    for dir in standard_search_paths() {
+        if let Some(found) = find_copilot_in_dir(&dir) {
+            return Ok((found, BinarySource::Local));
+        }
+    }
+
+    Err(Error::BinaryNotFound {
+        name: "copilot",
+        hint: "ensure the GitHub Copilot CLI is installed and on PATH, or set COPILOT_CLI_PATH. use COPILOT_CLI_NAME to override the binary name (default: copilot)",
+    })
+}
+
+/// Find the `copilot` CLI binary using only the current PATH entries.
+///
+/// This is intentionally narrower than [`copilot_binary`]: it does not honor
+/// override env vars and does not search inferred install locations.
+pub fn copilot_binary_on_path() -> Result<PathBuf, Error> {
+    if let Some(found) = find_executable_in_path(
+        env::var_os("PATH").as_deref(),
+        &literal_copilot_executable_names(),
+    ) {
+        return Ok(found);
+    }
+
+    Err(Error::BinaryNotFound {
+        name: "copilot",
+        hint: "ensure the `copilot` command is installed and available on PATH",
+    })
+}
+
+/// Build an extended `PATH` by prepending `extra` dirs to the standard
+/// search paths (current PATH + common install locations).
+pub fn extended_path(extra: &[PathBuf]) -> Option<std::ffi::OsString> {
+    let mut paths = SearchPaths::new();
+    for p in extra {
+        paths.push(p.clone());
+    }
+    paths.append_standard();
+    if paths.is_empty() {
+        return None;
+    }
+    env::join_paths(paths).ok()
+}
+
+fn copilot_executable_names() -> Vec<String> {
+    let base = env::var("COPILOT_CLI_NAME").unwrap_or_else(|_| "copilot".to_string());
+    executable_names_for_base(&base)
+}
+
+fn literal_copilot_executable_names() -> Vec<String> {
+    executable_names_for_base("copilot")
+}
+
+fn executable_names_for_base(base: &str) -> Vec<String> {
+    #[cfg(target_os = "windows")]
+    {
+        vec![
+            format!("{}.exe", base),
+            format!("{}.cmd", base),
+            format!("{}.bat", base),
+        ]
+    }
+    #[cfg(not(target_os = "windows"))]
+    {
+        vec![base.to_string()]
+    }
+}
+
+fn find_executable(dir: &Path, names: &[impl AsRef<OsStr>]) -> Option<PathBuf> {
+    if dir.as_os_str().is_empty() {
+        return None;
+    }
+    names
+        .iter()
+        .map(|n| dir.join(n.as_ref()))
+        .find(|c| c.is_file())
+}
+
+fn find_copilot_in_dir(dir: &Path) -> Option<PathBuf> {
+    find_executable(dir, &copilot_executable_names())
+}
+
+fn find_executable_in_path(
+    path_env: Option<&OsStr>,
+    names: &[impl AsRef<OsStr>],
+) -> Option<PathBuf> {
+    let path_env = path_env?;
+    for dir in env::split_paths(path_env) {
+        if let Some(found) = find_executable(&dir, names) {
+            return Some(found);
+        }
+    }
+    None
+}
+
+/// Ordered, deduplicated collection of directory paths to search for binaries.
+///
+/// Paths are stored in insertion order. Duplicates and empty paths are
+/// silently dropped on `push`. Implements `IntoIterator` so it can be passed
+/// directly to `env::join_paths` or used in a `for` loop.
+struct SearchPaths {
+    seen: HashSet<PathBuf>,
+    paths: Vec<PathBuf>,
+}
+
+impl SearchPaths {
+    fn new() -> Self {
+        Self {
+            seen: HashSet::new(),
+            paths: Vec::new(),
+        }
+    }
+
+    /// Add a path if it hasn't been seen before. Empty paths are ignored.
+    fn push(&mut self, path: PathBuf) {
+        if !path.as_os_str().is_empty() && self.seen.insert(path.clone()) {
+            self.paths.push(path);
+        }
+    }
+
+    fn is_empty(&self) -> bool {
+        self.paths.is_empty()
+    }
+
+    /// Append the standard search paths: current PATH, home-relative dirs,
+    /// version manager paths (nvm, nodenv, fnm), and platform-specific dirs.
+    fn append_standard(&mut self) {
+        if let Some(existing) = env::var_os("PATH") {
+            for p in env::split_paths(&existing) {
+                self.push(p);
+            }
+        }
+
+        if let Some(home) = dirs::home_dir() {
+            self.push(home.join(".local/bin"));
+            self.push(home.join(".cargo/bin"));
+            self.push(home.join(".bun/bin"));
+            self.push(home.join(".npm-global/bin"));
+            self.push(home.join(".yarn/bin"));
+            self.push(home.join(".volta/bin"));
+            self.push(home.join(".asdf/shims"));
+            self.push(home.join("bin"));
+        }
+
+        // Platform-specific standard dirs come before version-manager paths
+        // so that the system-installed node (e.g. /opt/homebrew/bin/node)
+        // takes precedence over arbitrary old versions found under
+        // ~/.nvm/versions, ~/.nodenv/versions, etc.
+        #[cfg(target_os = "macos")]
+        {
+            self.push(PathBuf::from("/opt/homebrew/bin"));
+            self.push(PathBuf::from("/usr/local/bin"));
+            self.push(PathBuf::from("/usr/bin"));
+            self.push(PathBuf::from("/bin"));
+            self.push(PathBuf::from("/usr/sbin"));
+            self.push(PathBuf::from("/sbin"));
+        }
+
+        #[cfg(target_os = "linux")]
+        {
+            self.push(PathBuf::from("/usr/local/bin"));
+            self.push(PathBuf::from("/usr/bin"));
+            self.push(PathBuf::from("/bin"));
+            self.push(PathBuf::from("/snap/bin"));
+        }
+
+        #[cfg(target_os = "windows")]
+        {
+            if let Some(appdata) = env::var_os("APPDATA") {
+                self.push(PathBuf::from(appdata).join("npm"));
+            }
+            if let Some(local) = env::var_os("LOCALAPPDATA") {
+                let local = PathBuf::from(local);
+                self.push(local.join("Programs"));
+                // User-scope winget install of Git for Windows.
+                self.push(local.join("Programs").join("Git").join("cmd"));
+                self.push(local.join("Programs").join("Git").join("bin"));
+            }
+            // Git for Windows standard machine-scope install locations.
+            for env_var in ["ProgramFiles", "ProgramW6432", "ProgramFiles(x86)"] {
+                if let Some(program_files) = env::var_os(env_var) {
+                    let program_files = PathBuf::from(program_files);
+                    self.push(program_files.join("Git").join("cmd"));
+                    self.push(program_files.join("Git").join("bin"));
+                }
+            }
+        }
+
+        // Version manager paths are a fallback for binary discovery:
+        // they enumerate every installed version, so an arbitrary old
+        // node/copilot can appear first if filesystem ordering is unlucky.
+        for p in collect_nvm_paths() {
+            self.push(p);
+        }
+        for p in collect_nodenv_paths() {
+            self.push(p);
+        }
+        for p in collect_fnm_paths() {
+            self.push(p);
+        }
+    }
+}
+
+impl IntoIterator for SearchPaths {
+    type IntoIter = std::vec::IntoIter<PathBuf>;
+    type Item = PathBuf;
+
+    fn into_iter(self) -> Self::IntoIter {
+        self.paths.into_iter()
+    }
+}
+
+/// Collect standard search paths for binary resolution.
+fn standard_search_paths() -> SearchPaths {
+    let mut paths = SearchPaths::new();
+    paths.append_standard();
+    paths
+}
+
+fn collect_nvm_paths() -> Vec<PathBuf> {
+    let mut paths = Vec::new();
+    let nvm_dir = env::var_os("NVM_DIR")
+        .map(PathBuf::from)
+        .or_else(|| dirs::home_dir().map(|home| home.join(".nvm")));
+    let Some(nvm_dir) = nvm_dir else {
+        return paths;
+    };
+    let versions_dir = nvm_dir.join("versions").join("node");
+    let entries = match std::fs::read_dir(&versions_dir) {
+        Ok(entries) => entries,
+        Err(_) => return paths,
+    };
+    for entry in entries.flatten() {
+        let path = entry.path();
+        if path.is_dir() {
+            paths.push(path.join("bin"));
+        }
+    }
+    paths
+}
+
+fn collect_nodenv_paths() -> Vec<PathBuf> {
+    let mut paths = Vec::new();
+    let root = env::var_os("NODENV_ROOT")
+        .map(PathBuf::from)
+        .or_else(|| dirs::home_dir().map(|home| home.join(".nodenv")));
+    let Some(root) = root else {
+        return paths;
+    };
+    let versions_dir = root.join("versions");
+    let entries = match std::fs::read_dir(&versions_dir) {
+        Ok(entries) => entries,
+        Err(_) => return paths,
+    };
+    for entry in entries.flatten() {
+        let path = entry.path();
+        if path.is_dir() {
+            paths.push(path.join("bin"));
+        }
+    }
+    paths
+}
+
+fn fnm_root_candidates_from(
+    fnm_dir: Option<PathBuf>,
+    xdg_data_home: Option<PathBuf>,
+    home: Option<PathBuf>,
+) -> Vec<PathBuf> {
+    let mut roots = SearchPaths::new();
+
+    if let Some(fnm_dir) = fnm_dir.filter(|path| !path.as_os_str().is_empty()) {
+        roots.push(fnm_dir);
+    }
+
+    if let Some(xdg_data_home) = xdg_data_home.filter(|path| !path.as_os_str().is_empty()) {
+        roots.push(xdg_data_home.join("fnm"));
+    }
+
+    if let Some(home) = home {
+        roots.push(home.join(".local").join("share").join("fnm"));
+        roots.push(home.join(".fnm"));
+    }
+
+    roots.paths
+}
+
+fn collect_fnm_paths() -> Vec<PathBuf> {
+    let roots = fnm_root_candidates_from(
+        env::var_os("FNM_DIR").map(PathBuf::from),
+        env::var_os("XDG_DATA_HOME").map(PathBuf::from),
+        dirs::home_dir(),
+    );
+
+    let mut paths = SearchPaths::new();
+    for root in &roots {
+        paths.push(root.join("aliases").join("default").join("bin"));
+
+        let versions_dir = root.join("node-versions");
+        let entries = match std::fs::read_dir(&versions_dir) {
+            Ok(entries) => entries,
+            Err(_) => continue,
+        };
+        for entry in entries.flatten() {
+            let path = entry.path();
+            if path.is_dir() {
+                paths.push(path.join("installation").join("bin"));
+            }
+        }
+    }
+
+    paths.paths
+}
+
+#[cfg(test)]
+mod tests {
+    use std::path::{Path, PathBuf};
+    use std::{env, fs};
+
+    use serial_test::serial;
+    use tempfile::tempdir;
+
+    use super::{
+        copilot_binary_on_path, find_executable_in_path, fnm_root_candidates_from,
+        literal_copilot_executable_names,
+    };
+
+    #[test]
+    fn fnm_root_candidates_include_xdg_and_legacy_locations() {
+        let home = PathBuf::from("/tmp/copilot-home");
+
+        let roots = fnm_root_candidates_from(None, None, Some(home.clone()));
+
+        assert_eq!(
+            roots,
+            vec![
+                home.join(".local").join("share").join("fnm"),
+                home.join(".fnm"),
+            ]
+        );
+    }
+
+    #[test]
+    fn fnm_root_candidates_prefer_explicit_locations_first() {
+        let home = PathBuf::from("/tmp/copilot-home");
+        let explicit_fnm_dir = PathBuf::from("/tmp/custom-fnm");
+        let xdg_data_home = PathBuf::from("/tmp/xdg-data");
+
+        let roots = fnm_root_candidates_from(
+            Some(explicit_fnm_dir.clone()),
+            Some(xdg_data_home.clone()),
+            Some(home.clone()),
+        );
+
+        assert_eq!(
+            roots,
+            vec![
+                explicit_fnm_dir,
+                xdg_data_home.join("fnm"),
+                home.join(".local").join("share").join("fnm"),
+                home.join(".fnm"),
+            ]
+        );
+    }
+
+    #[test]
+    fn fnm_root_candidates_ignore_empty_xdg_data_home() {
+        let home = PathBuf::from("/tmp/copilot-home");
+
+        let roots = fnm_root_candidates_from(None, Some(PathBuf::new()), Some(home.clone()));
+
+        assert_eq!(
+            roots,
+            vec![
+                home.join(".local").join("share").join("fnm"),
+                home.join(".fnm"),
+            ]
+        );
+        assert!(!roots.iter().any(|path| path == &PathBuf::from("fnm")));
+    }
+
+    #[test]
+    fn fnm_root_produces_expected_bin_paths() {
+        let temp_dir = tempdir().expect("should create temp dir");
+        let root = temp_dir.path().join("fnm-root");
+        let alias_bin = root.join("aliases").join("default").join("bin");
+        let version_bin = root
+            .join("node-versions")
+            .join("v22.18.0")
+            .join("installation")
+            .join("bin");
+
+        fs::create_dir_all(&alias_bin).expect("should create fnm alias bin");
+        fs::create_dir_all(&version_bin).expect("should create fnm version bin");
+
+        let roots = fnm_root_candidates_from(Some(root.clone()), None, None);
+        assert_eq!(roots, vec![root.clone()]);
+
+        // Verify the expected bin paths exist under the root structure
+        assert!(alias_bin.is_dir());
+        assert!(version_bin.is_dir());
+    }
+
+    #[test]
+    fn find_copilot_in_path_finds_binary_in_path_entries() {
+        let temp_dir = tempdir().expect("should create temp dir");
+        let bin_dir = temp_dir.path().join("bin");
+        fs::create_dir_all(&bin_dir).expect("should create bin dir");
+
+        let executable_name = literal_copilot_executable_names()
+            .into_iter()
+            .next()
+            .expect("should provide a copilot executable name");
+        let executable_path = bin_dir.join(&executable_name);
+        fs::write(&executable_path, "#!/bin/sh\n").expect("should create fake binary");
+
+        let path_env =
+            env::join_paths([Path::new("/missing"), bin_dir.as_path()]).expect("should build PATH");
+
+        assert_eq!(
+            find_executable_in_path(
+                Some(path_env.as_os_str()),
+                &literal_copilot_executable_names()
+            ),
+            Some(executable_path)
+        );
+    }
+
+    #[test]
+    fn find_copilot_in_path_ignores_missing_entries() {
+        let path_env = env::join_paths([Path::new("/missing-one"), Path::new("/missing-two")])
+            .expect("should build PATH");
+
+        assert_eq!(
+            find_executable_in_path(
+                Some(path_env.as_os_str()),
+                &literal_copilot_executable_names()
+            ),
+            None
+        );
+    }
+
+    #[test]
+    #[serial]
+    #[cfg(target_os = "macos")]
+    fn platform_dirs_precede_version_manager_dirs() {
+        let temp = tempdir().expect("should create temp dir");
+        let fake_home = temp.path().join("home");
+
+        // Create fake nvm version dirs so collect_nvm_paths() returns entries.
+        let nvm_dir = fake_home.join(".nvm");
+        let nvm_version_bin = nvm_dir
+            .join("versions")
+            .join("node")
+            .join("v18.0.0")
+            .join("bin");
+        fs::create_dir_all(&nvm_version_bin).expect("should create nvm version bin");
+
+        // Create fake nodenv version dirs.
+        let nodenv_root = fake_home.join(".nodenv");
+        let nodenv_version_bin = nodenv_root.join("versions").join("20.0.0").join("bin");
+        fs::create_dir_all(&nodenv_version_bin).expect("should create nodenv version bin");
+
+        // Create fake fnm version dirs.
+        let fnm_root = fake_home.join(".local").join("share").join("fnm");
+        let fnm_version_bin = fnm_root
+            .join("node-versions")
+            .join("v22.0.0")
+            .join("installation")
+            .join("bin");
+        fs::create_dir_all(&fnm_version_bin).expect("should create fnm version bin");
+
+        // Save env vars.
+        let prev_path = env::var_os("PATH");
+        let prev_home = env::var_os("HOME");
+        let prev_nvm_dir = env::var_os("NVM_DIR");
+        let prev_nodenv_root = env::var_os("NODENV_ROOT");
+        let prev_fnm_dir = env::var_os("FNM_DIR");
+        let prev_xdg_data_home = env::var_os("XDG_DATA_HOME");
+
+        // Set env: empty PATH so only append_standard() dirs appear,
+        // HOME to our fake home, and explicit version-manager roots.
+        // Safety: test-only, single-threaded via #[serial].
+        unsafe {
+            env::set_var("PATH", "");
+            env::set_var("HOME", &fake_home);
+            env::set_var("NVM_DIR", &nvm_dir);
+            env::set_var("NODENV_ROOT", &nodenv_root);
+            env::remove_var("FNM_DIR");
+            env::remove_var("XDG_DATA_HOME");
+        }
+
+        let paths: Vec<PathBuf> = super::standard_search_paths().into_iter().collect();
+
+        // Restore env vars.
+        // Safety: test-only, single-threaded via #[serial].
+ unsafe { + match prev_path { + Some(v) => env::set_var("PATH", v), + None => env::remove_var("PATH"), + } + match prev_home { + Some(v) => env::set_var("HOME", v), + None => env::remove_var("HOME"), + } + match prev_nvm_dir { + Some(v) => env::set_var("NVM_DIR", v), + None => env::remove_var("NVM_DIR"), + } + match prev_nodenv_root { + Some(v) => env::set_var("NODENV_ROOT", v), + None => env::remove_var("NODENV_ROOT"), + } + match prev_fnm_dir { + Some(v) => env::set_var("FNM_DIR", v), + None => env::remove_var("FNM_DIR"), + } + match prev_xdg_data_home { + Some(v) => env::set_var("XDG_DATA_HOME", v), + None => env::remove_var("XDG_DATA_HOME"), + } + } + + let platform_dirs: Vec = vec![ + PathBuf::from("/opt/homebrew/bin"), + PathBuf::from("/usr/local/bin"), + PathBuf::from("/usr/bin"), + PathBuf::from("/bin"), + PathBuf::from("/usr/sbin"), + PathBuf::from("/sbin"), + ]; + + // Find the last platform dir index and the first version-manager dir index. + let last_platform_idx = platform_dirs + .iter() + .filter_map(|d| paths.iter().position(|p| p == d)) + .max() + .expect("at least one platform dir should be present"); + + let version_manager_prefixes = [ + nvm_version_bin.parent().unwrap().parent().unwrap(), // .nvm/versions/node + nodenv_version_bin.parent().unwrap().parent().unwrap(), // .nodenv/versions + fnm_version_bin + .parent() + .unwrap() + .parent() + .unwrap() + .parent() + .unwrap() + .parent() + .unwrap(), // .local/share/fnm + ]; + + let first_version_mgr_idx = paths + .iter() + .position(|p| { + version_manager_prefixes + .iter() + .any(|prefix| p.starts_with(prefix)) + }) + .expect("at least one version-manager dir should be present"); + + assert!( + last_platform_idx < first_version_mgr_idx, + "Platform dirs (last at index {last_platform_idx}) must precede \ + version-manager dirs (first at index {first_version_mgr_idx}).\n\ + Full path list: {paths:#?}" + ); + } + + #[test] + #[serial] + fn 
+    find_executable_in_path_can_ignore_copilot_name_override() {
+        let temp_dir = tempdir().expect("should create temp dir");
+        let bin_dir = temp_dir.path().join("bin");
+        fs::create_dir_all(&bin_dir).expect("should create bin dir");
+
+        let path_executable_name = literal_copilot_executable_names()
+            .into_iter()
+            .next()
+            .expect("should provide a literal copilot executable name");
+        #[cfg(target_os = "windows")]
+        let overridden_executable_name = "my-copilot.exe";
+
+        #[cfg(not(target_os = "windows"))]
+        let overridden_executable_name = "my-copilot";
+
+        let path_executable_path = bin_dir.join(&path_executable_name);
+        let overridden_executable_path = bin_dir.join(overridden_executable_name);
+
+        fs::write(&path_executable_path, "#!/bin/sh\n").expect("should create literal fake binary");
+        fs::write(&overridden_executable_path, "#!/bin/sh\n")
+            .expect("should create overridden fake binary");
+
+        let path_env =
+            env::join_paths([Path::new("/missing"), bin_dir.as_path()]).expect("should build PATH");
+
+        let previous_path = env::var_os("PATH");
+        let previous_copilot_cli_name = env::var_os("COPILOT_CLI_NAME");
+        // Safety: test-only, single-threaded via #[serial].
+        unsafe {
+            env::set_var("PATH", &path_env);
+            env::set_var("COPILOT_CLI_NAME", "my-copilot");
+        }
+
+        let resolved_path = copilot_binary_on_path();
+
+        // Safety: test-only, single-threaded via #[serial].
+        unsafe {
+            if let Some(previous_path) = previous_path {
+                env::set_var("PATH", previous_path);
+            } else {
+                env::remove_var("PATH");
+            }
+
+            if let Some(previous_copilot_cli_name) = previous_copilot_cli_name {
+                env::set_var("COPILOT_CLI_NAME", previous_copilot_cli_name);
+            } else {
+                env::remove_var("COPILOT_CLI_NAME");
+            }
+        }
+
+        assert_eq!(
+            resolved_path.expect("should find the literal copilot binary on PATH"),
+            path_executable_path
+        );
+    }
+}
diff --git a/rust/src/router.rs b/rust/src/router.rs
new file mode 100644
index 000000000..e14630e03
--- /dev/null
+++ b/rust/src/router.rs
@@ -0,0 +1,178 @@
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use parking_lot::Mutex;
+use tokio::sync::{broadcast, mpsc};
+use tracing::warn;
+
+use crate::jsonrpc::{JsonRpcNotification, JsonRpcRequest};
+use crate::types::{SessionEventNotification, SessionId};
+
+/// Per-session channels created by the router during session registration.
+pub(crate) struct SessionChannels {
+    /// Filtered `session.event` notifications for this session.
+    pub(crate) notifications: mpsc::UnboundedReceiver<SessionEventNotification>,
+    /// Filtered JSON-RPC requests (tool.call, userInput.request, etc.) for this session.
+    pub(crate) requests: mpsc::UnboundedReceiver<JsonRpcRequest>,
+}
+
+struct SessionSenders {
+    notifications: mpsc::UnboundedSender<SessionEventNotification>,
+    requests: mpsc::UnboundedSender<JsonRpcRequest>,
+}
+
+/// Routes notifications and requests by sessionId to per-session channels.
+///
+/// Internal to the SDK — consumers interact via `Client::register_session()`.
+pub(crate) struct SessionRouter {
+    sessions: Arc<Mutex<HashMap<SessionId, SessionSenders>>>,
+    started: Mutex<bool>,
+}
+
+impl SessionRouter {
+    pub(crate) fn new() -> Self {
+        Self {
+            sessions: Arc::new(Mutex::new(HashMap::new())),
+            started: Mutex::new(false),
+        }
+    }
+
+    /// Register a session to receive filtered events and requests.
+    pub(crate) fn register(&self, session_id: &SessionId) -> SessionChannels {
+        let (notif_tx, notif_rx) = mpsc::unbounded_channel();
+        let (req_tx, req_rx) = mpsc::unbounded_channel();
+        self.sessions.lock().insert(
+            session_id.clone(),
+            SessionSenders {
+                notifications: notif_tx,
+                requests: req_tx,
+            },
+        );
+        SessionChannels {
+            notifications: notif_rx,
+            requests: req_rx,
+        }
+    }
+
+    /// Unregister a session, dropping its channels.
+    pub(crate) fn unregister(&self, session_id: &SessionId) {
+        self.sessions.lock().remove(session_id.as_str());
+    }
+
+    /// Snapshot every currently-registered session ID.
+    ///
+    /// Used by [`Client::stop`](crate::Client::stop) to iterate active
+    /// sessions for cooperative shutdown without holding the router lock
+    /// across `.await`.
+    pub(crate) fn session_ids(&self) -> Vec<SessionId> {
+        self.sessions.lock().keys().cloned().collect()
+    }
+
+    /// Drop all registered session channels.
+    ///
+    /// Used by [`Client::force_stop`](crate::Client::force_stop) to release
+    /// per-session state without waiting for graceful unregistration.
+    pub(crate) fn clear(&self) {
+        self.sessions.lock().clear();
+    }
+
+    /// Start the router tasks if not already running.
+    ///
+    /// Takes the notification broadcast and request channel from the Client.
+    /// If `request_rx` is `None` (already taken by `take_request_rx()`),
+    /// only notification routing is available.
+    pub(crate) fn ensure_started(
+        &self,
+        notification_tx: &broadcast::Sender<JsonRpcNotification>,
+        request_rx: &Mutex<Option<mpsc::UnboundedReceiver<JsonRpcRequest>>>,
+    ) {
+        let mut started = self.started.lock();
+        if *started {
+            return;
+        }
+        *started = true;
+
+        // Notification routing task
+        let sessions = self.sessions.clone();
+        let mut notif_rx = notification_tx.subscribe();
+        tokio::spawn(async move {
+            loop {
+                match notif_rx.recv().await {
+                    Ok(notification) => {
+                        if notification.method != "session.event" {
+                            continue;
+                        }
+                        let Some(ref params) = notification.params else {
+                            continue;
+                        };
+                        let Some(session_id) = params.get("sessionId").and_then(|v| v.as_str())
+                        else {
+                            continue;
+                        };
+
+                        let sender = {
+                            let guard = sessions.lock();
+                            guard.get(session_id).map(|s| s.notifications.clone())
+                        };
+                        if let Some(sender) = sender {
+                            match serde_json::from_value::<SessionEventNotification>(params.clone())
+                            {
+                                Ok(event_notification) => {
+                                    let _ = sender.send(event_notification);
+                                }
+                                Err(e) => {
+                                    warn!(
+                                        error = %e,
+                                        session_id = session_id,
+                                        "failed to deserialize session event notification"
+                                    );
+                                }
+                            }
+                        }
+                        // Unknown session IDs are silently dropped — the session
+                        // may have been unregistered between dispatch and delivery.
+                    }
+                    Err(broadcast::error::RecvError::Lagged(n)) => {
+                        warn!(missed = n, "notification router lagged");
+                    }
+                    Err(broadcast::error::RecvError::Closed) => break,
+                }
+            }
+        });
+
+        // Request routing task (if request_rx is available)
+        if let Some(mut rx) = request_rx.lock().take() {
+            let sessions = self.sessions.clone();
+            tokio::spawn(async move {
+                while let Some(request) = rx.recv().await {
+                    let session_id = request
+                        .params
+                        .as_ref()
+                        .and_then(|p| p.get("sessionId"))
+                        .and_then(|v| v.as_str());
+
+                    if let Some(sid) = session_id {
+                        let sender = {
+                            let guard = sessions.lock();
+                            guard.get(sid).map(|s| s.requests.clone())
+                        };
+                        if let Some(sender) = sender {
+                            let _ = sender.send(request);
+                        } else {
+                            warn!(
+                                session_id = sid,
+                                method = %request.method,
+                                "request for unregistered session"
+                            );
+                        }
+                    } else {
+                        warn!(
+                            method = %request.method,
+                            "request missing sessionId"
+                        );
+                    }
+                }
+            });
+        }
+    }
+}
diff --git a/rust/src/sdk_protocol_version.rs b/rust/src/sdk_protocol_version.rs
new file mode 100644
index 000000000..21089f99e
--- /dev/null
+++ b/rust/src/sdk_protocol_version.rs
@@ -0,0 +1,13 @@
+// Code generated by update-protocol-version.ts. DO NOT EDIT.
+
+//! The SDK protocol version. Must match the version expected by the
+//! copilot-agent-runtime server.
+
+/// The SDK protocol version.
+pub const SDK_PROTOCOL_VERSION: u32 = 3;
+
+/// Returns the SDK protocol version.
+#[must_use]
+pub const fn get_sdk_protocol_version() -> u32 {
+    SDK_PROTOCOL_VERSION
+}
diff --git a/rust/src/session.rs b/rust/src/session.rs
new file mode 100644
index 000000000..0d8ebc5da
--- /dev/null
+++ b/rust/src/session.rs
@@ -0,0 +1,1954 @@
+use std::collections::HashMap;
+use std::path::{Path, PathBuf};
+use std::sync::Arc;
+use std::time::Duration;
+
+use parking_lot::Mutex as ParkingLotMutex;
+use serde_json::Value;
+use tokio::sync::{Notify, oneshot};
+use tokio::task::JoinHandle;
+use tracing::{Instrument, warn};
+
+use crate::generated::api_types::{
+    LogRequest, ModeSetRequest, ModelSwitchToRequest, NameSetRequest, PermissionDecision,
+    PermissionDecisionApproveOnce, PermissionDecisionApproveOnceKind, PermissionDecisionReject,
+    PermissionDecisionRejectKind, PlanUpdateRequest, SessionMode, WorkspacesCreateFileRequest,
+    WorkspacesReadFileRequest,
+};
+use crate::generated::session_events::{
+    CommandExecuteData, ElicitationRequestedData, ExternalToolRequestedData, SessionErrorData,
+    SessionEventType,
+};
+use crate::handler::{
+    AutoModeSwitchResponse, ExitPlanModeResult, HandlerEvent, HandlerResponse, PermissionResult,
+    SessionHandler, UserInputResponse,
+};
+use crate::hooks::SessionHooks;
+use crate::session_fs::SessionFsProvider;
+use crate::trace_context::inject_trace_context;
+use crate::transforms::SystemMessageTransform;
+use crate::types::{
+    CommandContext, CommandDefinition, CommandHandler, CreateSessionResult, ElicitationRequest,
+    ElicitationResult, ExitPlanModeData, GetMessagesResponse, InputOptions, MessageOptions,
+    PermissionRequestData, RequestId, ResumeSessionConfig, SectionOverride, SessionCapabilities,
+    SessionConfig, SessionEvent, SessionId, SessionTelemetryEvent, SetModelOptions,
+    SystemMessageConfig, ToolInvocation, ToolResult, ToolResultResponse, TraceContext,
+    ensure_attachment_display_names,
+};
+use crate::{Client, Error, JsonRpcResponse, SessionError, SessionEventNotification, error_codes};
+
+/// Shared state
+/// between a [`Session`] and its event loop, used by [`Session::send_and_wait`].
+struct IdleWaiter {
+    tx: oneshot::Sender<Result<Option<String>, Error>>,
+    last_assistant_message: Option<String>,
+}
+
+/// RAII guard that clears the [`Session::idle_waiter`] slot on drop. Used
+/// by [`Session::send_and_wait`] to ensure the slot doesn't leak if the
+/// caller's future is cancelled (outer `tokio::time::timeout` / `select!`
+/// / dropped JoinHandle). Synchronous clear via `parking_lot::Mutex` —
+/// no async drop needed.
+///
+/// Without this, an outer cancellation between "install waiter" and
+/// "drain channel" would leave the slot occupied, causing all subsequent
+/// `send` and `send_and_wait` calls on the session to return
+/// [`SendWhileWaiting`](SessionError::SendWhileWaiting). Closes RFD-400
+/// review finding #2.
+struct WaiterGuard {
+    slot: Arc<ParkingLotMutex<Option<IdleWaiter>>>,
+}
+
+impl Drop for WaiterGuard {
+    fn drop(&mut self) {
+        self.slot.lock().take();
+    }
+}
+
+/// A session on a GitHub Copilot CLI server.
+///
+/// Created via [`Client::create_session`] or [`Client::resume_session`].
+/// Owns an internal event loop that dispatches events to the [`SessionHandler`].
+///
+/// Protocol methods (`send`, `get_messages`, `abort`, etc.) automatically
+/// inject the session ID into RPC params.
+///
+/// Call [`destroy`](Self::destroy) for graceful cleanup (RPC + local). If dropped
+/// without calling `destroy`, the `Drop` impl aborts the event loop and
+/// unregisters from the router as a best-effort safety net.
+pub struct Session {
+    id: SessionId,
+    cwd: PathBuf,
+    workspace_path: Option<PathBuf>,
+    remote_url: Option<String>,
+    client: Client,
+    /// Handle to the spawned event-loop task. Sync `parking_lot::Mutex`
+    /// because the lock is never held across an `.await` and the `Drop`
+    /// impl needs to take the handle synchronously without `try_lock`
+    /// fallibility.
+    event_loop: ParkingLotMutex<Option<JoinHandle<()>>>,
+    /// Cooperative shutdown signal for the event loop. The loop selects
+    /// on `shutdown.notified()` alongside its inbound channels;
+    /// [`Session::stop_event_loop`] and [`Drop`] both call `notify_one()`
+    /// (which buffers a single signal so it is not lost while the loop
+    /// is inside an awaiting branch) to ask the loop to exit between
+    /// iterations rather than `JoinHandle::abort` (which can land at any
+    /// await point and leave the session in mid-protocol state). See
+    /// RFD-400 review finding #3.
+    shutdown: Arc<Notify>,
+    /// Only populated while a `send_and_wait` call is in flight.
+    ///
+    /// Sync `parking_lot::Mutex` because the lock is never held across an
+    /// `.await`, and synchronous access lets the `WaiterGuard` RAII helper
+    /// in `send_and_wait` clear the slot from a `Drop` impl on caller-side
+    /// cancellation. See RFD-400 review (cancel-safety hardening).
+    idle_waiter: Arc<ParkingLotMutex<Option<IdleWaiter>>>,
+    /// Capabilities negotiated with the CLI, updated on `capabilities.changed` events.
+    capabilities: Arc<parking_lot::RwLock<SessionCapabilities>>,
+    /// Broadcast channel for runtime event subscribers — see [`Session::subscribe`].
+    event_tx: tokio::sync::broadcast::Sender<SessionEvent>,
+}
+
+impl Session {
+    /// Session ID assigned by the CLI.
+    pub fn id(&self) -> &SessionId {
+        &self.id
+    }
+
+    /// Working directory of the CLI process.
+    pub fn cwd(&self) -> &PathBuf {
+        &self.cwd
+    }
+
+    /// Workspace directory for the session (if using infinite sessions).
+    pub fn workspace_path(&self) -> Option<&Path> {
+        self.workspace_path.as_deref()
+    }
+
+    /// Remote session URL, if the session is running remotely.
+    pub fn remote_url(&self) -> Option<&str> {
+        self.remote_url.as_deref()
+    }
+
+    /// Session capabilities negotiated with the CLI.
+    ///
+    /// Capabilities are set during session creation and updated at runtime
+    /// via `capabilities.changed` events.
+    pub fn capabilities(&self) -> SessionCapabilities {
+        self.capabilities.read().clone()
+    }
+
+    /// Subscribe to events for this session.
+    ///
+    /// Returns an [`EventSubscription`](crate::subscription::EventSubscription)
+    /// that yields every [`SessionEvent`] dispatched on this session's
+    /// event loop. Drop the value to unsubscribe; there is no separate
+    /// cancel handle.
+    ///
+    /// **Observe-only.** Subscribers receive a clone of every
+    /// [`SessionEvent`] but cannot influence permission decisions, tool
+    /// results, or anything else that requires returning a
+    /// [`HandlerResponse`]. Those remain
+    /// the responsibility of the [`SessionHandler`] passed via
+    /// [`SessionConfig::handler`](crate::types::SessionConfig::handler).
+    ///
+    /// The returned handle implements both an inherent
+    /// [`recv`](crate::subscription::EventSubscription::recv) method and
+    /// [`Stream`](tokio_stream::Stream), so callers can use a `while let`
+    /// loop or any combinator from `tokio_stream::StreamExt` /
+    /// `futures::StreamExt`.
+    ///
+    /// Each subscriber maintains its own queue. If a consumer cannot keep
+    /// up, the oldest events are dropped and `recv` returns
+    /// [`RecvError::Lagged`](crate::subscription::RecvError::Lagged)
+    /// reporting the count of skipped events. Slow consumers do not block
+    /// the session's event loop.
+    ///
+    /// # Example
+    ///
+    /// ```no_run
+    /// # async fn example(session: github_copilot_sdk::session::Session) {
+    /// let mut events = session.subscribe();
+    /// tokio::spawn(async move {
+    ///     while let Ok(event) = events.recv().await {
+    ///         println!("[{}] event {}", event.id, event.event_type);
+    ///     }
+    /// });
+    /// # }
+    /// ```
+    pub fn subscribe(&self) -> crate::subscription::EventSubscription {
+        crate::subscription::EventSubscription::new(self.event_tx.subscribe())
+    }
+
+    /// The underlying Client (for advanced use cases).
+    pub fn client(&self) -> &Client {
+        &self.client
+    }
+
+    /// Typed RPC namespace for this session.
+    ///
+    /// Every protocol method lives here under its schema-aligned path —
+    /// e.g. `session.rpc().workspaces().list_files()`.
+    /// Wire method names
+    /// and request/response types are generated from the protocol schema,
+    /// so the typed namespace can't drift from the wire contract.
+    ///
+    /// The hand-authored helpers on [`Session`] delegate to this namespace
+    /// and remain the recommended entry point for everyday use; reach for
+    /// `rpc()` when you want a method without a hand-written wrapper.
+    pub fn rpc(&self) -> crate::generated::rpc::SessionRpc<'_> {
+        crate::generated::rpc::SessionRpc { session: self }
+    }
+
+    /// Stop the internal event loop. Called automatically on [`destroy`](Self::destroy).
+    ///
+    /// Cooperative: signals shutdown via the session's `Notify` and awaits
+    /// the loop's natural exit rather than aborting the task. Any in-flight
+    /// handler (permission callback, tool call, elicitation response)
+    /// completes before the loop exits, so the CLI never sees a
+    /// half-handled request. See RFD-400 review finding #3.
+    pub async fn stop_event_loop(&self) {
+        self.shutdown.notify_one();
+        let handle = self.event_loop.lock().take();
+        if let Some(handle) = handle {
+            let _ = handle.await;
+        }
+        // Fail any pending send_and_wait so it returns immediately.
+        if let Some(waiter) = self.idle_waiter.lock().take() {
+            let _ = waiter
+                .tx
+                .send(Err(Error::Session(SessionError::EventLoopClosed)));
+        }
+    }
+
+    /// Send a user message to the agent.
+    ///
+    /// Accepts anything convertible to [`MessageOptions`] — pass a `&str` for the
+    /// trivial case, or build a `MessageOptions` for mode/attachments. The
+    /// `wait_timeout` field on `MessageOptions` is ignored here (use
+    /// [`send_and_wait`](Self::send_and_wait) if you need to wait).
+    ///
+    /// Returns the assigned message ID, which can be used to correlate the
+    /// send with later [`SessionEvent`]s emitted in
+    /// response (assistant messages, tool requests, etc.).
+    ///
+    /// Returns an error if a [`send_and_wait`](Self::send_and_wait) call is
+    /// currently in flight, since the plain send would race with the waiter.
+    ///
+    /// # Cancel safety
+    ///
+    /// **Cancel-safe.** The underlying `session.send` RPC is dispatched
+    /// through the writer-actor (see [`Client::call`](crate::Client::call)),
+    /// so dropping this future after the actor has committed to writing
+    /// will not produce a partial frame on the wire. If the caller's
+    /// future is dropped between "frame enqueued" and "response received",
+    /// the message has already landed on the wire — the agent will process
+    /// it and emit events normally; the caller just won't see the returned
+    /// message ID.
+    pub async fn send(&self, opts: impl Into<MessageOptions>) -> Result<String, Error> {
+        if self.idle_waiter.lock().is_some() {
+            return Err(Error::Session(SessionError::SendWhileWaiting));
+        }
+        self.send_inner(opts.into()).await
+    }
+
+    async fn send_inner(&self, opts: MessageOptions) -> Result<String, Error> {
+        let mut params = serde_json::json!({
+            "sessionId": self.id,
+            "prompt": opts.prompt,
+        });
+        if let Some(m) = opts.mode {
+            params["mode"] = serde_json::to_value(m)?;
+        }
+        if let Some(mut a) = opts.attachments {
+            ensure_attachment_display_names(&mut a);
+            params["attachments"] = serde_json::to_value(a)?;
+        }
+        if let Some(headers) = opts.request_headers
+            && !headers.is_empty()
+        {
+            params["requestHeaders"] = serde_json::to_value(headers)?;
+        }
+        let trace_ctx = if opts.traceparent.is_some() || opts.tracestate.is_some() {
+            TraceContext {
+                traceparent: opts.traceparent,
+                tracestate: opts.tracestate,
+            }
+        } else {
+            self.client.resolve_trace_context().await
+        };
+        inject_trace_context(&mut params, &trace_ctx);
+        let result = self.client.call("session.send", Some(params)).await?;
+        let message_id = result
+            .get("messageId")
+            .and_then(|v| v.as_str())
+            .map(|s| s.to_string())
+            .unwrap_or_default();
+        Ok(message_id)
+    }
+
+    /// Enable or disable session-wide auto-approval for tool
+    /// permission requests.
+    pub async fn set_approve_all_permissions(&self, enabled: bool) -> Result<(), Error> {
+        self.rpc()
+            .permissions()
+            .set_approve_all(
+                crate::generated::api_types::PermissionsSetApproveAllRequest { enabled },
+            )
+            .await?;
+        Ok(())
+    }
+
+    /// Send a user message and wait for the agent to finish processing.
+    ///
+    /// Accepts anything convertible to [`MessageOptions`] — pass a `&str` for the
+    /// trivial case, or build a `MessageOptions` for mode/attachments/timeout.
+    /// Blocks until `session.idle` (success) or `session.error` (failure),
+    /// returning the last `assistant.message` event captured during streaming.
+    /// Times out after `MessageOptions::wait_timeout` (default 60 seconds).
+    ///
+    /// Only one `send_and_wait` call may be active per session at a time.
+    /// Calling [`send`](Self::send) while a `send_and_wait`
+    /// is in flight will also return an error.
+    ///
+    /// # Cancel safety
+    ///
+    /// **Cancel-safe.** A `WaiterGuard` clears the in-flight slot on every
+    /// exit path (success, internal failure, internal timeout, *and*
+    /// external cancellation via `tokio::time::timeout` / `select!` /
+    /// dropped JoinHandle). Subsequent `send` and `send_and_wait` calls on
+    /// this session will succeed normally — the slot is never leaked.
+    pub async fn send_and_wait(
+        &self,
+        opts: impl Into<MessageOptions>,
+    ) -> Result<Option<String>, Error> {
+        let opts = opts.into();
+        let timeout_duration = opts.wait_timeout.unwrap_or(Duration::from_secs(60));
+        let (tx, rx) = oneshot::channel();
+
+        {
+            let mut guard = self.idle_waiter.lock();
+            if guard.is_some() {
+                return Err(Error::Session(SessionError::SendWhileWaiting));
+            }
+            *guard = Some(IdleWaiter {
+                tx,
+                last_assistant_message: None,
+            });
+        }
+
+        // RAII: clears the idle_waiter slot on every exit path, including
+        // external cancellation (caller's outer `select!` / `timeout` /
+        // dropped future).
+        // Without this, an outer cancellation would leak
+        // the slot and brick subsequent `send`/`send_and_wait` calls.
+        let _waiter_guard = WaiterGuard {
+            slot: self.idle_waiter.clone(),
+        };
+
+        let result = tokio::time::timeout(timeout_duration, async {
+            self.send_inner(opts).await?;
+            match rx.await {
+                Ok(result) => result,
+                Err(_) => Err(Error::Session(SessionError::EventLoopClosed)),
+            }
+        })
+        .await;
+
+        match result {
+            Ok(inner) => inner,
+            Err(_) => Err(Error::Session(SessionError::Timeout(timeout_duration))),
+        }
+    }
+
+    /// Retrieve the session's message history.
+    pub async fn get_messages(&self) -> Result<Vec<SessionEvent>, Error> {
+        let result = self
+            .client
+            .call(
+                "session.getMessages",
+                Some(serde_json::json!({ "sessionId": self.id })),
+            )
+            .await?;
+        let response: GetMessagesResponse = serde_json::from_value(result)?;
+        Ok(response.events)
+    }
+
+    /// Abort the current agent turn.
+    ///
+    /// # Cancel safety
+    ///
+    /// **Cancel-safe.** Single `session.abort` RPC; the underlying
+    /// [`Client::call`](crate::Client::call) is cancel-safe via the
+    /// writer-actor.
+    pub async fn abort(&self) -> Result<(), Error> {
+        self.client
+            .call(
+                "session.abort",
+                Some(serde_json::json!({ "sessionId": self.id })),
+            )
+            .await?;
+        Ok(())
+    }
+
+    /// Switch to a different model.
+    ///
+    /// Pass `None` for `opts` if no extra configuration is needed.
+    pub async fn set_model(&self, model: &str, opts: Option<SetModelOptions>) -> Result<(), Error> {
+        let opts = opts.unwrap_or_default();
+        let request = ModelSwitchToRequest {
+            model_id: model.to_string(),
+            reasoning_effort: opts.reasoning_effort,
+            model_capabilities: opts.model_capabilities,
+        };
+        self.rpc().model().switch_to(request).await?;
+        Ok(())
+    }
+
+    /// Get the current model.
+    pub async fn get_model(&self) -> Result<Option<String>, Error> {
+        Ok(self.rpc().model().get_current().await?.model_id)
+    }
+
+    /// Set the session mode (e.g. "interactive", "plan", "autopilot").
+    pub async fn set_mode(&self, mode: &str) -> Result<String, Error> {
+        let parsed: SessionMode = serde_json::from_value(Value::String(mode.to_string()))?;
+        self.rpc()
+            .mode()
+            .set(ModeSetRequest { mode: parsed })
+            .await?;
+        Ok(mode.to_string())
+    }
+
+    /// Get the current session mode.
+    pub async fn get_mode(&self) -> Result<String, Error> {
+        let mode = self.rpc().mode().get().await?;
+        Ok(serde_json::to_value(mode)?
+            .as_str()
+            .unwrap_or("interactive")
+            .to_string())
+    }
+
+    /// Get the current session name.
+    pub async fn get_name(&self) -> Result<Option<String>, Error> {
+        Ok(self.rpc().name().get().await?.name)
+    }
+
+    /// Set the current session name.
+    pub async fn set_name(&self, name: &str) -> Result<(), Error> {
+        self.rpc()
+            .name()
+            .set(NameSetRequest {
+                name: name.to_string(),
+            })
+            .await
+    }
+
+    /// Disconnect this session from the CLI.
+    ///
+    /// Sends the `session.destroy` RPC, stops the event loop, and unregisters
+    /// the session from the client. **Session state on disk** (conversation
+    /// history, planning state, artifacts) is **preserved**, so the
+    /// conversation can be resumed later via [`Client::resume_session`]
+    /// using this session's ID. To permanently remove all on-disk session
+    /// data, use [`Client::delete_session`] instead.
+    ///
+    /// The caller should ensure the session is idle (e.g. [`send_and_wait`]
+    /// has returned) before disconnecting; in-flight tool or event handlers
+    /// may otherwise observe failures.
+    ///
+    /// [`Client::resume_session`]: crate::Client::resume_session
+    /// [`Client::delete_session`]: crate::Client::delete_session
+    /// [`send_and_wait`]: Self::send_and_wait
+    pub async fn disconnect(&self) -> Result<(), Error> {
+        self.client
+            .call(
+                "session.destroy",
+                Some(serde_json::json!({ "sessionId": self.id })),
+            )
+            .await?;
+        self.stop_event_loop().await;
+        self.client.unregister_session(&self.id);
+        Ok(())
+    }
+
+    /// Alias for [`disconnect`](Self::disconnect).
+    ///
+    /// Named after the `session.destroy` wire RPC.
+    /// Prefer `disconnect` in
+    /// new code — the wire-level "destroy" is misleading because on-disk
+    /// state is preserved.
+    pub async fn destroy(&self) -> Result<(), Error> {
+        self.disconnect().await
+    }
+
+    /// List files in the session workspace.
+    pub async fn list_workspace_files(&self) -> Result<Vec<String>, Error> {
+        Ok(self.rpc().workspaces().list_files().await?.files)
+    }
+
+    /// Read a file from the session workspace.
+    pub async fn read_workspace_file(&self, path: &Path) -> Result<String, Error> {
+        Ok(self
+            .rpc()
+            .workspaces()
+            .read_file(WorkspacesReadFileRequest {
+                path: path.to_string_lossy().into_owned(),
+            })
+            .await?
+            .content)
+    }
+
+    /// Create a file in the session workspace.
+    pub async fn create_workspace_file(&self, path: &Path, content: &str) -> Result<(), Error> {
+        self.rpc()
+            .workspaces()
+            .create_file(WorkspacesCreateFileRequest {
+                path: path.to_string_lossy().into_owned(),
+                content: content.to_string(),
+            })
+            .await
+    }
+
+    /// Read the session plan.
+    pub async fn read_plan(&self) -> Result<(bool, Option<String>), Error> {
+        let r = self.rpc().plan().read().await?;
+        Ok((r.exists, r.content))
+    }
+
+    /// Update the session plan.
+    pub async fn update_plan(&self, content: &str) -> Result<(), Error> {
+        self.rpc()
+            .plan()
+            .update(PlanUpdateRequest {
+                content: content.to_string(),
+            })
+            .await
+    }
+
+    /// Delete the session plan.
+    pub async fn delete_plan(&self) -> Result<(), Error> {
+        self.rpc().plan().delete().await
+    }
+
+    /// Write a log message to the session.
+    ///
+    /// Pass `None` for `opts` to use defaults (info level, persisted).
+    pub async fn log(
+        &self,
+        message: &str,
+        opts: Option<crate::types::LogOptions>,
+    ) -> Result<(), Error> {
+        let opts = opts.unwrap_or_default();
+        let level = match opts.level {
+            Some(level) => Some(serde_json::from_value(serde_json::to_value(level)?)?),
+            None => None,
+        };
+        let request = LogRequest {
+            message: message.to_string(),
+            level,
+            ephemeral: opts.ephemeral,
+            url: None,
+        };
+        self.rpc().log(request).await?;
+        Ok(())
+    }
+
+    /// Send a telemetry event through the session's internal shared API.
+    pub async fn send_telemetry(&self, event: SessionTelemetryEvent) -> Result<(), Error> {
+        let mut params = serde_json::to_value(event)?;
+        let params_object = params
+            .as_object_mut()
+            .expect("SessionTelemetryEvent always serializes to an object");
+        params_object.insert("sessionId".to_string(), serde_json::to_value(&self.id)?);
+
+        self.client
+            .call("session.sendTelemetry", Some(params))
+            .await?;
+        Ok(())
+    }
+
+    /// Returns the UI sub-API for elicitation, confirmation, selection, and
+    /// free-form input.
+    ///
+    /// All UI methods route through `session.ui.*` RPCs and require host
+    /// support — check `session.capabilities().ui.elicitation` before use.
+    pub fn ui(&self) -> SessionUi<'_> {
+        SessionUi { session: self }
+    }
+
+    /// Returns an error if the host doesn't support elicitation.
+    fn assert_elicitation(&self) -> Result<(), Error> {
+        if self
+            .capabilities
+            .read()
+            .ui
+            .as_ref()
+            .and_then(|u| u.elicitation)
+            != Some(true)
+        {
+            return Err(Error::Session(SessionError::ElicitationNotSupported));
+        }
+        Ok(())
+    }
+
+    /// Start a fleet of sub-agents.
+    pub async fn start_fleet(&self, prompt: Option<&str>) -> Result<bool, Error> {
+        Ok(self
+            .rpc()
+            .fleet()
+            .start(crate::generated::api_types::FleetStartRequest {
+                prompt: prompt.map(|s| s.to_string()),
+            })
+            .await?
+            .started)
+    }
+
+    /// Generic RPC forwarder — auto-injects sessionId into params.
+    pub async fn call_rpc(
+        &self,
+        method: &str,
+        extra_params: Option<Value>,
+    ) -> Result<Value, Error> {
+        let mut params = serde_json::json!({ "sessionId": self.id });
+        let extra_obj = extra_params.as_ref().and_then(Value::as_object);
+        if let (Some(base), Some(extra_obj)) = (params.as_object_mut(), extra_obj) {
+            for (k, v) in extra_obj {
+                base.insert(k.clone(), v.clone());
+            }
+        }
+        self.client.call(method, Some(params)).await
+    }
+}
+
+impl Drop for Session {
+    fn drop(&mut self) {
+        // Cooperative shutdown: notify the event loop to exit between
+        // iterations. The loop will see the signal on its next select
+        // poll and break cleanly without interrupting an in-flight
+        // handler. We do NOT abort the JoinHandle — that would land at
+        // any await point in the loop body, potentially leaving the CLI
+        // with an unanswered request id. RFD-400 review finding #3.
+        //
+        // The handle itself is left in `event_loop` to be reaped by the
+        // tokio runtime when it next polls; we intentionally don't await
+        // it here because Drop is sync.
+        self.shutdown.notify_one();
+        self.client.unregister_session(&self.id);
+    }
+}
+
+/// UI sub-API for a [`Session`] — elicitation, confirmation, selection,
+/// and free-form input.
+///
+/// Acquired via [`Session::ui`]. Methods route to `session.ui.*` RPCs and
+/// require host elicitation support — check
+/// `session.capabilities().ui.elicitation` before use.
+pub struct SessionUi<'a> {
+    session: &'a Session,
+}
+
+impl<'a> SessionUi<'a> {
+    /// Request user input via an interactive UI form (elicitation).
+    ///
+    /// Sends a JSON Schema describing form fields to the CLI host. The host
+    /// renders a form dialog and returns the user's response.
+    ///
+    /// Prefer the typed convenience methods [`confirm`](Self::confirm),
+    /// [`select`](Self::select), and [`input`](Self::input) for common cases.
+    pub async fn elicitation(
+        &self,
+        message: &str,
+        schema: Value,
+    ) -> Result<ElicitationResult, Error> {
+        self.session.assert_elicitation()?;
+        let result = self
+            .session
+            .client
+            .call(
+                "session.ui.elicitation",
+                Some(serde_json::json!({
+                    "sessionId": self.session.id,
+                    "message": message,
+                    "requestedSchema": schema,
+                })),
+            )
+            .await?;
+        let elicitation: ElicitationResult = serde_json::from_value(result)?;
+        Ok(elicitation)
+    }
+
+    /// Ask the user a yes/no confirmation question.
+    ///
+    /// Returns `true` if the user accepted and confirmed, `false` otherwise.
+    pub async fn confirm(&self, message: &str) -> Result<bool, Error> {
+        self.session.assert_elicitation()?;
+        let schema = serde_json::json!({
+            "type": "object",
+            "properties": {
+                "confirmed": {
+                    "type": "boolean",
+                    "default": true,
+                }
+            },
+            "required": ["confirmed"]
+        });
+        let result = self.elicitation(message, schema).await?;
+        Ok(result.action == "accept"
+            && result
+                .content
+                .and_then(|c| c.get("confirmed").and_then(|v| v.as_bool()))
+                == Some(true))
+    }
+
+    /// Ask the user to select from a list of options.
+    ///
+    /// Returns the selected option string on accept, or `None` on decline/cancel.
+    pub async fn select(&self, message: &str, options: &[&str]) -> Result<Option<String>, Error> {
+        self.session.assert_elicitation()?;
+        let schema = serde_json::json!({
+            "type": "object",
+            "properties": {
+                "selection": {
+                    "type": "string",
+                    "enum": options,
+                }
+            },
+            "required": ["selection"]
+        });
+        let result = self.elicitation(message, schema).await?;
+        if result.action != "accept" {
+            return Ok(None);
+        }
+        let selection = result.content.and_then(|c| {
+            c.get("selection")
+                .and_then(|v| v.as_str())
+                .map(String::from)
+        });
+        Ok(selection)
+    }
+
+    /// Ask the user for free-form text input.
+    ///
+    /// Returns the input string on accept, or `None` on decline/cancel.
+    /// Use [`InputOptions`] to set validation constraints and field metadata.
+    pub async fn input(
+        &self,
+        message: &str,
+        options: Option<&InputOptions<'_>>,
+    ) -> Result<Option<String>, Error> {
+        self.session.assert_elicitation()?;
+        let mut field = serde_json::json!({ "type": "string" });
+        if let Some(opts) = options {
+            if let Some(title) = opts.title {
+                field["title"] = Value::String(title.to_string());
+            }
+            if let Some(desc) = opts.description {
+                field["description"] = Value::String(desc.to_string());
+            }
+            if let Some(min) = opts.min_length {
+                field["minLength"] = Value::Number(min.into());
+            }
+            if let Some(max) = opts.max_length {
+                field["maxLength"] = Value::Number(max.into());
+            }
+            if let Some(fmt) = &opts.format {
+                field["format"] = Value::String(fmt.as_str().to_string());
+            }
+            if let Some(default) = opts.default {
+                field["default"] = Value::String(default.to_string());
+            }
+        }
+        let schema = serde_json::json!({
+            "type": "object",
+            "properties": { "value": field },
+            "required": ["value"]
+        });
+        let result = self.elicitation(message, schema).await?;
+        if result.action != "accept" {
+            return Ok(None);
+        }
+        let value = result
+            .content
+            .and_then(|c| c.get("value").and_then(|v| v.as_str()).map(String::from));
+        Ok(value)
+    }
+}
+
+impl Client {
+    /// Create a new session on the CLI.
+    ///
+    /// Sends `session.create`, registers the session on the router,
+    /// and spawns an internal event loop that dispatches to the handler.
+    ///
+    /// All callbacks (event handler, hooks, transform) are configured
+    /// via [`SessionConfig`] using [`with_handler`](SessionConfig::with_handler),
+    /// [`with_hooks`](SessionConfig::with_hooks), and
+    /// [`with_transform`](SessionConfig::with_transform).
+    ///
+    /// If [`hooks_handler`](SessionConfig::hooks_handler) is set, the
+    /// wire-level `hooks` flag is automatically enabled.
+    ///
+    /// If [`transform`](SessionConfig::transform) is set, the SDK injects
+    /// `action: "transform"` sections into the [`SystemMessageConfig`] wire
+    /// format and handles `systemMessage.transform` RPC callbacks during
+    /// the session.
+    ///
+    /// If [`handler`](SessionConfig::handler) is `None`, the session uses
+    /// [`DenyAllHandler`](crate::handler::DenyAllHandler) — permission
+    /// requests are denied; other events are no-ops.
+    pub async fn create_session(&self, mut config: SessionConfig) -> Result<Session, Error> {
+        let handler = config
+            .handler
+            .take()
+            .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler));
+        let hooks = config.hooks_handler.take();
+        let transforms = config.transform.take();
+        let command_handlers = build_command_handler_map(config.commands.as_deref());
+        let session_fs_provider = config.session_fs_provider.take();
+        if self.inner.session_fs_configured && session_fs_provider.is_none() {
+            return Err(Error::Session(SessionError::SessionFsProviderRequired));
+        }
+
+        if hooks.is_some() && config.hooks.is_none() {
+            config.hooks = Some(true);
+        }
+        if let Some(ref transforms) = transforms {
+            inject_transform_sections(&mut config, transforms.as_ref());
+        }
+        let mut params = serde_json::to_value(&config)?;
+        let trace_ctx = self.resolve_trace_context().await;
+        inject_trace_context(&mut params, &trace_ctx);
+        let result = self.call("session.create", Some(params)).await?;
+        let create_result: CreateSessionResult = serde_json::from_value(result)?;
+
+        let session_id = create_result.session_id;
+        let capabilities = Arc::new(parking_lot::RwLock::new(
+            create_result.capabilities.unwrap_or_default(),
+        ));
+        let channels = self.register_session(&session_id);
+
+        let idle_waiter = Arc::new(ParkingLotMutex::new(None));
+        let shutdown = Arc::new(Notify::new());
+        let (event_tx, _) = tokio::sync::broadcast::channel(512);
+        let event_loop = spawn_event_loop(
+            session_id.clone(),
+            self.clone(),
+            handler,
+            hooks,
+            transforms,
+            command_handlers,
+            session_fs_provider,
+            channels,
+            idle_waiter.clone(),
+            capabilities.clone(),
+            event_tx.clone(),
+            shutdown.clone(),
+        );
+
+        Ok(Session {
+            id: session_id,
+            cwd: self.cwd().clone(),
+            workspace_path: create_result.workspace_path,
+            remote_url: create_result.remote_url,
+            client: self.clone(),
+            event_loop: ParkingLotMutex::new(Some(event_loop)),
+            shutdown,
+            idle_waiter,
+            capabilities,
+            event_tx,
+        })
+    }
+
+    /// Resume an existing session on the CLI.
+    ///
+    /// Sends `session.resume` and `session.skills.reload`, registers the
+    /// session on the router, and spawns the event loop.
+    ///
+    /// All callbacks (event handler, hooks, transform) are configured
+    /// via [`ResumeSessionConfig`] using its `with_*` builder methods.
+    ///
+    /// See [`Self::create_session`] for the defaults applied when callback
+    /// fields are unset.
+    pub async fn resume_session(&self, mut config: ResumeSessionConfig) -> Result<Session, Error> {
+        let handler = config
+            .handler
+            .take()
+            .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler));
+        let hooks = config.hooks_handler.take();
+        let transforms = config.transform.take();
+        let command_handlers = build_command_handler_map(config.commands.as_deref());
+        let session_fs_provider = config.session_fs_provider.take();
+        if self.inner.session_fs_configured && session_fs_provider.is_none() {
+            return Err(Error::Session(SessionError::SessionFsProviderRequired));
+        }
+
+        if hooks.is_some() && config.hooks.is_none() {
+            config.hooks = Some(true);
+        }
+        if let Some(ref transforms) = transforms {
+            inject_transform_sections_resume(&mut config, transforms.as_ref());
+        }
+        let session_id = config.session_id.clone();
+        let mut params = serde_json::to_value(&config)?;
+        let trace_ctx = self.resolve_trace_context().await;
+        inject_trace_context(&mut params, &trace_ctx);
+        let result = self.call("session.resume", Some(params)).await?;
+
+        // The CLI may reassign the session ID on resume.
+ let cli_session_id: SessionId = result + .get("sessionId") + .and_then(|v| v.as_str()) + .unwrap_or(&session_id) + .into(); + + let resume_capabilities: Option = result + .get("capabilities") + .and_then(|v| { + serde_json::from_value(v.clone()) + .map_err(|e| warn!(error = %e, "failed to deserialize capabilities from resume response")) + .ok() + }); + let remote_url = result + .get("remoteUrl") + .or_else(|| result.get("remote_url")) + .and_then(|value| value.as_str()) + .map(ToString::to_string); + + // Reload skills after resume (best-effort). + if let Err(e) = self + .call( + "session.skills.reload", + Some(serde_json::json!({ "sessionId": cli_session_id })), + ) + .await + { + warn!(error = %e, "failed to reload skills after resume"); + } + + let capabilities = Arc::new(parking_lot::RwLock::new( + resume_capabilities.unwrap_or_default(), + )); + let channels = self.register_session(&cli_session_id); + + let idle_waiter = Arc::new(ParkingLotMutex::new(None)); + let shutdown = Arc::new(Notify::new()); + let (event_tx, _) = tokio::sync::broadcast::channel(512); + let event_loop = spawn_event_loop( + cli_session_id.clone(), + self.clone(), + handler, + hooks, + transforms, + command_handlers, + session_fs_provider, + channels, + idle_waiter.clone(), + capabilities.clone(), + event_tx.clone(), + shutdown.clone(), + ); + + Ok(Session { + id: cli_session_id, + cwd: self.cwd().clone(), + workspace_path: None, + remote_url, + client: self.clone(), + event_loop: ParkingLotMutex::new(Some(event_loop)), + shutdown, + idle_waiter, + capabilities, + event_tx, + }) + } +} + +type CommandHandlerMap = HashMap>; + +fn build_command_handler_map(commands: Option<&[CommandDefinition]>) -> Arc { + let map = match commands { + Some(commands) => commands + .iter() + .filter(|cmd| !cmd.name.is_empty()) + .map(|cmd| (cmd.name.clone(), cmd.handler.clone())) + .collect(), + None => HashMap::new(), + }; + Arc::new(map) +} + +#[allow(clippy::too_many_arguments)] +fn spawn_event_loop( + 
session_id: SessionId, + client: Client, + handler: Arc, + hooks: Option>, + transforms: Option>, + command_handlers: Arc, + session_fs_provider: Option>, + channels: crate::router::SessionChannels, + idle_waiter: Arc>>, + capabilities: Arc>, + event_tx: tokio::sync::broadcast::Sender, + shutdown: Arc, +) -> JoinHandle<()> { + let crate::router::SessionChannels { + mut notifications, + mut requests, + } = channels; + + let span = tracing::error_span!("session_event_loop", session_id = %session_id); + tokio::spawn( + async move { + loop { + // `mpsc::UnboundedReceiver::recv` and `Notify::notified()` + // are both cancel-safe per RFD 400. The selected branch's + // `await`'d handler is *not* mid-cancelled by the select + // — once a branch fires it runs to completion within the + // loop's iteration. Spawned child tasks inside + // `handle_notification` (permission/tool/elicitation + // callbacks) intentionally outlive the parent loop and + // own their own cleanup; this is RFD 400's "spawn + // background tasks to perform cancel-unsafe operations" + // pattern and is correct as-is. + tokio::select! { + _ = shutdown.notified() => break, + Some(notification) = notifications.recv() => { + handle_notification( + &session_id, &client, &handler, &command_handlers, notification, &idle_waiter, &capabilities, &event_tx, + ).await; + } + Some(request) = requests.recv() => { + handle_request( + &session_id, &client, &handler, hooks.as_deref(), transforms.as_deref(), session_fs_provider.as_ref(), request, + ).await; + } + else => break, + } + } + // Channels closed or shutdown signaled — fail any pending + // send_and_wait so the caller observes a clean error. 
+            if let Some(waiter) = idle_waiter.lock().take() {
+                let _ = waiter
+                    .tx
+                    .send(Err(Error::Session(SessionError::EventLoopClosed)));
+            }
+        }
+        .instrument(span),
+    )
+}
+
+fn extract_request_id(data: &Value) -> Option<RequestId> {
+    data.get("requestId")
+        .and_then(|v| v.as_str())
+        .filter(|s| !s.is_empty())
+        .map(RequestId::new)
+}
+
+fn pending_permission_result_kind(response: &HandlerResponse) -> &'static str {
+    match response {
+        HandlerResponse::Permission(PermissionResult::Approved) => "approve-once",
+        HandlerResponse::Permission(PermissionResult::Denied) => "reject",
+        HandlerResponse::Permission(PermissionResult::NoResult) => "no-result",
+        // Fallback to "user-not-available" for UserNotAvailable, Deferred (when
+        // forced through this path), Custom (handled separately upstream), and
+        // any non-permission HandlerResponse that gets here defensively.
+        _ => "user-not-available",
+    }
+}
+
+fn permission_request_response(response: &HandlerResponse) -> PermissionDecision {
+    match response {
+        HandlerResponse::Permission(PermissionResult::Approved) => {
+            PermissionDecision::ApproveOnce(PermissionDecisionApproveOnce {
+                kind: PermissionDecisionApproveOnceKind::ApproveOnce,
+            })
+        }
+        _ => PermissionDecision::Reject(PermissionDecisionReject {
+            kind: PermissionDecisionRejectKind::Reject,
+            feedback: None,
+        }),
+    }
+}
+
+/// Map a handler response into the `result` payload for the notification
+/// path (`session.permissions.handlePendingPermissionRequest`).
+///
+/// Returns `None` when the SDK must not respond — currently only the
+/// [`PermissionResult::Deferred`] case, where the handler takes over
+/// responsibility for the round-trip itself.
+fn notification_permission_payload(response: &HandlerResponse) -> Option<Value> {
+    match response {
+        HandlerResponse::Permission(PermissionResult::Deferred) => None,
+        HandlerResponse::Permission(PermissionResult::Custom(value)) => Some(value.clone()),
+        _ => Some(serde_json::json!({
+            "kind": pending_permission_result_kind(response),
+        })),
+    }
+}
+
+/// Map a handler response into the JSON-RPC `result` payload for the
+/// direct-RPC path (`permission.request`).
+///
+/// Always returns a value. [`PermissionResult::Deferred`] is treated as
+/// [`PermissionResult::Approved`] here because the JSON-RPC contract
+/// requires a reply — see the variant's doc comment.
+fn direct_permission_payload(response: &HandlerResponse) -> Value {
+    match response {
+        HandlerResponse::Permission(PermissionResult::Custom(value)) => value.clone(),
+        HandlerResponse::Permission(PermissionResult::Deferred) => serde_json::to_value(
+            permission_request_response(&HandlerResponse::Permission(PermissionResult::Approved)),
+        )
+        .expect("serializing direct permission response should succeed"),
+        HandlerResponse::Permission(PermissionResult::NoResult)
+        | HandlerResponse::Permission(PermissionResult::UserNotAvailable) => serde_json::json!({
+            "kind": pending_permission_result_kind(response),
+        }),
+        _ => serde_json::to_value(permission_request_response(response))
+            .expect("serializing direct permission response should succeed"),
+    }
+}
+
+/// Process a notification from the CLI's broadcast channel.
+#[allow(clippy::too_many_arguments)]
+async fn handle_notification(
+    session_id: &SessionId,
+    client: &Client,
+    handler: &Arc,
+    command_handlers: &Arc<CommandHandlerMap>,
+    notification: SessionEventNotification,
+    idle_waiter: &Arc>>,
+    capabilities: &Arc>,
+    event_tx: &tokio::sync::broadcast::Sender,
+) {
+    let event = notification.event.clone();
+    let event_type = event.parsed_type();
+
+    // Signal send_and_wait if active.
The lock is only contended when + // a send_and_wait call is in flight (idle_waiter is Some). + match event_type { + SessionEventType::AssistantMessage + | SessionEventType::SessionIdle + | SessionEventType::SessionError => { + let mut guard = idle_waiter.lock(); + if let Some(waiter) = guard.as_mut() { + match event_type { + SessionEventType::AssistantMessage => { + waiter.last_assistant_message = Some(event.clone()); + } + SessionEventType::SessionIdle | SessionEventType::SessionError => { + if let Some(waiter) = guard.take() { + if event_type == SessionEventType::SessionIdle { + let _ = waiter.tx.send(Ok(waiter.last_assistant_message)); + } else { + let error_msg = event + .typed_data::() + .map(|d| d.message) + .or_else(|| { + event + .data + .get("message") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) + }) + .unwrap_or_else(|| "session error".to_string()); + let _ = waiter + .tx + .send(Err(Error::Session(SessionError::AgentError(error_msg)))); + } + } + } + _ => {} + } + } + } + _ => {} + } + + // Fan out the event to runtime subscribers (`Session::subscribe`). `send` + // only errors when there are no receivers, which is the normal case + // before any consumer subscribes. + let _ = event_tx.send(event.clone()); + + // Fire-and-forget dispatch for the general event. + handler + .on_event(HandlerEvent::SessionEvent { + session_id: session_id.clone(), + event, + }) + .await; + + // Update capabilities when the CLI reports changes. The CLI sends + // the full updated capabilities object — replace wholesale so removals + // and new subfields are handled correctly. + if event_type == SessionEventType::CapabilitiesChanged { + match serde_json::from_value::(notification.event.data.clone()) { + Ok(changed) => *capabilities.write() = changed, + Err(e) => warn!(error = %e, "failed to deserialize capabilities.changed payload"), + } + } + + // Notification-based permission/tool/elicitation requests require a + // separate RPC callback. 
Spawn concurrently since the CLI doesn't block. + match event_type { + SessionEventType::PermissionRequested => { + let Some(request_id) = extract_request_id(¬ification.event.data) else { + return; + }; + let client = client.clone(); + let handler = handler.clone(); + let sid = session_id.clone(); + let data: PermissionRequestData = + serde_json::from_value(notification.event.data.clone()).unwrap_or_else(|_| { + PermissionRequestData { + kind: None, + tool_call_id: None, + extra: notification.event.data.clone(), + } + }); + tokio::spawn(async move { + let response = handler + .on_event(HandlerEvent::PermissionRequest { + session_id: sid.clone(), + request_id: request_id.clone(), + data, + }) + .await; + let Some(result_value) = notification_permission_payload(&response) else { + // Handler returned Deferred — it will call + // handlePendingPermissionRequest itself. + return; + }; + let _ = client + .call( + "session.permissions.handlePendingPermissionRequest", + Some(serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + "result": result_value, + })), + ) + .await; + }); + } + SessionEventType::ExternalToolRequested => { + let Some(request_id) = extract_request_id(¬ification.event.data) else { + return; + }; + let data: ExternalToolRequestedData = + match serde_json::from_value(notification.event.data.clone()) { + Ok(d) => d, + Err(e) => { + warn!(error = %e, "failed to deserialize external_tool.requested"); + let client = client.clone(); + let sid = session_id.clone(); + tokio::spawn(async move { + let _ = client + .call( + "session.tools.handlePendingToolCall", + Some(serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + "error": format!("Failed to deserialize tool request: {e}"), + })), + ) + .await; + }); + return; + } + }; + let client = client.clone(); + let handler = handler.clone(); + let sid = session_id.clone(); + tokio::spawn(async move { + if data.tool_call_id.is_empty() || data.tool_name.is_empty() { + let error_msg = if 
data.tool_call_id.is_empty() { + "Missing toolCallId" + } else { + "Missing toolName" + }; + let _ = client + .call( + "session.tools.handlePendingToolCall", + Some(serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + "error": error_msg, + })), + ) + .await; + return; + } + let invocation = ToolInvocation { + session_id: sid.clone(), + tool_call_id: data.tool_call_id, + tool_name: data.tool_name, + arguments: data + .arguments + .unwrap_or(Value::Object(serde_json::Map::new())), + traceparent: data.traceparent, + tracestate: data.tracestate, + }; + let response = handler + .on_event(HandlerEvent::ExternalTool { invocation }) + .await; + let tool_result = match response { + HandlerResponse::ToolResult(r) => r, + _ => ToolResult::Text("Unexpected handler response".to_string()), + }; + let result_value = serde_json::to_value(&tool_result).unwrap_or(Value::Null); + let _ = client + .call( + "session.tools.handlePendingToolCall", + Some(serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + "result": result_value, + })), + ) + .await; + }); + } + SessionEventType::UserInputRequested => { + // Notification-only signal for observers (UI, telemetry). + // The CLI follows up with a `userInput.request` JSON-RPC call + // that drives `HandlerEvent::UserInput` dispatch — handling + // the notification here too would double-fire the handler + // and produce duplicate prompts on the consumer side. See + // github/github-app#4249. 
+ } + SessionEventType::ElicitationRequested => { + let Some(request_id) = extract_request_id(¬ification.event.data) else { + return; + }; + let elicitation_data: ElicitationRequestedData = + match serde_json::from_value(notification.event.data.clone()) { + Ok(d) => d, + Err(e) => { + warn!(error = %e, "failed to deserialize elicitation request"); + return; + } + }; + let request = ElicitationRequest { + message: elicitation_data.message, + requested_schema: elicitation_data + .requested_schema + .map(|s| serde_json::to_value(s).unwrap_or(Value::Null)), + mode: elicitation_data.mode.map(|m| match m { + crate::generated::session_events::ElicitationRequestedMode::Form => { + crate::types::ElicitationMode::Form + } + crate::generated::session_events::ElicitationRequestedMode::Url => { + crate::types::ElicitationMode::Url + } + _ => crate::types::ElicitationMode::Unknown, + }), + elicitation_source: elicitation_data.elicitation_source, + url: elicitation_data.url, + }; + let client = client.clone(); + let handler = handler.clone(); + let sid = session_id.clone(); + tokio::spawn(async move { + let cancel = ElicitationResult { + action: "cancel".to_string(), + content: None, + }; + // Dispatch to handler inside a nested task so panics are + // caught as JoinErrors (matches Node SDK's try/catch pattern). 
+ let handler_task = tokio::spawn({ + let sid = sid.clone(); + let request_id = request_id.clone(); + async move { + handler + .on_event(HandlerEvent::ElicitationRequest { + session_id: sid, + request_id, + request, + }) + .await + } + }); + let result = match handler_task.await { + Ok(HandlerResponse::Elicitation(r)) => r, + _ => cancel.clone(), + }; + if let Err(e) = client + .call( + "session.ui.handlePendingElicitation", + Some(serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + "result": result, + })), + ) + .await + { + // RPC failed — attempt cancel as last resort + warn!(error = %e, "handlePendingElicitation failed, sending cancel"); + let _ = client + .call( + "session.ui.handlePendingElicitation", + Some(serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + "result": cancel, + })), + ) + .await; + } + }); + } + SessionEventType::CommandExecute => { + let data: CommandExecuteData = + match serde_json::from_value(notification.event.data.clone()) { + Ok(d) => d, + Err(e) => { + warn!(error = %e, "failed to deserialize command.execute"); + return; + } + }; + let client = client.clone(); + let command_handlers = command_handlers.clone(); + let sid = session_id.clone(); + tokio::spawn(async move { + let request_id = data.request_id; + let ack_error = match command_handlers.get(&data.command_name).cloned() { + None => Some(format!("Unknown command: {}", data.command_name)), + Some(handler) => { + let ctx = CommandContext { + session_id: sid.clone(), + command: data.command, + command_name: data.command_name, + args: data.args, + }; + match handler.on_command(ctx).await { + Ok(()) => None, + Err(e) => Some(e.to_string()), + } + } + }; + let mut params = serde_json::json!({ + "sessionId": sid, + "requestId": request_id, + }); + if let Some(error_msg) = ack_error { + params["error"] = serde_json::Value::String(error_msg); + } + let _ = client + .call("session.commands.handlePendingCommand", Some(params)) + .await; + }); + } + _ => {} 
+ } +} + +/// Process a JSON-RPC request from the CLI. +async fn handle_request( + session_id: &SessionId, + client: &Client, + handler: &Arc, + hooks: Option<&dyn SessionHooks>, + transforms: Option<&dyn SystemMessageTransform>, + session_fs_provider: Option<&Arc>, + request: crate::JsonRpcRequest, +) { + let sid = session_id.clone(); + + if request.method.starts_with("sessionFs.") { + crate::session_fs_dispatch::dispatch(client, session_fs_provider, request).await; + return; + } + + match request.method.as_str() { + "hooks.invoke" => { + let params = request.params.as_ref(); + let hook_type = params + .and_then(|p| p.get("hookType")) + .and_then(|v| v.as_str()) + .unwrap_or(""); + let input = params + .and_then(|p| p.get("input")) + .cloned() + .unwrap_or(Value::Object(Default::default())); + + let rpc_result = if let Some(hooks) = hooks { + match crate::hooks::dispatch_hook(hooks, &sid, hook_type, input).await { + Ok(output) => output, + Err(e) => { + warn!(error = %e, hook_type = hook_type, "hook dispatch failed"); + serde_json::json!({ "output": {} }) + } + } + } else { + serde_json::json!({ "output": {} }) + }; + + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(rpc_result), + error: None, + }; + let _ = client.send_response(&rpc_response).await; + } + + "tool.call" => { + let invocation: ToolInvocation = match request + .params + .as_ref() + .and_then(|p| serde_json::from_value::(p.clone()).ok()) + { + Some(inv) => inv, + None => { + let _ = send_error_response( + client, + request.id, + error_codes::INVALID_PARAMS, + "invalid tool.call params", + ) + .await; + return; + } + }; + let response = handler + .on_event(HandlerEvent::ExternalTool { invocation }) + .await; + let tool_result = match response { + HandlerResponse::ToolResult(r) => r, + _ => ToolResult::Text("Unexpected handler response".to_string()), + }; + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: 
Some(serde_json::json!(ToolResultResponse { + result: tool_result + })), + error: None, + }; + let _ = client.send_response(&rpc_response).await; + } + + "userInput.request" => { + let params = request.params.as_ref(); + let Some(question) = params + .and_then(|p| p.get("question")) + .and_then(|v| v.as_str()) + else { + warn!("userInput.request missing 'question' field"); + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: None, + error: Some(crate::JsonRpcError { + code: error_codes::INVALID_PARAMS, + message: "missing required field: question".to_string(), + data: None, + }), + }; + let _ = client.send_response(&rpc_response).await; + return; + }; + let question = question.to_string(); + let choices = params + .and_then(|p| p.get("choices")) + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| v.as_str().map(|s| s.to_string())) + .collect() + }); + let allow_freeform = params + .and_then(|p| p.get("allowFreeform")) + .and_then(|v| v.as_bool()); + + let response = handler + .on_event(HandlerEvent::UserInput { + session_id: sid, + question, + choices, + allow_freeform, + }) + .await; + + let rpc_result = match response { + HandlerResponse::UserInput(Some(UserInputResponse { + answer, + was_freeform, + })) => serde_json::json!({ + "answer": answer, + "wasFreeform": was_freeform, + }), + _ => serde_json::json!({ "noResponse": true }), + }; + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(rpc_result), + error: None, + }; + let _ = client.send_response(&rpc_response).await; + } + + "exitPlanMode.request" => { + let params = request + .params + .as_ref() + .cloned() + .unwrap_or(Value::Object(serde_json::Map::new())); + let data: ExitPlanModeData = match serde_json::from_value(params) { + Ok(d) => d, + Err(e) => { + warn!(error = %e, "failed to deserialize exitPlanMode.request params, using defaults"); + ExitPlanModeData::default() + } + }; + + let 
response = handler + .on_event(HandlerEvent::ExitPlanMode { + session_id: sid, + data, + }) + .await; + + let rpc_result = match response { + HandlerResponse::ExitPlanMode(ExitPlanModeResult { + approved, + selected_action, + feedback, + }) => serde_json::json!({ + "approved": approved, + "selectedAction": selected_action, + "feedback": feedback, + }), + _ => serde_json::json!({ "approved": true }), + }; + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(rpc_result), + error: None, + }; + let _ = client.send_response(&rpc_response).await; + } + + "autoModeSwitch.request" => { + let error_code = request + .params + .as_ref() + .and_then(|p| p.get("errorCode")) + .and_then(|v| v.as_str()) + .map(|s| s.to_string()); + let retry_after_seconds = request + .params + .as_ref() + .and_then(|p| p.get("retryAfterSeconds")) + .and_then(|v| v.as_u64()); + + let response = handler + .on_event(HandlerEvent::AutoModeSwitch { + session_id: sid, + error_code, + retry_after_seconds, + }) + .await; + + let answer = match response { + HandlerResponse::AutoModeSwitch(answer) => answer, + _ => AutoModeSwitchResponse::No, + }; + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(serde_json::json!({ "response": answer })), + error: None, + }; + let _ = client.send_response(&rpc_response).await; + } + + "permission.request" => { + let Some(request_id) = request + .params + .as_ref() + .and_then(|p| p.get("requestId")) + .and_then(|v| v.as_str()) + .filter(|s| !s.is_empty()) + else { + warn!("permission.request missing 'requestId' field"); + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: None, + error: Some(crate::JsonRpcError { + code: error_codes::INVALID_PARAMS, + message: "missing required field: requestId".to_string(), + data: None, + }), + }; + let _ = client.send_response(&rpc_response).await; + return; + }; + let request_id = 
RequestId::new(request_id); + let raw_params = request + .params + .as_ref() + .cloned() + .unwrap_or(Value::Object(serde_json::Map::new())); + let data: PermissionRequestData = + serde_json::from_value(raw_params.clone()).unwrap_or(PermissionRequestData { + kind: None, + tool_call_id: None, + extra: raw_params, + }); + + let response = handler + .on_event(HandlerEvent::PermissionRequest { + session_id: sid, + request_id, + data, + }) + .await; + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(direct_permission_payload(&response)), + error: None, + }; + let _ = client.send_response(&rpc_response).await; + } + + "systemMessage.transform" => { + let params = request.params.as_ref(); + let sections: HashMap = + match params.and_then(|p| p.get("sections")) { + Some(v) => match serde_json::from_value(v.clone()) { + Ok(s) => s, + Err(e) => { + let _ = send_error_response( + client, + request.id, + error_codes::INVALID_PARAMS, + &format!("invalid sections: {e}"), + ) + .await; + return; + } + }, + None => { + let _ = send_error_response( + client, + request.id, + error_codes::INVALID_PARAMS, + "missing sections parameter", + ) + .await; + return; + } + }; + + let rpc_result = if let Some(transforms) = transforms { + let response = + crate::transforms::dispatch_transform(transforms, &sid, sections).await; + match serde_json::to_value(response) { + Ok(v) => v, + Err(e) => { + warn!(error = %e, "failed to serialize transform response"); + serde_json::json!({ "sections": {} }) + } + } + } else { + // No transforms registered — pass through all sections unchanged. 
+ let passthrough: HashMap = sections; + serde_json::json!({ "sections": passthrough }) + }; + + let rpc_response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: request.id, + result: Some(rpc_result), + error: None, + }; + let _ = client.send_response(&rpc_response).await; + } + + method => { + warn!( + method = method, + "unhandled request method in session event loop" + ); + let _ = send_error_response( + client, + request.id, + error_codes::METHOD_NOT_FOUND, + &format!("unknown method: {method}"), + ) + .await; + } + } +} + +async fn send_error_response( + client: &Client, + id: u64, + code: i32, + message: &str, +) -> Result<(), Error> { + let response = JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id, + result: None, + error: Some(crate::JsonRpcError { + code, + message: message.to_string(), + data: None, + }), + }; + client.send_response(&response).await +} + +/// Inject `action: "transform"` sections into a `SystemMessageConfig`, +/// forcing `mode: "customize"` (required by the CLI for transforms to fire). +/// Preserves any existing caller-provided section overrides. 
+fn apply_transform_sections( + sys_msg: &mut SystemMessageConfig, + transforms: &dyn SystemMessageTransform, +) { + sys_msg.mode = Some("customize".to_string()); + let sections = sys_msg.sections.get_or_insert_with(HashMap::new); + for id in transforms.section_ids() { + sections.entry(id).or_insert_with(|| SectionOverride { + action: Some("transform".to_string()), + content: None, + }); + } +} + +fn inject_transform_sections(config: &mut SessionConfig, transforms: &dyn SystemMessageTransform) { + let sys_msg = config.system_message.get_or_insert_with(Default::default); + apply_transform_sections(sys_msg, transforms); +} + +fn inject_transform_sections_resume( + config: &mut ResumeSessionConfig, + transforms: &dyn SystemMessageTransform, +) { + let sys_msg = config.system_message.get_or_insert_with(Default::default); + apply_transform_sections(sys_msg, transforms); +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use super::{ + direct_permission_payload, notification_permission_payload, pending_permission_result_kind, + permission_request_response, + }; + use crate::handler::{HandlerResponse, PermissionResult}; + + #[test] + fn pending_permission_requests_use_decision_kinds() { + assert_eq!( + pending_permission_result_kind(&HandlerResponse::Permission( + PermissionResult::Approved, + )), + "approve-once" + ); + assert_eq!( + pending_permission_result_kind(&HandlerResponse::Permission(PermissionResult::Denied)), + "reject" + ); + assert_eq!( + pending_permission_result_kind(&HandlerResponse::Ok), + "user-not-available" + ); + } + + #[test] + fn direct_permission_requests_use_decision_response_kinds() { + assert_eq!( + serde_json::to_value(permission_request_response(&HandlerResponse::Permission( + PermissionResult::Approved + ),)) + .expect("serializing approved permission response should succeed"), + json!({ "kind": "approve-once" }) + ); + assert_eq!( + serde_json::to_value(permission_request_response(&HandlerResponse::Permission( + 
PermissionResult::Denied + ),)) + .expect("serializing denied permission response should succeed"), + json!({ "kind": "reject" }) + ); + assert_eq!( + serde_json::to_value(permission_request_response(&HandlerResponse::Ok)) + .expect("serializing fallback permission response should succeed"), + json!({ "kind": "reject" }) + ); + } + + #[test] + fn notification_payload_handles_deferred_and_custom() { + // Deferred → no payload, SDK must not respond. + assert!( + notification_permission_payload(&HandlerResponse::Permission( + PermissionResult::Deferred, + )) + .is_none() + ); + + // Custom → handler-supplied value passed through verbatim. + let custom = json!({ + "kind": "approve-and-remember", + "allowlist": ["ls", "grep"], + }); + assert_eq!( + notification_permission_payload(&HandlerResponse::Permission( + PermissionResult::Custom(custom.clone()), + )), + Some(custom) + ); + + // Approved/Denied → existing kind-only shape. + assert_eq!( + notification_permission_payload(&HandlerResponse::Permission( + PermissionResult::Approved, + )), + Some(json!({ "kind": "approve-once" })) + ); + assert_eq!( + notification_permission_payload( + &HandlerResponse::Permission(PermissionResult::Denied,) + ), + Some(json!({ "kind": "reject" })) + ); + } + + #[test] + fn direct_payload_handles_deferred_and_custom() { + // Custom → handler-supplied value passed through verbatim. + let custom = json!({ + "kind": "approve-and-remember", + "allowlist": ["ls", "grep"], + }); + assert_eq!( + direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Custom( + custom.clone(), + ))), + custom + ); + + // Deferred → falls back to Approved because the direct RPC must reply. + assert_eq!( + direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Deferred)), + json!({ "kind": "approve-once" }) + ); + + // Approved/Denied → existing kind-only shape. 
+        assert_eq!(
+            direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Approved)),
+            json!({ "kind": "approve-once" })
+        );
+        assert_eq!(
+            direct_permission_payload(&HandlerResponse::Permission(PermissionResult::Denied)),
+            json!({ "kind": "reject" })
+        );
+    }
+}
diff --git a/rust/src/session_fs.rs b/rust/src/session_fs.rs
new file mode 100644
index 000000000..e675760a1
--- /dev/null
+++ b/rust/src/session_fs.rs
@@ -0,0 +1,394 @@
+//! Session filesystem provider — virtualizable filesystem layer over JSON-RPC.
+//!
+//! When [`ClientOptions::session_fs`] is set, the SDK tells the CLI to delegate
+//! all per-session filesystem operations (`readFile`, `writeFile`, `stat`, ...)
+//! to a [`SessionFsProvider`] registered on each session. This lets host
+//! applications sandbox sessions, project files into in-memory or remote
+//! storage, and apply permission policies before bytes move.
+//!
+//! # Concurrency
+//!
+//! Each inbound `sessionFs.*` request is dispatched on its own spawned task,
+//! matching Node's behavior. Provider implementations MUST be safe for
+//! concurrent invocation across distinct paths. Use internal synchronization
+//! (e.g. [`tokio::sync::Mutex`] keyed by path) if your backing store needs
+//! ordering.
+//!
+//! # Errors
+//!
+//! Provider methods return `Result<_, FsError>`. The SDK adapts these into
+//! the schema's `{ ..., error: Option<SessionFsError> }` payload, mapping
+//! [`FsError::NotFound`] to the wire's `ENOENT` and everything else to
+//! `UNKNOWN`. A `From<std::io::Error>` conversion is provided so handlers
+//! backed by [`tokio::fs`](https://docs.rs/tokio/latest/tokio/fs/index.html)
+//! can propagate `io::Error` with `?`.
+//!
+//! # Example
+//!
+//! ```no_run
+//! use std::sync::Arc;
+//! use async_trait::async_trait;
+//! use github_copilot_sdk::types::{SessionFsProvider, FsError, FileInfo, DirEntry};
+//!
+//! struct MyProvider;
+//!
+//! #[async_trait]
+//! impl SessionFsProvider for MyProvider {
+//!     async fn read_file(&self, path: &str) -> Result<String, FsError> {
+//!         std::fs::read_to_string(path)
+//!             .map_err(FsError::from)
+//!     }
+//! }
+//! ```
+
+use async_trait::async_trait;
+
+use crate::generated::api_types::{
+    SessionFsError, SessionFsErrorCode, SessionFsReaddirWithTypesEntry,
+    SessionFsReaddirWithTypesEntryType, SessionFsSetProviderConventions, SessionFsStatResult,
+};
+
+/// Configuration for a custom session filesystem provider.
+///
+/// When set on [`ClientOptions::session_fs`](crate::ClientOptions::session_fs),
+/// the SDK calls `sessionFs.setProvider` during [`Client::start`](crate::Client::start)
+/// to tell the CLI to route per-session filesystem operations to the SDK.
+#[non_exhaustive]
+#[derive(Debug, Clone)]
+pub struct SessionFsConfig {
+    /// Initial working directory for sessions (the user's project directory).
+    pub initial_cwd: String,
+    /// Path within each session's SessionFs where the runtime stores
+    /// session-scoped files (events, workspace, checkpoints, etc.).
+    pub session_state_path: String,
+    /// Path conventions used by this filesystem provider.
+    pub conventions: SessionFsConventions,
+}
+
+impl SessionFsConfig {
+    /// Build a new config with the required fields.
+    pub fn new(
+        initial_cwd: impl Into<String>,
+        session_state_path: impl Into<String>,
+        conventions: SessionFsConventions,
+    ) -> Self {
+        Self {
+            initial_cwd: initial_cwd.into(),
+            session_state_path: session_state_path.into(),
+            conventions,
+        }
+    }
+}
+
+/// Path conventions used by a session filesystem provider.
+///
+/// Hand-authored consumer-facing enum (rather than reusing
+/// [`SessionFsSetProviderConventions`]) to avoid exposing the generated
+/// catch-all `Unknown` variant on the input side. The SDK rejects unknown
+/// conventions at validation time with a typed error.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum SessionFsConventions {
+    /// POSIX-style paths (`/foo/bar`).
+    Posix,
+    /// Windows-style paths (`C:\foo\bar`).
+    Windows,
+}
+
+impl SessionFsConventions {
+    pub(crate) fn into_wire(self) -> SessionFsSetProviderConventions {
+        match self {
+            Self::Posix => SessionFsSetProviderConventions::Posix,
+            Self::Windows => SessionFsSetProviderConventions::Windows,
+        }
+    }
+}
+
+/// Error returned by a [`SessionFsProvider`] method.
+///
+/// The SDK maps this onto the wire schema's [`SessionFsError`]:
+/// [`FsError::NotFound`] becomes `ENOENT`, everything else becomes `UNKNOWN`.
+#[non_exhaustive]
+#[derive(Debug, Clone, thiserror::Error)]
+pub enum FsError {
+    /// File or directory does not exist.
+    #[error("not found: {0}")]
+    NotFound(String),
+
+    /// Any other filesystem error (permission denied, I/O error, etc.).
+    ///
+    /// The wire mapping always uses `UNKNOWN` as the code; the message is
+    /// preserved for diagnostics.
+    #[error("{0}")]
+    Other(String),
+}
+
+impl FsError {
+    pub(crate) fn into_wire(self) -> SessionFsError {
+        match self {
+            Self::NotFound(message) => SessionFsError {
+                code: SessionFsErrorCode::ENOENT,
+                message: Some(message),
+            },
+            Self::Other(message) => SessionFsError {
+                code: SessionFsErrorCode::UNKNOWN,
+                message: Some(message),
+            },
+        }
+    }
+}
+
+impl From<std::io::Error> for FsError {
+    fn from(err: std::io::Error) -> Self {
+        match err.kind() {
+            std::io::ErrorKind::NotFound => Self::NotFound(err.to_string()),
+            _ => Self::Other(err.to_string()),
+        }
+    }
+}
+
+/// File or directory metadata returned by [`SessionFsProvider::stat`].
+///
+/// The SDK adapts this into the wire's [`SessionFsStatResult`].
+#[non_exhaustive]
+#[derive(Debug, Clone)]
+pub struct FileInfo {
+    /// Whether the path is a regular file.
+    pub is_file: bool,
+    /// Whether the path is a directory.
+    pub is_directory: bool,
+    /// File size in bytes.
+    pub size: i64,
+    /// ISO 8601 timestamp of last modification.
+    pub mtime: String,
+    /// ISO 8601 timestamp of creation.
+    pub birthtime: String,
+}
+
+impl FileInfo {
+    /// Build a metadata record. The mtime/birthtime arguments are caller-
+    /// supplied ISO 8601 strings — the SDK does not format timestamps for
+    /// you.
+    pub fn new(
+        is_file: bool,
+        is_directory: bool,
+        size: i64,
+        mtime: impl Into<String>,
+        birthtime: impl Into<String>,
+    ) -> Self {
+        Self {
+            is_file,
+            is_directory,
+            size,
+            mtime: mtime.into(),
+            birthtime: birthtime.into(),
+        }
+    }
+
+    pub(crate) fn into_wire(self) -> SessionFsStatResult {
+        SessionFsStatResult {
+            is_file: self.is_file,
+            is_directory: self.is_directory,
+            size: self.size,
+            mtime: self.mtime,
+            birthtime: self.birthtime,
+            error: None,
+        }
+    }
+}
+
+/// Kind of entry returned by [`SessionFsProvider::readdir_with_types`].
+///
+/// The wire schema's `Unknown` forward-compat variant is intentionally absent
+/// from this consumer-facing enum — providers must classify each entry as
+/// either a file or a directory.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum DirEntryKind {
+    /// Regular file.
+    File,
+    /// Directory.
+    Directory,
+}
+
+impl DirEntryKind {
+    fn into_wire(self) -> SessionFsReaddirWithTypesEntryType {
+        match self {
+            Self::File => SessionFsReaddirWithTypesEntryType::File,
+            Self::Directory => SessionFsReaddirWithTypesEntryType::Directory,
+        }
+    }
+}
+
+/// Single entry in a directory listing returned by
+/// [`SessionFsProvider::readdir_with_types`].
+#[non_exhaustive]
+#[derive(Debug, Clone)]
+pub struct DirEntry {
+    /// Entry name (basename, not full path).
+    pub name: String,
+    /// Whether the entry is a file or a directory.
+    pub kind: DirEntryKind,
+}
+
+impl DirEntry {
+    /// Build a new directory entry.
+    pub fn new(name: impl Into<String>, kind: DirEntryKind) -> Self {
+        Self {
+            name: name.into(),
+            kind,
+        }
+    }
+
+    pub(crate) fn into_wire(self) -> SessionFsReaddirWithTypesEntry {
+        SessionFsReaddirWithTypesEntry {
+            name: self.name,
+            r#type: self.kind.into_wire(),
+        }
+    }
+}
+
+/// Implementor-supplied filesystem backing for a session.
+///
+/// Each method takes a path using the conventions declared in
+/// [`SessionFsConfig::conventions`] and returns the operation's result. The
+/// SDK adapts every `Result<_, FsError>` into the JSON-RPC response shape
+/// expected by the GitHub Copilot CLI.
+///
+/// # Concurrency
+///
+/// Implementations MUST be `Send + Sync` and safe for concurrent invocation
+/// across distinct paths. The SDK dispatches each inbound `sessionFs.*`
+/// request on its own spawned task. Use internal synchronization (e.g.
+/// [`tokio::sync::Mutex`] keyed by path) if your backing store requires
+/// ordering.
+///
+/// # Forward compatibility
+///
+/// Methods on this trait have default implementations that return
+/// `Err(FsError::Other("operation not supported".into()))`. When the CLI
+/// schema grows new `sessionFs.*` methods, the SDK adds them to this trait
+/// with default impls so existing implementations continue to compile.
+/// Override only the methods relevant to your backing store.
+#[async_trait]
+pub trait SessionFsProvider: Send + Sync + 'static {
+    /// Read the full contents of a file as UTF-8.
+    async fn read_file(&self, path: &str) -> Result<String, FsError> {
+        let _ = path;
+        Err(FsError::Other("read_file not supported".to_string()))
+    }
+
+    /// Write content to a file, creating parent directories if needed.
+    async fn write_file(
+        &self,
+        path: &str,
+        content: &str,
+        mode: Option<u32>,
+    ) -> Result<(), FsError> {
+        let _ = (path, content, mode);
+        Err(FsError::Other("write_file not supported".to_string()))
+    }
+
+    /// Append content to a file, creating parent directories if needed.
+    async fn append_file(
+        &self,
+        path: &str,
+        content: &str,
+        mode: Option<u32>,
+    ) -> Result<(), FsError> {
+        let _ = (path, content, mode);
+        Err(FsError::Other("append_file not supported".to_string()))
+    }
+
+    /// Check whether a path exists.
+    ///
+    /// Returns `Ok(false)` for non-existent paths, not [`FsError::NotFound`].
+    async fn exists(&self, path: &str) -> Result<bool, FsError> {
+        let _ = path;
+        Err(FsError::Other("exists not supported".to_string()))
+    }
+
+    /// Get metadata about a file or directory.
+    async fn stat(&self, path: &str) -> Result<FileInfo, FsError> {
+        let _ = path;
+        Err(FsError::Other("stat not supported".to_string()))
+    }
+
+    /// Create a directory. When `recursive`, missing parents are also created.
+    async fn mkdir(&self, path: &str, recursive: bool, mode: Option<u32>) -> Result<(), FsError> {
+        let _ = (path, recursive, mode);
+        Err(FsError::Other("mkdir not supported".to_string()))
+    }
+
+    /// List entry names in a directory.
+    async fn readdir(&self, path: &str) -> Result<Vec<String>, FsError> {
+        let _ = path;
+        Err(FsError::Other("readdir not supported".to_string()))
+    }
+
+    /// List directory entries with type information.
+    async fn readdir_with_types(&self, path: &str) -> Result<Vec<DirEntry>, FsError> {
+        let _ = path;
+        Err(FsError::Other(
+            "readdir_with_types not supported".to_string(),
+        ))
+    }
+
+    /// Remove a file or directory. When `force`, missing paths are not an
+    /// error. When `recursive`, directory contents are removed as well.
+    async fn rm(&self, path: &str, recursive: bool, force: bool) -> Result<(), FsError> {
+        let _ = (path, recursive, force);
+        Err(FsError::Other("rm not supported".to_string()))
+    }
+
+    /// Rename or move a file or directory.
+ async fn rename(&self, src: &str, dest: &str) -> Result<(), FsError> { + let _ = (src, dest); + Err(FsError::Other("rename not supported".to_string())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn fs_error_maps_io_not_found_to_enoent() { + let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "missing.txt"); + let fs_err: FsError = io_err.into(); + assert!(matches!(fs_err, FsError::NotFound(_))); + let wire = fs_err.into_wire(); + assert_eq!(wire.code, SessionFsErrorCode::ENOENT); + } + + #[test] + fn fs_error_maps_other_io_to_unknown() { + let io_err = std::io::Error::other("disk full"); + let fs_err: FsError = io_err.into(); + assert!(matches!(fs_err, FsError::Other(_))); + let wire = fs_err.into_wire(); + assert_eq!(wire.code, SessionFsErrorCode::UNKNOWN); + assert!(wire.message.unwrap().contains("disk full")); + } + + #[test] + fn conventions_maps_to_wire() { + assert_eq!( + SessionFsConventions::Posix.into_wire(), + SessionFsSetProviderConventions::Posix + ); + assert_eq!( + SessionFsConventions::Windows.into_wire(), + SessionFsSetProviderConventions::Windows + ); + } + + struct DefaultProvider; + #[async_trait] + impl SessionFsProvider for DefaultProvider {} + + #[tokio::test] + async fn default_impls_return_unsupported() { + let p = DefaultProvider; + let err = p.read_file("/x").await.unwrap_err(); + assert!(matches!(err, FsError::Other(ref m) if m.contains("not supported"))); + } +} diff --git a/rust/src/session_fs_dispatch.rs b/rust/src/session_fs_dispatch.rs new file mode 100644 index 000000000..7b2ae49fd --- /dev/null +++ b/rust/src/session_fs_dispatch.rs @@ -0,0 +1,351 @@ +//! Inbound `sessionFs.*` JSON-RPC request dispatch helpers. +//! +//! Internal — public-facing trait lives in `crate::session_fs`. Each helper +//! deserializes the typed request, calls the [`SessionFsProvider`] method, +//! and serializes the schema response with `FsError` mapped onto the wire's +//! `SessionFsError` variant. 
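The `io::Error` to `FsError` to wire-code mapping described in `session_fs.rs` (io `NotFound` becomes the wire's `ENOENT`, everything else collapses to `UNKNOWN` with the message preserved) can be exercised in isolation. The following is a standalone sketch: the local `FsError` enum and `wire_code` helper are stand-ins for the SDK's `FsError` and `SessionFsErrorCode`, not the real types.

```rust
use std::io;

// Local stand-in for the SDK's `FsError` (sketch only, not the real type).
#[derive(Debug)]
enum FsError {
    NotFound(String),
    Other(String),
}

// Mirrors the diff's `impl From<std::io::Error> for FsError`: only
// `NotFound` is special-cased; every other kind collapses to `Other`.
impl From<io::Error> for FsError {
    fn from(err: io::Error) -> Self {
        match err.kind() {
            io::ErrorKind::NotFound => FsError::NotFound(err.to_string()),
            _ => FsError::Other(err.to_string()),
        }
    }
}

// Stand-in for `FsError::into_wire`: the wire code is the only lossy part.
fn wire_code(err: &FsError) -> &'static str {
    match err {
        FsError::NotFound(_) => "ENOENT",
        FsError::Other(_) => "UNKNOWN",
    }
}

fn main() {
    let missing = FsError::from(io::Error::new(io::ErrorKind::NotFound, "missing.txt"));
    let denied = FsError::from(io::Error::new(io::ErrorKind::PermissionDenied, "secret"));
    assert_eq!(wire_code(&missing), "ENOENT");
    assert_eq!(wire_code(&denied), "UNKNOWN");
    println!("{} {}", wire_code(&missing), wire_code(&denied));
}
```

The asymmetry is deliberate: `ENOENT` is the only code callers branch on (e.g. the `exists` adapter), so richer io kinds like `PermissionDenied` keep their message for diagnostics but share one wire code.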
+
+use std::sync::Arc;
+
+use serde::Serialize;
+use serde_json::Value;
+use tracing::warn;
+
+use crate::generated::api_types::{
+    SessionFsAppendFileRequest, SessionFsExistsRequest, SessionFsExistsResult,
+    SessionFsMkdirRequest, SessionFsReadFileRequest, SessionFsReadFileResult,
+    SessionFsReaddirRequest, SessionFsReaddirResult, SessionFsReaddirWithTypesRequest,
+    SessionFsReaddirWithTypesResult, SessionFsRenameRequest, SessionFsRmRequest,
+    SessionFsStatRequest, SessionFsStatResult, SessionFsWriteFileRequest,
+};
+use crate::session_fs::{FsError, SessionFsProvider};
+use crate::{Client, JsonRpcRequest, JsonRpcResponse, error_codes};
+
+/// Helper: serialize a typed result, send the response.
+async fn respond<T: Serialize>(client: &Client, request_id: u64, result: T) {
+    let value = match serde_json::to_value(&result) {
+        Ok(v) => v,
+        Err(e) => {
+            warn!(error = %e, "failed to serialize sessionFs response");
+            send_error(client, request_id, "serialization failure").await;
+            return;
+        }
+    };
+    let _ = client
+        .send_response(&JsonRpcResponse {
+            jsonrpc: "2.0".to_string(),
+            id: request_id,
+            result: Some(value),
+            error: None,
+        })
+        .await;
+}
+
+async fn send_error(client: &Client, request_id: u64, message: &str) {
+    let _ = client
+        .send_response(&JsonRpcResponse {
+            jsonrpc: "2.0".to_string(),
+            id: request_id,
+            result: None,
+            error: Some(crate::JsonRpcError {
+                code: error_codes::INTERNAL_ERROR,
+                message: message.to_string(),
+                data: None,
+            }),
+        })
+        .await;
+}
+
+fn parse_params<T: serde::de::DeserializeOwned>(request: &JsonRpcRequest) -> Option<T> {
+    request
+        .params
+        .as_ref()
+        .and_then(|p| serde_json::from_value(p.clone()).ok())
+}
+
+pub(crate) async fn read_file(
+    client: &Client,
+    provider: &Arc<dyn SessionFsProvider>,
+    request: JsonRpcRequest,
+) {
+    let params: SessionFsReadFileRequest = match parse_params(&request) {
+        Some(p) => p,
+        None => {
+            send_error(client, request.id, "invalid sessionFs.readFile params").await;
+            return;
+        }
+    };
+    let id = request.id;
+    let result = match provider.read_file(&params.path).await {
+        Ok(content) => SessionFsReadFileResult {
+            content,
+            error: None,
+        },
+        Err(e) => SessionFsReadFileResult {
+            content: String::new(),
+            error: Some(e.into_wire()),
+        },
+    };
+    respond(client, id, result).await;
+}
+
+pub(crate) async fn write_file(
+    client: &Client,
+    provider: &Arc<dyn SessionFsProvider>,
+    request: JsonRpcRequest,
+) {
+    let params: SessionFsWriteFileRequest = match parse_params(&request) {
+        Some(p) => p,
+        None => {
+            send_error(client, request.id, "invalid sessionFs.writeFile params").await;
+            return;
+        }
+    };
+    let id = request.id;
+    match provider
+        .write_file(&params.path, &params.content, params.mode)
+        .await
+    {
+        Ok(()) => respond(client, id, Value::Null).await,
+        Err(e) => respond(client, id, e.into_wire()).await,
+    }
+}
+
+pub(crate) async fn append_file(
+    client: &Client,
+    provider: &Arc<dyn SessionFsProvider>,
+    request: JsonRpcRequest,
+) {
+    let params: SessionFsAppendFileRequest = match parse_params(&request) {
+        Some(p) => p,
+        None => {
+            send_error(client, request.id, "invalid sessionFs.appendFile params").await;
+            return;
+        }
+    };
+    let id = request.id;
+    match provider
+        .append_file(&params.path, &params.content, params.mode)
+        .await
+    {
+        Ok(()) => respond(client, id, Value::Null).await,
+        Err(e) => respond(client, id, e.into_wire()).await,
+    }
+}
+
+pub(crate) async fn exists(
+    client: &Client,
+    provider: &Arc<dyn SessionFsProvider>,
+    request: JsonRpcRequest,
+) {
+    let params: SessionFsExistsRequest = match parse_params(&request) {
+        Some(p) => p,
+        None => {
+            send_error(client, request.id, "invalid sessionFs.exists params").await;
+            return;
+        }
+    };
+    let id = request.id;
+    // Match Node's `createSessionFsAdapter`: errors collapse to `exists: false`.
+    let exists_value = provider.exists(&params.path).await.unwrap_or(false);
+    respond(
+        client,
+        id,
+        SessionFsExistsResult {
+            exists: exists_value,
+        },
+    )
+    .await;
+}
+
+pub(crate) async fn stat(
+    client: &Client,
+    provider: &Arc<dyn SessionFsProvider>,
+    request: JsonRpcRequest,
+) {
+    let params: SessionFsStatRequest = match parse_params(&request) {
+        Some(p) => p,
+        None => {
+            send_error(client, request.id, "invalid sessionFs.stat params").await;
+            return;
+        }
+    };
+    let id = request.id;
+    let result = match provider.stat(&params.path).await {
+        Ok(info) => info.into_wire(),
+        Err(e) => SessionFsStatResult {
+            is_file: false,
+            is_directory: false,
+            size: 0,
+            mtime: String::new(),
+            birthtime: String::new(),
+            error: Some(e.into_wire()),
+        },
+    };
+    respond(client, id, result).await;
+}
+
+pub(crate) async fn mkdir(
+    client: &Client,
+    provider: &Arc<dyn SessionFsProvider>,
+    request: JsonRpcRequest,
+) {
+    let params: SessionFsMkdirRequest = match parse_params(&request) {
+        Some(p) => p,
+        None => {
+            send_error(client, request.id, "invalid sessionFs.mkdir params").await;
+            return;
+        }
+    };
+    let id = request.id;
+    let recursive = params.recursive.unwrap_or(false);
+    match provider.mkdir(&params.path, recursive, params.mode).await {
+        Ok(()) => respond(client, id, Value::Null).await,
+        Err(e) => respond(client, id, e.into_wire()).await,
+    }
+}
+
+pub(crate) async fn readdir(
+    client: &Client,
+    provider: &Arc<dyn SessionFsProvider>,
+    request: JsonRpcRequest,
+) {
+    let params: SessionFsReaddirRequest = match parse_params(&request) {
+        Some(p) => p,
+        None => {
+            send_error(client, request.id, "invalid sessionFs.readdir params").await;
+            return;
+        }
+    };
+    let id = request.id;
+    let result = match provider.readdir(&params.path).await {
+        Ok(entries) => SessionFsReaddirResult {
+            entries,
+            error: None,
+        },
+        Err(e) => SessionFsReaddirResult {
+            entries: Vec::new(),
+            error: Some(e.into_wire()),
+        },
+    };
+    respond(client, id, result).await;
+}
+
+pub(crate) async fn readdir_with_types(
+    client: &Client,
+    provider: &Arc<dyn SessionFsProvider>,
+    request: JsonRpcRequest,
+) {
+    let params: SessionFsReaddirWithTypesRequest = match parse_params(&request) {
+        Some(p) => p,
+        None => {
+            send_error(
+                client,
+                request.id,
+                "invalid sessionFs.readdirWithTypes params",
+            )
+            .await;
+            return;
+        }
+    };
+    let id = request.id;
+    let result = match provider.readdir_with_types(&params.path).await {
+        Ok(entries) => SessionFsReaddirWithTypesResult {
+            entries: entries.into_iter().map(|e| e.into_wire()).collect(),
+            error: None,
+        },
+        Err(e) => SessionFsReaddirWithTypesResult {
+            entries: Vec::new(),
+            error: Some(e.into_wire()),
+        },
+    };
+    respond(client, id, result).await;
+}
+
+pub(crate) async fn rm(
+    client: &Client,
+    provider: &Arc<dyn SessionFsProvider>,
+    request: JsonRpcRequest,
+) {
+    let params: SessionFsRmRequest = match parse_params(&request) {
+        Some(p) => p,
+        None => {
+            send_error(client, request.id, "invalid sessionFs.rm params").await;
+            return;
+        }
+    };
+    let id = request.id;
+    let recursive = params.recursive.unwrap_or(false);
+    let force = params.force.unwrap_or(false);
+    match provider.rm(&params.path, recursive, force).await {
+        Ok(()) => respond(client, id, Value::Null).await,
+        Err(e) => respond(client, id, e.into_wire()).await,
+    }
+}
+
+pub(crate) async fn rename(
+    client: &Client,
+    provider: &Arc<dyn SessionFsProvider>,
+    request: JsonRpcRequest,
+) {
+    let params: SessionFsRenameRequest = match parse_params(&request) {
+        Some(p) => p,
+        None => {
+            send_error(client, request.id, "invalid sessionFs.rename params").await;
+            return;
+        }
+    };
+    let id = request.id;
+    match provider.rename(&params.src, &params.dest).await {
+        Ok(()) => respond(client, id, Value::Null).await,
+        Err(e) => respond(client, id, e.into_wire()).await,
+    }
+}
+
+/// Dispatch a `sessionFs.*` request to the appropriate handler. Returns
+/// `true` if the request was a session-fs method (whether or not a provider
+/// was registered), `false` otherwise (caller should continue matching).
+pub(crate) async fn dispatch(
+    client: &Client,
+    provider: Option<&Arc<dyn SessionFsProvider>>,
+    request: JsonRpcRequest,
+) -> bool {
+    let method = request.method.as_str();
+    if !method.starts_with("sessionFs.") {
+        return false;
+    }
+    let provider = match provider {
+        Some(p) => p.clone(),
+        None => {
+            warn!(method = %method, "sessionFs request without registered provider");
+            send_error(
+                client,
+                request.id,
+                "no sessionFs provider registered for this session",
+            )
+            .await;
+            return true;
+        }
+    };
+    match method {
+        "sessionFs.readFile" => read_file(client, &provider, request).await,
+        "sessionFs.writeFile" => write_file(client, &provider, request).await,
+        "sessionFs.appendFile" => append_file(client, &provider, request).await,
+        "sessionFs.exists" => exists(client, &provider, request).await,
+        "sessionFs.stat" => stat(client, &provider, request).await,
+        "sessionFs.mkdir" => mkdir(client, &provider, request).await,
+        "sessionFs.readdir" => readdir(client, &provider, request).await,
+        "sessionFs.readdirWithTypes" => readdir_with_types(client, &provider, request).await,
+        "sessionFs.rm" => rm(client, &provider, request).await,
+        "sessionFs.rename" => rename(client, &provider, request).await,
+        _ => {
+            warn!(method = %method, "unknown sessionFs.* method");
+            send_error(client, request.id, "unknown sessionFs method").await;
+        }
+    }
+    true
+}
+
+// FsError is used through `into_wire()` calls above.
+#[allow(dead_code)]
+fn _ensure_fs_error_used(_e: FsError) {}
diff --git a/rust/src/subscription.rs b/rust/src/subscription.rs
new file mode 100644
index 000000000..ef5f95381
--- /dev/null
+++ b/rust/src/subscription.rs
@@ -0,0 +1,217 @@
+//! Subscription handles for observing session and lifecycle events.
+//!
+//! Returned by [`Session::subscribe`](crate::session::Session::subscribe) and
+//! [`Client::subscribe_lifecycle`](crate::Client::subscribe_lifecycle).
+//!
+//! Each subscription is an opt-in **observer** of events that are also
+//!
delivered to the [`SessionHandler`](crate::handler::SessionHandler). +//! Subscribers receive a clone of every event but cannot influence +//! permission decisions, tool results, or anything else that requires +//! returning a [`HandlerResponse`](crate::handler::HandlerResponse). +//! +//! # Async iteration +//! +//! The subscription types implement [`tokio_stream::Stream`], so consumers +//! can use adapter combinators from [`tokio_stream::StreamExt`] or +//! `futures::StreamExt` (filtering, mapping, batching, racing with +//! `tokio::select!`, etc.) without learning the SDK's internal channel +//! choice. A simple `while let Ok(event) = sub.recv().await { ... }` loop +//! also works for callers who don't need the [`Stream`](tokio_stream::Stream) +//! surface. +//! +//! # Lag policy +//! +//! Each subscriber maintains its own internal queue. If a consumer cannot +//! keep up, the oldest events are dropped and the next call yields +//! [`Lagged`] reporting how many events were skipped. Slow subscribers do +//! not block the producer. + +use std::pin::Pin; +use std::task::{Context, Poll}; + +use tokio::sync::broadcast::Receiver; +use tokio_stream::wrappers::BroadcastStream; +use tokio_stream::wrappers::errors::BroadcastStreamRecvError; +use tokio_stream::{Stream, StreamExt as _}; + +use crate::types::{SessionEvent, SessionLifecycleEvent}; + +/// The subscription fell behind the producer. +/// +/// Reports the number of events that were dropped from this subscriber's +/// queue because the consumer didn't keep up. The subscription continues +/// after this error, starting from the next live event — callers who care +/// about lag should match on it and decide whether to resync, re-fetch, or +/// log and continue. +#[derive(Debug, Clone, Copy, PartialEq, Eq, thiserror::Error)] +#[error("subscription lagged behind by {0} events")] +pub struct Lagged(u64); + +impl Lagged { + /// Number of events skipped before this consumer could read them. 
+ pub fn skipped(&self) -> u64 { + self.0 + } +} + +/// Error returned by [`EventSubscription::recv`] and +/// [`LifecycleSubscription::recv`]. +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +pub enum RecvError { + /// The producer is gone — the session has shut down or the client has + /// stopped. No further events will be delivered. + #[error("subscription closed")] + Closed, + + /// The subscriber fell behind. See [`Lagged`]. + #[error(transparent)] + Lagged(#[from] Lagged), +} + +macro_rules! define_subscription { + ( + $(#[$meta:meta])* + $name:ident, $item:ty $(,)? + ) => { + $(#[$meta])* + #[must_use = "subscriptions are inert until polled"] + pub struct $name { + inner: BroadcastStream<$item>, + } + + impl $name { + pub(crate) fn new(rx: Receiver<$item>) -> Self { + Self { + inner: BroadcastStream::new(rx), + } + } + + /// Receive the next event. + /// + /// Returns: + /// + /// - `Ok(event)` for the next delivered event. + /// - `Err(`[`RecvError::Lagged`]`)` if the subscriber fell behind; + /// call `recv` again to continue from the next live event. + /// - `Err(`[`RecvError::Closed`]`)` once the producer is gone. + /// + /// # Cancel safety + /// + /// **Cancel-safe.** Wraps a `tokio::sync::broadcast::Receiver` + /// via `BroadcastStream`; both are cancel-safe by design. + /// Dropping the future before completion is harmless — events + /// already buffered for this subscriber remain available on + /// the next `recv` call. 
+            pub async fn recv(&mut self) -> Result<$item, RecvError> {
+                match self.inner.next().await {
+                    Some(Ok(event)) => Ok(event),
+                    Some(Err(BroadcastStreamRecvError::Lagged(n))) => {
+                        Err(RecvError::Lagged(Lagged(n)))
+                    }
+                    None => Err(RecvError::Closed),
+                }
+            }
+        }
+
+        impl Stream for $name {
+            type Item = Result<$item, Lagged>;
+
+            fn poll_next(
+                mut self: Pin<&mut Self>,
+                cx: &mut Context<'_>,
+            ) -> Poll<Option<Self::Item>> {
+                match Pin::new(&mut self.inner).poll_next(cx) {
+                    Poll::Ready(Some(Ok(event))) => Poll::Ready(Some(Ok(event))),
+                    Poll::Ready(Some(Err(BroadcastStreamRecvError::Lagged(n)))) => {
+                        Poll::Ready(Some(Err(Lagged(n))))
+                    }
+                    Poll::Ready(None) => Poll::Ready(None),
+                    Poll::Pending => Poll::Pending,
+                }
+            }
+        }
+    };
+}
+
+define_subscription! {
+    /// Subscription to runtime events for a single
+    /// [`Session`](crate::session::Session).
+    ///
+    /// Created by [`Session::subscribe`](crate::session::Session::subscribe).
+    /// Implements [`Stream`] yielding `Result<SessionEvent, Lagged>`.
+    /// Drop the value to unsubscribe; there is no separate cancel handle.
+    EventSubscription, SessionEvent
+}
+
+define_subscription! {
+    /// Subscription to lifecycle events on a [`Client`](crate::Client).
+    ///
+    /// Created by
+    /// [`Client::subscribe_lifecycle`](crate::Client::subscribe_lifecycle).
+    /// Implements [`Stream`] yielding `Result<SessionLifecycleEvent, Lagged>`.
+    /// Drop the value to unsubscribe; there is no separate cancel handle.
+ LifecycleSubscription, SessionLifecycleEvent +} + +#[cfg(test)] +mod tests { + use tokio::sync::broadcast; + + use super::*; + + fn make_event(id: &str) -> SessionEvent { + SessionEvent { + id: id.into(), + timestamp: "2025-01-01T00:00:00Z".into(), + parent_id: None, + ephemeral: None, + debug_cli_received_at_ms: None, + debug_ws_forwarded_at_ms: None, + event_type: "noop".into(), + data: serde_json::json!({}), + } + } + + #[tokio::test] + async fn recv_yields_then_closes_on_drop_sender() { + let (tx, rx) = broadcast::channel(8); + let mut sub = EventSubscription::new(rx); + tx.send(make_event("a")).unwrap(); + tx.send(make_event("b")).unwrap(); + drop(tx); + + assert_eq!(sub.recv().await.unwrap().id, "a"); + assert_eq!(sub.recv().await.unwrap().id, "b"); + assert!(matches!(sub.recv().await, Err(RecvError::Closed))); + } + + #[tokio::test] + async fn recv_surfaces_lag() { + let (tx, rx) = broadcast::channel(2); + let mut sub = EventSubscription::new(rx); + for id in ["a", "b", "c", "d"] { + tx.send(make_event(id)).unwrap(); + } + match sub.recv().await { + Err(RecvError::Lagged(l)) => assert_eq!(l.skipped(), 2), + other => panic!("expected Lagged, got {other:?}"), + } + // Subscription continues with the live tail. + assert_eq!(sub.recv().await.unwrap().id, "c"); + assert_eq!(sub.recv().await.unwrap().id, "d"); + } + + #[tokio::test] + async fn stream_impl_matches_recv_semantics() { + let (tx, rx) = broadcast::channel(8); + let mut sub = EventSubscription::new(rx); + tx.send(make_event("a")).unwrap(); + drop(tx); + + // poll_next path + let next = sub.next().await; + assert_eq!(next.unwrap().unwrap().id, "a"); + assert!(sub.next().await.is_none()); + } +} diff --git a/rust/src/tool.rs b/rust/src/tool.rs new file mode 100644 index 000000000..cccdad486 --- /dev/null +++ b/rust/src/tool.rs @@ -0,0 +1,828 @@ +//! Typed tool definition framework. +//! +//! Provides the [`ToolHandler`](crate::tool::ToolHandler) trait for implementing tools as named types, +//! 
and [`ToolHandlerRouter`](crate::tool::ToolHandlerRouter) for automatic dispatch of tool calls within a
+//! [`SessionHandler`](crate::handler::SessionHandler).
+//!
+//! Enable the `derive` feature for `schema_for`, which generates JSON
+//! Schema from Rust types via `schemars`.
+
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+/// Re-export of [`schemars::JsonSchema`] for deriving tool parameter schemas.
+#[cfg(feature = "derive")]
+pub use schemars::JsonSchema;
+
+use crate::Error;
+use crate::handler::{ExitPlanModeResult, PermissionResult, SessionHandler, UserInputResponse};
+use crate::types::{
+    ElicitationRequest, ElicitationResult, ExitPlanModeData, PermissionRequestData, RequestId,
+    SessionEvent, SessionId, Tool, ToolInvocation, ToolResult, ToolResultExpanded,
+};
+
+/// Generate a JSON Schema [`Value`](serde_json::Value) from a Rust type.
+///
+/// Strips `$schema` and `title` root-level metadata so the output is ready
+/// to use as [`Tool::parameters`].
+///
+/// # Example
+///
+/// ```rust
+/// use github_copilot_sdk::tool::{schema_for, JsonSchema};
+///
+/// #[derive(JsonSchema)]
+/// struct Params {
+///     /// City name
+///     city: String,
+/// }
+///
+/// let schema = schema_for::<Params>();
+/// assert_eq!(schema["type"], "object");
+/// assert!(schema["properties"]["city"].is_object());
+/// ```
+#[cfg(feature = "derive")]
+pub fn schema_for<T: JsonSchema>() -> serde_json::Value {
+    let schema = schemars::schema_for!(T);
+    let mut value = serde_json::to_value(schema).expect("JSON Schema serialization cannot fail");
+    if let Some(obj) = value.as_object_mut() {
+        obj.remove("$schema");
+        obj.remove("title");
+    }
+    value
+}
+
+/// Convert a JSON Schema [`Value`](serde_json::Value) into the
+/// [`Tool::parameters`] map shape expected by the protocol.
+///
+/// Panics if the input is not a JSON object — tool parameter schemas
+/// are always top-level objects (`{"type": "object", ...}`). Pair with
+/// [`schema_for`] or a `serde_json::json!(...)` literal.
+///
+/// Use [`try_tool_parameters`] when the schema comes from dynamic input and
+/// should return a recoverable error instead of panicking.
+///
+/// # Example
+///
+/// ```rust
+/// use github_copilot_sdk::tool::tool_parameters;
+/// use github_copilot_sdk::Tool;
+///
+/// let mut tool = Tool::default();
+/// tool.name = "ping".to_string();
+/// tool.description = "ping the server".to_string();
+/// tool.parameters = tool_parameters(serde_json::json!({"type": "object"}));
+/// # let _ = tool;
+/// ```
+pub fn tool_parameters(schema: serde_json::Value) -> HashMap<String, serde_json::Value> {
+    try_tool_parameters(schema).expect("tool parameter schema must be a JSON object")
+}
+
+/// Fallible variant of [`tool_parameters`] for callers handling dynamic schema input.
+pub fn try_tool_parameters(
+    schema: serde_json::Value,
+) -> Result<HashMap<String, serde_json::Value>, serde_json::Error> {
+    serde_json::from_value(schema)
+}
+
+/// A client-defined tool with its handler logic.
+///
+/// Implement this trait for each tool you expose to the Copilot agent.
+/// The struct is a named type — visible in stack traces and navigable
+/// via "go to definition" — unlike closure-based alternatives.
+///
+/// # Example
+///
+/// ```rust,ignore
+/// use github_copilot_sdk::tool::{schema_for, tool_parameters, JsonSchema, ToolHandler};
+/// use github_copilot_sdk::{Error, Tool, ToolInvocation, ToolResult};
+/// use serde::Deserialize;
+/// use async_trait::async_trait;
+///
+/// #[derive(Deserialize, JsonSchema)]
+/// struct GetWeatherParams {
+///     /// City name
+///     city: String,
+///     /// Temperature unit
+///     unit: Option<String>,
+/// }
+///
+/// struct GetWeatherTool;
+///
+/// #[async_trait]
+/// impl ToolHandler for GetWeatherTool {
+///     fn tool(&self) -> Tool {
+///         Tool {
+///             name: "get_weather".to_string(),
+///             namespaced_name: None,
+///             description: "Get weather for a city".to_string(),
+///             parameters: tool_parameters(schema_for::<GetWeatherParams>()),
+///             instructions: None,
+///             ..Default::default()
+///         }
+///     }
+///
+///     async fn call(&self, inv: ToolInvocation) -> Result<ToolResult, Error> {
+///         let params: GetWeatherParams = serde_json::from_value(inv.arguments)?;
+///         Ok(ToolResult::Text(format!("Weather in {}: sunny", params.city)))
+///     }
+/// }
+/// ```
+#[async_trait]
+pub trait ToolHandler: Send + Sync {
+    /// The tool definition sent to the CLI during session creation.
+    fn tool(&self) -> Tool;
+
+    /// Handle a tool invocation from the agent.
+    async fn call(&self, invocation: ToolInvocation) -> Result<ToolResult, Error>;
+}
+
+/// Define a tool from an async function (or closure) that takes a typed,
+/// `JsonSchema`-derived parameter struct.
+///
+/// The returned `Box<dyn ToolHandler>` plugs directly into
+/// [`ToolHandlerRouter::new`]. JSON Schema for the parameter type is generated
+/// via [`schema_for`] at construction time.
+///
+/// The handler bound (`Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static`)
+/// accepts both bare `async fn` items and closures — the same shape as
+/// [`tower::service_fn`][tower-service-fn] and
+/// [`hyper::service::service_fn`][hyper-service-fn]. Prefer a free `async fn`
+/// for non-trivial tools so it shows up in stack traces by name.
+///
+/// The closure receives the full [`ToolInvocation`] alongside the deserialized
+/// parameters so handlers can use `inv.session_id`, `inv.tool_call_id`, or
+/// other invocation metadata. Handlers that don't need that metadata can
+/// destructure with `|_inv, params|`.
+///
+/// # Example
+///
+/// ```rust,no_run
+/// use github_copilot_sdk::tool::{define_tool, JsonSchema};
+/// use github_copilot_sdk::types::ToolInvocation;
+/// use github_copilot_sdk::{Error, ToolResult};
+/// use serde::Deserialize;
+///
+/// #[derive(Deserialize, JsonSchema)]
+/// struct GetWeatherParams {
+///     /// City name
+///     city: String,
+/// }
+///
+/// async fn get_weather(
+///     inv: ToolInvocation,
+///     params: GetWeatherParams,
+/// ) -> Result<ToolResult, Error> {
+///     // `inv.session_id` and `inv.tool_call_id` are available for telemetry,
+///     // streaming updates, scoping DB lookups, etc.
+///     let _ = inv.session_id;
+///     Ok(ToolResult::Text(format!("Sunny in {}", params.city)))
+/// }
+///
+/// // Pass a free async fn — preferred for non-trivial tools.
+/// let tool = define_tool("get_weather", "Get weather for a city", get_weather);
+///
+/// // ...or an inline closure when the body is trivial.
+/// let tool = define_tool(
+///     "echo",
+///     "Echo the input",
+///     |_inv, params: GetWeatherParams| async move {
+///         Ok(ToolResult::Text(params.city))
+///     },
+/// );
+/// # let _ = tool;
+/// ```
+///
+/// [tower-service-fn]: https://docs.rs/tower/latest/tower/fn.service_fn.html
+/// [hyper-service-fn]: https://docs.rs/hyper/latest/hyper/service/fn.service_fn.html
+#[cfg(feature = "derive")]
+pub fn define_tool<P, F, Fut>(
+    name: impl Into<String>,
+    description: impl Into<String>,
+    handler: F,
+) -> Box<dyn ToolHandler>
+where
+    P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
+    F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static,
+    Fut: std::future::Future<Output = Result<ToolResult, Error>> + Send + 'static,
+{
+    struct FnTool<F, P> {
+        name: String,
+        description: String,
+        parameters: HashMap<String, serde_json::Value>,
+        handler: F,
+        _marker: std::marker::PhantomData<fn() -> P>,
+    }
+
+    #[async_trait]
+    impl<F, P, Fut> ToolHandler for FnTool<F, P>
+    where
+        P: schemars::JsonSchema + serde::de::DeserializeOwned + Send + 'static,
+        F: Fn(ToolInvocation, P) -> Fut + Send + Sync + 'static,
+        Fut: std::future::Future<Output = Result<ToolResult, Error>> + Send + 'static,
+    {
+        fn tool(&self) -> Tool {
+            Tool {
+                name: self.name.clone(),
+                description: self.description.clone(),
+                parameters: self.parameters.clone(),
+                ..Default::default()
+            }
+        }
+
+        async fn call(&self, mut invocation: ToolInvocation) -> Result<ToolResult, Error> {
+            let arguments = std::mem::take(&mut invocation.arguments);
+            let params: P = serde_json::from_value(arguments)?;
+            (self.handler)(invocation, params).await
+        }
+    }
+
+    Box::new(FnTool {
+        name: name.into(),
+        description: description.into(),
+        parameters: tool_parameters(schema_for::
<P>
()),
+        handler,
+        _marker: std::marker::PhantomData,
+    })
+}
+
+/// A [`SessionHandler`] that dispatches tool calls to registered
+/// [`ToolHandler`] implementations by name.
+///
+/// For tool calls matching a registered handler, the handler is invoked
+/// directly. All other events (permissions, user input, unrecognized tools)
+/// are forwarded to the inner handler.
+///
+/// # Example
+///
+/// ```rust,no_run
+/// use std::sync::Arc;
+/// use github_copilot_sdk::handler::ApproveAllHandler;
+/// use github_copilot_sdk::tool::ToolHandlerRouter;
+///
+/// let router = ToolHandlerRouter::new(
+///     vec![/* Box::new(MyTool), ... */],
+///     Arc::new(ApproveAllHandler),
+/// );
+///
+/// // Use router.tools() in SessionConfig
+/// // Use Arc::new(router) as the session handler
+/// ```
+pub struct ToolHandlerRouter {
+    handlers: HashMap<String, Box<dyn ToolHandler>>,
+    inner: Arc<dyn SessionHandler>,
+}
+
+impl std::fmt::Debug for ToolHandlerRouter {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let mut tools: Vec<_> = self.handlers.keys().collect();
+        tools.sort();
+        f.debug_struct("ToolHandlerRouter")
+            .field("tool_count", &self.handlers.len())
+            .field("tools", &tools)
+            .finish()
+    }
+}
+
+impl ToolHandlerRouter {
+    /// Create a router from tool handler impls and a fallback handler.
+    ///
+    /// Call [`tools()`](Self::tools) to get the tool definitions for
+    /// [`SessionConfig::tools`](crate::SessionConfig::tools).
+    pub fn new(tools: Vec<Box<dyn ToolHandler>>, inner: Arc<dyn SessionHandler>) -> Self {
+        let mut handlers = HashMap::new();
+        for tool in tools {
+            handlers.insert(tool.tool().name.clone(), tool);
+        }
+        Self { handlers, inner }
+    }
+
+    /// Tool definitions for [`SessionConfig::tools`](crate::SessionConfig::tools).
+    pub fn tools(&self) -> Vec<Tool> {
+        self.handlers.values().map(|h| h.tool()).collect()
+    }
+}
+
+#[async_trait]
+impl SessionHandler for ToolHandlerRouter {
+    async fn on_external_tool(&self, invocation: ToolInvocation) -> ToolResult {
+        let Some(handler) = self.handlers.get(&invocation.tool_name) else {
+            return self.inner.on_external_tool(invocation).await;
+        };
+        match handler.call(invocation).await {
+            Ok(result) => result,
+            Err(e) => {
+                let msg = e.to_string();
+                ToolResult::Expanded(ToolResultExpanded {
+                    text_result_for_llm: msg.clone(),
+                    result_type: "failure".to_string(),
+                    session_log: None,
+                    error: Some(msg),
+                })
+            }
+        }
+    }
+
+    async fn on_session_event(&self, session_id: SessionId, event: SessionEvent) {
+        self.inner.on_session_event(session_id, event).await
+    }
+
+    async fn on_permission_request(
+        &self,
+        session_id: SessionId,
+        request_id: RequestId,
+        data: PermissionRequestData,
+    ) -> PermissionResult {
+        self.inner
+            .on_permission_request(session_id, request_id, data)
+            .await
+    }
+
+    async fn on_user_input(
+        &self,
+        session_id: SessionId,
+        question: String,
+        choices: Option<Vec<String>>,
+        allow_freeform: Option<bool>,
+    ) -> Option<UserInputResponse> {
+        self.inner
+            .on_user_input(session_id, question, choices, allow_freeform)
+            .await
+    }
+
+    async fn on_elicitation(
+        &self,
+        session_id: SessionId,
+        request_id: RequestId,
+        request: ElicitationRequest,
+    ) -> ElicitationResult {
+        self.inner
+            .on_elicitation(session_id, request_id, request)
+            .await
+    }
+
+    async fn on_exit_plan_mode(
+        &self,
+        session_id: SessionId,
+        data: ExitPlanModeData,
+    ) -> ExitPlanModeResult {
+        self.inner.on_exit_plan_mode(session_id, data).await
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::types::{PermissionRequestData, RequestId, SessionId};
+
+    struct EchoTool;
+
+    #[async_trait]
+    impl ToolHandler for EchoTool {
+        fn tool(&self) -> Tool {
+            Tool {
+                name: "echo".to_string(),
+                namespaced_name: None,
+                description: "Echo the input".to_string(),
+                parameters:
tool_parameters(serde_json::json!({"type": "object"})), + instructions: None, + ..Default::default() + } + } + + async fn call(&self, inv: ToolInvocation) -> Result { + Ok(ToolResult::Text(inv.arguments.to_string())) + } + } + + #[test] + fn tool_handler_returns_tool_definition() { + let tool = EchoTool; + let def = tool.tool(); + assert_eq!(def.name, "echo"); + assert_eq!(def.description, "Echo the input"); + assert!(def.parameters.contains_key("type")); + } + + #[test] + fn try_tool_parameters_rejects_non_object_schema() { + let err = try_tool_parameters(serde_json::json!(["not", "an", "object"])) + .expect_err("non-object schemas should be rejected"); + + assert!(err.is_data()); + } + + #[tokio::test] + async fn tool_handler_call_returns_result() { + let tool = EchoTool; + let inv = ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "echo".to_string(), + arguments: serde_json::json!({"msg": "hello"}), + traceparent: None, + tracestate: None, + }; + + let result = tool.call(inv).await.unwrap(); + match result { + ToolResult::Text(s) => assert!(s.contains("hello")), + _ => panic!("expected Text result"), + } + } + + #[cfg(feature = "derive")] + #[tokio::test] + async fn define_tool_builds_schema_and_dispatches() { + use serde::Deserialize; + + #[derive(Deserialize, schemars::JsonSchema)] + struct Params { + city: String, + } + + let tool = define_tool( + "weather", + "Get the weather for a city", + |_inv, params: Params| async move { + Ok(ToolResult::Text(format!("sunny in {}", params.city))) + }, + ); + + let def = tool.tool(); + assert_eq!(def.name, "weather"); + assert_eq!(def.description, "Get the weather for a city"); + assert_eq!(def.parameters["type"], "object"); + assert!(def.parameters["properties"]["city"].is_object()); + + let inv = ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "weather".to_string(), + arguments: serde_json::json!({"city": 
"Seattle"}), + traceparent: None, + tracestate: None, + }; + match tool.call(inv).await.unwrap() { + ToolResult::Text(s) => assert_eq!(s, "sunny in Seattle"), + _ => panic!("expected Text result"), + } + } + + #[tokio::test] + async fn router_dispatches_to_correct_handler() { + struct ToolA; + #[async_trait] + impl ToolHandler for ToolA { + fn tool(&self) -> Tool { + Tool { + name: "tool_a".to_string(), + namespaced_name: None, + description: "A".to_string(), + parameters: HashMap::new(), + instructions: None, + ..Default::default() + } + } + + async fn call(&self, _inv: ToolInvocation) -> Result { + Ok(ToolResult::Text("a_result".to_string())) + } + } + + struct ToolB; + #[async_trait] + impl ToolHandler for ToolB { + fn tool(&self) -> Tool { + Tool { + name: "tool_b".to_string(), + namespaced_name: None, + description: "B".to_string(), + parameters: HashMap::new(), + instructions: None, + ..Default::default() + } + } + + async fn call(&self, _inv: ToolInvocation) -> Result { + Ok(ToolResult::Text("b_result".to_string())) + } + } + + let router = ToolHandlerRouter::new( + vec![Box::new(ToolA), Box::new(ToolB)], + Arc::new(crate::handler::ApproveAllHandler), + ); + + let tools = router.tools(); + assert_eq!(tools.len(), 2); + + let response = router + .on_external_tool(ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "tool_b".to_string(), + arguments: serde_json::json!({}), + traceparent: None, + tracestate: None, + }) + .await; + match response { + ToolResult::Text(s) => assert_eq!(s, "b_result"), + _ => panic!("expected ToolResult::Text"), + } + } + + #[tokio::test] + async fn router_falls_through_for_unknown_tool() { + use std::sync::atomic::{AtomicBool, Ordering}; + + struct FallbackHandler { + called: AtomicBool, + } + #[async_trait] + impl SessionHandler for FallbackHandler { + async fn on_external_tool(&self, _inv: ToolInvocation) -> ToolResult { + self.called.store(true, Ordering::Relaxed); + 
ToolResult::Text("fallback".to_string()) + } + } + + let fallback = Arc::new(FallbackHandler { + called: AtomicBool::new(false), + }); + let router = ToolHandlerRouter::new(vec![], fallback.clone()); + + let response = router + .on_external_tool(ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "unknown".to_string(), + arguments: serde_json::json!({}), + traceparent: None, + tracestate: None, + }) + .await; + assert!(fallback.called.load(Ordering::Relaxed)); + match response { + ToolResult::Text(s) => assert_eq!(s, "fallback"), + _ => panic!("expected fallback result"), + } + } + + #[tokio::test] + async fn router_returns_failure_on_handler_error() { + struct FailTool; + #[async_trait] + impl ToolHandler for FailTool { + fn tool(&self) -> Tool { + Tool { + name: "bad_tool".to_string(), + namespaced_name: None, + description: "Always fails".to_string(), + parameters: HashMap::new(), + instructions: None, + ..Default::default() + } + } + + async fn call(&self, _inv: ToolInvocation) -> Result { + Err(Error::Rpc { + code: -1, + message: "intentional failure".to_string(), + }) + } + } + + let router = ToolHandlerRouter::new( + vec![Box::new(FailTool)], + Arc::new(crate::handler::ApproveAllHandler), + ); + + let response = router + .on_external_tool(ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "bad_tool".to_string(), + arguments: serde_json::json!({}), + traceparent: None, + tracestate: None, + }) + .await; + match response { + ToolResult::Expanded(exp) => { + assert_eq!(exp.result_type, "failure"); + assert!(exp.error.unwrap().contains("intentional failure")); + } + _ => panic!("expected expanded failure result"), + } + } + + #[tokio::test] + async fn router_forwards_non_tool_events() { + struct PermHandler; + #[async_trait] + impl SessionHandler for PermHandler { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + 
_data: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Denied + } + } + + let router = ToolHandlerRouter::new(vec![], Arc::new(PermHandler)); + + let response = router + .on_permission_request( + SessionId::from("s1"), + RequestId::new("r1"), + PermissionRequestData { + extra: serde_json::json!({}), + ..Default::default() + }, + ) + .await; + assert!(matches!(response, PermissionResult::Denied)); + } + + #[tokio::test] + async fn router_default_on_event_dispatches_via_per_event_methods() { + // Regression: callers using the legacy on_event entry point should + // still get correct dispatch through the inherited default impl. + use crate::handler::{HandlerEvent, HandlerResponse}; + + struct OkTool; + #[async_trait] + impl ToolHandler for OkTool { + fn tool(&self) -> Tool { + Tool { + name: "ok_tool".to_string(), + namespaced_name: None, + description: "ok".to_string(), + parameters: HashMap::new(), + instructions: None, + ..Default::default() + } + } + + async fn call(&self, _inv: ToolInvocation) -> Result { + Ok(ToolResult::Text("ok".to_string())) + } + } + + let router = ToolHandlerRouter::new( + vec![Box::new(OkTool)], + Arc::new(crate::handler::ApproveAllHandler), + ); + + let response = router + .on_event(HandlerEvent::ExternalTool { + invocation: ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "ok_tool".to_string(), + arguments: serde_json::json!({}), + traceparent: None, + tracestate: None, + }, + }) + .await; + match response { + HandlerResponse::ToolResult(ToolResult::Text(s)) => assert_eq!(s, "ok"), + _ => panic!("expected ToolResult via default on_event"), + } + } + + // Tests requiring `schemars` (the `derive` feature). + #[cfg(feature = "derive")] + mod derive_tests { + use serde::Deserialize; + + use super::super::*; + use crate::SessionId; + + #[derive(Deserialize, schemars::JsonSchema)] + struct GetWeatherParams { + /// City name to get weather for. 
+ city: String, + /// Temperature unit (celsius or fahrenheit). + unit: Option, + } + + #[test] + fn schema_for_generates_clean_schema() { + let schema = schema_for::(); + assert_eq!(schema["type"], "object"); + assert!(schema["properties"]["city"].is_object()); + assert!(schema["properties"]["unit"].is_object()); + // city is required (non-Option), unit is not + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&serde_json::json!("city"))); + assert!(!required.contains(&serde_json::json!("unit"))); + // Root-level metadata stripped + assert!(schema.get("$schema").is_none()); + assert!(schema.get("title").is_none()); + } + + struct GetWeatherTool; + + #[async_trait] + impl ToolHandler for GetWeatherTool { + fn tool(&self) -> Tool { + Tool { + name: "get_weather".to_string(), + namespaced_name: None, + description: "Get weather for a city".to_string(), + parameters: tool_parameters(schema_for::()), + instructions: None, + ..Default::default() + } + } + + async fn call(&self, inv: ToolInvocation) -> Result { + let params: GetWeatherParams = serde_json::from_value(inv.arguments)?; + Ok(ToolResult::Text(format!( + "{} {}", + params.city, + params.unit.unwrap_or_default() + ))) + } + } + + #[test] + fn tool_handler_with_schema_for() { + let tool = GetWeatherTool; + let def = tool.tool(); + assert_eq!(def.name, "get_weather"); + let schema = serde_json::to_value(&def.parameters).expect("serialize tool parameters"); + assert_eq!(schema["type"], "object"); + assert!(schema["properties"]["city"].is_object()); + } + + #[tokio::test] + async fn tool_handler_deserializes_typed_params() { + let tool = GetWeatherTool; + let inv = ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "get_weather".to_string(), + arguments: serde_json::json!({"city": "Seattle", "unit": "celsius"}), + traceparent: None, + tracestate: None, + }; + + let result = tool.call(inv).await.unwrap(); + match result { + 
ToolResult::Text(s) => assert_eq!(s, "Seattle celsius"), + _ => panic!("expected Text result"), + } + } + + #[tokio::test] + async fn tool_handler_returns_error_on_bad_params() { + let tool = GetWeatherTool; + let inv = ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "get_weather".to_string(), + arguments: serde_json::json!({"wrong_field": 42}), + traceparent: None, + tracestate: None, + }; + + let err = tool.call(inv).await.unwrap_err(); + assert!(matches!(err, Error::Json(_))); + } + + #[tokio::test] + async fn router_with_schema_for_tools() { + let router = ToolHandlerRouter::new( + vec![Box::new(GetWeatherTool)], + Arc::new(crate::handler::ApproveAllHandler), + ); + + let tools = router.tools(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0].name, "get_weather"); + + let response = router + .on_external_tool(ToolInvocation { + session_id: SessionId::from("s1"), + tool_call_id: "tc1".to_string(), + tool_name: "get_weather".to_string(), + arguments: serde_json::json!({"city": "Portland"}), + traceparent: None, + tracestate: None, + }) + .await; + match response { + ToolResult::Text(s) => assert!(s.contains("Portland")), + _ => panic!("expected ToolResult::Text"), + } + } + } +} diff --git a/rust/src/trace_context.rs b/rust/src/trace_context.rs new file mode 100644 index 000000000..287c87cbd --- /dev/null +++ b/rust/src/trace_context.rs @@ -0,0 +1,132 @@ +//! W3C Trace Context propagation for distributed tracing. +//! +//! The GitHub Copilot CLI propagates [W3C Trace Context] headers (`traceparent` +//! and `tracestate`) so SDK consumers can correlate spans created by the +//! CLI with their own observability pipelines. +//! +//! Two injection paths are supported: +//! +//! - **Per-turn override** via [`MessageOptions::traceparent`] / +//! [`MessageOptions::tracestate`](crate::types::MessageOptions::tracestate), +//! which take precedence when set. +//! - **Ambient callback** via +//! 
[`ClientOptions::on_get_trace_context`](crate::ClientOptions::on_get_trace_context),
+//! which the SDK invokes before `session.create`, `session.resume`, and
+//! `session.send` whenever the per-turn override is absent.
+//!
+//! [W3C Trace Context]: https://www.w3.org/TR/trace-context/
+//! [`MessageOptions::traceparent`]: crate::types::MessageOptions::traceparent
+
+use async_trait::async_trait;
+
+/// W3C Trace Context headers propagated to and from the GitHub Copilot CLI.
+///
+/// `traceparent` carries the trace and parent-span identifiers; `tracestate`
+/// carries vendor-specific extensions. Either field may be `None` when the
+/// caller has nothing to propagate; in that case the corresponding wire
+/// field is omitted.
+#[derive(Debug, Clone, Default, PartialEq, Eq)]
+#[non_exhaustive]
+pub struct TraceContext {
+    /// `traceparent` HTTP header value.
+    pub traceparent: Option<String>,
+    /// `tracestate` HTTP header value.
+    pub tracestate: Option<String>,
+}
+
+impl TraceContext {
+    /// Construct an empty [`TraceContext`]; both fields default to unset
+    /// (the SDK skips trace-context injection on the wire).
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Construct a [`TraceContext`] from a `traceparent` header value, with
+    /// no `tracestate`.
+    ///
+    /// Equivalent to `TraceContext::new().with_traceparent(value)`; kept
+    /// for ergonomics in the common single-header case.
+    pub fn from_traceparent(traceparent: impl Into<String>) -> Self {
+        Self::new().with_traceparent(traceparent)
+    }
+
+    /// Set or replace the `traceparent` header value, returning `self` for
+    /// chaining.
+    pub fn with_traceparent(mut self, traceparent: impl Into<String>) -> Self {
+        self.traceparent = Some(traceparent.into());
+        self
+    }
+
+    /// Set or replace the `tracestate` header value, returning `self` for
+    /// chaining.
+    pub fn with_tracestate(mut self, tracestate: impl Into<String>) -> Self {
+        self.tracestate = Some(tracestate.into());
+        self
+    }
+
+    /// Returns `true` when neither `traceparent` nor `tracestate` is set.
+    pub fn is_empty(&self) -> bool {
+        self.traceparent.is_none() && self.tracestate.is_none()
+    }
+}
+
+/// Async provider that returns the current [`TraceContext`] for outbound
+/// session RPCs.
+///
+/// Set via
+/// [`ClientOptions::on_get_trace_context`](crate::ClientOptions::on_get_trace_context).
+/// The SDK invokes [`get_trace_context`](Self::get_trace_context) before
+/// each `session.create`, `session.resume`, and `session.send` whenever
+/// the call site does not carry a per-turn override.
+///
+/// Implementations should handle errors internally and return
+/// [`TraceContext::default()`] to skip injection — no `Result` return type
+/// is exposed because trace propagation is a best-effort observability
+/// feature, not a correctness-critical RPC parameter.
+#[async_trait]
+pub trait TraceContextProvider: Send + Sync + 'static {
+    /// Return the current trace context, or [`TraceContext::default()`] to
+    /// skip injection.
+    async fn get_trace_context(&self) -> TraceContext;
+}
+
+/// Inject `traceparent` / `tracestate` from `ctx` into the JSON `params`
+/// object if either field is set. No-op when both are `None`.
+pub(crate) fn inject_trace_context(params: &mut serde_json::Value, ctx: &TraceContext) { + if let Some(tp) = &ctx.traceparent { + params["traceparent"] = serde_json::Value::String(tp.clone()); + } + if let Some(ts) = &ctx.tracestate { + params["tracestate"] = serde_json::Value::String(ts.clone()); + } +} + +#[cfg(test)] +mod tests { + use super::TraceContext; + + #[test] + fn new_yields_empty_context() { + let ctx = TraceContext::new(); + assert!(ctx.is_empty()); + assert!(ctx.traceparent.is_none()); + assert!(ctx.tracestate.is_none()); + } + + #[test] + fn builder_composes_traceparent_and_tracestate() { + let ctx = TraceContext::new() + .with_traceparent("00-trace-span-01") + .with_tracestate("vendor=key"); + assert_eq!(ctx.traceparent.as_deref(), Some("00-trace-span-01")); + assert_eq!(ctx.tracestate.as_deref(), Some("vendor=key")); + assert!(!ctx.is_empty()); + } + + #[test] + fn from_traceparent_matches_builder() { + let direct = TraceContext::from_traceparent("00-trace-span-01"); + let chained = TraceContext::new().with_traceparent("00-trace-span-01"); + assert_eq!(direct, chained); + } +} diff --git a/rust/src/transforms.rs b/rust/src/transforms.rs new file mode 100644 index 000000000..a090bc649 --- /dev/null +++ b/rust/src/transforms.rs @@ -0,0 +1,223 @@ +//! System message transform callbacks for customizing agent prompts. +//! +//! Implement [`SystemMessageTransform`](crate::transforms::SystemMessageTransform) to intercept and modify system prompt +//! sections during session creation. The CLI sends the current content for +//! each section the transform registered, and the SDK returns the modified +//! content. + +use std::collections::HashMap; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +use crate::types::SessionId; + +/// Context provided to every transform invocation. +#[derive(Debug, Clone)] +pub struct TransformContext { + /// The session being created or resumed. 
+    pub session_id: SessionId,
+}
+
+/// Handles `systemMessage.transform` RPC requests from the CLI.
+///
+/// The CLI sends these during session creation/resumption when the session's
+/// `SystemMessageConfig` contains sections with `action: "transform"`. For each
+/// such section, the CLI provides the current content and expects the SDK to
+/// return the (possibly modified) content.
+///
+/// Implement this trait and pass it to [`Client::create_session`](crate::Client::create_session) /
+/// [`Client::resume_session`](crate::Client::resume_session) to participate in system message customization.
+///
+/// # Example
+///
+/// ```ignore
+/// struct MyTransform;
+///
+/// #[async_trait::async_trait]
+/// impl SystemMessageTransform for MyTransform {
+///     fn section_ids(&self) -> Vec<String> {
+///         vec!["instructions".to_string()]
+///     }
+///
+///     async fn transform_section(
+///         &self,
+///         _section_id: &str,
+///         content: &str,
+///         _ctx: TransformContext,
+///     ) -> Option<String> {
+///         Some(format!("{content}\n\nAlways be concise."))
+///     }
+/// }
+/// ```
+#[async_trait]
+pub trait SystemMessageTransform: Send + Sync + 'static {
+    /// Section IDs this transform handles.
+    ///
+    /// The SDK injects `action: "transform"` entries into the
+    /// [`SystemMessageConfig`](crate::types::SystemMessageConfig) wire format
+    /// for each returned ID.
+    fn section_ids(&self) -> Vec<String>;
+
+    /// Transform a section's content. Return `Some(new_content)` to modify the
+    /// section, or `None` to pass through unchanged.
+    async fn transform_section(
+        &self,
+        section_id: &str,
+        content: &str,
+        ctx: TransformContext,
+    ) -> Option<String>;
+}
+
+/// Wire format for a single section in the transform request/response.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub(crate) struct TransformSection {
+    pub(crate) content: String,
+}
+
+/// Wire format for the `systemMessage.transform` response.
+#[derive(Debug, Clone, Serialize)] +pub(crate) struct TransformResponse { + pub(crate) sections: HashMap, +} + +/// Apply transforms to the incoming sections map, returning the response. +/// +/// For each section, calls the matching transform if the implementor returns +/// `Some`; otherwise passes through the original content. +pub(crate) async fn dispatch_transform( + transform: &dyn SystemMessageTransform, + session_id: &SessionId, + sections: HashMap, +) -> TransformResponse { + let ctx = TransformContext { + session_id: session_id.clone(), + }; + + let mut result = HashMap::with_capacity(sections.len()); + for (section_id, data) in sections { + let content = match transform + .transform_section(§ion_id, &data.content, ctx.clone()) + .await + { + Some(transformed) => transformed, + None => data.content, + }; + result.insert(section_id, TransformSection { content }); + } + + TransformResponse { sections: result } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct TestTransform; + + #[async_trait] + impl SystemMessageTransform for TestTransform { + fn section_ids(&self) -> Vec { + vec!["instructions".to_string(), "context".to_string()] + } + + async fn transform_section( + &self, + section_id: &str, + content: &str, + _ctx: TransformContext, + ) -> Option { + match section_id { + "instructions" => Some(format!("[modified] {content}")), + _ => None, + } + } + } + + #[tokio::test] + async fn dispatch_applies_matching_transform() { + let transform = TestTransform; + let mut sections = HashMap::new(); + sections.insert( + "instructions".to_string(), + TransformSection { + content: "be helpful".to_string(), + }, + ); + + let response = dispatch_transform(&transform, &SessionId::new("sess-1"), sections).await; + assert_eq!( + response.sections["instructions"].content, + "[modified] be helpful" + ); + } + + #[tokio::test] + async fn dispatch_passes_through_unhandled_section() { + let transform = TestTransform; + let mut sections = HashMap::new(); + 
+        sections.insert(
+            "context".to_string(),
+            TransformSection {
+                content: "original context".to_string(),
+            },
+        );
+
+        let response = dispatch_transform(&transform, &SessionId::new("sess-1"), sections).await;
+        assert_eq!(response.sections["context"].content, "original context");
+    }
+
+    #[tokio::test]
+    async fn dispatch_unknown_section_passes_through() {
+        let transform = TestTransform;
+        let mut sections = HashMap::new();
+        sections.insert(
+            "unknown".to_string(),
+            TransformSection {
+                content: "mystery".to_string(),
+            },
+        );
+
+        let response = dispatch_transform(&transform, &SessionId::new("sess-1"), sections).await;
+        assert_eq!(response.sections["unknown"].content, "mystery");
+    }
+
+    #[tokio::test]
+    async fn dispatch_mixed_sections() {
+        let transform = TestTransform;
+        let mut sections = HashMap::new();
+        sections.insert(
+            "instructions".to_string(),
+            TransformSection {
+                content: "help me".to_string(),
+            },
+        );
+        sections.insert(
+            "context".to_string(),
+            TransformSection {
+                content: "some context".to_string(),
+            },
+        );
+        sections.insert(
+            "other".to_string(),
+            TransformSection {
+                content: "other stuff".to_string(),
+            },
+        );
+
+        let response = dispatch_transform(&transform, &SessionId::new("sess-1"), sections).await;
+        assert_eq!(
+            response.sections["instructions"].content,
+            "[modified] help me"
+        );
+        assert_eq!(response.sections["context"].content, "some context");
+        assert_eq!(response.sections["other"].content, "other stuff");
+    }
+
+    #[tokio::test]
+    async fn section_ids_returns_registered_sections() {
+        let transform = TestTransform;
+        let ids = transform.section_ids();
+        assert_eq!(ids, vec!["instructions", "context"]);
+    }
+}
diff --git a/rust/src/types.rs b/rust/src/types.rs
new file mode 100644
index 000000000..d2e2ec012
--- /dev/null
+++ b/rust/src/types.rs
@@ -0,0 +1,3499 @@
+//! Protocol types shared between the SDK and the GitHub Copilot CLI.
+//!
+//! These types map directly to the JSON-RPC request/response payloads
+//! defined by the GitHub Copilot CLI protocol. They are used for session
+//! configuration, event handling, tool invocations, and model queries.
+
+use std::collections::HashMap;
+use std::path::{Path, PathBuf};
+use std::sync::Arc;
+use std::time::Duration;
+
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+
+use crate::handler::SessionHandler;
+use crate::hooks::SessionHooks;
+pub use crate::session_fs::{
+    DirEntry, DirEntryKind, FileInfo, FsError, SessionFsConfig, SessionFsConventions,
+    SessionFsProvider,
+};
+pub use crate::trace_context::{TraceContext, TraceContextProvider};
+use crate::transforms::SystemMessageTransform;
+
+/// Lifecycle state of a [`Client`](crate::Client) connection to the CLI.
+///
+/// The state advances from `Connecting` → `Connected` during construction,
+/// transitions to `Disconnected` after [`Client::stop`](crate::Client::stop) or
+/// [`Client::force_stop`](crate::Client::force_stop), and lands in
+/// `Error` if startup fails or the underlying transport tears down
+/// unexpectedly.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+#[non_exhaustive]
+pub enum ConnectionState {
+    /// No CLI process is attached or the process has exited cleanly.
+    Disconnected,
+    /// The client is starting up (spawning the CLI, negotiating protocol).
+    Connecting,
+    /// The client is connected and ready to handle RPC traffic.
+    Connected,
+    /// Startup failed or the connection encountered an unrecoverable error.
+    Error,
+}
+
+/// Type of [`SessionLifecycleEvent`] received via [`Client::subscribe_lifecycle`](crate::Client::subscribe_lifecycle).
+///
+/// Values serialize as the dotted JSON strings the CLI sends (e.g.
+/// `"session.created"`).
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
+#[non_exhaustive]
+pub enum SessionLifecycleEventType {
+    /// A new session was created.
+    #[serde(rename = "session.created")]
+    Created,
+    /// A session was deleted.
+    #[serde(rename = "session.deleted")]
+    Deleted,
+    /// A session's metadata was updated (e.g. summary regenerated).
+    #[serde(rename = "session.updated")]
+    Updated,
+    /// A session moved into the foreground.
+    #[serde(rename = "session.foreground")]
+    Foreground,
+    /// A session moved into the background.
+    #[serde(rename = "session.background")]
+    Background,
+}
+
+/// Optional metadata attached to a [`SessionLifecycleEvent`].
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct SessionLifecycleEventMetadata {
+    /// ISO-8601 timestamp the session was created.
+    #[serde(rename = "startTime")]
+    pub start_time: String,
+    /// ISO-8601 timestamp the session was last modified.
+    #[serde(rename = "modifiedTime")]
+    pub modified_time: String,
+    /// Optional generated summary of the session conversation so far.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub summary: Option<String>,
+}
+
+/// A `session.lifecycle` notification dispatched to subscribers obtained via
+/// [`Client::subscribe_lifecycle`](crate::Client::subscribe_lifecycle).
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+pub struct SessionLifecycleEvent {
+    /// The kind of lifecycle change this event represents.
+    #[serde(rename = "type")]
+    pub event_type: SessionLifecycleEventType,
+    /// Identifier of the session this event refers to.
+    #[serde(rename = "sessionId")]
+    pub session_id: SessionId,
+    /// Optional metadata describing the session at the time of the event.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub metadata: Option<SessionLifecycleEventMetadata>,
+}
+
+/// Opaque session identifier assigned by the CLI.
+///
+/// A newtype wrapper around `String` that provides type safety — prevents
+/// accidentally passing a workspace ID or request ID where a session ID
+/// is expected. Derefs to `str` for zero-friction borrowing.
+#[derive(Debug, Clone, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
+#[serde(transparent)]
+pub struct SessionId(String);
+
+impl SessionId {
+    /// Create a new session ID from any string-like value.
+    pub fn new(id: impl Into<String>) -> Self {
+        Self(id.into())
+    }
+
+    /// Borrow the inner string.
+    pub fn as_str(&self) -> &str {
+        &self.0
+    }
+
+    /// Consume the wrapper, returning the inner string.
+    pub fn into_inner(self) -> String {
+        self.0
+    }
+}
+
+impl std::ops::Deref for SessionId {
+    type Target = str;
+
+    fn deref(&self) -> &str {
+        &self.0
+    }
+}
+
+impl std::fmt::Display for SessionId {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&self.0)
+    }
+}
+
+impl From<String> for SessionId {
+    fn from(s: String) -> Self {
+        Self(s)
+    }
+}
+
+impl From<&str> for SessionId {
+    fn from(s: &str) -> Self {
+        Self(s.to_owned())
+    }
+}
+
+impl AsRef<str> for SessionId {
+    fn as_ref(&self) -> &str {
+        &self.0
+    }
+}
+
+impl std::borrow::Borrow<str> for SessionId {
+    fn borrow(&self) -> &str {
+        &self.0
+    }
+}
+
+impl From<SessionId> for String {
+    fn from(id: SessionId) -> String {
+        id.0
+    }
+}
+
+impl PartialEq<str> for SessionId {
+    fn eq(&self, other: &str) -> bool {
+        self.0 == other
+    }
+}
+
+impl PartialEq<String> for SessionId {
+    fn eq(&self, other: &String) -> bool {
+        &self.0 == other
+    }
+}
+
+impl PartialEq<SessionId> for String {
+    fn eq(&self, other: &SessionId) -> bool {
+        self == &other.0
+    }
+}
+
+impl PartialEq<&str> for SessionId {
+    fn eq(&self, other: &&str) -> bool {
+        self.0 == *other
+    }
+}
+
+impl PartialEq<&SessionId> for SessionId {
+    fn eq(&self, other: &&SessionId) -> bool {
+        self.0 == other.0
+    }
+}
+
+impl PartialEq<SessionId> for &SessionId {
+    fn eq(&self, other: &SessionId) -> bool {
+        self.0 == other.0
+    }
+}
+
+/// Opaque request identifier for pending CLI requests (permission, user-input, etc.).
+///
+/// A newtype wrapper around `String` that provides type safety — prevents
+/// accidentally passing a session ID or workspace ID where a request ID
+/// is expected. Derefs to `str` for zero-friction borrowing.
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+#[serde(transparent)]
+pub struct RequestId(String);
+
+impl RequestId {
+    /// Create a new request ID from any string-like value.
+    pub fn new(id: impl Into<String>) -> Self {
+        Self(id.into())
+    }
+
+    /// Consume the wrapper, returning the inner string.
+    pub fn into_inner(self) -> String {
+        self.0
+    }
+}
+
+impl std::ops::Deref for RequestId {
+    type Target = str;
+
+    fn deref(&self) -> &str {
+        &self.0
+    }
+}
+
+impl std::fmt::Display for RequestId {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str(&self.0)
+    }
+}
+
+impl From<String> for RequestId {
+    fn from(s: String) -> Self {
+        Self(s)
+    }
+}
+
+impl From<&str> for RequestId {
+    fn from(s: &str) -> Self {
+        Self(s.to_owned())
+    }
+}
+
+impl AsRef<str> for RequestId {
+    fn as_ref(&self) -> &str {
+        &self.0
+    }
+}
+
+impl std::borrow::Borrow<str> for RequestId {
+    fn borrow(&self) -> &str {
+        &self.0
+    }
+}
+
+impl From<RequestId> for String {
+    fn from(id: RequestId) -> String {
+        id.0
+    }
+}
+
+impl PartialEq<str> for RequestId {
+    fn eq(&self, other: &str) -> bool {
+        self.0 == other
+    }
+}
+
+impl PartialEq<String> for RequestId {
+    fn eq(&self, other: &String) -> bool {
+        &self.0 == other
+    }
+}
+
+impl PartialEq<RequestId> for String {
+    fn eq(&self, other: &RequestId) -> bool {
+        self == &other.0
+    }
+}
+
+impl PartialEq<&str> for RequestId {
+    fn eq(&self, other: &&str) -> bool {
+        self.0 == *other
+    }
+}
+
+/// A tool that the client exposes to the Copilot agent.
+///
+/// Sent to the CLI as part of [`SessionConfig::tools`] / [`ResumeSessionConfig::tools`]
+/// at session creation/resume time. The Rust SDK hand-authors this struct
+/// (rather than using the schema-generated form) so it can carry runtime
+/// hints — `overrides_built_in_tool`, `skip_permission` — that don't appear
+/// in the wire schema but are honored by the CLI.
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+#[non_exhaustive]
+pub struct Tool {
+    /// Tool identifier (e.g., `"bash"`, `"grep"`, `"str_replace_editor"`).
+    pub name: String,
+    /// Optional namespaced name for declarative filtering (e.g., `"playwright/navigate"`
+    /// for MCP tools).
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub namespaced_name: Option<String>,
+    /// Description of what the tool does.
+    #[serde(default)]
+    pub description: String,
+    /// Optional instructions for how to use this tool effectively.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub instructions: Option<String>,
+    /// JSON Schema for the tool's input parameters.
+    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
+    pub parameters: HashMap<String, Value>,
+    /// When `true`, this tool replaces a built-in tool of the same name
+    /// (e.g. supplying a custom `grep` that the agent uses in place of the
+    /// CLI's built-in implementation).
+    #[serde(default, skip_serializing_if = "is_false")]
+    pub overrides_built_in_tool: bool,
+    /// When `true`, the CLI does not request permission before invoking
+    /// this tool. Use with caution — the tool is responsible for any
+    /// access control.
+    #[serde(default, skip_serializing_if = "is_false")]
+    pub skip_permission: bool,
+}
+
+#[inline]
+fn is_false(b: &bool) -> bool {
+    !*b
+}
+
+impl Tool {
+    /// Construct a new [`Tool`] with the given name and otherwise default
+    /// values. The struct is `#[non_exhaustive]`, so external callers
+    /// cannot use struct-literal syntax — use this builder or
+    /// [`Default::default`] plus mut-let.
+    ///
+    /// # Example
+    ///
+    /// ```
+    /// # use github_copilot_sdk::types::Tool;
+    /// # use serde_json::json;
+    /// let tool = Tool::new("greet")
+    ///     .with_description("Say hello to a user")
+    ///     .with_parameters(json!({
+    ///         "type": "object",
+    ///         "properties": { "name": { "type": "string" } },
+    ///         "required": ["name"]
+    ///     }));
+    /// # let _ = tool;
+    /// ```
+    pub fn new(name: impl Into<String>) -> Self {
+        Self {
+            name: name.into(),
+            ..Default::default()
+        }
+    }
+
+    /// Set the namespaced name for declarative filtering (e.g.
+    /// `"playwright/navigate"` for MCP tools).
+    pub fn with_namespaced_name(mut self, namespaced_name: impl Into<String>) -> Self {
+        self.namespaced_name = Some(namespaced_name.into());
+        self
+    }
+
+    /// Set the human-readable description of what the tool does.
+    pub fn with_description(mut self, description: impl Into<String>) -> Self {
+        self.description = description.into();
+        self
+    }
+
+    /// Set optional instructions for how to use this tool effectively.
+    pub fn with_instructions(mut self, instructions: impl Into<String>) -> Self {
+        self.instructions = Some(instructions.into());
+        self
+    }
+
+    /// Set the JSON Schema for the tool's input parameters.
+    ///
+    /// Accepts anything that converts into a JSON object, including a
+    /// `serde_json::Value` produced by `json!({...})`. Non-object values
+    /// are stored as an empty parameter map; callers that need direct
+    /// control over the field can construct a `HashMap<String, Value>`
+    /// and assign it to [`Tool::parameters`] via [`Default::default`].
+    pub fn with_parameters(mut self, parameters: Value) -> Self {
+        self.parameters = parameters
+            .as_object()
+            .map(|m| m.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
+            .unwrap_or_default();
+        self
+    }
+
+    /// Mark this tool as overriding a built-in tool of the same name.
+    /// E.g. supplying a custom `grep` that the agent uses in place of the
+    /// CLI's built-in implementation.
+    pub fn with_overrides_built_in_tool(mut self, overrides: bool) -> Self {
+        self.overrides_built_in_tool = overrides;
+        self
+    }
+
+    /// When `true`, the CLI will not request permission before invoking
+    /// this tool. Use with caution — the tool is responsible for any
+    /// access control.
+    pub fn with_skip_permission(mut self, skip: bool) -> Self {
+        self.skip_permission = skip;
+        self
+    }
+}
+
+/// Context passed to a [`CommandHandler`] when a registered slash command
+/// is executed by the user.
+#[non_exhaustive]
+#[derive(Debug, Clone)]
+pub struct CommandContext {
+    /// Session ID where the command was invoked.
+    pub session_id: SessionId,
+    /// The full command text (e.g. `"/deploy production"`).
+    pub command: String,
+    /// Command name without the leading `/` (e.g. `"deploy"`).
+    pub command_name: String,
+    /// Raw argument string after the command name (e.g. `"production"`).
+    pub args: String,
+}
+
+/// Handler invoked when a registered slash command is executed.
+///
+/// Returning `Err(_)` causes the SDK to forward the error message back to
+/// the CLI via `session.commands.handlePendingCommand` so the TUI can
+/// surface it. Returning `Ok(())` reports success.
+#[async_trait::async_trait]
+pub trait CommandHandler: Send + Sync {
+    /// Called when the user invokes the command this handler is registered for.
+    async fn on_command(&self, ctx: CommandContext) -> Result<(), crate::Error>;
+}
+
+/// Definition of a slash command registered with the session.
+///
+/// When the CLI is running with a TUI, registered commands appear as
+/// `/name` for the user to invoke. Only `name` and `description` are sent
+/// over the wire — the handler is local to this SDK process.
+#[non_exhaustive]
+#[derive(Clone)]
+pub struct CommandDefinition {
+    /// Command name (without leading `/`).
+    pub name: String,
+    /// Human-readable description shown in command-completion UI.
+    pub description: Option<String>,
+    /// Handler invoked when the command is executed.
+    pub handler: Arc<dyn CommandHandler>,
+}
+
+impl CommandDefinition {
+    /// Construct a new command definition. Use [`with_description`](Self::with_description)
+    /// to add a description.
+    pub fn new(name: impl Into<String>, handler: Arc<dyn CommandHandler>) -> Self {
+        Self {
+            name: name.into(),
+            description: None,
+            handler,
+        }
+    }
+
+    /// Set the human-readable description shown in the CLI's command-completion UI.
+    pub fn with_description(mut self, description: impl Into<String>) -> Self {
+        self.description = Some(description.into());
+        self
+    }
+}
+
+impl std::fmt::Debug for CommandDefinition {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("CommandDefinition")
+            .field("name", &self.name)
+            .field("description", &self.description)
+            .field("handler", &"<handler>")
+            .finish()
+    }
+}
+
+impl Serialize for CommandDefinition {
+    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        use serde::ser::SerializeStruct;
+        let len = if self.description.is_some() { 2 } else { 1 };
+        let mut state = serializer.serialize_struct("CommandDefinition", len)?;
+        state.serialize_field("name", &self.name)?;
+        if let Some(description) = &self.description {
+            state.serialize_field("description", description)?;
+        }
+        state.end()
+    }
+}
+
+/// Configures a custom agent (sub-agent) for the session.
+///
+/// Custom agents have their own prompt, tool allowlist, and optionally
+/// their own MCP servers and skill set. The agent named in
+/// [`SessionConfig::agent`] (or the runtime default) is the active one
+/// when the session starts.
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+#[non_exhaustive]
+pub struct CustomAgentConfig {
+    /// Unique name of the custom agent.
+    pub name: String,
+    /// Display name for UI purposes.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub display_name: Option<String>,
+    /// Description of what the agent does.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub description: Option<String>,
+    /// List of tool names the agent can use. `None` means all tools.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<String>>,
+    /// Prompt content for the agent.
+    pub prompt: String,
+    /// MCP servers specific to this agent.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub mcp_servers: Option<HashMap<String, McpServerConfig>>,
+    /// Whether the agent is available for model inference.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub infer: Option<bool>,
+    /// Skill names to preload into this agent's context at startup.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub skills: Option<Vec<String>>,
+}
+
+impl CustomAgentConfig {
+    /// Construct a custom agent configuration with the required `name`
+    /// and `prompt` fields populated.
+    ///
+    /// All other fields default to unset; use the `with_*` chain to
+    /// customize them. Fields are also `pub` if direct assignment is
+    /// preferred for `Option` pass-through.
+    pub fn new(name: impl Into<String>, prompt: impl Into<String>) -> Self {
+        Self {
+            name: name.into(),
+            prompt: prompt.into(),
+            ..Self::default()
+        }
+    }
+
+    /// Set the display name shown in the CLI's agent-selection UI.
+    pub fn with_display_name(mut self, display_name: impl Into<String>) -> Self {
+        self.display_name = Some(display_name.into());
+        self
+    }
+
+    /// Set the description of what the agent does.
+    pub fn with_description(mut self, description: impl Into<String>) -> Self {
+        self.description = Some(description.into());
+        self
+    }
+
+    /// Restrict the agent to a specific tool allowlist. When unset, the
+    /// agent inherits the parent session's tool set.
+    pub fn with_tools<I, S>(mut self, tools: I) -> Self
+    where
+        I: IntoIterator<Item = S>,
+        S: Into<String>,
+    {
+        self.tools = Some(tools.into_iter().map(Into::into).collect());
+        self
+    }
+
+    /// Configure agent-specific MCP servers.
+    pub fn with_mcp_servers(mut self, mcp_servers: HashMap<String, McpServerConfig>) -> Self {
+        self.mcp_servers = Some(mcp_servers);
+        self
+    }
+
+    /// Whether the agent participates in model inference.
+    pub fn with_infer(mut self, infer: bool) -> Self {
+        self.infer = Some(infer);
+        self
+    }
+
+    /// Set the skills preloaded into the agent's context at startup.
+    pub fn with_skills<I, S>(mut self, skills: I) -> Self
+    where
+        I: IntoIterator<Item = S>,
+        S: Into<String>,
+    {
+        self.skills = Some(skills.into_iter().map(Into::into).collect());
+        self
+    }
+}
+
+/// Configures the default (built-in) agent that handles turns when no
+/// custom agent is selected.
+///
+/// Use [`Self::excluded_tools`] to hide tools from the default agent
+/// while keeping them available to custom sub-agents that list them in
+/// their [`CustomAgentConfig::tools`].
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct DefaultAgentConfig {
+    /// Tool names to exclude from the default agent.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub excluded_tools: Option<Vec<String>>,
+}
+
+/// Configures infinite sessions: persistent workspaces with automatic
+/// context-window compaction.
+///
+/// When enabled (default), sessions automatically manage context limits
+/// through background compaction and persist state to a workspace
+/// directory.
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+#[non_exhaustive]
+pub struct InfiniteSessionConfig {
+    /// Whether infinite sessions are enabled. Defaults to `true` on the CLI.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub enabled: Option<bool>,
+    /// Context utilization (0.0–1.0) at which background compaction starts.
+    /// Default: 0.80.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub background_compaction_threshold: Option<f64>,
+    /// Context utilization (0.0–1.0) at which the session blocks until
+    /// compaction completes. Default: 0.95.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub buffer_exhaustion_threshold: Option<f64>,
+}
+
+impl InfiniteSessionConfig {
+    /// Construct an empty [`InfiniteSessionConfig`]; all fields default to
+    /// unset (the CLI applies its own defaults).
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Toggle infinite sessions on or off. Defaults to `true` on the CLI
+    /// when unset.
+    pub fn with_enabled(mut self, enabled: bool) -> Self {
+        self.enabled = Some(enabled);
+        self
+    }
+
+    /// Set the context utilization (0.0–1.0) at which background
+    /// compaction starts.
+    pub fn with_background_compaction_threshold(mut self, threshold: f64) -> Self {
+        self.background_compaction_threshold = Some(threshold);
+        self
+    }
+
+    /// Set the context utilization (0.0–1.0) at which the session blocks
+    /// until compaction completes.
+    pub fn with_buffer_exhaustion_threshold(mut self, threshold: f64) -> Self {
+        self.buffer_exhaustion_threshold = Some(threshold);
+        self
+    }
+}
+
+/// Configuration for a single MCP server.
+///
+/// MCP (Model Context Protocol) servers expose external tools to the
+/// agent. Local servers run as a subprocess over stdio; remote servers
+/// speak HTTP or Server-Sent Events.
+///
+/// Serialized as a JSON object with a `type` discriminator (`"stdio"` |
+/// `"http"` | `"sse"`).
+///
+/// # Example
+///
+/// ```
+/// # use github_copilot_sdk::types::{McpServerConfig, McpStdioServerConfig, McpHttpServerConfig};
+/// # use std::collections::HashMap;
+/// let mut servers = HashMap::new();
+/// servers.insert(
+///     "playwright".to_string(),
+///     McpServerConfig::Stdio(McpStdioServerConfig {
+///         tools: vec!["*".to_string()],
+///         command: "npx".to_string(),
+///         args: vec!["-y".to_string(), "@playwright/mcp".to_string()],
+///         ..Default::default()
+///     }),
+/// );
+/// servers.insert(
+///     "weather".to_string(),
+///     McpServerConfig::Http(McpHttpServerConfig {
+///         tools: vec!["forecast".to_string()],
+///         url: "https://example.com/mcp".to_string(),
+///         ..Default::default()
+///     }),
+/// );
+/// ```
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(tag = "type", rename_all = "lowercase")]
+#[non_exhaustive]
+pub enum McpServerConfig {
+    /// Local MCP server launched as a subprocess and addressed over stdio.
+    /// On the wire this serializes as `{"type": "stdio", ...}`. The CLI
+    /// also accepts `"local"` as an alias on input.
+    #[serde(alias = "local")]
+    Stdio(McpStdioServerConfig),
+    /// Remote MCP server addressed over HTTP.
+    Http(McpHttpServerConfig),
+    /// Remote MCP server addressed over Server-Sent Events.
+    Sse(McpHttpServerConfig),
+}
+
+/// Configuration for a local/stdio MCP server.
+///
+/// See [`McpServerConfig::Stdio`].
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpStdioServerConfig {
+    /// Tools to expose from this server. `["*"]` exposes all; `[]` exposes none.
+    #[serde(default)]
+    pub tools: Vec<String>,
+    /// Optional timeout in milliseconds for tool calls to this server.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub timeout: Option<u64>,
+    /// Subprocess executable.
+    pub command: String,
+    /// Arguments to pass to the subprocess.
+    #[serde(default)]
+    pub args: Vec<String>,
+    /// Environment variables to set on the subprocess.
+    ///
+    /// Interpretation depends on the parent session's
+    /// `env_value_mode`: `"direct"` (default) treats values as literals;
+    /// `"indirect"` treats them as env-var names to look up at start time.
+    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
+    pub env: HashMap<String, String>,
+    /// Working directory for the subprocess.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub cwd: Option<PathBuf>,
+}
+
+/// Configuration for a remote MCP server (HTTP or SSE).
+///
+/// See [`McpServerConfig::Http`] and [`McpServerConfig::Sse`].
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct McpHttpServerConfig {
+    /// Tools to expose from this server. `["*"]` exposes all; `[]` exposes none.
+    #[serde(default)]
+    pub tools: Vec<String>,
+    /// Optional timeout in milliseconds for tool calls to this server.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub timeout: Option<u64>,
+    /// Server URL.
+    pub url: String,
+    /// Optional HTTP headers to include on every request.
+    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
+    pub headers: HashMap<String, String>,
+}
+
+/// Configures a custom inference provider (BYOK — Bring Your Own Key).
+///
+/// Routes session requests through an alternative model provider
+/// (OpenAI-compatible, Azure, Anthropic, or local) instead of GitHub
+/// Copilot's default routing.
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+#[non_exhaustive]
+pub struct ProviderConfig {
+    /// Provider type: `"openai"`, `"azure"`, or `"anthropic"`. Defaults to
+    /// `"openai"` on the CLI.
+    #[serde(default, skip_serializing_if = "Option::is_none", rename = "type")]
+    pub provider_type: Option<String>,
+    /// API format (openai/azure only): `"completions"` or `"responses"`.
+    /// Defaults to `"completions"`.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub wire_api: Option<String>,
+    /// API endpoint URL.
+    pub base_url: String,
+    /// API key. Optional for local providers like Ollama.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub api_key: Option<String>,
+    /// Bearer token for authentication. Sets the `Authorization` header
+    /// directly. Use for services requiring bearer-token auth instead of
+    /// API key. Takes precedence over `api_key` when both are set.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub bearer_token: Option<String>,
+    /// Azure-specific options.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub azure: Option<AzureProviderOptions>,
+    /// Custom HTTP headers included in outbound provider requests.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub headers: Option<HashMap<String, String>>,
+}
+
+impl ProviderConfig {
+    /// Construct a [`ProviderConfig`] with the required `base_url` set;
+    /// all other fields default to unset.
+    pub fn new(base_url: impl Into<String>) -> Self {
+        Self {
+            base_url: base_url.into(),
+            ..Self::default()
+        }
+    }
+
+    /// Set the provider type (`"openai"`, `"azure"`, or `"anthropic"`).
+    pub fn with_provider_type(mut self, provider_type: impl Into<String>) -> Self {
+        self.provider_type = Some(provider_type.into());
+        self
+    }
+
+    /// Set the API format (`"completions"` or `"responses"`; openai/azure only).
+    pub fn with_wire_api(mut self, wire_api: impl Into<String>) -> Self {
+        self.wire_api = Some(wire_api.into());
+        self
+    }
+
+    /// Set the API key. Optional for local providers like Ollama.
+    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
+        self.api_key = Some(api_key.into());
+        self
+    }
+
+    /// Set the bearer token used to populate the `Authorization` header.
+    /// Takes precedence over `api_key` when both are set.
+    pub fn with_bearer_token(mut self, bearer_token: impl Into<String>) -> Self {
+        self.bearer_token = Some(bearer_token.into());
+        self
+    }
+
+    /// Set Azure-specific options.
+ pub fn with_azure(mut self, azure: AzureProviderOptions) -> Self { + self.azure = Some(azure); + self + } + + /// Set the custom HTTP headers attached to outbound provider requests. + pub fn with_headers(mut self, headers: HashMap) -> Self { + self.headers = Some(headers); + self + } +} + +/// Azure-specific provider options. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AzureProviderOptions { + /// Azure API version. Defaults to `"2024-10-21"`. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub api_version: Option, +} + +/// Configuration for creating a new session via the `session.create` RPC. +/// +/// All fields are optional — the CLI applies sensible defaults. +/// +/// # Construction +/// +/// Two equivalent shapes are supported: +/// +/// 1. **Chained builder** (preferred for compile-time-known values): +/// +/// ``` +/// # use github_copilot_sdk::types::SessionConfig; +/// let cfg = SessionConfig::default() +/// .with_client_name("my-app") +/// .with_streaming(true) +/// .with_enable_config_discovery(true); +/// ``` +/// +/// 2. **Direct field assignment** (preferred when forwarding `Option` +/// from upstream code, since `with_` setters take the inner +/// `T`, not `Option`): +/// +/// ``` +/// # use github_copilot_sdk::types::SessionConfig; +/// # let upstream_model: Option = None; +/// # let upstream_system_message: Option = None; +/// let mut cfg = SessionConfig::default() +/// .with_client_name("my-app") +/// .with_streaming(true); +/// cfg.model = upstream_model; +/// cfg.system_message = upstream_system_message; +/// ``` +/// +/// Mixing the two is fine: chain the fields you know at compile time, +/// then assign the `Option` pass-through fields directly. All +/// fields on this struct are `pub`. This pattern matches the +/// `http::request::Parts` / `hyper::Body::Builder` convention in the +/// wider Rust ecosystem. 
+/// +/// # Field naming across SDKs +/// +/// Rust field names are snake_case (`available_tools`, `system_message`); +/// they round-trip to the camelCase wire protocol via `#[serde(rename_all = +/// "camelCase")]`. When porting code from the TypeScript, Go, Python, or +/// .NET SDKs — or reading the raw JSON-RPC traces — fields appear as +/// `availableTools`, `systemMessage`, etc. +#[derive(Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[non_exhaustive] +pub struct SessionConfig { + /// Custom session ID. When unset, the CLI generates one. + #[serde(skip_serializing_if = "Option::is_none")] + pub session_id: Option, + /// Model to use (e.g. `"gpt-4"`, `"claude-sonnet-4"`). + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + /// Application name sent as `User-Agent` context. + #[serde(skip_serializing_if = "Option::is_none")] + pub client_name: Option, + /// Reasoning effort level (e.g. `"low"`, `"medium"`, `"high"`). + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, + /// Enable streaming token deltas via `assistant.message_delta` events. + #[serde(skip_serializing_if = "Option::is_none")] + pub streaming: Option, + /// Custom system message configuration. + #[serde(skip_serializing_if = "Option::is_none")] + pub system_message: Option, + /// Client-defined tools to expose to the agent. + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + /// Allowlist of built-in tool names the agent may use. + #[serde(skip_serializing_if = "Option::is_none")] + pub available_tools: Option>, + /// Blocklist of built-in tool names the agent must not use. + #[serde(skip_serializing_if = "Option::is_none")] + pub excluded_tools: Option>, + /// MCP server configurations passed through to the CLI. + #[serde(skip_serializing_if = "Option::is_none")] + pub mcp_servers: Option>, + /// How the CLI interprets env values in MCP server configs. 
+ /// `"direct"` = literal values; `"indirect"` = env var names to look up. + #[serde(skip_serializing_if = "Option::is_none")] + pub env_value_mode: Option, + /// When true, the CLI runs config discovery (MCP config files, skills, plugins). + #[serde(skip_serializing_if = "Option::is_none")] + pub enable_config_discovery: Option, + /// Enable the `ask_user` tool for interactive user input. Defaults to + /// `Some(true)` via [`SessionConfig::default`]. + #[serde(skip_serializing_if = "Option::is_none")] + pub request_user_input: Option, + /// Enable `permission.request` JSON-RPC calls from the CLI. Defaults + /// to `Some(true)` via [`SessionConfig::default`]; the default + /// [`DenyAllHandler`](crate::handler::DenyAllHandler) refuses all + /// requests so the wire surface is safe out-of-the-box. + #[serde(skip_serializing_if = "Option::is_none")] + pub request_permission: Option, + /// Enable `exitPlanMode.request` JSON-RPC calls for plan approval. + /// Defaults to `Some(true)` via [`SessionConfig::default`]. + #[serde(skip_serializing_if = "Option::is_none")] + pub request_exit_plan_mode: Option, + /// Enable `autoModeSwitch.request` JSON-RPC calls. When `true`, the CLI + /// asks the handler whether to switch to auto model when an eligible + /// rate limit is hit. Defaults to `Some(true)` via + /// [`SessionConfig::default`]. Without this flag, the CLI surfaces the + /// rate-limit error directly without offering the auto-mode switch. + /// + /// Currently a Rust-only typed handler; cross-SDK parity (Node / + /// Python / Go / .NET) is post-release follow-up work — see + /// [`SessionHandler::on_auto_mode_switch`]. + /// + /// [`SessionHandler::on_auto_mode_switch`]: crate::handler::SessionHandler::on_auto_mode_switch + #[serde(skip_serializing_if = "Option::is_none")] + pub request_auto_mode_switch: Option, + /// Advertise elicitation provider capability. When true, the CLI sends + /// `elicitation.requested` events that the handler can respond to. 
+ /// Defaults to `Some(true)` via [`SessionConfig::default`]. + #[serde(skip_serializing_if = "Option::is_none")] + pub request_elicitation: Option, + /// Skill directory paths passed through to the GitHub Copilot CLI. + #[serde(skip_serializing_if = "Option::is_none")] + pub skill_directories: Option>, + /// Skill names to disable. Skills in this set will not be available + /// even if found in skill directories. + #[serde(skip_serializing_if = "Option::is_none")] + pub disabled_skills: Option>, + /// MCP server names to disable. Servers in this set will not be + /// started or connected. + #[serde(skip_serializing_if = "Option::is_none")] + pub disabled_mcp_servers: Option>, + /// Enable session hooks. When `true`, the CLI sends `hooks.invoke` + /// RPC requests at key lifecycle points (pre/post tool use, prompt + /// submission, session start/end, errors). + #[serde(skip_serializing_if = "Option::is_none")] + pub hooks: Option, + /// Custom agents (sub-agents) configured for this session. + #[serde(skip_serializing_if = "Option::is_none")] + pub custom_agents: Option>, + /// Configures the built-in default agent. Use `excluded_tools` to + /// hide tools from the default agent while keeping them available + /// to custom sub-agents that reference them in their `tools` list. + #[serde(skip_serializing_if = "Option::is_none")] + pub default_agent: Option, + /// Name of the custom agent to activate when the session starts. + /// Must match the `name` of one of the agents in [`Self::custom_agents`]. + #[serde(skip_serializing_if = "Option::is_none")] + pub agent: Option, + /// Configures infinite sessions: persistent workspace + automatic + /// context-window compaction. Enabled by default on the CLI. + #[serde(skip_serializing_if = "Option::is_none")] + pub infinite_sessions: Option, + /// Custom model provider (BYOK). When set, the session routes + /// requests through this provider instead of the default Copilot + /// routing. 
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub provider: Option<ProviderConfig>,
+    /// Per-property overrides for model capabilities, deep-merged over
+    /// runtime defaults.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model_capabilities: Option<crate::generated::api_types::ModelCapabilitiesOverride>,
+    /// Override the default configuration directory location. When set,
+    /// the session uses this directory for storing config and state.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub config_dir: Option<String>,
+    /// Working directory for the session. Tool operations resolve
+    /// relative paths against this directory.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub working_directory: Option<String>,
+    /// Per-session GitHub token. Distinct from
+    /// [`ClientOptions::github_token`](crate::ClientOptions::github_token),
+    /// which authenticates the CLI process itself; this token determines
+    /// the GitHub identity used for content exclusion, model routing, and
+    /// quota checks for *this session*.
+    #[serde(rename = "gitHubToken", skip_serializing_if = "Option::is_none")]
+    pub github_token: Option<String>,
+    /// Forward sub-agent streaming events to this connection. When false,
+    /// only non-streaming sub-agent events and `subagent.*` lifecycle events
+    /// are delivered. Defaults to true on the CLI.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub include_sub_agent_streaming_events: Option<bool>,
+    /// Slash commands registered for this session. When the CLI has a TUI,
+    /// each command appears as `/name` for the user to invoke and the
+    /// associated [`CommandHandler`] is called when executed.
+    #[serde(skip_serializing_if = "Option::is_none", skip_deserializing)]
+    pub commands: Option<Vec<CommandDefinition>>,
+    /// Custom session filesystem provider for this session. Required when
+    /// the [`Client`](crate::Client) was started with
+    /// [`ClientOptions::session_fs`](crate::ClientOptions::session_fs) set.
+    /// See [`SessionFsProvider`].
+    #[serde(skip)]
+    pub session_fs_provider: Option<Arc<dyn SessionFsProvider>>,
+    /// Session-level event handler. The default is
+    /// [`DenyAllHandler`](crate::handler::DenyAllHandler) — permission
+    /// requests are denied; other events are no-ops. Use
+    /// [`with_handler`](Self::with_handler) to install a custom handler.
+    #[serde(skip)]
+    pub handler: Option<Arc<dyn SessionHandler>>,
+    /// Session lifecycle hook handler (pre/post tool use, session
+    /// start/end, etc.). When set, the SDK auto-enables the wire-level
+    /// `hooks` flag. Use [`with_hooks`](Self::with_hooks) to install one.
+    #[serde(skip)]
+    pub hooks_handler: Option<Arc<dyn SessionHooks>>,
+    /// System-message transform. When set, the SDK injects the matching
+    /// `action: "transform"` sections into the system message and routes
+    /// `systemMessage.transform` RPC callbacks to it during the session.
+    /// Use [`with_transform`](Self::with_transform) to install one.
+    #[serde(skip)]
+    pub transform: Option<Arc<dyn SystemMessageTransform>>,
+}
+
+impl std::fmt::Debug for SessionConfig {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("SessionConfig")
+            .field("session_id", &self.session_id)
+            .field("model", &self.model)
+            .field("client_name", &self.client_name)
+            .field("reasoning_effort", &self.reasoning_effort)
+            .field("streaming", &self.streaming)
+            .field("system_message", &self.system_message)
+            .field("tools", &self.tools)
+            .field("available_tools", &self.available_tools)
+            .field("excluded_tools", &self.excluded_tools)
+            .field("mcp_servers", &self.mcp_servers)
+            .field("env_value_mode", &self.env_value_mode)
+            .field("enable_config_discovery", &self.enable_config_discovery)
+            .field("request_user_input", &self.request_user_input)
+            .field("request_permission", &self.request_permission)
+            .field("request_exit_plan_mode", &self.request_exit_plan_mode)
+            .field("request_auto_mode_switch", &self.request_auto_mode_switch)
+            .field("request_elicitation", &self.request_elicitation)
+            .field("skill_directories", &self.skill_directories)
+            .field("disabled_skills", &self.disabled_skills)
+            .field("disabled_mcp_servers", &self.disabled_mcp_servers)
+            .field("hooks", &self.hooks)
+            .field("custom_agents", &self.custom_agents)
+            .field("default_agent", &self.default_agent)
+            .field("agent", &self.agent)
+            .field("infinite_sessions", &self.infinite_sessions)
+            .field("provider", &self.provider)
+            .field("model_capabilities", &self.model_capabilities)
+            .field("config_dir", &self.config_dir)
+            .field("working_directory", &self.working_directory)
+            .field(
+                "github_token",
+                &self.github_token.as_ref().map(|_| "<redacted>"),
+            )
+            .field(
+                "include_sub_agent_streaming_events",
+                &self.include_sub_agent_streaming_events,
+            )
+            .field("commands", &self.commands)
+            .field(
+                "session_fs_provider",
+                &self.session_fs_provider.as_ref().map(|_| "<set>"),
+            )
+            .field("handler", &self.handler.as_ref().map(|_| "<set>"))
+            .field(
+                "hooks_handler",
+                &self.hooks_handler.as_ref().map(|_| "<set>"),
+            )
+            .field("transform", &self.transform.as_ref().map(|_| "<set>"))
+            .finish()
+    }
+}
+
+impl Default for SessionConfig {
+    /// Permission and elicitation flows are enabled by default. With
+    /// Rust's trait-based handlers, the SDK installs `DenyAllHandler` when
+    /// no handler is provided, so these flags being `Some(true)` means the
+    /// wire surface advertises the capabilities — and the default handler
+    /// safely refuses requests. Callers that want the wire surface fully
+    /// disabled set these explicitly to `Some(false)`.
+    fn default() -> Self {
+        Self {
+            session_id: None,
+            model: None,
+            client_name: None,
+            reasoning_effort: None,
+            streaming: None,
+            system_message: None,
+            tools: None,
+            available_tools: None,
+            excluded_tools: None,
+            mcp_servers: None,
+            env_value_mode: None,
+            enable_config_discovery: None,
+            request_user_input: Some(true),
+            request_permission: Some(true),
+            request_exit_plan_mode: Some(true),
+            request_auto_mode_switch: Some(true),
+            request_elicitation: Some(true),
+            skill_directories: None,
+            disabled_skills: None,
+            disabled_mcp_servers: None,
+            hooks: None,
+            custom_agents: None,
+            default_agent: None,
+            agent: None,
+            infinite_sessions: None,
+            provider: None,
+            model_capabilities: None,
+            config_dir: None,
+            working_directory: None,
+            github_token: None,
+            include_sub_agent_streaming_events: None,
+            commands: None,
+            session_fs_provider: None,
+            handler: None,
+            hooks_handler: None,
+            transform: None,
+        }
+    }
+}
+
+impl SessionConfig {
+    /// Install a custom [`SessionHandler`] for this session.
+    pub fn with_handler(mut self, handler: Arc<dyn SessionHandler>) -> Self {
+        self.handler = Some(handler);
+        self
+    }
+
+    /// Register slash commands for this session. Each command appears as
+    /// `/name` in the CLI's TUI; the handler is invoked when the user
+    /// executes the command. Replaces any commands previously set on this
+    /// config. See [`CommandDefinition`].
+    pub fn with_commands(mut self, commands: Vec<CommandDefinition>) -> Self {
+        self.commands = Some(commands);
+        self
+    }
+
+    /// Install a [`SessionFsProvider`] backing the session's filesystem.
+    /// Required when the [`Client`](crate::Client) was started with
+    /// [`ClientOptions::session_fs`](crate::ClientOptions::session_fs).
+    pub fn with_session_fs_provider(mut self, provider: Arc<dyn SessionFsProvider>) -> Self {
+        self.session_fs_provider = Some(provider);
+        self
+    }
+
+    /// Install a [`SessionHooks`] handler. Automatically enables the
+    /// wire-level `hooks` flag on session creation.
+    pub fn with_hooks(mut self, hooks: Arc<dyn SessionHooks>) -> Self {
+        self.hooks_handler = Some(hooks);
+        self
+    }
+
+    /// Install a [`SystemMessageTransform`]. The SDK injects the matching
+    /// `action: "transform"` sections into the system message and routes
+    /// `systemMessage.transform` RPC callbacks to it during the session.
+    pub fn with_transform(mut self, transform: Arc<dyn SystemMessageTransform>) -> Self {
+        self.transform = Some(transform);
+        self
+    }
+
+    /// Wrap the configured handler so every permission request is
+    /// auto-approved. Forwards every non-permission event to the inner
+    /// handler unchanged.
+    ///
+    /// If no handler has been installed via [`with_handler`](Self::with_handler),
+    /// wraps a [`DenyAllHandler`](crate::handler::DenyAllHandler) — useful
+    /// when you only care about permission policy and want the trait
+    /// fallback responses for everything else.
+    ///
+    /// Order matters: `with_handler(...).approve_all_permissions()` and
+    /// `approve_all_permissions().with_handler(...)` are NOT equivalent —
+    /// the second form discards the wrap because `with_handler` overwrites
+    /// the handler field. Always call `approve_all_permissions` *after*
+    /// `with_handler`.
+    pub fn approve_all_permissions(mut self) -> Self {
+        let inner = self
+            .handler
+            .take()
+            .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler));
+        self.handler = Some(crate::permission::approve_all(inner));
+        self
+    }
+
+    /// Wrap the configured handler so every permission request is
+    /// auto-denied. See [`approve_all_permissions`](Self::approve_all_permissions)
+    /// for ordering and default-handler semantics.
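The ordering rule above can be demonstrated with a minimal, self-contained sketch. `Config`, `Decision`, and the closures below are toy stand-ins, not SDK types: the point is only that a consuming builder whose `with_handler` overwrites the field discards any wrapper installed before it.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum Decision {
    Approve,
    Deny,
}

// Toy config: the handler is just a closure returning a Decision.
struct Config {
    handler: Box<dyn Fn() -> Decision>,
}

impl Config {
    fn new() -> Self {
        // Default policy denies, mirroring a deny-all default handler.
        Config { handler: Box::new(|| Decision::Deny) }
    }
    fn with_handler(mut self, h: Box<dyn Fn() -> Decision>) -> Self {
        self.handler = h; // overwrites: any earlier wrapper is discarded
        self
    }
    fn approve_all(mut self) -> Self {
        self.handler = Box::new(|| Decision::Approve); // wrapper: always approve
        self
    }
}

fn main() {
    // Correct order: wrap after installing the handler.
    let ok = Config::new()
        .with_handler(Box::new(|| Decision::Deny))
        .approve_all();
    assert_eq!((ok.handler)(), Decision::Approve);

    // Wrong order: with_handler overwrites the approve-all wrapper.
    let wrong = Config::new()
        .approve_all()
        .with_handler(Box::new(|| Decision::Deny));
    assert_eq!((wrong.handler)(), Decision::Deny);
}
```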
+    pub fn deny_all_permissions(mut self) -> Self {
+        let inner = self
+            .handler
+            .take()
+            .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler));
+        self.handler = Some(crate::permission::deny_all(inner));
+        self
+    }
+
+    /// Wrap the configured handler with a closure-based permission policy:
+    /// `predicate` is called for each permission request; `true` approves,
+    /// `false` denies. See
+    /// [`approve_all_permissions`](Self::approve_all_permissions) for
+    /// ordering and default-handler semantics.
+    pub fn approve_permissions_if<F>(mut self, predicate: F) -> Self
+    where
+        F: Fn(&crate::types::PermissionRequestData) -> bool + Send + Sync + 'static,
+    {
+        let inner = self
+            .handler
+            .take()
+            .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler));
+        self.handler = Some(crate::permission::approve_if(inner, predicate));
+        self
+    }
+
+    /// Set a custom session ID (when unset, the CLI generates one).
+    pub fn with_session_id(mut self, id: impl Into<SessionId>) -> Self {
+        self.session_id = Some(id.into());
+        self
+    }
+
+    /// Set the model identifier (e.g. `"claude-sonnet-4"`).
+    pub fn with_model(mut self, model: impl Into<String>) -> Self {
+        self.model = Some(model.into());
+        self
+    }
+
+    /// Set the application name sent as `User-Agent` context.
+    pub fn with_client_name(mut self, name: impl Into<String>) -> Self {
+        self.client_name = Some(name.into());
+        self
+    }
+
+    /// Set the reasoning effort level (e.g. `"low"`, `"medium"`, `"high"`).
+    pub fn with_reasoning_effort(mut self, effort: impl Into<String>) -> Self {
+        self.reasoning_effort = Some(effort.into());
+        self
+    }
+
+    /// Enable streaming token deltas via `assistant.message_delta` events.
+    pub fn with_streaming(mut self, streaming: bool) -> Self {
+        self.streaming = Some(streaming);
+        self
+    }
+
+    /// Set a custom system message configuration.
+    pub fn with_system_message(mut self, system_message: SystemMessageConfig) -> Self {
+        self.system_message = Some(system_message);
+        self
+    }
+
+    /// Set the client-defined tools to expose to the agent.
+    pub fn with_tools<I: IntoIterator<Item = Tool>>(mut self, tools: I) -> Self {
+        self.tools = Some(tools.into_iter().collect());
+        self
+    }
+
+    /// Set the allowlist of built-in tool names the agent may use.
+    pub fn with_available_tools<I, S>(mut self, tools: I) -> Self
+    where
+        I: IntoIterator<Item = S>,
+        S: Into<String>,
+    {
+        self.available_tools = Some(tools.into_iter().map(Into::into).collect());
+        self
+    }
+
+    /// Set the blocklist of built-in tool names the agent must not use.
+    pub fn with_excluded_tools<I, S>(mut self, tools: I) -> Self
+    where
+        I: IntoIterator<Item = S>,
+        S: Into<String>,
+    {
+        self.excluded_tools = Some(tools.into_iter().map(Into::into).collect());
+        self
+    }
+
+    /// Set MCP server configurations passed through to the CLI.
+    pub fn with_mcp_servers(mut self, servers: HashMap<String, McpServerConfig>) -> Self {
+        self.mcp_servers = Some(servers);
+        self
+    }
+
+    /// Set how the CLI interprets env values in MCP server configs
+    /// (`"direct"` literal vs `"indirect"` env var name lookup).
+    pub fn with_env_value_mode(mut self, mode: impl Into<String>) -> Self {
+        self.env_value_mode = Some(mode.into());
+        self
+    }
+
+    /// Enable or disable CLI config discovery (MCP config files, skills, plugins).
+    pub fn with_enable_config_discovery(mut self, enable: bool) -> Self {
+        self.enable_config_discovery = Some(enable);
+        self
+    }
+
+    /// Enable the `ask_user` tool. Defaults to `Some(true)` via [`Self::default`].
+    pub fn with_request_user_input(mut self, enable: bool) -> Self {
+        self.request_user_input = Some(enable);
+        self
+    }
+
+    /// Enable `permission.request` JSON-RPC calls. Defaults to `Some(true)`.
+    pub fn with_request_permission(mut self, enable: bool) -> Self {
+        self.request_permission = Some(enable);
+        self
+    }
+
+    /// Enable `exitPlanMode.request` JSON-RPC calls. Defaults to `Some(true)`.
+    pub fn with_request_exit_plan_mode(mut self, enable: bool) -> Self {
+        self.request_exit_plan_mode = Some(enable);
+        self
+    }
+
+    /// Enable `autoModeSwitch.request` JSON-RPC calls. Defaults to `Some(true)`.
+    pub fn with_request_auto_mode_switch(mut self, enable: bool) -> Self {
+        self.request_auto_mode_switch = Some(enable);
+        self
+    }
+
+    /// Advertise elicitation provider capability. Defaults to `Some(true)`.
+    pub fn with_request_elicitation(mut self, enable: bool) -> Self {
+        self.request_elicitation = Some(enable);
+        self
+    }
+
+    /// Set skill directory paths passed through to the CLI.
+    pub fn with_skill_directories<I, P>(mut self, paths: I) -> Self
+    where
+        I: IntoIterator<Item = P>,
+        P: Into<String>,
+    {
+        self.skill_directories = Some(paths.into_iter().map(Into::into).collect());
+        self
+    }
+
+    /// Set the names of skills to disable (overrides skill discovery).
+    pub fn with_disabled_skills<I, S>(mut self, names: I) -> Self
+    where
+        I: IntoIterator<Item = S>,
+        S: Into<String>,
+    {
+        self.disabled_skills = Some(names.into_iter().map(Into::into).collect());
+        self
+    }
+
+    /// Set the names of MCP servers to disable.
+    pub fn with_disabled_mcp_servers<I, S>(mut self, names: I) -> Self
+    where
+        I: IntoIterator<Item = S>,
+        S: Into<String>,
+    {
+        self.disabled_mcp_servers = Some(names.into_iter().map(Into::into).collect());
+        self
+    }
+
+    /// Set the custom agents (sub-agents) configured for this session.
+    pub fn with_custom_agents<I: IntoIterator<Item = CustomAgentConfig>>(
+        mut self,
+        agents: I,
+    ) -> Self {
+        self.custom_agents = Some(agents.into_iter().collect());
+        self
+    }
+
+    /// Configure the built-in default agent.
+    pub fn with_default_agent(mut self, agent: DefaultAgentConfig) -> Self {
+        self.default_agent = Some(agent);
+        self
+    }
+
+    /// Activate a named custom agent on session start. Must match the
+    /// `name` of one of the agents in [`Self::custom_agents`].
+    pub fn with_agent(mut self, name: impl Into<String>) -> Self {
+        self.agent = Some(name.into());
+        self
+    }
+
+    /// Configure infinite sessions (persistent workspace + automatic
+    /// context-window compaction).
+    pub fn with_infinite_sessions(mut self, config: InfiniteSessionConfig) -> Self {
+        self.infinite_sessions = Some(config);
+        self
+    }
+
+    /// Configure a custom model provider (BYOK).
+    pub fn with_provider(mut self, provider: ProviderConfig) -> Self {
+        self.provider = Some(provider);
+        self
+    }
+
+    /// Set per-property overrides for model capabilities.
+    pub fn with_model_capabilities(
+        mut self,
+        capabilities: crate::generated::api_types::ModelCapabilitiesOverride,
+    ) -> Self {
+        self.model_capabilities = Some(capabilities);
+        self
+    }
+
+    /// Override the default configuration directory location.
+    pub fn with_config_dir(mut self, dir: impl Into<String>) -> Self {
+        self.config_dir = Some(dir.into());
+        self
+    }
+
+    /// Set the per-session working directory. Tool operations resolve
+    /// relative paths against this directory.
+    pub fn with_working_directory(mut self, dir: impl Into<String>) -> Self {
+        self.working_directory = Some(dir.into());
+        self
+    }
+
+    /// Set the per-session GitHub token. Distinct from
+    /// [`ClientOptions::github_token`](crate::ClientOptions::github_token);
+    /// this token determines the GitHub identity used for content exclusion,
+    /// model routing, and quota checks for this session only.
+    pub fn with_github_token(mut self, token: impl Into<String>) -> Self {
+        self.github_token = Some(token.into());
+        self
+    }
+
+    /// Forward sub-agent streaming events to this connection. Defaults
+    /// to true on the CLI when unset.
+    pub fn with_include_sub_agent_streaming_events(mut self, include: bool) -> Self {
+        self.include_sub_agent_streaming_events = Some(include);
+        self
+    }
+}
+
+/// Configuration for resuming an existing session via the `session.resume` RPC.
+///
+/// See [`SessionConfig`] for the construction patterns (chained `with_*`
+/// builder vs. direct field assignment for `Option` pass-through) and
+/// the note on snake_case vs. camelCase field naming.
+#[derive(Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+#[non_exhaustive]
+pub struct ResumeSessionConfig {
+    /// ID of the session to resume.
+    pub session_id: SessionId,
+    /// Application name sent as User-Agent context.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub client_name: Option<String>,
+    /// Enable streaming token deltas.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub streaming: Option<bool>,
+    /// Re-supply the system message so the agent retains workspace context
+    /// across CLI process restarts.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_message: Option<SystemMessageConfig>,
+    /// Client-defined tools to re-supply on resume.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<Tool>>,
+    /// Blocklist of built-in tool names.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub excluded_tools: Option<Vec<String>>,
+    /// Re-supply MCP servers so they remain available after app restart.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub mcp_servers: Option<HashMap<String, McpServerConfig>>,
+    /// How the CLI interprets env values in MCP configs.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub env_value_mode: Option<String>,
+    /// Enable config discovery on resume.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub enable_config_discovery: Option<bool>,
+    /// Enable the ask_user tool.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub request_user_input: Option<bool>,
+    /// Enable permission request RPCs.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub request_permission: Option<bool>,
+    /// Enable exit-plan-mode request RPCs.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub request_exit_plan_mode: Option<bool>,
+    /// Enable auto-mode-switch request RPCs on resume. Defaults to
+    /// `Some(true)` via [`ResumeSessionConfig::new`]. See
+    /// [`SessionConfig::request_auto_mode_switch`] for details.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub request_auto_mode_switch: Option<bool>,
+    /// Advertise elicitation provider capability on resume.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub request_elicitation: Option<bool>,
+    /// Skill directory paths passed through to the GitHub Copilot CLI on resume.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub skill_directories: Option<Vec<String>>,
+    /// Enable session hooks on resume.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub hooks: Option<bool>,
+    /// Custom agents to re-supply on resume.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub custom_agents: Option<Vec<CustomAgentConfig>>,
+    /// Configures the built-in default agent on resume.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub default_agent: Option<DefaultAgentConfig>,
+    /// Name of the custom agent to activate.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub agent: Option<String>,
+    /// Re-supply infinite session configuration on resume.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub infinite_sessions: Option<InfiniteSessionConfig>,
+    /// Re-supply BYOK provider configuration on resume.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub provider: Option<ProviderConfig>,
+    /// Per-property model capability overrides on resume.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub model_capabilities: Option<crate::generated::api_types::ModelCapabilitiesOverride>,
+    /// Override the default configuration directory location on resume.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub config_dir: Option<String>,
+    /// Per-session working directory on resume.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub working_directory: Option<String>,
+    /// Per-session GitHub token on resume. See
+    /// [`SessionConfig::github_token`].
+    #[serde(rename = "gitHubToken", skip_serializing_if = "Option::is_none")]
+    pub github_token: Option<String>,
+    /// Forward sub-agent streaming events to this connection on resume.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub include_sub_agent_streaming_events: Option<bool>,
+    /// Slash commands registered for this session on resume. See
+    /// [`SessionConfig::commands`] — commands are not persisted server-side,
+    /// so the resume payload re-supplies the registration.
+    #[serde(skip_serializing_if = "Option::is_none", skip_deserializing)]
+    pub commands: Option<Vec<CommandDefinition>>,
+    /// Custom session filesystem provider. Required on resume when the
+    /// [`Client`](crate::Client) was started with
+    /// [`ClientOptions::session_fs`](crate::ClientOptions::session_fs).
+    /// See [`SessionConfig::session_fs_provider`].
+    #[serde(skip)]
+    pub session_fs_provider: Option<Arc<dyn SessionFsProvider>>,
+    /// Force-fail resume if the session does not exist on disk, instead of
+    /// silently starting a new session.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub disable_resume: Option<bool>,
+    /// Session-level event handler. See [`SessionConfig::handler`].
+    #[serde(skip)]
+    pub handler: Option<Arc<dyn SessionHandler>>,
+    /// Session hook handler. See [`SessionConfig::hooks_handler`].
+    #[serde(skip)]
+    pub hooks_handler: Option<Arc<dyn SessionHooks>>,
+    /// System-message transform. See [`SessionConfig::transform`].
+    #[serde(skip)]
+    pub transform: Option<Arc<dyn SystemMessageTransform>>,
+}
+
+impl std::fmt::Debug for ResumeSessionConfig {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ResumeSessionConfig")
+            .field("session_id", &self.session_id)
+            .field("client_name", &self.client_name)
+            .field("streaming", &self.streaming)
+            .field("system_message", &self.system_message)
+            .field("tools", &self.tools)
+            .field("excluded_tools", &self.excluded_tools)
+            .field("mcp_servers", &self.mcp_servers)
+            .field("env_value_mode", &self.env_value_mode)
+            .field("enable_config_discovery", &self.enable_config_discovery)
+            .field("request_user_input", &self.request_user_input)
+            .field("request_permission", &self.request_permission)
+            .field("request_exit_plan_mode", &self.request_exit_plan_mode)
+            .field("request_auto_mode_switch", &self.request_auto_mode_switch)
+            .field("request_elicitation", &self.request_elicitation)
+            .field("skill_directories", &self.skill_directories)
+            .field("hooks", &self.hooks)
+            .field("custom_agents", &self.custom_agents)
+            .field("default_agent", &self.default_agent)
+            .field("agent", &self.agent)
+            .field("infinite_sessions", &self.infinite_sessions)
+            .field("provider", &self.provider)
+            .field("model_capabilities", &self.model_capabilities)
+            .field("config_dir", &self.config_dir)
+            .field("working_directory", &self.working_directory)
+            .field(
+                "github_token",
+                &self.github_token.as_ref().map(|_| "<redacted>"),
+            )
+            .field(
+                "include_sub_agent_streaming_events",
+                &self.include_sub_agent_streaming_events,
+            )
+            .field("commands", &self.commands)
+            .field(
+                "session_fs_provider",
+                &self.session_fs_provider.as_ref().map(|_| "<set>"),
+            )
+            .field("disable_resume", &self.disable_resume)
+            .field("handler", &self.handler.as_ref().map(|_| "<set>"))
+            .field(
+                "hooks_handler",
+                &self.hooks_handler.as_ref().map(|_| "<set>"),
+            )
+            .field("transform", &self.transform.as_ref().map(|_| "<set>"))
+            .finish()
+    }
+}
+
+impl ResumeSessionConfig {
+    /// Construct a `ResumeSessionConfig` with the given session ID. The
+    /// request-capability flags default to `Some(true)` (matching
+    /// [`SessionConfig::default`]); all other fields start unset. Combine
+    /// with `.with_*` builders or struct update syntax
+    /// (`..ResumeSessionConfig::new(id)`) to populate the fields you need.
+    pub fn new(session_id: SessionId) -> Self {
+        Self {
+            session_id,
+            client_name: None,
+            streaming: None,
+            system_message: None,
+            tools: None,
+            excluded_tools: None,
+            mcp_servers: None,
+            env_value_mode: None,
+            enable_config_discovery: None,
+            request_user_input: Some(true),
+            request_permission: Some(true),
+            request_exit_plan_mode: Some(true),
+            request_auto_mode_switch: Some(true),
+            request_elicitation: Some(true),
+            skill_directories: None,
+            hooks: None,
+            custom_agents: None,
+            default_agent: None,
+            agent: None,
+            infinite_sessions: None,
+            provider: None,
+            model_capabilities: None,
+            config_dir: None,
+            working_directory: None,
+            github_token: None,
+            include_sub_agent_streaming_events: None,
+            commands: None,
+            session_fs_provider: None,
+            disable_resume: None,
+            handler: None,
+            hooks_handler: None,
+            transform: None,
+        }
+    }
+
+    /// Install a custom [`SessionHandler`] for this session.
+    pub fn with_handler(mut self, handler: Arc<dyn SessionHandler>) -> Self {
+        self.handler = Some(handler);
+        self
+    }
+
+    /// Install a [`SessionHooks`] handler. Automatically enables the
+    /// wire-level `hooks` flag on session resumption.
+    pub fn with_hooks(mut self, hooks: Arc<dyn SessionHooks>) -> Self {
+        self.hooks_handler = Some(hooks);
+        self
+    }
+
+    /// Install a [`SystemMessageTransform`].
+    pub fn with_transform(mut self, transform: Arc<dyn SystemMessageTransform>) -> Self {
+        self.transform = Some(transform);
+        self
+    }
+
+    /// Register slash commands for the resumed session. See
+    /// [`SessionConfig::with_commands`] — commands are not persisted
+    /// server-side, so the resume payload re-supplies the registration.
+    pub fn with_commands(mut self, commands: Vec<CommandDefinition>) -> Self {
+        self.commands = Some(commands);
+        self
+    }
+
+    /// Install a [`SessionFsProvider`] backing the resumed session's
+    /// filesystem. See [`SessionConfig::with_session_fs_provider`].
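The two construction styles mentioned for `new` — chained `.with_*` builders versus struct update syntax — can be sketched with a self-contained toy (`MiniConfig` and its two fields are illustrative only, not SDK types):

```rust
// Consuming-`self` builder over Option fields: each with_* sets a field
// and returns self, so calls chain and unset fields stay None.
#[derive(Clone, Debug, Default, PartialEq)]
struct MiniConfig {
    model: Option<String>,
    streaming: Option<bool>,
}

impl MiniConfig {
    fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }
    fn with_streaming(mut self, streaming: bool) -> Self {
        self.streaming = Some(streaming);
        self
    }
}

fn main() {
    // Chained builder form.
    let a = MiniConfig::default()
        .with_model("claude-sonnet-4")
        .with_streaming(true);

    // Struct update form: set one field directly, take the rest from a base.
    let b = MiniConfig {
        streaming: Some(true),
        ..MiniConfig::default().with_model("claude-sonnet-4")
    };

    // Both forms produce the same configuration.
    assert_eq!(a, b);
    assert_eq!(a.model.as_deref(), Some("claude-sonnet-4"));
}
```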
+    pub fn with_session_fs_provider(mut self, provider: Arc<dyn SessionFsProvider>) -> Self {
+        self.session_fs_provider = Some(provider);
+        self
+    }
+
+    /// Wrap the configured handler so every permission request is
+    /// auto-approved. See
+    /// [`SessionConfig::approve_all_permissions`] for semantics.
+    pub fn approve_all_permissions(mut self) -> Self {
+        let inner = self
+            .handler
+            .take()
+            .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler));
+        self.handler = Some(crate::permission::approve_all(inner));
+        self
+    }
+
+    /// Wrap the configured handler so every permission request is
+    /// auto-denied. See
+    /// [`SessionConfig::deny_all_permissions`] for semantics.
+    pub fn deny_all_permissions(mut self) -> Self {
+        let inner = self
+            .handler
+            .take()
+            .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler));
+        self.handler = Some(crate::permission::deny_all(inner));
+        self
+    }
+
+    /// Wrap the configured handler with a predicate-based permission policy.
+    /// See [`SessionConfig::approve_permissions_if`] for semantics.
+    pub fn approve_permissions_if<F>(mut self, predicate: F) -> Self
+    where
+        F: Fn(&crate::types::PermissionRequestData) -> bool + Send + Sync + 'static,
+    {
+        let inner = self
+            .handler
+            .take()
+            .unwrap_or_else(|| Arc::new(crate::handler::DenyAllHandler));
+        self.handler = Some(crate::permission::approve_if(inner, predicate));
+        self
+    }
+
+    /// Set the application name sent as `User-Agent` context.
+    pub fn with_client_name(mut self, name: impl Into<String>) -> Self {
+        self.client_name = Some(name.into());
+        self
+    }
+
+    /// Enable streaming token deltas via `assistant.message_delta` events.
+    pub fn with_streaming(mut self, streaming: bool) -> Self {
+        self.streaming = Some(streaming);
+        self
+    }
+
+    /// Re-supply the system message so the agent retains workspace context
+    /// across CLI process restarts.
+    pub fn with_system_message(mut self, system_message: SystemMessageConfig) -> Self {
+        self.system_message = Some(system_message);
+        self
+    }
+
+    /// Re-supply client-defined tools on resume.
+    pub fn with_tools<I: IntoIterator<Item = Tool>>(mut self, tools: I) -> Self {
+        self.tools = Some(tools.into_iter().collect());
+        self
+    }
+
+    /// Set the blocklist of built-in tool names the agent must not use.
+    pub fn with_excluded_tools<I, S>(mut self, tools: I) -> Self
+    where
+        I: IntoIterator<Item = S>,
+        S: Into<String>,
+    {
+        self.excluded_tools = Some(tools.into_iter().map(Into::into).collect());
+        self
+    }
+
+    /// Re-supply MCP server configurations on resume.
+    pub fn with_mcp_servers(mut self, servers: HashMap<String, McpServerConfig>) -> Self {
+        self.mcp_servers = Some(servers);
+        self
+    }
+
+    /// Set how the CLI interprets env values in MCP configs (`"direct"` /
+    /// `"indirect"`).
+    pub fn with_env_value_mode(mut self, mode: impl Into<String>) -> Self {
+        self.env_value_mode = Some(mode.into());
+        self
+    }
+
+    /// Enable or disable CLI config discovery on resume.
+    pub fn with_enable_config_discovery(mut self, enable: bool) -> Self {
+        self.enable_config_discovery = Some(enable);
+        self
+    }
+
+    /// Enable the `ask_user` tool. Defaults to `Some(true)` via [`Self::new`].
+    pub fn with_request_user_input(mut self, enable: bool) -> Self {
+        self.request_user_input = Some(enable);
+        self
+    }
+
+    /// Enable `permission.request` JSON-RPC calls. Defaults to `Some(true)`.
+    pub fn with_request_permission(mut self, enable: bool) -> Self {
+        self.request_permission = Some(enable);
+        self
+    }
+
+    /// Enable `exitPlanMode.request` JSON-RPC calls. Defaults to `Some(true)`.
+    pub fn with_request_exit_plan_mode(mut self, enable: bool) -> Self {
+        self.request_exit_plan_mode = Some(enable);
+        self
+    }
+
+    /// Enable `autoModeSwitch.request` JSON-RPC calls. Defaults to `Some(true)`.
+    pub fn with_request_auto_mode_switch(mut self, enable: bool) -> Self {
+        self.request_auto_mode_switch = Some(enable);
+        self
+    }
+
+    /// Advertise elicitation provider capability on resume. Defaults to `Some(true)`.
+    pub fn with_request_elicitation(mut self, enable: bool) -> Self {
+        self.request_elicitation = Some(enable);
+        self
+    }
+
+    /// Set skill directory paths passed through to the CLI on resume.
+    pub fn with_skill_directories<I, P>(mut self, paths: I) -> Self
+    where
+        I: IntoIterator<Item = P>,
+        P: Into<String>,
+    {
+        self.skill_directories = Some(paths.into_iter().map(Into::into).collect());
+        self
+    }
+
+    /// Re-supply custom agents on resume.
+    pub fn with_custom_agents<I: IntoIterator<Item = CustomAgentConfig>>(
+        mut self,
+        agents: I,
+    ) -> Self {
+        self.custom_agents = Some(agents.into_iter().collect());
+        self
+    }
+
+    /// Configure the built-in default agent on resume.
+    pub fn with_default_agent(mut self, agent: DefaultAgentConfig) -> Self {
+        self.default_agent = Some(agent);
+        self
+    }
+
+    /// Activate a named custom agent on resume.
+    pub fn with_agent(mut self, name: impl Into<String>) -> Self {
+        self.agent = Some(name.into());
+        self
+    }
+
+    /// Re-supply infinite session configuration on resume.
+    pub fn with_infinite_sessions(mut self, config: InfiniteSessionConfig) -> Self {
+        self.infinite_sessions = Some(config);
+        self
+    }
+
+    /// Re-supply BYOK provider configuration on resume.
+    pub fn with_provider(mut self, provider: ProviderConfig) -> Self {
+        self.provider = Some(provider);
+        self
+    }
+
+    /// Set per-property model capability overrides on resume.
+    pub fn with_model_capabilities(
+        mut self,
+        capabilities: crate::generated::api_types::ModelCapabilitiesOverride,
+    ) -> Self {
+        self.model_capabilities = Some(capabilities);
+        self
+    }
+
+    /// Override the default configuration directory location on resume.
+    pub fn with_config_dir(mut self, dir: impl Into<String>) -> Self {
+        self.config_dir = Some(dir.into());
+        self
+    }
+
+    /// Set the per-session working directory on resume.
+    pub fn with_working_directory(mut self, dir: impl Into<String>) -> Self {
+        self.working_directory = Some(dir.into());
+        self
+    }
+
+    /// Set the per-session GitHub token on resume. See
+    /// [`SessionConfig::github_token`] for distinction from the
+    /// client-level token.
+    pub fn with_github_token(mut self, token: impl Into<String>) -> Self {
+        self.github_token = Some(token.into());
+        self
+    }
+
+    /// Forward sub-agent streaming events to this connection on resume.
+    pub fn with_include_sub_agent_streaming_events(mut self, include: bool) -> Self {
+        self.include_sub_agent_streaming_events = Some(include);
+        self
+    }
+
+    /// Force-fail resume if the session does not exist on disk, instead
+    /// of silently starting a new session.
+    pub fn with_disable_resume(mut self, disable: bool) -> Self {
+        self.disable_resume = Some(disable);
+        self
+    }
+}
+
+/// Controls how the system message is constructed.
+///
+/// Use `mode: "append"` (default) to add content after the built-in system
+/// message, `"replace"` to substitute it entirely, or `"customize"` for
+/// section-level overrides.
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+#[non_exhaustive]
+pub struct SystemMessageConfig {
+    /// How content is applied: `"append"` (default), `"replace"`, or `"customize"`.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub mode: Option<String>,
+    /// Content string to append or replace.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<String>,
+    /// Section-level overrides (used with `mode: "customize"`).
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub sections: Option<HashMap<String, SectionOverride>>,
+}
+
+impl SystemMessageConfig {
+    /// Construct an empty [`SystemMessageConfig`]; all fields default to
+    /// unset.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Set the application mode: `"append"` (default), `"replace"`, or
+    /// `"customize"`.
+    pub fn with_mode(mut self, mode: impl Into<String>) -> Self {
+        self.mode = Some(mode.into());
+        self
+    }
+
+    /// Set the system message content (used by `"append"` and `"replace"`
+    /// modes).
+    pub fn with_content(mut self, content: impl Into<String>) -> Self {
+        self.content = Some(content.into());
+        self
+    }
+
+    /// Set the section-level overrides (used with `mode: "customize"`).
+    pub fn with_sections(mut self, sections: HashMap<String, SectionOverride>) -> Self {
+        self.sections = Some(sections);
+        self
+    }
+}
+
+/// An override operation for a single system prompt section.
+///
+/// Used within [`SystemMessageConfig::sections`] when `mode` is `"customize"`.
+/// The `action` field determines the operation: `"replace"`, `"remove"`,
+/// `"append"`, `"prepend"`, or `"transform"`.
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SectionOverride {
+    /// Override action: `"replace"`, `"remove"`, `"append"`, `"prepend"`, or `"transform"`.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub action: Option<String>,
+    /// Content for the override operation.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<String>,
+}
+
+/// Response from `session.create`.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CreateSessionResult {
+    /// The CLI-assigned session ID.
+    pub session_id: SessionId,
+    /// Workspace directory for the session (infinite sessions).
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub workspace_path: Option<PathBuf>,
+    /// Remote session URL, if the session is running remotely.
+    #[serde(default, alias = "remote_url")]
+    pub remote_url: Option<String>,
+    /// Capabilities negotiated with the CLI for this session.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub capabilities: Option<SessionCapabilities>,
+}
+
+/// Parameters for the `session.sendTelemetry` RPC.
+#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionTelemetryEvent {
+    /// Telemetry event kind (for example, `"session_shutdown"`).
+    pub kind: String,
+    /// Non-restricted string properties to include with the telemetry event.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub properties: Option<HashMap<String, String>>,
+    /// Restricted string properties that may contain sensitive data.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub restricted_properties: Option<HashMap<String, String>>,
+    /// Numeric metrics to include with the telemetry event.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub metrics: Option<HashMap<String, f64>>,
+}
+
+/// Severity level for [`Session::log`](crate::session::Session::log) messages.
+#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum LogLevel {
+    /// Informational message (default).
+    #[default]
+    Info,
+    /// Warning message.
+    Warning,
+    /// Error message.
+    Error,
+}
+
+/// Options for [`Session::log`](crate::session::Session::log).
+///
+/// Pass `None` to `log` for defaults (info level, persisted to the session
+/// event log on disk).
+#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct LogOptions {
+    /// Log severity. `None` lets the server pick (defaults to `info`).
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub level: Option<LogLevel>,
+    /// When `Some(true)`, the message is transient and not persisted to the
+    /// session event log on disk. `None` lets the server pick.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub ephemeral: Option<bool>,
+}
+
+impl LogOptions {
+    /// Set [`level`](Self::level).
+    pub fn with_level(mut self, level: LogLevel) -> Self {
+        self.level = Some(level);
+        self
+    }
+
+    /// Set [`ephemeral`](Self::ephemeral).
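+    ///
+    /// Illustrative sketch of the builder chain above (assumes only the
+    /// `LogOptions` API defined in this module):
+    ///
+    /// ```no_run
+    /// # use github_copilot_sdk::types::{LogLevel, LogOptions};
+    /// // A transient warning that is not persisted to the session log.
+    /// let opts = LogOptions::default()
+    ///     .with_level(LogLevel::Warning)
+    ///     .with_ephemeral(true);
+    /// # let _ = opts;
+    /// ```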
+    pub fn with_ephemeral(mut self, ephemeral: bool) -> Self {
+        self.ephemeral = Some(ephemeral);
+        self
+    }
+}
+
+/// Options for [`Session::set_model`](crate::session::Session::set_model).
+///
+/// Pass `None` to `set_model` to switch model without any overrides.
+#[derive(Debug, Clone, Default)]
+pub struct SetModelOptions {
+    /// Reasoning effort for the new model (e.g. `"low"`, `"medium"`,
+    /// `"high"`, `"xhigh"`).
+    pub reasoning_effort: Option<String>,
+    /// Override individual model capabilities resolved by the runtime. Only
+    /// fields set on the override are applied; the rest fall back to the
+    /// runtime-resolved values for the model.
+    pub model_capabilities: Option<crate::generated::api_types::ModelCapabilitiesOverride>,
+}
+
+impl SetModelOptions {
+    /// Set [`reasoning_effort`](Self::reasoning_effort).
+    pub fn with_reasoning_effort(mut self, effort: impl Into<String>) -> Self {
+        self.reasoning_effort = Some(effort.into());
+        self
+    }
+
+    /// Set [`model_capabilities`](Self::model_capabilities).
+    pub fn with_model_capabilities(
+        mut self,
+        caps: crate::generated::api_types::ModelCapabilitiesOverride,
+    ) -> Self {
+        self.model_capabilities = Some(caps);
+        self
+    }
+}
+
+/// Response from the top-level `ping` RPC.
+///
+/// The `protocol_version` field is the most commonly-inspected piece —
+/// see [`Client::verify_protocol_version`].
+///
+/// [`Client::verify_protocol_version`]: crate::Client::verify_protocol_version
+#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PingResponse {
+    /// The message echoed back by the CLI.
+    #[serde(default)]
+    pub message: String,
+    /// Server-side timestamp (Unix epoch milliseconds).
+    #[serde(default)]
+    pub timestamp: i64,
+    /// The protocol version negotiated by the CLI, if reported.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub protocol_version: Option<u32>,
+}
+
+/// Parameters for the top-level `sendTelemetry` RPC.
+#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ServerTelemetryEvent {
+    /// Telemetry event kind (for example, `"app.launched"`).
+    pub kind: String,
+    /// SDK client name. Non-allowlisted values are hashed in telemetry.
+    pub client_name: String,
+    /// Non-restricted string properties to include with the telemetry event.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub properties: Option<HashMap<String, String>>,
+    /// Restricted string properties that may contain sensitive data.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub restricted_properties: Option<HashMap<String, String>>,
+    /// Numeric metrics to include with the telemetry event.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub metrics: Option<HashMap<String, f64>>,
+}
+
+/// Line range for file attachments.
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AttachmentLineRange {
+    /// First line (1-based).
+    pub start: u32,
+    /// Last line (inclusive).
+    pub end: u32,
+}
+
+/// Cursor position within a file selection.
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AttachmentSelectionPosition {
+    /// Line number (0-based).
+    pub line: u32,
+    /// Character offset (0-based).
+    pub character: u32,
+}
+
+/// Range of selected text within a file.
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct AttachmentSelectionRange {
+    /// Start position.
+    pub start: AttachmentSelectionPosition,
+    /// End position.
+    pub end: AttachmentSelectionPosition,
+}
+
+/// Type of GitHub reference attachment.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+#[non_exhaustive]
+pub enum GitHubReferenceType {
+    /// GitHub issue.
+    Issue,
+    /// GitHub pull request.
+    Pr,
+    /// GitHub discussion.
+    Discussion,
+}
+
+/// An attachment included with a user message.
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(
+    tag = "type",
+    rename_all = "camelCase",
+    rename_all_fields = "camelCase"
+)]
+#[non_exhaustive]
+pub enum Attachment {
+    /// A file path, optionally with a line range.
+    File {
+        /// Absolute path to the file.
+        path: PathBuf,
+        /// Label shown in the UI.
+        #[serde(skip_serializing_if = "Option::is_none")]
+        display_name: Option<String>,
+        /// Optional line range to focus on.
+        #[serde(skip_serializing_if = "Option::is_none")]
+        line_range: Option<AttachmentLineRange>,
+    },
+    /// A directory path.
+    Directory {
+        /// Absolute path to the directory.
+        path: PathBuf,
+        /// Label shown in the UI.
+        #[serde(skip_serializing_if = "Option::is_none")]
+        display_name: Option<String>,
+    },
+    /// A text selection within a file.
+    Selection {
+        /// Path to the file containing the selection.
+        file_path: PathBuf,
+        /// The selected text content.
+        text: String,
+        /// Label shown in the UI.
+        #[serde(skip_serializing_if = "Option::is_none")]
+        display_name: Option<String>,
+        /// Character range of the selection.
+        selection: AttachmentSelectionRange,
+    },
+    /// Raw binary data (e.g. an image).
+    Blob {
+        /// Base64-encoded data.
+        data: String,
+        /// MIME type of the data.
+        mime_type: String,
+        /// Label shown in the UI.
+        #[serde(skip_serializing_if = "Option::is_none")]
+        display_name: Option<String>,
+    },
+    /// A reference to a GitHub issue, PR, or discussion.
+    #[serde(rename = "github_reference")]
+    GitHubReference {
+        /// Issue/PR/discussion number.
+        number: u64,
+        /// Title of the referenced item.
+        title: String,
+        /// Kind of reference.
+        reference_type: GitHubReferenceType,
+        /// Current state (e.g. "open", "closed").
+        state: String,
+        /// URL to the referenced item.
+        url: String,
+    },
+}
+
+impl Attachment {
+    /// Returns the display name, if set.
+    pub fn display_name(&self) -> Option<&str> {
+        match self {
+            Self::File { display_name, .. }
+            | Self::Directory { display_name, .. }
+            | Self::Selection { display_name, .. }
+            | Self::Blob { display_name, .. } => display_name.as_deref(),
+            Self::GitHubReference { .. } => None,
+        }
+    }
+
+    /// Returns a human-readable label, deriving one from the path if needed.
+    pub fn label(&self) -> Option<String> {
+        if let Some(display_name) = self
+            .display_name()
+            .map(str::trim)
+            .filter(|name| !name.is_empty())
+        {
+            return Some(display_name.to_string());
+        }
+
+        match self {
+            Self::GitHubReference { number, title, .. } => Some(if title.trim().is_empty() {
+                format!("#{}", number)
+            } else {
+                title.trim().to_string()
+            }),
+            _ => self.derived_display_name(),
+        }
+    }
+
+    /// Ensure `display_name` is populated when the variant supports one.
+    pub fn ensure_display_name(&mut self) {
+        if self
+            .display_name()
+            .map(str::trim)
+            .is_some_and(|name| !name.is_empty())
+        {
+            return;
+        }
+
+        let Some(derived_display_name) = self.derived_display_name() else {
+            return;
+        };
+
+        match self {
+            Self::File { display_name, .. }
+            | Self::Directory { display_name, .. }
+            | Self::Selection { display_name, .. }
+            | Self::Blob { display_name, .. } => *display_name = Some(derived_display_name),
+            Self::GitHubReference { .. } => {}
+        }
+    }
+
+    fn derived_display_name(&self) -> Option<String> {
+        match self {
+            Self::File { path, .. } | Self::Directory { path, .. } => {
+                Some(attachment_name_from_path(path))
+            }
+            Self::Selection { file_path, .. } => Some(attachment_name_from_path(file_path)),
+            Self::Blob { .. } => Some("attachment".to_string()),
+            Self::GitHubReference { .. } => None,
+        }
+    }
+}
+
+fn attachment_name_from_path(path: &Path) -> String {
+    path.file_name()
+        .map(|name| name.to_string_lossy().into_owned())
+        .filter(|name| !name.is_empty())
+        .unwrap_or_else(|| {
+            let full = path.to_string_lossy();
+            if full.is_empty() {
+                "attachment".to_string()
+            } else {
+                full.into_owned()
+            }
+        })
+}
+
+/// Normalize a list of attachments so every entry has a `display_name`.
+pub fn ensure_attachment_display_names(attachments: &mut [Attachment]) {
+    for attachment in attachments {
+        attachment.ensure_display_name();
+    }
+}
+
+/// Message delivery mode for [`MessageOptions::mode`].
+///
+/// Controls how a prompt is delivered relative to in-flight session work.
+/// Wire values: `"enqueue"` and `"immediate"`.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+#[non_exhaustive]
+pub enum DeliveryMode {
+    /// Queue the prompt behind any in-flight work (default).
+    Enqueue,
+    /// Interrupt the session and run the prompt immediately.
+    Immediate,
+}
+
+/// Options for sending a user message to the agent.
+///
+/// Used by both [`Session::send`](crate::session::Session::send) and
+/// [`Session::send_and_wait`](crate::session::Session::send_and_wait); the
+/// `wait_timeout` field is honored only by `send_and_wait` and is ignored by
+/// `send`.
+///
+/// `MessageOptions` is `#[non_exhaustive]` and constructed via [`MessageOptions::new`]
+/// plus the `with_*` chain so future fields can land without breaking callers.
+/// For the trivial case, both `&str` and `String` implement `Into<MessageOptions>`,
+/// so:
+///
+/// ```no_run
+/// # use github_copilot_sdk::session::Session;
+/// # async fn run(session: Session) -> Result<(), github_copilot_sdk::Error> {
+/// session.send("hello").await?;
+/// # Ok(()) }
+/// ```
+///
+/// is equivalent to:
+///
+/// ```no_run
+/// # use github_copilot_sdk::session::Session;
+/// # use github_copilot_sdk::types::MessageOptions;
+/// # async fn run(session: Session) -> Result<(), github_copilot_sdk::Error> {
+/// session.send(MessageOptions::new("hello")).await?;
+/// # Ok(()) }
+/// ```
+#[derive(Debug, Clone)]
+#[non_exhaustive]
+pub struct MessageOptions {
+    /// The user prompt to send.
+    pub prompt: String,
+    /// Optional message delivery mode for this turn.
+    ///
+    /// Controls whether the prompt is queued behind in-flight work
+    /// ([`DeliveryMode::Enqueue`], default) or interrupts the session and
+    /// runs immediately ([`DeliveryMode::Immediate`]).
+    pub mode: Option<DeliveryMode>,
+    /// Optional attachments to include with the message.
+    pub attachments: Option<Vec<Attachment>>,
+    /// Maximum time to wait for the session to go idle. Honored only by
+    /// `send_and_wait`. Defaults to 60 seconds when unset.
+    pub wait_timeout: Option<Duration>,
+    /// Custom HTTP headers to include in outbound model requests for this
+    /// turn. When `None` or empty, no `requestHeaders` field is sent on
+    /// the wire.
+    pub request_headers: Option<HashMap<String, String>>,
+    /// W3C Trace Context `traceparent` header for this turn.
+    ///
+    /// Per-turn override that takes precedence over
+    /// [`ClientOptions::on_get_trace_context`](crate::ClientOptions::on_get_trace_context).
+    /// When `None`, the SDK falls back to the provider (if configured)
+    /// before omitting the field.
+    pub traceparent: Option<String>,
+    /// W3C Trace Context `tracestate` header for this turn.
+    ///
+    /// Per-turn override paired with [`traceparent`](Self::traceparent).
+    pub tracestate: Option<String>,
+}
+
+impl MessageOptions {
+    /// Build a new `MessageOptions` with just a prompt.
+    pub fn new(prompt: impl Into<String>) -> Self {
+        Self {
+            prompt: prompt.into(),
+            mode: None,
+            attachments: None,
+            wait_timeout: None,
+            request_headers: None,
+            traceparent: None,
+            tracestate: None,
+        }
+    }
+
+    /// Set the message delivery mode for this turn.
+    ///
+    /// Pass [`DeliveryMode::Immediate`] to interrupt the session and run
+    /// the prompt now; the default ([`DeliveryMode::Enqueue`]) queues the
+    /// prompt behind in-flight work.
+    pub fn with_mode(mut self, mode: DeliveryMode) -> Self {
+        self.mode = Some(mode);
+        self
+    }
+
+    /// Attach files / selections / blobs to the message.
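+    ///
+    /// Illustrative sketch using only the types defined in this module:
+    ///
+    /// ```no_run
+    /// # use github_copilot_sdk::types::{Attachment, MessageOptions};
+    /// // Attach a file by path; the path shown here is hypothetical.
+    /// let opts = MessageOptions::new("summarize this file")
+    ///     .with_attachments(vec![Attachment::File {
+    ///         path: "/tmp/notes.md".into(),
+    ///         display_name: None,
+    ///         line_range: None,
+    ///     }]);
+    /// # let _ = opts;
+    /// ```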
+    pub fn with_attachments(mut self, attachments: Vec<Attachment>) -> Self {
+        self.attachments = Some(attachments);
+        self
+    }
+
+    /// Override the default 60-second wait timeout for `send_and_wait`.
+    pub fn with_wait_timeout(mut self, timeout: Duration) -> Self {
+        self.wait_timeout = Some(timeout);
+        self
+    }
+
+    /// Set custom HTTP headers for outbound model requests for this turn.
+    pub fn with_request_headers(mut self, headers: HashMap<String, String>) -> Self {
+        self.request_headers = Some(headers);
+        self
+    }
+
+    /// Set both `traceparent` and `tracestate` from a [`TraceContext`].
+    /// Either field may remain `None` if the [`TraceContext`] has no value
+    /// for it. Use [`with_traceparent`](Self::with_traceparent) or
+    /// [`with_tracestate`](Self::with_tracestate) to set them individually.
+    pub fn with_trace_context(mut self, ctx: TraceContext) -> Self {
+        self.traceparent = ctx.traceparent;
+        self.tracestate = ctx.tracestate;
+        self
+    }
+
+    /// Set the W3C `traceparent` header for this turn.
+    pub fn with_traceparent(mut self, traceparent: impl Into<String>) -> Self {
+        self.traceparent = Some(traceparent.into());
+        self
+    }
+
+    /// Set the W3C `tracestate` header for this turn.
+    pub fn with_tracestate(mut self, tracestate: impl Into<String>) -> Self {
+        self.tracestate = Some(tracestate.into());
+        self
+    }
+}
+
+impl From<&str> for MessageOptions {
+    fn from(prompt: &str) -> Self {
+        Self::new(prompt)
+    }
+}
+
+impl From<String> for MessageOptions {
+    fn from(prompt: String) -> Self {
+        Self::new(prompt)
+    }
+}
+
+impl From<&String> for MessageOptions {
+    fn from(prompt: &String) -> Self {
+        Self::new(prompt.clone())
+    }
+}
+
+/// Response from [`Client::get_status`](crate::Client::get_status).
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+#[non_exhaustive]
+pub struct GetStatusResponse {
+    /// Package version (e.g. `"1.0.0"`).
+    pub version: String,
+    /// Protocol version for SDK compatibility.
+    pub protocol_version: u32,
+}
+
+/// Response from [`Client::get_auth_status`](crate::Client::get_auth_status).
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+#[non_exhaustive]
+pub struct GetAuthStatusResponse {
+    /// Whether the user is authenticated.
+    pub is_authenticated: bool,
+    /// Authentication type (e.g. `"user"`, `"env"`, `"gh-cli"`, `"hmac"`,
+    /// `"api-key"`, `"token"`).
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub auth_type: Option<String>,
+    /// GitHub host URL.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub host: Option<String>,
+    /// User login name.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub login: Option<String>,
+    /// Human-readable status message.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub status_message: Option<String>,
+}
+
+/// Wrapper for session event notifications received from the CLI.
+///
+/// The CLI sends these as JSON-RPC notifications on the `session.event` method.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionEventNotification {
+    /// The session this event belongs to.
+    pub session_id: SessionId,
+    /// The event payload.
+    pub event: SessionEvent,
+}
+
+/// A single event in a session's timeline.
+///
+/// Events form a linked chain via `parent_id`. The `event_type` string
+/// identifies the kind (e.g. `"assistant.message_delta"`, `"session.idle"`,
+/// `"tool.execution_start"`). Event-specific payload is in `data` as
+/// untyped JSON.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionEvent {
+    /// Unique event ID (UUID v4).
+    pub id: String,
+    /// ISO 8601 timestamp.
+    pub timestamp: String,
+    /// ID of the preceding event in the chain.
+    pub parent_id: Option<String>,
+    /// When `Some(true)`, the event is transient and not persisted to disk.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub ephemeral: Option<bool>,
+    /// Debug timestamp: when the CLI received this event (ms since epoch).
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub debug_cli_received_at_ms: Option<i64>,
+    /// Debug timestamp: when the event was forwarded over WebSocket.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub debug_ws_forwarded_at_ms: Option<i64>,
+    /// Event type string (e.g. `"assistant.message"`, `"session.idle"`).
+    #[serde(rename = "type")]
+    pub event_type: String,
+    /// Event-specific data. Structure depends on `event_type`.
+    pub data: Value,
+}
+
+impl SessionEvent {
+    /// Parse the string `event_type` into a typed [`SessionEventType`](crate::generated::SessionEventType) enum.
+    ///
+    /// Returns `SessionEventType::Unknown` for unrecognized event types,
+    /// ensuring forward compatibility with newer CLI versions.
+    pub fn parsed_type(&self) -> crate::generated::SessionEventType {
+        use serde::de::IntoDeserializer;
+        let deserializer: serde::de::value::StrDeserializer<'_, serde::de::value::Error> =
+            self.event_type.as_str().into_deserializer();
+        crate::generated::SessionEventType::deserialize(deserializer)
+            .unwrap_or(crate::generated::SessionEventType::Unknown)
+    }
+
+    /// Deserialize the event `data` field into a typed struct.
+    ///
+    /// Returns `None` if deserialization fails (e.g. unknown event type
+    /// or schema mismatch). Prefer typed data accessors for specific
+    /// event types where you need strongly-typed field access.
+    pub fn typed_data<T: serde::de::DeserializeOwned>(&self) -> Option<T> {
+        serde_json::from_value(self.data.clone()).ok()
+    }
+
+    /// Returns `true` for transient `model_call` errors — the CLI agent
+    /// loop continues after them and may succeed on the next turn. These
+    /// should not be treated as session-ending errors.
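+    ///
+    /// Sketch of how an event consumer might use this check (the handler
+    /// shown here is hypothetical):
+    ///
+    /// ```no_run
+    /// # use github_copilot_sdk::types::SessionEvent;
+    /// # fn handle(event: &SessionEvent) {
+    /// if event.event_type == "session.error" && !event.is_transient_error() {
+    ///     // Non-transient error: surface it to the caller instead of retrying.
+    /// }
+    /// # }
+    /// ```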
+    pub fn is_transient_error(&self) -> bool {
+        self.event_type == "session.error"
+            && self.data.get("errorType").and_then(|v| v.as_str()) == Some("model_call")
+    }
+}
+
+/// A request from the CLI to invoke a client-defined tool.
+///
+/// Received as a JSON-RPC request on the `tool.call` method. The client
+/// must respond with a [`ToolResultResponse`].
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+#[non_exhaustive]
+pub struct ToolInvocation {
+    /// Session that owns this tool call.
+    pub session_id: SessionId,
+    /// Unique ID for this tool call, used to correlate the response.
+    pub tool_call_id: String,
+    /// Name of the tool being invoked.
+    pub tool_name: String,
+    /// Tool arguments as JSON.
+    pub arguments: Value,
+    /// W3C Trace Context `traceparent` header propagated from the CLI's
+    /// `execute_tool` span. Pass through to OpenTelemetry-aware code so
+    /// child spans created inside the handler are parented to the CLI
+    /// span. `None` when the CLI has no trace context for this call.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub traceparent: Option<String>,
+    /// W3C Trace Context `tracestate` paired with
+    /// [`traceparent`](Self::traceparent).
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tracestate: Option<String>,
+}
+
+impl ToolInvocation {
+    /// Deserialize this invocation's [`arguments`](Self::arguments) into a
+    /// strongly-typed parameter struct.
+    ///
+    /// Idiomatic way to extract typed parameters when implementing
+    /// [`ToolHandler`](crate::tool::ToolHandler) directly. Equivalent to
+    /// `serde_json::from_value(invocation.arguments.clone())` with the SDK's
+    /// error type.
+    ///
+    /// # Example
+    ///
+    /// ```rust,no_run
+    /// # use github_copilot_sdk::{Error, types::ToolInvocation, ToolResult};
+    /// # use serde::Deserialize;
+    /// # #[derive(Deserialize)] struct MyParams { city: String }
+    /// # async fn example(inv: ToolInvocation) -> Result<ToolResult, Error> {
+    /// let params: MyParams = inv.params()?;
+    /// // …use `inv.session_id` / `inv.tool_call_id` alongside `params`…
+    /// # let _ = params; Ok(ToolResult::Text(String::new()))
+    /// # }
+    /// ```
+    pub fn params<T: serde::de::DeserializeOwned>(&self) -> Result<T, crate::Error> {
+        serde_json::from_value(self.arguments.clone()).map_err(crate::Error::from)
+    }
+
+    /// Returns the propagated [`TraceContext`] for this invocation, or
+    /// [`TraceContext::default()`] when the CLI sent no headers.
+    pub fn trace_context(&self) -> TraceContext {
+        TraceContext {
+            traceparent: self.traceparent.clone(),
+            tracestate: self.tracestate.clone(),
+        }
+    }
+}
+
+/// Expanded tool result with metadata for the LLM and session log.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ToolResultExpanded {
+    /// Result text sent back to the LLM.
+    pub text_result_for_llm: String,
+    /// `"success"` or `"failure"`.
+    pub result_type: String,
+    /// Optional log message for the session timeline.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub session_log: Option<String>,
+    /// Error message, if the tool failed.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub error: Option<String>,
+}
+
+/// Result of a tool invocation — either a plain text string or an expanded result.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(untagged)]
+#[non_exhaustive]
+pub enum ToolResult {
+    /// Simple text result passed directly to the LLM.
+    Text(String),
+    /// Structured result with metadata.
+    Expanded(ToolResultExpanded),
+}
+
+/// JSON-RPC response wrapper for a tool result, sent back to the CLI.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ToolResultResponse {
+    /// The tool result payload.
+    pub result: ToolResult,
+}
+
+/// Metadata for a persisted session, returned by `session.list`.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionMetadata {
+    /// The session's unique identifier.
+    pub session_id: SessionId,
+    /// ISO 8601 timestamp when the session was created.
+    pub start_time: String,
+    /// ISO 8601 timestamp of the last modification.
+    pub modified_time: String,
+    /// Agent-generated session summary.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub summary: Option<String>,
+    /// Whether the session is running remotely.
+    pub is_remote: bool,
+}
+
+/// Response from `session.list`.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ListSessionsResponse {
+    /// The list of session metadata entries.
+    pub sessions: Vec<SessionMetadata>,
+}
+
+/// Filter options for [`Client::list_sessions`](crate::Client::list_sessions).
+///
+/// All fields are optional; unset fields don't constrain the result.
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionListFilter {
+    /// Filter by exact `cwd` match.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub cwd: Option<String>,
+    /// Filter by git root path.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub git_root: Option<String>,
+    /// Filter by repository in `owner/repo` form.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub repository: Option<String>,
+    /// Filter by git branch name.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub branch: Option<String>,
+}
+
+/// Response from `session.getMetadata`.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GetSessionMetadataResponse {
+    /// The session metadata, or `None` if the session was not found.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub session: Option<SessionMetadata>,
+}
+
+/// Response from `session.getLastId`.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GetLastSessionIdResponse {
+    /// The most recently updated session ID, or `None` if no sessions exist.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub session_id: Option<SessionId>,
+}
+
+/// Response from `session.getForeground`.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GetForegroundSessionResponse {
+    /// The current foreground session ID, or `None` if no foreground session.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub session_id: Option<SessionId>,
+}
+
+/// Response from `session.getMessages`.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct GetMessagesResponse {
+    /// Timeline events for the session.
+    pub events: Vec<SessionEvent>,
+}
+
+/// Result of an elicitation (interactive UI form) request.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ElicitationResult {
+    /// User's action: `"accept"`, `"decline"`, or `"cancel"`.
+    pub action: String,
+    /// Form data submitted by the user (present when action is `"accept"`).
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub content: Option<Value>,
+}
+
+/// Elicitation display mode.
+///
+/// New modes may be added by the CLI in future protocol versions; the
+/// `Unknown` variant keeps deserialization from failing on unrecognised
+/// values so the SDK can still surface the request to callers.
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+#[non_exhaustive]
+pub enum ElicitationMode {
+    /// Structured form input rendered by the host.
+    Form,
+    /// Browser redirect to a URL.
+    Url,
+    /// A mode not yet known to this SDK version.
+    #[serde(other)]
+    Unknown,
+}
+
+/// An incoming elicitation request from the CLI (provider side).
+///
+/// Received via `elicitation.requested` session event when the session was
+/// created with `request_elicitation: true`. The provider should render a
+/// form or dialog and return an [`ElicitationResult`].
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ElicitationRequest {
+    /// Message describing what information is needed from the user.
+    pub message: String,
+    /// JSON Schema describing the form fields to present.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub requested_schema: Option<Value>,
+    /// Elicitation display mode.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub mode: Option<ElicitationMode>,
+    /// The source that initiated the request (e.g. MCP server name).
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub elicitation_source: Option<String>,
+    /// URL to open in the user's browser (url mode only).
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub url: Option<String>,
+}
+
+/// Session-level capabilities reported by the CLI after session creation.
+///
+/// Capabilities indicate which features the CLI host supports for this session.
+/// Updated at runtime via `capabilities.changed` events.
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct SessionCapabilities {
+    /// UI capabilities (elicitation support, etc.).
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub ui: Option<UiCapabilities>,
+}
+
+/// UI-specific capabilities for a session.
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UiCapabilities {
+    /// Whether the host supports interactive elicitation dialogs.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub elicitation: Option<bool>,
+}
+
+/// Options for the [`SessionUi::input`](crate::session::SessionUi::input) convenience method.
+#[derive(Debug, Clone, Default)]
+pub struct InputOptions<'a> {
+    /// Title label for the input field.
+    pub title: Option<&'a str>,
+    /// Descriptive text shown below the field.
+    pub description: Option<&'a str>,
+    /// Minimum character length.
+    pub min_length: Option<u32>,
+    /// Maximum character length.
+    pub max_length: Option<u32>,
+    /// Semantic format hint.
+    pub format: Option<InputFormat>,
+    /// Default value pre-populated in the field.
+    pub default: Option<&'a str>,
+}
+
+/// Semantic format hints for text input fields.
+#[derive(Debug, Clone, Copy)]
+#[non_exhaustive]
+pub enum InputFormat {
+    /// Email address.
+    Email,
+    /// URI.
+    Uri,
+    /// Calendar date.
+    Date,
+    /// Date and time.
+    DateTime,
+}
+
+impl InputFormat {
+    /// Returns the JSON Schema format string for this variant.
+    pub fn as_str(&self) -> &'static str {
+        match self {
+            Self::Email => "email",
+            Self::Uri => "uri",
+            Self::Date => "date",
+            Self::DateTime => "date-time",
+        }
+    }
+}
+
+/// Re-exports of generated protocol types that are part of the SDK's
+/// public API surface. The canonical definitions live in
+/// [`crate::generated::api_types`]; they live here so the crate-root
+/// `pub use types::*` surfaces them alongside hand-written SDK types.
+pub use crate::generated::api_types::{
+    Model, ModelBilling, ModelCapabilities, ModelCapabilitiesLimits, ModelCapabilitiesLimitsVision,
+    ModelCapabilitiesSupports, ModelList, ModelPolicy,
+};
+
+/// Permission categories the CLI may request approval for.
+///
+/// Wire values are the lower-kebab strings the CLI sends as the `kind`
+/// discriminator on a permission request. Marked `#[non_exhaustive]`
+/// because the CLI may add new kinds; matches must include a `_` arm.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
+#[serde(rename_all = "kebab-case")]
+#[non_exhaustive]
+pub enum PermissionRequestKind {
+    /// Run a shell command.
+    Shell,
+    /// Write to a file.
+    Write,
+    /// Read a file.
+    Read,
+    /// Open a URL.
+    Url,
+    /// Invoke an MCP server tool.
+    Mcp,
+    /// Invoke a client-defined custom tool.
+    CustomTool,
+    /// Update agent memory.
+    Memory,
+    /// Run a hook callback.
+    Hook,
+    /// Unrecognized kind. The original wire string is available in
+    /// [`PermissionRequestData::extra`] under the `kind` key.
+    #[serde(other)]
+    Unknown,
+}
+
+/// Data sent by the CLI for permission-related events.
+///
+/// Used for both the `permission.request` RPC call (which expects a response)
+/// and `permission.requested` notifications (fire-and-forget). Contains the
+/// full params object. Note that `requestId` is also available as a separate
+/// field on `HandlerEvent::PermissionRequest`.
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct PermissionRequestData {
+    /// The permission category being requested. `None` means the CLI did
+    /// not include a `kind` field. Use this to branch on common cases
+    /// (shell, write, etc.) without parsing [`extra`](Self::extra).
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub kind: Option<PermissionRequestKind>,
+    /// The originating tool-call ID, if this permission request is tied
+    /// to a specific tool invocation.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+    /// The full permission request params from the CLI. The shape varies by
+    /// permission type and CLI version, so we preserve it as `Value`.
+    #[serde(flatten)]
+    pub extra: Value,
+}
+
+/// Data sent by the CLI with an `exitPlanMode.request` RPC call.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ExitPlanModeData {
+    /// Markdown summary of the plan presented to the user.
+    #[serde(default)]
+    pub summary: String,
+    /// Full plan content (e.g. the plan.md body), if available.
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub plan_content: Option<String>,
+    /// Allowed exit actions (e.g. "interactive", "autopilot", "autopilot_fleet").
+    #[serde(default)]
+    pub actions: Vec<String>,
+    /// Which action the CLI recommends; defaults to "autopilot".
+    #[serde(default = "default_recommended_action")]
+    pub recommended_action: String,
+}
+
+fn default_recommended_action() -> String {
+    "autopilot".to_string()
+}
+
+impl Default for ExitPlanModeData {
+    fn default() -> Self {
+        Self {
+            summary: String::new(),
+            plan_content: None,
+            actions: Vec::new(),
+            recommended_action: default_recommended_action(),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::path::PathBuf;
+
+    use serde_json::json;
+
+    use super::{
+        Attachment, AttachmentLineRange, AttachmentSelectionPosition, AttachmentSelectionRange,
+        ConnectionState, CustomAgentConfig, DeliveryMode, GitHubReferenceType,
+        InfiniteSessionConfig, ProviderConfig, ResumeSessionConfig, SessionConfig, SessionId,
+        SystemMessageConfig, Tool, ensure_attachment_display_names,
+    };
+
+    #[test]
+    fn tool_builder_composes() {
+        let tool = Tool::new("greet")
+            .with_description("Say hello")
+            .with_namespaced_name("hello/greet")
+            .with_instructions("Pass the user's name")
+            .with_parameters(json!({
+                "type": "object",
+                "properties": { "name": { "type": "string" } },
+                "required": ["name"]
+            }))
+            .with_overrides_built_in_tool(true)
+            .with_skip_permission(true);
+        assert_eq!(tool.name, "greet");
+        assert_eq!(tool.description, "Say hello");
+        assert_eq!(tool.namespaced_name.as_deref(), Some("hello/greet"));
+        assert_eq!(tool.instructions.as_deref(), Some("Pass the user's name"));
+        assert_eq!(tool.parameters.get("type").unwrap(),
&json!("object")); + assert!(tool.overrides_built_in_tool); + assert!(tool.skip_permission); + } + + #[test] + fn tool_with_parameters_handles_non_object_value() { + let tool = Tool::new("noop").with_parameters(json!(null)); + assert!(tool.parameters.is_empty()); + } + + #[test] + fn session_config_default_enables_permission_flow_flags() { + let cfg = SessionConfig::default(); + assert_eq!(cfg.request_user_input, Some(true)); + assert_eq!(cfg.request_permission, Some(true)); + assert_eq!(cfg.request_exit_plan_mode, Some(true)); + assert_eq!(cfg.request_auto_mode_switch, Some(true)); + assert_eq!(cfg.request_elicitation, Some(true)); + } + + #[test] + fn resume_session_config_new_enables_permission_flow_flags() { + let cfg = ResumeSessionConfig::new(SessionId::from("test-id")); + assert_eq!(cfg.request_user_input, Some(true)); + assert_eq!(cfg.request_permission, Some(true)); + assert_eq!(cfg.request_exit_plan_mode, Some(true)); + assert_eq!(cfg.request_auto_mode_switch, Some(true)); + assert_eq!(cfg.request_elicitation, Some(true)); + } + + #[test] + fn session_config_builder_composes() { + use std::collections::HashMap; + + let cfg = SessionConfig::default() + .with_session_id(SessionId::from("sess-1")) + .with_model("claude-sonnet-4") + .with_client_name("test-app") + .with_reasoning_effort("medium") + .with_streaming(true) + .with_tools([Tool::new("greet")]) + .with_available_tools(["bash", "view"]) + .with_excluded_tools(["dangerous"]) + .with_mcp_servers(HashMap::new()) + .with_env_value_mode("direct") + .with_enable_config_discovery(true) + .with_request_user_input(false) + .with_skill_directories([PathBuf::from("/tmp/skills")]) + .with_disabled_skills(["broken-skill"]) + .with_disabled_mcp_servers(["broken-server"]) + .with_agent("researcher") + .with_config_dir(PathBuf::from("/tmp/config")) + .with_working_directory(PathBuf::from("/tmp/work")) + .with_github_token("ghp_test") + .with_include_sub_agent_streaming_events(false); + + 
assert_eq!(cfg.session_id.as_ref().map(|s| s.as_str()), Some("sess-1")); + assert_eq!(cfg.model.as_deref(), Some("claude-sonnet-4")); + assert_eq!(cfg.client_name.as_deref(), Some("test-app")); + assert_eq!(cfg.reasoning_effort.as_deref(), Some("medium")); + assert_eq!(cfg.streaming, Some(true)); + assert_eq!(cfg.tools.as_ref().map(|t| t.len()), Some(1)); + assert_eq!( + cfg.available_tools.as_deref(), + Some(&["bash".to_string(), "view".to_string()][..]) + ); + assert_eq!( + cfg.excluded_tools.as_deref(), + Some(&["dangerous".to_string()][..]) + ); + assert!(cfg.mcp_servers.is_some()); + assert_eq!(cfg.env_value_mode.as_deref(), Some("direct")); + assert_eq!(cfg.enable_config_discovery, Some(true)); + assert_eq!(cfg.request_user_input, Some(false)); // overrode default + assert_eq!(cfg.request_permission, Some(true)); // default preserved + assert_eq!( + cfg.skill_directories.as_deref(), + Some(&[PathBuf::from("/tmp/skills")][..]) + ); + assert_eq!( + cfg.disabled_skills.as_deref(), + Some(&["broken-skill".to_string()][..]) + ); + assert_eq!(cfg.agent.as_deref(), Some("researcher")); + assert_eq!(cfg.config_dir, Some(PathBuf::from("/tmp/config"))); + assert_eq!(cfg.working_directory, Some(PathBuf::from("/tmp/work"))); + assert_eq!(cfg.github_token.as_deref(), Some("ghp_test")); + assert_eq!(cfg.include_sub_agent_streaming_events, Some(false)); + } + + #[test] + fn resume_session_config_builder_composes() { + use std::collections::HashMap; + + let cfg = ResumeSessionConfig::new(SessionId::from("sess-2")) + .with_client_name("test-app") + .with_streaming(true) + .with_tools([Tool::new("greet")]) + .with_excluded_tools(["dangerous"]) + .with_mcp_servers(HashMap::new()) + .with_env_value_mode("indirect") + .with_enable_config_discovery(true) + .with_request_user_input(false) + .with_skill_directories([PathBuf::from("/tmp/skills")]) + .with_agent("researcher") + .with_config_dir(PathBuf::from("/tmp/config")) + .with_working_directory(PathBuf::from("/tmp/work")) + 
.with_github_token("ghp_test") + .with_include_sub_agent_streaming_events(true) + .with_disable_resume(true); + + assert_eq!(cfg.session_id.as_str(), "sess-2"); + assert_eq!(cfg.client_name.as_deref(), Some("test-app")); + assert_eq!(cfg.streaming, Some(true)); + assert_eq!(cfg.tools.as_ref().map(|t| t.len()), Some(1)); + assert_eq!( + cfg.excluded_tools.as_deref(), + Some(&["dangerous".to_string()][..]) + ); + assert!(cfg.mcp_servers.is_some()); + assert_eq!(cfg.env_value_mode.as_deref(), Some("indirect")); + assert_eq!(cfg.enable_config_discovery, Some(true)); + assert_eq!(cfg.request_user_input, Some(false)); // overrode default + assert_eq!(cfg.request_permission, Some(true)); // default preserved + assert_eq!( + cfg.skill_directories.as_deref(), + Some(&[PathBuf::from("/tmp/skills")][..]) + ); + assert_eq!(cfg.agent.as_deref(), Some("researcher")); + assert_eq!(cfg.config_dir, Some(PathBuf::from("/tmp/config"))); + assert_eq!(cfg.working_directory, Some(PathBuf::from("/tmp/work"))); + assert_eq!(cfg.github_token.as_deref(), Some("ghp_test")); + assert_eq!(cfg.include_sub_agent_streaming_events, Some(true)); + assert_eq!(cfg.disable_resume, Some(true)); + } + + #[test] + fn custom_agent_config_builder_composes() { + use std::collections::HashMap; + + let cfg = CustomAgentConfig::new("researcher", "You are a research assistant.") + .with_display_name("Research Assistant") + .with_description("Investigates technical questions.") + .with_tools(["bash", "view"]) + .with_mcp_servers(HashMap::new()) + .with_infer(true) + .with_skills(["rust-coding-skill"]); + + assert_eq!(cfg.name, "researcher"); + assert_eq!(cfg.prompt, "You are a research assistant."); + assert_eq!(cfg.display_name.as_deref(), Some("Research Assistant")); + assert_eq!( + cfg.description.as_deref(), + Some("Investigates technical questions.") + ); + assert_eq!( + cfg.tools.as_deref(), + Some(&["bash".to_string(), "view".to_string()][..]) + ); + assert!(cfg.mcp_servers.is_some()); + 
assert_eq!(cfg.infer, Some(true)); + assert_eq!( + cfg.skills.as_deref(), + Some(&["rust-coding-skill".to_string()][..]) + ); + } + + #[test] + fn infinite_session_config_builder_composes() { + let cfg = InfiniteSessionConfig::new() + .with_enabled(true) + .with_background_compaction_threshold(0.75) + .with_buffer_exhaustion_threshold(0.92); + + assert_eq!(cfg.enabled, Some(true)); + assert_eq!(cfg.background_compaction_threshold, Some(0.75)); + assert_eq!(cfg.buffer_exhaustion_threshold, Some(0.92)); + } + + #[test] + fn provider_config_builder_composes() { + use std::collections::HashMap; + + let mut headers = HashMap::new(); + headers.insert("X-Custom".to_string(), "value".to_string()); + + let cfg = ProviderConfig::new("https://api.example.com") + .with_provider_type("openai") + .with_wire_api("completions") + .with_api_key("sk-test") + .with_bearer_token("bearer-test") + .with_headers(headers); + + assert_eq!(cfg.base_url, "https://api.example.com"); + assert_eq!(cfg.provider_type.as_deref(), Some("openai")); + assert_eq!(cfg.wire_api.as_deref(), Some("completions")); + assert_eq!(cfg.api_key.as_deref(), Some("sk-test")); + assert_eq!(cfg.bearer_token.as_deref(), Some("bearer-test")); + assert_eq!( + cfg.headers + .as_ref() + .and_then(|h| h.get("X-Custom")) + .map(String::as_str), + Some("value"), + ); + } + + #[test] + fn system_message_config_builder_composes() { + use std::collections::HashMap; + + let cfg = SystemMessageConfig::new() + .with_mode("replace") + .with_content("Custom system message.") + .with_sections(HashMap::new()); + + assert_eq!(cfg.mode.as_deref(), Some("replace")); + assert_eq!(cfg.content.as_deref(), Some("Custom system message.")); + assert!(cfg.sections.is_some()); + } + + #[test] + fn delivery_mode_serializes_to_kebab_case_strings() { + assert_eq!( + serde_json::to_string(&DeliveryMode::Enqueue).unwrap(), + "\"enqueue\"" + ); + assert_eq!( + serde_json::to_string(&DeliveryMode::Immediate).unwrap(), + "\"immediate\"" + ); + let 
parsed: DeliveryMode = serde_json::from_str("\"immediate\"").unwrap();
+        assert_eq!(parsed, DeliveryMode::Immediate);
+    }
+
+    #[test]
+    fn connection_state_error_serializes_to_match_go() {
+        let json = serde_json::to_string(&ConnectionState::Error).unwrap();
+        assert_eq!(json, "\"error\"");
+        let parsed: ConnectionState = serde_json::from_str("\"error\"").unwrap();
+        assert_eq!(parsed, ConnectionState::Error);
+    }
+
+    #[test]
+    fn connection_state_other_variants_serialize_as_lowercase() {
+        assert_eq!(
+            serde_json::to_string(&ConnectionState::Disconnected).unwrap(),
+            "\"disconnected\""
+        );
+        assert_eq!(
+            serde_json::to_string(&ConnectionState::Connecting).unwrap(),
+            "\"connecting\""
+        );
+        assert_eq!(
+            serde_json::to_string(&ConnectionState::Connected).unwrap(),
+            "\"connected\""
+        );
+    }
+
+    #[test]
+    fn deserializes_runtime_attachment_variants() {
+        let attachments: Vec<Attachment> = serde_json::from_value(json!([
+            {
+                "type": "file",
+                "path": "/tmp/file.rs",
+                "displayName": "file.rs",
+                "lineRange": { "start": 7, "end": 12 }
+            },
+            {
+                "type": "directory",
+                "path": "/tmp/project",
+                "displayName": "project"
+            },
+            {
+                "type": "selection",
+                "filePath": "/tmp/lib.rs",
+                "displayName": "lib.rs",
+                "text": "fn main() {}",
+                "selection": {
+                    "start": { "line": 1, "character": 2 },
+                    "end": { "line": 3, "character": 4 }
+                }
+            },
+            {
+                "type": "blob",
+                "data": "Zm9v",
+                "mimeType": "image/png",
+                "displayName": "image.png"
+            },
+            {
+                "type": "github_reference",
+                "number": 42,
+                "title": "Fix rendering",
+                "referenceType": "issue",
+                "state": "open",
+                "url": "https://github.com/example/repo/issues/42"
+            }
+        ]))
+        .expect("attachments should deserialize");
+
+        assert_eq!(attachments.len(), 5);
+        assert!(matches!(
+            &attachments[0],
+            Attachment::File {
+                path,
+                display_name,
+                line_range: Some(AttachmentLineRange { start: 7, end: 12 }),
+            } if path == &PathBuf::from("/tmp/file.rs") && display_name.as_deref() == Some("file.rs")
+        ));
+        assert!(matches!(
&attachments[1], + Attachment::Directory { path, display_name } + if path == &PathBuf::from("/tmp/project") && display_name.as_deref() == Some("project") + )); + assert!(matches!( + &attachments[2], + Attachment::Selection { + file_path, + display_name, + selection: + AttachmentSelectionRange { + start: AttachmentSelectionPosition { line: 1, character: 2 }, + end: AttachmentSelectionPosition { line: 3, character: 4 }, + }, + .. + } if file_path == &PathBuf::from("/tmp/lib.rs") && display_name.as_deref() == Some("lib.rs") + )); + assert!(matches!( + &attachments[3], + Attachment::Blob { + data, + mime_type, + display_name, + } if data == "Zm9v" && mime_type == "image/png" && display_name.as_deref() == Some("image.png") + )); + assert!(matches!( + &attachments[4], + Attachment::GitHubReference { + number: 42, + title, + reference_type: GitHubReferenceType::Issue, + state, + url, + } if title == "Fix rendering" + && state == "open" + && url == "https://github.com/example/repo/issues/42" + )); + } + + #[test] + fn ensures_display_names_for_variants_that_support_them() { + let mut attachments = vec![ + Attachment::File { + path: PathBuf::from("/tmp/file.rs"), + display_name: None, + line_range: None, + }, + Attachment::Selection { + file_path: PathBuf::from("/tmp/src/lib.rs"), + display_name: None, + text: "fn main() {}".to_string(), + selection: AttachmentSelectionRange { + start: AttachmentSelectionPosition { + line: 0, + character: 0, + }, + end: AttachmentSelectionPosition { + line: 0, + character: 10, + }, + }, + }, + Attachment::Blob { + data: "Zm9v".to_string(), + mime_type: "image/png".to_string(), + display_name: None, + }, + Attachment::GitHubReference { + number: 7, + title: "Track regressions".to_string(), + reference_type: GitHubReferenceType::Issue, + state: "open".to_string(), + url: "https://example.com/issues/7".to_string(), + }, + ]; + + ensure_attachment_display_names(&mut attachments); + + assert_eq!(attachments[0].display_name(), Some("file.rs")); + 
assert_eq!(attachments[1].display_name(), Some("lib.rs"));
+        assert_eq!(attachments[2].display_name(), Some("attachment"));
+        assert_eq!(attachments[3].display_name(), None);
+        assert_eq!(
+            attachments[3].label(),
+            Some("Track regressions".to_string())
+        );
+    }
+}
+
+#[cfg(test)]
+mod permission_builder_tests {
+    use std::sync::Arc;
+
+    use crate::handler::{
+        ApproveAllHandler, HandlerEvent, HandlerResponse, PermissionResult, SessionHandler,
+    };
+    use crate::types::{
+        PermissionRequestData, RequestId, ResumeSessionConfig, SessionConfig, SessionId,
+    };
+
+    fn permission_event() -> HandlerEvent {
+        HandlerEvent::PermissionRequest {
+            session_id: SessionId::from("s1"),
+            request_id: RequestId::new("1"),
+            data: PermissionRequestData {
+                extra: serde_json::json!({"tool": "shell"}),
+                ..Default::default()
+            },
+        }
+    }
+
+    async fn dispatch(handler: &Arc<dyn SessionHandler>) -> HandlerResponse {
+        handler.on_event(permission_event()).await
+    }
+
+    #[tokio::test]
+    async fn session_config_approve_all_wraps_existing_handler() {
+        let cfg = SessionConfig::default()
+            .with_handler(Arc::new(ApproveAllHandler))
+            .approve_all_permissions();
+        let handler = cfg.handler.expect("handler should be set");
+        match dispatch(&handler).await {
+            HandlerResponse::Permission(PermissionResult::Approved) => {}
+            other => panic!("expected Approved, got {other:?}"),
+        }
+    }
+
+    #[tokio::test]
+    async fn session_config_approve_all_defaults_to_deny_inner() {
+        // Without with_handler, the wrap defaults to DenyAllHandler. The
+        // approve-all wrap intercepts permission events, so they're still
+        // approved -- the inner handler is consulted only for other events.
+ let cfg = SessionConfig::default().approve_all_permissions(); + let handler = cfg.handler.expect("handler should be set"); + match dispatch(&handler).await { + HandlerResponse::Permission(PermissionResult::Approved) => {} + other => panic!("expected Approved, got {other:?}"), + } + } + + #[tokio::test] + async fn session_config_deny_all_denies() { + let cfg = SessionConfig::default() + .with_handler(Arc::new(ApproveAllHandler)) + .deny_all_permissions(); + let handler = cfg.handler.expect("handler should be set"); + match dispatch(&handler).await { + HandlerResponse::Permission(PermissionResult::Denied) => {} + other => panic!("expected Denied, got {other:?}"), + } + } + + #[tokio::test] + async fn session_config_approve_permissions_if_consults_predicate() { + let cfg = SessionConfig::default() + .with_handler(Arc::new(ApproveAllHandler)) + .approve_permissions_if(|data| { + data.extra.get("tool").and_then(|v| v.as_str()) != Some("shell") + }); + let handler = cfg.handler.expect("handler should be set"); + match dispatch(&handler).await { + HandlerResponse::Permission(PermissionResult::Denied) => {} + other => panic!("expected Denied for shell, got {other:?}"), + } + } + + #[tokio::test] + async fn resume_session_config_approve_all_wraps_existing_handler() { + let cfg = ResumeSessionConfig::new(SessionId::from("s1")) + .with_handler(Arc::new(ApproveAllHandler)) + .approve_all_permissions(); + let handler = cfg.handler.expect("handler should be set"); + match dispatch(&handler).await { + HandlerResponse::Permission(PermissionResult::Approved) => {} + other => panic!("expected Approved, got {other:?}"), + } + } +} diff --git a/rust/tests/integration_test.rs b/rust/tests/integration_test.rs new file mode 100644 index 000000000..90e2e1c7a --- /dev/null +++ b/rust/tests/integration_test.rs @@ -0,0 +1,107 @@ +#![allow(clippy::unwrap_used)] + +use std::time::Instant; + +use github_copilot_sdk::resolve::copilot_binary_with_source; +use github_copilot_sdk::{Client, 
ClientOptions, SDK_PROTOCOL_VERSION}; + +fn default_options() -> ClientOptions { + let mut opts = ClientOptions::default(); + opts.cwd = std::env::current_dir().expect("cwd"); + opts +} + +#[tokio::test] +#[ignore] // requires `copilot` CLI on PATH — run with `cargo test -- --ignored` +async fn start_ping_stop() { + let client = Client::start(default_options()) + .await + .expect("failed to start copilot CLI"); + + // start() calls verify_protocol_version(), so this should be set + let version = client + .protocol_version() + .expect("protocol version not negotiated"); + assert!((2..=SDK_PROTOCOL_VERSION).contains(&version)); + + client.ping(None).await.expect("ping failed"); + client.stop().await.expect("stop failed"); +} + +#[tokio::test] +#[ignore] // requires `copilot` CLI on PATH — run with `cargo test -- --ignored` +async fn force_stop_kills_real_child() { + let client = Client::start(default_options()) + .await + .expect("failed to start copilot CLI"); + + let pid = client.pid().expect("expected a CLI child pid"); + assert!(pid > 0); + + // force_stop is synchronous and must not panic. After it returns, + // pid() should report None because we've taken the child out of the + // mutex. + client.force_stop(); + assert!(client.pid().is_none()); + + // Calling it again should be a no-op rather than panicking. + client.force_stop(); +} + +/// Measures the latency of individual CLI operations that contribute to +/// session creation time. 
Run with: +/// +/// cargo test -p github-copilot-sdk --test integration_test -- --ignored --nocapture +#[tokio::test] +#[ignore] +async fn cli_operation_latency() { + // Cold start: spawn CLI process + verify protocol version + let t0 = Instant::now(); + let client = Client::start(default_options()) + .await + .expect("cold start failed"); + let cold_start = t0.elapsed(); + + // Warm ping: RPC round-trip on an already-running process + let t1 = Instant::now(); + client.ping(None).await.expect("warm ping failed"); + let warm_ping = t1.elapsed(); + + // list_models: RPC that fetches available models from the CLI + let t2 = Instant::now(); + let models = client.list_models().await.expect("list_models failed"); + let list_models = t2.elapsed(); + + // Second list_models: does the CLI cache internally? + let t2b = Instant::now(); + let _ = client.list_models().await.expect("list_models 2 failed"); + let list_models_2 = t2b.elapsed(); + + client.stop().await.expect("stop first client failed"); + + // Second cold start: measures process spawn cost when the binary is + // already resolved and cached (no extraction overhead) + let t3 = Instant::now(); + let client2 = Client::start(default_options()) + .await + .expect("second cold start failed"); + let second_start = t3.elapsed(); + + client2.stop().await.expect("stop second client failed"); + + let (cli_path, source) = copilot_binary_with_source().expect("copilot binary not found"); + + eprintln!(); + eprintln!("=== CLI operation latency ==="); + eprintln!(" binary: {} ({:?})", cli_path.display(), source); + eprintln!(" cold Client::start: {:>8.1?}", cold_start); + eprintln!(" warm ping(): {:>8.1?}", warm_ping); + eprintln!( + " list_models() ({:>2}): {:>8.1?}", + models.len(), + list_models + ); + eprintln!(" list_models() again: {:>8.1?}", list_models_2); + eprintln!(" second Client::start: {:>8.1?}", second_start); + eprintln!(); +} diff --git a/rust/tests/jsonrpc_test.rs b/rust/tests/jsonrpc_test.rs new file mode 
100644
index 000000000..7f7d43213
--- /dev/null
+++ b/rust/tests/jsonrpc_test.rs
@@ -0,0 +1,412 @@
+#![cfg(feature = "test-support")]
+#![allow(clippy::unwrap_used)]
+
+use github_copilot_sdk::test_support::{JsonRpcClient, JsonRpcNotification, JsonRpcRequest};
+use tokio::io::{AsyncWrite, AsyncWriteExt, duplex};
+use tokio::sync::{broadcast, mpsc};
+
+/// Write a Content-Length framed JSON-RPC message to a writer.
+async fn write_framed(writer: &mut (impl AsyncWrite + Unpin), body: &[u8]) {
+    let header = format!("Content-Length: {}\r\n\r\n", body.len());
+    writer.write_all(header.as_bytes()).await.unwrap();
+    writer.write_all(body).await.unwrap();
+    writer.flush().await.unwrap();
+}
+
+#[tokio::test]
+async fn request_response_round_trip() {
+    // duplex: client_write → server_read, server_write → client_read
+    let (client_write, mut server_read) = duplex(4096);
+    let (mut server_write, client_read) = duplex(4096);
+
+    let (notification_tx, _) = broadcast::channel(16);
+    let (request_tx, _request_rx) = mpsc::unbounded_channel();
+
+    let client = JsonRpcClient::new(client_write, client_read, notification_tx, request_tx);
+
+    // Spawn a task that reads the request from the server side and sends a response.
+ let server_handle = tokio::spawn(async move { + let mut buf = Vec::new(); + // Read the Content-Length header + let mut header = String::new(); + loop { + let mut byte = [0u8; 1]; + tokio::io::AsyncReadExt::read_exact(&mut server_read, &mut byte) + .await + .unwrap(); + header.push(byte[0] as char); + if header.ends_with("\r\n\r\n") { + break; + } + } + let length: usize = header + .trim() + .strip_prefix("Content-Length: ") + .unwrap() + .parse() + .unwrap(); + buf.resize(length, 0); + tokio::io::AsyncReadExt::read_exact(&mut server_read, &mut buf) + .await + .unwrap(); + + let request: JsonRpcRequest = serde_json::from_slice(&buf).unwrap(); + assert_eq!(request.method, "test.echo"); + assert_eq!(request.jsonrpc, "2.0"); + + // Send response + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": request.id, + "result": { "echoed": true } + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + request.id + }); + + let response = client + .send_request("test.echo", Some(serde_json::json!({"hello": "world"}))) + .await + .unwrap(); + + let request_id = server_handle.await.unwrap(); + assert_eq!(response.id, request_id); + assert!(!response.is_error()); + assert_eq!(response.result.unwrap()["echoed"], serde_json::json!(true)); +} + +#[tokio::test] +async fn notification_broadcasting() { + let (_client_write, _discard) = duplex(4096); + let (mut server_write, client_read) = duplex(4096); + + let (notification_tx, mut notification_rx) = broadcast::channel(16); + let (request_tx, _request_rx) = mpsc::unbounded_channel(); + + let _client = JsonRpcClient::new(_client_write, client_read, notification_tx, request_tx); + + // Server sends a notification (no id field). 
+ let notification = serde_json::json!({ + "jsonrpc": "2.0", + "method": "session.event", + "params": { "session_id": "s1", "event": "started" } + }); + write_framed( + &mut server_write, + &serde_json::to_vec(¬ification).unwrap(), + ) + .await; + + let received: JsonRpcNotification = + tokio::time::timeout(std::time::Duration::from_secs(2), notification_rx.recv()) + .await + .expect("timed out waiting for notification") + .unwrap(); + + assert_eq!(received.method, "session.event"); + assert_eq!(received.params.unwrap()["session_id"], "s1"); +} + +#[tokio::test] +async fn server_request_forwarding() { + let (_client_write, _discard) = duplex(4096); + let (mut server_write, client_read) = duplex(4096); + + let (notification_tx, _) = broadcast::channel(16); + let (request_tx, mut request_rx) = mpsc::unbounded_channel(); + + let _client = JsonRpcClient::new(_client_write, client_read, notification_tx, request_tx); + + // Server sends a request (has both id and method). + let request = serde_json::json!({ + "jsonrpc": "2.0", + "id": 42, + "method": "permission.request", + "params": { "kind": "shell" } + }); + write_framed(&mut server_write, &serde_json::to_vec(&request).unwrap()).await; + + let received: JsonRpcRequest = + tokio::time::timeout(std::time::Duration::from_secs(2), request_rx.recv()) + .await + .expect("timed out waiting for request") + .unwrap(); + + assert_eq!(received.method, "permission.request"); + assert_eq!(received.id, 42); +} + +#[tokio::test] +async fn error_response_round_trip() { + let (client_write, mut server_read) = duplex(4096); + let (mut server_write, client_read) = duplex(4096); + + let (notification_tx, _) = broadcast::channel(16); + let (request_tx, _) = mpsc::unbounded_channel(); + + let client = JsonRpcClient::new(client_write, client_read, notification_tx, request_tx); + + let server_handle = tokio::spawn(async move { + // Read request + let mut header = String::new(); + loop { + let mut byte = [0u8; 1]; + 
tokio::io::AsyncReadExt::read_exact(&mut server_read, &mut byte) + .await + .unwrap(); + header.push(byte[0] as char); + if header.ends_with("\r\n\r\n") { + break; + } + } + let length: usize = header + .trim() + .strip_prefix("Content-Length: ") + .unwrap() + .parse() + .unwrap(); + let mut buf = vec![0u8; length]; + tokio::io::AsyncReadExt::read_exact(&mut server_read, &mut buf) + .await + .unwrap(); + let request: JsonRpcRequest = serde_json::from_slice(&buf).unwrap(); + + // Send error response + let error_response = serde_json::json!({ + "jsonrpc": "2.0", + "id": request.id, + "error": { "code": -32600, "message": "Invalid Request" } + }); + write_framed( + &mut server_write, + &serde_json::to_vec(&error_response).unwrap(), + ) + .await; + }); + + let response = client.send_request("bad.method", None).await.unwrap(); + server_handle.await.unwrap(); + + assert!(response.is_error()); + let error = response.error.unwrap(); + assert_eq!(error.code, -32600); + assert_eq!(error.message, "Invalid Request"); +} + +#[tokio::test] +async fn read_loop_terminates_on_eof() { + let (client_write, _discard) = duplex(4096); + let (server_write, client_read) = duplex(4096); + + let (notification_tx, _) = broadcast::channel(16); + let (request_tx, _) = mpsc::unbounded_channel(); + + let _client = JsonRpcClient::new(client_write, client_read, notification_tx, request_tx); + + // Drop the server side — the read loop should see EOF and stop. + drop(server_write); + + // Give the read loop time to notice EOF. + tokio::time::sleep(std::time::Duration::from_millis(100)).await; +} + +/// Cancel-safety regression: dropping a `write()` future after the actor has +/// committed to writing must NOT produce a partial frame on the wire. +/// +/// Strategy: spawn a reader task that waits before draining the wire, so +/// the actor's `write_all` blocks waiting for room. 
Race the caller's +/// future against a sleep; when the sleep wins, the caller's future is +/// dropped while suspended on `ack_rx.await`. Release the reader and +/// verify both frames land on the wire intact. +/// +/// Closes RFD-400 finding #1: `JsonRpcClient::write` was holding a Tokio +/// mutex across `write_all` + `flush`, so caller cancellation mid-frame +/// could desync the transport. The writer-actor refactor moves the I/O +/// onto a dedicated task that owns the writer; caller cancellation drops +/// the ack receiver but does not interrupt the in-flight write. +#[tokio::test] +async fn write_actor_completes_on_caller_cancel() { + use std::sync::Arc; + + use tokio::sync::Notify; + + let (client_write, mut server_read) = duplex(8); + let (_server_write, client_read) = duplex(8); + + let (notification_tx, _) = broadcast::channel(16); + let (request_tx, _) = mpsc::unbounded_channel(); + let client = JsonRpcClient::new(client_write, client_read, notification_tx, request_tx); + + // Reader task that waits for `start` before draining; this gives us + // a window where the actor's write_all is suspended waiting for room. 
+ let start = Arc::new(Notify::new()); + let start_clone = start.clone(); + let reader_task = tokio::spawn(async move { + start_clone.notified().await; + let mut frames = Vec::new(); + for _ in 0..2 { + let mut header = String::new(); + loop { + let mut byte = [0u8; 1]; + tokio::io::AsyncReadExt::read_exact(&mut server_read, &mut byte) + .await + .unwrap(); + header.push(byte[0] as char); + if header.ends_with("\r\n\r\n") { + break; + } + } + let length: usize = header + .trim() + .strip_prefix("Content-Length: ") + .unwrap() + .parse() + .unwrap(); + let mut body = vec![0u8; length]; + tokio::io::AsyncReadExt::read_exact(&mut server_read, &mut body) + .await + .unwrap(); + let req: JsonRpcRequest = serde_json::from_slice(&body).unwrap(); + frames.push(req); + } + frames + }); + + let frame_a = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: 100, + method: "first.write".to_string(), + params: None, + }; + let frame_b = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: 101, + method: "second.write".to_string(), + params: None, + }; + + // First write: race the future against a sleep. With the reader + // gated, the actor's write_all blocks at the 8-byte buffer boundary, + // so the future stays suspended on `ack_rx.await`. The sleep wins + // after 50ms, dropping the caller's future. The actor still owns the + // write and must complete it once the reader drains. + tokio::select! { + _ = client.write(&frame_a) => panic!("write completed too quickly to test cancellation"), + _ = tokio::time::sleep(std::time::Duration::from_millis(50)) => {} + } + + // Enqueue the second write before releasing the reader. Both frames + // are now in the actor's queue; the actor will drain them in order + // once the reader starts pulling bytes. 
+    let second_handle = tokio::spawn({
+        let frame_b = frame_b.clone();
+        async move { client.write(&frame_b).await }
+    });
+
+    // Release the reader so both frames can flow through the actor.
+    start.notify_one();
+
+    let frames = reader_task.await.unwrap();
+    second_handle.await.unwrap().unwrap();
+
+    assert_eq!(frames.len(), 2);
+    assert_eq!(frames[0].method, "first.write");
+    assert_eq!(frames[0].id, 100);
+    assert_eq!(frames[1].method, "second.write");
+    assert_eq!(frames[1].id, 101);
+}
+
+/// Cancel-safety regression: cancelling a `send_request` future before the
+/// response arrives must NOT leak the pending-requests entry. The RAII
+/// `PendingGuard` removes the entry on drop.
+///
+/// Strategy: spawn `send_request`, then abort the JoinHandle so the future
+/// is cancelled (merely dropping a Tokio JoinHandle detaches the task
+/// rather than cancelling it). The CLI eventually sends a response for the
+/// cancelled request id; the read loop logs a warning and discards it
+/// (the pending entry was already removed by the guard). The next
+/// `send_request` should work normally and not collide with the orphan.
+///
+/// Closes RFD-400 finding #4.
+#[tokio::test]
+async fn send_request_cancellation_does_not_leak_pending() {
+    let (client_write, mut server_read) = duplex(4096);
+    let (mut server_write, client_read) = duplex(4096);
+
+    let (notification_tx, _) = broadcast::channel(16);
+    let (request_tx, _) = mpsc::unbounded_channel();
+    let client = JsonRpcClient::new(client_write, client_read, notification_tx, request_tx);
+    let client = std::sync::Arc::new(client);
+
+    // First request: abort before the server replies.
+    let cancelled = tokio::spawn({
+        let client = client.clone();
+        async move {
+            // Will await the response oneshot; the JoinHandle abort
+            // below cancels this future.
+            let _ = client.send_request("first", None).await;
+        }
+    });
+
+    // Read the first request off the wire so we know it was sent.
+    async fn read_one_method(reader: &mut tokio::io::DuplexStream) -> (u64, String) {
+        let mut header = String::new();
+        loop {
+            let mut byte = [0u8; 1];
+            tokio::io::AsyncReadExt::read_exact(reader, &mut byte)
+                .await
+                .unwrap();
+            header.push(byte[0] as char);
+            if header.ends_with("\r\n\r\n") {
+                break;
+            }
+        }
+        let length: usize = header
+            .trim()
+            .strip_prefix("Content-Length: ")
+            .unwrap()
+            .parse()
+            .unwrap();
+        let mut body = vec![0u8; length];
+        tokio::io::AsyncReadExt::read_exact(reader, &mut body)
+            .await
+            .unwrap();
+        let req: JsonRpcRequest = serde_json::from_slice(&body).unwrap();
+        (req.id, req.method)
+    }
+
+    let (first_id, first_method) = read_one_method(&mut server_read).await;
+    assert_eq!(first_method, "first");
+
+    // Now cancel the in-flight request.
+    cancelled.abort();
+    let _ = cancelled.await;
+
+    // Send a (late) response for the cancelled id. The read loop should
+    // log a warning and not blow up.
+    let stale_resp = serde_json::json!({
+        "jsonrpc": "2.0",
+        "id": first_id,
+        "result": {"echo": "ignored"}
+    });
+    write_framed(&mut server_write, &serde_json::to_vec(&stale_resp).unwrap()).await;
+
+    // Second request: should succeed normally without collision.
+    let server_task = tokio::spawn(async move {
+        let (id, method) = read_one_method(&mut server_read).await;
+        assert_eq!(method, "second");
+        let resp = serde_json::json!({
+            "jsonrpc": "2.0",
+            "id": id,
+            "result": {"ok": true}
+        });
+        write_framed(&mut server_write, &serde_json::to_vec(&resp).unwrap()).await;
+    });
+
+    let response = client.send_request("second", None).await.unwrap();
+    assert_eq!(response.result.unwrap()["ok"], true);
+    server_task.await.unwrap();
+}
diff --git a/rust/tests/protocol_version_test.rs b/rust/tests/protocol_version_test.rs
new file mode 100644
index 000000000..b442f723b
--- /dev/null
+++ b/rust/tests/protocol_version_test.rs
@@ -0,0 +1,93 @@
+#![allow(clippy::unwrap_used)]
+
+use github_copilot_sdk::Client;
+use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt, duplex};
+
+async fn write_framed(writer: &mut (impl AsyncWrite + Unpin), body: &[u8]) {
+    let header = format!("Content-Length: {}\r\n\r\n", body.len());
+    writer.write_all(header.as_bytes()).await.unwrap();
+    writer.write_all(body).await.unwrap();
+    writer.flush().await.unwrap();
+}
+
+async fn read_framed(reader: &mut (impl tokio::io::AsyncRead + Unpin)) -> serde_json::Value {
+    let mut header = String::new();
+    loop {
+        let mut byte = [0u8; 1];
+        AsyncReadExt::read_exact(reader, &mut byte).await.unwrap();
+        header.push(byte[0] as char);
+        if header.ends_with("\r\n\r\n") {
+            break;
+        }
+    }
+    let length: usize = header
+        .trim()
+        .strip_prefix("Content-Length: ")
+        .unwrap()
+        .parse()
+        .unwrap();
+    let mut buf = vec![0u8; length];
+    AsyncReadExt::read_exact(reader, &mut buf).await.unwrap();
+    serde_json::from_slice(&buf).unwrap()
+}
+
+/// Verify protocol version against a fake server that responds with `result`.
+async fn verify_with_result(
+    result: serde_json::Value,
+) -> (Result<(), github_copilot_sdk::Error>, Option<u64>) {
+    let (client_write, server_read) = duplex(8192);
+    let (server_write, client_read) = duplex(8192);
+    let client = Client::from_streams(client_read, client_write, std::env::temp_dir()).unwrap();
+
+    let mut server_read = server_read;
+    let mut server_write = server_write;
+
+    let verify_handle = tokio::spawn({
+        let client = client.clone();
+        async move { client.verify_protocol_version().await }
+    });
+
+    let req = read_framed(&mut server_read).await;
+    assert_eq!(req["method"], "ping");
+    let response = serde_json::json!({
+        "jsonrpc": "2.0",
+        "id": req["id"],
+        "result": result,
+    });
+    write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await;
+
+    let res = tokio::time::timeout(std::time::Duration::from_secs(2), verify_handle)
+        .await
+        .unwrap()
+        .unwrap();
+    let version = client.protocol_version();
+    (res, version)
+}
+
+#[tokio::test]
+async fn accepted_when_version_in_range() {
+    let (res, version) = verify_with_result(serde_json::json!({ "protocolVersion": 3 })).await;
+    assert!(res.is_ok());
+    assert_eq!(version, Some(3));
+}
+
+#[tokio::test]
+async fn rejected_when_version_out_of_range() {
+    let (res, version) = verify_with_result(serde_json::json!({ "protocolVersion": 1 })).await;
+    let err = res.unwrap_err();
+    assert!(matches!(
+        err,
+        github_copilot_sdk::Error::Protocol(github_copilot_sdk::ProtocolError::VersionMismatch {
+            server: 1,
+            ..
+        })
+    ));
+    assert_eq!(version, None);
+}
+
+#[tokio::test]
+async fn succeeds_when_version_missing() {
+    let (res, version) = verify_with_result(serde_json::json!({ "message": "pong" })).await;
+    assert!(res.is_ok());
+    assert_eq!(version, None);
+}
diff --git a/rust/tests/session_test.rs b/rust/tests/session_test.rs
new file mode 100644
index 000000000..15f28b726
--- /dev/null
+++ b/rust/tests/session_test.rs
@@ -0,0 +1,3695 @@
+#![allow(clippy::unwrap_used)]
+
+use std::path::Path;
+use std::sync::Arc;
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::time::Duration;
+
+use async_trait::async_trait;
+use github_copilot_sdk::Client;
+use github_copilot_sdk::handler::{
+    ApproveAllHandler, AutoModeSwitchResponse, ExitPlanModeResult, HandlerEvent, HandlerResponse,
+    PermissionResult, SessionHandler, UserInputResponse,
+};
+use github_copilot_sdk::types::{
+    CommandContext, CommandDefinition, CommandHandler, DeliveryMode, MessageOptions,
+    ServerTelemetryEvent, SessionConfig, SessionId, SessionTelemetryEvent, ToolResult,
+};
+use serde_json::Value;
+use tokio::io::{AsyncWrite, AsyncWriteExt, duplex};
+use tokio::sync::mpsc;
+use tokio::time::timeout;
+
+const TIMEOUT: Duration = Duration::from_secs(2);
+const METHOD_NOT_FOUND: i32 = -32601;
+
+struct NoopHandler;
+#[async_trait]
+impl SessionHandler for NoopHandler {
+    async fn on_event(&self, _event: HandlerEvent) -> HandlerResponse {
+        HandlerResponse::Ok
+    }
+}
+
+async fn write_framed(writer: &mut (impl AsyncWrite + Unpin), body: &[u8]) {
+    let header = format!("Content-Length: {}\r\n\r\n", body.len());
+    writer.write_all(header.as_bytes()).await.unwrap();
+    writer.write_all(body).await.unwrap();
+    writer.flush().await.unwrap();
+}
+
+async fn read_framed(reader: &mut (impl tokio::io::AsyncRead + Unpin)) -> Value {
+    let mut header = String::new();
+    loop {
+        let mut byte = [0u8; 1];
+        tokio::io::AsyncReadExt::read_exact(reader, &mut byte)
+            .await
+            .unwrap();
+        header.push(byte[0] as char);
+        if header.ends_with("\r\n\r\n") {
+            break;
+        }
+    }
+    let length: usize = header
+        .trim()
+        .strip_prefix("Content-Length: ")
+        .unwrap()
+        .parse()
+        .unwrap();
+    let mut buf = vec![0u8; length];
+    tokio::io::AsyncReadExt::read_exact(reader, &mut buf)
+        .await
+        .unwrap();
+    serde_json::from_slice(&buf).unwrap()
+}
+
+fn make_client() -> (Client, tokio::io::DuplexStream, tokio::io::DuplexStream) {
+    let (client_write, server_read) = duplex(8192);
+    let (server_write, client_read) = duplex(8192);
+    let client = Client::from_streams(client_read, client_write, std::env::temp_dir()).unwrap();
+    (client, server_read, server_write)
+}
+
+struct FakeServer {
+    read: tokio::io::DuplexStream,
+    write: tokio::io::DuplexStream,
+    session_id: String,
+}
+
+impl FakeServer {
+    async fn read_request(&mut self) -> Value {
+        read_framed(&mut self.read).await
+    }
+
+    async fn respond(&mut self, request: &Value, result: Value) {
+        let id = request["id"].as_u64().unwrap();
+        let response = serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": result });
+        write_framed(&mut self.write, &serde_json::to_vec(&response).unwrap()).await;
+    }
+
+    async fn send_notification(&mut self, method: &str, params: Value) {
+        let notification = serde_json::json!({
+            "jsonrpc": "2.0",
+            "method": method,
+            "params": params,
+        });
+        write_framed(&mut self.write, &serde_json::to_vec(&notification).unwrap()).await;
+    }
+
+    async fn send_event(&mut self, event_type: &str, data: Value) {
+        self.send_notification(
+            "session.event",
+            serde_json::json!({
+                "sessionId": self.session_id,
+                "event": {
+                    "id": format!("evt-{}", rand_id()),
+                    "timestamp": "2025-01-01T00:00:00Z",
+                    "type": event_type,
+                    "data": data,
+                },
+            }),
+        )
+        .await;
+    }
+
+    async fn send_request(&mut self, id: u64, method: &str, params: Value) {
+        let request = serde_json::json!({
+            "jsonrpc": "2.0",
+            "id": id,
+            "method": method,
+            "params": params,
+        });
+        write_framed(&mut self.write, &serde_json::to_vec(&request).unwrap()).await;
+    }
+
+    async fn read_response(&mut self) -> Value {
+        read_framed(&mut self.read).await
+    }
+}
+
+async fn create_session_pair(
+    handler: Arc<dyn SessionHandler>,
+) -> (github_copilot_sdk::session::Session, FakeServer) {
+    create_session_pair_with_capabilities(handler, serde_json::json!(null)).await
+}
+
+async fn create_session_pair_with_capabilities(
+    handler: Arc<dyn SessionHandler>,
+    capabilities: Value,
+) -> (github_copilot_sdk::session::Session, FakeServer) {
+    let (client, server_read, server_write) = make_client();
+    let session_id = format!("test-session-{}", rand_id());
+
+    let mut server = FakeServer {
+        read: server_read,
+        write: server_write,
+        session_id: session_id.clone(),
+    };
+
+    let create_handle = tokio::spawn({
+        let client = client.clone();
+        let handler = handler.clone();
+        async move {
+            client
+                .create_session(SessionConfig::default().with_handler(handler))
+                .await
+                .unwrap()
+        }
+    });
+
+    let create_req = server.read_request().await;
+    assert_eq!(create_req["method"], "session.create");
+    let mut result = serde_json::json!({
+        "sessionId": session_id,
+        "workspacePath": "/tmp/workspace"
+    });
+    if !capabilities.is_null() {
+        result["capabilities"] = capabilities;
+    }
+    server.respond(&create_req, result).await;
+
+    let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap();
+    (session, server)
+}
+
+fn rand_id() -> u64 {
+    static COUNTER: AtomicUsize = AtomicUsize::new(0);
+    COUNTER.fetch_add(1, Ordering::Relaxed) as u64
+}
+
+#[tokio::test]
+async fn session_subscribe_yields_events_observe_only() {
+    let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await;
+
+    let mut events = session.subscribe();
+    let count = Arc::new(AtomicUsize::new(0));
+    let last_type = Arc::new(parking_lot::Mutex::new(String::new()));
+    let count_clone = count.clone();
+    let last_type_clone = last_type.clone();
+    let consumer = tokio::spawn(async move {
+        while let Ok(event) = events.recv().await {
+            count_clone.fetch_add(1, Ordering::Relaxed);
+            *last_type_clone.lock() = event.event_type.clone();
+        }
+    });
+
+    server.send_event("noop.event", serde_json::json!({})).await;
+    server
+        .send_event("another.event", serde_json::json!({"k": "v"}))
+        .await;
+
+    for _ in 0..50 {
+        if count.load(Ordering::Relaxed) >= 2 {
+            break;
+        }
+        tokio::time::sleep(Duration::from_millis(20)).await;
+    }
+    assert_eq!(count.load(Ordering::Relaxed), 2);
+    assert_eq!(last_type.lock().as_str(), "another.event");
+    consumer.abort();
+}
+
+#[tokio::test]
+async fn session_subscribe_drop_stops_delivery() {
+    let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await;
+
+    let mut events = session.subscribe();
+    let count = Arc::new(AtomicUsize::new(0));
+    let count_clone = count.clone();
+    let consumer = tokio::spawn(async move {
+        while let Ok(_event) = events.recv().await {
+            count_clone.fetch_add(1, Ordering::Relaxed);
+        }
+    });
+
+    server.send_event("first", serde_json::json!({})).await;
+    for _ in 0..50 {
+        if count.load(Ordering::Relaxed) >= 1 {
+            break;
+        }
+        tokio::time::sleep(Duration::from_millis(20)).await;
+    }
+    assert_eq!(count.load(Ordering::Relaxed), 1);
+
+    // Aborting the consumer drops its receiver; further events have no
+    // effect on the (now-zero) subscriber count.
+    consumer.abort();
+    tokio::time::sleep(Duration::from_millis(20)).await;
+
+    server.send_event("second", serde_json::json!({})).await;
+    tokio::time::sleep(Duration::from_millis(100)).await;
+    assert_eq!(count.load(Ordering::Relaxed), 1);
+}
+
+#[tokio::test]
+async fn create_session_sends_correct_rpc() {
+    let (client, mut server_read, mut server_write) = make_client();
+
+    let create_handle = tokio::spawn({
+        let client = client.clone();
+        async move {
+            client
+                .create_session({
+                    let mut cfg = SessionConfig::default();
+                    cfg.model = Some("gpt-4".to_string());
+                    cfg.with_handler(Arc::new(NoopHandler))
+                })
+                .await
+                .unwrap()
+        }
+    });
+
+    let request = read_framed(&mut server_read).await;
+    assert_eq!(request["method"], "session.create");
+    assert_eq!(request["params"]["model"], "gpt-4");
+
+    let id = request["id"].as_u64().unwrap();
+    let response = serde_json::json!({
+        "jsonrpc": "2.0",
+        "id": id,
+        "result": { "sessionId": "s1", "workspacePath": "/ws" },
+    });
+    write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await;
+
+    let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap();
+    assert_eq!(session.id(), "s1");
+    assert_eq!(session.workspace_path(), Some(Path::new("/ws")));
+}
+
+#[tokio::test]
+async fn send_injects_session_id() {
+    let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await;
+    let session = Arc::new(session);
+
+    let handle = tokio::spawn({
+        let session = session.clone();
+        async move {
+            session
+                .send(MessageOptions::new("hello").with_mode(DeliveryMode::Immediate))
+                .await
+        }
+    });
+
+    let request = server.read_request().await;
+    assert_eq!(request["method"], "session.send");
+    assert_eq!(request["params"]["sessionId"], server.session_id);
+    assert_eq!(request["params"]["prompt"], "hello");
+    assert_eq!(request["params"]["mode"], "immediate");
+
+    server.respond(&request, serde_json::json!({})).await;
+    timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap();
+}
+
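All of the fake-server helpers in these tests speak the same LSP-style `Content-Length` framing over a byte stream. As an editor's aside, here is a minimal synchronous sketch of that framing using only the standard library; the helper names mirror the tests' async `write_framed`/`read_framed` but are illustrative, not the SDK's API:

```rust
use std::io::{Read, Write};

// Write one framed message: an ASCII header giving the body length in
// bytes, a blank line, then the raw body (illustrative helper).
fn write_framed(writer: &mut impl Write, body: &[u8]) -> std::io::Result<()> {
    write!(writer, "Content-Length: {}\r\n\r\n", body.len())?;
    writer.write_all(body)
}

// Read one framed message back: scan byte-by-byte until the blank line
// that terminates the header, parse the length, then read exactly that
// many bytes of body.
fn read_framed(reader: &mut impl Read) -> std::io::Result<Vec<u8>> {
    let mut header = String::new();
    loop {
        let mut byte = [0u8; 1];
        reader.read_exact(&mut byte)?;
        header.push(byte[0] as char);
        if header.ends_with("\r\n\r\n") {
            break;
        }
    }
    let length: usize = header
        .trim()
        .strip_prefix("Content-Length: ")
        .and_then(|s| s.parse().ok())
        .expect("malformed Content-Length header");
    let mut body = vec![0u8; length];
    reader.read_exact(&mut body)?;
    Ok(body)
}

fn main() -> std::io::Result<()> {
    // A Vec<u8> stands in for the wire: Write sink on the way out,
    // Read source (via &wire[..]) on the way back in.
    let mut wire = Vec::new();
    write_framed(&mut wire, br#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#)?;
    let body = read_framed(&mut &wire[..])?;
    assert_eq!(body, br#"{"jsonrpc":"2.0","id":1,"method":"ping"}"#);
    Ok(())
}
```

The byte-at-a-time header scan is what the tests above do as well; it is slow but avoids over-reading past the header into the body, which matters when the same reader must then `read_exact` the body.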
+#[tokio::test]
+async fn send_serializes_request_headers() {
+    use std::collections::HashMap;
+
+    let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await;
+    let session = Arc::new(session);
+
+    let handle = tokio::spawn({
+        let session = session.clone();
+        async move {
+            let mut headers = HashMap::new();
+            headers.insert("X-Custom-Tag".to_string(), "value-1".to_string());
+            headers.insert("Authorization".to_string(), "Bearer abc".to_string());
+            session
+                .send(MessageOptions::new("hi").with_request_headers(headers))
+                .await
+        }
+    });
+
+    let request = server.read_request().await;
+    assert_eq!(request["method"], "session.send");
+    assert_eq!(request["params"]["prompt"], "hi");
+    let headers = request["params"]["requestHeaders"]
+        .as_object()
+        .expect("requestHeaders should be an object");
+    assert_eq!(headers["X-Custom-Tag"], "value-1");
+    assert_eq!(headers["Authorization"], "Bearer abc");
+    assert_eq!(headers.len(), 2);
+
+    server.respond(&request, serde_json::json!({})).await;
+    timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap();
+}
+
+#[tokio::test]
+async fn send_omits_request_headers_when_unset_or_empty() {
+    use std::collections::HashMap;
+
+    let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await;
+    let session = Arc::new(session);
+
+    let handle = tokio::spawn({
+        let session = session.clone();
+        async move { session.send(MessageOptions::new("plain")).await }
+    });
+    let request = server.read_request().await;
+    assert!(
+        request["params"].get("requestHeaders").is_none(),
+        "requestHeaders should be omitted when unset, got: {}",
+        request["params"]
+    );
+    server.respond(&request, serde_json::json!({})).await;
+    timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap();
+
+    let handle = tokio::spawn({
+        let session = session.clone();
+        async move {
+            session
+                .send(MessageOptions::new("plain").with_request_headers(HashMap::new()))
+                .await
+        }
+    });
+    let request = server.read_request().await;
+    assert!(
+        request["params"].get("requestHeaders").is_none(),
+        "requestHeaders should be omitted for empty map, got: {}",
+        request["params"]
+    );
+    server.respond(&request, serde_json::json!({})).await;
+    timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap();
+}
+
+#[tokio::test]
+async fn session_rpc_methods_send_correct_method_names() {
+    let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await;
+    let session = Arc::new(session);
+
+    let cases: Vec<(&str, Option<&str>)> = vec![
+        ("session.abort", None),
+        ("session.plan.delete", None),
+        ("session.log", Some("message")),
+        ("session.sendTelemetry", Some("kind")),
+        ("session.destroy", None),
+    ];
+
+    for (expected_method, extra_param_key) in cases {
+        let s = session.clone();
+        let handle = tokio::spawn(async move {
+            match expected_method {
+                "session.abort" => s.abort().await.map(|_| ()),
+                "session.plan.delete" => s.delete_plan().await,
+                "session.log" => s.log("test msg", None).await,
+                "session.sendTelemetry" => {
+                    s.send_telemetry(SessionTelemetryEvent {
+                        kind: "sdk_test_event".to_string(),
+                        properties: Some(
+                            [("source".to_string(), "sdk".to_string())]
+                                .into_iter()
+                                .collect(),
+                        ),
+                        restricted_properties: None,
+                        metrics: None,
+                    })
+                    .await
+                }
+                "session.destroy" => s.destroy().await,
+                _ => unreachable!(),
+            }
+        });
+
+        let request = server.read_request().await;
+        assert_eq!(
+            request["method"], expected_method,
+            "wrong method for {expected_method}"
+        );
+        assert_eq!(request["params"]["sessionId"], server.session_id);
+        if let Some(key) = extra_param_key {
+            assert!(!request["params"][key].is_null(), "missing param {key}");
+        }
+        let response = match expected_method {
+            "session.log" => {
+                serde_json::json!({ "eventId": "00000000-0000-0000-0000-000000000000" })
+            }
+            _ => serde_json::json!({}),
+        };
+        server.respond(&request, response).await;
+        timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap();
+    }
+}
+
+#[tokio::test]
+async fn send_telemetry_injects_payload_and_session_id() {
+    let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await;
+    let session = Arc::new(session);
+
+    let handle = tokio::spawn({
+        let session = session.clone();
+        async move {
+            session
+                .send_telemetry(SessionTelemetryEvent {
+                    kind: "sdk_test_event".to_string(),
+                    properties: Some(
+                        [
+                            ("source".to_string(), "sdk".to_string()),
+                            ("feature".to_string(), "shared-api".to_string()),
+                        ]
+                        .into_iter()
+                        .collect(),
+                    ),
+                    restricted_properties: Some(
+                        [("file_path".to_string(), "/tmp/example.ts".to_string())]
+                            .into_iter()
+                            .collect(),
+                    ),
+                    metrics: Some(
+                        [
+                            ("count".to_string(), 1.0),
+                            ("duration_ms".to_string(), 12.5),
+                        ]
+                        .into_iter()
+                        .collect(),
+                    ),
+                })
+                .await
+        }
+    });
+
+    let request = server.read_request().await;
+    assert_eq!(request["method"], "session.sendTelemetry");
+    assert_eq!(request["params"]["sessionId"], server.session_id);
+    assert_eq!(request["params"]["kind"], "sdk_test_event");
+    assert_eq!(request["params"]["properties"]["source"], "sdk");
+    assert_eq!(
+        request["params"]["restrictedProperties"]["file_path"],
+        "/tmp/example.ts"
+    );
+    assert_eq!(request["params"]["metrics"]["duration_ms"], 12.5);
+
+    server.respond(&request, serde_json::json!(null)).await;
+    timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap();
+}
+
+#[tokio::test]
+async fn client_rpc_methods_send_correct_method_names() {
+    let (client, mut server_read, mut server_write) = make_client();
+
+    // Wire method names per the CLI runtime registration in @github/copilot
+    // app.js — verified against Node/Go/Python/.NET SDK call sites which all
+    // use these exact strings. The schema doesn't currently define these as
+    // typed RPCs (top-level methods, not under any namespace), so call site
+    // strings are the source of truth.
+    for expected_method in ["status.get", "auth.getStatus"] {
+        let c = client.clone();
+        let handle = tokio::spawn(async move {
+            match expected_method {
+                "status.get" => c.get_status().await.map(|_| ()),
+                "auth.getStatus" => c.get_auth_status().await.map(|_| ()),
+                _ => unreachable!(),
+            }
+        });
+
+        let request = read_framed(&mut server_read).await;
+        assert_eq!(request["method"], expected_method);
+        // Regression-prevention: must not have reverted to the
+        // hand-authored `getStatus` / `getAuthStatus` names that don't
+        // exist on the wire.
+        assert_ne!(request["method"], "getStatus");
+        assert_ne!(request["method"], "getAuthStatus");
+        let id = request["id"].as_u64().unwrap();
+        let result = match expected_method {
+            "status.get" => serde_json::json!({ "version": "1.0.0", "protocolVersion": 1 }),
+            "auth.getStatus" => serde_json::json!({ "isAuthenticated": true }),
+            _ => unreachable!(),
+        };
+        let resp = serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": result });
+        write_framed(&mut server_write, &serde_json::to_vec(&resp).unwrap()).await;
+        timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap();
+    }
+}
+
+#[tokio::test]
+async fn server_send_telemetry_sends_correct_payload() {
+    let (client, mut server_read, mut server_write) = make_client();
+
+    let handle = tokio::spawn({
+        let client = client.clone();
+        async move {
+            client
+                .send_telemetry(ServerTelemetryEvent {
+                    kind: "app.launched".to_string(),
+                    client_name: "github/autopilot".to_string(),
+                    properties: Some(
+                        [("machine_id".to_string(), "machine-123".to_string())]
+                            .into_iter()
+                            .collect(),
+                    ),
+                    restricted_properties: None,
+                    metrics: Some([("launch_count".to_string(), 1.0)].into_iter().collect()),
+                })
+                .await
+        }
+    });
+
+    let request = read_framed(&mut server_read).await;
+    assert_eq!(request["method"], "sendTelemetry");
+    assert_eq!(request["params"]["kind"], "app.launched");
+    assert_eq!(request["params"]["clientName"], "github/autopilot");
+    assert_eq!(request["params"]["properties"]["machine_id"], "machine-123");
+    assert_eq!(request["params"]["metrics"]["launch_count"], 1.0);
+
+    let id = request["id"].as_u64().unwrap();
+    let resp = serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": null });
+    write_framed(&mut server_write, &serde_json::to_vec(&resp).unwrap()).await;
+    timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap();
+}
+
+#[tokio::test]
+async fn server_send_telemetry_falls_back_to_namespaced_method_and_caches_it() {
+    let (client, mut server_read, mut server_write) = make_client();
+
+    let handle = tokio::spawn({
+        let client = client.clone();
+        async move {
+            client
+                .send_telemetry(ServerTelemetryEvent {
+                    kind: "app.launched".to_string(),
+                    client_name: "github/autopilot".to_string(),
+                    properties: Some(
+                        [("machine_id".to_string(), "machine-123".to_string())]
+                            .into_iter()
+                            .collect(),
+                    ),
+                    restricted_properties: None,
+                    metrics: Some([("launch_count".to_string(), 1.0)].into_iter().collect()),
+                })
+                .await?;
+            client
+                .send_telemetry(ServerTelemetryEvent {
+                    kind: "app.closed".to_string(),
+                    client_name: "github/autopilot".to_string(),
+                    properties: None,
+                    restricted_properties: None,
+                    metrics: None,
+                })
+                .await
+        }
+    });
+
+    let first_request = read_framed(&mut server_read).await;
+    assert_eq!(first_request["method"], "sendTelemetry");
+    let first_id = first_request["id"].as_u64().unwrap();
+    let first_response = serde_json::json!({
+        "jsonrpc": "2.0",
+        "id": first_id,
+        "error": {
+            "code": METHOD_NOT_FOUND,
+            "message": "Unhandled method sendTelemetry"
+        }
+    });
+    write_framed(
+        &mut server_write,
+        &serde_json::to_vec(&first_response).unwrap(),
+    )
+    .await;
+
+    let second_request = read_framed(&mut server_read).await;
+    assert_eq!(second_request["method"], "server.sendTelemetry");
+    assert_eq!(second_request["params"]["kind"], "app.launched");
+    assert_eq!(second_request["params"]["clientName"], "github/autopilot");
+    assert_eq!(
+        second_request["params"]["properties"]["machine_id"],
+        "machine-123"
+    );
+    assert_eq!(second_request["params"]["metrics"]["launch_count"], 1.0);
+
+    let second_id = second_request["id"].as_u64().unwrap();
+    let second_response = serde_json::json!({ "jsonrpc": "2.0", "id": second_id, "result": null });
+    write_framed(
+        &mut server_write,
+        &serde_json::to_vec(&second_response).unwrap(),
+    )
+    .await;
+
+    let third_request = read_framed(&mut server_read).await;
+    assert_eq!(third_request["method"], "server.sendTelemetry");
+    assert_eq!(third_request["params"]["kind"], "app.closed");
+
+    let third_id = third_request["id"].as_u64().unwrap();
+    let third_response = serde_json::json!({ "jsonrpc": "2.0", "id": third_id, "result": null });
+    write_framed(
+        &mut server_write,
+        &serde_json::to_vec(&third_response).unwrap(),
+    )
+    .await;
+
+    timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap();
+}
+
+#[tokio::test]
+async fn list_sessions_returns_typed_metadata() {
+    let (client, mut server_read, mut server_write) = make_client();
+
+    let handle = tokio::spawn({
+        let client = client.clone();
+        async move { client.list_sessions(None).await.unwrap() }
+    });
+
+    let request = read_framed(&mut server_read).await;
+    assert_eq!(request["method"], "session.list");
+    let id = request["id"].as_u64().unwrap();
+    let response = serde_json::json!({
+        "jsonrpc": "2.0",
+        "id": id,
+        "result": {
+            "sessions": [{
+                "sessionId": "s1",
+                "startTime": "2025-01-01T00:00:00Z",
+                "modifiedTime": "2025-01-01T01:00:00Z",
+                "summary": "test session",
+                "isRemote": false,
+            }]
+        },
+    });
+    write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await;
+
+    let sessions = timeout(TIMEOUT, handle).await.unwrap().unwrap();
+    assert_eq!(sessions.len(), 1);
+    assert_eq!(sessions[0].session_id, "s1");
+    assert_eq!(sessions[0].summary, Some("test session".to_string()));
+}
+
+#[tokio::test]
+async fn list_sessions_serializes_typed_filter() {
+    use github_copilot_sdk::SessionListFilter;
+
+    let (client, mut server_read, mut server_write) = make_client();
+
+    let filter = SessionListFilter {
+        repository: Some("octocat/hello".to_string()),
+        branch: Some("main".to_string()),
+        ..Default::default()
+    };
+
+    let handle = tokio::spawn({
+        let client = client.clone();
+        async move { client.list_sessions(Some(filter)).await.unwrap() }
+    });
+
+    let request = read_framed(&mut server_read).await;
+    assert_eq!(request["method"], "session.list");
+    assert_eq!(request["params"]["filter"]["repository"], "octocat/hello");
+    assert_eq!(request["params"]["filter"]["branch"], "main");
+    // cwd / gitRoot are None and must be omitted from the filter object.
+    assert!(request["params"]["filter"].get("cwd").is_none());
+    assert!(request["params"]["filter"].get("gitRoot").is_none());
+    // Regression check: filter must be wrapped under `params.filter`, not
+    // flattened onto `params` directly. All other SDKs (Node/Python/Go/.NET)
+    // wrap; flattening is silently ignored by the runtime.
+    assert!(
+        request["params"].get("repository").is_none(),
+        "wire shape is `params.filter.*`, not `params.*` — see Node/Go/Python/.NET"
+    );
+
+    let id = request["id"].as_u64().unwrap();
+    let response = serde_json::json!({
+        "jsonrpc": "2.0",
+        "id": id,
+        "result": { "sessions": [] },
+    });
+    write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await;
+
+    timeout(TIMEOUT, handle).await.unwrap().unwrap();
+}
+
+#[test]
+fn mcp_server_config_roundtrips_through_tagged_enum() {
+    use std::collections::HashMap;
+
+    use github_copilot_sdk::{McpServerConfig, McpStdioServerConfig};
+
+    let stdio = McpServerConfig::Stdio(McpStdioServerConfig {
+        command: "node".to_string(),
+        args: vec!["server.js".to_string()],
+        env: HashMap::new(),
+        cwd: None,
+        tools: vec!["*".to_string()],
+        timeout: None,
+    });
+    let json = serde_json::to_value(&stdio).unwrap();
+    assert_eq!(json["type"], "stdio");
+    assert_eq!(json["command"], "node");
+
+    // CLI may emit the legacy "local" alias; we accept it on the wire.
+    let local: McpServerConfig = serde_json::from_value(serde_json::json!({
+        "type": "local",
+        "command": "node",
+    }))
+    .unwrap();
+    assert!(matches!(local, McpServerConfig::Stdio(_)));
+
+    // SessionConfig.mcp_servers round-trips a typed map.
+    let mut servers = HashMap::new();
+    servers.insert("github".to_string(), stdio.clone());
+    let cfg_json = serde_json::to_value(&servers).unwrap();
+    assert_eq!(cfg_json["github"]["type"], "stdio");
+}
+
+#[test]
+fn permission_request_data_extracts_typed_kind() {
+    use github_copilot_sdk::{PermissionRequestData, PermissionRequestKind};
+
+    let data: PermissionRequestData = serde_json::from_value(serde_json::json!({
+        "kind": "shell",
+        "toolCallId": "t1",
+        "command": "ls",
+    }))
+    .unwrap();
+    assert_eq!(data.kind, Some(PermissionRequestKind::Shell));
+    assert_eq!(data.tool_call_id, Some("t1".to_string()));
+    assert_eq!(data.extra["command"], "ls");
+
+    let custom: PermissionRequestData = serde_json::from_value(serde_json::json!({
+        "kind": "custom-tool",
+    }))
+    .unwrap();
+    assert_eq!(custom.kind, Some(PermissionRequestKind::CustomTool));
+
+    // Unknown kinds fall through to the catch-all variant rather than failing.
+    let unknown: PermissionRequestData = serde_json::from_value(serde_json::json!({
+        "kind": "future-permission-type",
+    }))
+    .unwrap();
+    assert_eq!(unknown.kind, Some(PermissionRequestKind::Unknown));
+}
+
+#[tokio::test]
+async fn force_stop_is_idempotent_with_no_child() {
+    // Stream-based clients have no child process. force_stop should be a
+    // no-op and safe to call multiple times.
+    let (client, _server_read, _server_write) = make_client();
+    assert_eq!(
+        client.state(),
+        github_copilot_sdk::ConnectionState::Connected
+    );
+    client.force_stop();
+    assert_eq!(
+        client.state(),
+        github_copilot_sdk::ConnectionState::Disconnected
+    );
+    client.force_stop();
+    assert_eq!(
+        client.state(),
+        github_copilot_sdk::ConnectionState::Disconnected
+    );
+    assert!(client.pid().is_none());
+}
+
+#[tokio::test]
+async fn stop_transitions_state_to_disconnected() {
+    let (client, _server_read, _server_write) = make_client();
+    assert_eq!(
+        client.state(),
+        github_copilot_sdk::ConnectionState::Connected
+    );
+    client.stop().await.expect("stop should succeed");
+    assert_eq!(
+        client.state(),
+        github_copilot_sdk::ConnectionState::Disconnected
+    );
+}
+
+#[tokio::test]
+async fn lifecycle_subscribe_yields_events_with_filter() {
+    use github_copilot_sdk::{SessionLifecycleEventMetadata, SessionLifecycleEventType as Type};
+
+    let (client, _server_read, mut server_write) = make_client();
+
+    let mut all_events = client.subscribe_lifecycle();
+    let mut foreground_events = client.subscribe_lifecycle();
+
+    let wildcard_count = Arc::new(AtomicUsize::new(0));
+    let foreground_count = Arc::new(AtomicUsize::new(0));
+    let last_session = Arc::new(parking_lot::Mutex::new(None));
+
+    let w_count = wildcard_count.clone();
+    let w_last = last_session.clone();
+    let w_consumer = tokio::spawn(async move {
+        while let Ok(event) = all_events.recv().await {
+            w_count.fetch_add(1, Ordering::Relaxed);
+            *w_last.lock() = Some(event.session_id.clone());
+        }
+    });
+    let f_count = foreground_count.clone();
+    let f_consumer = tokio::spawn(async move {
+        while let Ok(event) = foreground_events.recv().await {
+            if event.event_type == Type::Foreground {
+                f_count.fetch_add(1, Ordering::Relaxed);
+            }
+        }
+    });
+
+    let body1 = serde_json::to_vec(&serde_json::json!({
+        "jsonrpc": "2.0",
+        "method": "session.lifecycle",
+        "params": { "type": "session.created", "sessionId": "s1" },
+    }))
+    .unwrap();
+    let body2 = serde_json::to_vec(&serde_json::json!({
+        "jsonrpc": "2.0",
+        "method": "session.lifecycle",
+        "params": {
+            "type": "session.foreground",
+            "sessionId": "s2",
+            "metadata": {
+                "startTime": "2025-01-01T00:00:00Z",
+                "modifiedTime": "2025-01-02T00:00:00Z",
+                "summary": "hello",
+            },
+        },
+    }))
+    .unwrap();
+    let body3 = serde_json::to_vec(&serde_json::json!({
+        "jsonrpc": "2.0",
+        "method": "session.event",
+        "params": { "sessionId": "ignored", "event": {
+            "id": "x", "timestamp": "t", "type": "noop", "data": {}
+        }},
+    }))
+    .unwrap();
+    write_framed(&mut server_write, &body1).await;
+    write_framed(&mut server_write, &body2).await;
+    write_framed(&mut server_write, &body3).await;
+
+    for _ in 0..50 {
+        if wildcard_count.load(Ordering::Relaxed) >= 2 {
+            break;
+        }
+        tokio::time::sleep(Duration::from_millis(20)).await;
+    }
+    assert_eq!(wildcard_count.load(Ordering::Relaxed), 2);
+    assert_eq!(foreground_count.load(Ordering::Relaxed), 1);
+    assert_eq!(last_session.lock().as_deref(), Some("s2"));
+    w_consumer.abort();
+    f_consumer.abort();
+
+    let meta = SessionLifecycleEventMetadata {
+        start_time: "t1".into(),
+        modified_time: "t2".into(),
+        summary: Some("s".into()),
+    };
+    assert_eq!(meta.summary.as_deref(), Some("s"));
+}
+
+#[tokio::test]
+async fn lifecycle_subscribe_drop_stops_delivery() {
+    let (client, _server_read, mut server_write) = make_client();
+
+    let mut events = client.subscribe_lifecycle();
+    let count = Arc::new(AtomicUsize::new(0));
+    let count_clone = count.clone();
+    let consumer = tokio::spawn(async move {
+        while let Ok(_event) = events.recv().await {
+            count_clone.fetch_add(1, Ordering::Relaxed);
+        }
+    });
+
+    let lifecycle_body = serde_json::to_vec(&serde_json::json!({
+        "jsonrpc": "2.0",
+        "method": "session.lifecycle",
+        "params": { "type": "session.created", "sessionId": "x" },
+    }))
+    .unwrap();
+
+    write_framed(&mut server_write, &lifecycle_body).await;
+    for _ in 0..50 {
+        if count.load(Ordering::Relaxed) >= 1 {
+            break;
+        }
+        tokio::time::sleep(Duration::from_millis(20)).await;
+    }
+    assert_eq!(count.load(Ordering::Relaxed), 1);
+
+    consumer.abort();
+    tokio::time::sleep(Duration::from_millis(20)).await;
+
+    write_framed(&mut server_write, &lifecycle_body).await;
+    tokio::time::sleep(Duration::from_millis(100)).await;
+    assert_eq!(count.load(Ordering::Relaxed), 1);
+}
+
+#[tokio::test]
+async fn delete_session_sends_session_id() {
+    let (client, mut server_read, mut server_write) = make_client();
+
+    let handle = tokio::spawn({
+        let client = client.clone();
+        async move { client.delete_session(&SessionId::new("s-to-delete")).await }
+    });
+
+    let request = read_framed(&mut server_read).await;
+    assert_eq!(request["method"], "session.delete");
+    assert_eq!(request["params"]["sessionId"], "s-to-delete");
+
+    let id = request["id"].as_u64().unwrap();
+    let resp = serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": {} });
+    write_framed(&mut server_write, &serde_json::to_vec(&resp).unwrap()).await;
+    timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap();
+}
+
+#[tokio::test]
+async fn get_last_session_id_returns_none_when_empty() {
+    let (client, mut server_read, mut server_write) = make_client();
+
+    let handle = tokio::spawn({
+        let client = client.clone();
+        async move { client.get_last_session_id().await.unwrap() }
+    });
+
+    let request = read_framed(&mut server_read).await;
+    assert_eq!(request["method"], "session.getLastId");
+
+    let id = request["id"].as_u64().unwrap();
+    let response = serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": {} });
+    write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await;
+
+    let last = timeout(TIMEOUT, handle).await.unwrap().unwrap();
+    assert!(last.is_none());
+}
+
+#[tokio::test]
+async fn get_last_session_id_returns_id_when_set() {
+    let (client, mut server_read, mut server_write) = make_client();
+
+    let handle = tokio::spawn({
+        let client = client.clone();
+        async move { client.get_last_session_id().await.unwrap() }
+    });
+
+    let request = read_framed(&mut server_read).await;
+    assert_eq!(request["method"], "session.getLastId");
+
+    let id = request["id"].as_u64().unwrap();
+    let response = serde_json::json!({
+        "jsonrpc": "2.0",
+        "id": id,
+        "result": { "sessionId": "s-last" },
+    });
+    write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await;
+
+    let last = timeout(TIMEOUT, handle).await.unwrap().unwrap();
+    assert_eq!(last.as_deref(), Some("s-last"));
+}
+
+#[tokio::test]
+async fn get_foreground_session_id_returns_id_when_set() {
+    let (client, mut server_read, mut server_write) = make_client();
+
+    let handle = tokio::spawn({
+        let client = client.clone();
+        async move { client.get_foreground_session_id().await.unwrap() }
+    });
+
+    let request = read_framed(&mut server_read).await;
+    assert_eq!(request["method"], "session.getForeground");
+
+    let id = request["id"].as_u64().unwrap();
+    let response = serde_json::json!({
+        "jsonrpc": "2.0",
+        "id": id,
+        "result": { "sessionId": "s-fg" },
+    });
+    write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await;
+
+    let fg = timeout(TIMEOUT, handle).await.unwrap().unwrap();
+    assert_eq!(fg.as_deref(), Some("s-fg"));
+}
+
+#[tokio::test]
+async fn set_foreground_session_id_sends_session_id() {
+    let (client, mut server_read, mut server_write) = make_client();
+
+    let handle = tokio::spawn({
+        let client = client.clone();
+        async move {
+            client
+                .set_foreground_session_id(&SessionId::new("s-target"))
+                .await
+        }
+    });
+
+    let request = read_framed(&mut server_read).await;
+    assert_eq!(request["method"], "session.setForeground");
+    assert_eq!(request["params"]["sessionId"], "s-target");
+
+    let id = request["id"].as_u64().unwrap();
+    let resp = serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": {} });
+    write_framed(&mut server_write, &serde_json::to_vec(&resp).unwrap()).await;
+    timeout(TIMEOUT,
handle).await.unwrap().unwrap().unwrap(); +} + +#[tokio::test] +async fn get_session_metadata_returns_typed_metadata() { + let (client, mut server_read, mut server_write) = make_client(); + + let handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .get_session_metadata(&SessionId::new("s1")) + .await + .unwrap() + } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.getMetadata"); + assert_eq!(request["params"]["sessionId"], "s1"); + + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { + "session": { + "sessionId": "s1", + "startTime": "2025-01-01T00:00:00Z", + "modifiedTime": "2025-01-01T01:00:00Z", + "summary": "loaded session", + "isRemote": false, + } + }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let metadata = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + let metadata = metadata.expect("server returned a session"); + assert_eq!(metadata.session_id, "s1"); + assert_eq!(metadata.summary.as_deref(), Some("loaded session")); +} + +#[tokio::test] +async fn get_session_metadata_returns_none_when_missing() { + let (client, mut server_read, mut server_write) = make_client(); + + let handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .get_session_metadata(&SessionId::new("missing")) + .await + .unwrap() + } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.getMetadata"); + + let id = request["id"].as_u64().unwrap(); + // Server responds with an empty result object; `session` is absent. 
+ let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": {}, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let metadata = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + assert!(metadata.is_none()); +} + +#[tokio::test] +async fn list_models_returns_typed_model_info() { + let (client, mut server_read, mut server_write) = make_client(); + + let handle = tokio::spawn({ + let client = client.clone(); + async move { client.list_models().await.unwrap() } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "models.list"); + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { + "models": [ + { "id": "gpt-4", "name": "GPT-4", "capabilities": {} }, + { "id": "claude-sonnet-4", "name": "Claude Sonnet", "capabilities": {} }, + ] + }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let models = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + assert_eq!(models.len(), 2); + assert_eq!(models[0].id, "gpt-4"); + assert_eq!(models[1].name, "Claude Sonnet"); +} + +#[tokio::test] +async fn get_messages_returns_typed_events() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { session.get_messages().await.unwrap() } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.getMessages"); + server + .respond( + &request, + serde_json::json!({ + "events": [{ + "id": "e1", + "timestamp": "2025-01-01T00:00:00Z", + "type": "user.message", + "data": { "text": "hello" }, + }] + }), + ) + .await; + + let events = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + assert_eq!(events.len(), 1); + assert_eq!(events[0].event_type, "user.message"); +} + +#[tokio::test] +async fn 
set_model_sends_switch_to_request() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { session.set_model("claude-sonnet-4", None).await.unwrap() } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.model.switchTo"); + assert_eq!(request["params"]["modelId"], "claude-sonnet-4"); + server + .respond( + &request, + serde_json::json!({ "modelId": "claude-sonnet-4" }), + ) + .await; + + timeout(TIMEOUT, handle).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn get_name_returns_name() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { session.get_name().await.unwrap() } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.name.get"); + server + .respond(&request, serde_json::json!({ "name": "Fix input flicker" })) + .await; + + assert_eq!( + timeout(TIMEOUT, handle).await.unwrap().unwrap(), + Some("Fix input flicker".to_string()) + ); +} + +#[tokio::test] +async fn set_name_sends_name() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { session.set_name("Fix input flicker").await.unwrap() } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.name.set"); + assert_eq!(request["params"]["name"], "Fix input flicker"); + server.respond(&request, serde_json::json!(null)).await; + + timeout(TIMEOUT, handle).await.unwrap().unwrap(); +} + +#[tokio::test] +async fn elicitation_returns_typed_result() { + let (session, mut server) = create_session_pair_with_capabilities( + Arc::new(NoopHandler), + serde_json::json!({ "ui": 
{ "elicitation": true } }), + ) + .await; + let session = Arc::new(session); + let schema = serde_json::json!({ + "type": "object", + "properties": { "name": { "type": "string" } }, + }); + + let handle = tokio::spawn({ + let session = session.clone(); + let schema = schema.clone(); + async move { + session + .ui() + .elicitation("Enter your name", schema) + .await + .unwrap() + } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.ui.elicitation"); + assert_eq!(request["params"]["message"], "Enter your name"); + assert_eq!(request["params"]["requestedSchema"], schema); + assert!( + request["params"].get("schema").is_none(), + "wire field is `requestedSchema`, not `schema`" + ); + server + .respond( + &request, + serde_json::json!({ "action": "accept", "content": { "name": "Octocat" } }), + ) + .await; + + let result = timeout(TIMEOUT, handle).await.unwrap().unwrap(); + assert_eq!(result.action, "accept"); + assert_eq!(result.content.unwrap()["name"], "Octocat"); +} + +#[tokio::test] +async fn tool_call_dispatches_to_handler() { + struct ToolHandler; + #[async_trait] + impl SessionHandler for ToolHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::ExternalTool { invocation } => { + assert_eq!(invocation.tool_name, "read_file"); + HandlerResponse::ToolResult(ToolResult::Text("file contents here".to_string())) + } + _ => HandlerResponse::Ok, + } + } + } + + let (_session, mut server) = create_session_pair(Arc::new(ToolHandler)).await; + server + .send_request( + 100, + "tool.call", + serde_json::json!({ + "sessionId": server.session_id, + "toolCallId": "tc-1", + "toolName": "read_file", + "arguments": { "path": "/foo.txt" }, + }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 100); + assert_eq!(response["result"]["result"], "file contents here"); +} + +#[tokio::test] +async fn 
permission_request_dispatches_to_handler() { + struct DenyHandler; + #[async_trait] + impl SessionHandler for DenyHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::PermissionRequest { .. } => { + HandlerResponse::Permission(PermissionResult::Denied) + } + _ => HandlerResponse::Ok, + } + } + } + + let (_session, mut server) = create_session_pair(Arc::new(DenyHandler)).await; + server + .send_request( + 200, + "permission.request", + serde_json::json!({ + "sessionId": server.session_id, + "requestId": "perm-1", + "kind": "shell", + }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 200); + assert_eq!(response["result"]["kind"], "reject"); +} + +#[tokio::test] +async fn user_input_request_dispatches_to_handler() { + struct InputHandler; + #[async_trait] + impl SessionHandler for InputHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::UserInput { question, .. } => { + assert_eq!(question, "Pick a color"); + HandlerResponse::UserInput(Some(UserInputResponse { + answer: "blue".to_string(), + was_freeform: true, + })) + } + _ => HandlerResponse::Ok, + } + } + } + + let (_session, mut server) = create_session_pair(Arc::new(InputHandler)).await; + server + .send_request( + 300, + "userInput.request", + serde_json::json!({ + "sessionId": server.session_id, + "question": "Pick a color", + "choices": ["red", "blue"], + "allowFreeform": true, + }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 300); + assert_eq!(response["result"]["answer"], "blue"); + assert_eq!(response["result"]["wasFreeform"], true); +} + +#[tokio::test] +async fn user_input_requested_notification_does_not_double_dispatch() { + use std::sync::atomic::{AtomicUsize, Ordering}; + // Regression for github/github-app#4249. 
The CLI sends BOTH a
+    // `user_input.requested` notification (for observers) AND a
+    // `userInput.request` JSON-RPC call (the actual prompt) for every
+    // user-input prompt. Only the JSON-RPC path should reach the
+    // handler — dispatching from the notification too produced
+    // duplicate ask_user widgets on the consumer side.
+
+    struct CountingHandler {
+        invocations: Arc<AtomicUsize>,
+    }
+    #[async_trait]
+    impl SessionHandler for CountingHandler {
+        async fn on_event(&self, event: HandlerEvent) -> HandlerResponse {
+            if let HandlerEvent::UserInput { .. } = event {
+                self.invocations.fetch_add(1, Ordering::SeqCst);
+                return HandlerResponse::UserInput(Some(UserInputResponse {
+                    answer: "ok".to_string(),
+                    was_freeform: true,
+                }));
+            }
+            HandlerResponse::Ok
+        }
+    }
+
+    let invocations = Arc::new(AtomicUsize::new(0));
+    let handler = Arc::new(CountingHandler {
+        invocations: invocations.clone(),
+    });
+    let (_session, mut server) = create_session_pair(handler).await;
+
+    server
+        .send_event(
+            "user_input.requested",
+            serde_json::json!({
+                "requestId": "ui-1",
+                "question": "Allow shell access?",
+                "choices": ["Yes", "No"],
+                "allowFreeform": false,
+            }),
+        )
+        .await;
+
+    // Give the SDK a beat to (incorrectly) auto-dispatch if the
+    // regression returned. Nothing should arrive on the wire.
+    let respond_observed = timeout(Duration::from_millis(150), server.read_request()).await;
+    assert!(
+        respond_observed.is_err(),
+        "notification triggered unexpected wire activity: {respond_observed:?}",
+    );
+    assert_eq!(
+        invocations.load(Ordering::SeqCst),
+        0,
+        "notification path must not invoke the user-input handler",
+    );
+
+    // Now drive the JSON-RPC path and confirm the handler still runs once.
+    server
+        .send_request(
+            301,
+            "userInput.request",
+            serde_json::json!({
+                "sessionId": server.session_id,
+                "question": "Pick a color",
+                "allowFreeform": true,
+            }),
+        )
+        .await;
+    let response = timeout(TIMEOUT, server.read_response()).await.unwrap();
+    assert_eq!(response["id"], 301);
+    assert_eq!(response["result"]["answer"], "ok");
+    assert_eq!(invocations.load(Ordering::SeqCst), 1);
+}
+
+#[tokio::test]
+async fn exit_plan_mode_dispatches_to_handler() {
+    struct PlanHandler;
+    #[async_trait]
+    impl SessionHandler for PlanHandler {
+        async fn on_event(&self, event: HandlerEvent) -> HandlerResponse {
+            match event {
+                HandlerEvent::ExitPlanMode { .. } => {
+                    HandlerResponse::ExitPlanMode(ExitPlanModeResult {
+                        approved: true,
+                        selected_action: Some("autopilot".to_string()),
+                        feedback: None,
+                    })
+                }
+                _ => HandlerResponse::Ok,
+            }
+        }
+    }
+
+    let (_session, mut server) = create_session_pair(Arc::new(PlanHandler)).await;
+    server
+        .send_request(
+            400,
+            "exitPlanMode.request",
+            serde_json::json!({ "sessionId": server.session_id, "plan": "do the thing" }),
+        )
+        .await;
+
+    let response = timeout(TIMEOUT, server.read_response()).await.unwrap();
+    assert_eq!(response["result"]["approved"], true);
+    assert_eq!(response["result"]["selectedAction"], "autopilot");
+}
+
+#[tokio::test]
+async fn auto_mode_switch_dispatches_to_handler_and_serializes_response() {
+    use std::sync::atomic::{AtomicUsize, Ordering};
+
+    struct AutoModeHandler {
+        calls: Arc<AtomicUsize>,
+        last_error_code: Arc<parking_lot::Mutex<Option<String>>>,
+        last_retry_after: Arc<parking_lot::Mutex<Option<u64>>>,
+    }
+    #[async_trait]
+    impl SessionHandler for AutoModeHandler {
+        async fn on_auto_mode_switch(
+            &self,
+            _session_id: github_copilot_sdk::types::SessionId,
+            error_code: Option<String>,
+            retry_after_seconds: Option<u64>,
+        ) -> AutoModeSwitchResponse {
+            self.calls.fetch_add(1, Ordering::SeqCst);
+            *self.last_error_code.lock() = error_code;
+            *self.last_retry_after.lock() = retry_after_seconds;
+            AutoModeSwitchResponse::YesAlways
+        }
+    }
+
+    let calls =
Arc::new(AtomicUsize::new(0)); + let last_error_code = Arc::new(parking_lot::Mutex::new(None)); + let last_retry_after = Arc::new(parking_lot::Mutex::new(None)); + let (_session, mut server) = create_session_pair(Arc::new(AutoModeHandler { + calls: calls.clone(), + last_error_code: last_error_code.clone(), + last_retry_after: last_retry_after.clone(), + })) + .await; + + server + .send_request( + 700, + "autoModeSwitch.request", + serde_json::json!({ + "sessionId": server.session_id, + "errorCode": "user_weekly_rate_limited", + "retryAfterSeconds": 3600, + }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 700); + assert_eq!(response["result"]["response"], "yes_always"); + assert_eq!(calls.load(Ordering::SeqCst), 1); + assert_eq!( + last_error_code.lock().as_deref(), + Some("user_weekly_rate_limited") + ); + assert_eq!(*last_retry_after.lock(), Some(3600)); +} + +#[tokio::test] +async fn auto_mode_switch_default_handler_replies_no() { + let (_session, mut server) = create_session_pair(Arc::new(ApproveAllHandler)).await; + + server + .send_request( + 701, + "autoModeSwitch.request", + serde_json::json!({ + "sessionId": server.session_id, + }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["result"]["response"], "no"); +} + +#[tokio::test] +async fn approve_all_handler_approves_permission_and_plan() { + let (_session, mut server) = create_session_pair(Arc::new(ApproveAllHandler)).await; + + server + .send_request( + 500, + "permission.request", + serde_json::json!({ + "sessionId": server.session_id, + "requestId": "perm-auto", + "kind": "shell", + }), + ) + .await; + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["result"]["kind"], "approve-once"); + + server + .send_request( + 501, + "exitPlanMode.request", + serde_json::json!({ "sessionId": server.session_id, "plan": "go" }), + ) + 
.await;
+    let response = timeout(TIMEOUT, server.read_response()).await.unwrap();
+    assert_eq!(response["result"]["approved"], true);
+}
+
+#[tokio::test]
+async fn session_event_notification_reaches_handler() {
+    let (event_tx, mut event_rx) = mpsc::unbounded_channel::<String>();
+
+    struct EventCollector {
+        tx: mpsc::UnboundedSender<String>,
+    }
+    #[async_trait]
+    impl SessionHandler for EventCollector {
+        async fn on_event(&self, event: HandlerEvent) -> HandlerResponse {
+            if let HandlerEvent::SessionEvent { event, .. } = event {
+                self.tx.send(event.event_type).unwrap();
+            }
+            HandlerResponse::Ok
+        }
+    }
+
+    let (_session, mut server) =
+        create_session_pair(Arc::new(EventCollector { tx: event_tx })).await;
+    server
+        .send_event("session.idle", serde_json::json!({}))
+        .await;
+
+    let event_type = timeout(TIMEOUT, event_rx.recv()).await.unwrap().unwrap();
+    assert_eq!(event_type, "session.idle");
+}
+
+#[tokio::test]
+async fn router_routes_to_correct_session() {
+    let (client, mut server_read, mut server_write) = make_client();
+    let (tx1, mut rx1) = mpsc::unbounded_channel::<String>();
+    let (tx2, mut rx2) = mpsc::unbounded_channel::<String>();
+
+    struct Collector {
+        tx: mpsc::UnboundedSender<String>,
+    }
+    #[async_trait]
+    impl SessionHandler for Collector {
+        async fn on_event(&self, event: HandlerEvent) -> HandlerResponse {
+            if let HandlerEvent::SessionEvent { event, ..
} = event {
+                self.tx.send(event.event_type).unwrap();
+            }
+            HandlerResponse::Ok
+        }
+    }
+
+    // Create two sessions on the same client
+    let mut sessions = Vec::new();
+    for (tx, sid) in [(tx1, "s-one"), (tx2, "s-two")] {
+        let h = tokio::spawn({
+            let client = client.clone();
+            async move {
+                client
+                    .create_session(
+                        SessionConfig::default().with_handler(Arc::new(Collector { tx })),
+                    )
+                    .await
+                    .unwrap()
+            }
+        });
+        let req = read_framed(&mut server_read).await;
+        let id = req["id"].as_u64().unwrap();
+        let resp = serde_json::json!({
+            "jsonrpc": "2.0", "id": id,
+            "result": { "sessionId": sid },
+        });
+        write_framed(&mut server_write, &serde_json::to_vec(&resp).unwrap()).await;
+        sessions.push(timeout(TIMEOUT, h).await.unwrap().unwrap());
+    }
+
+    // Event for s-two should only reach rx2
+    let notif = serde_json::json!({
+        "jsonrpc": "2.0",
+        "method": "session.event",
+        "params": {
+            "sessionId": "s-two",
+            "event": { "id": "e1", "timestamp": "2025-01-01T00:00:00Z", "type": "assistant.message", "data": {} },
+        },
+    });
+    write_framed(&mut server_write, &serde_json::to_vec(&notif).unwrap()).await;
+    assert_eq!(
+        timeout(TIMEOUT, rx2.recv()).await.unwrap().unwrap(),
+        "assistant.message"
+    );
+    assert!(rx1.try_recv().is_err());
+
+    // Event for s-one should only reach rx1
+    let notif = serde_json::json!({
+        "jsonrpc": "2.0",
+        "method": "session.event",
+        "params": {
+            "sessionId": "s-one",
+            "event": { "id": "e2", "timestamp": "2025-01-01T00:00:00Z", "type": "session.idle", "data": {} },
+        },
+    });
+    write_framed(&mut server_write, &serde_json::to_vec(&notif).unwrap()).await;
+    assert_eq!(
+        timeout(TIMEOUT, rx1.recv()).await.unwrap().unwrap(),
+        "session.idle"
+    );
+    assert!(rx2.try_recv().is_err());
+}
+
+#[tokio::test]
+async fn send_and_wait_returns_last_assistant_message_on_idle() {
+    let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await;
+    let session = Arc::new(session);
+
+    let handle = tokio::spawn({
+        let session =
session.clone(); + async move { + session + .send_and_wait( + MessageOptions::new("hello").with_wait_timeout(Duration::from_secs(5)), + ) + .await + } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.send"); + server.respond(&request, serde_json::json!({})).await; + + server + .send_event( + "assistant.message", + serde_json::json!({ "message": "Hello back!" }), + ) + .await; + server + .send_event("session.idle", serde_json::json!({})) + .await; + + let result = timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); + let event = result.expect("should have captured assistant.message"); + assert_eq!(event.event_type, "assistant.message"); + assert_eq!(event.data["message"], "Hello back!"); +} + +#[tokio::test] +async fn send_and_wait_returns_error_on_session_error() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { + session + .send_and_wait( + MessageOptions::new("fail").with_wait_timeout(Duration::from_secs(5)), + ) + .await + } + }); + + let request = server.read_request().await; + server.respond(&request, serde_json::json!({})).await; + server + .send_event( + "session.error", + serde_json::json!({ "message": "something went wrong" }), + ) + .await; + + let err = timeout(TIMEOUT, handle) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!( + matches!(err, github_copilot_sdk::Error::Session(github_copilot_sdk::SessionError::AgentError(ref msg)) if msg.contains("something went wrong")) + ); +} + +#[tokio::test] +async fn send_and_wait_times_out() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let handle = tokio::spawn({ + let session = session.clone(); + async move { + session + .send_and_wait( + MessageOptions::new("hello").with_wait_timeout(Duration::from_millis(100)), + ) + .await + } + 
}); + + let request = server.read_request().await; + server.respond(&request, serde_json::json!({})).await; + + let err = timeout(Duration::from_secs(2), handle) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!(matches!( + err, + github_copilot_sdk::Error::Session(github_copilot_sdk::SessionError::Timeout(_)) + )); +} + +/// Cancel-safety regression: an outer `tokio::time::timeout` around +/// `send_and_wait` must NOT leak the `idle_waiter` slot. After the outer +/// timeout fires and drops the future, subsequent `send` and +/// `send_and_wait` calls must succeed without `SendWhileWaiting`. +/// +/// Closes RFD-400 review finding #2. +#[tokio::test] +async fn send_and_wait_outer_cancellation_clears_waiter() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + // First call: wrap in outer timeout much shorter than the inner + // wait_timeout. The outer timeout expires, dropping the + // send_and_wait future before the idle/error event arrives. + let handle = tokio::spawn({ + let session = session.clone(); + async move { + tokio::time::timeout( + Duration::from_millis(50), + session.send_and_wait( + MessageOptions::new("first").with_wait_timeout(Duration::from_secs(60)), + ), + ) + .await + } + }); + + let request = server.read_request().await; + server.respond(&request, serde_json::json!({})).await; + + // Outer timeout fires → Err(Elapsed) returned, future is dropped. + let outer_result = timeout(Duration::from_secs(2), handle) + .await + .unwrap() + .unwrap(); + assert!(outer_result.is_err(), "outer timeout should have elapsed"); + + // The WaiterGuard's Drop should have cleared the slot. A subsequent + // `send` must NOT return SendWhileWaiting. 
+ let send_handle = tokio::spawn({ + let session = session.clone(); + async move { session.send("second").await } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.send"); + assert_eq!(request["params"]["prompt"], "second"); + server + .respond( + &request, + serde_json::json!({ "messageId": "msg-after-cancel" }), + ) + .await; + + let result = timeout(TIMEOUT, send_handle).await.unwrap().unwrap(); + assert_eq!(result.unwrap(), "msg-after-cancel"); +} + +/// Cancel-safety regression: explicitly dropping the JoinHandle of an +/// in-flight `send_and_wait` must clear the waiter slot via WaiterGuard's +/// Drop. The next `send` must succeed. +/// +/// Closes RFD-400 review finding #2. +#[tokio::test] +async fn send_and_wait_drop_clears_waiter() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + // Start a send_and_wait, let it install the waiter, then abort the + // task before any idle/error event arrives. + let handle = tokio::spawn({ + let session = session.clone(); + async move { + session + .send_and_wait( + MessageOptions::new("aborted").with_wait_timeout(Duration::from_secs(60)), + ) + .await + } + }); + + // Drain the session.send RPC so we know the waiter is installed. + let request = server.read_request().await; + server.respond(&request, serde_json::json!({})).await; + + // Now abort the in-flight send_and_wait. The WaiterGuard drops as + // the future unwinds, clearing the slot. + handle.abort(); + let _ = handle.await; + + // Give the runtime a moment to run the drop. + tokio::task::yield_now().await; + + // Next `send` must succeed — no SendWhileWaiting. 
+ let send_handle = tokio::spawn({ + let session = session.clone(); + async move { session.send("after-abort").await } + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.send"); + assert_eq!(request["params"]["prompt"], "after-abort"); + server + .respond( + &request, + serde_json::json!({ "messageId": "msg-after-abort" }), + ) + .await; + + let result = timeout(TIMEOUT, send_handle).await.unwrap().unwrap(); + assert_eq!(result.unwrap(), "msg-after-abort"); +} + +/// Cancel-safety regression: `Session::stop_event_loop` must NOT abort +/// the event-loop task mid-handler. An in-flight handler (here a slow +/// `userInput.request` callback) must run to completion before the loop +/// exits — the CLI receives the response on the wire before the session +/// tears down. +/// +/// Closes RFD-400 review finding #3. +#[tokio::test] +async fn stop_event_loop_completes_in_flight_handler() { + struct SlowHandler; + #[async_trait] + impl SessionHandler for SlowHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::UserInput { .. } => { + // Sleep so stop_event_loop has a chance to fire while + // the handler is mid-flight. The loop must wait for + // this to return rather than abort it. + tokio::time::sleep(Duration::from_millis(150)).await; + HandlerResponse::UserInput(Some(UserInputResponse { + answer: "completed".to_string(), + was_freeform: false, + })) + } + _ => HandlerResponse::Ok, + } + } + } + + let (session, mut server) = create_session_pair(Arc::new(SlowHandler)).await; + let session = Arc::new(session); + + server + .send_request( + 900, + "userInput.request", + serde_json::json!({ + "sessionId": server.session_id, + "question": "slow", + "choices": null, + "allowFreeform": true, + }), + ) + .await; + + // Give the loop a moment to dispatch into the handler. + tokio::time::sleep(Duration::from_millis(20)).await; + + // Now request shutdown. 
The loop is parked in handle_request awaiting
+    // the slow handler. `notify_one()` buffers the signal until the loop
+    // re-enters its select, which can only happen after the handler
+    // returns and the response is sent on the wire.
+    let stop_handle = tokio::spawn({
+        let session = session.clone();
+        async move { session.stop_event_loop().await }
+    });
+
+    // Verify the handler's response lands on the wire BEFORE the loop
+    // exits — i.e. stop_event_loop did not abort mid-handler.
+    let response = timeout(Duration::from_secs(2), server.read_response())
+        .await
+        .unwrap();
+    assert_eq!(response["id"], 900);
+    assert_eq!(response["result"]["answer"], "completed");
+
+    // stop_event_loop completes after the handler returns and the loop
+    // observes the buffered shutdown signal on its next select iteration.
+    timeout(Duration::from_secs(2), stop_handle)
+        .await
+        .unwrap()
+        .unwrap();
+}
+
+/// Cancel-safety regression: dropping a Session does NOT abort the event
+/// loop mid-handler. The loop sees the buffered shutdown signal on its
+/// next select iteration and exits cleanly. This is the Drop equivalent
+/// of stop_event_loop_completes_in_flight_handler; closes RFD-400 review
+/// finding #3 for the implicit-drop path that used to call
+/// `JoinHandle::abort()`.
+#[tokio::test]
+async fn drop_session_does_not_abort_handler() {
+    use std::sync::atomic::{AtomicBool, Ordering};
+
+    let handler_completed = Arc::new(AtomicBool::new(false));
+
+    struct CompletionHandler {
+        completed: Arc<AtomicBool>,
+    }
+    #[async_trait]
+    impl SessionHandler for CompletionHandler {
+        async fn on_event(&self, event: HandlerEvent) -> HandlerResponse {
+            match event {
+                HandlerEvent::UserInput { ..
} => { + tokio::time::sleep(Duration::from_millis(100)).await; + self.completed.store(true, Ordering::SeqCst); + HandlerResponse::UserInput(Some(UserInputResponse { + answer: "done".to_string(), + was_freeform: false, + })) + } + _ => HandlerResponse::Ok, + } + } + } + + let (session, mut server) = create_session_pair(Arc::new(CompletionHandler { + completed: handler_completed.clone(), + })) + .await; + + server + .send_request( + 901, + "userInput.request", + serde_json::json!({ + "sessionId": server.session_id, + "question": "drop-test", + "choices": null, + "allowFreeform": true, + }), + ) + .await; + + tokio::time::sleep(Duration::from_millis(20)).await; + drop(session); + + let response = timeout(Duration::from_secs(2), server.read_response()) + .await + .unwrap(); + assert_eq!(response["id"], 901); + assert_eq!(response["result"]["answer"], "done"); + assert!( + handler_completed.load(Ordering::SeqCst), + "handler must run to completion despite Session being dropped" + ); +} + +#[tokio::test] +async fn elicitation_requested_dispatches_to_handler_and_responds() { + use github_copilot_sdk::types::ElicitationResult; + + struct ElicitHandler; + #[async_trait] + impl SessionHandler for ElicitHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::ElicitationRequest { request, .. 
} => { + assert_eq!(request.message, "Enter your name"); + HandlerResponse::Elicitation(ElicitationResult { + action: "accept".to_string(), + content: Some(serde_json::json!({ "name": "Alice" })), + }) + } + _ => HandlerResponse::Ok, + } + } + } + + let (_session, mut server) = create_session_pair(Arc::new(ElicitHandler)).await; + + // CLI broadcasts elicitation.requested as a session event notification + server + .send_event( + "elicitation.requested", + serde_json::json!({ + "requestId": "elicit-1", + "message": "Enter your name", + "requestedSchema": { + "type": "object", + "properties": { "name": { "type": "string" } }, + "required": ["name"] + }, + "mode": "form", + }), + ) + .await; + + // The SDK should call session.ui.handlePendingElicitation RPC + let rpc_call = timeout(TIMEOUT, server.read_request()).await.unwrap(); + assert_eq!(rpc_call["method"], "session.ui.handlePendingElicitation"); + assert_eq!(rpc_call["params"]["requestId"], "elicit-1"); + assert_eq!(rpc_call["params"]["result"]["action"], "accept"); + assert_eq!(rpc_call["params"]["result"]["content"]["name"], "Alice"); +} + +#[tokio::test] +async fn elicitation_requested_cancels_on_handler_error() { + struct FailHandler; + #[async_trait] + impl SessionHandler for FailHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + // Return Ok instead of Elicitation — SDK should treat as cancel + HandlerEvent::ElicitationRequest { .. 
} => HandlerResponse::Ok, + _ => HandlerResponse::Ok, + } + } + } + + let (_session, mut server) = create_session_pair(Arc::new(FailHandler)).await; + server + .send_event( + "elicitation.requested", + serde_json::json!({ + "requestId": "elicit-2", + "message": "Pick something", + }), + ) + .await; + + let rpc_call = timeout(TIMEOUT, server.read_request()).await.unwrap(); + assert_eq!(rpc_call["method"], "session.ui.handlePendingElicitation"); + assert_eq!(rpc_call["params"]["result"]["action"], "cancel"); +} + +#[tokio::test] +async fn external_tool_requested_dispatches_to_handler_and_responds() { + struct ExternalToolHandler; + #[async_trait] + impl SessionHandler for ExternalToolHandler { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::ExternalTool { invocation } => { + assert_eq!(invocation.tool_name, "run_tests"); + assert_eq!(invocation.tool_call_id, "tc-ext-1"); + assert_eq!(invocation.arguments["suite"], "unit"); + HandlerResponse::ToolResult(ToolResult::Text("all tests passed".to_string())) + } + _ => HandlerResponse::Ok, + } + } + } + + let (_session, mut server) = create_session_pair(Arc::new(ExternalToolHandler)).await; + + server + .send_event( + "external_tool.requested", + serde_json::json!({ + "requestId": "req-ext-1", + "sessionId": server.session_id, + "toolCallId": "tc-ext-1", + "toolName": "run_tests", + "arguments": { "suite": "unit" }, + }), + ) + .await; + + let rpc_call = timeout(TIMEOUT, server.read_request()).await.unwrap(); + assert_eq!(rpc_call["method"], "session.tools.handlePendingToolCall"); + assert_eq!(rpc_call["params"]["requestId"], "req-ext-1"); + assert_eq!(rpc_call["params"]["result"], "all tests passed"); +} + +#[tokio::test] +async fn capabilities_captured_from_create_response() { + let (client, mut server_read, mut server_write) = make_client(); + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + 
.create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .await + .unwrap() + } + }); + + let request = read_framed(&mut server_read).await; + let id = request["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "result": { + "sessionId": "cap-session", + "capabilities": { + "ui": { "elicitation": true } + } + }, + }); + write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await; + + let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + let caps = session.capabilities(); + assert_eq!(caps.ui.as_ref().unwrap().elicitation, Some(true)); +} + +#[tokio::test] +async fn capabilities_changed_event_updates_session() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + + // Initially no capabilities (create_session_pair doesn't send them) + assert!(session.capabilities().ui.is_none()); + + // CLI sends capabilities.changed event + server + .send_event( + "capabilities.changed", + serde_json::json!({ + "ui": { "elicitation": true } + }), + ) + .await; + + // Poll until the event loop processes the notification + let caps = timeout(TIMEOUT, async { + loop { + let caps = session.capabilities(); + if caps.ui.is_some() { + return caps; + } + tokio::time::sleep(Duration::from_millis(5)).await; + } + }) + .await + .expect("capabilities should update within timeout"); + + assert_eq!(caps.ui.as_ref().unwrap().elicitation, Some(true)); +} + +#[tokio::test] +async fn request_elicitation_sent_in_create_params() { + let (client, mut server_read, mut server_write) = make_client(); + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .await + .unwrap() + } + }); + + let request = read_framed(&mut server_read).await; + assert_eq!(request["method"], "session.create"); + assert_eq!(request["params"]["requestElicitation"], true); + + 
let id = request["id"].as_u64().unwrap();
+    let response = serde_json::json!({
+        "jsonrpc": "2.0",
+        "id": id,
+        "result": { "sessionId": "s-elicit" },
+    });
+    write_framed(&mut server_write, &serde_json::to_vec(&response).unwrap()).await;
+    timeout(TIMEOUT, create_handle).await.unwrap().unwrap();
+}
+
+#[tokio::test]
+async fn elicitation_methods_fail_without_capability() {
+    let (session, _server) = create_session_pair(Arc::new(NoopHandler)).await;
+
+    // Session created without capabilities — elicitation should fail
+    let err = session
+        .ui()
+        .elicitation("test", serde_json::json!({}))
+        .await
+        .unwrap_err();
+    assert!(matches!(
+        err,
+        github_copilot_sdk::Error::Session(
+            github_copilot_sdk::SessionError::ElicitationNotSupported
+        )
+    ));
+
+    let err = session.ui().confirm("ok?").await.unwrap_err();
+    assert!(matches!(
+        err,
+        github_copilot_sdk::Error::Session(
+            github_copilot_sdk::SessionError::ElicitationNotSupported
+        )
+    ));
+}
+
+async fn create_session_pair_with_hooks(
+    handler: Arc<dyn SessionHandler>,
+    hooks: Arc<dyn github_copilot_sdk::hooks::SessionHooks>,
+) -> (github_copilot_sdk::session::Session, FakeServer) {
+    let (client, server_read, server_write) = make_client();
+    let session_id = format!("test-session-{}", rand_id());
+
+    let mut server = FakeServer {
+        read: server_read,
+        write: server_write,
+        session_id: session_id.clone(),
+    };
+
+    let create_handle = tokio::spawn({
+        let client = client.clone();
+        let handler = handler.clone();
+        async move {
+            client
+                .create_session(
+                    SessionConfig::default()
+                        .with_handler(handler)
+                        .with_hooks(hooks),
+                )
+                .await
+                .unwrap()
+        }
+    });
+
+    let create_req = server.read_request().await;
+    assert_eq!(create_req["method"], "session.create");
+    // Verify hooks: true is auto-set in the config
+    assert_eq!(create_req["params"]["hooks"], true);
+    server
+        .respond(
+            &create_req,
+            serde_json::json!({
+                "sessionId": session_id,
+                "workspacePath": "/tmp/workspace"
+            }),
+        )
+        .await;
+
+    let session = timeout(TIMEOUT,
create_handle).await.unwrap().unwrap(); + (session, server) +} + +#[tokio::test] +async fn hooks_invoke_dispatches_to_session_hooks() { + use github_copilot_sdk::hooks::{HookEvent, HookOutput, PreToolUseOutput, SessionHooks}; + + struct PolicyHooks; + #[async_trait] + impl SessionHooks for PolicyHooks { + async fn on_hook(&self, event: HookEvent) -> HookOutput { + match event { + HookEvent::PreToolUse { input, .. } => { + if input.tool_name == "rm" { + HookOutput::PreToolUse(PreToolUseOutput { + permission_decision: Some("deny".to_string()), + permission_decision_reason: Some("destructive".to_string()), + ..Default::default() + }) + } else { + HookOutput::None + } + } + _ => HookOutput::None, + } + } + } + + let (_session, mut server) = + create_session_pair_with_hooks(Arc::new(NoopHandler), Arc::new(PolicyHooks)).await; + + // Send a hooks.invoke request for a denied tool + server + .send_request( + 300, + "hooks.invoke", + serde_json::json!({ + "sessionId": server.session_id, + "hookType": "preToolUse", + "input": { + "timestamp": 1234567890, + "cwd": "/tmp", + "toolName": "rm", + "toolArgs": { "path": "/" } + } + }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 300); + assert_eq!(response["result"]["output"]["permissionDecision"], "deny"); + assert_eq!( + response["result"]["output"]["permissionDecisionReason"], + "destructive" + ); +} + +#[tokio::test] +async fn hooks_invoke_returns_empty_for_unregistered_hook() { + use github_copilot_sdk::hooks::SessionHooks; + + struct EmptyHooks; + #[async_trait] + impl SessionHooks for EmptyHooks {} + + let (_session, mut server) = + create_session_pair_with_hooks(Arc::new(NoopHandler), Arc::new(EmptyHooks)).await; + + server + .send_request( + 301, + "hooks.invoke", + serde_json::json!({ + "sessionId": server.session_id, + "hookType": "sessionEnd", + "input": { + "timestamp": 1234567890, + "cwd": "/tmp", + "reason": "complete" + } + }), + ) + .await; 
+
+    let response = timeout(TIMEOUT, server.read_response()).await.unwrap();
+    assert_eq!(response["id"], 301);
+    assert_eq!(response["result"]["output"], serde_json::json!({}));
+}
+
+async fn create_session_pair_with_transforms(
+    handler: Arc<dyn SessionHandler>,
+    transforms: Arc<dyn github_copilot_sdk::transforms::SystemMessageTransform>,
+) -> (github_copilot_sdk::session::Session, FakeServer) {
+    let (client, server_read, server_write) = make_client();
+    let session_id = format!("test-session-{}", rand_id());
+
+    let mut server = FakeServer {
+        read: server_read,
+        write: server_write,
+        session_id: session_id.clone(),
+    };
+
+    let create_handle = tokio::spawn({
+        let client = client.clone();
+        let handler = handler.clone();
+        async move {
+            client
+                .create_session(
+                    SessionConfig::default()
+                        .with_handler(handler)
+                        .with_transform(transforms),
+                )
+                .await
+                .unwrap()
+        }
+    });
+
+    let create_req = server.read_request().await;
+    assert_eq!(create_req["method"], "session.create");
+    // Verify transforms inject customize mode and section overrides
+    assert_eq!(create_req["params"]["systemMessage"]["mode"], "customize");
+    server
+        .respond(
+            &create_req,
+            serde_json::json!({
+                "sessionId": session_id,
+                "workspacePath": "/tmp/workspace"
+            }),
+        )
+        .await;
+
+    let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap();
+    (session, server)
+}
+
+#[tokio::test]
+async fn system_message_transform_dispatches_to_transform() {
+    use github_copilot_sdk::transforms::{SystemMessageTransform, TransformContext};
+
+    struct AppendTransform;
+    #[async_trait]
+    impl SystemMessageTransform for AppendTransform {
+        fn section_ids(&self) -> Vec<String> {
+            vec!["instructions".to_string()]
+        }
+
+        async fn transform_section(
+            &self,
+            _section_id: &str,
+            content: &str,
+            _ctx: TransformContext,
+        ) -> Option<String> {
+            Some(format!("{content}\nAlways be concise."))
+        }
+    }
+
+    let (_session, mut server) =
+        create_session_pair_with_transforms(Arc::new(NoopHandler), Arc::new(AppendTransform)).await;
+
+    server
+        .send_request(
+            400,
"systemMessage.transform",
+            serde_json::json!({
+                "sessionId": server.session_id,
+                "sections": {
+                    "instructions": { "content": "You are helpful." }
+                }
+            }),
+        )
+        .await;
+
+    let response = timeout(TIMEOUT, server.read_response()).await.unwrap();
+    assert_eq!(response["id"], 400);
+    assert_eq!(
+        response["result"]["sections"]["instructions"]["content"],
+        "You are helpful.\nAlways be concise."
+    );
+}
+
+#[tokio::test]
+async fn system_message_transform_returns_error_for_missing_sections() {
+    use github_copilot_sdk::transforms::{SystemMessageTransform, TransformContext};
+
+    struct DummyTransform;
+    #[async_trait]
+    impl SystemMessageTransform for DummyTransform {
+        fn section_ids(&self) -> Vec<String> {
+            vec!["instructions".to_string()]
+        }
+
+        async fn transform_section(
+            &self,
+            _section_id: &str,
+            _content: &str,
+            _ctx: TransformContext,
+        ) -> Option<String> {
+            None
+        }
+    }
+
+    let (_session, mut server) =
+        create_session_pair_with_transforms(Arc::new(NoopHandler), Arc::new(DummyTransform)).await;
+
+    // Send request with no sections parameter
+    server
+        .send_request(
+            401,
+            "systemMessage.transform",
+            serde_json::json!({
+                "sessionId": server.session_id,
+            }),
+        )
+        .await;
+
+    let response = timeout(TIMEOUT, server.read_response()).await.unwrap();
+    assert_eq!(response["id"], 401);
+    assert_eq!(response["error"]["code"], -32602);
+}
+
+#[tokio::test]
+async fn list_workspace_files_uses_plural_method_name() {
+    let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await;
+    let session = Arc::new(session);
+
+    let s = session.clone();
+    let handle = tokio::spawn(async move { s.list_workspace_files().await });
+
+    let request = server.read_request().await;
+    assert_eq!(request["method"], "session.workspaces.listFiles");
+    assert_eq!(request["params"]["sessionId"], server.session_id);
+    server
+        .respond(
+            &request,
+            serde_json::json!({ "files": ["a.txt", "subdir/b.txt"] }),
+        )
+        .await;
+
+    let files = timeout(TIMEOUT,
handle).await.unwrap().unwrap().unwrap(); + assert_eq!(files, vec!["a.txt".to_string(), "subdir/b.txt".to_string()]); +} + +#[tokio::test] +async fn read_workspace_file_uses_plural_method_name_and_forwards_path() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let s = session.clone(); + let handle = + tokio::spawn(async move { s.read_workspace_file(Path::new("notes/plan.md")).await }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.workspaces.readFile"); + assert_eq!(request["params"]["sessionId"], server.session_id); + assert_eq!(request["params"]["path"], "notes/plan.md"); + server + .respond(&request, serde_json::json!({ "content": "hello" })) + .await; + + let content = timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); + assert_eq!(content, "hello"); +} + +#[tokio::test] +async fn create_workspace_file_uses_plural_method_name_and_forwards_payload() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let s = session.clone(); + let handle = tokio::spawn(async move { + s.create_workspace_file(Path::new("notes/plan.md"), "body") + .await + }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.workspaces.createFile"); + assert_eq!(request["params"]["sessionId"], server.session_id); + assert_eq!(request["params"]["path"], "notes/plan.md"); + assert_eq!(request["params"]["content"], "body"); + server.respond(&request, serde_json::json!({})).await; + + timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); +} + +#[tokio::test] +async fn rpc_namespace_session_agent_list_dispatches_correctly() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let s = session.clone(); + let handle = tokio::spawn(async move { s.rpc().agent().list().await }); + + let request = 
server.read_request().await; + assert_eq!(request["method"], "session.agent.list"); + assert_eq!(request["params"]["sessionId"], server.session_id); + server + .respond(&request, serde_json::json!({ "agents": [] })) + .await; + + let result = timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); + assert!(result.agents.is_empty()); +} + +#[tokio::test] +async fn rpc_namespace_session_tasks_list_dispatches_correctly() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let s = session.clone(); + let handle = tokio::spawn(async move { s.rpc().tasks().list().await }); + + let request = server.read_request().await; + assert_eq!(request["method"], "session.tasks.list"); + assert_eq!(request["params"]["sessionId"], server.session_id); + server + .respond(&request, serde_json::json!({ "tasks": [] })) + .await; + + let result = timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); + assert!(result.tasks.is_empty()); +} + +#[tokio::test] +async fn rpc_namespace_client_models_list_dispatches_correctly() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let client = session.client().clone(); + let handle = tokio::spawn(async move { client.rpc().models().list().await }); + + let request = server.read_request().await; + assert_eq!(request["method"], "models.list"); + server + .respond(&request, serde_json::json!({ "models": [] })) + .await; + + let result = timeout(TIMEOUT, handle).await.unwrap().unwrap().unwrap(); + assert!(result.models.is_empty()); +} + +#[tokio::test] +async fn client_stop_sends_session_destroy_for_each_active_session() { + // One client, two registered sessions. Client::stop must send + // session.destroy for each before returning Ok. 
+ let (client, server_read, server_write) = make_client(); + let session_id_a = format!("test-session-{}", rand_id()); + let session_id_b = format!("test-session-{}", rand_id()); + + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: session_id_a.clone(), + }; + + // Spawn both create_session calls. + let create_a = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .await + .unwrap() + } + }); + let create_a_req = server.read_request().await; + assert_eq!(create_a_req["method"], "session.create"); + server + .respond( + &create_a_req, + serde_json::json!({ "sessionId": session_id_a, "workspacePath": "/tmp/ws-a" }), + ) + .await; + let _session_a = timeout(TIMEOUT, create_a).await.unwrap(); + + let create_b = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .await + .unwrap() + } + }); + let create_b_req = server.read_request().await; + assert_eq!(create_b_req["method"], "session.create"); + server + .respond( + &create_b_req, + serde_json::json!({ "sessionId": session_id_b, "workspacePath": "/tmp/ws-b" }), + ) + .await; + let _session_b = timeout(TIMEOUT, create_b).await.unwrap(); + + // Drive Client::stop and respond to each destroy in turn. 
+ let stop_handle = tokio::spawn({ + let client = client.clone(); + async move { client.stop().await } + }); + + let mut destroyed = Vec::new(); + for _ in 0..2 { + let req = server.read_request().await; + assert_eq!(req["method"], "session.destroy"); + destroyed.push(req["params"]["sessionId"].as_str().unwrap().to_string()); + server.respond(&req, serde_json::json!(null)).await; + } + destroyed.sort(); + let mut expected = [session_id_a.clone(), session_id_b.clone()]; + expected.sort(); + assert_eq!(destroyed, expected); + + let stop_result = timeout(TIMEOUT, stop_handle).await.unwrap().unwrap(); + assert!(stop_result.is_ok(), "stop returned errors: {stop_result:?}"); +} + +#[tokio::test] +async fn client_stop_aggregates_session_destroy_errors() { + // session.destroy fails on the wire — Client::stop returns + // StopErrors carrying the failure rather than short-circuiting. + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let client = session.client().clone(); + + let stop_handle = tokio::spawn(async move { client.stop().await }); + + let req = server.read_request().await; + assert_eq!(req["method"], "session.destroy"); + let id = req["id"].as_u64().unwrap(); + let response = serde_json::json!({ + "jsonrpc": "2.0", + "id": id, + "error": { "code": -32000, "message": "session gone" }, + }); + write_framed(&mut server.write, &serde_json::to_vec(&response).unwrap()).await; + + let stop_result = timeout(TIMEOUT, stop_handle).await.unwrap().unwrap(); + let errors = stop_result.expect_err("expected aggregated errors"); + assert_eq!(errors.errors().len(), 1); + let msg = errors.to_string(); + assert!(msg.contains("session gone"), "unexpected message: {msg}"); +} + +#[test] +fn session_config_serializes_bucket_b_fields() { + use std::path::PathBuf; + + use github_copilot_sdk::{SessionConfig, SessionId}; + + let cfg = { + let mut cfg = SessionConfig::default(); + cfg.session_id = Some(SessionId::from("custom-id")); + cfg.config_dir = 
Some(PathBuf::from("/tmp/cfg"));
+        cfg.working_directory = Some(PathBuf::from("/tmp/work"));
+        cfg.github_token = Some("ghs_secret".to_string());
+        cfg.include_sub_agent_streaming_events = Some(false);
+        cfg
+    };
+    let json = serde_json::to_value(&cfg).unwrap();
+    assert_eq!(json["sessionId"], "custom-id");
+    assert_eq!(json["configDir"], "/tmp/cfg");
+    assert_eq!(json["workingDirectory"], "/tmp/work");
+    assert_eq!(json["gitHubToken"], "ghs_secret");
+    assert_eq!(json["includeSubAgentStreamingEvents"], false);
+
+    // Debug never leaks the token.
+    let debug = format!("{cfg:?}");
+    assert!(!debug.contains("ghs_secret"), "leaked token: {debug}");
+    assert!(debug.contains("<redacted>"), "missing redaction: {debug}");
+
+    // Unset fields are omitted on the wire.
+    let empty = serde_json::to_value(SessionConfig::default()).unwrap();
+    assert!(empty.get("sessionId").is_none());
+    assert!(empty.get("gitHubToken").is_none());
+}
+
+#[test]
+fn resume_session_config_serializes_bucket_b_fields() {
+    use std::path::PathBuf;
+
+    use github_copilot_sdk::{ResumeSessionConfig, SessionId};
+
+    let mut cfg = ResumeSessionConfig::new(SessionId::from("sess-1"));
+    cfg.working_directory = Some(PathBuf::from("/tmp/work"));
+    cfg.config_dir = Some(PathBuf::from("/tmp/cfg"));
+    cfg.github_token = Some("ghs_secret".to_string());
+    cfg.include_sub_agent_streaming_events = Some(true);
+    let json = serde_json::to_value(&cfg).unwrap();
+    assert_eq!(json["sessionId"], "sess-1");
+    assert_eq!(json["workingDirectory"], "/tmp/work");
+    assert_eq!(json["configDir"], "/tmp/cfg");
+    assert_eq!(json["gitHubToken"], "ghs_secret");
+    assert_eq!(json["includeSubAgentStreamingEvents"], true);
+
+    let debug = format!("{cfg:?}");
+    assert!(!debug.contains("ghs_secret"), "leaked token: {debug}");
+}
+
+// =====================================================================
+// Slash commands (§ 4.1)
+// =====================================================================
+
+struct CountingCommandHandler {
+    last_ctx: Arc<parking_lot::Mutex<Option<CommandContext>>>,
+    error_to_return: Option<String>,
+}
+
+#[async_trait]
+impl CommandHandler for CountingCommandHandler {
+    async fn on_command(&self, ctx: CommandContext) -> Result<(), github_copilot_sdk::Error> {
+        *self.last_ctx.lock() = Some(ctx);
+        if let Some(message) = &self.error_to_return {
+            Err(github_copilot_sdk::Error::Session(
+                github_copilot_sdk::SessionError::AgentError(message.clone()),
+            ))
+        } else {
+            Ok(())
+        }
+    }
+}
+
+async fn create_session_pair_with_commands(
+    handler: Arc<dyn SessionHandler>,
+    commands: Vec<CommandDefinition>,
+) -> (github_copilot_sdk::session::Session, FakeServer, Value) {
+    let (client, server_read, server_write) = make_client();
+    let session_id = format!("test-session-{}", rand_id());
+
+    let mut server = FakeServer {
+        read: server_read,
+        write: server_write,
+        session_id: session_id.clone(),
+    };
+
+    let create_handle = tokio::spawn({
+        let client = client.clone();
+        let handler = handler.clone();
+        async move {
+            client
+                .create_session(
+                    SessionConfig::default()
+                        .with_handler(handler)
+                        .with_commands(commands),
+                )
+                .await
+                .unwrap()
+        }
+    });
+
+    let create_req = server.read_request().await;
+    assert_eq!(create_req["method"], "session.create");
+    server
+        .respond(
+            &create_req,
+            serde_json::json!({
+                "sessionId": session_id,
+                "workspacePath": "/tmp/workspace"
+            }),
+        )
+        .await;
+
+    let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap();
+    (session, server, create_req)
+}
+
+#[tokio::test]
+async fn create_serializes_commands_strips_handler() {
+    let last_ctx = Arc::new(parking_lot::Mutex::new(None));
+    let commands = vec![
+        CommandDefinition::new(
+            "deploy",
+            Arc::new(CountingCommandHandler {
+                last_ctx: last_ctx.clone(),
+                error_to_return: None,
+            }),
+        )
+        .with_description("Deploy to production"),
+        CommandDefinition::new(
+            "rollback",
+            Arc::new(CountingCommandHandler {
+                last_ctx: last_ctx.clone(),
+                error_to_return: None,
+            }),
+        ),
+    ];
+
+    let (_session, _server, create_req) =
create_session_pair_with_commands(Arc::new(NoopHandler), commands).await; + + let wire = create_req["params"]["commands"] + .as_array() + .expect("commands should be an array"); + assert_eq!(wire.len(), 2); + + let deploy = &wire[0]; + assert_eq!(deploy["name"], "deploy"); + assert_eq!(deploy["description"], "Deploy to production"); + assert!( + deploy.get("handler").is_none(), + "wire payload must not include handler, got: {deploy}" + ); + let deploy_keys: Vec<&String> = deploy.as_object().unwrap().keys().collect(); + assert_eq!(deploy_keys.len(), 2, "got keys: {deploy_keys:?}"); + + let rollback = &wire[1]; + assert_eq!(rollback["name"], "rollback"); + assert!( + rollback.get("description").is_none(), + "description should be omitted when None, got: {rollback}" + ); + assert!(rollback.get("handler").is_none()); + let rollback_keys: Vec<&String> = rollback.as_object().unwrap().keys().collect(); + assert_eq!(rollback_keys.len(), 1, "got keys: {rollback_keys:?}"); +} + +#[tokio::test] +async fn command_execute_dispatches_to_registered_handler_and_acks_success() { + let last_ctx = Arc::new(parking_lot::Mutex::new(None)); + let commands = vec![CommandDefinition::new( + "deploy", + Arc::new(CountingCommandHandler { + last_ctx: last_ctx.clone(), + error_to_return: None, + }), + )]; + + let (session, mut server, _) = + create_session_pair_with_commands(Arc::new(NoopHandler), commands).await; + + server + .send_event( + "command.execute", + serde_json::json!({ + "requestId": "req-deploy-1", + "command": "/deploy production", + "commandName": "deploy", + "args": "production", + }), + ) + .await; + + let ack = timeout(TIMEOUT, server.read_request()).await.unwrap(); + assert_eq!( + ack["method"], "session.commands.handlePendingCommand", + "expected handlePendingCommand RPC, got: {ack}" + ); + assert_eq!( + ack["params"]["sessionId"].as_str(), + Some(session.id().as_ref()) + ); + assert_eq!(ack["params"]["requestId"], "req-deploy-1"); + assert!( + 
ack["params"].get("error").is_none(), + "success ack should omit error, got: {ack}" + ); + + server + .respond(&ack, serde_json::json!({ "success": true })) + .await; + + let ctx = last_ctx + .lock() + .clone() + .expect("handler should have been invoked"); + assert_eq!(ctx.command, "/deploy production"); + assert_eq!(ctx.command_name, "deploy"); + assert_eq!(ctx.args, "production"); + assert_eq!(ctx.session_id.as_ref(), session.id().as_ref()); +} + +#[tokio::test] +async fn command_execute_unknown_command_acks_with_error() { + let (session, mut server, _) = + create_session_pair_with_commands(Arc::new(NoopHandler), vec![]).await; + + server + .send_event( + "command.execute", + serde_json::json!({ + "requestId": "req-unknown-1", + "command": "/missing", + "commandName": "missing", + "args": "", + }), + ) + .await; + + let ack = timeout(TIMEOUT, server.read_request()).await.unwrap(); + assert_eq!(ack["method"], "session.commands.handlePendingCommand"); + assert_eq!(ack["params"]["requestId"], "req-unknown-1"); + assert_eq!( + ack["params"]["error"], "Unknown command: missing", + "got: {ack}" + ); + server + .respond(&ack, serde_json::json!({ "success": false })) + .await; + drop(session); +} + +#[tokio::test] +async fn command_execute_handler_error_propagates_to_ack() { + let last_ctx = Arc::new(parking_lot::Mutex::new(None)); + let commands = vec![CommandDefinition::new( + "fail", + Arc::new(CountingCommandHandler { + last_ctx: last_ctx.clone(), + error_to_return: Some("deploy failed: dry-run rejected".to_string()), + }), + )]; + + let (_session, mut server, _) = + create_session_pair_with_commands(Arc::new(NoopHandler), commands).await; + + server + .send_event( + "command.execute", + serde_json::json!({ + "requestId": "req-fail-1", + "command": "/fail", + "commandName": "fail", + "args": "", + }), + ) + .await; + + let ack = timeout(TIMEOUT, server.read_request()).await.unwrap(); + assert_eq!(ack["method"], "session.commands.handlePendingCommand"); + 
assert_eq!(ack["params"]["requestId"], "req-fail-1");
+    let error_msg = ack["params"]["error"]
+        .as_str()
+        .expect("ack should include error");
+    assert!(
+        error_msg.contains("deploy failed: dry-run rejected"),
+        "expected handler error in ack, got: {error_msg}"
+    );
+    server
+        .respond(&ack, serde_json::json!({ "success": false }))
+        .await;
+}
+
+// SessionFsProvider tests --------------------------------------------------
+
+use github_copilot_sdk::session_fs::{
+    DirEntry, DirEntryKind, FileInfo, FsError, SessionFsConventions, SessionFsProvider,
+};
+
+struct RecordingFsProvider {
+    files: parking_lot::Mutex<std::collections::HashMap<String, String>>,
+}
+
+impl RecordingFsProvider {
+    fn new() -> Self {
+        Self {
+            files: parking_lot::Mutex::new(std::collections::HashMap::new()),
+        }
+    }
+
+    fn with_file(self, path: &str, content: &str) -> Self {
+        self.files
+            .lock()
+            .insert(path.to_string(), content.to_string());
+        self
+    }
+}
+
+#[async_trait]
+impl SessionFsProvider for RecordingFsProvider {
+    async fn read_file(&self, path: &str) -> Result<String, FsError> {
+        self.files
+            .lock()
+            .get(path)
+            .cloned()
+            .ok_or_else(|| FsError::NotFound(path.to_string()))
+    }
+
+    async fn write_file(
+        &self,
+        path: &str,
+        content: &str,
+        _mode: Option<u32>,
+    ) -> Result<(), FsError> {
+        self.files
+            .lock()
+            .insert(path.to_string(), content.to_string());
+        Ok(())
+    }
+
+    async fn stat(&self, path: &str) -> Result<FileInfo, FsError> {
+        let files = self.files.lock();
+        let content = files
+            .get(path)
+            .ok_or_else(|| FsError::NotFound(path.to_string()))?;
+        Ok(FileInfo::new(
+            true,
+            false,
+            content.len() as i64,
+            "2025-01-01T00:00:00Z",
+            "2025-01-01T00:00:00Z",
+        ))
+    }
+
+    async fn readdir_with_types(&self, _path: &str) -> Result<Vec<DirEntry>, FsError> {
+        Ok(vec![
+            DirEntry::new("README.md", DirEntryKind::File),
+            DirEntry::new("src", DirEntryKind::Directory),
+        ])
+    }
+
+    async fn rm(&self, path: &str, _recursive: bool, force: bool) -> Result<(), FsError> {
+        let mut files = self.files.lock();
+        if files.remove(path).is_none() && !force {
return Err(FsError::NotFound(path.to_string()));
+        }
+        Ok(())
+    }
+}
+
+async fn create_session_pair_with_fs_provider(
+    handler: Arc<dyn SessionHandler>,
+    provider: Arc<dyn SessionFsProvider>,
+) -> (github_copilot_sdk::session::Session, FakeServer) {
+    let (client, server_read, server_write) = make_client();
+    let session_id = format!("test-session-{}", rand_id());
+
+    let mut server = FakeServer {
+        read: server_read,
+        write: server_write,
+        session_id: session_id.clone(),
+    };
+
+    let create_handle = tokio::spawn({
+        let client = client.clone();
+        let handler = handler.clone();
+        async move {
+            client
+                .create_session(
+                    SessionConfig::default()
+                        .with_handler(handler)
+                        .with_session_fs_provider(provider),
+                )
+                .await
+                .unwrap()
+        }
+    });
+
+    let create_req = server.read_request().await;
+    assert_eq!(create_req["method"], "session.create");
+    server
+        .respond(
+            &create_req,
+            serde_json::json!({
+                "sessionId": session_id,
+                "workspacePath": "/tmp/workspace"
+            }),
+        )
+        .await;
+
+    let session = timeout(TIMEOUT, create_handle).await.unwrap().unwrap();
+    (session, server)
+}
+
+#[tokio::test]
+async fn session_fs_dispatches_read_file_to_provider() {
+    let provider = Arc::new(RecordingFsProvider::new().with_file("/foo.txt", "hello world"));
+    let (_session, mut server) =
+        create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await;
+
+    server
+        .send_request(
+            42,
+            "sessionFs.readFile",
+            serde_json::json!({ "sessionId": server.session_id, "path": "/foo.txt" }),
+        )
+        .await;
+
+    let response = timeout(TIMEOUT, server.read_response()).await.unwrap();
+    assert_eq!(response["id"], 42);
+    assert_eq!(response["result"]["content"], "hello world");
+    assert!(response["result"].get("error").is_none() || response["result"]["error"].is_null());
+}
+
+#[tokio::test]
+async fn session_fs_maps_not_found_to_enoent() {
+    let provider = Arc::new(RecordingFsProvider::new());
+    let (_session, mut server) =
+        create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await;
+
+    server
+ .send_request( + 7, + "sessionFs.readFile", + serde_json::json!({ "sessionId": server.session_id, "path": "/missing.txt" }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 7); + let error = &response["result"]["error"]; + assert_eq!(error["code"], "ENOENT"); + assert!(error["message"].as_str().unwrap().contains("missing.txt")); +} + +#[tokio::test] +async fn session_fs_maps_other_to_unknown() { + struct AlwaysFails; + #[async_trait] + impl SessionFsProvider for AlwaysFails { + async fn stat(&self, _path: &str) -> Result { + Err(FsError::Other("backing store unavailable".to_string())) + } + } + + let (_session, mut server) = + create_session_pair_with_fs_provider(Arc::new(NoopHandler), Arc::new(AlwaysFails)).await; + + server + .send_request( + 8, + "sessionFs.stat", + serde_json::json!({ "sessionId": server.session_id, "path": "/x" }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + let error = &response["result"]["error"]; + assert_eq!(error["code"], "UNKNOWN"); + assert!( + error["message"] + .as_str() + .unwrap() + .contains("backing store unavailable") + ); +} + +#[tokio::test] +async fn session_fs_dispatches_write_file_with_mode() { + let provider = Arc::new(RecordingFsProvider::new()); + let (_session, mut server) = + create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider.clone()).await; + + server + .send_request( + 10, + "sessionFs.writeFile", + serde_json::json!({ "sessionId": server.session_id, "path": "/out.txt", "content": "abc", "mode": 420 }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 10); + assert!(response["result"].get("error").is_none() || response["result"]["error"].is_null()); + assert_eq!(provider.files.lock().get("/out.txt").unwrap(), "abc"); +} + +#[tokio::test] +async fn session_fs_dispatches_readdir_with_types() { + let provider = 
Arc::new(RecordingFsProvider::new()); + let (_session, mut server) = + create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + + server + .send_request( + 11, + "sessionFs.readdirWithTypes", + serde_json::json!({ "sessionId": server.session_id, "path": "/dir" }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + let entries = response["result"]["entries"].as_array().unwrap(); + assert_eq!(entries.len(), 2); + assert_eq!(entries[0]["name"], "README.md"); + assert_eq!(entries[0]["type"], "file"); + assert_eq!(entries[1]["name"], "src"); + assert_eq!(entries[1]["type"], "directory"); +} + +#[tokio::test] +async fn session_fs_dispatches_rm_with_force() { + let provider = Arc::new(RecordingFsProvider::new()); + let (_session, mut server) = + create_session_pair_with_fs_provider(Arc::new(NoopHandler), provider).await; + + server + .send_request( + 12, + "sessionFs.rm", + serde_json::json!({ "sessionId": server.session_id, "path": "/missing", "force": true, "recursive": false }), + ) + .await; + + let response = timeout(TIMEOUT, server.read_response()).await.unwrap(); + assert_eq!(response["id"], 12); + assert!(response["result"].get("error").is_none() || response["result"]["error"].is_null()); +} + +#[tokio::test] +async fn validate_session_fs_config_rejects_empty_initial_cwd() { + let cfg = github_copilot_sdk::session_fs::SessionFsConfig::new( + "", + "/state", + SessionFsConventions::Posix, + ); + let opts = { + let mut opts = github_copilot_sdk::ClientOptions::default(); + opts.session_fs = Some(cfg); + opts + }; + let err = github_copilot_sdk::Client::start(opts).await.err(); + let err_string = format!("{err:?}"); + assert!( + err_string.contains("initial_cwd") || err_string.contains("InvalidSessionFsConfig"), + "got: {err_string}" + ); +} + +#[tokio::test] +async fn create_session_errors_when_provider_required_but_missing() { + // Without a CLI we can't exercise the configured-but-missing-provider path 
+    // through Client::start; the unit-level behavior is covered by the
+    // SessionError::SessionFsProviderRequired variant being constructible.
+    // This test asserts the error type's display formatting is stable.
+    let err = github_copilot_sdk::SessionError::SessionFsProviderRequired;
+    assert!(format!("{err}").contains("session_fs"));
+}
+
+// ---------- 4.3 trace context tests ----------
+
+struct StaticTraceProvider {
+    ctx: github_copilot_sdk::types::TraceContext,
+    calls: Arc<AtomicUsize>,
+}
+
+#[async_trait]
+impl github_copilot_sdk::types::TraceContextProvider for StaticTraceProvider {
+    async fn get_trace_context(&self) -> github_copilot_sdk::types::TraceContext {
+        self.calls.fetch_add(1, Ordering::Relaxed);
+        self.ctx.clone()
+    }
+}
+
+fn make_client_with_trace_provider(
+    provider: Arc<dyn github_copilot_sdk::types::TraceContextProvider>,
+) -> (Client, tokio::io::DuplexStream, tokio::io::DuplexStream) {
+    let (client_write, server_read) = duplex(8192);
+    let (server_write, client_read) = duplex(8192);
+    let client = Client::from_streams_with_trace_provider(
+        client_read,
+        client_write,
+        std::env::temp_dir(),
+        provider,
+    )
+    .unwrap();
+    (client, server_read, server_write)
+}
+
+#[tokio::test]
+async fn on_get_trace_context_called_on_session_create() {
+    let calls = Arc::new(AtomicUsize::new(0));
+    let provider = Arc::new(StaticTraceProvider {
+        ctx: github_copilot_sdk::types::TraceContext::from_traceparent("00-aaaa-bbbb-01")
+            .with_tracestate("vendor=value"),
+        calls: calls.clone(),
+    });
+    let (client, server_read, server_write) = make_client_with_trace_provider(provider);
+    let mut server = FakeServer {
+        read: server_read,
+        write: server_write,
+        session_id: "trace-create".to_string(),
+    };
+
+    let create_handle = tokio::spawn({
+        let client = client.clone();
+        async move {
+            client
+                .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler)))
+                .await
+                .unwrap()
+        }
+    });
+
+    let req = server.read_request().await;
+    assert_eq!(req["method"], "session.create");
assert_eq!(req["params"]["traceparent"], "00-aaaa-bbbb-01"); + assert_eq!(req["params"]["tracestate"], "vendor=value"); + server + .respond( + &req, + serde_json::json!({"sessionId": "trace-create", "workspacePath": "/tmp/ws"}), + ) + .await; + timeout(TIMEOUT, create_handle).await.unwrap().unwrap(); + assert_eq!(calls.load(Ordering::Relaxed), 1); +} + +#[tokio::test] +async fn on_get_trace_context_called_on_session_resume() { + use github_copilot_sdk::types::ResumeSessionConfig; + let calls = Arc::new(AtomicUsize::new(0)); + let provider = Arc::new(StaticTraceProvider { + ctx: github_copilot_sdk::types::TraceContext::from_traceparent("00-resume-trace-01"), + calls: calls.clone(), + }); + let (client, server_read, server_write) = make_client_with_trace_provider(provider); + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: "trace-resume".to_string(), + }; + + let resume_handle = tokio::spawn({ + let client = client.clone(); + async move { + let cfg = ResumeSessionConfig::new(SessionId::from("trace-resume")) + .with_handler(Arc::new(NoopHandler)); + client.resume_session(cfg).await.unwrap() + } + }); + + // resume sends `session.resume` then `session.skills.reload`. 
+ let req = server.read_request().await; + assert_eq!(req["method"], "session.resume"); + assert_eq!(req["params"]["traceparent"], "00-resume-trace-01"); + assert!( + req["params"].get("tracestate").is_none(), + "tracestate should be omitted when None" + ); + server + .respond( + &req, + serde_json::json!({"sessionId": "trace-resume", "workspacePath": "/tmp/ws"}), + ) + .await; + let reload_req = server.read_request().await; + assert_eq!(reload_req["method"], "session.skills.reload"); + server.respond(&reload_req, serde_json::json!({})).await; + + timeout(TIMEOUT, resume_handle).await.unwrap().unwrap(); + assert_eq!(calls.load(Ordering::Relaxed), 1); +} + +#[tokio::test] +async fn on_get_trace_context_called_on_session_send() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = Arc::new(StaticTraceProvider { + ctx: github_copilot_sdk::types::TraceContext::from_traceparent("00-send-trace-01"), + calls: calls.clone(), + }); + let (client, server_read, server_write) = make_client_with_trace_provider(provider); + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: "trace-send".to_string(), + }; + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .await + .unwrap() + } + }); + let create_req = server.read_request().await; + server + .respond( + &create_req, + serde_json::json!({"sessionId": "trace-send", "workspacePath": "/tmp/ws"}), + ) + .await; + let session = Arc::new(timeout(TIMEOUT, create_handle).await.unwrap().unwrap()); + + // Provider was called once for create; reset by reading the count baseline. 
+ let baseline = calls.load(Ordering::Relaxed); + assert_eq!(baseline, 1, "create_session should call the provider once"); + + let send_handle = tokio::spawn({ + let session = session.clone(); + async move { session.send(MessageOptions::new("hi")).await } + }); + let send_req = server.read_request().await; + assert_eq!(send_req["method"], "session.send"); + assert_eq!(send_req["params"]["traceparent"], "00-send-trace-01"); + server.respond(&send_req, serde_json::json!({})).await; + timeout(TIMEOUT, send_handle) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(calls.load(Ordering::Relaxed), baseline + 1); +} + +#[tokio::test] +async fn message_options_trace_context_overrides_callback() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = Arc::new(StaticTraceProvider { + ctx: github_copilot_sdk::types::TraceContext::from_traceparent("00-callback-01"), + calls: calls.clone(), + }); + let (client, server_read, server_write) = make_client_with_trace_provider(provider); + let mut server = FakeServer { + read: server_read, + write: server_write, + session_id: "trace-override".to_string(), + }; + + let create_handle = tokio::spawn({ + let client = client.clone(); + async move { + client + .create_session(SessionConfig::default().with_handler(Arc::new(NoopHandler))) + .await + .unwrap() + } + }); + let create_req = server.read_request().await; + server + .respond( + &create_req, + serde_json::json!({"sessionId": "trace-override", "workspacePath": "/tmp/ws"}), + ) + .await; + let session = Arc::new(timeout(TIMEOUT, create_handle).await.unwrap().unwrap()); + + let baseline = calls.load(Ordering::Relaxed); + + let send_handle = tokio::spawn({ + let session = session.clone(); + async move { + session + .send( + MessageOptions::new("hi") + .with_traceparent("00-override-01") + .with_tracestate("vendor=override"), + ) + .await + } + }); + let send_req = server.read_request().await; + assert_eq!(send_req["params"]["traceparent"], "00-override-01"); + 
assert_eq!(send_req["params"]["tracestate"], "vendor=override");
+    server.respond(&send_req, serde_json::json!({})).await;
+    timeout(TIMEOUT, send_handle)
+        .await
+        .unwrap()
+        .unwrap()
+        .unwrap();
+
+    // Callback must NOT have been invoked when MessageOptions carried an override.
+    assert_eq!(
+        calls.load(Ordering::Relaxed),
+        baseline,
+        "callback should be skipped when MessageOptions carries trace headers"
+    );
+}
+
+#[tokio::test]
+async fn message_options_trace_context_used_without_callback() {
+    let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await;
+    let session = Arc::new(session);
+
+    let send_handle = tokio::spawn({
+        let session = session.clone();
+        async move {
+            session
+                .send(MessageOptions::new("hi").with_traceparent("00-direct-01"))
+                .await
+        }
+    });
+    let req = server.read_request().await;
+    assert_eq!(req["method"], "session.send");
+    assert_eq!(req["params"]["traceparent"], "00-direct-01");
+    assert!(
+        req["params"].get("tracestate").is_none(),
+        "tracestate should be omitted when only traceparent is set"
+    );
+    server.respond(&req, serde_json::json!({})).await;
+    timeout(TIMEOUT, send_handle)
+        .await
+        .unwrap()
+        .unwrap()
+        .unwrap();
+}
+
+#[tokio::test]
+async fn tool_invocation_carries_trace_context_from_event() {
+    use github_copilot_sdk::handler::{HandlerEvent, HandlerResponse, SessionHandler};
+
+    struct CapturingHandler {
+        captured: parking_lot::Mutex<Option<(Option<String>, Option<String>)>>,
+        signal: tokio::sync::Notify,
+    }
+
+    #[async_trait]
+    impl SessionHandler for CapturingHandler {
+        async fn on_event(&self, event: HandlerEvent) -> HandlerResponse {
+            if let HandlerEvent::ExternalTool { invocation } = event {
+                *self.captured.lock() = Some((
+                    invocation.traceparent.clone(),
+                    invocation.tracestate.clone(),
+                ));
+                self.signal.notify_one();
+                return HandlerResponse::ToolResult(ToolResult::Text("ok".into()));
+            }
+            HandlerResponse::Ok
+        }
+    }
+
+    let handler = Arc::new(CapturingHandler {
+        captured:
parking_lot::Mutex::new(None), + signal: tokio::sync::Notify::new(), + }); + let (_session, mut server) = create_session_pair(handler.clone()).await; + + server + .send_event( + "external_tool.requested", + serde_json::json!({ + "requestId": "req-1", + "sessionId": server.session_id, + "toolCallId": "tc-1", + "toolName": "calc", + "arguments": {"x": 1}, + "traceparent": "00-tool-01", + "tracestate": "vendor=tool", + }), + ) + .await; + + // Drain the handlePendingToolCall RPC the dispatcher sends after the handler runs. + let pending = timeout(TIMEOUT, server.read_request()).await.unwrap(); + assert_eq!(pending["method"], "session.tools.handlePendingToolCall"); + + timeout(TIMEOUT, handler.signal.notified()).await.unwrap(); + let captured = handler.captured.lock().clone(); + assert_eq!( + captured, + Some((Some("00-tool-01".into()), Some("vendor=tool".into()))), + ); +} + +#[tokio::test] +async fn wire_omits_trace_fields_when_unset() { + let (session, mut server) = create_session_pair(Arc::new(NoopHandler)).await; + let session = Arc::new(session); + + let send_handle = tokio::spawn({ + let session = session.clone(); + async move { session.send(MessageOptions::new("hi")).await } + }); + let req = server.read_request().await; + assert!(req["params"].get("traceparent").is_none()); + assert!(req["params"].get("tracestate").is_none()); + server.respond(&req, serde_json::json!({})).await; + timeout(TIMEOUT, send_handle) + .await + .unwrap() + .unwrap() + .unwrap(); +} diff --git a/scripts/codegen/package.json b/scripts/codegen/package.json index a2df5dded..c42713d84 100644 --- a/scripts/codegen/package.json +++ b/scripts/codegen/package.json @@ -3,11 +3,12 @@ "private": true, "type": "module", "scripts": { - "generate": "tsx typescript.ts && tsx csharp.ts && tsx python.ts && tsx go.ts", + "generate": "tsx typescript.ts && tsx csharp.ts && tsx python.ts && tsx go.ts && tsx rust.ts", "generate:ts": "tsx typescript.ts", "generate:csharp": "tsx csharp.ts", 
"generate:python": "tsx python.ts", - "generate:go": "tsx go.ts" + "generate:go": "tsx go.ts", + "generate:rust": "tsx rust.ts" }, "dependencies": { "json-schema": "^0.4.0", diff --git a/scripts/codegen/rust.ts b/scripts/codegen/rust.ts new file mode 100644 index 000000000..e4d9a74ff --- /dev/null +++ b/scripts/codegen/rust.ts @@ -0,0 +1,1374 @@ +/** + * Rust code generator for the Copilot protocol JSON Schemas. + * + * Reads api.schema.json and session-events.schema.json, emits idiomatic Rust + * types to rust/src/generated/. + * + * Usage: npx tsx scripts/codegen/rust.ts + */ + +import { execFile } from "child_process"; +import fs from "fs/promises"; +import path from "path"; +import { promisify } from "util"; +import type { JSONSchema7, JSONSchema7Definition } from "json-schema"; +import { + type ApiSchema, + type DefinitionCollections, + EXCLUDED_EVENT_TYPES, + REPO_ROOT, + type RpcMethod, + collectDefinitionCollections, + collectDefinitions, + getApiSchemaPath, + getRpcSchemaTypeName, + getSessionEventsSchemaPath, + isObjectSchema, + isRpcMethod, + isSchemaDeprecated, + isVoidSchema, + postProcessSchema, + refTypeName, + resolveObjectSchema, + resolveRef, + resolveSchema, +} from "./utils.js"; + +const execFileAsync = promisify(execFile); + +const GENERATED_DIR = path.join(REPO_ROOT, "rust/src/generated"); + +/** + * JSON property names that should be emitted as a hand-authored newtype rather + * than `String`. The newtype is `#[serde(transparent)]`, so the wire format is + * unchanged. Add new entries sparingly — these only fire when a schema field + * has type `string` and an exact-match name in this map. 
+ */
+const STRING_NEWTYPE_OVERRIDES: Record<string, string> = {
+  sessionId: "SessionId",
+  remoteSessionId: "SessionId",
+  requestId: "RequestId",
+};
+
+// ── Naming helpers ──────────────────────────────────────────────────────────
+
+function toPascalCase(s: string): string {
+  return s
+    .split(/[._\-\s]+/)
+    .map((w) => w.charAt(0).toUpperCase() + w.slice(1))
+    .join("");
+}
+
+function toSnakeCase(s: string): string {
+  return s
+    .replace(/([A-Z])/g, "_$1")
+    .replace(/^_/, "")
+    .replace(/[.\-\s]+/g, "_")
+    .toLowerCase()
+    .replace(/_+/g, "_");
+}
+
+/** Convert a JSON property name (camelCase) to a Rust field name (snake_case). */
+function toRustFieldName(jsonName: string): string {
+  return toSnakeCase(jsonName);
+}
+
+/** Convert snake_case back to camelCase (matches serde's rename_all = "camelCase"). */
+function snakeToCamelCase(snake: string): string {
+  return snake.replace(/_([a-z0-9])/g, (_, c: string) => c.toUpperCase());
+}
+
+/**
+ * Rust reserved keywords that need raw identifier syntax (r#).
+ */
+const RUST_KEYWORDS = new Set([
+  "as",
+  "async",
+  "await",
+  "break",
+  "const",
+  "continue",
+  "crate",
+  "dyn",
+  "else",
+  "enum",
+  "extern",
+  "false",
+  "fn",
+  "for",
+  "if",
+  "impl",
+  "in",
+  "let",
+  "loop",
+  "match",
+  "mod",
+  "move",
+  "mut",
+  "pub",
+  "ref",
+  "return",
+  "self",
+  "Self",
+  "static",
+  "struct",
+  "super",
+  "trait",
+  "true",
+  "type",
+  "unsafe",
+  "use",
+  "where",
+  "while",
+  "yield",
+]);
+
+function safeRustFieldName(name: string): string {
+  const snake = toRustFieldName(name);
+  return RUST_KEYWORDS.has(snake) ? `r#${snake}` : snake;
+}
+
+// ── Codegen context ─────────────────────────────────────────────────────────
+
+interface RustCodegenCtx {
+  /** Accumulated struct definitions. */
+  structs: string[];
+  /** Accumulated enum definitions. */
+  enums: string[];
+  /** Track generated type names to avoid duplicates. */
+  generatedNames: Set<string>;
+  /** Schema definitions for $ref resolution.
*/ + definitions?: DefinitionCollections; +} + +function stripOption(typeName: string): string { + return typeName.startsWith("Option<") && typeName.endsWith(">") + ? typeName.slice("Option<".length, -1) + : typeName; +} + +function getUnionVariants(schema: JSONSchema7): JSONSchema7[] | null { + if (schema.anyOf) return schema.anyOf as JSONSchema7[]; + if (schema.oneOf) return schema.oneOf as JSONSchema7[]; + return null; +} + +function tryEmitRustDiscriminatedUnion( + schema: JSONSchema7, + parentTypeName: string, + jsonPropName: string, + ctx: RustCodegenCtx, +): string | null { + const variants = getUnionVariants(schema); + if (!variants) return null; + + const nonNull = variants.filter((variant) => variant.type !== "null"); + if (nonNull.length <= 1) return null; + + const enumName = + (typeof schema.title === "string" && schema.title) || + parentTypeName + toPascalCase(jsonPropName); + + const resolvedVariants = nonNull.map((variant) => { + if (variant.$ref && typeof variant.$ref === "string") { + const resolved = resolveRef(variant.$ref, ctx.definitions); + return { + schema: (resolved ?? variant) as JSONSchema7, + typeName: toPascalCase(refTypeName(variant.$ref, ctx.definitions)), + }; + } + + const resolved = + resolveObjectSchema(variant, ctx.definitions) ?? + resolveSchema(variant, ctx.definitions) ?? + variant; + const kindConst = (resolved.properties?.kind as JSONSchema7 | undefined) + ?.const; + const typeName = + (typeof resolved.title === "string" && resolved.title) || + (typeof kindConst === "string" + ? 
`${enumName}${toPascalCase(kindConst)}` + : `${enumName}Variant`); + + return { + schema: resolved as JSONSchema7, + typeName, + }; + }); + + const isDiscriminated = resolvedVariants.every( + ({ schema: variantSchema }) => { + if (!isObjectSchema(variantSchema) || !variantSchema.properties) + return false; + const kind = variantSchema.properties.kind as JSONSchema7 | undefined; + return typeof kind?.const === "string"; + }, + ); + if (!isDiscriminated) return null; + + if (ctx.generatedNames.has(enumName)) { + return enumName; + } + ctx.generatedNames.add(enumName); + + for (const { schema: variantSchema, typeName } of resolvedVariants) { + if (isObjectSchema(variantSchema)) { + emitRustStruct(typeName, variantSchema, ctx); + } + } + + const lines: string[] = []; + if (schema.description) { + for (const line of schema.description.split(/\r?\n/)) { + lines.push(`/// ${line}`); + } + } + lines.push("#[derive(Debug, Clone, Serialize, Deserialize)]"); + lines.push("#[serde(untagged)]"); + lines.push(`pub enum ${enumName} {`); + + for (const { schema: variantSchema, typeName } of resolvedVariants) { + const kind = ((variantSchema.properties?.kind as JSONSchema7 | undefined) + ?.const ?? typeName) as string; + lines.push(` ${toPascalCase(kind)}(${stripOption(typeName)}),`); + } + + lines.push("}"); + ctx.enums.push(lines.join("\n")); + return enumName; +} + +function makeCtx(definitions?: DefinitionCollections): RustCodegenCtx { + return { + structs: [], + enums: [], + generatedNames: new Set(), + definitions, + }; +} + +// ── Type resolution ───────────────────────────────────────────────────────── + +/** + * Map a JSON Schema to a Rust type string. Emits nested type definitions as + * side effects into ctx. 
+ */ +function resolveRustType( + propSchema: JSONSchema7, + parentTypeName: string, + jsonPropName: string, + isRequired: boolean, + ctx: RustCodegenCtx, +): string { + const nestedName = parentTypeName + toPascalCase(jsonPropName); + + // $ref — resolve and recurse + if (propSchema.$ref && typeof propSchema.$ref === "string") { + const typeName = toPascalCase( + refTypeName(propSchema.$ref, ctx.definitions), + ); + const resolved = resolveRef(propSchema.$ref, ctx.definitions); + if (resolved) { + if (resolved.enum) { + emitRustStringEnum( + typeName, + resolved.enum as string[], + ctx, + resolved.description, + ); + return wrapOption(typeName, isRequired); + } + if (isObjectSchema(resolved)) { + emitRustStruct(typeName, resolved, ctx); + return wrapOption(typeName, isRequired); + } + return resolveRustType( + resolved, + parentTypeName, + jsonPropName, + isRequired, + ctx, + ); + } + return wrapOption(typeName, isRequired); + } + + // anyOf — nullable pattern or union + if (propSchema.anyOf) { + const discriminatedUnion = tryEmitRustDiscriminatedUnion( + propSchema, + parentTypeName, + jsonPropName, + ctx, + ); + if (discriminatedUnion) { + return wrapOption(discriminatedUnion, isRequired); + } + + const nonNull = (propSchema.anyOf as JSONSchema7[]).filter( + (s) => s.type !== "null", + ); + const hasNull = (propSchema.anyOf as JSONSchema7[]).some( + (s) => s.type === "null", + ); + + if (nonNull.length === 1) { + const innerType = resolveRustType( + nonNull[0], + parentTypeName, + jsonPropName, + true, + ctx, + ); + if (isRequired && !hasNull) return innerType; + return wrapOption(innerType, false); + } + + if (nonNull.length > 1) { + // Multi-type union — use serde_json::Value as escape hatch + return wrapOption("serde_json::Value", isRequired); + } + } + + // oneOf — treat like anyOf for now + if (propSchema.oneOf) { + const discriminatedUnion = tryEmitRustDiscriminatedUnion( + propSchema, + parentTypeName, + jsonPropName, + ctx, + ); + if (discriminatedUnion) 
{ + return wrapOption(discriminatedUnion, isRequired); + } + + const nonNull = (propSchema.oneOf as JSONSchema7[]).filter( + (s) => s.type !== "null", + ); + if (nonNull.length === 1) { + const innerType = resolveRustType( + nonNull[0], + parentTypeName, + jsonPropName, + true, + ctx, + ); + return wrapOption(innerType, isRequired); + } + return wrapOption("serde_json::Value", isRequired); + } + + // allOf — merge and treat as object + if (propSchema.allOf) { + const merged = resolveObjectSchema(propSchema, ctx.definitions); + if (merged && isObjectSchema(merged)) { + const structName = (propSchema.title as string) || nestedName; + emitRustStruct(structName, merged, ctx); + return wrapOption(structName, isRequired); + } + } + + // enum + if (propSchema.enum && Array.isArray(propSchema.enum)) { + const enumName = (propSchema.title as string) || nestedName; + emitRustStringEnum( + enumName, + propSchema.enum as string[], + ctx, + propSchema.description, + ); + return wrapOption(enumName, isRequired); + } + + // const — just a string + if (propSchema.const !== undefined) { + if (typeof propSchema.const === "string") { + const enumName = (propSchema.title as string) || nestedName; + emitRustConstStringEnum( + enumName, + propSchema.const, + ctx, + propSchema.description, + ); + return wrapOption(enumName, isRequired); + } + return wrapOption("serde_json::Value", isRequired); + } + + const schemaType = propSchema.type; + + // Type arrays like ["string", "null"] + if (Array.isArray(schemaType)) { + const nonNullTypes = (schemaType as string[]).filter((t) => t !== "null"); + if (nonNullTypes.length === 1) { + const inner = resolveRustType( + { ...propSchema, type: nonNullTypes[0] as JSONSchema7["type"] }, + parentTypeName, + jsonPropName, + true, + ctx, + ); + return wrapOption(inner, false); + } + return wrapOption("serde_json::Value", isRequired); + } + + // Primitive types + if (schemaType === "string") { + const newtype = STRING_NEWTYPE_OVERRIDES[jsonPropName]; + if 
(newtype) return wrapOption(newtype, isRequired);
+    return wrapOption("String", isRequired);
+  }
+  if (schemaType === "number") return wrapOption("f64", isRequired);
+  if (schemaType === "integer") return wrapOption("i64", isRequired);
+  if (schemaType === "boolean") return wrapOption("bool", isRequired);
+
+  // Array
+  if (schemaType === "array") {
+    const items = propSchema.items as JSONSchema7 | undefined;
+    if (items) {
+      const itemType = resolveRustType(
+        items,
+        parentTypeName,
+        `${jsonPropName}Item`,
+        true,
+        ctx,
+      );
+      return wrapOption(`Vec<${itemType}>`, isRequired);
+    }
+    return wrapOption("Vec<serde_json::Value>", isRequired);
+  }
+
+  // Object
+  if (schemaType === "object" || (propSchema.properties && !schemaType)) {
+    if (
+      propSchema.properties &&
+      Object.keys(propSchema.properties).length > 0
+    ) {
+      const structName = (propSchema.title as string) || nestedName;
+      emitRustStruct(structName, propSchema, ctx);
+      return wrapOption(structName, isRequired);
+    }
+    if (propSchema.additionalProperties) {
+      if (
+        typeof propSchema.additionalProperties === "object" &&
+        Object.keys(propSchema.additionalProperties as Record<string, unknown>)
+          .length > 0
+      ) {
+        const ap = propSchema.additionalProperties as JSONSchema7;
+        if (ap.type === "object" && ap.properties) {
+          const valueName = (ap.title as string) || `${nestedName}Value`;
+          emitRustStruct(valueName, ap, ctx);
+          return wrapOption(`HashMap<String, ${valueName}>`, isRequired);
+        }
+        const valueType = resolveRustType(
+          ap,
+          parentTypeName,
+          `${jsonPropName}Value`,
+          true,
+          ctx,
+        );
+        return wrapOption(`HashMap<String, ${valueType}>`, isRequired);
+      }
+      return wrapOption("HashMap<String, serde_json::Value>", isRequired);
+    }
+    return wrapOption("serde_json::Value", isRequired);
+  }
+
+  // Fallback
+  return wrapOption("serde_json::Value", isRequired);
+}
+
+function wrapOption(rustType: string, isRequired: boolean): string {
+  if (isRequired) return rustType;
+  // Don't double-wrap Option, Vec, or HashMap (they're already nullable-ish)
+  if (
+    rustType.startsWith("Option<") ||
rustType.startsWith("Vec<") || + rustType.startsWith("HashMap<") + ) { + return rustType; + } + return `Option<${rustType}>`; +} + +// ── Struct emission ───────────────────────────────────────────────────────── + +function emitRustStruct( + typeName: string, + schema: JSONSchema7, + ctx: RustCodegenCtx, + description?: string, +): void { + if (ctx.generatedNames.has(typeName)) return; + ctx.generatedNames.add(typeName); + + const required = new Set(schema.required || []); + const lines: string[] = []; + const desc = description || schema.description; + if (desc) { + for (const line of desc.split(/\r?\n/)) { + lines.push(`/// ${line}`); + } + } + if (isSchemaDeprecated(schema)) { + lines.push("#[deprecated]"); + } + lines.push("#[derive(Debug, Clone, Serialize, Deserialize)]"); + lines.push(`#[serde(rename_all = "camelCase")]`); + lines.push(`pub struct ${typeName} {`); + + for (const [propName, propSchema] of Object.entries( + schema.properties || {}, + )) { + if (typeof propSchema !== "object") continue; + const prop = propSchema as JSONSchema7; + const isReq = required.has(propName); + const rustField = safeRustFieldName(propName); + const rustType = resolveRustType(prop, typeName, propName, isReq, ctx); + + if (prop.description) { + for (const line of prop.description.split(/\r?\n/)) { + lines.push(` /// ${line}`); + } + } + if (isSchemaDeprecated(prop)) { + lines.push(" #[deprecated]"); + } + + // Determine if an explicit rename is needed. `rename_all = "camelCase"` on + // the struct converts snake_case fields to camelCase automatically, so we + // only need an explicit rename when that automatic conversion doesn't produce + // the original JSON property name. 
+ const snakeField = toRustFieldName(propName); + const autoRename = snakeToCamelCase(snakeField); + const needsRename = autoRename !== propName; + const isOptionType = rustType.startsWith("Option<"); + const needsSkip = !isReq && isOptionType; + + if (needsSkip && needsRename) { + lines.push( + ` #[serde(rename = "${propName}", skip_serializing_if = "Option::is_none")]`, + ); + } else if (needsSkip) { + lines.push(` #[serde(skip_serializing_if = "Option::is_none")]`); + } else if (!isReq && !isOptionType && needsRename) { + lines.push(` #[serde(rename = "${propName}", default)]`); + } else if (!isReq && !isOptionType) { + lines.push(" #[serde(default)]"); + } else if (needsRename) { + lines.push(` #[serde(rename = "${propName}")]`); + } + + lines.push(` pub ${rustField}: ${rustType},`); + } + + lines.push("}"); + ctx.structs.push(lines.join("\n")); +} + +// ── Enum emission ─────────────────────────────────────────────────────────── + +function emitRustStringEnum( + enumName: string, + values: string[], + ctx: RustCodegenCtx, + description?: string, +): void { + if (ctx.generatedNames.has(enumName)) return; + ctx.generatedNames.add(enumName); + + const lines: string[] = []; + if (description) { + for (const line of description.split(/\r?\n/)) { + lines.push(`/// ${line}`); + } + } + lines.push("#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]"); + lines.push(`pub enum ${enumName} {`); + + for (const value of values) { + const variantName = toPascalCase(value); + if (variantName !== value) { + lines.push(` #[serde(rename = "${value}")]`); + } + lines.push(` ${variantName},`); + } + + // Add a catch-all for forward compatibility + lines.push(" /// Unknown variant for forward compatibility."); + lines.push(" #[serde(other)]"); + lines.push(" Unknown,"); + + lines.push("}"); + ctx.enums.push(lines.join("\n")); +} + +function emitRustConstStringEnum( + enumName: string, + value: string, + ctx: RustCodegenCtx, + description?: string, +): void { + if 
(ctx.generatedNames.has(enumName)) return;
+  ctx.generatedNames.add(enumName);
+
+  const lines: string[] = [];
+  if (description) {
+    for (const line of description.split(/\r?\n/)) {
+      lines.push(`/// ${line}`);
+    }
+  }
+  lines.push("#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]");
+  lines.push(`pub enum ${enumName} {`);
+  const variantName = toPascalCase(value);
+  if (variantName !== value) {
+    lines.push(`    #[serde(rename = "${value}")]`);
+  }
+  lines.push(`    ${variantName},`);
+  lines.push("}");
+  ctx.enums.push(lines.join("\n"));
+}
+
+// ── Session events generation ───────────────────────────────────────────────
+
+interface EventVariant {
+  /** The event type string, e.g. "session.start" */
+  typeName: string;
+  /** PascalCase variant name, e.g. "SessionStart" */
+  variantName: string;
+  /** Data struct name, e.g. "SessionStartData" */
+  dataClassName: string;
+  /** Schema for the data field */
+  dataSchema: JSONSchema7;
+  /** Description of the event */
+  description?: string;
+}
+
+function extractEventVariants(schema: JSONSchema7): EventVariant[] {
+  const definitionCollections = collectDefinitionCollections(
+    schema as Record<string, unknown>,
+  );
+  const sessionEvent =
+    resolveSchema(
+      { $ref: "#/definitions/SessionEvent" },
+      definitionCollections,
+    ) ?? resolveSchema({ $ref: "#/$defs/SessionEvent" }, definitionCollections);
+  if (!sessionEvent?.anyOf)
+    throw new Error("Schema must have SessionEvent definition with anyOf");
+
+  return (sessionEvent.anyOf as JSONSchema7[])
+    .map((variant) => {
+      const resolvedVariant =
+        resolveObjectSchema(variant as JSONSchema7, definitionCollections) ??
+        resolveSchema(variant as JSONSchema7, definitionCollections) ??
+        (variant as JSONSchema7);
+      if (typeof resolvedVariant !== "object" || !resolvedVariant.properties) {
+        throw new Error("Invalid variant");
+      }
+      const typeSchema = resolvedVariant.properties.type as JSONSchema7;
+      const typeName = typeSchema?.const as string;
+      if (!typeName) throw new Error("Variant must have type.const");
+
+      const dataSchema =
+        resolveObjectSchema(
+          resolvedVariant.properties.data as JSONSchema7,
+          definitionCollections,
+        ) ??
+        resolveSchema(
+          resolvedVariant.properties.data as JSONSchema7,
+          definitionCollections,
+        ) ??
+        ((resolvedVariant.properties.data as JSONSchema7) || {});
+
+      return {
+        typeName,
+        variantName: toPascalCase(typeName),
+        dataClassName: `${toPascalCase(typeName)}Data`,
+        dataSchema,
+        description: resolvedVariant.description || dataSchema.description,
+      };
+    })
+    .filter((v) => !EXCLUDED_EVENT_TYPES.has(v.typeName));
+}
+
+function generateSessionEventsCode(schema: JSONSchema7): string {
+  const variants = extractEventVariants(schema);
+  const ctx = makeCtx(
+    collectDefinitionCollections(schema as Record<string, unknown>),
+  );
+
+  // Generate per-event data structs
+  for (const variant of variants) {
+    emitRustStruct(
+      variant.dataClassName,
+      variant.dataSchema,
+      ctx,
+      variant.description,
+    );
+  }
+
+  // Build the SessionEventType enum
+  const typeEnumLines: string[] = [];
+  typeEnumLines.push("/// Identifies the kind of session event.");
+  typeEnumLines.push(
+    "#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]",
+  );
+  typeEnumLines.push("pub enum SessionEventType {");
+  for (const variant of variants) {
+    typeEnumLines.push(`    #[serde(rename = "${variant.typeName}")]`);
+    typeEnumLines.push(`    ${variant.variantName},`);
+  }
+  typeEnumLines.push("    /// Unknown event type for forward compatibility.");
+  typeEnumLines.push("    #[serde(other)]");
+  typeEnumLines.push("    Unknown,");
+  typeEnumLines.push("}");
+
+  // Build the SessionEventData enum (adjacently tagged by type/data)
+  const dataEnumLines: string[] =
[]; + dataEnumLines.push( + "/// Typed session event data, discriminated by the event `type` field.", + ); + dataEnumLines.push("///"); + dataEnumLines.push( + "/// Use with [`TypedSessionEvent`] for fully typed event handling.", + ); + dataEnumLines.push("#[derive(Debug, Clone, Serialize, Deserialize)]"); + dataEnumLines.push(`#[serde(tag = "type", content = "data")]`); + dataEnumLines.push("pub enum SessionEventData {"); + for (const variant of variants) { + dataEnumLines.push(` #[serde(rename = "${variant.typeName}")]`); + dataEnumLines.push(` ${variant.variantName}(${variant.dataClassName}),`); + } + dataEnumLines.push("}"); + + // Build TypedSessionEvent that combines common fields with typed data + const typedEventLines: string[] = []; + typedEventLines.push("/// A session event with typed data payload."); + typedEventLines.push("///"); + typedEventLines.push( + "/// The common event fields (id, timestamp, parentId, ephemeral) are", + ); + typedEventLines.push( + "/// available directly. 
The event-specific data is in the `payload` field",
+  );
+  typedEventLines.push("/// as a [`SessionEventData`] enum.");
+  typedEventLines.push("#[derive(Debug, Clone, Serialize, Deserialize)]");
+  typedEventLines.push(`#[serde(rename_all = "camelCase")]`);
+  typedEventLines.push("pub struct TypedSessionEvent {");
+  typedEventLines.push("    /// Unique event identifier (UUID v4).");
+  typedEventLines.push("    pub id: String,");
+  typedEventLines.push(
+    "    /// ISO 8601 timestamp when the event was created.",
+  );
+  typedEventLines.push("    pub timestamp: String,");
+  typedEventLines.push("    /// ID of the preceding event in the chain.");
+  typedEventLines.push(`    #[serde(skip_serializing_if = "Option::is_none")]`);
+  typedEventLines.push("    pub parent_id: Option<String>,");
+  typedEventLines.push(
+    "    /// When true, the event is transient and not persisted.",
+  );
+  typedEventLines.push(`    #[serde(skip_serializing_if = "Option::is_none")]`);
+  typedEventLines.push("    pub ephemeral: Option<bool>,");
+  typedEventLines.push(
+    "    /// The typed event payload (discriminated by event type).",
+  );
+  typedEventLines.push("    #[serde(flatten)]");
+  typedEventLines.push("    pub payload: SessionEventData,");
+  typedEventLines.push("}");
+
+  // Assemble file
+  const out: string[] = [];
+  out.push(
+    "//!
Auto-generated from session-events.schema.json — do not edit manually.",
+  );
+  out.push("");
+  out.push("use std::collections::HashMap;");
+  out.push("");
+  out.push("use serde::{Deserialize, Serialize};");
+  out.push("");
+  out.push("use crate::types::{RequestId, SessionId};");
+  out.push("");
+
+  // SessionEventType enum
+  out.push(typeEnumLines.join("\n"));
+  out.push("");
+
+  // SessionEventData enum
+  out.push(dataEnumLines.join("\n"));
+  out.push("");
+
+  // TypedSessionEvent struct
+  out.push(typedEventLines.join("\n"));
+  out.push("");
+
+  // Per-event data structs
+  for (const block of ctx.structs) {
+    out.push(block);
+    out.push("");
+  }
+
+  // Supporting enums
+  for (const block of ctx.enums) {
+    out.push(block);
+    out.push("");
+  }
+
+  return out.join("\n");
+}
+
+// ── API types generation ────────────────────────────────────────────────────
+
+function collectRpcMethods(
+  node: Record<string, unknown>,
+  prefix = "",
+): RpcMethod[] {
+  const methods: RpcMethod[] = [];
+  for (const [key, value] of Object.entries(node)) {
+    if (isRpcMethod(value)) {
+      methods.push(value);
+    } else if (typeof value === "object" && value !== null) {
+      methods.push(
+        ...collectRpcMethods(
+          value as Record<string, unknown>,
+          prefix ?
`${prefix}.${key}` : key,
+        ),
+      );
+    }
+  }
+  return methods;
+}
+
+function rustParamsTypeName(method: RpcMethod): string {
+  return getRpcSchemaTypeName(
+    method.params,
+    `${toPascalCase(method.rpcMethod)}Params`,
+  );
+}
+
+function rustResultTypeName(method: RpcMethod): string {
+  return getRpcSchemaTypeName(
+    method.result,
+    `${toPascalCase(method.rpcMethod)}Result`,
+  );
+}
+
+function generateApiTypesCode(apiSchema: ApiSchema): string {
+  const definitions = collectDefinitions(apiSchema as Record<string, unknown>);
+  const defCollections = collectDefinitionCollections(
+    apiSchema as Record<string, unknown>,
+  );
+  const ctx = makeCtx(defCollections);
+
+  // Generate shared definitions (structs & enums)
+  for (const [name, def] of Object.entries(definitions)) {
+    if (typeof def !== "object" || def === null) continue;
+    const schema = def as JSONSchema7;
+
+    if (schema.enum && Array.isArray(schema.enum)) {
+      emitRustStringEnum(
+        name,
+        schema.enum as string[],
+        ctx,
+        schema.description,
+      );
+    } else if (isObjectSchema(schema)) {
+      emitRustStruct(name, schema, ctx, schema.description);
+    }
+  }
+
+  // Collect all RPC methods and generate request/response types
+  const allMethods: RpcMethod[] = [];
+  for (const group of [
+    apiSchema.server,
+    apiSchema.session,
+    apiSchema.clientSession,
+  ]) {
+    if (group) {
+      allMethods.push(...collectRpcMethods(group as Record<string, unknown>));
+    }
+  }
+
+  // RPC method name constants
+  const methodConstLines: string[] = [];
+  methodConstLines.push("/// JSON-RPC method name constants.");
+  methodConstLines.push("pub mod rpc_methods {");
+
+  for (const method of allMethods) {
+    const constName = method.rpcMethod.replace(/\./g, "_").toUpperCase();
+    methodConstLines.push(`    /// \`${method.rpcMethod}\``);
+    methodConstLines.push(
+      `    pub const ${constName}: &str = "${method.rpcMethod}";`,
+    );
+  }
+  methodConstLines.push("}");
+
+  // Generate param/result types for each method
+  for (const method of allMethods) {
+    if (
+      method.params &&
isObjectSchema(method.params) &&
+      !isVoidSchema(method.params)
+    ) {
+      const paramsName = rustParamsTypeName(method);
+      emitRustStruct(paramsName, method.params, ctx, method.params.description);
+    }
+    if (method.result && !isVoidSchema(method.result)) {
+      const resultName = rustResultTypeName(method);
+      const resolved = resolveSchema(method.result, defCollections);
+      if (resolved) {
+        if (resolved.enum && Array.isArray(resolved.enum)) {
+          // Already generated from definitions
+        } else if (isObjectSchema(resolved)) {
+          emitRustStruct(resultName, resolved, ctx, resolved.description);
+        }
+      }
+    }
+  }
+
+  // Assemble file
+  const out: string[] = [];
+  out.push("//! Auto-generated from api.schema.json — do not edit manually.");
+  out.push("");
+  out.push("#![allow(clippy::large_enum_variant)]");
+  out.push("");
+  out.push("use std::collections::HashMap;");
+  out.push("");
+  out.push("use serde::{Deserialize, Serialize};");
+  out.push("");
+  out.push("use crate::types::{RequestId, SessionId};");
+  out.push("");
+
+  // Method constants
+  out.push(methodConstLines.join("\n"));
+  out.push("");
+
+  // Shared definition types first, then RPC types
+  for (const block of ctx.structs) {
+    out.push(block);
+    out.push("");
+  }
+
+  for (const block of ctx.enums) {
+    out.push(block);
+    out.push("");
+  }
+
+  return out.join("\n");
+}
+
+// ── Typed RPC namespace generation ──────────────────────────────────────────
+
+interface NamespaceNode {
+  name: string;
+  typeName: string;
+  methods: RpcMethod[];
+  children: Map<string, NamespaceNode>;
+}
+
+function newNamespaceNode(name: string, typeName: string): NamespaceNode {
+  return { name, typeName, methods: [], children: new Map() };
+}
+
+/**
+ * Build a namespace tree from a list of methods. `groupOf(method)` returns the
+ * dotted group path (e.g. "mcp.config" for "mcp.config.list" / "workspaces"
+ * for "workspaces.listFiles"); the last segment of `rpcMethod` is the leaf
+ * method name.
+ */ +function buildNamespaceTree( + rootTypeName: string, + methods: RpcMethod[], + stripPrefix: string, +): NamespaceNode { + const root = newNamespaceNode("", rootTypeName); + for (const method of methods) { + const trimmed = stripPrefix && method.rpcMethod.startsWith(stripPrefix) + ? method.rpcMethod.slice(stripPrefix.length) + : method.rpcMethod; + const segments = trimmed.split("."); + const groupSegments = segments.slice(0, -1); + let node = root; + for (const seg of groupSegments) { + let child = node.children.get(seg); + if (!child) { + const childTypeName = `${node.typeName}${toPascalCase(seg)}`; + child = newNamespaceNode(seg, childTypeName); + node.children.set(seg, child); + } + node = child; + } + node.methods.push(method); + } + return root; +} + +/** + * Determine if a method has typed params. Returns `{ hasParams, typeName }`. + * Handles `$ref`-based, title-bearing, and inline params uniformly: + * + * - Resolves `$ref` to its definition. + * - For session methods, ignores `sessionId` (the namespace injects it). + * - Returns `hasParams=false` when the resolved property set (after the + * sessionId filter for session methods) is empty. + * - The type name comes from `$ref` (preferred), then the resolved + * definition's `title`, then the inline params `title`. 
+ */ +function getMethodParamsInfo( + method: RpcMethod, + defCollections: DefinitionCollections, + isSession: boolean, +): { hasParams: boolean; typeName: string | null } { + if (!method.params) return { hasParams: false, typeName: null }; + const inline = method.params as JSONSchema7 & { $ref?: string }; + const resolved = resolveSchema(inline, defCollections); + if (!resolved) return { hasParams: false, typeName: null }; + + let typeName: string | null = null; + if (typeof inline.$ref === "string") { + typeName = refTypeName(inline.$ref, defCollections); + } else if (typeof resolved.title === "string") { + typeName = resolved.title; + } else if (typeof inline.title === "string") { + typeName = inline.title; + } + + const allProps = Object.keys(resolved.properties || {}); + const props = isSession + ? allProps.filter((p) => p !== "sessionId") + : allProps; + if (props.length === 0) return { hasParams: false, typeName: null }; + if (!typeName) return { hasParams: false, typeName: null }; + return { hasParams: true, typeName }; +} + +function rpcMethodConstName(method: RpcMethod): string { + return method.rpcMethod.replace(/\./g, "_").toUpperCase(); +} + +function emitNamespaceStruct( + out: string[], + node: NamespaceNode, + holderType: string, + holderField: string, + isSession: boolean, + defCollections: DefinitionCollections, + docPrefix: string, +): void { + const lifetimes = "<'a>"; + out.push(`/// ${docPrefix}`); + out.push(`#[derive(Clone, Copy)]`); + out.push(`pub struct ${node.typeName}${lifetimes} {`); + out.push(` pub(crate) ${holderField}: &'a ${holderType},`); + out.push(`}`); + out.push(""); + + out.push(`impl${lifetimes} ${node.typeName}${lifetimes} {`); + + // Sub-namespace accessors + const childNames = Array.from(node.children.keys()).sort(); + for (const childName of childNames) { + const child = node.children.get(childName)!; + const accessor = toSnakeCase(childName); + const desc = isSession + ? 
`\`session.${accessorPath(node, childName, isSession)}.*\`` + : `\`${accessorPath(node, childName, isSession)}.*\``; + out.push(` /// ${desc} sub-namespace.`); + out.push( + ` pub fn ${accessor}(&self) -> ${child.typeName}<'a> {`, + ); + out.push(` ${child.typeName} { ${holderField}: self.${holderField} }`); + out.push(` }`); + out.push(""); + } + + // Leaf methods + for (const method of node.methods) { + emitNamespaceMethod(out, method, holderField, isSession, defCollections); + } + + out.push(`}`); + out.push(""); + + // Recursively emit child structs + for (const childName of childNames) { + const child = node.children.get(childName)!; + const childDoc = isSession + ? `\`session.${accessorPath(node, childName, isSession)}.*\` RPCs.` + : `\`${accessorPath(node, childName, isSession)}.*\` RPCs.`; + emitNamespaceStruct( + out, + child, + holderType, + holderField, + isSession, + defCollections, + childDoc, + ); + } +} + +function accessorPath(parent: NamespaceNode, child: string, _isSession: boolean): string { + // Build wire-style dotted path from the namespace tree's "name" chain plus child. + // `parent.name === ""` for root; we accumulate by retrieving parent name only. + // (We don't track full ancestry here; this is just for doc strings — we + // fall back to the child name alone when at the root.) 
+ if (!parent.name) return child; + return `${parent.name}.${child}`; +} + +function getResultTypeName( + method: RpcMethod, + defCollections: DefinitionCollections, +): string | null { + const result = method.result as (JSONSchema7 & { $ref?: string }) | null; + if (!result || isVoidSchema(result)) return null; + if (typeof result.$ref === "string") { + return refTypeName(result.$ref, defCollections); + } + if (typeof result.title === "string") return result.title; + return `${toPascalCase(method.rpcMethod)}Result`; +} + +function emitNamespaceMethod( + out: string[], + method: RpcMethod, + holderField: string, + isSession: boolean, + defCollections: DefinitionCollections, +): void { + const wireMethod = method.rpcMethod; + const constName = rpcMethodConstName(method); + const lastSegment = wireMethod.split(".").pop()!; + const fnName = toSnakeCase(lastSegment); + + const paramsInfo = getMethodParamsInfo(method, defCollections, isSession); + const hasParams = paramsInfo.hasParams; + const paramsTypeName = paramsInfo.typeName; + + const resultTypeName = getResultTypeName(method, defCollections); + const returnType = resultTypeName ? resultTypeName : "()"; + const resultIsVoid = resultTypeName === null; + + const docs: string[] = []; + docs.push(` /// Wire method: \`${wireMethod}\`.`); + if (method.deprecated) docs.push(` #[deprecated]`); + const stability = method.stability; + if (stability && stability !== "stable") { + docs.push(` /// Stability: \`${stability}\`.`); + } + + const paramArg = hasParams ? `, params: ${paramsTypeName}` : ""; + + out.push(...docs); + out.push( + ` pub async fn ${fnName}(&self${paramArg}) -> Result<${returnType}, Error> {`, + ); + + // Build the params Value sent over the wire. 
+  if (isSession) {
+    if (hasParams) {
+      out.push(`        let mut wire_params = serde_json::to_value(params)?;`);
+      out.push(
+        `        wire_params["sessionId"] = serde_json::Value::String(self.session.id().to_string());`,
+      );
+    } else {
+      out.push(
+        `        let wire_params = serde_json::json!({ "sessionId": self.session.id() });`,
+      );
+    }
+    out.push(
+      `        let _value = self.session.client().call(rpc_methods::${constName}, Some(wire_params)).await?;`,
+    );
+  } else {
+    if (hasParams) {
+      out.push(`        let wire_params = serde_json::to_value(params)?;`);
+    } else {
+      out.push(`        let wire_params = serde_json::json!({});`);
+    }
+    out.push(
+      `        let _value = self.client.call(rpc_methods::${constName}, Some(wire_params)).await?;`,
+    );
+  }
+
+  if (resultIsVoid) {
+    out.push(`        Ok(())`);
+  } else {
+    out.push(`        Ok(serde_json::from_value(_value)?)`);
+  }
+  out.push(`    }`);
+  out.push("");
+}
+
+function generateRpcCode(apiSchema: ApiSchema): string {
+  const defCollections = collectDefinitionCollections(
+    apiSchema as unknown as Record<string, unknown>,
+  );
+
+  const serverMethods = apiSchema.server
+    ? collectRpcMethods(apiSchema.server as Record<string, unknown>)
+    : [];
+  const sessionMethods = apiSchema.session
+    ? collectRpcMethods(apiSchema.session as Record<string, unknown>)
+    : [];
+
+  const clientRoot = buildNamespaceTree("ClientRpc", serverMethods, "");
+  const sessionRoot = buildNamespaceTree(
+    "SessionRpc",
+    sessionMethods,
+    "session.",
+  );
+
+  const out: string[] = [];
+  out.push(
+    "//! Auto-generated typed JSON-RPC namespace — do not edit manually.",
+  );
+  out.push("//!");
+  out.push(
+    "//! Generated from `api.schema.json` by `scripts/codegen/rust.ts`. The",
+  );
+  out.push(
+    "//! [`ClientRpc`] and [`SessionRpc`] view structs let callers reach every",
+  );
+  out.push(
+    "//! protocol method through a typed namespace tree, so wire method names",
+  );
+  out.push(
+    "//!
and request/response shapes live in exactly one place — this file.", + ); + out.push(""); + out.push("#![allow(missing_docs)]"); + out.push("#![allow(clippy::too_many_arguments)]"); + out.push(""); + out.push("use super::api_types::*;"); + out.push("use super::api_types::rpc_methods;"); + out.push("use crate::session::Session;"); + out.push("use crate::{Client, Error};"); + out.push(""); + + emitNamespaceStruct( + out, + clientRoot, + "Client", + "client", + false, + defCollections, + "Typed view over the [`Client`]'s server-level RPC namespace.", + ); + emitNamespaceStruct( + out, + sessionRoot, + "Session", + "session", + true, + defCollections, + "Typed view over a [`Session`]'s RPC namespace.", + ); + + return out.join("\n"); +} + +// ── mod.rs generation ─────────────────────────────────────────────────────── + +function generateModRs(): string { + const lines: string[] = []; + lines.push("//! Auto-generated protocol types — do not edit manually."); + lines.push("//!"); + lines.push( + "//! Generated from the Copilot protocol JSON Schemas by `scripts/codegen/rust.ts`.", + ); + lines.push("#![allow(missing_docs)]"); + lines.push("#![allow(rustdoc::bare_urls)]"); + lines.push(""); + lines.push("pub mod api_types;"); + lines.push("pub mod rpc;"); + lines.push("pub mod session_events;"); + lines.push(""); + lines.push( + "// Re-export session event types at the module root — no conflicts with", + ); + lines.push( + "// hand-written types. API types are kept namespaced under `api_types::`", + ); + lines.push( + "// because some names (Tool, ModelCapabilities, etc.) 
overlap with the",
+  );
+  lines.push("// hand-written SDK API types in `types.rs`.");
+  lines.push("pub use session_events::*;");
+  lines.push("");
+  return lines.join("\n");
+}
+
+// ── Format with rustfmt ─────────────────────────────────────────────────────
+
+async function rustfmt(filePath: string): Promise<void> {
+  try {
+    await execFileAsync("rustfmt", ["--edition", "2021", filePath]);
+  } catch (e: unknown) {
+    const error = e as { stderr?: string };
+    console.warn(
+      `rustfmt warning for ${path.basename(filePath)}: ${error.stderr || e}`,
+    );
+  }
+}
+
+// ── Main ────────────────────────────────────────────────────────────────────
+
+async function generate(): Promise<void> {
+  console.log("Loading schemas...");
+
+  const sessionEventsSchemaPath = await getSessionEventsSchemaPath();
+  const apiSchemaPath = await getApiSchemaPath(process.argv[2]);
+
+  const sessionEventsRaw = JSON.parse(
+    await fs.readFile(sessionEventsSchemaPath, "utf-8"),
+  );
+  const apiRaw = JSON.parse(
+    await fs.readFile(apiSchemaPath, "utf-8"),
+  ) as ApiSchema;
+
+  const sessionEventsSchema = postProcessSchema(
+    sessionEventsRaw as JSONSchema7,
+  );
+  const apiSchema = postProcessSchema(
+    apiRaw as JSONSchema7,
+  ) as unknown as ApiSchema;
+
+  // Ensure output directory exists
+  await fs.mkdir(GENERATED_DIR, { recursive: true });
+
+  // Generate session events
+  console.log("Generating session_events.rs...");
+  const sessionEventsCode = generateSessionEventsCode(sessionEventsSchema);
+  const sessionEventsPath = path.join(GENERATED_DIR, "session_events.rs");
+  await fs.writeFile(sessionEventsPath, sessionEventsCode, "utf-8");
+  await rustfmt(sessionEventsPath);
+
+  // Generate API types
+  console.log("Generating api_types.rs...");
+  const apiTypesCode = generateApiTypesCode(apiSchema);
+  const apiTypesPath = path.join(GENERATED_DIR, "api_types.rs");
+  await fs.writeFile(apiTypesPath, apiTypesCode, "utf-8");
+  await rustfmt(apiTypesPath);
+
+  // Generate typed RPC namespace
console.log("Generating rpc.rs...");
+  const rpcCode = generateRpcCode(apiSchema);
+  const rpcPath = path.join(GENERATED_DIR, "rpc.rs");
+  await fs.writeFile(rpcPath, rpcCode, "utf-8");
+  await rustfmt(rpcPath);
+
+  // Generate mod.rs
+  console.log("Generating mod.rs...");
+  const modRsCode = generateModRs();
+  const modRsPath = path.join(GENERATED_DIR, "mod.rs");
+  await fs.writeFile(modRsPath, modRsCode, "utf-8");
+  await rustfmt(modRsPath);
+
+  console.log(`Done! Generated files in ${GENERATED_DIR}`);
+}
+
+generate().catch((err) => {
+  console.error("Code generation failed:", err);
+  process.exit(1);
+});
diff --git a/test/scenarios/RUST_COVERAGE.md b/test/scenarios/RUST_COVERAGE.md
new file mode 100644
index 000000000..f0c61979f
--- /dev/null
+++ b/test/scenarios/RUST_COVERAGE.md
@@ -0,0 +1,61 @@
+# Rust scenario coverage
+
+Rust SDK scenario samples live alongside the TypeScript / Python / Go / C# samples under
+`test/scenarios/*/<scenario>/rust/`. The monorepo's `scenario-builds.yml` workflow
+auto-discovers any `*/rust/Cargo.toml` under `test/scenarios/` and verifies it builds.
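That discovery contract can be sanity-checked locally. A minimal sketch: the temp-dir layout below is illustrative, and the `find` pattern is an assumption about what `scenario-builds.yml` actually matches:

```sh
# Sketch: emulate the workflow's Cargo.toml auto-discovery in a throwaway
# tree. The scenario path is illustrative, not the real repository layout.
tmp=$(mktemp -d)
mkdir -p "$tmp/test/scenarios/transport/stdio/rust/src"
touch "$tmp/test/scenarios/transport/stdio/rust/Cargo.toml"
# Every */rust/Cargo.toml under test/scenarios/ gets a build check in CI.
find "$tmp/test/scenarios" -path '*/rust/Cargo.toml'
rm -rf "$tmp"
```

Under that assumption, a new `rust/` sample that follows the layout is picked up without any workflow changes.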
+ +## Coverage + +| Category | Scenario | Status | +|----------------|-------------------------|--------| +| `transport/` | `stdio` | ✅ | +| `transport/` | `tcp` | ✅ | +| `transport/` | `external` | ❌ deferred (needs `from_streams`-style sample) | +| `sessions/` | `streaming` | ✅ | +| `sessions/` | `session-resume` | ✅ | +| `sessions/` | `infinite-sessions` | ✅ | +| `sessions/` | `concurrent-sessions` | ✅ | +| `sessions/` | `multi-user-*` | ❌ deferred (multi-client orchestration) | +| `modes/` | `default` | ✅ | +| `modes/` | non-default | ❌ deferred (plan mode, read-only) | +| `tools/` | `no-tools` | ✅ | +| `tools/` | `mcp-servers` | ✅ | +| `tools/` | `skills` | ✅ | +| `tools/` | `tool-filtering` | ✅ | +| `tools/` | `custom-agents` | ✅ | +| `tools/` | `tool-overrides` | ✅ | +| `tools/` | `virtual-filesystem` | ❌ deferred (needs `VirtualFilesystem` hook port) | +| `callbacks/` | `hooks` | ✅ | +| `callbacks/` | `permissions` | ✅ | +| `callbacks/` | `user-input` | ✅ | +| `prompts/` | `system-message` | ✅ | +| `prompts/` | `reasoning-effort` | ✅ | +| `prompts/` | `attachments` | ✅ | +| `bundling/` | * | ❌ app-level concern, not an SDK gap | +| `auth/` | * | ❌ deferred (GitHub-App / token-exchange) | + +## Remaining gaps + +- `transport/external` — needs a sample using an externally-managed CLI process (parity with Node's `from_streams`). +- `tools/virtual-filesystem` — depends on a future `VirtualFilesystem` hook port. +- `modes/*` (non-default) — plan-mode and read-only-mode samples. +- `sessions/multi-user-*` — multi-client orchestration. +- `auth/*` — GitHub-App / token-exchange sample programs. +- `bundling/*` — process bundling is application-level, not an SDK concern. + +## Running the samples locally + +Each scenario's `verify.sh` runs the Rust build + run phase alongside the other +languages. 
With a token in place (`GITHUB_TOKEN`, or `gh auth login`): + +```sh +cd test/scenarios/transport/stdio && ./verify.sh +``` + +To build all Rust scenario samples without running them (what CI does): + +```sh +for d in $(find test/scenarios -path '*/rust/Cargo.toml'); do + (cd "$(dirname "$d")" && cargo build --quiet) || echo "FAILED: $d" +done +``` diff --git a/test/scenarios/callbacks/hooks/rust/Cargo.toml b/test/scenarios/callbacks/hooks/rust/Cargo.toml new file mode 100644 index 000000000..4c16a91b5 --- /dev/null +++ b/test/scenarios/callbacks/hooks/rust/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "hooks-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +async-trait = "0.1" +tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] } diff --git a/test/scenarios/callbacks/hooks/rust/src/main.rs b/test/scenarios/callbacks/hooks/rust/src/main.rs new file mode 100644 index 000000000..179765d2f --- /dev/null +++ b/test/scenarios/callbacks/hooks/rust/src/main.rs @@ -0,0 +1,131 @@ +//! Session hooks — intercept lifecycle events (session start/end, pre/post +//! tool use, user prompt, errors) and log every firing. 
+
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use github_copilot_sdk::handler::ApproveAllHandler;
+use github_copilot_sdk::hooks::{
+    ErrorOccurredInput, ErrorOccurredOutput, HookContext, PostToolUseInput, PostToolUseOutput,
+    PreToolUseInput, PreToolUseOutput, SessionEndInput, SessionEndOutput, SessionHooks,
+    SessionStartInput, SessionStartOutput, UserPromptSubmittedInput, UserPromptSubmittedOutput,
+};
+use github_copilot_sdk::types::SessionConfig;
+use github_copilot_sdk::{Client, ClientOptions};
+use tokio::sync::Mutex;
+
+struct HookLogger {
+    log: Arc<Mutex<Vec<String>>>,
+}
+
+impl HookLogger {
+    async fn append(&self, entry: String) {
+        self.log.lock().await.push(entry);
+    }
+}
+
+#[async_trait]
+impl SessionHooks for HookLogger {
+    async fn on_session_start(
+        &self,
+        _input: SessionStartInput,
+        _ctx: HookContext,
+    ) -> Option<SessionStartOutput> {
+        self.append("onSessionStart".to_string()).await;
+        None
+    }
+
+    async fn on_session_end(
+        &self,
+        _input: SessionEndInput,
+        _ctx: HookContext,
+    ) -> Option<SessionEndOutput> {
+        self.append("onSessionEnd".to_string()).await;
+        None
+    }
+
+    async fn on_pre_tool_use(
+        &self,
+        input: PreToolUseInput,
+        _ctx: HookContext,
+    ) -> Option<PreToolUseOutput> {
+        self.append(format!("onPreToolUse:{}", input.tool_name))
+            .await;
+        let mut out = PreToolUseOutput::default();
+        out.permission_decision = Some("allow".to_string());
+        Some(out)
+    }
+
+    async fn on_post_tool_use(
+        &self,
+        input: PostToolUseInput,
+        _ctx: HookContext,
+    ) -> Option<PostToolUseOutput> {
+        self.append(format!("onPostToolUse:{}", input.tool_name))
+            .await;
+        None
+    }
+
+    async fn on_user_prompt_submitted(
+        &self,
+        input: UserPromptSubmittedInput,
+        _ctx: HookContext,
+    ) -> Option<UserPromptSubmittedOutput> {
+        self.append("onUserPromptSubmitted".to_string()).await;
+        let mut out = UserPromptSubmittedOutput::default();
+        out.modified_prompt = Some(input.prompt);
+        Some(out)
+    }
+
+    async fn on_error_occurred(
+        &self,
+        input: ErrorOccurredInput,
+        _ctx: HookContext,
+    ) -> Option<ErrorOccurredOutput> {
+        self.append(format!("onErrorOccurred:{}", input.error))
+            .await;
+        None
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<(), github_copilot_sdk::Error> {
+    let mut opts = ClientOptions::default();
+    opts.github_token = std::env::var("GITHUB_TOKEN").ok();
+    let client = Client::start(opts).await?;
+
+    let hook_log = Arc::new(Mutex::new(Vec::<String>::new()));
+    let hooks = Arc::new(HookLogger {
+        log: hook_log.clone(),
+    });
+
+    let mut config = SessionConfig::default();
+    config.model = Some("claude-haiku-4.5".to_string());
+    let config = config
+        .with_handler(Arc::new(ApproveAllHandler))
+        .with_hooks(hooks);
+
+    let session = client.create_session(config).await?;
+
+    let response = session
+        .send_and_wait(
+            "List the files in the current directory using the glob tool with pattern '*.md'.",
+        )
+        .await?;
+
+    if let Some(event) = response {
+        if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) {
+            println!("{content}");
+        }
+    }
+
+    println!("\n--- Hook execution log ---");
+    let log = hook_log.lock().await;
+    for entry in log.iter() {
+        println!("  {entry}");
+    }
+    println!("\nTotal hooks fired: {}", log.len());
+
+    session.destroy().await?;
+    Ok(())
+}
diff --git a/test/scenarios/callbacks/hooks/verify.sh b/test/scenarios/callbacks/hooks/verify.sh
index 8157fed78..e6f706e61 100755
--- a/test/scenarios/callbacks/hooks/verify.sh
+++ b/test/scenarios/callbacks/hooks/verify.sh
@@ -120,6 +120,8 @@
 check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o hooks-go .
2>&1" # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "══════════════════════════════════════" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" @@ -137,6 +139,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./hooks-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "══════════════════════════════════════" echo " Results: $PASS passed, $FAIL failed" diff --git a/test/scenarios/callbacks/permissions/rust/Cargo.toml b/test/scenarios/callbacks/permissions/rust/Cargo.toml new file mode 100644 index 000000000..a30a94162 --- /dev/null +++ b/test/scenarios/callbacks/permissions/rust/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "permissions-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +async-trait = "0.1" +tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] } diff --git a/test/scenarios/callbacks/permissions/rust/src/main.rs b/test/scenarios/callbacks/permissions/rust/src/main.rs new file mode 100644 index 000000000..214620e35 --- /dev/null +++ b/test/scenarios/callbacks/permissions/rust/src/main.rs @@ -0,0 +1,91 @@ +//! Permission callback — log every `permission.request` from the CLI and +//! approve all of them. 
+
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use github_copilot_sdk::handler::{PermissionResult, SessionHandler};
+use github_copilot_sdk::hooks::{HookContext, PreToolUseInput, PreToolUseOutput, SessionHooks};
+use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionConfig, SessionId};
+use github_copilot_sdk::{Client, ClientOptions};
+use tokio::sync::Mutex;
+
+struct PermissionLogger {
+    log: Arc<Mutex<Vec<String>>>,
+}
+
+#[async_trait]
+impl SessionHandler for PermissionLogger {
+    async fn on_permission_request(
+        &self,
+        _session_id: SessionId,
+        _request_id: RequestId,
+        data: PermissionRequestData,
+    ) -> PermissionResult {
+        let tool_name = data
+            .extra
+            .get("tool")
+            .and_then(|v| v.as_str())
+            .unwrap_or("")
+            .to_string();
+        self.log.lock().await.push(format!("approved:{tool_name}"));
+        PermissionResult::Approved
+    }
+}
+
+struct AllowAllHooks;
+
+#[async_trait]
+impl SessionHooks for AllowAllHooks {
+    async fn on_pre_tool_use(
+        &self,
+        _input: PreToolUseInput,
+        _ctx: HookContext,
+    ) -> Option<PreToolUseOutput> {
+        let mut out = PreToolUseOutput::default();
+        out.permission_decision = Some("allow".to_string());
+        Some(out)
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<(), github_copilot_sdk::Error> {
+    let mut opts = ClientOptions::default();
+    opts.github_token = std::env::var("GITHUB_TOKEN").ok();
+    let client = Client::start(opts).await?;
+
+    let permission_log = Arc::new(Mutex::new(Vec::<String>::new()));
+    let handler = Arc::new(PermissionLogger {
+        log: permission_log.clone(),
+    });
+
+    let mut config = SessionConfig::default();
+    config.model = Some("claude-haiku-4.5".to_string());
+    let config = config
+        .with_handler(handler)
+        .with_hooks(Arc::new(AllowAllHooks));
+
+    let session = client.create_session(config).await?;
+
+    let response = session
+        .send_and_wait(
+            "List the files in the current directory using glob with pattern '*.md'.",
+        )
+        .await?;
+
+    if let Some(event) = response {
+        if let Some(content) =
event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + println!("\n--- Permission request log ---"); + let log = permission_log.lock().await; + for entry in log.iter() { + println!(" {entry}"); + } + println!("\nTotal permission requests: {}", log.len()); + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/callbacks/permissions/verify.sh b/test/scenarios/callbacks/permissions/verify.sh index bc4af1f6a..e63438a6e 100755 --- a/test/scenarios/callbacks/permissions/verify.sh +++ b/test/scenarios/callbacks/permissions/verify.sh @@ -114,6 +114,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o permissions-go . # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "══════════════════════════════════════" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" @@ -131,6 +133,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./permissions-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "══════════════════════════════════════" echo " Results: $PASS passed, $FAIL failed" diff --git a/test/scenarios/callbacks/user-input/rust/Cargo.toml b/test/scenarios/callbacks/user-input/rust/Cargo.toml new file mode 100644 index 000000000..83430f128 --- /dev/null +++ b/test/scenarios/callbacks/user-input/rust/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "user-input-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +async-trait = "0.1" +tokio = { version = "1", features = ["macros", "rt-multi-thread", "sync"] } diff --git a/test/scenarios/callbacks/user-input/rust/src/main.rs 
b/test/scenarios/callbacks/user-input/rust/src/main.rs new file mode 100644 index 000000000..b7fea906e --- /dev/null +++ b/test/scenarios/callbacks/user-input/rust/src/main.rs @@ -0,0 +1,103 @@ +//! User-input callback — answer the agent's `ask_user` prompts and log +//! every question. + +use std::sync::Arc; + +use async_trait::async_trait; +use github_copilot_sdk::handler::{PermissionResult, SessionHandler, UserInputResponse}; +use github_copilot_sdk::hooks::{HookContext, PreToolUseInput, PreToolUseOutput, SessionHooks}; +use github_copilot_sdk::types::{PermissionRequestData, RequestId, SessionConfig, SessionId}; +use github_copilot_sdk::{Client, ClientOptions}; +use tokio::sync::Mutex; + +struct InputResponder { + log: Arc<Mutex<Vec<String>>>, +} + +#[async_trait] +impl SessionHandler for InputResponder { + async fn on_permission_request( + &self, + _session_id: SessionId, + _request_id: RequestId, + _data: PermissionRequestData, + ) -> PermissionResult { + PermissionResult::Approved + } + + async fn on_user_input( + &self, + _session_id: SessionId, + question: String, + _choices: Option<Vec<String>>, + _allow_freeform: Option<bool>, + ) -> Option<UserInputResponse> { + self.log + .lock() + .await + .push(format!("question: {question}")); + Some(UserInputResponse { + answer: "Paris".to_string(), + was_freeform: true, + }) + } +} + +struct AllowAllHooks; + +#[async_trait] +impl SessionHooks for AllowAllHooks { + async fn on_pre_tool_use( + &self, + _input: PreToolUseInput, + _ctx: HookContext, + ) -> Option<PreToolUseOutput> { + let mut out = PreToolUseOutput::default(); + out.permission_decision = Some("allow".to_string()); + Some(out) + } +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let input_log = Arc::new(Mutex::new(Vec::<String>::new())); + let handler = Arc::new(InputResponder { + log: input_log.clone(), + }); + + let mut config = SessionConfig::default();
+ config.model = Some("claude-haiku-4.5".to_string()); + config.request_user_input = Some(true); + let config = config + .with_handler(handler) + .with_hooks(Arc::new(AllowAllHooks)); + + let session = client.create_session(config).await?; + + let response = session + .send_and_wait( + "I want to learn about a city. Use the ask_user tool to ask me \ + which city I'm interested in. Then tell me about that city.", + ) + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + println!("\n--- User input log ---"); + let log = input_log.lock().await; + for entry in log.iter() { + println!(" {entry}"); + } + println!("\nTotal user input requests: {}", log.len()); + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/callbacks/user-input/verify.sh b/test/scenarios/callbacks/user-input/verify.sh index 4550a4c1f..5e35eb67c 100755 --- a/test/scenarios/callbacks/user-input/verify.sh +++ b/test/scenarios/callbacks/user-input/verify.sh @@ -114,6 +114,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o user-input-go . 
2 # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "══════════════════════════════════════" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" @@ -131,6 +133,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./user-input-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "══════════════════════════════════════" echo " Results: $PASS passed, $FAIL failed" diff --git a/test/scenarios/modes/default/rust/Cargo.toml b/test/scenarios/modes/default/rust/Cargo.toml new file mode 100644 index 000000000..d3483ec64 --- /dev/null +++ b/test/scenarios/modes/default/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "default-mode-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/modes/default/rust/src/main.rs b/test/scenarios/modes/default/rust/src/main.rs new file mode 100644 index 000000000..ba890997d --- /dev/null +++ b/test/scenarios/modes/default/rust/src/main.rs @@ -0,0 +1,36 @@ +//! Default agent mode — the agent has access to built-in tools (grep, view, etc.) +//! and can use them to complete a task. 
+ +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::SessionConfig; +use github_copilot_sdk::{Client, ClientOptions}; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + let session = client.create_session(config).await?; + + let response = session + .send_and_wait( + "Use the grep tool to search for the word 'SDK' in README.md and show the matching lines.", + ) + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("Response: {content}"); + } + } + + println!("Default mode test complete"); + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/modes/default/verify.sh b/test/scenarios/modes/default/verify.sh index 9d9b78578..e8811d0d9 100755 --- a/test/scenarios/modes/default/verify.sh +++ b/test/scenarios/modes/default/verify.sh @@ -107,6 +107,9 @@ check "Python (syntax)" bash -c "python3 -c \"import ast; ast.parse(open('$SCRI # Go: build check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o default-go . 
2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" + # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" @@ -125,6 +128,9 @@ run_with_timeout "Python (run)" bash -c "cd '$SCRIPT_DIR/python' && python3 main # Go: run run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./default-go" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" + # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" diff --git a/test/scenarios/prompts/attachments/README.md b/test/scenarios/prompts/attachments/README.md index 76b76751d..2bdb551fb 100644 --- a/test/scenarios/prompts/attachments/README.md +++ b/test/scenarios/prompts/attachments/README.md @@ -34,12 +34,14 @@ Demonstrates sending **file attachments** alongside a prompt using the Copilot S | TypeScript | `attachments: [{ type: "file", path: sampleFile }]` | | Python | `"attachments": [{"type": "file", "path": sample_file}]` | | Go | `Attachments: []copilot.Attachment{{Type: "file", Path: sampleFile}}` | +| Rust | `Attachment::File { path, display_name: None, line_range: None }` | | Language | Blob Attachment Syntax | |----------|------------------------| | TypeScript | `attachments: [{ type: "blob", data: base64Data, mimeType: "image/png" }]` | | Python | `"attachments": [{"type": "blob", "data": base64_data, "mimeType": "image/png"}]` | | Go | `Attachments: []copilot.Attachment{{Type: copilot.AttachmentTypeBlob, Data: &data, MIMEType: &mime}}` | +| Rust | `Attachment::Blob { data, mime_type, display_name: None }` | ## Sample Data diff --git a/test/scenarios/prompts/attachments/rust/Cargo.toml b/test/scenarios/prompts/attachments/rust/Cargo.toml new file mode 100644 index 000000000..e87952f14 --- /dev/null +++ b/test/scenarios/prompts/attachments/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "attachments-rust" +version = 
"0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/prompts/attachments/rust/src/main.rs b/test/scenarios/prompts/attachments/rust/src/main.rs new file mode 100644 index 000000000..9ba9cc176 --- /dev/null +++ b/test/scenarios/prompts/attachments/rust/src/main.rs @@ -0,0 +1,58 @@ +//! File attachments — send a prompt alongside a file attachment so the +//! model can reference the file's content in its response. + +use std::path::PathBuf; +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{Attachment, MessageOptions, SessionConfig, SystemMessageConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +const SYSTEM_PROMPT: &str = + "You are a helpful assistant. Answer questions about attached files concisely."; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut sysmsg = SystemMessageConfig::default(); + sysmsg.mode = Some("replace".to_string()); + sysmsg.content = Some(SYSTEM_PROMPT.to_string()); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.system_message = Some(sysmsg); + config.available_tools = Some(Vec::new()); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + + let session = client.create_session(config).await?; + + // CARGO_MANIFEST_DIR resolves to .../prompts/attachments/rust at compile time. 
+ let sample_file: PathBuf = [env!("CARGO_MANIFEST_DIR"), "..", "sample-data.txt"] + .iter() + .collect(); + let sample_file = sample_file.canonicalize().unwrap_or(sample_file); + + let response = session + .send_and_wait( + MessageOptions::new("What languages are listed in the attached file?").with_attachments( + vec![Attachment::File { + path: sample_file, + display_name: None, + line_range: None, + }], + ), + ) + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/prompts/attachments/verify.sh b/test/scenarios/prompts/attachments/verify.sh index cf4a91977..41b4f108c 100755 --- a/test/scenarios/prompts/attachments/verify.sh +++ b/test/scenarios/prompts/attachments/verify.sh @@ -110,6 +110,9 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o attachments-go . # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" + echo "══════════════════════════════════════" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" echo "══════════════════════════════════════" @@ -127,6 +130,9 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./attachments-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" + echo "══════════════════════════════════════" echo " Results: $PASS passed, $FAIL failed" echo "══════════════════════════════════════" diff --git a/test/scenarios/prompts/reasoning-effort/rust/Cargo.toml b/test/scenarios/prompts/reasoning-effort/rust/Cargo.toml new file mode 100644 index 000000000..c48db3c98 --- /dev/null +++ b/test/scenarios/prompts/reasoning-effort/rust/Cargo.toml @@ 
-0,0 +1,9 @@ +[package] +name = "reasoning-effort-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/prompts/reasoning-effort/rust/src/main.rs b/test/scenarios/prompts/reasoning-effort/rust/src/main.rs new file mode 100644 index 000000000..bf1ab9720 --- /dev/null +++ b/test/scenarios/prompts/reasoning-effort/rust/src/main.rs @@ -0,0 +1,40 @@ +//! Reasoning effort — set the model's reasoning depth via +//! `SessionConfig::reasoning_effort`. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{SessionConfig, SystemMessageConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut sysmsg = SystemMessageConfig::default(); + sysmsg.mode = Some("replace".to_string()); + sysmsg.content = Some("You are a helpful assistant. 
Answer concisely.".to_string()); + + let mut config = SessionConfig::default(); + config.model = Some("claude-opus-4.6".to_string()); + config.reasoning_effort = Some("low".to_string()); + config.available_tools = Some(Vec::new()); + config.system_message = Some(sysmsg); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + + let session = client.create_session(config).await?; + + let response = session.send_and_wait("What is the capital of France?").await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("Reasoning effort: low"); + println!("Response: {content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/prompts/reasoning-effort/verify.sh b/test/scenarios/prompts/reasoning-effort/verify.sh index fe528229e..4d32e4d87 100755 --- a/test/scenarios/prompts/reasoning-effort/verify.sh +++ b/test/scenarios/prompts/reasoning-effort/verify.sh @@ -110,6 +110,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o reasoning-effort- # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "══════════════════════════════════════" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" @@ -127,6 +129,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./reasoning-effort-g # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "══════════════════════════════════════" echo " Results: $PASS passed, $FAIL failed" diff --git a/test/scenarios/prompts/system-message/rust/Cargo.toml b/test/scenarios/prompts/system-message/rust/Cargo.toml new file mode 100644 index 000000000..0d153f9cc --- /dev/null +++ 
b/test/scenarios/prompts/system-message/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "system-message-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/prompts/system-message/rust/src/main.rs b/test/scenarios/prompts/system-message/rust/src/main.rs new file mode 100644 index 000000000..4218a389b --- /dev/null +++ b/test/scenarios/prompts/system-message/rust/src/main.rs @@ -0,0 +1,40 @@ +//! Custom system message — replace the built-in prompt entirely. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{SessionConfig, SystemMessageConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +const PIRATE_PROMPT: &str = "You are a pirate. Always respond in pirate speak. Say 'Arrr!' \ +in every response. Use nautical terms and pirate slang throughout."; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut sysmsg = SystemMessageConfig::default(); + sysmsg.mode = Some("replace".to_string()); + sysmsg.content = Some(PIRATE_PROMPT.to_string()); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.system_message = Some(sysmsg); + config.available_tools = Some(Vec::new()); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + + let session = client.create_session(config).await?; + + let response = session.send_and_wait("What is the capital of France?").await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git 
a/test/scenarios/prompts/system-message/verify.sh b/test/scenarios/prompts/system-message/verify.sh index c2699768b..d1f60e5c4 100755 --- a/test/scenarios/prompts/system-message/verify.sh +++ b/test/scenarios/prompts/system-message/verify.sh @@ -109,6 +109,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o system-message-go # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "══════════════════════════════════════" @@ -127,6 +129,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./system-message-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "══════════════════════════════════════" diff --git a/test/scenarios/sessions/concurrent-sessions/rust/Cargo.toml b/test/scenarios/sessions/concurrent-sessions/rust/Cargo.toml new file mode 100644 index 000000000..a6de4e273 --- /dev/null +++ b/test/scenarios/sessions/concurrent-sessions/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "concurrent-sessions-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs b/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs new file mode 100644 index 000000000..43932b613 --- /dev/null +++ b/test/scenarios/sessions/concurrent-sessions/rust/src/main.rs @@ -0,0 +1,53 @@ +//! Concurrent sessions — two sessions on a single client running in +//! parallel with different system prompts. 
+ +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{SessionConfig, SystemMessageConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +const PIRATE_PROMPT: &str = "You are a pirate. Always say Arrr!"; +const ROBOT_PROMPT: &str = "You are a robot. Always say BEEP BOOP!"; + +fn make_config(system: &str) -> SessionConfig { + let mut sysmsg = SystemMessageConfig::default(); + sysmsg.mode = Some("replace".to_string()); + sysmsg.content = Some(system.to_string()); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.system_message = Some(sysmsg); + config.available_tools = Some(Vec::new()); + config.with_handler(Arc::new(ApproveAllHandler)) +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let session1 = client.create_session(make_config(PIRATE_PROMPT)).await?; + let session2 = client.create_session(make_config(ROBOT_PROMPT)).await?; + + let (r1, r2) = tokio::join!( + session1.send_and_wait("What is the capital of France?"), + session2.send_and_wait("What is the capital of France?"), + ); + + if let Some(event) = r1? { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("Session 1 (pirate): {content}"); + } + } + if let Some(event) = r2? 
{ + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("Session 2 (robot): {content}"); + } + } + + session1.destroy().await?; + session2.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/sessions/concurrent-sessions/verify.sh b/test/scenarios/sessions/concurrent-sessions/verify.sh index be4e3d309..25e6fab18 100755 --- a/test/scenarios/sessions/concurrent-sessions/verify.sh +++ b/test/scenarios/sessions/concurrent-sessions/verify.sh @@ -138,6 +138,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o concurrent-sessio # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "══════════════════════════════════════" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" @@ -155,6 +157,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./concurrent-session # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "══════════════════════════════════════" echo " Results: $PASS passed, $FAIL failed" diff --git a/test/scenarios/sessions/infinite-sessions/rust/Cargo.toml b/test/scenarios/sessions/infinite-sessions/rust/Cargo.toml new file mode 100644 index 000000000..1f23af8a6 --- /dev/null +++ b/test/scenarios/sessions/infinite-sessions/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "infinite-sessions-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/sessions/infinite-sessions/rust/src/main.rs b/test/scenarios/sessions/infinite-sessions/rust/src/main.rs new file mode 100644 index 000000000..0c0f06814 --- /dev/null +++ 
b/test/scenarios/sessions/infinite-sessions/rust/src/main.rs @@ -0,0 +1,55 @@ +//! Infinite sessions — explicit `InfiniteSessionConfig` thresholds and a +//! sequence of three turns to exercise the persistent workspace. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{InfiniteSessionConfig, SessionConfig, SystemMessageConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut sysmsg = SystemMessageConfig::default(); + sysmsg.mode = Some("replace".to_string()); + sysmsg.content = + Some("You are a helpful assistant. Answer concisely in one sentence.".to_string()); + + let mut infinite = InfiniteSessionConfig::default(); + infinite.enabled = Some(true); + infinite.background_compaction_threshold = Some(0.80); + infinite.buffer_exhaustion_threshold = Some(0.95); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.available_tools = Some(Vec::new()); + config.system_message = Some(sysmsg); + config.infinite_sessions = Some(infinite); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + + let session = client.create_session(config).await?; + + let prompts = [ + "What is the capital of France?", + "What is the capital of Japan?", + "What is the capital of Brazil?", + ]; + + for prompt in prompts { + let response = session.send_and_wait(prompt).await?; + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("Q: {prompt}"); + println!("A: {content}\n"); + } + } + } + + println!("Infinite sessions test complete — all messages processed successfully"); + + session.destroy().await?; + Ok(()) +} diff --git 
a/test/scenarios/sessions/infinite-sessions/verify.sh b/test/scenarios/sessions/infinite-sessions/verify.sh index fe4de01e4..367901f28 100755 --- a/test/scenarios/sessions/infinite-sessions/verify.sh +++ b/test/scenarios/sessions/infinite-sessions/verify.sh @@ -116,6 +116,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o infinite-sessions # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "══════════════════════════════════════" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" @@ -133,6 +135,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./infinite-sessions- # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "══════════════════════════════════════" echo " Results: $PASS passed, $FAIL failed" diff --git a/test/scenarios/sessions/session-resume/rust/Cargo.toml b/test/scenarios/sessions/session-resume/rust/Cargo.toml new file mode 100644 index 000000000..ed6207260 --- /dev/null +++ b/test/scenarios/sessions/session-resume/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "session-resume-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/sessions/session-resume/rust/src/main.rs b/test/scenarios/sessions/session-resume/rust/src/main.rs new file mode 100644 index 000000000..10cd4fa62 --- /dev/null +++ b/test/scenarios/sessions/session-resume/rust/src/main.rs @@ -0,0 +1,46 @@ +//! Session resume — create a session, plant a memory, then resume by ID +//! and verify the agent recalls it. 
+ +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{ResumeSessionConfig, SessionConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.available_tools = Some(Vec::new()); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + let session = client.create_session(config).await?; + + session + .send_and_wait("Remember this: the secret word is PINEAPPLE.") + .await?; + + let session_id = session.id().clone(); + // Note: do NOT destroy — `resume_session` needs the session to persist. + + let resume_config = + ResumeSessionConfig::new(session_id).with_handler(Arc::new(ApproveAllHandler)); + let resumed = client.resume_session(resume_config).await?; + println!("Session resumed"); + + let response = resumed + .send_and_wait("What was the secret word I told you?") + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + resumed.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/sessions/session-resume/verify.sh b/test/scenarios/sessions/session-resume/verify.sh index 02cc14d5a..07a5992e9 100755 --- a/test/scenarios/sessions/session-resume/verify.sh +++ b/test/scenarios/sessions/session-resume/verify.sh @@ -117,6 +117,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o session-resume-go # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "══════════════════════════════════════" @@ -135,6 +137,8 @@ 
run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./session-resume-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "══════════════════════════════════════" diff --git a/test/scenarios/sessions/streaming/rust/Cargo.toml b/test/scenarios/sessions/streaming/rust/Cargo.toml new file mode 100644 index 000000000..31acc381b --- /dev/null +++ b/test/scenarios/sessions/streaming/rust/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "streaming-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +async-trait = "0.1" +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/sessions/streaming/rust/src/main.rs b/test/scenarios/sessions/streaming/rust/src/main.rs new file mode 100644 index 000000000..f5cf23764 --- /dev/null +++ b/test/scenarios/sessions/streaming/rust/src/main.rs @@ -0,0 +1,66 @@ +//! Streaming session — count `assistant.message_delta` events while waiting +//! for the final response. + +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use async_trait::async_trait; +use github_copilot_sdk::handler::{HandlerEvent, HandlerResponse, PermissionResult, SessionHandler}; +use github_copilot_sdk::types::SessionConfig; +use github_copilot_sdk::{Client, ClientOptions}; + +struct StreamCounter { + chunks: Arc<AtomicUsize>, +} + +#[async_trait] +impl SessionHandler for StreamCounter { + async fn on_event(&self, event: HandlerEvent) -> HandlerResponse { + match event { + HandlerEvent::SessionEvent { event, .. } => { + if event.event_type == "assistant.message_delta" { + self.chunks.fetch_add(1, Ordering::Relaxed); + } + HandlerResponse::Ok + } + HandlerEvent::PermissionRequest { ..
} => { + HandlerResponse::Permission(PermissionResult::Approved) + } + _ => HandlerResponse::Ok, + } + } +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let chunks = Arc::new(AtomicUsize::new(0)); + let handler = Arc::new(StreamCounter { + chunks: chunks.clone(), + }); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.streaming = Some(true); + let config = config.with_handler(handler); + let session = client.create_session(config).await?; + + let response = session.send_and_wait("What is the capital of France?").await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + println!( + "\nStreaming chunks received: {}", + chunks.load(Ordering::Relaxed) + ); + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/sessions/streaming/verify.sh b/test/scenarios/sessions/streaming/verify.sh index 070ef059b..828f42a43 100755 --- a/test/scenarios/sessions/streaming/verify.sh +++ b/test/scenarios/sessions/streaming/verify.sh @@ -114,6 +114,9 @@ check "Python (syntax)" bash -c "python3 -c \"import ast; ast.parse(open('$SCRI # Go: build check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o streaming-go . 
2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" + # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" @@ -132,6 +135,9 @@ run_with_timeout "Python (run)" bash -c "cd '$SCRIPT_DIR/python' && python3 main # Go: run run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./streaming-go" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" + # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" diff --git a/test/scenarios/tools/custom-agents/rust/Cargo.toml b/test/scenarios/tools/custom-agents/rust/Cargo.toml new file mode 100644 index 000000000..6d536052c --- /dev/null +++ b/test/scenarios/tools/custom-agents/rust/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "custom-agents-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust", features = ["derive"] } +schemars = "1" +serde = { version = "1", features = ["derive"] } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/tools/custom-agents/rust/src/main.rs b/test/scenarios/tools/custom-agents/rust/src/main.rs new file mode 100644 index 000000000..e707770bc --- /dev/null +++ b/test/scenarios/tools/custom-agents/rust/src/main.rs @@ -0,0 +1,82 @@ +//! Custom agents — define a sub-agent ("researcher") with its own prompt +//! and tool allowlist, alongside a client-defined `analyze-codebase` tool. 
+ +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::tool::{ToolHandlerRouter, define_tool}; +use github_copilot_sdk::types::{CustomAgentConfig, DefaultAgentConfig, SessionConfig, ToolResult}; +use github_copilot_sdk::{Client, ClientOptions}; +use schemars::JsonSchema; +use serde::Deserialize; + +#[derive(Deserialize, JsonSchema)] +#[schemars(description = "Parameters for analyze-codebase")] +struct AnalyzeParams { + /// the analysis query + query: String, +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let analyze_codebase = define_tool( + "analyze-codebase", + "Performs deep analysis of the codebase", + |_inv, params: AnalyzeParams| async move { + Ok(ToolResult::Text(format!( + "Analysis result for: {}", + params.query + ))) + }, + ); + + let router = ToolHandlerRouter::new(vec![analyze_codebase], Arc::new(ApproveAllHandler)); + let tools = router.tools(); + + let mut researcher = CustomAgentConfig::default(); + researcher.name = "researcher".to_string(); + researcher.display_name = Some("Research Agent".to_string()); + researcher.description = Some( + "A research agent that can only read and search files, not modify them".to_string(), + ); + researcher.tools = Some(vec![ + "grep".to_string(), + "glob".to_string(), + "view".to_string(), + "analyze-codebase".to_string(), + ]); + researcher.prompt = + "You are a research assistant. You can search and read files but cannot modify \ + anything. When asked about your capabilities, list the tools you have access to." 
+ .to_string(); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.tools = Some(tools); + config.default_agent = Some(DefaultAgentConfig { + excluded_tools: Some(vec!["analyze-codebase".to_string()]), + }); + config.custom_agents = Some(vec![researcher]); + let config = config.with_handler(Arc::new(router)); + + let session = client.create_session(config).await?; + + let response = session + .send_and_wait( + "What custom agents are available? Describe the researcher agent and its capabilities.", + ) + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/tools/custom-agents/verify.sh b/test/scenarios/tools/custom-agents/verify.sh index 826f9df9d..4d295b47f 100755 --- a/test/scenarios/tools/custom-agents/verify.sh +++ b/test/scenarios/tools/custom-agents/verify.sh @@ -109,6 +109,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o custom-agents-go # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "══════════════════════════════════════" @@ -127,6 +129,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./custom-agents-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "══════════════════════════════════════" diff --git a/test/scenarios/tools/mcp-servers/rust/Cargo.toml b/test/scenarios/tools/mcp-servers/rust/Cargo.toml new file mode 100644 index 000000000..84c40e3be --- /dev/null +++ b/test/scenarios/tools/mcp-servers/rust/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "mcp-servers-rust" +version = 
"0.0.0"
+edition = "2024"
+publish = false
+
+[dependencies]
+github-copilot-sdk = { path = "../../../../../rust" }
+serde_json = "1"
+tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
diff --git a/test/scenarios/tools/mcp-servers/rust/src/main.rs b/test/scenarios/tools/mcp-servers/rust/src/main.rs
new file mode 100644
index 000000000..fd76147a1
--- /dev/null
+++ b/test/scenarios/tools/mcp-servers/rust/src/main.rs
@@ -0,0 +1,68 @@
+//! MCP servers — configure an MCP server from env and pass it through to
+//! the CLI via `SessionConfig::mcp_servers`. Build-only when
+//! `MCP_SERVER_CMD` is unset.
+
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use github_copilot_sdk::handler::ApproveAllHandler;
+use github_copilot_sdk::types::{
+    McpServerConfig, McpStdioServerConfig, SessionConfig, SystemMessageConfig,
+};
+use github_copilot_sdk::{Client, ClientOptions};
+
+#[tokio::main]
+async fn main() -> Result<(), github_copilot_sdk::Error> {
+    let mut opts = ClientOptions::default();
+    opts.github_token = std::env::var("GITHUB_TOKEN").ok();
+    let client = Client::start(opts).await?;
+
+    let mcp_cmd = std::env::var("MCP_SERVER_CMD").ok();
+    let mcp_args_env = std::env::var("MCP_SERVER_ARGS").ok();
+    let mcp_servers = mcp_cmd.as_ref().map(|cmd| {
+        let args: Vec<String> = mcp_args_env
+            .as_deref()
+            .map(|s| s.split(' ').map(str::to_string).collect())
+            .unwrap_or_default();
+        let stdio = McpStdioServerConfig {
+            tools: vec!["*".to_string()],
+            command: cmd.clone(),
+            args,
+            ..Default::default()
+        };
+        let mut map = HashMap::new();
+        map.insert("example".to_string(), McpServerConfig::Stdio(stdio));
+        map
+    });
+
+    let mut sysmsg = SystemMessageConfig::default();
+    sysmsg.mode = Some("replace".to_string());
+    sysmsg.content =
+        Some("You are a helpful assistant. 
Answer questions concisely.".to_string()); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.system_message = Some(sysmsg); + config.available_tools = Some(Vec::new()); + config.mcp_servers = mcp_servers; + let config = config.with_handler(Arc::new(ApproveAllHandler)); + + let session = client.create_session(config).await?; + + let response = session.send_and_wait("What is the capital of France?").await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + if mcp_cmd.is_some() { + println!("\nMCP servers configured: example"); + } else { + println!("\nNo MCP servers configured (set MCP_SERVER_CMD to test with a real server)"); + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/tools/mcp-servers/verify.sh b/test/scenarios/tools/mcp-servers/verify.sh index b087e0625..abde4508e 100755 --- a/test/scenarios/tools/mcp-servers/verify.sh +++ b/test/scenarios/tools/mcp-servers/verify.sh @@ -105,6 +105,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o mcp-servers-go . 
# C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "══════════════════════════════════════" @@ -123,6 +125,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./mcp-servers-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "══════════════════════════════════════" diff --git a/test/scenarios/tools/no-tools/rust/Cargo.toml b/test/scenarios/tools/no-tools/rust/Cargo.toml new file mode 100644 index 000000000..461469946 --- /dev/null +++ b/test/scenarios/tools/no-tools/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "no-tools-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/tools/no-tools/rust/src/main.rs b/test/scenarios/tools/no-tools/rust/src/main.rs new file mode 100644 index 000000000..691ac47ed --- /dev/null +++ b/test/scenarios/tools/no-tools/rust/src/main.rs @@ -0,0 +1,44 @@ +//! No-tools session — replace the system prompt and empty the available tools +//! list so the agent cannot execute code, read files, or call any built-ins. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{SessionConfig, SystemMessageConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +const SYSTEM_PROMPT: &str = "You are a minimal assistant with no tools available. +You cannot execute code, read files, edit files, search, or perform any actions. +You can only respond with text based on your training data. 
+If asked about your capabilities or tools, clearly state that you have no tools available."; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut sysmsg = SystemMessageConfig::default(); + sysmsg.mode = Some("replace".to_string()); + sysmsg.content = Some(SYSTEM_PROMPT.to_string()); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.system_message = Some(sysmsg); + config.available_tools = Some(Vec::new()); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + let session = client.create_session(config).await?; + + let response = session + .send_and_wait("Use the bash tool to run 'echo hello'.") + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/tools/no-tools/verify.sh b/test/scenarios/tools/no-tools/verify.sh index 1223c7dcc..286796b70 100755 --- a/test/scenarios/tools/no-tools/verify.sh +++ b/test/scenarios/tools/no-tools/verify.sh @@ -107,6 +107,9 @@ check "Python (syntax)" bash -c "python3 -c \"import ast; ast.parse(open('$SCRI # Go: build check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o no-tools-go . 
2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" + # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" @@ -125,6 +128,9 @@ run_with_timeout "Python (run)" bash -c "cd '$SCRIPT_DIR/python' && python3 main # Go: run run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./no-tools-go" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" + # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" diff --git a/test/scenarios/tools/skills/rust/Cargo.toml b/test/scenarios/tools/skills/rust/Cargo.toml new file mode 100644 index 000000000..c2de4b20e --- /dev/null +++ b/test/scenarios/tools/skills/rust/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "skills-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +async-trait = "0.1" +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/tools/skills/rust/src/main.rs b/test/scenarios/tools/skills/rust/src/main.rs new file mode 100644 index 000000000..845704fac --- /dev/null +++ b/test/scenarios/tools/skills/rust/src/main.rs @@ -0,0 +1,62 @@ +//! Skills — point the CLI at a directory of user-defined skills via +//! `SessionConfig::skill_directories`. 
+
+use std::path::PathBuf;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use github_copilot_sdk::handler::ApproveAllHandler;
+use github_copilot_sdk::hooks::{HookContext, PreToolUseInput, PreToolUseOutput, SessionHooks};
+use github_copilot_sdk::types::SessionConfig;
+use github_copilot_sdk::{Client, ClientOptions};
+
+struct AllowAllHooks;
+
+#[async_trait]
+impl SessionHooks for AllowAllHooks {
+    async fn on_pre_tool_use(
+        &self,
+        _input: PreToolUseInput,
+        _ctx: HookContext,
+    ) -> Option<PreToolUseOutput> {
+        let mut out = PreToolUseOutput::default();
+        out.permission_decision = Some("allow".to_string());
+        Some(out)
+    }
+}
+
+#[tokio::main]
+async fn main() -> Result<(), github_copilot_sdk::Error> {
+    let mut opts = ClientOptions::default();
+    opts.github_token = std::env::var("GITHUB_TOKEN").ok();
+    let client = Client::start(opts).await?;
+
+    // CARGO_MANIFEST_DIR resolves to .../tools/skills/rust at compile time.
+    let skills_dir: PathBuf = [env!("CARGO_MANIFEST_DIR"), "..", "sample-skills"]
+        .iter()
+        .collect();
+
+    let mut config = SessionConfig::default();
+    config.model = Some("claude-haiku-4.5".to_string());
+    config.skill_directories = Some(vec![skills_dir]);
+    let config = config
+        .with_handler(Arc::new(ApproveAllHandler))
+        .with_hooks(Arc::new(AllowAllHooks));
+
+    let session = client.create_session(config).await?;
+
+    let response = session
+        .send_and_wait("Use the greeting skill to greet someone named Alice.")
+        .await?;
+
+    if let Some(event) = response {
+        if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) {
+            println!("{content}");
+        }
+    }
+
+    println!("\nSkill directories configured successfully");
+
+    session.destroy().await?;
+    Ok(())
+}
diff --git a/test/scenarios/tools/skills/verify.sh b/test/scenarios/tools/skills/verify.sh
index fb13fcb16..6d1881173 100755
--- a/test/scenarios/tools/skills/verify.sh
+++ b/test/scenarios/tools/skills/verify.sh
@@ -108,6 +108,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go 
build -o skills-go . 2>&1" # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "══════════════════════════════════════" echo " Phase 2: E2E Run (timeout ${TIMEOUT}s each)" @@ -125,6 +127,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./skills-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "══════════════════════════════════════" echo " Results: $PASS passed, $FAIL failed" diff --git a/test/scenarios/tools/tool-filtering/rust/Cargo.toml b/test/scenarios/tools/tool-filtering/rust/Cargo.toml new file mode 100644 index 000000000..88e38073d --- /dev/null +++ b/test/scenarios/tools/tool-filtering/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "tool-filtering-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/tools/tool-filtering/rust/src/main.rs b/test/scenarios/tools/tool-filtering/rust/src/main.rs new file mode 100644 index 000000000..edc203550 --- /dev/null +++ b/test/scenarios/tools/tool-filtering/rust/src/main.rs @@ -0,0 +1,47 @@ +//! Tool filtering — restrict the agent to a subset of built-in tools via +//! `SessionConfig::available_tools`. + +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::{SessionConfig, SystemMessageConfig}; +use github_copilot_sdk::{Client, ClientOptions}; + +const SYSTEM_PROMPT: &str = "You are a helpful assistant. You have access to a limited set \ +of tools. 
When asked about your tools, list exactly which tools you have available."; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut sysmsg = SystemMessageConfig::default(); + sysmsg.mode = Some("replace".to_string()); + sysmsg.content = Some(SYSTEM_PROMPT.to_string()); + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.system_message = Some(sysmsg); + config.available_tools = Some(vec![ + "grep".to_string(), + "glob".to_string(), + "view".to_string(), + ]); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + + let session = client.create_session(config).await?; + + let response = session + .send_and_wait("What tools do you have available? List each one by name.") + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/tools/tool-filtering/verify.sh b/test/scenarios/tools/tool-filtering/verify.sh index 058b7129e..d73377718 100755 --- a/test/scenarios/tools/tool-filtering/verify.sh +++ b/test/scenarios/tools/tool-filtering/verify.sh @@ -119,6 +119,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o tool-filtering-go # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "══════════════════════════════════════" @@ -137,6 +139,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./tool-filtering-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && 
cargo run --quiet 2>&1" echo "══════════════════════════════════════" diff --git a/test/scenarios/tools/tool-overrides/rust/Cargo.toml b/test/scenarios/tools/tool-overrides/rust/Cargo.toml new file mode 100644 index 000000000..f3b9d6aef --- /dev/null +++ b/test/scenarios/tools/tool-overrides/rust/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "tool-overrides-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust", features = ["derive"] } +schemars = "1" +serde = { version = "1", features = ["derive"] } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/tools/tool-overrides/rust/src/main.rs b/test/scenarios/tools/tool-overrides/rust/src/main.rs new file mode 100644 index 000000000..ce002a27d --- /dev/null +++ b/test/scenarios/tools/tool-overrides/rust/src/main.rs @@ -0,0 +1,61 @@ +//! Tool overrides — replace the built-in `grep` tool with a custom +//! implementation that returns a distinct marker. 
+ +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::tool::{ToolHandlerRouter, define_tool}; +use github_copilot_sdk::types::{SessionConfig, ToolResult}; +use github_copilot_sdk::{Client, ClientOptions}; +use schemars::JsonSchema; +use serde::Deserialize; + +#[derive(Deserialize, JsonSchema)] +#[schemars(description = "Parameters for custom grep")] +struct GrepParams { + /// Search query + query: String, +} + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let grep_tool = define_tool( + "grep", + "A custom grep implementation that overrides the built-in", + |_inv, params: GrepParams| async move { + Ok(ToolResult::Text(format!("CUSTOM_GREP_RESULT: {}", params.query))) + }, + ); + + let router = ToolHandlerRouter::new(vec![grep_tool], Arc::new(ApproveAllHandler)); + let mut tools = router.tools(); + for t in tools.iter_mut() { + if t.name == "grep" { + t.overrides_built_in_tool = true; + } + } + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + config.tools = Some(tools); + let config = config.with_handler(Arc::new(router)); + + let session = client.create_session(config).await?; + + let response = session + .send_and_wait("Use grep to search for the word 'hello'") + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/tools/tool-overrides/verify.sh b/test/scenarios/tools/tool-overrides/verify.sh index b7687de50..cf9b34d51 100755 --- a/test/scenarios/tools/tool-overrides/verify.sh +++ b/test/scenarios/tools/tool-overrides/verify.sh @@ -109,6 +109,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o 
tool-overrides-go
# C#: build
check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1"
+# Rust: build
+check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1"

echo "══════════════════════════════════════"

@@ -127,6 +129,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./tool-overrides-go"
# C#: run
run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1"
+# Rust: run
+run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1"

echo "══════════════════════════════════════"
diff --git a/test/scenarios/transport/stdio/README.md b/test/scenarios/transport/stdio/README.md
index 5178935cc..7de2457ec 100644
--- a/test/scenarios/transport/stdio/README.md
+++ b/test/scenarios/transport/stdio/README.md
@@ -23,6 +23,7 @@ Each sample follows the same flow:
 | `typescript/` | `@github/copilot-sdk` | TypeScript (Node.js) |
 | `python/` | `github-copilot-sdk` | Python |
 | `go/` | `github.com/github/copilot-sdk/go` | Go |
+| `rust/` | `github-copilot-sdk` | Rust |

 ## Prerequisites
diff --git a/test/scenarios/transport/stdio/rust/Cargo.toml b/test/scenarios/transport/stdio/rust/Cargo.toml
new file mode 100644
index 000000000..aa22474c0
--- /dev/null
+++ b/test/scenarios/transport/stdio/rust/Cargo.toml
@@ -0,0 +1,9 @@
+[package]
+name = "stdio-rust"
+version = "0.0.0"
+edition = "2024"
+publish = false
+
+[dependencies]
+github-copilot-sdk = { path = "../../../../../rust" }
+tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
diff --git a/test/scenarios/transport/stdio/rust/src/main.rs b/test/scenarios/transport/stdio/rust/src/main.rs
new file mode 100644
index 000000000..156b3587d
--- /dev/null
+++ b/test/scenarios/transport/stdio/rust/src/main.rs
@@ -0,0 +1,30 @@
+//! Stdio transport — spawn the CLI as a child and exchange JSON-RPC over its stdio. 
+ +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::SessionConfig; +use github_copilot_sdk::{Client, ClientOptions}; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let mut opts = ClientOptions::default(); + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + let session = client.create_session(config).await?; + + let response = session.send_and_wait("What is the capital of France?").await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/transport/stdio/verify.sh b/test/scenarios/transport/stdio/verify.sh index 9a5b11b17..f9f004675 100755 --- a/test/scenarios/transport/stdio/verify.sh +++ b/test/scenarios/transport/stdio/verify.sh @@ -104,6 +104,9 @@ check "Python (syntax)" bash -c "python3 -c \"import ast; ast.parse(open('$SCRI # Go: build check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o stdio-go . 
2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" + # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" @@ -122,6 +125,9 @@ run_with_timeout "Python (run)" bash -c "cd '$SCRIPT_DIR/python' && python3 main # Go: run run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./stdio-go" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" + # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" diff --git a/test/scenarios/transport/tcp/rust/Cargo.toml b/test/scenarios/transport/tcp/rust/Cargo.toml new file mode 100644 index 000000000..fe5d19a91 --- /dev/null +++ b/test/scenarios/transport/tcp/rust/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "tcp-rust" +version = "0.0.0" +edition = "2024" +publish = false + +[dependencies] +github-copilot-sdk = { path = "../../../../../rust" } +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } diff --git a/test/scenarios/transport/tcp/rust/src/main.rs b/test/scenarios/transport/tcp/rust/src/main.rs new file mode 100644 index 000000000..49691c1b2 --- /dev/null +++ b/test/scenarios/transport/tcp/rust/src/main.rs @@ -0,0 +1,43 @@ +//! TCP transport — connect to an externally-running CLI server. Reads +//! `COPILOT_CLI_URL` (default `localhost:3000`) for `host:port`. 
+ +use std::sync::Arc; + +use github_copilot_sdk::handler::ApproveAllHandler; +use github_copilot_sdk::types::SessionConfig; +use github_copilot_sdk::{Client, ClientOptions, Transport}; + +#[tokio::main] +async fn main() -> Result<(), github_copilot_sdk::Error> { + let cli_url = + std::env::var("COPILOT_CLI_URL").unwrap_or_else(|_| "localhost:3000".to_string()); + let (host, port_str) = cli_url + .split_once(':') + .expect("COPILOT_CLI_URL must be 'host:port'"); + let port: u16 = port_str.parse().expect("COPILOT_CLI_URL port must be u16"); + + let mut opts = ClientOptions::default(); + opts.transport = Transport::External { + host: host.to_string(), + port, + }; + opts.github_token = std::env::var("GITHUB_TOKEN").ok(); + let client = Client::start(opts).await?; + + let mut config = SessionConfig::default(); + config.model = Some("claude-haiku-4.5".to_string()); + let config = config.with_handler(Arc::new(ApproveAllHandler)); + + let session = client.create_session(config).await?; + + let response = session.send_and_wait("What is the capital of France?").await?; + + if let Some(event) = response { + if let Some(content) = event.data.get("content").and_then(|c| c.as_str()) { + println!("{content}"); + } + } + + session.destroy().await?; + Ok(()) +} diff --git a/test/scenarios/transport/tcp/verify.sh b/test/scenarios/transport/tcp/verify.sh index 711e0959a..fd30b98f9 100755 --- a/test/scenarios/transport/tcp/verify.sh +++ b/test/scenarios/transport/tcp/verify.sh @@ -163,6 +163,8 @@ check "Go (build)" bash -c "cd '$SCRIPT_DIR/go' && go build -o tcp-go . 
2>&1" # C#: build check "C# (build)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet build --nologo -v q 2>&1" +# Rust: build +check "Rust (build)" bash -c "cd '$SCRIPT_DIR/rust' && cargo build --quiet 2>&1" echo "══════════════════════════════════════" @@ -181,6 +183,8 @@ run_with_timeout "Go (run)" bash -c "cd '$SCRIPT_DIR/go' && ./tcp-go" # C#: run run_with_timeout "C# (run)" bash -c "cd '$SCRIPT_DIR/csharp' && dotnet run --no-build 2>&1" +# Rust: run +run_with_timeout "Rust (run)" bash -c "cd '$SCRIPT_DIR/rust' && cargo run --quiet 2>&1" echo "══════════════════════════════════════" diff --git a/test/scenarios/verify.sh b/test/scenarios/verify.sh index 543c93d2b..7b6b066a0 100755 --- a/test/scenarios/verify.sh +++ b/test/scenarios/verify.sh @@ -43,12 +43,13 @@ TOTAL=${#VERIFY_SCRIPTS[@]} # ── SDK icon helpers ──────────────────────────────────────────────── sdk_icons() { local log="$1" - local ts py go cs + local ts py go cs rs ts="$(sdk_status "$log" "TypeScript")" py="$(sdk_status "$log" "Python")" go="$(sdk_status "$log" "Go ")" cs="$(sdk_status "$log" "C#")" - printf "TS %s PY %s GO %s C# %s" "$ts" "$py" "$go" "$cs" + rs="$(sdk_status "$log" "Rust")" + printf "TS %s PY %s GO %s C# %s RS %s" "$ts" "$py" "$go" "$cs" "$rs" } sdk_status() {