From 5a423b1010dec771fa7da78c2f9a0ec399d1bcce Mon Sep 17 00:00:00 2001 From: RedBoxing Date: Tue, 23 May 2023 19:12:39 +0200 Subject: [PATCH 01/20] Use HF tokenizer when vocab file is provided --- Cargo.lock | 1461 +++++++++++++++++++++- binaries/llm-cli/src/cli_args.rs | 5 + binaries/llm-cli/src/main.rs | 2 +- crates/llm-base/Cargo.toml | 1 + crates/llm-base/src/inference_session.rs | 8 +- crates/llm-base/src/loader.rs | 42 +- crates/llm-base/src/model/mod.rs | 7 +- crates/llm-base/src/quantize.rs | 2 +- crates/llm-base/src/vocabulary.rs | 130 +- crates/llm/examples/inference.rs | 1 + crates/llm/examples/vicuna-chat.rs | 1 + crates/llm/src/lib.rs | 39 +- crates/models/llama/src/convert.rs | 3 +- 13 files changed, 1614 insertions(+), 88 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4c9d1ec4..35f336d2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,12 +17,32 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" +[[package]] +name = "aes" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "433cfd6710c9986c576a25ca913c39d66a6474107b406f34f91d4a8923395241" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + [[package]] name = "ahash" version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8fd72866655d1904d6b0997d0b07ba561047d070fbe29de039031c641b61217" +[[package]] +name = "aho-corasick" +version = "0.7.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc936419f96fa211c1b9166887b38e5e40b19958e5b895be7c1f93adec7071ac" +dependencies = [ + "memchr", +] + [[package]] name = "aho-corasick" version = "1.0.1" @@ -68,7 +88,7 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b" dependencies = [ - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -78,7 +98,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "180abfa45703aebe0093f79badacc01b8fd4ea2e35118747e5811127f926e188" dependencies = [ "anstyle", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -114,11 +134,29 @@ dependencies = [ "cc", "cfg-if", "libc", - "miniz_oxide", + "miniz_oxide 0.6.2", "object", "rustc-demangle", ] +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + +[[package]] +name = "base64" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f1e31e207a6b8fb791a38ea3105e6cb541f55e4d029902d3039a4ad07cc4105" + +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "bincode" version = "1.3.3" @@ -157,18 +195,88 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "block-buffer" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +dependencies = [ + "generic-array", +] + +[[package]] +name = "bumpalo" +version = "3.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" + [[package]] name = "bytemuck" version = "1.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + +[[package]] +name = "bytes" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" + [[package]] name = "bytesize" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38fcc2979eff34a4b84e1cf9a1e3da42a7d44b3b690a40cdcb23e3d556cfb2e5" +[[package]] +name = "bzip2" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] + +[[package]] +name = "cached-path" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "097968e38f1319207f057d0f4d76452e4f4f847a5de61c5215379f297fa034f3" +dependencies = [ + "flate2", + "fs2", + "glob", + "indicatif 0.16.2", + "log", + "rand", + "reqwest", + "serde", + "serde_json", + "sha2", + "tar", + "tempfile", + "thiserror", + "zip", +] + [[package]] name = "cc" version = "1.0.79" @@ -193,6 +301,16 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clang-sys" version = "1.6.1" @@ -287,6 +405,59 @@ dependencies = [ "winapi", ] +[[package]] +name = "console" +version = "0.15.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c926e00cc70edefdc64d3a5ff31cc65bb97a3460097762bd23afb4d8145fccf8" +dependencies = [ + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys 0.45.0", +] + +[[package]] +name = "constant_time_eq" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" + +[[package]] +name = "core-foundation" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" + +[[package]] +name = "cpufeatures" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e4c1eaa2012c47becbbad2ab175484c2a84d1185b566fb2cc5b8707343dfe58" +dependencies = [ + "libc", +] + +[[package]] +name = "crc32fast" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossbeam-channel" version = "0.5.8" @@ -336,6 +507,16 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "csv" version = "1.2.1" @@ -357,6 +538,92 @@ dependencies = [ "memchr", ] +[[package]] +name = "darling" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b750cb3417fd1b327431a470f388520309479ab0bf5e323505daf0290cd3850" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "109c1ca6e6b7f82cc233a97004ea8ed7ca123a9af07a8230878fcfda9b158bf0" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn 1.0.109", +] + +[[package]] +name = "darling_macro" +version = "0.14.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4aab4dbc9f7611d8b55048a3a16d2d010c2c8334e46304b40ac1cc14bf3b48e" +dependencies = [ + "darling_core", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_builder" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8d67778784b508018359cbc8696edb3db78160bab2c2a28ba7f56ef6932997f8" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c11bdc11a0c47bc7d37d582b5285da6849c96681023680b906673c5707af7b0f" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "derive_builder_macro" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebcda35c7a396850a55ffeac740804b40ffec779b98fffbb1738f4033f0ee79e" +dependencies = [ + "derive_builder_core", + "syn 1.0.109", +] + +[[package]] +name = "digest" +version = "0.10.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +dependencies = [ + "block-buffer", + "crypto-common", + "subtle", +] + +[[package]] +name = "dirs" +version = "4.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca3aa72a6f96ea37bbc5aa912f6788242832f75369bdfdadcb0e38423f100059" +dependencies = [ + "dirs-sys", +] + [[package]] name = "dirs-next" version = "2.0.0" @@ -367,6 +634,17 @@ dependencies = [ "dirs-sys-next", ] +[[package]] +name = "dirs-sys" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + [[package]] name = "dirs-sys-next" version = "0.1.2" @@ -384,6 +662,21 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +[[package]] +name = "encode_unicode" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" + +[[package]] +name = "encoding_rs" +version = "0.8.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394" +dependencies = [ + "cfg-if", +] + [[package]] name = "endian-type" version = "0.1.2" @@ -411,7 +704,7 @@ checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" dependencies = [ "errno-dragonfly", "libc", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -434,6 +727,15 @@ dependencies = [ "str-buf", ] +[[package]] +name = "esaxx-rs" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f748b253ceca9fed5f42f8b5ceb3851e93102199bc25b64b65369f76e5c0a35" +dependencies = [ + "cc", +] + [[package]] name = "eyre" version = "0.6.8" @@ -444,6 +746,15 @@ dependencies = [ "once_cell", ] +[[package]] +name = "fastrand" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" +dependencies = [ + "instant", +] + [[package]] name = "fd-lock" version = "3.0.12" @@ -452,7 +763,117 @@ checksum = "39ae6b3d9530211fb3b12a95374b8b0823be812f53d09e18c5675c0146b09642" dependencies = [ "cfg-if", "rustix", - "windows-sys", + "windows-sys 0.48.0", +] + +[[package]] +name = "filetime" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cbc844cecaee9d4443931972e1289c8ff485cb4cc2767cb03ca139ed6885153" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall 0.2.16", + "windows-sys 0.48.0", +] + +[[package]] +name = "flate2" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743" +dependencies = [ + "crc32fast", + "miniz_oxide 0.7.1", +] + +[[package]] +name = "fnv" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "form_urlencoded" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9c384f161156f5260c24a097c56119f9be8c798586aecc13afbcbe7b7e26bf8" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "fs2" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9564fc758e15025b46aa6643b1b77d047d1a56a1aea6e01002ac0c7026876213" +dependencies = [ + "libc", + "winapi", +] + +[[package]] +name = "futures-channel" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +dependencies = [ + "futures-core", +] + +[[package]] +name = "futures-core" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" + +[[package]] +name = "futures-io" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" + +[[package]] +name = "futures-sink" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" + +[[package]] +name = "futures-task" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" + +[[package]] +name = "futures-util" +version = "0.3.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +dependencies = [ + "futures-core", + "futures-io", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", ] [[package]] @@ -462,6 +883,16 @@ dependencies = [ "bindgen", ] +[[package]] +name = "generic-array" +version = "0.14.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.9" @@ -503,7 +934,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" [[package]] -name = "half" +name = "h2" +version = "0.3.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d357c7ae988e7d2182f7d7871d0b963962420b0678b0997ce7de72001aeab782" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "half" version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "02b4af3693f1b705df946e9fe5631932443781d0aabb423b62fcd4d73f6d2fd0" @@ -521,6 +971,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = "heck" version = "0.4.1" @@ -551,18 +1007,166 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" +[[package]] +name = "hmac" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +dependencies = [ + "digest", +] + +[[package]] +name = "http" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd6effc99afb63425aff9b05836f029929e345a6148a14b7ecd5ab67af944482" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + +[[package]] +name = "http-body" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" +dependencies = [ + "bytes", + "http", + "pin-project-lite", +] + +[[package]] +name = "httparse" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" + +[[package]] +name = "httpdate" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421" + [[package]] name = "humantime" version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" +[[package]] +name = "hyper" +version = "0.14.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab302d72a6f11a3b910431ff93aae7e773078c769f0a3ef15fb9ec692ed147d4" +dependencies = [ + "bytes", + "futures-channel", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "socket2", + "tokio", + "tower-service", + "tracing", + "want", +] + +[[package]] +name = "hyper-tls" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6183ddfa99b85da61a140bea0efc93fdf56ceaa041b37d553518030827f9905" +dependencies = [ + "bytes", + "hyper", + "native-tls", + "tokio", + "tokio-native-tls", +] + +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + +[[package]] +name = "idna" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e14ddfc70884202db2244c223200c204c2bda1bc6e0998d11b5e024d657209e6" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + [[package]] name = "indenter" version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", +] + +[[package]] +name = "indicatif" +version = "0.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7baab56125e25686df467fe470785512329883aab42696d661247aca2a2896e4" +dependencies = [ + "console", + "lazy_static", + "number_prefix 0.3.0", + "regex", +] + +[[package]] +name = "indicatif" +version = "0.16.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d207dc617c7a380ab07ff572a6e52fa202a2a8f355860ac9c38e23f8196be1b" +dependencies = [ + "console", + "lazy_static", + "number_prefix 0.4.0", + "regex", +] + +[[package]] +name = "inout" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" +dependencies = [ + "generic-array", +] + +[[package]] +name = "instant" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +dependencies = [ + "cfg-if", +] + [[package]] name = "io-lifetimes" version = "1.0.10" @@ -571,9 +1175,15 @@ checksum = "9c66c74d2ae7e79a5a8f7ac924adbe38ee42a859c6539ad869eb51f0b52dc220" dependencies = [ "hermit-abi 0.3.1", "libc", - "windows-sys", + "windows-sys 0.48.0", ] +[[package]] +name = "ipnet" +version = "2.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12b6ee2129af8d4fb011108c73d99a1b83a85977f23b82460c0ae2e25bb4b57f" + [[package]] name = "is-terminal" version = "0.4.7" @@ -583,7 +1193,7 @@ dependencies = [ "hermit-abi 0.3.1", "io-lifetimes", "rustix", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -595,6 +1205,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "284f18f85651fe11e8a991b2adb42cb078325c996ed026d994719efcfca1d54b" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.6" @@ -610,6 +1229,15 @@ dependencies = [ "libc", ] +[[package]] +name = "js-sys" +version = "0.3.63" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f37a4a5928311ac501dee68b3c7613a1037d0edb30c8e5427bd832d55d1b790" +dependencies = [ + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -677,6 +1305,7 @@ dependencies = [ "serde", "serde_bytes", "thiserror", + "tokenizers", ] [[package]] @@ -702,7 +1331,7 @@ dependencies = [ "rand", "rustyline", "spinoff", - "zstd", + "zstd 0.12.3+zstd.1.5.2", ] [[package]] @@ -761,6 +1390,22 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "macro_rules_attribute" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf0c9b980bf4f3a37fd7b1c066941dd1b1d0152ce6ee6e8fe8c49b9f6810d862" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d" + [[package]] name = "memchr" version = "2.5.0" @@ -785,6 +1430,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -800,6 +1451,66 @@ dependencies = [ "adler", ] +[[package]] +name = "miniz_oxide" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +dependencies = [ + "adler", +] + +[[package]] +name = "mio" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b9d9a46eff5b4ff64b45a9e316a6d1e0bc719ef429cbec4dc630684212bfdf9" +dependencies = [ + "libc", + "log", + "wasi", + "windows-sys 0.45.0", +] + +[[package]] +name = "monostate" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0230b703f1ac35df1e24f6d0d2255472bcccaf657ecdfa4f1fcbcad1ad5bb98a" +dependencies = [ + "monostate-impl", + "serde", +] + +[[package]] +name = "monostate-impl" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8795add3e14028f11f8e848bd3294898a8294767b3776b6f733560d33bd2530b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.15", +] + +[[package]] +name = "native-tls" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" +dependencies = [ + "lazy_static", + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nibble_vec" version = "0.1.0" @@ -841,6 +1552,18 @@ dependencies = [ "libc", ] +[[package]] +name = "number_prefix" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17b02fc0ff9a9e4b35b3342880f48e896ebf69f2967921fe8646bf5b7125956a" + +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "object" version = "0.30.3" @@ -856,6 +1579,72 @@ version = "1.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3" +[[package]] +name = "onig" +version = "6.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c4b31c8722ad9171c6d77d3557db078cab2bd50afcc9d09c8b315c59df8ca4f" +dependencies = [ + "bitflags", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b829e3d7e9cc74c7e315ee8edb185bf4190da5acde74afd7fc59c35b1f086e7" +dependencies = [ + "cc", + "pkg-config", +] + +[[package]] +name = "openssl" +version = "0.10.52" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01b8574602df80f7b85fdfc5392fa884a4e3b3f4f35402c070ab34c3d3f78d56" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.15", +] + +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "openssl-sys" +version = "0.9.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e17f59264b2809d77ae94f0e1ebabc434773f370d6ca667bd223ea10e06cc7e" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + [[package]] name = "owo-colors" version = "3.5.0" @@ -868,18 +1657,59 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7924d1d0ad836f665c9065e26d016c673ece3993f30d340068b16f282afc1156" +[[package]] +name = "password-hash" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + [[package]] name = "paste" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f746c4065a8fa3fe23974dd82f15431cc8d40779821001404d10d2e79ca7d79" +[[package]] +name = "pbkdf2" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" +dependencies = [ + "digest", + "hmac", + "password-hash", + "sha2", +] + [[package]] name = "peeking_take_while" version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" +[[package]] +name = "percent-encoding" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "478c572c3d73181ff3c2539045f6eb99e5491218eae919370993b890cdbdd98e" + +[[package]] +name = "pin-project-lite" +version = "0.2.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0a7ae3ac2f1173085d398531c705756c94a4c56843785df85a60c1a0afac116" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + [[package]] name = "pkg-config" version = "0.3.27" @@ -976,6 +1806,17 @@ dependencies = [ "rayon-core", ] +[[package]] +name = "rayon-cond" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd1259362c9065e5ea39a789ef40b1e3fd934c94beb7b5ab3ac6629d3b5e7cb7" +dependencies = [ + "either", + "itertools 0.8.2", + "rayon", +] + [[package]] name = "rayon-core" version = "1.11.0" @@ -997,6 +1838,15 @@ dependencies = [ "bitflags", ] +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags", +] + [[package]] name = "redox_users" version = "0.4.3" @@ -1004,7 +1854,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" dependencies = [ "getrandom", - "redox_syscall", + "redox_syscall 0.2.16", "thiserror", ] @@ -1014,17 +1864,60 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af83e617f331cc6ae2da5443c602dfa5af81e517212d9d611a5b3ba1777b5370" dependencies = [ - "aho-corasick", + "aho-corasick 1.0.1", "memchr", - "regex-syntax", + "regex-syntax 0.7.1", ] +[[package]] +name = "regex-syntax" +version = "0.6.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" + [[package]] name = "regex-syntax" version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a5996294f19bd3aae0453a862ad728f60e6600695733dd5df01da90c54363a3c" +[[package]] +name = "reqwest" +version = "0.11.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" +dependencies = [ + "base64 0.21.1", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2", + "http", + "http-body", + "hyper", + "hyper-tls", + "ipnet", + "js-sys", + "log", + "mime", + "native-tls", + "once_cell", + "percent-encoding", + "pin-project-lite", + "serde", + "serde_json", + "serde_urlencoded", + "tokio", + "tokio-native-tls", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", + "winreg", +] + [[package]] name = "rust_tokenizers" version = "3.1.6" @@ -1032,8 +1925,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77c4313059ea8764ff2743ffaaa42fba0e4d5f8ff12febe4f3c74d598f629f62" dependencies = [ "csv", - "hashbrown", - "itertools", + "hashbrown 0.7.2", + "itertools 0.8.2", "lazy_static", "protobuf", "rayon", @@ -1067,7 +1960,7 @@ dependencies = [ "io-lifetimes", "libc", "linux-raw-sys", - "windows-sys", + "windows-sys 0.48.0", ] [[package]] @@ -1112,11 +2005,43 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" [[package]] -name = "scopeguard" -version = "1.1.0" +name = "schannel" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" +dependencies = [ + "windows-sys 0.42.0", +] + +[[package]] +name = "scopeguard" +version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "security-framework" +version = "2.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "serde" version = "1.0.162" @@ -1157,18 +2082,71 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_urlencoded" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3491c14715ca2294c4d6a88f15e84739788c1d030eed8c110436aafdaa2f3fd" +dependencies = [ + "form_urlencoded", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "sha1" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + +[[package]] +name = "sha2" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82e6b795fe2e3b1e845bafcb27aa35405c4d47cdfc92af5fc8d3002f76cebdc0" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shlex" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3" +[[package]] +name = "slab" +version = "0.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6528351c9bc8ab22353f9d776db39a20288e8d6c37ef8cfe3317cf875eecfc2d" +dependencies = [ + "autocfg", +] + [[package]] name = "smallvec" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +[[package]] +name = "socket2" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64a4a911eed85daf18834cfaa86a79b7d266ff93ff5ba14005426219480ed662" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "spinoff" version = "0.7.0" @@ -1180,6 +2158,18 @@ dependencies = [ "paste", ] +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom", + "serde", + "unicode-segmentation", +] + [[package]] name = "static_assertions" version = "1.1.0" @@ -1198,6 +2188,12 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "subtle" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" + [[package]] name = "syn" version = "1.0.109" @@ -1220,6 +2216,30 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tar" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b55807c0344e1e6c04d7c965f5289c39a8d94ae23ed5c0b57aabac549f871c6" +dependencies = [ + "filetime", + "libc", + "xattr", +] + +[[package]] +name = "tempfile" +version = "3.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9fbec84f381d5795b08656e4912bec604d162bff9291d6189a78f4c8ab87998" +dependencies = [ + "cfg-if", + "fastrand", + "redox_syscall 0.3.5", + "rustix", + "windows-sys 0.45.0", +] + [[package]] name = "termcolor" version = "1.2.0" @@ -1249,6 +2269,22 @@ dependencies = [ "syn 2.0.15", ] +[[package]] +name = "time" +version = "0.3.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f3403384eaacbca9923fa06940178ac13e4edb725486d70e8e15881d0c836cc" +dependencies = [ + "serde", + "time-core", +] + +[[package]] +name = "time-core" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb" + [[package]] name = "tinyvec" version = "1.6.0" @@ -1264,6 +2300,126 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokenizers" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cf49017523bf0bc01c9966f172c5f120bbb7b96cccd1708772dd42e767fb9f5" +dependencies = [ + "aho-corasick 0.7.20", + "cached-path", + "clap", + "derive_builder", + "dirs", + "esaxx-rs", + "getrandom", + "indicatif 0.15.0", + "itertools 0.9.0", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand", + "rayon", + "rayon-cond", + "regex", + "regex-syntax 0.6.29", + "reqwest", + "serde", + "serde_json", + "spm_precompiled", + "thiserror", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + +[[package]] +name = "tokio" +version = "1.28.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0aa32867d44e6f2ce3385e89dceb990188b8bb0fb25b0cf576647a6f98ac5105" +dependencies = [ + "autocfg", + "bytes", + "libc", + "mio", + "num_cpus", + "pin-project-lite", + "socket2", + "windows-sys 0.48.0", +] + +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "806fe8c2c87eccc8b3267cbae29ed3ab2d0bd37fca70ab622e46aaa9375ddb7d" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", + "tracing", +] + +[[package]] +name = "tower-service" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" + +[[package]] +name = "tracing" +version = "0.1.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8" +dependencies = [ + "cfg-if", + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" +dependencies = [ + "once_cell", +] + +[[package]] +name = "try-lock" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" + +[[package]] +name = "typenum" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" + +[[package]] +name = "unicode-bidi" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" + [[package]] name = "unicode-ident" version = "1.0.8" @@ -1300,18 +2456,133 @@ version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0edd1e5b14653f783770bce4a4dabb4a5108a5370a5f5d8cfe8710c361f6c8b" +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + +[[package]] +name = "url" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d68c799ae75762b8c3fe375feb6600ef5602c883c5d21eb51c09f22b83c4643" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + [[package]] name = "utf8parse" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "want" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ce8a968cb1cd110d136ff8b819a556d6fb6d919363c61534f6860c7eb172ba0" +dependencies = [ + "log", + "try-lock", +] + [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasm-bindgen" +version = "0.2.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5bba0e8cb82ba49ff4e229459ff22a191bbe9a1cb3a341610c9c33efc27ddf73" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b04bc93f9d6bdee709f6bd2118f57dd6679cf1176a1af464fca3ab0d66d8fb" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.15", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-futures" +version = "0.4.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d1985d03709c53167ce907ff394f5316aa22cb4e12761295c5dc57dacb6297e" +dependencies = [ + "cfg-if", + "js-sys", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14d6b024f1a526bb0234f52840389927257beb670610081360e5a03c5df9c258" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e128beba882dd1eb6200e1dc92ae6c5dbaa4311aa7bb211ca035779e5efc39f8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.15", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed9d5b4305409d1fc9482fee2d7f9bcbf24b3972bf59817ef757e23982242a93" + +[[package]] +name = "web-sys" +version = "0.3.63" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bdd9ef4e984da1187bf8110c5cf5b845fbc87a23602cdf912386a76fcd3a7c2" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "which" version = "4.4.0" @@ -1354,13 +2625,52 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.45.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0" +dependencies = [ + "windows-targets 0.42.2", +] + [[package]] name = "windows-sys" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ - "windows-targets", + "windows-targets 0.48.0", +] + +[[package]] +name = "windows-targets" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e5180c00cd44c9b1c88adb3693291f1cd93605ded80c250a75d472756b4d071" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", ] [[package]] @@ -1369,64 +2679,163 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.48.0", + "windows_aarch64_msvc 0.48.0", + "windows_i686_gnu 0.48.0", + "windows_i686_msvc 0.48.0", + "windows_x86_64_gnu 0.48.0", + "windows_x86_64_gnullvm 0.48.0", + "windows_x86_64_msvc 0.48.0", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + [[package]] name = "windows_aarch64_msvc" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + [[package]] name = "windows_i686_gnu" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + [[package]] name = "windows_i686_msvc" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + [[package]] name = "windows_x86_64_gnu" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + [[package]] name = "windows_x86_64_msvc" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +[[package]] +name = "winreg" +version = "0.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +dependencies = [ + "winapi", +] + +[[package]] +name = "xattr" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d1526bbe5aaeb5eb06885f4d987bcdfa5e23187055de9b83fe00156a821fabc" +dependencies = [ + "libc", +] + +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "aes", + "byteorder", + "bzip2", + "constant_time_eq", + "crc32fast", + "crossbeam-utils", + "flate2", + "hmac", + "pbkdf2", + "sha1", + "time", + "zstd 0.11.2+zstd.1.5.2", +] + +[[package]] +name = "zstd" +version = "0.11.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" +dependencies = [ + "zstd-safe 5.0.2+zstd.1.5.2", +] + [[package]] name = "zstd" version = "0.12.3+zstd.1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76eea132fb024e0e13fd9c2f5d5d595d8a967aa72382ac2f9d39fcc95afd0806" dependencies = [ - "zstd-safe", + "zstd-safe 6.0.5+zstd.1.5.4", +] + +[[package]] +name = "zstd-safe" +version = "5.0.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" +dependencies = [ + "libc", + "zstd-sys", ] [[package]] diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 20d3baae..e4dd5740 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -332,6 +332,10 @@ pub struct ModelLoad { #[arg(long, short = 'm')] pub model_path: PathBuf, + /// Where to save the model from + #[arg(long, short = 'v')] + pub vocab_path: Option, + /// Sets the size of the context (in tokens). Allows feeding longer prompts. /// Note that this affects memory. /// @@ -376,6 +380,7 @@ impl ModelLoad { let model = llm::load::( &self.model_path, + self.vocab_path.as_deref(), params, overrides, |progress| match progress { diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index 720cf969..deb4bf73 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -149,7 +149,7 @@ fn perplexity( fn info(args: &cli_args::Info) -> Result<()> { let file = File::open(&args.model_path)?; let mut reader = BufReader::new(&file); - let mut loader: llm::Loader = llm::Loader::new(|_| { + let mut loader: llm::Loader = llm::Loader::new(None, |_| { // We purposely do not print progress here, as we are only interested in the metadata }); diff --git a/crates/llm-base/Cargo.toml b/crates/llm-base/Cargo.toml index 94d91169..3d155ec8 100644 --- a/crates/llm-base/Cargo.toml +++ b/crates/llm-base/Cargo.toml @@ -22,3 +22,4 @@ partial_sort = "0.2.0" serde_bytes = "0.11" memmap2 = "0.5.10" half = "2.2.1" +tokenizers = "0.13.3" diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index d072939a..0018f75f 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -94,7 +94,7 @@ impl InferenceSession { if should_call_callback { // NOTE: No string ever tokenizes to the end of sentence. So we // can just return the id here. - match callback(vocab.token(tk as usize)) { + match callback(&vocab.token(tk as usize)) { Err(e) => return Err(InferenceError::UserCallback(Some(Box::new(e)))), Ok(f) => match f { InferenceFeedback::Continue => (), @@ -118,7 +118,7 @@ impl InferenceSession { params: &InferenceParameters, output_request: &mut OutputRequest, rng: &mut impl rand::Rng, - ) -> Result<&'v [u8], InferenceError> { + ) -> Result, InferenceError> { if self.n_past + 1 >= model.context_size() { return Err(InferenceError::ContextFull); } @@ -163,7 +163,7 @@ impl InferenceSession { for token_id in &self.tokens { // Buffer the token until it's valid UTF-8, then call the callback. if let Some(tokens) = - token_utf8_buf.push(model.vocabulary().token(*token_id as usize)) + token_utf8_buf.push(&model.vocabulary().token(*token_id as usize)) { if let Err(e) = callback(InferenceResponse::SnapshotToken(tokens)) { return Err(InferenceError::UserCallback(Some(Box::new(e)))); @@ -204,7 +204,7 @@ impl InferenceSession { }; // Buffer the token until it's valid UTF-8, then call the callback. - if let Some(tokens) = token_utf8_buf.push(token) { + if let Some(tokens) = token_utf8_buf.push(&token) { match callback(InferenceResponse::InferredToken(tokens)) { Err(e) => return Err(InferenceError::UserCallback(Some(Box::new(e)))), Ok(f) => match f { diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index d4da0861..9931435f 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -18,6 +18,8 @@ use ggml::{ use memmap2::Mmap; use thiserror::Error; +use tokenizers::Tokenizer; + #[derive(Debug, PartialEq, Clone, Copy, Eq, Default)] /// Information about the file. pub struct FileType { @@ -280,6 +282,15 @@ pub enum LoadError { /// The paths that were found. paths: Vec, }, + + /// The vocab file for the tokenizer could not be loaded. + /// + /// + #[error("could not load vocab file {path:?}")] + VocabLoadError { + /// The path that failed. + path: PathBuf, + }, } impl From for LoadError { fn from(value: util::FindAllModelFilesError) -> Self { @@ -343,6 +354,7 @@ pub trait TensorLoader { /// store any information about the architecture. pub fn load( path: &Path, + vocab_path: Option<&Path>, params: ModelParameters, overrides: Option, load_progress_callback: impl FnMut(LoadProgress), @@ -364,7 +376,29 @@ pub fn load( })?; let mut reader = BufReader::new(&file); - let mut loader = Loader::new(load_progress_callback); + let tokenizer = if let Some(path) = vocab_path { + let tok = if !path.exists() && path.to_str().unwrap().matches("/").count() == 1 { + Tokenizer::from_pretrained(path.to_str().unwrap(), None) + } else if path.exists() && path.is_file() { + Tokenizer::from_file(path) + } else { + return Err(LoadError::VocabLoadError { + path: path.to_owned(), + }); + }; + + if tok.is_err() { + return Err(LoadError::VocabLoadError { + path: path.to_owned(), + }); + } + + Some(tok.unwrap()) + } else { + None + }; + + let mut loader = Loader::new(tokenizer, load_progress_callback); ggml::format::load(&mut reader, &mut loader) .map_err(|err| LoadError::from_format_error(err, path.to_owned()))?; @@ -422,7 +456,7 @@ pub fn load( let mut lora_reader = BufReader::new(&lora_file); // TODO: Consider updating the progress callback to report the progress of the LoRA file. // Most LoRAs are small enough that this is not necessary, but it would be nice to have. - let mut lora_loader: Loader = Loader::new(|_| {}); + let mut lora_loader: Loader = Loader::new(None, |_| {}); ggml::format::load(&mut lora_reader, &mut lora_loader) .map_err(|err| LoadError::from_format_error(err, lora_path.to_owned()))?; @@ -498,13 +532,13 @@ pub struct Loader { } impl Loader { /// Creates a new loader. - pub fn new(load_progress_callback: F) -> Self { + pub fn new(tokenizer: Option, load_progress_callback: F) -> Self { Self { load_progress_callback, container_type: ContainerType::Ggml, hyperparameters: Hp::default(), - vocabulary: Vocabulary::default(), + vocabulary: Vocabulary::new(tokenizer), tensors: HashMap::default(), } } diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs index 03ec965f..b2423aee 100644 --- a/crates/llm-base/src/model/mod.rs +++ b/crates/llm-base/src/model/mod.rs @@ -114,6 +114,7 @@ pub trait KnownModel: Send + Sync { /// is a helper function on top of [llm_base::load](crate::load). fn load( path: &Path, + vocab_path: Option<&Path>, params: ModelParameters, overrides: Option, load_progress_callback: impl FnMut(LoadProgress), @@ -121,7 +122,7 @@ pub trait KnownModel: Send + Sync { where Self: Sized, { - crate::load(path, params, overrides, load_progress_callback) + crate::load(path, vocab_path, params, overrides, load_progress_callback) } /// Creates a new model from the provided [ModelParameters] hyperparameters. @@ -151,7 +152,7 @@ pub trait KnownModel: Send + Sync { output_request: &mut OutputRequest, ); - /// Get the vocabulary (loaded from the GGML file) for this model. + /// Get the vocabulary for this model. fn vocabulary(&self) -> &Vocabulary; /// Get the context size (configured with [ModelParameters::context_size]) used by @@ -188,7 +189,7 @@ pub trait Model: Send + Sync { output_request: &mut OutputRequest, ); - /// Get the vocabulary (loaded from the GGML file) for this model. + /// Get the vocabulary for this model. fn vocabulary(&self) -> &Vocabulary; /// Get the context size (configured with [ModelParameters::context_size]) used by diff --git a/crates/llm-base/src/quantize.rs b/crates/llm-base/src/quantize.rs index 9ea8ef4a..c1e63fdb 100644 --- a/crates/llm-base/src/quantize.rs +++ b/crates/llm-base/src/quantize.rs @@ -151,7 +151,7 @@ pub fn quantize( // Load the model let progress_callback = Arc::new(progress_callback); - let mut loader = Loader::::new({ + let mut loader = Loader::::new(None, { let progress_callback = progress_callback.clone(); move |p| { if let LoadProgress::HyperparametersLoaded = p { diff --git a/crates/llm-base/src/vocabulary.rs b/crates/llm-base/src/vocabulary.rs index 736df47b..d277703f 100644 --- a/crates/llm-base/src/vocabulary.rs +++ b/crates/llm-base/src/vocabulary.rs @@ -1,9 +1,10 @@ use std::{collections::HashMap, error::Error, fmt::Display, str::FromStr}; use thiserror::Error; +use tokenizers::Tokenizer; /// The identifier of a token in a vocabulary. -pub type TokenId = i32; +pub type TokenId = u32; pub(crate) type Token = Vec; pub(crate) type TokenScore = f32; @@ -34,9 +35,19 @@ pub struct Vocabulary { /// The longest token in this vocabulary. pub max_token_length: usize, + + /// The tokenizer + pub tokenizer: Option, } impl Vocabulary { + /// Intialize a new vocabulary. + pub fn new(tokenizer: Option) -> Vocabulary { + let mut vocab = Vocabulary::default(); + vocab.tokenizer = tokenizer; + + vocab + } /// Add a token to the vocabulary. /// /// The token added must have `id` directly after the last token in the vocabulary. @@ -45,6 +56,10 @@ impl Vocabulary { /// - This function can panic if `id` does not correspond to the next token in the vocabulary. /// That is, if there are already `n` tokens in the vocabulary, then `id` must be `n`. pub fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) { + if self.tokenizer.is_some() { + return; + } + // These are loader invariants. If this is broken, then the loader is broken and this is a bug, // not an issue with the model itself. assert_eq!(self.id_to_token.len(), self.id_to_token_score.len()); @@ -60,17 +75,33 @@ impl Vocabulary { } /// Converts a token index to the token it represents in this vocabulary. - pub fn token(&self, idx: usize) -> &[u8] { - &self.id_to_token[idx] + pub fn token(&self, idx: usize) -> Vec { + if let Some(tokenizer) = &self.tokenizer { + return tokenizer + .decode(vec![idx as u32], true) + .unwrap() + .as_bytes() + .to_vec(); + } + + (&self.id_to_token[idx]).clone() } /// Returns the number of tokens in the vocabulary. pub fn len(&self) -> usize { + if let Some(tokenizer) = &self.tokenizer { + return tokenizer.get_vocab_size(false) as usize; + } + self.id_to_token.len() } /// Returns whether the vocabulary is empty. pub fn is_empty(&self) -> bool { + if let Some(tokenizer) = &self.tokenizer { + return tokenizer.get_vocab_size(false) == 0; + } + self.id_to_token.is_empty() } @@ -82,53 +113,68 @@ impl Vocabulary { &'a self, text: &str, bos: bool, - ) -> Result, TokenizationError> { - let len = text.len(); - - let mut score = vec![0usize; len + 1]; - let mut prev = vec![TokenId::default(); len + 1]; - - for i in 0..len { - let max_len = (len - i).min(self.max_token_length); - for sub_len in 1..=max_len { - let sub = &text.as_bytes()[i..i + sub_len]; - let token = self.token_to_id.get(sub); - - if let Some(token) = token { - let token_score = sub.len() * sub.len(); - let local_score = score[i] + token_score; - let next = i + sub_len; - - if score[next] < local_score { - score[next] = local_score; - prev[next] = *token; + ) -> Result, TokenId)>, TokenizationError> { + if let Some(tokenizer) = &self.tokenizer { + let res = tokenizer.encode(text, bos); + if res.is_err() { + return Err(TokenizationError::TokenizationFailed); + } else { + Ok(tokenizer + .encode(text, bos) + .unwrap() + .get_ids() + .iter() + .map(|id| (self.token(*id as usize), *id)) + .collect::, TokenId)>>()) + } + } else { + let len = text.len(); + + let mut score = vec![0usize; len + 1]; + let mut prev = vec![TokenId::default(); len + 1]; + + for i in 0..len { + let max_len = (len - i).min(self.max_token_length); + for sub_len in 1..=max_len { + let sub = &text.as_bytes()[i..i + sub_len]; + let token = self.token_to_id.get(sub); + + if let Some(token) = token { + let token_score = sub.len() * sub.len(); + let local_score = score[i] + token_score; + let next = i + sub_len; + + if score[next] < local_score { + score[next] = local_score; + prev[next] = *token; + } } } } - } - // Backward pass - let mut res = vec![]; - let mut i = len; - while i > 0 { - let token_id = prev[i]; - if token_id == 0 { - return Err(TokenizationError::TokenizationFailed); + // Backward pass + let mut res = vec![]; + let mut i = len; + while i > 0 { + let token_id = prev[i]; + if token_id == 0 { + return Err(TokenizationError::TokenizationFailed); + } + let token = self.id_to_token[token_id as usize].as_slice(); + res.push((token.to_vec(), token_id)); + i -= token.len(); } - let token = self.id_to_token[token_id as usize].as_slice(); - res.push((token, token_id)); - i -= token.len(); - } - if bos { - // TODO: replace with vocab.bos - res.push((&[], 1)); - } + if bos { + // TODO: replace with vocab.bos + res.push((vec![], 1)); + } - // Pieces are in reverse order so correct that - res.reverse(); + // Pieces are in reverse order so correct that + res.reverse(); - Ok(res) + Ok(res) + } } } diff --git a/crates/llm/examples/inference.rs b/crates/llm/examples/inference.rs index 29388f67..83e5f198 100644 --- a/crates/llm/examples/inference.rs +++ b/crates/llm/examples/inference.rs @@ -24,6 +24,7 @@ fn main() { let model = llm::load_dynamic( model_architecture, model_path, + None, Default::default(), overrides, load_callback, diff --git a/crates/llm/examples/vicuna-chat.rs b/crates/llm/examples/vicuna-chat.rs index 6c15df1a..5065158f 100644 --- a/crates/llm/examples/vicuna-chat.rs +++ b/crates/llm/examples/vicuna-chat.rs @@ -24,6 +24,7 @@ fn main() { let model = llm::load_dynamic( model_architecture, model_path, + None, Default::default(), overrides, load_progress_callback(sp, now, prev_load_time), diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index 050193ee..95250f81 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -222,6 +222,7 @@ impl Display for ModelArchitecture { pub fn load_dynamic( architecture: ModelArchitecture, path: &Path, + vocab_path: Option<&Path>, params: ModelParameters, overrides: Option, load_progress_callback: impl FnMut(LoadProgress), @@ -230,12 +231,14 @@ pub fn load_dynamic( fn load_model( path: &Path, + vocab_path: Option<&Path>, params: ModelParameters, overrides: Option, load_progress_callback: impl FnMut(LoadProgress), ) -> Result, LoadError> { Ok(Box::new(load::( path, + vocab_path, params, overrides.map(|o| o.into()), load_progress_callback, @@ -244,17 +247,41 @@ pub fn load_dynamic( let model: Box = match architecture { #[cfg(feature = "bloom")] - Bloom => load_model::(path, params, overrides, load_progress_callback)?, + Bloom => load_model::( + path, + vocab_path, + params, + overrides, + load_progress_callback, + )?, #[cfg(feature = "gpt2")] - Gpt2 => load_model::(path, params, overrides, load_progress_callback)?, + Gpt2 => { + load_model::(path, vocab_path, params, overrides, load_progress_callback)? + } #[cfg(feature = "gptj")] - GptJ => load_model::(path, params, overrides, load_progress_callback)?, + GptJ => { + load_model::(path, vocab_path, params, overrides, load_progress_callback)? + } #[cfg(feature = "gptneox")] - GptNeoX => load_model::(path, params, overrides, load_progress_callback)?, + GptNeoX => load_model::( + path, + vocab_path, + params, + overrides, + load_progress_callback, + )?, #[cfg(feature = "llama")] - Llama => load_model::(path, params, overrides, load_progress_callback)?, + Llama => load_model::( + path, + vocab_path, + params, + overrides, + load_progress_callback, + )?, #[cfg(feature = "mpt")] - Mpt => load_model::(path, params, overrides, load_progress_callback)?, + Mpt => { + load_model::(path, vocab_path, params, overrides, load_progress_callback)? + } }; Ok(model) diff --git a/crates/models/llama/src/convert.rs b/crates/models/llama/src/convert.rs index 763d5fbd..8a116830 100644 --- a/crates/models/llama/src/convert.rs +++ b/crates/models/llama/src/convert.rs @@ -55,7 +55,7 @@ fn load_vocabulary(path: &Path) -> Vocabulary { let word = piece.get_piece().as_bytes(); max_token_length = max_token_length.max(word.len()); id_to_token.push(word.to_owned()); - token_to_id.insert(word.to_owned(), idx as i32); + token_to_id.insert(word.to_owned(), idx as u32); id_to_token_score.push(piece.get_score()); } Vocabulary { @@ -63,6 +63,7 @@ fn load_vocabulary(path: &Path) -> Vocabulary { id_to_token_score, token_to_id, max_token_length, + tokenizer: None, } } From aba3bb884772ed7e0443a14a3a64518f525299e3 Mon Sep 17 00:00:00 2001 From: RedBoxing Date: Tue, 23 May 2023 21:01:38 +0200 Subject: [PATCH 02/20] Change Vocabulary into enum --- binaries/llm-cli/src/cli_args.rs | 4 +- binaries/llm-cli/src/main.rs | 15 +- crates/llm-base/src/lib.rs | 5 +- crates/llm-base/src/loader.rs | 49 +++-- crates/llm-base/src/quantize.rs | 18 +- crates/llm-base/src/vocabulary.rs | 302 ++++++++++++++++++++--------- crates/models/bloom/src/lib.rs | 8 +- crates/models/gpt2/src/lib.rs | 4 +- crates/models/gptj/src/lib.rs | 4 +- crates/models/gptneox/src/lib.rs | 4 +- crates/models/llama/src/convert.rs | 18 +- crates/models/mpt/src/lib.rs | 9 +- 12 files changed, 284 insertions(+), 156 deletions(-) diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index e4dd5740..f94e542f 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -334,7 +334,7 @@ pub struct ModelLoad { /// Where to save the model from #[arg(long, short = 'v')] - pub vocab_path: Option, + pub vocabulary_path: Option, /// Sets the size of the context (in tokens). Allows feeding longer prompts. /// Note that this affects memory. @@ -380,7 +380,7 @@ impl ModelLoad { let model = llm::load::( &self.model_path, - self.vocab_path.as_deref(), + self.vocabulary_path.as_deref(), params, overrides, |progress| match progress { diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index deb4bf73..8f76fc90 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -7,7 +7,7 @@ use std::{ use clap::Parser; use cli_args::{Args, BaseArgs}; use color_eyre::eyre::{Context, Result}; -use llm::{InferenceError, InferenceFeedback, InferenceResponse}; +use llm::{InferenceError, InferenceFeedback, InferenceResponse, Vocabulary}; use rustyline::error::ReadlineError; use rustyline::validate::{ValidationContext, ValidationResult, Validator}; use rustyline::{history::DefaultHistory, Cmd, Event, EventHandler, KeyCode, KeyEvent, Modifiers}; @@ -149,9 +149,10 @@ fn perplexity( fn info(args: &cli_args::Info) -> Result<()> { let file = File::open(&args.model_path)?; let mut reader = BufReader::new(&file); - let mut loader: llm::Loader = llm::Loader::new(None, |_| { - // We purposely do not print progress here, as we are only interested in the metadata - }); + let mut loader: llm::Loader = + llm::Loader::new(Vocabulary::new_ggml(), |_| { + // We purposely do not print progress here, as we are only interested in the metadata + }); llm::ggml_format::load(&mut reader, &mut loader)?; @@ -165,12 +166,12 @@ fn info(args: &cli_args::Info) -> Result<()> { .map(|(name, tensor)| format!("{} ({:?})", name, tensor.element_type)) .collect::>() ); - log::info!("Vocabulary size: {}", loader.vocabulary.id_to_token.len()); + log::info!("Vocabulary size: {}", loader.vocabulary.len()); if args.dump_vocabulary { log::info!("Dumping vocabulary:"); - for (tid, token) in loader.vocabulary.id_to_token.iter().enumerate() { - log::info!("{}: {}", tid, utf8_or_array(token)); + for i in 0..loader.vocabulary.len() { + log::info!("{}: {}", i, utf8_or_array(&loader.vocabulary.token(i))); } } diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index fb1a2238..98f64f6f 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -36,7 +36,10 @@ pub use model::{ }; pub use quantize::{quantize, QuantizeError, QuantizeProgress}; pub use util::TokenUtf8Buffer; -pub use vocabulary::{InvalidTokenBias, Prompt, TokenBias, TokenId, TokenizationError, Vocabulary}; +pub use vocabulary::{ + GgmlVocabulary, InvalidTokenBias, Prompt, TokenBias, TokenId, TokenizationError, + TokenizerVocabulary, Vocabulary, +}; #[derive(Clone, Debug, PartialEq)] /// The parameters for text generation. diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index 9931435f..7c1d1b87 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -1,5 +1,6 @@ use std::{ collections::HashMap, + error::Error, fmt::{Display, Formatter}, fs::File, io::{BufRead, BufReader, Read, Seek, SeekFrom}, @@ -284,12 +285,13 @@ pub enum LoadError { }, /// The vocab file for the tokenizer could not be loaded. - /// - /// #[error("could not load vocab file {path:?}")] - VocabLoadError { + VocabularyLoadError { /// The path that failed. path: PathBuf, + + /// The error that occurred. + error: Box, }, } impl From for LoadError { @@ -354,7 +356,7 @@ pub trait TensorLoader { /// store any information about the architecture. pub fn load( path: &Path, - vocab_path: Option<&Path>, + vocabulary_path: Option<&Path>, params: ModelParameters, overrides: Option, load_progress_callback: impl FnMut(LoadProgress), @@ -376,29 +378,34 @@ pub fn load( })?; let mut reader = BufReader::new(&file); - let tokenizer = if let Some(path) = vocab_path { - let tok = if !path.exists() && path.to_str().unwrap().matches("/").count() == 1 { + let vocabulary = if let Some(path) = vocabulary_path { + let tok = if !path.exists() && path.to_string_lossy().matches("/").count() == 1 { Tokenizer::from_pretrained(path.to_str().unwrap(), None) } else if path.exists() && path.is_file() { Tokenizer::from_file(path) } else { - return Err(LoadError::VocabLoadError { + return Err(LoadError::VocabularyLoadError { path: path.to_owned(), + error: Box::new(std::io::Error::new( + std::io::ErrorKind::NotFound, + "Vocabulary file not found", + )), }); }; if tok.is_err() { - return Err(LoadError::VocabLoadError { + return Err(LoadError::VocabularyLoadError { path: path.to_owned(), + error: tok.unwrap_err(), }); } - Some(tok.unwrap()) + Vocabulary::new_tokenizer(tok.unwrap()) } else { - None + Vocabulary::new_ggml() }; - let mut loader = Loader::new(tokenizer, load_progress_callback); + let mut loader = Loader::new(vocabulary, load_progress_callback); ggml::format::load(&mut reader, &mut loader) .map_err(|err| LoadError::from_format_error(err, path.to_owned()))?; @@ -456,7 +463,8 @@ pub fn load( let mut lora_reader = BufReader::new(&lora_file); // TODO: Consider updating the progress callback to report the progress of the LoRA file. // Most LoRAs are small enough that this is not necessary, but it would be nice to have. - let mut lora_loader: Loader = Loader::new(None, |_| {}); + let mut lora_loader: Loader = + Loader::new(Vocabulary::new_ggml(), |_| {}); ggml::format::load(&mut lora_reader, &mut lora_loader) .map_err(|err| LoadError::from_format_error(err, lora_path.to_owned()))?; @@ -532,13 +540,13 @@ pub struct Loader { } impl Loader { /// Creates a new loader. - pub fn new(tokenizer: Option, load_progress_callback: F) -> Self { + pub fn new(vocabulary: Vocabulary, load_progress_callback: F) -> Self { Self { load_progress_callback, container_type: ContainerType::Ggml, hyperparameters: Hp::default(), - vocabulary: Vocabulary::new(tokenizer), + vocabulary: vocabulary, tensors: HashMap::default(), } } @@ -552,11 +560,14 @@ impl ggml::format::LoadHandler, score: f32) -> Result<(), LoadError> { - let id = match TokenId::try_from(i) { - Ok(id) => id, - Err(err) => return Err(LoadError::InvalidIntegerConversion(err)), - }; - self.vocabulary.push_token(id, token, score); + if let Vocabulary::Ggml(_) = &self.vocabulary { + let id = match TokenId::try_from(i) { + Ok(id) => id, + Err(err) => return Err(LoadError::InvalidIntegerConversion(err)), + }; + + self.vocabulary.push_token(id, token, score); + } Ok(()) } diff --git a/crates/llm-base/src/quantize.rs b/crates/llm-base/src/quantize.rs index c1e63fdb..b11873b3 100644 --- a/crates/llm-base/src/quantize.rs +++ b/crates/llm-base/src/quantize.rs @@ -2,6 +2,7 @@ use crate::{ model::HyperparametersWriteError, Hyperparameters, KnownModel, LoadError, LoadProgress, Loader, + Vocabulary, }; use ggml::format::{SaveError, SaveHandler, TensorLoadInfo, TensorSaveInfo}; use half::f16; @@ -151,7 +152,7 @@ pub fn quantize( // Load the model let progress_callback = Arc::new(progress_callback); - let mut loader = Loader::::new(None, { + let mut loader = Loader::::new(Vocabulary::new_ggml(), { let progress_callback = progress_callback.clone(); move |p| { if let LoadProgress::HyperparametersLoaded = p { @@ -177,12 +178,15 @@ pub fn quantize( .expect("format has no corresponding ftype"); } - let vocabulary = vocabulary - .id_to_token - .iter() - .cloned() - .zip(vocabulary.id_to_token_score) - .collect::>(); + let vocabulary = match vocabulary { + Vocabulary::Ggml(v) => v + .id_to_token + .iter() + .cloned() + .zip(v.id_to_token_score) + .collect::>(), + Vocabulary::Tokenizer(_) => vec![], + }; let mut saver = QuantizeSaver::new(desired_type, &hyperparameters, &tensors, reader, |p| { progress_callback(p) diff --git a/crates/llm-base/src/vocabulary.rs b/crates/llm-base/src/vocabulary.rs index d277703f..4fcf3145 100644 --- a/crates/llm-base/src/vocabulary.rs +++ b/crates/llm-base/src/vocabulary.rs @@ -19,9 +19,103 @@ pub enum TokenizationError { InvalidTokenId(TokenId), } -/// The vocabulary used by a model. +pub trait VocabularyTrait { + fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore); + fn token_to_id(&self, token: &[u8]) -> Option; + fn token(&self, idx: usize) -> Vec; + fn len(&self) -> usize; + fn is_empty(&self) -> bool; + fn tokenize<'a>( + &'a self, + text: &str, + bos: bool, + ) -> Result, TokenId)>, TokenizationError>; +} + +/// Vocabulary enum +pub enum Vocabulary { + /// The vocabulary built-in to the model. + Ggml(GgmlVocabulary), + + /// A custom vocabulary provided by the user. + Tokenizer(TokenizerVocabulary), +} + +impl Vocabulary { + /// Create a new vocabulary with the default GGML vocabulary. + pub fn new_ggml() -> Self { + Vocabulary::Ggml(GgmlVocabulary::default()) + } + + /// Create a new vocabulary with a custom tokenizer. + pub fn new_tokenizer(tokenizer: Tokenizer) -> Self { + Vocabulary::Tokenizer(TokenizerVocabulary::new(tokenizer)) + } + + /// Add a token to the vocabulary. + /// + /// The token added must have `id` directly after the last token in the vocabulary. + /// + /// # Panics + /// - This function can panic if `id` does not correspond to the next token in the vocabulary. + /// That is, if there are already `n` tokens in the vocabulary, then `id` must be `n`. + pub fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) { + match self { + Vocabulary::Ggml(v) => v.push_token(id, content, score), + Vocabulary::Tokenizer(v) => v.push_token(id, content, score), + } + } + + /// Converts a token to the token ID it represents in this vocabulary. + pub fn token_to_id(&self, token: &[u8]) -> Option { + match self { + Vocabulary::Ggml(v) => v.token_to_id(token), + Vocabulary::Tokenizer(v) => v.token_to_id(token), + } + } + + /// Converts a token index to the token it represents in this vocabulary. + pub fn token(&self, idx: usize) -> Vec { + match self { + Vocabulary::Ggml(v) => v.token(idx), + Vocabulary::Tokenizer(v) => v.token(idx), + } + } + + /// Returns the number of tokens in the vocabulary. + pub fn len(&self) -> usize { + match self { + Vocabulary::Ggml(v) => v.len(), + Vocabulary::Tokenizer(v) => v.len(), + } + } + + /// Returns whether the vocabulary is empty. + pub fn is_empty(&self) -> bool { + match self { + Vocabulary::Ggml(v) => v.is_empty(), + Vocabulary::Tokenizer(v) => v.is_empty(), + } + } + + /// Tokenize a `text` with this vocabulary. + /// + /// `bos` controls whether a beginning-of-string token should be inserted. + pub fn tokenize<'a>( + &'a self, + text: &str, + bos: bool, + ) -> Result, TokenId)>, TokenizationError> { + match self { + Vocabulary::Ggml(v) => v.tokenize(text, bos), + Vocabulary::Tokenizer(v) => v.tokenize(text, bos), + } + } +} + +/// The built-in GGML vocabulary. #[derive(Debug, Clone, Default)] -pub struct Vocabulary { +pub struct GgmlVocabulary { // TODO: make these private /// Maps every integer (index) token ID to its corresponding token. pub id_to_token: Vec, @@ -35,19 +129,9 @@ pub struct Vocabulary { /// The longest token in this vocabulary. pub max_token_length: usize, - - /// The tokenizer - pub tokenizer: Option, } -impl Vocabulary { - /// Intialize a new vocabulary. - pub fn new(tokenizer: Option) -> Vocabulary { - let mut vocab = Vocabulary::default(); - vocab.tokenizer = tokenizer; - - vocab - } +impl VocabularyTrait for GgmlVocabulary { /// Add a token to the vocabulary. /// /// The token added must have `id` directly after the last token in the vocabulary. @@ -55,11 +139,7 @@ impl Vocabulary { /// # Panics /// - This function can panic if `id` does not correspond to the next token in the vocabulary. /// That is, if there are already `n` tokens in the vocabulary, then `id` must be `n`. - pub fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) { - if self.tokenizer.is_some() { - return; - } - + fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) { // These are loader invariants. If this is broken, then the loader is broken and this is a bug, // not an issue with the model itself. assert_eq!(self.id_to_token.len(), self.id_to_token_score.len()); @@ -74,34 +154,22 @@ impl Vocabulary { self.token_to_id.insert(content, id); } - /// Converts a token index to the token it represents in this vocabulary. - pub fn token(&self, idx: usize) -> Vec { - if let Some(tokenizer) = &self.tokenizer { - return tokenizer - .decode(vec![idx as u32], true) - .unwrap() - .as_bytes() - .to_vec(); - } + fn token_to_id(&self, token: &[u8]) -> Option { + self.token_to_id.get(token).copied() + } + /// Converts a token index to the token it represents in this vocabulary. + fn token(&self, idx: usize) -> Vec { (&self.id_to_token[idx]).clone() } /// Returns the number of tokens in the vocabulary. - pub fn len(&self) -> usize { - if let Some(tokenizer) = &self.tokenizer { - return tokenizer.get_vocab_size(false) as usize; - } - + fn len(&self) -> usize { self.id_to_token.len() } /// Returns whether the vocabulary is empty. - pub fn is_empty(&self) -> bool { - if let Some(tokenizer) = &self.tokenizer { - return tokenizer.get_vocab_size(false) == 0; - } - + fn is_empty(&self) -> bool { self.id_to_token.is_empty() } @@ -109,71 +177,123 @@ impl Vocabulary { /// Tokenize a `text` with this vocabulary. /// /// `bos` controls whether a beginning-of-string token should be inserted. - pub fn tokenize<'a>( + fn tokenize<'a>( &'a self, text: &str, bos: bool, ) -> Result, TokenId)>, TokenizationError> { - if let Some(tokenizer) = &self.tokenizer { - let res = tokenizer.encode(text, bos); - if res.is_err() { - return Err(TokenizationError::TokenizationFailed); - } else { - Ok(tokenizer - .encode(text, bos) - .unwrap() - .get_ids() - .iter() - .map(|id| (self.token(*id as usize), *id)) - .collect::, TokenId)>>()) - } - } else { - let len = text.len(); - - let mut score = vec![0usize; len + 1]; - let mut prev = vec![TokenId::default(); len + 1]; - - for i in 0..len { - let max_len = (len - i).min(self.max_token_length); - for sub_len in 1..=max_len { - let sub = &text.as_bytes()[i..i + sub_len]; - let token = self.token_to_id.get(sub); - - if let Some(token) = token { - let token_score = sub.len() * sub.len(); - let local_score = score[i] + token_score; - let next = i + sub_len; - - if score[next] < local_score { - score[next] = local_score; - prev[next] = *token; - } + let len = text.len(); + + let mut score = vec![0usize; len + 1]; + let mut prev = vec![TokenId::default(); len + 1]; + + for i in 0..len { + let max_len = (len - i).min(self.max_token_length); + for sub_len in 1..=max_len { + let sub = &text.as_bytes()[i..i + sub_len]; + let token = self.token_to_id.get(sub); + + if let Some(token) = token { + let token_score = sub.len() * sub.len(); + let local_score = score[i] + token_score; + let next = i + sub_len; + + if score[next] < local_score { + score[next] = local_score; + prev[next] = *token; } } } + } - // Backward pass - let mut res = vec![]; - let mut i = len; - while i > 0 { - let token_id = prev[i]; - if token_id == 0 { - return Err(TokenizationError::TokenizationFailed); - } - let token = self.id_to_token[token_id as usize].as_slice(); - res.push((token.to_vec(), token_id)); - i -= token.len(); + // Backward pass + let mut res = vec![]; + let mut i = len; + while i > 0 { + let token_id = prev[i]; + if token_id == 0 { + return Err(TokenizationError::TokenizationFailed); } + let token = self.id_to_token[token_id as usize].as_slice(); + res.push((token.to_vec(), token_id)); + i -= token.len(); + } - if bos { - // TODO: replace with vocab.bos - res.push((vec![], 1)); - } + if bos { + // TODO: replace with vocab.bos + res.push((vec![], 1)); + } + + // Pieces are in reverse order so correct that + res.reverse(); + + Ok(res) + } +} + +/// A vocabulary provided by the user. +#[derive(Debug, Clone)] +pub struct TokenizerVocabulary { + tokenizer: Tokenizer, +} + +impl TokenizerVocabulary { + /// Create a new `TokenizerVocabulary`. + pub fn new(tokenizer: Tokenizer) -> Self { + Self { tokenizer } + } +} + +impl VocabularyTrait for TokenizerVocabulary { + fn push_token(&mut self, _id: TokenId, _content: Token, _score: TokenScore) { + panic!("Cannot push token to tokenizer vocabulary."); + } + + fn token_to_id(&self, token: &[u8]) -> Option { + self.tokenizer + .token_to_id(std::str::from_utf8(token).unwrap()) + } + + /// Converts a token index to the token it represents in this vocabulary. + fn token(&self, idx: usize) -> Vec { + self.tokenizer + .decode(vec![idx as u32], true) + .unwrap() + .as_bytes() + .to_vec() + } + + /// Returns the number of tokens in the vocabulary. + fn len(&self) -> usize { + self.tokenizer.get_vocab_size(false) as usize + } - // Pieces are in reverse order so correct that - res.reverse(); + /// Returns whether the vocabulary is empty. + fn is_empty(&self) -> bool { + self.tokenizer.get_vocab_size(false) == 0 + } - Ok(res) + // SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece + /// Tokenize a `text` with this vocabulary. + /// + /// `bos` controls whether a beginning-of-string token should be inserted. + fn tokenize<'a>( + &'a self, + text: &str, + bos: bool, + ) -> Result, TokenId)>, TokenizationError> { + let res = self.tokenizer.encode(text, bos); + if res.is_err() { + return Err(TokenizationError::TokenizationFailed); + } else { + Ok(self + .tokenizer + .encode(text, bos) + .unwrap() + .get_ids() + .iter() + .map(|id| (self.token(*id as usize), *id)) + .collect::, TokenId)>>()) } } } @@ -214,7 +334,7 @@ impl Prompt<'_> { if let Some(t) = tokens .iter() .copied() - .find(|t| vocab.id_to_token.get(*t as usize).is_none()) + .find(|t| vocab.token(*t as usize).is_empty()) { return Err(TokenizationError::InvalidTokenId(t)); } diff --git a/crates/models/bloom/src/lib.rs b/crates/models/bloom/src/lib.rs index c0073e33..e36b8f1c 100644 --- a/crates/models/bloom/src/lib.rs +++ b/crates/models/bloom/src/lib.rs @@ -383,15 +383,11 @@ impl KnownModel for Bloom { } fn bot_token_id(&self) -> Option { - self.vocabulary.token_to_id.get("".as_bytes()).copied() + self.vocabulary.token_to_id("".as_bytes()) } fn eot_token_id(&self) -> TokenId { - self.vocabulary - .token_to_id - .get("".as_bytes()) - .copied() - .unwrap() + self.vocabulary.token_to_id("".as_bytes()).unwrap() } fn inference_parameters(&self) -> &InferenceParameters { diff --git a/crates/models/gpt2/src/lib.rs b/crates/models/gpt2/src/lib.rs index c57bad35..3c853cb0 100644 --- a/crates/models/gpt2/src/lib.rs +++ b/crates/models/gpt2/src/lib.rs @@ -344,9 +344,7 @@ impl KnownModel for Gpt2 { fn eot_token_id(&self) -> TokenId { self.vocabulary - .token_to_id - .get("<|endoftext|>".as_bytes()) - .copied() + .token_to_id("<|endoftext|>".as_bytes()) .unwrap() } diff --git a/crates/models/gptj/src/lib.rs b/crates/models/gptj/src/lib.rs index ee1b0d95..0782460d 100644 --- a/crates/models/gptj/src/lib.rs +++ b/crates/models/gptj/src/lib.rs @@ -314,9 +314,7 @@ impl KnownModel for GptJ { fn eot_token_id(&self) -> TokenId { self.vocabulary - .token_to_id - .get("<|endoftext|>".as_bytes()) - .copied() + .token_to_id("<|endoftext|>".as_bytes()) .unwrap() } diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs index 2e2e9e9d..bcb81df8 100644 --- a/crates/models/gptneox/src/lib.rs +++ b/crates/models/gptneox/src/lib.rs @@ -395,9 +395,7 @@ impl KnownModel for GptNeoX { fn eot_token_id(&self) -> TokenId { self.vocabulary - .token_to_id - .get("<|endoftext|>".as_bytes()) - .copied() + .token_to_id("<|endoftext|>".as_bytes()) .unwrap() } diff --git a/crates/models/llama/src/convert.rs b/crates/models/llama/src/convert.rs index 8a116830..e4ee76f2 100644 --- a/crates/models/llama/src/convert.rs +++ b/crates/models/llama/src/convert.rs @@ -3,7 +3,7 @@ //! This is *incomplete* and does not convert the weights. It only converts the //! vocabulary and hyperparameters. It is included as a preliminary step to //! full conversion. -use llm_base::FileType; +use llm_base::{FileType, GgmlVocabulary}; /// /// For reference, see [the PR](https://github.com/rustformers/llm/pull/83). use rust_tokenizers::preprocessing::vocab::sentencepiece_proto::sentencepiece_model::ModelProto; @@ -17,7 +17,7 @@ use std::{ vec, }; -use crate::{Hyperparameters, Vocabulary}; +use crate::Hyperparameters; /// Converts a `pth` file to a `ggml` file. pub fn convert_pth_to_ggml(model_directory: &Path, file_type: FileType) { @@ -39,7 +39,7 @@ pub fn convert_pth_to_ggml(model_directory: &Path, file_type: FileType) { } } -fn load_vocabulary(path: &Path) -> Vocabulary { +fn load_vocabulary(path: &Path) -> GgmlVocabulary { let mut f = File::open(path).unwrap(); let mut contents = Vec::new(); f.read_to_end(&mut contents).unwrap(); @@ -58,16 +58,20 @@ fn load_vocabulary(path: &Path) -> Vocabulary { token_to_id.insert(word.to_owned(), idx as u32); id_to_token_score.push(piece.get_score()); } - Vocabulary { + + GgmlVocabulary { id_to_token, id_to_token_score, token_to_id, max_token_length, - tokenizer: None, } } -fn load_hyperparameters(path: &Path, file_type: FileType, vocab: &Vocabulary) -> Hyperparameters { +fn load_hyperparameters( + path: &Path, + file_type: FileType, + vocab: &GgmlVocabulary, +) -> Hyperparameters { #[derive(Deserialize)] struct HyperparametersJson { dim: usize, @@ -117,7 +121,7 @@ fn write_header(fout: &mut File, hparams: &Hyperparameters) -> Result<(), String Ok(()) } -fn write_tokens(file: &mut File, vocab: &Vocabulary) -> Result<(), String> { +fn write_tokens(file: &mut File, vocab: &GgmlVocabulary) -> Result<(), String> { let mut values: Vec = vec![]; for (i, token) in vocab.id_to_token.iter().enumerate() { let text = if let Ok(token) = std::str::from_utf8(token) { diff --git a/crates/models/mpt/src/lib.rs b/crates/models/mpt/src/lib.rs index 960d0903..c1c08038 100644 --- a/crates/models/mpt/src/lib.rs +++ b/crates/models/mpt/src/lib.rs @@ -285,17 +285,12 @@ impl KnownModel for Mpt { } fn bot_token_id(&self) -> Option { - self.vocabulary - .token_to_id - .get("<|padding|>".as_bytes()) - .copied() + self.vocabulary.token_to_id("<|padding|>".as_bytes()) } fn eot_token_id(&self) -> TokenId { self.vocabulary - .token_to_id - .get("<|endoftext|>".as_bytes()) - .copied() + .token_to_id("<|endoftext|>".as_bytes()) .unwrap() } From 3b96aa05910eee9fa0a72cdf6122ff4568fc41de Mon Sep 17 00:00:00 2001 From: RedBoxing Date: Tue, 23 May 2023 21:23:49 +0200 Subject: [PATCH 03/20] VocabularySource argument to load tokenizer --- binaries/llm-cli/src/cli_args.rs | 17 ++++++-- crates/llm-base/src/lib.rs | 2 +- crates/llm-base/src/loader.rs | 63 ++++++++++++++++++------------ crates/llm-base/src/model/mod.rs | 12 ++++-- crates/llm-base/src/vocabulary.rs | 14 ++++++- crates/llm/examples/inference.rs | 4 +- crates/llm/examples/vicuna-chat.rs | 4 +- crates/llm/src/lib.rs | 43 +++++++++++++------- 8 files changed, 107 insertions(+), 52 deletions(-) diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index f94e542f..5ea29806 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -4,7 +4,7 @@ use clap::{Parser, Subcommand, ValueEnum}; use color_eyre::eyre::{Result, WrapErr}; use llm::{ ggml_format, ElementType, InferenceParameters, InferenceSessionConfig, InvalidTokenBias, - LoadProgress, Model, ModelKVMemoryType, ModelParameters, TokenBias, + LoadProgress, Model, ModelKVMemoryType, ModelParameters, TokenBias, VocabularySource, }; use rand::SeedableRng; @@ -332,10 +332,14 @@ pub struct ModelLoad { #[arg(long, short = 'm')] pub model_path: PathBuf, - /// Where to save the model from + /// Where to load the vocabulary from #[arg(long, short = 'v')] pub vocabulary_path: Option, + /// Where to load the vocabulary from + #[arg(long, short = 'r')] + pub vocabulary_repo: Option, + /// Sets the size of the context (in tokens). Allows feeding longer prompts. /// Note that this affects memory. /// @@ -378,9 +382,16 @@ impl ModelLoad { let now = std::time::Instant::now(); let mut prev_load_time = now; + let mut vocabulary_source = VocabularySource::ModelEmbedded; + if let Some(path) = &self.vocabulary_path { + vocabulary_source = VocabularySource::TokenizerFile(path.clone()); + } else if let Some(repo) = &self.vocabulary_repo { + vocabulary_source = VocabularySource::TokenizerHfPretrained(repo.clone()); + } + let model = llm::load::( &self.model_path, - self.vocabulary_path.as_deref(), + vocabulary_source, params, overrides, |progress| match progress { diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index 98f64f6f..17510b87 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -38,7 +38,7 @@ pub use quantize::{quantize, QuantizeError, QuantizeProgress}; pub use util::TokenUtf8Buffer; pub use vocabulary::{ GgmlVocabulary, InvalidTokenBias, Prompt, TokenBias, TokenId, TokenizationError, - TokenizerVocabulary, Vocabulary, + TokenizerVocabulary, Vocabulary, VocabularySource, }; #[derive(Clone, Debug, PartialEq)] diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index 7c1d1b87..6def2f96 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -9,7 +9,7 @@ use std::{ use crate::{ util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelParameters, TokenId, - Vocabulary, + Vocabulary, VocabularySource, }; pub use ggml::ContainerType; use ggml::{ @@ -288,7 +288,7 @@ pub enum LoadError { #[error("could not load vocab file {path:?}")] VocabularyLoadError { /// The path that failed. - path: PathBuf, + path: String, /// The error that occurred. error: Box, @@ -356,7 +356,7 @@ pub trait TensorLoader { /// store any information about the architecture. pub fn load( path: &Path, - vocabulary_path: Option<&Path>, + vocabulary_source: VocabularySource, params: ModelParameters, overrides: Option, load_progress_callback: impl FnMut(LoadProgress), @@ -378,31 +378,44 @@ pub fn load( })?; let mut reader = BufReader::new(&file); - let vocabulary = if let Some(path) = vocabulary_path { - let tok = if !path.exists() && path.to_string_lossy().matches("/").count() == 1 { - Tokenizer::from_pretrained(path.to_str().unwrap(), None) - } else if path.exists() && path.is_file() { - Tokenizer::from_file(path) - } else { - return Err(LoadError::VocabularyLoadError { - path: path.to_owned(), - error: Box::new(std::io::Error::new( - std::io::ErrorKind::NotFound, - "Vocabulary file not found", - )), - }); - }; + let vocabulary = match vocabulary_source { + VocabularySource::TokenizerHfPretrained(identifier) => { + let tokenizer = Tokenizer::from_pretrained(&identifier, None); - if tok.is_err() { - return Err(LoadError::VocabularyLoadError { - path: path.to_owned(), - error: tok.unwrap_err(), - }); + if tokenizer.is_err() { + return Err(LoadError::VocabularyLoadError { + path: identifier, + error: tokenizer.unwrap_err(), + }); + } + + Vocabulary::new_tokenizer(tokenizer.unwrap()) } - Vocabulary::new_tokenizer(tok.unwrap()) - } else { - Vocabulary::new_ggml() + VocabularySource::TokenizerFile(path) => { + if path.exists() && path.is_file() { + let tokenizer = Tokenizer::from_file(&path); + + if tokenizer.is_err() { + return Err(LoadError::VocabularyLoadError { + path: path.to_string_lossy().to_string(), + error: tokenizer.unwrap_err(), + }); + } + + Vocabulary::new_tokenizer(tokenizer.unwrap()) + } else { + return Err(LoadError::VocabularyLoadError { + path: path.to_string_lossy().to_string(), + error: Box::new(std::io::Error::new( + std::io::ErrorKind::NotFound, + "Vocabulary file not found", + )), + }); + } + } + + VocabularySource::ModelEmbedded => Vocabulary::new_ggml(), }; let mut loader = Loader::new(vocabulary, load_progress_callback); diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs index b2423aee..19a669e3 100644 --- a/crates/llm-base/src/model/mod.rs +++ b/crates/llm-base/src/model/mod.rs @@ -13,7 +13,7 @@ use thiserror::Error; use crate::{ loader::TensorLoader, vocabulary::TokenId, FileType, InferenceParameters, InferenceSession, - InferenceSessionConfig, LoadError, LoadProgress, Vocabulary, + InferenceSessionConfig, LoadError, LoadProgress, Vocabulary, VocabularySource, }; /// Common functions for model evaluation @@ -114,7 +114,7 @@ pub trait KnownModel: Send + Sync { /// is a helper function on top of [llm_base::load](crate::load). fn load( path: &Path, - vocab_path: Option<&Path>, + vocabulary_source: VocabularySource, params: ModelParameters, overrides: Option, load_progress_callback: impl FnMut(LoadProgress), @@ -122,7 +122,13 @@ pub trait KnownModel: Send + Sync { where Self: Sized, { - crate::load(path, vocab_path, params, overrides, load_progress_callback) + crate::load( + path, + vocabulary_source, + params, + overrides, + load_progress_callback, + ) } /// Creates a new model from the provided [ModelParameters] hyperparameters. diff --git a/crates/llm-base/src/vocabulary.rs b/crates/llm-base/src/vocabulary.rs index 4fcf3145..42cbb584 100644 --- a/crates/llm-base/src/vocabulary.rs +++ b/crates/llm-base/src/vocabulary.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, error::Error, fmt::Display, str::FromStr}; +use std::{collections::HashMap, error::Error, fmt::Display, path::PathBuf, str::FromStr}; use thiserror::Error; use tokenizers::Tokenizer; @@ -19,6 +19,18 @@ pub enum TokenizationError { InvalidTokenId(TokenId), } +/// The source of a vocabulary. +pub enum VocabularySource { + /// The vocabulary is built-in to the model if available. + ModelEmbedded, + + /// The vocabulary is loaded from a file. + TokenizerFile(PathBuf), + + /// The vocabulary is loaded from a huggingface repository. + TokenizerHfPretrained(String), +} + pub trait VocabularyTrait { fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore); fn token_to_id(&self, token: &[u8]) -> Option; diff --git a/crates/llm/examples/inference.rs b/crates/llm/examples/inference.rs index 83e5f198..298c0174 100644 --- a/crates/llm/examples/inference.rs +++ b/crates/llm/examples/inference.rs @@ -1,6 +1,6 @@ use llm::{ load_progress_callback_stdout as load_callback, InferenceFeedback, InferenceRequest, - InferenceResponse, ModelArchitecture, + InferenceResponse, ModelArchitecture, VocabularySource, }; use std::{convert::Infallible, io::Write, path::Path}; @@ -24,7 +24,7 @@ fn main() { let model = llm::load_dynamic( model_architecture, model_path, - None, + VocabularySource::ModelEmbedded, Default::default(), overrides, load_callback, diff --git a/crates/llm/examples/vicuna-chat.rs b/crates/llm/examples/vicuna-chat.rs index 5065158f..a83baa99 100644 --- a/crates/llm/examples/vicuna-chat.rs +++ b/crates/llm/examples/vicuna-chat.rs @@ -1,6 +1,6 @@ use llm::{ InferenceFeedback, InferenceRequest, InferenceResponse, InferenceStats, LoadProgress, - ModelArchitecture, + ModelArchitecture, VocabularySource, }; use rustyline::error::ReadlineError; use spinoff::{spinners::Dots2, Spinner}; @@ -24,7 +24,7 @@ fn main() { let model = llm::load_dynamic( model_architecture, model_path, - None, + VocabularySource::ModelEmbedded, Default::default(), overrides, load_progress_callback(sp, now, prev_load_time), diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index 95250f81..183e4f24 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -81,6 +81,7 @@ pub use llm_base::{ LoadError, LoadProgress, Loader, Model, ModelDynamicOverrideValue, ModelDynamicOverrides, ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, Vocabulary, + VocabularySource, }; use serde::Serialize; @@ -222,7 +223,7 @@ impl Display for ModelArchitecture { pub fn load_dynamic( architecture: ModelArchitecture, path: &Path, - vocab_path: Option<&Path>, + vocabulary_source: VocabularySource, params: ModelParameters, overrides: Option, load_progress_callback: impl FnMut(LoadProgress), @@ -231,14 +232,14 @@ pub fn load_dynamic( fn load_model( path: &Path, - vocab_path: Option<&Path>, + vocabulary_source: VocabularySource, params: ModelParameters, overrides: Option, load_progress_callback: impl FnMut(LoadProgress), ) -> Result, LoadError> { Ok(Box::new(load::( path, - vocab_path, + vocabulary_source, params, overrides.map(|o| o.into()), load_progress_callback, @@ -249,23 +250,31 @@ pub fn load_dynamic( #[cfg(feature = "bloom")] Bloom => load_model::( path, - vocab_path, + vocabulary_source, params, overrides, load_progress_callback, )?, #[cfg(feature = "gpt2")] - Gpt2 => { - load_model::(path, vocab_path, params, overrides, load_progress_callback)? - } + Gpt2 => load_model::( + path, + vocabulary_source, + params, + overrides, + load_progress_callback, + )?, #[cfg(feature = "gptj")] - GptJ => { - load_model::(path, vocab_path, params, overrides, load_progress_callback)? - } + GptJ => load_model::( + path, + vocabulary_source, + params, + overrides, + load_progress_callback, + )?, #[cfg(feature = "gptneox")] GptNeoX => load_model::( path, - vocab_path, + vocabulary_source, params, overrides, load_progress_callback, @@ -273,15 +282,19 @@ pub fn load_dynamic( #[cfg(feature = "llama")] Llama => load_model::( path, - vocab_path, + vocabulary_source, params, overrides, load_progress_callback, )?, #[cfg(feature = "mpt")] - Mpt => { - load_model::(path, vocab_path, params, overrides, load_progress_callback)? - } + Mpt => load_model::( + path, + vocabulary_source, + params, + overrides, + load_progress_callback, + )?, }; Ok(model) From 81dc8aafde7f9b1e5675ccea11e05650cb7c54ff Mon Sep 17 00:00:00 2001 From: RedBoxing Date: Tue, 23 May 2023 21:56:10 +0200 Subject: [PATCH 04/20] Error handling in TokenizerVocabulary --- crates/llm-base/src/vocabulary.rs | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/crates/llm-base/src/vocabulary.rs b/crates/llm-base/src/vocabulary.rs index 42cbb584..1e67557f 100644 --- a/crates/llm-base/src/vocabulary.rs +++ b/crates/llm-base/src/vocabulary.rs @@ -268,11 +268,13 @@ impl VocabularyTrait for TokenizerVocabulary { /// Converts a token index to the token it represents in this vocabulary. fn token(&self, idx: usize) -> Vec { - self.tokenizer - .decode(vec![idx as u32], true) - .unwrap() - .as_bytes() - .to_vec() + let res = self.tokenizer.decode(vec![idx as u32], true); + + if res.is_err() { + panic!("Cannot decode token from tokenizer vocabulary."); + } else { + res.unwrap().as_bytes().to_vec() + } } /// Returns the number of tokens in the vocabulary. @@ -298,9 +300,13 @@ impl VocabularyTrait for TokenizerVocabulary { if res.is_err() { return Err(TokenizationError::TokenizationFailed); } else { - Ok(self - .tokenizer - .encode(text, bos) + let res = self.tokenizer.encode(text, bos); + + if res.is_err() { + return Err(TokenizationError::TokenizationFailed); + } + + Ok(res .unwrap() .get_ids() .iter() From 12e72c346c0e0cf467c4ebd5ba010dba70ab1919 Mon Sep 17 00:00:00 2001 From: RedBoxing Date: Wed, 24 May 2023 17:22:19 +0200 Subject: [PATCH 05/20] mainly documentation edits --- binaries/llm-cli/src/cli_args.rs | 17 ++++++++++++----- binaries/llm-cli/src/main.rs | 1 + crates/llm-base/src/loader.rs | 8 ++++---- crates/llm-base/src/vocabulary.rs | 24 ++++++++++++------------ crates/llm/examples/inference.rs | 2 +- crates/llm/examples/vicuna-chat.rs | 2 +- crates/models/bloom/src/lib.rs | 4 ++-- crates/models/gpt2/src/lib.rs | 4 +--- crates/models/gptj/src/lib.rs | 4 +--- crates/models/gptneox/src/lib.rs | 4 +--- crates/models/mpt/src/lib.rs | 6 ++---- 11 files changed, 38 insertions(+), 38 deletions(-) diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 5ea29806..09a73829 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -332,11 +332,11 @@ pub struct ModelLoad { #[arg(long, short = 'm')] pub model_path: PathBuf, - /// Where to load the vocabulary from + /// Local path to vocabulary #[arg(long, short = 'v')] pub vocabulary_path: Option, - /// Where to load the vocabulary from + /// Remote HuggingFace repository containing vocabulary #[arg(long, short = 'r')] pub vocabulary_repo: Option, @@ -382,11 +382,18 @@ impl ModelLoad { let now = std::time::Instant::now(); let mut prev_load_time = now; - let mut vocabulary_source = VocabularySource::ModelEmbedded; - if let Some(path) = &self.vocabulary_path { + let mut vocabulary_source = VocabularySource::ModelFile; + + if self.vocabulary_path.is_some() && self.vocabulary_repo.is_some() { + if let Some(sp) = sp.take() { + sp.fail("Invalid arguments"); + }; + + panic!("Cannot specify both --vocabulary-path and --vocabulary-repo"); + } else if let Some(path) = &self.vocabulary_path { vocabulary_source = VocabularySource::TokenizerFile(path.clone()); } else if let Some(repo) = &self.vocabulary_repo { - vocabulary_source = VocabularySource::TokenizerHfPretrained(repo.clone()); + vocabulary_source = VocabularySource::HuggingFaceRemote(repo.clone()); } let model = llm::load::( diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index 8f76fc90..d9705cd1 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -56,6 +56,7 @@ fn infer( let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref()); let inference_session_config = args.generate.inference_session_config(); let model = args.model_load.load::(overrides)?; + let (mut session, session_loaded) = snapshot::read_or_create_session( model.as_ref(), args.persist_session.as_deref(), diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index 6def2f96..d853b537 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -287,7 +287,7 @@ pub enum LoadError { /// The vocab file for the tokenizer could not be loaded. #[error("could not load vocab file {path:?}")] VocabularyLoadError { - /// The path that failed. + /// The invalid vocabulary path path: String, /// The error that occurred. @@ -379,7 +379,7 @@ pub fn load( let mut reader = BufReader::new(&file); let vocabulary = match vocabulary_source { - VocabularySource::TokenizerHfPretrained(identifier) => { + VocabularySource::HuggingFaceRemote(identifier) => { let tokenizer = Tokenizer::from_pretrained(&identifier, None); if tokenizer.is_err() { @@ -415,7 +415,7 @@ pub fn load( } } - VocabularySource::ModelEmbedded => Vocabulary::new_ggml(), + VocabularySource::ModelFile => Vocabulary::new_ggml(), }; let mut loader = Loader::new(vocabulary, load_progress_callback); @@ -559,7 +559,7 @@ impl Loader { container_type: ContainerType::Ggml, hyperparameters: Hp::default(), - vocabulary: vocabulary, + vocabulary, tensors: HashMap::default(), } } diff --git a/crates/llm-base/src/vocabulary.rs b/crates/llm-base/src/vocabulary.rs index 1e67557f..cebebe93 100644 --- a/crates/llm-base/src/vocabulary.rs +++ b/crates/llm-base/src/vocabulary.rs @@ -21,19 +21,19 @@ pub enum TokenizationError { /// The source of a vocabulary. pub enum VocabularySource { - /// The vocabulary is built-in to the model if available. - ModelEmbedded, + /// Fetch vocabulary from model file + ModelFile, - /// The vocabulary is loaded from a file. + /// Fetch vocabulary from a vocabulary file TokenizerFile(PathBuf), - /// The vocabulary is loaded from a huggingface repository. - TokenizerHfPretrained(String), + /// Fetch vocabulary from remote HuggingFace repository + HuggingFaceRemote(String), } pub trait VocabularyTrait { fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore); - fn token_to_id(&self, token: &[u8]) -> Option; + fn id(&self, token: &[u8]) -> Option; fn token(&self, idx: usize) -> Vec; fn len(&self) -> usize; fn is_empty(&self) -> bool; @@ -79,10 +79,10 @@ impl Vocabulary { } /// Converts a token to the token ID it represents in this vocabulary. - pub fn token_to_id(&self, token: &[u8]) -> Option { + pub fn id(&self, token: &[u8]) -> Option { match self { - Vocabulary::Ggml(v) => v.token_to_id(token), - Vocabulary::Tokenizer(v) => v.token_to_id(token), + Vocabulary::Ggml(v) => v.id(token), + Vocabulary::Tokenizer(v) => v.id(token), } } @@ -166,7 +166,7 @@ impl VocabularyTrait for GgmlVocabulary { self.token_to_id.insert(content, id); } - fn token_to_id(&self, token: &[u8]) -> Option { + fn id(&self, token: &[u8]) -> Option { self.token_to_id.get(token).copied() } @@ -243,7 +243,7 @@ impl VocabularyTrait for GgmlVocabulary { } } -/// A vocabulary provided by the user. +/// A vocabulary that does not originate from the model file. #[derive(Debug, Clone)] pub struct TokenizerVocabulary { tokenizer: Tokenizer, @@ -261,7 +261,7 @@ impl VocabularyTrait for TokenizerVocabulary { panic!("Cannot push token to tokenizer vocabulary."); } - fn token_to_id(&self, token: &[u8]) -> Option { + fn id(&self, token: &[u8]) -> Option { self.tokenizer .token_to_id(std::str::from_utf8(token).unwrap()) } diff --git a/crates/llm/examples/inference.rs b/crates/llm/examples/inference.rs index 298c0174..9b349938 100644 --- a/crates/llm/examples/inference.rs +++ b/crates/llm/examples/inference.rs @@ -24,7 +24,7 @@ fn main() { let model = llm::load_dynamic( model_architecture, model_path, - VocabularySource::ModelEmbedded, + VocabularySource::ModelFile, Default::default(), overrides, load_callback, diff --git a/crates/llm/examples/vicuna-chat.rs b/crates/llm/examples/vicuna-chat.rs index a83baa99..53a1f914 100644 --- a/crates/llm/examples/vicuna-chat.rs +++ b/crates/llm/examples/vicuna-chat.rs @@ -24,7 +24,7 @@ fn main() { let model = llm::load_dynamic( model_architecture, model_path, - VocabularySource::ModelEmbedded, + VocabularySource::ModelFile, Default::default(), overrides, load_progress_callback(sp, now, prev_load_time), diff --git a/crates/models/bloom/src/lib.rs b/crates/models/bloom/src/lib.rs index e36b8f1c..749d43c7 100644 --- a/crates/models/bloom/src/lib.rs +++ b/crates/models/bloom/src/lib.rs @@ -383,11 +383,11 @@ impl KnownModel for Bloom { } fn bot_token_id(&self) -> Option { - self.vocabulary.token_to_id("".as_bytes()) + self.vocabulary.id("".as_bytes()) } fn eot_token_id(&self) -> TokenId { - self.vocabulary.token_to_id("".as_bytes()).unwrap() + self.vocabulary.id("".as_bytes()).unwrap() } fn inference_parameters(&self) -> &InferenceParameters { diff --git a/crates/models/gpt2/src/lib.rs b/crates/models/gpt2/src/lib.rs index 3c853cb0..505e391a 100644 --- a/crates/models/gpt2/src/lib.rs +++ b/crates/models/gpt2/src/lib.rs @@ -343,9 +343,7 @@ impl KnownModel for Gpt2 { } fn eot_token_id(&self) -> TokenId { - self.vocabulary - .token_to_id("<|endoftext|>".as_bytes()) - .unwrap() + self.vocabulary.id("<|endoftext|>".as_bytes()).unwrap() } fn inference_parameters(&self) -> &InferenceParameters { diff --git a/crates/models/gptj/src/lib.rs b/crates/models/gptj/src/lib.rs index 0782460d..77884be4 100644 --- a/crates/models/gptj/src/lib.rs +++ b/crates/models/gptj/src/lib.rs @@ -313,9 +313,7 @@ impl KnownModel for GptJ { } fn eot_token_id(&self) -> TokenId { - self.vocabulary - .token_to_id("<|endoftext|>".as_bytes()) - .unwrap() + self.vocabulary.id("<|endoftext|>".as_bytes()).unwrap() } fn inference_parameters(&self) -> &InferenceParameters { diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs index bcb81df8..199bd0fd 100644 --- a/crates/models/gptneox/src/lib.rs +++ b/crates/models/gptneox/src/lib.rs @@ -394,9 +394,7 @@ impl KnownModel for GptNeoX { } fn eot_token_id(&self) -> TokenId { - self.vocabulary - .token_to_id("<|endoftext|>".as_bytes()) - .unwrap() + self.vocabulary.id("<|endoftext|>".as_bytes()).unwrap() } fn inference_parameters(&self) -> &InferenceParameters { diff --git a/crates/models/mpt/src/lib.rs b/crates/models/mpt/src/lib.rs index c1c08038..73bc7216 100644 --- a/crates/models/mpt/src/lib.rs +++ b/crates/models/mpt/src/lib.rs @@ -285,13 +285,11 @@ impl KnownModel for Mpt { } fn bot_token_id(&self) -> Option { - self.vocabulary.token_to_id("<|padding|>".as_bytes()) + self.vocabulary.id("<|padding|>".as_bytes()) } fn eot_token_id(&self) -> TokenId { - self.vocabulary - .token_to_id("<|endoftext|>".as_bytes()) - .unwrap() + self.vocabulary.id("<|endoftext|>".as_bytes()).unwrap() } fn inference_parameters(&self) -> &InferenceParameters { From d4acc9b7b6635510544cf0815502f1ed63a3439b Mon Sep 17 00:00:00 2001 From: RedBoxing Date: Wed, 24 May 2023 17:36:56 +0200 Subject: [PATCH 06/20] rename vocabulary enum --- binaries/llm-cli/src/main.rs | 2 +- crates/llm-base/src/lib.rs | 4 +-- crates/llm-base/src/loader.rs | 10 +++---- crates/llm-base/src/quantize.rs | 6 ++-- crates/llm-base/src/vocabulary.rs | 46 +++++++++++++++--------------- crates/models/llama/src/convert.rs | 10 +++---- 6 files changed, 39 insertions(+), 39 deletions(-) diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index d9705cd1..c05000c0 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -151,7 +151,7 @@ fn info(args: &cli_args::Info) -> Result<()> { let file = File::open(&args.model_path)?; let mut reader = BufReader::new(&file); let mut loader: llm::Loader = - llm::Loader::new(Vocabulary::new_ggml(), |_| { + llm::Loader::new(Vocabulary::new_model(), |_| { // We purposely do not print progress here, as we are only interested in the metadata }); diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index 17510b87..9d623295 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -37,8 +37,8 @@ pub use model::{ pub use quantize::{quantize, QuantizeError, QuantizeProgress}; pub use util::TokenUtf8Buffer; pub use vocabulary::{ - GgmlVocabulary, InvalidTokenBias, Prompt, TokenBias, TokenId, TokenizationError, - TokenizerVocabulary, Vocabulary, VocabularySource, + ExternalVocabulary, InvalidTokenBias, ModelVocabulary, Prompt, TokenBias, TokenId, + TokenizationError, Vocabulary, VocabularySource, }; #[derive(Clone, Debug, PartialEq)] diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index d853b537..fe4be51c 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -389,7 +389,7 @@ pub fn load( }); } - Vocabulary::new_tokenizer(tokenizer.unwrap()) + Vocabulary::new_external(tokenizer.unwrap()) } VocabularySource::TokenizerFile(path) => { @@ -403,7 +403,7 @@ pub fn load( }); } - Vocabulary::new_tokenizer(tokenizer.unwrap()) + Vocabulary::new_external(tokenizer.unwrap()) } else { return Err(LoadError::VocabularyLoadError { path: path.to_string_lossy().to_string(), @@ -415,7 +415,7 @@ pub fn load( } } - VocabularySource::ModelFile => Vocabulary::new_ggml(), + VocabularySource::ModelFile => Vocabulary::new_model(), }; let mut loader = Loader::new(vocabulary, load_progress_callback); @@ -477,7 +477,7 @@ pub fn load( // TODO: Consider updating the progress callback to report the progress of the LoRA file. // Most LoRAs are small enough that this is not necessary, but it would be nice to have. let mut lora_loader: Loader = - Loader::new(Vocabulary::new_ggml(), |_| {}); + Loader::new(Vocabulary::new_model(), |_| {}); ggml::format::load(&mut lora_reader, &mut lora_loader) .map_err(|err| LoadError::from_format_error(err, lora_path.to_owned()))?; @@ -573,7 +573,7 @@ impl ggml::format::LoadHandler, score: f32) -> Result<(), LoadError> { - if let Vocabulary::Ggml(_) = &self.vocabulary { + if let Vocabulary::Model(_) = &self.vocabulary { let id = match TokenId::try_from(i) { Ok(id) => id, Err(err) => return Err(LoadError::InvalidIntegerConversion(err)), diff --git a/crates/llm-base/src/quantize.rs b/crates/llm-base/src/quantize.rs index b11873b3..d9ea914f 100644 --- a/crates/llm-base/src/quantize.rs +++ b/crates/llm-base/src/quantize.rs @@ -152,7 +152,7 @@ pub fn quantize( // Load the model let progress_callback = Arc::new(progress_callback); - let mut loader = Loader::::new(Vocabulary::new_ggml(), { + let mut loader = Loader::::new(Vocabulary::new_model(), { let progress_callback = progress_callback.clone(); move |p| { if let LoadProgress::HyperparametersLoaded = p { @@ -179,13 +179,13 @@ pub fn quantize( } let vocabulary = match vocabulary { - Vocabulary::Ggml(v) => v + Vocabulary::Model(v) => v .id_to_token .iter() .cloned() .zip(v.id_to_token_score) .collect::>(), - Vocabulary::Tokenizer(_) => vec![], + Vocabulary::External(_) => vec![], }; let mut saver = QuantizeSaver::new(desired_type, &hyperparameters, &tensors, reader, |p| { diff --git a/crates/llm-base/src/vocabulary.rs b/crates/llm-base/src/vocabulary.rs index cebebe93..117e6e1d 100644 --- a/crates/llm-base/src/vocabulary.rs +++ b/crates/llm-base/src/vocabulary.rs @@ -47,21 +47,21 @@ pub trait VocabularyTrait { /// Vocabulary enum pub enum Vocabulary { /// The vocabulary built-in to the model. - Ggml(GgmlVocabulary), + Model(ModelVocabulary), /// A custom vocabulary provided by the user. - Tokenizer(TokenizerVocabulary), + External(ExternalVocabulary), } impl Vocabulary { /// Create a new vocabulary with the default GGML vocabulary. - pub fn new_ggml() -> Self { - Vocabulary::Ggml(GgmlVocabulary::default()) + pub fn new_model() -> Self { + Vocabulary::Model(ModelVocabulary::default()) } /// Create a new vocabulary with a custom tokenizer. - pub fn new_tokenizer(tokenizer: Tokenizer) -> Self { - Vocabulary::Tokenizer(TokenizerVocabulary::new(tokenizer)) + pub fn new_external(tokenizer: Tokenizer) -> Self { + Vocabulary::External(ExternalVocabulary::new(tokenizer)) } /// Add a token to the vocabulary. @@ -73,40 +73,40 @@ impl Vocabulary { /// That is, if there are already `n` tokens in the vocabulary, then `id` must be `n`. pub fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) { match self { - Vocabulary::Ggml(v) => v.push_token(id, content, score), - Vocabulary::Tokenizer(v) => v.push_token(id, content, score), + Vocabulary::Model(v) => v.push_token(id, content, score), + Vocabulary::External(v) => v.push_token(id, content, score), } } /// Converts a token to the token ID it represents in this vocabulary. pub fn id(&self, token: &[u8]) -> Option { match self { - Vocabulary::Ggml(v) => v.id(token), - Vocabulary::Tokenizer(v) => v.id(token), + Vocabulary::Model(v) => v.id(token), + Vocabulary::External(v) => v.id(token), } } /// Converts a token index to the token it represents in this vocabulary. pub fn token(&self, idx: usize) -> Vec { match self { - Vocabulary::Ggml(v) => v.token(idx), - Vocabulary::Tokenizer(v) => v.token(idx), + Vocabulary::Model(v) => v.token(idx), + Vocabulary::External(v) => v.token(idx), } } /// Returns the number of tokens in the vocabulary. pub fn len(&self) -> usize { match self { - Vocabulary::Ggml(v) => v.len(), - Vocabulary::Tokenizer(v) => v.len(), + Vocabulary::Model(v) => v.len(), + Vocabulary::External(v) => v.len(), } } /// Returns whether the vocabulary is empty. pub fn is_empty(&self) -> bool { match self { - Vocabulary::Ggml(v) => v.is_empty(), - Vocabulary::Tokenizer(v) => v.is_empty(), + Vocabulary::Model(v) => v.is_empty(), + Vocabulary::External(v) => v.is_empty(), } } @@ -119,15 +119,15 @@ impl Vocabulary { bos: bool, ) -> Result, TokenId)>, TokenizationError> { match self { - Vocabulary::Ggml(v) => v.tokenize(text, bos), - Vocabulary::Tokenizer(v) => v.tokenize(text, bos), + Vocabulary::Model(v) => v.tokenize(text, bos), + Vocabulary::External(v) => v.tokenize(text, bos), } } } /// The built-in GGML vocabulary. #[derive(Debug, Clone, Default)] -pub struct GgmlVocabulary { +pub struct ModelVocabulary { // TODO: make these private /// Maps every integer (index) token ID to its corresponding token. pub id_to_token: Vec, @@ -143,7 +143,7 @@ pub struct GgmlVocabulary { pub max_token_length: usize, } -impl VocabularyTrait for GgmlVocabulary { +impl VocabularyTrait for ModelVocabulary { /// Add a token to the vocabulary. /// /// The token added must have `id` directly after the last token in the vocabulary. @@ -245,18 +245,18 @@ impl VocabularyTrait for GgmlVocabulary { /// A vocabulary that does not originate from the model file. #[derive(Debug, Clone)] -pub struct TokenizerVocabulary { +pub struct ExternalVocabulary { tokenizer: Tokenizer, } -impl TokenizerVocabulary { +impl ExternalVocabulary { /// Create a new `TokenizerVocabulary`. pub fn new(tokenizer: Tokenizer) -> Self { Self { tokenizer } } } -impl VocabularyTrait for TokenizerVocabulary { +impl VocabularyTrait for ExternalVocabulary { fn push_token(&mut self, _id: TokenId, _content: Token, _score: TokenScore) { panic!("Cannot push token to tokenizer vocabulary."); } diff --git a/crates/models/llama/src/convert.rs b/crates/models/llama/src/convert.rs index e4ee76f2..f556afb9 100644 --- a/crates/models/llama/src/convert.rs +++ b/crates/models/llama/src/convert.rs @@ -3,7 +3,7 @@ //! This is *incomplete* and does not convert the weights. It only converts the //! vocabulary and hyperparameters. It is included as a preliminary step to //! full conversion. -use llm_base::{FileType, GgmlVocabulary}; +use llm_base::{FileType, ModelVocabulary}; /// /// For reference, see [the PR](https://github.com/rustformers/llm/pull/83). use rust_tokenizers::preprocessing::vocab::sentencepiece_proto::sentencepiece_model::ModelProto; @@ -39,7 +39,7 @@ pub fn convert_pth_to_ggml(model_directory: &Path, file_type: FileType) { } } -fn load_vocabulary(path: &Path) -> GgmlVocabulary { +fn load_vocabulary(path: &Path) -> ModelVocabulary { let mut f = File::open(path).unwrap(); let mut contents = Vec::new(); f.read_to_end(&mut contents).unwrap(); @@ -59,7 +59,7 @@ fn load_vocabulary(path: &Path) -> GgmlVocabulary { id_to_token_score.push(piece.get_score()); } - GgmlVocabulary { + ModelVocabulary { id_to_token, id_to_token_score, token_to_id, @@ -70,7 +70,7 @@ fn load_vocabulary(path: &Path) -> GgmlVocabulary { fn load_hyperparameters( path: &Path, file_type: FileType, - vocab: &GgmlVocabulary, + vocab: &ModelVocabulary, ) -> Hyperparameters { #[derive(Deserialize)] struct HyperparametersJson { @@ -121,7 +121,7 @@ fn write_header(fout: &mut File, hparams: &Hyperparameters) -> Result<(), String Ok(()) } -fn write_tokens(file: &mut File, vocab: &GgmlVocabulary) -> Result<(), String> { +fn write_tokens(file: &mut File, vocab: &ModelVocabulary) -> Result<(), String> { let mut values: Vec = vec![]; for (i, token) in vocab.id_to_token.iter().enumerate() { let text = if let Ok(token) = std::str::from_utf8(token) { From 7d18d5c912ed13d94bee8ce88909d21b082671a6 Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 25 May 2023 00:30:57 +0200 Subject: [PATCH 07/20] chore: clippy warnings --- crates/llm-base/src/inference_session.rs | 4 +- crates/llm-base/src/loader.rs | 30 ++++----- crates/llm-base/src/vocabulary.rs | 81 ++++++++++++------------ crates/llm/examples/embeddings.rs | 1 + 4 files changed, 56 insertions(+), 60 deletions(-) diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index 64536b18..1022f322 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -112,9 +112,9 @@ impl InferenceSession { } /// Infer the next token for this session. - pub fn infer_next_token<'v>( + pub fn infer_next_token( &mut self, - model: &'v dyn Model, + model: &dyn Model, params: &InferenceParameters, output_request: &mut OutputRequest, rng: &mut impl rand::Rng, diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index fe4be51c..df42e239 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -380,30 +380,22 @@ pub fn load( let vocabulary = match vocabulary_source { VocabularySource::HuggingFaceRemote(identifier) => { - let tokenizer = Tokenizer::from_pretrained(&identifier, None); - - if tokenizer.is_err() { - return Err(LoadError::VocabularyLoadError { - path: identifier, - error: tokenizer.unwrap_err(), - }); - } - - Vocabulary::new_external(tokenizer.unwrap()) + Vocabulary::new_external(Tokenizer::from_pretrained(&identifier, None).map_err( + |error| LoadError::VocabularyLoadError { + path: path.to_string_lossy().to_string(), + error, + }, + )?) } VocabularySource::TokenizerFile(path) => { if path.exists() && path.is_file() { - let tokenizer = Tokenizer::from_file(&path); - - if tokenizer.is_err() { - return Err(LoadError::VocabularyLoadError { + Vocabulary::new_external(Tokenizer::from_file(&path).map_err(|error| { + LoadError::VocabularyLoadError { path: path.to_string_lossy().to_string(), - error: tokenizer.unwrap_err(), - }); - } - - Vocabulary::new_external(tokenizer.unwrap()) + error, + } + })?) } else { return Err(LoadError::VocabularyLoadError { path: path.to_string_lossy().to_string(), diff --git a/crates/llm-base/src/vocabulary.rs b/crates/llm-base/src/vocabulary.rs index b90ee498..a802d4a3 100644 --- a/crates/llm-base/src/vocabulary.rs +++ b/crates/llm-base/src/vocabulary.rs @@ -13,7 +13,11 @@ pub(crate) type TokenScore = f32; pub enum TokenizationError { #[error("an invalid token was encountered during tokenization")] /// During tokenization, one of the produced tokens was invalid / zero. - TokenizationFailed, + TokenizationFailed { + #[source] + /// The error that occurred during tokenization. + error: Box, + }, #[error("the token ID {0} was invalid for this model")] /// One of the tokens provided by the user was invalid, and did not belong to this model's vocabulary. InvalidTokenId(TokenId), @@ -37,11 +41,8 @@ pub trait VocabularyTrait { fn token(&self, idx: usize) -> Vec; fn len(&self) -> usize; fn is_empty(&self) -> bool; - fn tokenize<'a>( - &'a self, - text: &str, - bos: bool, - ) -> Result, TokenId)>, TokenizationError>; + fn tokenize(&self, text: &str, bos: bool) + -> Result, TokenId)>, TokenizationError>; } /// Vocabulary enum @@ -113,8 +114,8 @@ impl Vocabulary { /// Tokenize a `text` with this vocabulary. /// /// `bos` controls whether a beginning-of-string token should be inserted. - pub fn tokenize<'a>( - &'a self, + pub fn tokenize( + &self, text: &str, bos: bool, ) -> Result, TokenId)>, TokenizationError> { @@ -125,6 +126,14 @@ impl Vocabulary { } } +#[derive(Debug, Error)] +/// Errors that can occur when using a model vocabulary. +pub enum ModelVocabularyError { + /// Arbitrary error that occurred during use of the model vocabulary. + #[error("Arbitrary error: {0:?}")] + Arbitrary(String), +} + /// The built-in GGML vocabulary. #[derive(Debug, Clone, Default)] pub struct ModelVocabulary { @@ -172,7 +181,7 @@ impl VocabularyTrait for ModelVocabulary { /// Converts a token index to the token it represents in this vocabulary. fn token(&self, idx: usize) -> Vec { - (&self.id_to_token[idx]).clone() + self.id_to_token[idx].clone() } /// Returns the number of tokens in the vocabulary. @@ -189,8 +198,8 @@ impl VocabularyTrait for ModelVocabulary { /// Tokenize a `text` with this vocabulary. /// /// `bos` controls whether a beginning-of-string token should be inserted. - fn tokenize<'a>( - &'a self, + fn tokenize( + &self, text: &str, bos: bool, ) -> Result, TokenId)>, TokenizationError> { @@ -224,7 +233,12 @@ impl VocabularyTrait for ModelVocabulary { while i > 0 { let token_id = prev[i]; if token_id == 0 { - return Err(TokenizationError::TokenizationFailed); + return Err(TokenizationError::TokenizationFailed { + error: Box::new(ModelVocabularyError::Arbitrary( + "the backward pass for the tokenizer encountered a non-set token" + .to_string(), + )), + }); } let token = self.id_to_token[token_id as usize].as_slice(); res.push((token.to_vec(), token_id)); @@ -268,18 +282,16 @@ impl VocabularyTrait for ExternalVocabulary { /// Converts a token index to the token it represents in this vocabulary. fn token(&self, idx: usize) -> Vec { - let res = self.tokenizer.decode(vec![idx as u32], true); - - if res.is_err() { - panic!("Cannot decode token from tokenizer vocabulary."); - } else { - res.unwrap().as_bytes().to_vec() - } + self.tokenizer + .decode(vec![idx as u32], true) + .expect("Cannot decode token from tokenizer vocabulary.") + .as_bytes() + .to_vec() } /// Returns the number of tokens in the vocabulary. fn len(&self) -> usize { - self.tokenizer.get_vocab_size(false) as usize + self.tokenizer.get_vocab_size(false) } /// Returns whether the vocabulary is empty. @@ -291,28 +303,19 @@ impl VocabularyTrait for ExternalVocabulary { /// Tokenize a `text` with this vocabulary. /// /// `bos` controls whether a beginning-of-string token should be inserted. - fn tokenize<'a>( - &'a self, + fn tokenize( + &self, text: &str, bos: bool, ) -> Result, TokenId)>, TokenizationError> { - let res = self.tokenizer.encode(text, bos); - if res.is_err() { - return Err(TokenizationError::TokenizationFailed); - } else { - let res = self.tokenizer.encode(text, bos); - - if res.is_err() { - return Err(TokenizationError::TokenizationFailed); - } - - Ok(res - .unwrap() - .get_ids() - .iter() - .map(|id| (self.token(*id as usize), *id)) - .collect::, TokenId)>>()) - } + Ok(self + .tokenizer + .encode(text, bos) + .map_err(|e| TokenizationError::TokenizationFailed { error: e })? + .get_ids() + .iter() + .map(|id| (self.token(*id as usize), *id)) + .collect::, TokenId)>>()) } } diff --git a/crates/llm/examples/embeddings.rs b/crates/llm/examples/embeddings.rs index f0c18d0c..97baa69f 100644 --- a/crates/llm/examples/embeddings.rs +++ b/crates/llm/examples/embeddings.rs @@ -35,6 +35,7 @@ fn main() { let model = llm::load_dynamic( model_architecture, model_path, + llm::VocabularySource::ModelFile, model_params, overrides, llm::load_progress_callback_stdout, From 7ddbdb26324e105a585e6a9c0f7e5309c499859e Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 25 May 2023 01:11:53 +0200 Subject: [PATCH 08/20] refactor: clean up vocabulary stuff --- binaries/llm-cli/src/cli_args.rs | 73 ++++++++++++----- binaries/llm-cli/src/main.rs | 22 ++++-- crates/llm-base/src/lib.rs | 5 +- crates/llm-base/src/loader.rs | 64 +++++---------- crates/llm-base/src/quantize.rs | 3 +- crates/llm-base/src/vocabulary.rs | 125 +++++++++++++++++++----------- 6 files changed, 175 insertions(+), 117 deletions(-) diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index d2630f6b..848f82da 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -1,7 +1,7 @@ use std::{fmt, ops::Deref, path::PathBuf}; use clap::{Parser, Subcommand, ValueEnum}; -use color_eyre::eyre::{Result, WrapErr}; +use color_eyre::eyre::{eyre, Result, WrapErr}; use llm::{ ggml_format, ElementType, InferenceParameters, InferenceSessionConfig, InvalidTokenBias, LoadProgress, Model, ModelKVMemoryType, ModelParameters, TokenBias, VocabularySource, @@ -141,9 +141,8 @@ pub struct Perplexity { #[derive(Parser, Debug)] pub struct Info { - /// The model to inspect. - #[arg(long, short = 'm')] - pub model_path: PathBuf, + #[command(flatten)] + pub model_and_vocabulary: ModelAndVocabulary, /// Show all of the tensors in the model, including their names, formats and shapes. #[arg(long, short = 't')] @@ -331,18 +330,61 @@ fn parse_bias(s: &str) -> Result { } #[derive(Parser, Debug)] -pub struct ModelLoad { +pub struct ModelVocabulary { + /// Local path to vocabulary + #[arg(long, short = 'v')] + pub vocabulary_path: Option, + + /// Remote HuggingFace repository containing vocabulary + #[arg(long, short = 'r')] + pub vocabulary_repository: Option, +} +impl ModelVocabulary { + pub fn to_source(&self, sp: &mut Option) -> Result { + Ok(match (&self.vocabulary_path, &self.vocabulary_repository) { + (Some(_), Some(_)) => { + if let Some(sp) = sp.take() { + sp.fail("Invalid arguments"); + }; + + return Err(eyre!( + "Cannot specify both --vocabulary-path and --vocabulary-repo" + )); + } + (Some(path), None) => VocabularySource::HuggingFaceTokenizerFile(path.to_owned()), + (None, Some(repo)) => VocabularySource::HuggingFaceRemote(repo.to_owned()), + (None, None) => VocabularySource::ModelFile, + }) + } +} + +#[derive(Parser, Debug)] +pub struct ModelAndVocabulary { /// Where to load the model from #[arg(long, short = 'm')] pub model_path: PathBuf, + #[command(flatten)] + pub vocabulary: ModelVocabulary, + /// Local path to vocabulary #[arg(long, short = 'v')] pub vocabulary_path: Option, /// Remote HuggingFace repository containing vocabulary #[arg(long, short = 'r')] - pub vocabulary_repo: Option, + pub vocabulary_repository: Option, +} +impl ModelAndVocabulary { + pub fn to_source(&self, sp: &mut Option) -> Result { + self.vocabulary.to_source(sp) + } +} + +#[derive(Parser, Debug)] +pub struct ModelLoad { + #[command(flatten)] + pub model_and_vocabulary: ModelAndVocabulary, /// Sets the size of the context (in tokens). Allows feeding longer prompts. /// Note that this affects memory. @@ -385,22 +427,10 @@ impl ModelLoad { let now = std::time::Instant::now(); let mut prev_load_time = now; - let mut vocabulary_source = VocabularySource::ModelFile; - - if self.vocabulary_path.is_some() && self.vocabulary_repo.is_some() { - if let Some(sp) = sp.take() { - sp.fail("Invalid arguments"); - }; - - panic!("Cannot specify both --vocabulary-path and --vocabulary-repo"); - } else if let Some(path) = &self.vocabulary_path { - vocabulary_source = VocabularySource::TokenizerFile(path.clone()); - } else if let Some(repo) = &self.vocabulary_repo { - vocabulary_source = VocabularySource::HuggingFaceRemote(repo.clone()); - } + let vocabulary_source = self.model_and_vocabulary.to_source(&mut sp)?; let model = llm::load::( - &self.model_path, + &self.model_and_vocabulary.model_path, vocabulary_source, params, overrides, @@ -547,6 +577,9 @@ pub struct Quantize { #[arg()] pub destination: PathBuf, + #[command(flatten)] + pub vocabulary: ModelVocabulary, + /// The GGML container type to target. /// /// Note that using GGML requires the original model to have diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index 5968fa55..2110212d 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -7,7 +7,7 @@ use std::{ use clap::Parser; use cli_args::{Args, BaseArgs}; use color_eyre::eyre::{Context, Result}; -use llm::{InferenceError, InferenceFeedback, InferenceResponse, Vocabulary}; +use llm::{InferenceError, InferenceFeedback, InferenceResponse}; use rustyline::error::ReadlineError; use rustyline::validate::{ValidationContext, ValidationResult, Validator}; use rustyline::{history::DefaultHistory, Cmd, Event, EventHandler, KeyCode, KeyEvent, Modifiers}; @@ -148,12 +148,17 @@ fn perplexity( } fn info(args: &cli_args::Info) -> Result<()> { - let file = File::open(&args.model_path)?; + let model_path = &args.model_and_vocabulary.model_path; + let vocabulary = args + .model_and_vocabulary + .to_source(&mut None)? + .retrieve(model_path)?; + + let file = File::open(model_path)?; let mut reader = BufReader::new(&file); - let mut loader: llm::Loader = - llm::Loader::new(Vocabulary::new_model(), |_| { - // We purposely do not print progress here, as we are only interested in the metadata - }); + let mut loader: llm::Loader = llm::Loader::new(vocabulary, |_| { + // We purposely do not print progress here, as we are only interested in the metadata + }); llm::ggml_format::load(&mut reader, &mut loader)?; @@ -316,10 +321,15 @@ fn quantize(args: &cli_args::Quantize) -> Result<( let mut source = BufReader::new(std::fs::File::open(&args.source)?); let mut destination = BufWriter::new(std::fs::File::create(&args.destination)?); + let vocabulary = args + .vocabulary + .to_source(&mut None)? + .retrieve(&args.source)?; llm::quantize::( &mut source, &mut destination, + vocabulary, args.container_type.into(), args.target.into(), |progress| match progress { diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index 8156bbd2..923b8af8 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -36,9 +36,10 @@ pub use model::{ }; pub use quantize::{quantize, QuantizeError, QuantizeProgress}; pub use util::TokenUtf8Buffer; +pub(crate) use vocabulary::ModelVocabulary; pub use vocabulary::{ - ExternalVocabulary, InvalidTokenBias, ModelVocabulary, Prompt, TokenBias, TokenId, - TokenizationError, Vocabulary, VocabularySource, + InvalidTokenBias, Prompt, TokenBias, TokenId, TokenizationError, Vocabulary, + VocabularyLoadError, VocabularySource, }; #[derive(Clone, Debug, PartialEq)] diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index df42e239..b05ffbc8 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -8,8 +8,8 @@ use std::{ }; use crate::{ - util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelParameters, TokenId, - Vocabulary, VocabularySource, + util, Hyperparameters, KnownModel, LoraAdapter, LoraParameters, ModelParameters, + ModelVocabulary, TokenId, Vocabulary, VocabularyLoadError, VocabularySource, }; pub use ggml::ContainerType; use ggml::{ @@ -19,8 +19,6 @@ use ggml::{ use memmap2::Mmap; use thiserror::Error; -use tokenizers::Tokenizer; - #[derive(Debug, PartialEq, Clone, Copy, Eq, Default)] /// Information about the file. pub struct FileType { @@ -283,12 +281,11 @@ pub enum LoadError { /// The paths that were found. paths: Vec, }, - /// The vocab file for the tokenizer could not be loaded. - #[error("could not load vocab file {path:?}")] + #[error("could not load vocabulary file {path:?}")] VocabularyLoadError { /// The invalid vocabulary path - path: String, + path: PathBuf, /// The error that occurred. error: Box, @@ -302,6 +299,14 @@ impl From for LoadError { } } } +impl From for LoadError { + fn from(value: VocabularyLoadError) -> Self { + LoadError::VocabularyLoadError { + path: value.path, + error: value.error, + } + } +} impl LoadError { #[doc(hidden)] @@ -378,38 +383,7 @@ pub fn load( })?; let mut reader = BufReader::new(&file); - let vocabulary = match vocabulary_source { - VocabularySource::HuggingFaceRemote(identifier) => { - Vocabulary::new_external(Tokenizer::from_pretrained(&identifier, None).map_err( - |error| LoadError::VocabularyLoadError { - path: path.to_string_lossy().to_string(), - error, - }, - )?) - } - - VocabularySource::TokenizerFile(path) => { - if path.exists() && path.is_file() { - Vocabulary::new_external(Tokenizer::from_file(&path).map_err(|error| { - LoadError::VocabularyLoadError { - path: path.to_string_lossy().to_string(), - error, - } - })?) - } else { - return Err(LoadError::VocabularyLoadError { - path: path.to_string_lossy().to_string(), - error: Box::new(std::io::Error::new( - std::io::ErrorKind::NotFound, - "Vocabulary file not found", - )), - }); - } - } - - VocabularySource::ModelFile => Vocabulary::new_model(), - }; - + let vocabulary = vocabulary_source.retrieve(path)?; let mut loader = Loader::new(vocabulary, load_progress_callback); ggml::format::load(&mut reader, &mut loader) @@ -469,7 +443,7 @@ pub fn load( // TODO: Consider updating the progress callback to report the progress of the LoRA file. // Most LoRAs are small enough that this is not necessary, but it would be nice to have. let mut lora_loader: Loader = - Loader::new(Vocabulary::new_model(), |_| {}); + Loader::new(ModelVocabulary::default().into(), |_| {}); ggml::format::load(&mut lora_reader, &mut lora_loader) .map_err(|err| LoadError::from_format_error(err, lora_path.to_owned()))?; @@ -533,13 +507,15 @@ pub struct Loader { // Input load_progress_callback: F, + // Input/Output + /// The vocabulary of the model. + pub vocabulary: Vocabulary, + // Output /// The container type of the model. pub container_type: ContainerType, /// The hyperparameters of the model. pub hyperparameters: Hp, - /// The vocabulary of the model. - pub vocabulary: Vocabulary, /// The tensors of the model. pub tensors: HashMap, } @@ -565,13 +541,13 @@ impl ggml::format::LoadHandler, score: f32) -> Result<(), LoadError> { - if let Vocabulary::Model(_) = &self.vocabulary { + if let Vocabulary::Model(mv) = &mut self.vocabulary { let id = match TokenId::try_from(i) { Ok(id) => id, Err(err) => return Err(LoadError::InvalidIntegerConversion(err)), }; - self.vocabulary.push_token(id, token, score); + mv.push_token(id, token, score); } Ok(()) diff --git a/crates/llm-base/src/quantize.rs b/crates/llm-base/src/quantize.rs index d9ea914f..affaedcd 100644 --- a/crates/llm-base/src/quantize.rs +++ b/crates/llm-base/src/quantize.rs @@ -138,6 +138,7 @@ impl QuantizeError { pub fn quantize( reader: &mut R, writer: &mut W, + vocabulary: Vocabulary, save_container_type: ggml::format::SaveContainerType, desired_type: ggml::Type, progress_callback: impl Fn(QuantizeProgress), @@ -152,7 +153,7 @@ pub fn quantize( // Load the model let progress_callback = Arc::new(progress_callback); - let mut loader = Loader::::new(Vocabulary::new_model(), { + let mut loader = Loader::::new(vocabulary, { let progress_callback = progress_callback.clone(); move |p| { if let LoadProgress::HyperparametersLoaded = p { diff --git a/crates/llm-base/src/vocabulary.rs b/crates/llm-base/src/vocabulary.rs index a802d4a3..5bdb4703 100644 --- a/crates/llm-base/src/vocabulary.rs +++ b/crates/llm-base/src/vocabulary.rs @@ -1,4 +1,10 @@ -use std::{collections::HashMap, error::Error, fmt::Display, path::PathBuf, str::FromStr}; +use std::{ + collections::HashMap, + error::Error, + fmt::Display, + path::{Path, PathBuf}, + str::FromStr, +}; use thiserror::Error; use tokenizers::Tokenizer; @@ -23,26 +29,76 @@ pub enum TokenizationError { InvalidTokenId(TokenId), } +#[derive(Error, Debug)] +/// Errors related to loading the vocabulary. +#[error("error loading vocabulary from {path}: {error}")] +pub struct VocabularyLoadError { + /// The path to the vocabulary. + pub path: PathBuf, + /// The error that occurred during loading. + pub error: Box, +} + +impl VocabularyLoadError { + fn new(path: impl Into, error: impl Into>) -> Self { + Self { + path: path.into(), + error: error.into(), + } + } +} + +#[derive(Clone, Debug, PartialEq)] /// The source of a vocabulary. pub enum VocabularySource { - /// Fetch vocabulary from model file + /// Read the vocabulary from the model if available, and use a simplistic tokenizer. + /// + /// This is easy to use, but may not be the best choice for your use case. ModelFile, - /// Fetch vocabulary from a vocabulary file - TokenizerFile(PathBuf), + /// Read the vocabulary from a local HuggingFace-format tokenizer file, and use the + /// HuggingFace tokenizer. + HuggingFaceTokenizerFile(PathBuf), - /// Fetch vocabulary from remote HuggingFace repository + /// Fetch the vocabulary from a remote HuggingFace repository. This will make a blocking + /// HTTP request to HuggingFace to retrieve the vocabulary and may store files locally, + /// so it is not recommended for production use. This will use the HuggingFace tokenizer. HuggingFaceRemote(String), } +impl VocabularySource { + /// Retrieve the vocabulary from the source. + /// + /// Note that this may make a blocking HTTP request to HuggingFace to retrieve the vocabulary + /// if `self` is [`Self::HuggingFaceRemote`]. + pub fn retrieve(self, model_path: &Path) -> Result { + Ok(match self { + Self::HuggingFaceRemote(identifier) => ExternalVocabulary::new( + Tokenizer::from_pretrained(&identifier, None) + .map_err(|error| VocabularyLoadError::new(model_path, error))?, + ) + .into(), + + Self::HuggingFaceTokenizerFile(path) => { + if !path.is_file() { + return Err(VocabularyLoadError::new( + path, + std::io::Error::new( + std::io::ErrorKind::NotFound, + "Vocabulary file not found", + ), + )); + } -pub trait VocabularyTrait { - fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore); - fn id(&self, token: &[u8]) -> Option; - fn token(&self, idx: usize) -> Vec; - fn len(&self) -> usize; - fn is_empty(&self) -> bool; - fn tokenize(&self, text: &str, bos: bool) - -> Result, TokenId)>, TokenizationError>; + ExternalVocabulary::new( + Tokenizer::from_file(&path) + .map_err(|error| VocabularyLoadError::new(path, error))?, + ) + .into() + } + + Self::ModelFile => ModelVocabulary::default().into(), + }) + } } /// Vocabulary enum @@ -53,32 +109,17 @@ pub enum Vocabulary { /// A custom vocabulary provided by the user. External(ExternalVocabulary), } - -impl Vocabulary { - /// Create a new vocabulary with the default GGML vocabulary. - pub fn new_model() -> Self { - Vocabulary::Model(ModelVocabulary::default()) - } - - /// Create a new vocabulary with a custom tokenizer. - pub fn new_external(tokenizer: Tokenizer) -> Self { - Vocabulary::External(ExternalVocabulary::new(tokenizer)) +impl From for Vocabulary { + fn from(v: ModelVocabulary) -> Self { + Self::Model(v) } - - /// Add a token to the vocabulary. - /// - /// The token added must have `id` directly after the last token in the vocabulary. - /// - /// # Panics - /// - This function can panic if `id` does not correspond to the next token in the vocabulary. - /// That is, if there are already `n` tokens in the vocabulary, then `id` must be `n`. - pub fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) { - match self { - Vocabulary::Model(v) => v.push_token(id, content, score), - Vocabulary::External(v) => v.push_token(id, content, score), - } +} +impl From for Vocabulary { + fn from(v: ExternalVocabulary) -> Self { + Self::External(v) } - +} +impl Vocabulary { /// Converts a token to the token ID it represents in this vocabulary. pub fn id(&self, token: &[u8]) -> Option { match self { @@ -152,7 +193,7 @@ pub struct ModelVocabulary { pub max_token_length: usize, } -impl VocabularyTrait for ModelVocabulary { +impl ModelVocabulary { /// Add a token to the vocabulary. /// /// The token added must have `id` directly after the last token in the vocabulary. @@ -160,7 +201,7 @@ impl VocabularyTrait for ModelVocabulary { /// # Panics /// - This function can panic if `id` does not correspond to the next token in the vocabulary. /// That is, if there are already `n` tokens in the vocabulary, then `id` must be `n`. - fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) { + pub(crate) fn push_token(&mut self, id: TokenId, content: Token, score: TokenScore) { // These are loader invariants. If this is broken, then the loader is broken and this is a bug, // not an issue with the model itself. assert_eq!(self.id_to_token.len(), self.id_to_token_score.len()); @@ -270,11 +311,7 @@ impl ExternalVocabulary { } } -impl VocabularyTrait for ExternalVocabulary { - fn push_token(&mut self, _id: TokenId, _content: Token, _score: TokenScore) { - panic!("Cannot push token to tokenizer vocabulary."); - } - +impl ExternalVocabulary { fn id(&self, token: &[u8]) -> Option { self.tokenizer .token_to_id(std::str::from_utf8(token).unwrap()) From b1945d78dc2c328ea8a90d77bb4d92f92b8c22b0 Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 25 May 2023 01:38:42 +0200 Subject: [PATCH 09/20] chore: minor fixes --- binaries/llm-cli/src/cli_args.rs | 2 +- crates/llm-base/src/vocabulary.rs | 30 +++++++++++++++++++----------- crates/llm/examples/embeddings.rs | 2 +- crates/llm/examples/inference.rs | 2 +- crates/llm/examples/vicuna-chat.rs | 2 +- crates/llm/src/lib.rs | 2 ++ 6 files changed, 25 insertions(+), 15 deletions(-) diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 848f82da..40756aa1 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -353,7 +353,7 @@ impl ModelVocabulary { } (Some(path), None) => VocabularySource::HuggingFaceTokenizerFile(path.to_owned()), (None, Some(repo)) => VocabularySource::HuggingFaceRemote(repo.to_owned()), - (None, None) => VocabularySource::ModelFile, + (None, None) => VocabularySource::Model, }) } } diff --git a/crates/llm-base/src/vocabulary.rs b/crates/llm-base/src/vocabulary.rs index 5bdb4703..f5d09a6e 100644 --- a/crates/llm-base/src/vocabulary.rs +++ b/crates/llm-base/src/vocabulary.rs @@ -53,8 +53,9 @@ impl VocabularyLoadError { pub enum VocabularySource { /// Read the vocabulary from the model if available, and use a simplistic tokenizer. /// - /// This is easy to use, but may not be the best choice for your use case. - ModelFile, + /// This is easy to use, but may not be the best choice for your use case, and is not + /// guaranteed to be available for all models. + Model, /// Read the vocabulary from a local HuggingFace-format tokenizer file, and use the /// HuggingFace tokenizer. @@ -96,7 +97,7 @@ impl VocabularySource { .into() } - Self::ModelFile => ModelVocabulary::default().into(), + Self::Model => ModelVocabulary::default().into(), }) } } @@ -305,7 +306,7 @@ pub struct ExternalVocabulary { } impl ExternalVocabulary { - /// Create a new `TokenizerVocabulary`. + /// Create a new `ExternalVocabulary`. pub fn new(tokenizer: Tokenizer) -> Self { Self { tokenizer } } @@ -336,7 +337,6 @@ impl ExternalVocabulary { self.tokenizer.get_vocab_size(false) == 0 } - // SentencePiece implementation after https://guillaume-be.github.io/2020-05-30/sentence_piece /// Tokenize a `text` with this vocabulary. /// /// `bos` controls whether a beginning-of-string token should be inserted. @@ -345,14 +345,22 @@ impl ExternalVocabulary { text: &str, bos: bool, ) -> Result, TokenId)>, TokenizationError> { - Ok(self + let encoding = self + .tokenizer + .encode(text, false) + .map_err(|e| TokenizationError::TokenizationFailed { error: e })?; + + let encoding = self .tokenizer - .encode(text, bos) - .map_err(|e| TokenizationError::TokenizationFailed { error: e })? - .get_ids() + .post_process(encoding, None, bos) + .map_err(|e| TokenizationError::TokenizationFailed { error: e })?; + + Ok(encoding + .get_tokens() .iter() - .map(|id| (self.token(*id as usize), *id)) - .collect::, TokenId)>>()) + .map(|t| t.as_bytes().to_vec()) + .zip(encoding.get_ids().iter().copied()) + .collect()) } } diff --git a/crates/llm/examples/embeddings.rs b/crates/llm/examples/embeddings.rs index 97baa69f..2663b3e8 100644 --- a/crates/llm/examples/embeddings.rs +++ b/crates/llm/examples/embeddings.rs @@ -35,7 +35,7 @@ fn main() { let model = llm::load_dynamic( model_architecture, model_path, - llm::VocabularySource::ModelFile, + llm::VocabularySource::Model, model_params, overrides, llm::load_progress_callback_stdout, diff --git a/crates/llm/examples/inference.rs b/crates/llm/examples/inference.rs index 6894f1bc..df0ce7ae 100644 --- a/crates/llm/examples/inference.rs +++ b/crates/llm/examples/inference.rs @@ -20,7 +20,7 @@ fn main() { let model = llm::load_dynamic( model_architecture, model_path, - llm::VocabularySource::ModelFile, + llm::VocabularySource::Model, Default::default(), overrides, llm::load_progress_callback_stdout, diff --git a/crates/llm/examples/vicuna-chat.rs b/crates/llm/examples/vicuna-chat.rs index d1b50738..07aa439b 100644 --- a/crates/llm/examples/vicuna-chat.rs +++ b/crates/llm/examples/vicuna-chat.rs @@ -15,7 +15,7 @@ fn main() { let model = llm::load_dynamic( model_architecture, model_path, - llm::VocabularySource::ModelFile, + llm::VocabularySource::Model, Default::default(), overrides, llm::load_progress_callback_stdout, diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index 596f6e6a..a83967f9 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -21,6 +21,8 @@ //! let llama = llm::load::( //! // path to GGML file //! std::path::Path::new("/path/to/model"), +//! // llm::VocabularySource +//! llm::VocabularySource::Model, //! // llm::ModelParameters //! Default::default(), //! // llm::KnownModel::Overrides From d939fcd996d8b20cc2ec4e7bb117f6aaadb8bce4 Mon Sep 17 00:00:00 2001 From: RedBoxing Date: Sun, 28 May 2023 11:23:26 +0200 Subject: [PATCH 10/20] Fix issues with llama tokenizer --- crates/llm-base/src/inference_session.rs | 36 ++++++++++++++++++++++-- crates/llm-base/src/vocabulary.rs | 31 ++++++++++++++++++++ 2 files changed, 65 insertions(+), 2 deletions(-) diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index 1022f322..0e326be2 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -55,6 +55,9 @@ pub struct InferenceSession { /// All tokens generated by this inference session pub(crate) tokens: Vec, + // All decoded tokens generated by this inference session + pub(crate) decoded_tokens: Vec, + /// The logits that were last predicted by the network. Zeroed out otherwise. #[doc(hidden)] pub last_logits: Vec, @@ -91,10 +94,23 @@ impl InferenceSession { for &tk in batch { let should_call_callback = Some(tk) != model.bot_token_id(); + let mut token = match model.vocabulary() { + crate::Vocabulary::Model(_) => model.vocabulary().token(tk as usize).to_vec(), + crate::Vocabulary::External(_) => { + let mut previous_tokens = self.tokens.clone(); + previous_tokens.push(tk); + + let all_tokens = model.vocabulary().decode(previous_tokens, true); + let splitted = all_tokens.split_at(self.decoded_tokens.len()); + + splitted.1.to_vec() + } + }; + if should_call_callback { // NOTE: No string ever tokenizes to the end of sentence. So we // can just return the id here. - match callback(&vocab.token(tk as usize)) { + match callback(&token) { Err(e) => return Err(InferenceError::UserCallback(Some(Box::new(e)))), Ok(f) => match f { InferenceFeedback::Continue => (), @@ -105,6 +121,7 @@ impl InferenceSession { // Update the tokens for this session self.tokens.push(tk); + self.decoded_tokens.append(&mut token); } } @@ -136,7 +153,20 @@ impl InferenceSession { if next_token as TokenId == model.eot_token_id() { Err(InferenceError::EndOfText) } else { - Ok(model.vocabulary().token(next_token as usize)) + let res = match model.vocabulary() { + crate::Vocabulary::Model(_) => { + model.vocabulary().token(next_token as usize).to_vec() + } + crate::Vocabulary::External(_) => { + let all_tokens = model.vocabulary().decode(self.tokens.clone(), true); + let splitted = all_tokens.split_at(self.decoded_tokens.len()); + + splitted.1.to_vec() + } + }; + + self.decoded_tokens.append(&mut res.clone()); + Ok(res) } } @@ -493,6 +523,7 @@ impl InferenceSession { n_past: 0, mem_per_token: 0, tokens: vec![], + decoded_tokens: vec![], last_logits: vec![0.0; n_vocab], scratch: scratch_buffers(), } @@ -513,6 +544,7 @@ impl Clone for InferenceSession { n_past: self.n_past, mem_per_token: self.mem_per_token, tokens: self.tokens.clone(), + decoded_tokens: self.decoded_tokens.clone(), last_logits: self.last_logits.clone(), scratch: scratch_buffers(), } diff --git a/crates/llm-base/src/vocabulary.rs b/crates/llm-base/src/vocabulary.rs index f5d09a6e..ade4cd42 100644 --- a/crates/llm-base/src/vocabulary.rs +++ b/crates/llm-base/src/vocabulary.rs @@ -166,6 +166,14 @@ impl Vocabulary { Vocabulary::External(v) => v.tokenize(text, bos), } } + + /// decode a list `tokens` with this vocabulary. + pub fn decode(&self, tokens: Vec, bos: bool) -> Vec { + match self { + Vocabulary::Model(v) => v.decode(tokens, bos), + Vocabulary::External(v) => v.decode(tokens, bos), + } + } } #[derive(Debug, Error)] @@ -297,6 +305,20 @@ impl ModelVocabulary { Ok(res) } + + /// decode a list `tokens` with this vocabulary. + fn decode(&self, tokens: Vec, skip_special_tokens: bool) -> Vec { + for token in tokens { + if skip_special_tokens && token == 1 { + continue; + } + let token = self.id_to_token[token as usize].as_slice(); + + return token.to_vec(); + } + + vec![] + } } /// A vocabulary that does not originate from the model file. @@ -362,6 +384,15 @@ impl ExternalVocabulary { .zip(encoding.get_ids().iter().copied()) .collect()) } + + /// decode a list `tokens` with this vocabulary. + fn decode(&self, tokens: Vec, skip_special_tokens: bool) -> Vec { + self.tokenizer + .decode(tokens, skip_special_tokens) + .expect("Cannot decode token from tokenizer vocabulary.") + .as_bytes() + .to_vec() + } } #[derive(Debug, PartialEq, Clone, Copy)] From bae1682f1ea9a66df66b24e4796611e6872b1b3c Mon Sep 17 00:00:00 2001 From: RedBoxing Date: Sun, 28 May 2023 11:27:53 +0200 Subject: [PATCH 11/20] oups --- crates/llm-base/src/vocabulary.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/crates/llm-base/src/vocabulary.rs b/crates/llm-base/src/vocabulary.rs index ade4cd42..c0f27ac7 100644 --- a/crates/llm-base/src/vocabulary.rs +++ b/crates/llm-base/src/vocabulary.rs @@ -309,9 +309,6 @@ impl ModelVocabulary { /// decode a list `tokens` with this vocabulary. fn decode(&self, tokens: Vec, skip_special_tokens: bool) -> Vec { for token in tokens { - if skip_special_tokens && token == 1 { - continue; - } let token = self.id_to_token[token as usize].as_slice(); return token.to_vec(); From 7b06bb93d15273272af1fad348607cf32a552f0e Mon Sep 17 00:00:00 2001 From: RedBoxing Date: Sun, 28 May 2023 11:32:52 +0200 Subject: [PATCH 12/20] Fix decode function on internal vocab --- crates/llm-base/src/vocabulary.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/llm-base/src/vocabulary.rs b/crates/llm-base/src/vocabulary.rs index c0f27ac7..0014b056 100644 --- a/crates/llm-base/src/vocabulary.rs +++ b/crates/llm-base/src/vocabulary.rs @@ -308,13 +308,13 @@ impl ModelVocabulary { /// decode a list `tokens` with this vocabulary. fn decode(&self, tokens: Vec, skip_special_tokens: bool) -> Vec { - for token in tokens { - let token = self.id_to_token[token as usize].as_slice(); + let mut vec = vec![]; - return token.to_vec(); + for token in tokens { + vec.append(&mut self.id_to_token[token as usize].to_vec()); } - vec![] + vec } } From 45e407a21faa16bc1030341f7d710ca7b1bdb118 Mon Sep 17 00:00:00 2001 From: RedBoxing Date: Sun, 28 May 2023 11:40:11 +0200 Subject: [PATCH 13/20] Revert `oups` (I thought copilot hallucinated smt) --- crates/llm-base/src/vocabulary.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/crates/llm-base/src/vocabulary.rs b/crates/llm-base/src/vocabulary.rs index 0014b056..43540c83 100644 --- a/crates/llm-base/src/vocabulary.rs +++ b/crates/llm-base/src/vocabulary.rs @@ -311,6 +311,10 @@ impl ModelVocabulary { let mut vec = vec![]; for token in tokens { + if skip_special_tokens && token == 1 { + continue; + } + vec.append(&mut self.id_to_token[token as usize].to_vec()); } From 7407264364c2ea701a973f8bbab194e8333d4318 Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 29 May 2023 21:48:58 +0200 Subject: [PATCH 14/20] minor fix --- binaries/llm-cli/src/cli_args.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 6a5e2c39..9ce12736 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -348,7 +348,7 @@ impl ModelVocabulary { }; return Err(eyre!( - "Cannot specify both --vocabulary-path and --vocabulary-repo" + "Cannot specify both --vocabulary-path and --vocabulary-repository" )); } (Some(path), None) => VocabularySource::HuggingFaceTokenizerFile(path.to_owned()), From faecd36585239c9d571e66e99d435b557bc3b4d5 Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 29 May 2023 22:49:14 +0200 Subject: [PATCH 15/20] fix(cli): remove sp pass-down to ModelVocabulary --- binaries/llm-cli/src/cli_args.rs | 26 ++++++++++++++------------ binaries/llm-cli/src/main.rs | 7 ++----- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 9ce12736..3d3d549a 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -1,7 +1,7 @@ use std::{fmt, ops::Deref, path::PathBuf}; use clap::{Parser, Subcommand, ValueEnum}; -use color_eyre::eyre::{eyre, Result, WrapErr}; +use color_eyre::eyre::{bail, Result, WrapErr}; use llm::{ ggml_format, ElementType, InferenceParameters, InferenceSessionConfig, InvalidTokenBias, LoadProgress, Model, ModelKVMemoryType, ModelParameters, TokenBias, VocabularySource, @@ -340,16 +340,10 @@ pub struct ModelVocabulary { pub vocabulary_repository: Option, } impl ModelVocabulary { - pub fn to_source(&self, sp: &mut Option) -> Result { + pub fn to_source(&self) -> Result { Ok(match (&self.vocabulary_path, &self.vocabulary_repository) { (Some(_), Some(_)) => { - if let Some(sp) = sp.take() { - sp.fail("Invalid arguments"); - }; - - return Err(eyre!( - "Cannot specify both --vocabulary-path and --vocabulary-repository" - )); + bail!("Cannot specify both --vocabulary-path and --vocabulary-repository"); } (Some(path), None) => VocabularySource::HuggingFaceTokenizerFile(path.to_owned()), (None, Some(repo)) => VocabularySource::HuggingFaceRemote(repo.to_owned()), @@ -376,8 +370,8 @@ pub struct ModelAndVocabulary { pub vocabulary_repository: Option, } impl ModelAndVocabulary { - pub fn to_source(&self, sp: &mut Option) -> Result { - self.vocabulary.to_source(sp) + pub fn to_source(&self) -> Result { + self.vocabulary.to_source() } } @@ -427,7 +421,15 @@ impl ModelLoad { let now = std::time::Instant::now(); let mut prev_load_time = now; - let vocabulary_source = self.model_and_vocabulary.to_source(&mut sp)?; + let vocabulary_source = match self.model_and_vocabulary.to_source() { + Ok(vs) => vs, + Err(err) => { + if let Some(sp) = sp.take() { + sp.fail(&format!("Failed to load vocabulary: {}", err)); + } + return Err(err); + } + }; let model = llm::load::( &self.model_and_vocabulary.model_path, diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index 2110212d..7088ca3a 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -151,7 +151,7 @@ fn info(args: &cli_args::Info) -> Result<()> { let model_path = &args.model_and_vocabulary.model_path; let vocabulary = args .model_and_vocabulary - .to_source(&mut None)? + .to_source()? .retrieve(model_path)?; let file = File::open(model_path)?; @@ -321,10 +321,7 @@ fn quantize(args: &cli_args::Quantize) -> Result<( let mut source = BufReader::new(std::fs::File::open(&args.source)?); let mut destination = BufWriter::new(std::fs::File::create(&args.destination)?); - let vocabulary = args - .vocabulary - .to_source(&mut None)? - .retrieve(&args.source)?; + let vocabulary = args.vocabulary.to_source()?.retrieve(&args.source)?; llm::quantize::( &mut source, From d795bc83bb25a74ffb5bf263fa1d65b0e2dcc5db Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 29 May 2023 23:21:00 +0200 Subject: [PATCH 16/20] refactor(examples): add vocabulary via clap --- Cargo.lock | 1 + Cargo.toml | 6 ++- binaries/llm-cli/Cargo.toml | 2 +- crates/llm/Cargo.toml | 1 + crates/llm/examples/embeddings.rs | 84 ++++++++++++++++++------------ crates/llm/examples/inference.rs | 57 +++++++++++++------- crates/llm/examples/vicuna-chat.rs | 50 ++++++++++++------ 7 files changed, 130 insertions(+), 71 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e03cb0d6..493df0d3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1240,6 +1240,7 @@ name = "llm" version = "0.2.0-dev" dependencies = [ "bytesize", + "clap", "llm-base", "llm-bloom", "llm-gpt2", diff --git a/Cargo.toml b/Cargo.toml index f3d4d795..0309bcb3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,12 +20,14 @@ bytemuck = "1.13.1" bytesize = "1.1" log = "0.4" rand = "0.8.5" +thiserror = "1.0" +anyhow = "1.0" + rustyline = { version = "11.0.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] } serde_json = { version = "1.0" } spinoff = { version = "0.7.0", default-features = false, features = ["dots2"] } -thiserror = "1.0" -anyhow = "1.0" +clap = { version = "4.1.8", features = ["derive"] } # Config for 'cargo dist' [workspace.metadata.dist] diff --git a/binaries/llm-cli/Cargo.toml b/binaries/llm-cli/Cargo.toml index 6f8561f3..ffb30920 100644 --- a/binaries/llm-cli/Cargo.toml +++ b/binaries/llm-cli/Cargo.toml @@ -19,11 +19,11 @@ log = { workspace = true } rand = { workspace = true } rustyline = { workspace = true } spinoff = { workspace = true } +clap = { workspace = true } bincode = "1.3.3" env_logger = "0.10.0" num_cpus = "1.15.0" -clap = { version = "4.1.8", features = ["derive"] } color-eyre = { version = "0.6.2", default-features = false } zstd = { version = "0.12", default-features = false } diff --git a/crates/llm/Cargo.toml b/crates/llm/Cargo.toml index 13fca508..05a2bdae 100644 --- a/crates/llm/Cargo.toml +++ b/crates/llm/Cargo.toml @@ -25,6 +25,7 @@ rand = { workspace = true } rustyline = { workspace = true } spinoff = { workspace = true } serde_json = { workspace = true } +clap = { workspace = true } [features] default = ["llama", "gpt2", "gptj", "bloom", "gptneox", "mpt"] diff --git a/crates/llm/examples/embeddings.rs b/crates/llm/examples/embeddings.rs index 2663b3e8..afe538a4 100644 --- a/crates/llm/examples/embeddings.rs +++ b/crates/llm/examples/embeddings.rs @@ -1,30 +1,52 @@ -use std::path::Path; +use std::path::PathBuf; -fn main() { - // Get arguments from command line - let raw_args: Vec = std::env::args().skip(1).collect(); - if raw_args.len() < 2 { - println!("Usage: cargo run --release --example embeddings [query] [comma-separated comparands] [overrides, json]"); - std::process::exit(1); +use clap::Parser; + +#[derive(Parser)] +struct Args { + architecture: String, + path: PathBuf, + #[arg(long, short = 'v')] + pub vocabulary_path: Option, + #[arg(long, short = 'r')] + pub vocabulary_repository: Option, + #[arg(long, short = 'q')] + pub query: Option, + #[arg(long, short = 'c')] + pub comparands: Vec, +} +impl Args { + pub fn to_vocabulary_source(&self) -> llm::VocabularySource { + match (&self.vocabulary_path, &self.vocabulary_repository) { + (Some(_), Some(_)) => { + panic!("Cannot specify both --vocabulary-path and --vocabulary-repository"); + } + (Some(path), None) => llm::VocabularySource::HuggingFaceTokenizerFile(path.to_owned()), + (None, Some(repo)) => llm::VocabularySource::HuggingFaceRemote(repo.to_owned()), + (None, None) => llm::VocabularySource::Model, + } } +} - let model_architecture: llm::ModelArchitecture = raw_args[0].parse().unwrap(); - let model_path = Path::new(&raw_args[1]); - let query = raw_args - .get(2) - .map(|s| s.as_str()) +fn main() { + let args = Args::parse(); + + let vocabulary_source = args.to_vocabulary_source(); + let architecture = args.architecture.parse().unwrap(); + let path = args.path; + let query = args + .query + .as_deref() .unwrap_or("My favourite animal is the dog"); - let comparands = raw_args - .get(3) - .map(|s| s.split(',').map(|s| s.trim()).collect::>()) - .unwrap_or_else(|| { - vec![ - "My favourite animal is the dog", - "I have just adopted a cute dog", - "My favourite animal is the cat", - ] - }); - let overrides = raw_args.get(4).map(|s| serde_json::from_str(s).unwrap()); + let comparands = if !args.comparands.is_empty() { + args.comparands + } else { + vec![ + "My favourite animal is the dog".to_string(), + "I have just adopted a cute dog".to_string(), + "My favourite animal is the cat".to_string(), + ] + }; // Load model let model_params = llm::ModelParameters { @@ -33,25 +55,23 @@ fn main() { lora_adapters: None, }; let model = llm::load_dynamic( - model_architecture, - model_path, - llm::VocabularySource::Model, + architecture, + &path, + vocabulary_source, model_params, - overrides, + None, llm::load_progress_callback_stdout, ) - .unwrap_or_else(|err| { - panic!("Failed to load {model_architecture} model from {model_path:?}: {err}") - }); + .unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}")); let inference_parameters = llm::InferenceParameters::default(); // Generate embeddings for query and comparands let query_embeddings = get_embeddings(model.as_ref(), &inference_parameters, query); let comparand_embeddings: Vec<(String, Vec)> = comparands .iter() - .map(|&text| { + .map(|text| { ( - text.to_owned(), + text.clone(), get_embeddings(model.as_ref(), &inference_parameters, text), ) }) diff --git a/crates/llm/examples/inference.rs b/crates/llm/examples/inference.rs index df0ce7ae..0fb3463a 100644 --- a/crates/llm/examples/inference.rs +++ b/crates/llm/examples/inference.rs @@ -1,33 +1,52 @@ -use std::{convert::Infallible, io::Write, path::Path}; +use clap::Parser; +use std::{convert::Infallible, io::Write, path::PathBuf}; -fn main() { - let raw_args: Vec = std::env::args().skip(1).collect(); - if raw_args.len() < 2 { - println!("Usage: cargo run --release --example inference [prompt] [overrides, json]"); - std::process::exit(1); +#[derive(Parser)] +struct Args { + architecture: String, + path: PathBuf, + #[arg(long, short = 'p')] + prompt: Option, + #[arg(long, short = 'v')] + vocabulary_path: Option, + #[arg(long, short = 'r')] + vocabulary_repository: Option, +} +impl Args { + pub fn to_vocabulary_source(&self) -> llm::VocabularySource { + match (&self.vocabulary_path, &self.vocabulary_repository) { + (Some(_), Some(_)) => { + panic!("Cannot specify both --vocabulary-path and --vocabulary-repository"); + } + (Some(path), None) => llm::VocabularySource::HuggingFaceTokenizerFile(path.to_owned()), + (None, Some(repo)) => llm::VocabularySource::HuggingFaceRemote(repo.to_owned()), + (None, None) => llm::VocabularySource::Model, + } } +} + +fn main() { + let args = Args::parse(); - let model_architecture: llm::ModelArchitecture = raw_args[0].parse().unwrap(); - let model_path = Path::new(&raw_args[1]); - let prompt = raw_args - .get(2) - .map(|s| s.as_str()) + let vocabulary_source = args.to_vocabulary_source(); + let architecture = args.architecture.parse().unwrap(); + let path = args.path; + let prompt = args + .prompt + .as_deref() .unwrap_or("Rust is a cool programming language because"); - let overrides = raw_args.get(3).map(|s| serde_json::from_str(s).unwrap()); let now = std::time::Instant::now(); let model = llm::load_dynamic( - model_architecture, - model_path, - llm::VocabularySource::Model, + architecture, + &path, + vocabulary_source, Default::default(), - overrides, + None, llm::load_progress_callback_stdout, ) - .unwrap_or_else(|err| { - panic!("Failed to load {model_architecture} model from {model_path:?}: {err}") - }); + .unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}")); println!( "Model fully loaded! Elapsed: {}ms", diff --git a/crates/llm/examples/vicuna-chat.rs b/crates/llm/examples/vicuna-chat.rs index 07aa439b..e23f42a4 100644 --- a/crates/llm/examples/vicuna-chat.rs +++ b/crates/llm/examples/vicuna-chat.rs @@ -1,28 +1,44 @@ +use clap::Parser; use rustyline::error::ReadlineError; -use std::{convert::Infallible, io::Write, path::Path}; - -fn main() { - let raw_args: Vec = std::env::args().skip(1).collect(); - if raw_args.len() < 2 { - println!("Usage: cargo run --release --example vicuna-chat [overrides, json]"); - std::process::exit(1); +use std::{convert::Infallible, io::Write, path::PathBuf}; + +#[derive(Parser)] +struct Args { + architecture: String, + path: PathBuf, + #[arg(long, short = 'v')] + pub vocabulary_path: Option, + #[arg(long, short = 'r')] + pub vocabulary_repository: Option, +} +impl Args { + pub fn to_vocabulary_source(&self) -> llm::VocabularySource { + match (&self.vocabulary_path, &self.vocabulary_repository) { + (Some(_), Some(_)) => { + panic!("Cannot specify both --vocabulary-path and --vocabulary-repository"); + } + (Some(path), None) => llm::VocabularySource::HuggingFaceTokenizerFile(path.to_owned()), + (None, Some(repo)) => llm::VocabularySource::HuggingFaceRemote(repo.to_owned()), + (None, None) => llm::VocabularySource::Model, + } } +} - let model_architecture: llm::ModelArchitecture = raw_args[0].parse().unwrap(); - let model_path = Path::new(&raw_args[1]); - let overrides = raw_args.get(2).map(|s| serde_json::from_str(s).unwrap()); +fn main() { + let args = Args::parse(); + let vocabulary_source = args.to_vocabulary_source(); + let architecture = args.architecture.parse().unwrap(); + let path = args.path; let model = llm::load_dynamic( - model_architecture, - model_path, - llm::VocabularySource::Model, + architecture, + &path, + vocabulary_source, Default::default(), - overrides, + None, llm::load_progress_callback_stdout, ) - .unwrap_or_else(|err| { - panic!("Failed to load {model_architecture} model from {model_path:?}: {err}") - }); + .unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}")); let mut session = model.start_session(Default::default()); From 8f62b39ade8d4282077995a2f6d0b6f9057d301d Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 29 May 2023 23:27:58 +0200 Subject: [PATCH 17/20] refactor: remove model overrides --- binaries/llm-cli/src/cli_args.rs | 6 +- binaries/llm-cli/src/main.rs | 44 ++++++-------- crates/llm-base/src/lib.rs | 5 +- crates/llm-base/src/loader.rs | 3 +- crates/llm-base/src/model/mod.rs | 96 +----------------------------- crates/llm/examples/embeddings.rs | 1 - crates/llm/examples/inference.rs | 1 - crates/llm/examples/vicuna-chat.rs | 1 - crates/llm/src/lib.rs | 68 ++++++--------------- crates/models/bloom/src/lib.rs | 2 - crates/models/gpt2/src/lib.rs | 2 - crates/models/gptj/src/lib.rs | 2 - crates/models/gptneox/src/lib.rs | 2 - crates/models/llama/src/lib.rs | 2 - crates/models/mpt/src/lib.rs | 2 - 15 files changed, 40 insertions(+), 197 deletions(-) diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 3d3d549a..06403a43 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -403,10 +403,7 @@ pub struct ModelLoad { pub lora_paths: Option>, } impl ModelLoad { - pub fn load( - &self, - overrides: Option, - ) -> Result> { + pub fn load(&self) -> Result> { let params = ModelParameters { prefer_mmap: !self.no_mmap, context_size: self.num_ctx_tokens, @@ -435,7 +432,6 @@ impl ModelLoad { &self.model_and_vocabulary.model_path, vocabulary_source, params, - overrides, |progress| match progress { LoadProgress::HyperparametersLoaded => { if let Some(sp) = sp.as_mut() { diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index 7088ca3a..d61d62e5 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -25,37 +25,31 @@ fn main() -> Result<()> { let cli_args = Args::parse(); match &cli_args { - Args::Llama { args } => handle_args::(args, None), - Args::Bloom { args } => handle_args::(args, None), - Args::Gpt2 { args } => handle_args::(args, None), - Args::GptJ { args } => handle_args::(args, None), - Args::GptNeoX { args } => handle_args::(args, None), - Args::Mpt { args } => handle_args::(args, None), + Args::Llama { args } => handle_args::(args), + Args::Bloom { args } => handle_args::(args), + Args::Gpt2 { args } => handle_args::(args), + Args::GptJ { args } => handle_args::(args), + Args::GptNeoX { args } => handle_args::(args), + Args::Mpt { args } => handle_args::(args), } } -fn handle_args( - args: &cli_args::BaseArgs, - overrides: Option, -) -> Result<()> { +fn handle_args(args: &cli_args::BaseArgs) -> Result<()> { match args { - BaseArgs::Infer(args) => infer::(args, overrides), - BaseArgs::Perplexity(args) => perplexity::(args, overrides), + BaseArgs::Infer(args) => infer::(args), + BaseArgs::Perplexity(args) => perplexity::(args), BaseArgs::Info(args) => info::(args), BaseArgs::PromptTokens(args) => prompt_tokens::(args), - BaseArgs::Repl(args) => interactive::(args, overrides, false), - BaseArgs::Chat(args) => interactive::(args, overrides, true), + BaseArgs::Repl(args) => interactive::(args, false), + BaseArgs::Chat(args) => interactive::(args, true), BaseArgs::Quantize(args) => quantize::(args), } } -fn infer( - args: &cli_args::Infer, - overrides: Option, -) -> Result<()> { +fn infer(args: &cli_args::Infer) -> Result<()> { let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref()); let inference_session_config = args.generate.inference_session_config(); - let model = args.model_load.load::(overrides)?; + let model = args.model_load.load::()?; let (mut session, session_loaded) = snapshot::read_or_create_session( model.as_ref(), @@ -120,13 +114,10 @@ fn infer( Ok(()) } -fn perplexity( - args: &cli_args::Perplexity, - overrides: Option, -) -> Result<()> { +fn perplexity(args: &cli_args::Perplexity) -> Result<()> { let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref()); let inference_session_config = args.generate.inference_session_config(); - let model = args.model_load.load::(overrides)?; + let model = args.model_load.load::()?; let (mut session, _) = snapshot::read_or_create_session( model.as_ref(), None, @@ -191,7 +182,7 @@ fn info(args: &cli_args::Info) -> Result<()> { fn prompt_tokens(args: &cli_args::PromptTokens) -> Result<()> { let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref()); - let model = args.model_load.load::(None)?; + let model = args.model_load.load::()?; let toks = match model.vocabulary().tokenize(&prompt, false) { Ok(toks) => toks, Err(e) => { @@ -220,14 +211,13 @@ fn prompt_tokens(args: &cli_args::PromptTokens) -> fn interactive( args: &cli_args::Repl, - overrides: Option, // If set to false, the session will be cloned after each inference // to ensure that previous state is not carried over. chat_mode: bool, ) -> Result<()> { let prompt_file = args.prompt_file.contents(); let inference_session_config = args.generate.inference_session_config(); - let model = args.model_load.load::(overrides)?; + let model = args.model_load.load::()?; let (mut session, session_loaded) = snapshot::read_or_create_session( model.as_ref(), None, diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index cabe7d25..33b5197d 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -30,10 +30,7 @@ pub use loader::{ }; pub use lora::{LoraAdapter, LoraParameters}; pub use memmap2::Mmap; -pub use model::{ - Hyperparameters, KnownModel, Model, ModelDynamicOverrideValue, ModelDynamicOverrides, - ModelParameters, OutputRequest, -}; +pub use model::{Hyperparameters, KnownModel, Model, ModelParameters, OutputRequest}; pub use quantize::{quantize, QuantizeError, QuantizeProgress}; pub use regex::Regex; pub use util::TokenUtf8Buffer; diff --git a/crates/llm-base/src/loader.rs b/crates/llm-base/src/loader.rs index 665afe5a..28140882 100644 --- a/crates/llm-base/src/loader.rs +++ b/crates/llm-base/src/loader.rs @@ -363,7 +363,6 @@ pub fn load( path: &Path, vocabulary_source: VocabularySource, params: ModelParameters, - overrides: Option, load_progress_callback: impl FnMut(LoadProgress), ) -> Result { if !path.exists() { @@ -492,7 +491,7 @@ pub fn load( loaded_tensors: Default::default(), }; - let model = KnownModel::new(hyperparameters, params, overrides, vocabulary, tl)?; + let model = KnownModel::new(hyperparameters, params, vocabulary, tl)?; (load_progress_callback)(LoadProgress::Loaded { file_size, diff --git a/crates/llm-base/src/model/mod.rs b/crates/llm-base/src/model/mod.rs index dbb8376b..6c5b13e9 100644 --- a/crates/llm-base/src/model/mod.rs +++ b/crates/llm-base/src/model/mod.rs @@ -1,7 +1,6 @@ //! Large language model traits and types use std::{ - collections::HashMap, error::Error, fmt::Debug, io::{BufRead, Write}, @@ -9,7 +8,6 @@ use std::{ }; use regex::Regex; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; use thiserror::Error; use crate::{ @@ -20,96 +18,12 @@ use crate::{ /// Common functions for model evaluation pub mod common; -macro_rules! define_model_dynamic_override_value { - ($(($name:ident, $type:ty, $doc:literal)),*) => { - #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)] - #[serde(untagged)] - /// Valid value types for dynamic model overrides. - pub enum ModelDynamicOverrideValue { - $(#[doc=$doc] $name($type),)* - } - - $( - impl TryFrom for $type { - type Error = (); - - fn try_from(value: ModelDynamicOverrideValue) -> Result { - match value { - ModelDynamicOverrideValue::$name(value) => Ok(value), - _ => Err(()), - } - } - } - - impl From<$type> for ModelDynamicOverrideValue { - fn from(value: $type) -> Self { - Self::$name(value) - } - } - )* - }; -} - -define_model_dynamic_override_value!( - (Bool, bool, "A boolean value"), - (String, String, "A string value"), - (Int, i64, "An integer value"), - (Float, f64, "A float value") -); - -/// Model options that can be overridden by the user at runtime. -/// -/// Each model has its own set of options that can be overridden. -/// However, the calling code may not know the type of the model -/// at compile time. This type is used to store the overrides -/// for a model in a generic way. -#[derive(Debug, PartialEq, Serialize, Deserialize, Default, Clone)] -#[serde(transparent)] -pub struct ModelDynamicOverrides(pub HashMap); -impl ModelDynamicOverrides { - /// Get the value of the override with the given `key`. - pub fn get>(&self, key: &str) -> Option { - self.0 - .get(key) - .cloned() - .and_then(|value| T::try_from(value).ok()) - } - - /// Merge the overrides from `other` into this one. - pub fn merge(&mut self, other: impl Into) -> &mut Self { - self.0.extend(other.into().0.into_iter()); - self - } - - /// Insert a new override with the given `key` and `value`. - pub fn insert(&mut self, key: impl Into, value: impl Into) { - self.0.insert(key.into(), value.into()); - } -} -impl From for () { - fn from(_: ModelDynamicOverrides) -> Self {} -} -impl From<()> for ModelDynamicOverrides { - fn from(_: ()) -> Self { - Self::default() - } -} - /// Interfaces for creating and interacting with a large language model with a known type /// of [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)). pub trait KnownModel: Send + Sync { /// Hyperparameters for the model. type Hyperparameters: Hyperparameters; - /// Model options that can be overridden by the user. - /// - /// If there are no options to override, use `()`. - type Overrides: Serialize - + DeserializeOwned - + Default - + From - + Into; - /// Load this model from the `path` and configure it per the `params`. The status /// of the loading process will be reported through `load_progress_callback`. This /// is a helper function on top of [llm_base::load](crate::load). @@ -117,19 +31,12 @@ pub trait KnownModel: Send + Sync { path: &Path, vocabulary_source: VocabularySource, params: ModelParameters, - overrides: Option, load_progress_callback: impl FnMut(LoadProgress), ) -> Result where Self: Sized, { - crate::load( - path, - vocabulary_source, - params, - overrides, - load_progress_callback, - ) + crate::load(path, vocabulary_source, params, load_progress_callback) } /// Creates a new model from the provided [ModelParameters] hyperparameters. @@ -137,7 +44,6 @@ pub trait KnownModel: Send + Sync { fn new( hyperparameters: Self::Hyperparameters, params: ModelParameters, - overrides: Option, vocabulary: Vocabulary, tensor_loader: impl TensorLoader, ) -> Result diff --git a/crates/llm/examples/embeddings.rs b/crates/llm/examples/embeddings.rs index afe538a4..6a4fcbdc 100644 --- a/crates/llm/examples/embeddings.rs +++ b/crates/llm/examples/embeddings.rs @@ -59,7 +59,6 @@ fn main() { &path, vocabulary_source, model_params, - None, llm::load_progress_callback_stdout, ) .unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}")); diff --git a/crates/llm/examples/inference.rs b/crates/llm/examples/inference.rs index 0fb3463a..93708d9e 100644 --- a/crates/llm/examples/inference.rs +++ b/crates/llm/examples/inference.rs @@ -43,7 +43,6 @@ fn main() { &path, vocabulary_source, Default::default(), - None, llm::load_progress_callback_stdout, ) .unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}")); diff --git a/crates/llm/examples/vicuna-chat.rs b/crates/llm/examples/vicuna-chat.rs index e23f42a4..e2f98966 100644 --- a/crates/llm/examples/vicuna-chat.rs +++ b/crates/llm/examples/vicuna-chat.rs @@ -35,7 +35,6 @@ fn main() { &path, vocabulary_source, Default::default(), - None, llm::load_progress_callback_stdout, ) .unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}")); diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index a83967f9..73939130 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -82,10 +82,9 @@ pub use llm_base::{ quantize, ElementType, FileType, FileTypeFormat, InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest, InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot, InferenceStats, InvalidTokenBias, KnownModel, - LoadError, LoadProgress, Loader, Model, ModelDynamicOverrideValue, ModelDynamicOverrides, - ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, - SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, Vocabulary, - VocabularySource, + LoadError, LoadProgress, Loader, Model, ModelKVMemoryType, ModelParameters, OutputRequest, + Prompt, QuantizeError, QuantizeProgress, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, + TokenizationError, Vocabulary, VocabularySource, }; use serde::Serialize; @@ -229,7 +228,6 @@ pub fn load_dynamic( path: &Path, vocabulary_source: VocabularySource, params: ModelParameters, - overrides: Option, load_progress_callback: impl FnMut(LoadProgress), ) -> Result, LoadError> { use ModelArchitecture::*; @@ -238,67 +236,39 @@ pub fn load_dynamic( path: &Path, vocabulary_source: VocabularySource, params: ModelParameters, - overrides: Option, load_progress_callback: impl FnMut(LoadProgress), ) -> Result, LoadError> { Ok(Box::new(load::( path, vocabulary_source, params, - overrides.map(|o| o.into()), load_progress_callback, )?)) } let model: Box = match architecture { #[cfg(feature = "bloom")] - Bloom => load_model::( - path, - vocabulary_source, - params, - overrides, - load_progress_callback, - )?, + Bloom => { + load_model::(path, vocabulary_source, params, load_progress_callback)? + } #[cfg(feature = "gpt2")] - Gpt2 => load_model::( - path, - vocabulary_source, - params, - overrides, - load_progress_callback, - )?, + Gpt2 => { + load_model::(path, vocabulary_source, params, load_progress_callback)? + } #[cfg(feature = "gptj")] - GptJ => load_model::( - path, - vocabulary_source, - params, - overrides, - load_progress_callback, - )?, + GptJ => { + load_model::(path, vocabulary_source, params, load_progress_callback)? + } #[cfg(feature = "gptneox")] - GptNeoX => load_model::( - path, - vocabulary_source, - params, - overrides, - load_progress_callback, - )?, + GptNeoX => { + load_model::(path, vocabulary_source, params, load_progress_callback)? + } #[cfg(feature = "llama")] - Llama => load_model::( - path, - vocabulary_source, - params, - overrides, - load_progress_callback, - )?, + Llama => { + load_model::(path, vocabulary_source, params, load_progress_callback)? + } #[cfg(feature = "mpt")] - Mpt => load_model::( - path, - vocabulary_source, - params, - overrides, - load_progress_callback, - )?, + Mpt => load_model::(path, vocabulary_source, params, load_progress_callback)?, }; Ok(model) diff --git a/crates/models/bloom/src/lib.rs b/crates/models/bloom/src/lib.rs index db95d8f0..dafd8783 100644 --- a/crates/models/bloom/src/lib.rs +++ b/crates/models/bloom/src/lib.rs @@ -45,12 +45,10 @@ unsafe impl Sync for Bloom {} impl KnownModel for Bloom { type Hyperparameters = Hyperparameters; - type Overrides = (); fn new( hyperparameters: Self::Hyperparameters, params: ModelParameters, - _overrides: Option, vocabulary: Vocabulary, tensor_loader: impl llm_base::TensorLoader, ) -> Result { diff --git a/crates/models/gpt2/src/lib.rs b/crates/models/gpt2/src/lib.rs index 9d995b80..a21231da 100644 --- a/crates/models/gpt2/src/lib.rs +++ b/crates/models/gpt2/src/lib.rs @@ -44,12 +44,10 @@ unsafe impl Sync for Gpt2 {} impl KnownModel for Gpt2 { type Hyperparameters = Hyperparameters; - type Overrides = (); fn new( hyperparameters: Self::Hyperparameters, params: ModelParameters, - _overrides: Option, vocabulary: Vocabulary, tensor_loader: impl llm_base::TensorLoader, ) -> Result { diff --git a/crates/models/gptj/src/lib.rs b/crates/models/gptj/src/lib.rs index 4f7c46fb..fa4a4c2b 100644 --- a/crates/models/gptj/src/lib.rs +++ b/crates/models/gptj/src/lib.rs @@ -45,12 +45,10 @@ unsafe impl Sync for GptJ {} impl KnownModel for GptJ { type Hyperparameters = Hyperparameters; - type Overrides = (); fn new( hyperparameters: Self::Hyperparameters, params: ModelParameters, - _overrides: Option, vocabulary: Vocabulary, tensor_loader: impl TensorLoader, ) -> Result diff --git a/crates/models/gptneox/src/lib.rs b/crates/models/gptneox/src/lib.rs index 78d4efd5..02f4649b 100644 --- a/crates/models/gptneox/src/lib.rs +++ b/crates/models/gptneox/src/lib.rs @@ -45,12 +45,10 @@ unsafe impl Sync for GptNeoX {} impl KnownModel for GptNeoX { type Hyperparameters = Hyperparameters; - type Overrides = (); fn new( hyperparameters: Hyperparameters, params: ModelParameters, - _overrides: Option, vocabulary: Vocabulary, tensor_loader: impl TensorLoader, ) -> Result diff --git a/crates/models/llama/src/lib.rs b/crates/models/llama/src/lib.rs index 59a3b315..21d78fa6 100644 --- a/crates/models/llama/src/lib.rs +++ b/crates/models/llama/src/lib.rs @@ -42,12 +42,10 @@ unsafe impl Sync for Llama {} impl KnownModel for Llama { type Hyperparameters = Hyperparameters; - type Overrides = (); fn new( hyperparameters: Self::Hyperparameters, params: ModelParameters, - _overrides: Option, vocabulary: Vocabulary, tensor_loader: impl TensorLoader, ) -> Result { diff --git a/crates/models/mpt/src/lib.rs b/crates/models/mpt/src/lib.rs index 0ae6dc2e..888818b0 100644 --- a/crates/models/mpt/src/lib.rs +++ b/crates/models/mpt/src/lib.rs @@ -39,12 +39,10 @@ unsafe impl Sync for Mpt {} impl KnownModel for Mpt { type Hyperparameters = Hyperparameters; - type Overrides = (); fn new( hyperparameters: Self::Hyperparameters, params: ModelParameters, - _overrides: Option, vocabulary: Vocabulary, tensor_loader: impl llm_base::TensorLoader, ) -> Result { From 79a214f9c56e10cea017cd97b49d0226339dcf0e Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 29 May 2023 23:40:40 +0200 Subject: [PATCH 18/20] doctest. --- crates/llm/src/lib.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index 73939130..53a2639f 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -25,8 +25,6 @@ //! llm::VocabularySource::Model, //! // llm::ModelParameters //! Default::default(), -//! // llm::KnownModel::Overrides -//! None, //! // load progress callback //! llm::load_progress_callback_stdout //! ) From c5ce2d135bd171324bfd6aeb371756b6d4b3de5c Mon Sep 17 00:00:00 2001 From: Philpax Date: Mon, 29 May 2023 23:44:51 +0200 Subject: [PATCH 19/20] doc. --- crates/llm/src/lib.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index 53a2639f..63f8e5fa 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -216,10 +216,6 @@ impl Display for ModelArchitecture { /// A helper function that loads the specified model from disk using an architecture /// specified at runtime. /// -/// The `overrides` will attempt to deserialize to the [KnownModel::Overrides] type -/// for that model. If the model does not support overrides, this will be an empty -/// struct. If the overrides are invalid, this will return an error. -/// /// A wrapper around [load] that dispatches to the correct model. pub fn load_dynamic( architecture: ModelArchitecture, From bf30f841edd802ea934db7ba820538164b9fa5f3 Mon Sep 17 00:00:00 2001 From: Philpax Date: Tue, 30 May 2023 00:24:15 +0200 Subject: [PATCH 20/20] revise example arguments --- crates/llm/examples/embeddings.rs | 16 +++++++++------- crates/llm/examples/inference.rs | 16 +++++++++------- crates/llm/examples/vicuna-chat.rs | 16 +++++++++------- 3 files changed, 27 insertions(+), 21 deletions(-) diff --git a/crates/llm/examples/embeddings.rs b/crates/llm/examples/embeddings.rs index 6a4fcbdc..4e092bc3 100644 --- a/crates/llm/examples/embeddings.rs +++ b/crates/llm/examples/embeddings.rs @@ -4,8 +4,8 @@ use clap::Parser; #[derive(Parser)] struct Args { - architecture: String, - path: PathBuf, + model_architecture: llm::ModelArchitecture, + model_path: PathBuf, #[arg(long, short = 'v')] pub vocabulary_path: Option, #[arg(long, short = 'r')] @@ -32,8 +32,8 @@ fn main() { let args = Args::parse(); let vocabulary_source = args.to_vocabulary_source(); - let architecture = args.architecture.parse().unwrap(); - let path = args.path; + let model_architecture = args.model_architecture; + let model_path = args.model_path; let query = args .query .as_deref() @@ -55,13 +55,15 @@ fn main() { lora_adapters: None, }; let model = llm::load_dynamic( - architecture, - &path, + model_architecture, + &model_path, vocabulary_source, model_params, llm::load_progress_callback_stdout, ) - .unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}")); + .unwrap_or_else(|err| { + panic!("Failed to load {model_architecture} model from {model_path:?}: {err}") + }); let inference_parameters = llm::InferenceParameters::default(); // Generate embeddings for query and comparands diff --git a/crates/llm/examples/inference.rs b/crates/llm/examples/inference.rs index 93708d9e..d2385b8c 100644 --- a/crates/llm/examples/inference.rs +++ b/crates/llm/examples/inference.rs @@ -3,8 +3,8 @@ use std::{convert::Infallible, io::Write, path::PathBuf}; #[derive(Parser)] struct Args { - architecture: String, - path: PathBuf, + model_architecture: llm::ModelArchitecture, + model_path: PathBuf, #[arg(long, short = 'p')] prompt: Option, #[arg(long, short = 'v')] @@ -29,8 +29,8 @@ fn main() { let args = Args::parse(); let vocabulary_source = args.to_vocabulary_source(); - let architecture = args.architecture.parse().unwrap(); - let path = args.path; + let model_architecture = args.model_architecture; + let model_path = args.model_path; let prompt = args .prompt .as_deref() @@ -39,13 +39,15 @@ fn main() { let now = std::time::Instant::now(); let model = llm::load_dynamic( - architecture, - &path, + model_architecture, + &model_path, vocabulary_source, Default::default(), llm::load_progress_callback_stdout, ) - .unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}")); + .unwrap_or_else(|err| { + panic!("Failed to load {model_architecture} model from {model_path:?}: {err}") + }); println!( "Model fully loaded! Elapsed: {}ms", diff --git a/crates/llm/examples/vicuna-chat.rs b/crates/llm/examples/vicuna-chat.rs index e2f98966..98d94606 100644 --- a/crates/llm/examples/vicuna-chat.rs +++ b/crates/llm/examples/vicuna-chat.rs @@ -4,8 +4,8 @@ use std::{convert::Infallible, io::Write, path::PathBuf}; #[derive(Parser)] struct Args { - architecture: String, - path: PathBuf, + model_architecture: llm::ModelArchitecture, + model_path: PathBuf, #[arg(long, short = 'v')] pub vocabulary_path: Option, #[arg(long, short = 'r')] @@ -28,16 +28,18 @@ fn main() { let args = Args::parse(); let vocabulary_source = args.to_vocabulary_source(); - let architecture = args.architecture.parse().unwrap(); - let path = args.path; + let model_architecture = args.model_architecture; + let model_path = args.model_path; let model = llm::load_dynamic( - architecture, - &path, + model_architecture, + &model_path, vocabulary_source, Default::default(), llm::load_progress_callback_stdout, ) - .unwrap_or_else(|err| panic!("Failed to load {architecture} model from {path:?}: {err}")); + .unwrap_or_else(|err| { + panic!("Failed to load {model_architecture} model from {model_path:?}: {err}") + }); let mut session = model.start_session(Default::default());