diff --git a/.gitignore b/.gitignore index 3c38649..15e55f8 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,12 @@ pylate-index/ #web .next/ node_modules/ + +# local app artifacts +apps/kinic-portal/.env.local +apps/kinic-portal/.dev.vars +apps/kinic-portal/.cache/ +apps/kinic-portal/.wrangler/ +apps/kinic-portal/tsconfig.tsbuildinfo +apps/kinic-portal/workers/public-api/.dev.vars +apps/kinic-portal/workers/public-api/.wrangler/ diff --git a/Cargo.lock b/Cargo.lock index ef2bee9..b590495 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,11 +24,22 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" dependencies = [ "cfg-if", + "getrandom 0.3.4", "once_cell", + "serde", "version_check", "zerocopy", ] +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + [[package]] name = "allocator-api2" version = "0.2.21" @@ -528,6 +539,7 @@ dependencies = [ "itoa", "rustversion", "ryu", + "serde", "static_assertions", ] @@ -540,6 +552,19 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "unicode-width 0.2.0", + "windows-sys 0.59.0", +] + [[package]] name = "const-oid" version = "0.9.6" @@ -599,6 +624,25 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -712,6 +756,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "darling" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +dependencies = [ + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + [[package]] name = "darling" version = "0.21.3" @@ -732,6 +786,20 @@ dependencies = [ "darling_macro 0.23.0", ] +[[package]] +name = "darling_core" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 2.0.109", +] + [[package]] name = "darling_core" version = "0.21.3" @@ -759,6 +827,17 @@ dependencies = [ "syn 2.0.109", ] +[[package]] +name = "darling_macro" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" +dependencies = [ + "darling_core 0.20.11", + "quote", + "syn 2.0.109", +] + [[package]] name = "darling_macro" version = "0.21.3" @@ -781,6 +860,15 @@ dependencies = [ "syn 2.0.109", ] +[[package]] +name = "dary_heap" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06d2e3287df1c007e74221c49ca10a95d557349e54b3a75dc2fb14712c751f04" +dependencies = [ + "serde", +] + [[package]] name = "data-encoding" version = "2.9.0" @@ -829,6 +917,37 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling 0.20.11", + "proc-macro2", + "quote", + "syn 2.0.109", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn 2.0.109", +] + [[package]] name = "derive_more" version = "2.1.1" @@ -1007,6 +1126,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "encoding_rs" version = "0.8.35" @@ -1032,6 +1157,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "esaxx-rs" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" + [[package]] name = "euclid" version = "0.20.14" @@ -1068,6 +1199,22 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "fastembed" +version = "5.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f54fc1188b7f7eac8f47be2ab7b3a79ffd842cc8ff2e38316dd59ba4858890e" +dependencies = [ + "anyhow", + "hf-hub", + "ndarray", + "ort", + "safetensors", + "serde", + "serde_json", + "tokenizers", +] + [[package]] name = "fastrand" version = "2.3.0" @@ -1352,6 +1499,8 @@ dependencies = [ "allocator-api2", "equivalent", "foldhash 0.2.0", + "serde", + "serde_core", ] [[package]] @@ -1369,6 +1518,26 @@ dependencies = [ "serde", ] +[[package]] +name = "hf-hub" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "629d8f3bbeda9d148036d6b0de0a3ab947abd08ce90626327fc3547a49d59d97" +dependencies = [ + "dirs", + "http", + "indicatif", + "libc", + "log", + "rand 0.9.2", + "reqwest", + "serde", + "serde_json", + "thiserror 2.0.17", + "ureq 2.12.1", + "windows-sys 0.60.2", +] + [[package]] name = "hkdf" version = "0.12.4" @@ -1387,6 +1556,12 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "hmac-sha256" +version = "1.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec9d92d097f4749b64e8cc33d924d9f40a2d4eb91402b458014b781f5733d60f" + [[package]] name = "home" version = "0.5.12" @@ -1478,7 +1653,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", - "webpki-roots", + "webpki-roots 1.0.4", ] [[package]] @@ -1829,6 +2004,19 @@ dependencies = [ "hashbrown 0.16.1", ] +[[package]] +name = "indicatif" +version = "0.17.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" +dependencies = [ + "console", + "number_prefix", + "portable-atomic", + "unicode-width 0.2.0", + "web-time", +] + [[package]] name = "indoc" version = "2.0.7" @@ -2007,6 +2195,7 @@ dependencies = [ "clap", "der", "dotenvy", + "fastembed", "gag", "hex", "ic-agent", @@ -2187,6 +2376,28 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" +[[package]] +name = "lzma-rust2" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1670343e58806300d87950e3401e820b519b9384281bbabfb15e3636689ffd69" + +[[package]] +name = "macro_rules_attribute" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65049d7923698040cd0b1ddcced9b0eb14dd22c5f86ae59c3740eab64a676520" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" + [[package]] name = "malloc_buf" version = "0.0.6" @@ -2202,6 +2413,16 @@ version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "md-5" version = "0.10.6" @@ -2293,6 +2514,43 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "monostate" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3341a273f6c9d5bef1908f17b7267bbab0e95c9bf69a0d4dcf8e9e1b2c76ef67" +dependencies = [ + "monostate-impl", + "serde", + "serde_core", +] + +[[package]] +name = "monostate-impl" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4db6d5580af57bf992f59068d4ea26fd518574ff48d7639b255a36f9de6e7e9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.109", +] + +[[package]] +name = "ndarray" +version = "0.17.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520080814a7a6b4a6e9070823bb24b4531daac8c4627e08ba5de8c5ef2f2752d" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "ndk-context" version = "0.1.1" @@ -2329,6 +2587,15 @@ dependencies = [ "serde", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" @@ -2354,6 +2621,12 @@ dependencies = [ "libm", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "objc" version = "0.2.7" @@ -2384,6 +2657,28 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "onig" +version = "6.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "336b9c63443aceef14bea841b899035ae3abe89b7c486aaf4c5bd8aafedac3f0" +dependencies = [ + "bitflags", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f86c6eef3d6df15f23bcfb6af487cbd2fed4e5581d58d5bf1f5f8b7f6727dc" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "opaque-debug" version = "0.3.1" @@ -2444,6 +2739,30 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "ort" +version = "2.0.0-rc.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5df903c0d2c07b56950f1058104ab0c8557159f2741782223704de9be73c3c" +dependencies = [ + "ndarray", + "ort-sys", + "smallvec", + "tracing", + "ureq 3.3.0", +] + +[[package]] +name = "ort-sys" +version = "2.0.0-rc.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06503bb33f294c5f1ba484011e053bfa6ae227074bdb841e9863492dc5960d4b" +dependencies = [ + "hmac-sha256", + "lzma-rust2", + "ureq 3.3.0", +] + [[package]] name = "p256" version = "0.13.2" @@ -2590,6 +2909,15 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +[[package]] +name = "portable-atomic-util" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "091397be61a01d4be58e7841595bd4bfedb15f1cd54977d79b8271e94ed799a3" +dependencies = [ + "portable-atomic", +] + [[package]] name = "postscript" version = "0.14.1" @@ -2945,6 +3273,43 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9" +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-cond" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2964d0cf57a3e7a06e8183d14a8b527195c706b7983549cd5462d5aa3747438f" +dependencies = [ + "either", + "itertools 0.14.0", + "rayon", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -2985,6 +3350,35 @@ dependencies = [ "syn 2.0.109", ] +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc897dd8d9e8bd1ed8cdad82b5966c3e0ecae09fb1907d58efaa013543185d0a" + [[package]] name = "reqwest" version = "0.12.24" @@ -3024,7 +3418,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", - "webpki-roots", + "webpki-roots 1.0.4", ] [[package]] @@ -3132,6 +3526,7 @@ version = "0.23.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "533f54bc6a7d4f647e46ad909549eda97bf5afc1585190ef692b4286b198bd8f" dependencies = [ + "log", "once_cell", "ring", "rustls-pki-types", @@ -3173,6 +3568,17 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "safetensors" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "675656c1eabb620b921efea4f9199f97fc86e36dd6ffd1fbbe48d0f59a4987f5" +dependencies = [ + "hashbrown 0.16.1", + "serde", + "serde_json", +] + [[package]] name = "same-file" version = "1.0.6" @@ -3499,6 +3905,17 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "spki" version = "0.7.3" @@ -3509,6 +3926,18 @@ dependencies = [ "der", ] +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom", + "serde", + "unicode-segmentation", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -3773,6 +4202,39 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokenizers" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b238e22d44a15349529690fb07bd645cf58149a1b1e44d6cb5bd1641ff1a6223" +dependencies = [ + "ahash", + "aho-corasick", + "compact_str 0.9.0", + "dary_heap", + "derive_builder", + "esaxx-rs", + "getrandom 0.3.4", + "itertools 0.14.0", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand 0.9.2", + "rayon", + "rayon-cond", + "regex", + "regex-syntax", + "serde", + "serde_json", + "spm_precompiled", + "thiserror 2.0.17", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.48.0" @@ -4021,6 +4483,15 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec", +] + [[package]] name = "unicode-segmentation" version = "1.12.0" @@ -4061,6 +4532,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fc81956842c57dac11422a97c3b8195a1ff727f06e85c84ed2e8aa277c9a0fd" +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + [[package]] name = "unindent" version = "0.2.4" @@ -4079,6 +4556,54 @@ version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "2.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02d1a66277ed75f640d608235660df48c8e3c19f3b4edb6a263315626cc3c01d" +dependencies = [ + "base64 0.22.1", + "flate2", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "socks", + "url", + "webpki-roots 0.26.11", +] + +[[package]] +name = "ureq" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dea7109cdcd5864d4eeb1b58a1648dc9bf520360d7af16ec26d0a9354bafcfc0" +dependencies = [ + "base64 0.22.1", + "log", + "percent-encoding", + "rustls", + "rustls-pki-types", + "socks", + "ureq-proto", + "utf8-zero", + "webpki-roots 1.0.4", +] + +[[package]] +name = "ureq-proto" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e994ba84b0bd1b1b0cf92878b7ef898a5c1760108fe7b6010327e274917a808c" +dependencies = [ + "base64 0.22.1", + "http", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.7" @@ -4091,6 +4616,12 @@ dependencies = [ "serde", ] +[[package]] +name = "utf8-zero" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8c0a043c9540bae7c578c88f91dda8bd82e59ae27c21baca69c8b191aaf5a6e" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -4263,6 +4794,15 @@ dependencies = [ "web-sys", ] +[[package]] +name = "webpki-roots" +version = "0.26.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "521bc38abb08001b01866da9f51eb7c5d647a19260e00054a8c7fd5f9e57f7a9" +dependencies = [ + "webpki-roots 1.0.4", +] + [[package]] name = "webpki-roots" version = "1.0.4" diff --git a/Cargo.toml b/Cargo.toml index 086a82f..09290d0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,6 +39,7 @@ ring = "0.17.14" der = "0.7.10" pkcs8 = "0.10.2" ic-ed25519 = "0.2.0" +fastembed = { version = "5.13.2", default-features = false, features = ["ort-download-binaries-rustls-tls", "hf-hub-rustls-tls"] } kinic-core = { path = "crates/kinic-core" } tui-kit-host = { path = "tui/crates/tui-kit-host" } tui-kit-runtime = { path = "tui/crates/tui-kit-runtime" } diff --git a/README.md b/README.md index 60de5b0..09d3fa3 100644 --- a/README.md +++ b/README.md @@ -384,7 +384,7 @@ See `python/examples/insert_pdf_file.py` for a runnable script. ## Ask AI -Runs a search and prepares context for an AI answer. The CLI calls `/chat` at `EMBEDDING_API_ENDPOINT` (default `https://api.kinic.io`) and prints only the `` text. +Runs a search with the configured embedding backend, prepares context for an AI answer, then calls `/chat` at `EMBEDDING_API_ENDPOINT` (default `https://api.kinic.io`) and prints only the `` text. ```python prompt, answer = km.ask_ai(memory_id, "What did we say about quarterly goals?", top_k=3, language="en") @@ -395,6 +395,28 @@ print("Answer:\n", answer) - `km.ask_ai` returns `(prompt, answer)` where `answer` is the `` section from the chat response. - CLI usage: `cargo run -- --identity ask-ai --memory-id --query "" --top-k 3` +Embedding backend initialization: + +- initial saved backend: `api` +- API dimension: `1024` +- local option example: `BAAI/bge-m3` +- local cache dir: `$HOME/.cache/kinic-cli/embeddings` +- new memories are created with dimension `1024` + +Shared settings behavior: + +- if the shared config directory is available and `tui.yaml` is missing, Kinic initializes with the saved default `api` +- if the shared config directory is unavailable, embedding-backed commands fail explicitly instead of falling back +- local search/insert against older non-`1024` memories fails explicitly until the memory is reset or reindexed + +Local runtime overrides: + +- `KINIC_LOCAL_EMBEDDING_CACHE_DIR` +- `KINIC_LOCAL_EMBEDDING_MAX_LENGTH` +- `KINIC_LOCAL_EMBEDDING_CHUNK_SOFT_LIMIT` +- `KINIC_LOCAL_EMBEDDING_CHUNK_HARD_LIMIT` +- `KINIC_LOCAL_EMBEDDING_CHUNK_OVERLAP` + --- ## Configure memory visibility diff --git a/rust/cli_defs.rs b/rust/cli_defs.rs index 4a9c80d..963e1bd 100644 --- a/rust/cli_defs.rs +++ b/rust/cli_defs.rs @@ -116,7 +116,7 @@ pub enum Command { Capabilities(CapabilitiesArgs), #[command( about = "Manage local Kinic preferences shared with the TUI. All prefs commands return JSON.", - after_help = "Examples:\n kinic-cli prefs show\n kinic-cli prefs set-default-memory --memory-id MEMORY_CANISTER_ID\n kinic-cli prefs set-chat-overall-top-k --value 10\n\nReturns:\n show -> {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer}\n mutations -> {\"resource\": string, \"action\": string, \"status\": \"updated\"|\"unchanged\", \"value\": string|integer|null}" + after_help = "Examples:\n kinic-cli prefs show\n kinic-cli prefs set-default-memory --memory-id MEMORY_CANISTER_ID\n kinic-cli prefs set-embedding-backend --model-id BAAI/bge-m3\n kinic-cli prefs set-chat-overall-top-k --value 10\n\nReturns:\n show -> {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer, \"embedding_model_id\": string}\n mutations -> {\"resource\": string, \"action\": string, \"status\": \"updated\"|\"unchanged\", \"value\": string|integer|null}" )] Prefs(PrefsArgs), #[command( @@ -411,7 +411,7 @@ pub struct PrefsArgs { pub enum PrefsCommand { #[command( about = "Show local preferences shared with the TUI. Returns JSON.", - after_help = "Returns:\n {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer}\n\nExample:\n kinic-cli prefs show" + after_help = "Returns:\n {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer, \"embedding_model_id\": string}\n\nExample:\n kinic-cli prefs show" )] Show, #[command( @@ -459,6 +459,11 @@ pub enum PrefsCommand { after_help = "Returns:\n {\"resource\": \"chat_mmr_lambda\", \"action\": \"set\", \"status\": \"updated\"|\"unchanged\", \"value\": integer}\n\nExample:\n kinic-cli prefs set-chat-mmr-lambda --value 80" )] SetChatMmrLambda(ChatMmrLambdaArgs), + #[command( + about = "Set the embedding backend shared with the TUI. Returns JSON.", + after_help = "Returns:\n {\"resource\": \"embedding_model_id\", \"action\": \"set\", \"status\": \"updated\"|\"unchanged\", \"value\": string}\n\nExample:\n kinic-cli prefs set-embedding-backend --model-id BAAI/bge-m3" + )] + SetEmbeddingBackend(EmbeddingBackendArgs), } #[derive(Args, Debug)] @@ -533,6 +538,16 @@ pub struct ChatMmrLambdaArgs { pub value: u8, } +#[derive(Args, Debug)] +pub struct EmbeddingBackendArgs { + #[arg( + long, + required = true, + help = "Embedding backend id shared with the TUI, e.g. api or BAAI/bge-m3" + )] + pub model_id: String, +} + #[derive(Args, Debug)] pub struct UpdateArgs { #[arg( diff --git a/rust/clients/launcher.rs b/rust/clients/launcher.rs index 6b3a1ac..4491033 100644 --- a/rust/clients/launcher.rs +++ b/rust/clients/launcher.rs @@ -13,9 +13,10 @@ use icrc_ledger_types::{ use serde_json::json; use thiserror::Error; -use crate::clients::{LAUNCHER_CANISTER, LEDGER_CANISTER}; - -const DEFAULT_VECTOR_DIM: u64 = 1024; +use crate::{ + clients::{LAUNCHER_CANISTER, LEDGER_CANISTER}, + embedding_config::create_memory_dimension_u64, +}; const APPROVAL_TTL_NS: u64 = 10 * 60 * 1_000_000_000; pub struct LauncherClient { @@ -131,7 +132,10 @@ fn encode_deploy_args(name: &str, description: &str) -> Result> { "name": name, "description": description}) .to_string(); - Ok(candid::encode_args((payload, DEFAULT_VECTOR_DIM))?) + Ok(candid::encode_args(( + payload, + create_memory_dimension_u64(), + ))?) } fn encode_update_instance_args(instance_pid_str: &str) -> Result> { diff --git a/rust/commands/ask_ai.rs b/rust/commands/ask_ai.rs index 263e3df..ad2ab62 100644 --- a/rust/commands/ask_ai.rs +++ b/rust/commands/ask_ai.rs @@ -2,14 +2,13 @@ use std::cmp::Ordering; use anyhow::{Context, Result}; use ic_agent::export::Principal; -use reqwest::Client; use tracing::info; use crate::{ agent::AgentFactory, cli::AskAiArgs, clients::memory::MemoryClient, - embedding::{embedding_base_url, fetch_embedding}, + embedding::{call_chat_http, ensure_memory_dim_matches, fetch_embedding}, prompt_utils::{escape_xml, prompt_language_instruction}, }; @@ -20,8 +19,6 @@ const MAX_RESULTS: usize = 5; const MAX_HITS_PER_DOC: usize = 6; const MAX_HIT_LEN: usize = 600; const MAX_FULL_LEN: usize = 4096; -const CHAT_PATH: &str = "/chat"; - pub struct AskAiResult { pub prompt: String, pub response: String, @@ -67,6 +64,7 @@ pub async fn ask_ai_flow( let client = MemoryClient::new(agent, *memory_id); let embedding = fetch_embedding(query).await?; + ensure_memory_dim_matches(&client, &memory_id.to_text(), embedding.len()).await?; let mut results = client.search(embedding).await?; results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal)); @@ -84,24 +82,7 @@ pub async fn ask_ai_flow( } pub(crate) async fn call_chat_endpoint(prompt: &str) -> Result { - let url = format!("{}{}", embedding_base_url(), CHAT_PATH); - let response = Client::new() - .post(url) - .json(&ChatRequest { message: prompt }) - .send() - .await - .context("Failed to call chat endpoint")?; - - if !response.status().is_success() { - let status = response.status(); - let body = response.text().await.unwrap_or_default(); - anyhow::bail!("chat endpoint returned {status}: {body}"); - } - - let body = response - .text() - .await - .context("Failed to read chat response")?; + let body = call_chat_http(prompt).await?; let mut acc = String::new(); for line in body.lines() { @@ -141,11 +122,6 @@ struct SearchResult { hits: Vec, } -#[derive(serde::Serialize)] -struct ChatRequest<'a> { - message: &'a str, -} - #[derive(serde::Deserialize)] struct ChatChunk { content: Option, diff --git a/rust/commands/prefs.rs b/rust/commands/prefs.rs index ebc8bcb..1d169cd 100644 --- a/rust/commands/prefs.rs +++ b/rust/commands/prefs.rs @@ -14,10 +14,12 @@ use serde::Serialize; use crate::{ cli::{ - AddMemoryArgs, ChatMmrLambdaArgs, ChatOverallTopKArgs, ChatPerMemoryCapArgs, GlobalOpts, - MemoryIdArgs, PrefsArgs, PrefsCommand, SetDefaultMemoryArgs, TagArgs, + AddMemoryArgs, ChatMmrLambdaArgs, ChatOverallTopKArgs, ChatPerMemoryCapArgs, + EmbeddingBackendArgs, GlobalOpts, MemoryIdArgs, PrefsArgs, PrefsCommand, + SetDefaultMemoryArgs, TagArgs, }, clients::memory::MemoryClient, + embedding_config::normalize_supported_embedding_backend_id, preferences::{self, UserPreferences}, }; @@ -33,6 +35,7 @@ pub async fn handle(args: PrefsArgs, global: &GlobalOpts) -> Result<()> { PrefsCommand::SetChatOverallTopK(args) => set_chat_overall_top_k(args), PrefsCommand::SetChatPerMemoryCap(args) => set_chat_per_memory_cap(args), PrefsCommand::SetChatMmrLambda(args) => set_chat_mmr_lambda(args), + PrefsCommand::SetEmbeddingBackend(args) => set_embedding_backend(args), } } @@ -200,6 +203,18 @@ fn set_chat_mmr_lambda(args: ChatMmrLambdaArgs) -> Result<()> { print_json_response(PrefsResponse::updated("chat_mmr_lambda", "set", value)) } +fn set_embedding_backend(args: EmbeddingBackendArgs) -> Result<()> { + let value = validate_embedding_backend(args.model_id.as_str()); + let mut preferences = load_preferences()?; + if preferences.embedding_model_id == value { + return print_json_response(PrefsResponse::unchanged("embedding_model_id", "set", value)); + } + + preferences.embedding_model_id = value.clone(); + save_preferences(&preferences)?; + print_json_response(PrefsResponse::updated("embedding_model_id", "set", value)) +} + fn load_preferences() -> Result { preferences::load_user_preferences().context("Failed to load shared TUI preferences") } @@ -267,6 +282,10 @@ fn validate_chat_mmr_lambda(value: u8) -> Result { } } +fn validate_embedding_backend(value: &str) -> String { + normalize_supported_embedding_backend_id(value).to_string() +} + fn display_options_usize(values: &[usize]) -> String { values .iter() @@ -291,6 +310,7 @@ struct ShowPreferences { chat_overall_top_k: usize, chat_per_memory_cap: usize, chat_mmr_lambda: u8, + embedding_model_id: String, } impl From for ShowPreferences { @@ -302,6 +322,7 @@ impl From for ShowPreferences { chat_overall_top_k: value.chat_overall_top_k, chat_per_memory_cap: value.chat_per_memory_cap, chat_mmr_lambda: value.chat_mmr_lambda, + embedding_model_id: value.embedding_model_id, } } } @@ -407,6 +428,7 @@ mod tests { chat_overall_top_k: DEFAULT_CHAT_OVERALL_TOP_K, chat_per_memory_cap: DEFAULT_CHAT_PER_MEMORY_CAP, chat_mmr_lambda: DEFAULT_CHAT_MMR_LAMBDA, + embedding_model_id: "invalid".to_string(), }; let normalized = preferences::normalize_user_preferences(preferences); @@ -423,6 +445,21 @@ mod tests { assert_eq!(normalized.chat_overall_top_k, DEFAULT_CHAT_OVERALL_TOP_K); assert_eq!(normalized.chat_per_memory_cap, DEFAULT_CHAT_PER_MEMORY_CAP); assert_eq!(normalized.chat_mmr_lambda, DEFAULT_CHAT_MMR_LAMBDA); + assert_eq!( + normalized.embedding_model_id, + preferences::default_embedding_model_id() + ); + } + + #[test] + fn show_preferences_preserves_bgem3_embedding_model_id() { + let serialized = serde_json::to_value(ShowPreferences::from(UserPreferences { + embedding_model_id: "BAAI/bge-m3".to_string(), + ..UserPreferences::default() + })) + .expect("show preferences should serialize"); + + assert_eq!(serialized["embedding_model_id"], "BAAI/bge-m3"); } #[test] @@ -439,6 +476,10 @@ mod tests { DEFAULT_CHAT_PER_MEMORY_CAP ); assert_eq!(serialized["chat_mmr_lambda"], DEFAULT_CHAT_MMR_LAMBDA); + assert_eq!( + serialized["embedding_model_id"], + preferences::default_embedding_model_id() + ); } #[test] @@ -483,6 +524,14 @@ mod tests { ); } + #[test] + fn validate_embedding_backend_normalizes_unknown_values_to_api() { + assert_eq!( + validate_embedding_backend("unsupported"), + preferences::default_embedding_model_id() + ); + } + #[test] fn prefs_response_skips_memory_name_by_default() { let json = serde_json::to_value(PrefsResponse::updated( diff --git a/rust/commands/search.rs b/rust/commands/search.rs index e1bc029..d4b3dd7 100644 --- a/rust/commands/search.rs +++ b/rust/commands/search.rs @@ -13,7 +13,7 @@ use tracing::info; use crate::{ cli::SearchArgs, clients::{launcher::LauncherClient, memory::MemoryClient}, - embedding::fetch_embedding, + embedding::{ensure_memory_dim_matches, fetch_embedding}, shared::cross_memory_search::{ SearchHit, collect_searchable_memory_ids, fold_search_batches, searchable_memory_id_from_state, sort_search_hits, @@ -142,6 +142,7 @@ pub(crate) async fn search_single_memory_items( embedding: Vec, ) -> Result> { let client = build_memory_client(agent, &memory_id)?; + ensure_memory_dim_matches(&client, &memory_id, embedding.len()).await?; let mut rows = client .search(embedding) .await diff --git a/rust/commands/search_raw.rs b/rust/commands/search_raw.rs index d98a38c..2f21a54 100644 --- a/rust/commands/search_raw.rs +++ b/rust/commands/search_raw.rs @@ -1,13 +1,18 @@ use anyhow::{Context, Result, bail}; use tracing::info; -use crate::{cli::SearchRawArgs, memory_client_builder::build_memory_client}; +use crate::{ + cli::SearchRawArgs, embedding::ensure_vector_dim_matches, + memory_client_builder::build_memory_client, +}; use super::CommandContext; pub async fn handle(args: SearchRawArgs, ctx: &CommandContext) -> Result<()> { let client = build_memory_client(&ctx.agent_factory, &args.memory_id).await?; let embedding = parse_embedding(&args.embedding)?; + let expected_dim = client.get_dim().await?; + ensure_vector_dim_matches(&args.memory_id, embedding.len(), expected_dim)?; let mut results = client.search(embedding).await?; results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal)); diff --git a/rust/embedding.rs b/rust/embedding.rs index 4331903..4d1e2ef 100644 --- a/rust/embedding.rs +++ b/rust/embedding.rs @@ -1,10 +1,19 @@ +//! Where: shared embedding facade used by CLI, TUI, MCP, and Python bindings. +//! What: routes query embeddings and late chunking to either the remote API or a local model. +//! Why: old API-backed memories must remain usable while local models stay opt-in. + use std::env; use anyhow::{Context, Result, bail}; use reqwest::Client; use serde::{Deserialize, Serialize}; -use crate::operation_timeout::embedding_request_timeout; +use crate::{ + clients::memory::MemoryClient, + embedding_config::{API_EMBEDDING_BACKEND_ID, selected_embedding_backend_id}, + local_embedding, + operation_timeout::embedding_request_timeout, +}; pub(crate) const EMBEDDING_API_ENV_VAR: &str = "EMBEDDING_API_ENDPOINT"; pub(crate) const DEFAULT_EMBEDDING_API_ENDPOINT: &str = "https://api.kinic.io"; @@ -12,6 +21,69 @@ const LATE_CHUNKING_PATH: &str = "/late-chunking"; const EMBEDDING_PATH: &str = "/embedding"; pub async fn late_chunking(text: &str) -> Result> { + if selected_embedding_backend_id()? == API_EMBEDDING_BACKEND_ID { + return late_chunking_remote(text).await; + } + local_embedding::late_chunk_and_embed(text).await +} + +pub async fn fetch_embedding(text: &str) -> Result> { + if selected_embedding_backend_id()? == API_EMBEDDING_BACKEND_ID { + return fetch_embedding_remote(text).await; + } + local_embedding::embed_query(text).await +} + +pub(crate) async fn ensure_memory_dim_matches( + client: &MemoryClient, + memory_id: &str, + provided_dim: usize, +) -> Result { + let actual_dim = client + .get_dim() + .await + .context("Failed to load memory embedding dimension")?; + ensure_vector_dim_matches(memory_id, provided_dim, actual_dim)?; + Ok(actual_dim) +} + +pub(crate) fn ensure_vector_dim_matches( + memory_id: &str, + provided_dim: usize, + expected_dim: u64, +) -> Result<()> { + if provided_dim == expected_dim as usize { + return Ok(()); + } + bail!( + "Embedding dimension mismatch for memory {memory_id}. Provided {provided_dim}, expected {expected_dim}. Reindex or reset the memory before searching or inserting." + ); +} + +pub(crate) fn embedding_base_url() -> String { + env::var(EMBEDDING_API_ENV_VAR).unwrap_or_else(|_| DEFAULT_EMBEDDING_API_ENDPOINT.to_string()) +} + +pub(crate) async fn call_chat_http(prompt: &str) -> Result { + let url = format!("{}/chat", embedding_base_url()); + let response = Client::new() + .post(url) + .json(&ChatRequest { message: prompt }) + .send() + .await + .context("Failed to call chat endpoint")?; + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + bail!("chat endpoint returned {status}: {body}"); + } + response + .text() + .await + .context("Failed to read chat response") +} + +async fn late_chunking_remote(text: &str) -> Result> { let url = format!("{}{}", embedding_base_url(), LATE_CHUNKING_PATH); let timeout = embedding_request_timeout(text.len()); let response = Client::new() @@ -30,7 +102,7 @@ pub async fn late_chunking(text: &str) -> Result> { Ok(payload.chunks) } -pub async fn fetch_embedding(text: &str) -> Result> { +async fn fetch_embedding_remote(text: &str) -> Result> { let url = format!("{}{}", embedding_base_url(), EMBEDDING_PATH); let timeout = embedding_request_timeout(text.len()); let response = Client::new() @@ -59,10 +131,6 @@ async fn ensure_success(response: reqwest::Response) -> Result String { - env::var(EMBEDDING_API_ENV_VAR).unwrap_or_else(|_| DEFAULT_EMBEDDING_API_ENDPOINT.to_string()) -} - #[derive(Serialize)] struct LateChunkingRequest<'a> { markdown: &'a str, @@ -79,6 +147,11 @@ pub struct LateChunk { pub sentence: String, } +#[derive(Serialize)] +struct ChatRequest<'a> { + message: &'a str, +} + #[derive(Serialize)] struct EmbeddingRequest<'a> { content: &'a str, diff --git a/rust/embedding_config.rs b/rust/embedding_config.rs new file mode 100644 index 0000000..bd458b6 --- /dev/null +++ b/rust/embedding_config.rs @@ -0,0 +1,331 @@ +//! Where: shared by embedding routing, local inference, memory creation defaults, and dimension guards. +//! What: centralizes the selected embedding backend plus local model cache and chunking parameters. +//! Why: create/search/insert must resolve the same backend and dimension without drift. + +use std::{env, path::PathBuf}; + +use anyhow::{Context, Result, anyhow, bail}; +use fastembed::{EmbeddingModel, TextInitOptions}; +use tui_kit_host::settings::SettingsError; + +use crate::preferences; + +const CACHE_DIR_ENV_VAR: &str = "KINIC_LOCAL_EMBEDDING_CACHE_DIR"; +const MAX_LENGTH_ENV_VAR: &str = "KINIC_LOCAL_EMBEDDING_MAX_LENGTH"; +const CHUNK_SOFT_LIMIT_ENV_VAR: &str = "KINIC_LOCAL_EMBEDDING_CHUNK_SOFT_LIMIT"; +const CHUNK_HARD_LIMIT_ENV_VAR: &str = "KINIC_LOCAL_EMBEDDING_CHUNK_HARD_LIMIT"; +const CHUNK_OVERLAP_ENV_VAR: &str = "KINIC_LOCAL_EMBEDDING_CHUNK_OVERLAP"; +const DEFAULT_CACHE_DIR: &str = ".cache/kinic-cli/embeddings"; +const DEFAULT_MAX_LENGTH: usize = 512; +const DEFAULT_CHUNK_SOFT_LIMIT: usize = 800; +const DEFAULT_CHUNK_HARD_LIMIT: usize = 1200; +const DEFAULT_CHUNK_OVERLAP: usize = 120; +pub(crate) const API_EMBEDDING_BACKEND_ID: &str = "api"; +const API_EMBEDDING_BACKEND_LABEL: &str = "API (remote default)"; +const API_EMBEDDING_DIMENSION: usize = 1024; +pub(crate) const BGEM3_EMBEDDING_BACKEND_ID: &str = "BAAI/bge-m3"; +const BGEM3_EMBEDDING_BACKEND_LABEL: &str = "BAAI BGE-M3"; +const BGEM3_EMBEDDING_DIMENSION: usize = 1024; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct SupportedEmbeddingBackend { + pub id: &'static str, + pub label: &'static str, + pub dimension: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ChunkingConfig { + pub soft_limit: usize, + pub hard_limit: usize, + pub overlap: usize, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct LocalEmbeddingConfig { + pub cache_dir: PathBuf, + pub max_length: usize, + pub chunking: ChunkingConfig, +} + +impl LocalEmbeddingConfig { + pub(crate) fn bgem3() -> Result { + let max_length = env_usize(MAX_LENGTH_ENV_VAR, DEFAULT_MAX_LENGTH)?; + let soft_limit = env_usize(CHUNK_SOFT_LIMIT_ENV_VAR, DEFAULT_CHUNK_SOFT_LIMIT)?; + let hard_limit = env_usize(CHUNK_HARD_LIMIT_ENV_VAR, DEFAULT_CHUNK_HARD_LIMIT)?; + let overlap = env_usize(CHUNK_OVERLAP_ENV_VAR, DEFAULT_CHUNK_OVERLAP)?; + if max_length == 0 { + bail!("{MAX_LENGTH_ENV_VAR} must be a positive integer"); + } + if soft_limit == 0 || hard_limit == 0 { + bail!("Chunk limits must be positive."); + } + if soft_limit > hard_limit { + bail!("Chunk soft limit cannot exceed hard limit."); + } + if overlap >= hard_limit { + bail!("Chunk overlap must be smaller than hard limit."); + } + + Ok(Self { + cache_dir: cache_dir()?, + max_length, + chunking: ChunkingConfig { + soft_limit, + hard_limit, + overlap, + }, + }) + } + + pub(crate) fn text_init_options(&self) -> TextInitOptions { + TextInitOptions::new(EmbeddingModel::BGEM3) + .with_cache_dir(self.cache_dir.clone()) + .with_max_length(self.max_length) + .with_show_download_progress(false) + } +} + +pub(crate) fn selected_embedding_backend_id() -> Result<&'static str> { + resolve_embedding_backend_id() +} + +pub(crate) fn create_memory_dimension_u64() -> u64 { + API_EMBEDDING_DIMENSION as u64 +} + +pub(crate) fn selected_local_embedding_config() -> Result> { + match resolve_embedding_backend_id()? { + API_EMBEDDING_BACKEND_ID => Ok(None), + BGEM3_EMBEDDING_BACKEND_ID => LocalEmbeddingConfig::bgem3().map(Some), + _ => unreachable!("embedding backend should already be normalized"), + } +} + +pub(crate) fn selected_embedding_dimension() -> Result { + Ok(embedding_dimension_for_backend( + resolve_embedding_backend_id()?, + )) +} + +fn resolve_embedding_backend_id() -> Result<&'static str> { + let preferences = load_embedding_preferences()?; + Ok(normalize_supported_embedding_backend_id( + &preferences.embedding_model_id, + )) +} + +fn load_embedding_preferences() -> Result { + match preferences::load_user_preferences() { + Ok(preferences) => Ok(preferences), + Err(SettingsError::NoConfigDir) => Ok(preferences::UserPreferences::default()), + Err(other) => { + Err(anyhow!(other).context("Failed to load shared embedding backend from tui.yaml")) + } + } +} + +pub(crate) fn supported_embedding_backends() -> Vec { + vec![ + SupportedEmbeddingBackend { + id: API_EMBEDDING_BACKEND_ID, + label: API_EMBEDDING_BACKEND_LABEL, + dimension: API_EMBEDDING_DIMENSION, + }, + SupportedEmbeddingBackend { + id: BGEM3_EMBEDDING_BACKEND_ID, + label: BGEM3_EMBEDDING_BACKEND_LABEL, + dimension: BGEM3_EMBEDDING_DIMENSION, + }, + ] +} + +pub(crate) fn normalize_supported_embedding_backend_id(raw: &str) -> &'static str { + match raw.trim() { + API_EMBEDDING_BACKEND_ID => API_EMBEDDING_BACKEND_ID, + BGEM3_EMBEDDING_BACKEND_ID => BGEM3_EMBEDDING_BACKEND_ID, + _ => API_EMBEDDING_BACKEND_ID, + } +} + +pub(crate) fn embedding_dimension_for_backend(backend_id: &str) -> usize { + match normalize_supported_embedding_backend_id(backend_id) { + API_EMBEDDING_BACKEND_ID => API_EMBEDDING_DIMENSION, + BGEM3_EMBEDDING_BACKEND_ID => BGEM3_EMBEDDING_DIMENSION, + _ => unreachable!("embedding backend should already be normalized"), + } +} + +fn cache_dir() -> Result { + if let Ok(value) = env::var(CACHE_DIR_ENV_VAR) { + let trimmed = value.trim(); + if trimmed.is_empty() { + bail!("{CACHE_DIR_ENV_VAR} cannot be blank."); + } + return Ok(PathBuf::from(trimmed)); + } + + let home = env::var("HOME").context("HOME is not set")?; + Ok(PathBuf::from(home).join(DEFAULT_CACHE_DIR)) +} + +fn env_usize(name: &str, default: usize) -> Result { + match env::var(name) { + Ok(raw) => raw + .trim() + .parse::() + .with_context(|| format!("{name} must be a positive integer")), + Err(_) => Ok(default), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::preferences; + use std::sync::{Mutex, OnceLock}; + + fn env_guard() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())).lock().unwrap() + } + + fn reset_test_preference_error() { + preferences::set_load_user_preferences_error_for_tests(None); + } + + #[test] + fn selected_local_embedding_config_defaults_to_api_backend() { + let _guard = env_guard(); + reset_test_preference_error(); + unsafe { + env::remove_var(CACHE_DIR_ENV_VAR); + env::remove_var(MAX_LENGTH_ENV_VAR); + env::remove_var(CHUNK_SOFT_LIMIT_ENV_VAR); + env::remove_var(CHUNK_HARD_LIMIT_ENV_VAR); + env::remove_var(CHUNK_OVERLAP_ENV_VAR); + } + let config = selected_local_embedding_config().expect("api backend should load"); + + assert_eq!(config, None); + } + + #[test] + fn local_model_config_uses_bgem3() { + let _guard = env_guard(); + reset_test_preference_error(); + unsafe { + env::remove_var(CACHE_DIR_ENV_VAR); + env::remove_var(MAX_LENGTH_ENV_VAR); + env::remove_var(CHUNK_SOFT_LIMIT_ENV_VAR); + env::remove_var(CHUNK_HARD_LIMIT_ENV_VAR); + env::remove_var(CHUNK_OVERLAP_ENV_VAR); + } + let config = LocalEmbeddingConfig::bgem3().expect("local config should load"); + + assert_eq!(config.max_length, 512); + assert_eq!(config.chunking.soft_limit, 800); + } + + #[test] + fn invalid_chunk_bounds_are_rejected() { + let _guard = env_guard(); + reset_test_preference_error(); + unsafe { + env::set_var(CHUNK_SOFT_LIMIT_ENV_VAR, "1300"); + env::set_var(CHUNK_HARD_LIMIT_ENV_VAR, "1200"); + } + + let error = LocalEmbeddingConfig::bgem3().expect_err("invalid bounds should fail"); + assert!(error.to_string().contains("soft limit")); + + unsafe { + env::remove_var(CHUNK_SOFT_LIMIT_ENV_VAR); + env::remove_var(CHUNK_HARD_LIMIT_ENV_VAR); + } + } + + #[test] + fn unsupported_backend_normalizes_to_api() { + reset_test_preference_error(); + assert_eq!( + normalize_supported_embedding_backend_id("bad-model"), + API_EMBEDDING_BACKEND_ID + ); + } + + #[test] + fn legacy_mxbai_backend_normalizes_to_api() { + reset_test_preference_error(); + assert_eq!( + normalize_supported_embedding_backend_id("mixedbread-ai/mxbai-embed-large-v1"), + API_EMBEDDING_BACKEND_ID + ); + } + + #[test] + fn selected_local_embedding_config_defaults_to_api_when_no_config_dir_is_available() { + let _guard = env_guard(); + preferences::set_load_user_preferences_error_for_tests(Some( + preferences::TestLoadPreferencesError::NoConfigDir, + )); + + let config = selected_local_embedding_config().expect("no config dir should fall back"); + + assert_eq!(config, None); + reset_test_preference_error(); + } + + #[test] + fn selected_embedding_backend_defaults_to_api_when_no_config_dir_is_available() { + let _guard = env_guard(); + preferences::set_load_user_preferences_error_for_tests(Some( + preferences::TestLoadPreferencesError::NoConfigDir, + )); + + let backend = + selected_embedding_backend_id().expect("no config dir should select api backend"); + + assert_eq!(backend, API_EMBEDDING_BACKEND_ID); + reset_test_preference_error(); + } + + #[test] + fn selected_local_embedding_config_still_errors_on_yaml_failure() { + let _guard = env_guard(); + preferences::set_load_user_preferences_error_for_tests(Some( + preferences::TestLoadPreferencesError::Yaml, + )); + + let error = + selected_local_embedding_config().expect_err("yaml failure should reach caller"); + + assert!( + error + .to_string() + .contains("Failed to load shared embedding backend from tui.yaml") + ); + reset_test_preference_error(); + } + + #[test] + fn local_model_config_rejects_zero_max_length() { + let _guard = env_guard(); + reset_test_preference_error(); + unsafe { + env::set_var(MAX_LENGTH_ENV_VAR, "0"); + } + + let error = LocalEmbeddingConfig::bgem3().expect_err("zero max length should fail"); + + assert!(error.to_string().contains(MAX_LENGTH_ENV_VAR)); + unsafe { + env::remove_var(MAX_LENGTH_ENV_VAR); + } + } + + #[test] + fn create_memory_dimension_is_fixed_to_1024() { + assert_eq!(create_memory_dimension_u64(), 1024); + } +} diff --git a/rust/insert_service.rs b/rust/insert_service.rs index f111449..a40db15 100644 --- a/rust/insert_service.rs +++ b/rust/insert_service.rs @@ -13,7 +13,10 @@ use ic_agent::export::Principal; use serde_json::json; use crate::{ - clients::memory::MemoryClient, commands::convert_pdf::pdf_to_markdown, embedding::late_chunking, + clients::memory::MemoryClient, + commands::convert_pdf::pdf_to_markdown, + embedding::{ensure_vector_dim_matches, late_chunking}, + embedding_config::selected_embedding_dimension, }; #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -86,7 +89,13 @@ pub async fn execute_insert_request( ) -> Result { validate_insert_request_fields(request)?; let validated = validate_and_transform_insert_request(request)?; + let expected_dim = client + .get_dim() + .await + .context("Failed to load memory embedding dimension")?; + ensure_request_matches_memory_dim(&validated, request.memory_id(), expected_dim)?; let prepared = prepare_insert_request(&validated).await?; + ensure_prepared_items_match_memory(request.memory_id(), &prepared, expected_dim)?; let inserted_count = prepared.len(); let source_name = validated.source_name(); @@ -149,6 +158,33 @@ pub fn validate_insert_request_fields(request: &InsertRequest) -> Result<()> { Ok(()) } +fn ensure_request_matches_memory_dim( + request: &ValidatedInsertRequest, + memory_id: &str, + expected_dim: u64, +) -> Result<()> { + match request { + ValidatedInsertRequest::Raw { embedding, .. } => { + ensure_vector_dim_matches(memory_id, embedding.len(), expected_dim) + } + ValidatedInsertRequest::Normal { .. } | ValidatedInsertRequest::Pdf { .. } => { + let selected_dim = selected_embedding_dimension()?; + ensure_vector_dim_matches(memory_id, selected_dim, expected_dim) + } + } +} + +fn ensure_prepared_items_match_memory( + memory_id: &str, + items: &[PreparedInsertItem], + expected_dim: u64, +) -> Result<()> { + let Some(first) = items.first() else { + bail!("Insert content did not produce any chunks."); + }; + ensure_vector_dim_matches(memory_id, first.embedding.len(), expected_dim) +} + pub fn validate_insert_request_for_submit(request: &InsertRequest) -> Result<()> { let _ = validate_and_transform_insert_request(request)?; Ok(()) @@ -458,6 +494,62 @@ mod tests { assert_eq!(payload, "{\"sentence\":\"hello\",\"tag\":\"docs\"}"); } + #[test] + fn prepared_items_match_expected_dimension() { + ensure_prepared_items_match_memory( + "aaaaa-aa", + &[PreparedInsertItem { + embedding: vec![0.1, 0.2], + payload: "{}".to_string(), + }], + 2, + ) + .unwrap(); + } + + #[test] + fn prepared_items_reject_dimension_mismatch() { + let error = ensure_prepared_items_match_memory( + "aaaaa-aa", + &[PreparedInsertItem { + embedding: vec![0.1, 0.2], + payload: "{}".to_string(), + }], + 3, + ) + .unwrap_err(); + + assert!(error.to_string().contains("Embedding dimension mismatch")); + } + + #[test] + fn raw_insert_request_fails_fast_on_dimension_mismatch() { + let request = ValidatedInsertRequest::Raw { + memory_id: "aaaaa-aa".to_string(), + tag: "docs".to_string(), + text: "payload".to_string(), + embedding: vec![0.1, 0.2], + }; + + let error = ensure_request_matches_memory_dim(&request, "aaaaa-aa", 3).unwrap_err(); + + assert!(error.to_string().contains("Embedding dimension mismatch")); + } + + #[test] + fn normal_insert_request_uses_selected_backend_dimension_for_fail_fast_checks() { + let request = ValidatedInsertRequest::Normal { + memory_id: "aaaaa-aa".to_string(), + tag: "docs".to_string(), + text: Some("payload".to_string()), + file_path: None, + }; + + let error = ensure_request_matches_memory_dim(&request, "aaaaa-aa", 3).unwrap_err(); + + assert!(error.to_string().contains("Embedding dimension mismatch")); + } + #[test] fn validated_insert_request_source_name_uses_file_name() { let request = ValidatedInsertRequest::Normal { diff --git a/rust/lib.rs b/rust/lib.rs index da9fecd..32a9232 100644 --- a/rust/lib.rs +++ b/rust/lib.rs @@ -6,9 +6,12 @@ pub(crate) mod clients; mod commands; pub(crate) mod create_domain; mod embedding; +mod embedding_config; pub(crate) mod identity_store; pub(crate) mod insert_service; mod ledger; +mod local_chunking; +mod local_embedding; pub(crate) mod memory_client_builder; mod operation_timeout; pub(crate) mod preferences; diff --git a/rust/local_chunking.rs b/rust/local_chunking.rs new file mode 100644 index 0000000..6b6154a --- /dev/null +++ b/rust/local_chunking.rs @@ -0,0 +1,131 @@ +//! Where: local chunk preparation for insert and PDF ingestion. +//! What: splits markdown/text into bounded chunks before passage embedding. +//! Why: local backends need an in-process late-chunking path without changing insert callers. + +use crate::embedding_config::ChunkingConfig; + +pub(crate) fn chunk_markdown(markdown: &str, config: &ChunkingConfig) -> Vec { + let normalized = markdown.replace("\r\n", "\n"); + let blocks = normalized + .split("\n\n") + .map(str::trim) + .filter(|block| !block.is_empty()) + .flat_map(|block| chunk_block(block, config)) + .collect::>(); + if blocks.is_empty() { + return Vec::new(); + } + blocks +} + +fn chunk_block(block: &str, config: &ChunkingConfig) -> Vec { + if block.chars().count() <= config.hard_limit { + return vec![block.to_string()]; + } + + let sentences = split_sentences(block); + if sentences.len() <= 1 { + return split_long_text(block, config); + } + + let mut chunks = Vec::new(); + let mut current = String::new(); + for sentence in sentences { + if sentence.chars().count() > config.hard_limit { + if !current.is_empty() { + chunks.push(current.trim().to_string()); + current.clear(); + } + chunks.extend(split_long_text(&sentence, config)); + continue; + } + + let separator = if current.is_empty() { "" } else { " " }; + if current.chars().count() + separator.len() + sentence.chars().count() > config.soft_limit + && !current.is_empty() + { + chunks.push(current.trim().to_string()); + current.clear(); + } + + if !current.is_empty() { + current.push(' '); + } + current.push_str(sentence.trim()); + } + + if !current.is_empty() { + chunks.push(current.trim().to_string()); + } + chunks +} + +fn split_sentences(block: &str) -> Vec { + let mut pieces = Vec::new(); + let mut current = String::new(); + for ch in block.chars() { + current.push(ch); + if matches!(ch, '.' | '!' | '?' | '\n') { + let trimmed = current.trim(); + if !trimmed.is_empty() { + pieces.push(trimmed.to_string()); + } + current.clear(); + } + } + if !current.trim().is_empty() { + pieces.push(current.trim().to_string()); + } + pieces +} + +fn split_long_text(text: &str, config: &ChunkingConfig) -> Vec { + let chars = text.chars().collect::>(); + let mut chunks = Vec::new(); + let mut start = 0usize; + while start < chars.len() { + let end = usize::min(start + config.hard_limit, chars.len()); + let slice = chars[start..end] + .iter() + .collect::() + .trim() + .to_string(); + if !slice.is_empty() { + chunks.push(slice); + } + if end == chars.len() { + break; + } + start = end.saturating_sub(config.overlap); + } + chunks +} + +#[cfg(test)] +mod tests { + use super::*; + + fn config() -> ChunkingConfig { + ChunkingConfig { + soft_limit: 16, + hard_limit: 24, + overlap: 4, + } + } + + #[test] + fn chunk_markdown_splits_paragraphs_and_long_blocks() { + let chunks = chunk_markdown( + "First sentence. Second sentence.\n\nThis block is definitely long enough to split.", + &config(), + ); + + assert!(chunks.len() >= 3); + assert!(chunks.iter().all(|chunk| !chunk.trim().is_empty())); + } + + #[test] + fn chunk_markdown_ignores_empty_input() { + assert!(chunk_markdown(" \n\n ", &config()).is_empty()); + } +} diff --git a/rust/local_embedding.rs b/rust/local_embedding.rs new file mode 100644 index 0000000..a883437 --- /dev/null +++ b/rust/local_embedding.rs @@ -0,0 +1,107 @@ +//! Where: local embedding inference shared by CLI, TUI, tools, and Python bindings. +//! What: lazily initializes the selected local text embedding model and embeds queries/passages. +//! Why: local backends are opt-in and the selected model can change at runtime via preferences. + +use std::sync::{Mutex, OnceLock}; + +use anyhow::{Context, Result, bail}; +use fastembed::TextEmbedding; + +use crate::{ + embedding::LateChunk, + embedding_config::{LocalEmbeddingConfig, selected_local_embedding_config}, + local_chunking::chunk_markdown, +}; + +struct CachedModel { + model: TextEmbedding, +} + +static MODEL: OnceLock>> = OnceLock::new(); + +pub(crate) async fn embed_query(text: &str) -> Result> { + load_selected_local_config()?; + embed_texts(vec![text.trim().to_string()]) + .await + .map(|mut rows| { + rows.pop() + .expect("one query input should always produce one embedding") + }) +} + +pub(crate) async fn late_chunk_and_embed(markdown: &str) -> Result> { + let config = load_selected_local_config()?; + let chunks = chunk_markdown(markdown, &config.chunking); + if chunks.is_empty() { + bail!("Insert content is empty after normalization."); + } + let embeddings = embed_texts(chunks.iter().map(|chunk| chunk.to_string()).collect()).await?; + Ok(chunks + .into_iter() + .zip(embeddings) + .map(|(sentence, embedding)| LateChunk { + embedding, + sentence, + }) + .collect()) +} + +async fn embed_texts(inputs: Vec) -> Result>> { + tokio::task::spawn_blocking(move || { + let config = load_selected_local_config()?; + let mut guard = model_cache() + .lock() + .map_err(|_| anyhow::anyhow!("Embedding model lock poisoned"))?; + let cached = ensure_cached_model(&mut guard, &config)?; + cached + .model + .embed(inputs, None) + .context("Failed to generate local embeddings") + }) + .await + .context("Local embedding worker crashed")? +} + +fn load_selected_local_config() -> Result { + selected_local_embedding_config()? + .ok_or_else(|| anyhow::anyhow!("Local embedding backend is not selected")) +} + +fn model_cache() -> &'static Mutex> { + MODEL.get_or_init(|| Mutex::new(None)) +} + +fn ensure_cached_model<'a>( + cache: &'a mut Option, + config: &LocalEmbeddingConfig, +) -> Result<&'a mut CachedModel> { + if cache.is_none() { + std::fs::create_dir_all(&config.cache_dir).with_context(|| { + format!( + "Failed to create embedding cache dir {}", + config.cache_dir.display() + ) + })?; + let model = TextEmbedding::try_new(config.text_init_options())?; + *cache = Some(CachedModel { model }); + } + Ok(cache + .as_mut() + .expect("embedding cache should exist after initialization")) +} + +#[cfg(test)] +mod tests { + #[test] + fn local_embedding_supports_bgem3_model_id() { + let config = crate::embedding_config::LocalEmbeddingConfig::bgem3() + .expect("bgem3 config should load"); + assert_eq!(config.max_length, 512); + } + + #[test] + fn embed_query_trims_without_instruction_prefix() { + let trimmed = " hello ".trim().to_string(); + assert_eq!(trimmed, "hello"); + } +} diff --git a/rust/preferences.rs b/rust/preferences.rs index 27a0f7b..4b0245d 100644 --- a/rust/preferences.rs +++ b/rust/preferences.rs @@ -5,10 +5,14 @@ use kinic_core::{prefs_policy, principal::normalize_memory_id_text, tag}; use serde::{Deserialize, Serialize}; +#[cfg(test)] +use std::cell::Cell; use tui_kit_host::settings::SettingsError; #[cfg(not(test))] use tui_kit_host::settings::{load_yaml_or_default, save_yaml}; +use crate::embedding_config::{API_EMBEDDING_BACKEND_ID, BGEM3_EMBEDDING_BACKEND_ID}; + #[cfg(not(test))] const APP_NAMESPACE: &str = "kinic"; #[cfg(not(test))] @@ -17,6 +21,8 @@ pub use kinic_core::prefs_policy::{ DEFAULT_CHAT_MMR_LAMBDA, DEFAULT_CHAT_OVERALL_TOP_K, DEFAULT_CHAT_PER_MEMORY_CAP, }; +const DEFAULT_EMBEDDING_MODEL_ID: &str = API_EMBEDDING_BACKEND_ID; + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] // Missing legacy fields are backfilled with defaults and unknown fields are ignored. // Type mismatches and malformed YAML still fail to decode so broken settings stay explicit. @@ -32,6 +38,8 @@ pub struct UserPreferences { pub chat_per_memory_cap: usize, #[serde(default = "default_chat_mmr_lambda")] pub chat_mmr_lambda: u8, + #[serde(default = "default_embedding_model_id")] + pub embedding_model_id: String, } impl Default for UserPreferences { @@ -43,6 +51,7 @@ impl Default for UserPreferences { chat_overall_top_k: DEFAULT_CHAT_OVERALL_TOP_K, chat_per_memory_cap: DEFAULT_CHAT_PER_MEMORY_CAP, chat_mmr_lambda: DEFAULT_CHAT_MMR_LAMBDA, + embedding_model_id: default_embedding_model_id(), } } } @@ -59,8 +68,22 @@ pub fn default_chat_mmr_lambda() -> u8 { DEFAULT_CHAT_MMR_LAMBDA } +pub fn default_embedding_model_id() -> String { + DEFAULT_EMBEDDING_MODEL_ID.to_string() +} + +#[cfg(test)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum TestLoadPreferencesError { + NoConfigDir, + Yaml, +} + #[cfg(test)] pub fn load_user_preferences() -> Result { + if let Some(error) = test_load_error() { + return Err(error); + } Ok(normalize_user_preferences(UserPreferences::default())) } @@ -75,6 +98,36 @@ pub fn save_user_preferences(_preferences: &UserPreferences) -> Result<(), Setti Ok(()) } +#[cfg(test)] +pub fn set_load_user_preferences_error_for_tests(error: Option) { + TEST_LOAD_ERROR.with(|slot| slot.set(test_load_error_code(error))); +} + +#[cfg(test)] +fn test_load_error() -> Option { + match TEST_LOAD_ERROR.with(Cell::get) { + 1 => Some(SettingsError::NoConfigDir), + 2 => Some(SettingsError::Yaml( + serde_yaml::from_str::(":").unwrap_err(), + )), + _ => None, + } +} + +#[cfg(test)] +thread_local! { + static TEST_LOAD_ERROR: Cell = const { Cell::new(0) }; +} + +#[cfg(test)] +const fn test_load_error_code(error: Option) -> u8 { + match error { + Some(TestLoadPreferencesError::NoConfigDir) => 1, + Some(TestLoadPreferencesError::Yaml) => 2, + None => 0, + } +} + #[cfg(not(test))] pub fn save_user_preferences(preferences: &UserPreferences) -> Result<(), SettingsError> { save_yaml( @@ -97,6 +150,7 @@ pub fn normalize_user_preferences(mut preferences: UserPreferences) -> UserPrefe preferences.chat_per_memory_cap = normalize_chat_per_memory_cap(preferences.chat_per_memory_cap); preferences.chat_mmr_lambda = normalize_chat_mmr_lambda(preferences.chat_mmr_lambda); + preferences.embedding_model_id = normalize_embedding_model_id(preferences.embedding_model_id); preferences } @@ -124,6 +178,14 @@ pub fn chat_diversity_display(value: u8) -> String { format!("{:.2}", f32::from(value) / 100.0) } +pub fn normalize_embedding_model_id(value: String) -> String { + match value.trim() { + API_EMBEDDING_BACKEND_ID => API_EMBEDDING_BACKEND_ID.to_string(), + BGEM3_EMBEDDING_BACKEND_ID => BGEM3_EMBEDDING_BACKEND_ID.to_string(), + _ => default_embedding_model_id(), + } +} + fn normalize_default_memory_id(memory_id: Option) -> Option { memory_id.and_then(|value| normalize_memory_id_text(&value).ok()) } @@ -176,6 +238,7 @@ future_setting: true assert_eq!(normalized.chat_overall_top_k, DEFAULT_CHAT_OVERALL_TOP_K); assert_eq!(normalized.chat_per_memory_cap, DEFAULT_CHAT_PER_MEMORY_CAP); assert_eq!(normalized.chat_mmr_lambda, DEFAULT_CHAT_MMR_LAMBDA); + assert_eq!(normalized.embedding_model_id, DEFAULT_EMBEDDING_MODEL_ID); } #[test] @@ -227,6 +290,33 @@ manual_memory_ids: assert_eq!(normalized.chat_overall_top_k, DEFAULT_CHAT_OVERALL_TOP_K); assert_eq!(normalized.chat_per_memory_cap, DEFAULT_CHAT_PER_MEMORY_CAP); assert_eq!(normalized.chat_mmr_lambda, DEFAULT_CHAT_MMR_LAMBDA); + assert_eq!(normalized.embedding_model_id, DEFAULT_EMBEDDING_MODEL_ID); + } + + #[test] + fn user_preferences_normalizes_blank_embedding_model_id() { + let preferences: UserPreferences = serde_yaml::from_str( + r#" +embedding_model_id: " " +"#, + ) + .expect("blank embedding model should deserialize"); + + let normalized = normalize_user_preferences(preferences); + assert_eq!(normalized.embedding_model_id, DEFAULT_EMBEDDING_MODEL_ID); + } + + #[test] + fn user_preferences_normalizes_legacy_mxbai_embedding_model_id_to_default() { + let preferences: UserPreferences = serde_yaml::from_str( + r#" +embedding_model_id: "mixedbread-ai/mxbai-embed-large-v1" +"#, + ) + .expect("legacy embedding model should deserialize"); + + let normalized = normalize_user_preferences(preferences); + assert_eq!(normalized.embedding_model_id, DEFAULT_EMBEDDING_MODEL_ID); } #[test] diff --git a/rust/tools/service.rs b/rust/tools/service.rs index aee1a8b..a543b64 100644 --- a/rust/tools/service.rs +++ b/rust/tools/service.rs @@ -16,7 +16,7 @@ use crate::{ search::{search_across_memories, searchable_memory_ids}, show::{ShowOutput, load_show_output}, }, - embedding::fetch_embedding, + embedding::{ensure_memory_dim_matches, fetch_embedding}, insert_service::{InsertRequest, execute_insert_request}, memory_client_builder::build_memory_client, }; @@ -198,9 +198,10 @@ impl ToolService { async fn search_memory( client: &MemoryClient, - _memory_id: &str, + memory_id: &str, embedding: Vec, ) -> anyhow::Result> { + ensure_memory_dim_matches(client, memory_id, embedding.len()).await?; let rows = client .search(embedding) .await diff --git a/rust/tui/bridge.rs b/rust/tui/bridge.rs index 156b48b..4bed515 100644 --- a/rust/tui/bridge.rs +++ b/rust/tui/bridge.rs @@ -7,7 +7,7 @@ use crate::{ memory::MemoryClient, }, create_domain::{BalanceDelta, balance_delta, required_balance}, - embedding::embedding_base_url, + embedding::{embedding_base_url, ensure_memory_dim_matches}, insert_service::{InsertRequest, execute_insert_request}, ledger::{fetch_balance, fetch_fee, transfer}, shared::{ @@ -355,6 +355,7 @@ pub async fn search_memory_with_agent( ) -> Result> { let memory = Principal::from_text(&memory_id).context("Failed to parse memory canister id")?; let client = MemoryClient::new(agent, memory); + ensure_memory_dim_matches(&client, &memory_id, embedding.len()).await?; let mut results = client.search(embedding).await?; results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(Ordering::Equal)); diff --git a/rust/tui/provider/mod.rs b/rust/tui/provider/mod.rs index 457778b..6b42a74 100644 --- a/rust/tui/provider/mod.rs +++ b/rust/tui/provider/mod.rs @@ -15,6 +15,7 @@ use crate::{ agent::{KeychainErrorCode, extract_keychain_error_code}, create_domain::derive_create_cost, embedding::fetch_embedding, + embedding_config::{normalize_supported_embedding_backend_id, supported_embedding_backends}, insert_service::{ InsertRequest, parse_embedding_json, validate_insert_request_fields, validate_insert_request_for_submit, @@ -368,6 +369,7 @@ fn add_action_label_for_context(context: PickerContext) -> Option<&'static str> PickerContext::DefaultMemory | PickerContext::InsertTarget | PickerContext::AddTag + | PickerContext::EmbeddingModel | PickerContext::ChatResultLimit | PickerContext::ChatPerMemoryLimit | PickerContext::ChatDiversity => None, @@ -388,6 +390,7 @@ fn picker_selected_id_for_context( (!insert_tag.is_empty()).then(|| insert_tag.to_string()) } PickerContext::TagManagement | PickerContext::AddTag => None, + PickerContext::EmbeddingModel => Some(user_preferences.embedding_model_id.clone()), PickerContext::ChatResultLimit => Some(user_preferences.chat_overall_top_k.to_string()), PickerContext::ChatPerMemoryLimit => Some(user_preferences.chat_per_memory_cap.to_string()), PickerContext::ChatDiversity => Some(user_preferences.chat_mmr_lambda.to_string()), @@ -442,6 +445,16 @@ fn picker_items_for_context( .unwrap_or_default(), _ => Vec::new(), }, + PickerContext::EmbeddingModel => supported_embedding_backends() + .iter() + .map(|backend| { + PickerItem::option( + backend.id.to_string(), + format!("{} ({})", backend.label, backend.dimension), + user_preferences.embedding_model_id == backend.id, + ) + }) + .collect(), PickerContext::ChatResultLimit => prefs_policy::chat_result_limit_options() .iter() .map(|value| { @@ -513,6 +526,7 @@ impl<'a> DefaultMemoryController<'a> { chat_overall_top_k: self.user_preferences.chat_overall_top_k, chat_per_memory_cap: self.user_preferences.chat_per_memory_cap, chat_mmr_lambda: self.user_preferences.chat_mmr_lambda, + embedding_model_id: self.user_preferences.embedding_model_id.clone(), }; #[cfg(test)] let _settings_io_lock = settings_io_lock() @@ -2523,6 +2537,9 @@ impl KinicProvider { CoreEffect::Notify(format!("Selected tag {} for insert", item.id)), ], PickerContext::AddTag => Vec::new(), + PickerContext::EmbeddingModel => { + vec![self.set_embedding_model_preference(item.id.as_str())] + } PickerContext::ChatResultLimit => item .id .parse::() @@ -2790,6 +2807,21 @@ impl KinicProvider { } } + fn set_embedding_model_preference(&mut self, model_id: &str) -> CoreEffect { + let normalized = normalize_supported_embedding_backend_id(model_id); + if self.user_preferences.embedding_model_id == normalized { + return CoreEffect::Notify(format!("Embedding backend already set to {normalized}")); + } + match self.update_user_preferences(|preferences| { + preferences.embedding_model_id = normalized.to_string(); + }) { + Ok(()) => CoreEffect::Notify(format!( + "Embedding backend set to {normalized}. API preserves old memories. Local backends may need reindex." + )), + Err(error) => CoreEffect::Notify(format!("Embedding backend save failed: {error}")), + } + } + fn update_user_preferences(&mut self, update: F) -> Result<(), String> where F: FnOnce(&mut UserPreferences), @@ -4998,6 +5030,7 @@ impl DataProvider for KinicProvider { | PickerContext::InsertTarget | PickerContext::InsertTag | PickerContext::TagManagement + | PickerContext::EmbeddingModel | PickerContext::ChatResultLimit | PickerContext::ChatPerMemoryLimit | PickerContext::ChatDiversity diff --git a/rust/tui/settings.rs b/rust/tui/settings.rs index ad1fb94..750378c 100644 --- a/rust/tui/settings.rs +++ b/rust/tui/settings.rs @@ -12,10 +12,12 @@ use tui_kit_host::settings::SettingsError; use tui_kit_host::settings::{load_yaml_or_default, save_yaml}; use tui_kit_runtime::{ SETTINGS_ENTRY_CHAT_DIVERSITY_ID, SETTINGS_ENTRY_CHAT_PER_MEMORY_LIMIT_ID, - SETTINGS_ENTRY_CHAT_RESULT_LIMIT_ID, SETTINGS_ENTRY_DEFAULT_MEMORY_ID, SessionAccountOverview, - SessionSettingsSnapshot, SettingsEntry, SettingsSection, SettingsSnapshot, + SETTINGS_ENTRY_CHAT_RESULT_LIMIT_ID, SETTINGS_ENTRY_DEFAULT_MEMORY_ID, + SETTINGS_ENTRY_EMBEDDING_MODEL_ID, SessionAccountOverview, SessionSettingsSnapshot, + SettingsEntry, SettingsSection, SettingsSnapshot, }; +use crate::embedding_config::supported_embedding_backends; use crate::preferences::{ UserPreferences, chat_diversity_display, chat_per_memory_limit_display, chat_result_limit_display, @@ -28,6 +30,7 @@ const APP_NAMESPACE: &str = "kinic"; const CHAT_HISTORY_FILE_NAME: &str = "chat-threads.yaml"; const UNAVAILABLE: &str = "unavailable"; const NOT_SET: &str = "not set"; +const EMBEDDING_MODEL_NOTE: &str = "Affects search/insert. New memories are created at 1024 dims. API keeps existing memories usable. Reindex or reset older non-1024 memories before local search/insert."; const CHAT_HISTORY_MAX_MESSAGES: usize = 40; const CHAT_MESSAGE_MAX_CONTENT_LEN: usize = 4096; @@ -457,6 +460,7 @@ pub fn build_settings_snapshot( let default_memory_display = default_memory_display(preferences, selector_items, selector_labels); let saved_tags_display = saved_tags_display(preferences); + let embedding_model_display = embedding_model_display(preferences); let chat_result_limit_display = chat_result_limit_display(preferences.chat_overall_top_k); let chat_per_memory_limit_display = chat_per_memory_limit_display(preferences.chat_per_memory_cap); @@ -490,7 +494,7 @@ pub fn build_settings_snapshot( }, SettingsEntry { id: "embedding_api_endpoint".to_string(), - label: "Embedding".to_string(), + label: "Chat API".to_string(), value: session.embedding_api_endpoint.clone(), note: None, }, @@ -511,6 +515,12 @@ pub fn build_settings_snapshot( value: saved_tags_display, note: None, }, + SettingsEntry { + id: SETTINGS_ENTRY_EMBEDDING_MODEL_ID.to_string(), + label: "Embedding backend".to_string(), + value: embedding_model_display, + note: Some(EMBEDDING_MODEL_NOTE.to_string()), + }, SettingsEntry { id: "preferences_status".to_string(), label: "Preferences status".to_string(), @@ -567,7 +577,7 @@ pub fn build_settings_snapshot( }, SettingsEntry { id: "embedding_api_endpoint".to_string(), - label: "Embedding".to_string(), + label: "Chat API".to_string(), value: session.embedding_api_endpoint.clone(), note: None, }, @@ -719,6 +729,14 @@ fn saved_tags_display(preferences: &UserPreferences) -> String { preferences.saved_tags.join(", ") } +fn embedding_model_display(preferences: &UserPreferences) -> String { + supported_embedding_backends() + .into_iter() + .find(|model| model.id == preferences.embedding_model_id) + .map(|model| format!("{} ({})", model.label, model.dimension)) + .unwrap_or_else(|| preferences.embedding_model_id.clone()) +} + #[cfg(test)] #[path = "settings_tests.rs"] mod tests; diff --git a/rust/tui/settings_tests.rs b/rust/tui/settings_tests.rs index 090a8fa..0a35c82 100644 --- a/rust/tui/settings_tests.rs +++ b/rust/tui/settings_tests.rs @@ -1,7 +1,9 @@ use super::*; use crate::preferences::UserPreferences; use candid::Nat; -use tui_kit_runtime::{SETTINGS_ENTRY_DEFAULT_MEMORY_ID, SessionAccountOverview}; +use tui_kit_runtime::{ + SETTINGS_ENTRY_DEFAULT_MEMORY_ID, SETTINGS_ENTRY_EMBEDDING_MODEL_ID, SessionAccountOverview, +}; fn deferred_session() -> SessionSettingsSnapshot { session_settings_snapshot( @@ -198,6 +200,24 @@ fn settings_snapshot_projects_default_memory_and_preferences_status() { status, "{name}" ); + assert_eq!( + section_entry_value( + &snapshot, + "Saved preferences", + SETTINGS_ENTRY_EMBEDDING_MODEL_ID + ), + "API (remote default) (1024)", + "{name}" + ); + assert!( + section_entry_note( + &snapshot, + "Saved preferences", + SETTINGS_ENTRY_EMBEDDING_MODEL_ID + ) + .is_some(), + "{name}" + ); assert_eq!( section_titles, vec![ @@ -271,6 +291,7 @@ fn settings_snapshot_projects_chat_retrieval_section() { chat_overall_top_k: 10, chat_per_memory_cap: 4, chat_mmr_lambda: 80, + embedding_model_id: "BAAI/bge-m3".to_string(), ..UserPreferences::default() }, &Vec::new(), @@ -290,6 +311,14 @@ fn settings_snapshot_projects_chat_retrieval_section() { section_entry_value(&snapshot, "Chat retrieval", "chat_diversity"), "0.80" ); + assert_eq!( + section_entry_value( + &snapshot, + "Saved preferences", + SETTINGS_ENTRY_EMBEDDING_MODEL_ID + ), + "BAAI BGE-M3 (1024)" + ); } #[test] diff --git a/tests/fixtures/capabilities_golden.json b/tests/fixtures/capabilities_golden.json index 046a212..49b5b2b 100644 --- a/tests/fixtures/capabilities_golden.json +++ b/tests/fixtures/capabilities_golden.json @@ -1044,6 +1044,32 @@ "value_kind": "integer" } ] + }, + { + "name": "set-embedding-backend", + "summary": "Set the embedding backend shared with the TUI. Returns JSON.", + "auth": { + "required": false, + "sources": [] + }, + "output": { + "default": "json", + "supported": [ + "json" + ], + "interactive": false + }, + "global_flags_supported": [ + "verbose" + ], + "arguments": [ + { + "name": "model_id", + "required": true, + "input_shape": "single_value", + "value_kind": "string" + } + ] } ] }, diff --git a/tests/kinic_cli_capabilities.rs b/tests/kinic_cli_capabilities.rs index fec22de..b1db9e2 100644 --- a/tests/kinic_cli_capabilities.rs +++ b/tests/kinic_cli_capabilities.rs @@ -112,6 +112,13 @@ fn capabilities_describes_prefs_and_tui_contracts() { .iter() .any(|entry| entry["name"] == "set-default-memory") ); + assert!( + prefs["subcommands"] + .as_array() + .unwrap() + .iter() + .any(|entry| entry["name"] == "set-embedding-backend") + ); let prefs_add_memory = prefs["subcommands"] .as_array() .unwrap() diff --git a/tests/kinic_cli_prefs.rs b/tests/kinic_cli_prefs.rs index 192b629..603f572 100644 --- a/tests/kinic_cli_prefs.rs +++ b/tests/kinic_cli_prefs.rs @@ -58,6 +58,7 @@ fn prefs_show_runs_without_identity() { "chat_overall_top_k": 8, "chat_per_memory_cap": 3, "chat_mmr_lambda": 70, + "embedding_model_id": "api", }) ); } @@ -166,6 +167,7 @@ fn prefs_mutations_update_shared_yaml_and_preserve_chat_fields() { "chat_overall_top_k": 10, "chat_per_memory_cap": 4, "chat_mmr_lambda": 80, + "embedding_model_id": "api", }) ); } @@ -299,6 +301,113 @@ fn prefs_chat_retrieval_mutations_update_yaml_and_show_output() { assert_eq!(parsed["chat_mmr_lambda"], json!(80)); } +#[test] +fn prefs_set_embedding_backend_updates_yaml_and_show_output() { + let config_dir = temp_config_dir("embedding-backend"); + let kinic_dir = app_config_root(&config_dir).join("kinic"); + fs::create_dir_all(&kinic_dir).unwrap(); + fs::write( + kinic_dir.join("tui.yaml"), + "embedding_model_id: BAAI/bge-m3\n", + ) + .unwrap(); + + let output = prefs_command(&config_dir) + .args(["prefs", "set-embedding-backend", "--model-id", "api"]) + .output() + .unwrap(); + assert!(output.status.success()); + let parsed: serde_json::Value = + serde_json::from_slice(&output.stdout).expect("json response should parse"); + assert_eq!( + parsed, + json!({ + "resource": "embedding_model_id", + "action": "set", + "status": "updated", + "value": "api" + }) + ); + + let output = prefs_command(&config_dir) + .args(["prefs", "set-embedding-backend", "--model-id", "api"]) + .output() + .unwrap(); + assert!(output.status.success()); + let parsed: serde_json::Value = + serde_json::from_slice(&output.stdout).expect("json response should parse"); + assert_eq!( + parsed, + json!({ + "resource": "embedding_model_id", + "action": "set", + "status": "unchanged", + "value": "api" + }) + ); + + let yaml = read_prefs_yaml(&config_dir); + assert!(yaml.contains("embedding_model_id: api")); + + let output = prefs_command(&config_dir) + .args(["prefs", "show"]) + .output() + .unwrap(); + assert!(output.status.success()); + let parsed: serde_json::Value = + serde_json::from_slice(&output.stdout).expect("show output should parse"); + assert_eq!(parsed["embedding_model_id"], json!("api")); +} + +#[test] +fn prefs_set_embedding_backend_accepts_bgem3() { + let config_dir = temp_config_dir("embedding-backend-bgem3"); + + let output = prefs_command(&config_dir) + .args([ + "prefs", + "set-embedding-backend", + "--model-id", + "BAAI/bge-m3", + ]) + .output() + .unwrap(); + assert!(output.status.success()); + let parsed: serde_json::Value = + serde_json::from_slice(&output.stdout).expect("json response should parse"); + assert_eq!( + parsed, + json!({ + "resource": "embedding_model_id", + "action": "set", + "status": "updated", + "value": "BAAI/bge-m3" + }) + ); +} + +#[test] +fn prefs_set_embedding_backend_normalizes_invalid_values_to_api() { + let config_dir = temp_config_dir("embedding-backend-normalize"); + + let output = prefs_command(&config_dir) + .args(["prefs", "set-embedding-backend", "--model-id", "bad-model"]) + .output() + .unwrap(); + assert!(output.status.success()); + let parsed: serde_json::Value = + serde_json::from_slice(&output.stdout).expect("json response should parse"); + assert_eq!( + parsed, + json!({ + "resource": "embedding_model_id", + "action": "set", + "status": "unchanged", + "value": "api" + }) + ); +} + #[test] fn prefs_add_memory_validate_requires_identity_or_ii() { let config_dir = temp_config_dir("add-memory-validate"); diff --git a/tests/kinic_cli_tui.rs b/tests/kinic_cli_tui.rs index 86a38d4..c46f708 100644 --- a/tests/kinic_cli_tui.rs +++ b/tests/kinic_cli_tui.rs @@ -30,8 +30,9 @@ fn prefs_help_mentions_json_contract_and_examples() { assert!(output.status.success()); let stdout = String::from_utf8(output.stdout).unwrap(); assert!(stdout.contains("All prefs commands return JSON.")); - assert!(stdout.contains("show -> {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer}")); + assert!(stdout.contains("show -> {\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer, \"embedding_model_id\": string}")); assert!(stdout.contains("kinic-cli prefs set-default-memory --memory-id")); + assert!(stdout.contains("kinic-cli prefs set-embedding-backend --model-id BAAI/bge-m3")); } #[test] @@ -44,7 +45,7 @@ fn prefs_show_help_mentions_return_shape() { assert!(output.status.success()); let stdout = String::from_utf8(output.stdout).unwrap(); assert!(stdout.contains("Returns JSON.")); - assert!(stdout.contains("{\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer}")); + assert!(stdout.contains("{\"default_memory_id\": string|null, \"saved_tags\": string[], \"manual_memory_ids\": string[], \"chat_overall_top_k\": integer, \"chat_per_memory_cap\": integer, \"chat_mmr_lambda\": integer, \"embedding_model_id\": string}")); } #[test] diff --git a/tui/crates/tui-kit-render/src/ui/app/screens/settings/mod.rs b/tui/crates/tui-kit-render/src/ui/app/screens/settings/mod.rs index e721936..fe9b2fe 100644 --- a/tui/crates/tui-kit-render/src/ui/app/screens/settings/mod.rs +++ b/tui/crates/tui-kit-render/src/ui/app/screens/settings/mod.rs @@ -184,6 +184,7 @@ pub(crate) fn picker_hint(context: PickerContext) -> &'static str { PickerContext::InsertTag => " Enter: choose ↑/↓: move Esc: close", PickerContext::TagManagement => " Enter: use d: delete ↑/↓: browse Esc: close", PickerContext::AddTag => " Enter: save tag Esc: close", + PickerContext::EmbeddingModel => " Enter: save ↑/↓: move Esc: close", PickerContext::ChatResultLimit | PickerContext::ChatPerMemoryLimit | PickerContext::ChatDiversity => " Enter: save ↑/↓: move Esc: close", @@ -197,6 +198,7 @@ pub(crate) fn picker_input_placeholder(context: PickerContext) -> &'static str { | PickerContext::InsertTarget | PickerContext::InsertTag | PickerContext::TagManagement + | PickerContext::EmbeddingModel | PickerContext::ChatResultLimit | PickerContext::ChatPerMemoryLimit | PickerContext::ChatDiversity => "", @@ -217,6 +219,7 @@ fn picker_context_title(context: PickerContext) -> &'static str { PickerContext::InsertTag => "Select insert tag", PickerContext::TagManagement => "Saved tags", PickerContext::AddTag => "Add tag", + PickerContext::EmbeddingModel => "Embedding backend", PickerContext::ChatResultLimit => "Chat result limit", PickerContext::ChatPerMemoryLimit => "Per-memory limit", PickerContext::ChatDiversity => "Chat diversity", @@ -252,6 +255,7 @@ fn picker_empty_message(context: PickerContext) -> &'static str { PickerContext::InsertTag | PickerContext::TagManagement | PickerContext::AddTag => { " No saved tags yet." } + PickerContext::EmbeddingModel => " No embedding backends available.", PickerContext::ChatResultLimit | PickerContext::ChatPerMemoryLimit | PickerContext::ChatDiversity => " No options available yet.", diff --git a/tui/crates/tui-kit-runtime/src/lib.rs b/tui/crates/tui-kit-runtime/src/lib.rs index 1d80a05..c942565 100644 --- a/tui/crates/tui-kit-runtime/src/lib.rs +++ b/tui/crates/tui-kit-runtime/src/lib.rs @@ -26,6 +26,7 @@ use tui_kit_model::{UiContextNode, UiItemContent, UiItemKind, UiItemSummary}; pub const SETTINGS_ENTRY_DEFAULT_MEMORY_ID: &str = "default_memory"; pub const SETTINGS_ENTRY_KINIC_BALANCE_ID: &str = "kinic_balance"; pub const SETTINGS_ENTRY_SAVED_TAGS_ID: &str = "saved_tags"; +pub const SETTINGS_ENTRY_EMBEDDING_MODEL_ID: &str = "embedding_model"; pub const SETTINGS_ENTRY_CHAT_RESULT_LIMIT_ID: &str = "chat_result_limit"; pub const SETTINGS_ENTRY_CHAT_PER_MEMORY_LIMIT_ID: &str = "chat_per_memory_limit"; pub const SETTINGS_ENTRY_CHAT_DIVERSITY_ID: &str = "chat_diversity"; @@ -179,6 +180,7 @@ pub enum PickerContext { InsertTag, TagManagement, AddTag, + EmbeddingModel, ChatResultLimit, ChatPerMemoryLimit, ChatDiversity, @@ -2629,6 +2631,10 @@ pub fn settings_row_behavior_for_index( Some(CoreAction::OpenPicker(PickerContext::TagManagement)), " manage saved tags ", ), + SETTINGS_ENTRY_EMBEDDING_MODEL_ID => SettingsRowBehavior::new( + Some(CoreAction::OpenPicker(PickerContext::EmbeddingModel)), + " choose embedding backend ", + ), SETTINGS_ENTRY_CHAT_RESULT_LIMIT_ID => SettingsRowBehavior::new( Some(CoreAction::OpenPicker(PickerContext::ChatResultLimit)), " adjust chat limit ", @@ -2780,6 +2786,7 @@ fn picker_selected_index( } PickerContext::TagManagement | PickerContext::AddTag + | PickerContext::EmbeddingModel | PickerContext::ChatResultLimit | PickerContext::ChatPerMemoryLimit | PickerContext::ChatDiversity => None, @@ -4017,6 +4024,38 @@ mod tests { ); } + #[test] + fn selected_settings_row_behavior_matches_embedding_model_row() { + let settings = SettingsSnapshot { + quick_entries: vec![], + sections: vec![SettingsSection { + title: "Saved preferences".to_string(), + entries: vec![SettingsEntry { + id: SETTINGS_ENTRY_EMBEDDING_MODEL_ID.to_string(), + label: "Embedding backend".to_string(), + value: "API (remote default) (1024)".to_string(), + note: None, + }], + footer: None, + }], + }; + let state = CoreState { + current_tab_id: kinic_tabs::KINIC_SETTINGS_TAB_ID.to_string(), + focus: PaneFocus::Content, + selected_index: Some(0), + settings, + ..CoreState::default() + }; + + assert_eq!( + selected_settings_row_behavior(&state), + Some(SettingsRowBehavior { + enter_action: Some(CoreAction::OpenPicker(PickerContext::EmbeddingModel)), + status_hint: " choose embedding backend ", + }) + ); + } + #[test] fn selected_settings_row_behavior_matches_chat_result_limit_row() { let settings = SettingsSnapshot {