diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 756e3bcb..94a7bc33 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -34,6 +34,10 @@ This is a multi-language SDK repository. Install the tools for the SDK(s) you pl 1. Install [.NET 8.0+](https://dotnet.microsoft.com/download) 1. Install dependencies: `cd dotnet && dotnet restore` +### Rust SDK +1. Install [Rust 1.75+](https://www.rust-lang.org/tools/install) (2021 edition) +1. Install dependencies: `cd rust && cargo build` + ## Submitting a pull request 1. [Fork][fork] and clone the repository @@ -60,11 +64,13 @@ just test-nodejs # Node.js tests just test-python # Python tests just test-go # Go tests just test-dotnet # .NET tests +just test-rust # Rust tests just lint-nodejs # Node.js linting just lint-python # Python linting just lint-go # Go linting just lint-dotnet # .NET linting +just lint-rust # Rust linting ``` Or run commands directly in each SDK directory: @@ -81,6 +87,9 @@ cd go && go test ./... && golangci-lint run ./... # .NET cd dotnet && dotnet test test/GitHub.Copilot.SDK.Test.csproj + +# Rust +cd rust && cargo test && cargo clippy -- -D warnings ``` Here are a few things you can do that will increase the likelihood of your pull request being accepted: diff --git a/README.md b/README.md index cf437522..3d9d2524 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ All SDKs are in technical preview and may change in breaking ways as we move tow | **Python** | [`./python/`](./python/README.md) | `pip install github-copilot-sdk` | | **Go** | [`./go/`](./go/README.md) | `go get github.com/github/copilot-sdk/go` | | **.NET** | [`./dotnet/`](./dotnet/README.md) | `dotnet add package GitHub.Copilot.SDK` | +| **Rust** | [`./rust/`](./rust/README.md) | `cargo add copilot-sdk` | See the individual SDK READMEs for installation, usage examples, and API reference. diff --git a/justfile b/justfile index 8b1af30c..97dd00f1 100644 --- a/justfile +++ b/justfile @@ -3,13 +3,13 @@ default: @just --list # Format all code across all languages -format: format-go format-python format-nodejs format-dotnet +format: format-go format-python format-nodejs format-dotnet format-rust # Lint all code across all languages -lint: lint-go lint-python lint-nodejs lint-dotnet +lint: lint-go lint-python lint-nodejs lint-dotnet lint-rust # Run tests for all languages -test: test-go test-python test-nodejs test-dotnet +test: test-go test-python test-nodejs test-dotnet test-rust # Format Go code format-go: @@ -71,6 +71,21 @@ test-dotnet: @echo "=== Testing .NET code ===" @cd dotnet && dotnet test test/GitHub.Copilot.SDK.Test.csproj +# Format Rust code +format-rust: + @echo "=== Formatting Rust code ===" + @cd rust && cargo fmt + +# Lint Rust code +lint-rust: + @echo "=== Linting Rust code ===" + @cd rust && cargo clippy -- -D warnings + +# Test Rust code +test-rust: + @echo "=== Testing Rust code ===" + @cd rust && cargo test + # Install all dependencies install: @echo "=== Installing dependencies ===" diff --git a/rust/.gitignore b/rust/.gitignore new file mode 100644 index 00000000..49b80a8d --- /dev/null +++ b/rust/.gitignore @@ -0,0 +1,5 @@ +# Rust build artifacts +/target/ + +# Cargo lock file (optional for libraries) +# Cargo.lock diff --git a/rust/Cargo.lock b/rust/Cargo.lock new file mode 100644 index 00000000..d3bd9fe4 --- /dev/null +++ b/rust/Cargo.lock @@ -0,0 +1,836 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + +[[package]] +name = "bumpalo" +version = "3.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" + +[[package]] +name = "bytes" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" + +[[package]] +name = "cc" +version = "1.2.53" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "755d2fce177175ffca841e9a06afdb2c4ab0f593d53b4dee48147dfaade85932" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "chrono" +version = "0.4.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fac4744fb15ae8337dc853fee7fb3f4e48c0fbaa23d0afe49c447b4fab126118" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-link", +] + +[[package]] +name = "copilot-sdk" +version = "0.1.0" +dependencies = [ + "chrono", + "futures", + "once_cell", + "regex", + "schemars", + "serde", + "serde_json", + "tempfile", + "thiserror", + "tokio", + "tokio-test", + "uuid", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + +[[package]] +name = "dyn-clone" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "find-msvc-tools" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8591b0bcc8a98a64310a2fae1bb3e9b8564dd10e381e6e28010fde8e8e8568db" + +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" + +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + +[[package]] +name = "futures-task" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988" + +[[package]] +name = "futures-util" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + +[[package]] +name = "iana-time-zone" +version = "0.1.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33e57f83510bb73707521ebaffa789ec8caf86f9657cad665b092b581d40e9fb" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "js-sys" +version = "0.3.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c942ebf8e95485ca0d52d97da7c5a2c387d0e7f0ba4c35e93bfcaee045955b3" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + +[[package]] +name = "libc" +version = "0.2.180" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bcc35a38544a891a5f7c865aca548a982ccb3b8650a5b06d0fd33a10283c56fc" + +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "memchr" +version = "2.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" + +[[package]] +name = "mio" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "proc-macro2" +version = "1.0.105" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "535d180e0ecab6268a3e718bb9fd44db66bbbc256257165fc699dadf70d16fe7" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc74d9a594b72ae6656596548f56f667211f8a97b3d4c3d467150794690dc40a" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "regex" +version = "1.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5276caf25ac86c8d810222b3dbb938e512c55c6831a10f3e6ed1c93b84041f1c" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" + +[[package]] +name = "rustix" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "schemars" +version = "0.8.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" +dependencies = [ + "dyn-clone", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "0.8.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e265784ad618884abaea0600a9adf15393368d840e0222d101a072f3f7534d" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn", +] + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + +[[package]] +name = "socket2" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17129e116933cf371d018bb80ae557e889637989d8638274fb25622827b03881" +dependencies = [ + "libc", + "windows-sys 0.60.2", +] + +[[package]] +name = "syn" +version = "2.0.114" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4d107df263a3013ef9b1879b0df87d706ff80f65a86ea879bd9c31f9b307c2a" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tempfile" +version = "3.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" +dependencies = [ + "fastrand", + "getrandom", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio" +version = "1.49.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72a2903cd7736441aac9df9d7688bd0ce48edccaadf181c3b90be801e81d3d86" +dependencies = [ + "bytes", + "libc", + "mio", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tokio-stream" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + +[[package]] +name = "tokio-test" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f6d24790a10a7af737693a3e8f1d03faef7e6ca0cc99aae5066f533766de545" +dependencies = [ + "futures-core", + "tokio", + "tokio-stream", +] + +[[package]] +name = "unicode-ident" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" + +[[package]] +name = "uuid" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a" +dependencies = [ + "getrandom", + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "wasip2" +version = "1.0.2+wasi-0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9517f9239f02c069db75e65f174b3da828fe5f5b945c4dd26bd25d89c03ebcf5" +dependencies = [ + "wit-bindgen", +] + +[[package]] +name = "wasm-bindgen" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64024a30ec1e37399cf85a7ffefebdb72205ca1c972291c51512360d90bd8566" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "008b239d9c740232e71bd39e8ef6429d27097518b6b30bdf9086833bd5b6d608" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5256bae2d58f54820e6490f9839c49780dff84c65aeab9e772f15d5f0e913a55" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f01b580c9ac74c8d8f0c0e4afb04eeef2acf145458e52c03845ee9cd23e3d12" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + +[[package]] +name = "wit-bindgen" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" + +[[package]] +name = "zmij" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfcd145825aace48cff44a8844de64bf75feec3080e0aa5cdbde72961ae51a65" diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 00000000..b319a976 --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "copilot-sdk" +version = "0.1.0" +edition = "2021" +license = "MIT" +description = "Rust SDK for programmatic access to the GitHub Copilot CLI" +repository = "https://github.com/github/copilot-sdk" +readme = "README.md" +keywords = ["copilot", "github", "ai", "sdk"] +categories = ["api-bindings", "development-tools"] +rust-version = "1.75" + +[dependencies] +tokio = { version = "1", features = ["rt-multi-thread", "process", "io-util", "net", "sync", "time", "macros"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +uuid = { version = "1", features = ["v4"] } +thiserror = "2" +schemars = { version = "0.8", features = ["derive"] } +chrono = { version = "0.4", features = ["serde"] } +regex = "1" +futures = "0.3" +once_cell = "1" + +[dev-dependencies] +tokio-test = "0.4" +tempfile = "3" + +[[example]] +name = "basic" +path = "examples/basic.rs" diff --git a/rust/README.md b/rust/README.md new file mode 100644 index 00000000..6088aac6 --- /dev/null +++ b/rust/README.md @@ -0,0 +1,474 @@ +# Copilot CLI SDK for Rust + +A Rust SDK for programmatic access to the GitHub Copilot CLI. + +> **Note:** This SDK is in technical preview and may change in breaking ways. + +## Installation + +```bash +cargo add copilot-sdk +cargo add tokio --features full +``` + +Or add to your `Cargo.toml`: + +```toml +[dependencies] +copilot-sdk = "0.1" +tokio = { version = "1", features = ["full"] } +``` + +## Quick Start + +```rust +use copilot_sdk::{CopilotClient, ClientOptions, SessionConfig, SessionEvent, MessageOptions, SessionEventType}; +use std::sync::Arc; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Create client (returns Result) + let client = CopilotClient::new(Some(ClientOptions { + log_level: Some("error".to_string()), + ..Default::default() + }))?; + + // Start the client + client.start().await?; + + // Create a session + let session = client.create_session(Some(SessionConfig { + model: Some("gpt-5".to_string()), + ..Default::default() + })).await?; + + // Set up event handler (receives Arc) + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + let tx_clone = tx.clone(); + + session.on(Arc::new(move |event: Arc| { + if event.event_type == SessionEventType::AssistantMessage { + if let Some(content) = &event.data.content { + println!("{}", content); + } + } + if event.event_type == SessionEventType::SessionIdle { + let _ = tx_clone.try_send(()); + } + })); + + // Send a message + session.send(MessageOptions { + prompt: "What is 2+2?".to_string(), + ..Default::default() + }).await?; + + // Wait for completion + rx.recv().await; + + // Clean up + session.destroy().await?; + client.stop().await; + + Ok(()) +} +``` + +## API Reference + +### CopilotClient + +The main client for interacting with the Copilot CLI server. + +#### Constructor + +```rust +CopilotClient::new(options: Option) -> Result +``` + +Creates a new client. Returns `Result` to handle invalid configuration errors. + +**ClientOptions:** + +| Field | Type | Default | Description | +|-------|------|---------|-------------| +| `cli_path` | `Option` | `"copilot"` | Path to CLI executable (or `COPILOT_CLI_PATH` env var) | +| `cli_url` | `Option` | `None` | URL of existing CLI server (e.g., `"localhost:8080"`, `"http://127.0.0.1:9000"`, or `"8080"`). When provided, the client will not spawn a CLI process. | +| `cwd` | `Option` | `None` | Working directory for CLI process | +| `port` | `Option` | `0` | Server port for TCP mode (0 = random) | +| `use_stdio` | `Option` | `true` | Use stdio transport instead of TCP | +| `log_level` | `Option` | `"info"` | Log level for CLI server | +| `auto_start` | `Option` | `true` | Auto-start server on first use | +| `auto_restart` | `Option` | `true` | Auto-restart on crash | +| `env` | `Option>` | `None` | Environment variables for CLI process | + +#### Methods + +##### `start() -> Result<()>` + +Start the CLI server and establish connection. + +##### `stop() -> Vec` + +Stop the CLI server and close all sessions. Returns a list of any errors encountered during cleanup. + +##### `force_stop()` + +Forcefully stop without graceful cleanup. Use when `stop()` takes too long. + +##### `create_session(config: Option) -> Result>` + +Create a new conversation session. + +**SessionConfig:** + +| Field | Type | Description | +|-------|------|-------------| +| `session_id` | `Option` | Custom session ID | +| `model` | `Option` | Model to use (`"gpt-5"`, `"claude-sonnet-4.5"`, etc.) | +| `tools` | `Vec` | Custom tools exposed to the CLI | +| `streaming` | `Option` | Enable streaming responses | +| `system_message` | `Option` | System message customization | +| `provider` | `Option` | Custom model provider | +| `mcp_servers` | `Option>` | MCP server configurations | +| `available_tools` | `Option>` | Allowlist of available tools | +| `excluded_tools` | `Option>` | Tools to exclude | + +##### `resume_session(session_id: &str, config: Option) -> Result>` + +Resume an existing session. + +##### `get_state() -> ConnectionState` + +Get current connection state (`Disconnected`, `Connecting`, `Connected`, `Error`). + +##### `ping(message: Option<&str>) -> Result` + +Ping the server to verify connectivity. + +##### `list_sessions() -> Result>` + +List all available sessions. + +##### `delete_session(session_id: &str) -> Result<()>` + +Delete a session and its data from disk. + +--- + +### CopilotSession + +Represents a single conversation session. + +#### Methods + +##### `send(options: MessageOptions) -> Result` + +Send a message to the session. Returns immediately after the message is queued; use event handlers or `send_and_wait()` to wait for completion. + +**MessageOptions:** + +| Field | Type | Description | +|-------|------|-------------| +| `prompt` | `String` | The message/prompt to send | +| `attachments` | `Option>` | File attachments | +| `mode` | `Option` | Delivery mode (`"enqueue"` or `"immediate"`) | + +Returns the message ID. + +##### `send_and_wait(options: MessageOptions, timeout: Option) -> Result>` + +Send a message and wait until the session becomes idle. Returns the final assistant message event, or `None` if none was received. + +##### `on(handler: SessionEventHandler) -> impl FnOnce()` + +Subscribe to session events. Returns an unsubscribe function. + +**Important:** The handler receives `Arc` (not `SessionEvent`) to avoid expensive clones when dispatching to multiple handlers. + +```rust +let unsubscribe = session.on(Arc::new(|event: Arc| { + println!("Event: {:?}", event.event_type); +})); + +// Later... +unsubscribe(); +``` + +##### `abort() -> Result<()>` + +Abort the currently processing message. + +##### `get_messages() -> Result>` + +Get all events/messages from this session's history. + +##### `destroy() -> Result<()>` + +Destroy the session and free resources. + +--- + +## Tools + +Expose your own functionality to Copilot by attaching tools to a session. + +### Using `define_tool` (Recommended) + +Use `define_tool` for type-safe tools with automatic JSON schema generation: + +```rust +use copilot_sdk::{define_tool, SessionConfig, ToolInvocation}; +use schemars::JsonSchema; +use serde::Deserialize; + +#[derive(Deserialize, JsonSchema)] +struct LookupIssueParams { + /// Issue identifier + id: String, +} + +let lookup_issue = define_tool::( + "lookup_issue", + "Fetch issue details from our tracker", + |params, _inv| async move { + let issue = fetch_issue(¶ms.id).await?; + Ok(issue.summary) + }, +); + +let session = client.create_session(Some(SessionConfig { + model: Some("gpt-5".to_string()), + tools: vec![lookup_issue], + ..Default::default() +})).await?; +``` + +### Using `ToolBuilder` + +For more control over the JSON schema: + +```rust +use copilot_sdk::{ToolBuilder, ToolResult}; +use serde_json::json; + +let lookup_issue = ToolBuilder::new("lookup_issue") + .description("Fetch issue details from our tracker") + .parameters(json!({ + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "Issue identifier" + } + }, + "required": ["id"] + })) + .handler(|inv| async move { + let id = inv.arguments.get("id") + .and_then(|v| v.as_str()) + .unwrap_or(""); + let issue = fetch_issue(id).await?; + Ok(ToolResult::success(issue.summary)) + }); + +let session = client.create_session(Some(SessionConfig { + model: Some("gpt-5".to_string()), + tools: vec![lookup_issue], + ..Default::default() +})).await?; +``` + +When the model selects a tool, the SDK automatically runs your handler and responds to the CLI's `tool.call` with the result. + +--- + +## Event Types + +Sessions emit various events during processing: + +| Event Type | Description | +|------------|-------------| +| `UserMessage` | User message added | +| `AssistantMessage` | Complete assistant response | +| `AssistantMessageDelta` | Streaming response chunk | +| `AssistantReasoning` | Complete reasoning content | +| `AssistantReasoningDelta` | Streaming reasoning chunk | +| `ToolExecutionStart` | Tool execution started | +| `ToolExecutionComplete` | Tool execution completed | +| `SessionIdle` | Session finished processing | +| `SessionError` | Error occurred | +| `SessionStart` | Session started | + +See [`SessionEventType`](src/generated/session_events.rs) for the full list. + +--- + +## Streaming + +Enable streaming to receive assistant response chunks as they're generated: + +```rust +use copilot_sdk::{CopilotClient, SessionConfig, SessionEvent, MessageOptions, SessionEventType}; +use std::sync::Arc; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = CopilotClient::new(None)?; + client.start().await?; + + let session = client.create_session(Some(SessionConfig { + model: Some("gpt-5".to_string()), + streaming: Some(true), + ..Default::default() + })).await?; + + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + let tx_clone = tx.clone(); + + session.on(Arc::new(move |event: Arc| { + match event.event_type { + SessionEventType::AssistantMessageDelta => { + // Streaming message chunk - print incrementally + if let Some(delta) = &event.data.delta_content { + print!("{}", delta); + } + } + SessionEventType::AssistantReasoningDelta => { + // Streaming reasoning chunk (if model supports reasoning) + if let Some(delta) = &event.data.delta_content { + print!("{}", delta); + } + } + SessionEventType::AssistantMessage => { + // Final message - complete content + println!("\n--- Final message ---"); + if let Some(content) = &event.data.content { + println!("{}", content); + } + } + SessionEventType::AssistantReasoning => { + // Final reasoning content (if model supports reasoning) + println!("--- Reasoning ---"); + if let Some(content) = &event.data.content { + println!("{}", content); + } + } + SessionEventType::SessionIdle => { + let _ = tx_clone.try_send(()); + } + _ => {} + } + })); + + session.send(MessageOptions { + prompt: "Tell me a short story".to_string(), + ..Default::default() + }).await?; + + rx.recv().await; + + session.destroy().await?; + client.stop().await; + + Ok(()) +} +``` + +When `streaming: Some(true)`: + +- `AssistantMessageDelta` events contain `delta_content` with incremental text +- `AssistantReasoningDelta` events contain reasoning chunks (model-dependent) +- Accumulate `delta_content` values to build the response progressively +- Final `AssistantMessage` and `AssistantReasoning` events contain complete content + +Note: Final events are always sent regardless of streaming setting. + +--- + +## Transport Modes + +### stdio (Default) + +Communicates with CLI via stdin/stdout pipes. Recommended for most use cases. + +```rust +let client = CopilotClient::new(None)?; // Uses stdio by default +``` + +### TCP + +Communicates with CLI via TCP socket. Useful for distributed scenarios. + +```rust +let client = CopilotClient::new(Some(ClientOptions { + use_stdio: Some(false), + port: Some(3000), + ..Default::default() +}))?; +``` + +### External Server + +Connect to an already-running CLI server: + +```rust +let client = CopilotClient::new(Some(ClientOptions { + cli_url: Some("localhost:8080".to_string()), + ..Default::default() +}))?; +``` + +--- + +## Error Handling + +The SDK uses `CopilotError` for all error types: + +```rust +use copilot_sdk::{CopilotClient, CopilotError}; + +match CopilotClient::new(None) { + Ok(client) => { + // Use client... + } + Err(CopilotError::InvalidConfig(msg)) => { + eprintln!("Configuration error: {}", msg); + } + Err(e) => { + eprintln!("Error: {}", e); + } +} +``` + +Common error types: + +| Error | Description | +|-------|-------------| +| `InvalidConfig` | Invalid client configuration | +| `NotConnected` | Client not connected | +| `Connection` | Connection error | +| `Process` | CLI process error | +| `Timeout` | Operation timed out | +| `JsonRpc` | JSON-RPC error from server | +| `Session` | Session-related error | + +--- + +## Environment Variables + +- `COPILOT_CLI_PATH` - Path to the Copilot CLI executable + +--- + +## Requirements + +- Rust 1.75+ (2021 edition) +- Tokio async runtime +- GitHub Copilot CLI installed and accessible + +--- + +## License + +MIT diff --git a/rust/e2e/client_test.rs b/rust/e2e/client_test.rs new file mode 100644 index 00000000..7903546f --- /dev/null +++ b/rust/e2e/client_test.rs @@ -0,0 +1,129 @@ +//! Client E2E tests. + +use copilot_sdk::{ClientOptions, ConnectionState, CopilotClient}; + +mod testharness; +use testharness::cli_path; + +/// Skip test if CLI is not available. +macro_rules! require_cli { + () => { + if cli_path().is_none() { + eprintln!("Skipping test: CLI not found. Run 'npm install' in the nodejs directory first."); + return; + } + }; +} + +#[tokio::test] +async fn test_start_and_connect_using_stdio() { + require_cli!(); + + let cli = cli_path().unwrap(); + let client = CopilotClient::new(Some(ClientOptions { + cli_path: Some(cli), + use_stdio: Some(true), + ..Default::default() + })); + + // Start the client + let result = client.start().await; + assert!(result.is_ok(), "Failed to start client: {:?}", result.err()); + + // Verify state + assert_eq!(client.get_state().await, ConnectionState::Connected); + + // Ping the server + let pong = client.ping(Some("test message")).await; + assert!(pong.is_ok(), "Failed to ping: {:?}", pong.err()); + + let pong = pong.unwrap(); + assert_eq!(pong.message, "pong: test message"); + assert!(pong.timestamp >= 0); + + // Stop the client + let errors = client.stop().await; + assert!(errors.is_empty(), "Expected no errors on stop, got: {:?}", errors); + + // Verify disconnected state + assert_eq!(client.get_state().await, ConnectionState::Disconnected); +} + +#[tokio::test] +async fn test_start_and_connect_using_tcp() { + require_cli!(); + + let cli = cli_path().unwrap(); + let client = CopilotClient::new(Some(ClientOptions { + cli_path: Some(cli), + use_stdio: Some(false), + ..Default::default() + })); + + // Start the client + let result = client.start().await; + assert!(result.is_ok(), "Failed to start client: {:?}", result.err()); + + // Verify state + assert_eq!(client.get_state().await, ConnectionState::Connected); + + // Ping the server + let pong = client.ping(Some("test message")).await; + assert!(pong.is_ok(), "Failed to ping: {:?}", pong.err()); + + let pong = pong.unwrap(); + assert_eq!(pong.message, "pong: test message"); + assert!(pong.timestamp >= 0); + + // Stop the client + let errors = client.stop().await; + assert!(errors.is_empty(), "Expected no errors on stop, got: {:?}", errors); + + // Verify disconnected state + assert_eq!(client.get_state().await, ConnectionState::Disconnected); +} + +#[tokio::test] +async fn test_force_stop_without_cleanup() { + require_cli!(); + + let cli = cli_path().unwrap(); + let client = CopilotClient::new(Some(ClientOptions { + cli_path: Some(cli), + ..Default::default() + })); + + // Create a session + let session = client.create_session(None).await; + assert!(session.is_ok(), "Failed to create session: {:?}", session.err()); + + // Force stop + client.force_stop().await; + + // Verify disconnected state + assert_eq!(client.get_state().await, ConnectionState::Disconnected); +} + +#[tokio::test] +async fn test_auto_start_on_create_session() { + require_cli!(); + + let cli = cli_path().unwrap(); + let client = CopilotClient::new(Some(ClientOptions { + cli_path: Some(cli), + auto_start: Some(true), + ..Default::default() + })); + + // Don't call start() - it should auto-start + assert_eq!(client.get_state().await, ConnectionState::Disconnected); + + // Create a session - this should auto-start + let session = client.create_session(None).await; + assert!(session.is_ok(), "Failed to create session: {:?}", session.err()); + + // Should now be connected + assert_eq!(client.get_state().await, ConnectionState::Connected); + + client.force_stop().await; +} diff --git a/rust/e2e/mod.rs b/rust/e2e/mod.rs new file mode 100644 index 00000000..eace9248 --- /dev/null +++ b/rust/e2e/mod.rs @@ -0,0 +1,3 @@ +//! End-to-end tests for the Copilot SDK. + +pub mod testharness; diff --git a/rust/e2e/testharness/context.rs b/rust/e2e/testharness/context.rs new file mode 100644 index 00000000..1e9baf45 --- /dev/null +++ b/rust/e2e/testharness/context.rs @@ -0,0 +1,121 @@ +//! Test context for E2E tests. + +use copilot_sdk::{ClientOptions, CopilotClient}; +use std::env; +use std::path::PathBuf; +use std::sync::OnceLock; +use tempfile::TempDir; + +static CLI_PATH: OnceLock> = OnceLock::new(); + +/// Get the path to the Copilot CLI. +/// +/// Checks the `COPILOT_CLI_PATH` environment variable first, then looks for the CLI +/// in the sibling nodejs directory's node_modules. +pub fn cli_path() -> Option { + CLI_PATH + .get_or_init(|| { + // Check environment variable first + if let Ok(path) = env::var("COPILOT_CLI_PATH") { + if !path.is_empty() { + return Some(path); + } + } + + // Look for CLI in sibling nodejs directory's node_modules + let mut path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); + path.push("../nodejs/node_modules/@github/copilot/index.js"); + + if let Ok(abs_path) = path.canonicalize() { + if abs_path.exists() { + return abs_path.to_str().map(|s| s.to_string()); + } + } + + None + }) + .clone() +} + +/// Test context for E2E tests. +/// +/// Provides isolated directories and configuration for testing. +pub struct TestContext { + /// Path to the Copilot CLI. + pub cli_path: String, + /// Temporary home directory. + pub home_dir: TempDir, + /// Temporary work directory. + pub work_dir: TempDir, +} + +impl TestContext { + /// Create a new test context. + /// + /// # Panics + /// + /// Panics if the CLI is not found. + pub fn new() -> Self { + let cli = cli_path().expect( + "CLI not found. Run 'npm install' in the nodejs directory first, or set COPILOT_CLI_PATH.", + ); + + let home_dir = TempDir::new().expect("Failed to create temp home dir"); + let work_dir = TempDir::new().expect("Failed to create temp work dir"); + + Self { + cli_path: cli, + home_dir, + work_dir, + } + } + + /// Get environment variables configured for isolated testing. + pub fn env(&self) -> Vec<(String, String)> { + let mut env = Vec::new(); + + // Copy current environment + for (key, value) in std::env::vars() { + env.push((key, value)); + } + + // Add overrides for isolated testing + env.push(( + "XDG_CONFIG_HOME".to_string(), + self.home_dir.path().to_str().unwrap().to_string(), + )); + env.push(( + "XDG_STATE_HOME".to_string(), + self.home_dir.path().to_str().unwrap().to_string(), + )); + + env + } + + /// Create a CopilotClient configured for this test context. + pub fn new_client(&self) -> CopilotClient { + CopilotClient::new(Some(ClientOptions { + cli_path: Some(self.cli_path.clone()), + cwd: Some(self.work_dir.path().to_str().unwrap().to_string()), + env: Some(self.env()), + ..Default::default() + })) + } +} + +impl Default for TestContext { + fn default() -> Self { + Self::new() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cli_path_returns_some_or_none() { + // This test just ensures the function doesn't panic + let _ = cli_path(); + } +} diff --git a/rust/e2e/testharness/mod.rs b/rust/e2e/testharness/mod.rs new file mode 100644 index 00000000..bf7347a2 --- /dev/null +++ b/rust/e2e/testharness/mod.rs @@ -0,0 +1,5 @@ +//! Test harness for E2E tests. + +mod context; + +pub use context::*; diff --git a/rust/examples/basic.rs b/rust/examples/basic.rs new file mode 100644 index 00000000..ff8563c1 --- /dev/null +++ b/rust/examples/basic.rs @@ -0,0 +1,100 @@ +//! Basic example demonstrating the Copilot SDK. +//! +//! This example shows how to: +//! - Create a client and connect to the CLI server +//! - Create a session +//! - Subscribe to events +//! - Send a message and wait for a response +//! +//! # Running +//! +//! ```bash +//! cargo run --example basic +//! ``` + +use copilot_sdk::{ + CopilotClient, MessageOptions, SessionConfig, SessionEvent, SessionEventType, +}; +use std::sync::Arc; + +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("Starting Copilot SDK basic example...\n"); + + // Create a client with default options + // This will spawn the CLI server automatically when needed + let client = CopilotClient::new(None)?; + + // Start the client (connects to CLI server) + println!("Connecting to Copilot CLI server..."); + client.start().await?; + println!("Connected!\n"); + + // Ping the server to verify connectivity + let pong = client.ping(Some("hello")).await?; + println!("Ping response: {}", pong.message); + println!("Protocol version: {:?}\n", pong.protocol_version); + + // Create a session + println!("Creating session..."); + let session = client + .create_session(Some(SessionConfig { + model: Some("gpt-4".to_string()), + ..Default::default() + })) + .await?; + println!("Session created: {}\n", session.session_id()); + + // Subscribe to events (handler receives Arc) + let _unsubscribe = session.on(Arc::new(|event: Arc| { + match event.event_type { + SessionEventType::AssistantMessage => { + if let Some(content) = &event.data.content { + println!("Assistant: {}", content); + } + } + SessionEventType::SessionError => { + if let Some(message) = &event.data.message { + eprintln!("Error: {}", message); + } + } + SessionEventType::SessionIdle => { + println!("\n[Session idle]"); + } + _ => { + // Log other event types + println!("[Event: {}]", event.event_type); + } + } + })); + + // Send a message and wait for response + println!("Sending message...\n"); + let response = session + .send_and_wait( + MessageOptions { + prompt: "What is 2 + 2? Answer briefly.".to_string(), + ..Default::default() + }, + None, // Use default timeout + ) + .await?; + + if let Some(event) = response { + if let Some(content) = event.data.content { + println!("\nFinal response: {}", content); + } + } + + // Clean up + println!("\nStopping client..."); + let errors = client.stop().await; + if !errors.is_empty() { + for err in errors { + eprintln!("Cleanup error: {}", err); + } + } + + println!("Done!"); + Ok(()) +} diff --git a/rust/src/client.rs b/rust/src/client.rs new file mode 100644 index 00000000..546440cd --- /dev/null +++ b/rust/src/client.rs @@ -0,0 +1,1055 @@ +//! CopilotClient implementation for managing the Copilot CLI connection. + +use crate::error::{CopilotError, Result}; +use crate::generated::SessionEvent; +use crate::jsonrpc::JsonRpcClient; +use crate::session::CopilotSession; +use crate::tool::{ToolInvocation, ToolResult}; +use crate::types::{ + ClientOptions, ConnectionState, PingResponse, ProviderConfig, + ResumeSessionConfig, SessionConfig, SessionMetadata, get_sdk_protocol_version, +}; +use once_cell::sync::Lazy; +use regex::Regex; +use serde_json::{json, Value}; +use std::collections::HashMap; +use std::env; +use std::process::Stdio; +use std::sync::{Arc, Weak}; +use tokio::io::{BufReader, BufWriter}; +use tokio::net::TcpStream; +use tokio::process::{Child, Command}; +use tokio::sync::{RwLock, Semaphore}; + +/// Maximum number of concurrent event dispatch tasks to prevent unbounded task spawning. +const MAX_CONCURRENT_EVENT_TASKS: usize = 100; + +/// Semaphore to limit concurrent event dispatch tasks. +static EVENT_SEMAPHORE: Lazy = Lazy::new(|| Semaphore::new(MAX_CONCURRENT_EVENT_TASKS)); + +/// Client for interacting with the Copilot CLI server. +/// +/// The client manages the connection to the CLI server and provides methods +/// for creating and managing sessions. +/// +/// # Example +/// +/// ```ignore +/// use copilot_sdk::{CopilotClient, ClientOptions}; +/// +/// #[tokio::main] +/// async fn main() -> Result<(), Box> { +/// let client = CopilotClient::new(None)?; +/// client.start().await?; +/// +/// let session = client.create_session(None).await?; +/// // Use the session... +/// +/// client.stop().await; +/// Ok(()) +/// } +/// ``` +pub struct CopilotClient { + options: ClientOptions, + process: RwLock>, + rpc_client: RwLock>>, + state: RwLock, + sessions: Arc>>>, + actual_port: RwLock, + actual_host: RwLock, + is_external_server: bool, + auto_start: bool, + auto_restart: bool, + /// Weak self-reference for use in callbacks. Set by `start_arc()`. + self_ref: RwLock>>, +} + +impl CopilotClient { + /// Create a new CopilotClient with the given options. + /// + /// If options is None, default options are used (spawns CLI server using stdio). + /// + /// # Errors + /// + /// Returns `CopilotError::InvalidConfig` if the configuration is invalid: + /// - `cli_url` is provided along with `use_stdio` or `cli_path` (mutually exclusive) + /// - `cli_url` has an invalid format + pub fn new(options: Option) -> Result { + let mut opts = ClientOptions { + cli_path: Some("copilot".to_string()), + use_stdio: Some(true), + log_level: Some("info".to_string()), + ..Default::default() + }; + + let mut is_external_server = false; + let mut auto_start = true; + let mut auto_restart = true; + let mut actual_host = "localhost".to_string(); + let mut actual_port = 0u16; + + if let Some(user_opts) = options { + // Validate mutually exclusive options + if user_opts.cli_url.is_some() + && (user_opts.use_stdio.unwrap_or(false) || user_opts.cli_path.is_some()) + { + return Err(CopilotError::InvalidConfig( + "cli_url is mutually exclusive with use_stdio and cli_path".to_string(), + )); + } + + // Parse cli_url if provided + if let Some(ref url) = user_opts.cli_url { + let (host, port) = parse_cli_url(url)?; + actual_host = host; + actual_port = port; + is_external_server = true; + opts.use_stdio = Some(false); + opts.cli_url = user_opts.cli_url; + } + + if let Some(path) = user_opts.cli_path { + opts.cli_path = Some(path); + } + if let Some(cwd) = user_opts.cwd { + opts.cwd = Some(cwd); + } + if let Some(port) = user_opts.port { + opts.port = Some(port); + opts.use_stdio = Some(false); + } + if let Some(log_level) = user_opts.log_level { + opts.log_level = Some(log_level); + } + if let Some(env) = user_opts.env { + opts.env = Some(env); + } + if let Some(auto) = user_opts.auto_start { + auto_start = auto; + } + if let Some(auto) = user_opts.auto_restart { + auto_restart = auto; + } + } + + // Check environment variable for CLI path + if let Ok(cli_path) = env::var("COPILOT_CLI_PATH") { + opts.cli_path = Some(cli_path); + } + + Ok(Self { + options: opts, + process: RwLock::new(None), + rpc_client: RwLock::new(None), + state: RwLock::new(ConnectionState::Disconnected), + sessions: Arc::new(RwLock::new(HashMap::new())), + actual_port: RwLock::new(actual_port), + actual_host: RwLock::new(actual_host), + is_external_server, + auto_start, + auto_restart, + self_ref: RwLock::new(None), + }) + } + + /// Start the CLI server and establish a connection. + pub async fn start(&self) -> Result<()> { + { + let state = self.state.read().await; + if *state == ConnectionState::Connected { + return Ok(()); + } + } + + { + let mut state = self.state.write().await; + *state = ConnectionState::Connecting; + } + + // Only start CLI server process if not connecting to external server + if !self.is_external_server { + if let Err(e) = self.start_cli_server().await { + let mut state = self.state.write().await; + *state = ConnectionState::Error; + return Err(e); + } + } + + // Connect to the server + if let Err(e) = self.connect_to_server().await { + let mut state = self.state.write().await; + *state = ConnectionState::Error; + return Err(e); + } + + // Verify protocol version + if let Err(e) = self.verify_protocol_version().await { + let mut state = self.state.write().await; + *state = ConnectionState::Error; + return Err(e); + } + + { + let mut state = self.state.write().await; + *state = ConnectionState::Connected; + } + + Ok(()) + } + + /// Stop the CLI server and close all active sessions. + /// + /// Returns a vector of errors encountered during cleanup. + pub async fn stop(&self) -> Vec { + let mut errors = Vec::new(); + + // Destroy all active sessions + let sessions: Vec> = { + let sessions = self.sessions.read().await; + sessions.values().cloned().collect() + }; + + for session in sessions { + if let Err(e) = session.destroy().await { + errors.push(CopilotError::Session(format!( + "failed to destroy session {}: {}", + session.session_id(), + e + ))); + } + } + + { + let mut sessions = self.sessions.write().await; + sessions.clear(); + } + + // Kill CLI process (only if we spawned it) + if !self.is_external_server { + let mut process = self.process.write().await; + if let Some(mut child) = process.take() { + if let Err(e) = child.kill().await { + errors.push(CopilotError::Process(format!( + "failed to kill CLI process: {}", + e + ))); + } + } + } + + // Close JSON-RPC client + { + let mut client = self.rpc_client.write().await; + if let Some(rpc) = client.take() { + rpc.stop().await; + } + } + + { + let mut state = self.state.write().await; + *state = ConnectionState::Disconnected; + } + + if !self.is_external_server { + let mut port = self.actual_port.write().await; + *port = 0; + } + + errors + } + + /// Forcefully stop the CLI server without graceful cleanup. + pub async fn force_stop(&self) { + // Clear sessions immediately + { + let mut sessions = self.sessions.write().await; + sessions.clear(); + } + + // Kill CLI process (only if we spawned it) + if !self.is_external_server { + let mut process = self.process.write().await; + if let Some(mut child) = process.take() { + let _ = child.kill().await; + } + } + + // Close JSON-RPC client + { + let mut client = self.rpc_client.write().await; + if let Some(rpc) = client.take() { + rpc.stop().await; + } + } + + { + let mut state = self.state.write().await; + *state = ConnectionState::Disconnected; + } + + if !self.is_external_server { + let mut port = self.actual_port.write().await; + *port = 0; + } + } + + /// Create a new session. + pub async fn create_session( + &self, + config: Option, + ) -> Result> { + self.ensure_connected().await?; + + let config = config.unwrap_or_default(); + let mut params = json!({}); + + if let Some(ref model) = config.model { + params["model"] = json!(model); + } + if let Some(ref session_id) = config.session_id { + params["sessionId"] = json!(session_id); + } + if !config.tools.is_empty() { + let tool_defs: Vec = config + .tools + .iter() + .filter(|t| !t.name.is_empty()) + .map(|t| { + let mut def = json!({ + "name": t.name, + "description": t.description, + }); + if let Some(ref params) = t.parameters { + def["parameters"] = params.clone(); + } + def + }) + .collect(); + if !tool_defs.is_empty() { + params["tools"] = json!(tool_defs); + } + } + if let Some(ref sys_msg) = config.system_message { + let mut system_message = json!({}); + if let Some(ref mode) = sys_msg.mode { + system_message["mode"] = json!(mode); + } + if let Some(ref content) = sys_msg.content { + system_message["content"] = json!(content); + } + params["systemMessage"] = system_message; + } + if let Some(ref available) = config.available_tools { + params["availableTools"] = json!(available); + } + if let Some(ref excluded) = config.excluded_tools { + params["excludedTools"] = json!(excluded); + } + if let Some(streaming) = config.streaming { + params["streaming"] = json!(streaming); + } + if let Some(ref provider) = config.provider { + params["provider"] = build_provider_params(provider); + } + if let Some(ref mcp_servers) = config.mcp_servers { + params["mcpServers"] = json!(mcp_servers); + } + if let Some(ref custom_agents) = config.custom_agents { + params["customAgents"] = json!(custom_agents); + } + if let Some(ref config_dir) = config.config_dir { + params["configDir"] = json!(config_dir); + } + if let Some(ref skill_dirs) = config.skill_directories { + params["skillDirectories"] = json!(skill_dirs); + } + if let Some(ref disabled) = config.disabled_skills { + params["disabledSkills"] = json!(disabled); + } + + let rpc = self.get_rpc_client().await?; + let result = rpc.request("session.create", params).await?; + + let session_id = result + .get("sessionId") + .and_then(|v| v.as_str()) + .ok_or_else(|| CopilotError::InvalidResponse("missing sessionId".to_string()))? + .to_string(); + + let session = Arc::new(CopilotSession::new(session_id.clone(), rpc.clone())); + session.register_tools(config.tools).await; + + { + let mut sessions = self.sessions.write().await; + sessions.insert(session_id, session.clone()); + } + + Ok(session) + } + + /// Resume an existing session. + pub async fn resume_session( + &self, + session_id: &str, + config: Option, + ) -> Result> { + self.ensure_connected().await?; + + let config = config.unwrap_or_default(); + let mut params = json!({ + "sessionId": session_id, + }); + + if !config.tools.is_empty() { + let tool_defs: Vec = config + .tools + .iter() + .filter(|t| !t.name.is_empty()) + .map(|t| { + let mut def = json!({ + "name": t.name, + "description": t.description, + }); + if let Some(ref params) = t.parameters { + def["parameters"] = params.clone(); + } + def + }) + .collect(); + if !tool_defs.is_empty() { + params["tools"] = json!(tool_defs); + } + } + if let Some(ref provider) = config.provider { + params["provider"] = build_provider_params(provider); + } + if let Some(streaming) = config.streaming { + params["streaming"] = json!(streaming); + } + if let Some(ref mcp_servers) = config.mcp_servers { + params["mcpServers"] = json!(mcp_servers); + } + if let Some(ref custom_agents) = config.custom_agents { + params["customAgents"] = json!(custom_agents); + } + if let Some(ref skill_dirs) = config.skill_directories { + params["skillDirectories"] = json!(skill_dirs); + } + if let Some(ref disabled) = config.disabled_skills { + params["disabledSkills"] = json!(disabled); + } + + let rpc = self.get_rpc_client().await?; + let result = rpc.request("session.resume", params).await?; + + let resumed_session_id = result + .get("sessionId") + .and_then(|v| v.as_str()) + .ok_or_else(|| CopilotError::InvalidResponse("missing sessionId".to_string()))? + .to_string(); + + let session = Arc::new(CopilotSession::new(resumed_session_id.clone(), rpc.clone())); + session.register_tools(config.tools).await; + + { + let mut sessions = self.sessions.write().await; + sessions.insert(resumed_session_id, session.clone()); + } + + Ok(session) + } + + /// Get the current connection state. + pub async fn get_state(&self) -> ConnectionState { + *self.state.read().await + } + + /// Ping the server to verify connectivity. + pub async fn ping(&self, message: Option<&str>) -> Result { + let rpc = self.get_rpc_client().await?; + + let mut params = json!({}); + if let Some(msg) = message { + params["message"] = json!(msg); + } + + let result = rpc.request("ping", params).await?; + + Ok(PingResponse { + message: result + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(), + timestamp: result + .get("timestamp") + .and_then(|v| v.as_i64()) + .unwrap_or(0), + protocol_version: result + .get("protocolVersion") + .and_then(|v| v.as_i64()) + .map(|v| v as i32), + }) + } + + /// Delete a session by ID. + pub async fn delete_session(&self, session_id: &str) -> Result<()> { + let rpc = self.get_rpc_client().await?; + + let params = json!({ + "sessionId": session_id, + }); + + rpc.request("session.delete", params).await?; + + { + let mut sessions = self.sessions.write().await; + sessions.remove(session_id); + } + + Ok(()) + } + + /// List all sessions. + pub async fn list_sessions(&self) -> Result> { + let rpc = self.get_rpc_client().await?; + + let result = rpc.request("session.list", json!({})).await?; + + let sessions = result + .get("sessions") + .and_then(|v| v.as_array()) + .map(|arr| { + arr.iter() + .filter_map(|v| serde_json::from_value(v.clone()).ok()) + .collect() + }) + .unwrap_or_default(); + + Ok(sessions) + } + + /// Start the CLI server with auto-reconnect support. + /// + /// This method is similar to [`start()`](Self::start) but requires the client to be + /// wrapped in an `Arc`. When `auto_restart` is enabled in the client options, this + /// method sets up automatic reconnection when the connection is lost. + /// + /// # Example + /// + /// ```ignore + /// use copilot_sdk::{CopilotClient, ClientOptions}; + /// use std::sync::Arc; + /// + /// let client = Arc::new(CopilotClient::new(Some(ClientOptions { + /// auto_restart: Some(true), + /// ..Default::default() + /// }))?); + /// + /// client.start_arc().await?; + /// // Connection will automatically reconnect if lost + /// ``` + pub async fn start_arc(self: &Arc) -> Result<()> { + // Store weak self-reference for use in callbacks + { + let mut self_ref = self.self_ref.write().await; + *self_ref = Some(Arc::downgrade(self)); + } + + // Call the regular start method + self.start().await?; + + // Set up disconnect handler if auto_restart is enabled + if self.auto_restart { + self.setup_disconnect_handler(Arc::clone(self)).await; + } + + Ok(()) + } + + /// Set up the disconnect handler on the RPC client. + /// + /// This spawns a background task that listens for disconnect events + /// and handles reconnection. + async fn setup_disconnect_handler(&self, client_arc: Arc) { + // Get RPC client, ensuring the guard is dropped before the next await + let rpc = { + let guard = self.rpc_client.read().await; + guard.clone() + }; + + let Some(rpc) = rpc else { + return; + }; + + // Create a channel for disconnect notifications + let (tx, rx) = tokio::sync::mpsc::channel::<()>(1); + + // Set up the callback to send on the channel (sync operation) + rpc.set_on_disconnect(Arc::new(move || { + // try_send is non-blocking and won't fail if receiver is ready + let _ = tx.try_send(()); + })) + .await; + + // Spawn the handler task with the Arc reference + Self::spawn_disconnect_handler(client_arc, rx); + } + + /// Spawn a task to handle disconnect events. + /// + /// This is a separate function to ensure the spawned future doesn't capture + /// any non-Send types from the calling context. + fn spawn_disconnect_handler( + client: Arc, + mut rx: tokio::sync::mpsc::Receiver<()>, + ) { + tokio::spawn(async move { + // Wait for disconnect notification + if rx.recv().await.is_some() { + client.handle_disconnect().await; + } + }); + } + + /// Handle a disconnection event. + /// + /// Called when the RPC connection is lost. Triggers reconnection if + /// `auto_restart` is enabled and the client was in Connected state. + async fn handle_disconnect(&self) { + let should_reconnect = { + let state = self.state.read().await; + self.auto_restart && *state == ConnectionState::Connected + }; + + if should_reconnect { + self.reconnect().await; + } + } + + /// Attempt to reconnect to the server. + /// + /// Notifies all active sessions of the disconnection, stops the current + /// connection, and attempts to establish a new one. + async fn reconnect(&self) { + // Notify sessions of disconnection + self.invalidate_sessions().await; + + // Update state + { + let mut state = self.state.write().await; + *state = ConnectionState::Disconnected; + } + + // Stop the current connection + let _ = self.stop().await; + + // Attempt to restart + // Use the stored weak reference to call start_arc if available + let self_ref = self.self_ref.read().await.clone(); + if let Some(weak_self) = self_ref { + if let Some(arc_self) = weak_self.upgrade() { + if let Err(e) = arc_self.start_arc().await { + eprintln!("Reconnection failed: {}", e); + let mut state = arc_self.state.write().await; + *state = ConnectionState::Error; + } + return; + } + } + + // Fallback to regular start if no Arc reference available + if let Err(e) = self.start().await { + eprintln!("Reconnection failed: {}", e); + let mut state = self.state.write().await; + *state = ConnectionState::Error; + } + } + + /// Notify all sessions that the connection has been lost. + async fn invalidate_sessions(&self) { + // Collect sessions first to avoid holding the lock across await points + let sessions: Vec> = { + let guard = self.sessions.read().await; + guard.values().cloned().collect() + }; + + for session in sessions { + session.dispatch_error("Connection lost").await; + } + } + + // Private methods + + async fn ensure_connected(&self) -> Result<()> { + let state = self.state.read().await; + if *state == ConnectionState::Connected { + return Ok(()); + } + drop(state); + + if self.auto_start { + self.start().await + } else { + Err(CopilotError::NotConnected) + } + } + + async fn get_rpc_client(&self) -> Result> { + let client = self.rpc_client.read().await; + client.clone().ok_or(CopilotError::NotConnected) + } + + async fn start_cli_server(&self) -> Result<()> { + let cli_path = self + .options + .cli_path + .as_deref() + .unwrap_or("copilot") + .to_string(); + let log_level = self + .options + .log_level + .as_deref() + .unwrap_or("info") + .to_string(); + + let mut args = vec![ + "--server".to_string(), + "--log-level".to_string(), + log_level, + ]; + + let use_stdio = self.options.use_stdio.unwrap_or(true); + if use_stdio { + args.push("--stdio".to_string()); + } else if let Some(port) = self.options.port { + args.push("--port".to_string()); + args.push(port.to_string()); + } + + // Determine command - if CLI path is a .js file, run with node + let (command, final_args) = if cli_path.ends_with(".js") { + let mut new_args = vec![cli_path]; + new_args.extend(args); + ("node".to_string(), new_args) + } else { + (cli_path, args) + }; + + let mut cmd = Command::new(&command); + cmd.args(&final_args) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()); + + if let Some(ref cwd) = self.options.cwd { + cmd.current_dir(cwd); + } + + if let Some(ref env_vars) = self.options.env { + for (key, value) in env_vars { + cmd.env(key, value); + } + } + + let mut child = cmd + .spawn() + .map_err(|e| CopilotError::Process(format!("failed to start CLI server: {}", e)))?; + + if use_stdio { + // For stdio mode, create JSON-RPC client from stdin/stdout + let stdin = child + .stdin + .take() + .ok_or_else(|| CopilotError::Process("failed to get stdin".to_string()))?; + let stdout = child + .stdout + .take() + .ok_or_else(|| CopilotError::Process("failed to get stdout".to_string()))?; + + let reader = BufReader::new(stdout); + let writer = BufWriter::new(stdin); + + let rpc = Arc::new(JsonRpcClient::new(reader, writer)); + self.setup_notification_handler(&rpc).await; + + { + let mut client = self.rpc_client.write().await; + *client = Some(rpc); + } + + { + let mut process = self.process.write().await; + *process = Some(child); + } + + Ok(()) + } else { + // For TCP mode, wait for port announcement + let stdout = child + .stdout + .take() + .ok_or_else(|| CopilotError::Process("failed to get stdout".to_string()))?; + + let mut reader = BufReader::new(stdout); + let port_regex = Regex::new(r"listening on port (\d+)").unwrap(); + + let mut line = String::new(); + let timeout = tokio::time::timeout( + std::time::Duration::from_secs(10), + async { + loop { + line.clear(); + use tokio::io::AsyncBufReadExt; + if reader.read_line(&mut line).await.is_err() { + break Err(CopilotError::Process("failed to read from CLI".to_string())); + } + if let Some(caps) = port_regex.captures(&line) { + if let Some(port_str) = caps.get(1) { + if let Ok(port) = port_str.as_str().parse::() { + break Ok(port); + } + } + } + } + }, + ) + .await; + + let port = match timeout { + Ok(Ok(port)) => port, + Ok(Err(e)) => return Err(e), + Err(_) => { + return Err(CopilotError::Timeout); + } + }; + + { + let mut actual_port = self.actual_port.write().await; + *actual_port = port; + } + + { + let mut process = self.process.write().await; + *process = Some(child); + } + + Ok(()) + } + } + + async fn connect_to_server(&self) -> Result<()> { + let use_stdio = self.options.use_stdio.unwrap_or(true); + if use_stdio && !self.is_external_server { + // Already connected via stdio in start_cli_server + return Ok(()); + } + + // Connect via TCP + self.connect_via_tcp().await + } + + async fn connect_via_tcp(&self) -> Result<()> { + let port = *self.actual_port.read().await; + if port == 0 { + return Err(CopilotError::Connection( + "server port not available".to_string(), + )); + } + + let host = self.actual_host.read().await.clone(); + let addr = format!("{}:{}", host, port); + + let stream = tokio::time::timeout( + std::time::Duration::from_secs(10), + TcpStream::connect(&addr), + ) + .await + .map_err(|_| CopilotError::Timeout)? + .map_err(|e| { + CopilotError::Connection(format!("failed to connect to CLI server at {}: {}", addr, e)) + })?; + + let (reader, writer) = stream.into_split(); + let reader = BufReader::new(reader); + let writer = BufWriter::new(writer); + + let rpc = Arc::new(JsonRpcClient::new(reader, writer)); + self.setup_notification_handler(&rpc).await; + + { + let mut client = self.rpc_client.write().await; + *client = Some(rpc); + } + + Ok(()) + } + + async fn setup_notification_handler(&self, rpc: &Arc) { + let sessions = self.sessions.clone(); + + // Set up notification handler for session events + rpc.set_notification_handler(Arc::new(move |method, params| { + if method == "session.event" { + if let Some(session_id) = params.get("sessionId").and_then(|v| v.as_str()) { + if let Some(event_value) = params.get("event") { + if let Ok(event) = serde_json::from_value::(event_value.clone()) + { + let sessions = sessions.clone(); + let session_id = session_id.to_string(); + // Use semaphore to limit concurrent event dispatch tasks + tokio::spawn(async move { + let _permit = EVENT_SEMAPHORE.acquire().await.unwrap(); + let sessions = sessions.read().await; + if let Some(session) = sessions.get(&session_id) { + session.dispatch_event(event).await; + } + }); + } + } + } + } + })) + .await; + + // Set up tool call handler + let sessions_for_tools = self.sessions.clone(); + rpc.set_request_handler( + "tool.call", + Arc::new(move |params| { + let sessions = sessions_for_tools.clone(); + + let session_id = params + .get("sessionId") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let tool_call_id = params + .get("toolCallId") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let tool_name = params + .get("toolName") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let arguments = params.get("arguments").cloned().unwrap_or(Value::Null); + + Box::pin(async move { + if session_id.is_empty() || tool_call_id.is_empty() || tool_name.is_empty() { + return Err(CopilotError::InvalidResponse( + "invalid tool call payload".to_string(), + )); + } + + let sessions = sessions.read().await; + let session = sessions.get(&session_id); + + let result = if let Some(session) = session { + let inv = ToolInvocation { + session_id: session_id.clone(), + tool_call_id, + tool_name: tool_name.clone(), + arguments, + }; + session.execute_tool(&tool_name, inv).await + } else { + Ok(ToolResult::unsupported(&tool_name)) + }; + + match result { + Ok(tool_result) => Ok(json!({ "result": tool_result })), + Err(e) => Ok(json!({ "result": ToolResult::failure(e.to_string()) })), + } + }) + }), + ) + .await; + } + + async fn verify_protocol_version(&self) -> Result<()> { + let expected_version = get_sdk_protocol_version(); + let ping_result = self.ping(None).await?; + + match ping_result.protocol_version { + None => Err(CopilotError::ProtocolVersionNotReported { + expected: expected_version, + }), + Some(version) if version != expected_version => { + Err(CopilotError::ProtocolVersionMismatch { + expected: expected_version, + actual: version, + }) + } + Some(_) => Ok(()), + } + } +} + +/// Parse a CLI URL into host and port components. +fn parse_cli_url(url: &str) -> Result<(String, u16)> { + // Remove protocol if present + let clean_url = Regex::new(r"^https?://") + .unwrap() + .replace(url, "") + .to_string(); + + // Check if it's just a port number + if Regex::new(r"^\d+$").unwrap().is_match(&clean_url) { + let port: u16 = clean_url.parse().map_err(|_| { + CopilotError::InvalidConfig(format!("Invalid port in cli_url: {}", url)) + })?; + return Ok(("localhost".to_string(), port)); + } + + // Parse host:port format + let parts: Vec<&str> = clean_url.splitn(2, ':').collect(); + if parts.len() != 2 { + return Err(CopilotError::InvalidConfig(format!( + "Invalid cli_url format: {}. Expected 'host:port', 'http://host:port', or 'port'", + url + ))); + } + + let host = if parts[0].is_empty() { + "localhost".to_string() + } else { + parts[0].to_string() + }; + + let port: u16 = parts[1].parse().map_err(|_| { + CopilotError::InvalidConfig(format!("Invalid port in cli_url: {}", url)) + })?; + + Ok((host, port)) +} + +/// Build provider params for JSON-RPC. +fn build_provider_params(provider: &ProviderConfig) -> Value { + let mut params = json!({}); + + if let Some(ref t) = provider.provider_type { + params["type"] = json!(t); + } + if let Some(ref w) = provider.wire_api { + params["wireApi"] = json!(w); + } + params["baseUrl"] = json!(provider.base_url); + if let Some(ref k) = provider.api_key { + params["apiKey"] = json!(k); + } + if let Some(ref t) = provider.bearer_token { + params["bearerToken"] = json!(t); + } + if let Some(ref azure) = provider.azure { + let mut azure_params = json!({}); + if let Some(ref v) = azure.api_version { + azure_params["apiVersion"] = json!(v); + } + params["azure"] = azure_params; + } + + params +} diff --git a/rust/src/error.rs b/rust/src/error.rs new file mode 100644 index 00000000..e138b39b --- /dev/null +++ b/rust/src/error.rs @@ -0,0 +1,332 @@ +//! Error types for the Copilot SDK. +//! +//! This module provides error types used throughout the SDK for handling various +//! failure scenarios including JSON-RPC errors, connection issues, process failures, +//! and configuration problems. +//! +//! # Error Types +//! +//! The main error type is [`CopilotError`], which encompasses all possible errors +//! that can occur when using the SDK: +//! +//! - **Protocol Errors**: JSON-RPC errors, version mismatches +//! - **Connection Errors**: Network issues, connection failures +//! - **Process Errors**: CLI server startup/crash failures +//! - **Session Errors**: Session creation/management failures +//! - **Configuration Errors**: Invalid client configuration +//! +//! # Example +//! +//! ```ignore +//! use copilot_sdk::{CopilotClient, CopilotError}; +//! +//! match CopilotClient::new(None) { +//! Ok(client) => { +//! // Use client... +//! } +//! Err(CopilotError::InvalidConfig(msg)) => { +//! eprintln!("Configuration error: {}", msg); +//! } +//! Err(e) => { +//! eprintln!("Error: {}", e); +//! } +//! } +//! ``` + +use thiserror::Error; + +/// Main error type for the Copilot SDK. +/// +/// This enum represents all possible errors that can occur when using the SDK. +/// Each variant includes contextual information to help diagnose the issue. +/// +/// # Variants +/// +/// - [`JsonRpc`](Self::JsonRpc) - Error received from the JSON-RPC server +/// - [`Connection`](Self::Connection) - Network or connection failure +/// - [`Process`](Self::Process) - CLI server process failure +/// - [`Session`](Self::Session) - Session-related error +/// - [`NotConnected`](Self::NotConnected) - Client not connected +/// - [`ProtocolVersionMismatch`](Self::ProtocolVersionMismatch) - SDK/server version incompatibility +/// - [`Timeout`](Self::Timeout) - Operation timed out +/// - [`InvalidResponse`](Self::InvalidResponse) - Malformed server response +/// - [`ToolExecution`](Self::ToolExecution) - Tool handler failure +/// - [`Serialization`](Self::Serialization) - JSON serialization failure +/// - [`Io`](Self::Io) - I/O operation failure +/// - [`ClientStopped`](Self::ClientStopped) - Client has been stopped +/// - [`InvalidConfig`](Self::InvalidConfig) - Invalid configuration provided +#[derive(Error, Debug)] +pub enum CopilotError { + /// JSON-RPC error received from the server. + /// + /// This error is returned when the JSON-RPC server responds with an error. + /// The error code and message are provided by the server. + /// + /// # Fields + /// + /// - `code` - Standard JSON-RPC error code (see [`JsonRpcError`] constants) + /// - `message` - Human-readable error description from the server + /// - `data` - Optional additional error data provided by the server + #[error("JSON-RPC error {code}: {message}")] + JsonRpc { + /// The JSON-RPC error code. Standard codes are defined in [`JsonRpcError`]. + code: i32, + /// Human-readable error message from the server. + message: String, + /// Optional additional error data provided by the server. + data: Option, + }, + + /// Connection error. + /// + /// Returned when there is a network or transport-level failure + /// communicating with the CLI server. + #[error("Connection error: {0}")] + Connection(String), + + /// Process error (CLI server failed to start or crashed). + /// + /// Returned when the CLI server process fails to start, crashes, + /// or exits unexpectedly. + #[error("Process error: {0}")] + Process(String), + + /// Session error. + /// + /// Returned when a session-related operation fails, such as + /// creating, resuming, or destroying a session. + #[error("Session error: {0}")] + Session(String), + + /// Client not connected. + /// + /// Returned when attempting to use a client that hasn't been + /// connected yet. Call [`CopilotClient::start()`](crate::CopilotClient::start) + /// before using other methods. + #[error("Client not connected. Call start() first")] + NotConnected, + + /// Protocol version mismatch. + /// + /// Returned when the SDK's protocol version doesn't match the server's + /// version. This usually means either the SDK or CLI needs to be updated. + /// + /// # Fields + /// + /// - `expected` - The protocol version the SDK expects + /// - `actual` - The protocol version reported by the server + #[error("SDK protocol version mismatch: SDK expects version {expected}, but server reports version {actual}. Please update your SDK or server to ensure compatibility")] + ProtocolVersionMismatch { + /// The protocol version expected by this SDK. + expected: i32, + /// The protocol version reported by the server. + actual: i32, + }, + + /// Protocol version not reported by server. + /// + /// Returned when the server doesn't report a protocol version. + /// This usually indicates an older server that needs to be updated. + /// + /// # Fields + /// + /// - `expected` - The protocol version the SDK expects + #[error("SDK protocol version mismatch: SDK expects version {expected}, but server does not report a protocol version. Please update your server to ensure compatibility")] + ProtocolVersionNotReported { + /// The protocol version expected by this SDK. + expected: i32, + }, + + /// Timeout waiting for response. + /// + /// Returned when an operation takes longer than the configured + /// timeout duration. + #[error("Timeout waiting for response")] + Timeout, + + /// Invalid response from server. + /// + /// Returned when the server sends a response that cannot be + /// parsed or is malformed. + #[error("Invalid response: {0}")] + InvalidResponse(String), + + /// Tool execution error. + /// + /// Returned when a tool handler fails during execution. + /// The error message contains details about what went wrong. + #[error("Tool execution error: {0}")] + ToolExecution(String), + + /// Serialization error. + /// + /// Returned when JSON serialization or deserialization fails. + #[error("Serialization error: {0}")] + Serialization(#[from] serde_json::Error), + + /// IO error. + /// + /// Returned when an I/O operation fails, such as reading from + /// or writing to the CLI server process. + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + /// Client stopped. + /// + /// Returned when attempting to use a client that has been stopped. + #[error("Client stopped")] + ClientStopped, + + /// Invalid configuration. + /// + /// Returned when the client configuration is invalid, such as + /// providing mutually exclusive options. + /// + /// # Example + /// + /// ```ignore + /// // This will return InvalidConfig because cli_url and use_stdio are + /// // mutually exclusive + /// let client = CopilotClient::new(Some(ClientOptions { + /// cli_url: Some("localhost:8080".to_string()), + /// use_stdio: Some(true), + /// ..Default::default() + /// })); + /// ``` + #[error("Invalid configuration: {0}")] + InvalidConfig(String), +} + +/// JSON-RPC error representation. +/// +/// This struct represents a JSON-RPC 2.0 error object, containing an error code, +/// message, and optional additional data. It can be converted into a [`CopilotError`]. +/// +/// # Standard Error Codes +/// +/// The JSON-RPC 2.0 specification defines several standard error codes: +/// +/// | Code | Constant | Description | +/// |------|----------|-------------| +/// | -32700 | [`PARSE_ERROR`](Self::PARSE_ERROR) | Invalid JSON was received | +/// | -32600 | [`INVALID_REQUEST`](Self::INVALID_REQUEST) | Invalid JSON-RPC request | +/// | -32601 | [`METHOD_NOT_FOUND`](Self::METHOD_NOT_FOUND) | Method does not exist | +/// | -32602 | [`INVALID_PARAMS`](Self::INVALID_PARAMS) | Invalid method parameters | +/// | -32603 | [`INTERNAL_ERROR`](Self::INTERNAL_ERROR) | Internal JSON-RPC error | +/// +/// # Example +/// +/// ``` +/// use copilot_sdk::JsonRpcError; +/// +/// let error = JsonRpcError::new(JsonRpcError::METHOD_NOT_FOUND, "Method not found") +/// .with_data(serde_json::json!({"method": "unknown.method"})); +/// +/// assert_eq!(error.code, -32601); +/// assert_eq!(error.message, "Method not found"); +/// ``` +#[derive(Debug, Clone)] +pub struct JsonRpcError { + /// The JSON-RPC error code. + /// + /// Standard error codes are available as constants on this type + /// (e.g., [`Self::PARSE_ERROR`], [`Self::METHOD_NOT_FOUND`]). + /// Server-specific error codes may also be used. + pub code: i32, + + /// Human-readable error message describing what went wrong. + pub message: String, + + /// Optional additional error data. + /// + /// This can contain any additional information about the error + /// that might be useful for debugging or error handling. + pub data: Option, +} + +impl JsonRpcError { + /// Create a new JSON-RPC error with the given code and message. + /// + /// # Arguments + /// + /// * `code` - The JSON-RPC error code + /// * `message` - Human-readable error message + /// + /// # Example + /// + /// ``` + /// use copilot_sdk::JsonRpcError; + /// + /// let error = JsonRpcError::new(-32600, "Invalid Request"); + /// assert_eq!(error.code, JsonRpcError::INVALID_REQUEST); + /// ``` + pub fn new(code: i32, message: impl Into) -> Self { + Self { + code, + message: message.into(), + data: None, + } + } + + /// Add additional data to the error. + /// + /// Returns `self` for method chaining. + /// + /// # Arguments + /// + /// * `data` - Additional error data as a JSON value + /// + /// # Example + /// + /// ``` + /// use copilot_sdk::JsonRpcError; + /// use serde_json::json; + /// + /// let error = JsonRpcError::new(-32602, "Invalid params") + /// .with_data(json!({"expected": "string", "got": "number"})); + /// ``` + pub fn with_data(mut self, data: serde_json::Value) -> Self { + self.data = Some(data); + self + } + + /// Parse error: Invalid JSON was received by the server. + /// + /// An error occurred on the server while parsing the JSON text. + pub const PARSE_ERROR: i32 = -32700; + + /// Invalid Request: The JSON sent is not a valid JSON-RPC Request object. + pub const INVALID_REQUEST: i32 = -32600; + + /// Method not found: The method does not exist or is not available. + pub const METHOD_NOT_FOUND: i32 = -32601; + + /// Invalid params: Invalid method parameter(s). + pub const INVALID_PARAMS: i32 = -32602; + + /// Internal error: Internal JSON-RPC error. + pub const INTERNAL_ERROR: i32 = -32603; +} + +impl std::fmt::Display for JsonRpcError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "JSON-RPC Error {}: {}", self.code, self.message) + } +} + +impl std::error::Error for JsonRpcError {} + +impl From for CopilotError { + fn from(err: JsonRpcError) -> Self { + CopilotError::JsonRpc { + code: err.code, + message: err.message, + data: err.data, + } + } +} + +/// Result type alias for Copilot operations. +/// +/// This is a convenience alias for `std::result::Result`. +pub type Result = std::result::Result; diff --git a/rust/src/generated/mod.rs b/rust/src/generated/mod.rs new file mode 100644 index 00000000..30d8b7a6 --- /dev/null +++ b/rust/src/generated/mod.rs @@ -0,0 +1,5 @@ +//! Generated types for the Copilot SDK. + +pub mod session_events; + +pub use session_events::*; diff --git a/rust/src/generated/session_events.rs b/rust/src/generated/session_events.rs new file mode 100644 index 00000000..e0ed705e --- /dev/null +++ b/rust/src/generated/session_events.rs @@ -0,0 +1,878 @@ +// AUTO-GENERATED FILE - DO NOT EDIT +// +// Generated from: @github/copilot/session-events.schema.json +// Generated by: scripts/generate-session-types.ts +// +// To update these types: +// 1. Update the schema in copilot-agent-runtime +// 2. Run: npm run generate:session-types + +//! Session event types for the Copilot SDK. +//! +//! This module contains types representing events emitted by Copilot sessions +//! during conversation processing. Events cover the full lifecycle of a session, +//! including message handling, tool execution, and session state changes. +//! +//! # Event Flow +//! +//! A typical session flow emits events in this order: +//! +//! 1. [`SessionEventType::SessionStart`] - Session created +//! 2. [`SessionEventType::UserMessage`] - User message received +//! 3. [`SessionEventType::AssistantTurnStart`] - Assistant begins processing +//! 4. [`SessionEventType::AssistantMessageDelta`] - Streaming response chunks (if enabled) +//! 5. [`SessionEventType::AssistantMessage`] - Complete assistant response +//! 6. [`SessionEventType::AssistantTurnEnd`] - Assistant finished processing +//! 7. [`SessionEventType::SessionIdle`] - Session ready for next message +//! +//! # Tool Execution +//! +//! When the assistant invokes tools: +//! +//! 1. [`SessionEventType::ToolExecutionStart`] - Tool execution begins +//! 2. [`SessionEventType::ToolExecutionPartialResult`] - Intermediate results (optional) +//! 3. [`SessionEventType::ToolExecutionComplete`] - Tool execution finished +//! +//! # Example +//! +//! ```ignore +//! use copilot_sdk::{SessionEvent, SessionEventType}; +//! +//! fn handle_event(event: &SessionEvent) { +//! match event.event_type { +//! SessionEventType::AssistantMessage => { +//! if let Some(content) = &event.data.content { +//! println!("Assistant: {}", content); +//! } +//! } +//! SessionEventType::SessionError => { +//! if let Some(msg) = &event.data.message { +//! eprintln!("Error: {}", msg); +//! } +//! } +//! _ => {} +//! } +//! } +//! ``` + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; + +/// Session event from the Copilot CLI. +/// +/// This is the primary event structure emitted by sessions during processing. +/// Each event has a type, timestamp, unique ID, and type-specific data. +/// +/// # Fields +/// +/// - `event_type` - The type of event (see [`SessionEventType`]) +/// - `id` - Unique identifier for this event +/// - `timestamp` - When the event occurred +/// - `data` - Event-specific data (see [`SessionEventData`]) +/// - `parent_id` - ID of parent event (for nested events) +/// - `ephemeral` - Whether this event should be persisted +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionEvent { + /// Event-specific data payload. + /// + /// This is a flat structure containing all possible fields for all event types. + /// Only fields relevant to the specific `event_type` will be populated. + pub data: SessionEventData, + + /// Whether the event is ephemeral (not persisted to session history). + /// + /// Ephemeral events are typically used for streaming deltas and + /// intermediate progress updates. + #[serde(skip_serializing_if = "Option::is_none")] + pub ephemeral: Option, + + /// Unique identifier for this event. + /// + /// Can be used to track specific events or correlate related events. + pub id: String, + + /// Parent event ID (if this event is a child of another event). + /// + /// Used to establish event hierarchies, such as tool execution events + /// being children of the message that triggered them. + #[serde(rename = "parentId")] + pub parent_id: Option, + + /// Timestamp when this event occurred. + pub timestamp: DateTime, + + /// The type of this event. + /// + /// Determines which fields in `data` are populated and how the event + /// should be interpreted. + #[serde(rename = "type")] + pub event_type: SessionEventType, +} + +/// Session event types. +/// +/// Each variant represents a different type of event that can occur during +/// a Copilot session. The event type determines which fields in +/// [`SessionEventData`] will be populated. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub enum SessionEventType { + /// Message processing was aborted. + #[serde(rename = "abort")] + Abort, + + /// Assistant expressed an intent or plan. + #[serde(rename = "assistant.intent")] + AssistantIntent, + + /// Complete assistant message. + /// + /// Contains the full response content in `data.content`. + #[serde(rename = "assistant.message")] + AssistantMessage, + + /// Streaming assistant message chunk. + /// + /// Contains incremental content in `data.delta_content`. + /// Only emitted when streaming is enabled. + #[serde(rename = "assistant.message_delta")] + AssistantMessageDelta, + + /// Complete assistant reasoning (chain-of-thought). + /// + /// Contains the full reasoning content in `data.content`. + /// Only available for models that support reasoning. + #[serde(rename = "assistant.reasoning")] + AssistantReasoning, + + /// Streaming assistant reasoning chunk. + /// + /// Contains incremental reasoning in `data.delta_content`. + /// Only emitted when streaming is enabled and model supports reasoning. + #[serde(rename = "assistant.reasoning_delta")] + AssistantReasoningDelta, + + /// Assistant finished processing this turn. + #[serde(rename = "assistant.turn_end")] + AssistantTurnEnd, + + /// Assistant started processing this turn. + #[serde(rename = "assistant.turn_start")] + AssistantTurnStart, + + /// Token usage information for this turn. + /// + /// Contains `data.input_tokens`, `data.output_tokens`, etc. + #[serde(rename = "assistant.usage")] + AssistantUsage, + + /// Hook execution completed. + #[serde(rename = "hook.end")] + HookEnd, + + /// Hook execution started. + #[serde(rename = "hook.start")] + HookStart, + + /// Pending messages queue was modified. + #[serde(rename = "pending_messages.modified")] + PendingMessagesModified, + + /// Context compaction completed. + #[serde(rename = "session.compaction_complete")] + SessionCompactionComplete, + + /// Context compaction started. + #[serde(rename = "session.compaction_start")] + SessionCompactionStart, + + /// Session error occurred. + /// + /// Contains error details in `data.message`, `data.error_type`, `data.stack`. + #[serde(rename = "session.error")] + SessionError, + + /// Session handed off to another destination. + #[serde(rename = "session.handoff")] + SessionHandoff, + + /// Session is idle and ready for new messages. + /// + /// This event indicates the session has finished processing and is + /// waiting for the next user message. + #[serde(rename = "session.idle")] + SessionIdle, + + /// Informational message about the session. + #[serde(rename = "session.info")] + SessionInfo, + + /// Model was changed during the session. + /// + /// Contains `data.previous_model` and `data.new_model`. + #[serde(rename = "session.model_change")] + SessionModelChange, + + /// Session was resumed from saved state. + #[serde(rename = "session.resume")] + SessionResume, + + /// Session was started. + /// + /// This is typically the first event emitted after creating a session. + #[serde(rename = "session.start")] + SessionStart, + + /// Context was truncated to fit within token limits. + #[serde(rename = "session.truncation")] + SessionTruncation, + + /// Usage information for the session. + #[serde(rename = "session.usage_info")] + SessionUsageInfo, + + /// Subagent completed execution. + #[serde(rename = "subagent.completed")] + SubagentCompleted, + + /// Subagent execution failed. + #[serde(rename = "subagent.failed")] + SubagentFailed, + + /// Subagent was selected for execution. + #[serde(rename = "subagent.selected")] + SubagentSelected, + + /// Subagent started execution. + #[serde(rename = "subagent.started")] + SubagentStarted, + + /// System message was added. + #[serde(rename = "system.message")] + SystemMessage, + + /// Tool execution completed. + /// + /// Contains tool result in `data.result`, tool name in `data.tool_name`. + #[serde(rename = "tool.execution_complete")] + ToolExecutionComplete, + + /// Partial tool execution result. + /// + /// Contains intermediate output in `data.partial_output`. + #[serde(rename = "tool.execution_partial_result")] + ToolExecutionPartialResult, + + /// Tool execution started. + /// + /// Contains tool name in `data.tool_name`, arguments in `data.arguments`. + #[serde(rename = "tool.execution_start")] + ToolExecutionStart, + + /// Tool was requested by user. + #[serde(rename = "tool.user_requested")] + ToolUserRequested, + + /// User message was added to the session. + /// + /// Contains message content in `data.content`, attachments in `data.attachments`. + #[serde(rename = "user.message")] + UserMessage, +} + +impl std::fmt::Display for SessionEventType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SessionEventType::Abort => write!(f, "abort"), + SessionEventType::AssistantIntent => write!(f, "assistant.intent"), + SessionEventType::AssistantMessage => write!(f, "assistant.message"), + SessionEventType::AssistantMessageDelta => write!(f, "assistant.message_delta"), + SessionEventType::AssistantReasoning => write!(f, "assistant.reasoning"), + SessionEventType::AssistantReasoningDelta => write!(f, "assistant.reasoning_delta"), + SessionEventType::AssistantTurnEnd => write!(f, "assistant.turn_end"), + SessionEventType::AssistantTurnStart => write!(f, "assistant.turn_start"), + SessionEventType::AssistantUsage => write!(f, "assistant.usage"), + SessionEventType::HookEnd => write!(f, "hook.end"), + SessionEventType::HookStart => write!(f, "hook.start"), + SessionEventType::PendingMessagesModified => write!(f, "pending_messages.modified"), + SessionEventType::SessionCompactionComplete => write!(f, "session.compaction_complete"), + SessionEventType::SessionCompactionStart => write!(f, "session.compaction_start"), + SessionEventType::SessionError => write!(f, "session.error"), + SessionEventType::SessionHandoff => write!(f, "session.handoff"), + SessionEventType::SessionIdle => write!(f, "session.idle"), + SessionEventType::SessionInfo => write!(f, "session.info"), + SessionEventType::SessionModelChange => write!(f, "session.model_change"), + SessionEventType::SessionResume => write!(f, "session.resume"), + SessionEventType::SessionStart => write!(f, "session.start"), + SessionEventType::SessionTruncation => write!(f, "session.truncation"), + SessionEventType::SessionUsageInfo => write!(f, "session.usage_info"), + SessionEventType::SubagentCompleted => write!(f, "subagent.completed"), + SessionEventType::SubagentFailed => write!(f, "subagent.failed"), + SessionEventType::SubagentSelected => write!(f, "subagent.selected"), + SessionEventType::SubagentStarted => write!(f, "subagent.started"), + SessionEventType::SystemMessage => write!(f, "system.message"), + SessionEventType::ToolExecutionComplete => write!(f, "tool.execution_complete"), + SessionEventType::ToolExecutionPartialResult => write!(f, "tool.execution_partial_result"), + SessionEventType::ToolExecutionStart => write!(f, "tool.execution_start"), + SessionEventType::ToolUserRequested => write!(f, "tool.user_requested"), + SessionEventType::UserMessage => write!(f, "user.message"), + } + } +} + +/// Session event data - a flat structure containing all possible fields. +/// +/// This structure contains all possible fields for all event types. +/// Only fields relevant to the specific [`SessionEventType`] will be populated; +/// all others will be `None`. +/// +/// # Common Fields by Event Type +/// +/// | Event Type | Relevant Fields | +/// |------------|-----------------| +/// | `AssistantMessage` | `content`, `turn_id`, `message_id` | +/// | `AssistantMessageDelta` | `delta_content`, `message_id` | +/// | `UserMessage` | `content`, `attachments` | +/// | `SessionError` | `message`, `error_type`, `stack`, `error` | +/// | `ToolExecutionStart` | `tool_name`, `tool_call_id`, `arguments` | +/// | `ToolExecutionComplete` | `tool_name`, `tool_call_id`, `result` | +/// | `AssistantUsage` | `input_tokens`, `output_tokens`, `model`, `cost` | +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct SessionEventData { + /// Context information (working directory, git info, etc.). + #[serde(skip_serializing_if = "Option::is_none")] + pub context: Option, + + /// Copilot CLI version. + #[serde(rename = "copilotVersion", skip_serializing_if = "Option::is_none")] + pub copilot_version: Option, + + /// Event producer identifier. + #[serde(skip_serializing_if = "Option::is_none")] + pub producer: Option, + + /// Currently selected model. + #[serde(rename = "selectedModel", skip_serializing_if = "Option::is_none")] + pub selected_model: Option, + + /// Session identifier. + #[serde(rename = "sessionId", skip_serializing_if = "Option::is_none")] + pub session_id: Option, + + /// Session start timestamp. + #[serde(rename = "startTime", skip_serializing_if = "Option::is_none")] + pub start_time: Option>, + + /// Protocol version number. + #[serde(skip_serializing_if = "Option::is_none")] + pub version: Option, + + /// Total event count in session. + #[serde(rename = "eventCount", skip_serializing_if = "Option::is_none")] + pub event_count: Option, + + /// Session resume timestamp. + #[serde(rename = "resumeTime", skip_serializing_if = "Option::is_none")] + pub resume_time: Option>, + + /// Error type identifier. + #[serde(rename = "errorType", skip_serializing_if = "Option::is_none")] + pub error_type: Option, + + /// Human-readable message content or error message. + #[serde(skip_serializing_if = "Option::is_none")] + pub message: Option, + + /// Error stack trace. + #[serde(skip_serializing_if = "Option::is_none")] + pub stack: Option, + + /// Info type for informational events. + #[serde(rename = "infoType", skip_serializing_if = "Option::is_none")] + pub info_type: Option, + + /// New model (for model change events). + #[serde(rename = "newModel", skip_serializing_if = "Option::is_none")] + pub new_model: Option, + + /// Previous model (for model change events). + #[serde(rename = "previousModel", skip_serializing_if = "Option::is_none")] + pub previous_model: Option, + + /// Handoff timestamp. + #[serde(rename = "handoffTime", skip_serializing_if = "Option::is_none")] + pub handoff_time: Option>, + + /// Remote session ID (for handoffs). + #[serde(rename = "remoteSessionId", skip_serializing_if = "Option::is_none")] + pub remote_session_id: Option, + + /// Repository information. + #[serde(skip_serializing_if = "Option::is_none")] + pub repository: Option, + + /// Source type (local or remote). + #[serde(rename = "sourceType", skip_serializing_if = "Option::is_none")] + pub source_type: Option, + + /// Summary text. + #[serde(skip_serializing_if = "Option::is_none")] + pub summary: Option, + + /// Number of messages removed during context truncation. + #[serde(rename = "messagesRemovedDuringTruncation", skip_serializing_if = "Option::is_none")] + pub messages_removed_during_truncation: Option, + + /// Who performed the action. + #[serde(rename = "performedBy", skip_serializing_if = "Option::is_none")] + pub performed_by: Option, + + /// Message count after truncation. + #[serde(rename = "postTruncationMessagesLength", skip_serializing_if = "Option::is_none")] + pub post_truncation_messages_length: Option, + + /// Token count after truncation. + #[serde(rename = "postTruncationTokensInMessages", skip_serializing_if = "Option::is_none")] + pub post_truncation_tokens_in_messages: Option, + + /// Message count before truncation. + #[serde(rename = "preTruncationMessagesLength", skip_serializing_if = "Option::is_none")] + pub pre_truncation_messages_length: Option, + + /// Token count before truncation. + #[serde(rename = "preTruncationTokensInMessages", skip_serializing_if = "Option::is_none")] + pub pre_truncation_tokens_in_messages: Option, + + /// Maximum token limit. + #[serde(rename = "tokenLimit", skip_serializing_if = "Option::is_none")] + pub token_limit: Option, + + /// Tokens removed during truncation. + #[serde(rename = "tokensRemovedDuringTruncation", skip_serializing_if = "Option::is_none")] + pub tokens_removed_during_truncation: Option, + + /// Current token count. + #[serde(rename = "currentTokens", skip_serializing_if = "Option::is_none")] + pub current_tokens: Option, + + /// Current message count. + #[serde(rename = "messagesLength", skip_serializing_if = "Option::is_none")] + pub messages_length: Option, + + /// Token usage for compaction operation. + #[serde(rename = "compactionTokensUsed", skip_serializing_if = "Option::is_none")] + pub compaction_tokens_used: Option, + + /// Detailed error information. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + + /// Messages removed during compaction. + #[serde(rename = "messagesRemoved", skip_serializing_if = "Option::is_none")] + pub messages_removed: Option, + + /// Token count after compaction. + #[serde(rename = "postCompactionTokens", skip_serializing_if = "Option::is_none")] + pub post_compaction_tokens: Option, + + /// Message count before compaction. + #[serde(rename = "preCompactionMessagesLength", skip_serializing_if = "Option::is_none")] + pub pre_compaction_messages_length: Option, + + /// Token count before compaction. + #[serde(rename = "preCompactionTokens", skip_serializing_if = "Option::is_none")] + pub pre_compaction_tokens: Option, + + /// Whether the operation succeeded. + #[serde(skip_serializing_if = "Option::is_none")] + pub success: Option, + + /// Summary content from compaction. + #[serde(rename = "summaryContent", skip_serializing_if = "Option::is_none")] + pub summary_content: Option, + + /// Tokens removed during operation. + #[serde(rename = "tokensRemoved", skip_serializing_if = "Option::is_none")] + pub tokens_removed: Option, + + /// Message attachments (files, directories). + #[serde(skip_serializing_if = "Option::is_none")] + pub attachments: Option>, + + /// Text content of message or response. + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + + /// Event source identifier. + #[serde(skip_serializing_if = "Option::is_none")] + pub source: Option, + + /// Transformed content (after processing). + #[serde(rename = "transformedContent", skip_serializing_if = "Option::is_none")] + pub transformed_content: Option, + + /// Turn identifier within the session. + #[serde(rename = "turnId", skip_serializing_if = "Option::is_none")] + pub turn_id: Option, + + /// Assistant's stated intent. + #[serde(skip_serializing_if = "Option::is_none")] + pub intent: Option, + + /// Reasoning chain identifier. + #[serde(rename = "reasoningId", skip_serializing_if = "Option::is_none")] + pub reasoning_id: Option, + + /// Incremental content for streaming events. + #[serde(rename = "deltaContent", skip_serializing_if = "Option::is_none")] + pub delta_content: Option, + + /// Message identifier. + #[serde(rename = "messageId", skip_serializing_if = "Option::is_none")] + pub message_id: Option, + + /// Parent tool call ID (for nested tool calls). + #[serde(rename = "parentToolCallId", skip_serializing_if = "Option::is_none")] + pub parent_tool_call_id: Option, + + /// List of tool requests from the assistant. + #[serde(rename = "toolRequests", skip_serializing_if = "Option::is_none")] + pub tool_requests: Option>, + + /// Total response size in bytes. + #[serde(rename = "totalResponseSizeBytes", skip_serializing_if = "Option::is_none")] + pub total_response_size_bytes: Option, + + /// API call identifier for tracking. + #[serde(rename = "apiCallId", skip_serializing_if = "Option::is_none")] + pub api_call_id: Option, + + /// Tokens read from cache. + #[serde(rename = "cacheReadTokens", skip_serializing_if = "Option::is_none")] + pub cache_read_tokens: Option, + + /// Tokens written to cache. + #[serde(rename = "cacheWriteTokens", skip_serializing_if = "Option::is_none")] + pub cache_write_tokens: Option, + + /// Cost of the operation (in USD). + #[serde(skip_serializing_if = "Option::is_none")] + pub cost: Option, + + /// Duration of operation in milliseconds. + #[serde(skip_serializing_if = "Option::is_none")] + pub duration: Option, + + /// Who initiated the operation. + #[serde(skip_serializing_if = "Option::is_none")] + pub initiator: Option, + + /// Number of input tokens used. + #[serde(rename = "inputTokens", skip_serializing_if = "Option::is_none")] + pub input_tokens: Option, + + /// Model name used for the operation. + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + + /// Number of output tokens generated. + #[serde(rename = "outputTokens", skip_serializing_if = "Option::is_none")] + pub output_tokens: Option, + + /// Provider-specific call identifier. + #[serde(rename = "providerCallId", skip_serializing_if = "Option::is_none")] + pub provider_call_id: Option, + + /// Quota usage snapshots by model. + #[serde(rename = "quotaSnapshots", skip_serializing_if = "Option::is_none")] + pub quota_snapshots: Option>, + + /// Reason for the event or action. + #[serde(skip_serializing_if = "Option::is_none")] + pub reason: Option, + + /// Tool call arguments (JSON value). + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, + + /// Tool call identifier. + #[serde(rename = "toolCallId", skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + + /// Name of the tool being executed. + #[serde(rename = "toolName", skip_serializing_if = "Option::is_none")] + pub tool_name: Option, + + /// Partial/intermediate tool output. + #[serde(rename = "partialOutput", skip_serializing_if = "Option::is_none")] + pub partial_output: Option, + + /// Whether the tool was requested by the user. + #[serde(rename = "isUserRequested", skip_serializing_if = "Option::is_none")] + pub is_user_requested: Option, + + /// Tool execution result. + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + + /// Tool-specific telemetry data. + #[serde(rename = "toolTelemetry", skip_serializing_if = "Option::is_none")] + pub tool_telemetry: Option>, + + /// Subagent description. + #[serde(rename = "agentDescription", skip_serializing_if = "Option::is_none")] + pub agent_description: Option, + + /// Subagent display name. + #[serde(rename = "agentDisplayName", skip_serializing_if = "Option::is_none")] + pub agent_display_name: Option, + + /// Subagent name identifier. + #[serde(rename = "agentName", skip_serializing_if = "Option::is_none")] + pub agent_name: Option, + + /// List of available tools. + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + /// Hook invocation identifier. + #[serde(rename = "hookInvocationId", skip_serializing_if = "Option::is_none")] + pub hook_invocation_id: Option, + + /// Type of hook being executed. + #[serde(rename = "hookType", skip_serializing_if = "Option::is_none")] + pub hook_type: Option, + + /// Input data for the operation. + #[serde(skip_serializing_if = "Option::is_none")] + pub input: Option, + + /// Output data from the operation. + #[serde(skip_serializing_if = "Option::is_none")] + pub output: Option, + + /// Additional metadata. + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, + + /// Name identifier. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + + /// Message role (system, developer, etc.). + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option, +} + +/// File or directory attachment in a message. +/// +/// Attachments allow users to include files or directories as context +/// for their messages. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Attachment { + /// Display name shown to the user. + #[serde(rename = "displayName")] + pub display_name: String, + + /// File system path to the attachment. + pub path: String, + + /// Type of attachment (file or directory). + #[serde(rename = "type")] + pub attachment_type: AttachmentType, +} + +/// Type of attachment. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum AttachmentType { + /// Directory attachment. + Directory, + /// File attachment. + File, +} + +/// Token usage information for compaction operations. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CompactionTokensUsed { + /// Tokens read from cache during compaction. + #[serde(rename = "cachedInput")] + pub cached_input: f64, + + /// Input tokens used for compaction. + pub input: f64, + + /// Output tokens generated during compaction. + pub output: f64, +} + +/// Context information - can be a simple string or detailed object. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ContextUnion { + /// Simple string context. + String(String), + /// Detailed context object with working directory and git info. + Object(ContextClass), +} + +/// Detailed context information. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ContextClass { + /// Current git branch (if in a git repository). + #[serde(skip_serializing_if = "Option::is_none")] + pub branch: Option, + + /// Current working directory. + pub cwd: String, + + /// Git repository root directory. + #[serde(rename = "gitRoot", skip_serializing_if = "Option::is_none")] + pub git_root: Option, + + /// Repository identifier. + #[serde(skip_serializing_if = "Option::is_none")] + pub repository: Option, +} + +/// Error information - can be a simple message or detailed object. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ErrorUnion { + /// Simple error message string. + String(String), + /// Detailed error object with code and stack trace. + Object(ErrorClass), +} + +/// Detailed error information. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ErrorClass { + /// Error code identifier. + #[serde(skip_serializing_if = "Option::is_none")] + pub code: Option, + + /// Human-readable error message. + pub message: String, + + /// Error stack trace (if available). + #[serde(skip_serializing_if = "Option::is_none")] + pub stack: Option, +} + +/// Additional metadata for system messages. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Metadata { + /// Prompt version identifier. + #[serde(rename = "promptVersion", skip_serializing_if = "Option::is_none")] + pub prompt_version: Option, + + /// Template variables used in the prompt. + #[serde(skip_serializing_if = "Option::is_none")] + pub variables: Option>, +} + +/// Quota usage snapshot for a model. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct QuotaSnapshot { + /// Number of entitlement requests. + #[serde(rename = "entitlementRequests")] + pub entitlement_requests: f64, + + /// Whether user has unlimited entitlement. + #[serde(rename = "isUnlimitedEntitlement")] + pub is_unlimited_entitlement: bool, + + /// Overage amount. + pub overage: f64, + + /// Whether overage is allowed when quota is exhausted. + #[serde(rename = "overageAllowedWithExhaustedQuota")] + pub overage_allowed_with_exhausted_quota: bool, + + /// Percentage of quota remaining. + #[serde(rename = "remainingPercentage")] + pub remaining_percentage: f64, + + /// When the quota resets. + #[serde(rename = "resetDate", skip_serializing_if = "Option::is_none")] + pub reset_date: Option>, + + /// Whether usage is allowed when quota is exhausted. + #[serde(rename = "usageAllowedWithExhaustedQuota")] + pub usage_allowed_with_exhausted_quota: bool, + + /// Number of requests used. + #[serde(rename = "usedRequests")] + pub used_requests: f64, +} + +/// Git repository information. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Repository { + /// Current branch name. + #[serde(skip_serializing_if = "Option::is_none")] + pub branch: Option, + + /// Repository name. + pub name: String, + + /// Repository owner. + pub owner: String, +} + +/// Tool execution result data. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolResultData { + /// Result content (usually displayed to the assistant). + pub content: String, +} + +/// Tool request from the assistant. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolRequest { + /// Arguments to pass to the tool (as JSON). + pub arguments: Value, + + /// Tool name. + pub name: String, + + /// Unique identifier for this tool call. + #[serde(rename = "toolCallId")] + pub tool_call_id: String, + + /// Type of tool request. + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub request_type: Option, +} + +/// Type of tool request. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum ToolRequestType { + /// Custom tool type. + Custom, + /// Standard function tool. + Function, +} + +/// Source type for session events. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum SourceType { + /// Event originated locally. + Local, + /// Event originated from a remote source. + Remote, +} + +/// Message role. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Role { + /// Developer-provided message. + Developer, + /// System-generated message. + System, +} diff --git a/rust/src/jsonrpc.rs b/rust/src/jsonrpc.rs new file mode 100644 index 00000000..bf054236 --- /dev/null +++ b/rust/src/jsonrpc.rs @@ -0,0 +1,708 @@ +//! JSON-RPC 2.0 client implementation. +//! +//! This module provides a JSON-RPC 2.0 client implementation for communicating with the +//! Copilot CLI server using either stdio or TCP transport with Content-Length framing. +//! +//! The client handles: +//! - Sending requests and receiving responses +//! - Receiving and dispatching notifications +//! - Handling incoming requests from the server (for tool calls) +//! +//! # Protocol +//! +//! The client uses the Language Server Protocol (LSP) framing format: +//! - Messages are preceded by a `Content-Length: \r\n\r\n` header +//! - Message bodies are JSON-encoded according to JSON-RPC 2.0 +//! +//! # Security +//! +//! Messages larger than [`MAX_MESSAGE_SIZE`] (100 MB) are rejected to prevent +//! denial-of-service attacks via unbounded memory allocation. +//! +//! # Example +//! +//! ```ignore +//! use copilot_sdk::jsonrpc::JsonRpcClient; +//! use serde_json::json; +//! +//! // Create client from stdio streams +//! let client = JsonRpcClient::new(stdin, stdout); +//! +//! // Send a request +//! let result = client.request("ping", json!({"message": "hello"})).await?; +//! ``` + +use crate::error::{CopilotError, JsonRpcError, Result}; + +/// Maximum allowed message size (100 MB) to prevent DoS attacks via unbounded memory allocation. +/// +/// Any incoming message with a `Content-Length` header exceeding this value will be rejected +/// with an I/O error. +pub const MAX_MESSAGE_SIZE: usize = 100 * 1024 * 1024; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}; +use tokio::sync::{mpsc, oneshot, RwLock}; + +/// JSON-RPC 2.0 request message. +/// +/// Represents a request sent to the server that expects a response. +/// +/// # JSON Format +/// +/// ```json +/// { +/// "jsonrpc": "2.0", +/// "id": "unique-request-id", +/// "method": "method.name", +/// "params": { ... } +/// } +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcRequest { + /// JSON-RPC protocol version, always "2.0". + pub jsonrpc: String, + + /// Unique identifier for this request. Used to match responses to requests. + pub id: Value, + + /// The method name to invoke on the server. + pub method: String, + + /// Optional parameters for the method call. + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +/// JSON-RPC 2.0 response message. +/// +/// Represents a response from the server to a request. +/// +/// # JSON Format +/// +/// Success response: +/// ```json +/// { +/// "jsonrpc": "2.0", +/// "id": "request-id", +/// "result": { ... } +/// } +/// ``` +/// +/// Error response: +/// ```json +/// { +/// "jsonrpc": "2.0", +/// "id": "request-id", +/// "error": { "code": -32600, "message": "..." } +/// } +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcResponse { + /// JSON-RPC protocol version, always "2.0". + pub jsonrpc: String, + + /// Request ID this response corresponds to. + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + + /// Result value on success. + #[serde(skip_serializing_if = "Option::is_none")] + pub result: Option, + + /// Error information on failure. + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, +} + +/// JSON-RPC 2.0 error object in a response. +/// +/// Contains error information when a request fails. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcErrorResponse { + /// Error code indicating the type of error. + /// Standard codes are defined in [`crate::JsonRpcError`]. + pub code: i32, + + /// Human-readable error message. + pub message: String, + + /// Optional additional error data. + #[serde(skip_serializing_if = "Option::is_none")] + pub data: Option, +} + +/// JSON-RPC 2.0 notification message. +/// +/// Represents a one-way message that does not expect a response. +/// +/// # JSON Format +/// +/// ```json +/// { +/// "jsonrpc": "2.0", +/// "method": "notification.name", +/// "params": { ... } +/// } +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct JsonRpcNotification { + /// JSON-RPC protocol version, always "2.0". + pub jsonrpc: String, + + /// The notification method name. + pub method: String, + + /// Optional parameters for the notification. + #[serde(skip_serializing_if = "Option::is_none")] + pub params: Option, +} + +/// Notification handler function type. +/// +/// Called when the server sends a notification. Receives the method name +/// and parameters. +/// +/// # Arguments +/// +/// - First argument: Method name (e.g., `"session.event"`) +/// - Second argument: Parameters as a JSON value +pub type NotificationHandler = Arc; + +/// Request handler function type for incoming server requests. +/// +/// Called when the server sends a request (e.g., tool calls). +/// Returns a future that resolves to the response value or an error. +/// +/// # Arguments +/// +/// The handler receives the request parameters as a JSON value. +/// +/// # Returns +/// +/// A pinned boxed future that resolves to `Result`. +pub type RequestHandler = Arc< + dyn Fn(Value) -> Pin> + Send>> + Send + Sync +>; + +/// Internal message types for the write loop. +enum WriteMessage { + /// Send the given bytes to the writer. + Send(Vec), + /// Stop the write loop. + Stop, +} + +/// Disconnect handler function type. +/// +/// Called when the connection is lost (read loop exits due to EOF or error). +pub type DisconnectHandler = Arc; + +/// JSON-RPC client for stdio/TCP transport with Content-Length framing. +/// +/// This client handles bidirectional JSON-RPC 2.0 communication over async streams. +/// It supports: +/// +/// - Sending requests and receiving responses +/// - Sending notifications (one-way messages) +/// - Receiving notifications from the server +/// - Handling incoming requests from the server (e.g., tool calls) +/// +/// # Thread Safety +/// +/// The client is `Send + Sync` and can be safely shared across tasks. +/// +/// # Example +/// +/// ```ignore +/// use copilot_sdk::jsonrpc::JsonRpcClient; +/// use serde_json::json; +/// +/// let client = JsonRpcClient::new(reader, writer); +/// +/// // Set up notification handler +/// client.set_notification_handler(Arc::new(|method, params| { +/// println!("Notification: {} {:?}", method, params); +/// })).await; +/// +/// // Send a request +/// let result = client.request("session.create", json!({ +/// "model": "gpt-5" +/// })).await?; +/// ``` +pub struct JsonRpcClient { + write_tx: mpsc::Sender, + pending_requests: Arc>>>, + notification_handler: Arc>>, + request_handlers: Arc>>, + running: Arc, + on_disconnect: Arc>>, +} + +impl JsonRpcClient { + /// Create a new JSON-RPC client from async read/write streams. + /// + /// This spawns two background tasks: + /// - A write loop that sends outgoing messages + /// - A read loop that receives and dispatches incoming messages + /// + /// # Arguments + /// + /// * `reader` - Async reader for incoming messages + /// * `writer` - Async writer for outgoing messages + /// + /// # Example + /// + /// ```ignore + /// // From stdio + /// let client = JsonRpcClient::new(tokio::io::stdin(), tokio::io::stdout()); + /// + /// // From TCP + /// let (reader, writer) = stream.into_split(); + /// let client = JsonRpcClient::new(reader, writer); + /// ``` + pub fn new(reader: R, writer: W) -> Self + where + R: tokio::io::AsyncRead + Unpin + Send + 'static, + W: tokio::io::AsyncWrite + Unpin + Send + 'static, + { + let (write_tx, write_rx) = mpsc::channel::(100); + let pending_requests: Arc>>> = + Arc::new(RwLock::new(HashMap::new())); + let notification_handler: Arc>> = + Arc::new(RwLock::new(None)); + let request_handlers: Arc>> = + Arc::new(RwLock::new(HashMap::new())); + let running = Arc::new(std::sync::atomic::AtomicBool::new(true)); + let on_disconnect: Arc>> = Arc::new(RwLock::new(None)); + + let client = Self { + write_tx: write_tx.clone(), + pending_requests: pending_requests.clone(), + notification_handler: notification_handler.clone(), + request_handlers: request_handlers.clone(), + running: running.clone(), + on_disconnect: on_disconnect.clone(), + }; + + // Spawn write loop + let running_write = running.clone(); + tokio::spawn(async move { + Self::write_loop(writer, write_rx, running_write).await; + }); + + // Spawn read loop + let write_tx_for_read = write_tx; + tokio::spawn(async move { + Self::read_loop( + reader, + pending_requests, + notification_handler, + request_handlers, + running, + write_tx_for_read, + on_disconnect, + ) + .await; + }); + + client + } + + /// Write loop - sends messages to the writer with Content-Length framing. + async fn write_loop( + mut writer: W, + mut write_rx: mpsc::Receiver, + running: Arc, + ) where + W: tokio::io::AsyncWrite + Unpin, + { + while running.load(std::sync::atomic::Ordering::SeqCst) { + match write_rx.recv().await { + Some(WriteMessage::Send(data)) => { + // Write Content-Length header + message + let header = format!("Content-Length: {}\r\n\r\n", data.len()); + if writer.write_all(header.as_bytes()).await.is_err() { + break; + } + if writer.write_all(&data).await.is_err() { + break; + } + if writer.flush().await.is_err() { + break; + } + } + Some(WriteMessage::Stop) | None => break, + } + } + } + + /// Read loop - reads messages from the reader and dispatches them. + async fn read_loop( + reader: R, + pending_requests: Arc>>>, + notification_handler: Arc>>, + request_handlers: Arc>>, + running: Arc, + write_tx: mpsc::Sender, + on_disconnect: Arc>>, + ) where + R: tokio::io::AsyncRead + Unpin, + { + let mut reader = BufReader::new(reader); + + while running.load(std::sync::atomic::Ordering::SeqCst) { + // Read Content-Length header + let content_length = match Self::read_content_length(&mut reader).await { + Ok(Some(len)) => len, + Ok(None) => continue, + Err(_) => break, + }; + + if content_length == 0 { + continue; + } + + // Read message body + let mut body = vec![0u8; content_length]; + if reader.read_exact(&mut body).await.is_err() { + break; + } + + // Parse message + let message: Value = match serde_json::from_slice(&body) { + Ok(v) => v, + Err(_) => continue, + }; + + // Determine message type and dispatch + let has_id = message.get("id").is_some(); + let has_method = message.get("method").is_some(); + + if has_id && has_method { + // Request from server (e.g., tool.call) + Self::handle_server_request( + message, + request_handlers.clone(), + write_tx.clone(), + ) + .await; + } else if has_id { + // Response to our request + Self::handle_response(message, pending_requests.clone()).await; + } else if has_method { + // Notification from server + Self::handle_notification(message, notification_handler.clone()).await; + } + } + + // Invoke disconnect callback when read loop exits + if let Some(callback) = on_disconnect.read().await.as_ref() { + callback(); + } + } + + /// Read the Content-Length header from the stream. + /// + /// Returns the content length, or an error if the message is too large + /// or the connection is closed. + async fn read_content_length(reader: &mut BufReader) -> std::io::Result> + where + R: tokio::io::AsyncRead + Unpin, + { + let mut content_length = 0usize; + + loop { + let mut line = String::new(); + let bytes_read = reader.read_line(&mut line).await?; + if bytes_read == 0 { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "EOF", + )); + } + + // Check for blank line (end of headers) + if line == "\r\n" || line == "\n" { + break; + } + + // Parse Content-Length + if let Some(len_str) = line.strip_prefix("Content-Length: ") { + if let Ok(len) = len_str.trim().parse::() { + if len > MAX_MESSAGE_SIZE { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + format!("Message size {} exceeds maximum {}", len, MAX_MESSAGE_SIZE), + )); + } + content_length = len; + } + } + } + + Ok(Some(content_length)) + } + + /// Handle a response to one of our requests. + async fn handle_response( + message: Value, + pending_requests: Arc>>>, + ) { + let response: JsonRpcResponse = match serde_json::from_value(message) { + Ok(r) => r, + Err(_) => return, + }; + + let id = match &response.id { + Some(Value::String(s)) => s.clone(), + _ => return, + }; + + let sender = { + let mut pending = pending_requests.write().await; + pending.remove(&id) + }; + + if let Some(sender) = sender { + let _ = sender.send(response); + } + } + + /// Handle a notification from the server. + async fn handle_notification( + message: Value, + notification_handler: Arc>>, + ) { + let notification: JsonRpcNotification = match serde_json::from_value(message) { + Ok(n) => n, + Err(_) => return, + }; + + let handler = notification_handler.read().await; + if let Some(handler) = handler.as_ref() { + let params = notification.params.unwrap_or(Value::Null); + handler(notification.method, params); + } + } + + /// Handle a request from the server (e.g., tool.call). + async fn handle_server_request( + message: Value, + request_handlers: Arc>>, + write_tx: mpsc::Sender, + ) { + let request: JsonRpcRequest = match serde_json::from_value(message) { + Ok(r) => r, + Err(_) => return, + }; + + let handlers = request_handlers.read().await; + let handler = handlers.get(&request.method).cloned(); + drop(handlers); + + let response = if let Some(handler) = handler { + let params = request.params.unwrap_or(Value::Null); + match handler(params).await { + Ok(result) => JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: Some(request.id), + result: Some(result), + error: None, + }, + Err(e) => { + let (code, message) = match e { + CopilotError::JsonRpc { code, message, .. } => (code, message), + _ => (JsonRpcError::INTERNAL_ERROR, e.to_string()), + }; + JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: Some(request.id), + result: None, + error: Some(JsonRpcErrorResponse { + code, + message, + data: None, + }), + } + } + } + } else { + JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: Some(request.id), + result: None, + error: Some(JsonRpcErrorResponse { + code: JsonRpcError::METHOD_NOT_FOUND, + message: format!("Method not found: {}", request.method), + data: None, + }), + } + }; + + // Send response + if let Ok(data) = serde_json::to_vec(&response) { + let _ = write_tx.send(WriteMessage::Send(data)).await; + } + } + + /// Set the notification handler for incoming server notifications. + /// + /// Only one handler can be active at a time. Setting a new handler + /// replaces the previous one. + /// + /// # Arguments + /// + /// * `handler` - Function called for each notification + pub async fn set_notification_handler(&self, handler: NotificationHandler) { + let mut h = self.notification_handler.write().await; + *h = Some(handler); + } + + /// Set the disconnect handler called when the connection is lost. + /// + /// The callback is invoked when the read loop exits due to EOF or error. + /// Only one handler can be active at a time. Setting a new handler + /// replaces the previous one. + /// + /// # Arguments + /// + /// * `handler` - Function called when disconnected + pub async fn set_on_disconnect(&self, handler: DisconnectHandler) { + let mut h = self.on_disconnect.write().await; + *h = Some(handler); + } + + /// Set a request handler for a specific method. + /// + /// Used to handle incoming requests from the server, such as tool calls. + /// + /// # Arguments + /// + /// * `method` - The method name to handle (e.g., `"tool.call"`) + /// * `handler` - Async function to handle the request + pub async fn set_request_handler(&self, method: &str, handler: RequestHandler) { + let mut handlers = self.request_handlers.write().await; + handlers.insert(method.to_string(), handler); + } + + /// Remove a request handler for a specific method. + /// + /// # Arguments + /// + /// * `method` - The method name to stop handling + pub async fn remove_request_handler(&self, method: &str) { + let mut handlers = self.request_handlers.write().await; + handlers.remove(method); + } + + /// Send a JSON-RPC request and wait for the response. + /// + /// This method sends a request to the server and blocks until a response + /// is received or the connection is closed. + /// + /// # Arguments + /// + /// * `method` - The method name to call + /// * `params` - Parameters for the method call + /// + /// # Returns + /// + /// The result value from the server, or an error if the request failed. + /// + /// # Errors + /// + /// - [`CopilotError::JsonRpc`] - Server returned an error + /// - [`CopilotError::ClientStopped`] - Connection was closed + /// - [`CopilotError::Serialization`] - Failed to serialize request + pub async fn request(&self, method: &str, params: Value) -> Result { + let request_id = uuid::Uuid::new_v4().to_string(); + + // Create response channel + let (tx, rx) = oneshot::channel(); + { + let mut pending = self.pending_requests.write().await; + pending.insert(request_id.clone(), tx); + } + + // Build request + let request = JsonRpcRequest { + jsonrpc: "2.0".to_string(), + id: Value::String(request_id.clone()), + method: method.to_string(), + params: Some(params), + }; + + // Send request + let data = serde_json::to_vec(&request)?; + self.write_tx + .send(WriteMessage::Send(data)) + .await + .map_err(|_| CopilotError::ClientStopped)?; + + // Wait for response + let response = rx.await.map_err(|_| CopilotError::ClientStopped)?; + + // Clean up pending request + { + let mut pending = self.pending_requests.write().await; + pending.remove(&request_id); + } + + // Handle response + if let Some(error) = response.error { + return Err(CopilotError::JsonRpc { + code: error.code, + message: error.message, + data: error.data, + }); + } + + Ok(response.result.unwrap_or(Value::Null)) + } + + /// Send a JSON-RPC notification (no response expected). + /// + /// Notifications are one-way messages that don't expect a response + /// from the server. + /// + /// # Arguments + /// + /// * `method` - The notification method name + /// * `params` - Parameters for the notification + /// + /// # Errors + /// + /// - [`CopilotError::ClientStopped`] - Connection was closed + /// - [`CopilotError::Serialization`] - Failed to serialize notification + pub async fn notify(&self, method: &str, params: Value) -> Result<()> { + let notification = JsonRpcNotification { + jsonrpc: "2.0".to_string(), + method: method.to_string(), + params: Some(params), + }; + + let data = serde_json::to_vec(¬ification)?; + self.write_tx + .send(WriteMessage::Send(data)) + .await + .map_err(|_| CopilotError::ClientStopped)?; + + Ok(()) + } + + /// Stop the client and close the connection. + /// + /// This signals the read and write loops to terminate and closes + /// the underlying connection. + pub async fn stop(&self) { + self.running + .store(false, std::sync::atomic::Ordering::SeqCst); + let _ = self.write_tx.send(WriteMessage::Stop).await; + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs new file mode 100644 index 00000000..da608c2a --- /dev/null +++ b/rust/src/lib.rs @@ -0,0 +1,131 @@ +//! Rust SDK for programmatic access to the GitHub Copilot CLI. +//! +//! This crate provides a Rust interface for interacting with the Copilot CLI server, +//! creating and managing conversation sessions, and integrating custom tools. +//! +//! # Quick Start +//! +//! ```ignore +//! use copilot_sdk::{CopilotClient, SessionConfig, MessageOptions}; +//! +//! #[tokio::main] +//! async fn main() -> Result<(), Box> { +//! // Create a client (spawns CLI server automatically) +//! let client = CopilotClient::new(None); +//! client.start().await?; +//! +//! // Create a session +//! let session = client.create_session(Some(SessionConfig { +//! model: Some("gpt-4".to_string()), +//! ..Default::default() +//! })).await?; +//! +//! // Subscribe to events +//! let _unsubscribe = session.on(std::sync::Arc::new(|event| { +//! if event.event_type == copilot_sdk::SessionEventType::AssistantMessage { +//! if let Some(content) = &event.data.content { +//! println!("Assistant: {}", content); +//! } +//! } +//! })); +//! +//! // Send a message +//! session.send(MessageOptions { +//! prompt: "Hello, Copilot!".to_string(), +//! ..Default::default() +//! }).await?; +//! +//! // Clean up +//! client.stop().await; +//! +//! Ok(()) +//! } +//! ``` +//! +//! # Custom Tools +//! +//! You can define custom tools that the assistant can invoke: +//! +//! ```ignore +//! use copilot_sdk::{define_tool, SessionConfig}; +//! use schemars::JsonSchema; +//! use serde::Deserialize; +//! +//! #[derive(Deserialize, JsonSchema)] +//! struct GetWeatherParams { +//! city: String, +//! } +//! +//! let tool = define_tool::( +//! "get_weather", +//! "Get weather for a city", +//! |params, _inv| async move { +//! Ok(format!("Weather in {}: 22 degrees", params.city)) +//! }, +//! ); +//! +//! let session = client.create_session(Some(SessionConfig { +//! tools: vec![tool], +//! ..Default::default() +//! })).await?; +//! ``` + +#![warn(missing_docs)] +#![warn(rustdoc::missing_crate_level_docs)] + +pub mod client; +pub mod error; +pub mod generated; +pub mod jsonrpc; +pub mod session; +pub mod tool; +pub mod types; + +// Re-export main types at crate root +pub use client::CopilotClient; +pub use error::{CopilotError, JsonRpcError, Result}; +pub use generated::{SessionEvent, SessionEventData, SessionEventType}; +pub use session::{CopilotSession, SessionEventHandler, UnsubscribeFn}; +pub use tool::{ + define_tool, IntoToolResult, Tool, ToolBinaryResult, ToolBuilder, ToolHandler, ToolInvocation, + ToolResult, +}; +pub use types::{ + Attachment, AttachmentType, AzureProviderOptions, ClientOptions, ConnectionState, + CustomAgentConfig, McpLocalServerConfig, McpRemoteServerConfig, McpServerConfig, + MessageOptions, PermissionInvocation, PermissionRequest, PermissionRequestResult, PingResponse, + ProviderConfig, ResumeSessionConfig, SessionConfig, SessionMetadata, SystemMessageConfig, + get_sdk_protocol_version, SDK_PROTOCOL_VERSION, +}; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_sdk_protocol_version() { + assert_eq!(get_sdk_protocol_version(), SDK_PROTOCOL_VERSION); + assert_eq!(SDK_PROTOCOL_VERSION, 1); + } + + #[test] + fn test_connection_state_display() { + assert_eq!(ConnectionState::Disconnected.to_string(), "disconnected"); + assert_eq!(ConnectionState::Connecting.to_string(), "connecting"); + assert_eq!(ConnectionState::Connected.to_string(), "connected"); + assert_eq!(ConnectionState::Error.to_string(), "error"); + } + + #[test] + fn test_session_event_type_display() { + assert_eq!( + SessionEventType::AssistantMessage.to_string(), + "assistant.message" + ); + assert_eq!(SessionEventType::SessionIdle.to_string(), "session.idle"); + assert_eq!( + SessionEventType::ToolExecutionStart.to_string(), + "tool.execution_start" + ); + } +} diff --git a/rust/src/session.rs b/rust/src/session.rs new file mode 100644 index 00000000..40932c2e --- /dev/null +++ b/rust/src/session.rs @@ -0,0 +1,365 @@ +//! CopilotSession implementation for managing conversation sessions. + +use crate::error::{CopilotError, Result}; +use crate::generated::{SessionEvent, SessionEventType}; +use crate::jsonrpc::JsonRpcClient; +use crate::tool::{Tool, ToolHandler, ToolInvocation, ToolResult}; +use crate::types::MessageOptions; +use futures::FutureExt; +use serde_json::json; +use std::collections::HashMap; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::{mpsc, RwLock}; + +/// Callback type for session events. +/// +/// Takes an `Arc` to avoid expensive clones in hot paths when +/// dispatching events to multiple handlers. +pub type SessionEventHandler = Arc) + Send + Sync>; + +/// Unsubscribe function returned by `on()`. +pub type UnsubscribeFn = Box; + +struct EventHandler { + id: u64, + handler: SessionEventHandler, +} + +/// A session for conversing with the Copilot CLI. +/// +/// Sessions maintain conversation state, handle events, and manage tool execution. +/// Sessions are created via [`CopilotClient::create_session()`](crate::CopilotClient::create_session) or resumed via +/// [`CopilotClient::resume_session()`](crate::CopilotClient::resume_session). +/// +/// # Example +/// +/// ```ignore +/// use copilot_sdk::{CopilotClient, MessageOptions}; +/// use std::sync::Arc; +/// +/// #[tokio::main] +/// async fn main() -> Result<(), Box> { +/// let client = CopilotClient::new(None)?; +/// let session = client.create_session(None).await?; +/// +/// // Subscribe to events +/// let _unsubscribe = session.on(Arc::new(|event| { +/// println!("Event: {:?}", event.event_type); +/// })); +/// +/// // Send a message +/// let message_id = session.send(MessageOptions { +/// prompt: "Hello!".to_string(), +/// ..Default::default() +/// }).await?; +/// +/// Ok(()) +/// } +/// ``` +pub struct CopilotSession { + session_id: String, + rpc_client: Arc, + handlers: Arc>>, + next_handler_id: AtomicU64, + tool_handlers: RwLock>, + destroyed: AtomicBool, +} + +impl CopilotSession { + /// Create a new session wrapper. + /// + /// Note: This is primarily for internal use. Use `CopilotClient::create_session` + /// to create sessions with proper initialization. + pub(crate) fn new(session_id: String, rpc_client: Arc) -> Self { + Self { + session_id, + rpc_client, + handlers: Arc::new(std::sync::Mutex::new(Vec::new())), + next_handler_id: AtomicU64::new(0), + tool_handlers: RwLock::new(HashMap::new()), + destroyed: AtomicBool::new(false), + } + } + + /// Get the session ID. + pub fn session_id(&self) -> &str { + &self.session_id + } + + /// Send a message to this session. + /// + /// The message is processed asynchronously. Subscribe to events via [`Self::on()`] + /// to receive streaming responses and other session events. + /// + /// Returns the message ID of the response. + pub async fn send(&self, options: MessageOptions) -> Result { + let mut params = json!({ + "sessionId": self.session_id, + "prompt": options.prompt, + }); + + if let Some(ref attachments) = options.attachments { + params["attachments"] = json!(attachments); + } + if let Some(ref mode) = options.mode { + params["mode"] = json!(mode); + } + + let result = self.rpc_client.request("session.send", params).await?; + + let message_id = result + .get("messageId") + .and_then(|v| v.as_str()) + .ok_or_else(|| CopilotError::InvalidResponse("missing messageId".to_string()))? + .to_string(); + + Ok(message_id) + } + + /// Send a message and wait for the session to become idle. + /// + /// This is a convenience method that combines [`Self::send()`] with waiting for + /// the `session.idle` event. + /// + /// Events are still delivered to handlers registered via [`Self::on()`] while waiting. + /// + /// Returns the final assistant message event, or None if none was received. + pub async fn send_and_wait( + &self, + options: MessageOptions, + timeout: Option, + ) -> Result> { + let timeout = timeout.unwrap_or(Duration::from_secs(60)); + + let (idle_tx, mut idle_rx) = mpsc::channel::<()>(1); + let (error_tx, mut error_rx) = mpsc::channel::(1); + let last_assistant_message: Arc>> = Arc::new(RwLock::new(None)); + + let last_msg = last_assistant_message.clone(); + let idle_tx_clone = idle_tx.clone(); + let error_tx_clone = error_tx.clone(); + + let unsubscribe = self.on(Arc::new(move |event: Arc| { + let last_msg = last_msg.clone(); + let idle_tx = idle_tx_clone.clone(); + let error_tx = error_tx_clone.clone(); + + tokio::spawn(async move { + match event.event_type { + SessionEventType::AssistantMessage => { + let mut last = last_msg.write().await; + // Clone the inner SessionEvent from the Arc for storage + *last = Some((*event).clone()); + } + SessionEventType::SessionIdle => { + let _ = idle_tx.send(()).await; + } + SessionEventType::SessionError => { + let msg = event + .data + .message + .clone() + .unwrap_or_else(|| "session error".to_string()); + let _ = error_tx.send(msg).await; + } + _ => {} + } + }); + })); + + // Send the message + self.send(options).await?; + + // Wait for idle, error, or timeout + let result = tokio::select! { + _ = idle_rx.recv() => { + let last = last_assistant_message.read().await; + Ok(last.clone()) + } + Some(err) = error_rx.recv() => { + Err(CopilotError::Session(format!("session error: {}", err))) + } + _ = tokio::time::sleep(timeout) => { + Err(CopilotError::Timeout) + } + }; + + // Unsubscribe + unsubscribe(); + + result + } + + /// Subscribe to events from this session. + /// + /// Events include assistant messages, tool executions, errors, and session state + /// changes. Multiple handlers can be registered and will all receive events. + /// + /// Returns a function that can be called to unsubscribe the handler. + pub fn on(&self, handler: SessionEventHandler) -> impl FnOnce() + Send { + let id = self.next_handler_id.fetch_add(1, Ordering::SeqCst); + + // Use synchronous mutex lock - no async runtime needed + { + let mut h = self.handlers.lock().unwrap(); + h.push(EventHandler { id, handler }); + } + + // Return unsubscribe closure + let handlers = self.handlers.clone(); + move || { + let mut h = handlers.lock().unwrap(); + h.retain(|h| h.id != id); + } + } + + /// Get all events and messages from this session's history. + pub async fn get_messages(&self) -> Result> { + let params = json!({ + "sessionId": self.session_id, + }); + + let result = self.rpc_client.request("session.getMessages", params).await?; + + let events_raw = result + .get("events") + .and_then(|v| v.as_array()) + .ok_or_else(|| CopilotError::InvalidResponse("missing events".to_string()))?; + + let events: Vec = events_raw + .iter() + .filter_map(|v| serde_json::from_value(v.clone()).ok()) + .collect(); + + Ok(events) + } + + /// Abort the currently processing message in this session. + pub async fn abort(&self) -> Result<()> { + let params = json!({ + "sessionId": self.session_id, + }); + + self.rpc_client.request("session.abort", params).await?; + Ok(()) + } + + /// Destroy this session and release all associated resources. + /// + /// After calling this method, the session can no longer be used. + pub async fn destroy(&self) -> Result<()> { + // Mark as destroyed first to prevent any new events from being dispatched + self.destroyed.store(true, Ordering::SeqCst); + + let params = json!({ + "sessionId": self.session_id, + }); + + self.rpc_client.request("session.destroy", params).await?; + + // Clear handlers (using sync mutex) + { + let mut handlers = self.handlers.lock().unwrap(); + handlers.clear(); + } + + // Clear tool handlers + { + let mut tool_handlers = self.tool_handlers.write().await; + tool_handlers.clear(); + } + + Ok(()) + } + + /// Register tools for this session. + pub(crate) async fn register_tools(&self, tools: Vec) { + let mut handlers = self.tool_handlers.write().await; + handlers.clear(); + for tool in tools { + if !tool.name.is_empty() { + handlers.insert(tool.name.clone(), tool.handler.clone()); + } + } + } + + /// Get a tool handler by name. + pub(crate) async fn get_tool_handler(&self, name: &str) -> Option { + let handlers = self.tool_handlers.read().await; + handlers.get(name).cloned() + } + + /// Execute a tool. + pub(crate) async fn execute_tool( + &self, + tool_name: &str, + invocation: ToolInvocation, + ) -> Result { + let handler = self.get_tool_handler(tool_name).await; + + match handler { + Some(handler) => { + // Execute the tool handler with async catch_unwind + let tool_name_for_err = tool_name.to_string(); + match std::panic::AssertUnwindSafe(handler(invocation)) + .catch_unwind() + .await + { + Ok(result) => result, + Err(_) => Ok(ToolResult::failure(format!( + "tool panic: {}", + tool_name_for_err + ))), + } + } + None => Ok(ToolResult::unsupported(tool_name)), + } + } + + /// Dispatch an event to all registered handlers. + pub(crate) async fn dispatch_event(&self, event: SessionEvent) { + // Don't dispatch events if the session has been destroyed + if self.destroyed.load(Ordering::SeqCst) { + return; + } + + let handlers: Vec = { + let h = self.handlers.lock().unwrap(); + h.iter().map(|h| h.handler.clone()).collect() + }; + + // Wrap event in Arc once, then clone the Arc for each handler (cheap) + let event = Arc::new(event); + + for handler in handlers { + // Don't let panics crash the dispatcher + let event_clone = Arc::clone(&event); + let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + handler(event_clone); + })); + } + } + + /// Dispatch an error event to all registered handlers. + /// + /// Used internally to notify session handlers when the connection is lost. + pub(crate) async fn dispatch_error(&self, message: &str) { + use crate::generated::SessionEventData; + use chrono::Utc; + + let event = SessionEvent { + event_type: SessionEventType::SessionError, + id: uuid::Uuid::new_v4().to_string(), + timestamp: Utc::now(), + parent_id: None, + ephemeral: None, + data: SessionEventData { + message: Some(message.to_string()), + ..Default::default() + }, + }; + self.dispatch_event(event).await; + } +} diff --git a/rust/src/tool.rs b/rust/src/tool.rs new file mode 100644 index 00000000..9c3364c0 --- /dev/null +++ b/rust/src/tool.rs @@ -0,0 +1,340 @@ +//! Tool definition helpers for the Copilot SDK. + +use crate::error::Result; +use schemars::JsonSchema; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; +use serde_json::Value; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; + +/// Information about a tool invocation. +#[derive(Debug, Clone)] +pub struct ToolInvocation { + /// Session ID. + pub session_id: String, + /// Unique ID for this tool call. + pub tool_call_id: String, + /// Name of the tool being called. + pub tool_name: String, + /// Raw arguments as JSON value. + pub arguments: Value, +} + +/// Result of a tool execution. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolResult { + /// Text result for the LLM. + #[serde(rename = "textResultForLlm")] + pub text_result_for_llm: String, + + /// Binary results for the LLM. + #[serde(rename = "binaryResultsForLlm", skip_serializing_if = "Option::is_none")] + pub binary_results_for_llm: Option>, + + /// Result type: "success" or "failure". + #[serde(rename = "resultType")] + pub result_type: String, + + /// Error message (for failures). + #[serde(skip_serializing_if = "Option::is_none")] + pub error: Option, + + /// Session log (optional). + #[serde(rename = "sessionLog", skip_serializing_if = "Option::is_none")] + pub session_log: Option, + + /// Tool telemetry data. + #[serde(rename = "toolTelemetry", skip_serializing_if = "Option::is_none")] + pub tool_telemetry: Option, +} + +impl ToolResult { + /// Create a successful result with text. + pub fn success(text: impl Into) -> Self { + Self { + text_result_for_llm: text.into(), + binary_results_for_llm: None, + result_type: "success".to_string(), + error: None, + session_log: None, + tool_telemetry: None, + } + } + + /// Create a failure result. + pub fn failure(error: impl Into) -> Self { + Self { + text_result_for_llm: "Invoking this tool produced an error. Detailed information is not available.".to_string(), + binary_results_for_llm: None, + result_type: "failure".to_string(), + error: Some(error.into()), + session_log: None, + tool_telemetry: None, + } + } + + /// Create a result for an unsupported tool. + pub fn unsupported(tool_name: &str) -> Self { + Self { + text_result_for_llm: format!("Tool '{}' is not supported by this client instance.", tool_name), + binary_results_for_llm: None, + result_type: "failure".to_string(), + error: Some(format!("tool '{}' not supported", tool_name)), + session_log: None, + tool_telemetry: None, + } + } +} + +/// Binary result for tools. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolBinaryResult { + /// Base64-encoded data. + pub data: String, + /// MIME type. + #[serde(rename = "mimeType")] + pub mime_type: String, + /// Result type. + #[serde(rename = "type")] + pub result_type: String, + /// Optional description. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, +} + +/// Type alias for async tool handlers. +pub type ToolHandler = Arc< + dyn Fn(ToolInvocation) -> Pin> + Send>> + + Send + + Sync, +>; + +/// A tool definition that can be exposed to Copilot. +#[derive(Clone)] +pub struct Tool { + /// Tool name. + pub name: String, + /// Tool description. + pub description: String, + /// JSON Schema for parameters. + pub parameters: Option, + /// Tool handler function. + pub handler: ToolHandler, +} + +impl std::fmt::Debug for Tool { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Tool") + .field("name", &self.name) + .field("description", &self.description) + .field("parameters", &self.parameters) + .finish() + } +} + +/// Define a tool with automatic JSON schema generation from a typed handler. +/// +/// # Example +/// +/// ```ignore +/// use schemars::JsonSchema; +/// use serde::Deserialize; +/// +/// #[derive(Deserialize, JsonSchema)] +/// struct GetWeatherParams { +/// city: String, +/// } +/// +/// let tool = define_tool::( +/// "get_weather", +/// "Get weather for a city", +/// |params, _inv| async move { +/// Ok(format!("Weather in {}: 22 degrees", params.city)) +/// }, +/// ); +/// ``` +pub fn define_tool(name: &str, description: &str, handler: F) -> Tool +where + P: DeserializeOwned + JsonSchema + Send + 'static, + F: Fn(P, ToolInvocation) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + R: IntoToolResult + 'static, +{ + // Generate JSON schema for the parameters + let schema = schemars::schema_for!(P); + let parameters = serde_json::to_value(schema).ok(); + + let handler = Arc::new(handler); + + let wrapped_handler: ToolHandler = Arc::new(move |inv: ToolInvocation| { + let handler = handler.clone(); + Box::pin(async move { + // Parse arguments into typed struct + let params: P = serde_json::from_value(inv.arguments.clone()) + .map_err(|e| crate::error::CopilotError::ToolExecution( + format!("Failed to parse arguments: {}", e) + ))?; + + let result = handler(params, inv).await?; + result.into_tool_result() + }) + }); + + Tool { + name: name.to_string(), + description: description.to_string(), + parameters, + handler: wrapped_handler, + } +} + +/// Trait for converting values into ToolResult. +/// +/// This trait enables flexible return types from tool handlers. Instead of +/// always returning `ToolResult`, handlers can return simpler types like +/// `String`, `&str`, `()`, or `serde_json::Value`, and they will be +/// automatically converted to successful `ToolResult` values. +/// +/// # Built-in Implementations +/// +/// | Type | Result | +/// |------|--------| +/// | `ToolResult` | Passed through unchanged | +/// | `String` | Success with the string as content | +/// | `&str` | Success with the string as content | +/// | `()` | Success with empty content | +/// | `serde_json::Value` | Success with JSON serialized as string | +/// +/// # Example +/// +/// ```ignore +/// // These tool handlers are all valid: +/// +/// // Return a String +/// |params, _inv| async move { Ok("Done!".to_string()) } +/// +/// // Return a ToolResult for more control +/// |params, _inv| async move { Ok(ToolResult::success("Done!")) } +/// +/// // Return nothing (empty success) +/// |params, _inv| async move { Ok(()) } +/// +/// // Return JSON +/// |params, _inv| async move { Ok(serde_json::json!({"status": "ok"})) } +/// ``` +pub trait IntoToolResult { + /// Convert this value into a [`ToolResult`]. + /// + /// # Returns + /// + /// A `Result` containing the converted `ToolResult`, or an error if + /// conversion fails (e.g., JSON serialization error for `Value` types). + fn into_tool_result(self) -> Result; +} + +impl IntoToolResult for ToolResult { + fn into_tool_result(self) -> Result { + Ok(self) + } +} + +impl IntoToolResult for String { + fn into_tool_result(self) -> Result { + Ok(ToolResult::success(self)) + } +} + +impl IntoToolResult for &str { + fn into_tool_result(self) -> Result { + Ok(ToolResult::success(self)) + } +} + +impl IntoToolResult for () { + fn into_tool_result(self) -> Result { + Ok(ToolResult::success("")) + } +} + +impl IntoToolResult for Value { + fn into_tool_result(self) -> Result { + let json = serde_json::to_string(&self)?; + Ok(ToolResult::success(json)) + } +} + +/// Builder for creating tools manually without automatic schema generation. +pub struct ToolBuilder { + name: String, + description: String, + parameters: Option, +} + +impl ToolBuilder { + /// Create a new tool builder. + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + description: String::new(), + parameters: None, + } + } + + /// Set the tool description. + pub fn description(mut self, description: impl Into) -> Self { + self.description = description.into(); + self + } + + /// Set the JSON schema for parameters. + pub fn parameters(mut self, parameters: Value) -> Self { + self.parameters = Some(parameters); + self + } + + /// Build the tool with an async handler. + pub fn handler(self, handler: F) -> Tool + where + F: Fn(ToolInvocation) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + let handler = Arc::new(handler); + Tool { + name: self.name, + description: self.description, + parameters: self.parameters, + handler: Arc::new(move |inv| { + let handler = handler.clone(); + Box::pin(async move { handler(inv).await }) + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tool_result_success() { + let result = ToolResult::success("Hello"); + assert_eq!(result.result_type, "success"); + assert_eq!(result.text_result_for_llm, "Hello"); + assert!(result.error.is_none()); + } + + #[test] + fn test_tool_result_failure() { + let result = ToolResult::failure("Something went wrong"); + assert_eq!(result.result_type, "failure"); + assert!(result.error.is_some()); + } + + #[test] + fn test_tool_result_unsupported() { + let result = ToolResult::unsupported("unknown_tool"); + assert_eq!(result.result_type, "failure"); + assert!(result.text_result_for_llm.contains("unknown_tool")); + } +} diff --git a/rust/src/types.rs b/rust/src/types.rs new file mode 100644 index 00000000..4369b174 --- /dev/null +++ b/rust/src/types.rs @@ -0,0 +1,483 @@ +//! Core type definitions for the Copilot SDK. +//! +//! This module contains all the configuration types, enums, and data structures +//! used throughout the SDK for client configuration, session management, and +//! message handling. +//! +//! # Main Types +//! +//! - [`ClientOptions`] - Configuration for creating a [`CopilotClient`](crate::CopilotClient) +//! - [`SessionConfig`] - Configuration for creating a new session +//! - [`MessageOptions`] - Options for sending messages to a session +//! - [`ConnectionState`] - Current connection state of the client + +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// Connection state of the client. +/// +/// Represents the current state of the connection between the SDK and the +/// Copilot CLI server. Use [`CopilotClient::get_state()`](crate::CopilotClient::get_state) +/// to retrieve the current state. +/// +/// # Example +/// +/// ```ignore +/// use copilot_sdk::{CopilotClient, ConnectionState}; +/// +/// let client = CopilotClient::new(None)?; +/// assert_eq!(client.get_state(), ConnectionState::Disconnected); +/// +/// client.start().await?; +/// assert_eq!(client.get_state(), ConnectionState::Connected); +/// ``` +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConnectionState { + /// Client is not connected to the CLI server. + /// + /// This is the initial state before [`start()`](crate::CopilotClient::start) is called, + /// or after [`stop()`](crate::CopilotClient::stop) completes. + Disconnected, + + /// Client is in the process of connecting to the CLI server. + /// + /// This transient state occurs during [`start()`](crate::CopilotClient::start) + /// while the connection is being established. + Connecting, + + /// Client is connected and ready to use. + /// + /// The client can create sessions and send messages in this state. + Connected, + + /// An error occurred with the connection. + /// + /// This state indicates the connection failed or was lost unexpectedly. + /// Check logs for details and consider calling [`start()`](crate::CopilotClient::start) + /// again to reconnect. + Error, +} + +impl std::fmt::Display for ConnectionState { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ConnectionState::Disconnected => write!(f, "disconnected"), + ConnectionState::Connecting => write!(f, "connecting"), + ConnectionState::Connected => write!(f, "connected"), + ConnectionState::Error => write!(f, "error"), + } + } +} + +/// Options for configuring the CopilotClient. +#[derive(Debug, Clone, Default)] +pub struct ClientOptions { + /// Path to the Copilot CLI executable (default: "copilot"). + pub cli_path: Option, + + /// Working directory for the CLI process. + pub cwd: Option, + + /// Port for TCP transport (default: 0 = random port). + pub port: Option, + + /// Enable stdio transport instead of TCP (default: true). + pub use_stdio: Option, + + /// URL of an existing Copilot CLI server to connect to over TCP. + /// Format: "host:port", "http://host:port", or just "port" (defaults to localhost). + /// Mutually exclusive with cli_path and use_stdio. + pub cli_url: Option, + + /// Log level for the CLI server. + pub log_level: Option, + + /// Automatically start the CLI server on first use (default: true). + pub auto_start: Option, + + /// Automatically restart the CLI server if it crashes (default: true). + pub auto_restart: Option, + + /// Environment variables for the CLI process. + pub env: Option>, +} + +/// System message configuration for session creation. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SystemMessageConfig { + /// Mode: "append" or "replace". + #[serde(skip_serializing_if = "Option::is_none")] + pub mode: Option, + + /// Content for the system message. + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, +} + +/// Configuration for a local/stdio MCP server. +/// +/// Local MCP servers are spawned as child processes and communicate via stdio. +/// This is the most common configuration for MCP servers running on the same machine. +/// +/// # Example +/// +/// ```ignore +/// use copilot_sdk::McpLocalServerConfig; +/// +/// let config = McpLocalServerConfig { +/// tools: vec!["read_file".to_string(), "write_file".to_string()], +/// server_type: None, // Defaults to "stdio" +/// timeout: Some(30000), +/// command: "npx".to_string(), +/// args: Some(vec!["-y".to_string(), "@modelcontextprotocol/server-filesystem".to_string()]), +/// env: None, +/// cwd: Some("/path/to/project".to_string()), +/// }; +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpLocalServerConfig { + /// List of tool names this server provides. + pub tools: Vec, + + /// Server type (typically "stdio" for local servers). + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub server_type: Option, + + /// Timeout in milliseconds for server operations. + #[serde(skip_serializing_if = "Option::is_none")] + pub timeout: Option, + + /// Command to execute to start the server. + pub command: String, + + /// Arguments to pass to the command. + #[serde(skip_serializing_if = "Option::is_none")] + pub args: Option>, + + /// Environment variables to set for the server process. + #[serde(skip_serializing_if = "Option::is_none")] + pub env: Option>, + + /// Working directory for the server process. + #[serde(skip_serializing_if = "Option::is_none")] + pub cwd: Option, +} + +/// Configuration for a remote MCP server (HTTP or SSE). +/// +/// Remote MCP servers communicate over HTTP or Server-Sent Events (SSE). +/// Use this for MCP servers hosted on remote machines or as web services. +/// +/// # Example +/// +/// ```ignore +/// use copilot_sdk::McpRemoteServerConfig; +/// use std::collections::HashMap; +/// +/// let config = McpRemoteServerConfig { +/// tools: vec!["search".to_string()], +/// server_type: "sse".to_string(), +/// timeout: Some(60000), +/// url: "https://mcp.example.com/sse".to_string(), +/// headers: Some(HashMap::from([ +/// ("Authorization".to_string(), "Bearer token".to_string()), +/// ])), +/// }; +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpRemoteServerConfig { + /// List of tool names this server provides. + pub tools: Vec, + + /// Server type: "http" or "sse". + #[serde(rename = "type")] + pub server_type: String, + + /// Timeout in milliseconds for server operations. + #[serde(skip_serializing_if = "Option::is_none")] + pub timeout: Option, + + /// URL of the remote MCP server. + pub url: String, + + /// HTTP headers to include in requests to the server. + #[serde(skip_serializing_if = "Option::is_none")] + pub headers: Option>, +} + +/// MCP server configuration (can be local or remote). +/// +/// This enum represents different types of MCP server configurations. +/// Use [`McpLocalServerConfig`] for locally-spawned servers or +/// [`McpRemoteServerConfig`] for remote HTTP/SSE servers. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum McpServerConfig { + /// Local MCP server spawned as a child process. + Local(McpLocalServerConfig), + /// Remote MCP server accessed via HTTP or SSE. + Remote(McpRemoteServerConfig), + /// Raw JSON configuration for advanced use cases. + Raw(serde_json::Value), +} + +/// Configuration for a custom agent. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CustomAgentConfig { + /// Unique name of the custom agent. + pub name: String, + + /// Display name for UI purposes. + #[serde(rename = "displayName", skip_serializing_if = "Option::is_none")] + pub display_name: Option, + + /// Description of what the agent does. + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + + /// List of tool names the agent can use (None for all tools). + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + + /// Prompt content for the agent. + pub prompt: String, + + /// MCP servers specific to this agent. + #[serde(rename = "mcpServers", skip_serializing_if = "Option::is_none")] + pub mcp_servers: Option>, + + /// Whether the agent should be available for model inference. + #[serde(skip_serializing_if = "Option::is_none")] + pub infer: Option, +} + +/// Azure-specific provider configuration. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AzureProviderOptions { + /// Azure API version (default: "2024-10-21"). + #[serde(rename = "apiVersion", skip_serializing_if = "Option::is_none")] + pub api_version: Option, +} + +/// Configuration for a custom model provider. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ProviderConfig { + /// Provider type: "openai", "azure", or "anthropic" (default: "openai"). + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + pub provider_type: Option, + + /// API format (openai/azure only): "completions" or "responses" (default: "completions"). + #[serde(rename = "wireApi", skip_serializing_if = "Option::is_none")] + pub wire_api: Option, + + /// API endpoint URL. + #[serde(rename = "baseUrl")] + pub base_url: String, + + /// API key. Optional for local providers like Ollama. + #[serde(rename = "apiKey", skip_serializing_if = "Option::is_none")] + pub api_key: Option, + + /// Bearer token for authentication. + #[serde(rename = "bearerToken", skip_serializing_if = "Option::is_none")] + pub bearer_token: Option, + + /// Azure-specific options. + #[serde(skip_serializing_if = "Option::is_none")] + pub azure: Option, +} + +/// Configuration for creating a new session. +#[derive(Debug, Clone, Default)] +pub struct SessionConfig { + /// Optional custom session ID. + pub session_id: Option, + + /// Model to use for this session. + pub model: Option, + + /// Override the default configuration directory location. + pub config_dir: Option, + + /// Caller-implemented tools to expose to the CLI. + pub tools: Vec, + + /// System message customization. + pub system_message: Option, + + /// List of tool names to allow. When specified, only these tools will be available. + pub available_tools: Option>, + + /// List of tool names to disable. All other tools remain available. + pub excluded_tools: Option>, + + /// Enable streaming of assistant message and reasoning chunks. + pub streaming: Option, + + /// Custom model provider (BYOK). + pub provider: Option, + + /// MCP servers for the session. + pub mcp_servers: Option>, + + /// Custom agents for the session. + pub custom_agents: Option>, + + /// Directories to load skills from. + pub skill_directories: Option>, + + /// Skill names to disable. + pub disabled_skills: Option>, +} + +/// Configuration for resuming an existing session. +#[derive(Debug, Clone, Default)] +pub struct ResumeSessionConfig { + /// Caller-implemented tools to expose to the CLI. + pub tools: Vec, + + /// Custom model provider. + pub provider: Option, + + /// Enable streaming of assistant message and reasoning chunks. + pub streaming: Option, + + /// MCP servers for the session. + pub mcp_servers: Option>, + + /// Custom agents for the session. + pub custom_agents: Option>, + + /// Directories to load skills from. + pub skill_directories: Option>, + + /// Skill names to disable. + pub disabled_skills: Option>, +} + +/// Options for sending a message. +#[derive(Debug, Clone, Default)] +pub struct MessageOptions { + /// The message to send. + pub prompt: String, + + /// File or directory attachments. + pub attachments: Option>, + + /// Message delivery mode (default: "enqueue"). + pub mode: Option, +} + +/// File or directory attachment. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Attachment { + /// Display name for the attachment. + #[serde(rename = "displayName")] + pub display_name: String, + + /// Path to the file or directory. + pub path: String, + + /// Type: "file" or "directory". + #[serde(rename = "type")] + pub attachment_type: AttachmentType, +} + +/// Attachment type. +/// +/// Specifies whether an [`Attachment`] refers to a single file or a directory. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum AttachmentType { + /// A single file attachment. + File, + /// A directory attachment (contents may be recursively included). + Directory, +} + +/// Response from a ping request. +/// +/// Returned by [`CopilotClient::ping()`](crate::CopilotClient::ping) to verify +/// connectivity and protocol compatibility with the CLI server. +/// +/// # Example +/// +/// ```ignore +/// let response = client.ping(Some("hello")).await?; +/// println!("Server responded: {}", response.message); +/// println!("Protocol version: {:?}", response.protocol_version); +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PingResponse { + /// Echo of the ping message, or default response if none was sent. + pub message: String, + + /// Unix timestamp (milliseconds) when the server processed the ping. + pub timestamp: i64, + + /// Protocol version reported by the server. + /// + /// Used to verify SDK and server compatibility. If `None`, the server + /// may be an older version that doesn't report protocol versions. + #[serde(rename = "protocolVersion")] + pub protocol_version: Option, +} + +/// Permission request from the server. +/// +/// Sent by the CLI server when an operation requires user permission. +/// The SDK can be configured with a permission handler to respond to these requests. +#[derive(Debug, Clone)] +pub struct PermissionRequest { + /// Type of permission being requested (e.g., "file_write", "shell_execute"). + pub kind: String, + + /// ID of the tool call that triggered this permission request. + pub tool_call_id: Option, + + /// Additional context about the permission request as JSON. + pub extra: serde_json::Value, +} + +/// Result of a permission request. +/// +/// Returned by permission handlers to grant or deny permission requests. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct PermissionRequestResult { + /// Type of permission (should match the request's kind). + pub kind: String, + + /// Permission rules to apply. If `None`, permission is denied. + #[serde(skip_serializing_if = "Option::is_none")] + pub rules: Option>, +} + +/// Context for a permission request. +/// +/// Provides additional context to permission handlers about where the +/// permission request originated. +#[derive(Debug, Clone)] +pub struct PermissionInvocation { + /// ID of the session that triggered the permission request. + pub session_id: String, +} + +/// Session metadata returned by list_sessions. +/// +/// Contains basic information about an existing session. Use +/// [`CopilotClient::list_sessions()`](crate::CopilotClient::list_sessions) +/// to retrieve all available sessions. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SessionMetadata { + /// Unique identifier for the session. + #[serde(rename = "sessionId")] + pub session_id: String, +} + +/// SDK protocol version constant. +pub const SDK_PROTOCOL_VERSION: i32 = 1; + +/// Returns the SDK protocol version. +pub fn get_sdk_protocol_version() -> i32 { + SDK_PROTOCOL_VERSION +}