diff --git a/Cargo.lock b/Cargo.lock index f46addca..eff66fa1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,21 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "addr2line" +version = "0.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a76fd60b23679b7d19bd066031410fb7e458ccc5e958eb5c325888ce4baedc97" +dependencies = [ + "gimli", +] + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + [[package]] name = "ahash" version = "0.3.8" @@ -17,12 +32,67 @@ dependencies = [ "memchr", ] +[[package]] +name = "anstream" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "342258dd14006105c2b75ab1bd7543a03bdf0cfc94383303ac212a04939dff6f" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-wincon", + "concolor-override", + "concolor-query", + "is-terminal", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23ea9e81bd02e310c216d080f6223c179012256e5151c41db88d12c88a1684d2" + +[[package]] +name = "anstyle-parse" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7d1bb534e9efed14f3e5f44e7dd1a4f709384023a4165199a4241e18dff0116" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-wincon" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3127af6145b149f3287bb9a0d10ad9c5692dba8c53ad48285e5bec4063834fa" +dependencies = [ + "anstyle", + "windows-sys 0.45.0", +] + [[package]] name = "autocfg" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "backtrace" +version = "0.3.67" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "233d376d6d185f2a3093e58f283f60f880315b6c60075b01f36b3b85154564ca" +dependencies = [ + "addr2line", + "cc", + "cfg-if", + "libc", + "miniz_oxide", + "object", + "rustc-demangle", +] + [[package]] name = "bincode" version = "1.3.3" @@ -103,40 +173,45 @@ dependencies = [ [[package]] name = "clap" -version = "4.1.8" +version = "4.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3d7ae14b20b94cb02149ed21a86c423859cbe18dc7ed69845cace50e52b40a5" +checksum = "046ae530c528f252094e4a77886ee1374437744b2bff1497aa898bbddbbb29b3" dependencies = [ - "bitflags", + "clap_builder", "clap_derive", - "clap_lex", - "is-terminal", "once_cell", +] + +[[package]] +name = "clap_builder" +version = "4.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "223163f58c9a40c3b0a43e1c4b50a9ce09f007ea2cb1ec258a687945b4b7929f" +dependencies = [ + "anstream", + "anstyle", + "bitflags", + "clap_lex", "strsim", - "termcolor", ] [[package]] name = "clap_derive" -version = "4.1.8" +version = "4.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44bec8e5c9d09e439c4335b1af0abaab56dcf3b94999a936e1bb47b9134288f0" +checksum = "3f9644cd56d6b87dbe899ef8b053e331c0637664e9e21a33dfcdc36093f5c5c4" dependencies = [ "heck", - "proc-macro-error", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.13", ] [[package]] name = "clap_lex" -version = "0.3.2" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "350b9cf31731f9957399229e9b2adc51eeabdfbe9d71d9a0552275fd12710d09" -dependencies = [ - "os_str_bytes", -] +checksum = "8a2dd5a6fe8c6e3502f568a6353e5273bbb15193ad9a89e457b9970798efbea1" [[package]] name = "clipboard-win" @@ -149,6 +224,34 @@ dependencies = [ "winapi", ] +[[package]] +name = "color-eyre" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"5a667583cca8c4f8436db8de46ea8233c42a7d9ae424a82d338f2e4675229204" +dependencies = [ + "backtrace", + "eyre", + "indenter", + "once_cell", + "owo-colors", +] + +[[package]] +name = "concolor-override" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a855d4a1978dc52fb0536a04d384c2c0c1aa273597f08b77c8c4d3b2eec6037f" + +[[package]] +name = "concolor-query" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d11d52c3d7ca2e6d0040212be9e4dbbcd78b6447f535b6b561f449427944cf" +dependencies = [ + "windows-sys 0.45.0", +] + [[package]] name = "crossbeam-channel" version = "0.5.7" @@ -261,13 +364,13 @@ dependencies = [ [[package]] name = "errno" -version = "0.2.8" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" +checksum = "50d6a0976c999d473fe89ad888d5a284e55366d9dc9038b1ba2aa15128c4afa0" dependencies = [ "errno-dragonfly", "libc", - "winapi", + "windows-sys 0.45.0", ] [[package]] @@ -290,15 +393,25 @@ dependencies = [ "str-buf", ] +[[package]] +name = "eyre" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c2b6b5a29c02cdc822728b7d7b8ae1bab3e3b05d44522770ddd49722eeac7eb" +dependencies = [ + "indenter", + "once_cell", +] + [[package]] name = "fd-lock" -version = "3.0.10" +version = "3.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ef1a30ae415c3a691a4f41afddc2dbcd6d70baf338368d85ebc1e8ed92cedb9" +checksum = "39ae6b3d9530211fb3b12a95374b8b0823be812f53d09e18c5675c0146b09642" dependencies = [ "cfg-if", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -310,9 +423,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" +checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" dependencies = [ "cfg-if", "libc", @@ -326,6 +439,14 @@ dependencies = [ "ggml-sys", ] +[[package]] +name = "ggml-loader" +version = "0.1.0" +dependencies = [ + "ggml", + "thiserror", +] + [[package]] name = "ggml-sys" version = "0.1.0" @@ -333,6 +454,12 @@ dependencies = [ "cc", ] +[[package]] +name = "gimli" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad0a93d233ebf96623465aad4046a8d3aa4da22d4f4beba5388838c8a434bbb4" + [[package]] name = "glob" version = "0.3.1" @@ -376,26 +503,33 @@ version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "indenter" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" + [[package]] name = "io-lifetimes" -version = "1.0.6" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfa919a82ea574332e2de6e74b4c36e74d41982b335080fa59d4ef31be20fdf3" +checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" dependencies = [ + "hermit-abi 0.3.1", "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "is-terminal" -version = "0.4.4" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b6b32576413a8e69b90e952e4a026476040d81017b80445deda5f2d3921857" +checksum = "adcf93614601c8129ddf72e2d5633df827ba6551541c6d8c59520a371475be1f" dependencies = [ "hermit-abi 0.3.1", "io-lifetimes", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -436,9 +570,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.140" 
+version = "0.2.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99227334921fae1a979cf0bfdfcc6b3e5ce376ef57e16fb6fb3ea2ed6095f80c" +checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5" [[package]] name = "libloading" @@ -452,9 +586,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.1.4" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" +checksum = "d59d8c75012853d2e872fb56bc8a2e53718e2cafe1a4c823143141c6d90c322f" [[package]] name = "llama-cli" @@ -462,6 +596,7 @@ version = "0.1.0" dependencies = [ "bincode", "clap", + "color-eyre", "env_logger", "llama-rs", "log", @@ -479,6 +614,8 @@ version = "0.1.0" dependencies = [ "bytemuck", "ggml", + "ggml-loader", + "memmap2", "partial_sort", "protobuf", "rand", @@ -510,6 +647,15 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +[[package]] +name = "memmap2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83faa42c0a078c393f6b29d5db232d8be22776a891f8f56e5284faee4a20b327" +dependencies = [ + "libc", +] + [[package]] name = "memoffset" version = "0.8.0" @@ -525,6 +671,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "miniz_oxide" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa" +dependencies = [ + "adler", +] + [[package]] name = "nibble_vec" version = "0.1.0" @@ -566,6 +721,15 @@ dependencies = [ "libc", ] +[[package]] +name = "object" +version = "0.30.3" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "ea86265d3d3dcb6a27fc51bd29a4bf387fae9d2986b823079d4986af253eb439" +dependencies = [ + "memchr", +] + [[package]] name = "once_cell" version = "1.17.1" @@ -573,10 +737,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" [[package]] -name = "os_str_bytes" -version = "6.4.1" +name = "owo-colors" +version = "3.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" +checksum = "c1b04fb49957986fdce4d6ee7a65027d55d4b6d2265e5848bbb507b58ccfdb6f" [[package]] name = "partial_sort" @@ -602,35 +766,11 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" -[[package]] -name = "proc-macro-error" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" -dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.109", - "version_check", -] - -[[package]] -name = "proc-macro-error-attr" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" -dependencies = [ - "proc-macro2", - "quote", - "version_check", -] - [[package]] name = "proc-macro2" -version = "1.0.52" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d0e1ae9e836cc3beddd63db0df682593d7e2d3d891ae8c9083d2113e1744224" +checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435" dependencies = [ "unicode-ident", ] @@ -643,9 +783,9 @@ checksum = "8e86d370532557ae7573551a1ec8235a0f8d6cb276c7c9e6aa490b511c447485" [[package]] name = "quote" -version = "1.0.25" +version = "1.0.26" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "5308e8208729c3e1504a6cfad0d5daacc4614c9a2e65d1ea312a34b5cb00fe84" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" dependencies = [ "proc-macro2", ] @@ -734,9 +874,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.7.1" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" +checksum = "8b1f693b24f6ac912f4893ef08244d70b6067480d2f1a46e950c9691e6749d1d" dependencies = [ "aho-corasick", "memchr", @@ -745,9 +885,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.28" +version = "0.6.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "rust_tokenizers" @@ -768,6 +908,12 @@ dependencies = [ "unicode-normalization-alignments", ] +[[package]] +name = "rustc-demangle" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" + [[package]] name = "rustc-hash" version = "1.1.0" @@ -776,16 +922,16 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustix" -version = "0.36.9" +version = "0.37.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd5c6ff11fecd55b40746d1995a02f2eb375bf8c00d192d521ee09f42bef37bc" +checksum = "1aef160324be24d31a62147fae491c14d2204a3865c7ca8c3b0d7f7bcb3ea635" dependencies = [ "bitflags", "errno", "io-lifetimes", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -831,9 +977,9 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.158" +version = 
"1.0.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "771d4d9c4163ee138805e12c710dd365e4f44be8be0503cb1bb9eb989425d9c9" +checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065" dependencies = [ "serde_derive", ] @@ -849,20 +995,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.158" +version = "1.0.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e801c1712f48475582b7696ac71e0ca34ebb30e09338425384269d9717c62cad" +checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585" dependencies = [ "proc-macro2", "quote", - "syn 2.0.10", + "syn 2.0.13", ] [[package]] name = "serde_json" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c533a59c9d8a93a09c6ab31f0fd5e5f4dd1b8fc9434804029839884765d04ea" +checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744" dependencies = [ "itoa", "ryu", @@ -945,9 +1091,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.10" +version = "2.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aad1363ed6d37b84299588d62d3a7d95b5a5c2d9aad5c85609fda12afaa1f40" +checksum = "4c9da457c5285ac1f936ebd076af6dac17a61cfe7826f2076b4d015cf47bc8ec" dependencies = [ "proc-macro2", "quote", @@ -965,22 +1111,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.39" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5ab016db510546d856297882807df8da66a16fb8c4101cb8b30054b0d5b2d9c" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.39" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5420d42e90af0c38c3290abcca25b9b3bdf379fc9f55c528f53a269d9c9a267e" +checksum = 
"f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.13", ] [[package]] @@ -1040,12 +1186,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" -[[package]] -name = "version_check" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1100,7 +1240,16 @@ version = "0.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" dependencies = [ - "windows-targets", + "windows-targets 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.0", ] [[package]] @@ -1109,13 +1258,28 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-targets" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +dependencies = [ + "windows_aarch64_gnullvm 0.48.0", + 
"windows_aarch64_msvc 0.48.0", + "windows_i686_gnu 0.48.0", + "windows_i686_msvc 0.48.0", + "windows_x86_64_gnu 0.48.0", + "windows_x86_64_gnullvm 0.48.0", + "windows_x86_64_msvc 0.48.0", ] [[package]] @@ -1124,42 +1288,84 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" + [[package]] name = "windows_aarch64_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" + [[package]] name = "windows_i686_gnu" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" +[[package]] +name = "windows_i686_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" + [[package]] name = "windows_i686_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" +[[package]] +name = "windows_i686_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" + [[package]] name = "windows_x86_64_gnu" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" + [[package]] name = "windows_x86_64_gnullvm" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" + [[package]] name = "windows_x86_64_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" + [[package]] name = "zstd" version = "0.12.3+zstd.1.5.2" @@ -1171,9 +1377,9 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "6.0.4+zstd.1.5.4" +version = "6.0.5+zstd.1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7afb4b54b8910cf5447638cb54bf4e8a65cbedd783af98b98c62ffe91f185543" +checksum = "d56d9e60b4b1758206c238a10165fbcae3ca37b01744e394c463463f6529d23b" dependencies = [ "libc", "zstd-sys", @@ -1181,9 +1387,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.7+zstd.1.5.4" +version = "2.0.8+zstd.1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94509c3ba2fe55294d752b79842c530ccfab760192521df74a081a78d2b3c7f5" +checksum = "5556e6ee25d32df2586c098bbfa278803692a20d0ab9565e049480d52707ec8c" dependencies = [ "cc", "libc", diff --git a/Cargo.toml b/Cargo.toml index 8ea220d8..f579b1c6 
100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "ggml-sys", "ggml", + "ggml-loader", "llama-rs", "llama-cli", "generate-ggml-bindings" diff --git a/ggml-loader/Cargo.toml b/ggml-loader/Cargo.toml new file mode 100644 index 00000000..ab711363 --- /dev/null +++ b/ggml-loader/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "ggml-loader" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +ggml = { path = "../ggml" } +thiserror = "*" diff --git a/ggml-loader/src/lib.rs b/ggml-loader/src/lib.rs new file mode 100644 index 00000000..7d29d4b3 --- /dev/null +++ b/ggml-loader/src/lib.rs @@ -0,0 +1,236 @@ +//! standalone model loader +//! +//! Only the hyperparameter is llama-specific. Everything else can be reused for other LLM. +#![allow(clippy::nonminimal_bool)] + +pub mod util; + +use std::ops::ControlFlow; +use util::*; + +pub type ElementType = ggml::Type; + +/// file type containing the model +#[derive(Debug, PartialEq, Clone, Copy)] +#[allow(clippy::upper_case_acronyms)] +pub enum ContainerType { + /// legacy format, oldest ggml tensor file format + GGML, + /// also legacy format, newer than GGML, older than GGJT + GGMF, + /// mmap-able format + GGJT, +} + +impl ContainerType { + pub fn support_mmap(&self) -> bool { + match self { + ContainerType::GGML => false, + ContainerType::GGMF => false, + ContainerType::GGJT => true, + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum LoadError { + #[error("invalid file magic number: {0}")] + InvalidMagic(u32), + + #[error("invalid ggml format: version={0}")] + InvalidFormatVersion(u32), + + #[error("{0}")] + Io(#[from] std::io::Error), + + #[error("{0}")] + FailedCast(#[from] std::num::TryFromIntError), + + /// return `ControlFlow::Break` from any of the `cb_*` function to trigger this error + #[error("user requested interrupt: {0}")] + UserInterrupted(T), + + #[error("unsupported tensor dtype/f16_: 
{0}")] + UnsupportedElementType(i32), + + /// sanity check failed + #[error("invariant broken: {0}")] + InvariantBroken(String), +} + +#[derive(Debug, Clone)] +pub struct TensorInfo { + pub name: Vec, + pub n_dims: usize, + pub dims: [usize; 2], + pub n_elements: usize, + pub ftype: ElementType, + /// start of tensor - start of file + pub start_offset: u64, +} + +/// Info in hyperparameter used for later loading tasks. Used in callback. +/// see [`LoadHandler::load_hyper_parameters`] +#[derive(Debug, Clone)] +pub struct PartialHyperparameters { + pub n_vocab: usize, +} + +pub enum TensorDataTreatment<'a> { + CopyInto(&'a mut [u8]), + SeekPast { + /// should be `tensor.nbytes` + n_bytes: usize, + }, +} + +#[allow(unused_variables)] +pub trait LoadHandler { + fn got_container_type(&mut self, container_type: ContainerType) -> ControlFlow { + ControlFlow::Continue(()) + } + + fn got_vocab_token(&mut self, i: usize, token: Vec, score: f32) -> ControlFlow { + ControlFlow::Continue(()) + } + + fn load_hyper_parameters(&mut self, reader: &mut R) -> ControlFlow; + + /// callback to get tensor buffer to populate + /// + /// # Returns + /// + /// `None` to skip copying + /// `Some(buf)` to provide a buffer for copying weights into + fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow; +} + +#[test] +fn can_be_vtable() { + use std::mem::MaybeUninit; + let _a: MaybeUninit>> = MaybeUninit::uninit(); +} + +pub fn load_model_from_reader( + reader: &mut R, + handler: &mut impl LoadHandler, +) -> Result<(), LoadError> { + // Verify magic + let container_type: ContainerType = match read_u32(reader)? 
{ + ggml::FILE_MAGIC_GGMF => ContainerType::GGMF, + ggml::FILE_MAGIC_GGJT => ContainerType::GGJT, + ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML, + magic => return Err(LoadError::InvalidMagic(magic)), + }; + controlflow_to_result(handler.got_container_type(container_type))?; + + // Load format version + match container_type { + ContainerType::GGMF | ContainerType::GGJT => { + let _version: u32 = match read_u32(reader)? { + ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, + version => return Err(LoadError::InvalidFormatVersion(version)), + }; + } + ContainerType::GGML => {} + } + + // Load hyper params + let hparams = controlflow_to_result(handler.load_hyper_parameters(reader))?; + let n_vocab = hparams.n_vocab; + + // Load vocabulary + for i in 0..n_vocab { + let len = read_u32(reader)?.try_into()?; + let token = read_bytes_with_len(reader, len)?; + let token_score = match container_type { + ContainerType::GGMF | ContainerType::GGJT => read_f32(reader)?, + ContainerType::GGML => { + // Legacy model, set empty score + 0. + } + }; + controlflow_to_result(handler.got_vocab_token(i, token, token_score))?; + } + + // Load tensor data + match container_type { + ContainerType::GGMF | ContainerType::GGML => load_weights(reader, handler, false), + ContainerType::GGJT => load_weights(reader, handler, true), + } +} + +/// # Params +/// +/// `align` +/// align to 4 bytes before reading tensor weights +pub fn load_weights( + reader: &mut R, + handler: &mut impl LoadHandler, + align: bool, +) -> Result<(), LoadError> { + while has_data_left(reader)? 
{ + // load tensor header + let n_dims: usize = read_i32(reader)?.try_into()?; + let name_len = read_i32(reader)?; + let ftype = decode_element_type_res(read_i32(reader)?)?; + + let mut n_elements: usize = 1; + let mut dims = [1usize, 1]; + let ne_len = dims.len(); + if !(n_dims <= ne_len) { + return Err(LoadError::InvariantBroken(format!("{n_dims} <= {ne_len}"))); + } + #[allow(clippy::needless_range_loop)] + for i in 0..n_dims { + let dim: usize = read_i32(reader)?.try_into()?; + dims[i] = dim; + n_elements *= dim; + } + + // load tensor name + let name = read_bytes_with_len(reader, name_len.try_into()?)?; + + // sanity check + match ftype { + ElementType::Q4_0 | ElementType::Q4_1 => { + if !(dims[0] % 64 == 0) { + return Err(LoadError::InvariantBroken(format!("{dims:?}[0] % 64 == 0"))); + } + } + _ => {} + } + + // load tensor weights + let offset_curr = reader.stream_position()?; + let offset_aligned: u64 = if align { + (offset_curr + 31) & !31 + } else { + offset_curr + }; + + let tensor_info = TensorInfo { + name, + dims, + n_dims, + n_elements, + ftype, + start_offset: offset_aligned, + }; + + match controlflow_to_result(handler.tensor_buffer(tensor_info))? 
{ + TensorDataTreatment::CopyInto(buf) => { + if align { + reader.seek(SeekFrom::Start(offset_aligned))?; + } + reader.read_exact(buf)?; + } + TensorDataTreatment::SeekPast { n_bytes } => { + // skip if no buffer is given + reader.seek(SeekFrom::Start(offset_aligned + n_bytes as u64))?; + } + } + } + + Ok(()) +} diff --git a/ggml-loader/src/util.rs b/ggml-loader/src/util.rs new file mode 100644 index 00000000..33374fd6 --- /dev/null +++ b/ggml-loader/src/util.rs @@ -0,0 +1,77 @@ +pub use std::io::{BufRead, Seek, SeekFrom}; +use std::ops::ControlFlow; + +use crate::{ElementType, LoadError}; + +pub fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], std::io::Error> { + let mut bytes = [0u8; N]; + reader.read_exact(&mut bytes)?; + Ok(bytes) +} + +pub fn read_i32(reader: &mut impl BufRead) -> Result { + Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub fn read_u32(reader: &mut impl BufRead) -> Result { + Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub fn read_f32(reader: &mut impl BufRead) -> Result { + Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub fn read_bytes_with_len( + reader: &mut impl BufRead, + len: usize, +) -> Result, std::io::Error> { + let mut bytes = vec![0u8; len]; + reader.read_exact(&mut bytes)?; + Ok(bytes) +} + +// NOTE: Implementation from #![feature(buf_read_has_data_left)] +pub fn has_data_left(reader: &mut impl BufRead) -> Result { + reader.fill_buf().map(|b| !b.is_empty()) +} + +pub fn decode_element_type(ftype: i32) -> Option { + match ftype { + 0 => Some(ggml::Type::F32), + 1 => Some(ggml::Type::F16), + 2 => Some(ggml::Type::Q4_0), + 3 => Some(ggml::Type::Q4_1), + _ => None, + } +} + +pub fn encode_element_type(element_type: ElementType) -> Option { + match element_type { + ggml::Type::F32 => Some(0), + ggml::Type::F16 => Some(1), + ggml::Type::Q4_0 => Some(2), + ggml::Type::Q4_1 => Some(3), + _ => None, + } +} + +pub fn decode_element_type_res(ftype: i32) -> Result> { + match decode_element_type(ftype) 
{ + Some(x) => Ok(x), + None => Err(LoadError::UnsupportedElementType(ftype)), + } +} + +pub fn controlflow_to_result(x: ControlFlow) -> Result> { + match x { + ControlFlow::Continue(x) => Ok(x), + ControlFlow::Break(y) => Err(LoadError::UserInterrupted(y)), + } +} + +pub fn result_to_controlflow>(x: Result) -> ControlFlow { + match x { + Ok(x) => ControlFlow::Continue(x), + Err(y) => ControlFlow::Break(y.into()), + } +} diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 7625e5b1..11d4246b 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -14,18 +14,21 @@ use std::{ sync::{Arc, Weak}, }; -/// Magic constant for `ggml` files (versioned). -pub const FILE_MAGIC: u32 = 0x67676d66; +/// Magic constant for `ggml` files (versioned, ggmf). +pub const FILE_MAGIC_GGMF: u32 = 0x67676d66; +/// Magic constant for `ggml` files (versioned, ggjt). +pub const FILE_MAGIC_GGJT: u32 = 0x67676a74; /// Magic constant for `ggml` files (unversioned). pub const FILE_MAGIC_UNVERSIONED: u32 = 0x67676d6c; /// The currently-supported format version for `ggml` files. pub const FORMAT_VERSION: u32 = 1; -#[derive(Debug, Copy, Clone, PartialEq, Eq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] /// The type of a value in `ggml`. pub enum Type { /// Quantized 4-bit (type 0). + #[default] Q4_0, /// Quantized 4-bit (type 1); used by GPTQ. Q4_1, @@ -83,14 +86,14 @@ pub struct Context { } impl Context { /// Creates a new [Context] with the specified `mem_size` as a working area. - pub fn init(mem_size: usize) -> Self { + pub fn init(mem_size: usize, alloc: bool) -> Self { let raw = unsafe { ggml_sys::ggml_init(ggml_sys::ggml_init_params { mem_size, // Null here means we want ggml to own this memory. We don't // support passing an owned buffer from the Rust side. 
mem_buffer: std::ptr::null_mut(), - no_alloc: false, + no_alloc: !alloc, }) }; Self { @@ -523,6 +526,14 @@ impl Tensor { } } + fn with_alive_ctx_mut(&self, mut f: impl FnMut() -> U) -> U { + if let Some(_ctx) = self.ctx.upgrade() { + f() + } else { + panic!("Using a tensor after the context was dropped") + } + } + /// Number of bytes used by this tensor. pub fn nbytes(&self) -> usize { self.with_alive_ctx(|| { @@ -535,14 +546,27 @@ impl Tensor { /// /// # Safety /// - /// The data must not be mutated while being read from. - pub unsafe fn data(&self) -> *mut c_void { + /// Only `std::slice::from_raw_parts_mut(tensor.data(), tensor.nbytes())` is safe to mutate. + pub unsafe fn data(&mut self) -> *mut c_void { self.with_alive_ctx(|| { // SAFETY: The with_alive_call guarantees the context is alive unsafe { *self.ptr.as_ptr() }.data }) } + /// Set the tensor's data pointer (useful for mmap-ed data) + /// + /// # Safety + /// + /// The memory region from `data_ptr` to `data_ptr.offset(tensor.nbytes())` will be read from. + pub unsafe fn set_data(&mut self, data_ptr: *mut c_void) { + let tensor = self.ptr.as_mut(); + self.with_alive_ctx_mut(|| { + // SAFETY: The with_alive_call guarantees the context is alive + tensor.data = data_ptr; + }) + } + /// Number of elements in this tensor. pub fn nelements(&self) -> usize { self.with_alive_ctx(|| { @@ -576,12 +600,12 @@ impl Tensor { /// # Safety /// /// This tensor must not be written to or read by from any other code. - pub unsafe fn write_data(&self, src: &[u8]) { + pub unsafe fn write_data(&mut self, src: &[u8]) { std::ptr::copy_nonoverlapping(src.as_ptr(), self.data() as *mut u8, src.len()) } /// Zeroes out this tensor. 
- pub fn zero_data(&self) { + pub fn zero_data(&mut self) { unsafe { std::ptr::write_bytes(self.data() as *mut u8, 0, self.nbytes()) } } diff --git a/llama-cli/Cargo.toml b/llama-cli/Cargo.toml index 31400c67..d4914b15 100644 --- a/llama-cli/Cargo.toml +++ b/llama-cli/Cargo.toml @@ -1,22 +1,22 @@ [package] -name = "llama-cli" -version = { workspace = true } edition = "2021" +name = "llama-cli" +version = {workspace = true} # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -llama-rs = { path = "../llama-rs", features = ["convert"] } +llama-rs = {path = "../llama-rs", features = ["convert"]} -rand = { workspace = true } +rand = {workspace = true} bincode = "1.3.3" -clap = { version = "4.1.8", features = ["derive"] } +clap = {version = "4.1.8", features = ["derive"]} +color-eyre = {version = "0.6.2", default-features = false} env_logger = "0.10.0" log = "0.4" num_cpus = "1.15.0" once_cell = "1.17.1" rustyline = "11.0.0" spinners = "4.1.0" -zstd = { version = "0.12", default-features = false } -color-eyre = { version = "0.6.2", default-features = false } +zstd = {version = "0.12", default-features = false} diff --git a/llama-cli/src/cli_args.rs b/llama-cli/src/cli_args.rs index e215d486..e31d4f48 100644 --- a/llama-cli/src/cli_args.rs +++ b/llama-cli/src/cli_args.rs @@ -259,59 +259,73 @@ pub struct ModelLoad { /// will likely not perform as well as a model with a larger context size. #[arg(long, default_value_t = 2048)] pub num_ctx_tokens: usize, + + /// Don't use mmap to load the model. 
+ #[arg(long)] + pub no_mmap: bool, } impl ModelLoad { pub fn load(&self) -> Result { - let model = llama_rs::Model::load(&self.model_path, self.num_ctx_tokens, |progress| { - use llama_rs::LoadProgress; - match progress { - LoadProgress::HyperparametersLoaded(hparams) => { - log::debug!("Loaded hyperparameters {hparams:#?}") - } - LoadProgress::ContextSize { bytes } => log::info!( - "ggml ctx size = {:.2} MB\n", - bytes as f64 / (1024.0 * 1024.0) - ), - LoadProgress::PartLoading { - file, - current_part, - total_parts, - } => { - let current_part = current_part + 1; - log::info!( - "Loading model part {}/{} from '{}'\n", + let now = std::time::Instant::now(); + let model = llama_rs::Model::load( + &self.model_path, + !self.no_mmap, + self.num_ctx_tokens, + |progress| { + use llama_rs::LoadProgress; + match progress { + LoadProgress::HyperparametersLoaded(hparams) => { + log::debug!("Loaded hyperparameters {hparams:#?}") + } + LoadProgress::ContextSize { bytes } => log::info!( + "ggml ctx size = {:.2} MB\n", + bytes as f64 / (1024.0 * 1024.0) + ), + LoadProgress::PartLoading { + file, current_part, total_parts, - file.to_string_lossy(), - ) - } - LoadProgress::PartTensorLoaded { - current_tensor, - tensor_count, - .. - } => { - let current_tensor = current_tensor + 1; - if current_tensor % 8 == 0 { - log::info!("Loaded tensor {current_tensor}/{tensor_count}"); + } => { + let current_part = current_part + 1; + log::info!( + "Loading model part {}/{} from '{}' (mmap preferred: {})\n", + current_part, + total_parts, + file.to_string_lossy(), + !self.no_mmap + ) + } + LoadProgress::PartTensorLoaded { + current_tensor, + tensor_count, + .. 
+ } => { + let current_tensor = current_tensor + 1; + if current_tensor % 8 == 0 { + log::info!("Loaded tensor {current_tensor}/{tensor_count}"); + } + } + LoadProgress::PartLoaded { + file, + byte_size, + tensor_count, + } => { + log::info!("Loading of '{}' complete", file.to_string_lossy()); + log::info!( + "Model size = {:.2} MB / num tensors = {}", + byte_size as f64 / 1024.0 / 1024.0, + tensor_count + ); } } - LoadProgress::PartLoaded { - file, - byte_size, - tensor_count, - } => { - log::info!("Loading of '{}' complete", file.to_string_lossy()); - log::info!( - "Model size = {:.2} MB / num tensors = {}", - byte_size as f64 / 1024.0 / 1024.0, - tensor_count - ); - } - } - }) - .wrap_err("Failed to load model")?; + }, + ) + .wrap_err("Could not load model")?; - log::info!("Model fully loaded!"); + log::info!( + "Model fully loaded! Elapsed: {}ms", + now.elapsed().as_millis() + ); Ok(model) } @@ -375,8 +389,8 @@ pub enum ElementType { F32, } impl From for llama_rs::ElementType { - fn from(model_type: ElementType) -> Self { - match model_type { + fn from(t: ElementType) -> Self { + match t { ElementType::Q4_0 => llama_rs::ElementType::Q4_0, ElementType::Q4_1 => llama_rs::ElementType::Q4_1, ElementType::F16 => llama_rs::ElementType::F16, diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index 076dd7bc..7ed254a4 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -8,6 +8,7 @@ rust-version = "1.65" [dependencies] ggml = { path = "../ggml" } +ggml-loader = { path = "../ggml-loader" } rand = { workspace = true } @@ -16,6 +17,7 @@ partial_sort = "0.2.0" thiserror = "1.0" serde = { version = "1.0", features = ["derive"] } serde_bytes = "0.11" +memmap2 = "0.5.10" # Used for the `convert` feature serde_json = { version = "1.0", optional = true } @@ -23,4 +25,4 @@ protobuf = { version = "= 2.14.0", optional = true } rust_tokenizers = { version = "3.1.2", optional = true } [features] -convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] \ No 
newline at end of file +convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] diff --git a/llama-rs/src/convert.rs b/llama-rs/src/convert.rs index fc562d48..67557b8f 100644 --- a/llama-rs/src/convert.rs +++ b/llama-rs/src/convert.rs @@ -17,6 +17,7 @@ use std::{ }; use crate::{util, Hyperparameters, Vocabulary}; +use ggml_loader::util::encode_element_type; /// Converts a `pth` file to a `ggml` file. pub fn convert_pth_to_ggml(model_directory: &Path, element_type: ggml::Type) { @@ -28,12 +29,12 @@ pub fn convert_pth_to_ggml(model_directory: &Path, element_type: ggml::Type) { let model_files = util::find_all_model_files(model_directory).unwrap(); for (i, _file) in model_files.iter().enumerate() { - let fname_out = model_directory.join(format!("rust-model-{}.bin", element_type)); + let fname_out = model_directory.join(format!("rust-model-{element_type}.bin")); let mut file = File::create(fname_out).expect("Unable to create file"); write_header(file.borrow_mut(), &hparams).unwrap(); write_tokens(file.borrow_mut(), &vocab).unwrap(); - let _fname_model = model_directory.join(format!("consolidated.0{}.pth", i)); + let _fname_model = model_directory.join(format!("consolidated.0{i}.pth")); // Todo process and write variables } } @@ -82,13 +83,7 @@ fn load_hyperparameters( let json = read_to_string(path.join("params.json")).expect("Unable to read file"); let json: HyperParametersJson = serde_json::from_str(&json).expect("Unable to parse json"); Hyperparameters { - f16_: match element_type { - ggml::Type::F32 => 0, - ggml::Type::F16 => 1, - ggml::Type::Q4_0 => 2, - ggml::Type::Q4_1 => 3, - _ => panic!("unsupported element type"), - }, + element_type, n_ctx: 0, n_embd: json.dim, n_head: json.n_heads, @@ -112,7 +107,7 @@ fn write_header(fout: &mut File, hparams: &Hyperparameters) -> Result<(), String i32::try_from(hparams.n_head).unwrap(), i32::try_from(hparams.n_layer).unwrap(), i32::try_from(hparams.n_embd / hparams.n_head).unwrap(), - 
i32::try_from(hparams.f16_).unwrap(), + encode_element_type(hparams.element_type).unwrap(), ]; let mut packed_values: Vec = vec![]; diff --git a/llama-rs/src/inference_session.rs b/llama-rs/src/inference_session.rs index 3af27812..428b9a7b 100644 --- a/llama-rs/src/inference_session.rs +++ b/llama-rs/src/inference_session.rs @@ -365,7 +365,7 @@ impl InferenceSession { ctx_size }; - let session_ctx = ggml::Context::init(ctx_size); + let session_ctx = ggml::Context::init(ctx_size, true); // Initialize key + value memory tensors let n_mem = n_layer * n_ctx; @@ -397,7 +397,7 @@ impl InferenceSession { } impl Clone for InferenceSession { fn clone(&self) -> Self { - let context = ggml::Context::init(self.memory_size); + let context = ggml::Context::init(self.memory_size, true); let memory_k = context.new_tensor_1d(self.memory_k.get_type(), self.memory_k.nelements()); let memory_v = context.new_tensor_1d(self.memory_v.get_type(), self.memory_v.nelements()); @@ -473,7 +473,7 @@ impl InferenceSnapshotRef<'_> { } /// A serializable snapshot of the inference process. Can be restored by calling -/// [Model::session_from_snapshot]. +/// [InferenceSession::from_snapshot]. #[derive(serde::Deserialize, Clone, PartialEq)] // Keep in sync with [InferenceSession] and [InferenceSnapshotRef]. pub struct InferenceSnapshot { @@ -493,7 +493,7 @@ pub struct InferenceSnapshot { pub memory_v: Vec, } -#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] /// Parameters for an inference session. pub struct InferenceSessionParameters { /// The number of tokens to consider for the repetition penalty. @@ -550,7 +550,7 @@ impl Display for InferenceStats { } /// Allowed types for the model memory K/V tensors. 
-#[derive(Clone, Copy, Debug, PartialEq, serde::Serialize, serde::Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] pub enum ModelKVMemoryType { /// 16-bit float. Float16, diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index 51218cb8..3f0a6c69 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -8,6 +8,8 @@ pub mod convert; mod inference_session; mod loader; +mod loader2; +mod loader_common; mod model; mod util; mod vocabulary; @@ -17,7 +19,7 @@ pub use inference_session::{ InferenceSession, InferenceSessionParameters, InferenceSnapshot, ModelKVMemoryType, SnapshotError, }; -pub use loader::{LoadError, LoadProgress}; +pub use loader_common::{LoadError, LoadProgress}; pub use model::{Hyperparameters, Model}; pub use util::TokenUtf8Buffer; pub use vocabulary::{TokenBias, TokenId, Vocabulary}; diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs index f99344dd..9ef545e5 100644 --- a/llama-rs/src/loader.rs +++ b/llama-rs/src/loader.rs @@ -1,140 +1,23 @@ +#![allow(dead_code)] + use std::{ collections::HashMap, io::{BufRead, Read, Seek, SeekFrom}, - path::{Path, PathBuf}, + path::Path, }; -use thiserror::Error; - use crate::{ util::{self, mulf}, - vocabulary::TokenId, - Hyperparameters, Model, Vocabulary, + LoadError, LoadProgress, Model, TokenId, Vocabulary, }; +use crate::{ElementType, Hyperparameters}; +use ggml_loader::util::*; +use ggml_loader::ContainerType; +use memmap2::Mmap; -/// Each variant represents a step within the process of loading the model. -/// These can be used to report progress to the user. -#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] -pub enum LoadProgress<'a> { - /// The hyperparameters have been loaded from the model. - HyperparametersLoaded(&'a Hyperparameters), - /// The context has been created. - ContextSize { - /// The size of the context. - bytes: usize, - }, - /// A part of the model is being loaded. - PartLoading { - /// The path to the model part. 
- file: &'a Path, - /// The current part (0-indexed). - current_part: usize, - /// The number of total parts. - total_parts: usize, - }, - /// A tensor from the current part has been loaded. - PartTensorLoaded { - /// The path to the model part. - file: &'a Path, - /// The current tensor (0-indexed). - current_tensor: usize, - /// The number of total tensors. - tensor_count: usize, - }, - /// A model part has finished fully loading. - PartLoaded { - /// The path to the model part. - file: &'a Path, - /// The number of bytes in the part. - byte_size: usize, - /// The number of tensors in the part. - tensor_count: usize, - }, -} - -#[derive(Error, Debug)] -/// Errors encountered during the loading process. -pub enum LoadError { - #[error("could not open file {path:?}")] - /// A file failed to open. - OpenFileFailed { - /// The original error. - source: std::io::Error, - /// The path that failed. - path: PathBuf, - }, - #[error("no parent path for {path:?}")] - /// There is no parent path for a given path. - NoParentPath { - /// The path without a parent. - path: PathBuf, - }, - #[error("unable to read exactly {bytes} bytes")] - /// Reading exactly `bytes` from a file failed. - ReadExactFailed { - /// The original error. - source: std::io::Error, - /// The number of bytes that were attempted to be read. - bytes: usize, - }, - #[error("non-specific I/O error")] - /// A non-specific IO error. - IO(#[from] std::io::Error), - #[error("could not convert bytes to a UTF-8 string")] - /// One of the strings encountered was not valid UTF-8. - InvalidUtf8(#[from] std::string::FromUtf8Error), - #[error("invalid integer conversion")] - /// One of the integers encountered could not be converted to a more appropriate type. - InvalidIntegerConversion(#[from] std::num::TryFromIntError), - #[error("invalid magic number for {path:?}")] - /// An invalid magic number was encountered during the loading process. - InvalidMagic { - /// The path that failed. 
- path: PathBuf, - }, - #[error("invalid file format version {value}")] - /// The version of the format is not supported by this version of `llama-rs`. - InvalidFormatVersion { - /// The version that was encountered. - value: u32, - }, - #[error("invalid value {ftype} for `f16` in hyperparameters")] - /// The `f16` hyperparameter had an invalid value. - HyperparametersF16Invalid { - /// The format type that was encountered. - ftype: u32, - }, - #[error("unknown tensor `{tensor_name}` in {path:?}")] - /// The tensor `tensor_name` was encountered during the loading of `path`, but was not seen during - /// the model prelude. - UnknownTensor { - /// The name of the tensor. - tensor_name: String, - /// The path that failed. - path: PathBuf, - }, - #[error("the tensor `{tensor_name}` has the wrong size in {path:?}")] - /// The tensor `tensor_name` did not match its expected size. - TensorWrongSize { - /// The name of the tensor. - tensor_name: String, - /// The path that failed. - path: PathBuf, - }, - /// The tensor `tensor_name` did not have the expected format type. - #[error("invalid ftype {ftype} for tensor `{tensor_name}` in {path:?}")] - InvalidFtype { - /// The name of the tensor. - tensor_name: String, - /// The format type that was encountered. - ftype: u32, - /// The path that failed. - path: PathBuf, - }, -} - -pub fn load( +pub(crate) fn load( path: impl AsRef, + prefer_mmap: bool, n_context_tokens: usize, mut load_progress_callback: impl FnMut(LoadProgress), ) -> Result { @@ -143,18 +26,17 @@ pub fn load( let main_path = path.as_ref(); - let mut reader = - BufReader::new( - File::open(main_path).map_err(|e| LoadError::OpenFileFailed { - source: e, - path: main_path.to_owned(), - })?, - ); + let file = File::open(main_path).map_err(|e| LoadError::OpenFileFailed { + source: e, + path: main_path.to_owned(), + })?; + let mut reader = BufReader::new(&file); // Verify magic - let is_legacy_model: bool = match read_u32(&mut reader)? 
{ - ggml::FILE_MAGIC => false, - ggml::FILE_MAGIC_UNVERSIONED => true, + let model_type: ContainerType = match read_u32(&mut reader)? { + ggml::FILE_MAGIC_GGMF => ContainerType::GGMF, + ggml::FILE_MAGIC_GGJT => ContainerType::GGJT, + ggml::FILE_MAGIC_UNVERSIONED => ContainerType::GGML, _ => { return Err(LoadError::InvalidMagic { path: main_path.to_owned(), @@ -163,12 +45,14 @@ pub fn load( }; // Load format version - if !is_legacy_model { - #[allow(unused_variables)] - let version: u32 = match read_u32(&mut reader)? { - ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, - version => return Err(LoadError::InvalidFormatVersion { value: version }), - }; + match model_type { + ContainerType::GGMF | ContainerType::GGJT => { + let _version: u32 = match read_u32(&mut reader)? { + ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, + version => return Err(LoadError::InvalidFormatVersion { version }), + }; + } + ContainerType::GGML => {} } // ================= @@ -185,7 +69,10 @@ pub fn load( n_head: read_i32(&mut reader)?.try_into()?, n_layer: read_i32(&mut reader)?.try_into()?, n_rot: read_i32(&mut reader)?.try_into()?, - f16_: read_i32(&mut reader)?.try_into()?, + element_type: { + let ftype = read_i32(&mut reader)?; + decode_element_type(ftype).ok_or_else(|| LoadError::UnsupportedElementType(ftype)) + }?, }; let n_ff = @@ -197,76 +84,64 @@ pub fn load( // Load vocabulary // =============== let vocabulary = { - let mut id_to_token = vec![]; - let mut id_to_token_score = vec![]; - let mut token_to_id = HashMap::new(); - let mut max_token_length = 0; + let mut vocab = Vocabulary::default(); for i in 0..hparams.n_vocab { let len = read_i32(&mut reader)?; - let token = read_bytes_with_len(&mut reader, len as usize)?; - max_token_length = max_token_length.max(token.len()); - id_to_token.push(token.clone()); - token_to_id.insert(token, TokenId::try_from(i)?); - - // Token score, currently unused - if !is_legacy_model { - if let Ok(score) = read_f32(&mut reader) { - 
id_to_token_score.push(score); + let id = i as TokenId; + let token = read_bytes_with_len(&mut reader, len.try_into()?)?; + + let score = match model_type { + ContainerType::GGMF | ContainerType::GGJT => read_f32(&mut reader)?, + ContainerType::GGML => { + // Legacy model, set empty score + 0. } - } else { - // Legacy model, set empty score - id_to_token_score.push(0.); - } - } + }; - Vocabulary { - id_to_token, - id_to_token_score, - token_to_id, - max_token_length, + vocab.push_token(id, token, score); } + + vocab }; // for the big tensors, we have the option to store the data in 16-bit // floats or quantized in order to save memory and also to speed up the // computation - let wtype = match hparams.f16_ { - 0 => ggml::Type::F32, - 1 => ggml::Type::F16, - 2 => ggml::Type::Q4_0, - 3 => ggml::Type::Q4_1, - invalid => return Err(LoadError::HyperparametersF16Invalid { ftype: invalid }), - }; + let wtype = hparams.element_type; let n_embd = hparams.n_embd; let n_layer = hparams.n_layer; let n_vocab = hparams.n_vocab; + let alloc = !(prefer_mmap && model_type.support_mmap()); + let ctx_size = { // Use 64-bit math to prevent overflow. 
- let mut ctx_size: usize = 0; + let mut ctx_size: usize = (5 + 10 * n_layer) * 256; // object overhead - ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // tok_embeddings + if alloc { + let mut model_size: usize = 0; - ctx_size += mulf!(n_embd, ggml::type_sizef(ggml::Type::F32)); // norm + ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // tok_embeddings + ctx_size += mulf!(n_embd, ggml::type_sizef(ggml::Type::F32)); // norm + ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // output - ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // output + model_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // attention_norm - ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // attention_norm + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wq + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wk + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wv + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wo - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wq - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wk - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wv - ctx_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wo + model_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // ffn_norm - ctx_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // ffn_norm + model_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w1 + model_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w2 + model_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w3 - ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w1 - ctx_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w2 - ctx_size += mulf!(n_layer, n_ff, 
n_embd, ggml::type_sizef(wtype)); // w3 - - ctx_size += (5 + 10 * n_layer) * 256; // object overhead + ctx_size += model_size; + } load_progress_callback(LoadProgress::ContextSize { bytes: ctx_size }); @@ -274,18 +149,66 @@ pub fn load( }; // Initialize the context - let context = ggml::Context::init(ctx_size); + let context = ggml::Context::init(ctx_size, alloc); + + let (mmap, mmap_ptr) = if prefer_mmap && model_type.support_mmap() { + let mmap = unsafe { Mmap::map(&file)? }; + let ptr = mmap.as_ptr(); + (Some(mmap), Some(ptr)) + } else { + (None, None) + }; - let model = Model::new(context, hparams, vocabulary, n_ff, wtype); + let mut model = Model::new(context, hparams, vocabulary, n_ff, wtype, model_type, mmap); + match model_type { + ContainerType::GGMF | ContainerType::GGML => { + let file_offset = reader.stream_position()?; + drop(reader); + load_weights_ggmf_or_unversioned( + file_offset, + main_path, + load_progress_callback, + model.tensors_mut(), + )? + } + ContainerType::GGJT => { + load_weights_ggjt( + &mut reader, + mmap_ptr, + main_path, + load_progress_callback, + model.tensors_mut(), + )?; + } + } + + Ok(model) +} - // Close the file, but keep its offset. That way we know how to skip the - // metadata when loading the parts. - let file_offset = reader.stream_position()?; - drop(reader); +/// Helper function. Reads a string from the buffer and returns it. 
+pub(crate) fn read_string(reader: &mut impl BufRead, len: usize) -> Result { + let mut buf = vec![0; len]; + reader + .read_exact(&mut buf) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: buf.len(), + })?; + let s = String::from_utf8(buf)?; + Ok(s) +} + +fn load_weights_ggmf_or_unversioned( + file_offset: u64, + main_path: &Path, + mut load_progress_callback: impl FnMut(LoadProgress), + tensors: &mut HashMap, +) -> Result<(), LoadError> { + use std::{fs::File, io::BufReader}; let paths = util::find_all_model_files(main_path)?; - let n_parts = paths.len(); + let n_parts = paths.len(); for (i, part_path) in paths.into_iter().enumerate() { let part_id = i; @@ -305,125 +228,23 @@ pub fn load( // Load weights loop { - // NOTE: Implementation from #![feature(buf_read_has_data_left)] - let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?; - - if is_eof { + if !has_data_left(&mut part_reader)? { break; } let n_dims = usize::try_from(read_i32(&mut part_reader)?)?; let length = read_i32(&mut part_reader)?; - let ftype = read_u32(&mut part_reader)?; - - let mut nelements = 1; - let mut ne = [1i64, 1i64]; - - #[allow(clippy::needless_range_loop)] - for i in 0..n_dims { - ne[i] = read_i32(&mut part_reader)? 
as i64; - nelements *= usize::try_from(ne[i])?; - } - - let tensor_name = read_string(&mut part_reader, length as usize)?; - - let Some(tensor) = model.tensors().get(&tensor_name) - else { - return Err(LoadError::UnknownTensor { tensor_name, path: part_path }); - }; - - // split_type = 0: split by columns - // split_type = 1: split by rows - // - // split_type = 0: - // regex: - // - tok_embeddings.* - // - layers.*.attention.wo.weight - // - layers.*.feed_forward.w2.weight - - // split_type = 1: - // regex: - // - output.* - // - layers.*.attention.wq.weight - // - layers.*.attention.wk.weight - // - layers.*.attention.wv.weight - // - layers.*.feed_forward.w1.weight - // - layers.*.feed_forward.w3.weight - #[allow(clippy::if_same_then_else)] - let split_type = if tensor_name.contains("tok_embeddings") { - 0 - } else if tensor_name.contains("layers") { - if tensor_name.contains("attention.wo.weight") { - 0 - } else if tensor_name.contains("feed_forward.w2.weight") { - 0 - } else { - 1 - } - } else if tensor_name.contains("output") { - 1 - } else { - 0 - }; - - if n_dims == 1 { - if tensor.nelements() != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.nelements() / n_parts != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if n_dims == 1 { - if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if split_type == 0 { - if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] - || tensor.get_ne()[1] != ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.get_ne()[0] != ne[0] - || tensor.get_ne()[1] / i64::try_from(n_parts)? 
!= ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - let bpe = match ftype { - 0 => ggml::type_size(ggml::Type::F32), - 1 => ggml::type_size(ggml::Type::F16), - 2 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::Type::Q4_0) - } - 3 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::Type::Q4_1) - } - _ => { - return Err(LoadError::InvalidFtype { - tensor_name, - ftype, - path: part_path, - }) - } - }; + let ftype = read_i32(&mut part_reader)?; + + let (nelements, ne, tensor_name, tensor, split_type, bpe) = load_tensor_header_ggmf( + n_dims, + &mut part_reader, + length, + tensors, + &part_path, + n_parts, + ftype, + )?; if n_dims == 1 || n_parts == 1 { if (nelements * bpe) / ggml::blck_size(tensor.get_type()) != tensor.nbytes() { @@ -500,7 +321,7 @@ pub fn load( load_progress_callback(LoadProgress::PartTensorLoaded { file: &part_path, current_tensor: n_tensors.try_into()?, - tensor_count: model.tensors().len(), + tensor_count: tensors.len(), }); } @@ -510,45 +331,224 @@ pub fn load( tensor_count: n_tensors.try_into()?, }); } - - Ok(model) + Ok(()) } -pub fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { - let mut bytes = [0u8; N]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: N, - })?; - Ok(bytes) +#[allow(clippy::type_complexity)] +fn load_tensor_header_ggmf<'a>( + n_dims: usize, + reader: &mut impl BufRead, + length: i32, + tensors: &'a mut HashMap, + path: &Path, + n_parts: usize, + ftype: i32, +) -> Result<(usize, [i64; 2], String, &'a mut ggml::Tensor, i32, usize), LoadError> { + let mut nelements = 1; + let mut ne = [1i64, 1i64]; + assert!(n_dims <= ne.len()); + #[allow(clippy::needless_range_loop)] + for i in 0..n_dims { + ne[i] = read_i32(reader)? 
as i64; + nelements *= usize::try_from(ne[i])?; + } + let tensor_name = read_string(reader, length as usize)?; + let Some(tensor) = tensors.get_mut(&tensor_name) + else { + return Err(LoadError::UnknownTensor { tensor_name, path: path.to_owned() }); + }; + #[allow(clippy::if_same_then_else)] + let split_type = if tensor_name.contains("tok_embeddings") { + 0 + } else if tensor_name.contains("layers") { + if tensor_name.contains("attention.wo.weight") { + 0 + } else if tensor_name.contains("feed_forward.w2.weight") { + 0 + } else { + 1 + } + } else if tensor_name.contains("output") { + 1 + } else { + 0 + }; + if n_dims == 1 { + if tensor.nelements() != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + } else if tensor.nelements() / n_parts != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + if n_dims == 1 { + if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + } else if split_type == 0 { + if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] || tensor.get_ne()[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + } else if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] / i64::try_from(n_parts)? 
!= ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + let bpe = tensor_type_size(ftype, ne); + let bpe = match bpe { + Some(x) => x, + None => { + return Err(LoadError::InvalidFtype { + tensor_name, + ftype, + path: path.to_owned(), + }); + } + }; + Ok((nelements, ne, tensor_name, tensor, split_type, bpe)) } -pub fn read_bytes_with_len(reader: &mut impl BufRead, len: usize) -> Result, LoadError> { - let mut bytes = vec![0u8; len]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: len, - })?; - Ok(bytes) +fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { + let ftype = decode_element_type(ftype)?; + match ftype { + ElementType::Q4_0 | ElementType::Q4_1 => { + assert_eq!(ne[0] % 64, 0); + } + _ => {} + } + Some(ggml::type_size(ftype)) } -pub fn read_i32(reader: &mut impl BufRead) -> Result { - Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) -} +fn load_weights_ggjt( + reader: &mut (impl BufRead + Seek), + mmap_base: Option<*const u8>, + path: &Path, + mut load_progress_callback: impl FnMut(LoadProgress), + tensors: &mut HashMap, +) -> Result<(), LoadError> +// where R: std::io::Read +{ + let mut loop_i = 0; + let mut total_loaded_bytes = 0; + load_progress_callback(LoadProgress::PartLoading { + file: path, + current_part: 0, + total_parts: 1, + }); + + loop { + if !has_data_left(reader)? { + break; + } -pub fn read_u32(reader: &mut impl BufRead) -> Result { - Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) + let n_dims = read_i32(reader)? as usize; + let length = read_i32(reader)?; + let ftype = read_i32(reader)?; + + let mut nelements: usize = 1; + let mut ne = [1i64, 1]; + assert!(n_dims <= ne.len()); + #[allow(clippy::needless_range_loop)] + for i in 0..n_dims { + let dim = read_i32(reader)? 
as usize; + ne[i] = dim as i64; + nelements *= dim; + } + let tensor_name = read_string(reader, length as usize)?; + let Some(tensor) = tensors.get_mut(&tensor_name) + else { + return Err(LoadError::UnknownTensor { tensor_name, path: path.to_owned() }); + }; + + if tensor.nelements() != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + let tensor_ne = tensor.get_ne(); + if tensor_ne[0] != ne[0] || tensor_ne[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + + match tensor_type_size(ftype, ne) { + Some(_) => {} + None => { + return Err(LoadError::InvalidFtype { + tensor_name, + ftype, + path: path.to_owned(), + }); + } + }; + + if let Some(mmap_base) = mmap_base { + load_tensor_ggjt_mmap(reader, mmap_base, tensor)?; + } else { + load_tensor_ggjt_copy(reader, tensor)?; + } + + total_loaded_bytes += tensor.nbytes() as u64; + + load_progress_callback(LoadProgress::PartTensorLoaded { + file: path, + current_tensor: loop_i, + tensor_count: tensors.len(), + }); + + loop_i += 1; + } + + load_progress_callback(LoadProgress::PartLoaded { + file: path, + byte_size: total_loaded_bytes as usize, + tensor_count: loop_i, + }); + + Ok(()) } -pub fn read_f32(reader: &mut impl BufRead) -> Result { - Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) +fn load_tensor_ggjt_mmap( + reader: &mut (impl BufRead + Seek), + mmap_base: *const u8, + tensor: &mut ggml::Tensor, +) -> Result<(), LoadError> { + let offset_curr = reader.stream_position()?; + let offset_aligned: u64 = (offset_curr + 31) & !31; + unsafe { + let ptr = mmap_base.offset(offset_aligned as isize); + tensor.set_data(ptr as *mut std::ffi::c_void); + } + reader.seek(SeekFrom::Start(offset_aligned + tensor.nbytes() as u64))?; + Ok(()) } -/// Helper function. Reads a string from the buffer and returns it. 
-pub fn read_string(reader: &mut impl BufRead, len: usize) -> Result { - Ok(String::from_utf8(read_bytes_with_len(reader, len)?)?) +fn load_tensor_ggjt_copy<'a>( + reader: &mut (impl BufRead + Seek), + tensor: &'a mut ggml::Tensor, +) -> Result<(), LoadError> { + let offset_curr = reader.stream_position()?; + let offset_aligned: u64 = (offset_curr + 31) & !31; + reader.seek(SeekFrom::Start(offset_aligned))?; + + let buf: &'a mut [u8] = + unsafe { std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) }; + reader.read_exact(buf)?; + + Ok(()) } diff --git a/llama-rs/src/loader2.rs b/llama-rs/src/loader2.rs new file mode 100644 index 00000000..ff84e3b1 --- /dev/null +++ b/llama-rs/src/loader2.rs @@ -0,0 +1,280 @@ +use ggml_loader::util::*; +use ggml_loader::*; +use memmap2::Mmap; + +use std::{ + fs::File, + io::{BufRead, BufReader, Seek}, + ops::ControlFlow, + path::{Path, PathBuf}, +}; + +use crate::{ + util::{self, mulf}, + Hyperparameters, LoadError, LoadProgress, Model, TokenId, Vocabulary, +}; + +impl LoadError { + fn from_ggml_loader_error(value: ggml_loader::LoadError, path: PathBuf) -> Self { + match value { + ggml_loader::LoadError::InvalidMagic(_magic) => LoadError::InvalidMagic { path }, + ggml_loader::LoadError::InvalidFormatVersion(version) => { + LoadError::InvalidFormatVersion { version } + } + ggml_loader::LoadError::Io(err) => LoadError::Io(err), + ggml_loader::LoadError::FailedCast(err) => LoadError::InvalidIntegerConversion(err), + ggml_loader::LoadError::UserInterrupted(err) => err, + ggml_loader::LoadError::UnsupportedElementType(ty) => { + LoadError::HyperparametersF16Invalid { ftype: ty } + } + ggml_loader::LoadError::InvariantBroken(invariant) => { + LoadError::InvariantBroken { path, invariant } + } + } + } +} + +pub(crate) fn load( + path: impl AsRef, + prefer_mmap: bool, + n_context_tokens: usize, + load_progress_callback: impl FnMut(LoadProgress), +) -> Result { + let main_path = path.as_ref(); + + let paths = 
util::find_all_model_files(main_path)?; + if paths.len() != 1 { + return Err(LoadError::MultipartNotSupported { paths }); + } + + let file = File::open(main_path).map_err(|e| LoadError::OpenFileFailed { + source: e, + path: main_path.to_owned(), + })?; + let mut reader = BufReader::new(&file); + + let path = path.as_ref().to_owned(); + let mut loader = Loader { + path: path.clone(), + vocab: Default::default(), + model: None, + n_ctx: n_context_tokens, + load_progress_callback, + prefer_mmap, + + tensor_accumulator: 0, + hyperparameters: Hyperparameters::default(), + container_type: ContainerType::GGJT, + }; + + ggml_loader::load_model_from_reader(&mut reader, &mut loader) + .map_err(|err| LoadError::from_ggml_loader_error(err, path.clone()))?; + + loader.model.ok_or(LoadError::ModelNotCreated { path }) +} + +struct Loader { + // input data and options + path: PathBuf, + n_ctx: usize, + prefer_mmap: bool, + + // Internal state + tensor_accumulator: usize, + container_type: ContainerType, + hyperparameters: Hyperparameters, + model: Option, + vocab: Vocabulary, + load_progress_callback: F, +} + +impl ggml_loader::LoadHandler> for Loader { + fn load_hyper_parameters( + &mut self, + reader: &mut BufReader<&File>, + ) -> ControlFlow { + let (hyperparameters, partial) = match load_hyperparameters(reader, self.n_ctx) { + Ok(t) => t, + Err(err) => { + return ControlFlow::Break(LoadError::from_ggml_loader_error( + err, + self.path.clone(), + )) + } + }; + self.hyperparameters = hyperparameters; + (self.load_progress_callback)(LoadProgress::HyperparametersLoaded(&self.hyperparameters)); + + ControlFlow::Continue(partial) + } + + fn got_container_type(&mut self, t: ContainerType) -> ControlFlow { + self.container_type = t; + ControlFlow::Continue(()) + } + + fn got_vocab_token(&mut self, i: usize, token: Vec, score: f32) -> ControlFlow { + let id = match TokenId::try_from(i) { + Ok(id) => id, + Err(err) => return ControlFlow::Break(LoadError::InvalidIntegerConversion(err)), 
+ }; + self.vocab.push_token(id, token, score); + + ControlFlow::Continue(()) + } + + fn tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow { + let model = match &mut self.model { + Some(model) => model, + None => { + let model = result_to_controlflow(self.create_model(self.vocab.clone()))?; + self.model.insert(model) + } + }; + + let tensor_name = match String::from_utf8(info.name) { + Ok(n) => n, + Err(err) => return ControlFlow::Break(LoadError::InvalidUtf8(err)), + }; + + let tensor_count = model.tensors_mut().len(); + + // to satisfy borrow checker + macro_rules! get_tensor { + () => { + match model.tensors_mut().get_mut(&tensor_name) { + Some(tensor) => tensor, + None => { + return ControlFlow::Break(LoadError::UnknownTensor { + path: self.path.clone(), + tensor_name, + }) + } + } + }; + } + + let ret = match &model.mmap { + Some(map) => unsafe { + let ptr = map.as_ptr().offset(info.start_offset as isize); + let tensor = get_tensor!(); + tensor.set_data(ptr as *mut std::ffi::c_void); + TensorDataTreatment::SeekPast { + n_bytes: tensor.nbytes(), + } + }, + None => { + let tensor = get_tensor!(); + let buf: &mut [u8] = unsafe { + std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) + }; + TensorDataTreatment::CopyInto(buf) + } + }; + (self.load_progress_callback)(LoadProgress::PartTensorLoaded { + file: &self.path, + current_tensor: self.tensor_accumulator, + tensor_count, + }); + self.tensor_accumulator += 1; + + ControlFlow::Continue(ret) + } +} + +impl Loader { + fn create_model(&mut self, vocabulary: Vocabulary) -> Result { + (self.load_progress_callback)(LoadProgress::PartLoading { + file: &self.path, + current_part: 0, + total_parts: 1, + }); + let alloc = !(self.use_mmap()); + let Hyperparameters { + n_vocab, + n_embd, + n_mult, + n_layer, + element_type, + .. 
+ } = self.hyperparameters; + let n_ff = ((2 * (4 * n_embd) / 3 + n_mult - 1) / n_mult) * n_mult; + let wtype = element_type; + let ctx_size = { + // Use 64-bit math to prevent overflow. + let mut ctx_size: usize = (5 + 10 * n_layer) * 256; // object overhead + + if alloc { + let mut model_size: usize = 0; + + ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // tok_embeddings + ctx_size += mulf!(n_embd, ggml::type_sizef(ggml::Type::F32)); // norm + ctx_size += mulf!(n_embd, n_vocab, ggml::type_sizef(wtype)); // output + + model_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // attention_norm + + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wq + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wk + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wv + model_size += mulf!(n_layer, n_embd, n_embd, ggml::type_sizef(wtype)); // wo + + model_size += mulf!(n_layer, n_embd, ggml::type_sizef(ggml::Type::F32)); // ffn_norm + + model_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w1 + model_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w2 + model_size += mulf!(n_layer, n_ff, n_embd, ggml::type_sizef(wtype)); // w3 + + ctx_size += model_size; + } + + (self.load_progress_callback)(LoadProgress::ContextSize { bytes: ctx_size }); + + ctx_size + }; + // Initialize the context + let context = ggml::Context::init(ctx_size, alloc); + + let mmap = if self.use_mmap() { + let file = File::open(&self.path)?; + Some(unsafe { Mmap::map(&file)? 
}) + } else { + None + }; + + Ok(Model::new( + context, + self.hyperparameters, + vocabulary, + n_ff, + wtype, + self.container_type, + mmap, + )) + } + + fn use_mmap(&mut self) -> bool { + self.prefer_mmap && self.container_type.support_mmap() + } +} + +/// use this to load params for llama model inside [`LoadHandler::load_hyper_parameters`] +fn load_hyperparameters( + reader: &mut R, + n_ctx: usize, +) -> Result<(Hyperparameters, PartialHyperparameters), ggml_loader::LoadError> { + // NOTE: Field order matters! Data is laid out in the file exactly in this order. + let hparams = Hyperparameters { + n_vocab: read_i32(reader)?.try_into()?, + n_embd: read_i32(reader)?.try_into()?, + n_mult: read_i32(reader)?.try_into()?, + n_head: read_i32(reader)?.try_into()?, + n_layer: read_i32(reader)?.try_into()?, + n_rot: read_i32(reader)?.try_into()?, + element_type: decode_element_type_res(read_i32(reader)?)?, + n_ctx, + }; + let partial = PartialHyperparameters { + n_vocab: hparams.n_vocab, + }; + Ok((hparams, partial)) +} diff --git a/llama-rs/src/loader_common.rs b/llama-rs/src/loader_common.rs new file mode 100644 index 00000000..4a219642 --- /dev/null +++ b/llama-rs/src/loader_common.rs @@ -0,0 +1,165 @@ +use std::path::{Path, PathBuf}; + +use thiserror::Error; + +use crate::{util::FindAllModelFilesError, Hyperparameters}; + +/// Each variant represents a step within the process of loading the model. +/// These can be used to report progress to the user. +#[derive(Clone, PartialEq, Eq, Debug)] +pub enum LoadProgress<'a> { + /// The hyperparameters have been loaded from the model. + HyperparametersLoaded(&'a Hyperparameters), + /// The context has been created. + ContextSize { + /// The size of the context. + bytes: usize, + }, + /// A part of the model is being loaded. + PartLoading { + /// The path to the model part. + file: &'a Path, + /// The current part (0-indexed). + current_part: usize, + /// The number of total parts. 
+    total_parts: usize,
+    },
+    /// A tensor from the current part has been loaded.
+    PartTensorLoaded {
+        /// The path to the model part.
+        file: &'a Path,
+        /// The current tensor (0-indexed).
+        current_tensor: usize,
+        /// The number of total tensors.
+        tensor_count: usize,
+    },
+    /// A model part has finished fully loading.
+    PartLoaded {
+        /// The path to the model part.
+        file: &'a Path,
+        /// The number of bytes in the part.
+        byte_size: usize,
+        /// The number of tensors in the part.
+        tensor_count: usize,
+    },
+}
+
+#[derive(Error, Debug)]
+/// Errors encountered during the loading process.
+pub enum LoadError {
+    #[error("could not open file {path:?}")]
+    /// A file failed to open.
+    OpenFileFailed {
+        /// The original error.
+        source: std::io::Error,
+        /// The path that failed.
+        path: PathBuf,
+    },
+    #[error("no parent path for {path:?}")]
+    /// There is no parent path for a given path.
+    NoParentPath {
+        /// The path without a parent.
+        path: PathBuf,
+    },
+    #[error("unable to read exactly {bytes} bytes")]
+    /// Reading exactly `bytes` from a file failed.
+    ReadExactFailed {
+        /// The original error.
+        source: std::io::Error,
+        /// The number of bytes that were attempted to be read.
+        bytes: usize,
+    },
+    #[error("non-specific I/O error")]
+    /// A non-specific IO error.
+    Io(#[from] std::io::Error),
+    #[error("could not convert bytes to a UTF-8 string")]
+    /// One of the strings encountered was not valid UTF-8.
+    InvalidUtf8(#[from] std::string::FromUtf8Error),
+    #[error("invalid integer conversion")]
+    /// One of the integers encountered could not be converted to a more appropriate type.
+    InvalidIntegerConversion(#[from] std::num::TryFromIntError),
+    #[error("unsupported f16_: {0}")]
+    /// The element type (`f16_` value) read from the file is not supported.
+    UnsupportedElementType(i32),
+    #[error("invalid magic number for {path:?}")]
+    /// An invalid magic number was encountered during the loading process.
+ InvalidMagic { + /// The path that failed. + path: PathBuf, + }, + #[error("invalid file format version {version}")] + /// The version of the format is not supported by this version of `llama-rs`. + InvalidFormatVersion { + /// The version that was encountered. + version: u32, + }, + #[error("invalid value {ftype} for `f16` in hyperparameters")] + /// The `f16` hyperparameter had an invalid value. + HyperparametersF16Invalid { + /// The format type that was encountered. + ftype: i32, + }, + #[error("unknown tensor `{tensor_name}` in {path:?}")] + /// The tensor `tensor_name` was encountered during the loading of `path`, but was not seen during + /// the model prelude. + UnknownTensor { + /// The name of the tensor. + tensor_name: String, + /// The path that failed. + path: PathBuf, + }, + #[error("the tensor `{tensor_name}` has the wrong size in {path:?}")] + /// The tensor `tensor_name` did not match its expected size. + TensorWrongSize { + /// The name of the tensor. + tensor_name: String, + /// The path that failed. + path: PathBuf, + }, + /// The tensor `tensor_name` did not have the expected format type. + #[error("invalid ftype {ftype} for tensor `{tensor_name}` in {path:?}")] + InvalidFtype { + /// The name of the tensor. + tensor_name: String, + /// The format type that was encountered. + ftype: i32, + /// The path that failed. + path: PathBuf, + }, + /// An invariant was broken. + /// + /// This error is not relevant unless `loader2` is being used. + #[error("invariant broken: {invariant} in {path:?}")] + InvariantBroken { + /// The path that failed. + path: PathBuf, + /// The invariant that was broken. + invariant: String, + }, + /// The model could not be created. + /// + /// This implies that there were no tensors in the model to be loaded. + /// + /// This error is not relevant unless `loader2` is being used. + #[error("could not create model from {path:?}")] + ModelNotCreated { + /// The path that failed. 
+ path: PathBuf, + }, + /// Multiple parts of the model were found. + /// + /// Multi-part models are not supported. Please convert the model to a single part. + #[error("multipart models are not supported")] + MultipartNotSupported { + /// The paths that were found. + paths: Vec, + }, +} +impl From for LoadError { + fn from(value: FindAllModelFilesError) -> Self { + match value { + FindAllModelFilesError::NoParentPath { path } => LoadError::NoParentPath { path }, + FindAllModelFilesError::IO(err) => LoadError::Io(err), + } + } +} diff --git a/llama-rs/src/model.rs b/llama-rs/src/model.rs index 370e62df..6cd64dc1 100644 --- a/llama-rs/src/model.rs +++ b/llama-rs/src/model.rs @@ -1,11 +1,12 @@ use std::{collections::HashMap, path::Path}; -use serde::Deserialize; - use crate::{ - loader, vocabulary::TokenId, EvaluateOutputRequest, InferenceParameters, InferenceSession, - InferenceSessionParameters, LoadError, LoadProgress, Vocabulary, + loader, loader2, vocabulary::TokenId, EvaluateOutputRequest, InferenceParameters, + InferenceSession, InferenceSessionParameters, LoadError, LoadProgress, Vocabulary, }; +use memmap2::Mmap; + +use ggml_loader::ContainerType; /// The weights for the LLaMA model. All the mutable state is split into a /// separate struct `InferenceSession`. @@ -23,6 +24,11 @@ pub struct Model { tensors: HashMap, + /// Needs to kept alive while the model is alive + pub(crate) mmap: Option, + + _version: ContainerType, + // Must be kept alive for the model _context: ggml::Context, } @@ -33,6 +39,8 @@ impl Model { vocabulary: Vocabulary, n_ff: usize, wtype: ggml::Type, + container_type: ContainerType, + mmap: Option, ) -> Model { let n_embd = hparams.n_embd; let n_layer = hparams.n_layer; @@ -102,6 +110,8 @@ impl Model { layers, tensors, _context: context, + mmap, + _version: container_type, } } @@ -110,10 +120,25 @@ impl Model { /// The status of the loading process will be reported through `load_progress_callback`. 
pub fn load( path: impl AsRef, + prefer_mmap: bool, n_context_tokens: usize, load_progress_callback: impl FnMut(LoadProgress), ) -> Result { - loader::load(path, n_context_tokens, load_progress_callback) + // Loader2 is the default. It can support GGML, GGMF and GGJT, but does not support multipart models. + // + // Loader1 is the old loader. It can support multipart models, but will be deprecated. + let use_loader_2: bool = match std::env::var("GGML_LOADER").as_deref() { + Ok("2") => true, + Ok("1") => false, + Ok(_) => panic!("Please use GGML_LOADER=1 or GGML_LOADER=2"), + Err(_) => true, + }; + + if use_loader_2 { + loader2::load(path, prefer_mmap, n_context_tokens, load_progress_callback) + } else { + loader::load(path, prefer_mmap, n_context_tokens, load_progress_callback) + } } /// Starts a new `InferenceSession` for this model. @@ -155,7 +180,7 @@ impl Model { n_head, n_layer, n_rot, - f16_: _, + element_type: _, } = self.hparams; // For the first run, we need to guess a maximum buffer size so we can measure @@ -176,11 +201,11 @@ impl Model { // add 10% to account for ggml object overhead buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize; }; - let ctx0 = ggml::Context::init(buf_size); + let ctx0 = ggml::Context::init(buf_size, true); let mut gf = ggml::ComputationGraph::new(n_threads); - let embd = ctx0.new_tensor_1d(ggml::Type::I32, n); + let mut embd = ctx0.new_tensor_1d(ggml::Type::I32, n); unsafe { embd.write_data(bytemuck::cast_slice(input_tokens)) }; let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, &embd); @@ -425,13 +450,13 @@ impl Model { &self.vocabulary } - pub(crate) fn tensors(&self) -> &HashMap { - &self.tensors + pub(crate) fn tensors_mut(&mut self) -> &mut HashMap { + &mut self.tensors } } /// The hyperparameters of the model. 
-#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Deserialize)] +#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)] pub struct Hyperparameters { /// n_vocab pub n_vocab: usize, @@ -447,8 +472,8 @@ pub struct Hyperparameters { pub n_layer: usize, /// n_rot pub n_rot: usize, - /// f16_ - pub f16_: u32, + /// element_type + pub element_type: crate::ElementType, } struct Layer { diff --git a/llama-rs/src/util.rs b/llama-rs/src/util.rs index 3eb8f06d..4ada4d22 100644 --- a/llama-rs/src/util.rs +++ b/llama-rs/src/util.rs @@ -1,7 +1,5 @@ use std::path::{Path, PathBuf}; -use crate::LoadError; - /// NOTE: The original code relies in promotion rules and automatic cast between /// int to float. What we do instead is use this macro to convert every term of /// the multiplication to f64, which should have enough precision bits to hold @@ -16,12 +14,13 @@ macro_rules! mulf { } pub(crate) use mulf; +use thiserror::Error; /// Used to buffer incoming tokens until they produce a valid string of UTF-8 text. /// /// Tokens are *not* valid UTF-8 by themselves. However, the LLM will produce valid UTF-8 /// from multiple tokens. This helps alleviate that issue. -#[derive(Clone, PartialEq, Default)] +#[derive(Clone, PartialEq, Eq, Default)] pub struct TokenUtf8Buffer(Vec); impl TokenUtf8Buffer { /// Create a new buffer. @@ -69,11 +68,29 @@ impl TokenUtf8Buffer { } } -pub(crate) fn find_all_model_files(main_path: &Path) -> Result, LoadError> { +#[derive(Error, Debug)] +/// Errors encountered during the loading process. +pub enum FindAllModelFilesError { + #[error("no parent path for {path:?}")] + /// There is no parent path for a given path. + NoParentPath { + /// The path without a parent. + path: PathBuf, + }, + #[error("non-specific I/O error")] + /// A non-specific IO error. 
+ IO(#[from] std::io::Error), +} + +pub(crate) fn find_all_model_files( + main_path: &Path, +) -> Result, FindAllModelFilesError> { Ok(collect_related_paths( main_path, - std::fs::read_dir(main_path.parent().ok_or_else(|| LoadError::NoParentPath { - path: main_path.to_owned(), + std::fs::read_dir(main_path.parent().ok_or_else(|| { + FindAllModelFilesError::NoParentPath { + path: main_path.to_owned(), + } })?)? .filter_map(Result::ok) .map(|de| de.path()), diff --git a/llama-rs/src/vocabulary.rs b/llama-rs/src/vocabulary.rs index 80e619c7..32bdd07f 100644 --- a/llama-rs/src/vocabulary.rs +++ b/llama-rs/src/vocabulary.rs @@ -8,7 +8,7 @@ pub(crate) type Token = Vec; pub(crate) type TokenScore = f32; /// The vocabulary used by a model. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] pub struct Vocabulary { /// Maps every integer (index) token id to its corresponding token pub(crate) id_to_token: Vec, @@ -16,13 +16,37 @@ pub struct Vocabulary { /// Maps every integer (index) token id to corresponding score pub(crate) id_to_token_score: Vec, + // todo: use a radix tree /// Maps a token to a token id pub(crate) token_to_id: HashMap, /// The longest token in this vocabulary pub(crate) max_token_length: usize, } + impl Vocabulary { + /// Add a token to the vocabulary. + /// + /// The token added must have `id` directly after the last token in the vocabulary. + /// + /// # Panics + /// - This function can panic if `id` does not correspond to the next token in the vocabulary. + /// That is, if there are already `n` tokens in the vocabulary, then `id` must be `n`. + pub(crate) fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) { + // These are loader invariants. If this is broken, then the loader is broken and this is a bug, + // not an issue with the model itself. 
+ assert_eq!(self.id_to_token.len(), self.id_to_token_score.len()); + if self.id_to_token.len() != id as usize || self.id_to_token_score.len() != id as usize { + let expected_id = self.id_to_token.len() as TokenId; + panic!("the id of token added should be {expected_id}; is {id}"); + } + + self.max_token_length = self.max_token_length.max(content.len()); + self.id_to_token.push(content.clone()); + self.id_to_token_score.push(score); + self.token_to_id.insert(content, id); + } + pub(crate) fn token(&self, idx: usize) -> &[u8] { &self.id_to_token[idx] }