diff --git a/Cargo.lock b/Cargo.lock index 9ac1053c..38bd0498 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,6 +17,46 @@ dependencies = [ "memchr", ] +[[package]] +name = "anstream" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "342258dd14006105c2b75ab1bd7543a03bdf0cfc94383303ac212a04939dff6f" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-wincon", + "concolor-override", + "concolor-query", + "is-terminal", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23ea9e81bd02e310c216d080f6223c179012256e5151c41db88d12c88a1684d2" + +[[package]] +name = "anstyle-parse" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7d1bb534e9efed14f3e5f44e7dd1a4f709384023a4165199a4241e18dff0116" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-wincon" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3127af6145b149f3287bb9a0d10ad9c5692dba8c53ad48285e5bec4063834fa" +dependencies = [ + "anstyle", + "windows-sys 0.45.0", +] + [[package]] name = "autocfg" version = "1.1.0" @@ -103,40 +143,45 @@ dependencies = [ [[package]] name = "clap" -version = "4.1.8" +version = "4.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3d7ae14b20b94cb02149ed21a86c423859cbe18dc7ed69845cace50e52b40a5" +checksum = "046ae530c528f252094e4a77886ee1374437744b2bff1497aa898bbddbbb29b3" dependencies = [ - "bitflags", + "clap_builder", "clap_derive", - "clap_lex", - "is-terminal", "once_cell", +] + +[[package]] +name = "clap_builder" +version = "4.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "223163f58c9a40c3b0a43e1c4b50a9ce09f007ea2cb1ec258a687945b4b7929f" +dependencies = [ + "anstream", + "anstyle", + "bitflags", + "clap_lex", "strsim", - "termcolor", ] [[package]] name = "clap_derive" -version = "4.1.8" +version = "4.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44bec8e5c9d09e439c4335b1af0abaab56dcf3b94999a936e1bb47b9134288f0" +checksum = "3f9644cd56d6b87dbe899ef8b053e331c0637664e9e21a33dfcdc36093f5c5c4" dependencies = [ "heck", - "proc-macro-error", "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.13", ] [[package]] name = "clap_lex" -version = "0.3.2" +version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "350b9cf31731f9957399229e9b2adc51eeabdfbe9d71d9a0552275fd12710d09" -dependencies = [ - "os_str_bytes", -] +checksum = "8a2dd5a6fe8c6e3502f568a6353e5273bbb15193ad9a89e457b9970798efbea1" [[package]] name = "clipboard-win" @@ -149,6 +194,21 @@ dependencies = [ "winapi", ] +[[package]] +name = "concolor-override" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a855d4a1978dc52fb0536a04d384c2c0c1aa273597f08b77c8c4d3b2eec6037f" + +[[package]] +name = "concolor-query" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d11d52c3d7ca2e6d0040212be9e4dbbcd78b6447f535b6b561f449427944cf" +dependencies = [ + "windows-sys 0.45.0", +] + [[package]] name = "crossbeam-channel" version = "0.5.7" @@ -261,13 +321,13 @@ dependencies = [ [[package]] name = "errno" -version = "0.2.8" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1" +checksum = "50d6a0976c999d473fe89ad888d5a284e55366d9dc9038b1ba2aa15128c4afa0" dependencies = [ "errno-dragonfly", "libc", - "winapi", + "windows-sys 0.45.0", ] [[package]] @@ -292,13 +352,13 @@ dependencies = [ [[package]] name = "fd-lock" -version = "3.0.10" +version = "3.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ef1a30ae415c3a691a4f41afddc2dbcd6d70baf338368d85ebc1e8ed92cedb9" +checksum = "39ae6b3d9530211fb3b12a95374b8b0823be812f53d09e18c5675c0146b09642" dependencies = [ "cfg-if", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -310,9 +370,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.8" +version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c05aeb6a22b8f62540c194aac980f2115af067bfe15a0734d7277a768d396b31" +checksum = "c85e1d9ab2eadba7e5040d4e09cbd6d072b76a557ad64e797c2cb9d4da21d7e4" dependencies = [ "cfg-if", "libc", @@ -378,24 +438,25 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "io-lifetimes" -version = "1.0.6" +version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfa919a82ea574332e2de6e74b4c36e74d41982b335080fa59d4ef31be20fdf3" +checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" dependencies = [ + "hermit-abi 0.3.1", "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] name = "is-terminal" -version = "0.4.4" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b6b32576413a8e69b90e952e4a026476040d81017b80445deda5f2d3921857" +checksum = "adcf93614601c8129ddf72e2d5633df827ba6551541c6d8c59520a371475be1f" dependencies = [ "hermit-abi 0.3.1", "io-lifetimes", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -436,9 +497,9 @@ checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" [[package]] name = "libc" -version = "0.2.140" +version = "0.2.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99227334921fae1a979cf0bfdfcc6b3e5ce376ef57e16fb6fb3ea2ed6095f80c" +checksum = "3304a64d199bb964be99741b7a14d26972741915b3649639149b2479bb46f4b5" [[package]] name = "libloading" @@ -452,9 +513,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.1.4" +version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" +checksum = "d59d8c75012853d2e872fb56bc8a2e53718e2cafe1a4c823143141c6d90c322f" [[package]] name = "llama-cli" @@ -479,6 +540,7 @@ dependencies = [ "bincode", "bytemuck", "ggml", + "memmap2", "partial_sort", "protobuf", "rand", @@ -510,6 +572,15 @@ version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +[[package]] +name = "memmap2" +version = "0.5.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83faa42c0a078c393f6b29d5db232d8be22776a891f8f56e5284faee4a20b327" +dependencies = [ + "libc", +] + [[package]] name = "memoffset" version = "0.8.0" @@ -572,12 +643,6 @@ version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" -[[package]] -name = "os_str_bytes" -version = "6.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee" - [[package]] name = "partial_sort" version = "0.2.0" @@ -602,35 +667,11 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" -[[package]] -name = "proc-macro-error" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" -dependencies = [ - "proc-macro-error-attr", - "proc-macro2", - "quote", - "syn 1.0.109", - "version_check", -] - -[[package]] -name = "proc-macro-error-attr" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" -dependencies = [ - "proc-macro2", - "quote", - "version_check", -] - [[package]] name = "proc-macro2" -version = "1.0.52" +version = "1.0.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d0e1ae9e836cc3beddd63db0df682593d7e2d3d891ae8c9083d2113e1744224" +checksum = "2b63bdb0cd06f1f4dedf69b254734f9b45af66e4a031e42a7480257d9898b435" dependencies = [ "unicode-ident", ] @@ -643,9 +684,9 @@ checksum = "8e86d370532557ae7573551a1ec8235a0f8d6cb276c7c9e6aa490b511c447485" [[package]] name = "quote" -version = "1.0.25" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5308e8208729c3e1504a6cfad0d5daacc4614c9a2e65d1ea312a34b5cb00fe84" +checksum = "4424af4bf778aae2051a77b60283332f386554255d722233d09fbfc7e30da2fc" dependencies = [ "proc-macro2", ] @@ -734,9 +775,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.7.1" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733" +checksum = "8b1f693b24f6ac912f4893ef08244d70b6067480d2f1a46e950c9691e6749d1d" dependencies = [ "aho-corasick", "memchr", @@ -745,9 +786,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.6.28" +version = "0.6.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" [[package]] name = "rust_tokenizers" @@ -776,16 +817,16 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustix" -version = "0.36.9" +version = "0.37.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fd5c6ff11fecd55b40746d1995a02f2eb375bf8c00d192d521ee09f42bef37bc" +checksum = "1aef160324be24d31a62147fae491c14d2204a3865c7ca8c3b0d7f7bcb3ea635" dependencies = [ "bitflags", "errno", "io-lifetimes", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -831,9 +872,9 @@ checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" [[package]] name = "serde" -version = "1.0.158" +version = "1.0.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "771d4d9c4163ee138805e12c710dd365e4f44be8be0503cb1bb9eb989425d9c9" +checksum = "3c04e8343c3daeec41f58990b9d77068df31209f2af111e059e9fe9646693065" dependencies = [ "serde_derive", ] @@ -849,20 +890,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.158" +version = "1.0.159" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e801c1712f48475582b7696ac71e0ca34ebb30e09338425384269d9717c62cad" +checksum = "4c614d17805b093df4b147b51339e7e44bf05ef59fba1e45d83500bcfb4d8585" dependencies = [ "proc-macro2", "quote", - "syn 2.0.10", + "syn 2.0.13", ] [[package]] name = "serde_json" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c533a59c9d8a93a09c6ab31f0fd5e5f4dd1b8fc9434804029839884765d04ea" +checksum = "d721eca97ac802aa7777b701877c8004d950fc142651367300d21c1cc0194744" dependencies = [ "itoa", "ryu", @@ -945,9 +986,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.10" +version = "2.0.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5aad1363ed6d37b84299588d62d3a7d95b5a5c2d9aad5c85609fda12afaa1f40" +checksum = "4c9da457c5285ac1f936ebd076af6dac17a61cfe7826f2076b4d015cf47bc8ec" dependencies = [ "proc-macro2", "quote", @@ -965,22 +1006,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.39" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5ab016db510546d856297882807df8da66a16fb8c4101cb8b30054b0d5b2d9c" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.39" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5420d42e90af0c38c3290abcca25b9b3bdf379fc9f55c528f53a269d9c9a267e" +checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" dependencies = [ "proc-macro2", "quote", - "syn 1.0.109", + "syn 2.0.13", ] [[package]] @@ -1040,12 +1081,6 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" -[[package]] -name = "version_check" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" - [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" @@ -1100,7 +1135,16 @@ version = "0.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" dependencies = [ - "windows-targets", + "windows-targets 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.0", ] [[package]] @@ -1109,13 +1153,28 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-targets" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +dependencies = [ + "windows_aarch64_gnullvm 0.48.0", + "windows_aarch64_msvc 0.48.0", + "windows_i686_gnu 0.48.0", + "windows_i686_msvc 0.48.0", + "windows_x86_64_gnu 0.48.0", + "windows_x86_64_gnullvm 0.48.0", + "windows_x86_64_msvc 0.48.0", ] [[package]] @@ -1124,42 +1183,84 @@ version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" + [[package]] name = "windows_aarch64_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" + [[package]] name = "windows_i686_gnu" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" +[[package]] +name = "windows_i686_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" + [[package]] name = "windows_i686_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" +[[package]] +name = "windows_i686_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" + [[package]] name = "windows_x86_64_gnu" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" + [[package]] name = "windows_x86_64_gnullvm" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" + [[package]] name = "windows_x86_64_msvc" version = "0.42.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" + [[package]] name = "zstd" version = "0.12.3+zstd.1.5.2" @@ -1171,9 +1272,9 @@ dependencies = [ [[package]] name = "zstd-safe" -version = "6.0.4+zstd.1.5.4" +version = "6.0.5+zstd.1.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7afb4b54b8910cf5447638cb54bf4e8a65cbedd783af98b98c62ffe91f185543" +checksum = "d56d9e60b4b1758206c238a10165fbcae3ca37b01744e394c463463f6529d23b" dependencies = [ "libc", "zstd-sys", @@ -1181,9 +1282,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.7+zstd.1.5.4" +version = "2.0.8+zstd.1.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94509c3ba2fe55294d752b79842c530ccfab760192521df74a081a78d2b3c7f5" +checksum = "5556e6ee25d32df2586c098bbfa278803692a20d0ab9565e049480d52707ec8c" dependencies = [ "cc", "libc", diff --git a/ggml/src/lib.rs b/ggml/src/lib.rs index 22c9eee8..89a6a32f 100644 --- a/ggml/src/lib.rs +++ b/ggml/src/lib.rs @@ -14,8 +14,10 @@ use std::{ sync::{Arc, Weak}, }; -/// Magic constant for `ggml` files (versioned). -pub const FILE_MAGIC: u32 = 0x67676d66; +/// Magic constant for `ggml` files (versioned, ggmf). +pub const FILE_MAGIC_GGMF: u32 = 0x67676d66; +/// Magic constant for `ggml` files (versioned, ggjt). +pub const FILE_MAGIC_GGJT: u32 = 0x67676a74; /// Magic constant for `ggml` files (unversioned). pub const FILE_MAGIC_UNVERSIONED: u32 = 0x67676d6c; @@ -351,7 +353,7 @@ impl Tensor { /// /// # Safety /// - /// The data must not be mutated while being read from. + /// Only `std::slice::from_raw_parts_mut(tensor.data(), tensor.nbytes())` is safe to mutate. pub unsafe fn data(&self) -> *mut c_void { self.with_alive_ctx(|| { // SAFETY: The with_alive_call guarantees the context is alive @@ -359,6 +361,18 @@ impl Tensor { }) } + // /// Set the tensor's data pointer (useful for mmap-ed data) + // /// + // /// # Safety + // /// + // /// The memory region from `data_ptr` to `data_ptr.offset(tensor.nbytes())` will be read from. + // pub unsafe fn set_data(&self, data_ptr: *mut c_void) { + // self.with_alive_ctx(|| { + // // SAFETY: The with_alive_call guarantees the context is alive + // unsafe { *self.ptr.as_ptr() }.data = data_ptr; + // }) + // } + /// Number of elements in this tensor. pub fn nelements(&self) -> usize { self.with_alive_ctx(|| { diff --git a/llama-rs/Cargo.toml b/llama-rs/Cargo.toml index 0e48cd58..302a1389 100644 --- a/llama-rs/Cargo.toml +++ b/llama-rs/Cargo.toml @@ -16,6 +16,7 @@ rand = { workspace = true } serde = { version = "1.0.156", features = ["derive"] } serde_bytes = "0.11" bincode = "1.3.3" +memmap2 = { version = "0.5.10", optional = true } # Used for the `convert` feature serde_json = { version = "1.0.94", optional = true } @@ -23,4 +24,7 @@ protobuf = { version = "= 2.14.0", optional = true } rust_tokenizers = { version = "3.1.2", optional = true } [features] -convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] \ No newline at end of file +convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"] + +# broken atm, see https://github.com/rustformers/llama-rs/pull/114#issuecomment-1500337463 +mmap = ["dep:memmap2"] diff --git a/llama-rs/src/convert.rs b/llama-rs/src/convert.rs index 285cf3c0..f4f55996 100644 --- a/llama-rs/src/convert.rs +++ b/llama-rs/src/convert.rs @@ -28,12 +28,12 @@ pub fn convert_pth_to_ggml(model_directory: &Path, element_type: ggml::Type) { let model_files = util::find_all_model_files(model_directory).unwrap(); for (i, _file) in model_files.iter().enumerate() { - let fname_out = model_directory.join(format!("rust-model-{}.bin", element_type)); + let fname_out = model_directory.join(format!("rust-model-{element_type}.bin")); let mut file = File::create(fname_out).expect("Unable to create file"); write_header(file.borrow_mut(), &hparams).unwrap(); write_tokens(file.borrow_mut(), &vocab).unwrap(); - let _fname_model = model_directory.join(format!("consolidated.0{}.pth", i)); + let _fname_model = model_directory.join(format!("consolidated.0{i}.pth")); // Todo process and write variables } } diff --git a/llama-rs/src/lib.rs b/llama-rs/src/lib.rs index d5ef2a23..ddffce02 100644 --- a/llama-rs/src/lib.rs +++ b/llama-rs/src/lib.rs @@ -1,6 +1,11 @@ #![deny(missing_docs)] //! LLaMA-rs is a Rust port of the llama.cpp project. This allows running inference for Facebook's LLaMA model on a CPU with good performance using full precision, f16 or 4-bit quantized versions of the model. +#[cfg(feature = "convert")] +pub mod convert; +mod loader; +mod util; + use core::slice; use std::{ collections::HashMap, @@ -11,18 +16,30 @@ use std::{ time, }; -use serde::Deserialize; -use thiserror::Error; - +pub use ggml::Type as ElementType; use partial_sort::PartialSort; use rand::{distributions::WeightedIndex, prelude::Distribution}; +use serde::Deserialize; +use thiserror::Error; -pub use ggml::Type as ElementType; +#[cfg(feature = "mmap")] +use memmap2::Mmap; -#[cfg(feature = "convert")] -pub mod convert; +/// dummy struct +#[cfg(not(feature = "mmap"))] +pub(crate) struct Mmap; -mod util; +/// dummy impl +#[cfg(not(feature = "mmap"))] +impl Mmap { + pub(crate) unsafe fn map(_: &std::fs::File) -> Result { + Ok(Mmap) + } + pub(crate) fn as_ptr(&self) -> *const u8 { + std::ptr::null() + } +} +// map /// The end of text token. pub const EOT_TOKEN_ID: TokenId = 2; // Hardcoded (for now?) @@ -57,6 +74,15 @@ struct Layer { w3: ggml::Tensor, } +/// Model Version +#[derive(Debug, PartialEq, Clone, Copy)] +#[allow(clippy::upper_case_acronyms)] +pub(crate) enum ModelVersion { + GGMF, + GGJT, + Unversioned, +} + /// The weights for the LLaMA model. All the mutable state is split into a /// separate struct `InferenceSession`. pub struct Model { @@ -71,6 +97,10 @@ pub struct Model { tensors: HashMap, + mmap: Option, + + _version: ModelVersion, + // Must be kept alive for the model _context: ggml::Context, } @@ -509,7 +539,7 @@ pub enum LoadError { /// The name of the tensor. tensor_name: String, /// The format type that was encountered. - ftype: u32, + ftype: i32, /// The path that failed. path: PathBuf, }, @@ -586,59 +616,23 @@ impl Model { n_context_tokens: usize, load_progress_callback: impl Fn(LoadProgress), ) -> Result<(Model, Vocabulary), LoadError> { + use loader::*; use std::fs::File; use std::io::BufReader; let main_path = path.as_ref(); - let mut reader = - BufReader::new( - File::open(main_path).map_err(|e| LoadError::OpenFileFailed { - source: e, - path: main_path.to_owned(), - })?, - ); - - fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { - let mut bytes = [0u8; N]; - reader - .read_exact(&mut bytes) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: N, - })?; - Ok(bytes) - } - - fn read_i32(reader: &mut impl BufRead) -> Result { - Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_u32(reader: &mut impl BufRead) -> Result { - Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - fn read_f32(reader: &mut impl BufRead) -> Result { - Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) - } - - /// Helper function. Reads a string from the buffer and returns it. - fn read_string(reader: &mut BufReader, len: usize) -> Result { - let mut buf = vec![0; len]; - reader - .read_exact(&mut buf) - .map_err(|e| LoadError::ReadExactFailed { - source: e, - bytes: buf.len(), - })?; - let s = String::from_utf8(buf)?; - Ok(s) - } + let file = File::open(main_path).map_err(|e| LoadError::OpenFileFailed { + source: e, + path: main_path.to_owned(), + })?; + let mut reader = BufReader::new(&file); // Verify magic - let is_legacy_model: bool = match read_u32(&mut reader)? { - ggml::FILE_MAGIC => false, - ggml::FILE_MAGIC_UNVERSIONED => true, + let model_type: ModelVersion = match read_u32(&mut reader)? { + ggml::FILE_MAGIC_GGMF => ModelVersion::GGMF, + ggml::FILE_MAGIC_GGJT => ModelVersion::GGJT, + ggml::FILE_MAGIC_UNVERSIONED => ModelVersion::Unversioned, _ => { return Err(LoadError::InvalidMagic { path: main_path.to_owned(), @@ -647,12 +641,14 @@ impl Model { }; // Load format version - if !is_legacy_model { - #[allow(unused_variables)] - let version: u32 = match read_u32(&mut reader)? { - ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, - version => return Err(LoadError::InvalidFormatVersion { value: version }), - }; + match model_type { + ModelVersion::GGMF | ModelVersion::GGJT => { + let _version: u32 = match read_u32(&mut reader)? { + ggml::FORMAT_VERSION => ggml::FORMAT_VERSION, + version => return Err(LoadError::InvalidFormatVersion { value: version }), + }; + } + ModelVersion::Unversioned => {} } // ================= @@ -687,24 +683,41 @@ impl Model { let mut max_token_length = 0; for i in 0..hparams.n_vocab { - let len = read_i32(&mut reader)?; - if let Ok(word) = read_string(&mut reader, len as usize) { - max_token_length = max_token_length.max(word.len()); - id_to_token.push(word.clone()); - token_to_id.insert(word, TokenId::try_from(i)?); + let len = match model_type { + // `read_i32` maybe a typo + ModelVersion::GGMF | ModelVersion::Unversioned => { + read_i32(&mut reader)? as usize + } + ModelVersion::GGJT => read_u32(&mut reader)? as usize, + }; + let maybe_word = if len > 0 { + read_string(&mut reader, len) } else { - load_progress_callback(LoadProgress::BadToken { index: i }); - id_to_token.push("�".to_string()); + Ok("".into()) + }; + match maybe_word { + Ok(word) => { + max_token_length = max_token_length.max(word.len()); + id_to_token.push(word.clone()); + token_to_id.insert(word, TokenId::try_from(i)?); + } + Err(_e) => { + load_progress_callback(LoadProgress::BadToken { index: i }); + id_to_token.push("�".to_string()); + } } // Token score, currently unused - if !is_legacy_model { - if let Ok(score) = read_f32(&mut reader) { - id_to_token_score.push(score); + match model_type { + ModelVersion::GGMF | ModelVersion::GGJT => { + if let Ok(score) = read_f32(&mut reader) { + id_to_token_score.push(score); + } + } + ModelVersion::Unversioned => { + // Legacy model, set empty score + id_to_token_score.push(0.); } - } else { - // Legacy model, set empty score - id_to_token_score.push(0.); } } @@ -764,7 +777,7 @@ impl Model { // Initialize the context let context = ggml::Context::init(ctx_size); - let model = { + let mut model = { let mut tensors = HashMap::new(); let tok_embeddings = context.new_tensor_2d(wtype, n_embd, n_vocab); @@ -828,243 +841,28 @@ impl Model { layers, tensors, _context: context, + mmap: None, + _version: model_type, } }; - // Close the file, but keep its offset. That way we know how to skip the - // metadata when loading the parts. - let file_offset = reader.stream_position()?; - drop(reader); - - let paths = util::find_all_model_files(main_path)?; - let n_parts = paths.len(); - - for (i, part_path) in paths.into_iter().enumerate() { - let part_id = i; - - load_progress_callback(LoadProgress::PartLoading { - file: &part_path, - current_part: i, - total_parts: n_parts, - }); - - let mut part_reader = BufReader::new(File::open(&part_path)?); - - // Skip metadata - part_reader.seek(SeekFrom::Start(file_offset))?; - - let mut total_size = 0; - let mut n_tensors = 0; - - // Load weights - loop { - // NOTE: Implementation from #![feature(buf_read_has_data_left)] - let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?; - - if is_eof { - break; - } - - let n_dims = usize::try_from(read_i32(&mut part_reader)?)?; - let length = read_i32(&mut part_reader)?; - let ftype = read_u32(&mut part_reader)?; - - let mut nelements = 1; - let mut ne = [1i64, 1i64]; - - #[allow(clippy::needless_range_loop)] - for i in 0..n_dims { - ne[i] = read_i32(&mut part_reader)? as i64; - nelements *= usize::try_from(ne[i])?; - } - - let tensor_name = read_string(&mut part_reader, length as usize)?; - - let Some(tensor) = model.tensors.get(&tensor_name) - else { - return Err(LoadError::UnknownTensor { tensor_name, path: part_path }); - }; - - // split_type = 0: split by columns - // split_type = 1: split by rows - // - // split_type = 0: - // regex: - // - tok_embeddings.* - // - layers.*.attention.wo.weight - // - layers.*.feed_forward.w2.weight - - // split_type = 1: - // regex: - // - output.* - // - layers.*.attention.wq.weight - // - layers.*.attention.wk.weight - // - layers.*.attention.wv.weight - // - layers.*.feed_forward.w1.weight - // - layers.*.feed_forward.w3.weight - #[allow(clippy::if_same_then_else)] - let split_type = if tensor_name.contains("tok_embeddings") { - 0 - } else if tensor_name.contains("layers") { - if tensor_name.contains("attention.wo.weight") { - 0 - } else if tensor_name.contains("feed_forward.w2.weight") { - 0 - } else { - 1 - } - } else if tensor_name.contains("output") { - 1 - } else { - 0 - }; - - if n_dims == 1 { - if tensor.nelements() != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.nelements() / n_parts != nelements { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if n_dims == 1 { - if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if split_type == 0 { - if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] - || tensor.get_ne()[1] != ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - } else if tensor.get_ne()[0] != ne[0] - || tensor.get_ne()[1] / i64::try_from(n_parts)? != ne[1] - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - let bpe = match ftype { - 0 => ggml::type_size(ggml::Type::F32), - 1 => ggml::type_size(ggml::Type::F16), - 2 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::Type::Q4_0) - } - 3 => { - assert_eq!(ne[0] % 64, 0); - ggml::type_size(ggml::Type::Q4_1) - } - _ => { - return Err(LoadError::InvalidFtype { - tensor_name, - ftype, - path: part_path, - }) - } - }; - - if n_dims == 1 || n_parts == 1 { - if (nelements * bpe) / ggml::blck_size(tensor.get_type()) != tensor.nbytes() { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if part_id == 0 { - // SAFETY: yolo, same as original code - let slice = unsafe { - let data = tensor.data(); - std::slice::from_raw_parts_mut(data as *mut u8, tensor.nbytes()) - }; - part_reader.read_exact(slice)?; - } else { - part_reader.seek(SeekFrom::Current(tensor.nbytes() as i64))?; - } - - total_size += tensor.nbytes(); - } else { - if (nelements * bpe) / ggml::blck_size(tensor.get_type()) - != tensor.nbytes() / n_parts - { - return Err(LoadError::TensorWrongSize { - tensor_name, - path: part_path, - }); - } - - if split_type == 0 { - let np0 = ne[0]; - let row_size = (usize::try_from(tensor.get_ne()[0])? - / ggml::blck_size(tensor.get_type())) - * ggml::type_size(tensor.get_type()); - - assert_eq!(row_size, tensor.get_nb()[1]); - - for i1 in 0..ne[1] { - let offset_row = i1 as usize * row_size; - let offset = offset_row - + ((part_id * np0 as usize) / ggml::blck_size(tensor.get_type())) - * ggml::type_size(tensor.get_type()); - // SAFETY: yolo, same as original code - unsafe { - let ptr = tensor.data().add(offset); - let slice = std::slice::from_raw_parts_mut( - ptr as *mut u8, - row_size / n_parts, - ); - part_reader.read_exact(slice)?; - } - } - } else { - let np1 = ne[1]; - let row_size = (usize::try_from(tensor.get_ne()[0])? - / ggml::blck_size(tensor.get_type())) - * ggml::type_size(tensor.get_type()); - - for i1 in 0..ne[1] { - let offset_row = (i1 as usize + part_id * np1 as usize) * row_size; - // SAFETY: yolo, same as original code - unsafe { - let ptr = tensor.data().add(offset_row); - let slice = - std::slice::from_raw_parts_mut(ptr as *mut u8, row_size); - part_reader.read_exact(slice)?; - } - } - } - - total_size += tensor.nbytes() / n_parts; - } - - n_tensors += 1; - load_progress_callback(LoadProgress::PartTensorLoaded { - file: &part_path, - current_tensor: n_tensors.try_into()?, - tensor_count: model.tensors.len(), - }); + match model_type { + ModelVersion::GGMF | ModelVersion::Unversioned => { + let file_offset = reader.stream_position()?; + drop(reader); + load_weights_ggmf_or_unversioned( + file_offset, + main_path, + load_progress_callback, + &model, + )? + } + ModelVersion::GGJT => { + let mmap = unsafe { Mmap::map(&file)? }; + let ptr = mmap.as_ptr(); + model.mmap = Some(mmap); + load_weights_ggjt(&mut reader, ptr, main_path, load_progress_callback, &model)?; } - - load_progress_callback(LoadProgress::PartLoaded { - file: &part_path, - byte_size: total_size, - tensor_count: n_tensors.try_into()?, - }); } Ok((model, vocab)) diff --git a/llama-rs/src/loader.rs b/llama-rs/src/loader.rs new file mode 100644 index 00000000..e05483c8 --- /dev/null +++ b/llama-rs/src/loader.rs @@ -0,0 +1,402 @@ +use crate::*; + +pub(crate) fn read_bytes(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> { + let mut bytes = [0u8; N]; + reader + .read_exact(&mut bytes) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: N, + })?; + Ok(bytes) +} + +pub(crate) fn read_i32(reader: &mut impl BufRead) -> Result { + Ok(i32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub(crate) fn read_u32(reader: &mut impl BufRead) -> Result { + Ok(u32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result { + Ok(f32::from_le_bytes(read_bytes::<4>(reader)?)) +} + +/// Helper function. Reads a string from the buffer and returns it. +pub(crate) fn read_string(reader: &mut impl BufRead, len: usize) -> Result { + let mut buf = vec![0; len]; + reader + .read_exact(&mut buf) + .map_err(|e| LoadError::ReadExactFailed { + source: e, + bytes: buf.len(), + })?; + let s = String::from_utf8(buf)?; + Ok(s) +} + +// NOTE: Implementation from #![feature(buf_read_has_data_left)] +fn has_data_left(reader: &mut impl BufRead) -> Result { + reader.fill_buf().map(|b| !b.is_empty()) +} + +pub(crate) fn load_weights_ggmf_or_unversioned( + file_offset: u64, + main_path: &Path, + load_progress_callback: impl Fn(LoadProgress), + model: &Model, +) -> Result<(), LoadError> { + use std::{fs::File, io::BufReader}; + + let paths = util::find_all_model_files(main_path)?; + + let n_parts = paths.len(); + for (i, part_path) in paths.into_iter().enumerate() { + let part_id = i; + + load_progress_callback(LoadProgress::PartLoading { + file: &part_path, + current_part: i, + total_parts: n_parts, + }); + + let mut part_reader = BufReader::new(File::open(&part_path)?); + + // Skip metadata + part_reader.seek(SeekFrom::Start(file_offset))?; + + let mut total_size = 0; + let mut n_tensors = 0; + + // Load weights + loop { + if !has_data_left(&mut part_reader)? { + break; + } + + let n_dims = usize::try_from(read_i32(&mut part_reader)?)?; + let length = read_i32(&mut part_reader)?; + let ftype = read_i32(&mut part_reader)?; + + let (nelements, ne, tensor_name, tensor, split_type, bpe) = load_tensor_header_ggmf( + n_dims, + &mut part_reader, + length, + model, + &part_path, + n_parts, + ftype, + )?; + + if n_dims == 1 || n_parts == 1 { + if (nelements * bpe) / ggml::blck_size(tensor.get_type()) != tensor.nbytes() { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + if part_id == 0 { + // SAFETY: yolo, same as original code + let slice = unsafe { + let data = tensor.data(); + std::slice::from_raw_parts_mut(data as *mut u8, tensor.nbytes()) + }; + part_reader.read_exact(slice)?; + } else { + part_reader.seek(SeekFrom::Current(tensor.nbytes() as i64))?; + } + + total_size += tensor.nbytes(); + } else { + if (nelements * bpe) / ggml::blck_size(tensor.get_type()) + != tensor.nbytes() / n_parts + { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: part_path, + }); + } + + if split_type == 0 { + let np0 = ne[0]; + let row_size = (usize::try_from(tensor.get_ne()[0])? + / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + + assert_eq!(row_size, tensor.get_nb()[1]); + + for i1 in 0..ne[1] { + let offset_row = i1 as usize * row_size; + let offset = offset_row + + ((part_id * np0 as usize) / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + // SAFETY: yolo, same as original code + unsafe { + let ptr = tensor.data().add(offset); + let slice = + std::slice::from_raw_parts_mut(ptr as *mut u8, row_size / n_parts); + part_reader.read_exact(slice)?; + } + } + } else { + let np1 = ne[1]; + let row_size = (usize::try_from(tensor.get_ne()[0])? + / ggml::blck_size(tensor.get_type())) + * ggml::type_size(tensor.get_type()); + + for i1 in 0..ne[1] { + let offset_row = (i1 as usize + part_id * np1 as usize) * row_size; + // SAFETY: yolo, same as original code + unsafe { + let ptr = tensor.data().add(offset_row); + let slice = std::slice::from_raw_parts_mut(ptr as *mut u8, row_size); + part_reader.read_exact(slice)?; + } + } + } + + total_size += tensor.nbytes() / n_parts; + } + + n_tensors += 1; + load_progress_callback(LoadProgress::PartTensorLoaded { + file: &part_path, + current_tensor: n_tensors.try_into()?, + tensor_count: model.tensors.len(), + }); + } + + load_progress_callback(LoadProgress::PartLoaded { + file: &part_path, + byte_size: total_size, + tensor_count: n_tensors.try_into()?, + }); + } + Ok(()) +} + +#[allow(clippy::type_complexity)] +fn load_tensor_header_ggmf<'a>( + n_dims: usize, + reader: &mut impl BufRead, + length: i32, + model: &'a Model, + path: &Path, + n_parts: usize, + ftype: i32, +) -> Result<(usize, [i64; 2], String, &'a ggml::Tensor, i32, usize), LoadError> { + let mut nelements = 1; + let mut ne = [1i64, 1i64]; + assert!(n_dims <= ne.len()); + #[allow(clippy::needless_range_loop)] + for i in 0..n_dims { + ne[i] = read_i32(reader)? as i64; + nelements *= usize::try_from(ne[i])?; + } + let tensor_name = read_string(reader, length as usize)?; + let Some(tensor) = model.tensors.get(&tensor_name) + else { + return Err(LoadError::UnknownTensor { tensor_name, path: path.to_owned() }); + }; + #[allow(clippy::if_same_then_else)] + let split_type = if tensor_name.contains("tok_embeddings") { + 0 + } else if tensor_name.contains("layers") { + if tensor_name.contains("attention.wo.weight") { + 0 + } else if tensor_name.contains("feed_forward.w2.weight") { + 0 + } else { + 1 + } + } else if tensor_name.contains("output") { + 1 + } else { + 0 + }; + if n_dims == 1 { + if tensor.nelements() != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + } else if tensor.nelements() / n_parts != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + if n_dims == 1 { + if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + } else if split_type == 0 { + if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0] || tensor.get_ne()[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + } else if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] / i64::try_from(n_parts)? != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + let bpe = tensor_type_size(ftype, ne); + let bpe = match bpe { + Some(x) => x, + None => { + return Err(LoadError::InvalidFtype { + tensor_name, + ftype, + path: path.to_owned(), + }); + } + }; + Ok((nelements, ne, tensor_name, tensor, split_type, bpe)) +} + +fn tensor_type_size(ftype: i32, ne: [i64; 2]) -> Option { + match ftype { + 0 => Some(ggml::type_size(ggml::Type::F32)), + 1 => Some(ggml::type_size(ggml::Type::F16)), + 2 => { + assert_eq!(ne[0] % 64, 0); + Some(ggml::type_size(ggml::Type::Q4_0)) + } + 3 => { + assert_eq!(ne[0] % 64, 0); + Some(ggml::type_size(ggml::Type::Q4_1)) + } + _ => None, + } +} + +pub(crate) fn load_weights_ggjt( + reader: &mut (impl BufRead + Seek), + mmap_base: *const u8, + path: &Path, + load_progress_callback: impl Fn(LoadProgress), + model: &Model, +) -> Result<(), LoadError> +// where R: std::io::Read +{ + let mut loop_i = 0; + let mut total_loaded_bytes = 0; + load_progress_callback(LoadProgress::PartLoading { + file: path, + current_part: 0, + total_parts: 1, + }); + + loop { + if !has_data_left(reader)? { + break; + } + + let n_dims = read_i32(reader)? as usize; + let length = read_i32(reader)?; + let ftype = read_i32(reader)?; + + let mut nelements: usize = 1; + let mut ne = [1i64, 1]; + assert!(n_dims <= ne.len()); + #[allow(clippy::needless_range_loop)] + for i in 0..n_dims { + let dim = read_i32(reader)? as usize; + ne[i] = dim as i64; + nelements *= dim; + } + let tensor_name = read_string(reader, length as usize)?; + let Some(tensor) = model.tensors.get(&tensor_name) + else { + return Err(LoadError::UnknownTensor { tensor_name, path: path.to_owned() }); + }; + + if tensor.nelements() != nelements { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + let tensor_ne = tensor.get_ne(); + if tensor_ne[0] != ne[0] || tensor_ne[1] != ne[1] { + return Err(LoadError::TensorWrongSize { + tensor_name, + path: path.to_owned(), + }); + } + + match tensor_type_size(ftype, ne) { + Some(_) => {} + None => { + return Err(LoadError::InvalidFtype { + tensor_name, + ftype, + path: path.to_owned(), + }); + } + }; + + load_tensor_ggjt(reader, mmap_base, tensor)?; + + total_loaded_bytes += tensor.nbytes() as u64; + + load_progress_callback(LoadProgress::PartTensorLoaded { + file: path, + current_tensor: loop_i, + tensor_count: model.tensors.len(), + }); + + loop_i += 1; + } + + load_progress_callback(LoadProgress::PartLoaded { + file: path, + byte_size: total_loaded_bytes as usize, + tensor_count: loop_i, + }); + + Ok(()) +} + +#[cfg(feature = "mmap")] +fn load_tensor_ggjt( + reader: &mut (impl BufRead + Seek), + mmap_base: *const u8, + tensor: &ggml::Tensor, +) -> Result<(), LoadError> { + let offset_curr = reader.stream_position()?; + let offset_aligned: u64 = (offset_curr + 31) & !31; + unsafe { + let ptr = mmap_base.offset(offset_aligned as isize); + tensor.set_data(ptr as *mut std::ffi::c_void); + } + reader.seek(SeekFrom::Start(offset_aligned + tensor.nbytes() as u8))?; + Ok(()) +} + +#[cfg(not(feature = "mmap"))] +fn load_tensor_ggjt<'a>( + reader: &mut (impl BufRead + Seek), + mmap_base: *const u8, + tensor: &'a ggml::Tensor, +) -> Result<(), LoadError> { + _ = mmap_base; + let offset_curr = reader.stream_position()?; + let offset_aligned: u64 = (offset_curr + 31) & !31; + reader.seek(SeekFrom::Start(offset_aligned))?; + + let buf: &'a mut [u8] = + unsafe { std::slice::from_raw_parts_mut(tensor.data() as *mut u8, tensor.nbytes()) }; + reader.read_exact(buf)?; + + Ok(()) +}