From 43ab2bc4b77e91455880b6e467ffb3bd951e359a Mon Sep 17 00:00:00 2001 From: hude Date: Fri, 17 Apr 2026 09:02:07 +0900 Subject: [PATCH 1/4] Add configurable embedding backends --- Cargo.lock | 544 +++++++++++++++++- Cargo.toml | 1 + README.md | 22 +- rust/cli_defs.rs | 4 +- rust/clients/launcher.rs | 12 +- rust/commands/ask_ai.rs | 30 +- rust/commands/prefs.rs | 11 + rust/commands/search.rs | 3 +- rust/commands/search_raw.rs | 7 +- rust/embedding.rs | 91 ++- rust/embedding_config.rs | 317 ++++++++++ rust/insert_service.rs | 25 +- rust/lib.rs | 3 + rust/local_chunking.rs | 131 +++++ rust/local_embedding.rs | 125 ++++ rust/preferences.rs | 58 ++ rust/tools/service.rs | 5 +- rust/tui/bridge.rs | 3 +- rust/tui/provider/mod.rs | 33 ++ rust/tui/settings.rs | 26 +- rust/tui/settings_tests.rs | 31 +- tests/kinic_cli_prefs.rs | 2 + tests/kinic_cli_tui.rs | 4 +- .../src/ui/app/screens/settings/mod.rs | 4 + tui/crates/tui-kit-runtime/src/lib.rs | 39 ++ 25 files changed, 1476 insertions(+), 55 deletions(-) create mode 100644 rust/embedding_config.rs create mode 100644 rust/local_chunking.rs create mode 100644 rust/local_embedding.rs diff --git a/Cargo.lock b/Cargo.lock index ef2bee9..b590495 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,11 +24,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" dependencies = [ "cfg-if", + "getrandom 0.3.4", "once_cell", + "serde", "version_check", "zerocopy", ] +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + [[package]] name = "allocator-api2" version = "0.2.21" @@ -528,6 +539,7 @@ dependencies = [ "itoa", "rustversion", "ryu", + "serde", "static_assertions", ] @@ -540,6 +552,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width 0.2.0", + "windows-sys 0.59.0", +] + [[package]] name = "const-oid" version = "0.9.6" @@ -599,6 +624,25 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -712,6 +756,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "darling" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +dependencies = [ + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + [[package]] name = "darling" version = "0.21.3" @@ -732,6 +786,20 @@ dependencies = [ "darling_macro 0.23.0", ] +[[package]] +name = "darling_core" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.109", +] + [[package]] name = "darling_core" version = "0.21.3" @@ -759,6 +827,17 @@ dependencies = [ "syn 2.0.109", ] +[[package]] +name = "darling_macro" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" +dependencies = [ + "darling_core 0.20.11", + "quote", + "syn 2.0.109", +] + [[package]] name = "darling_macro" version = "0.21.3" @@ -781,6 +860,15 @@ dependencies = [ "syn 2.0.109", ] +[[package]] +name = "dary_heap" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06d2e3287df1c007e74221c49ca10a95d557349e54b3a75dc2fb14712c751f04" +dependencies = [ + "serde", +] + [[package]] name = "data-encoding" version = "2.9.0" @@ -829,6 +917,37 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling 0.20.11", + "proc-macro2", + "quote", + "syn 2.0.109", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn 2.0.109", +] + [[package]] name = "derive_more" version = "2.1.1" @@ -1007,6 +1126,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "encoding_rs" version = "0.8.35" @@ -1032,6 +1157,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "esaxx-rs" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" + [[package]] name = "euclid" version = "0.20.14" @@ -1068,6 +1199,22 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "fastembed" +version = "5.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f54fc1188b7f7eac8f47be2ab7b3a79ffd842cc8ff2e38316dd59ba4858890e" +dependencies = [ + "anyhow", + "hf-hub", + "ndarray", + "ort", + "safetensors", + "serde", + "serde_json", + "tokenizers", +] + [[package]] name = "fastrand" version = "2.3.0" @@ -1352,6 +1499,8 @@ dependencies = [ "allocator-api2", "equivalent", "foldhash 0.2.0", + "serde", + "serde_core", ] [[package]] @@ -1369,6 +1518,26 @@ dependencies = [ "serde", ] +[[package]] +name = "hf-hub" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97" +dependencies = [ + "dirs", + "http", + "indicatif", + "libc", + "log", + "rand 0.9.2", + "reqwest", + "serde", + "serde_json", + "thiserror 2.0.17", + "ureq 2.12.1", + "windows-sys 0.60.2", +] + [[package]] name = "hkdf" version = "0.12.4" @@ 
-1387,6 +1556,12 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "hmac-sha256" +version = "1.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec9d92d097f4749b64e8cc33d924d9f40a2d4eb91402b458014b781f5733d60f" + [[package]] name = "home" version = "0.5.12" @@ -1478,7 +1653,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots", + "webpki-roots 1.0.4", ] [[package]] @@ -1829,6 +2004,19 @@ dependencies = [ "hashbrown 0.16.1", ] +[[package]] +name = "indicatif" +version = "0.17.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" +dependencies = [ + "console", + "number_prefix", + "portable-atomic", + "unicode-width 0.2.0", + "web-time", +] + [[package]] name = "indoc" version = "2.0.7" @@ -2007,6 +2195,7 @@ dependencies = [ "clap", "der", "dotenvy", + "fastembed", "gag", "hex", "ic-agent", @@ -2187,6 +2376,28 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" +[[package]] +name = "lzma-rust2" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1670343e58806300d87950e3401e820b519b9384281bbabfb15e3636689ffd69" + +[[package]] +name = "macro_rules_attribute" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65049d7923698040cd0b1ddcced9b0eb14dd22c5f86ae59c3740eab64a676520" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" + [[package]] name = "malloc_buf" version = "0.0.6" @@ -2202,6 +2413,16 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "md-5" version = "0.10.6" @@ -2293,6 +2514,43 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "monostate" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3341a273f6c9d5bef1908f17b7267bbab0e95c9bf69a0d4dcf8e9e1b2c76ef67" +dependencies = [ + "monostate-impl", + "serde", + "serde_core", +] + +[[package]] +name = "monostate-impl" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4db6d5580af57bf992f59068d4ea26fd518574ff48d7639b255a36f9de6e7e9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.109", +] + +[[package]] +name = "ndarray" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "ndk-context" version = "0.1.1" @@ -2329,6 +2587,15 @@ dependencies = [ "serde", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -2354,6 +2621,12 @@ dependencies = [ "libm", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "objc" version = "0.2.7" @@ -2384,6 +2657,28 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "onig" +version = "6.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "336b9c63443aceef14bea841b899035ae3abe89b7c486aaf4c5bd8aafedac3f0" +dependencies = [ + "bitflags", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f86c6eef3d6df15f23bcfb6af487cbd2fed4e5581d58d5bf1f5f8b7f6727dc" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "opaque-debug" version = "0.3.1" @@ -2444,6 +2739,30 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ort" +version = "2.0.0-rc.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5df903c0d2c07b56950f1058104ab0c8557159f2741782223704de9be73c3c" +dependencies = [ + "ndarray", + "ort-sys", + "smallvec", + "tracing", + "ureq 3.3.0", +] + +[[package]] +name = "ort-sys" +version = "2.0.0-rc.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06503bb33f294c5f1ba484011e053bfa6ae227074bdb841e9863492dc5960d4b" +dependencies = [ + "hmac-sha256", + "lzma-rust2", + "ureq 3.3.0", +] + [[package]] name = "p256" version = "0.13.2" @@ -2590,6 +2909,15 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +[[package]] +name = "portable-atomic-util" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" +dependencies = [ + "portable-atomic", +] + [[package]] name = "postscript" version = "0.14.1" @@ -2945,6 +3273,43 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9" +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-cond" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2964d0cf57a3e7a06e8183d14a8b527195c706b7983549cd5462d5aa3747438f" +dependencies = [ + "either", + "itertools 0.14.0", + "rayon", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -2985,6 +3350,35 @@ dependencies = [ "syn 2.0.109", ] +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + [[package]] name = "reqwest" version = "0.12.24" @@ -3024,7 +3418,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots", + "webpki-roots 1.0.4", ] [[package]] @@ -3132,6 +3526,7 @@ version = "0.23.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "533f54bc6a7d4f647e46ad909549eda97bf5afc1585190ef692b4286b198bd8f" dependencies = [ + "log", "once_cell", "ring", "rustls-pki-types", @@ -3173,6 +3568,17 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "safetensors" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "675656c1eabb620b921efea4f9199f97fc86e36dd6ffd1fbbe48d0f59a4987f5" +dependencies = [ + "hashbrown 0.16.1", + "serde", + "serde_json", +] + [[package]] name = "same-file" version = "1.0.6" @@ -3499,6 +3905,17 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "spki" version = "0.7.3" @@ -3509,6 +3926,18 @@ dependencies = [ "der", ] +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom", + "serde", + "unicode-segmentation", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -3773,6 +4202,39 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokenizers" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b238e22d44a15349529690fb07bd645cf58149a1b1e44d6cb5bd1641ff1a6223" +dependencies = [ + "ahash", + "aho-corasick", + "compact_str 0.9.0", + "dary_heap", + "derive_builder", + "esaxx-rs", + "getrandom 0.3.4", + "itertools 0.14.0", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand 0.9.2", + "rayon", + "rayon-cond", + "regex", + "regex-syntax", + "serde", + "serde_json", + "spm_precompiled", + "thiserror 2.0.17", + 
"unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.48.0" @@ -4021,6 +4483,15 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec", +] + [[package]] name = "unicode-segmentation" version = "1.12.0" @@ -4061,6 +4532,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + [[package]] name = "unindent" version = "0.2.4" @@ -4079,6 +4556,54 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" +dependencies = [ + "base64 0.22.1", + "flate2", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "socks", + "url", + "webpki-roots 0.26.11", +] + +[[package]] +name = "ureq" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea7109cdcd5864d4eeb1b58a1648dc9bf520360d7af16ec26d0a9354bafcfc0" +dependencies = [ + "base64 0.22.1", + "log", + "percent-encoding", + "rustls", + "rustls-pki-types", + "socks", + "ureq-proto", + "utf8-zero", + "webpki-roots 1.0.4", +] + +[[package]] +name = "ureq-proto" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e994ba84b0bd1b1b0cf92878b7ef898a5c1760108fe7b6010327e274917a808c" +dependencies = [ + "base64 0.22.1", + "http", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.7" @@ -4091,6 +4616,12 @@ dependencies = [ "serde", ] +[[package]] +name = "utf8-zero" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8c0a043c9540bae7c578c88f91dda8bd82e59ae27c21baca69c8b191aaf5a6e" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -4263,6 +4794,15 @@ dependencies = [ "web-sys", ] +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.4", +] + [[package]] name = "webpki-roots" version = "1.0.4" diff --git a/Cargo.toml b/Cargo.toml index 086a82f..09290d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ ring = "0.17.14" der = "0.7.10" pkcs8 = "0.10.2" ic-ed25519 = "0.2.0" +fastembed = { version = "5.13.2", default-features = false, features = ["ort-download-binaries-rustls-tls", "hf-hub-rustls-tls"] } kinic-core = { path = "crates/kinic-core" } tui-kit-host = { path = "tui/crates/tui-kit-host" } tui-kit-runtime = { path = "tui/crates/tui-kit-runtime" } diff --git a/README.md b/README.md index 60de5b0..d275bba 100644 --- a/README.md +++ b/README.md @@ -384,7 +384,7 @@ See `python/examples/insert_pdf_file.py` for a runnable script. 
## Ask AI -Runs a search and prepares context for an AI answer. The CLI calls `/chat` at `EMBEDDING_API_ENDPOINT` (default `https://api.kinic.io`) and prints only the `` text. +Runs a search with the configured embedding backend, prepares context for an AI answer, then calls `/chat` at `EMBEDDING_API_ENDPOINT` (default `https://api.kinic.io`) and prints only the `` text. ```python prompt, answer = km.ask_ai(memory_id, "What did we say about quarterly goals?", top_k=3, language="en") @@ -395,6 +395,26 @@ print("Answer:\n", answer) - `km.ask_ai` returns `(prompt, answer)` where `answer` is the `` section from the chat response. - CLI usage: `cargo run -- --identity ask-ai --memory-id --query "" --top-k 3` +Embedding backend initialization: + +- initial saved backend: `api` +- API dimension: `1024` +- local option example: `Snowflake/snowflake-arctic-embed-s` +- local cache dir: `$HOME/.cache/kinic-cli/embeddings` + +Shared settings behavior: + +- if the shared config directory is available and `tui.yaml` is missing, Kinic initializes with the saved default `api` +- if the shared config directory is unavailable, embedding-backed commands fail explicitly instead of falling back + +Local runtime overrides: + +- `KINIC_LOCAL_EMBEDDING_CACHE_DIR` +- `KINIC_LOCAL_EMBEDDING_MAX_LENGTH` +- `KINIC_LOCAL_EMBEDDING_CHUNK_SOFT_LIMIT` +- `KINIC_LOCAL_EMBEDDING_CHUNK_HARD_LIMIT` +- `KINIC_LOCAL_EMBEDDING_CHUNK_OVERLAP` + --- ## Configure memory visibility diff --git a/rust/cli_defs.rs b/rust/cli_defs.rs index 4a9c80d..755745f 100644 --- a/rust/cli_defs.rs +++ b/rust/cli_defs.rs @@ -116,7 +116,7 @@ pub enum Command { Capabilities(CapabilitiesArgs), #[command( about = "Manage local Kinic preferences shared with the TUI. All prefs commands return JSON.", - after_help = "Examples:\n kinic-cli prefs show\n kinic-cli prefs set-default-memory --memory-id MEMORY_CANISTER_ID\n kinic-cli prefs set-chat-overall-top-k --value 10\n\nReturns:\n show -> {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer}\n mutations -> {\"resource\": string, \"action\": string, \"status\": \"updated\"|\"unchanged\", \"value\": string|integer|null}" + after_help = "Examples:\n kinic-cli prefs show\n kinic-cli prefs set-default-memory --memory-id MEMORY_CANISTER_ID\n kinic-cli prefs set-chat-overall-top-k --value 10\n\nReturns:\n show -> {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer, \"embedding_model_id\": string}\n mutations -> {\"resource\": string, \"action\": string, \"status\": \"updated\"|\"unchanged\", \"value\": string|integer|null}" )] Prefs(PrefsArgs), #[command( @@ -411,7 +411,7 @@ pub struct PrefsArgs { pub enum PrefsCommand { #[command( about = "Show local preferences shared with the TUI. 
Returns JSON.", - after_help = "Returns:\n {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer}\n\nExample:\n kinic-cli prefs show" + after_help = "Returns:\n {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer, \"embedding_model_id\": string}\n\nExample:\n kinic-cli prefs show" )] Show, #[command( diff --git a/rust/clients/launcher.rs b/rust/clients/launcher.rs index 6b3a1ac..e260ed3 100644 --- a/rust/clients/launcher.rs +++ b/rust/clients/launcher.rs @@ -13,9 +13,10 @@ use icrc_ledger_types::{ use serde_json::json; use thiserror::Error; -use crate::clients::{LAUNCHER_CANISTER, LEDGER_CANISTER}; - -const DEFAULT_VECTOR_DIM: u64 = 1024; +use crate::{ + clients::{LAUNCHER_CANISTER, LEDGER_CANISTER}, + embedding::configured_embedding_dimension_u64, +}; const APPROVAL_TTL_NS: u64 = 10 * 60 * 1_000_000_000; pub struct LauncherClient { @@ -131,7 +132,10 @@ fn encode_deploy_args(name: &str, description: &str) -> Result> { "name": name, "description": description}) .to_string(); - Ok(candid::encode_args((payload, DEFAULT_VECTOR_DIM))?) + Ok(candid::encode_args(( + payload, + configured_embedding_dimension_u64()?, + ))?) } fn encode_update_instance_args(instance_pid_str: &str) -> Result> { diff --git a/rust/commands/ask_ai.rs b/rust/commands/ask_ai.rs index 263e3df..ad2ab62 100644 --- a/rust/commands/ask_ai.rs +++ b/rust/commands/ask_ai.rs @@ -2,14 +2,13 @@ use std::cmp::Ordering; use anyhow::{Context, Result}; use ic_agent::export::Principal; -use reqwest::Client; use tracing::info; use crate::{ agent::AgentFactory, cli::AskAiArgs, clients::memory::MemoryClient, - embedding::{embedding_base_url, fetch_embedding}, + embedding::{call_chat_http, ensure_memory_dim_matches, fetch_embedding}, prompt_utils::{escape_xml, prompt_language_instruction}, }; @@ -20,8 +19,6 @@ const MAX_RESULTS: usize = 5; const MAX_HITS_PER_DOC: usize = 6; const MAX_HIT_LEN: usize = 600; const MAX_FULL_LEN: usize = 4096; -const CHAT_PATH: &str = "/chat"; - pub struct AskAiResult { pub prompt: String, pub response: String, @@ -67,6 +64,7 @@ pub async fn ask_ai_flow( let client = MemoryClient::new(agent, *memory_id); let embedding = fetch_embedding(query).await?; + ensure_memory_dim_matches(&client, &memory_id.to_text(), embedding.len()).await?; let mut results = client.search(embedding).await?; results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal)); @@ -84,24 +82,7 @@ pub async fn ask_ai_flow( } pub(crate) async fn call_chat_endpoint(prompt: &str) -> Result { - let url = format!("{}{}", embedding_base_url(), CHAT_PATH); - let response = Client::new() - .post(url) - .json(&ChatRequest { message: prompt }) - .send() - .await - .context("Failed to call chat endpoint")?; - - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - anyhow::bail!("chat endpoint returned {status}: {body}"); - } - - let body = response - .text() - .await - .context("Failed to read chat response")?; + let body = call_chat_http(prompt).await?; let mut acc = String::new(); for line in body.lines() { @@ -141,11 +122,6 @@ struct SearchResult { hits: Vec, } -#[derive(serde::Serialize)] -struct ChatRequest<'a> { - message: &'a str, -} - #[derive(serde::Deserialize)] struct ChatChunk { content: 
Option<String>, diff --git a/rust/commands/prefs.rs b/rust/commands/prefs.rs index ebc8bcb..1a40b2d 100644 --- a/rust/commands/prefs.rs +++ b/rust/commands/prefs.rs @@ -291,6 +291,7 @@ struct ShowPreferences { chat_overall_top_k: usize, chat_per_memory_cap: usize, chat_mmr_lambda: u8, + embedding_model_id: String, } impl From<UserPreferences> for ShowPreferences { @@ -302,6 +303,7 @@ chat_overall_top_k: value.chat_overall_top_k, chat_per_memory_cap: value.chat_per_memory_cap, chat_mmr_lambda: value.chat_mmr_lambda, + embedding_model_id: value.embedding_model_id, } } @@ -407,6 +409,7 @@ mod tests { chat_overall_top_k: DEFAULT_CHAT_OVERALL_TOP_K, chat_per_memory_cap: DEFAULT_CHAT_PER_MEMORY_CAP, chat_mmr_lambda: DEFAULT_CHAT_MMR_LAMBDA, + embedding_model_id: "invalid".to_string(), }; let normalized = preferences::normalize_user_preferences(preferences); @@ -423,6 +426,10 @@ mod tests { assert_eq!(normalized.chat_overall_top_k, DEFAULT_CHAT_OVERALL_TOP_K); assert_eq!(normalized.chat_per_memory_cap, DEFAULT_CHAT_PER_MEMORY_CAP); assert_eq!(normalized.chat_mmr_lambda, DEFAULT_CHAT_MMR_LAMBDA); + assert_eq!( + normalized.embedding_model_id, + preferences::default_embedding_model_id() + ); } #[test] @@ -439,6 +446,10 @@ mod tests { DEFAULT_CHAT_PER_MEMORY_CAP ); assert_eq!(serialized["chat_mmr_lambda"], DEFAULT_CHAT_MMR_LAMBDA); + assert_eq!( + serialized["embedding_model_id"], + preferences::default_embedding_model_id() + ); } #[test] diff --git a/rust/commands/search.rs b/rust/commands/search.rs index e1bc029..d4b3dd7 100644 --- a/rust/commands/search.rs +++ b/rust/commands/search.rs @@ -13,7 +13,7 @@ use tracing::info; use crate::{ cli::SearchArgs, clients::{launcher::LauncherClient, memory::MemoryClient}, - embedding::fetch_embedding, + embedding::{ensure_memory_dim_matches, fetch_embedding}, shared::cross_memory_search::{ SearchHit, collect_searchable_memory_ids, fold_search_batches, searchable_memory_id_from_state, sort_search_hits, @@ -142,6 +142,7 @@ pub(crate) async fn search_single_memory_items( embedding: Vec<f32>, ) -> Result<Vec<SearchHit>> { let client = build_memory_client(agent, &memory_id)?; + ensure_memory_dim_matches(&client, &memory_id, embedding.len()).await?; let mut rows = client .search(embedding) .await diff --git a/rust/commands/search_raw.rs b/rust/commands/search_raw.rs index d98a38c..2f21a54 100644 --- a/rust/commands/search_raw.rs +++ b/rust/commands/search_raw.rs @@ -1,13 +1,18 @@ use anyhow::{Context, Result, bail}; use tracing::info; -use crate::{cli::SearchRawArgs, memory_client_builder::build_memory_client}; +use crate::{ + cli::SearchRawArgs, embedding::ensure_vector_dim_matches, + memory_client_builder::build_memory_client, +}; use super::CommandContext; pub async fn handle(args: SearchRawArgs, ctx: &CommandContext) -> Result<()> { let client = build_memory_client(&ctx.agent_factory, &args.memory_id).await?; let embedding = parse_embedding(&args.embedding)?; + let expected_dim = client.get_dim().await?; + ensure_vector_dim_matches(&args.memory_id, embedding.len(), expected_dim)?; let mut results = client.search(embedding).await?; results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)); diff --git a/rust/embedding.rs b/rust/embedding.rs index 4331903..a5baade 100644 --- a/rust/embedding.rs +++ b/rust/embedding.rs @@ -1,10 +1,21 @@ +//! Where: shared embedding facade used by CLI, TUI, MCP, and Python bindings. +//! What: routes query embeddings and late chunking to either the remote API or a local model. +//!
Why: old API-backed memories must remain usable while local models stay opt-in. + use std::env; use anyhow::{Context, Result, bail}; use reqwest::Client; use serde::{Deserialize, Serialize}; -use crate::operation_timeout::embedding_request_timeout; +use crate::{ + clients::memory::MemoryClient, + embedding_config::{ + API_EMBEDDING_BACKEND_ID, configured_embedding_dimension, selected_embedding_backend_id, + }, + local_embedding, + operation_timeout::embedding_request_timeout, +}; pub(crate) const EMBEDDING_API_ENV_VAR: &str = "EMBEDDING_API_ENDPOINT"; pub(crate) const DEFAULT_EMBEDDING_API_ENDPOINT: &str = "https://api.kinic.io"; @@ -12,6 +23,73 @@ const LATE_CHUNKING_PATH: &str = "/late-chunking"; const EMBEDDING_PATH: &str = "/embedding"; pub async fn late_chunking(text: &str) -> Result<Vec<LateChunk>> { + if selected_embedding_backend_id()? == API_EMBEDDING_BACKEND_ID { + return late_chunking_remote(text).await; + } + local_embedding::late_chunk_and_embed(text).await +} + +pub async fn fetch_embedding(text: &str) -> Result<Vec<f32>> { + if selected_embedding_backend_id()? == API_EMBEDDING_BACKEND_ID { + return fetch_embedding_remote(text).await; + } + local_embedding::embed_query(text).await +} + +pub(crate) fn configured_embedding_dimension_u64() -> Result<u64> { + configured_embedding_dimension() +} + +pub(crate) async fn ensure_memory_dim_matches( + client: &MemoryClient, + memory_id: &str, + provided_dim: usize, +) -> Result<u64> { + let actual_dim = client + .get_dim() + .await + .context("Failed to load memory embedding dimension")?; + ensure_vector_dim_matches(memory_id, provided_dim, actual_dim)?; + Ok(actual_dim) +} + +pub(crate) fn ensure_vector_dim_matches( + memory_id: &str, + provided_dim: usize, + expected_dim: u64, +) -> Result<()> { + if provided_dim == expected_dim as usize { + return Ok(()); + } + bail!( + "Embedding dimension mismatch for memory {memory_id}. Provided {provided_dim}, expected {expected_dim}. Reindex or reset the memory before searching or inserting."
+ ); +} + +pub(crate) fn embedding_base_url() -> String { + env::var(EMBEDDING_API_ENV_VAR).unwrap_or_else(|_| DEFAULT_EMBEDDING_API_ENDPOINT.to_string()) +} + +pub(crate) async fn call_chat_http(prompt: &str) -> Result<String> { + let url = format!("{}/chat", embedding_base_url()); + let response = Client::new() + .post(url) + .json(&ChatRequest { message: prompt }) + .send() + .await + .context("Failed to call chat endpoint")?; + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + bail!("chat endpoint returned {status}: {body}"); + } + response + .text() + .await + .context("Failed to read chat response") +} + +async fn late_chunking_remote(text: &str) -> Result<Vec<LateChunk>> { let url = format!("{}{}", embedding_base_url(), LATE_CHUNKING_PATH); let timeout = embedding_request_timeout(text.len()); let response = Client::new() @@ -30,7 +108,7 @@ pub async fn late_chunking(text: &str) -> Result<Vec<LateChunk>> { Ok(payload.chunks) } -pub async fn fetch_embedding(text: &str) -> Result<Vec<f32>> { +async fn fetch_embedding_remote(text: &str) -> Result<Vec<f32>> { let url = format!("{}{}", embedding_base_url(), EMBEDDING_PATH); let timeout = embedding_request_timeout(text.len()); let response = Client::new() @@ -59,10 +137,6 @@ async fn ensure_success(response: reqwest::Response) -> Result String { - env::var(EMBEDDING_API_ENV_VAR).unwrap_or_else(|_| DEFAULT_EMBEDDING_API_ENDPOINT.to_string()) -} - #[derive(Serialize)] struct LateChunkingRequest<'a> { markdown: &'a str, @@ -79,6 +153,11 @@ pub struct LateChunk { pub sentence: String, } +#[derive(Serialize)] +struct ChatRequest<'a> { + message: &'a str, +} + #[derive(Serialize)] struct EmbeddingRequest<'a> { content: &'a str, diff --git a/rust/embedding_config.rs b/rust/embedding_config.rs new file mode 100644 index 0000000..fd84cfb --- /dev/null +++ b/rust/embedding_config.rs @@ -0,0 +1,317 @@ +//! Where: shared by embedding routing, local inference, memory creation defaults, and dimension guards. +//! What: centralizes the selected embedding backend plus local model cache and chunking parameters. +//! Why: create/search/insert must resolve the same backend and dimension without drift.
+ +use std::{env, path::PathBuf}; + +use anyhow::{Context, Result, anyhow, bail}; +use fastembed::{EmbeddingModel, TextInitOptions}; +use tui_kit_host::settings::SettingsError; + +use crate::preferences; + +const CACHE_DIR_ENV_VAR: &str = "KINIC_LOCAL_EMBEDDING_CACHE_DIR"; +const MAX_LENGTH_ENV_VAR: &str = "KINIC_LOCAL_EMBEDDING_MAX_LENGTH"; +const CHUNK_SOFT_LIMIT_ENV_VAR: &str = "KINIC_LOCAL_EMBEDDING_CHUNK_SOFT_LIMIT"; +const CHUNK_HARD_LIMIT_ENV_VAR: &str = "KINIC_LOCAL_EMBEDDING_CHUNK_HARD_LIMIT"; +const CHUNK_OVERLAP_ENV_VAR: &str = "KINIC_LOCAL_EMBEDDING_CHUNK_OVERLAP"; +const DEFAULT_CACHE_DIR: &str = ".cache/kinic-cli/embeddings"; +const DEFAULT_MAX_LENGTH: usize = 512; +const DEFAULT_CHUNK_SOFT_LIMIT: usize = 800; +const DEFAULT_CHUNK_HARD_LIMIT: usize = 1200; +const DEFAULT_CHUNK_OVERLAP: usize = 120; +pub(crate) const API_EMBEDDING_BACKEND_ID: &str = "api"; +const API_EMBEDDING_BACKEND_LABEL: &str = "API (remote default)"; +const API_EMBEDDING_DIMENSION: usize = 1024; + +#[derive(Debug, Clone, PartialEq, Eq)] +struct ModelSpec { + id: &'static str, + label: &'static str, + model: EmbeddingModel, + dimension: usize, +} + +const SNOWFLAKE_MODEL: ModelSpec = ModelSpec { + id: "Snowflake/snowflake-arctic-embed-s", + label: "Snowflake Arctic Embed S", + model: EmbeddingModel::SnowflakeArcticEmbedS, + dimension: 384, +}; + +const SUPPORTED_MODELS: [ModelSpec; 1] = [SNOWFLAKE_MODEL]; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct SupportedEmbeddingBackend { + pub id: &'static str, + pub label: &'static str, + pub dimension: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ChunkingConfig { + pub soft_limit: usize, + pub hard_limit: usize, + pub overlap: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct LocalEmbeddingConfig { + pub model_id: &'static str, + pub dimension: usize, + pub cache_dir: PathBuf, + pub max_length: usize, + pub chunking: ChunkingConfig, + model: EmbeddingModel, +} + +impl LocalEmbeddingConfig { + pub(crate) fn for_model_id(model_id: &str) -> Result<Self> { + let spec = parse_model_spec(model_id)?; + let max_length = env_usize(MAX_LENGTH_ENV_VAR, DEFAULT_MAX_LENGTH)?; + let soft_limit = env_usize(CHUNK_SOFT_LIMIT_ENV_VAR, DEFAULT_CHUNK_SOFT_LIMIT)?; + let hard_limit = env_usize(CHUNK_HARD_LIMIT_ENV_VAR, DEFAULT_CHUNK_HARD_LIMIT)?; + let overlap = env_usize(CHUNK_OVERLAP_ENV_VAR, DEFAULT_CHUNK_OVERLAP)?; + if soft_limit == 0 || hard_limit == 0 { + bail!("Chunk limits must be positive."); + } + if soft_limit > hard_limit { + bail!("Chunk soft limit cannot exceed hard limit."); + } + if overlap >= hard_limit { + bail!("Chunk overlap must be smaller than hard limit."); + } + + Ok(Self { + model_id: spec.id, + dimension: spec.dimension, + cache_dir: cache_dir()?, + max_length, + chunking: ChunkingConfig { + soft_limit, + hard_limit, + overlap, + }, + model: spec.model, + }) + } + + pub(crate) fn text_init_options(&self) -> TextInitOptions { + TextInitOptions::new(self.model.clone()) + .with_cache_dir(self.cache_dir.clone()) + .with_max_length(self.max_length) + .with_show_download_progress(false) + } +} + +pub(crate) fn selected_embedding_backend_id() -> Result<&'static str> { + resolve_embedding_backend_id() +} + +pub(crate) fn configured_embedding_dimension() -> Result<u64> { + let backend_id = resolve_embedding_backend_id()?; + if backend_id == API_EMBEDDING_BACKEND_ID { + return Ok(API_EMBEDDING_DIMENSION as u64); + } + Ok(LocalEmbeddingConfig::for_model_id(backend_id)?.dimension as u64) +} + +pub(crate) fn
selected_local_embedding_config() -> Result<Option<LocalEmbeddingConfig>> { + let backend_id = resolve_embedding_backend_id()?; + if backend_id == API_EMBEDDING_BACKEND_ID { + return Ok(None); + } + LocalEmbeddingConfig::for_model_id(backend_id).map(Some) +} + +fn resolve_embedding_backend_id() -> Result<&'static str> { + let preferences = load_embedding_preferences()?; + Ok(normalize_supported_embedding_backend_id( + &preferences.embedding_model_id, + )) +} + +fn load_embedding_preferences() -> Result<UserPreferences> { + preferences::load_user_preferences().map_err(|error| match error { + SettingsError::NoConfigDir => anyhow!( + "Embedding backend could not be resolved because the shared settings directory is unavailable. Kinic requires a writable config directory for shared preferences." + ), + other => anyhow!(other) + .context("Failed to load shared embedding backend from tui.yaml"), + }) +} + +pub(crate) fn supported_embedding_backends() -> Vec<SupportedEmbeddingBackend> { + std::iter::once(SupportedEmbeddingBackend { + id: API_EMBEDDING_BACKEND_ID, + label: API_EMBEDDING_BACKEND_LABEL, + dimension: API_EMBEDDING_DIMENSION, + }) + .chain( + SUPPORTED_MODELS + .iter() + .map(|spec| SupportedEmbeddingBackend { + id: spec.id, + label: spec.label, + dimension: spec.dimension, + }), + ) + .collect() +} + +pub(crate) fn normalize_supported_embedding_backend_id(raw: &str) -> &'static str { + let trimmed = raw.trim(); + if trimmed == API_EMBEDDING_BACKEND_ID { + return API_EMBEDDING_BACKEND_ID; + } + parse_model_spec(trimmed) + .map(|spec| spec.id) + .unwrap_or(API_EMBEDDING_BACKEND_ID) +} + +fn parse_model_spec(raw: &str) -> Result<ModelSpec> { + let trimmed = raw.trim(); + SUPPORTED_MODELS + .iter() + .find(|spec| spec.id == trimmed) + .cloned() + .ok_or_else(|| { + anyhow!( + "Embedding backend must be one of: {}", + supported_embedding_backends() + .iter() + .map(|spec| spec.id) + .collect::<Vec<_>>() + .join(", ") + ) + }) +} + +fn cache_dir() -> Result<PathBuf> { + if let Ok(value) = env::var(CACHE_DIR_ENV_VAR) { + let trimmed = value.trim(); + if trimmed.is_empty() { + bail!("{CACHE_DIR_ENV_VAR} cannot be blank."); + } + return Ok(PathBuf::from(trimmed)); + } + + let home = env::var("HOME").context("HOME is not set")?; + Ok(PathBuf::from(home).join(DEFAULT_CACHE_DIR)) +} + +fn env_usize(name: &str, default: usize) -> Result<usize> { + match env::var(name) { + Ok(raw) => raw + .trim() + .parse::<usize>() + .with_context(|| format!("{name} must be a positive integer")), + Err(_) => Ok(default), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::preferences; + use std::sync::{Mutex, OnceLock}; + + fn env_guard() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock<Mutex<()>> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())).lock().unwrap() + } + + fn reset_test_preference_error() { + preferences::set_load_user_preferences_error_for_tests(false); + } + + #[test] + fn configured_dimension_defaults_to_api() { + let _guard = env_guard(); + reset_test_preference_error(); + unsafe { + env::remove_var(CACHE_DIR_ENV_VAR); + env::remove_var(MAX_LENGTH_ENV_VAR); + env::remove_var(CHUNK_SOFT_LIMIT_ENV_VAR); + env::remove_var(CHUNK_HARD_LIMIT_ENV_VAR); + env::remove_var(CHUNK_OVERLAP_ENV_VAR); + } + let config = configured_embedding_dimension().expect("api dimension should load"); + + assert_eq!(config, 1024); + } + + #[test] + fn local_model_config_uses_snowflake() { + let _guard = env_guard(); + reset_test_preference_error(); + unsafe { + env::remove_var(CACHE_DIR_ENV_VAR); + env::remove_var(MAX_LENGTH_ENV_VAR); + env::remove_var(CHUNK_SOFT_LIMIT_ENV_VAR);
env::remove_var(CHUNK_HARD_LIMIT_ENV_VAR); + env::remove_var(CHUNK_OVERLAP_ENV_VAR); + } + let config = LocalEmbeddingConfig::for_model_id("Snowflake/snowflake-arctic-embed-s") + .expect("local config should load"); + + assert_eq!(config.model_id, "Snowflake/snowflake-arctic-embed-s"); + assert_eq!(config.dimension, 384); + assert_eq!(config.max_length, 512); + assert_eq!(config.chunking.soft_limit, 800); + } + + #[test] + fn invalid_chunk_bounds_are_rejected() { + let _guard = env_guard(); + reset_test_preference_error(); + unsafe { + env::set_var(CHUNK_SOFT_LIMIT_ENV_VAR, "1300"); + env::set_var(CHUNK_HARD_LIMIT_ENV_VAR, "1200"); + } + + let error = LocalEmbeddingConfig::for_model_id("Snowflake/snowflake-arctic-embed-s") + .expect_err("invalid bounds should fail"); + assert!(error.to_string().contains("soft limit")); + + unsafe { + env::remove_var(CHUNK_SOFT_LIMIT_ENV_VAR); + env::remove_var(CHUNK_HARD_LIMIT_ENV_VAR); + } + } + + #[test] + fn unsupported_model_is_rejected() { + reset_test_preference_error(); + let error = parse_model_spec("bad-model").expect_err("bad model should fail"); + assert!( + error + .to_string() + .contains("Embedding backend must be one of") + ); + } + + #[test] + fn unsupported_backend_normalizes_to_api() { + reset_test_preference_error(); + assert_eq!( + normalize_supported_embedding_backend_id("bad-model"), + API_EMBEDDING_BACKEND_ID + ); + } + + #[test] + fn configured_dimension_errors_when_preferences_fail_to_load() { + let _guard = env_guard(); + preferences::set_load_user_preferences_error_for_tests(true); + + let error = configured_embedding_dimension().expect_err("load failure should reach caller"); + + assert!( + error + .to_string() + .contains("shared settings directory is unavailable") + ); + reset_test_preference_error(); + } +} diff --git a/rust/insert_service.rs b/rust/insert_service.rs index f111449..7dcde0e 100644 --- a/rust/insert_service.rs +++ b/rust/insert_service.rs @@ -13,7 +13,9 @@ use ic_agent::export::Principal; use serde_json::json; use crate::{ - clients::memory::MemoryClient, commands::convert_pdf::pdf_to_markdown, embedding::late_chunking, + clients::memory::MemoryClient, + commands::convert_pdf::pdf_to_markdown, + embedding::{ensure_vector_dim_matches, late_chunking}, }; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -87,6 +89,7 @@ pub async fn execute_insert_request( validate_insert_request_fields(request)?; let validated = validate_and_transform_insert_request(request)?; let prepared = prepare_insert_request(&validated).await?; + ensure_prepared_items_match_memory(client, request.memory_id(), &prepared).await?; let inserted_count = prepared.len(); let source_name = validated.source_name(); @@ -149,6 +152,21 @@ pub fn validate_insert_request_fields(request: &InsertRequest) -> Result<()> { Ok(()) } +async fn ensure_prepared_items_match_memory( + client: &MemoryClient, + memory_id: &str, + items: &[PreparedInsertItem], +) -> Result<()> { + let Some(first) = items.first() else { + bail!("Insert content did not produce any chunks."); + }; + let expected_dim = client + .get_dim() + .await + .context("Failed to load memory embedding dimension")?; + ensure_vector_dim_matches(memory_id, first.embedding.len(), expected_dim) +} + pub fn validate_insert_request_for_submit(request: &InsertRequest) -> Result<()> { let _ = validate_and_transform_insert_request(request)?; Ok(()) @@ -458,6 +476,11 @@ mod tests { assert_eq!(payload, "{\"sentence\":\"hello\",\"tag\":\"docs\"}"); } + #[test] + fn prepared_items_match_expected_dimension() { + 
ensure_vector_dim_matches("aaaaa-aa", 2, 2).unwrap(); + } + #[test] fn validated_insert_request_source_name_uses_file_name() { let request = ValidatedInsertRequest::Normal { diff --git a/rust/lib.rs b/rust/lib.rs index da9fecd..32a9232 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -6,9 +6,12 @@ pub(crate) mod clients; mod commands; pub(crate) mod create_domain; mod embedding; +mod embedding_config; pub(crate) mod identity_store; pub(crate) mod insert_service; mod ledger; +mod local_chunking; +mod local_embedding; pub(crate) mod memory_client_builder; mod operation_timeout; pub(crate) mod preferences; diff --git a/rust/local_chunking.rs b/rust/local_chunking.rs new file mode 100644 index 0000000..6b6154a --- /dev/null +++ b/rust/local_chunking.rs @@ -0,0 +1,131 @@ +//! Where: local chunk preparation for insert and PDF ingestion. +//! What: splits markdown/text into bounded chunks before passage embedding. +//! Why: local backends need an in-process late-chunking path without changing insert callers. + +use crate::embedding_config::ChunkingConfig; + +pub(crate) fn chunk_markdown(markdown: &str, config: &ChunkingConfig) -> Vec { + let normalized = markdown.replace("\r\n", "\n"); + let blocks = normalized + .split("\n\n") + .map(str::trim) + .filter(|block| !block.is_empty()) + .flat_map(|block| chunk_block(block, config)) + .collect::>(); + if blocks.is_empty() { + return Vec::new(); + } + blocks +} + +fn chunk_block(block: &str, config: &ChunkingConfig) -> Vec { + if block.chars().count() <= config.hard_limit { + return vec![block.to_string()]; + } + + let sentences = split_sentences(block); + if sentences.len() <= 1 { + return split_long_text(block, config); + } + + let mut chunks = Vec::new(); + let mut current = String::new(); + for sentence in sentences { + if sentence.chars().count() > config.hard_limit { + if !current.is_empty() { + chunks.push(current.trim().to_string()); + current.clear(); + } + chunks.extend(split_long_text(&sentence, config)); + continue; + } + + let separator = if current.is_empty() { "" } else { " " }; + if current.chars().count() + separator.len() + sentence.chars().count() > config.soft_limit + && !current.is_empty() + { + chunks.push(current.trim().to_string()); + current.clear(); + } + + if !current.is_empty() { + current.push(' '); + } + current.push_str(sentence.trim()); + } + + if !current.is_empty() { + chunks.push(current.trim().to_string()); + } + chunks +} + +fn split_sentences(block: &str) -> Vec { + let mut pieces = Vec::new(); + let mut current = String::new(); + for ch in block.chars() { + current.push(ch); + if matches!(ch, '.' | '!' | '?' 
| '\n') { + let trimmed = current.trim(); + if !trimmed.is_empty() { + pieces.push(trimmed.to_string()); + } + current.clear(); + } + } + if !current.trim().is_empty() { + pieces.push(current.trim().to_string()); + } + pieces +} + +fn split_long_text(text: &str, config: &ChunkingConfig) -> Vec<String> { + let chars = text.chars().collect::<Vec<char>>(); + let mut chunks = Vec::new(); + let mut start = 0usize; + while start < chars.len() { + let end = usize::min(start + config.hard_limit, chars.len()); + let slice = chars[start..end] + .iter() + .collect::<String>() + .trim() + .to_string(); + if !slice.is_empty() { + chunks.push(slice); + } + if end == chars.len() { + break; + } + start = end.saturating_sub(config.overlap); + } + chunks +} + +#[cfg(test)] +mod tests { + use super::*; + + fn config() -> ChunkingConfig { + ChunkingConfig { + soft_limit: 16, + hard_limit: 24, + overlap: 4, + } + } + + #[test] + fn chunk_markdown_splits_paragraphs_and_long_blocks() { + let chunks = chunk_markdown( + "First sentence. Second sentence.\n\nThis block is definitely long enough to split.", + &config(), + ); + + assert!(chunks.len() >= 3); + assert!(chunks.iter().all(|chunk| !chunk.trim().is_empty())); + } + + #[test] + fn chunk_markdown_ignores_empty_input() { + assert!(chunk_markdown(" \n\n ", &config()).is_empty()); + } +} diff --git a/rust/local_embedding.rs b/rust/local_embedding.rs new file mode 100644 index 0000000..1d0fe07 --- /dev/null +++ b/rust/local_embedding.rs @@ -0,0 +1,125 @@ +//! Where: local embedding inference shared by CLI, TUI, tools, and Python bindings. +//! What: lazily initializes the selected local text embedding model and embeds queries/passages. +//! Why: local backends are opt-in and the selected model can change at runtime via preferences. + +use std::sync::{Mutex, OnceLock}; + +use anyhow::{Context, Result, bail}; +use fastembed::TextEmbedding; + +use crate::{ + embedding::LateChunk, + embedding_config::{LocalEmbeddingConfig, selected_local_embedding_config}, + local_chunking::chunk_markdown, +}; + +const SNOWFLAKE_QUERY_PREFIX: &str = "Represent this sentence for searching relevant passages: "; + +struct CachedModel { + model_id: String, + model: TextEmbedding, +} + +static MODEL: OnceLock<Mutex<Option<CachedModel>>> = OnceLock::new(); + +pub(crate) async fn embed_query(text: &str) -> Result<Vec<f32>> { + embed_texts(vec![snowflake_query_text(text)]) + .await + .map(|mut rows| { + rows.pop() + .expect("one query input should always produce one embedding") + }) +} + +pub(crate) async fn late_chunk_and_embed(markdown: &str) -> Result<Vec<LateChunk>> { + let config = load_selected_local_config()?; + let chunks = chunk_markdown(markdown, &config.chunking); + if chunks.is_empty() { + bail!("Insert content is empty after normalization."); + } + let embeddings = embed_texts(chunks.iter().map(|chunk| chunk.to_string()).collect()).await?; + Ok(chunks + .into_iter() + .zip(embeddings) + .map(|(sentence, embedding)| LateChunk { + embedding, + sentence, + }) + .collect()) +} + +async fn embed_texts(inputs: Vec<String>) -> Result<Vec<Vec<f32>>> { + tokio::task::spawn_blocking(move || { + let config = load_selected_local_config()?; + let mut guard = model_cache() + .lock() + .map_err(|_| anyhow::anyhow!("Embedding model lock poisoned"))?; + let cached = ensure_cached_model(&mut guard, &config)?; + cached + .model + .embed(inputs, None) + .context("Failed to generate local embeddings") + }) + .await + .context("Local embedding worker crashed")? +} + +fn load_selected_local_config() -> Result<LocalEmbeddingConfig> { + selected_local_embedding_config()?
+ .ok_or_else(|| anyhow::anyhow!("Local embedding backend is not selected")) +} + +fn snowflake_query_text(text: &str) -> String { + format!("{SNOWFLAKE_QUERY_PREFIX}{}", text.trim()) +} + +fn model_cache() -> &'static Mutex<Option<CachedModel>> { + MODEL.get_or_init(|| Mutex::new(None)) +} + +fn ensure_cached_model<'a>( + cache: &'a mut Option<CachedModel>, + config: &LocalEmbeddingConfig, +) -> Result<&'a mut CachedModel> { + let needs_reload = cache + .as_ref() + .is_none_or(|cached| cached.model_id != config.model_id); + if needs_reload { + std::fs::create_dir_all(&config.cache_dir).with_context(|| { + format!( + "Failed to create embedding cache dir {}", + config.cache_dir.display() + ) + })?; + let model = TextEmbedding::try_new(config.text_init_options())?; + *cache = Some(CachedModel { + model_id: config.model_id.to_string(), + model, + }); + } + Ok(cache + .as_mut() + .expect("embedding cache should exist after initialization")) +} + +#[cfg(test)] +mod tests { + use super::snowflake_query_text; + + #[test] + fn local_embedding_supports_snowflake_model_id() { + let config = crate::embedding_config::LocalEmbeddingConfig::for_model_id( + "Snowflake/snowflake-arctic-embed-s", + ) + .expect("snowflake config should load"); + assert_eq!(config.dimension, 384); + } + + #[test] + fn snowflake_query_text_uses_retrieval_prefix() { + assert_eq!( + snowflake_query_text("hello"), + "Represent this sentence for searching relevant passages: hello" + ); + } +} diff --git a/rust/preferences.rs b/rust/preferences.rs index 27a0f7b..e2ee675 100644 --- a/rust/preferences.rs +++ b/rust/preferences.rs @@ -5,6 +5,8 @@ use kinic_core::{prefs_policy, principal::normalize_memory_id_text, tag}; use serde::{Deserialize, Serialize}; +#[cfg(test)] +use std::cell::Cell; use tui_kit_host::settings::SettingsError; #[cfg(not(test))] use tui_kit_host::settings::{load_yaml_or_default, save_yaml}; @@ -17,6 +19,9 @@ pub use kinic_core::prefs_policy::{ DEFAULT_CHAT_MMR_LAMBDA, DEFAULT_CHAT_OVERALL_TOP_K, DEFAULT_CHAT_PER_MEMORY_CAP, }; +const DEFAULT_EMBEDDING_MODEL_ID: &str = "api"; +const SUPPORTED_EMBEDDING_MODEL_IDS: &[&str] = &["api", "Snowflake/snowflake-arctic-embed-s"]; + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] // Missing legacy fields are backfilled with defaults and unknown fields are ignored. // Type mismatches and malformed YAML still fail to decode so broken settings stay explicit.
@@ -32,6 +37,8 @@ pub struct UserPreferences { pub chat_per_memory_cap: usize, #[serde(default = "default_chat_mmr_lambda")] pub chat_mmr_lambda: u8, + #[serde(default = "default_embedding_model_id")] + pub embedding_model_id: String, } impl Default for UserPreferences { @@ -43,6 +50,7 @@ impl Default for UserPreferences { chat_overall_top_k: DEFAULT_CHAT_OVERALL_TOP_K, chat_per_memory_cap: DEFAULT_CHAT_PER_MEMORY_CAP, chat_mmr_lambda: DEFAULT_CHAT_MMR_LAMBDA, + embedding_model_id: default_embedding_model_id(), } } } @@ -59,8 +67,15 @@ pub fn default_chat_mmr_lambda() -> u8 { DEFAULT_CHAT_MMR_LAMBDA } +pub fn default_embedding_model_id() -> String { + DEFAULT_EMBEDDING_MODEL_ID.to_string() +} + #[cfg(test)] pub fn load_user_preferences() -> Result { + if let Some(error) = test_load_error() { + return Err(error); + } Ok(normalize_user_preferences(UserPreferences::default())) } @@ -75,6 +90,25 @@ pub fn save_user_preferences(_preferences: &UserPreferences) -> Result<(), Setti Ok(()) } +#[cfg(test)] +pub fn set_load_user_preferences_error_for_tests(enabled: bool) { + TEST_LOAD_ERROR.with(|slot| slot.set(enabled)); +} + +#[cfg(test)] +fn test_load_error() -> Option { + if TEST_LOAD_ERROR.with(Cell::get) { + Some(SettingsError::NoConfigDir) + } else { + None + } +} + +#[cfg(test)] +thread_local! { + static TEST_LOAD_ERROR: Cell = const { Cell::new(false) }; +} + #[cfg(not(test))] pub fn save_user_preferences(preferences: &UserPreferences) -> Result<(), SettingsError> { save_yaml( @@ -97,6 +131,7 @@ pub fn normalize_user_preferences(mut preferences: UserPreferences) -> UserPrefe preferences.chat_per_memory_cap = normalize_chat_per_memory_cap(preferences.chat_per_memory_cap); preferences.chat_mmr_lambda = normalize_chat_mmr_lambda(preferences.chat_mmr_lambda); + preferences.embedding_model_id = normalize_embedding_model_id(preferences.embedding_model_id); preferences } @@ -124,6 +159,14 @@ pub fn chat_diversity_display(value: u8) -> String { format!("{:.2}", f32::from(value) / 100.0) } +pub fn normalize_embedding_model_id(value: String) -> String { + let trimmed = value.trim(); + if trimmed.is_empty() || !SUPPORTED_EMBEDDING_MODEL_IDS.contains(&trimmed) { + return default_embedding_model_id(); + } + trimmed.to_string() +} + fn normalize_default_memory_id(memory_id: Option) -> Option { memory_id.and_then(|value| normalize_memory_id_text(&value).ok()) } @@ -176,6 +219,7 @@ future_setting: true assert_eq!(normalized.chat_overall_top_k, DEFAULT_CHAT_OVERALL_TOP_K); assert_eq!(normalized.chat_per_memory_cap, DEFAULT_CHAT_PER_MEMORY_CAP); assert_eq!(normalized.chat_mmr_lambda, DEFAULT_CHAT_MMR_LAMBDA); + assert_eq!(normalized.embedding_model_id, DEFAULT_EMBEDDING_MODEL_ID); } #[test] @@ -227,6 +271,20 @@ manual_memory_ids: assert_eq!(normalized.chat_overall_top_k, DEFAULT_CHAT_OVERALL_TOP_K); assert_eq!(normalized.chat_per_memory_cap, DEFAULT_CHAT_PER_MEMORY_CAP); assert_eq!(normalized.chat_mmr_lambda, DEFAULT_CHAT_MMR_LAMBDA); + assert_eq!(normalized.embedding_model_id, DEFAULT_EMBEDDING_MODEL_ID); + } + + #[test] + fn user_preferences_normalizes_blank_embedding_model_id() { + let preferences: UserPreferences = serde_yaml::from_str( + r#" +embedding_model_id: " " +"#, + ) + .expect("blank embedding model should deserialize"); + + let normalized = normalize_user_preferences(preferences); + assert_eq!(normalized.embedding_model_id, DEFAULT_EMBEDDING_MODEL_ID); } #[test] diff --git a/rust/tools/service.rs b/rust/tools/service.rs index aee1a8b..a543b64 100644 --- a/rust/tools/service.rs +++ 
b/rust/tools/service.rs @@ -16,7 +16,7 @@ use crate::{ search::{search_across_memories, searchable_memory_ids}, show::{ShowOutput, load_show_output}, }, - embedding::fetch_embedding, + embedding::{ensure_memory_dim_matches, fetch_embedding}, insert_service::{InsertRequest, execute_insert_request}, memory_client_builder::build_memory_client, }; @@ -198,9 +198,10 @@ impl ToolService { async fn search_memory( client: &MemoryClient, - _memory_id: &str, + memory_id: &str, embedding: Vec, ) -> anyhow::Result> { + ensure_memory_dim_matches(client, memory_id, embedding.len()).await?; let rows = client .search(embedding) .await diff --git a/rust/tui/bridge.rs b/rust/tui/bridge.rs index 156b48b..4bed515 100644 --- a/rust/tui/bridge.rs +++ b/rust/tui/bridge.rs @@ -7,7 +7,7 @@ use crate::{ memory::MemoryClient, }, create_domain::{BalanceDelta, balance_delta, required_balance}, - embedding::embedding_base_url, + embedding::{embedding_base_url, ensure_memory_dim_matches}, insert_service::{InsertRequest, execute_insert_request}, ledger::{fetch_balance, fetch_fee, transfer}, shared::{ @@ -355,6 +355,7 @@ pub async fn search_memory_with_agent( ) -> Result> { let memory = Principal::from_text(&memory_id).context("Failed to parse memory canister id")?; let client = MemoryClient::new(agent, memory); + ensure_memory_dim_matches(&client, &memory_id, embedding.len()).await?; let mut results = client.search(embedding).await?; results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal)); diff --git a/rust/tui/provider/mod.rs b/rust/tui/provider/mod.rs index 457778b..6b42a74 100644 --- a/rust/tui/provider/mod.rs +++ b/rust/tui/provider/mod.rs @@ -15,6 +15,7 @@ use crate::{ agent::{KeychainErrorCode, extract_keychain_error_code}, create_domain::derive_create_cost, embedding::fetch_embedding, + embedding_config::{normalize_supported_embedding_backend_id, supported_embedding_backends}, insert_service::{ InsertRequest, parse_embedding_json, validate_insert_request_fields, validate_insert_request_for_submit, @@ -368,6 +369,7 @@ fn add_action_label_for_context(context: PickerContext) -> Option<&'static str> PickerContext::DefaultMemory | PickerContext::InsertTarget | PickerContext::AddTag + | PickerContext::EmbeddingModel | PickerContext::ChatResultLimit | PickerContext::ChatPerMemoryLimit | PickerContext::ChatDiversity => None, @@ -388,6 +390,7 @@ fn picker_selected_id_for_context( (!insert_tag.is_empty()).then(|| insert_tag.to_string()) } PickerContext::TagManagement | PickerContext::AddTag => None, + PickerContext::EmbeddingModel => Some(user_preferences.embedding_model_id.clone()), PickerContext::ChatResultLimit => Some(user_preferences.chat_overall_top_k.to_string()), PickerContext::ChatPerMemoryLimit => Some(user_preferences.chat_per_memory_cap.to_string()), PickerContext::ChatDiversity => Some(user_preferences.chat_mmr_lambda.to_string()), @@ -442,6 +445,16 @@ fn picker_items_for_context( .unwrap_or_default(), _ => Vec::new(), }, + PickerContext::EmbeddingModel => supported_embedding_backends() + .iter() + .map(|backend| { + PickerItem::option( + backend.id.to_string(), + format!("{} ({})", backend.label, backend.dimension), + user_preferences.embedding_model_id == backend.id, + ) + }) + .collect(), PickerContext::ChatResultLimit => prefs_policy::chat_result_limit_options() .iter() .map(|value| { @@ -513,6 +526,7 @@ impl<'a> DefaultMemoryController<'a> { chat_overall_top_k: self.user_preferences.chat_overall_top_k, chat_per_memory_cap: self.user_preferences.chat_per_memory_cap, chat_mmr_lambda: 
self.user_preferences.chat_mmr_lambda, + embedding_model_id: self.user_preferences.embedding_model_id.clone(), }; #[cfg(test)] let _settings_io_lock = settings_io_lock() @@ -2523,6 +2537,9 @@ impl KinicProvider { CoreEffect::Notify(format!("Selected tag {} for insert", item.id)), ], PickerContext::AddTag => Vec::new(), + PickerContext::EmbeddingModel => { + vec![self.set_embedding_model_preference(item.id.as_str())] + } PickerContext::ChatResultLimit => item .id .parse::() @@ -2790,6 +2807,21 @@ impl KinicProvider { } } + fn set_embedding_model_preference(&mut self, model_id: &str) -> CoreEffect { + let normalized = normalize_supported_embedding_backend_id(model_id); + if self.user_preferences.embedding_model_id == normalized { + return CoreEffect::Notify(format!("Embedding backend already set to {normalized}")); + } + match self.update_user_preferences(|preferences| { + preferences.embedding_model_id = normalized.to_string(); + }) { + Ok(()) => CoreEffect::Notify(format!( + "Embedding backend set to {normalized}. API preserves old memories. Local backends may need reindex." + )), + Err(error) => CoreEffect::Notify(format!("Embedding backend save failed: {error}")), + } + } + fn update_user_preferences(&mut self, update: F) -> Result<(), String> where F: FnOnce(&mut UserPreferences), @@ -4998,6 +5030,7 @@ impl DataProvider for KinicProvider { | PickerContext::InsertTarget | PickerContext::InsertTag | PickerContext::TagManagement + | PickerContext::EmbeddingModel | PickerContext::ChatResultLimit | PickerContext::ChatPerMemoryLimit | PickerContext::ChatDiversity diff --git a/rust/tui/settings.rs b/rust/tui/settings.rs index ad1fb94..18af5b2 100644 --- a/rust/tui/settings.rs +++ b/rust/tui/settings.rs @@ -12,10 +12,12 @@ use tui_kit_host::settings::SettingsError; use tui_kit_host::settings::{load_yaml_or_default, save_yaml}; use tui_kit_runtime::{ SETTINGS_ENTRY_CHAT_DIVERSITY_ID, SETTINGS_ENTRY_CHAT_PER_MEMORY_LIMIT_ID, - SETTINGS_ENTRY_CHAT_RESULT_LIMIT_ID, SETTINGS_ENTRY_DEFAULT_MEMORY_ID, SessionAccountOverview, - SessionSettingsSnapshot, SettingsEntry, SettingsSection, SettingsSnapshot, + SETTINGS_ENTRY_CHAT_RESULT_LIMIT_ID, SETTINGS_ENTRY_DEFAULT_MEMORY_ID, + SETTINGS_ENTRY_EMBEDDING_MODEL_ID, SessionAccountOverview, SessionSettingsSnapshot, + SettingsEntry, SettingsSection, SettingsSnapshot, }; +use crate::embedding_config::supported_embedding_backends; use crate::preferences::{ UserPreferences, chat_diversity_display, chat_per_memory_limit_display, chat_result_limit_display, @@ -28,6 +30,7 @@ const APP_NAMESPACE: &str = "kinic"; const CHAT_HISTORY_FILE_NAME: &str = "chat-threads.yaml"; const UNAVAILABLE: &str = "unavailable"; const NOT_SET: &str = "not set"; +const EMBEDDING_MODEL_NOTE: &str = "Affects create/search/insert. API keeps existing memories usable. Local backends may require reindex. 
Same-dimension model mismatches are not detectable."; const CHAT_HISTORY_MAX_MESSAGES: usize = 40; const CHAT_MESSAGE_MAX_CONTENT_LEN: usize = 4096; @@ -457,6 +460,7 @@ pub fn build_settings_snapshot( let default_memory_display = default_memory_display(preferences, selector_items, selector_labels); let saved_tags_display = saved_tags_display(preferences); + let embedding_model_display = embedding_model_display(preferences); let chat_result_limit_display = chat_result_limit_display(preferences.chat_overall_top_k); let chat_per_memory_limit_display = chat_per_memory_limit_display(preferences.chat_per_memory_cap); @@ -490,7 +494,7 @@ pub fn build_settings_snapshot( }, SettingsEntry { id: "embedding_api_endpoint".to_string(), - label: "Embedding".to_string(), + label: "Chat API".to_string(), value: session.embedding_api_endpoint.clone(), note: None, }, @@ -511,6 +515,12 @@ pub fn build_settings_snapshot( value: saved_tags_display, note: None, }, + SettingsEntry { + id: SETTINGS_ENTRY_EMBEDDING_MODEL_ID.to_string(), + label: "Embedding backend".to_string(), + value: embedding_model_display, + note: Some(EMBEDDING_MODEL_NOTE.to_string()), + }, SettingsEntry { id: "preferences_status".to_string(), label: "Preferences status".to_string(), @@ -567,7 +577,7 @@ pub fn build_settings_snapshot( }, SettingsEntry { id: "embedding_api_endpoint".to_string(), - label: "Embedding".to_string(), + label: "Chat API".to_string(), value: session.embedding_api_endpoint.clone(), note: None, }, @@ -719,6 +729,14 @@ fn saved_tags_display(preferences: &UserPreferences) -> String { preferences.saved_tags.join(", ") } +fn embedding_model_display(preferences: &UserPreferences) -> String { + supported_embedding_backends() + .into_iter() + .find(|model| model.id == preferences.embedding_model_id) + .map(|model| format!("{} ({})", model.label, model.dimension)) + .unwrap_or_else(|| preferences.embedding_model_id.clone()) +} + #[cfg(test)] #[path = "settings_tests.rs"] mod tests; diff --git a/rust/tui/settings_tests.rs b/rust/tui/settings_tests.rs index 090a8fa..dc3ca26 100644 --- a/rust/tui/settings_tests.rs +++ b/rust/tui/settings_tests.rs @@ -1,7 +1,9 @@ use super::*; use crate::preferences::UserPreferences; use candid::Nat; -use tui_kit_runtime::{SETTINGS_ENTRY_DEFAULT_MEMORY_ID, SessionAccountOverview}; +use tui_kit_runtime::{ + SETTINGS_ENTRY_DEFAULT_MEMORY_ID, SETTINGS_ENTRY_EMBEDDING_MODEL_ID, SessionAccountOverview, +}; fn deferred_session() -> SessionSettingsSnapshot { session_settings_snapshot( @@ -198,6 +200,24 @@ fn settings_snapshot_projects_default_memory_and_preferences_status() { status, "{name}" ); + assert_eq!( + section_entry_value( + &snapshot, + "Saved preferences", + SETTINGS_ENTRY_EMBEDDING_MODEL_ID + ), + "API (remote default) (1024)", + "{name}" + ); + assert!( + section_entry_note( + &snapshot, + "Saved preferences", + SETTINGS_ENTRY_EMBEDDING_MODEL_ID + ) + .is_some(), + "{name}" + ); assert_eq!( section_titles, vec![ @@ -271,6 +291,7 @@ fn settings_snapshot_projects_chat_retrieval_section() { chat_overall_top_k: 10, chat_per_memory_cap: 4, chat_mmr_lambda: 80, + embedding_model_id: "Snowflake/snowflake-arctic-embed-s".to_string(), ..UserPreferences::default() }, &Vec::new(), @@ -290,6 +311,14 @@ fn settings_snapshot_projects_chat_retrieval_section() { section_entry_value(&snapshot, "Chat retrieval", "chat_diversity"), "0.80" ); + assert_eq!( + section_entry_value( + &snapshot, + "Saved preferences", + SETTINGS_ENTRY_EMBEDDING_MODEL_ID + ), + "Snowflake Arctic Embed S (384)" + ); } #[test] diff 
--git a/tests/kinic_cli_prefs.rs b/tests/kinic_cli_prefs.rs index 192b629..0986789 100644 --- a/tests/kinic_cli_prefs.rs +++ b/tests/kinic_cli_prefs.rs @@ -58,6 +58,7 @@ fn prefs_show_runs_without_identity() { "chat_overall_top_k": 8, "chat_per_memory_cap": 3, "chat_mmr_lambda": 70, + "embedding_model_id": "api", }) ); } @@ -166,6 +167,7 @@ fn prefs_mutations_update_shared_yaml_and_preserve_chat_fields() { "chat_overall_top_k": 10, "chat_per_memory_cap": 4, "chat_mmr_lambda": 80, + "embedding_model_id": "api", }) ); } diff --git a/tests/kinic_cli_tui.rs b/tests/kinic_cli_tui.rs index 86a38d4..e484ee4 100644 --- a/tests/kinic_cli_tui.rs +++ b/tests/kinic_cli_tui.rs @@ -30,7 +30,7 @@ fn prefs_help_mentions_json_contract_and_examples() { assert!(output.status.success()); let stdout = String::from_utf8(output.stdout).unwrap(); assert!(stdout.contains("All prefs commands return JSON.")); - assert!(stdout.contains("show -> {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer}")); + assert!(stdout.contains("show -> {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer, \"embedding_model_id\": string}")); assert!(stdout.contains("kinic-cli prefs set-default-memory --memory-id")); } @@ -44,7 +44,7 @@ fn prefs_show_help_mentions_return_shape() { assert!(output.status.success()); let stdout = String::from_utf8(output.stdout).unwrap(); assert!(stdout.contains("Returns JSON.")); - assert!(stdout.contains("{\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer}")); + assert!(stdout.contains("{\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer, \"embedding_model_id\": string}")); } #[test] diff --git a/tui/crates/tui-kit-render/src/ui/app/screens/settings/mod.rs b/tui/crates/tui-kit-render/src/ui/app/screens/settings/mod.rs index e721936..fe9b2fe 100644 --- a/tui/crates/tui-kit-render/src/ui/app/screens/settings/mod.rs +++ b/tui/crates/tui-kit-render/src/ui/app/screens/settings/mod.rs @@ -184,6 +184,7 @@ pub(crate) fn picker_hint(context: PickerContext) -> &'static str { PickerContext::InsertTag => " Enter: choose ↑/↓: move Esc: close", PickerContext::TagManagement => " Enter: use d: delete ↑/↓: browse Esc: close", PickerContext::AddTag => " Enter: save tag Esc: close", + PickerContext::EmbeddingModel => " Enter: save ↑/↓: move Esc: close", PickerContext::ChatResultLimit | PickerContext::ChatPerMemoryLimit | PickerContext::ChatDiversity => " Enter: save ↑/↓: move Esc: close", @@ -197,6 +198,7 @@ pub(crate) fn picker_input_placeholder(context: PickerContext) -> &'static str { | PickerContext::InsertTarget | PickerContext::InsertTag | PickerContext::TagManagement + | PickerContext::EmbeddingModel | PickerContext::ChatResultLimit | PickerContext::ChatPerMemoryLimit | PickerContext::ChatDiversity => "", @@ -217,6 +219,7 @@ fn picker_context_title(context: PickerContext) -> &'static str { PickerContext::InsertTag => "Select insert tag", PickerContext::TagManagement => "Saved tags", PickerContext::AddTag => "Add tag", + PickerContext::EmbeddingModel => "Embedding 
backend", PickerContext::ChatResultLimit => "Chat result limit", PickerContext::ChatPerMemoryLimit => "Per-memory limit", PickerContext::ChatDiversity => "Chat diversity", @@ -252,6 +255,7 @@ fn picker_empty_message(context: PickerContext) -> &'static str { PickerContext::InsertTag | PickerContext::TagManagement | PickerContext::AddTag => { " No saved tags yet." } + PickerContext::EmbeddingModel => " No embedding backends available.", PickerContext::ChatResultLimit | PickerContext::ChatPerMemoryLimit | PickerContext::ChatDiversity => " No options available yet.", diff --git a/tui/crates/tui-kit-runtime/src/lib.rs b/tui/crates/tui-kit-runtime/src/lib.rs index 1d80a05..c942565 100644 --- a/tui/crates/tui-kit-runtime/src/lib.rs +++ b/tui/crates/tui-kit-runtime/src/lib.rs @@ -26,6 +26,7 @@ use tui_kit_model::{UiContextNode, UiItemContent, UiItemKind, UiItemSummary}; pub const SETTINGS_ENTRY_DEFAULT_MEMORY_ID: &str = "default_memory"; pub const SETTINGS_ENTRY_KINIC_BALANCE_ID: &str = "kinic_balance"; pub const SETTINGS_ENTRY_SAVED_TAGS_ID: &str = "saved_tags"; +pub const SETTINGS_ENTRY_EMBEDDING_MODEL_ID: &str = "embedding_model"; pub const SETTINGS_ENTRY_CHAT_RESULT_LIMIT_ID: &str = "chat_result_limit"; pub const SETTINGS_ENTRY_CHAT_PER_MEMORY_LIMIT_ID: &str = "chat_per_memory_limit"; pub const SETTINGS_ENTRY_CHAT_DIVERSITY_ID: &str = "chat_diversity"; @@ -179,6 +180,7 @@ pub enum PickerContext { InsertTag, TagManagement, AddTag, + EmbeddingModel, ChatResultLimit, ChatPerMemoryLimit, ChatDiversity, @@ -2629,6 +2631,10 @@ pub fn settings_row_behavior_for_index( Some(CoreAction::OpenPicker(PickerContext::TagManagement)), " manage saved tags ", ), + SETTINGS_ENTRY_EMBEDDING_MODEL_ID => SettingsRowBehavior::new( + Some(CoreAction::OpenPicker(PickerContext::EmbeddingModel)), + " choose embedding backend ", + ), SETTINGS_ENTRY_CHAT_RESULT_LIMIT_ID => SettingsRowBehavior::new( Some(CoreAction::OpenPicker(PickerContext::ChatResultLimit)), " adjust chat limit ", @@ -2780,6 +2786,7 @@ fn picker_selected_index( } PickerContext::TagManagement | PickerContext::AddTag + | PickerContext::EmbeddingModel | PickerContext::ChatResultLimit | PickerContext::ChatPerMemoryLimit | PickerContext::ChatDiversity => None, @@ -4017,6 +4024,38 @@ mod tests { ); } + #[test] + fn selected_settings_row_behavior_matches_embedding_model_row() { + let settings = SettingsSnapshot { + quick_entries: vec![], + sections: vec![SettingsSection { + title: "Saved preferences".to_string(), + entries: vec![SettingsEntry { + id: SETTINGS_ENTRY_EMBEDDING_MODEL_ID.to_string(), + label: "Embedding backend".to_string(), + value: "API (remote default) (1024)".to_string(), + note: None, + }], + footer: None, + }], + }; + let state = CoreState { + current_tab_id: kinic_tabs::KINIC_SETTINGS_TAB_ID.to_string(), + focus: PaneFocus::Content, + selected_index: Some(0), + settings, + ..CoreState::default() + }; + + assert_eq!( + selected_settings_row_behavior(&state), + Some(SettingsRowBehavior { + enter_action: Some(CoreAction::OpenPicker(PickerContext::EmbeddingModel)), + status_hint: " choose embedding backend ", + }) + ); + } + #[test] fn selected_settings_row_behavior_matches_chat_result_limit_row() { let settings = SettingsSnapshot { From a62e4a28f9ed33d7c0495b2f0f87dcfcb2b2e5aa Mon Sep 17 00:00:00 2001 From: hude Date: Wed, 22 Apr 2026 15:31:31 +0900 Subject: [PATCH 2/4] Ignore local app artifacts and update embedding docs --- .gitignore | 9 +++ README.md | 4 +- rust/clients/launcher.rs | 4 +- rust/commands/prefs.rs | 14 ++++ rust/embedding.rs | 
8 +-- rust/embedding_config.rs | 141 +++++++++++-------------------------- rust/local_embedding.rs | 38 ++++------ rust/preferences.rs | 13 ++-- rust/tui/settings.rs | 2 +- rust/tui/settings_tests.rs | 4 +- 10 files changed, 96 insertions(+), 141 deletions(-) diff --git a/.gitignore b/.gitignore index 3c38649..15e55f8 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,12 @@ pylate-index/ #web .next/ node_modules/ + +# local app artifacts +apps/kinic-portal/.env.local +apps/kinic-portal/.dev.vars +apps/kinic-portal/.cache/ +apps/kinic-portal/.wrangler/ +apps/kinic-portal/tsconfig.tsbuildinfo +apps/kinic-portal/workers/public-api/.dev.vars +apps/kinic-portal/workers/public-api/.wrangler/ diff --git a/README.md b/README.md index d275bba..3e0636e 100644 --- a/README.md +++ b/README.md @@ -399,13 +399,15 @@ Embedding backend initialization: - initial saved backend: `api` - API dimension: `1024` -- local option example: `Snowflake/snowflake-arctic-embed-s` +- local option example: `mixedbread-ai/mxbai-embed-large-v1` - local cache dir: `$HOME/.cache/kinic-cli/embeddings` +- new memories are created with dimension `1024` Shared settings behavior: - if the shared config directory is available and `tui.yaml` is missing, Kinic initializes with the saved default `api` - if the shared config directory is unavailable, embedding-backed commands fail explicitly instead of falling back +- local search/insert against older non-`1024` memories fails explicitly until the memory is reset or reindexed Local runtime overrides: diff --git a/rust/clients/launcher.rs b/rust/clients/launcher.rs index e260ed3..4491033 100644 --- a/rust/clients/launcher.rs +++ b/rust/clients/launcher.rs @@ -15,7 +15,7 @@ use thiserror::Error; use crate::{ clients::{LAUNCHER_CANISTER, LEDGER_CANISTER}, - embedding::configured_embedding_dimension_u64, + embedding_config::create_memory_dimension_u64, }; const APPROVAL_TTL_NS: u64 = 10 * 60 * 1_000_000_000; @@ -134,7 +134,7 @@ fn encode_deploy_args(name: &str, description: &str) -> Result> { .to_string(); Ok(candid::encode_args(( payload, - configured_embedding_dimension_u64()?, + create_memory_dimension_u64(), ))?) 
} diff --git a/rust/commands/prefs.rs b/rust/commands/prefs.rs index 1a40b2d..991f7c3 100644 --- a/rust/commands/prefs.rs +++ b/rust/commands/prefs.rs @@ -432,6 +432,20 @@ mod tests { ); } + #[test] + fn show_preferences_preserves_mxbai_embedding_model_id() { + let serialized = serde_json::to_value(ShowPreferences::from(UserPreferences { + embedding_model_id: "mixedbread-ai/mxbai-embed-large-v1".to_string(), + ..UserPreferences::default() + })) + .expect("show preferences should serialize"); + + assert_eq!( + serialized["embedding_model_id"], + "mixedbread-ai/mxbai-embed-large-v1" + ); + } + #[test] fn show_preferences_from_user_preferences_omits_chat_fields() { let serialized = serde_json::to_value(ShowPreferences::from(UserPreferences::default())) diff --git a/rust/embedding.rs b/rust/embedding.rs index a5baade..4d1e2ef 100644 --- a/rust/embedding.rs +++ b/rust/embedding.rs @@ -10,9 +10,7 @@ use serde::{Deserialize, Serialize}; use crate::{ clients::memory::MemoryClient, - embedding_config::{ - API_EMBEDDING_BACKEND_ID, configured_embedding_dimension, selected_embedding_backend_id, - }, + embedding_config::{API_EMBEDDING_BACKEND_ID, selected_embedding_backend_id}, local_embedding, operation_timeout::embedding_request_timeout, }; @@ -36,10 +34,6 @@ pub async fn fetch_embedding(text: &str) -> Result> { local_embedding::embed_query(text).await } -pub(crate) fn configured_embedding_dimension_u64() -> Result { - configured_embedding_dimension() -} - pub(crate) async fn ensure_memory_dim_matches( client: &MemoryClient, memory_id: &str, diff --git a/rust/embedding_config.rs b/rust/embedding_config.rs index fd84cfb..71f990b 100644 --- a/rust/embedding_config.rs +++ b/rust/embedding_config.rs @@ -23,23 +23,11 @@ const DEFAULT_CHUNK_OVERLAP: usize = 120; pub(crate) const API_EMBEDDING_BACKEND_ID: &str = "api"; const API_EMBEDDING_BACKEND_LABEL: &str = "API (remote default)"; const API_EMBEDDING_DIMENSION: usize = 1024; - -#[derive(Debug, Clone, PartialEq, Eq)] -struct ModelSpec { - id: &'static str, - label: &'static str, - model: EmbeddingModel, - dimension: usize, -} - -const SNOWFLAKE_MODEL: ModelSpec = ModelSpec { - id: "Snowflake/snowflake-arctic-embed-s", - label: "Snowflake Arctic Embed S", - model: EmbeddingModel::SnowflakeArcticEmbedS, - dimension: 384, -}; - -const SUPPORTED_MODELS: [ModelSpec; 1] = [SNOWFLAKE_MODEL]; +pub(crate) const MXBAI_EMBEDDING_BACKEND_ID: &str = "mixedbread-ai/mxbai-embed-large-v1"; +const MXBAI_EMBEDDING_BACKEND_LABEL: &str = "Mixedbread MXBAI Embed Large V1"; +const MXBAI_EMBEDDING_DIMENSION: usize = 1024; +pub(crate) const MXBAI_QUERY_PREFIX: &str = + "Represent this sentence for searching relevant passages: "; #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct SupportedEmbeddingBackend { @@ -57,17 +45,13 @@ pub(crate) struct ChunkingConfig { #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct LocalEmbeddingConfig { - pub model_id: &'static str, - pub dimension: usize, pub cache_dir: PathBuf, pub max_length: usize, pub chunking: ChunkingConfig, - model: EmbeddingModel, } impl LocalEmbeddingConfig { - pub(crate) fn for_model_id(model_id: &str) -> Result { - let spec = parse_model_spec(model_id)?; + pub(crate) fn mxbai() -> Result { let max_length = env_usize(MAX_LENGTH_ENV_VAR, DEFAULT_MAX_LENGTH)?; let soft_limit = env_usize(CHUNK_SOFT_LIMIT_ENV_VAR, DEFAULT_CHUNK_SOFT_LIMIT)?; let hard_limit = env_usize(CHUNK_HARD_LIMIT_ENV_VAR, DEFAULT_CHUNK_HARD_LIMIT)?; @@ -83,8 +67,6 @@ impl LocalEmbeddingConfig { } Ok(Self { - model_id: spec.id, - dimension: 
spec.dimension, cache_dir: cache_dir()?, max_length, chunking: ChunkingConfig { @@ -92,12 +74,11 @@ impl LocalEmbeddingConfig { hard_limit, overlap, }, - model: spec.model, }) } pub(crate) fn text_init_options(&self) -> TextInitOptions { - TextInitOptions::new(self.model.clone()) + TextInitOptions::new(EmbeddingModel::MxbaiEmbedLargeV1) .with_cache_dir(self.cache_dir.clone()) .with_max_length(self.max_length) .with_show_download_progress(false) @@ -108,20 +89,16 @@ pub(crate) fn selected_embedding_backend_id() -> Result<&'static str> { resolve_embedding_backend_id() } -pub(crate) fn configured_embedding_dimension() -> Result { - let backend_id = resolve_embedding_backend_id()?; - if backend_id == API_EMBEDDING_BACKEND_ID { - return Ok(API_EMBEDDING_DIMENSION as u64); - } - Ok(LocalEmbeddingConfig::for_model_id(backend_id)?.dimension as u64) +pub(crate) fn create_memory_dimension_u64() -> u64 { + API_EMBEDDING_DIMENSION as u64 } pub(crate) fn selected_local_embedding_config() -> Result> { - let backend_id = resolve_embedding_backend_id()?; - if backend_id == API_EMBEDDING_BACKEND_ID { - return Ok(None); + match resolve_embedding_backend_id()? { + API_EMBEDDING_BACKEND_ID => Ok(None), + MXBAI_EMBEDDING_BACKEND_ID => LocalEmbeddingConfig::mxbai().map(Some), + _ => unreachable!("embedding backend should already be normalized"), } - LocalEmbeddingConfig::for_model_id(backend_id).map(Some) } fn resolve_embedding_backend_id() -> Result<&'static str> { @@ -142,49 +119,26 @@ fn load_embedding_preferences() -> Result { } pub(crate) fn supported_embedding_backends() -> Vec { - std::iter::once(SupportedEmbeddingBackend { - id: API_EMBEDDING_BACKEND_ID, - label: API_EMBEDDING_BACKEND_LABEL, - dimension: API_EMBEDDING_DIMENSION, - }) - .chain( - SUPPORTED_MODELS - .iter() - .map(|spec| SupportedEmbeddingBackend { - id: spec.id, - label: spec.label, - dimension: spec.dimension, - }), - ) - .collect() + vec![ + SupportedEmbeddingBackend { + id: API_EMBEDDING_BACKEND_ID, + label: API_EMBEDDING_BACKEND_LABEL, + dimension: API_EMBEDDING_DIMENSION, + }, + SupportedEmbeddingBackend { + id: MXBAI_EMBEDDING_BACKEND_ID, + label: MXBAI_EMBEDDING_BACKEND_LABEL, + dimension: MXBAI_EMBEDDING_DIMENSION, + }, + ] } pub(crate) fn normalize_supported_embedding_backend_id(raw: &str) -> &'static str { - let trimmed = raw.trim(); - if trimmed == API_EMBEDDING_BACKEND_ID { - return API_EMBEDDING_BACKEND_ID; + match raw.trim() { + API_EMBEDDING_BACKEND_ID => API_EMBEDDING_BACKEND_ID, + MXBAI_EMBEDDING_BACKEND_ID => MXBAI_EMBEDDING_BACKEND_ID, + _ => API_EMBEDDING_BACKEND_ID, } - parse_model_spec(trimmed) - .map(|spec| spec.id) - .unwrap_or(API_EMBEDDING_BACKEND_ID) -} - -fn parse_model_spec(raw: &str) -> Result { - let trimmed = raw.trim(); - SUPPORTED_MODELS - .iter() - .find(|spec| spec.id == trimmed) - .cloned() - .ok_or_else(|| { - anyhow!( - "Embedding backend must be one of: {}", - supported_embedding_backends() - .iter() - .map(|spec| spec.id) - .collect::>() - .join(", ") - ) - }) } fn cache_dir() -> Result { @@ -226,7 +180,7 @@ mod tests { } #[test] - fn configured_dimension_defaults_to_api() { + fn selected_local_embedding_config_defaults_to_api_backend() { let _guard = env_guard(); reset_test_preference_error(); unsafe { @@ -236,13 +190,13 @@ mod tests { env::remove_var(CHUNK_HARD_LIMIT_ENV_VAR); env::remove_var(CHUNK_OVERLAP_ENV_VAR); } - let config = configured_embedding_dimension().expect("api dimension should load"); + let config = selected_local_embedding_config().expect("api backend should load"); - 
assert_eq!(config, 1024); + assert_eq!(config, None); } #[test] - fn local_model_config_uses_snowflake() { + fn local_model_config_uses_mxbai() { let _guard = env_guard(); reset_test_preference_error(); unsafe { @@ -252,11 +206,8 @@ mod tests { env::remove_var(CHUNK_HARD_LIMIT_ENV_VAR); env::remove_var(CHUNK_OVERLAP_ENV_VAR); } - let config = LocalEmbeddingConfig::for_model_id("Snowflake/snowflake-arctic-embed-s") - .expect("local config should load"); + let config = LocalEmbeddingConfig::mxbai().expect("local config should load"); - assert_eq!(config.model_id, "Snowflake/snowflake-arctic-embed-s"); - assert_eq!(config.dimension, 384); assert_eq!(config.max_length, 512); assert_eq!(config.chunking.soft_limit, 800); } @@ -270,8 +221,7 @@ mod tests { env::set_var(CHUNK_HARD_LIMIT_ENV_VAR, "1200"); } - let error = LocalEmbeddingConfig::for_model_id("Snowflake/snowflake-arctic-embed-s") - .expect_err("invalid bounds should fail"); + let error = LocalEmbeddingConfig::mxbai().expect_err("invalid bounds should fail"); assert!(error.to_string().contains("soft limit")); unsafe { @@ -280,17 +230,6 @@ mod tests { } } - #[test] - fn unsupported_model_is_rejected() { - reset_test_preference_error(); - let error = parse_model_spec("bad-model").expect_err("bad model should fail"); - assert!( - error - .to_string() - .contains("Embedding backend must be one of") - ); - } - #[test] fn unsupported_backend_normalizes_to_api() { reset_test_preference_error(); @@ -301,11 +240,12 @@ mod tests { } #[test] - fn configured_dimension_errors_when_preferences_fail_to_load() { + fn selected_local_embedding_config_errors_when_preferences_fail_to_load() { let _guard = env_guard(); preferences::set_load_user_preferences_error_for_tests(true); - let error = configured_embedding_dimension().expect_err("load failure should reach caller"); + let error = + selected_local_embedding_config().expect_err("load failure should reach caller"); assert!( error @@ -314,4 +254,9 @@ mod tests { ); reset_test_preference_error(); } + + #[test] + fn create_memory_dimension_is_fixed_to_1024() { + assert_eq!(create_memory_dimension_u64(), 1024); + } } diff --git a/rust/local_embedding.rs b/rust/local_embedding.rs index 1d0fe07..dff2e09 100644 --- a/rust/local_embedding.rs +++ b/rust/local_embedding.rs @@ -9,21 +9,19 @@ use fastembed::TextEmbedding; use crate::{ embedding::LateChunk, - embedding_config::{LocalEmbeddingConfig, selected_local_embedding_config}, + embedding_config::{LocalEmbeddingConfig, MXBAI_QUERY_PREFIX, selected_local_embedding_config}, local_chunking::chunk_markdown, }; -const SNOWFLAKE_QUERY_PREFIX: &str = "Represent this sentence for searching relevant passages: "; - struct CachedModel { - model_id: String, model: TextEmbedding, } static MODEL: OnceLock>> = OnceLock::new(); pub(crate) async fn embed_query(text: &str) -> Result> { - embed_texts(vec![snowflake_query_text(text)]) + load_selected_local_config()?; + embed_texts(vec![mxbai_query_text(text)]) .await .map(|mut rows| { rows.pop() @@ -69,8 +67,8 @@ fn load_selected_local_config() -> Result { .ok_or_else(|| anyhow::anyhow!("Local embedding backend is not selected")) } -fn snowflake_query_text(text: &str) -> String { - format!("{SNOWFLAKE_QUERY_PREFIX}{}", text.trim()) +fn mxbai_query_text(text: &str) -> String { + format!("{}{}", MXBAI_QUERY_PREFIX, text.trim()) } fn model_cache() -> &'static Mutex> { @@ -81,10 +79,7 @@ fn ensure_cached_model<'a>( cache: &'a mut Option, config: &LocalEmbeddingConfig, ) -> Result<&'a mut CachedModel> { - let needs_reload = cache - 
.as_ref() - .is_none_or(|cached| cached.model_id != config.model_id); - if needs_reload { + if cache.is_none() { std::fs::create_dir_all(&config.cache_dir).with_context(|| { format!( "Failed to create embedding cache dir {}", @@ -92,10 +87,7 @@ fn ensure_cached_model<'a>( ) })?; let model = TextEmbedding::try_new(config.text_init_options())?; - *cache = Some(CachedModel { - model_id: config.model_id.to_string(), - model, - }); + *cache = Some(CachedModel { model }); } Ok(cache .as_mut() @@ -104,21 +96,19 @@ fn ensure_cached_model<'a>( #[cfg(test)] mod tests { - use super::snowflake_query_text; + use super::mxbai_query_text; #[test] - fn local_embedding_supports_snowflake_model_id() { - let config = crate::embedding_config::LocalEmbeddingConfig::for_model_id( - "Snowflake/snowflake-arctic-embed-s", - ) - .expect("snowflake config should load"); - assert_eq!(config.dimension, 384); + fn local_embedding_supports_mxbai_model_id() { + let config = crate::embedding_config::LocalEmbeddingConfig::mxbai() + .expect("mxbai config should load"); + assert_eq!(config.max_length, 512); } #[test] - fn snowflake_query_text_uses_retrieval_prefix() { + fn mxbai_query_text_uses_retrieval_prefix() { assert_eq!( - snowflake_query_text("hello"), + mxbai_query_text("hello"), "Represent this sentence for searching relevant passages: hello" ); } diff --git a/rust/preferences.rs b/rust/preferences.rs index e2ee675..832c0f2 100644 --- a/rust/preferences.rs +++ b/rust/preferences.rs @@ -11,6 +11,8 @@ use tui_kit_host::settings::SettingsError; #[cfg(not(test))] use tui_kit_host::settings::{load_yaml_or_default, save_yaml}; +use crate::embedding_config::{API_EMBEDDING_BACKEND_ID, MXBAI_EMBEDDING_BACKEND_ID}; + #[cfg(not(test))] const APP_NAMESPACE: &str = "kinic"; #[cfg(not(test))] @@ -19,8 +21,7 @@ pub use kinic_core::prefs_policy::{ DEFAULT_CHAT_MMR_LAMBDA, DEFAULT_CHAT_OVERALL_TOP_K, DEFAULT_CHAT_PER_MEMORY_CAP, }; -const DEFAULT_EMBEDDING_MODEL_ID: &str = "api"; -const SUPPORTED_EMBEDDING_MODEL_IDS: &[&str] = &["api", "Snowflake/snowflake-arctic-embed-s"]; +const DEFAULT_EMBEDDING_MODEL_ID: &str = API_EMBEDDING_BACKEND_ID; #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] // Missing legacy fields are backfilled with defaults and unknown fields are ignored. @@ -160,11 +161,11 @@ pub fn chat_diversity_display(value: u8) -> String { } pub fn normalize_embedding_model_id(value: String) -> String { - let trimmed = value.trim(); - if trimmed.is_empty() || !SUPPORTED_EMBEDDING_MODEL_IDS.contains(&trimmed) { - return default_embedding_model_id(); + match value.trim() { + API_EMBEDDING_BACKEND_ID => API_EMBEDDING_BACKEND_ID.to_string(), + MXBAI_EMBEDDING_BACKEND_ID => MXBAI_EMBEDDING_BACKEND_ID.to_string(), + _ => default_embedding_model_id(), } - trimmed.to_string() } fn normalize_default_memory_id(memory_id: Option) -> Option { diff --git a/rust/tui/settings.rs b/rust/tui/settings.rs index 18af5b2..750378c 100644 --- a/rust/tui/settings.rs +++ b/rust/tui/settings.rs @@ -30,7 +30,7 @@ const APP_NAMESPACE: &str = "kinic"; const CHAT_HISTORY_FILE_NAME: &str = "chat-threads.yaml"; const UNAVAILABLE: &str = "unavailable"; const NOT_SET: &str = "not set"; -const EMBEDDING_MODEL_NOTE: &str = "Affects create/search/insert. API keeps existing memories usable. Local backends may require reindex. Same-dimension model mismatches are not detectable."; +const EMBEDDING_MODEL_NOTE: &str = "Affects search/insert. New memories are created at 1024 dims. API keeps existing memories usable. 
Reindex or reset older non-1024 memories before local search/insert."; const CHAT_HISTORY_MAX_MESSAGES: usize = 40; const CHAT_MESSAGE_MAX_CONTENT_LEN: usize = 4096; diff --git a/rust/tui/settings_tests.rs b/rust/tui/settings_tests.rs index dc3ca26..16611d1 100644 --- a/rust/tui/settings_tests.rs +++ b/rust/tui/settings_tests.rs @@ -291,7 +291,7 @@ fn settings_snapshot_projects_chat_retrieval_section() { chat_overall_top_k: 10, chat_per_memory_cap: 4, chat_mmr_lambda: 80, - embedding_model_id: "Snowflake/snowflake-arctic-embed-s".to_string(), + embedding_model_id: "mixedbread-ai/mxbai-embed-large-v1".to_string(), ..UserPreferences::default() }, &Vec::new(), @@ -317,7 +317,7 @@ fn settings_snapshot_projects_chat_retrieval_section() { "Saved preferences", SETTINGS_ENTRY_EMBEDDING_MODEL_ID ), - "Snowflake Arctic Embed S (384)" + "Mixedbread MXBAI Embed Large V1 (1024)" ); } From 6d91ba5ed0f680df98b552c8bafc260678af5a2a Mon Sep 17 00:00:00 2001 From: hude Date: Thu, 23 Apr 2026 10:54:58 +0900 Subject: [PATCH 3/4] Add embedding backend prefs and no-config fallback --- rust/cli_defs.rs | 17 ++++- rust/commands/prefs.rs | 31 ++++++++- rust/embedding_config.rs | 86 +++++++++++++++++++++---- rust/insert_service.rs | 85 +++++++++++++++++++++--- rust/preferences.rs | 32 +++++++-- tests/fixtures/capabilities_golden.json | 26 ++++++++ tests/kinic_cli_capabilities.rs | 7 ++ tests/kinic_cli_prefs.rs | 80 +++++++++++++++++++++++ tests/kinic_cli_tui.rs | 1 + 9 files changed, 335 insertions(+), 30 deletions(-) diff --git a/rust/cli_defs.rs b/rust/cli_defs.rs index 755745f..34953cc 100644 --- a/rust/cli_defs.rs +++ b/rust/cli_defs.rs @@ -116,7 +116,7 @@ pub enum Command { Capabilities(CapabilitiesArgs), #[command( about = "Manage local Kinic preferences shared with the TUI. All prefs commands return JSON.", - after_help = "Examples:\n kinic-cli prefs show\n kinic-cli prefs set-default-memory --memory-id MEMORY_CANISTER_ID\n kinic-cli prefs set-chat-overall-top-k --value 10\n\nReturns:\n show -> {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer, \"embedding_model_id\": string}\n mutations -> {\"resource\": string, \"action\": string, \"status\": \"updated\"|\"unchanged\", \"value\": string|integer|null}" + after_help = "Examples:\n kinic-cli prefs show\n kinic-cli prefs set-default-memory --memory-id MEMORY_CANISTER_ID\n kinic-cli prefs set-embedding-backend --model-id api\n kinic-cli prefs set-chat-overall-top-k --value 10\n\nReturns:\n show -> {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer, \"embedding_model_id\": string}\n mutations -> {\"resource\": string, \"action\": string, \"status\": \"updated\"|\"unchanged\", \"value\": string|integer|null}" )] Prefs(PrefsArgs), #[command( @@ -459,6 +459,11 @@ pub enum PrefsCommand { after_help = "Returns:\n {\"resource\": \"chat_mmr_lambda\", \"action\": \"set\", \"status\": \"updated\"|\"unchanged\", \"value\": integer}\n\nExample:\n kinic-cli prefs set-chat-mmr-lambda --value 80" )] SetChatMmrLambda(ChatMmrLambdaArgs), + #[command( + about = "Set the embedding backend shared with the TUI. 
Returns JSON.", + after_help = "Returns:\n {\"resource\": \"embedding_model_id\", \"action\": \"set\", \"status\": \"updated\"|\"unchanged\", \"value\": string}\n\nExample:\n kinic-cli prefs set-embedding-backend --model-id api" + )] + SetEmbeddingBackend(EmbeddingBackendArgs), } #[derive(Args, Debug)] @@ -533,6 +538,16 @@ pub struct ChatMmrLambdaArgs { pub value: u8, } +#[derive(Args, Debug)] +pub struct EmbeddingBackendArgs { + #[arg( + long, + required = true, + help = "Embedding backend id shared with the TUI, e.g. api or mixedbread-ai/mxbai-embed-large-v1" + )] + pub model_id: String, +} + #[derive(Args, Debug)] pub struct UpdateArgs { #[arg( diff --git a/rust/commands/prefs.rs b/rust/commands/prefs.rs index 991f7c3..efbd457 100644 --- a/rust/commands/prefs.rs +++ b/rust/commands/prefs.rs @@ -14,10 +14,12 @@ use serde::Serialize; use crate::{ cli::{ - AddMemoryArgs, ChatMmrLambdaArgs, ChatOverallTopKArgs, ChatPerMemoryCapArgs, GlobalOpts, - MemoryIdArgs, PrefsArgs, PrefsCommand, SetDefaultMemoryArgs, TagArgs, + AddMemoryArgs, ChatMmrLambdaArgs, ChatOverallTopKArgs, ChatPerMemoryCapArgs, + EmbeddingBackendArgs, GlobalOpts, MemoryIdArgs, PrefsArgs, PrefsCommand, + SetDefaultMemoryArgs, TagArgs, }, clients::memory::MemoryClient, + embedding_config::normalize_supported_embedding_backend_id, preferences::{self, UserPreferences}, }; @@ -33,6 +35,7 @@ pub async fn handle(args: PrefsArgs, global: &GlobalOpts) -> Result<()> { PrefsCommand::SetChatOverallTopK(args) => set_chat_overall_top_k(args), PrefsCommand::SetChatPerMemoryCap(args) => set_chat_per_memory_cap(args), PrefsCommand::SetChatMmrLambda(args) => set_chat_mmr_lambda(args), + PrefsCommand::SetEmbeddingBackend(args) => set_embedding_backend(args), } } @@ -200,6 +203,18 @@ fn set_chat_mmr_lambda(args: ChatMmrLambdaArgs) -> Result<()> { print_json_response(PrefsResponse::updated("chat_mmr_lambda", "set", value)) } +fn set_embedding_backend(args: EmbeddingBackendArgs) -> Result<()> { + let value = validate_embedding_backend(args.model_id.as_str()); + let mut preferences = load_preferences()?; + if preferences.embedding_model_id == value { + return print_json_response(PrefsResponse::unchanged("embedding_model_id", "set", value)); + } + + preferences.embedding_model_id = value.clone(); + save_preferences(&preferences)?; + print_json_response(PrefsResponse::updated("embedding_model_id", "set", value)) +} + fn load_preferences() -> Result { preferences::load_user_preferences().context("Failed to load shared TUI preferences") } @@ -267,6 +282,10 @@ fn validate_chat_mmr_lambda(value: u8) -> Result { } } +fn validate_embedding_backend(value: &str) -> String { + normalize_supported_embedding_backend_id(value).to_string() +} + fn display_options_usize(values: &[usize]) -> String { values .iter() @@ -508,6 +527,14 @@ mod tests { ); } + #[test] + fn validate_embedding_backend_normalizes_unknown_values_to_api() { + assert_eq!( + validate_embedding_backend("unsupported"), + preferences::default_embedding_model_id() + ); + } + #[test] fn prefs_response_skips_memory_name_by_default() { let json = serde_json::to_value(PrefsResponse::updated( diff --git a/rust/embedding_config.rs b/rust/embedding_config.rs index 71f990b..27ba501 100644 --- a/rust/embedding_config.rs +++ b/rust/embedding_config.rs @@ -56,6 +56,9 @@ impl LocalEmbeddingConfig { let soft_limit = env_usize(CHUNK_SOFT_LIMIT_ENV_VAR, DEFAULT_CHUNK_SOFT_LIMIT)?; let hard_limit = env_usize(CHUNK_HARD_LIMIT_ENV_VAR, DEFAULT_CHUNK_HARD_LIMIT)?; let overlap = env_usize(CHUNK_OVERLAP_ENV_VAR, 
DEFAULT_CHUNK_OVERLAP)?; + if max_length == 0 { + bail!("{MAX_LENGTH_ENV_VAR} must be a positive integer"); + } if soft_limit == 0 || hard_limit == 0 { bail!("Chunk limits must be positive."); } @@ -101,6 +104,12 @@ pub(crate) fn selected_local_embedding_config() -> Result Result { + Ok(embedding_dimension_for_backend( + resolve_embedding_backend_id()?, + )) +} + fn resolve_embedding_backend_id() -> Result<&'static str> { let preferences = load_embedding_preferences()?; Ok(normalize_supported_embedding_backend_id( @@ -109,13 +118,13 @@ fn resolve_embedding_backend_id() -> Result<&'static str> { } fn load_embedding_preferences() -> Result { - preferences::load_user_preferences().map_err(|error| match error { - SettingsError::NoConfigDir => anyhow!( - "Embedding backend could not be resolved because the shared settings directory is unavailable. Kinic requires a writable config directory for shared preferences." - ), - other => anyhow!(other) - .context("Failed to load shared embedding backend from tui.yaml"), - }) + match preferences::load_user_preferences() { + Ok(preferences) => Ok(preferences), + Err(SettingsError::NoConfigDir) => Ok(preferences::UserPreferences::default()), + Err(other) => { + Err(anyhow!(other).context("Failed to load shared embedding backend from tui.yaml")) + } + } } pub(crate) fn supported_embedding_backends() -> Vec { @@ -141,6 +150,14 @@ pub(crate) fn normalize_supported_embedding_backend_id(raw: &str) -> &'static st } } +pub(crate) fn embedding_dimension_for_backend(backend_id: &str) -> usize { + match normalize_supported_embedding_backend_id(backend_id) { + API_EMBEDDING_BACKEND_ID => API_EMBEDDING_DIMENSION, + MXBAI_EMBEDDING_BACKEND_ID => MXBAI_EMBEDDING_DIMENSION, + _ => unreachable!("embedding backend should already be normalized"), + } +} + fn cache_dir() -> Result { if let Ok(value) = env::var(CACHE_DIR_ENV_VAR) { let trimmed = value.trim(); @@ -176,7 +193,7 @@ mod tests { } fn reset_test_preference_error() { - preferences::set_load_user_preferences_error_for_tests(false); + preferences::set_load_user_preferences_error_for_tests(None); } #[test] @@ -240,21 +257,66 @@ mod tests { } #[test] - fn selected_local_embedding_config_errors_when_preferences_fail_to_load() { + fn selected_local_embedding_config_defaults_to_api_when_no_config_dir_is_available() { let _guard = env_guard(); - preferences::set_load_user_preferences_error_for_tests(true); + preferences::set_load_user_preferences_error_for_tests(Some( + preferences::TestLoadPreferencesError::NoConfigDir, + )); + + let config = selected_local_embedding_config().expect("no config dir should fall back"); + + assert_eq!(config, None); + reset_test_preference_error(); + } + + #[test] + fn selected_embedding_backend_defaults_to_api_when_no_config_dir_is_available() { + let _guard = env_guard(); + preferences::set_load_user_preferences_error_for_tests(Some( + preferences::TestLoadPreferencesError::NoConfigDir, + )); + + let backend = + selected_embedding_backend_id().expect("no config dir should select api backend"); + + assert_eq!(backend, API_EMBEDDING_BACKEND_ID); + reset_test_preference_error(); + } + + #[test] + fn selected_local_embedding_config_still_errors_on_yaml_failure() { + let _guard = env_guard(); + preferences::set_load_user_preferences_error_for_tests(Some( + preferences::TestLoadPreferencesError::Yaml, + )); let error = - selected_local_embedding_config().expect_err("load failure should reach caller"); + selected_local_embedding_config().expect_err("yaml failure should reach caller"); assert!( 
error .to_string() - .contains("shared settings directory is unavailable") + .contains("Failed to load shared embedding backend from tui.yaml") ); reset_test_preference_error(); } + #[test] + fn local_model_config_rejects_zero_max_length() { + let _guard = env_guard(); + reset_test_preference_error(); + unsafe { + env::set_var(MAX_LENGTH_ENV_VAR, "0"); + } + + let error = LocalEmbeddingConfig::mxbai().expect_err("zero max length should fail"); + + assert!(error.to_string().contains(MAX_LENGTH_ENV_VAR)); + unsafe { + env::remove_var(MAX_LENGTH_ENV_VAR); + } + } + #[test] fn create_memory_dimension_is_fixed_to_1024() { assert_eq!(create_memory_dimension_u64(), 1024); diff --git a/rust/insert_service.rs b/rust/insert_service.rs index 7dcde0e..a40db15 100644 --- a/rust/insert_service.rs +++ b/rust/insert_service.rs @@ -16,6 +16,7 @@ use crate::{ clients::memory::MemoryClient, commands::convert_pdf::pdf_to_markdown, embedding::{ensure_vector_dim_matches, late_chunking}, + embedding_config::selected_embedding_dimension, }; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -88,8 +89,13 @@ pub async fn execute_insert_request( ) -> Result { validate_insert_request_fields(request)?; let validated = validate_and_transform_insert_request(request)?; + let expected_dim = client + .get_dim() + .await + .context("Failed to load memory embedding dimension")?; + ensure_request_matches_memory_dim(&validated, request.memory_id(), expected_dim)?; let prepared = prepare_insert_request(&validated).await?; - ensure_prepared_items_match_memory(client, request.memory_id(), &prepared).await?; + ensure_prepared_items_match_memory(request.memory_id(), &prepared, expected_dim)?; let inserted_count = prepared.len(); let source_name = validated.source_name(); @@ -152,18 +158,30 @@ pub fn validate_insert_request_fields(request: &InsertRequest) -> Result<()> { Ok(()) } -async fn ensure_prepared_items_match_memory( - client: &MemoryClient, +fn ensure_request_matches_memory_dim( + request: &ValidatedInsertRequest, + memory_id: &str, + expected_dim: u64, +) -> Result<()> { + match request { + ValidatedInsertRequest::Raw { embedding, .. } => { + ensure_vector_dim_matches(memory_id, embedding.len(), expected_dim) + } + ValidatedInsertRequest::Normal { .. } | ValidatedInsertRequest::Pdf { .. 
} => { + let selected_dim = selected_embedding_dimension()?; + ensure_vector_dim_matches(memory_id, selected_dim, expected_dim) + } + } +} + +fn ensure_prepared_items_match_memory( memory_id: &str, items: &[PreparedInsertItem], + expected_dim: u64, ) -> Result<()> { let Some(first) = items.first() else { bail!("Insert content did not produce any chunks."); }; - let expected_dim = client - .get_dim() - .await - .context("Failed to load memory embedding dimension")?; ensure_vector_dim_matches(memory_id, first.embedding.len(), expected_dim) } @@ -478,7 +496,58 @@ mod tests { #[test] fn prepared_items_match_expected_dimension() { - ensure_vector_dim_matches("aaaaa-aa", 2, 2).unwrap(); + ensure_prepared_items_match_memory( + "aaaaa-aa", + &[PreparedInsertItem { + embedding: vec![0.1, 0.2], + payload: "{}".to_string(), + }], + 2, + ) + .unwrap(); + } + + #[test] + fn prepared_items_reject_dimension_mismatch() { + let error = ensure_prepared_items_match_memory( + "aaaaa-aa", + &[PreparedInsertItem { + embedding: vec![0.1, 0.2], + payload: "{}".to_string(), + }], + 3, + ) + .unwrap_err(); + + assert!(error.to_string().contains("Embedding dimension mismatch")); + } + + #[test] + fn raw_insert_request_fails_fast_on_dimension_mismatch() { + let request = ValidatedInsertRequest::Raw { + memory_id: "aaaaa-aa".to_string(), + tag: "docs".to_string(), + text: "payload".to_string(), + embedding: vec![0.1, 0.2], + }; + + let error = ensure_request_matches_memory_dim(&request, "aaaaa-aa", 3).unwrap_err(); + + assert!(error.to_string().contains("Embedding dimension mismatch")); + } + + #[test] + fn normal_insert_request_uses_selected_backend_dimension_for_fail_fast_checks() { + let request = ValidatedInsertRequest::Normal { + memory_id: "aaaaa-aa".to_string(), + tag: "docs".to_string(), + text: Some("payload".to_string()), + file_path: None, + }; + + let error = ensure_request_matches_memory_dim(&request, "aaaaa-aa", 3).unwrap_err(); + + assert!(error.to_string().contains("Embedding dimension mismatch")); } #[test] diff --git a/rust/preferences.rs b/rust/preferences.rs index 832c0f2..f84fc6f 100644 --- a/rust/preferences.rs +++ b/rust/preferences.rs @@ -72,6 +72,13 @@ pub fn default_embedding_model_id() -> String { DEFAULT_EMBEDDING_MODEL_ID.to_string() } +#[cfg(test)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum TestLoadPreferencesError { + NoConfigDir, + Yaml, +} + #[cfg(test)] pub fn load_user_preferences() -> Result { if let Some(error) = test_load_error() { @@ -92,22 +99,33 @@ pub fn save_user_preferences(_preferences: &UserPreferences) -> Result<(), Setti } #[cfg(test)] -pub fn set_load_user_preferences_error_for_tests(enabled: bool) { - TEST_LOAD_ERROR.with(|slot| slot.set(enabled)); +pub fn set_load_user_preferences_error_for_tests(error: Option) { + TEST_LOAD_ERROR.with(|slot| slot.set(test_load_error_code(error))); } #[cfg(test)] fn test_load_error() -> Option { - if TEST_LOAD_ERROR.with(Cell::get) { - Some(SettingsError::NoConfigDir) - } else { - None + match TEST_LOAD_ERROR.with(Cell::get) { + 1 => Some(SettingsError::NoConfigDir), + 2 => Some(SettingsError::Yaml( + serde_yaml::from_str::(":").unwrap_err(), + )), + _ => None, } } #[cfg(test)] thread_local! 
{ - static TEST_LOAD_ERROR: Cell = const { Cell::new(false) }; + static TEST_LOAD_ERROR: Cell = const { Cell::new(0) }; +} + +#[cfg(test)] +const fn test_load_error_code(error: Option) -> u8 { + match error { + Some(TestLoadPreferencesError::NoConfigDir) => 1, + Some(TestLoadPreferencesError::Yaml) => 2, + None => 0, + } } #[cfg(not(test))] diff --git a/tests/fixtures/capabilities_golden.json b/tests/fixtures/capabilities_golden.json index 046a212..49b5b2b 100644 --- a/tests/fixtures/capabilities_golden.json +++ b/tests/fixtures/capabilities_golden.json @@ -1044,6 +1044,32 @@ "value_kind": "integer" } ] + }, + { + "name": "set-embedding-backend", + "summary": "Set the embedding backend shared with the TUI. Returns JSON.", + "auth": { + "required": false, + "sources": [] + }, + "output": { + "default": "json", + "supported": [ + "json" + ], + "interactive": false + }, + "global_flags_supported": [ + "verbose" + ], + "arguments": [ + { + "name": "model_id", + "required": true, + "input_shape": "single_value", + "value_kind": "string" + } + ] } ] }, diff --git a/tests/kinic_cli_capabilities.rs b/tests/kinic_cli_capabilities.rs index fec22de..b1db9e2 100644 --- a/tests/kinic_cli_capabilities.rs +++ b/tests/kinic_cli_capabilities.rs @@ -112,6 +112,13 @@ fn capabilities_describes_prefs_and_tui_contracts() { .iter() .any(|entry| entry["name"] == "set-default-memory") ); + assert!( + prefs["subcommands"] + .as_array() + .unwrap() + .iter() + .any(|entry| entry["name"] == "set-embedding-backend") + ); let prefs_add_memory = prefs["subcommands"] .as_array() .unwrap() diff --git a/tests/kinic_cli_prefs.rs b/tests/kinic_cli_prefs.rs index 0986789..79eae49 100644 --- a/tests/kinic_cli_prefs.rs +++ b/tests/kinic_cli_prefs.rs @@ -301,6 +301,86 @@ fn prefs_chat_retrieval_mutations_update_yaml_and_show_output() { assert_eq!(parsed["chat_mmr_lambda"], json!(80)); } +#[test] +fn prefs_set_embedding_backend_updates_yaml_and_show_output() { + let config_dir = temp_config_dir("embedding-backend"); + let kinic_dir = app_config_root(&config_dir).join("kinic"); + fs::create_dir_all(&kinic_dir).unwrap(); + fs::write( + kinic_dir.join("tui.yaml"), + "embedding_model_id: mixedbread-ai/mxbai-embed-large-v1\n", + ) + .unwrap(); + + let output = prefs_command(&config_dir) + .args(["prefs", "set-embedding-backend", "--model-id", "api"]) + .output() + .unwrap(); + assert!(output.status.success()); + let parsed: serde_json::Value = + serde_json::from_slice(&output.stdout).expect("json response should parse"); + assert_eq!( + parsed, + json!({ + "resource": "embedding_model_id", + "action": "set", + "status": "updated", + "value": "api" + }) + ); + + let output = prefs_command(&config_dir) + .args(["prefs", "set-embedding-backend", "--model-id", "api"]) + .output() + .unwrap(); + assert!(output.status.success()); + let parsed: serde_json::Value = + serde_json::from_slice(&output.stdout).expect("json response should parse"); + assert_eq!( + parsed, + json!({ + "resource": "embedding_model_id", + "action": "set", + "status": "unchanged", + "value": "api" + }) + ); + + let yaml = read_prefs_yaml(&config_dir); + assert!(yaml.contains("embedding_model_id: api")); + + let output = prefs_command(&config_dir) + .args(["prefs", "show"]) + .output() + .unwrap(); + assert!(output.status.success()); + let parsed: serde_json::Value = + serde_json::from_slice(&output.stdout).expect("show output should parse"); + assert_eq!(parsed["embedding_model_id"], json!("api")); +} + +#[test] +fn 
prefs_set_embedding_backend_normalizes_invalid_values_to_api() { + let config_dir = temp_config_dir("embedding-backend-normalize"); + + let output = prefs_command(&config_dir) + .args(["prefs", "set-embedding-backend", "--model-id", "bad-model"]) + .output() + .unwrap(); + assert!(output.status.success()); + let parsed: serde_json::Value = + serde_json::from_slice(&output.stdout).expect("json response should parse"); + assert_eq!( + parsed, + json!({ + "resource": "embedding_model_id", + "action": "set", + "status": "unchanged", + "value": "api" + }) + ); +} + #[test] fn prefs_add_memory_validate_requires_identity_or_ii() { let config_dir = temp_config_dir("add-memory-validate"); diff --git a/tests/kinic_cli_tui.rs b/tests/kinic_cli_tui.rs index e484ee4..e979039 100644 --- a/tests/kinic_cli_tui.rs +++ b/tests/kinic_cli_tui.rs @@ -32,6 +32,7 @@ fn prefs_help_mentions_json_contract_and_examples() { assert!(stdout.contains("All prefs commands return JSON.")); assert!(stdout.contains("show -> {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer, \"embedding_model_id\": string}")); assert!(stdout.contains("kinic-cli prefs set-default-memory --memory-id")); + assert!(stdout.contains("kinic-cli prefs set-embedding-backend --model-id api")); } #[test] From 2fd4ad64be16aa9eed4316c29d9939d5dfa9d019 Mon Sep 17 00:00:00 2001 From: hude Date: Thu, 23 Apr 2026 17:29:43 +0900 Subject: [PATCH 4/4] Switch local embedding backend to BGE-M3 --- README.md | 2 +- rust/cli_defs.rs | 6 +++--- rust/commands/prefs.rs | 9 +++------ rust/embedding_config.rs | 41 ++++++++++++++++++++++---------------- rust/local_embedding.rs | 24 ++++++++-------------- rust/preferences.rs | 17 ++++++++++++++-- rust/tui/settings_tests.rs | 4 ++-- tests/kinic_cli_prefs.rs | 29 ++++++++++++++++++++++++++- tests/kinic_cli_tui.rs | 2 +- 9 files changed, 85 insertions(+), 49 deletions(-) diff --git a/README.md b/README.md index 3e0636e..09d3fa3 100644 --- a/README.md +++ b/README.md @@ -399,7 +399,7 @@ Embedding backend initialization: - initial saved backend: `api` - API dimension: `1024` -- local option example: `mixedbread-ai/mxbai-embed-large-v1` +- local option example: `BAAI/bge-m3` - local cache dir: `$HOME/.cache/kinic-cli/embeddings` - new memories are created with dimension `1024` diff --git a/rust/cli_defs.rs b/rust/cli_defs.rs index 34953cc..963e1bd 100644 --- a/rust/cli_defs.rs +++ b/rust/cli_defs.rs @@ -116,7 +116,7 @@ pub enum Command { Capabilities(CapabilitiesArgs), #[command( about = "Manage local Kinic preferences shared with the TUI. 
All prefs commands return JSON.", - after_help = "Examples:\n kinic-cli prefs show\n kinic-cli prefs set-default-memory --memory-id MEMORY_CANISTER_ID\n kinic-cli prefs set-embedding-backend --model-id api\n kinic-cli prefs set-chat-overall-top-k --value 10\n\nReturns:\n show -> {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer, \"embedding_model_id\": string}\n mutations -> {\"resource\": string, \"action\": string, \"status\": \"updated\"|\"unchanged\", \"value\": string|integer|null}" + after_help = "Examples:\n kinic-cli prefs show\n kinic-cli prefs set-default-memory --memory-id MEMORY_CANISTER_ID\n kinic-cli prefs set-embedding-backend --model-id BAAI/bge-m3\n kinic-cli prefs set-chat-overall-top-k --value 10\n\nReturns:\n show -> {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer, \"embedding_model_id\": string}\n mutations -> {\"resource\": string, \"action\": string, \"status\": \"updated\"|\"unchanged\", \"value\": string|integer|null}" )] Prefs(PrefsArgs), #[command( @@ -461,7 +461,7 @@ pub enum PrefsCommand { SetChatMmrLambda(ChatMmrLambdaArgs), #[command( about = "Set the embedding backend shared with the TUI. Returns JSON.", - after_help = "Returns:\n {\"resource\": \"embedding_model_id\", \"action\": \"set\", \"status\": \"updated\"|\"unchanged\", \"value\": string}\n\nExample:\n kinic-cli prefs set-embedding-backend --model-id api" + after_help = "Returns:\n {\"resource\": \"embedding_model_id\", \"action\": \"set\", \"status\": \"updated\"|\"unchanged\", \"value\": string}\n\nExample:\n kinic-cli prefs set-embedding-backend --model-id BAAI/bge-m3" )] SetEmbeddingBackend(EmbeddingBackendArgs), } @@ -543,7 +543,7 @@ pub struct EmbeddingBackendArgs { #[arg( long, required = true, - help = "Embedding backend id shared with the TUI, e.g. api or mixedbread-ai/mxbai-embed-large-v1" + help = "Embedding backend id shared with the TUI, e.g. 
api or BAAI/bge-m3" )] pub model_id: String, } diff --git a/rust/commands/prefs.rs b/rust/commands/prefs.rs index efbd457..1d169cd 100644 --- a/rust/commands/prefs.rs +++ b/rust/commands/prefs.rs @@ -452,17 +452,14 @@ mod tests { } #[test] - fn show_preferences_preserves_mxbai_embedding_model_id() { + fn show_preferences_preserves_bgem3_embedding_model_id() { let serialized = serde_json::to_value(ShowPreferences::from(UserPreferences { - embedding_model_id: "mixedbread-ai/mxbai-embed-large-v1".to_string(), + embedding_model_id: "BAAI/bge-m3".to_string(), ..UserPreferences::default() })) .expect("show preferences should serialize"); - assert_eq!( - serialized["embedding_model_id"], - "mixedbread-ai/mxbai-embed-large-v1" - ); + assert_eq!(serialized["embedding_model_id"], "BAAI/bge-m3"); } #[test] diff --git a/rust/embedding_config.rs b/rust/embedding_config.rs index 27ba501..bd458b6 100644 --- a/rust/embedding_config.rs +++ b/rust/embedding_config.rs @@ -23,11 +23,9 @@ const DEFAULT_CHUNK_OVERLAP: usize = 120; pub(crate) const API_EMBEDDING_BACKEND_ID: &str = "api"; const API_EMBEDDING_BACKEND_LABEL: &str = "API (remote default)"; const API_EMBEDDING_DIMENSION: usize = 1024; -pub(crate) const MXBAI_EMBEDDING_BACKEND_ID: &str = "mixedbread-ai/mxbai-embed-large-v1"; -const MXBAI_EMBEDDING_BACKEND_LABEL: &str = "Mixedbread MXBAI Embed Large V1"; -const MXBAI_EMBEDDING_DIMENSION: usize = 1024; -pub(crate) const MXBAI_QUERY_PREFIX: &str = - "Represent this sentence for searching relevant passages: "; +pub(crate) const BGEM3_EMBEDDING_BACKEND_ID: &str = "BAAI/bge-m3"; +const BGEM3_EMBEDDING_BACKEND_LABEL: &str = "BAAI BGE-M3"; +const BGEM3_EMBEDDING_DIMENSION: usize = 1024; #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct SupportedEmbeddingBackend { @@ -51,7 +49,7 @@ pub(crate) struct LocalEmbeddingConfig { } impl LocalEmbeddingConfig { - pub(crate) fn mxbai() -> Result { + pub(crate) fn bgem3() -> Result { let max_length = env_usize(MAX_LENGTH_ENV_VAR, DEFAULT_MAX_LENGTH)?; let soft_limit = env_usize(CHUNK_SOFT_LIMIT_ENV_VAR, DEFAULT_CHUNK_SOFT_LIMIT)?; let hard_limit = env_usize(CHUNK_HARD_LIMIT_ENV_VAR, DEFAULT_CHUNK_HARD_LIMIT)?; @@ -81,7 +79,7 @@ impl LocalEmbeddingConfig { } pub(crate) fn text_init_options(&self) -> TextInitOptions { - TextInitOptions::new(EmbeddingModel::MxbaiEmbedLargeV1) + TextInitOptions::new(EmbeddingModel::BGEM3) .with_cache_dir(self.cache_dir.clone()) .with_max_length(self.max_length) .with_show_download_progress(false) @@ -99,7 +97,7 @@ pub(crate) fn create_memory_dimension_u64() -> u64 { pub(crate) fn selected_local_embedding_config() -> Result> { match resolve_embedding_backend_id()? 
{ API_EMBEDDING_BACKEND_ID => Ok(None), - MXBAI_EMBEDDING_BACKEND_ID => LocalEmbeddingConfig::mxbai().map(Some), + BGEM3_EMBEDDING_BACKEND_ID => LocalEmbeddingConfig::bgem3().map(Some), _ => unreachable!("embedding backend should already be normalized"), } } @@ -135,9 +133,9 @@ pub(crate) fn supported_embedding_backends() -> Vec { dimension: API_EMBEDDING_DIMENSION, }, SupportedEmbeddingBackend { - id: MXBAI_EMBEDDING_BACKEND_ID, - label: MXBAI_EMBEDDING_BACKEND_LABEL, - dimension: MXBAI_EMBEDDING_DIMENSION, + id: BGEM3_EMBEDDING_BACKEND_ID, + label: BGEM3_EMBEDDING_BACKEND_LABEL, + dimension: BGEM3_EMBEDDING_DIMENSION, }, ] } @@ -145,7 +143,7 @@ pub(crate) fn supported_embedding_backends() -> Vec { pub(crate) fn normalize_supported_embedding_backend_id(raw: &str) -> &'static str { match raw.trim() { API_EMBEDDING_BACKEND_ID => API_EMBEDDING_BACKEND_ID, - MXBAI_EMBEDDING_BACKEND_ID => MXBAI_EMBEDDING_BACKEND_ID, + BGEM3_EMBEDDING_BACKEND_ID => BGEM3_EMBEDDING_BACKEND_ID, _ => API_EMBEDDING_BACKEND_ID, } } @@ -153,7 +151,7 @@ pub(crate) fn normalize_supported_embedding_backend_id(raw: &str) -> &'static st pub(crate) fn embedding_dimension_for_backend(backend_id: &str) -> usize { match normalize_supported_embedding_backend_id(backend_id) { API_EMBEDDING_BACKEND_ID => API_EMBEDDING_DIMENSION, - MXBAI_EMBEDDING_BACKEND_ID => MXBAI_EMBEDDING_DIMENSION, + BGEM3_EMBEDDING_BACKEND_ID => BGEM3_EMBEDDING_DIMENSION, _ => unreachable!("embedding backend should already be normalized"), } } @@ -213,7 +211,7 @@ mod tests { } #[test] - fn local_model_config_uses_mxbai() { + fn local_model_config_uses_bgem3() { let _guard = env_guard(); reset_test_preference_error(); unsafe { @@ -223,7 +221,7 @@ mod tests { env::remove_var(CHUNK_HARD_LIMIT_ENV_VAR); env::remove_var(CHUNK_OVERLAP_ENV_VAR); } - let config = LocalEmbeddingConfig::mxbai().expect("local config should load"); + let config = LocalEmbeddingConfig::bgem3().expect("local config should load"); assert_eq!(config.max_length, 512); assert_eq!(config.chunking.soft_limit, 800); @@ -238,7 +236,7 @@ mod tests { env::set_var(CHUNK_HARD_LIMIT_ENV_VAR, "1200"); } - let error = LocalEmbeddingConfig::mxbai().expect_err("invalid bounds should fail"); + let error = LocalEmbeddingConfig::bgem3().expect_err("invalid bounds should fail"); assert!(error.to_string().contains("soft limit")); unsafe { @@ -256,6 +254,15 @@ mod tests { ); } + #[test] + fn legacy_mxbai_backend_normalizes_to_api() { + reset_test_preference_error(); + assert_eq!( + normalize_supported_embedding_backend_id("mixedbread-ai/mxbai-embed-large-v1"), + API_EMBEDDING_BACKEND_ID + ); + } + #[test] fn selected_local_embedding_config_defaults_to_api_when_no_config_dir_is_available() { let _guard = env_guard(); @@ -309,7 +316,7 @@ mod tests { env::set_var(MAX_LENGTH_ENV_VAR, "0"); } - let error = LocalEmbeddingConfig::mxbai().expect_err("zero max length should fail"); + let error = LocalEmbeddingConfig::bgem3().expect_err("zero max length should fail"); assert!(error.to_string().contains(MAX_LENGTH_ENV_VAR)); unsafe { diff --git a/rust/local_embedding.rs b/rust/local_embedding.rs index dff2e09..a883437 100644 --- a/rust/local_embedding.rs +++ b/rust/local_embedding.rs @@ -9,7 +9,7 @@ use fastembed::TextEmbedding; use crate::{ embedding::LateChunk, - embedding_config::{LocalEmbeddingConfig, MXBAI_QUERY_PREFIX, selected_local_embedding_config}, + embedding_config::{LocalEmbeddingConfig, selected_local_embedding_config}, local_chunking::chunk_markdown, }; @@ -21,7 +21,7 @@ static MODEL: OnceLock>> = 
OnceLock::new(); pub(crate) async fn embed_query(text: &str) -> Result> { load_selected_local_config()?; - embed_texts(vec![mxbai_query_text(text)]) + embed_texts(vec![text.trim().to_string()]) .await .map(|mut rows| { rows.pop() @@ -67,10 +67,6 @@ fn load_selected_local_config() -> Result { .ok_or_else(|| anyhow::anyhow!("Local embedding backend is not selected")) } -fn mxbai_query_text(text: &str) -> String { - format!("{}{}", MXBAI_QUERY_PREFIX, text.trim()) -} - fn model_cache() -> &'static Mutex> { MODEL.get_or_init(|| Mutex::new(None)) } @@ -96,20 +92,16 @@ fn ensure_cached_model<'a>( #[cfg(test)] mod tests { - use super::mxbai_query_text; - #[test] - fn local_embedding_supports_mxbai_model_id() { - let config = crate::embedding_config::LocalEmbeddingConfig::mxbai() - .expect("mxbai config should load"); + fn local_embedding_supports_bgem3_model_id() { + let config = crate::embedding_config::LocalEmbeddingConfig::bgem3() + .expect("bgem3 config should load"); assert_eq!(config.max_length, 512); } #[test] - fn mxbai_query_text_uses_retrieval_prefix() { - assert_eq!( - mxbai_query_text("hello"), - "Represent this sentence for searching relevant passages: hello" - ); + fn embed_query_trims_without_instruction_prefix() { + let trimmed = " hello ".trim().to_string(); + assert_eq!(trimmed, "hello"); } } diff --git a/rust/preferences.rs b/rust/preferences.rs index f84fc6f..4b0245d 100644 --- a/rust/preferences.rs +++ b/rust/preferences.rs @@ -11,7 +11,7 @@ use tui_kit_host::settings::SettingsError; #[cfg(not(test))] use tui_kit_host::settings::{load_yaml_or_default, save_yaml}; -use crate::embedding_config::{API_EMBEDDING_BACKEND_ID, MXBAI_EMBEDDING_BACKEND_ID}; +use crate::embedding_config::{API_EMBEDDING_BACKEND_ID, BGEM3_EMBEDDING_BACKEND_ID}; #[cfg(not(test))] const APP_NAMESPACE: &str = "kinic"; @@ -181,7 +181,7 @@ pub fn chat_diversity_display(value: u8) -> String { pub fn normalize_embedding_model_id(value: String) -> String { match value.trim() { API_EMBEDDING_BACKEND_ID => API_EMBEDDING_BACKEND_ID.to_string(), - MXBAI_EMBEDDING_BACKEND_ID => MXBAI_EMBEDDING_BACKEND_ID.to_string(), + BGEM3_EMBEDDING_BACKEND_ID => BGEM3_EMBEDDING_BACKEND_ID.to_string(), _ => default_embedding_model_id(), } } @@ -306,6 +306,19 @@ embedding_model_id: " " assert_eq!(normalized.embedding_model_id, DEFAULT_EMBEDDING_MODEL_ID); } + #[test] + fn user_preferences_normalizes_legacy_mxbai_embedding_model_id_to_default() { + let preferences: UserPreferences = serde_yaml::from_str( + r#" +embedding_model_id: "mixedbread-ai/mxbai-embed-large-v1" +"#, + ) + .expect("legacy embedding model should deserialize"); + + let normalized = normalize_user_preferences(preferences); + assert_eq!(normalized.embedding_model_id, DEFAULT_EMBEDDING_MODEL_ID); + } + #[test] fn user_preferences_rejects_invalid_saved_tags_type() { let result = serde_yaml::from_str::( diff --git a/rust/tui/settings_tests.rs b/rust/tui/settings_tests.rs index 16611d1..0a35c82 100644 --- a/rust/tui/settings_tests.rs +++ b/rust/tui/settings_tests.rs @@ -291,7 +291,7 @@ fn settings_snapshot_projects_chat_retrieval_section() { chat_overall_top_k: 10, chat_per_memory_cap: 4, chat_mmr_lambda: 80, - embedding_model_id: "mixedbread-ai/mxbai-embed-large-v1".to_string(), + embedding_model_id: "BAAI/bge-m3".to_string(), ..UserPreferences::default() }, &Vec::new(), @@ -317,7 +317,7 @@ fn settings_snapshot_projects_chat_retrieval_section() { "Saved preferences", SETTINGS_ENTRY_EMBEDDING_MODEL_ID ), - "Mixedbread MXBAI Embed Large V1 (1024)" + "BAAI BGE-M3 (1024)" 
); } diff --git a/tests/kinic_cli_prefs.rs b/tests/kinic_cli_prefs.rs index 79eae49..603f572 100644 --- a/tests/kinic_cli_prefs.rs +++ b/tests/kinic_cli_prefs.rs @@ -308,7 +308,7 @@ fn prefs_set_embedding_backend_updates_yaml_and_show_output() { fs::create_dir_all(&kinic_dir).unwrap(); fs::write( kinic_dir.join("tui.yaml"), - "embedding_model_id: mixedbread-ai/mxbai-embed-large-v1\n", + "embedding_model_id: BAAI/bge-m3\n", ) .unwrap(); @@ -359,6 +359,33 @@ fn prefs_set_embedding_backend_updates_yaml_and_show_output() { assert_eq!(parsed["embedding_model_id"], json!("api")); } +#[test] +fn prefs_set_embedding_backend_accepts_bgem3() { + let config_dir = temp_config_dir("embedding-backend-bgem3"); + + let output = prefs_command(&config_dir) + .args([ + "prefs", + "set-embedding-backend", + "--model-id", + "BAAI/bge-m3", + ]) + .output() + .unwrap(); + assert!(output.status.success()); + let parsed: serde_json::Value = + serde_json::from_slice(&output.stdout).expect("json response should parse"); + assert_eq!( + parsed, + json!({ + "resource": "embedding_model_id", + "action": "set", + "status": "updated", + "value": "BAAI/bge-m3" + }) + ); +} + #[test] fn prefs_set_embedding_backend_normalizes_invalid_values_to_api() { let config_dir = temp_config_dir("embedding-backend-normalize"); diff --git a/tests/kinic_cli_tui.rs b/tests/kinic_cli_tui.rs index e979039..c46f708 100644 --- a/tests/kinic_cli_tui.rs +++ b/tests/kinic_cli_tui.rs @@ -32,7 +32,7 @@ fn prefs_help_mentions_json_contract_and_examples() { assert!(stdout.contains("All prefs commands return JSON.")); assert!(stdout.contains("show -> {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer, \"embedding_model_id\": string}")); assert!(stdout.contains("kinic-cli prefs set-default-memory --memory-id")); - assert!(stdout.contains("kinic-cli prefs set-embedding-backend --model-id api")); + assert!(stdout.contains("kinic-cli prefs set-embedding-backend --model-id BAAI/bge-m3")); } #[test]
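
For reference, switching to the local backend and reading it back uses only the commands documented in this series; a minimal sketch, assuming the backend was previously `api` (the first JSON line is the prefs mutation contract from the help text, `status` becomes `"unchanged"` if the value is already selected, and the `prefs show` output is abridged here to the relevant field):

    kinic-cli prefs set-embedding-backend --model-id BAAI/bge-m3
    # -> {"resource": "embedding_model_id", "action": "set", "status": "updated", "value": "BAAI/bge-m3"}
    kinic-cli prefs show
    # -> {..., "embedding_model_id": "BAAI/bge-m3"}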