diff --git a/rust/Cargo.lock b/rust/Cargo.lock new file mode 100644 index 000000000..2e58e9322 --- /dev/null +++ b/rust/Cargo.lock @@ -0,0 +1,1814 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "ahash" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f" +dependencies = [ + "cfg-if", + "once_cell", + "version_check", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstream" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ca84f3628370c59db74ee214b3263d58f9aadd9b4fe7e711fd87dc452b7f163" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is-terminal", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a30da5c5f2d5e72842e00bcb57657162cdabef0931f40e2deb9b4140440cecd" + +[[package]] +name = "anstyle-parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "938874ff5980b03a87c5524b3ae5b59cf99b1d6bc836848df7bc5ada9643c333" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ca11d4be1bab0c8bc8734a9aa7bf4ee8316d462a08c6ac5052f888fef5b494b" +dependencies = [ + "windows-sys 0.48.0", +] + +[[package]] +name = "anstyle-wincon" +version = "1.0.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "180abfa45703aebe0093f79badacc01b8fd4ea2e35118747e5811127f926e188" +dependencies = [ + "anstyle", + "windows-sys 0.48.0", +] + +[[package]] +name = "anyhow" +version = "1.0.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c7d0618f0e0b7e8ff11427422b64564d5fb0be1940354bfe2e0529b18a9d9b8" + +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "base64" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" + +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "build_and_insert_delete_memory_index" +version = "0.1.0" +dependencies = [ + "diskann", + "logger", + "vector", +] + +[[package]] +name = "build_and_insert_memory_index" +version = "0.1.0" +dependencies = [ + "diskann", + "logger", + "vector", +] + +[[package]] +name = "build_disk_index" +version = "0.1.0" +dependencies = [ + 
"diskann", + "logger", + "openblas-src", + "vector", +] + +[[package]] +name = "build_memory_index" +version = "0.1.0" +dependencies = [ + "clap", + "diskann", + "logger", + "vector", +] + +[[package]] +name = "bumpalo" +version = "3.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3e2c3daef883ecc1b5d58c15adae93470a91d425f3532ba1695849656af3fc1" + +[[package]] +name = "bytemuck" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" + +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + +[[package]] +name = "bytes" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "cblas" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3de46dff748ed7e891bc46faae117f48d2a7911041c6630aed3c61a3fe12326f" +dependencies = [ + "cblas-sys", + "libc", + "num-complex", +] + +[[package]] +name = "cblas-sys" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b6feecd82cce51b0204cf063f0041d69f24ce83f680d87514b004248e7b0fa65" +dependencies = [ + "libc", +] + +[[package]] +name = "cc" +version = "1.0.79" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "ciborium" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "effd91f6c78e5a4ace8a5d3c0b6bfaec9e2baaef55f3efc00e45fb2e477ee926" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdf919175532b369853f5d5e20b26b43112613fd6fe7aee757e35f7a44642656" + +[[package]] +name = "ciborium-ll" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b" +dependencies = [ + "ciborium-io", + "half 1.8.2", +] + +[[package]] +name = "clap" +version = "4.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9394150f5b4273a1763355bd1c2ec54cc5a2593f790587bcd6b2c947cfa9211" +dependencies = [ + "clap_builder", + "clap_derive", + "once_cell", +] + +[[package]] +name = "clap_builder" +version = "4.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a78fbdd3cc2914ddf37ba444114bc7765bbdcb55ec9cbe6fa054f0137400717" +dependencies = [ + "anstream", + "anstyle", + "bitflags", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8cd2b2a819ad6eec39e8f1d6b53001af1e5469f8c177579cdaeb313115b825f" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn 2.0.18", +] + +[[package]] +name = "clap_lex" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" + +[[package]] +name = "colorchoice" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" + +[[package]] +name = "convert_f32_to_bf16" +version = "0.1.0" +dependencies = [ + "half 2.2.1", +] + +[[package]] +name = "core-foundation" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "core-foundation-sys" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" + +[[package]] +name = "crc32fast" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools", +] + +[[package]] +name = "crossbeam" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2801af0d36612ae591caa9568261fddce32ce6e08a7275ea334a06a4ad021a2c" +dependencies = [ + "cfg-if", + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.8" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", + "memoffset", + "scopeguard", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1cfb3ea8a53f37c40dea2c7bedcbd88bdfae54f5e2175d6ecaff1c988353add" +dependencies = [ + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "dirs" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30baa043103c9d0c2a57cf537cc2f35623889dc0d405e6c3cccfadbc81c71309" +dependencies = [ + "dirs-sys", +] + +[[package]] +name = "dirs-sys" +version = "0.3.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b1d1d91c932ef41c0f2663aa8b0ca0342d444d842c06914aa0a7e352d0bada6" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + +[[package]] +name = "diskann" 
+version = "0.1.0" +dependencies = [ + "approx", + "bincode", + "bit-vec", + "byteorder", + "cblas", + "cc", + "criterion", + "crossbeam", + "half 2.2.1", + "hashbrown 0.13.2", + "logger", + "num-traits", + "once_cell", + "openblas-src", + "platform", + "rand", + "rayon", + "serde", + "thiserror", + "vector", + "winapi", +] + +[[package]] +name = "either" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" + +[[package]] +name = "errno" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" +dependencies = [ + "errno-dragonfly", + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "errno-dragonfly" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "fastrand" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be" +dependencies = [ + "instant", +] + +[[package]] +name = "filetime" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cbc844cecaee9d4443931972e1289c8ff485cb4cc2767cb03ca139ed6885153" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall 0.2.16", + "windows-sys 0.48.0", +] + +[[package]] +name = "fixedbitset" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" + +[[package]] +name = "flate2" +version = "1.0.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b9429470923de8e8cbd4d2dc513535400b4b3fef0319fb5c4e1f520a7bef743" +dependencies = [ + "crc32fast", + 
"miniz_oxide", +] + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + +[[package]] +name = "form_urlencoded" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a62bc1cf6f830c2ec14a513a9fb124d0a213a629668a4186f329db21fe045652" +dependencies = [ + "percent-encoding", +] + +[[package]] +name = "getrandom" +version = "0.2.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "half" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" + +[[package]] +name = "half" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02b4af3693f1b705df946e9fe5631932443781d0aabb423b62fcd4d73f6d2fd0" +dependencies = [ + "crunchy", +] + +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + +[[package]] +name = "hashbrown" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43a3c133739dddd0d2990f9a4bdf8eb4b21ef50e4851ca85ab661199821d510e" +dependencies = [ + "ahash", +] + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "hermit-abi" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7" +dependencies = [ + "libc", +] + +[[package]] +name = "hermit-abi" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fed44880c466736ef9a5c5b5facefb5ed0785676d0c02d612db14e54f0d84286" + +[[package]] +name = "idna" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d20d6b07bfbc108882d88ed8e37d39636dcc260e15e30c45e6ba089610b917c" +dependencies = [ + "unicode-bidi", + "unicode-normalization", +] + +[[package]] +name = "indexmap" +version = "1.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +dependencies = [ + "autocfg", + "hashbrown 0.12.3", +] + +[[package]] +name = "instant" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "io-lifetimes" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" +dependencies = [ + "hermit-abi 0.3.1", + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "is-terminal" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adcf93614601c8129ddf72e2d5633df827ba6551541c6d8c59520a371475be1f" +dependencies = [ + "hermit-abi 0.3.1", + "io-lifetimes", + "rustix", + "windows-sys 0.48.0", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "453ad9f582a441959e5f0d088b02ce04cfe8d51a8eaf077f12ac6d3e94164ca6" + +[[package]] +name = "js-sys" +version = "0.3.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5f195fe497f702db0f318b07fdd68edb16955aed830df8363d837542f8f935a" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = "libc" +version = "0.2.146" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f92be4933c13fd498862a9e02a3055f8a8d9c039ce33db97306fd5a6caa7f29b" + +[[package]] +name = "linux-raw-sys" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" + +[[package]] +name = "load_and_insert_memory_index" +version = "0.1.0" +dependencies = [ + "diskann", + "logger", + "vector", +] + +[[package]] +name = "log" +version = "0.4.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" + +[[package]] +name = "logger" +version = "0.1.0" +dependencies = [ + "lazy_static", + "log", + "once_cell", + "prost", + "prost-build", + "prost-types", + "thiserror", + "vcpkg", + "win_etw_macros", + "win_etw_provider", +] + +[[package]] +name = "memoffset" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "miniz_oxide" +version = "0.7.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +dependencies = [ + "adler", +] + +[[package]] +name = "multimap" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a" + +[[package]] +name = "native-tls" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" +dependencies = [ + "lazy_static", + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + +[[package]] +name = "num-complex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd" +dependencies = [ + "autocfg", +] + +[[package]] +name = "num_cpus" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b" +dependencies = [ + "hermit-abi 0.2.6", + "libc", +] + +[[package]] +name = "once_cell" +version = "1.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" + +[[package]] +name = "oorandom" +version = "11.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" + +[[package]] +name = "openblas-build" +version = "0.10.8" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "eba42c395477605f400a8d79ee0b756cfb82abe3eb5618e35fa70d3a36010a7f" +dependencies = [ + "anyhow", + "flate2", + "native-tls", + "tar", + "thiserror", + "ureq", + "walkdir", +] + +[[package]] +name = "openblas-src" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38e5d8af0b707ac2fe1574daa88b4157da73b0de3dc7c39fe3e2c0bb64070501" +dependencies = [ + "dirs", + "openblas-build", + "vcpkg", +] + +[[package]] +name = "openssl" +version = "0.10.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.18", +] + +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "openssl-sys" +version = "0.9.90" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "374533b0e45f3a7ced10fcaeccca020e66656bc03dac384f852e4e5a7a8104a6" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "percent-encoding" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b2a4787296e9989611394c33f193f676704af1686e70b8f8033ab5ba9a35a94" + +[[package]] +name = "petgraph" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" +dependencies = [ + 
"fixedbitset", + "indexmap", +] + +[[package]] +name = "pkg-config" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" + +[[package]] +name = "platform" +version = "0.1.0" +dependencies = [ + "log", + "winapi", +] + +[[package]] +name = "plotters" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" + +[[package]] +name = "plotters-svg" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "prettyplease" +version = "0.1.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8646e95016a7a6c4adea95bafa8a16baab64b583356217f2c85db4a39d9a86" +dependencies = [ + "proc-macro2", + "syn 1.0.109", +] + +[[package]] +name = "proc-macro2" +version = "1.0.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dec2b086b7a862cf4de201096214fa870344cf922b2b30c167badb3af3195406" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "prost" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b82eaa1d779e9a4bc1c3217db8ffbeabaae1dca241bf70183242128d48681cd" +dependencies = [ 
+ "bytes", + "prost-derive", +] + +[[package]] +name = "prost-build" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "119533552c9a7ffacc21e099c24a0ac8bb19c2a2a3f363de84cd9b844feab270" +dependencies = [ + "bytes", + "heck", + "itertools", + "lazy_static", + "log", + "multimap", + "petgraph", + "prettyplease", + "prost", + "prost-types", + "regex", + "syn 1.0.109", + "tempfile", + "which", +] + +[[package]] +name = "prost-derive" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5d2d8d10f3c6ded6da8b05b5fb3b8a5082514344d56c9f871412d29b4e075b4" +dependencies = [ + "anyhow", + "itertools", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "prost-types" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213622a1460818959ac1181aaeb2dc9c7f63df720db7d788b3e24eacd1983e13" +dependencies = [ + "prost", +] + +[[package]] +name = "quote" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9ab9c7eadfd8df19006f1cf1a4aed13540ed5cbc047010ece5826e10825488" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rayon" +version = 
"1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-utils", + "num_cpus", +] + +[[package]] +name = "redox_syscall" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a" +dependencies = [ + "bitflags", +] + +[[package]] +name = "redox_syscall" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "567664f262709473930a4bf9e51bf2ebf3348f2e748ccc50dea20646858f8f29" +dependencies = [ + "bitflags", +] + +[[package]] +name = "redox_users" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b033d837a7cf162d7993aded9304e30a83213c648b6e389db233191f891e5c2b" +dependencies = [ + "getrandom", + "redox_syscall 0.2.16", + "thiserror", +] + +[[package]] +name = "regex" +version = "1.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0ab3ca65655bb1e41f2a8c8cd662eb4fb035e67c3f78da1d61dffe89d07300f" +dependencies = [ + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78" + +[[package]] +name = "rustix" +version = "0.37.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b96e891d04aa506a6d1f318d2771bcb1c7dfda84e126660ace067c9b474bb2c0" +dependencies = [ + "bitflags", + "errno", + "io-lifetimes", + "libc", + "linux-raw-sys", + "windows-sys 
0.48.0", +] + +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" +dependencies = [ + "base64", +] + +[[package]] +name = "ryu" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f91339c0467de62360649f8d3e185ca8de4224ff281f66000de5eb2a77a79041" + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "schannel" +version = "0.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "713cfb06c7059f3588fb8044c0fad1d09e3c01d225e25b9220dbfdcf16dbb1b3" +dependencies = [ + "windows-sys 0.42.0", +] + +[[package]] +name = "scopeguard" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" + +[[package]] +name = "search_memory_index" +version = "0.1.0" +dependencies = [ + "bytemuck", + "diskann", + "num_cpus", + "rayon", + "vector", +] + +[[package]] +name = "security-framework" +version = "2.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fc758eb7bffce5b308734e9b0c1468893cae9ff70ebf13e7090be8dcbcc83a8" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.9.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "f51d0c0d83bec45f16480d0ce0058397a69e48fcdc52d1dc8855fb68acbd31a7" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "serde" +version = "1.0.164" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e8c8cf938e98f769bc164923b06dce91cea1751522f46f8466461af04c9027d" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.164" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9735b638ccc51c28bf6914d90a2e9725b377144fc612c49a611fddd1b631d68" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.18", +] + +[[package]] +name = "serde_json" +version = "1.0.97" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bdf3bf93142acad5821c99197022e170842cdbc1c30482b98750c688c640842a" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "sha1_smol" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012" + +[[package]] +name = "strsim" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "syn" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32d41677bcbe24c20c52e7c70b0d8db04134c5d1066bf98662e2871ad200ea3e" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tar" +version = "0.4.38" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"4b55807c0344e1e6c04d7c965f5289c39a8d94ae23ed5c0b57aabac549f871c6" +dependencies = [ + "filetime", + "libc", + "xattr", +] + +[[package]] +name = "tempfile" +version = "3.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31c0432476357e58790aaa47a8efb0c5138f137343f3b5f23bd36a27e3b0a6d6" +dependencies = [ + "autocfg", + "cfg-if", + "fastrand", + "redox_syscall 0.3.5", + "rustix", + "windows-sys 0.48.0", +] + +[[package]] +name = "thiserror" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978c9a314bd8dc99be594bc3c175faaa9794be04a5a5e153caba6915336cebac" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9456a42c5b0d803c8cd86e73dd7cc9edd429499f37a3550d286d5e86720569f" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.18", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "tinyvec" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "unicode-bidi" +version = "0.3.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92888ba5573ff080736b3648696b70cafad7d250551175acbaa4e0385b3e1460" + +[[package]] +name = "unicode-ident" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" + +[[package]] +name = "unicode-normalization" +version = "0.1.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" +dependencies = [ + "tinyvec", +] + +[[package]] +name = "ureq" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9" +dependencies = [ + "base64", + "flate2", + "log", + "native-tls", + "once_cell", + "rustls-native-certs", + "url", +] + +[[package]] +name = "url" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50bff7831e19200a85b17131d085c25d7811bc4e186efdaf54bbd132994a88cb" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + +[[package]] +name = "uuid" +version = "1.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa2982af2eec27de306107c027578ff7f423d65f7250e40ce0fea8f45248b81" +dependencies = [ + "sha1_smol", +] + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "vector" +version = "0.1.0" +dependencies = [ + "approx", + "base64", + "bincode", + "bytemuck", + "cc", + "half 2.2.1", + "rand", + "serde", + "thiserror", +] + +[[package]] +name = "vector_base64" +version = "0.1.0" +dependencies = [ + "base64", + "bincode", + "half 2.2.1", + "serde", +] + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "w32-error" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7c61a6bd91e168c12fc170985725340f6b458eb6f971d1cf6c34f74ffafb43" +dependencies = [ + "winapi", +] + +[[package]] +name = "walkdir" +version = "2.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36df944cda56c7d8d8b7496af378e6b16de9284591917d307c9b4d313c44e698" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ef2b6d3c510e9625e5fe6f509ab07d66a760f0885d858736483c32ed7809abd" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.18", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dee495e55982a3bd48105a7b947fd2a9b4a8ae3010041b9e0faab3f9cd028f1d" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54681b18a46765f095758388f2d0cf16eb8d4169b639ab575a8f5693af210c7b" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.18", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = 
"wasm-bindgen-shared" +version = "0.2.87" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" + +[[package]] +name = "web-sys" +version = "0.3.64" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "which" +version = "4.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2441c784c52b289a054b7201fc93253e288f094e2f4be9058343127c4226a269" +dependencies = [ + "either", + "libc", + "once_cell", +] + +[[package]] +name = "widestring" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "653f141f39ec16bba3c5abe400a0c60da7468261cc2cbf36805022876bc721a8" + +[[package]] +name = "win_etw_macros" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1bc4c591edb4858e3445f7a60c7e0a50915aedadfa044f28f17c98c145ef54d" +dependencies = [ + "proc-macro2", + "quote", + "sha1_smol", + "syn 1.0.109", + "uuid", + "win_etw_metadata", +] + +[[package]] +name = "win_etw_metadata" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e50d0fa665033a19ecefd281b4fb5481eba2972dedbb5ec129c9392a206d652f" +dependencies = [ + "bitflags", +] + +[[package]] +name = "win_etw_provider" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dffcc196e0e180e73a275a91f6914f173227fd627cabac3efdd8d6adec113892" +dependencies = [ + "w32-error", + "widestring", + "win_etw_metadata", + "winapi", + "zerocopy", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", 
+ "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +dependencies = [ + "winapi", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-sys" +version = "0.42.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7" +dependencies = [ + "windows_aarch64_gnullvm 0.42.2", + "windows_aarch64_msvc 0.42.2", + "windows_i686_gnu 0.42.2", + "windows_i686_msvc 0.42.2", + "windows_x86_64_gnu 0.42.2", + "windows_x86_64_gnullvm 0.42.2", + "windows_x86_64_msvc 0.42.2", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b1eb6f0cd7c80c79759c929114ef071b87354ce476d9d94271031c0497adfd5" +dependencies = [ + "windows_aarch64_gnullvm 0.48.0", + "windows_aarch64_msvc 0.48.0", + "windows_i686_gnu 0.48.0", + "windows_i686_msvc 0.48.0", + "windows_x86_64_gnu 0.48.0", + "windows_x86_64_gnullvm 0.48.0", + "windows_x86_64_msvc 0.48.0", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"597a5118570b68bc08d8d59125332c54f1ba9d9adeedeef5b99b02ba2b0698f8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e08e8864a60f06ef0d0ff4ba04124db8b0fb3be5776a5cd47641e942e58c4d43" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" + +[[package]] +name = "windows_i686_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c61d927d8da41da96a81f029489353e68739737d3beca43145c8afec9a31a84f" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" + +[[package]] +name = "windows_i686_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44d840b6ec649f480a41c8d80f9c65108b92d89345dd94027bfe06ac444d1060" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8de912b8b8feb55c064867cf047dda097f92d51efad5b491dfb98f6bbb70cb36" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" + +[[package]] +name = "windows_x86_64_gnullvm" 
+version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26d41b46a36d453748aedef1486d5c7a85db22e56aff34643984ea85514e94a3" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.42.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9aec5da331524158c6d1a4ac0ab1541149c0b9505fde06423b02f5ef0106b9f0" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" + +[[package]] +name = "xattr" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d1526bbe5aaeb5eb06885f4d987bcdfa5e23187055de9b83fe00156a821fabc" +dependencies = [ + "libc", +] + +[[package]] +name = "zerocopy" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "332f188cc1bcf1fe1064b8c58d150f497e697f49774aa846f2dc949d9a25f236" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6505e6815af7de1746a08f69c69606bb45695a17149517680f3b2149713b19a3" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] diff --git a/rust/Cargo.toml b/rust/Cargo.toml new file mode 100644 index 000000000..5236f96a0 --- /dev/null +++ b/rust/Cargo.toml @@ -0,0 +1,23 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
+ +[workspace] +members = [ + "cmd_drivers/build_memory_index", + "cmd_drivers/build_and_insert_memory_index", + "cmd_drivers/load_and_insert_memory_index", + "cmd_drivers/convert_f32_to_bf16", + "cmd_drivers/search_memory_index", + "cmd_drivers/build_disk_index", + "cmd_drivers/build_and_insert_delete_memory_index", + "vector", + "diskann", + "platform", + "logger", + "vector_base64" +] +resolver = "2" + +[profile.release] +opt-level = 3 +codegen-units=1 diff --git a/rust/cmd_drivers/build_and_insert_delete_memory_index/Cargo.toml b/rust/cmd_drivers/build_and_insert_delete_memory_index/Cargo.toml new file mode 100644 index 000000000..42aa1851a --- /dev/null +++ b/rust/cmd_drivers/build_and_insert_delete_memory_index/Cargo.toml @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. +[package] +name = "build_and_insert_delete_memory_index" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +diskann = { path = "../../diskann" } +logger = { path = "../../logger" } +vector = { path = "../../vector" } + diff --git a/rust/cmd_drivers/build_and_insert_delete_memory_index/src/main.rs b/rust/cmd_drivers/build_and_insert_delete_memory_index/src/main.rs new file mode 100644 index 000000000..4593a9ed5 --- /dev/null +++ b/rust/cmd_drivers/build_and_insert_delete_memory_index/src/main.rs @@ -0,0 +1,420 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use std::env; + +use diskann::{ + common::{ANNError, ANNResult}, + index::create_inmem_index, + model::{ + configuration::index_write_parameters::IndexWriteParametersBuilder, + vertex::{DIM_104, DIM_128, DIM_256}, + IndexConfiguration, + }, + utils::round_up, + utils::{file_exists, load_ids_to_delete_from_file, load_metadata_from_file, Timer}, +}; + +use vector::{FullPrecisionDistance, Half, Metric}; + +// The main function to build an in-memory index +#[allow(clippy::too_many_arguments)] +fn build_and_insert_delete_in_memory_index( + metric: Metric, + data_path: &str, + delta_path: &str, + r: u32, + l: u32, + alpha: f32, + save_path: &str, + num_threads: u32, + _use_pq_build: bool, + _num_pq_bytes: usize, + use_opq: bool, + delete_path: &str, +) -> ANNResult<()> +where + T: Default + Copy + Sync + Send + Into, + [T; DIM_104]: FullPrecisionDistance, + [T; DIM_128]: FullPrecisionDistance, + [T; DIM_256]: FullPrecisionDistance, +{ + let index_write_parameters = IndexWriteParametersBuilder::new(l, r) + .with_alpha(alpha) + .with_saturate_graph(false) + .with_num_threads(num_threads) + .build(); + + let (data_num, data_dim) = load_metadata_from_file(data_path)?; + + let config = IndexConfiguration::new( + metric, + data_dim, + round_up(data_dim as u64, 8_u64) as usize, + data_num, + false, + 0, + use_opq, + 0, + 2.0f32, + index_write_parameters, + ); + let mut index = create_inmem_index::(config)?; + + let timer = Timer::new(); + + index.build(data_path, data_num)?; + + let diff = timer.elapsed(); + + println!("Initial indexing time: {}", diff.as_secs_f64()); + + let (delta_data_num, _) = load_metadata_from_file(delta_path)?; + + index.insert(delta_path, delta_data_num)?; + + if !delete_path.is_empty() { + if !file_exists(delete_path) { + return Err(ANNError::log_index_error(format!( + "ERROR: Data file for delete {} does not exist.", + delete_path + ))); + } + + let (num_points_to_delete, vertex_ids_to_delete) = + load_ids_to_delete_from_file(delete_path)?; + 
index.soft_delete(vertex_ids_to_delete, num_points_to_delete)?; + } + + index.save(save_path)?; + + Ok(()) +} + +fn main() -> ANNResult<()> { + let mut data_type = String::new(); + let mut dist_fn = String::new(); + let mut data_path = String::new(); + let mut insert_path = String::new(); + let mut index_path_prefix = String::new(); + let mut delete_path = String::new(); + + let mut num_threads = 0u32; + let mut r = 64u32; + let mut l = 100u32; + + let mut alpha = 1.2f32; + let mut build_pq_bytes = 0u32; + let mut _use_pq_build = false; + let mut use_opq = false; + + let args: Vec = env::args().collect(); + let mut iter = args.iter().skip(1).peekable(); + + while let Some(arg) = iter.next() { + match arg.as_str() { + "--help" | "-h" => { + print_help(); + return Ok(()); + } + "--data_type" => { + data_type = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "data_type".to_string(), + "Missing data type".to_string(), + ) + })? + .to_owned(); + } + "--dist_fn" => { + dist_fn = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "dist_fn".to_string(), + "Missing distance function".to_string(), + ) + })? + .to_owned(); + } + "--data_path" => { + data_path = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "data_path".to_string(), + "Missing data path".to_string(), + ) + })? + .to_owned(); + } + "--insert_path" => { + insert_path = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "insert_path".to_string(), + "Missing insert path".to_string(), + ) + })? + .to_owned(); + } + "--index_path_prefix" => { + index_path_prefix = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "index_path_prefix".to_string(), + "Missing index path prefix".to_string(), + ) + })? + .to_owned(); + } + "--max_degree" | "-R" => { + r = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "max_degree".to_string(), + "Missing max degree".to_string(), + ) + })? 
+ .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "max_degree".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--Lbuild" | "-L" => { + l = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "Lbuild".to_string(), + "Missing build complexity".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "Lbuild".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--alpha" => { + alpha = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "alpha".to_string(), + "Missing alpha".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "alpha".to_string(), + format!("ParseFloatError: {}", err), + ) + })?; + } + "--num_threads" | "-T" => { + num_threads = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "num_threads".to_string(), + "Missing number of threads".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "num_threads".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--build_PQ_bytes" => { + build_pq_bytes = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "build_PQ_bytes".to_string(), + "Missing PQ bytes".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "build_PQ_bytes".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--use_opq" => { + use_opq = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "use_opq".to_string(), + "Missing use_opq flag".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "use_opq".to_string(), + format!("ParseBoolError: {}", err), + ) + })?; + } + "--delete_path" => { + delete_path = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "delete_path".to_string(), + "Missing delete_path".to_string(), + ) + })? 
+ .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "delete_set_path".to_string(), + format!("ParseStringError: {}", err), + ) + })?; + } + _ => { + return Err(ANNError::log_index_config_error( + String::from(""), + format!("Unknown argument: {}", arg), + )); + } + } + } + + if data_type.is_empty() + || dist_fn.is_empty() + || data_path.is_empty() + || index_path_prefix.is_empty() + { + return Err(ANNError::log_index_config_error( + String::from(""), + "Missing required arguments".to_string(), + )); + } + + _use_pq_build = build_pq_bytes > 0; + + let metric = dist_fn + .parse::() + .map_err(|err| ANNError::log_index_config_error("dist_fn".to_string(), err.to_string()))?; + + println!( + "Starting index build with R: {} Lbuild: {} alpha: {} #threads: {}", + r, l, alpha, num_threads + ); + + match data_type.as_str() { + "int8" => { + build_and_insert_delete_in_memory_index::( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + &delete_path, + )?; + } + "uint8" => { + build_and_insert_delete_in_memory_index::( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + &delete_path, + )?; + } + "float" => { + build_and_insert_delete_in_memory_index::( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + &delete_path, + )?; + } + "f16" => { + build_and_insert_delete_in_memory_index::( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + &delete_path, + )?; + } + _ => { + println!("Unsupported type. 
Use one of int8, uint8 or float."); + return Err(ANNError::log_index_config_error( + "data_type".to_string(), + "Invalid data type".to_string(), + )); + } + } + + Ok(()) +} + +fn print_help() { + println!("Arguments"); + println!("--help, -h Print information on arguments"); + println!("--data_type data type (required)"); + println!("--dist_fn distance function (required)"); + println!( + "--data_path Input data file in bin format for initial build (required)" + ); + println!("--insert_path Input data file in bin format for insert (required)"); + println!("--index_path_prefix Path prefix for saving index file components (required)"); + println!("--max_degree, -R Maximum graph degree (default: 64)"); + println!("--Lbuild, -L Build complexity, higher value results in better graphs (default: 100)"); + println!("--alpha alpha controls density and diameter of graph, set 1 for sparse graph, 1.2 or 1.4 for denser graphs with lower diameter (default: 1.2)"); + println!("--num_threads, -T Number of threads used for building index (defaults to num of CPU logic cores)"); + println!("--build_PQ_bytes Number of PQ bytes to build the index; 0 for full precision build (default: 0)"); + println!("--use_opq Set true for OPQ compression while using PQ distance comparisons for building the index, and false for PQ compression (default: false)"); +} + diff --git a/rust/cmd_drivers/build_and_insert_memory_index/Cargo.toml b/rust/cmd_drivers/build_and_insert_memory_index/Cargo.toml new file mode 100644 index 000000000..d9811fc22 --- /dev/null +++ b/rust/cmd_drivers/build_and_insert_memory_index/Cargo.toml @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
+[package] +name = "build_and_insert_memory_index" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +diskann = { path = "../../diskann" } +logger = { path = "../../logger" } +vector = { path = "../../vector" } + diff --git a/rust/cmd_drivers/build_and_insert_memory_index/src/main.rs b/rust/cmd_drivers/build_and_insert_memory_index/src/main.rs new file mode 100644 index 000000000..46e4ba4a4 --- /dev/null +++ b/rust/cmd_drivers/build_and_insert_memory_index/src/main.rs @@ -0,0 +1,382 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::env; + +use diskann::{ + common::{ANNResult, ANNError}, + index::create_inmem_index, + utils::round_up, + model::{ + IndexWriteParametersBuilder, + IndexConfiguration, + vertex::{DIM_128, DIM_256, DIM_104} + }, + utils::{load_metadata_from_file, Timer}, +}; + +use vector::{Metric, FullPrecisionDistance, Half}; + +// The main function to build an in-memory index +#[allow(clippy::too_many_arguments)] +fn build_and_insert_in_memory_index ( + metric: Metric, + data_path: &str, + delta_path: &str, + r: u32, + l: u32, + alpha: f32, + save_path: &str, + num_threads: u32, + _use_pq_build: bool, + _num_pq_bytes: usize, + use_opq: bool +) -> ANNResult<()> +where + T: Default + Copy + Sync + Send + Into, + [T; DIM_104]: FullPrecisionDistance, + [T; DIM_128]: FullPrecisionDistance, + [T; DIM_256]: FullPrecisionDistance +{ + let index_write_parameters = IndexWriteParametersBuilder::new(l, r) + .with_alpha(alpha) + .with_saturate_graph(false) + .with_num_threads(num_threads) + .build(); + + let (data_num, data_dim) = load_metadata_from_file(data_path)?; + + let config = IndexConfiguration::new( + metric, + data_dim, + round_up(data_dim as u64, 8_u64) as usize, + data_num, + false, + 0, + use_opq, + 0, + 2.0f32, + index_write_parameters, + ); + let mut index = 
create_inmem_index::(config)?; + + let timer = Timer::new(); + + index.build(data_path, data_num)?; + + let diff = timer.elapsed(); + + println!("Initial indexing time: {}", diff.as_secs_f64()); + + let (delta_data_num, _) = load_metadata_from_file(delta_path)?; + + index.insert(delta_path, delta_data_num)?; + + index.save(save_path)?; + + Ok(()) +} + +fn main() -> ANNResult<()> { + let mut data_type = String::new(); + let mut dist_fn = String::new(); + let mut data_path = String::new(); + let mut insert_path = String::new(); + let mut index_path_prefix = String::new(); + + let mut num_threads = 0u32; + let mut r = 64u32; + let mut l = 100u32; + + let mut alpha = 1.2f32; + let mut build_pq_bytes = 0u32; + let mut _use_pq_build = false; + let mut use_opq = false; + + let args: Vec = env::args().collect(); + let mut iter = args.iter().skip(1).peekable(); + + while let Some(arg) = iter.next() { + match arg.as_str() { + "--help" | "-h" => { + print_help(); + return Ok(()); + } + "--data_type" => { + data_type = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "data_type".to_string(), + "Missing data type".to_string(), + ) + })? + .to_owned(); + } + "--dist_fn" => { + dist_fn = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "dist_fn".to_string(), + "Missing distance function".to_string(), + ) + })? + .to_owned(); + } + "--data_path" => { + data_path = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "data_path".to_string(), + "Missing data path".to_string(), + ) + })? + .to_owned(); + } + "--insert_path" => { + insert_path = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "insert_path".to_string(), + "Missing insert path".to_string(), + ) + })? + .to_owned(); + } + "--index_path_prefix" => { + index_path_prefix = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "index_path_prefix".to_string(), + "Missing index path prefix".to_string(), + ) + })? 
+ .to_owned(); + } + "--max_degree" | "-R" => { + r = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "max_degree".to_string(), + "Missing max degree".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "max_degree".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--Lbuild" | "-L" => { + l = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "Lbuild".to_string(), + "Missing build complexity".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "Lbuild".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--alpha" => { + alpha = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "alpha".to_string(), + "Missing alpha".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "alpha".to_string(), + format!("ParseFloatError: {}", err), + ) + })?; + } + "--num_threads" | "-T" => { + num_threads = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "num_threads".to_string(), + "Missing number of threads".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "num_threads".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--build_PQ_bytes" => { + build_pq_bytes = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "build_PQ_bytes".to_string(), + "Missing PQ bytes".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "build_PQ_bytes".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--use_opq" => { + use_opq = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "use_opq".to_string(), + "Missing use_opq flag".to_string(), + ) + })? 
+ .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "use_opq".to_string(), + format!("ParseBoolError: {}", err), + ) + })?; + } + _ => { + return Err(ANNError::log_index_config_error( + String::from(""), + format!("Unknown argument: {}", arg), + )); + } + } + } + + if data_type.is_empty() + || dist_fn.is_empty() + || data_path.is_empty() + || index_path_prefix.is_empty() + { + return Err(ANNError::log_index_config_error( + String::from(""), + "Missing required arguments".to_string(), + )); + } + + _use_pq_build = build_pq_bytes > 0; + + let metric = dist_fn + .parse::() + .map_err(|err| ANNError::log_index_config_error( + "dist_fn".to_string(), + err.to_string(), + ))?; + + println!( + "Starting index build with R: {} Lbuild: {} alpha: {} #threads: {}", + r, l, alpha, num_threads + ); + + match data_type.as_str() { + "int8" => { + build_and_insert_in_memory_index::( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + )?; + } + "uint8" => { + build_and_insert_in_memory_index::( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + )?; + } + "float" => { + build_and_insert_in_memory_index::( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + )?; + } + "f16" => { + build_and_insert_in_memory_index::( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + )?; + } + _ => { + println!("Unsupported type. 
Use one of int8, uint8 or float."); + return Err(ANNError::log_index_config_error("data_type".to_string(), "Invalid data type".to_string())); + } + } + + Ok(()) +} + +fn print_help() { + println!("Arguments"); + println!("--help, -h Print information on arguments"); + println!("--data_type data type (required)"); + println!("--dist_fn distance function (required)"); + println!("--data_path Input data file in bin format for initial build (required)"); + println!("--insert_path Input data file in bin format for insert (required)"); + println!("--index_path_prefix Path prefix for saving index file components (required)"); + println!("--max_degree, -R Maximum graph degree (default: 64)"); + println!("--Lbuild, -L Build complexity, higher value results in better graphs (default: 100)"); + println!("--alpha alpha controls density and diameter of graph, set 1 for sparse graph, 1.2 or 1.4 for denser graphs with lower diameter (default: 1.2)"); + println!("--num_threads, -T Number of threads used for building index (defaults to num of CPU logic cores)"); + println!("--build_PQ_bytes Number of PQ bytes to build the index; 0 for full precision build (default: 0)"); + println!("--use_opq Set true for OPQ compression while using PQ distance comparisons for building the index, and false for PQ compression (default: false)"); +} + diff --git a/rust/cmd_drivers/build_disk_index/Cargo.toml b/rust/cmd_drivers/build_disk_index/Cargo.toml new file mode 100644 index 000000000..afe5e5b33 --- /dev/null +++ b/rust/cmd_drivers/build_disk_index/Cargo.toml @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
+[package] +name = "build_disk_index" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +diskann = { path = "../../diskann" } +logger = { path = "../../logger" } +vector = { path = "../../vector" } +openblas-src = { version = "0.10.8", features = ["system", "static"] } diff --git a/rust/cmd_drivers/build_disk_index/src/main.rs b/rust/cmd_drivers/build_disk_index/src/main.rs new file mode 100644 index 000000000..e0b6dbe24 --- /dev/null +++ b/rust/cmd_drivers/build_disk_index/src/main.rs @@ -0,0 +1,377 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::env; + +use diskann::{ + common::{ANNError, ANNResult}, + index::ann_disk_index::create_disk_index, + model::{ + default_param_vals::ALPHA, + vertex::{DIM_104, DIM_128, DIM_256}, + DiskIndexBuildParameters, IndexConfiguration, IndexWriteParametersBuilder, + }, + storage::DiskIndexStorage, + utils::round_up, + utils::{load_metadata_from_file, Timer}, +}; + +use vector::{FullPrecisionDistance, Half, Metric}; + +/// The main function to build a disk index +#[allow(clippy::too_many_arguments)] +fn build_disk_index( + metric: Metric, + data_path: &str, + r: u32, + l: u32, + index_path_prefix: &str, + num_threads: u32, + search_ram_limit_gb: f64, + index_build_ram_limit_gb: f64, + num_pq_chunks: usize, + use_opq: bool, +) -> ANNResult<()> +where + T: Default + Copy + Sync + Send + Into, + [T; DIM_104]: FullPrecisionDistance, + [T; DIM_128]: FullPrecisionDistance, + [T; DIM_256]: FullPrecisionDistance, +{ + let disk_index_build_parameters = + DiskIndexBuildParameters::new(search_ram_limit_gb, index_build_ram_limit_gb)?; + + let index_write_parameters = IndexWriteParametersBuilder::new(l, r) + .with_saturate_graph(true) + .with_num_threads(num_threads) + .build(); + + let (data_num, data_dim) = load_metadata_from_file(data_path)?; + + let config = 
IndexConfiguration::new( + metric, + data_dim, + round_up(data_dim as u64, 8_u64) as usize, + data_num, + num_pq_chunks > 0, + num_pq_chunks, + use_opq, + 0, + 1f32, + index_write_parameters, + ); + let storage = DiskIndexStorage::new(data_path.to_string(), index_path_prefix.to_string())?; + let mut index = create_disk_index::(Some(disk_index_build_parameters), config, storage)?; + + let timer = Timer::new(); + + index.build("")?; + + let diff = timer.elapsed(); + println!("Indexing time: {}", diff.as_secs_f64()); + + Ok(()) +} + +fn main() -> ANNResult<()> { + let mut data_type = String::new(); + let mut dist_fn = String::new(); + let mut data_path = String::new(); + let mut index_path_prefix = String::new(); + + let mut num_threads = 0u32; + let mut r = 64u32; + let mut l = 100u32; + let mut search_ram_limit_gb = 0f64; + let mut index_build_ram_limit_gb = 0f64; + + let mut build_pq_bytes = 0u32; + let mut use_opq = false; + + let args: Vec = env::args().collect(); + let mut iter = args.iter().skip(1).peekable(); + + while let Some(arg) = iter.next() { + match arg.as_str() { + "--help" | "-h" => { + print_help(); + return Ok(()); + } + "--data_type" => { + data_type = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "data_type".to_string(), + "Missing data type".to_string(), + ) + })? + .to_owned(); + } + "--dist_fn" => { + dist_fn = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "dist_fn".to_string(), + "Missing distance function".to_string(), + ) + })? + .to_owned(); + } + "--data_path" => { + data_path = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "data_path".to_string(), + "Missing data path".to_string(), + ) + })? + .to_owned(); + } + "--index_path_prefix" => { + index_path_prefix = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "index_path_prefix".to_string(), + "Missing index path prefix".to_string(), + ) + })? 
+ .to_owned(); + } + "--max_degree" | "-R" => { + r = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "max_degree".to_string(), + "Missing max degree".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "max_degree".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--Lbuild" | "-L" => { + l = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "Lbuild".to_string(), + "Missing build complexity".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "Lbuild".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--num_threads" | "-T" => { + num_threads = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "num_threads".to_string(), + "Missing number of threads".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "num_threads".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--build_PQ_bytes" => { + build_pq_bytes = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "build_PQ_bytes".to_string(), + "Missing PQ bytes".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "build_PQ_bytes".to_string(), + format!("ParseIntError: {}", err), + ) + })?; + } + "--use_opq" => { + use_opq = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "use_opq".to_string(), + "Missing use_opq flag".to_string(), + ) + })? + .parse() + .map_err(|err| { + ANNError::log_index_config_error( + "use_opq".to_string(), + format!("ParseBoolError: {}", err), + ) + })?; + } + "--search_DRAM_budget" | "-B" => { + search_ram_limit_gb = iter + .next() + .ok_or_else(|| { + ANNError::log_index_config_error( + "search_DRAM_budget".to_string(), + "Missing search_DRAM_budget flag".to_string(), + ) + })? 
+                    .parse()
+                    .map_err(|err| {
+                        ANNError::log_index_config_error(
+                            "search_DRAM_budget".to_string(),
+                            format!("ParseFloatError: {}", err),
+                        )
+                    })?;
+            }
+            "--build_DRAM_budget" | "-M" => {
+                index_build_ram_limit_gb = iter
+                    .next()
+                    .ok_or_else(|| {
+                        ANNError::log_index_config_error(
+                            "build_DRAM_budget".to_string(),
+                            "Missing build_DRAM_budget flag".to_string(),
+                        )
+                    })?
+                    .parse()
+                    .map_err(|err| {
+                        ANNError::log_index_config_error(
+                            "build_DRAM_budget".to_string(),
+                            format!("ParseFloatError: {}", err),
+                        )
+                    })?;
+            }
+            _ => {
+                return Err(ANNError::log_index_config_error(
+                    String::from(""),
+                    format!("Unknown argument: {}", arg),
+                ));
+            }
+        }
+    }
+
+    if data_type.is_empty()
+        || dist_fn.is_empty()
+        || data_path.is_empty()
+        || index_path_prefix.is_empty()
+    {
+        return Err(ANNError::log_index_config_error(
+            String::from(""),
+            "Missing required arguments".to_string(),
+        ));
+    }
+
+    let metric = dist_fn
+        .parse::<Metric>()
+        .map_err(|err| ANNError::log_index_config_error("dist_fn".to_string(), err.to_string()))?;
+
+    println!(
+        "Starting index build with R: {} Lbuild: {} alpha: {} #threads: {} search_DRAM_budget: {} build_DRAM_budget: {}",
+        r, l, ALPHA, num_threads, search_ram_limit_gb, index_build_ram_limit_gb
+    );
+
+    let err = match data_type.as_str() {
+        "int8" => build_disk_index::<i8>(
+            metric,
+            &data_path,
+            r,
+            l,
+            &index_path_prefix,
+            num_threads,
+            search_ram_limit_gb,
+            index_build_ram_limit_gb,
+            build_pq_bytes as usize,
+            use_opq,
+        ),
+        "uint8" => build_disk_index::<u8>(
+            metric,
+            &data_path,
+            r,
+            l,
+            &index_path_prefix,
+            num_threads,
+            search_ram_limit_gb,
+            index_build_ram_limit_gb,
+            build_pq_bytes as usize,
+            use_opq,
+        ),
+        "float" => build_disk_index::<f32>(
+            metric,
+            &data_path,
+            r,
+            l,
+            &index_path_prefix,
+            num_threads,
+            search_ram_limit_gb,
+            index_build_ram_limit_gb,
+            build_pq_bytes as usize,
+            use_opq,
+        ),
+        "f16" => build_disk_index::<Half>(
+            metric,
+            &data_path,
+            r,
+            l,
+            &index_path_prefix,
+            num_threads,
+            
search_ram_limit_gb, + index_build_ram_limit_gb, + build_pq_bytes as usize, + use_opq, + ), + _ => { + println!("Unsupported type. Use one of int8, uint8, float or f16."); + return Err(ANNError::log_index_config_error( + "data_type".to_string(), + "Invalid data type".to_string(), + )); + } + }; + + match err { + Ok(_) => { + println!("Index build completed successfully"); + Ok(()) + } + Err(err) => { + eprintln!("Error: {:?}", err); + Err(err) + } + } +} + +fn print_help() { + println!("Arguments"); + println!("--help, -h Print information on arguments"); + println!("--data_type data type (required)"); + println!("--dist_fn distance function (required)"); + println!("--data_path Input data file in bin format (required)"); + println!("--index_path_prefix Path prefix for saving index file components (required)"); + println!("--max_degree, -R Maximum graph degree (default: 64)"); + println!("--Lbuild, -L Build complexity, higher value results in better graphs (default: 100)"); + println!("--search_DRAM_budget Bound on the memory footprint of the index at search time in GB. Once built, the index will use up only the specified RAM limit, the rest will reside on disk"); + println!("--build_DRAM_budget Limit on the memory allowed for building the index in GB"); + println!("--num_threads, -T Number of threads used for building index (defaults to num of CPU logic cores)"); + println!("--build_PQ_bytes Number of PQ bytes to build the index; 0 for full precision build (default: 0)"); + println!("--use_opq Set true for OPQ compression while using PQ distance comparisons for building the index, and false for PQ compression (default: false)"); +} diff --git a/rust/cmd_drivers/build_memory_index/Cargo.toml b/rust/cmd_drivers/build_memory_index/Cargo.toml new file mode 100644 index 000000000..eb4708d84 --- /dev/null +++ b/rust/cmd_drivers/build_memory_index/Cargo.toml @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
+[package]
+name = "build_memory_index"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+clap = { version = "4.3.8", features = ["derive"] }
+diskann = { path = "../../diskann" }
+logger = { path = "../../logger" }
+vector = { path = "../../vector" }
+
diff --git a/rust/cmd_drivers/build_memory_index/src/args.rs b/rust/cmd_drivers/build_memory_index/src/args.rs
new file mode 100644
index 000000000..ede31f2db
--- /dev/null
+++ b/rust/cmd_drivers/build_memory_index/src/args.rs
@@ -0,0 +1,62 @@
+use clap::{Parser, ValueEnum};
+
+#[derive(Debug, Clone, ValueEnum)]
+enum DataType {
+    /// Float data type.
+    Float,
+
+    /// Half data type.
+    FP16,
+}
+
+#[derive(Debug, Clone, ValueEnum)]
+enum DistanceFunction {
+    /// Euclidean distance.
+    L2,
+
+    /// Cosine distance.
+    Cosine,
+}
+
+#[derive(Debug, Parser)]
+struct BuildMemoryIndexArgs {
+    /// Data type of the vectors.
+    #[clap(long, default_value = "float")]
+    pub data_type: DataType,
+
+    /// Distance function to use.
+    #[clap(long, default_value = "l2")]
+    pub dist_fn: DistanceFunction,
+
+    /// Path to the data file. The file should be in the format specified by the `data_type` argument.
+    #[clap(long, short, required = true)]
+    pub data_path: String,
+
+    /// Path to the index file. The index will be saved to this prefixed name.
+    #[clap(long, short, required = true)]
+    pub index_path_prefix: String,
+
+    /// Number of max out degree from a vertex.
+    #[clap(long, default_value = "32")]
+    pub max_degree: usize,
+
+    /// Number of candidates to consider when building out edges
+    #[clap(long, short, default_value = "50")]
+    pub l_build: usize,
+
+    /// Alpha to use to build diverse edges
+    #[clap(long, short, default_value = "1.0")]
+    pub alpha: f32,
+
+    /// Number of threads to use.
+    #[clap(long, short, default_value = "1")]
+    pub num_threads: u8,
+
+    /// Number of PQ bytes to use.
+    #[clap(long, short, default_value = "8")]
+    pub build_pq_bytes: usize,
+
+    /// Use opq?
+    #[clap(long, short, default_value = "false")]
+    pub use_opq: bool,
+}
diff --git a/rust/cmd_drivers/build_memory_index/src/main.rs b/rust/cmd_drivers/build_memory_index/src/main.rs
new file mode 100644
index 000000000..cdccc0061
--- /dev/null
+++ b/rust/cmd_drivers/build_memory_index/src/main.rs
@@ -0,0 +1,174 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use clap::{Parser, ValueEnum};
+use std::path::PathBuf;
+
+use diskann::{
+    common::ANNResult,
+    index::create_inmem_index,
+    model::{
+        vertex::{DIM_104, DIM_128, DIM_256},
+        IndexConfiguration, IndexWriteParametersBuilder,
+    },
+    utils::round_up,
+    utils::{load_metadata_from_file, Timer},
+};
+
+use vector::{FullPrecisionDistance, Half, Metric};
+
+/// The main function to build an in-memory index
+#[allow(clippy::too_many_arguments)]
+fn build_in_memory_index<T>(
+    metric: Metric,
+    data_path: &str,
+    r: u32,
+    l: u32,
+    alpha: f32,
+    save_path: &str,
+    num_threads: u32,
+    _use_pq_build: bool,
+    _num_pq_bytes: usize,
+    use_opq: bool,
+) -> ANNResult<()>
+where
+    T: Default + Copy + Sync + Send + Into<f32>,
+    [T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
+    [T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
+    [T; DIM_256]: FullPrecisionDistance<T, DIM_256>,
+{
+    let index_write_parameters = IndexWriteParametersBuilder::new(l, r)
+        .with_alpha(alpha)
+        .with_saturate_graph(false)
+        .with_num_threads(num_threads)
+        .build();
+
+    let (data_num, data_dim) = load_metadata_from_file(data_path)?;
+
+    let config = IndexConfiguration::new(
+        metric,
+        data_dim,
+        round_up(data_dim as u64, 8_u64) as usize,
+        data_num,
+        false,
+        0,
+        use_opq,
+        0,
+        1f32,
+        index_write_parameters,
+    );
+    let mut index = create_inmem_index::<T>(config)?;
+
+    let timer = Timer::new();
+
+    index.build(data_path, data_num)?;
+
+    let diff = timer.elapsed();
+
+    println!("Indexing time: {}", diff.as_secs_f64());
+
+    
index.save(save_path)?; + + Ok(()) +} + +fn main() -> ANNResult<()> { + let args = BuildMemoryIndexArgs::parse(); + + let _use_pq_build = args.build_pq_bytes > 0; + + println!( + "Starting index build with R: {} Lbuild: {} alpha: {} #threads: {}", + args.max_degree, args.l_build, args.alpha, args.num_threads + ); + + let err = match args.data_type { + DataType::Float => build_in_memory_index::( + args.dist_fn, + &args.data_path.to_string_lossy(), + args.max_degree, + args.l_build, + args.alpha, + &args.index_path_prefix, + args.num_threads, + _use_pq_build, + args.build_pq_bytes, + args.use_opq, + ), + DataType::FP16 => build_in_memory_index::( + args.dist_fn, + &args.data_path.to_string_lossy(), + args.max_degree, + args.l_build, + args.alpha, + &args.index_path_prefix, + args.num_threads, + _use_pq_build, + args.build_pq_bytes, + args.use_opq, + ), + }; + + match err { + Ok(_) => { + println!("Index build completed successfully"); + Ok(()) + } + Err(err) => { + eprintln!("Error: {:?}", err); + Err(err) + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug)] +enum DataType { + /// Float data type. + Float, + + /// Half data type. + FP16, +} + +#[derive(Debug, Parser)] +struct BuildMemoryIndexArgs { + /// data type (required) + #[arg(long = "data_type", default_value = "float")] + pub data_type: DataType, + + /// Distance function to use. + #[arg(long = "dist_fn", default_value = "l2")] + pub dist_fn: Metric, + + /// Path to the data file. The file should be in the format specified by the `data_type` argument. + #[arg(long = "data_path", short, required = true)] + pub data_path: PathBuf, + + /// Path to the index file. The index will be saved to this prefixed name. + #[arg(long = "index_path_prefix", short, required = true)] + pub index_path_prefix: String, + + /// Number of max out degree from a vertex. 
+ #[arg(long = "max_degree", short = 'R', default_value = "64")] + pub max_degree: u32, + + /// Number of candidates to consider when building out edges + #[arg(long = "l_build", short = 'L', default_value = "100")] + pub l_build: u32, + + /// alpha controls density and diameter of graph, set 1 for sparse graph, 1.2 or 1.4 for denser graphs with lower diameter + #[arg(long, short, default_value = "1.2")] + pub alpha: f32, + + /// Number of threads to use. + #[arg(long = "num_threads", short = 'T', default_value = "1")] + pub num_threads: u32, + + /// Number of PQ bytes to build the index; 0 for full precision build + #[arg(long = "build_pq_bytes", short, default_value = "0")] + pub build_pq_bytes: usize, + + /// Set true for OPQ compression while using PQ distance comparisons for building the index, and false for PQ compression + #[arg(long = "use_opq", short, default_value = "false")] + pub use_opq: bool, +} diff --git a/rust/cmd_drivers/convert_f32_to_bf16/Cargo.toml b/rust/cmd_drivers/convert_f32_to_bf16/Cargo.toml new file mode 100644 index 000000000..1993aab9d --- /dev/null +++ b/rust/cmd_drivers/convert_f32_to_bf16/Cargo.toml @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. +[package] +name = "convert_f32_to_bf16" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +half = "2.2.1" diff --git a/rust/cmd_drivers/convert_f32_to_bf16/src/main.rs b/rust/cmd_drivers/convert_f32_to_bf16/src/main.rs new file mode 100644 index 000000000..87b4fbaf3 --- /dev/null +++ b/rust/cmd_drivers/convert_f32_to_bf16/src/main.rs @@ -0,0 +1,154 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */
+use half::{bf16, f16};
+use std::env;
+use std::fs::{File, OpenOptions};
+use std::io::{self, Read, Write, BufReader, BufWriter};
+
+enum F16OrBF16 {
+    F16(f16),
+    BF16(bf16),
+}
+
+fn main() -> io::Result<()> {
+    // Retrieve command-line arguments
+    let args: Vec<String> = env::args().collect();
+
+    match args.len() {
+        3|4|5|6=> {},
+        _ => {
+            print_usage();
+            std::process::exit(1);
+        }
+    }
+
+    // Retrieve the input and output file paths from the arguments
+    let input_file_path = &args[1];
+    let output_file_path = &args[2];
+    let use_f16 = args.len() >= 4 && args[3] == "f16";
+    let save_as_float = args.len() >= 5 && args[4] == "save_as_float";
+    let batch_size = if args.len() >= 6 { args[5].parse::<i32>().expect("batch_size must be an integer") } else { 100000 };
+    println!("use_f16: {}", use_f16);
+    println!("save_as_float: {}", save_as_float);
+    println!("batch_size: {}", batch_size);
+
+    // Open the input file for reading
+    let mut input_file = BufReader::new(File::open(input_file_path)?);
+
+    // Open the output file for writing
+    let mut output_file = BufWriter::new(OpenOptions::new().write(true).create(true).open(output_file_path)?);
+
+    // Read the first 8 bytes as metadata
+    let mut metadata = [0; 8];
+    input_file.read_exact(&mut metadata)?;
+
+    // Write the metadata to the output file
+    output_file.write_all(&metadata)?;
+
+    // Extract the number of points and dimension from the metadata
+    let num_points = i32::from_le_bytes(metadata[..4].try_into().unwrap());
+    let dimension = i32::from_le_bytes(metadata[4..].try_into().unwrap());
+    // Round up so the trailing partial batch is converted too; plain integer
+    // division silently dropped the last num_points % batch_size points
+    let num_batches = (num_points + batch_size - 1) / batch_size;
+    let mut points_processed = 0;
+    let numbers_to_print = 2;
+    let mut numbers_printed = 0;
+    let mut num_fb16_wins = 0;
+    let mut num_f16_wins = 0;
+    let mut bf16_overflow = 0;
+    let mut f16_overflow = 0;
+
+    // Process each batch of points
+    for batch_idx in 0..num_batches {
+        // The last batch may contain fewer than batch_size points
+        let points_in_batch = std::cmp::min(batch_size, num_points - batch_idx * batch_size);
+        let mut buffer = vec![0; (dimension * 4 * points_in_batch) as usize];
+        match input_file.read_exact(&mut buffer){
+            Ok(()) => {
+                // Convert the float32 data to bf16
+                let half_data: Vec<F16OrBF16> = buffer
+                    .chunks_exact(4)
+                    .map(|chunk| {
+                        let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
+                        let converted_bf16 = bf16::from_f32(value);
+                        let converted_f16 = f16::from_f32(value);
+                        let distance_f16 = (converted_f16.to_f32() - value).abs();
+                        let distance_bf16 = (converted_bf16.to_f32() - value).abs();
+
+                        if distance_f16 < distance_bf16 {
+                            num_f16_wins += 1;
+                        } else {
+                            num_fb16_wins += 1;
+                        }
+
+                        if (converted_bf16 == bf16::INFINITY) || (converted_bf16 == bf16::NEG_INFINITY) {
+                            bf16_overflow += 1;
+                        }
+
+                        if (converted_f16 == f16::INFINITY) || (converted_f16 == f16::NEG_INFINITY) {
+                            f16_overflow += 1;
+                        }
+
+                        if numbers_printed < numbers_to_print {
+                            numbers_printed += 1;
+                            println!("f32 value: {} f16 value: {} | distance {}, bf16 value: {} | distance {},",
+                                value, converted_f16, converted_f16.to_f32() - value, converted_bf16, converted_bf16.to_f32() - value);
+                        }
+
+                        if use_f16 {
+                            F16OrBF16::F16(converted_f16)
+                        } else {
+                            F16OrBF16::BF16(converted_bf16)
+                        }
+                    })
+                    .collect();
+
+                points_processed += points_in_batch;
+
+                match save_as_float {
+                    true => {
+                        for float_val in half_data {
+                            match float_val {
+                                F16OrBF16::F16(f16_val) => output_file.write_all(&f16_val.to_f32().to_le_bytes())?,
+                                F16OrBF16::BF16(bf16_val) => output_file.write_all(&bf16_val.to_f32().to_le_bytes())?,
+                            }
+                        }
+                    }
+                    false => {
+                        for float_val in half_data {
+                            match float_val {
+                                F16OrBF16::F16(f16_val) => output_file.write_all(&f16_val.to_le_bytes())?,
+                                F16OrBF16::BF16(bf16_val) => output_file.write_all(&bf16_val.to_le_bytes())?,
+                            }
+                        }
+                    }
+                }
+
+                // Print the number of points processed
+                println!("Processed {} points out of {}", points_processed, num_points);
+            }
+            // Should not happen now that batch sizes are exact, but stop cleanly
+            Err(ref e) if e.kind() == io::ErrorKind::UnexpectedEof => {
+                break;
+            }
+            Err(err) => {
+                println!("Error: {}", err);
+                break;
+            }
+        };
+    }
+
+    // Report the summary on the normal path as well; previously it was only
+    // printed when the reader hit an unexpected EOF, so a clean run was silent
+    println!("Conversion completed! {} of times f16 wins | overflow count {}, {} of times bf16 wins | overflow count{}",
+        num_f16_wins, f16_overflow, num_fb16_wins, bf16_overflow);
+    // Flush explicitly: BufWriter's Drop swallows write errors
+    output_file.flush()?;
+
+    Ok(())
+}
+
+/// Prints the usage information
+fn print_usage() {
+    println!("Usage: program_name input_file output_file [f16] [save_as_float] [batch_size]]");
+    println!("specify f16 to downscale to f16. otherwise, downscale to bf16.");
+    println!("specify save_as_float to downcast to f16 or bf16, and upcast to float before saving the output data. otherwise, the data will be saved as half type.");
+    println!("specify the batch_size as a int, the default value is 100000.");
+}
+
diff --git a/rust/cmd_drivers/load_and_insert_memory_index/Cargo.toml b/rust/cmd_drivers/load_and_insert_memory_index/Cargo.toml
new file mode 100644
index 000000000..cbb4e1e3c
--- /dev/null
+++ b/rust/cmd_drivers/load_and_insert_memory_index/Cargo.toml
@@ -0,0 +1,14 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT license.
+[package]
+name = "load_and_insert_memory_index"
+version = "0.1.0"
+edition = "2021"
+
+# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
+
+[dependencies]
+diskann = { path = "../../diskann" }
+logger = { path = "../../logger" }
+vector = { path = "../../vector" }
+
diff --git a/rust/cmd_drivers/load_and_insert_memory_index/src/main.rs b/rust/cmd_drivers/load_and_insert_memory_index/src/main.rs
new file mode 100644
index 000000000..41680460a
--- /dev/null
+++ b/rust/cmd_drivers/load_and_insert_memory_index/src/main.rs
@@ -0,0 +1,313 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+use std::env;
+
+use diskann::{
+    common::{ANNResult, ANNError},
+    index::create_inmem_index,
+    utils::round_up,
+    model::{
+        IndexWriteParametersBuilder,
+        IndexConfiguration,
+        vertex::{DIM_128, DIM_256, DIM_104}
+    },
+    utils::{Timer, load_metadata_from_file},
+};
+
+use vector::{Metric, FullPrecisionDistance, Half};
+
+// The main function to build an in-memory index
+#[allow(clippy::too_many_arguments)]
+fn load_and_insert_in_memory_index<T>(
+    metric: Metric,
+    data_path: &str,
+    delta_path: &str,
+    r: u32,
+    l: u32,
+    alpha: f32,
+    save_path: &str,
+    num_threads: u32,
+    _use_pq_build: bool,
+    _num_pq_bytes: usize,
+    use_opq: bool
+) -> ANNResult<()>
+where
+    T: Default + Copy + Sync + Send + Into<f32>,
+    [T; DIM_104]: FullPrecisionDistance<T, DIM_104>,
+    [T; DIM_128]: FullPrecisionDistance<T, DIM_128>,
+    [T; DIM_256]: FullPrecisionDistance<T, DIM_256>
+{
+    let index_write_parameters = IndexWriteParametersBuilder::new(l, r)
+        .with_alpha(alpha)
+        .with_saturate_graph(false)
+        .with_num_threads(num_threads)
+        .build();
+
+    let (data_num, data_dim) = load_metadata_from_file(&format!("{}.data", data_path))?;
+
+    let config = IndexConfiguration::new(
+        metric,
+        data_dim,
+        round_up(data_dim as u64, 8_u64) as usize,
+        data_num,
+        false,
+        0,
+        use_opq,
+        0,
+        2.0f32,
+        index_write_parameters,
+    );
+    let mut index = create_inmem_index::<T>(config)?;
+
+    let timer = Timer::new();
+
+    index.load(data_path, data_num)?;
+
+    let diff = timer.elapsed();
+
+    println!("Initial indexing time: {}", diff.as_secs_f64());
+
+    let (delta_data_num, _) = load_metadata_from_file(delta_path)?;
+
+    index.insert(delta_path, delta_data_num)?;
+
+    index.save(save_path)?;
+
+    Ok(())
+}
+
+fn main() -> ANNResult<()> {
+    let mut data_type = String::new();
+    let mut dist_fn = String::new();
+    let mut data_path = String::new();
+    let mut insert_path = String::new();
+    let mut index_path_prefix = String::new();
+
+    let mut num_threads = 0u32;
+    let mut r = 64u32;
+    let mut l = 100u32;
+
+    let mut alpha = 1.2f32;
+    let mut 
build_pq_bytes = 0u32; + let mut _use_pq_build = false; + let mut use_opq = false; + + let args: Vec = env::args().collect(); + let mut iter = args.iter().skip(1).peekable(); + + while let Some(arg) = iter.next() { + match arg.as_str() { + "--help" | "-h" => { + print_help(); + return Ok(()); + } + "--data_type" => { + data_type = iter.next().ok_or_else(|| ANNError::log_index_config_error( + "data_type".to_string(), + "Missing data type".to_string()) + )? + .to_owned(); + } + "--dist_fn" => { + dist_fn = iter.next().ok_or_else(|| ANNError::log_index_config_error( + "dist_fn".to_string(), + "Missing distance function".to_string()) + )? + .to_owned(); + } + "--data_path" => { + data_path = iter.next().ok_or_else(|| ANNError::log_index_config_error( + "data_path".to_string(), + "Missing data path".to_string()) + )? + .to_owned(); + } + "--insert_path" => { + insert_path = iter.next().ok_or_else(|| ANNError::log_index_config_error( + "insert_path".to_string(), + "Missing insert path".to_string()) + )? + .to_owned(); + } + "--index_path_prefix" => { + index_path_prefix = iter.next().ok_or_else(|| ANNError::log_index_config_error( + "index_path_prefix".to_string(), + "Missing index path prefix".to_string()))? + .to_owned(); + } + "--max_degree" | "-R" => { + r = iter.next().ok_or_else(|| ANNError::log_index_config_error( + "max_degree".to_string(), + "Missing max degree".to_string()))? + .parse() + .map_err(|err| ANNError::log_index_config_error( + "max_degree".to_string(), + format!("ParseIntError: {}", err)) + )?; + } + "--Lbuild" | "-L" => { + l = iter.next().ok_or_else(|| ANNError::log_index_config_error( + "Lbuild".to_string(), + "Missing build complexity".to_string()))? + .parse() + .map_err(|err| ANNError::log_index_config_error( + "Lbuild".to_string(), + format!("ParseIntError: {}", err)) + )?; + } + "--alpha" => { + alpha = iter.next().ok_or_else(|| ANNError::log_index_config_error( + "alpha".to_string(), + "Missing alpha".to_string()))? 
+ .parse() + .map_err(|err| ANNError::log_index_config_error( + "alpha".to_string(), + format!("ParseFloatError: {}", err)) + )?; + } + "--num_threads" | "-T" => { + num_threads = iter.next().ok_or_else(|| ANNError::log_index_config_error( + "num_threads".to_string(), + "Missing number of threads".to_string()))? + .parse() + .map_err(|err| ANNError::log_index_config_error( + "num_threads".to_string(), + format!("ParseIntError: {}", err)) + )?; + } + "--build_PQ_bytes" => { + build_pq_bytes = iter.next().ok_or_else(|| ANNError::log_index_config_error( + "build_PQ_bytes".to_string(), + "Missing PQ bytes".to_string()))? + .parse() + .map_err(|err| ANNError::log_index_config_error( + "build_PQ_bytes".to_string(), + format!("ParseIntError: {}", err)) + )?; + } + "--use_opq" => { + use_opq = iter.next().ok_or_else(|| ANNError::log_index_config_error( + "use_opq".to_string(), + "Missing use_opq flag".to_string()))? + .parse() + .map_err(|err| ANNError::log_index_config_error( + "use_opq".to_string(), + format!("ParseBoolError: {}", err)) + )?; + } + _ => { + return Err(ANNError::log_index_config_error(String::from(""), format!("Unknown argument: {}", arg))); + } + } + } + + if data_type.is_empty() + || dist_fn.is_empty() + || data_path.is_empty() + || index_path_prefix.is_empty() + { + return Err(ANNError::log_index_config_error(String::from(""), "Missing required arguments".to_string())); + } + + _use_pq_build = build_pq_bytes > 0; + + let metric = dist_fn + .parse::() + .map_err(|err| ANNError::log_index_config_error( + "dist_fn".to_string(), + err.to_string(), + ))?; + + println!( + "Starting index build with R: {} Lbuild: {} alpha: {} #threads: {}", + r, l, alpha, num_threads + ); + + match data_type.as_str() { + "int8" => { + load_and_insert_in_memory_index::( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + )?; + } + "uint8" => { + 
load_and_insert_in_memory_index::( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + )?; + } + "float" => { + load_and_insert_in_memory_index::( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + )?; + } + "f16" => { + load_and_insert_in_memory_index::( + metric, + &data_path, + &insert_path, + r, + l, + alpha, + &index_path_prefix, + num_threads, + _use_pq_build, + build_pq_bytes as usize, + use_opq, + )? + } + _ => { + println!("Unsupported type. Use one of int8, uint8 or float."); + return Err(ANNError::log_index_config_error("data_type".to_string(), "Invalid data type".to_string())); + } + } + + Ok(()) +} + +fn print_help() { + println!("Arguments"); + println!("--help, -h Print information on arguments"); + println!("--data_type data type (required)"); + println!("--dist_fn distance function (required)"); + println!("--data_path Input data file in bin format for initial build (required)"); + println!("--insert_path Input data file in bin format for insert (required)"); + println!("--index_path_prefix Path prefix for saving index file components (required)"); + println!("--max_degree, -R Maximum graph degree (default: 64)"); + println!("--Lbuild, -L Build complexity, higher value results in better graphs (default: 100)"); + println!("--alpha alpha controls density and diameter of graph, set 1 for sparse graph, 1.2 or 1.4 for denser graphs with lower diameter (default: 1.2)"); + println!("--num_threads, -T Number of threads used for building index (defaults to num of CPU logic cores)"); + println!("--build_PQ_bytes Number of PQ bytes to build the index; 0 for full precision build (default: 0)"); + println!("--use_opq Set true for OPQ compression while using PQ distance comparisons for building the index, and false for PQ compression (default: false)"); +} + 
diff --git a/rust/cmd_drivers/search_memory_index/Cargo.toml b/rust/cmd_drivers/search_memory_index/Cargo.toml new file mode 100644 index 000000000..cba3709aa --- /dev/null +++ b/rust/cmd_drivers/search_memory_index/Cargo.toml @@ -0,0 +1,16 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. +[package] +name = "search_memory_index" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +bytemuck = "1.13.1" +diskann = { path = "../../diskann" } +num_cpus = "1.15.0" +rayon = "1.7.0" +vector = { path = "../../vector" } + diff --git a/rust/cmd_drivers/search_memory_index/src/main.rs b/rust/cmd_drivers/search_memory_index/src/main.rs new file mode 100644 index 000000000..ca4d4cd1d --- /dev/null +++ b/rust/cmd_drivers/search_memory_index/src/main.rs @@ -0,0 +1,430 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +mod search_index_utils; +use bytemuck::Pod; +use diskann::{ + common::{ANNError, ANNResult}, + index, + model::{ + configuration::index_write_parameters::{default_param_vals, IndexWriteParametersBuilder}, + vertex::{DIM_104, DIM_128, DIM_256}, + IndexConfiguration, + }, + utils::{load_metadata_from_file, save_bin_u32}, +}; +use std::{env, path::Path, process::exit, time::Instant}; +use vector::{FullPrecisionDistance, Half, Metric}; + +use rayon::prelude::*; + +#[allow(clippy::too_many_arguments)] +fn search_memory_index( + metric: Metric, + index_path: &str, + result_path_prefix: &str, + query_file: &str, + truthset_file: &str, + num_threads: u32, + recall_at: u32, + print_all_recalls: bool, + l_vec: &Vec, + show_qps_per_thread: bool, + fail_if_recall_below: f32, +) -> ANNResult +where + T: Default + Copy + Sized + Pod + Sync + Send + Into, + [T; DIM_104]: FullPrecisionDistance, + [T; DIM_128]: FullPrecisionDistance, + [T; DIM_256]: FullPrecisionDistance, +{ + // Load the query file + let (query, query_num, query_dim, query_aligned_dim) = + search_index_utils::load_aligned_bin::(query_file)?; + let mut gt_dim: usize = 0; + let mut gt_ids: Option> = None; + let mut gt_dists: Option> = None; + + // Check for ground truth + let mut calc_recall_flag = false; + if !truthset_file.is_empty() && Path::new(truthset_file).exists() { + let ret = search_index_utils::load_truthset(truthset_file)?; + gt_ids = Some(ret.0); + gt_dists = ret.1; + let gt_num = ret.2; + gt_dim = ret.3; + + if gt_num != query_num { + println!("Error. Mismatch in number of queries and ground truth data"); + } + + calc_recall_flag = true; + } else { + println!( + "Truthset file {} not found. Not computing recall", + truthset_file + ); + } + + let num_frozen_pts = search_index_utils::get_graph_num_frozen_points(index_path)?; + + // C++ uses the max given L value, so we do the same here. 
Max degree is never specified in C++ so use the rust default + let index_write_params = IndexWriteParametersBuilder::new( + *l_vec.iter().max().unwrap(), + default_param_vals::MAX_DEGREE, + ) + .with_num_threads(num_threads) + .build(); + + let (index_num_points, _) = load_metadata_from_file(&format!("{}.data", index_path))?; + + let index_config = IndexConfiguration::new( + metric, + query_dim, + query_aligned_dim, + index_num_points, + false, + 0, + false, + num_frozen_pts, + 1f32, + index_write_params, + ); + let mut index = index::create_inmem_index::(index_config)?; + + index.load(index_path, index_num_points)?; + + println!("Using {} threads to search", num_threads); + let qps_title = if show_qps_per_thread { + "QPS/thread" + } else { + "QPS" + }; + let mut table_width = 4 + 12 + 18 + 20 + 15; + let mut table_header_str = format!( + "{:>4}{:>12}{:>18}{:>20}{:>15}", + "Ls", qps_title, "Avg dist cmps", "Mean Latency (mus)", "99.9 Latency" + ); + + let first_recall: u32 = if print_all_recalls { 1 } else { recall_at }; + let mut recalls_to_print: usize = 0; + if calc_recall_flag { + for curr_recall in first_recall..=recall_at { + let recall_str = format!("Recall@{}", curr_recall); + table_header_str.push_str(&format!("{:>12}", recall_str)); + recalls_to_print = (recall_at + 1 - first_recall) as usize; + table_width += recalls_to_print * 12; + } + } + + println!("{}", table_header_str); + println!("{}", "=".repeat(table_width)); + + let mut query_result_ids: Vec> = + vec![vec![0; query_num * recall_at as usize]; l_vec.len()]; + let mut latency_stats: Vec = vec![0.0; query_num]; + let mut cmp_stats: Vec = vec![0; query_num]; + let mut best_recall = 0.0; + + std::env::set_var("RAYON_NUM_THREADS", num_threads.to_string()); + + for test_id in 0..l_vec.len() { + let l_value = l_vec[test_id]; + + if l_value < recall_at { + println!( + "Ignoring search with L:{} since it's smaller than K:{}", + l_value, recall_at + ); + continue; + } + + let zipped = cmp_stats + 
.par_iter_mut() + .zip(latency_stats.par_iter_mut()) + .zip(query_result_ids[test_id].par_chunks_mut(recall_at as usize)) + .zip(query.par_chunks(query_aligned_dim)); + + let start = Instant::now(); + zipped.for_each(|(((cmp, latency), query_result), query_chunk)| { + let query_start = Instant::now(); + *cmp = index + .search(query_chunk, recall_at as usize, l_value, query_result) + .unwrap(); + + let query_end = Instant::now(); + let diff = query_end.duration_since(query_start); + *latency = diff.as_micros() as f32; + }); + let diff = Instant::now().duration_since(start); + + let mut displayed_qps: f32 = query_num as f32 / diff.as_secs_f32(); + if show_qps_per_thread { + displayed_qps /= num_threads as f32; + } + + let mut recalls: Vec = Vec::new(); + if calc_recall_flag { + recalls.reserve(recalls_to_print); + for curr_recall in first_recall..=recall_at { + recalls.push(search_index_utils::calculate_recall( + query_num, + gt_ids.as_ref().unwrap(), + >_dists, + gt_dim, + &query_result_ids[test_id], + recall_at, + curr_recall, + )? as f32); + } + } + + latency_stats.sort_by(|a, b| a.partial_cmp(b).unwrap()); + let mean_latency = latency_stats.iter().sum::() / query_num as f32; + let avg_cmps = cmp_stats.iter().sum::() as f32 / query_num as f32; + + let mut stat_str = format!( + "{: >4}{: >12.2}{: >18.2}{: >20.2}{: >15.2}", + l_value, + displayed_qps, + avg_cmps, + mean_latency, + latency_stats[(0.999 * query_num as f32).round() as usize] + ); + + for recall in recalls.iter() { + stat_str.push_str(&format!("{: >12.2}", recall)); + best_recall = f32::max(best_recall, *recall); + } + + println!("{}", stat_str); + } + + println!("Done searching. 
Now saving results"); + for (test_id, l_value) in l_vec.iter().enumerate() { + if *l_value < recall_at { + println!( + "Ignoring all search with L: {} since it's smaller than K: {}", + l_value, recall_at + ); + } + + let cur_result_path = format!("{}_{}_idx_uint32.bin", result_path_prefix, l_value); + save_bin_u32( + &cur_result_path, + query_result_ids[test_id].as_slice(), + query_num, + recall_at as usize, + 0, + )?; + } + + if best_recall >= fail_if_recall_below { + Ok(0) + } else { + Ok(-1) + } +} + +fn main() -> ANNResult<()> { + let return_val: i32; + { + let mut data_type: String = String::new(); + let mut metric: Option = None; + let mut index_path: String = String::new(); + let mut result_path_prefix: String = String::new(); + let mut query_file: String = String::new(); + let mut truthset_file: String = String::new(); + let mut num_cpus: u32 = num_cpus::get() as u32; + let mut recall_at: Option = None; + let mut print_all_recalls: bool = false; + let mut l_vec: Vec = Vec::new(); + let mut show_qps_per_thread: bool = false; + let mut fail_if_recall_below: f32 = 0.0; + + let args: Vec = env::args().collect(); + let mut iter = args.iter().skip(1).peekable(); + while let Some(arg) = iter.next() { + let ann_error = + || ANNError::log_index_config_error(String::from(arg), format!("Missing {}", arg)); + match arg.as_str() { + "--help" | "-h" => { + print_help(); + return Ok(()); + } + "--data_type" => { + data_type = iter.next().ok_or_else(ann_error)?.to_owned(); + } + "--dist_fn" => { + metric = Some(iter.next().ok_or_else(ann_error)?.parse().map_err(|err| { + ANNError::log_index_config_error( + String::from(arg), + format!("ParseError: {}", err), + ) + })?); + } + "--index_path_prefix" => { + index_path = iter.next().ok_or_else(ann_error)?.to_owned(); + } + "--result_path" => { + result_path_prefix = iter.next().ok_or_else(ann_error)?.to_owned(); + } + "--query_file" => { + query_file = iter.next().ok_or_else(ann_error)?.to_owned(); + } + "--gt_file" => { + 
truthset_file = iter.next().ok_or_else(ann_error)?.to_owned(); + } + "--recall_at" | "-K" => { + recall_at = + Some(iter.next().ok_or_else(ann_error)?.parse().map_err(|err| { + ANNError::log_index_config_error( + String::from(arg), + format!("ParseError: {}", err), + ) + })?); + } + "--print_all_recalls" => { + print_all_recalls = true; + } + "--search_list" | "-L" => { + while iter.peek().is_some() && !iter.peek().unwrap().starts_with('-') { + l_vec.push(iter.next().ok_or_else(ann_error)?.parse().map_err(|err| { + ANNError::log_index_config_error( + String::from(arg), + format!("ParseError: {}", err), + ) + })?); + } + } + "--num_threads" => { + num_cpus = iter.next().ok_or_else(ann_error)?.parse().map_err(|err| { + ANNError::log_index_config_error( + String::from(arg), + format!("ParseError: {}", err), + ) + })?; + } + "--qps_per_thread" => { + show_qps_per_thread = true; + } + "--fail_if_recall_below" => { + fail_if_recall_below = + iter.next().ok_or_else(ann_error)?.parse().map_err(|err| { + ANNError::log_index_config_error( + String::from(arg), + format!("ParseError: {}", err), + ) + })?; + } + _ => { + return Err(ANNError::log_index_error(format!( + "Unknown argument: {}", + arg + ))); + } + } + } + + if metric.is_none() { + return Err(ANNError::log_index_error(String::from("No metric given!"))); + } else if recall_at.is_none() { + return Err(ANNError::log_index_error(String::from( + "No recall_at given!", + ))); + } + + // Seems like float is the only supported data type for FullPrecisionDistance right now, + // but keep the structure in place here for future data types + match data_type.as_str() { + "float" => { + return_val = search_memory_index::( + metric.unwrap(), + &index_path, + &result_path_prefix, + &query_file, + &truthset_file, + num_cpus, + recall_at.unwrap(), + print_all_recalls, + &l_vec, + show_qps_per_thread, + fail_if_recall_below, + )?; + } + "int8" => { + return_val = search_memory_index::( + metric.unwrap(), + &index_path, + 
&result_path_prefix, + &query_file, + &truthset_file, + num_cpus, + recall_at.unwrap(), + print_all_recalls, + &l_vec, + show_qps_per_thread, + fail_if_recall_below, + )?; + } + "uint8" => { + return_val = search_memory_index::( + metric.unwrap(), + &index_path, + &result_path_prefix, + &query_file, + &truthset_file, + num_cpus, + recall_at.unwrap(), + print_all_recalls, + &l_vec, + show_qps_per_thread, + fail_if_recall_below, + )?; + } + "f16" => { + return_val = search_memory_index::( + metric.unwrap(), + &index_path, + &result_path_prefix, + &query_file, + &truthset_file, + num_cpus, + recall_at.unwrap(), + print_all_recalls, + &l_vec, + show_qps_per_thread, + fail_if_recall_below, + )?; + } + _ => { + return Err(ANNError::log_index_error(format!( + "Unknown data type: {}!", + data_type + ))); + } + } + } + + // Rust only allows returning values with this method, but this will immediately terminate the program without running destructors on the + // stack. To get around this enclose main function logic in a block so that by the time we return here all destructors have been called. 
+ exit(return_val); +} + +fn print_help() { + println!("Arguments"); + println!("--help, -h Print information on arguments"); + println!("--data_type data type (required)"); + println!("--dist_fn distance function (required)"); + println!("--index_path_prefix Path prefix to the index (required)"); + println!("--result_path Path prefix for saving results of the queries (required)"); + println!("--query_file Query file in binary format"); + println!("--gt_file Ground truth file for the queryset"); + println!("--recall_at, -K Number of neighbors to be returned"); + println!("--print_all_recalls Print recalls at all positions, from 1 up to specified recall_at value"); + println!("--search_list List of L values for search"); + println!("--num_threads, -T Number of threads used for searching the index (defaults to num_cpus::get())"); + println!("--qps_per_thread Print overall QPS divided by the number of threads in the output table"); + println!("--fail_if_recall_below If set to a value >0 and <100%, program returns -1 if best recall found is below this threshold"); +} diff --git a/rust/cmd_drivers/search_memory_index/src/search_index_utils.rs b/rust/cmd_drivers/search_memory_index/src/search_index_utils.rs new file mode 100644 index 000000000..c7b04a47f --- /dev/null +++ b/rust/cmd_drivers/search_memory_index/src/search_index_utils.rs @@ -0,0 +1,186 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license.
+ */ +use bytemuck::{cast_slice, Pod}; +use diskann::{ + common::{ANNError, ANNResult, AlignedBoxWithSlice}, + model::data_store::DatasetDto, + utils::{copy_aligned_data_from_file, is_aligned, round_up}, +}; +use std::collections::HashSet; +use std::fs::File; +use std::io::Read; +use std::mem::size_of; + +pub(crate) fn calculate_recall( + num_queries: usize, + gold_std: &[u32], + gs_dist: &Option>, + dim_gs: usize, + our_results: &[u32], + dim_or: u32, + recall_at: u32, +) -> ANNResult { + let mut total_recall: f64 = 0.0; + let (mut gt, mut res): (HashSet, HashSet) = (HashSet::new(), HashSet::new()); + + for i in 0..num_queries { + gt.clear(); + res.clear(); + + let gt_slice = &gold_std[dim_gs * i..]; + let res_slice = &our_results[dim_or as usize * i..]; + let mut tie_breaker = recall_at as usize; + + if gs_dist.is_some() { + tie_breaker = (recall_at - 1) as usize; + let gt_dist_vec = &gs_dist.as_ref().unwrap()[dim_gs * i..]; + while tie_breaker < dim_gs + && gt_dist_vec[tie_breaker] == gt_dist_vec[(recall_at - 1) as usize] + { + tie_breaker += 1; + } + } + + (0..tie_breaker).for_each(|idx| { + gt.insert(gt_slice[idx]); + }); + + (0..tie_breaker).for_each(|idx| { + res.insert(res_slice[idx]); + }); + + let mut cur_recall: u32 = 0; + for v in gt.iter() { + if res.contains(v) { + cur_recall += 1; + } + } + + total_recall += cur_recall as f64; + } + + Ok(total_recall / num_queries as f64 * (100.0 / recall_at as f64)) +} + +pub(crate) fn get_graph_num_frozen_points(graph_file: &str) -> ANNResult { + let mut file = File::open(graph_file)?; + let mut usize_buffer = [0; size_of::()]; + let mut u32_buffer = [0; size_of::()]; + + file.read_exact(&mut usize_buffer)?; + file.read_exact(&mut u32_buffer)?; + file.read_exact(&mut u32_buffer)?; + file.read_exact(&mut usize_buffer)?; + let file_frozen_pts = usize::from_le_bytes(usize_buffer); + + Ok(file_frozen_pts) +} + +#[inline] +pub(crate) fn load_truthset( + bin_file: &str, +) -> ANNResult<(Vec, Option>, usize, usize)> { + 
let mut file = File::open(bin_file)?; + let actual_file_size = file.metadata()?.len() as usize; + + let mut buffer = [0; size_of::()]; + file.read_exact(&mut buffer)?; + let npts = i32::from_le_bytes(buffer) as usize; + + file.read_exact(&mut buffer)?; + let dim = i32::from_le_bytes(buffer) as usize; + + println!("Metadata: #pts = {npts}, #dims = {dim}... "); + + let expected_file_size_with_dists: usize = + 2 * npts * dim * size_of::() + 2 * size_of::(); + let expected_file_size_just_ids: usize = npts * dim * size_of::() + 2 * size_of::(); + + let truthset_type : i32 = match actual_file_size + { + // This is in the C++ code, but nothing is done in this case. Keeping it here for future reference just in case. + // expected_file_size_just_ids => 2, + x if x == expected_file_size_with_dists => 1, + _ => return Err(ANNError::log_index_error(format!("Error. File size mismatch. File should have bin format, with npts followed by ngt + followed by npts*ngt ids and optionally followed by npts*ngt distance values; actual size: {}, expected: {} or {}", + actual_file_size, + expected_file_size_with_dists, + expected_file_size_just_ids))) + }; + + let mut ids: Vec = vec![0; npts * dim]; + let mut buffer = vec![0; npts * dim * size_of::()]; + file.read_exact(&mut buffer)?; + ids.clone_from_slice(cast_slice::(&buffer)); + + if truthset_type == 1 { + let mut dists: Vec = vec![0.0; npts * dim]; + let mut buffer = vec![0; npts * dim * size_of::()]; + file.read_exact(&mut buffer)?; + dists.clone_from_slice(cast_slice::(&buffer)); + + return Ok((ids, Some(dists), npts, dim)); + } + + Ok((ids, None, npts, dim)) +} + +#[inline] +pub(crate) fn load_aligned_bin( + bin_file: &str, +) -> ANNResult<(AlignedBoxWithSlice, usize, usize, usize)> { + let t_size = size_of::(); + let (npts, dim, file_size): (usize, usize, usize); + { + println!("Reading (with alignment) bin file: {bin_file}"); + let mut file = File::open(bin_file)?; + file_size = file.metadata()?.len() as usize; + + let mut buffer 
= [0; size_of::()]; + file.read_exact(&mut buffer)?; + npts = i32::from_le_bytes(buffer) as usize; + + file.read_exact(&mut buffer)?; + dim = i32::from_le_bytes(buffer) as usize; + } + + let rounded_dim = round_up(dim, 8); + let expected_actual_file_size = npts * dim * size_of::() + 2 * size_of::(); + + if file_size != expected_actual_file_size { + return Err(ANNError::log_index_error(format!( + "ERROR: File size mismatch. Actual size is {} while expected size is {} + npts = {}, #dims = {}, aligned_dim = {}", + file_size, expected_actual_file_size, npts, dim, rounded_dim + ))); + } + + println!("Metadata: #pts = {npts}, #dims = {dim}, aligned_dim = {rounded_dim}..."); + + let alloc_size = npts * rounded_dim; + let alignment = 8 * t_size; + println!( + "allocating aligned memory of {} bytes... ", + alloc_size * t_size + ); + if !is_aligned(alloc_size * t_size, alignment) { + return Err(ANNError::log_index_error(format!( + "Requested memory size is not a multiple of {}. Can not be allocated.", + alignment + ))); + } + + let mut data = AlignedBoxWithSlice::::new(alloc_size, alignment)?; + let dto = DatasetDto { + data: &mut data, + rounded_dim, + }; + + println!("done. Copying data to mem_aligned buffer..."); + + let (_, _) = copy_aligned_data_from_file(bin_file, dto, 0)?; + + Ok((data, npts, dim, rounded_dim)) +} diff --git a/rust/diskann/Cargo.toml b/rust/diskann/Cargo.toml new file mode 100644 index 000000000..a5be54750 --- /dev/null +++ b/rust/diskann/Cargo.toml @@ -0,0 +1,45 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
+[package] +name = "diskann" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +bincode = "1.3.3" +bit-vec = "0.6.3" +byteorder = "1.4.3" +cblas = "0.4.0" +crossbeam = "0.8.2" +half = "2.2.1" +hashbrown = "0.13.2" +num-traits = "0.2.15" +once_cell = "1.17.1" +openblas-src = { version = "0.10.8", features = ["system"] } +rand = { version = "0.8.5", features = [ "small_rng" ] } +rayon = "1.7.0" +serde = { version = "1.0.130", features = ["derive"] } +thiserror = "1.0.40" +winapi = { version = "0.3.9", features = ["errhandlingapi", "fileapi", "ioapiset", "handleapi", "winnt", "minwindef", "basetsd", "winerror", "winbase"] } + +logger = { path = "../logger" } +platform = { path = "../platform" } +vector = { path = "../vector" } + +[build-dependencies] +cc = "1.0.79" + +[dev-dependencies] +approx = "0.5.1" +criterion = "0.5.1" + + +[[bench]] +name = "distance_bench" +harness = false + +[[bench]] +name = "neighbor_bench" +harness = false diff --git a/rust/diskann/benches/distance_bench.rs b/rust/diskann/benches/distance_bench.rs new file mode 100644 index 000000000..885c95bac --- /dev/null +++ b/rust/diskann/benches/distance_bench.rs @@ -0,0 +1,47 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +use rand::{thread_rng, Rng}; +use vector::{FullPrecisionDistance, Metric}; + +// make sure the vector is 256-bit (32 bytes) aligned required by _mm256_load_ps +#[repr(C, align(32))] +struct Vector32ByteAligned { + v: [f32; 256], +} + +fn benchmark_l2_distance_float_rust(c: &mut Criterion) { + let (a, b) = prepare_random_aligned_vectors(); + let mut group = c.benchmark_group("avx-computation"); + group.sample_size(5000); + + group.bench_function("AVX Rust run", |f| { + f.iter(|| { + black_box(<[f32; 256]>::distance_compare( + black_box(&a.v), + black_box(&b.v), + Metric::L2, + )) + }) + }); +} + +// make sure the vector is 256-bit (32 bytes) aligned required by _mm256_load_ps +fn prepare_random_aligned_vectors() -> (Box, Box) { + let a = Box::new(Vector32ByteAligned { + v: [(); 256].map(|_| thread_rng().gen_range(0.0..100.0)), + }); + + let b = Box::new(Vector32ByteAligned { + v: [(); 256].map(|_| thread_rng().gen_range(0.0..100.0)), + }); + + (a, b) +} + +criterion_group!(benches, benchmark_l2_distance_float_rust,); +criterion_main!(benches); + diff --git a/rust/diskann/benches/kmeans_bench.rs b/rust/diskann/benches/kmeans_bench.rs new file mode 100644 index 000000000..c69c16a8c --- /dev/null +++ b/rust/diskann/benches/kmeans_bench.rs @@ -0,0 +1,70 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use criterion::{criterion_group, criterion_main, Criterion}; +use diskann::utils::k_means_clustering; +use rand::Rng; + +const NUM_POINTS: usize = 10000; +const DIM: usize = 100; +const NUM_CENTERS: usize = 256; +const MAX_KMEANS_REPS: usize = 12; + +fn benchmark_kmeans_rust(c: &mut Criterion) { + let mut rng = rand::thread_rng(); + let data: Vec = (0..NUM_POINTS * DIM) + .map(|_| rng.gen_range(-1.0..1.0)) + .collect(); + let centers: Vec = vec![0.0; NUM_CENTERS * DIM]; + + let mut group = c.benchmark_group("kmeans-computation"); + group.sample_size(500); + + group.bench_function("K-Means Rust run", |f| { + f.iter(|| { + // let mut centers_copy = centers.clone(); + let data_copy = data.clone(); + let mut centers_copy = centers.clone(); + k_means_clustering( + &data_copy, + NUM_POINTS, + DIM, + &mut centers_copy, + NUM_CENTERS, + MAX_KMEANS_REPS, + ) + }) + }); +} + +fn benchmark_kmeans_c(c: &mut Criterion) { + let mut rng = rand::thread_rng(); + let data: Vec = (0..NUM_POINTS * DIM) + .map(|_| rng.gen_range(-1.0..1.0)) + .collect(); + let centers: Vec = vec![0.0; NUM_CENTERS * DIM]; + + let mut group = c.benchmark_group("kmeans-computation"); + group.sample_size(500); + + group.bench_function("K-Means C++ Run", |f| { + f.iter(|| { + let data_copy = data.clone(); + let mut centers_copy = centers.clone(); + let _ = k_means_clustering( + data_copy.as_slice(), + NUM_POINTS, + DIM, + centers_copy.as_mut_slice(), + NUM_CENTERS, + MAX_KMEANS_REPS, + ); + }) + }); +} + +criterion_group!(benches, benchmark_kmeans_rust, benchmark_kmeans_c); + +criterion_main!(benches); + diff --git a/rust/diskann/benches/neighbor_bench.rs b/rust/diskann/benches/neighbor_bench.rs new file mode 100644 index 000000000..958acdce2 --- /dev/null +++ b/rust/diskann/benches/neighbor_bench.rs @@ -0,0 +1,49 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use std::time::Duration; + +use criterion::{black_box, criterion_group, criterion_main, Criterion}; + +use diskann::model::{Neighbor, NeighborPriorityQueue}; +use rand::distributions::{Distribution, Uniform}; +use rand::rngs::StdRng; +use rand::SeedableRng; + +fn benchmark_priority_queue_insert(c: &mut Criterion) { + let vec = generate_random_floats(); + let mut group = c.benchmark_group("neighborqueue-insert"); + group.measurement_time(Duration::from_secs(3)).sample_size(500); + + let mut queue = NeighborPriorityQueue::with_capacity(64_usize); + group.bench_function("Neighbor Priority Queue Insert", |f| { + f.iter(|| { + queue.clear(); + for n in vec.iter() { + queue.insert(*n); + } + + black_box(&1) + }); + }); +} + +fn generate_random_floats() -> Vec { + let seed: [u8; 32] = [73; 32]; + let mut rng: StdRng = SeedableRng::from_seed(seed); + let range = Uniform::new(0.0, 1.0); + let mut random_floats = Vec::with_capacity(100); + + for i in 0..100 { + let random_float = range.sample(&mut rng) as f32; + let n = Neighbor::new(i, random_float); + random_floats.push(n); + } + + random_floats +} + +criterion_group!(benches, benchmark_priority_queue_insert); +criterion_main!(benches); + diff --git a/rust/diskann/src/algorithm/mod.rs b/rust/diskann/src/algorithm/mod.rs new file mode 100644 index 000000000..87e377c8b --- /dev/null +++ b/rust/diskann/src/algorithm/mod.rs @@ -0,0 +1,7 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +pub mod search; + +pub mod prune; diff --git a/rust/diskann/src/algorithm/prune/mod.rs b/rust/diskann/src/algorithm/prune/mod.rs new file mode 100644 index 000000000..4627eeb10 --- /dev/null +++ b/rust/diskann/src/algorithm/prune/mod.rs @@ -0,0 +1,6 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +#[allow(clippy::module_inception)] +pub mod prune; diff --git a/rust/diskann/src/algorithm/prune/prune.rs b/rust/diskann/src/algorithm/prune/prune.rs new file mode 100644 index 000000000..40fec4a5d --- /dev/null +++ b/rust/diskann/src/algorithm/prune/prune.rs @@ -0,0 +1,288 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use hashbrown::HashSet; +use vector::{FullPrecisionDistance, Metric}; + +use crate::common::{ANNError, ANNResult}; +use crate::index::InmemIndex; +use crate::model::graph::AdjacencyList; +use crate::model::neighbor::SortedNeighborVector; +use crate::model::scratch::InMemQueryScratch; +use crate::model::Neighbor; + +impl InmemIndex +where + T: Default + Copy + Sync + Send + Into, + [T; N]: FullPrecisionDistance, +{ + /// A method that occludes a list of neighbors based on some criteria + #[allow(clippy::too_many_arguments)] + fn occlude_list( + &self, + location: u32, + pool: &mut SortedNeighborVector, + alpha: f32, + degree: u32, + max_candidate_size: usize, + result: &mut AdjacencyList, + scratch: &mut InMemQueryScratch, + delete_set_ptr: Option<&HashSet>, + ) -> ANNResult<()> { + if pool.is_empty() { + return Ok(()); + } + + if !result.is_empty() { + return Err(ANNError::log_index_error( + "result is not empty.".to_string(), + )); + } + + // Truncate pool at max_candidate_size and initialize scratch spaces + if pool.len() > max_candidate_size { + pool.truncate(max_candidate_size); + } + + let occlude_factor = &mut scratch.occlude_factor; + + // occlude_list can be called with the same scratch more than once by + // search_for_point_and_add_link through inter_insert. 
+ occlude_factor.clear(); + + // Initialize occlude_factor to pool.len() many 0.0 values for correctness + occlude_factor.resize(pool.len(), 0.0); + + let mut cur_alpha = 1.0; + while cur_alpha <= alpha && result.len() < degree as usize { + for (i, neighbor) in pool.iter().enumerate() { + if result.len() >= degree as usize { + break; + } + if occlude_factor[i] > cur_alpha { + continue; + } + // Set the entry to f32::MAX so that is not considered again + occlude_factor[i] = f32::MAX; + + // Add the entry to the result if its not been deleted, and doesn't + // add a self loop + if delete_set_ptr.map_or(true, |delete_set| !delete_set.contains(&neighbor.id)) + && neighbor.id != location + { + result.push(neighbor.id); + } + + // Update occlude factor for points from i+1 to pool.len() + for (j, neighbor2) in pool.iter().enumerate().skip(i + 1) { + if occlude_factor[j] > alpha { + continue; + } + + // todo - self.filtered_index + let djk = self.get_distance(neighbor2.id, neighbor.id)?; + match self.configuration.dist_metric { + Metric::L2 | Metric::Cosine => { + occlude_factor[j] = if djk == 0.0 { + f32::MAX + } else { + occlude_factor[j].max(neighbor2.distance / djk) + }; + } + } + } + } + + cur_alpha *= 1.2; + } + + Ok(()) + } + + /// Prunes the neighbors of a given data point based on some criteria and returns a list of pruned ids. + /// + /// # Arguments + /// + /// * `location` - The id of the data point whose neighbors are to be pruned. + /// * `pool` - A vector of neighbors to be pruned, sorted by distance to the query point. + /// * `pruned_list` - A vector to store the ids of the pruned neighbors. + /// * `scratch` - A mutable reference to a scratch space for in-memory queries. + /// + /// # Panics + /// + /// Panics if `pruned_list` contains more than `range` elements after pruning. 
+ pub fn prune_neighbors( + &self, + location: u32, + pool: &mut Vec, + pruned_list: &mut AdjacencyList, + scratch: &mut InMemQueryScratch, + ) -> ANNResult<()> { + self.robust_prune( + location, + pool, + self.configuration.index_write_parameter.max_degree, + self.configuration.index_write_parameter.max_occlusion_size, + self.configuration.index_write_parameter.alpha, + pruned_list, + scratch, + ) + } + + /// Prunes the neighbors of a given data point based on some criteria and returns a list of pruned ids. + /// + /// # Arguments + /// + /// * `location` - The id of the data point whose neighbors are to be pruned. + /// * `pool` - A vector of neighbors to be pruned, sorted by distance to the query point. + /// * `range` - The maximum number of neighbors to keep after pruning. + /// * `max_candidate_size` - The maximum number of candidates to consider for pruning. + /// * `alpha` - A parameter that controls the occlusion pruning strategy. + /// * `pruned_list` - A vector to store the ids of the pruned neighbors. + /// * `scratch` - A mutable reference to a scratch space for in-memory queries. + /// + /// # Error + /// + /// Return error if `pruned_list` contains more than `range` elements after pruning. 
+ #[allow(clippy::too_many_arguments)] + fn robust_prune( + &self, + location: u32, + pool: &mut Vec, + range: u32, + max_candidate_size: u32, + alpha: f32, + pruned_list: &mut AdjacencyList, + scratch: &mut InMemQueryScratch, + ) -> ANNResult<()> { + if pool.is_empty() { + // if the pool is empty, behave like a noop + pruned_list.clear(); + return Ok(()); + } + + // If using _pq_build, over-write the PQ distances with actual distances + // todo : pq_dist + + // sort the pool based on distance to query and prune it with occlude_list + let mut pool = SortedNeighborVector::new(pool); + pruned_list.clear(); + + self.occlude_list( + location, + &mut pool, + alpha, + range, + max_candidate_size as usize, + pruned_list, + scratch, + Option::None, + )?; + + if pruned_list.len() > range as usize { + return Err(ANNError::log_index_error(format!( + "pruned_list's len {} is over range {}.", + pruned_list.len(), + range + ))); + } + + if self.configuration.index_write_parameter.saturate_graph && alpha > 1.0f32 { + for neighbor in pool.iter() { + if pruned_list.len() >= (range as usize) { + break; + } + if !pruned_list.contains(&neighbor.id) && neighbor.id != location { + pruned_list.push(neighbor.id); + } + } + } + + Ok(()) + } + + /// A method that inserts a point n into the graph of its neighbors and their neighbors, + /// pruning the graph if necessary to keep it within the specified range + /// * `n` - The index of the new point + /// * `pruned_list` is a vector of the neighbors of n that have been pruned by a previous step + /// * `range` is the target number of neighbors for each point + /// * `scratch` is a mutable reference to a scratch space that can be reused for intermediate computations + pub fn inter_insert( + &self, + n: u32, + pruned_list: &Vec, + range: u32, + scratch: &mut InMemQueryScratch, + ) -> ANNResult<()> { + // Borrow the pruned_list as a source pool of neighbors + let src_pool = pruned_list; + + if src_pool.is_empty() { + return 
Err(ANNError::log_index_error("src_pool is empty.".to_string())); + } + + for &vertex_id in src_pool { + // vertex is the index of a neighbor of n + // Assert that vertex is within the valid range of points + if (vertex_id as usize) + >= self.configuration.max_points + self.configuration.num_frozen_pts + { + return Err(ANNError::log_index_error(format!( + "vertex_id {} is out of valid range of points {}", + vertex_id, + self.configuration.max_points + self.configuration.num_frozen_pts, + ))); + } + + let neighbors = self.add_to_neighbors(vertex_id, n, range)?; + + if let Some(copy_of_neighbors) = neighbors { + // Pruning is needed, create a dummy set and a dummy vector to store the unique neighbors of vertex_id + let mut dummy_pool = self.get_unique_neighbors(©_of_neighbors, vertex_id)?; + + // Create a new vector to store the pruned neighbors of vertex_id + let mut new_out_neighbors = + AdjacencyList::for_range(self.configuration.write_range()); + // Prune the neighbors of vertex_id using a helper method + self.prune_neighbors(vertex_id, &mut dummy_pool, &mut new_out_neighbors, scratch)?; + + self.set_neighbors(vertex_id, new_out_neighbors)?; + } + } + + Ok(()) + } + + /// Adds a node to the list of neighbors for the given node. + /// + /// # Arguments + /// + /// * `vertex_id` - The ID of the node to add the neighbor to. + /// * `node_id` - The ID of the node to add. + /// * `range` - The range of the graph. + /// + /// # Return + /// + /// Returns `None` if the node is already in the list of neighbors, or a `Vec` containing the updated list of neighbors if the list of neighbors is full. 
+ fn add_to_neighbors( + &self, + vertex_id: u32, + node_id: u32, + range: u32, + ) -> ANNResult>> { + // vertex contains a vector of the neighbors of vertex_id + let mut vertex_guard = self.final_graph.write_vertex_and_neighbors(vertex_id)?; + + Ok(vertex_guard.add_to_neighbors(node_id, range)) + } + + fn set_neighbors(&self, vertex_id: u32, new_out_neighbors: AdjacencyList) -> ANNResult<()> { + // vertex contains a vector of the neighbors of vertex_id + let mut vertex_guard = self.final_graph.write_vertex_and_neighbors(vertex_id)?; + + vertex_guard.set_neighbors(new_out_neighbors); + Ok(()) + } +} + diff --git a/rust/diskann/src/algorithm/search/mod.rs b/rust/diskann/src/algorithm/search/mod.rs new file mode 100644 index 000000000..9f007ab69 --- /dev/null +++ b/rust/diskann/src/algorithm/search/mod.rs @@ -0,0 +1,7 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#[allow(clippy::module_inception)] +pub mod search; + diff --git a/rust/diskann/src/algorithm/search/search.rs b/rust/diskann/src/algorithm/search/search.rs new file mode 100644 index 000000000..ab6d01696 --- /dev/null +++ b/rust/diskann/src/algorithm/search/search.rs @@ -0,0 +1,359 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! 
Search algorithm for index construction and query + +use crate::common::{ANNError, ANNResult}; +use crate::index::InmemIndex; +use crate::model::{scratch::InMemQueryScratch, Neighbor, Vertex}; +use hashbrown::hash_set::Entry::*; +use vector::FullPrecisionDistance; + +impl InmemIndex +where + T: Default + Copy + Sync + Send + Into, + [T; N]: FullPrecisionDistance, +{ + /// Search for query using given L value, for benchmarking purposes + /// # Arguments + /// * `query` - query vertex + /// * `scratch` - in-memory query scratch + /// * `search_list_size` - search list size to use for the benchmark + pub fn search_with_l_override( + &self, + query: &Vertex, + scratch: &mut InMemQueryScratch, + search_list_size: usize, + ) -> ANNResult { + let init_ids = self.get_init_ids()?; + self.init_graph_for_point(query, init_ids, scratch)?; + // Scratch is created using largest L val from search_memory_index, so we artifically make it smaller here + // This allows us to use the same scratch for all L values without having to rebuild the query scratch + scratch.best_candidates.set_capacity(search_list_size); + let (_, cmp) = self.greedy_search(query, scratch)?; + + Ok(cmp) + } + + /// search for point + /// # Arguments + /// * `query` - query vertex + /// * `scratch` - in-memory query scratch + /// TODO: use_filter, filteredLindex + pub fn search_for_point( + &self, + query: &Vertex, + scratch: &mut InMemQueryScratch, + ) -> ANNResult> { + let init_ids = self.get_init_ids()?; + self.init_graph_for_point(query, init_ids, scratch)?; + let (mut visited_nodes, _) = self.greedy_search(query, scratch)?; + + visited_nodes.retain(|&element| element.id != query.vertex_id()); + Ok(visited_nodes) + } + + /// Returns the locations of start point and frozen points suitable for use with iterate_to_fixed_point. 
+ fn get_init_ids(&self) -> ANNResult> { + let mut init_ids = Vec::with_capacity(1 + self.configuration.num_frozen_pts); + init_ids.push(self.start); + + for frozen in self.configuration.max_points + ..(self.configuration.max_points + self.configuration.num_frozen_pts) + { + let frozen_u32 = frozen.try_into()?; + if frozen_u32 != self.start { + init_ids.push(frozen_u32); + } + } + + Ok(init_ids) + } + + /// Initialize graph for point + /// # Arguments + /// * `query` - query vertex + /// * `init_ids` - initial nodes from which search starts + /// * `scratch` - in-memory query scratch + /// * `search_list_size_override` - override for search list size in index config + fn init_graph_for_point( + &self, + query: &Vertex, + init_ids: Vec, + scratch: &mut InMemQueryScratch, + ) -> ANNResult<()> { + scratch + .best_candidates + .reserve(self.configuration.index_write_parameter.search_list_size as usize); + scratch.query.memcpy(query.vector())?; + + if !scratch.id_scratch.is_empty() { + return Err(ANNError::log_index_error( + "id_scratch is not empty.".to_string(), + )); + } + + let query_vertex = Vertex::::try_from((&scratch.query[..], query.vertex_id())) + .map_err(|err| { + ANNError::log_index_error(format!( + "TryFromSliceError: failed to get Vertex for query, err={}", + err + )) + })?; + + for id in init_ids { + if (id as usize) >= self.configuration.max_points + self.configuration.num_frozen_pts { + return Err(ANNError::log_index_error(format!( + "vertex_id {} is out of valid range of points {}", + id, + self.configuration.max_points + self.configuration.num_frozen_pts + ))); + } + + if let Vacant(entry) = scratch.node_visited_robinset.entry(id) { + entry.insert(); + + let vertex = self.dataset.get_vertex(id)?; + + let distance = vertex.compare(&query_vertex, self.configuration.dist_metric); + let neighbor = Neighbor::new(id, distance); + scratch.best_candidates.insert(neighbor); + } + } + + Ok(()) + } + + /// GreedySearch against query node + /// Returns visited 
nodes + /// # Arguments + /// * `query` - query vertex + /// * `scratch` - in-memory query scratch + /// TODO: use_filter, filter_label, search_invocation + fn greedy_search( + &self, + query: &Vertex, + scratch: &mut InMemQueryScratch, + ) -> ANNResult<(Vec, u32)> { + let mut visited_nodes = + Vec::with_capacity((3 * scratch.candidate_size + scratch.max_degree) as usize); + + // TODO: uncomment hops? + // let mut hops: u32 = 0; + let mut cmps: u32 = 0; + + let query_vertex = Vertex::::try_from((&scratch.query[..], query.vertex_id())) + .map_err(|err| { + ANNError::log_index_error(format!( + "TryFromSliceError: failed to get Vertex for query, err={}", + err + )) + })?; + + while scratch.best_candidates.has_notvisited_node() { + let closest_node = scratch.best_candidates.closest_notvisited(); + + // Add node to visited nodes to create pool for prune later + // TODO: search_invocation and use_filter + visited_nodes.push(closest_node); + + // Find which of the nodes in des have not been visited before + scratch.id_scratch.clear(); + + let max_vertex_id = self.configuration.max_points + self.configuration.num_frozen_pts; + + for id in self + .final_graph + .read_vertex_and_neighbors(closest_node.id)? + .get_neighbors() + { + let current_vertex_id = *id; + debug_assert!( + (current_vertex_id as usize) < max_vertex_id, + "current_vertex_id {} is out of valid range of points {}", + current_vertex_id, + max_vertex_id + ); + if current_vertex_id as usize >= max_vertex_id { + continue; + } + + // quickly de-dup. 
Remember, we are in a read lock + // we want to exit out of it quickly + if scratch.node_visited_robinset.insert(current_vertex_id) { + scratch.id_scratch.push(current_vertex_id); + } + } + + let len = scratch.id_scratch.len(); + for (m, &id) in scratch.id_scratch.iter().enumerate() { + if m + 1 < len { + let next_node = unsafe { *scratch.id_scratch.get_unchecked(m + 1) }; + self.dataset.prefetch_vector(next_node); + } + + let vertex = self.dataset.get_vertex(id)?; + let distance = query_vertex.compare(&vertex, self.configuration.dist_metric); + + // Insert pairs into the pool of candidates + scratch.best_candidates.insert(Neighbor::new(id, distance)); + } + + cmps += len as u32; + } + + Ok((visited_nodes, cmps)) + } +} + +#[cfg(test)] +mod search_test { + use vector::Metric; + + use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder; + use crate::model::graph::AdjacencyList; + use crate::model::IndexConfiguration; + use crate::test_utils::inmem_index_initialization::create_index_with_test_data; + + use super::*; + + #[test] + fn get_init_ids_no_forzen_pts() { + let index_write_parameters = IndexWriteParametersBuilder::new(50, 4) + .with_alpha(1.2) + .build(); + let config = IndexConfiguration::new( + Metric::L2, + 256, + 256, + 256, + false, + 0, + false, + 0, + 1f32, + index_write_parameters, + ); + + let index = InmemIndex::::new(config).unwrap(); + let init_ids = index.get_init_ids().unwrap(); + assert_eq!(init_ids.len(), 1); + assert_eq!(init_ids[0], 256); + } + + #[test] + fn get_init_ids_with_forzen_pts() { + let index_write_parameters = IndexWriteParametersBuilder::new(50, 4) + .with_alpha(1.2) + .build(); + let config = IndexConfiguration::new( + Metric::L2, + 256, + 256, + 256, + false, + 0, + false, + 2, + 1f32, + index_write_parameters, + ); + + let index = InmemIndex::::new(config).unwrap(); + let init_ids = index.get_init_ids().unwrap(); + assert_eq!(init_ids.len(), 2); + assert_eq!(init_ids[0], 256); + 
assert_eq!(init_ids[1], 257); + } + + #[test] + fn search_for_point_initial_call() { + let index = create_index_with_test_data(); + let query = index.dataset.get_vertex(0).unwrap(); + + let mut scratch = InMemQueryScratch::new( + index.configuration.index_write_parameter.search_list_size, + &index.configuration.index_write_parameter, + false, + ) + .unwrap(); + let visited_nodes = index.search_for_point(&query, &mut scratch).unwrap(); + assert_eq!(visited_nodes.len(), 1); + assert_eq!(scratch.best_candidates.size(), 1); + assert_eq!(scratch.best_candidates[0].id, 72); + assert_eq!(scratch.best_candidates[0].distance, 125678.0_f32); + assert!(scratch.best_candidates[0].visited); + } + + fn set_neighbors(index: &InmemIndex, vertex_id: u32, neighbors: Vec) { + index + .final_graph + .write_vertex_and_neighbors(vertex_id) + .unwrap() + .set_neighbors(AdjacencyList::from(neighbors)); + } + #[test] + fn search_for_point_works_with_edges() { + let index = create_index_with_test_data(); + let query = index.dataset.get_vertex(14).unwrap(); + + set_neighbors(&index, 0, vec![12, 72, 5, 9]); + set_neighbors(&index, 1, vec![2, 12, 10, 4]); + set_neighbors(&index, 2, vec![1, 72, 9]); + set_neighbors(&index, 3, vec![13, 6, 5, 11]); + set_neighbors(&index, 4, vec![1, 3, 7, 9]); + set_neighbors(&index, 5, vec![3, 0, 8, 11, 13]); + set_neighbors(&index, 6, vec![3, 72, 7, 10, 13]); + set_neighbors(&index, 7, vec![72, 4, 6]); + set_neighbors(&index, 8, vec![72, 5, 9, 12]); + set_neighbors(&index, 9, vec![8, 4, 0, 2]); + set_neighbors(&index, 10, vec![72, 1, 9, 6]); + set_neighbors(&index, 11, vec![3, 0, 5]); + set_neighbors(&index, 12, vec![1, 0, 8, 9]); + set_neighbors(&index, 13, vec![3, 72, 5, 6]); + set_neighbors(&index, 72, vec![7, 2, 10, 8, 13]); + + let mut scratch = InMemQueryScratch::new( + index.configuration.index_write_parameter.search_list_size, + &index.configuration.index_write_parameter, + false, + ) + .unwrap(); + let visited_nodes = index.search_for_point(&query, 
&mut scratch).unwrap(); + assert_eq!(visited_nodes.len(), 15); + assert_eq!(scratch.best_candidates.size(), 15); + assert_eq!(scratch.best_candidates[0].id, 2); + assert_eq!(scratch.best_candidates[0].distance, 120899.0_f32); + assert_eq!(scratch.best_candidates[1].id, 8); + assert_eq!(scratch.best_candidates[1].distance, 145538.0_f32); + assert_eq!(scratch.best_candidates[2].id, 72); + assert_eq!(scratch.best_candidates[2].distance, 146046.0_f32); + assert_eq!(scratch.best_candidates[3].id, 4); + assert_eq!(scratch.best_candidates[3].distance, 148462.0_f32); + assert_eq!(scratch.best_candidates[4].id, 7); + assert_eq!(scratch.best_candidates[4].distance, 148912.0_f32); + assert_eq!(scratch.best_candidates[5].id, 10); + assert_eq!(scratch.best_candidates[5].distance, 154570.0_f32); + assert_eq!(scratch.best_candidates[6].id, 1); + assert_eq!(scratch.best_candidates[6].distance, 159448.0_f32); + assert_eq!(scratch.best_candidates[7].id, 12); + assert_eq!(scratch.best_candidates[7].distance, 170698.0_f32); + assert_eq!(scratch.best_candidates[8].id, 9); + assert_eq!(scratch.best_candidates[8].distance, 177205.0_f32); + assert_eq!(scratch.best_candidates[9].id, 0); + assert_eq!(scratch.best_candidates[9].distance, 259996.0_f32); + assert_eq!(scratch.best_candidates[10].id, 6); + assert_eq!(scratch.best_candidates[10].distance, 371819.0_f32); + assert_eq!(scratch.best_candidates[11].id, 5); + assert_eq!(scratch.best_candidates[11].distance, 385240.0_f32); + assert_eq!(scratch.best_candidates[12].id, 3); + assert_eq!(scratch.best_candidates[12].distance, 413899.0_f32); + assert_eq!(scratch.best_candidates[13].id, 13); + assert_eq!(scratch.best_candidates[13].distance, 416386.0_f32); + assert_eq!(scratch.best_candidates[14].id, 11); + assert_eq!(scratch.best_candidates[14].distance, 449266.0_f32); + } +} diff --git a/rust/diskann/src/common/aligned_allocator.rs b/rust/diskann/src/common/aligned_allocator.rs new file mode 100644 index 000000000..6164a1f40 --- /dev/null 
+++ b/rust/diskann/src/common/aligned_allocator.rs @@ -0,0 +1,281 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Aligned allocator + +use std::alloc::Layout; +use std::ops::{Deref, DerefMut, Range}; +use std::ptr::copy_nonoverlapping; + +use super::{ANNResult, ANNError}; + +#[derive(Debug)] +/// A box that holds a slice but is aligned to the specified layout. +/// +/// This type is useful for working with types that require a certain alignment, +/// such as SIMD vectors or FFI structs. It allocates memory using the global allocator +/// and frees it when dropped. It also implements Deref and DerefMut to allow access +/// to the underlying slice. +pub struct AlignedBoxWithSlice { + /// The layout of the allocated memory. + layout: Layout, + + /// The slice that points to the allocated memory. + val: Box<[T]>, +} + +impl AlignedBoxWithSlice { + /// Creates a new `AlignedBoxWithSlice` with the given capacity and alignment. + /// The allocated memory are set to 0. + /// + /// # Error + /// + /// Return IndexError if the alignment is not a power of two or if the layout is invalid. + /// + /// This function is unsafe because it allocates uninitialized memory and casts it to + /// a slice of `T`. The caller must ensure that the capacity and alignment are valid + /// for the type `T` and that the memory is initialized before accessing the elements + /// of the slice. 
+ pub fn new(capacity: usize, alignment: usize) -> ANNResult { + let allocsize = capacity.checked_mul(std::mem::size_of::()) + .ok_or_else(|| ANNError::log_index_error("capacity overflow".to_string()))?; + let layout = Layout::from_size_align(allocsize, alignment) + .map_err(ANNError::log_mem_alloc_layout_error)?; + + let val = unsafe { + let mem = std::alloc::alloc_zeroed(layout); + let ptr = mem as *mut T; + let slice = std::slice::from_raw_parts_mut(ptr, capacity); + std::boxed::Box::from_raw(slice) + }; + + Ok(Self { layout, val }) + } + + /// Returns a reference to the slice. + pub fn as_slice(&self) -> &[T] { + &self.val + } + + /// Returns a mutable reference to the slice. + pub fn as_mut_slice(&mut self) -> &mut [T] { + &mut self.val + } + + /// Copies data from the source slice to the destination box. + pub fn memcpy(&mut self, src: &[T]) -> ANNResult<()> { + if src.len() > self.val.len() { + return Err(ANNError::log_index_error(format!("source slice is too large (src:{}, dst:{})", src.len(), self.val.len()))); + } + + // Check that they don't overlap + let src_ptr = src.as_ptr(); + let src_end = unsafe { src_ptr.add(src.len()) }; + let dst_ptr = self.val.as_mut_ptr(); + let dst_end = unsafe { dst_ptr.add(self.val.len()) }; + + if src_ptr < dst_end && src_end > dst_ptr { + return Err(ANNError::log_index_error("Source and destination overlap".to_string())); + } + + unsafe { + copy_nonoverlapping(src.as_ptr(), self.val.as_mut_ptr(), src.len()); + } + + Ok(()) + } + + /// Split the range of memory into nonoverlapping mutable slices. + /// The number of returned slices is (range length / slice_len) and each has a length of slice_len. 
+ pub fn split_into_nonoverlapping_mut_slices(&mut self, range: Range, slice_len: usize) -> ANNResult> { + if range.len() % slice_len != 0 || range.end > self.len() { + return Err(ANNError::log_index_error(format!( + "Cannot split range ({:?}) of AlignedBoxWithSlice (len: {}) into nonoverlapping mutable slices with length {}", + range, + self.len(), + slice_len, + ))); + } + + let mut slices = Vec::with_capacity(range.len() / slice_len); + let mut remaining_slice = &mut self.val[range]; + + while remaining_slice.len() >= slice_len { + let (left, right) = remaining_slice.split_at_mut(slice_len); + slices.push(left); + remaining_slice = right; + } + + Ok(slices) + } +} + + +impl Drop for AlignedBoxWithSlice { + /// Frees the memory allocated for the slice using the global allocator. + fn drop(&mut self) { + let val = std::mem::take(&mut self.val); + let mut val2 = std::mem::ManuallyDrop::new(val); + let ptr = val2.as_mut_ptr(); + + unsafe { + // let nonNull = NonNull::new_unchecked(ptr as *mut u8); + std::alloc::dealloc(ptr as *mut u8, self.layout) + } + } +} + +impl Deref for AlignedBoxWithSlice { + type Target = [T]; + + fn deref(&self) -> &Self::Target { + &self.val + } +} + +impl DerefMut for AlignedBoxWithSlice { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.val + } +} + +#[cfg(test)] +mod tests { + use rand::Rng; + + use crate::utils::is_aligned; + + use super::*; + + #[test] + fn create_alignedvec_works_32() { + (0..100).for_each(|_| { + let size = 1_000_000; + println!("Attempting {}", size); + let data = AlignedBoxWithSlice::::new(size, 32).unwrap(); + assert_eq!(data.len(), size, "Capacity should match"); + + let ptr = data.as_ptr() as usize; + assert_eq!(ptr % 32, 0, "Ptr should be aligned to 32"); + + // assert that the slice is initialized. 
+ (0..size).for_each(|i| { + assert_eq!(data[i], f32::default()); + }); + + drop(data); + }); + } + + #[test] + fn create_alignedvec_works_256() { + let mut rng = rand::thread_rng(); + + (0..100).for_each(|_| { + let n = rng.gen::(); + let size = usize::from(n) + 1; + println!("Attempting {}", size); + let data = AlignedBoxWithSlice::::new(size, 256).unwrap(); + assert_eq!(data.len(), size, "Capacity should match"); + + let ptr = data.as_ptr() as usize; + assert_eq!(ptr % 256, 0, "Ptr should be aligned to 32"); + + // assert that the slice is initialized. + (0..size).for_each(|i| { + assert_eq!(data[i], u8::default()); + }); + + drop(data); + }); + } + + #[test] + fn as_slice_test() { + let size = 1_000_000; + let data = AlignedBoxWithSlice::::new(size, 32).unwrap(); + // assert that the slice is initialized. + (0..size).for_each(|i| { + assert_eq!(data[i], f32::default()); + }); + + let slice = data.as_slice(); + (0..size).for_each(|i| { + assert_eq!(slice[i], f32::default()); + }); + } + + #[test] + fn as_mut_slice_test() { + let size = 1_000_000; + let mut data = AlignedBoxWithSlice::::new(size, 32).unwrap(); + let mut_slice = data.as_mut_slice(); + (0..size).for_each(|i| { + assert_eq!(mut_slice[i], f32::default()); + }); + } + + #[test] + fn memcpy_test() { + let size = 1_000_000; + let mut data = AlignedBoxWithSlice::::new(size, 32).unwrap(); + let mut destination = AlignedBoxWithSlice::::new(size-2, 32).unwrap(); + let mut_destination = destination.as_mut_slice(); + data.memcpy(mut_destination).unwrap(); + (0..size-2).for_each(|i| { + assert_eq!(data[i], mut_destination[i]); + }); + } + + #[test] + #[should_panic(expected = "source slice is too large (src:1000000, dst:999998)")] + fn memcpy_panic_test() { + let size = 1_000_000; + let mut data = AlignedBoxWithSlice::::new(size-2, 32).unwrap(); + let mut destination = AlignedBoxWithSlice::::new(size, 32).unwrap(); + let mut_destination = destination.as_mut_slice(); + data.memcpy(mut_destination).unwrap(); + } 
+ + #[test] + fn is_aligned_test() { + assert!(is_aligned(256,256)); + assert!(!is_aligned(255,256)); + } + + #[test] + fn split_into_nonoverlapping_mut_slices_test() { + let size = 10; + let slice_len = 2; + let mut data = AlignedBoxWithSlice::::new(size, 32).unwrap(); + let slices = data.split_into_nonoverlapping_mut_slices(2..8, slice_len).unwrap(); + assert_eq!(slices.len(), 3); + for (i, slice) in slices.into_iter().enumerate() { + assert_eq!(slice.len(), slice_len); + slice[0] = i as f32 + 1.0; + slice[1] = i as f32 + 1.0; + } + let expected_arr = [0.0f32, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 0.0, 0.0]; + assert_eq!(data.as_ref(), &expected_arr); + } + + #[test] + fn split_into_nonoverlapping_mut_slices_error_when_indivisible() { + let size = 10; + let slice_len = 2; + let range = 2..7; + let mut data = AlignedBoxWithSlice::::new(size, 32).unwrap(); + let result = data.split_into_nonoverlapping_mut_slices(range.clone(), slice_len); + let expected_err_str = format!( + "IndexError: Cannot split range ({:?}) of AlignedBoxWithSlice (len: {}) into nonoverlapping mutable slices with length {}", + range, + size, + slice_len, + ); + assert!(result.is_err_and(|e| e.to_string() == expected_err_str)); + } +} + diff --git a/rust/diskann/src/common/ann_result.rs b/rust/diskann/src/common/ann_result.rs new file mode 100644 index 000000000..69fcf03f6 --- /dev/null +++ b/rust/diskann/src/common/ann_result.rs @@ -0,0 +1,179 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use std::alloc::LayoutError; +use std::array::TryFromSliceError; +use std::io; +use std::num::TryFromIntError; + +use logger::error_logger::log_error; +use logger::log_error::LogError; + +/// Result +pub type ANNResult = Result; + +/// DiskANN Error +/// ANNError is `Send` (i.e., safe to send across threads) +#[derive(thiserror::Error, Debug)] +pub enum ANNError { + /// Index construction and search error + #[error("IndexError: {err}")] + IndexError { err: String }, + + /// Index configuration error + #[error("IndexConfigError: {parameter} is invalid, err={err}")] + IndexConfigError { parameter: String, err: String }, + + /// Integer conversion error + #[error("TryFromIntError: {err}")] + TryFromIntError { + #[from] + err: TryFromIntError, + }, + + /// IO error + #[error("IOError: {err}")] + IOError { + #[from] + err: io::Error, + }, + + /// Layout error in memory allocation + #[error("MemoryAllocLayoutError: {err}")] + MemoryAllocLayoutError { + #[from] + err: LayoutError, + }, + + /// PoisonError which can be returned whenever a lock is acquired + /// Both Mutexes and RwLocks are poisoned whenever a thread fails while the lock is held + #[error("LockPoisonError: {err}")] + LockPoisonError { err: String }, + + /// DiskIOAlignmentError which can be returned when calling windows API CreateFileA for the disk index file fails. 
+ #[error("DiskIOAlignmentError: {err}")] + DiskIOAlignmentError { err: String }, + + /// Logging error + #[error("LogError: {err}")] + LogError { + #[from] + err: LogError, + }, + + // PQ construction error + // Error happened when we construct PQ pivot or PQ compressed table + #[error("PQError: {err}")] + PQError { err: String }, + + /// Array conversion error + #[error("Error try creating array from slice: {err}")] + TryFromSliceError { + #[from] + err: TryFromSliceError, + }, +} + +impl ANNError { + /// Create, log and return IndexError + #[inline] + pub fn log_index_error(err: String) -> Self { + let ann_err = ANNError::IndexError { err }; + match log_error(ann_err.to_string()) { + Ok(()) => ann_err, + Err(log_err) => ANNError::LogError { err: log_err }, + } + } + + /// Create, log and return IndexConfigError + #[inline] + pub fn log_index_config_error(parameter: String, err: String) -> Self { + let ann_err = ANNError::IndexConfigError { parameter, err }; + match log_error(ann_err.to_string()) { + Ok(()) => ann_err, + Err(log_err) => ANNError::LogError { err: log_err }, + } + } + + /// Create, log and return TryFromIntError + #[inline] + pub fn log_try_from_int_error(err: TryFromIntError) -> Self { + let ann_err = ANNError::TryFromIntError { err }; + match log_error(ann_err.to_string()) { + Ok(()) => ann_err, + Err(log_err) => ANNError::LogError { err: log_err }, + } + } + + /// Create, log and return IOError + #[inline] + pub fn log_io_error(err: io::Error) -> Self { + let ann_err = ANNError::IOError { err }; + match log_error(ann_err.to_string()) { + Ok(()) => ann_err, + Err(log_err) => ANNError::LogError { err: log_err }, + } + } + + /// Create, log and return DiskIOAlignmentError + /// #[inline] + pub fn log_disk_io_request_alignment_error(err: String) -> Self { + let ann_err: ANNError = ANNError::DiskIOAlignmentError { err }; + match log_error(ann_err.to_string()) { + Ok(()) => ann_err, + Err(log_err) => ANNError::LogError { err: log_err }, + } + } + + 
/// Create, log and return IOError + #[inline] + pub fn log_mem_alloc_layout_error(err: LayoutError) -> Self { + let ann_err = ANNError::MemoryAllocLayoutError { err }; + match log_error(ann_err.to_string()) { + Ok(()) => ann_err, + Err(log_err) => ANNError::LogError { err: log_err }, + } + } + + /// Create, log and return LockPoisonError + #[inline] + pub fn log_lock_poison_error(err: String) -> Self { + let ann_err = ANNError::LockPoisonError { err }; + match log_error(ann_err.to_string()) { + Ok(()) => ann_err, + Err(log_err) => ANNError::LogError { err: log_err }, + } + } + + /// Create, log and return PQError + #[inline] + pub fn log_pq_error(err: String) -> Self { + let ann_err = ANNError::PQError { err }; + match log_error(ann_err.to_string()) { + Ok(()) => ann_err, + Err(log_err) => ANNError::LogError { err: log_err }, + } + } + + /// Create, log and return TryFromSliceError + #[inline] + pub fn log_try_from_slice_error(err: TryFromSliceError) -> Self { + let ann_err = ANNError::TryFromSliceError { err }; + match log_error(ann_err.to_string()) { + Ok(()) => ann_err, + Err(log_err) => ANNError::LogError { err: log_err }, + } + } +} + +#[cfg(test)] +mod ann_result_test { + use super::*; + + #[test] + fn ann_err_is_send() { + fn assert_send() {} + assert_send::(); + } +} diff --git a/rust/diskann/src/common/mod.rs b/rust/diskann/src/common/mod.rs new file mode 100644 index 000000000..d9da72bbc --- /dev/null +++ b/rust/diskann/src/common/mod.rs @@ -0,0 +1,9 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +mod aligned_allocator; +pub use aligned_allocator::AlignedBoxWithSlice; + +mod ann_result; +pub use ann_result::*; diff --git a/rust/diskann/src/index/disk_index/ann_disk_index.rs b/rust/diskann/src/index/disk_index/ann_disk_index.rs new file mode 100644 index 000000000..a6e053e17 --- /dev/null +++ b/rust/diskann/src/index/disk_index/ann_disk_index.rs @@ -0,0 +1,54 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_docs)] + +//! ANN disk index abstraction + +use vector::FullPrecisionDistance; + +use crate::model::{IndexConfiguration, DiskIndexBuildParameters}; +use crate::storage::DiskIndexStorage; +use crate::model::vertex::{DIM_128, DIM_256, DIM_104}; + +use crate::common::{ANNResult, ANNError}; + +use super::DiskIndex; + +/// ANN disk index abstraction for custom +pub trait ANNDiskIndex : Sync + Send +where T : Default + Copy + Sync + Send + Into + { + /// Build index + fn build(&mut self, codebook_prefix: &str) -> ANNResult<()>; +} + +/// Create Index based on configuration +pub fn create_disk_index<'a, T>( + disk_build_param: Option, + config: IndexConfiguration, + storage: DiskIndexStorage, +) -> ANNResult + 'a>> +where + T: Default + Copy + Sync + Send + Into + 'a, + [T; DIM_104]: FullPrecisionDistance, + [T; DIM_128]: FullPrecisionDistance, + [T; DIM_256]: FullPrecisionDistance, +{ + match config.aligned_dim { + DIM_104 => { + let index = Box::new(DiskIndex::::new(disk_build_param, config, storage)); + Ok(index as Box>) + }, + DIM_128 => { + let index = Box::new(DiskIndex::::new(disk_build_param, config, storage)); + Ok(index as Box>) + }, + DIM_256 => { + let index = Box::new(DiskIndex::::new(disk_build_param, config, storage)); + Ok(index as Box>) + }, + _ => Err(ANNError::log_index_error(format!("Invalid dimension: {}", config.aligned_dim))), + } +} diff --git a/rust/diskann/src/index/disk_index/disk_index.rs b/rust/diskann/src/index/disk_index/disk_index.rs new file mode 
100644 index 000000000..16f0d5969 --- /dev/null +++ b/rust/diskann/src/index/disk_index/disk_index.rs @@ -0,0 +1,161 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::mem; + +use logger::logger::indexlog::DiskIndexConstructionCheckpoint; +use vector::FullPrecisionDistance; + +use crate::common::{ANNResult, ANNError}; +use crate::index::{InmemIndex, ANNInmemIndex}; +use crate::instrumentation::DiskIndexBuildLogger; +use crate::model::configuration::DiskIndexBuildParameters; +use crate::model::{IndexConfiguration, MAX_PQ_TRAINING_SET_SIZE, MAX_PQ_CHUNKS, generate_quantized_data, GRAPH_SLACK_FACTOR}; +use crate::storage::DiskIndexStorage; +use crate::utils::set_rayon_num_threads; + +use super::ann_disk_index::ANNDiskIndex; + +pub const OVERHEAD_FACTOR: f64 = 1.1f64; + +pub const MAX_SAMPLE_POINTS_FOR_WARMUP: usize = 100_000; + +pub struct DiskIndex +where + [T; N]: FullPrecisionDistance, +{ + /// Parameters for index construction + /// None for query path + disk_build_param: Option, + + configuration: IndexConfiguration, + + pub storage: DiskIndexStorage, +} + +impl DiskIndex +where + T: Default + Copy + Sync + Send + Into, + [T; N]: FullPrecisionDistance, +{ + pub fn new( + disk_build_param: Option, + configuration: IndexConfiguration, + storage: DiskIndexStorage, + ) -> Self { + Self { + disk_build_param, + configuration, + storage, + } + } + + pub fn disk_build_param(&self) -> &Option { + &self.disk_build_param + } + + pub fn index_configuration(&self) -> &IndexConfiguration { + &self.configuration + } + + fn build_inmem_index(&self, num_points: usize, data_path: &str, inmem_index_path: &str) -> ANNResult<()> { + let estimated_index_ram = self.estimate_ram_usage(num_points); + if estimated_index_ram >= self.fetch_disk_build_param()?.index_build_ram_limit() * 1024_f64 * 1024_f64 * 1024_f64 { + return Err(ANNError::log_index_error(format!( + "Insufficient memory budget for index build, 
index_build_ram_limit={}GB estimated_index_ram={}GB", + self.fetch_disk_build_param()?.index_build_ram_limit(), + estimated_index_ram / (1024_f64 * 1024_f64 * 1024_f64), + ))); + } + + let mut index = InmemIndex::::new(self.configuration.clone())?; + index.build(data_path, num_points)?; + index.save(inmem_index_path)?; + + Ok(()) + } + + #[inline] + fn estimate_ram_usage(&self, size: usize) -> f64 { + let degree = self.configuration.index_write_parameter.max_degree as usize; + let datasize = mem::size_of::(); + + let dataset_size = (size * N * datasize) as f64; + let graph_size = (size * degree * mem::size_of::()) as f64 * GRAPH_SLACK_FACTOR; + + OVERHEAD_FACTOR * (dataset_size + graph_size) + } + + #[inline] + fn fetch_disk_build_param(&self) -> ANNResult<&DiskIndexBuildParameters> { + self.disk_build_param + .as_ref() + .ok_or_else(|| ANNError::log_index_config_error( + "disk_build_param".to_string(), + "disk_build_param is None".to_string())) + } +} + +impl ANNDiskIndex for DiskIndex +where + T: Default + Copy + Sync + Send + Into, + [T; N]: FullPrecisionDistance, +{ + fn build(&mut self, codebook_prefix: &str) -> ANNResult<()> { + if self.configuration.index_write_parameter.num_threads > 0 { + set_rayon_num_threads(self.configuration.index_write_parameter.num_threads); + } + + println!("Starting index build: R={} L={} Query RAM budget={} Indexing RAM budget={} T={}", + self.configuration.index_write_parameter.max_degree, + self.configuration.index_write_parameter.search_list_size, + self.fetch_disk_build_param()?.search_ram_limit(), + self.fetch_disk_build_param()?.index_build_ram_limit(), + self.configuration.index_write_parameter.num_threads + ); + + let mut logger = DiskIndexBuildLogger::new(DiskIndexConstructionCheckpoint::PqConstruction); + + // PQ memory consumption = PQ pivots + PQ compressed table + // PQ pivots: dim * num_centroids * sizeof::() + // PQ compressed table: num_pts * num_pq_chunks * (dim / num_pq_chunks) * sizeof::() + // * Because 
num_centroids is 256, centroid id can be represented by u8 + let num_points = self.configuration.max_points; + let dim = self.configuration.dim; + let p_val = MAX_PQ_TRAINING_SET_SIZE / (num_points as f64); + let mut num_pq_chunks = ((self.fetch_disk_build_param()?.search_ram_limit() / (num_points as f64)).floor()) as usize; + num_pq_chunks = if num_pq_chunks == 0 { 1 } else { num_pq_chunks }; + num_pq_chunks = if num_pq_chunks > dim { dim } else { num_pq_chunks }; + num_pq_chunks = if num_pq_chunks > MAX_PQ_CHUNKS { MAX_PQ_CHUNKS } else { num_pq_chunks }; + + println!("Compressing {}-dimensional data into {} bytes per vector.", dim, num_pq_chunks); + + // TODO: Decouple PQ from file access + generate_quantized_data::( + p_val, + num_pq_chunks, + codebook_prefix, + self.storage.get_pq_storage(), + )?; + logger.log_checkpoint(DiskIndexConstructionCheckpoint::InmemIndexBuild)?; + + // TODO: Decouple index from file access + let inmem_index_path = self.storage.index_path_prefix().clone() + "_mem.index"; + self.build_inmem_index(num_points, self.storage.dataset_file(), inmem_index_path.as_str())?; + logger.log_checkpoint(DiskIndexConstructionCheckpoint::DiskLayout)?; + + self.storage.create_disk_layout()?; + logger.log_checkpoint(DiskIndexConstructionCheckpoint::None)?; + + let ten_percent_points = ((num_points as f64) * 0.1_f64).ceil(); + let num_sample_points = if ten_percent_points > (MAX_SAMPLE_POINTS_FOR_WARMUP as f64) { MAX_SAMPLE_POINTS_FOR_WARMUP as f64 } else { ten_percent_points }; + let sample_sampling_rate = num_sample_points / (num_points as f64); + self.storage.gen_query_warmup_data(sample_sampling_rate)?; + + self.storage.index_build_cleanup()?; + + Ok(()) + } +} + diff --git a/rust/diskann/src/index/disk_index/mod.rs b/rust/diskann/src/index/disk_index/mod.rs new file mode 100644 index 000000000..4f07bd78d --- /dev/null +++ b/rust/diskann/src/index/disk_index/mod.rs @@ -0,0 +1,9 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ * Licensed under the MIT license. + */ +#[allow(clippy::module_inception)] +mod disk_index; +pub use disk_index::DiskIndex; + +pub mod ann_disk_index; diff --git a/rust/diskann/src/index/inmem_index/ann_inmem_index.rs b/rust/diskann/src/index/inmem_index/ann_inmem_index.rs new file mode 100644 index 000000000..dc8dfc876 --- /dev/null +++ b/rust/diskann/src/index/inmem_index/ann_inmem_index.rs @@ -0,0 +1,97 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_docs)] + +//! ANN in-memory index abstraction + +use vector::FullPrecisionDistance; + +use crate::model::{vertex::{DIM_128, DIM_256, DIM_104}, IndexConfiguration}; +use crate::common::{ANNResult, ANNError}; + +use super::InmemIndex; + +/// ANN inmem-index abstraction for custom +pub trait ANNInmemIndex : Sync + Send +where T : Default + Copy + Sync + Send + Into + { + /// Build index + fn build(&mut self, filename: &str, num_points_to_load: usize) -> ANNResult<()>; + + /// Save index + fn save(&mut self, filename: &str) -> ANNResult<()>; + + /// Load index + fn load(&mut self, filename: &str, expected_num_points: usize) -> ANNResult<()>; + + /// insert index + fn insert(&mut self, filename: &str, num_points_to_insert: usize) -> ANNResult<()>; + + /// Search the index for K nearest neighbors of query using given L value, for benchmarking purposes + fn search(&self, query : &[T], k_value : usize, l_value : u32, indices : &mut[u32]) -> ANNResult; + + /// Soft deletes the nodes with the ids in the given array. 
+ fn soft_delete(&mut self, vertex_ids_to_delete: Vec, num_points_to_delete: usize) -> ANNResult<()>; +} + +/// Create Index based on configuration +pub fn create_inmem_index<'a, T>(config: IndexConfiguration) -> ANNResult + 'a>> +where + T: Default + Copy + Sync + Send + Into + 'a, + [T; DIM_104]: FullPrecisionDistance, + [T; DIM_128]: FullPrecisionDistance, + [T; DIM_256]: FullPrecisionDistance, +{ + match config.aligned_dim { + DIM_104 => { + let index = Box::new(InmemIndex::::new(config)?); + Ok(index as Box>) + }, + DIM_128 => { + let index = Box::new(InmemIndex::::new(config)?); + Ok(index as Box>) + }, + DIM_256 => { + let index = Box::new(InmemIndex::::new(config)?); + Ok(index as Box>) + }, + _ => Err(ANNError::log_index_error(format!("Invalid dimension: {}", config.aligned_dim))), + } +} + +#[cfg(test)] +mod dataset_test { + use vector::Metric; + + use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder; + + use super::*; + + #[test] + #[should_panic(expected = "ERROR: Data file fake_file does not exist.")] + fn create_index_test() { + let index_write_parameters = IndexWriteParametersBuilder::new(50, 4) + .with_alpha(1.2) + .with_saturate_graph(false) + .with_num_threads(1) + .build(); + + let config = IndexConfiguration::new( + Metric::L2, + 128, + 256, + 1_000_000, + false, + 0, + false, + 0, + 1f32, + index_write_parameters, + ); + let mut index = create_inmem_index::(config).unwrap(); + index.build("fake_file", 100).unwrap(); + } +} + diff --git a/rust/diskann/src/index/inmem_index/inmem_index.rs b/rust/diskann/src/index/inmem_index/inmem_index.rs new file mode 100644 index 000000000..871d21092 --- /dev/null +++ b/rust/diskann/src/index/inmem_index/inmem_index.rs @@ -0,0 +1,1033 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use std::cmp; +use std::sync::RwLock; +use std::time::Duration; + +use hashbrown::hash_set::Entry::*; +use hashbrown::HashSet; +use vector::FullPrecisionDistance; + +use crate::common::{ANNError, ANNResult}; +use crate::index::ANNInmemIndex; +use crate::instrumentation::IndexLogger; +use crate::model::graph::AdjacencyList; +use crate::model::{ + ArcConcurrentBoxedQueue, InMemQueryScratch, InMemoryGraph, IndexConfiguration, InmemDataset, + Neighbor, ScratchStoreManager, Vertex, +}; + +use crate::utils::file_util::{file_exists, load_metadata_from_file}; +use crate::utils::rayon_util::execute_with_rayon; +use crate::utils::{set_rayon_num_threads, Timer}; + +/// In-memory Index +pub struct InmemIndex +where + [T; N]: FullPrecisionDistance, +{ + /// Dataset + pub dataset: InmemDataset, + + /// Graph + pub final_graph: InMemoryGraph, + + /// Index configuration + pub configuration: IndexConfiguration, + + /// Start point of the search. When _num_frozen_pts is greater than zero, + /// this is the location of the first frozen point. Otherwise, this is a + /// location of one of the points in index. + pub start: u32, + + /// Max observed out degree + pub max_observed_degree: u32, + + /// Number of active points i.e. existing in the graph + pub num_active_pts: usize, + + /// query scratch queue. + query_scratch_queue: ArcConcurrentBoxedQueue>, + + pub delete_set: RwLock>, +} + +impl InmemIndex +where + T: Default + Copy + Sync + Send + Into, + [T; N]: FullPrecisionDistance, +{ + /// Create Index obj based on configuration + pub fn new(mut config: IndexConfiguration) -> ANNResult { + // Sanity check. While logically it is correct, max_points = 0 causes + // downstream problems. 
+ if config.max_points == 0 { + config.max_points = 1; + } + + let total_internal_points = config.max_points + config.num_frozen_pts; + + if config.use_pq_dist { + // TODO: pq + todo!("PQ is not supported now"); + } + + let start = config.max_points.try_into()?; + + let query_scratch_queue = ArcConcurrentBoxedQueue::>::new(); + let delete_set = RwLock::new(HashSet::::new()); + + Ok(Self { + dataset: InmemDataset::::new(total_internal_points, config.growth_potential)?, + final_graph: InMemoryGraph::new( + total_internal_points, + config.index_write_parameter.max_degree, + ), + configuration: config, + start, + max_observed_degree: 0, + num_active_pts: 0, + query_scratch_queue, + delete_set, + }) + } + + /// Get distance between two vertices. + pub fn get_distance(&self, id1: u32, id2: u32) -> ANNResult { + self.dataset + .get_distance(id1, id2, self.configuration.dist_metric) + } + + fn build_with_data_populated(&mut self) -> ANNResult<()> { + println!( + "Starting index build with {} points...", + self.num_active_pts + ); + + if self.num_active_pts < 1 { + return Err(ANNError::log_index_error( + "Error: Trying to build an index with 0 points.".to_string(), + )); + } + + if self.query_scratch_queue.size()? == 0 { + self.initialize_query_scratch( + 5 + self.configuration.index_write_parameter.num_threads, + self.configuration.index_write_parameter.search_list_size, + )?; + } + + // TODO: generate_frozen_point() + + self.link()?; + + self.print_stats()?; + + Ok(()) + } + + fn link(&mut self) -> ANNResult<()> { + // visit_order is a vector that is initialized to the entire graph + let mut visit_order = + Vec::with_capacity(self.num_active_pts + self.configuration.num_frozen_pts); + for i in 0..self.num_active_pts { + visit_order.push(i as u32); + } + + // If there are any frozen points, add them all. 
+ for frozen in self.configuration.max_points + ..(self.configuration.max_points + self.configuration.num_frozen_pts) + { + visit_order.push(frozen as u32); + } + + // if there are frozen points, the first such one is set to be the _start + if self.configuration.num_frozen_pts > 0 { + self.start = self.configuration.max_points as u32; + } else { + self.start = self.dataset.calculate_medoid_point_id()?; + } + + let timer = Timer::new(); + + let range = visit_order.len(); + let logger = IndexLogger::new(range); + + execute_with_rayon( + 0..range, + self.configuration.index_write_parameter.num_threads, + |idx| { + self.insert_vertex_id(visit_order[idx])?; + logger.vertex_processed()?; + + Ok(()) + }, + )?; + + self.cleanup_graph(&visit_order)?; + + if self.num_active_pts > 0 { + println!("{}", timer.elapsed_seconds_for_step("Link time: ")); + } + + Ok(()) + } + + fn insert_vertex_id(&self, vertex_id: u32) -> ANNResult<()> { + let mut scratch_manager = + ScratchStoreManager::new(self.query_scratch_queue.clone(), Duration::from_millis(10))?; + let scratch = scratch_manager.scratch_space().ok_or_else(|| { + ANNError::log_index_error( + "ScratchStoreManager doesn't have InMemQueryScratch instance available".to_string(), + ) + })?; + + let new_neighbors = self.search_for_point_and_prune(scratch, vertex_id)?; + self.update_vertex_with_neighbors(vertex_id, new_neighbors)?; + self.update_neighbors_of_vertex(vertex_id, scratch)?; + + Ok(()) + } + + fn update_neighbors_of_vertex( + &self, + vertex_id: u32, + scratch: &mut InMemQueryScratch, + ) -> Result<(), ANNError> { + let vertex = self.final_graph.read_vertex_and_neighbors(vertex_id)?; + assert!(vertex.size() <= self.configuration.index_write_parameter.max_degree as usize); + self.inter_insert( + vertex_id, + vertex.get_neighbors(), + self.configuration.index_write_parameter.max_degree, + scratch, + )?; + Ok(()) + } + + fn update_vertex_with_neighbors( + &self, + vertex_id: u32, + new_neighbors: AdjacencyList, + ) -> 
Result<(), ANNError> { + let vertex = &mut self.final_graph.write_vertex_and_neighbors(vertex_id)?; + vertex.set_neighbors(new_neighbors); + assert!(vertex.size() <= self.configuration.index_write_parameter.max_degree as usize); + Ok(()) + } + + fn search_for_point_and_prune( + &self, + scratch: &mut InMemQueryScratch, + vertex_id: u32, + ) -> ANNResult { + let mut pruned_list = + AdjacencyList::for_range(self.configuration.index_write_parameter.max_degree as usize); + let vertex = self.dataset.get_vertex(vertex_id)?; + let mut visited_nodes = self.search_for_point(&vertex, scratch)?; + + self.prune_neighbors(vertex_id, &mut visited_nodes, &mut pruned_list, scratch)?; + + if pruned_list.is_empty() { + return Err(ANNError::log_index_error( + "pruned_list is empty.".to_string(), + )); + } + + if self.final_graph.size() + != self.configuration.max_points + self.configuration.num_frozen_pts + { + return Err(ANNError::log_index_error(format!( + "final_graph has {} vertices instead of {}", + self.final_graph.size(), + self.configuration.max_points + self.configuration.num_frozen_pts, + ))); + } + + Ok(pruned_list) + } + + fn search( + &self, + query: &Vertex, + k_value: usize, + l_value: u32, + indices: &mut [u32], + ) -> ANNResult { + if k_value > l_value as usize { + return Err(ANNError::log_index_error(format!( + "Set L: {} to a value of at least K: {}", + l_value, k_value + ))); + } + + let mut scratch_manager = + ScratchStoreManager::new(self.query_scratch_queue.clone(), Duration::from_millis(10))?; + + let scratch = scratch_manager.scratch_space().ok_or_else(|| { + ANNError::log_index_error( + "ScratchStoreManager doesn't have InMemQueryScratch instance available".to_string(), + ) + })?; + + if l_value > scratch.candidate_size { + println!("Attempting to expand query scratch_space. Was created with Lsize: {} but search L is: {}", scratch.candidate_size, l_value); + scratch.resize_for_new_candidate_size(l_value); + println!( + "Resize completed. 
New scratch size is: {}", + scratch.candidate_size + ); + } + + let cmp = self.search_with_l_override(query, scratch, l_value as usize)?; + let mut pos = 0; + + for i in 0..scratch.best_candidates.size() { + if scratch.best_candidates[i].id < self.configuration.max_points as u32 { + // Filter out the deleted points. + if let Ok(delete_set_guard) = self.delete_set.read() { + if !delete_set_guard.contains(&scratch.best_candidates[i].id) { + indices[pos] = scratch.best_candidates[i].id; + pos += 1; + } + } else { + return Err(ANNError::log_lock_poison_error( + "failed to acquire the lock for delete_set.".to_string(), + )); + } + } + + if pos == k_value { + break; + } + } + + if pos < k_value { + eprintln!( + "Found fewer than K elements for query! Found: {} but K: {}", + pos, k_value + ); + } + + Ok(cmp) + } + + fn cleanup_graph(&mut self, visit_order: &Vec) -> ANNResult<()> { + if self.num_active_pts > 0 { + println!("Starting final cleanup.."); + } + + execute_with_rayon( + 0..visit_order.len(), + self.configuration.index_write_parameter.num_threads, + |idx| { + let vertex_id = visit_order[idx]; + let num_nbrs = self.get_neighbor_count(vertex_id)?; + + if num_nbrs <= self.configuration.index_write_parameter.max_degree as usize { + // Neighbor list is already small enough. + return Ok(()); + } + + let mut scratch_manager = ScratchStoreManager::new( + self.query_scratch_queue.clone(), + Duration::from_millis(10), + )?; + let scratch = scratch_manager.scratch_space().ok_or_else(|| { + ANNError::log_index_error( + "ScratchStoreManager doesn't have InMemQueryScratch instance available" + .to_string(), + ) + })?; + + let mut dummy_pool = self.get_neighbors_for_vertex(vertex_id)?; + + let mut new_out_neighbors = AdjacencyList::for_range( + self.configuration.index_write_parameter.max_degree as usize, + ); + self.prune_neighbors(vertex_id, &mut dummy_pool, &mut new_out_neighbors, scratch)?; + + self.final_graph + .write_vertex_and_neighbors(vertex_id)? 
+ .set_neighbors(new_out_neighbors); + + Ok(()) + }, + ) + } + + /// Get the unique neighbors for a vertex. + /// + /// This code feels out of place here. This should have nothing to do with whether this + /// is in memory index? + /// # Errors + /// + /// This function will return an error if we are not able to get the read lock. + fn get_neighbors_for_vertex(&self, vertex_id: u32) -> ANNResult> { + let binding = self.final_graph.read_vertex_and_neighbors(vertex_id)?; + let neighbors = binding.get_neighbors(); + let dummy_pool = self.get_unique_neighbors(neighbors, vertex_id)?; + + Ok(dummy_pool) + } + + /// Returns a vector of unique neighbors for the given vertex, along with their distances. + /// + /// # Arguments + /// + /// * `neighbors` - A vector of neighbor id index for the given vertex. + /// * `vertex_id` - The given vertex id. + /// + /// # Errors + /// + /// Returns an `ANNError` if there is an error retrieving the vertex or one of its neighbors. + pub fn get_unique_neighbors( + &self, + neighbors: &Vec, + vertex_id: u32, + ) -> Result, ANNError> { + let vertex = self.dataset.get_vertex(vertex_id)?; + + let len = neighbors.len(); + if len == 0 { + return Ok(Vec::new()); + } + + self.dataset.prefetch_vector(neighbors[0]); + + let mut dummy_visited: HashSet = HashSet::with_capacity(len); + let mut dummy_pool: Vec = Vec::with_capacity(len); + + // let slice = ['w', 'i', 'n', 'd', 'o', 'w', 's']; + // for window in slice.windows(2) { + // &println!{"[{}, {}]", window[0], window[1]}; + // } + // prints: [w, i] -> [i, n] -> [n, d] -> [d, o] -> [o, w] -> [w, s] + for current in neighbors.windows(2) { + // Prefetch the next item. 
+ self.dataset.prefetch_vector(current[1]); + let current = current[0]; + + self.insert_neighbor_if_unique( + &mut dummy_visited, + current, + vertex_id, + &vertex, + &mut dummy_pool, + )?; + } + + // Insert the last neighbor + #[allow(clippy::unwrap_used)] + self.insert_neighbor_if_unique( + &mut dummy_visited, + *neighbors.last().unwrap(), // we know len != 0, so this is safe. + vertex_id, + &vertex, + &mut dummy_pool, + )?; + + Ok(dummy_pool) + } + + fn insert_neighbor_if_unique( + &self, + dummy_visited: &mut HashSet, + current: u32, + vertex_id: u32, + vertex: &Vertex<'_, T, N>, + dummy_pool: &mut Vec, + ) -> Result<(), ANNError> { + if current != vertex_id { + if let Vacant(entry) = dummy_visited.entry(current) { + let cur_nbr_vertex = self.dataset.get_vertex(current)?; + let dist = vertex.compare(&cur_nbr_vertex, self.configuration.dist_metric); + dummy_pool.push(Neighbor::new(current, dist)); + entry.insert(); + } + } + + Ok(()) + } + + /// Get count of neighbors for a given vertex. + /// + /// # Errors + /// + /// This function will return an error if we can't get a lock. + fn get_neighbor_count(&self, vertex_id: u32) -> ANNResult { + let num_nbrs = self + .final_graph + .read_vertex_and_neighbors(vertex_id)? 
+ .size(); + Ok(num_nbrs) + } + + fn soft_delete_vertex(&self, vertex_id_to_delete: u32) -> ANNResult<()> { + if vertex_id_to_delete as usize > self.num_active_pts { + return Err(ANNError::log_index_error(format!( + "vertex_id_to_delete: {} is greater than the number of active points in the graph: {}", + vertex_id_to_delete, self.num_active_pts + ))); + } + + let mut delete_set_guard = match self.delete_set.write() { + Ok(guard) => guard, + Err(_) => { + return Err(ANNError::log_index_error(format!( + "Failed to acquire delete_set lock, cannot delete vertex {}", + vertex_id_to_delete + ))); + } + }; + + delete_set_guard.insert(vertex_id_to_delete); + Ok(()) + } + + fn initialize_query_scratch( + &mut self, + num_threads: u32, + search_candidate_size: u32, + ) -> ANNResult<()> { + self.query_scratch_queue.reserve(num_threads as usize)?; + for _ in 0..num_threads { + let scratch = Box::new(InMemQueryScratch::::new( + search_candidate_size, + &self.configuration.index_write_parameter, + false, + )?); + + self.query_scratch_queue.push(scratch)?; + } + + Ok(()) + } + + fn print_stats(&mut self) -> ANNResult<()> { + let mut max = 0; + let mut min = usize::MAX; + let mut total = 0; + let mut cnt = 0; + + for i in 0..self.num_active_pts { + let vertex_id = i.try_into()?; + let pool_size = self + .final_graph + .read_vertex_and_neighbors(vertex_id)? 
+ .size(); + max = cmp::max(max, pool_size); + min = cmp::min(min, pool_size); + total += pool_size; + if pool_size < 2 { + cnt += 1; + } + } + + println!( + "Index built with degree: max: {} avg: {} min: {} count(deg<2): {}", + max, + (total as f32) / ((self.num_active_pts + self.configuration.num_frozen_pts) as f32), + min, + cnt + ); + + match self.delete_set.read() { + Ok(guard) => { + println!( + "Number of soft deleted vertices {}, soft deleted percentage: {}", + guard.len(), + (guard.len() as f32) + / ((self.num_active_pts + self.configuration.num_frozen_pts) as f32), + ); + } + Err(_) => { + return Err(ANNError::log_lock_poison_error( + "Failed to acquire delete_set lock, cannot get the number of deleted vertices" + .to_string(), + )); + } + }; + + self.max_observed_degree = cmp::max(max as u32, self.max_observed_degree); + + Ok(()) + } +} + +impl ANNInmemIndex for InmemIndex +where + T: Default + Copy + Sync + Send + Into, + [T; N]: FullPrecisionDistance, +{ + fn build(&mut self, filename: &str, num_points_to_load: usize) -> ANNResult<()> { + // TODO: fresh-diskANN + // std::unique_lock ul(_update_lock); + + if !file_exists(filename) { + return Err(ANNError::log_index_error(format!( + "ERROR: Data file {} does not exist.", + filename + ))); + } + + let (file_num_points, file_dim) = load_metadata_from_file(filename)?; + if file_num_points > self.configuration.max_points { + return Err(ANNError::log_index_error(format!( + "ERROR: Driver requests loading {} points and file has {} points, + but index can support only {} points as specified in configuration.", + num_points_to_load, file_num_points, self.configuration.max_points + ))); + } + + if num_points_to_load > file_num_points { + return Err(ANNError::log_index_error(format!( + "ERROR: Driver requests loading {} points and file has only {} points.", + num_points_to_load, file_num_points + ))); + } + + if file_dim != self.configuration.dim { + return Err(ANNError::log_index_error(format!( + "ERROR: Driver 
requests loading {} dimension, but file has {} dimension.", + self.configuration.dim, file_dim + ))); + } + + if self.configuration.use_pq_dist { + // TODO: PQ + todo!("PQ is not supported now"); + } + + if self.configuration.index_write_parameter.num_threads > 0 { + set_rayon_num_threads(self.configuration.index_write_parameter.num_threads); + } + + self.dataset.build_from_file(filename, num_points_to_load)?; + + println!("Using only first {} from file.", num_points_to_load); + + // TODO: tag_lock + + self.num_active_pts = num_points_to_load; + self.build_with_data_populated()?; + + Ok(()) + } + + fn insert(&mut self, filename: &str, num_points_to_insert: usize) -> ANNResult<()> { + // fresh-diskANN + if !file_exists(filename) { + return Err(ANNError::log_index_error(format!( + "ERROR: Data file {} does not exist.", + filename + ))); + } + + let (file_num_points, file_dim) = load_metadata_from_file(filename)?; + + if num_points_to_insert > file_num_points { + return Err(ANNError::log_index_error(format!( + "ERROR: Driver requests loading {} points and file has only {} points.", + num_points_to_insert, file_num_points + ))); + } + + if file_dim != self.configuration.dim { + return Err(ANNError::log_index_error(format!( + "ERROR: Driver requests loading {} dimension, but file has {} dimension.", + self.configuration.dim, file_dim + ))); + } + + if self.configuration.use_pq_dist { + // TODO: PQ + todo!("PQ is not supported now"); + } + + if self.query_scratch_queue.size()? == 0 { + self.initialize_query_scratch( + 5 + self.configuration.index_write_parameter.num_threads, + self.configuration.index_write_parameter.search_list_size, + )?; + } + + if self.configuration.index_write_parameter.num_threads > 0 { + // set the thread count of Rayon, otherwise it will use threads as many as logical cores. 
+ std::env::set_var( + "RAYON_NUM_THREADS", + self.configuration + .index_write_parameter + .num_threads + .to_string(), + ); + } + + self.dataset + .append_from_file(filename, num_points_to_insert)?; + self.final_graph.extend( + num_points_to_insert, + self.configuration.index_write_parameter.max_degree, + ); + + // TODO: this should not consider frozen points + let previous_last_pt = self.num_active_pts; + self.num_active_pts += num_points_to_insert; + self.configuration.max_points += num_points_to_insert; + + println!("Inserting {} vectors from file.", num_points_to_insert); + + // TODO: tag_lock + let logger = IndexLogger::new(num_points_to_insert); + let timer = Timer::new(); + execute_with_rayon( + previous_last_pt..self.num_active_pts, + self.configuration.index_write_parameter.num_threads, + |idx| { + self.insert_vertex_id(idx as u32)?; + logger.vertex_processed()?; + + Ok(()) + }, + )?; + + let mut visit_order = + Vec::with_capacity(self.num_active_pts + self.configuration.num_frozen_pts); + for i in 0..self.num_active_pts { + visit_order.push(i as u32); + } + + self.cleanup_graph(&visit_order)?; + println!("{}", timer.elapsed_seconds_for_step("Insert time: ")); + + self.print_stats()?; + + Ok(()) + } + + fn save(&mut self, filename: &str) -> ANNResult<()> { + let data_file = filename.to_string() + ".data"; + let delete_file = filename.to_string() + ".delete"; + + self.save_graph(filename)?; + self.save_data(data_file.as_str())?; + self.save_delete_list(delete_file.as_str())?; + + Ok(()) + } + + fn load(&mut self, filename: &str, expected_num_points: usize) -> ANNResult<()> { + self.num_active_pts = expected_num_points; + self.dataset + .build_from_file(&format!("{}.data", filename), expected_num_points)?; + + self.load_graph(filename, expected_num_points)?; + self.load_delete_list(&format!("{}.delete", filename))?; + + if self.query_scratch_queue.size()? 
== 0 { + self.initialize_query_scratch( + 5 + self.configuration.index_write_parameter.num_threads, + self.configuration.index_write_parameter.search_list_size, + )?; + } + + Ok(()) + } + + fn search( + &self, + query: &[T], + k_value: usize, + l_value: u32, + indices: &mut [u32], + ) -> ANNResult { + let query_vector = Vertex::new(<&[T; N]>::try_from(query)?, 0); + InmemIndex::search(self, &query_vector, k_value, l_value, indices) + } + + fn soft_delete( + &mut self, + vertex_ids_to_delete: Vec, + num_points_to_delete: usize, + ) -> ANNResult<()> { + println!("Deleting {} vectors from file.", num_points_to_delete); + + let logger = IndexLogger::new(num_points_to_delete); + let timer = Timer::new(); + + execute_with_rayon( + 0..num_points_to_delete, + self.configuration.index_write_parameter.num_threads, + |idx: usize| { + self.soft_delete_vertex(vertex_ids_to_delete[idx])?; + logger.vertex_processed()?; + + Ok(()) + }, + )?; + + println!("{}", timer.elapsed_seconds_for_step("Delete time: ")); + self.print_stats()?; + + Ok(()) + } +} + +#[cfg(test)] +mod index_test { + use vector::Metric; + + use super::*; + use crate::{ + model::{ + configuration::index_write_parameters::IndexWriteParametersBuilder, vertex::DIM_128, + }, + test_utils::get_test_file_path, + utils::file_util::load_ids_to_delete_from_file, + utils::round_up, + }; + + const TEST_DATA_FILE: &str = "tests/data/siftsmall_learn_256pts.fbin"; + const TRUTH_GRAPH: &str = "tests/data/truth_index_siftsmall_learn_256pts_R4_L50_A1.2"; + const TEST_DELETE_FILE: &str = "tests/data/delete_set_50pts.bin"; + const TRUTH_GRAPH_WITH_SATURATED: &str = + "tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_mem.index"; + const R: u32 = 4; + const L: u32 = 50; + const ALPHA: f32 = 1.2; + + /// Build the index with TEST_DATA_FILE and compare the index graph with truth graph TRUTH_GRAPH + /// Change above constants if you want to test with different dataset + macro_rules! 
index_end_to_end_test_singlethread { + ($saturate_graph:expr, $truth_graph:expr) => {{ + let (data_num, dim) = + load_metadata_from_file(get_test_file_path(TEST_DATA_FILE).as_str()).unwrap(); + + let index_write_parameters = IndexWriteParametersBuilder::new(L, R) + .with_alpha(ALPHA) + .with_num_threads(1) + .with_saturate_graph($saturate_graph) + .build(); + let config = IndexConfiguration::new( + Metric::L2, + dim, + round_up(dim as u64, 16_u64) as usize, + data_num, + false, + 0, + false, + 0, + 1.0f32, + index_write_parameters, + ); + let mut index: InmemIndex = InmemIndex::new(config.clone()).unwrap(); + + index + .build(get_test_file_path(TEST_DATA_FILE).as_str(), data_num) + .unwrap(); + + let mut truth_index: InmemIndex = InmemIndex::new(config).unwrap(); + truth_index + .load_graph(get_test_file_path($truth_graph).as_str(), data_num) + .unwrap(); + + compare_graphs(&index, &truth_index); + }}; + } + + #[test] + fn index_end_to_end_test_singlethread() { + index_end_to_end_test_singlethread!(false, TRUTH_GRAPH); + } + + #[test] + fn index_end_to_end_test_singlethread_with_saturate_graph() { + index_end_to_end_test_singlethread!(true, TRUTH_GRAPH_WITH_SATURATED); + } + + #[test] + fn index_end_to_end_test_multithread() { + let (data_num, dim) = + load_metadata_from_file(get_test_file_path(TEST_DATA_FILE).as_str()).unwrap(); + + let index_write_parameters = IndexWriteParametersBuilder::new(L, R) + .with_alpha(ALPHA) + .with_num_threads(8) + .build(); + let config = IndexConfiguration::new( + Metric::L2, + dim, + round_up(dim as u64, 16_u64) as usize, + data_num, + false, + 0, + false, + 0, + 1f32, + index_write_parameters, + ); + let mut index: InmemIndex = InmemIndex::new(config).unwrap(); + + index + .build(get_test_file_path(TEST_DATA_FILE).as_str(), data_num) + .unwrap(); + + for i in 0..index.final_graph.size() { + assert_ne!( + index + .final_graph + .read_vertex_and_neighbors(i as u32) + .unwrap() + .size(), + 0 + ); + } + } + + const TEST_DATA_FILE_2: 
&str = "tests/data/siftsmall_learn_256pts_2.fbin"; + const INSERT_TRUTH_GRAPH: &str = + "tests/data/truth_index_siftsmall_learn_256pts_1+2_R4_L50_A1.2"; + const INSERT_TRUTH_GRAPH_WITH_SATURATED: &str = + "tests/data/truth_index_siftsmall_learn_256pts_1+2_saturated_R4_L50_A1.2"; + + /// Build the index with TEST_DATA_FILE, insert TEST_DATA_FILE_2 and compare the index graph with truth graph TRUTH_GRAPH + /// Change above constants if you want to test with different dataset + macro_rules! index_insert_end_to_end_test_singlethread { + ($saturate_graph:expr, $truth_graph:expr) => {{ + let (data_num, dim) = + load_metadata_from_file(get_test_file_path(TEST_DATA_FILE).as_str()).unwrap(); + + let index_write_parameters = IndexWriteParametersBuilder::new(L, R) + .with_alpha(ALPHA) + .with_num_threads(1) + .with_saturate_graph($saturate_graph) + .build(); + let config = IndexConfiguration::new( + Metric::L2, + dim, + round_up(dim as u64, 16_u64) as usize, + data_num, + false, + 0, + false, + 0, + 2.0f32, + index_write_parameters, + ); + let mut index: InmemIndex = InmemIndex::new(config.clone()).unwrap(); + + index + .build(get_test_file_path(TEST_DATA_FILE).as_str(), data_num) + .unwrap(); + index + .insert(get_test_file_path(TEST_DATA_FILE_2).as_str(), data_num) + .unwrap(); + + let config2 = IndexConfiguration::new( + Metric::L2, + dim, + round_up(dim as u64, 16_u64) as usize, + data_num * 2, + false, + 0, + false, + 0, + 1.0f32, + index_write_parameters, + ); + let mut truth_index: InmemIndex = InmemIndex::new(config2).unwrap(); + truth_index + .load_graph(get_test_file_path($truth_graph).as_str(), data_num) + .unwrap(); + + compare_graphs(&index, &truth_index); + }}; + } + + /// Build the index with TEST_DATA_FILE, and delete the vertices with id defined in TEST_DELETE_SET + macro_rules! 
index_delete_end_to_end_test_singlethread { + () => {{ + let (data_num, dim) = + load_metadata_from_file(get_test_file_path(TEST_DATA_FILE).as_str()).unwrap(); + + let index_write_parameters = IndexWriteParametersBuilder::new(L, R) + .with_alpha(ALPHA) + .with_num_threads(1) + .build(); + let config = IndexConfiguration::new( + Metric::L2, + dim, + round_up(dim as u64, 16_u64) as usize, + data_num, + false, + 0, + false, + 0, + 2.0f32, + index_write_parameters, + ); + let mut index: InmemIndex = InmemIndex::new(config.clone()).unwrap(); + + index + .build(get_test_file_path(TEST_DATA_FILE).as_str(), data_num) + .unwrap(); + + let (num_points_to_delete, vertex_ids_to_delete) = + load_ids_to_delete_from_file(TEST_DELETE_FILE).unwrap(); + index + .soft_delete(vertex_ids_to_delete, num_points_to_delete) + .unwrap(); + assert!(index.delete_set.read().unwrap().len() == num_points_to_delete); + }}; + } + + #[test] + fn index_insert_end_to_end_test_singlethread() { + index_insert_end_to_end_test_singlethread!(false, INSERT_TRUTH_GRAPH); + } + + #[test] + fn index_delete_end_to_end_test_singlethread() { + index_delete_end_to_end_test_singlethread!(); + } + + #[test] + fn index_insert_end_to_end_test_saturated_singlethread() { + index_insert_end_to_end_test_singlethread!(true, INSERT_TRUTH_GRAPH_WITH_SATURATED); + } + + fn compare_graphs(index: &InmemIndex, truth_index: &InmemIndex) { + assert_eq!(index.start, truth_index.start); + assert_eq!(index.max_observed_degree, truth_index.max_observed_degree); + assert_eq!(index.final_graph.size(), truth_index.final_graph.size()); + + for i in 0..index.final_graph.size() { + assert_eq!( + index + .final_graph + .read_vertex_and_neighbors(i as u32) + .unwrap() + .size(), + truth_index + .final_graph + .read_vertex_and_neighbors(i as u32) + .unwrap() + .size() + ); + assert_eq!( + index + .final_graph + .read_vertex_and_neighbors(i as u32) + .unwrap() + .get_neighbors(), + truth_index + .final_graph + .read_vertex_and_neighbors(i as 
u32) + .unwrap() + .get_neighbors() + ); + } + } +} diff --git a/rust/diskann/src/index/inmem_index/inmem_index_storage.rs b/rust/diskann/src/index/inmem_index/inmem_index_storage.rs new file mode 100644 index 000000000..fa14d70b2 --- /dev/null +++ b/rust/diskann/src/index/inmem_index/inmem_index_storage.rs @@ -0,0 +1,304 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::fs::File; +use std::io::{BufReader, BufWriter, Seek, SeekFrom, Write}; +use std::path::Path; + +use byteorder::{LittleEndian, ReadBytesExt}; +use vector::FullPrecisionDistance; + +use crate::common::{ANNError, ANNResult}; +use crate::model::graph::AdjacencyList; +use crate::model::InMemoryGraph; +use crate::utils::{file_exists, save_data_in_base_dimensions}; + +use super::InmemIndex; + +impl InmemIndex +where + T: Default + Copy + Sync + Send + Into, + [T; N]: FullPrecisionDistance, +{ + pub fn load_graph(&mut self, filename: &str, expected_num_points: usize) -> ANNResult { + // let file_offset = 0; // will need this for single file format support + + let mut in_file = BufReader::new(File::open(Path::new(filename))?); + // in_file.seek(SeekFrom::Start(file_offset as u64))?; + + let expected_file_size: usize = in_file.read_u64::()? as usize; + self.max_observed_degree = in_file.read_u32::()?; + self.start = in_file.read_u32::()?; + let file_frozen_pts: usize = in_file.read_u64::()? as usize; + + let vamana_metadata_size = 24; + + println!("From graph header, expected_file_size: {}, max_observed_degree: {}, start: {}, file_frozen_pts: {}", + expected_file_size, self.max_observed_degree, self.start, file_frozen_pts); + + if file_frozen_pts != self.configuration.num_frozen_pts { + if file_frozen_pts == 1 { + return Err(ANNError::log_index_config_error( + "num_frozen_pts".to_string(), + "ERROR: When loading index, detected dynamic index, but constructor asks for static index. 
Exitting.".to_string()) + ); + } else { + return Err(ANNError::log_index_config_error( + "num_frozen_pts".to_string(), + "ERROR: When loading index, detected static index, but constructor asks for dynamic index. Exitting.".to_string()) + ); + } + } + + println!("Loading vamana graph {}...", filename); + + let expected_max_points = expected_num_points - file_frozen_pts; + + // If user provides more points than max_points + // resize the _final_graph to the larger size. + if self.configuration.max_points < expected_max_points { + println!("Number of points in data: {} is greater than max_points: {} Setting max points to: {}", expected_max_points, self.configuration.max_points, expected_max_points); + + self.configuration.max_points = expected_max_points; + self.final_graph = InMemoryGraph::new( + self.configuration.max_points + self.configuration.num_frozen_pts, + self.configuration.index_write_parameter.max_degree, + ); + } + + let mut bytes_read = vamana_metadata_size; + let mut num_edges = 0; + let mut nodes_read = 0; + let mut max_observed_degree = 0; + + while bytes_read != expected_file_size { + let num_nbrs = in_file.read_u32::()?; + max_observed_degree = if num_nbrs > max_observed_degree { + num_nbrs + } else { + max_observed_degree + }; + + if num_nbrs == 0 { + return Err(ANNError::log_index_error(format!( + "ERROR: Point found with no out-neighbors, point# {}", + nodes_read + ))); + } + + num_edges += num_nbrs; + nodes_read += 1; + let mut tmp: Vec = Vec::with_capacity(num_nbrs as usize); + for _ in 0..num_nbrs { + tmp.push(in_file.read_u32::()?); + } + + self.final_graph + .write_vertex_and_neighbors(nodes_read - 1)? + .set_neighbors(AdjacencyList::from(tmp)); + bytes_read += 4 * (num_nbrs as usize + 1); + } + + println!( + "Done. 
Index has {} nodes and {} out-edges, _start is set to {}", + nodes_read, num_edges, self.start + ); + + self.max_observed_degree = max_observed_degree; + Ok(nodes_read as usize) + } + + /// Save the graph index on a file as an adjacency list. + /// For each point, first store the number of neighbors, + /// and then the neighbor list (each as 4 byte u32) + pub fn save_graph(&mut self, graph_file: &str) -> ANNResult { + let file: File = File::create(graph_file)?; + let mut out = BufWriter::new(file); + + let file_offset: u64 = 0; + out.seek(SeekFrom::Start(file_offset))?; + let mut index_size: u64 = 24; + let mut max_degree: u32 = 0; + out.write_all(&index_size.to_le_bytes())?; + out.write_all(&self.max_observed_degree.to_le_bytes())?; + out.write_all(&self.start.to_le_bytes())?; + out.write_all(&(self.configuration.num_frozen_pts as u64).to_le_bytes())?; + + // At this point, either nd == max_points or any frozen points have + // been temporarily moved to nd, so nd + num_frozen_points is the valid + // location limit + for i in 0..self.num_active_pts + self.configuration.num_frozen_pts { + let idx = i as u32; + let gk: u32 = self.final_graph.read_vertex_and_neighbors(idx)?.size() as u32; + out.write_all(&gk.to_le_bytes())?; + for neighbor in self + .final_graph + .read_vertex_and_neighbors(idx)? + .get_neighbors() + .iter() + { + out.write_all(&neighbor.to_le_bytes())?; + } + max_degree = + if self.final_graph.read_vertex_and_neighbors(idx)?.size() as u32 > max_degree { + self.final_graph.read_vertex_and_neighbors(idx)?.size() as u32 + } else { + max_degree + }; + index_size += (std::mem::size_of::() * (gk as usize + 1)) as u64; + } + out.seek(SeekFrom::Start(file_offset))?; + out.write_all(&index_size.to_le_bytes())?; + out.write_all(&max_degree.to_le_bytes())?; + out.flush()?; + Ok(index_size) + } + + /// Save the data on a file. 
+ pub fn save_data(&mut self, data_file: &str) -> ANNResult { + // Note: at this point, either _nd == _max_points or any frozen points have + // been temporarily moved to _nd, so _nd + _num_frozen_points is the valid + // location limit. + Ok(save_data_in_base_dimensions( + data_file, + &mut self.dataset.data, + self.num_active_pts + self.configuration.num_frozen_pts, + self.configuration.dim, + self.configuration.aligned_dim, + 0, + )?) + } + + /// Save the delete list to a file only if the delete list length is not zero. + pub fn save_delete_list(&mut self, delete_list_file: &str) -> ANNResult { + let mut delete_file_size = 0; + if let Ok(delete_set) = self.delete_set.read() { + let delete_set_len = delete_set.len() as u32; + + if delete_set_len != 0 { + let file: File = File::create(delete_list_file)?; + let mut writer = BufWriter::new(file); + + // Write the length of the set. + writer.write_all(&delete_set_len.to_le_bytes())?; + delete_file_size += std::mem::size_of::(); + + // Write the elements of the set. + for &item in delete_set.iter() { + writer.write_all(&item.to_be_bytes())?; + delete_file_size += std::mem::size_of::(); + } + + writer.flush()?; + } + } else { + return Err(ANNError::log_lock_poison_error( + "Poisoned lock on delete set. Can't save deleted list.".to_string(), + )); + } + + Ok(delete_file_size) + } + + // load the deleted list from the delete file if it exists. + pub fn load_delete_list(&mut self, delete_list_file: &str) -> ANNResult { + let mut len = 0; + + if file_exists(delete_list_file) { + let file = File::open(delete_list_file)?; + let mut reader = BufReader::new(file); + + len = reader.read_u32::()? as usize; + + if let Ok(mut delete_set) = self.delete_set.write() { + for _ in 0..len { + let item = reader.read_u32::()?; + delete_set.insert(item); + } + } else { + return Err(ANNError::log_lock_poison_error( + "Poisoned lock on delete set. 
Can't load deleted list.".to_string(), + )); + } + } + + Ok(len) + } +} + +#[cfg(test)] +mod index_test { + use std::fs; + + use vector::Metric; + + use super::*; + use crate::{ + index::ANNInmemIndex, + model::{ + configuration::index_write_parameters::IndexWriteParametersBuilder, vertex::DIM_128, + IndexConfiguration, + }, + utils::{load_metadata_from_file, round_up}, + }; + + const TEST_DATA_FILE: &str = "tests/data/siftsmall_learn_256pts.fbin"; + const R: u32 = 4; + const L: u32 = 50; + const ALPHA: f32 = 1.2; + + #[cfg_attr(not(coverage), test)] + fn save_graph_test() { + let parameters = IndexWriteParametersBuilder::new(50, 4) + .with_alpha(1.2) + .build(); + let config = + IndexConfiguration::new(Metric::L2, 10, 16, 16, false, 0, false, 8, 1f32, parameters); + let mut index = InmemIndex::::new(config).unwrap(); + let final_graph = InMemoryGraph::new(10, 3); + let num_active_pts = 2_usize; + index.final_graph = final_graph; + index.num_active_pts = num_active_pts; + let graph_file = "test_save_graph_data.bin"; + let result = index.save_graph(graph_file); + assert!(result.is_ok()); + + fs::remove_file(graph_file).expect("Failed to delete file"); + } + + #[test] + fn save_data_test() { + let (data_num, dim) = load_metadata_from_file(TEST_DATA_FILE).unwrap(); + + let index_write_parameters = IndexWriteParametersBuilder::new(L, R) + .with_alpha(ALPHA) + .build(); + let config = IndexConfiguration::new( + Metric::L2, + dim, + round_up(dim as u64, 16_u64) as usize, + data_num, + false, + 0, + false, + 0, + 1f32, + index_write_parameters, + ); + let mut index: InmemIndex = InmemIndex::new(config).unwrap(); + + index.build(TEST_DATA_FILE, data_num).unwrap(); + + let data_file = "test.data"; + let result = index.save_data(data_file); + assert_eq!( + result.unwrap(), + 2 * std::mem::size_of::() + + (index.num_active_pts + index.configuration.num_frozen_pts) + * index.configuration.dim + * (std::mem::size_of::()) + ); + fs::remove_file(data_file).expect("Failed to 
delete file"); + } +} diff --git a/rust/diskann/src/index/inmem_index/mod.rs b/rust/diskann/src/index/inmem_index/mod.rs new file mode 100644 index 000000000..f2a091a09 --- /dev/null +++ b/rust/diskann/src/index/inmem_index/mod.rs @@ -0,0 +1,12 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#[allow(clippy::module_inception)] +mod inmem_index; +pub use inmem_index::InmemIndex; + +mod inmem_index_storage; + +pub mod ann_inmem_index; + diff --git a/rust/diskann/src/index/mod.rs b/rust/diskann/src/index/mod.rs new file mode 100644 index 000000000..18c3bd5e9 --- /dev/null +++ b/rust/diskann/src/index/mod.rs @@ -0,0 +1,11 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +mod inmem_index; +pub use inmem_index::ann_inmem_index::*; +pub use inmem_index::InmemIndex; + +mod disk_index; +pub use disk_index::*; + diff --git a/rust/diskann/src/instrumentation/disk_index_build_logger.rs b/rust/diskann/src/instrumentation/disk_index_build_logger.rs new file mode 100644 index 000000000..d34935342 --- /dev/null +++ b/rust/diskann/src/instrumentation/disk_index_build_logger.rs @@ -0,0 +1,57 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use logger::logger::indexlog::DiskIndexConstructionCheckpoint; +use logger::logger::indexlog::DiskIndexConstructionLog; +use logger::logger::indexlog::Log; +use logger::logger::indexlog::LogLevel; +use logger::message_handler::send_log; + +use crate::{utils::Timer, common::ANNResult}; + +pub struct DiskIndexBuildLogger { + timer: Timer, + checkpoint: DiskIndexConstructionCheckpoint, +} + +impl DiskIndexBuildLogger { + pub fn new(checkpoint: DiskIndexConstructionCheckpoint) -> Self { + Self { + timer: Timer::new(), + checkpoint, + } + } + + pub fn log_checkpoint(&mut self, next_checkpoint: DiskIndexConstructionCheckpoint) -> ANNResult<()> { + if self.checkpoint == DiskIndexConstructionCheckpoint::None { + return Ok(()); + } + + let mut log = Log::default(); + let disk_index_construction_log = DiskIndexConstructionLog { + checkpoint: self.checkpoint as i32, + time_spent_in_seconds: self.timer.elapsed().as_secs_f32(), + g_cycles_spent: self.timer.elapsed_gcycles(), + log_level: LogLevel::Info as i32, + }; + log.disk_index_construction_log = Some(disk_index_construction_log); + + send_log(log)?; + self.checkpoint = next_checkpoint; + self.timer.reset(); + Ok(()) + } +} + +#[cfg(test)] +mod dataset_test { + use super::*; + + #[test] + fn test_log() { + let mut logger = DiskIndexBuildLogger::new(DiskIndexConstructionCheckpoint::PqConstruction); + logger.log_checkpoint(DiskIndexConstructionCheckpoint::InmemIndexBuild).unwrap();logger.log_checkpoint(logger::logger::indexlog::DiskIndexConstructionCheckpoint::DiskLayout).unwrap(); + } +} + diff --git a/rust/diskann/src/instrumentation/index_logger.rs b/rust/diskann/src/instrumentation/index_logger.rs new file mode 100644 index 000000000..dfc81ad15 --- /dev/null +++ b/rust/diskann/src/instrumentation/index_logger.rs @@ -0,0 +1,47 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use std::sync::atomic::{AtomicUsize, Ordering}; + +use logger::logger::indexlog::IndexConstructionLog; +use logger::logger::indexlog::Log; +use logger::logger::indexlog::LogLevel; +use logger::message_handler::send_log; + +use crate::common::ANNResult; +use crate::utils::Timer; + +pub struct IndexLogger { + items_processed: AtomicUsize, + timer: Timer, + range: usize, +} + +impl IndexLogger { + pub fn new(range: usize) -> Self { + Self { + items_processed: AtomicUsize::new(0), + timer: Timer::new(), + range, + } + } + + pub fn vertex_processed(&self) -> ANNResult<()> { + let count = self.items_processed.fetch_add(1, Ordering::Relaxed); + if count % 100_000 == 0 { + let mut log = Log::default(); + let index_construction_log = IndexConstructionLog { + percentage_complete: (100_f32 * count as f32) / (self.range as f32), + time_spent_in_seconds: self.timer.elapsed().as_secs_f32(), + g_cycles_spent: self.timer.elapsed_gcycles(), + log_level: LogLevel::Info as i32, + }; + log.index_construction_log = Some(index_construction_log); + + send_log(log)?; + } + + Ok(()) + } +} diff --git a/rust/diskann/src/instrumentation/mod.rs b/rust/diskann/src/instrumentation/mod.rs new file mode 100644 index 000000000..234e53ce9 --- /dev/null +++ b/rust/diskann/src/instrumentation/mod.rs @@ -0,0 +1,9 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +mod index_logger; +pub use index_logger::IndexLogger; + +mod disk_index_build_logger; +pub use disk_index_build_logger::DiskIndexBuildLogger; diff --git a/rust/diskann/src/lib.rs b/rust/diskann/src/lib.rs new file mode 100644 index 000000000..1f89e33fc --- /dev/null +++ b/rust/diskann/src/lib.rs @@ -0,0 +1,26 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +#![cfg_attr( + not(test), + warn(clippy::panic, clippy::unwrap_used, clippy::expect_used) +)] +#![cfg_attr(test, allow(clippy::unused_io_amount))] + +pub mod utils; + +pub mod algorithm; + +pub mod model; + +pub mod common; + +pub mod index; + +pub mod storage; + +pub mod instrumentation; + +#[cfg(test)] +pub mod test_utils; diff --git a/rust/diskann/src/model/configuration/disk_index_build_parameter.rs b/rust/diskann/src/model/configuration/disk_index_build_parameter.rs new file mode 100644 index 000000000..539192af0 --- /dev/null +++ b/rust/diskann/src/model/configuration/disk_index_build_parameter.rs @@ -0,0 +1,85 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Parameters for disk index construction. + +use crate::common::{ANNResult, ANNError}; + +/// Cached nodes size in GB +const SPACE_FOR_CACHED_NODES_IN_GB: f64 = 0.25; + +/// Threshold for caching in GB +const THRESHOLD_FOR_CACHING_IN_GB: f64 = 1.0; + +/// Parameters specific for disk index construction. +#[derive(Clone, Copy, PartialEq, Debug)] +pub struct DiskIndexBuildParameters { + /// Bound on the memory footprint of the index at search time in bytes. + /// Once built, the index will use up only the specified RAM limit, the rest will reside on disk. + /// This will dictate how aggressively we compress the data vectors to store in memory. + /// Larger will yield better performance at search time. + search_ram_limit: f64, + + /// Limit on the memory allowed for building the index in bytes. 
+ index_build_ram_limit: f64, +} + +impl DiskIndexBuildParameters { + /// Create DiskIndexBuildParameters instance + pub fn new(search_ram_limit_gb: f64, index_build_ram_limit_gb: f64) -> ANNResult { + let param = Self { + search_ram_limit: Self::get_memory_budget(search_ram_limit_gb), + index_build_ram_limit: index_build_ram_limit_gb * 1024_f64 * 1024_f64 * 1024_f64, + }; + + if param.search_ram_limit <= 0f64 { + return Err(ANNError::log_index_config_error("search_ram_limit".to_string(), "RAM budget should be > 0".to_string())) + } + + if param.index_build_ram_limit <= 0f64 { + return Err(ANNError::log_index_config_error("index_build_ram_limit".to_string(), "RAM budget should be > 0".to_string())) + } + + Ok(param) + } + + /// Get search_ram_limit + pub fn search_ram_limit(&self) -> f64 { + self.search_ram_limit + } + + /// Get index_build_ram_limit + pub fn index_build_ram_limit(&self) -> f64 { + self.index_build_ram_limit + } + + fn get_memory_budget(mut index_ram_limit_gb: f64) -> f64 { + if index_ram_limit_gb - SPACE_FOR_CACHED_NODES_IN_GB > THRESHOLD_FOR_CACHING_IN_GB { + // slack for space used by cached nodes + index_ram_limit_gb -= SPACE_FOR_CACHED_NODES_IN_GB; + } + + index_ram_limit_gb * 1024_f64 * 1024_f64 * 1024_f64 + } +} + +#[cfg(test)] +mod dataset_test { + use super::*; + + #[test] + fn sufficient_ram_for_caching() { + let param = DiskIndexBuildParameters::new(1.26_f64, 1.0_f64).unwrap(); + assert_eq!(param.search_ram_limit, 1.01_f64 * 1024_f64 * 1024_f64 * 1024_f64); + } + + #[test] + fn insufficient_ram_for_caching() { + let param = DiskIndexBuildParameters::new(0.03_f64, 1.0_f64).unwrap(); + assert_eq!(param.search_ram_limit, 0.03_f64 * 1024_f64 * 1024_f64 * 1024_f64); + } +} + diff --git a/rust/diskann/src/model/configuration/index_configuration.rs b/rust/diskann/src/model/configuration/index_configuration.rs new file mode 100644 index 000000000..3e8c472ae --- /dev/null +++ b/rust/diskann/src/model/configuration/index_configuration.rs @@ -0,0 
+1,92 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Index configuration. + +use vector::Metric; + +use super::index_write_parameters::IndexWriteParameters; + +/// The index configuration +#[derive(Debug, Clone)] +pub struct IndexConfiguration { + /// Index write parameter + pub index_write_parameter: IndexWriteParameters, + + /// Distance metric + pub dist_metric: Metric, + + /// Dimension of the raw data + pub dim: usize, + + /// Aligned dimension - round up dim to the nearest multiple of 8 + pub aligned_dim: usize, + + /// Total number of points in given data set + pub max_points: usize, + + /// Number of points which are used as initial candidates when iterating to + /// closest point(s). These are not visible externally and won't be returned + /// by search. DiskANN forces at least 1 frozen point for dynamic index. + /// The frozen points have consecutive locations. + pub num_frozen_pts: usize, + + /// Calculate distance by PQ or not + pub use_pq_dist: bool, + + /// Number of PQ chunks + pub num_pq_chunks: usize, + + /// Use optimized product quantization + /// Currently not supported + pub use_opq: bool, + + /// potential for growth. 1.2 means the index can grow by up to 20%. 
+ pub growth_potential: f32, + + // TODO: below settings are not supported in current iteration + // pub concurrent_consolidate: bool, + // pub has_built: bool, + // pub save_as_one_file: bool, + // pub dynamic_index: bool, + // pub enable_tags: bool, + // pub normalize_vecs: bool, +} + +impl IndexConfiguration { + /// Create IndexConfiguration instance + #[allow(clippy::too_many_arguments)] + pub fn new( + dist_metric: Metric, + dim: usize, + aligned_dim: usize, + max_points: usize, + use_pq_dist: bool, + num_pq_chunks: usize, + use_opq: bool, + num_frozen_pts: usize, + growth_potential: f32, + index_write_parameter: IndexWriteParameters + ) -> Self { + Self { + index_write_parameter, + dist_metric, + dim, + aligned_dim, + max_points, + num_frozen_pts, + use_pq_dist, + num_pq_chunks, + use_opq, + growth_potential, + } + } + + /// Get the size of adjacency list that we build out. + pub fn write_range(&self) -> usize { + self.index_write_parameter.max_degree as usize + } +} diff --git a/rust/diskann/src/model/configuration/index_write_parameters.rs b/rust/diskann/src/model/configuration/index_write_parameters.rs new file mode 100644 index 000000000..cb71f4297 --- /dev/null +++ b/rust/diskann/src/model/configuration/index_write_parameters.rs @@ -0,0 +1,245 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Index write parameters. + +/// Default parameter values. +pub mod default_param_vals { + /// Default value of alpha. + pub const ALPHA: f32 = 1.2; + + /// Default value of number of threads. + pub const NUM_THREADS: u32 = 0; + + /// Default value of number of rounds. + pub const NUM_ROUNDS: u32 = 2; + + /// Default value of max occlusion size. + pub const MAX_OCCLUSION_SIZE: u32 = 750; + + /// Default value of filter list size. + pub const FILTER_LIST_SIZE: u32 = 0; + + /// Default value of number of frozen points. 
/// Default parameter values.
pub mod default_param_vals {
    /// Default value of alpha.
    pub const ALPHA: f32 = 1.2;

    /// Default value of number of threads.
    pub const NUM_THREADS: u32 = 0;

    /// Default value of number of rounds.
    pub const NUM_ROUNDS: u32 = 2;

    /// Default value of max occlusion size.
    pub const MAX_OCCLUSION_SIZE: u32 = 750;

    /// Default value of filter list size.
    pub const FILTER_LIST_SIZE: u32 = 0;

    /// Default value of number of frozen points.
    pub const NUM_FROZEN_POINTS: u32 = 0;

    /// Default value of max degree.
    pub const MAX_DEGREE: u32 = 64;

    /// Default value of build list size.
    pub const BUILD_LIST_SIZE: u32 = 100;

    /// Default value of saturate graph.
    pub const SATURATE_GRAPH: bool = false;

    /// Default value of search list size.
    pub const SEARCH_LIST_SIZE: u32 = 100;
}

/// Index write parameters.
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct IndexWriteParameters {
    /// Search list size - L.
    pub search_list_size: u32,

    /// Max degree - R.
    pub max_degree: u32,

    /// Saturate graph.
    pub saturate_graph: bool,

    /// Max occlusion size - C.
    pub max_occlusion_size: u32,

    /// Alpha.
    pub alpha: f32,

    /// Number of rounds.
    pub num_rounds: u32,

    /// Number of threads.
    pub num_threads: u32,

    /// Number of frozen points.
    pub num_frozen_points: u32,
}

impl Default for IndexWriteParameters {
    /// Create IndexWriteParameters with all-default values.
    fn default() -> Self {
        IndexWriteParametersBuilder::new(
            default_param_vals::SEARCH_LIST_SIZE,
            default_param_vals::MAX_DEGREE,
        )
        .build()
    }
}

/// The builder for IndexWriteParameters.
#[derive(Debug)]
pub struct IndexWriteParametersBuilder {
    search_list_size: u32,
    max_degree: u32,
    max_occlusion_size: Option<u32>,
    saturate_graph: Option<bool>,
    alpha: Option<f32>,
    num_rounds: Option<u32>,
    num_threads: Option<u32>,
    // filter_list_size: Option<u32>,
    num_frozen_points: Option<u32>,
}

impl IndexWriteParametersBuilder {
    /// Initialize IndexWriteParametersBuilder with the two mandatory
    /// parameters; every other field falls back to its default at build time.
    pub fn new(search_list_size: u32, max_degree: u32) -> Self {
        Self {
            search_list_size,
            max_degree,
            max_occlusion_size: None,
            saturate_graph: None,
            alpha: None,
            num_rounds: None,
            num_threads: None,
            // filter_list_size: None,
            num_frozen_points: None,
        }
    }

    /// Set max occlusion size.
    pub fn with_max_occlusion_size(mut self, max_occlusion_size: u32) -> Self {
        self.max_occlusion_size = Some(max_occlusion_size);
        self
    }

    /// Set saturate graph.
    pub fn with_saturate_graph(mut self, saturate_graph: bool) -> Self {
        self.saturate_graph = Some(saturate_graph);
        self
    }

    /// Set alpha.
    pub fn with_alpha(mut self, alpha: f32) -> Self {
        self.alpha = Some(alpha);
        self
    }

    /// Set number of rounds.
    pub fn with_num_rounds(mut self, num_rounds: u32) -> Self {
        self.num_rounds = Some(num_rounds);
        self
    }

    /// Set number of threads.
    pub fn with_num_threads(mut self, num_threads: u32) -> Self {
        self.num_threads = Some(num_threads);
        self
    }

    /*
    pub fn with_filter_list_size(mut self, filter_list_size: u32) -> Self {
        self.filter_list_size = Some(filter_list_size);
        self
    }
    */

    /// Set number of frozen points.
    pub fn with_num_frozen_points(mut self, num_frozen_points: u32) -> Self {
        self.num_frozen_points = Some(num_frozen_points);
        self
    }

    /// Build IndexWriteParameters, substituting `default_param_vals` for any
    /// field that was not explicitly set.
    pub fn build(self) -> IndexWriteParameters {
        use default_param_vals as d;
        IndexWriteParameters {
            search_list_size: self.search_list_size,
            max_degree: self.max_degree,
            saturate_graph: self.saturate_graph.unwrap_or(d::SATURATE_GRAPH),
            max_occlusion_size: self.max_occlusion_size.unwrap_or(d::MAX_OCCLUSION_SIZE),
            alpha: self.alpha.unwrap_or(d::ALPHA),
            num_rounds: self.num_rounds.unwrap_or(d::NUM_ROUNDS),
            num_threads: self.num_threads.unwrap_or(d::NUM_THREADS),
            // filter_list_size: self.filter_list_size.unwrap_or(d::FILTER_LIST_SIZE),
            num_frozen_points: self.num_frozen_points.unwrap_or(d::NUM_FROZEN_POINTS),
        }
    }
}

/// Construct IndexWriteParametersBuilder from IndexWriteParameters.
impl From<IndexWriteParameters> for IndexWriteParametersBuilder {
    fn from(param: IndexWriteParameters) -> Self {
        Self {
            search_list_size: param.search_list_size,
            max_degree: param.max_degree,
            max_occlusion_size: Some(param.max_occlusion_size),
            saturate_graph: Some(param.saturate_graph),
            alpha: Some(param.alpha),
            num_rounds: Some(param.num_rounds),
            num_threads: Some(param.num_threads),
            // filter_list_size: Some(param.filter_list_size),
            num_frozen_points: Some(param.num_frozen_points),
        }
    }
}

#[cfg(test)]
mod parameters_test {
    use super::*;

    #[test]
    fn test_default_index_params() {
        let wp1 = IndexWriteParameters::default();
        assert_eq!(wp1.search_list_size, default_param_vals::SEARCH_LIST_SIZE);
        assert_eq!(wp1.max_degree, default_param_vals::MAX_DEGREE);
        assert_eq!(wp1.saturate_graph, default_param_vals::SATURATE_GRAPH);
        assert_eq!(
            wp1.max_occlusion_size,
            default_param_vals::MAX_OCCLUSION_SIZE
        );
        assert_eq!(wp1.alpha, default_param_vals::ALPHA);
        assert_eq!(wp1.num_rounds, default_param_vals::NUM_ROUNDS);
        assert_eq!(wp1.num_threads, default_param_vals::NUM_THREADS);
        assert_eq!(wp1.num_frozen_points, default_param_vals::NUM_FROZEN_POINTS);
    }

    #[test]
    fn test_index_write_parameters_builder() {
        // Only the mandatory fields set: everything else is defaulted.
        let wp1 = IndexWriteParametersBuilder::new(10, 20).build();
        assert_eq!(wp1.search_list_size, 10);
        assert_eq!(wp1.max_degree, 20);
        assert_eq!(wp1.saturate_graph, default_param_vals::SATURATE_GRAPH);
        assert_eq!(
            wp1.max_occlusion_size,
            default_param_vals::MAX_OCCLUSION_SIZE
        );
        assert_eq!(wp1.alpha, default_param_vals::ALPHA);
        assert_eq!(wp1.num_rounds, default_param_vals::NUM_ROUNDS);
        assert_eq!(wp1.num_threads, default_param_vals::NUM_THREADS);
        assert_eq!(wp1.num_frozen_points, default_param_vals::NUM_FROZEN_POINTS);

        // Every optional field overridden.
        let wp2 = IndexWriteParametersBuilder::new(10, 20)
            .with_max_occlusion_size(30)
            .with_saturate_graph(true)
            .with_alpha(0.5)
            .with_num_rounds(40)
            .with_num_threads(50)
            .with_num_frozen_points(60)
            .build();
        assert_eq!(wp2.search_list_size, 10);
        assert_eq!(wp2.max_degree, 20);
        assert!(wp2.saturate_graph);
        assert_eq!(wp2.max_occlusion_size, 30);
        assert_eq!(wp2.alpha, 0.5);
        assert_eq!(wp2.num_rounds, 40);
        assert_eq!(wp2.num_threads, 50);
        assert_eq!(wp2.num_frozen_points, 60);

        // Round-trip through From preserves every field.
        let wp3 = IndexWriteParametersBuilder::from(wp2).build();
        assert_eq!(wp3, wp2);
    }
}
+ */ +pub mod index_configuration; +pub use index_configuration::IndexConfiguration; + +pub mod index_write_parameters; +pub use index_write_parameters::*; + +pub mod disk_index_build_parameter; +pub use disk_index_build_parameter::DiskIndexBuildParameters; diff --git a/rust/diskann/src/model/data_store/disk_scratch_dataset.rs b/rust/diskann/src/model/data_store/disk_scratch_dataset.rs new file mode 100644 index 000000000..0d9a007ab --- /dev/null +++ b/rust/diskann/src/model/data_store/disk_scratch_dataset.rs @@ -0,0 +1,76 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Disk scratch dataset + +use std::mem::{size_of, size_of_val}; +use std::ptr; + +use crate::common::{AlignedBoxWithSlice, ANNResult}; +use crate::model::MAX_N_CMPS; +use crate::utils::round_up; + +/// DiskScratchDataset alignment +pub const DISK_SCRATCH_DATASET_ALIGN: usize = 256; + +/// Disk scratch dataset storing fp vectors with aligned dim +#[derive(Debug)] +pub struct DiskScratchDataset +{ + /// fp vectors with aligned dim + pub data: AlignedBoxWithSlice, + + /// current index to store the next fp vector + pub cur_index: usize, +} + +impl DiskScratchDataset +{ + /// Create DiskScratchDataset instance + pub fn new() -> ANNResult { + Ok(Self { + // C++ code allocates round_up(MAX_N_CMPS * N, 256) bytes, shouldn't it be round_up(MAX_N_CMPS * N, 256) * size_of:: bytes? 
+ data: AlignedBoxWithSlice::new( + round_up(MAX_N_CMPS * N, DISK_SCRATCH_DATASET_ALIGN), + DISK_SCRATCH_DATASET_ALIGN)?, + cur_index: 0, + }) + } + + /// memcpy from fp vector bytes (its len should be `dim * size_of::()`) to self.data + /// The dest slice is a fp vector with aligned dim + /// * fp_vector_buf's dim might not be aligned dim (N) + /// # Safety + /// Behavior is undefined if any of the following conditions are violated: + /// + /// * `fp_vector_buf`'s len must be `dim * size_of::()` bytes + /// + /// * `fp_vector_buf` must be smaller than or equal to `N * size_of::()` bytes. + /// + /// * `fp_vector_buf` and `self.data` must be nonoverlapping. + pub unsafe fn memcpy_from_fp_vector_buf(&mut self, fp_vector_buf: &[u8]) -> &[T] { + if self.cur_index == MAX_N_CMPS { + self.cur_index = 0; + } + + let aligned_dim_vector = &mut self.data[self.cur_index * N..(self.cur_index + 1) * N]; + + assert!(fp_vector_buf.len() % size_of::() == 0); + assert!(fp_vector_buf.len() <= size_of_val(aligned_dim_vector)); + + // memcpy from fp_vector_buf to aligned_dim_vector + unsafe { + ptr::copy_nonoverlapping( + fp_vector_buf.as_ptr(), + aligned_dim_vector.as_mut_ptr() as *mut u8, + fp_vector_buf.len(), + ); + } + + self.cur_index += 1; + aligned_dim_vector + } +} diff --git a/rust/diskann/src/model/data_store/inmem_dataset.rs b/rust/diskann/src/model/data_store/inmem_dataset.rs new file mode 100644 index 000000000..6d8b649a2 --- /dev/null +++ b/rust/diskann/src/model/data_store/inmem_dataset.rs @@ -0,0 +1,285 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! 
In-memory Dataset + +use rayon::prelude::*; +use std::mem; +use vector::{FullPrecisionDistance, Metric}; + +use crate::common::{ANNError, ANNResult, AlignedBoxWithSlice}; +use crate::model::Vertex; +use crate::utils::copy_aligned_data_from_file; + +/// Dataset of all in-memory FP points +#[derive(Debug)] +pub struct InmemDataset +where + [T; N]: FullPrecisionDistance, +{ + /// All in-memory points + pub data: AlignedBoxWithSlice, + + /// Number of points we anticipate to have + pub num_points: usize, + + /// Number of active points i.e. existing in the graph + pub num_active_pts: usize, + + /// Capacity of the dataset + pub capacity: usize, +} + +impl<'a, T, const N: usize> InmemDataset +where + T: Default + Copy + Sync + Send + Into, + [T; N]: FullPrecisionDistance, +{ + /// Create the dataset with size num_points and growth factor. + /// growth factor=1 means no growth (provision 100% space of num_points) + /// growth factor=1.2 means provision 120% space of num_points (20% extra space) + pub fn new(num_points: usize, index_growth_factor: f32) -> ANNResult { + let capacity = (((num_points * N) as f32) * index_growth_factor) as usize; + + Ok(Self { + data: AlignedBoxWithSlice::new(capacity, mem::size_of::() * 16)?, + num_points, + num_active_pts: num_points, + capacity, + }) + } + + /// get immutable data slice + pub fn get_data(&self) -> &[T] { + &self.data + } + + /// Build the dataset from file + pub fn build_from_file(&mut self, filename: &str, num_points_to_load: usize) -> ANNResult<()> { + println!( + "Loading {} vectors from file {} into dataset...", + num_points_to_load, filename + ); + self.num_active_pts = num_points_to_load; + + copy_aligned_data_from_file(filename, self.into_dto(), 0)?; + + println!("Dataset loaded."); + Ok(()) + } + + /// Append the dataset from file + pub fn append_from_file( + &mut self, + filename: &str, + num_points_to_append: usize, + ) -> ANNResult<()> { + println!( + "Appending {} vectors from file {} into dataset...", + 
num_points_to_append, filename + ); + if self.num_points + num_points_to_append > self.capacity { + return Err(ANNError::log_index_error(format!( + "Cannot append {} points to dataset of capacity {}", + num_points_to_append, self.capacity + ))); + } + + let pts_offset = self.num_active_pts; + copy_aligned_data_from_file(filename, self.into_dto(), pts_offset)?; + + self.num_active_pts += num_points_to_append; + self.num_points += num_points_to_append; + + println!("Dataset appended."); + Ok(()) + } + + /// Get vertex by id + pub fn get_vertex(&'a self, id: u32) -> ANNResult> { + let start = id as usize * N; + let end = start + N; + + if end <= self.data.len() { + let val = <&[T; N]>::try_from(&self.data[start..end]).map_err(|err| { + ANNError::log_index_error(format!("Failed to get vertex {}, err={}", id, err)) + })?; + Ok(Vertex::new(val, id)) + } else { + Err(ANNError::log_index_error(format!( + "Invalid vertex id {}.", + id + ))) + } + } + + /// Get full precision distance between two nodes + pub fn get_distance(&self, id1: u32, id2: u32, metric: Metric) -> ANNResult { + let vertex1 = self.get_vertex(id1)?; + let vertex2 = self.get_vertex(id2)?; + + Ok(vertex1.compare(&vertex2, metric)) + } + + /// find out the medoid, the vertex in the dataset that is closest to the centroid + pub fn calculate_medoid_point_id(&self) -> ANNResult { + Ok(self.find_nearest_point_id(self.calculate_centroid_point()?)) + } + + /// calculate centroid, average of all vertices in the dataset + fn calculate_centroid_point(&self) -> ANNResult<[f32; N]> { + // Allocate and initialize the centroid vector + let mut center: [f32; N] = [0.0; N]; + + // Sum the data points' components + for i in 0..self.num_active_pts { + let vertex = self.get_vertex(i as u32)?; + let vertex_slice = vertex.vector(); + for j in 0..N { + center[j] += vertex_slice[j].into(); + } + } + + // Divide by the number of points to calculate the centroid + let capacity = self.num_active_pts as f32; + for item in 
center.iter_mut().take(N) { + *item /= capacity; + } + + Ok(center) + } + + /// find out the vertex closest to the given point + fn find_nearest_point_id(&self, point: [f32; N]) -> u32 { + // compute all to one distance + let mut distances = vec![0f32; self.num_active_pts]; + let slice = &self.data[..]; + distances.par_iter_mut().enumerate().for_each(|(i, dist)| { + let start = i * N; + for j in 0..N { + let diff: f32 = (point.as_slice()[j] - slice[start + j].into()) + * (point.as_slice()[j] - slice[start + j].into()); + *dist += diff; + } + }); + + let mut min_idx = 0; + let mut min_dist = f32::MAX; + for (i, distance) in distances.iter().enumerate().take(self.num_active_pts) { + if *distance < min_dist { + min_idx = i; + min_dist = *distance; + } + } + min_idx as u32 + } + + /// Prefetch vertex data in the memory hierarchy + /// NOTE: good efficiency when total_vec_size is integral multiple of 64 + #[inline] + pub fn prefetch_vector(&self, id: u32) { + let start = id as usize * N; + let end = start + N; + + if end <= self.data.len() { + let vec = &self.data[start..end]; + vector::prefetch_vector(vec); + } + } + + /// Convert into dto object + pub fn into_dto(&mut self) -> DatasetDto { + DatasetDto { + data: &mut self.data, + rounded_dim: N, + } + } +} + +/// Dataset dto used for other layer, such as storage +/// N is the aligned dimension +#[derive(Debug)] +pub struct DatasetDto<'a, T> { + /// data slice borrow from dataset + pub data: &'a mut [T], + + /// rounded dimension + pub rounded_dim: usize, +} + +#[cfg(test)] +mod dataset_test { + use std::fs; + + use super::*; + use crate::model::vertex::DIM_128; + + #[test] + fn get_vertex_within_range() { + let num_points = 1_000_000; + let id = 999_999; + let dataset = InmemDataset::::new(num_points, 1f32).unwrap(); + + let vertex = dataset.get_vertex(999_999).unwrap(); + + assert_eq!(vertex.vertex_id(), id); + assert_eq!(vertex.vector().len(), DIM_128); + assert_eq!(vertex.vector().as_ptr(), unsafe { + 
dataset.data.as_ptr().add((id as usize) * DIM_128) + }); + } + + #[test] + fn get_vertex_out_of_range() { + let num_points = 1_000_000; + let invalid_id = 1_000_000; + let dataset = InmemDataset::::new(num_points, 1f32).unwrap(); + + if dataset.get_vertex(invalid_id).is_ok() { + panic!("id ({}) should be out of range", invalid_id) + }; + } + + #[test] + fn load_data_test() { + let file_name = "dataset_test_load_data_test.bin"; + //npoints=2, dim=8, 2 vectors [1.0;8] [2.0;8] + let data: [u8; 72] = [ + 2, 0, 0, 0, 8, 0, 0, 0, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, + 0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, + 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, + 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, + 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41, + ]; + std::fs::write(file_name, data).expect("Failed to write sample file"); + + let mut dataset = InmemDataset::::new(2, 1f32).unwrap(); + + match copy_aligned_data_from_file( + file_name, + dataset.into_dto(), + 0, + ) { + Ok((npts, dim)) => { + fs::remove_file(file_name).expect("Failed to delete file"); + assert!(npts == 2); + assert!(dim == 8); + assert!(dataset.data.len() == 16); + + let first_vertex = dataset.get_vertex(0).unwrap(); + let second_vertex = dataset.get_vertex(1).unwrap(); + + assert!(*first_vertex.vector() == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); + assert!(*second_vertex.vector() == [9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]); + } + Err(e) => { + fs::remove_file(file_name).expect("Failed to delete file"); + panic!("{}", e) + } + } + } +} + diff --git a/rust/diskann/src/model/data_store/mod.rs b/rust/diskann/src/model/data_store/mod.rs new file mode 100644 index 000000000..4e7e68393 --- /dev/null +++ b/rust/diskann/src/model/data_store/mod.rs @@ -0,0 +1,11 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ * Licensed under the MIT license. + */ +#[allow(clippy::module_inception)] +mod inmem_dataset; +pub use inmem_dataset::InmemDataset; +pub use inmem_dataset::DatasetDto; + +mod disk_scratch_dataset; +pub use disk_scratch_dataset::*; diff --git a/rust/diskann/src/model/graph/adjacency_list.rs b/rust/diskann/src/model/graph/adjacency_list.rs new file mode 100644 index 000000000..7ad2d7d5b --- /dev/null +++ b/rust/diskann/src/model/graph/adjacency_list.rs @@ -0,0 +1,64 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Adjacency List + +use std::ops::{Deref, DerefMut}; + +#[derive(Debug, Eq, PartialEq)] +/// Represents the out neighbors of a vertex +pub struct AdjacencyList { + edges: Vec, +} + +/// In-mem index related limits +const GRAPH_SLACK_FACTOR: f32 = 1.3_f32; + +impl AdjacencyList { + /// Create AdjacencyList with capacity slack for a range. + pub fn for_range(range: usize) -> Self { + let capacity = (range as f32 * GRAPH_SLACK_FACTOR).ceil() as usize; + Self { + edges: Vec::with_capacity(capacity), + } + } + + /// Push a node to the list of neighbors for the given node. 
+ pub fn push(&mut self, node_id: u32) { + debug_assert!(self.edges.len() < self.edges.capacity()); + self.edges.push(node_id); + } +} + +impl From> for AdjacencyList { + fn from(edges: Vec) -> Self { + Self { edges } + } +} + +impl Deref for AdjacencyList { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.edges + } +} + +impl DerefMut for AdjacencyList { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.edges + } +} + +impl<'a> IntoIterator for &'a AdjacencyList { + type Item = &'a u32; + type IntoIter = std::slice::Iter<'a, u32>; + + fn into_iter(self) -> Self::IntoIter { + self.edges.iter() + } +} + diff --git a/rust/diskann/src/model/graph/disk_graph.rs b/rust/diskann/src/model/graph/disk_graph.rs new file mode 100644 index 000000000..49190b1cd --- /dev/null +++ b/rust/diskann/src/model/graph/disk_graph.rs @@ -0,0 +1,179 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_docs)] + +//! 
Disk graph + +use byteorder::{LittleEndian, ByteOrder}; +use vector::FullPrecisionDistance; + +use crate::common::{ANNResult, ANNError}; +use crate::model::data_store::DiskScratchDataset; +use crate::model::Vertex; +use crate::storage::DiskGraphStorage; + +use super::{VertexAndNeighbors, SectorGraph, AdjacencyList}; + +/// Disk graph +pub struct DiskGraph { + /// dim of fp vector in disk sector + dim: usize, + + /// number of nodes per sector + num_nodes_per_sector: u64, + + /// max node length in bytes + max_node_len: u64, + + /// the len of fp vector + fp_vector_len: u64, + + /// list of nodes (vertex_id) to fetch from disk + nodes_to_fetch: Vec, + + /// Sector graph + sector_graph: SectorGraph, +} + +impl<'a> DiskGraph { + /// Create DiskGraph instance + pub fn new( + dim: usize, + num_nodes_per_sector: u64, + max_node_len: u64, + fp_vector_len: u64, + beam_width: usize, + graph_storage: DiskGraphStorage, + ) -> ANNResult { + let graph = Self { + dim, + num_nodes_per_sector, + max_node_len, + fp_vector_len, + nodes_to_fetch: Vec::with_capacity(2 * beam_width), + sector_graph: SectorGraph::new(graph_storage)?, + }; + + Ok(graph) + } + + /// Add vertex_id into the list to fetch from disk + pub fn add_vertex(&mut self, id: u32) { + self.nodes_to_fetch.push(id); + } + + /// Fetch nodes from disk index + pub fn fetch_nodes(&mut self) -> ANNResult<()> { + let sectors_to_fetch: Vec = self.nodes_to_fetch.iter().map(|&id| self.node_sector_index(id)).collect(); + self.sector_graph.read_graph(§ors_to_fetch)?; + + Ok(()) + } + + /// Copy disk fp vector to DiskScratchDataset + /// Return the fp vector with aligned dim from DiskScratchDataset + pub fn copy_fp_vector_to_disk_scratch_dataset( + &self, + node_index: usize, + disk_scratch_dataset: &'a mut DiskScratchDataset + ) -> ANNResult> + where + [T; N]: FullPrecisionDistance, + { + if self.dim > N { + return Err(ANNError::log_index_error(format!( + "copy_sector_fp_to_aligned_dataset: dim {} is greater than aligned dim {}", 
+ self.dim, N))); + } + + let fp_vector_buf = self.node_fp_vector_buf(node_index); + + // Safety condition is met here + let aligned_dim_vector = unsafe { disk_scratch_dataset.memcpy_from_fp_vector_buf(fp_vector_buf) }; + + Vertex::<'a, T, N>::try_from((aligned_dim_vector, self.nodes_to_fetch[node_index])) + .map_err(|err| ANNError::log_index_error(format!("TryFromSliceError: failed to get Vertex for disk index node, err={}", err))) + } + + /// Reset graph + pub fn reset(&mut self) { + self.nodes_to_fetch.clear(); + self.sector_graph.reset(); + } + + fn get_vertex_and_neighbors(&self, node_index: usize) -> VertexAndNeighbors { + let node_disk_buf = self.node_disk_buf(node_index); + let buf = &node_disk_buf[self.fp_vector_len as usize..]; + let num_neighbors = LittleEndian::read_u32(&buf[0..4]) as usize; + let neighbors_buf = &buf[4..4 + num_neighbors * 4]; + + let mut adjacency_list = AdjacencyList::for_range(num_neighbors); + for chunk in neighbors_buf.chunks(4) { + let neighbor_id = LittleEndian::read_u32(chunk); + adjacency_list.push(neighbor_id); + } + + VertexAndNeighbors::new(self.nodes_to_fetch[node_index], adjacency_list) + } + + #[inline] + fn node_sector_index(&self, vertex_id: u32) -> u64 { + vertex_id as u64 / self.num_nodes_per_sector + 1 + } + + #[inline] + fn node_disk_buf(&self, node_index: usize) -> &[u8] { + let vertex_id = self.nodes_to_fetch[node_index]; + + // get sector_buf where this node is located + let sector_buf = self.sector_graph.get_sector_buf(node_index); + let node_offset = (vertex_id as u64 % self.num_nodes_per_sector * self.max_node_len) as usize; + §or_buf[node_offset..node_offset + self.max_node_len as usize] + } + + #[inline] + fn node_fp_vector_buf(&self, node_index: usize) -> &[u8] { + let node_disk_buf = self.node_disk_buf(node_index); + &node_disk_buf[..self.fp_vector_len as usize] + } +} + +/// Iterator for DiskGraph +pub struct DiskGraphIntoIterator<'a> { + graph: &'a DiskGraph, + index: usize, +} + +impl<'a> IntoIterator 
for &'a DiskGraph +{ + type IntoIter = DiskGraphIntoIterator<'a>; + type Item = ANNResult<(usize, VertexAndNeighbors)>; + + #[inline] + fn into_iter(self) -> Self::IntoIter { + DiskGraphIntoIterator { + graph: self, + index: 0, + } + } +} + +impl<'a> Iterator for DiskGraphIntoIterator<'a> +{ + type Item = ANNResult<(usize, VertexAndNeighbors)>; + + fn next(&mut self) -> Option { + if self.index >= self.graph.nodes_to_fetch.len() { + return None; + } + + let node_index = self.index; + let vertex_and_neighbors = self.graph.get_vertex_and_neighbors(self.index); + + self.index += 1; + Some(Ok((node_index, vertex_and_neighbors))) + } +} + diff --git a/rust/diskann/src/model/graph/inmem_graph.rs b/rust/diskann/src/model/graph/inmem_graph.rs new file mode 100644 index 000000000..3d08db837 --- /dev/null +++ b/rust/diskann/src/model/graph/inmem_graph.rs @@ -0,0 +1,141 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! 
In-memory graph + +use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard}; + +use crate::common::ANNError; + +use super::VertexAndNeighbors; + +/// The entire graph of in-memory index +#[derive(Debug)] +pub struct InMemoryGraph { + /// The entire graph + pub final_graph: Vec>, +} + +impl InMemoryGraph { + /// Create InMemoryGraph instance + pub fn new(size: usize, max_degree: u32) -> Self { + let mut graph = Vec::with_capacity(size); + for id in 0..size { + graph.push(RwLock::new(VertexAndNeighbors::for_range( + id as u32, + max_degree as usize, + ))); + } + Self { final_graph: graph } + } + + /// Size of graph + pub fn size(&self) -> usize { + self.final_graph.len() + } + + /// Extend the graph by size vectors + pub fn extend(&mut self, size: usize, max_degree: u32) { + for id in 0..size { + self.final_graph + .push(RwLock::new(VertexAndNeighbors::for_range( + id as u32, + max_degree as usize, + ))); + } + } + + /// Get read guard of vertex_id + pub fn read_vertex_and_neighbors( + &self, + vertex_id: u32, + ) -> Result, ANNError> { + self.final_graph[vertex_id as usize].read().map_err(|err| { + ANNError::log_lock_poison_error(format!( + "PoisonError: Lock poisoned when reading final_graph for vertex_id {}, err={}", + vertex_id, err + )) + }) + } + + /// Get write guard of vertex_id + pub fn write_vertex_and_neighbors( + &self, + vertex_id: u32, + ) -> Result, ANNError> { + self.final_graph[vertex_id as usize].write().map_err(|err| { + ANNError::log_lock_poison_error(format!( + "PoisonError: Lock poisoned when writing final_graph for vertex_id {}, err={}", + vertex_id, err + )) + }) + } +} + +#[cfg(test)] +mod graph_tests { + use crate::model::{graph::AdjacencyList, GRAPH_SLACK_FACTOR}; + + use super::*; + + #[test] + fn test_new() { + let graph = InMemoryGraph::new(10, 10); + let capacity = (GRAPH_SLACK_FACTOR * 10_f64).ceil() as usize; + + assert_eq!(graph.final_graph.len(), 10); + for i in 0..10 { + let neighbor = graph.final_graph[i].read().unwrap(); + 
assert_eq!(neighbor.vertex_id, i as u32); + assert_eq!(neighbor.get_neighbors().capacity(), capacity); + } + } + + #[test] + fn test_size() { + let graph = InMemoryGraph::new(10, 10); + assert_eq!(graph.size(), 10); + } + + #[test] + fn test_extend() { + let mut graph = InMemoryGraph::new(10, 10); + graph.extend(10, 10); + + assert_eq!(graph.size(), 20); + + let capacity = (GRAPH_SLACK_FACTOR * 10_f64).ceil() as usize; + let mut id: u32 = 0; + + for i in 10..20 { + let neighbor = graph.final_graph[i].read().unwrap(); + assert_eq!(neighbor.vertex_id, id); + assert_eq!(neighbor.get_neighbors().capacity(), capacity); + id += 1; + } + } + + #[test] + fn test_read_vertex_and_neighbors() { + let graph = InMemoryGraph::new(10, 10); + let neighbor = graph.read_vertex_and_neighbors(0); + assert!(neighbor.is_ok()); + assert_eq!(neighbor.unwrap().vertex_id, 0); + } + + #[test] + fn test_write_vertex_and_neighbors() { + let graph = InMemoryGraph::new(10, 10); + { + let neighbor = graph.write_vertex_and_neighbors(0); + assert!(neighbor.is_ok()); + neighbor.unwrap().add_to_neighbors(10, 10); + } + + let neighbor = graph.read_vertex_and_neighbors(0).unwrap(); + assert_eq!(neighbor.get_neighbors(), &AdjacencyList::from(vec![10_u32])); + } +} diff --git a/rust/diskann/src/model/graph/mod.rs b/rust/diskann/src/model/graph/mod.rs new file mode 100644 index 000000000..d1457f1c2 --- /dev/null +++ b/rust/diskann/src/model/graph/mod.rs @@ -0,0 +1,20 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +#[allow(clippy::module_inception)] +mod inmem_graph; +pub use inmem_graph::InMemoryGraph; + +pub mod vertex_and_neighbors; +pub use vertex_and_neighbors::VertexAndNeighbors; + +mod adjacency_list; +pub use adjacency_list::AdjacencyList; + +mod sector_graph; +pub use sector_graph::*; + +mod disk_graph; +pub use disk_graph::*; + diff --git a/rust/diskann/src/model/graph/sector_graph.rs b/rust/diskann/src/model/graph/sector_graph.rs new file mode 100644 index 000000000..e51e0bf03 --- /dev/null +++ b/rust/diskann/src/model/graph/sector_graph.rs @@ -0,0 +1,87 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_docs)] + +//! Sector graph + +use std::ops::Deref; + +use crate::common::{AlignedBoxWithSlice, ANNResult, ANNError}; +use crate::model::{MAX_N_SECTOR_READS, SECTOR_LEN, AlignedRead}; +use crate::storage::DiskGraphStorage; + +/// Sector graph read from disk index +pub struct SectorGraph { + /// Sector bytes from disk + /// One sector has num_nodes_per_sector nodes + /// Each node's layout: {full precision vector:[T; DIM]}{num_nbrs: u32}{neighbors: [u32; num_nbrs]} + /// The fp vector is not aligned + sectors_data: AlignedBoxWithSlice, + + /// Graph storage to read sectors + graph_storage: DiskGraphStorage, + + /// Current sector index into which the next read reads data + cur_sector_idx: u64, +} + +impl SectorGraph { + /// Create SectorGraph instance + pub fn new(graph_storage: DiskGraphStorage) -> ANNResult { + Ok(Self { + sectors_data: AlignedBoxWithSlice::new(MAX_N_SECTOR_READS * SECTOR_LEN, SECTOR_LEN)?, + graph_storage, + cur_sector_idx: 0, + }) + } + + /// Reset SectorGraph + pub fn reset(&mut self) { + self.cur_sector_idx = 0; + } + + /// Read sectors into sectors_data + /// They are in the same order as sectors_to_fetch + pub fn read_graph(&mut self, sectors_to_fetch: &[u64]) -> ANNResult<()> { + let cur_sector_idx_usize: usize = self.cur_sector_idx.try_into()?; + if 
sectors_to_fetch.len() > MAX_N_SECTOR_READS - cur_sector_idx_usize { + return Err(ANNError::log_index_error(format!( + "Trying to read too many sectors. number of sectors to read: {}, max number of sectors can read: {}", + sectors_to_fetch.len(), + MAX_N_SECTOR_READS - cur_sector_idx_usize, + ))); + } + + let mut sector_slices = self.sectors_data.split_into_nonoverlapping_mut_slices( + cur_sector_idx_usize * SECTOR_LEN..(cur_sector_idx_usize + sectors_to_fetch.len()) * SECTOR_LEN, + SECTOR_LEN)?; + + let mut read_requests = Vec::with_capacity(sector_slices.len()); + for (local_sector_idx, slice) in sector_slices.iter_mut().enumerate() { + let sector_id = sectors_to_fetch[local_sector_idx]; + read_requests.push(AlignedRead::new(sector_id * SECTOR_LEN as u64, slice)?); + } + + self.graph_storage.read(&mut read_requests)?; + self.cur_sector_idx += sectors_to_fetch.len() as u64; + + Ok(()) + } + + /// Get sector data by local index + #[inline] + pub fn get_sector_buf(&self, local_sector_idx: usize) -> &[u8] { + &self.sectors_data[local_sector_idx * SECTOR_LEN..(local_sector_idx + 1) * SECTOR_LEN] + } +} + +impl Deref for SectorGraph { + type Target = [u8]; + + fn deref(&self) -> &Self::Target { + &self.sectors_data + } +} + diff --git a/rust/diskann/src/model/graph/vertex_and_neighbors.rs b/rust/diskann/src/model/graph/vertex_and_neighbors.rs new file mode 100644 index 000000000..a9fa38932 --- /dev/null +++ b/rust/diskann/src/model/graph/vertex_and_neighbors.rs @@ -0,0 +1,159 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! 
Vertex and its Adjacency List + +use crate::model::GRAPH_SLACK_FACTOR; + +use super::AdjacencyList; + +/// The out neighbors of vertex_id +#[derive(Debug)] +pub struct VertexAndNeighbors { + /// The id of the vertex + pub vertex_id: u32, + + /// All out neighbors (id) of vertex_id + neighbors: AdjacencyList, +} + +impl VertexAndNeighbors { + /// Create VertexAndNeighbors with id and capacity + pub fn for_range(id: u32, range: usize) -> Self { + Self { + vertex_id: id, + neighbors: AdjacencyList::for_range(range), + } + } + + /// Create VertexAndNeighbors with id and neighbors + pub fn new(vertex_id: u32, neighbors: AdjacencyList) -> Self { + Self { + vertex_id, + neighbors, + } + } + + /// Get size of neighbors + #[inline(always)] + pub fn size(&self) -> usize { + self.neighbors.len() + } + + /// Update the neighbors vector (post a pruning exercise) + #[inline(always)] + pub fn set_neighbors(&mut self, new_neighbors: AdjacencyList) { + // Replace the graph entry with the pruned neighbors + self.neighbors = new_neighbors; + } + + /// Get the neighbors + #[inline(always)] + pub fn get_neighbors(&self) -> &AdjacencyList { + &self.neighbors + } + + /// Adds a node to the list of neighbors for the given node. + /// + /// # Arguments + /// + /// * `node_id` - The ID of the node to add. + /// * `range` - The range of the graph. + /// + /// # Return + /// + /// Returns `None` if the node is already in the list of neighbors, or a `Vec` containing the updated list of neighbors if the list of neighbors is full. 
+ pub fn add_to_neighbors(&mut self, node_id: u32, range: u32) -> Option> { + // Check if n is already in the graph entry + if self.neighbors.contains(&node_id) { + return None; + } + + let neighbor_len = self.neighbors.len(); + + // If not, check if the graph entry has enough space + if neighbor_len < (GRAPH_SLACK_FACTOR * range as f64) as usize { + // If yes, add n to the graph entry + self.neighbors.push(node_id); + return None; + } + + let mut copy_of_neighbors = Vec::with_capacity(neighbor_len + 1); + unsafe { + let dst = copy_of_neighbors.as_mut_ptr(); + std::ptr::copy_nonoverlapping(self.neighbors.as_ptr(), dst, neighbor_len); + dst.add(neighbor_len).write(node_id); + copy_of_neighbors.set_len(neighbor_len + 1); + } + + Some(copy_of_neighbors) + } +} + +#[cfg(test)] +mod vertex_and_neighbors_tests { + use crate::model::GRAPH_SLACK_FACTOR; + + use super::*; + + #[test] + fn test_set_with_capacity() { + let neighbors = VertexAndNeighbors::for_range(20, 10); + assert_eq!(neighbors.vertex_id, 20); + assert_eq!( + neighbors.neighbors.capacity(), + (10_f32 * GRAPH_SLACK_FACTOR as f32).ceil() as usize + ); + } + + #[test] + fn test_size() { + let mut neighbors = VertexAndNeighbors::for_range(20, 10); + + for i in 0..5 { + neighbors.neighbors.push(i); + } + + assert_eq!(neighbors.size(), 5); + } + + #[test] + fn test_set_neighbors() { + let mut neighbors = VertexAndNeighbors::for_range(20, 10); + let new_vec = AdjacencyList::from(vec![1, 2, 3, 4, 5]); + neighbors.set_neighbors(AdjacencyList::from(new_vec.clone())); + + assert_eq!(neighbors.neighbors, new_vec); + } + + #[test] + fn test_get_neighbors() { + let mut neighbors = VertexAndNeighbors::for_range(20, 10); + neighbors.set_neighbors(AdjacencyList::from(vec![1, 2, 3, 4, 5])); + let neighbor_ref = neighbors.get_neighbors(); + + assert!(std::ptr::eq(&neighbors.neighbors, neighbor_ref)) + } + + #[test] + fn test_add_to_neighbors() { + let mut neighbors = VertexAndNeighbors::for_range(20, 10); + + 
assert_eq!(neighbors.add_to_neighbors(1, 1), None); + assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1])); + + assert_eq!(neighbors.add_to_neighbors(1, 1), None); + assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1])); + + let ret = neighbors.add_to_neighbors(2, 1); + assert!(ret.is_some()); + assert_eq!(ret.unwrap(), vec![1, 2]); + assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1])); + + assert_eq!(neighbors.add_to_neighbors(2, 2), None); + assert_eq!(neighbors.neighbors, AdjacencyList::from(vec![1, 2])); + } +} diff --git a/rust/diskann/src/model/mod.rs b/rust/diskann/src/model/mod.rs new file mode 100644 index 000000000..a4f15ee52 --- /dev/null +++ b/rust/diskann/src/model/mod.rs @@ -0,0 +1,29 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +pub mod neighbor; +pub use neighbor::Neighbor; +pub use neighbor::NeighborPriorityQueue; + +pub mod data_store; +pub use data_store::InmemDataset; + +pub mod graph; +pub use graph::InMemoryGraph; +pub use graph::VertexAndNeighbors; + +pub mod configuration; +pub use configuration::*; + +pub mod scratch; +pub use scratch::*; + +pub mod vertex; +pub use vertex::Vertex; + +pub mod pq; +pub use pq::*; + +pub mod windows_aligned_file_reader; +pub use windows_aligned_file_reader::*; diff --git a/rust/diskann/src/model/neighbor/mod.rs b/rust/diskann/src/model/neighbor/mod.rs new file mode 100644 index 000000000..cd0dbad2a --- /dev/null +++ b/rust/diskann/src/model/neighbor/mod.rs @@ -0,0 +1,13 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +#[allow(clippy::module_inception)] +mod neighbor; +pub use neighbor::*; + +mod neighbor_priority_queue; +pub use neighbor_priority_queue::*; + +mod sorted_neighbor_vector; +pub use sorted_neighbor_vector::SortedNeighborVector; diff --git a/rust/diskann/src/model/neighbor/neighbor.rs b/rust/diskann/src/model/neighbor/neighbor.rs new file mode 100644 index 000000000..8c712bcd3 --- /dev/null +++ b/rust/diskann/src/model/neighbor/neighbor.rs @@ -0,0 +1,104 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::cmp::Ordering; + +/// Neighbor node +#[derive(Debug, Clone, Copy)] +pub struct Neighbor { + /// The id of the node + pub id: u32, + + /// The distance from the query node to current node + pub distance: f32, + + /// Whether the current is visited or not + pub visited: bool, +} + +impl Neighbor { + /// Create the neighbor node and it has not been visited + pub fn new (id: u32, distance: f32) -> Self { + Self { + id, + distance, + visited: false + } + } +} + +impl Default for Neighbor { + fn default() -> Self { + Self { id: 0, distance: 0.0_f32, visited: false } + } +} + +impl PartialEq for Neighbor { + #[inline] + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} + +impl Eq for Neighbor {} + +impl Ord for Neighbor { + fn cmp(&self, other: &Self) -> Ordering { + let ord = self.distance.partial_cmp(&other.distance).unwrap_or(std::cmp::Ordering::Equal); + + if ord == Ordering::Equal { + return self.id.cmp(&other.id); + } + + ord + } +} + +impl PartialOrd for Neighbor { + #[inline] + fn lt(&self, other: &Self) -> bool { + self.distance < other.distance || (self.distance == other.distance && self.id < other.id) + } + + // Reason for allowing panic = "Does not support comparing Neighbor with partial_cmp" + #[allow(clippy::panic)] + fn partial_cmp(&self, _: &Self) -> Option { + panic!("Neighbor only allows eq and lt") + } +} + +#[cfg(test)] +mod neighbor_test { + use super::*; + + 
#[test] + fn eq_lt_works() { + let n1 = Neighbor::new(1, 1.1); + let n2 = Neighbor::new(2, 2.0); + let n3 = Neighbor::new(1, 1.1); + + assert!(n1 != n2); + assert!(n1 < n2); + assert!(n1 == n3); + } + + #[test] + #[should_panic] + fn gt_should_panic() { + let n1 = Neighbor::new(1, 1.1); + let n2 = Neighbor::new(2, 2.0); + + assert!(n2 > n1); + } + + #[test] + #[should_panic] + fn le_should_panic() { + let n1 = Neighbor::new(1, 1.1); + let n2 = Neighbor::new(2, 2.0); + + assert!(n1 <= n2); + } +} + diff --git a/rust/diskann/src/model/neighbor/neighbor_priority_queue.rs b/rust/diskann/src/model/neighbor/neighbor_priority_queue.rs new file mode 100644 index 000000000..81b161026 --- /dev/null +++ b/rust/diskann/src/model/neighbor/neighbor_priority_queue.rs @@ -0,0 +1,241 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use crate::model::Neighbor; + +/// Neighbor priority Queue based on the distance to the query node +#[derive(Debug)] +pub struct NeighborPriorityQueue { + /// The size of the priority queue + size: usize, + + /// The capacity of the priority queue + capacity: usize, + + /// The current notvisited neighbor whose distance is smallest among all notvisited neighbor + cur: usize, + + /// The neighbor collection + data: Vec, +} + +impl Default for NeighborPriorityQueue { + fn default() -> Self { + Self::new() + } +} + +impl NeighborPriorityQueue { + /// Create NeighborPriorityQueue without capacity + pub fn new() -> Self { + Self { + size: 0, + capacity: 0, + cur: 0, + data: Vec::new(), + } + } + + /// Create NeighborPriorityQueue with capacity + pub fn with_capacity(capacity: usize) -> Self { + Self { + size: 0, + capacity, + cur: 0, + data: vec![Neighbor::default(); capacity + 1], + } + } + + /// Inserts item with order. + /// The item will be dropped if queue is full / already exist in queue / it has a greater distance than the last item. 
+ /// The set cursor that is used to pop() the next item will be set to the lowest index of an uncheck item. + pub fn insert(&mut self, nbr: Neighbor) { + if self.size == self.capacity && self.get_at(self.size - 1) < &nbr { + return; + } + + let mut lo = 0; + let mut hi = self.size; + while lo < hi { + let mid = (lo + hi) >> 1; + if &nbr < self.get_at(mid) { + hi = mid; + } else if self.get_at(mid).id == nbr.id { + // Make sure the same neighbor isn't inserted into the set + return; + } else { + lo = mid + 1; + } + } + + if lo < self.capacity { + self.data.copy_within(lo..self.size, lo + 1); + } + self.data[lo] = Neighbor::new(nbr.id, nbr.distance); + if self.size < self.capacity { + self.size += 1; + } + if lo < self.cur { + self.cur = lo; + } + } + + /// Get the neighbor at index - SAFETY: index must be less than size + fn get_at(&self, index: usize) -> &Neighbor { + unsafe { self.data.get_unchecked(index) } + } + + /// Get the closest and notvisited neighbor + pub fn closest_notvisited(&mut self) -> Neighbor { + self.data[self.cur].visited = true; + let pre = self.cur; + while self.cur < self.size && self.get_at(self.cur).visited { + self.cur += 1; + } + self.data[pre] + } + + /// Whether there is notvisited node or not + pub fn has_notvisited_node(&self) -> bool { + self.cur < self.size + } + + /// Get the size of the NeighborPriorityQueue + pub fn size(&self) -> usize { + self.size + } + + /// Get the capacity of the NeighborPriorityQueue + pub fn capacity(&self) -> usize { + self.capacity + } + + /// Sets an artificial capacity of the NeighborPriorityQueue. For benchmarking purposes only. 
+ pub fn set_capacity(&mut self, capacity: usize) { + if capacity < self.data.len() { + self.capacity = capacity; + } + } + + /// Reserve capacity + pub fn reserve(&mut self, capacity: usize) { + if capacity > self.capacity { + self.data.resize(capacity + 1, Neighbor::default()); + self.capacity = capacity; + } + } + + /// Set size and cur to 0 + pub fn clear(&mut self) { + self.size = 0; + self.cur = 0; + } +} + +impl std::ops::Index for NeighborPriorityQueue { + type Output = Neighbor; + + fn index(&self, i: usize) -> &Self::Output { + &self.data[i] + } +} + +#[cfg(test)] +mod neighbor_priority_queue_test { + use super::*; + + #[test] + fn test_reserve_capacity() { + let mut queue = NeighborPriorityQueue::with_capacity(10); + assert_eq!(queue.capacity(), 10); + queue.reserve(20); + assert_eq!(queue.capacity(), 20); + } + + #[test] + fn test_insert() { + let mut queue = NeighborPriorityQueue::with_capacity(3); + assert_eq!(queue.size(), 0); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + assert_eq!(queue.size(), 2); + queue.insert(Neighbor::new(2, 0.5)); // should be ignored as the same neighbor + assert_eq!(queue.size(), 2); + queue.insert(Neighbor::new(3, 0.9)); + assert_eq!(queue.size(), 3); + assert_eq!(queue[2].id, 1); + queue.insert(Neighbor::new(4, 2.0)); // should be dropped as queue is full and distance is greater than last item + assert_eq!(queue.size(), 3); + assert_eq!(queue[0].id, 2); // node id in queue should be [2,3,1] + assert_eq!(queue[1].id, 3); + assert_eq!(queue[2].id, 1); + println!("{:?}", queue); + } + + #[test] + fn test_index() { + let mut queue = NeighborPriorityQueue::with_capacity(3); + queue.insert(Neighbor::new(1, 1.0)); + queue.insert(Neighbor::new(2, 0.5)); + queue.insert(Neighbor::new(3, 1.5)); + assert_eq!(queue[0].id, 2); + assert_eq!(queue[0].distance, 0.5); + } + + #[test] + fn test_visit() { + let mut queue = NeighborPriorityQueue::with_capacity(3); + queue.insert(Neighbor::new(1, 1.0)); + 
queue.insert(Neighbor::new(2, 0.5));
+        queue.insert(Neighbor::new(3, 1.5)); // node id in queue should be [2,1,3]
+        assert!(queue.has_notvisited_node());
+        let nbr = queue.closest_notvisited();
+        assert_eq!(nbr.id, 2);
+        assert_eq!(nbr.distance, 0.5);
+        assert!(nbr.visited);
+        assert!(queue.has_notvisited_node());
+        let nbr = queue.closest_notvisited();
+        assert_eq!(nbr.id, 1);
+        assert_eq!(nbr.distance, 1.0);
+        assert!(nbr.visited);
+        assert!(queue.has_notvisited_node());
+        let nbr = queue.closest_notvisited();
+        assert_eq!(nbr.id, 3);
+        assert_eq!(nbr.distance, 1.5);
+        assert!(nbr.visited);
+        assert!(!queue.has_notvisited_node());
+    }
+
+    #[test]
+    fn test_clear_queue() {
+        let mut queue = NeighborPriorityQueue::with_capacity(3);
+        queue.insert(Neighbor::new(1, 1.0));
+        queue.insert(Neighbor::new(2, 0.5));
+        assert_eq!(queue.size(), 2);
+        assert!(queue.has_notvisited_node());
+        queue.clear();
+        assert_eq!(queue.size(), 0);
+        assert!(!queue.has_notvisited_node());
+    }
+
+    #[test]
+    fn test_reserve() {
+        let mut queue = NeighborPriorityQueue::new();
+        queue.reserve(10);
+        // data holds capacity + 1 entries (slack slot used by insert's shift).
+        assert_eq!(queue.data.len(), 11);
+        assert_eq!(queue.capacity, 10);
+    }
+
+    #[test]
+    fn test_set_capacity() {
+        let mut queue = NeighborPriorityQueue::with_capacity(10);
+        queue.set_capacity(5);
+        assert_eq!(queue.capacity, 5);
+        assert_eq!(queue.data.len(), 11);
+
+        // Larger than the allocated buffer: silently ignored.
+        queue.set_capacity(11);
+        assert_eq!(queue.capacity, 5);
+    }
+}
+
diff --git a/rust/diskann/src/model/neighbor/sorted_neighbor_vector.rs b/rust/diskann/src/model/neighbor/sorted_neighbor_vector.rs
new file mode 100644
index 000000000..4c3eff00f
--- /dev/null
+++ b/rust/diskann/src/model/neighbor/sorted_neighbor_vector.rs
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+#![warn(missing_debug_implementations, missing_docs)]
+
+//! 
Sorted Neighbor Vector
+
+use std::ops::{Deref, DerefMut};
+
+use super::Neighbor;
+
+/// A newtype on top of vector of neighbors, is sorted by distance
+#[derive(Debug)]
+pub struct SortedNeighborVector<'a>(&'a mut Vec);
+
+impl<'a> SortedNeighborVector<'a> {
+    /// Create a new SortedNeighborVector
+    ///
+    /// Sorts the borrowed vector in place (ascending, by the Neighbor
+    /// ordering) before wrapping it.
+    pub fn new(vec: &'a mut Vec) -> Self {
+        vec.sort_unstable();
+        Self(vec)
+    }
+}
+
+impl<'a> Deref for SortedNeighborVector<'a> {
+    type Target = Vec;
+
+    fn deref(&self) -> &Self::Target {
+        self.0
+    }
+}
+
+impl<'a> DerefMut for SortedNeighborVector<'a> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.0
+    }
+}
diff --git a/rust/diskann/src/model/pq/fixed_chunk_pq_table.rs b/rust/diskann/src/model/pq/fixed_chunk_pq_table.rs
new file mode 100644
index 000000000..bfedcae6e
--- /dev/null
+++ b/rust/diskann/src/model/pq/fixed_chunk_pq_table.rs
@@ -0,0 +1,483 @@
+/*
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT license.
+ */
+#![warn(missing_debug_implementations)]
+
+use hashbrown::HashMap;
+use rayon::prelude::{
+    IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator, ParallelSliceMut,
+};
+use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
+
+use crate::{
+    common::{ANNError, ANNResult},
+    model::NUM_PQ_CENTROIDS,
+};
+
+/// PQ Pivot table loading and calculate distance
+#[derive(Debug)]
+pub struct FixedChunkPQTable {
+    /// pq_tables = float array of size [256 * ndims]
+    pq_table: Vec,
+
+    /// ndims = true dimension of vectors
+    dim: usize,
+
+    /// num_pq_chunks = the pq chunk number
+    num_pq_chunks: usize,
+
+    /// chunk_offsets = the offset of each chunk, start from 0
+    chunk_offsets: Vec,
+
+    /// centroid of each dimension
+    centroids: Vec,
+
+    /// Because we're using L2 distance, this is not needed now.
+    /// Transpose of pq_table. transport_pq_table = float array of size [ndims * 256].
+    /// e.g. 
if pq_table is 2 centroids * 3 dims
+    /// [ 1, 2, 3,
+    ///   4, 5, 6]
+    /// then transport_pq_table would be 3 dims * 2 centroids
+    /// [ 1, 4,
+    ///   2, 5,
+    ///   3, 6]
+    /// transport_pq_table: Vec,
+
+    /// Map dim offset to chunk index e.g., 8 dims in to 2 chunks
+    /// then would be [(0,0), (1,0), (2,0), (3,0), (4,1), (5,1), (6,1), (7,1)]
+    dimoffset_chunk_mapping: HashMap,
+}
+
+impl FixedChunkPQTable {
+    /// Create the FixedChunkPQTable with dim and chunk numbers and pivot file data (pivot table + centroids + chunk offsets)
+    pub fn new(
+        dim: usize,
+        num_pq_chunks: usize,
+        pq_table: Vec,
+        centroids: Vec,
+        chunk_offsets: Vec,
+    ) -> Self {
+        // Precompute dim offset -> owning chunk so inflate_vector can map a
+        // dimension back to its chunk in O(1).
+        let mut dimoffset_chunk_mapping = HashMap::new();
+        for chunk_index in 0..num_pq_chunks {
+            for dim_offset in chunk_offsets[chunk_index]..chunk_offsets[chunk_index + 1] {
+                dimoffset_chunk_mapping.insert(dim_offset, chunk_index);
+            }
+        }
+
+        Self {
+            pq_table,
+            dim,
+            num_pq_chunks,
+            chunk_offsets,
+            centroids,
+            dimoffset_chunk_mapping,
+        }
+    }
+
+    /// Get chunk number
+    pub fn get_num_chunks(&self) -> usize {
+        self.num_pq_chunks
+    }
+
+    /// Shifting the query according to mean or the whole corpus
+    ///
+    /// Subtracts the per-dimension centroid (mean) from the query in place.
+    pub fn preprocess_query(&self, query_vec: &mut [f32]) {
+        for (query, &centroid) in query_vec.iter_mut().zip(self.centroids.iter()) {
+            *query -= centroid;
+        }
+    }
+
+    /// Pre-calculated the distance between query and each centroid by l2 distance
+    /// * `query_vec` - query vector: 1 * dim
+    /// * `dist_vec` - pre-calculated the distance between query and each centroid: chunk_size * num_centroids
+    #[allow(clippy::needless_range_loop)]
+    pub fn populate_chunk_distances(&self, query_vec: &[f32]) -> Vec {
+        let mut dist_vec = vec![0.0; self.num_pq_chunks * NUM_PQ_CENTROIDS];
+        // dist_vec[chunk * 256 + centroid] accumulates the squared L2
+        // distance between the query's sub-vector for `chunk` and that
+        // centroid's sub-vector.
+        for centroid_index in 0..NUM_PQ_CENTROIDS {
+            for chunk_index in 0..self.num_pq_chunks {
+                for dim_offset in
+                    self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1]
+                {
+                    let diff: f32 = self.pq_table[self.dim * centroid_index + dim_offset]
+
- query_vec[dim_offset];
+                    dist_vec[chunk_index * NUM_PQ_CENTROIDS + centroid_index] += diff * diff;
+                }
+            }
+        }
+        dist_vec
+    }
+
+    /// Pre-calculated the distance between query and each centroid by inner product
+    /// * `query_vec` - query vector: 1 * dim
+    /// * `dist_vec` - pre-calculated the distance between query and each centroid: chunk_size * num_centroids
+    ///
+    /// Reason to allow clippy::needless_range_loop:
+    /// The inner loop is operating over a range that is different for each iteration of the outer loop.
+    /// This isn't a scenario where using iter().enumerate() would be easily applicable,
+    /// because the inner loop isn't iterating directly over the contents of a slice or array.
+    /// Thus, using indexing might be the most straightforward way to express this logic.
+    #[allow(clippy::needless_range_loop)]
+    pub fn populate_chunk_inner_products(&self, query_vec: &[f32]) -> Vec {
+        let mut dist_vec = vec![0.0; self.num_pq_chunks * NUM_PQ_CENTROIDS];
+        for centroid_index in 0..NUM_PQ_CENTROIDS {
+            for chunk_index in 0..self.num_pq_chunks {
+                for dim_offset in
+                    self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1]
+                {
+                    // assumes that we are not shifting the vectors to mean zero, i.e., centroid
+                    // array should be all zeros returning negative to keep the search code
+                    // clean (max inner product vs min distance)
+                    let diff: f32 = self.pq_table[self.dim * centroid_index + dim_offset]
+                        * query_vec[dim_offset];
+                    dist_vec[chunk_index * NUM_PQ_CENTROIDS + centroid_index] -= diff;
+                }
+            }
+        }
+        dist_vec
+    }
+
+    /// Calculate the distance between query and given centroid by l2 distance
+    /// * `query_vec` - query vector: 1 * dim
+    /// * `base_vec` - given centroid array: 1 * num_pq_chunks
+    #[allow(clippy::needless_range_loop)]
+    pub fn l2_distance(&self, query_vec: &[f32], base_vec: &[u8]) -> f32 {
+        let mut res_vec: Vec = vec![0.0; self.num_pq_chunks];
+        // One rayon task per chunk: each accumulates its chunk's squared-L2
+        // contribution into its own slot of res_vec (no shared mutation).
+        res_vec
+            .par_iter_mut()
+            .enumerate()
+            .for_each(|(chunk_index, chunk_diff)| {
+
for dim_offset in
+                    self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1]
+                {
+                    let diff = self.pq_table
+                        [self.dim * base_vec[chunk_index] as usize + dim_offset]
+                        - query_vec[dim_offset];
+                    *chunk_diff += diff * diff;
+                }
+            });
+
+        // Sum the per-chunk partial distances.
+        let res: f32 = res_vec.iter().sum::();
+
+        res
+    }
+
+    /// Calculate the distance between query and given centroid by inner product
+    /// * `query_vec` - query vector: 1 * dim
+    /// * `base_vec` - given centroid array: 1 * num_pq_chunks
+    #[allow(clippy::needless_range_loop)]
+    pub fn inner_product(&self, query_vec: &[f32], base_vec: &[u8]) -> f32 {
+        let mut res_vec: Vec = vec![0.0; self.num_pq_chunks];
+        res_vec
+            .par_iter_mut()
+            .enumerate()
+            .for_each(|(chunk_index, chunk_diff)| {
+                for dim_offset in
+                    self.chunk_offsets[chunk_index]..self.chunk_offsets[chunk_index + 1]
+                {
+                    *chunk_diff += self.pq_table
+                        [self.dim * base_vec[chunk_index] as usize + dim_offset]
+                        * query_vec[dim_offset];
+                }
+            });
+
+        let res: f32 = res_vec.iter().sum::();
+
+        // returns negative value to simulate distances (max -> min conversion)
+        -res
+    }
+
+    /// Revert vector by adding centroid
+    /// * `base_vec` - given centroid array: 1 * num_pq_chunks
+    /// * `out_vec` - reverted vector
+    pub fn inflate_vector(&self, base_vec: &[u8]) -> ANNResult> {
+        let mut out_vec: Vec = vec![0.0; self.dim];
+        for (dim_offset, value) in out_vec.iter_mut().enumerate() {
+            // Map this dimension back to its owning chunk, then add the
+            // centroid back to undo preprocess_query's mean-shift.
+            let chunk_index =
+                self.dimoffset_chunk_mapping
+                    .get(&dim_offset)
+                    .ok_or(ANNError::log_pq_error(
+                        "ERROR: dim_offset not found in dimoffset_chunk_mapping".to_string(),
+                    ))?;
+            *value = self.pq_table[self.dim * base_vec[*chunk_index] as usize + dim_offset]
+                + self.centroids[dim_offset];
+        }
+
+        Ok(out_vec)
+    }
+}
+
+/// Given a batch input nodes, return a batch of PQ distance
+/// * `pq_ids` - batch nodes: n_pts * pq_nchunks
+/// * `n_pts` - batch number
+/// * `pq_nchunks` - pq chunk number
+/// * `pq_dists` - pre-calculated the distance between query and each centroid: 
chunk_size * num_centroids +/// * `dists_out` - n_pts * 1 +pub fn pq_dist_lookup( + pq_ids: &[u8], + n_pts: usize, + pq_nchunks: usize, + pq_dists: &[f32], +) -> Vec { + let mut dists_out: Vec = vec![0.0; n_pts]; + unsafe { + _mm_prefetch(dists_out.as_ptr() as *const i8, _MM_HINT_T0); + _mm_prefetch(pq_ids.as_ptr() as *const i8, _MM_HINT_T0); + _mm_prefetch(pq_ids.as_ptr().add(64) as *const i8, _MM_HINT_T0); + _mm_prefetch(pq_ids.as_ptr().add(128) as *const i8, _MM_HINT_T0); + } + for chunk in 0..pq_nchunks { + let chunk_dists = &pq_dists[256 * chunk..]; + if chunk < pq_nchunks - 1 { + unsafe { + _mm_prefetch( + chunk_dists.as_ptr().offset(256 * chunk as isize).add(256) as *const i8, + _MM_HINT_T0, + ); + } + } + dists_out + .par_iter_mut() + .enumerate() + .for_each(|(n_iter, dist)| { + let pq_centerid = pq_ids[pq_nchunks * n_iter + chunk]; + *dist += chunk_dists[pq_centerid as usize]; + }); + } + dists_out +} + +pub fn aggregate_coords(ids: &[u32], all_coords: &[u8], ndims: usize) -> Vec { + let mut out: Vec = vec![0u8; ids.len() * ndims]; + let ndim_u32 = ndims as u32; + out.par_chunks_mut(ndims) + .enumerate() + .for_each(|(index, chunk)| { + let id_compressed_pivot = &all_coords + [(ids[index] * ndim_u32) as usize..(ids[index] * ndim_u32 + ndim_u32) as usize]; + let temp_slice = + unsafe { std::slice::from_raw_parts(id_compressed_pivot.as_ptr(), ndims) }; + chunk.copy_from_slice(temp_slice); + }); + + out +} + +#[cfg(test)] +mod fixed_chunk_pq_table_test { + + use super::*; + use crate::common::{ANNError, ANNResult}; + use crate::utils::{convert_types_u32_usize, convert_types_u64_usize, file_exists, load_bin}; + + const DIM: usize = 128; + + #[test] + fn load_pivot_test() { + let pq_pivots_path: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin"; + let (dim, pq_table, centroids, chunk_offsets) = + load_pq_pivots_bin(pq_pivots_path, &1).unwrap(); + let fixed_chunk_pq_table = + FixedChunkPQTable::new(dim, 1, pq_table, centroids, chunk_offsets); + + 
assert_eq!(dim, DIM); + assert_eq!(fixed_chunk_pq_table.pq_table.len(), DIM * NUM_PQ_CENTROIDS); + assert_eq!(fixed_chunk_pq_table.centroids.len(), DIM); + + assert_eq!(fixed_chunk_pq_table.chunk_offsets[0], 0); + assert_eq!(fixed_chunk_pq_table.chunk_offsets[1], DIM); + assert_eq!(fixed_chunk_pq_table.chunk_offsets.len(), 2); + } + + #[test] + fn get_num_chunks_test() { + let num_chunks = 7; + let pa_table = vec![0.0; DIM * NUM_PQ_CENTROIDS]; + let centroids = vec![0.0; DIM]; + let chunk_offsets = vec![0, 7, 9, 11, 22, 34, 78, 127]; + let fixed_chunk_pq_table = + FixedChunkPQTable::new(DIM, num_chunks, pa_table, centroids, chunk_offsets); + let chunk: usize = fixed_chunk_pq_table.get_num_chunks(); + assert_eq!(chunk, num_chunks); + } + + #[test] + fn preprocess_query_test() { + let pq_pivots_path: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin"; + let (dim, pq_table, centroids, chunk_offsets) = + load_pq_pivots_bin(pq_pivots_path, &1).unwrap(); + let fixed_chunk_pq_table = + FixedChunkPQTable::new(dim, 1, pq_table, centroids, chunk_offsets); + + let mut query_vec: Vec = vec![ + 32.39f32, 78.57f32, 50.32f32, 80.46f32, 6.47f32, 69.76f32, 94.2f32, 83.36f32, 5.8f32, + 68.78f32, 42.32f32, 61.77f32, 90.26f32, 60.41f32, 3.86f32, 61.21f32, 16.6f32, 54.46f32, + 7.29f32, 54.24f32, 92.49f32, 30.18f32, 65.36f32, 99.09f32, 3.8f32, 36.4f32, 86.72f32, + 65.18f32, 29.87f32, 62.21f32, 58.32f32, 43.23f32, 94.3f32, 79.61f32, 39.67f32, + 11.18f32, 48.88f32, 38.19f32, 93.95f32, 10.46f32, 36.7f32, 14.75f32, 81.64f32, + 59.18f32, 99.03f32, 74.23f32, 1.26f32, 82.69f32, 35.7f32, 38.39f32, 46.17f32, 64.75f32, + 7.15f32, 36.55f32, 77.32f32, 18.65f32, 32.8f32, 74.84f32, 18.12f32, 20.19f32, 70.06f32, + 48.37f32, 40.18f32, 45.69f32, 88.3f32, 39.15f32, 60.97f32, 71.29f32, 61.79f32, + 47.23f32, 94.71f32, 58.04f32, 52.4f32, 34.66f32, 59.1f32, 47.11f32, 30.2f32, 58.72f32, + 74.35f32, 83.68f32, 66.8f32, 28.57f32, 29.45f32, 52.02f32, 91.95f32, 92.44f32, + 65.25f32, 38.3f32, 35.6f32, 41.67f32, 
91.33f32, 76.81f32, 74.88f32, 33.17f32, 48.36f32, + 41.42f32, 23f32, 8.31f32, 81.69f32, 80.08f32, 50.55f32, 54.46f32, 23.79f32, 43.46f32, + 84.5f32, 10.42f32, 29.51f32, 19.73f32, 46.48f32, 35.01f32, 52.3f32, 66.97f32, 4.8f32, + 74.81f32, 2.82f32, 61.82f32, 25.06f32, 17.3f32, 17.29f32, 63.2f32, 64.1f32, 61.68f32, + 37.42f32, 3.39f32, 97.45f32, 5.32f32, 59.02f32, 35.6f32, + ]; + fixed_chunk_pq_table.preprocess_query(&mut query_vec); + assert_eq!(query_vec[0], 32.39f32 - fixed_chunk_pq_table.centroids[0]); + assert_eq!( + query_vec[127], + 35.6f32 - fixed_chunk_pq_table.centroids[127] + ); + } + + #[test] + fn calculate_distances_tests() { + let pq_pivots_path: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin"; + + let (dim, pq_table, centroids, chunk_offsets) = + load_pq_pivots_bin(pq_pivots_path, &1).unwrap(); + let fixed_chunk_pq_table = + FixedChunkPQTable::new(dim, 1, pq_table, centroids, chunk_offsets); + + let query_vec: Vec = vec![ + 32.39f32, 78.57f32, 50.32f32, 80.46f32, 6.47f32, 69.76f32, 94.2f32, 83.36f32, 5.8f32, + 68.78f32, 42.32f32, 61.77f32, 90.26f32, 60.41f32, 3.86f32, 61.21f32, 16.6f32, 54.46f32, + 7.29f32, 54.24f32, 92.49f32, 30.18f32, 65.36f32, 99.09f32, 3.8f32, 36.4f32, 86.72f32, + 65.18f32, 29.87f32, 62.21f32, 58.32f32, 43.23f32, 94.3f32, 79.61f32, 39.67f32, + 11.18f32, 48.88f32, 38.19f32, 93.95f32, 10.46f32, 36.7f32, 14.75f32, 81.64f32, + 59.18f32, 99.03f32, 74.23f32, 1.26f32, 82.69f32, 35.7f32, 38.39f32, 46.17f32, 64.75f32, + 7.15f32, 36.55f32, 77.32f32, 18.65f32, 32.8f32, 74.84f32, 18.12f32, 20.19f32, 70.06f32, + 48.37f32, 40.18f32, 45.69f32, 88.3f32, 39.15f32, 60.97f32, 71.29f32, 61.79f32, + 47.23f32, 94.71f32, 58.04f32, 52.4f32, 34.66f32, 59.1f32, 47.11f32, 30.2f32, 58.72f32, + 74.35f32, 83.68f32, 66.8f32, 28.57f32, 29.45f32, 52.02f32, 91.95f32, 92.44f32, + 65.25f32, 38.3f32, 35.6f32, 41.67f32, 91.33f32, 76.81f32, 74.88f32, 33.17f32, 48.36f32, + 41.42f32, 23f32, 8.31f32, 81.69f32, 80.08f32, 50.55f32, 54.46f32, 23.79f32, 43.46f32, + 
84.5f32, 10.42f32, 29.51f32, 19.73f32, 46.48f32, 35.01f32, 52.3f32, 66.97f32, 4.8f32, + 74.81f32, 2.82f32, 61.82f32, 25.06f32, 17.3f32, 17.29f32, 63.2f32, 64.1f32, 61.68f32, + 37.42f32, 3.39f32, 97.45f32, 5.32f32, 59.02f32, 35.6f32, + ]; + + let dist_vec = fixed_chunk_pq_table.populate_chunk_distances(&query_vec); + assert_eq!(dist_vec.len(), 256); + + // populate_chunk_distances_test + let mut sampled_output = 0.0; + (0..DIM).for_each(|dim_offset| { + let diff = fixed_chunk_pq_table.pq_table[dim_offset] - query_vec[dim_offset]; + sampled_output += diff * diff; + }); + assert_eq!(sampled_output, dist_vec[0]); + + // populate_chunk_inner_products_test + let dist_vec = fixed_chunk_pq_table.populate_chunk_inner_products(&query_vec); + assert_eq!(dist_vec.len(), 256); + + let mut sampled_output = 0.0; + (0..DIM).for_each(|dim_offset| { + sampled_output -= fixed_chunk_pq_table.pq_table[dim_offset] * query_vec[dim_offset]; + }); + assert_eq!(sampled_output, dist_vec[0]); + + // l2_distance_test + let base_vec: Vec = vec![3u8]; + let dist = fixed_chunk_pq_table.l2_distance(&query_vec, &base_vec); + let mut l2_output = 0.0; + (0..DIM).for_each(|dim_offset| { + let diff = fixed_chunk_pq_table.pq_table[3 * DIM + dim_offset] - query_vec[dim_offset]; + l2_output += diff * diff; + }); + assert_eq!(l2_output, dist); + + // inner_product_test + let dist = fixed_chunk_pq_table.inner_product(&query_vec, &base_vec); + let mut l2_output = 0.0; + (0..DIM).for_each(|dim_offset| { + l2_output -= + fixed_chunk_pq_table.pq_table[3 * DIM + dim_offset] * query_vec[dim_offset]; + }); + assert_eq!(l2_output, dist); + + // inflate_vector_test + let inflate_vector = fixed_chunk_pq_table.inflate_vector(&base_vec).unwrap(); + assert_eq!(inflate_vector.len(), DIM); + assert_eq!( + inflate_vector[0], + fixed_chunk_pq_table.pq_table[3 * DIM] + fixed_chunk_pq_table.centroids[0] + ); + assert_eq!( + inflate_vector[1], + fixed_chunk_pq_table.pq_table[3 * DIM + 1] + fixed_chunk_pq_table.centroids[1] + 
); + assert_eq!( + inflate_vector[127], + fixed_chunk_pq_table.pq_table[3 * DIM + 127] + fixed_chunk_pq_table.centroids[127] + ); + } + + fn load_pq_pivots_bin( + pq_pivots_path: &str, + num_pq_chunks: &usize, + ) -> ANNResult<(usize, Vec, Vec, Vec)> { + if !file_exists(pq_pivots_path) { + return Err(ANNError::log_pq_error( + "ERROR: PQ k-means pivot file not found.".to_string(), + )); + } + + let (data, offset_num, offset_dim) = load_bin::(pq_pivots_path, 0)?; + let file_offset_data = convert_types_u64_usize(&data, offset_num, offset_dim); + if offset_num != 4 { + let error_message = format!("Error reading pq_pivots file {}. Offsets don't contain correct metadata, # offsets = {}, but expecting 4.", pq_pivots_path, offset_num); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, pq_center_num, dim) = load_bin::(pq_pivots_path, file_offset_data[0])?; + let pq_table = data.to_vec(); + if pq_center_num != NUM_PQ_CENTROIDS { + let error_message = format!( + "Error reading pq_pivots file {}. file_num_centers = {}, but expecting {} centers.", + pq_pivots_path, pq_center_num, NUM_PQ_CENTROIDS + ); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, centroid_dim, nc) = load_bin::(pq_pivots_path, file_offset_data[1])?; + let centroids = data.to_vec(); + if centroid_dim != dim || nc != 1 { + let error_message = format!("Error reading pq_pivots file {}. 
file_dim = {}, file_cols = {} but expecting {} entries in 1 dimension.", pq_pivots_path, centroid_dim, nc, dim); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, chunk_offset_num, nc) = load_bin::(pq_pivots_path, file_offset_data[2])?; + let chunk_offsets = convert_types_u32_usize(&data, chunk_offset_num, nc); + if chunk_offset_num != num_pq_chunks + 1 || nc != 1 { + let error_message = format!("Error reading pq_pivots file at chunk offsets; file has nr={}, nc={} but expecting nr={} and nc=1.", chunk_offset_num, nc, num_pq_chunks + 1); + return Err(ANNError::log_pq_error(error_message)); + } + + Ok((dim, pq_table, centroids, chunk_offsets)) + } +} + +#[cfg(test)] +mod pq_index_prune_query_test { + + use super::*; + + #[test] + fn pq_dist_lookup_test() { + let pq_ids: Vec = vec![1u8, 3u8, 2u8, 2u8]; + let mut pq_dists: Vec = Vec::with_capacity(256 * 2); + for _ in 0..pq_dists.capacity() { + pq_dists.push(rand::random()); + } + + let dists_out = pq_dist_lookup(&pq_ids, 2, 2, &pq_dists); + assert_eq!(dists_out.len(), 2); + assert_eq!(dists_out[0], pq_dists[0 + 1] + pq_dists[256 + 3]); + assert_eq!(dists_out[1], pq_dists[0 + 2] + pq_dists[256 + 2]); + } +} diff --git a/rust/diskann/src/model/pq/mod.rs b/rust/diskann/src/model/pq/mod.rs new file mode 100644 index 000000000..85daaa7c6 --- /dev/null +++ b/rust/diskann/src/model/pq/mod.rs @@ -0,0 +1,9 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +mod fixed_chunk_pq_table; +pub use fixed_chunk_pq_table::*; + +mod pq_construction; +pub use pq_construction::*; diff --git a/rust/diskann/src/model/pq/pq_construction.rs b/rust/diskann/src/model/pq/pq_construction.rs new file mode 100644 index 000000000..0a7b0784e --- /dev/null +++ b/rust/diskann/src/model/pq/pq_construction.rs @@ -0,0 +1,398 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +#![warn(missing_debug_implementations)] + +use rayon::prelude::{IndexedParallelIterator, ParallelIterator}; +use rayon::slice::ParallelSliceMut; + +use crate::common::{ANNError, ANNResult}; +use crate::storage::PQStorage; +use crate::utils::{compute_closest_centers, file_exists, k_means_clustering}; + +/// Max size of PQ training set +pub const MAX_PQ_TRAINING_SET_SIZE: f64 = 256_000f64; + +/// Max number of PQ chunks +pub const MAX_PQ_CHUNKS: usize = 512; + +pub const NUM_PQ_CENTROIDS: usize = 256; +/// block size for reading/processing large files and matrices in blocks +const BLOCK_SIZE: usize = 5000000; +const NUM_KMEANS_REPS_PQ: usize = 12; + +/// given training data in train_data of dimensions num_train * dim, generate +/// PQ pivots using k-means algorithm to partition the co-ordinates into +/// num_pq_chunks (if it divides dimension, else rounded) chunks, and runs +/// k-means in each chunk to compute the PQ pivots and stores in bin format in +/// file pq_pivots_path as a s num_centers*dim floating point binary file +/// PQ pivot table layout: {pivot offsets data: METADATA_SIZE}{pivot vector:[dim; num_centroid]}{centroid vector:[dim; 1]}{chunk offsets:[chunk_num+1; 1]} +fn generate_pq_pivots( + train_data: &mut [f32], + num_train: usize, + dim: usize, + num_centers: usize, + num_pq_chunks: usize, + max_k_means_reps: usize, + pq_storage: &mut PQStorage, +) -> ANNResult<()> { + if num_pq_chunks > dim { + return Err(ANNError::log_pq_error( + "Error: number of chunks more than dimension.".to_string(), + )); + } + + if pq_storage.pivot_data_exist() { + let (file_num_centers, file_dim) = pq_storage.read_pivot_metadata()?; + if file_dim == dim && file_num_centers == num_centers { + // PQ pivot file exists. Not generating again. + return Ok(()); + } + } + + // Calculate centroid and center the training data + // If we use L2 distance, there is an option to + // translate all vectors to make them centered and + // then compute PQ. 
This needs to be set to false
+    // when using PQ for MIPS as such translations dont
+    // preserve inner products.
+    // Now, we're using L2 as default.
+    let mut centroid: Vec = vec![0.0; dim];
+    // Per-dimension mean over the whole training set.
+    for dim_index in 0..dim {
+        for train_data_index in 0..num_train {
+            centroid[dim_index] += train_data[train_data_index * dim + dim_index];
+        }
+        centroid[dim_index] /= num_train as f32;
+    }
+    // Center the training data in place (subtract the mean).
+    for dim_index in 0..dim {
+        for train_data_index in 0..num_train {
+            train_data[train_data_index * dim + dim_index] -= centroid[dim_index];
+        }
+    }
+
+    // Calculate each chunk's offset
+    // If we have 8 dimension and 3 chunk then offsets would be [0,3,6,8]
+    // (the first dim % num_pq_chunks chunks receive one extra dimension).
+    let mut chunk_offsets: Vec = vec![0; num_pq_chunks + 1];
+    let mut chunk_offset: usize = 0;
+    for chunk_index in 0..num_pq_chunks {
+        chunk_offset += dim / num_pq_chunks;
+        if chunk_index < (dim % num_pq_chunks) {
+            chunk_offset += 1;
+        }
+        chunk_offsets[chunk_index + 1] = chunk_offset;
+    }
+
+    let mut full_pivot_data: Vec = vec![0.0; num_centers * dim];
+    for chunk_index in 0..num_pq_chunks {
+        let chunk_size = chunk_offsets[chunk_index + 1] - chunk_offsets[chunk_index];
+
+        // Column-slice the training data for this chunk: one chunk_size-dim
+        // row per training point, filled in parallel.
+        let mut cur_train_data: Vec = vec![0.0; num_train * chunk_size];
+        let mut cur_pivot_data: Vec = vec![0.0; num_centers * chunk_size];
+
+        cur_train_data
+            .par_chunks_mut(chunk_size)
+            .enumerate()
+            .for_each(|(train_data_index, chunk)| {
+                for (dim_offset, item) in chunk.iter_mut().enumerate() {
+                    *item = train_data
+                        [train_data_index * dim + chunk_offsets[chunk_index] + dim_offset];
+                }
+            });
+
+        // Run kmeans to get the centroids of this chunk. 
+ let (_closest_docs, _closest_center, _residual) = k_means_clustering( + &cur_train_data, + num_train, + chunk_size, + &mut cur_pivot_data, + num_centers, + max_k_means_reps, + )?; + + // Copy centroids from this chunk table to full table + for center_index in 0..num_centers { + full_pivot_data[center_index * dim + chunk_offsets[chunk_index] + ..center_index * dim + chunk_offsets[chunk_index + 1]] + .copy_from_slice( + &cur_pivot_data[center_index * chunk_size..(center_index + 1) * chunk_size], + ); + } + } + + pq_storage.write_pivot_data( + &full_pivot_data, + ¢roid, + &chunk_offsets, + num_centers, + dim, + )?; + + Ok(()) +} + +/// streams the base file (data_file), and computes the closest centers in each +/// chunk to generate the compressed data_file and stores it in +/// pq_compressed_vectors_path. +/// If the numbber of centers is < 256, it stores as byte vector, else as +/// 4-byte vector in binary format. +/// Compressed PQ table layout: {num_points: usize}{num_chunks: usize}{compressed pq table: [num_points; num_chunks]} +fn generate_pq_data_from_pivots>( + num_centers: usize, + num_pq_chunks: usize, + pq_storage: &mut PQStorage, +) -> ANNResult<()> { + let (num_points, dim) = pq_storage.read_pq_data_metadata()?; + + let full_pivot_data: Vec; + let centroid: Vec; + let chunk_offsets: Vec; + + if !pq_storage.pivot_data_exist() { + return Err(ANNError::log_pq_error( + "ERROR: PQ k-means pivot file not found.".to_string(), + )); + } else { + (full_pivot_data, centroid, chunk_offsets) = + pq_storage.load_pivot_data(&num_pq_chunks, &num_centers, &dim)?; + } + + pq_storage.write_compressed_pivot_metadata(num_points as i32, num_pq_chunks as i32)?; + + let block_size = if num_points <= BLOCK_SIZE { + num_points + } else { + BLOCK_SIZE + }; + let num_blocks = (num_points / block_size) + (num_points % block_size != 0) as usize; + + for block_index in 0..num_blocks { + let start_index: usize = block_index * block_size; + let end_index: usize = 
std::cmp::min((block_index + 1) * block_size, num_points); + let cur_block_size: usize = end_index - start_index; + + let mut block_compressed_base: Vec = vec![0; cur_block_size * num_pq_chunks]; + + let block_data: Vec = pq_storage.read_pq_block_data(cur_block_size, dim)?; + + let mut adjusted_block_data: Vec = vec![0.0; cur_block_size * dim]; + + for block_data_index in 0..cur_block_size { + for dim_index in 0..dim { + adjusted_block_data[block_data_index * dim + dim_index] = + block_data[block_data_index * dim + dim_index].into() - centroid[dim_index]; + } + } + + for chunk_index in 0..num_pq_chunks { + let cur_chunk_size = chunk_offsets[chunk_index + 1] - chunk_offsets[chunk_index]; + if cur_chunk_size == 0 { + continue; + } + + let mut cur_pivot_data: Vec = vec![0.0; num_centers * cur_chunk_size]; + let mut cur_data: Vec = vec![0.0; cur_block_size * cur_chunk_size]; + let mut closest_center: Vec = vec![0; cur_block_size]; + + // Divide the data into chunks and process each chunk in parallel. 
+ cur_data + .par_chunks_mut(cur_chunk_size) + .enumerate() + .for_each(|(block_data_index, chunk)| { + for (dim_offset, item) in chunk.iter_mut().enumerate() { + *item = adjusted_block_data + [block_data_index * dim + chunk_offsets[chunk_index] + dim_offset]; + } + }); + + cur_pivot_data + .par_chunks_mut(cur_chunk_size) + .enumerate() + .for_each(|(center_index, chunk)| { + for (din_offset, item) in chunk.iter_mut().enumerate() { + *item = full_pivot_data + [center_index * dim + chunk_offsets[chunk_index] + din_offset]; + } + }); + + // Compute the closet centers + compute_closest_centers( + &cur_data, + cur_block_size, + cur_chunk_size, + &cur_pivot_data, + num_centers, + 1, + &mut closest_center, + None, + None, + )?; + + block_compressed_base + .par_chunks_mut(num_pq_chunks) + .enumerate() + .for_each(|(block_data_index, slice)| { + slice[chunk_index] = closest_center[block_data_index] as usize; + }); + } + + _ = pq_storage.write_compressed_pivot_data( + &block_compressed_base, + num_centers, + cur_block_size, + num_pq_chunks, + ); + } + Ok(()) +} + +/// Save the data on a file. +/// # Arguments +/// * `p_val` - choose how many ratio sample data as trained data to get pivot +/// * `num_pq_chunks` - pq chunk number +/// * `codebook_prefix` - predefined pivots file named +/// * `pq_storage` - pq file access +pub fn generate_quantized_data>( + p_val: f64, + num_pq_chunks: usize, + codebook_prefix: &str, + pq_storage: &mut PQStorage, +) -> ANNResult<()> { + // If predefined pivots already exists, skip training. + if !file_exists(codebook_prefix) { + // Instantiates train data with random sample updates train_data_vector + // Training data with train_size samples loaded. + // Each sampled file has train_dim. 
+ let (mut train_data_vector, train_size, train_dim) = + pq_storage.gen_random_slice::(p_val)?; + + generate_pq_pivots( + &mut train_data_vector, + train_size, + train_dim, + NUM_PQ_CENTROIDS, + num_pq_chunks, + NUM_KMEANS_REPS_PQ, + pq_storage, + )?; + } + generate_pq_data_from_pivots::(NUM_PQ_CENTROIDS, num_pq_chunks, pq_storage)?; + Ok(()) +} + +#[cfg(test)] +mod pq_test { + + use std::fs::File; + use std::io::Write; + + use super::*; + use crate::utils::{convert_types_u32_usize, convert_types_u64_usize, load_bin, METADATA_SIZE}; + + #[test] + fn generate_pq_pivots_test() { + let pivot_file_name = "generate_pq_pivots_test.bin"; + let compressed_file_name = "compressed.bin"; + let pq_training_file_name = "tests/data/siftsmall_learn.bin"; + let mut pq_storage = + PQStorage::new(pivot_file_name, compressed_file_name, pq_training_file_name).unwrap(); + let mut train_data: Vec = vec![ + 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, + 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, + 2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, + 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, + ]; + generate_pq_pivots(&mut train_data, 5, 8, 2, 2, 5, &mut pq_storage).unwrap(); + + let (data, nr, nc) = load_bin::(pivot_file_name, 0).unwrap(); + let file_offset_data = convert_types_u64_usize(&data, nr, nc); + assert_eq!(file_offset_data[0], METADATA_SIZE); + assert_eq!(nr, 4); + assert_eq!(nc, 1); + + let (data, nr, nc) = load_bin::(pivot_file_name, file_offset_data[0]).unwrap(); + let full_pivot_data = data.to_vec(); + assert_eq!(full_pivot_data.len(), 16); + assert_eq!(nr, 2); + assert_eq!(nc, 8); + + let (data, nr, nc) = load_bin::(pivot_file_name, file_offset_data[1]).unwrap(); + let centroid = data.to_vec(); + assert_eq!( + centroid[0], + (1.0f32 + 2.0f32 + 2.1f32 + 2.2f32 + 100.0f32) / 5.0f32 + ); + assert_eq!(nr, 8); + assert_eq!(nc, 
1); + + let (data, nr, nc) = load_bin::(pivot_file_name, file_offset_data[2]).unwrap(); + let chunk_offsets = convert_types_u32_usize(&data, nr, nc); + assert_eq!(chunk_offsets[0], 0); + assert_eq!(chunk_offsets[1], 4); + assert_eq!(chunk_offsets[2], 8); + assert_eq!(nr, 3); + assert_eq!(nc, 1); + std::fs::remove_file(pivot_file_name).unwrap(); + } + + #[test] + fn generate_pq_data_from_pivots_test() { + let data_file = "generate_pq_data_from_pivots_test_data.bin"; + //npoints=5, dim=8, 5 vectors [1.0;8] [2.0;8] [2.1;8] [2.2;8] [100.0;8] + let mut train_data: Vec = vec![ + 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, + 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, 2.1f32, + 2.1f32, 2.1f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, 2.2f32, + 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, 100.0f32, + ]; + let my_nums_unstructured: &[u8] = unsafe { + std::slice::from_raw_parts(train_data.as_ptr() as *const u8, train_data.len() * 4) + }; + let meta: Vec = vec![5, 8]; + let meta_unstructured: &[u8] = + unsafe { std::slice::from_raw_parts(meta.as_ptr() as *const u8, meta.len() * 4) }; + let mut data_file_writer = File::create(data_file).unwrap(); + data_file_writer + .write_all(meta_unstructured) + .expect("Failed to write sample file"); + data_file_writer + .write_all(my_nums_unstructured) + .expect("Failed to write sample file"); + + let pq_pivots_path = "generate_pq_data_from_pivots_test_pivot.bin"; + let pq_compressed_vectors_path = "generate_pq_data_from_pivots_test.bin"; + let mut pq_storage = + PQStorage::new(pq_pivots_path, pq_compressed_vectors_path, data_file).unwrap(); + generate_pq_pivots(&mut train_data, 5, 8, 2, 2, 5, &mut pq_storage).unwrap(); + generate_pq_data_from_pivots::(2, 2, &mut pq_storage).unwrap(); + let (data, nr, nc) = load_bin::(pq_compressed_vectors_path, 0).unwrap(); + assert_eq!(nr, 5); + assert_eq!(nc, 2); + assert_eq!(data[0], 
data[2]); + assert_ne!(data[0], data[8]); + + std::fs::remove_file(data_file).unwrap(); + std::fs::remove_file(pq_pivots_path).unwrap(); + std::fs::remove_file(pq_compressed_vectors_path).unwrap(); + } + + #[test] + fn pq_end_to_end_validation_with_codebook_test() { + let data_file = "tests/data/siftsmall_learn.bin"; + let pq_pivots_path = "tests/data/siftsmall_learn.bin_pq_pivots.bin"; + let gound_truth_path = "tests/data/siftsmall_learn.bin_pq_compressed.bin"; + let pq_compressed_vectors_path = "validation.bin"; + let mut pq_storage = + PQStorage::new(pq_pivots_path, pq_compressed_vectors_path, data_file).unwrap(); + generate_quantized_data::(0.5, 1, pq_pivots_path, &mut pq_storage).unwrap(); + + let (data, nr, nc) = load_bin::(pq_compressed_vectors_path, 0).unwrap(); + let (gt_data, gt_nr, gt_nc) = load_bin::(gound_truth_path, 0).unwrap(); + assert_eq!(nr, gt_nr); + assert_eq!(nc, gt_nc); + for i in 0..data.len() { + assert_eq!(data[i], gt_data[i]); + } + std::fs::remove_file(pq_compressed_vectors_path).unwrap(); + } +} diff --git a/rust/diskann/src/model/scratch/concurrent_queue.rs b/rust/diskann/src/model/scratch/concurrent_queue.rs new file mode 100644 index 000000000..8c72bab02 --- /dev/null +++ b/rust/diskann/src/model/scratch/concurrent_queue.rs @@ -0,0 +1,312 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! 
Aligned allocator + +use std::collections::VecDeque; +use std::ops::Deref; +use std::sync::{Arc, Condvar, Mutex, MutexGuard}; +use std::time::Duration; + +use crate::common::{ANNError, ANNResult}; + +#[derive(Debug)] +/// Query scratch data structures +pub struct ConcurrentQueue { + q: Mutex>, + c: Mutex, + push_cv: Condvar, +} + +impl Default for ConcurrentQueue { + fn default() -> Self { + Self::new() + } +} + +impl ConcurrentQueue { + /// Create a concurrent queue + pub fn new() -> Self { + Self { + q: Mutex::new(VecDeque::new()), + c: Mutex::new(false), + push_cv: Condvar::new(), + } + } + + /// Block the current thread until it is able to acquire the mutex + pub fn reserve(&self, size: usize) -> ANNResult<()> { + let mut guard = lock(&self.q)?; + guard.reserve(size); + Ok(()) + } + + /// queue stats + pub fn size(&self) -> ANNResult { + let guard = lock(&self.q)?; + + Ok(guard.len()) + } + + /// empty the queue + pub fn is_empty(&self) -> ANNResult { + Ok(self.size()? == 0) + } + + /// push back + pub fn push(&self, new_val: T) -> ANNResult<()> { + let mut guard = lock(&self.q)?; + self.push_internal(&mut guard, new_val); + self.push_cv.notify_all(); + Ok(()) + } + + /// push back + fn push_internal(&self, guard: &mut MutexGuard>, new_val: T) { + guard.push_back(new_val); + } + + /// insert into queue + pub fn insert(&self, iter: I) -> ANNResult<()> + where + I: IntoIterator, + { + let mut guard = lock(&self.q)?; + for item in iter { + self.push_internal(&mut guard, item); + } + + self.push_cv.notify_all(); + Ok(()) + } + + /// pop front + pub fn pop(&self) -> ANNResult> { + let mut guard = lock(&self.q)?; + Ok(guard.pop_front()) + } + + /// Empty - is this necessary? 
+ pub fn empty_queue(&self) -> ANNResult<()> { + let mut guard = lock(&self.q)?; + while !guard.is_empty() { + let _ = guard.pop_front(); + } + Ok(()) + } + + /// register for push notifications + pub fn wait_for_push_notify(&self, wait_time: Duration) -> ANNResult<()> { + let guard_lock = lock(&self.c)?; + let _ = self + .push_cv + .wait_timeout(guard_lock, wait_time) + .map_err(|err| { + ANNError::log_lock_poison_error(format!( + "ConcurrentQueue Lock is poisoned, err={}", + err + )) + })?; + Ok(()) + } +} + +fn lock(mutex: &Mutex) -> ANNResult> { + let guard = mutex.lock().map_err(|err| { + ANNError::log_lock_poison_error(format!("ConcurrentQueue lock is poisoned, err={}", err)) + })?; + Ok(guard) +} + +/// A thread-safe queue that holds instances of `T`. +/// Each instance is stored in a `Box` to keep the size of the queue node constant. +#[derive(Debug)] +pub struct ArcConcurrentBoxedQueue { + internal_queue: Arc>>, +} + +impl ArcConcurrentBoxedQueue { + /// Create a new `ArcConcurrentBoxedQueue`. + pub fn new() -> Self { + Self { + internal_queue: Arc::new(ConcurrentQueue::new()), + } + } +} + +impl Default for ArcConcurrentBoxedQueue { + fn default() -> Self { + Self::new() + } +} + +impl Clone for ArcConcurrentBoxedQueue { + /// Create a new `ArcConcurrentBoxedQueue` that shares the same internal queue + /// with the existing one. This allows multiple `ArcConcurrentBoxedQueue` to + /// operate on the same underlying queue. + fn clone(&self) -> Self { + Self { + internal_queue: Arc::clone(&self.internal_queue), + } + } +} + +/// Deref to the ConcurrentQueue. 
+impl Deref for ArcConcurrentBoxedQueue { + type Target = ConcurrentQueue>; + + fn deref(&self) -> &Self::Target { + &self.internal_queue + } +} + +#[cfg(test)] +mod tests { + use crate::model::ConcurrentQueue; + use std::sync::Arc; + use std::thread; + use std::time::Duration; + + #[test] + fn test_push_pop() { + let queue = ConcurrentQueue::::new(); + + queue.push(1).unwrap(); + queue.push(2).unwrap(); + queue.push(3).unwrap(); + + assert_eq!(queue.pop().unwrap(), Some(1)); + assert_eq!(queue.pop().unwrap(), Some(2)); + assert_eq!(queue.pop().unwrap(), Some(3)); + assert_eq!(queue.pop().unwrap(), None); + } + + #[test] + fn test_size_empty() { + let queue = ConcurrentQueue::new(); + + assert_eq!(queue.size().unwrap(), 0); + assert!(queue.is_empty().unwrap()); + + queue.push(1).unwrap(); + queue.push(2).unwrap(); + + assert_eq!(queue.size().unwrap(), 2); + assert!(!queue.is_empty().unwrap()); + + queue.pop().unwrap(); + queue.pop().unwrap(); + + assert_eq!(queue.size().unwrap(), 0); + assert!(queue.is_empty().unwrap()); + } + + #[test] + fn test_insert() { + let queue = ConcurrentQueue::new(); + + let data = vec![1, 2, 3]; + queue.insert(data.into_iter()).unwrap(); + + assert_eq!(queue.pop().unwrap(), Some(1)); + assert_eq!(queue.pop().unwrap(), Some(2)); + assert_eq!(queue.pop().unwrap(), Some(3)); + assert_eq!(queue.pop().unwrap(), None); + } + + #[test] + fn test_notifications() { + let queue = Arc::new(ConcurrentQueue::new()); + let queue_clone = Arc::clone(&queue); + + let producer = thread::spawn(move || { + for i in 0..3 { + thread::sleep(Duration::from_millis(50)); + queue_clone.push(i).unwrap(); + } + }); + + let consumer = thread::spawn(move || { + let mut values = vec![]; + + for _ in 0..3 { + let mut val = -1; + while val == -1 { + queue + .wait_for_push_notify(Duration::from_millis(10)) + .unwrap(); + val = queue.pop().unwrap().unwrap_or(-1); + } + + values.push(val); + } + + values + }); + + producer.join().unwrap(); + let consumer_results = 
consumer.join().unwrap(); + + assert_eq!(consumer_results, vec![0, 1, 2]); + } + + #[test] + fn test_multithreaded_push_pop() { + let queue = Arc::new(ConcurrentQueue::new()); + let queue_clone = Arc::clone(&queue); + + let producer = thread::spawn(move || { + for i in 0..10 { + queue_clone.push(i).unwrap(); + thread::sleep(Duration::from_millis(50)); + } + }); + + let consumer = thread::spawn(move || { + let mut values = vec![]; + + for _ in 0..10 { + let mut val = -1; + while val == -1 { + val = queue.pop().unwrap().unwrap_or(-1); + thread::sleep(Duration::from_millis(10)); + } + + values.push(val); + } + + values + }); + + producer.join().unwrap(); + let consumer_results = consumer.join().unwrap(); + + assert_eq!(consumer_results, (0..10).collect::>()); + } + + /// This is a single value test. It avoids the unlimited wait until the collectin got empty on the previous test. + /// It will make sure the signal mutex is matching the waiting mutex. + #[test] + fn test_wait_for_push_notify() { + let queue = Arc::new(ConcurrentQueue::::new()); + let queue_clone = Arc::clone(&queue); + + let producer = thread::spawn(move || { + thread::sleep(Duration::from_millis(100)); + queue_clone.push(1).unwrap(); + }); + + let consumer = thread::spawn(move || { + queue + .wait_for_push_notify(Duration::from_millis(200)) + .unwrap(); + assert_eq!(queue.pop().unwrap(), Some(1)); + }); + + producer.join().unwrap(); + consumer.join().unwrap(); + } +} diff --git a/rust/diskann/src/model/scratch/inmem_query_scratch.rs b/rust/diskann/src/model/scratch/inmem_query_scratch.rs new file mode 100644 index 000000000..f0fa432c2 --- /dev/null +++ b/rust/diskann/src/model/scratch/inmem_query_scratch.rs @@ -0,0 +1,186 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! 
Scratch space for in-memory index based search + +use std::cmp::max; +use std::mem; + +use hashbrown::HashSet; + +use crate::common::{ANNError, ANNResult, AlignedBoxWithSlice}; +use crate::model::configuration::index_write_parameters::IndexWriteParameters; +use crate::model::{Neighbor, NeighborPriorityQueue, PQScratch}; + +use super::Scratch; + +/// In-mem index related limits +pub const GRAPH_SLACK_FACTOR: f64 = 1.3_f64; + +/// Max number of points for using bitset +pub const MAX_POINTS_FOR_USING_BITSET: usize = 100000; + +/// TODO: SSD Index related limits +pub const MAX_GRAPH_DEGREE: usize = 512; + +/// TODO: SSD Index related limits +pub const MAX_N_CMPS: usize = 16384; + +/// TODO: SSD Index related limits +pub const SECTOR_LEN: usize = 4096; + +/// TODO: SSD Index related limits +pub const MAX_N_SECTOR_READS: usize = 128; + +/// The alignment required for memory access. This will be multiplied with size of T to get the actual alignment +pub const QUERY_ALIGNMENT_OF_T_SIZE: usize = 16; + +/// Scratch space for in-memory index based search +#[derive(Debug)] +pub struct InMemQueryScratch { + /// Size of the candidate queue + pub candidate_size: u32, + + /// Max degree for each vertex + pub max_degree: u32, + + /// Max occlusion size + pub max_occlusion_size: u32, + + /// Query node + pub query: AlignedBoxWithSlice, + + /// Best candidates, whose size is candidate_queue_size + pub best_candidates: NeighborPriorityQueue, + + /// Occlude factor + pub occlude_factor: Vec, + + /// Visited neighbor id + pub id_scratch: Vec, + + /// The distance between visited neighbor and query node + pub dist_scratch: Vec, + + /// The PQ Scratch, keey it private since this class use the Box to own the memory. 
Use the function pq_scratch to get its reference + pub pq_scratch: Option>, + + /// Buffers used in process delete, capacity increases as needed + pub expanded_nodes_set: HashSet, + + /// Expanded neighbors + pub expanded_neighbors_vector: Vec, + + /// Occlude list + pub occlude_list_output: Vec, + + /// RobinSet for larger dataset + pub node_visited_robinset: HashSet, +} + +impl InMemQueryScratch { + /// Create InMemQueryScratch instance + pub fn new( + search_candidate_size: u32, + index_write_parameter: &IndexWriteParameters, + init_pq_scratch: bool, + ) -> ANNResult { + let indexing_candidate_size = index_write_parameter.search_list_size; + let max_degree = index_write_parameter.max_degree; + let max_occlusion_size = index_write_parameter.max_occlusion_size; + + if search_candidate_size == 0 || indexing_candidate_size == 0 || max_degree == 0 || N == 0 { + return Err(ANNError::log_index_error(format!( + "In InMemQueryScratch, one of search_candidate_size = {}, indexing_candidate_size = {}, dim = {} or max_degree = {} is zero.", + search_candidate_size, indexing_candidate_size, N, max_degree))); + } + + let query = AlignedBoxWithSlice::new(N, mem::size_of::() * QUERY_ALIGNMENT_OF_T_SIZE)?; + let pq_scratch = if init_pq_scratch { + Some(Box::new(PQScratch::new(MAX_GRAPH_DEGREE, N)?)) + } else { + None + }; + + let occlude_factor = Vec::with_capacity(max_occlusion_size as usize); + + let capacity = (1.5 * GRAPH_SLACK_FACTOR * (max_degree as f64)).ceil() as usize; + let id_scratch = Vec::with_capacity(capacity); + let dist_scratch = Vec::with_capacity(capacity); + + let expanded_nodes_set = HashSet::::new(); + let expanded_neighbors_vector = Vec::::new(); + let occlude_list_output = Vec::::new(); + + let candidate_size = max(search_candidate_size, indexing_candidate_size); + let node_visited_robinset = HashSet::::with_capacity(20 * candidate_size as usize); + let scratch = Self { + candidate_size, + max_degree, + max_occlusion_size, + query, + best_candidates: 
NeighborPriorityQueue::with_capacity(candidate_size as usize), + occlude_factor, + id_scratch, + dist_scratch, + pq_scratch, + expanded_nodes_set, + expanded_neighbors_vector, + occlude_list_output, + node_visited_robinset, + }; + + Ok(scratch) + } + + /// Resize the scratch with new candidate size + pub fn resize_for_new_candidate_size(&mut self, new_candidate_size: u32) { + if new_candidate_size > self.candidate_size { + let delta = new_candidate_size - self.candidate_size; + self.candidate_size = new_candidate_size; + self.best_candidates.reserve(delta as usize); + self.node_visited_robinset.reserve((20 * delta) as usize); + } + } +} + +impl Scratch for InMemQueryScratch { + fn clear(&mut self) { + self.best_candidates.clear(); + self.occlude_factor.clear(); + + self.node_visited_robinset.clear(); + + self.id_scratch.clear(); + self.dist_scratch.clear(); + + self.expanded_nodes_set.clear(); + self.expanded_neighbors_vector.clear(); + self.occlude_list_output.clear(); + } +} + +#[cfg(test)] +mod inmemory_query_scratch_test { + use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder; + + use super::*; + + #[test] + fn node_visited_robinset_test() { + let index_write_parameter = IndexWriteParametersBuilder::new(10, 10) + .with_max_occlusion_size(5) + .build(); + + let mut scratch = + InMemQueryScratch::::new(100, &index_write_parameter, false).unwrap(); + + assert_eq!(scratch.node_visited_robinset.len(), 0); + + scratch.clear(); + assert_eq!(scratch.node_visited_robinset.len(), 0); + } +} diff --git a/rust/diskann/src/model/scratch/mod.rs b/rust/diskann/src/model/scratch/mod.rs new file mode 100644 index 000000000..cf9ee2900 --- /dev/null +++ b/rust/diskann/src/model/scratch/mod.rs @@ -0,0 +1,28 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +pub mod scratch_traits; +pub use scratch_traits::*; + +pub mod concurrent_queue; +pub use concurrent_queue::*; + +pub mod pq_scratch; +pub use pq_scratch::*; + + +pub mod inmem_query_scratch; +pub use inmem_query_scratch::*; + +pub mod scratch_store_manager; +pub use scratch_store_manager::*; + +pub mod ssd_query_scratch; +pub use ssd_query_scratch::*; + +pub mod ssd_thread_data; +pub use ssd_thread_data::*; + +pub mod ssd_io_context; +pub use ssd_io_context::*; diff --git a/rust/diskann/src/model/scratch/pq_scratch.rs b/rust/diskann/src/model/scratch/pq_scratch.rs new file mode 100644 index 000000000..bf9d6c547 --- /dev/null +++ b/rust/diskann/src/model/scratch/pq_scratch.rs @@ -0,0 +1,105 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Aligned allocator + +use std::mem::size_of; + +use crate::common::{ANNResult, AlignedBoxWithSlice}; + +const MAX_PQ_CHUNKS: usize = 512; + +#[derive(Debug)] +/// PQ scratch +pub struct PQScratch { + /// Aligned pq table dist scratch, must be at least [256 * NCHUNKS] + pub aligned_pqtable_dist_scratch: AlignedBoxWithSlice, + /// Aligned dist scratch, must be at least diskann MAX_DEGREE + pub aligned_dist_scratch: AlignedBoxWithSlice, + /// Aligned pq coord scratch, must be at least [N_CHUNKS * MAX_DEGREE] + pub aligned_pq_coord_scratch: AlignedBoxWithSlice, + /// Rotated query + pub rotated_query: AlignedBoxWithSlice, + /// Aligned query float + pub aligned_query_float: AlignedBoxWithSlice, +} + +impl PQScratch { + const ALIGNED_ALLOC_256: usize = 256; + + /// Create a new pq scratch + pub fn new(graph_degree: usize, aligned_dim: usize) -> ANNResult { + let aligned_pq_coord_scratch = + AlignedBoxWithSlice::new(graph_degree * MAX_PQ_CHUNKS, PQScratch::ALIGNED_ALLOC_256)?; + let aligned_pqtable_dist_scratch = + AlignedBoxWithSlice::new(256 * MAX_PQ_CHUNKS, PQScratch::ALIGNED_ALLOC_256)?; + let 
aligned_dist_scratch = + AlignedBoxWithSlice::new(graph_degree, PQScratch::ALIGNED_ALLOC_256)?; + let aligned_query_float = AlignedBoxWithSlice::new(aligned_dim, 8 * size_of::())?; + let rotated_query = AlignedBoxWithSlice::new(aligned_dim, 8 * size_of::())?; + + Ok(Self { + aligned_pqtable_dist_scratch, + aligned_dist_scratch, + aligned_pq_coord_scratch, + rotated_query, + aligned_query_float, + }) + } + + /// Set rotated_query and aligned_query_float values + pub fn set(&mut self, dim: usize, query: &[T], norm: f32) + where + T: Into + Copy, + { + for (d, item) in query.iter().enumerate().take(dim) { + let query_val: f32 = (*item).into(); + if (norm - 1.0).abs() > f32::EPSILON { + self.rotated_query[d] = query_val / norm; + self.aligned_query_float[d] = query_val / norm; + } else { + self.rotated_query[d] = query_val; + self.aligned_query_float[d] = query_val; + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::model::PQScratch; + + #[test] + fn test_pq_scratch() { + let graph_degree = 512; + let aligned_dim = 8; + + let mut pq_scratch: PQScratch = PQScratch::new(graph_degree, aligned_dim).unwrap(); + + // Check alignment + assert_eq!( + (pq_scratch.aligned_pqtable_dist_scratch.as_ptr() as usize) % 256, + 0 + ); + assert_eq!((pq_scratch.aligned_dist_scratch.as_ptr() as usize) % 256, 0); + assert_eq!( + (pq_scratch.aligned_pq_coord_scratch.as_ptr() as usize) % 256, + 0 + ); + assert_eq!((pq_scratch.rotated_query.as_ptr() as usize) % 32, 0); + assert_eq!((pq_scratch.aligned_query_float.as_ptr() as usize) % 32, 0); + + // Test set() method + let query = vec![1u8, 2, 3, 4, 5, 6, 7, 8]; + let norm = 2.0f32; + pq_scratch.set::(query.len(), &query, norm); + + (0..query.len()).for_each(|i| { + assert_eq!(pq_scratch.rotated_query[i], query[i] as f32 / norm); + assert_eq!(pq_scratch.aligned_query_float[i], query[i] as f32 / norm); + }); + } +} diff --git a/rust/diskann/src/model/scratch/scratch_store_manager.rs b/rust/diskann/src/model/scratch/scratch_store_manager.rs 
new file mode 100644 index 000000000..4e2397f49 --- /dev/null +++ b/rust/diskann/src/model/scratch/scratch_store_manager.rs @@ -0,0 +1,84 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use crate::common::ANNResult; + +use super::ArcConcurrentBoxedQueue; +use super::{scratch_traits::Scratch}; +use std::time::Duration; + +pub struct ScratchStoreManager { + scratch: Option>, + scratch_pool: ArcConcurrentBoxedQueue, +} + +impl ScratchStoreManager { + pub fn new(scratch_pool: ArcConcurrentBoxedQueue, wait_time: Duration) -> ANNResult { + let mut scratch = scratch_pool.pop()?; + while scratch.is_none() { + scratch_pool.wait_for_push_notify(wait_time)?; + scratch = scratch_pool.pop()?; + } + + Ok(ScratchStoreManager { + scratch, + scratch_pool, + }) + } + + pub fn scratch_space(&mut self) -> Option<&mut T> { + self.scratch.as_deref_mut() + } +} + +impl Drop for ScratchStoreManager { + fn drop(&mut self) { + if let Some(mut scratch) = self.scratch.take() { + scratch.clear(); + let _ = self.scratch_pool.push(scratch); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(Debug)] + struct MyScratch { + data: Vec, + } + + impl Scratch for MyScratch { + fn clear(&mut self) { + self.data.clear(); + } + } + + #[test] + fn test_scratch_store_manager() { + let wait_time = Duration::from_millis(100); + + let scratch_pool = ArcConcurrentBoxedQueue::new(); + for i in 1..3 { + scratch_pool.push(Box::new(MyScratch { + data: vec![i, 2 * i, 3 * i], + })).unwrap(); + } + + let mut manager = ScratchStoreManager::new(scratch_pool.clone(), wait_time).unwrap(); + let scratch_space = manager.scratch_space().unwrap(); + + assert_eq!(scratch_space.data, vec![1, 2, 3]); + + // At this point, the ScratchStoreManager will go out of scope, + // causing the Drop implementation to be called, which should + // call the clear method on MyScratch. 
+ drop(manager); + + let current_scratch = scratch_pool.pop().unwrap().unwrap(); + assert_eq!(current_scratch.data, vec![2, 4, 6]); + } +} + diff --git a/rust/diskann/src/model/scratch/scratch_traits.rs b/rust/diskann/src/model/scratch/scratch_traits.rs new file mode 100644 index 000000000..71e4b932d --- /dev/null +++ b/rust/diskann/src/model/scratch/scratch_traits.rs @@ -0,0 +1,8 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +pub trait Scratch { + fn clear(&mut self); +} + diff --git a/rust/diskann/src/model/scratch/ssd_io_context.rs b/rust/diskann/src/model/scratch/ssd_io_context.rs new file mode 100644 index 000000000..d4dff0cec --- /dev/null +++ b/rust/diskann/src/model/scratch/ssd_io_context.rs @@ -0,0 +1,38 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![allow(dead_code)] // Todo: Remove this when the disk index query code is complete. +use crate::common::ANNError; + +use platform::{FileHandle, IOCompletionPort}; + +// The IOContext struct for disk I/O. One for each thread. +pub struct IOContext { + pub status: Status, + pub file_handle: FileHandle, + pub io_completion_port: IOCompletionPort, +} + +impl Default for IOContext { + fn default() -> Self { + IOContext { + status: Status::ReadWait, + file_handle: FileHandle::default(), + io_completion_port: IOCompletionPort::default(), + } + } +} + +impl IOContext { + pub fn new() -> Self { + Self::default() + } +} + +pub enum Status { + ReadWait, + ReadSuccess, + ReadFailed(ANNError), + ProcessComplete, +} diff --git a/rust/diskann/src/model/scratch/ssd_query_scratch.rs b/rust/diskann/src/model/scratch/ssd_query_scratch.rs new file mode 100644 index 000000000..b36669303 --- /dev/null +++ b/rust/diskann/src/model/scratch/ssd_query_scratch.rs @@ -0,0 +1,132 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +#![allow(dead_code)] // Todo: Remove this when the disk index query code is complete. +use std::mem; +use std::vec::Vec; + +use hashbrown::HashSet; + +use crate::{ + common::{ANNResult, AlignedBoxWithSlice}, + model::{Neighbor, NeighborPriorityQueue}, + model::data_store::DiskScratchDataset, +}; + +use super::{PQScratch, Scratch, MAX_GRAPH_DEGREE, QUERY_ALIGNMENT_OF_T_SIZE}; + +// Scratch space for disk index based search. +pub struct SSDQueryScratch +{ + // Disk scratch dataset storing fp vectors with aligned dim (N) + pub scratch_dataset: DiskScratchDataset, + + // The query scratch. + pub query: AlignedBoxWithSlice, + + /// The PQ Scratch. + pub pq_scratch: Option>, + + // The visited set. + pub id_scratch: HashSet, + + /// Best candidates, whose size is candidate_queue_size + pub best_candidates: NeighborPriorityQueue, + + // Full return set. + pub full_return_set: Vec, +} + +// +impl SSDQueryScratch +{ + pub fn new( + visited_reserve: usize, + candidate_queue_size: usize, + init_pq_scratch: bool, + ) -> ANNResult { + let scratch_dataset = DiskScratchDataset::::new()?; + + let query = AlignedBoxWithSlice::::new(N, mem::size_of::() * QUERY_ALIGNMENT_OF_T_SIZE)?; + + let id_scratch = HashSet::::with_capacity(visited_reserve); + let full_return_set = Vec::::with_capacity(visited_reserve); + let best_candidates = NeighborPriorityQueue::with_capacity(candidate_queue_size); + + let pq_scratch = if init_pq_scratch { + Some(Box::new(PQScratch::new(MAX_GRAPH_DEGREE, N)?)) + } else { + None + }; + + Ok(Self { + scratch_dataset, + query, + pq_scratch, + id_scratch, + best_candidates, + full_return_set, + }) + } + + pub fn pq_scratch(&mut self) -> &Option> { + &self.pq_scratch + } +} + +impl Scratch for SSDQueryScratch +{ + fn clear(&mut self) { + self.id_scratch.clear(); + self.best_candidates.clear(); + self.full_return_set.clear(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_new() { + // Arrange + let visited_reserve = 100; + let 
candidate_queue_size = 10; + let init_pq_scratch = true; + + // Act + let result = + SSDQueryScratch::::new(visited_reserve, candidate_queue_size, init_pq_scratch); + + // Assert + assert!(result.is_ok()); + + let scratch = result.unwrap(); + + // Assert the properties of the scratch instance + assert!(scratch.pq_scratch.is_some()); + assert!(scratch.id_scratch.is_empty()); + assert!(scratch.best_candidates.size() == 0); + assert!(scratch.full_return_set.is_empty()); + } + + #[test] + fn test_clear() { + // Arrange + let mut scratch = SSDQueryScratch::::new(100, 10, true).unwrap(); + + // Add some data to scratch fields + scratch.id_scratch.insert(1); + scratch.best_candidates.insert(Neighbor::new(2, 0.5)); + scratch.full_return_set.push(Neighbor::new(3, 0.8)); + + // Act + scratch.clear(); + + // Assert + assert!(scratch.id_scratch.is_empty()); + assert!(scratch.best_candidates.size() == 0); + assert!(scratch.full_return_set.is_empty()); + } +} diff --git a/rust/diskann/src/model/scratch/ssd_thread_data.rs b/rust/diskann/src/model/scratch/ssd_thread_data.rs new file mode 100644 index 000000000..e37495901 --- /dev/null +++ b/rust/diskann/src/model/scratch/ssd_thread_data.rs @@ -0,0 +1,92 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![allow(dead_code)] // Todo: Remove this when the disk index query code is complete. +use std::sync::Arc; + +use super::{scratch_traits::Scratch, IOContext, SSDQueryScratch}; +use crate::common::ANNResult; + +// The thread data struct for SSD I/O. One for each thread, contains the ScratchSpace and the IOContext. 
+pub struct SSDThreadData { + pub scratch: SSDQueryScratch, + pub io_context: Option>, +} + +impl SSDThreadData { + pub fn new( + aligned_dim: usize, + visited_reserve: usize, + init_pq_scratch: bool, + ) -> ANNResult { + let scratch = SSDQueryScratch::new(aligned_dim, visited_reserve, init_pq_scratch)?; + Ok(SSDThreadData { + scratch, + io_context: None, + }) + } + + pub fn clear(&mut self) { + self.scratch.clear(); + } +} + +#[cfg(test)] +mod tests { + use crate::model::Neighbor; + + use super::*; + + #[test] + fn test_new() { + // Arrange + let aligned_dim = 10; + let visited_reserve = 100; + let init_pq_scratch = true; + + // Act + let result = SSDThreadData::::new(aligned_dim, visited_reserve, init_pq_scratch); + + // Assert + assert!(result.is_ok()); + + let thread_data = result.unwrap(); + + // Assert the properties of the thread data instance + assert!(thread_data.io_context.is_none()); + + let scratch = &thread_data.scratch; + // Assert the properties of the scratch instance + assert!(scratch.pq_scratch.is_some()); + assert!(scratch.id_scratch.is_empty()); + assert!(scratch.best_candidates.size() == 0); + assert!(scratch.full_return_set.is_empty()); + } + + #[test] + fn test_clear() { + // Arrange + let mut thread_data = SSDThreadData::::new(10, 100, true).unwrap(); + + // Add some data to scratch fields + thread_data.scratch.id_scratch.insert(1); + thread_data + .scratch + .best_candidates + .insert(Neighbor::new(2, 0.5)); + thread_data + .scratch + .full_return_set + .push(Neighbor::new(3, 0.8)); + + // Act + thread_data.clear(); + + // Assert + assert!(thread_data.scratch.id_scratch.is_empty()); + assert!(thread_data.scratch.best_candidates.size() == 0); + assert!(thread_data.scratch.full_return_set.is_empty()); + } +} + diff --git a/rust/diskann/src/model/vertex/dimension.rs b/rust/diskann/src/model/vertex/dimension.rs new file mode 100644 index 000000000..32670a8db --- /dev/null +++ b/rust/diskann/src/model/vertex/dimension.rs @@ -0,0 +1,22 @@ +/* + * 
Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Vertex dimension + +/// 32 vertex dimension +pub const DIM_32: usize = 32; + +/// 64 vertex dimension +pub const DIM_64: usize = 64; + +/// 104 vertex dimension +pub const DIM_104: usize = 104; + +/// 128 vertex dimension +pub const DIM_128: usize = 128; + +/// 256 vertex dimension +pub const DIM_256: usize = 256; diff --git a/rust/diskann/src/model/vertex/mod.rs b/rust/diskann/src/model/vertex/mod.rs new file mode 100644 index 000000000..224d476dc --- /dev/null +++ b/rust/diskann/src/model/vertex/mod.rs @@ -0,0 +1,10 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#[allow(clippy::module_inception)] +mod vertex; +pub use vertex::Vertex; + +mod dimension; +pub use dimension::*; diff --git a/rust/diskann/src/model/vertex/vertex.rs b/rust/diskann/src/model/vertex/vertex.rs new file mode 100644 index 000000000..55369748e --- /dev/null +++ b/rust/diskann/src/model/vertex/vertex.rs @@ -0,0 +1,68 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Vertex + +use std::array::TryFromSliceError; + +use vector::{FullPrecisionDistance, Metric}; + +/// Vertex with data type T and dimension N +#[derive(Debug)] +pub struct Vertex<'a, T, const N: usize> +where + [T; N]: FullPrecisionDistance, +{ + /// Vertex value + val: &'a [T; N], + + /// Vertex Id + id: u32, +} + +impl<'a, T, const N: usize> Vertex<'a, T, N> +where + [T; N]: FullPrecisionDistance, +{ + /// Create the vertex with data + pub fn new(val: &'a [T; N], id: u32) -> Self { + Self { + val, + id, + } + } + + /// Compare the vertex with another. 
+ #[inline(always)] + pub fn compare(&self, other: &Vertex<'a, T, N>, metric: Metric) -> f32 { + <[T; N]>::distance_compare(self.val, other.val, metric) + } + + /// Get the vector associated with the vertex. + #[inline] + pub fn vector(&self) -> &[T; N] { + self.val + } + + /// Get the vertex id. + #[inline] + pub fn vertex_id(&self) -> u32 { + self.id + } +} + +impl<'a, T, const N: usize> TryFrom<(&'a [T], u32)> for Vertex<'a, T, N> +where + [T; N]: FullPrecisionDistance, +{ + type Error = TryFromSliceError; + + fn try_from((mem_slice, id): (&'a [T], u32)) -> Result { + let array: &[T; N] = mem_slice.try_into()?; + Ok(Vertex::new(array, id)) + } +} + diff --git a/rust/diskann/src/model/windows_aligned_file_reader/mod.rs b/rust/diskann/src/model/windows_aligned_file_reader/mod.rs new file mode 100644 index 000000000..0e63df0a6 --- /dev/null +++ b/rust/diskann/src/model/windows_aligned_file_reader/mod.rs @@ -0,0 +1,7 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#[allow(clippy::module_inception)] +mod windows_aligned_file_reader; +pub use windows_aligned_file_reader::*; diff --git a/rust/diskann/src/model/windows_aligned_file_reader/windows_aligned_file_reader.rs b/rust/diskann/src/model/windows_aligned_file_reader/windows_aligned_file_reader.rs new file mode 100644 index 000000000..1cc3dc032 --- /dev/null +++ b/rust/diskann/src/model/windows_aligned_file_reader/windows_aligned_file_reader.rs @@ -0,0 +1,414 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use std::sync::Arc; +use std::time::Duration; +use std::{ptr, thread}; + +use crossbeam::sync::ShardedLock; +use hashbrown::HashMap; +use once_cell::sync::Lazy; + +use platform::file_handle::{AccessMode, ShareMode}; +use platform::{ + file_handle::FileHandle, + file_io::{get_queued_completion_status, read_file_to_slice}, + io_completion_port::IOCompletionPort, +}; + +use winapi::{ + shared::{basetsd::ULONG_PTR, minwindef::DWORD}, + um::minwinbase::OVERLAPPED, +}; + +use crate::common::{ANNError, ANNResult}; +use crate::model::IOContext; + +pub const MAX_IO_CONCURRENCY: usize = 128; // To do: explore the optimal value for this. The current value is taken from C++ code. +pub const FILE_ATTRIBUTE_READONLY: DWORD = 0x00000001; +pub const IO_COMPLETION_TIMEOUT: DWORD = u32::MAX; // Infinite timeout. +pub const DISK_IO_ALIGNMENT: usize = 512; +pub const ASYNC_IO_COMPLETION_CHECK_INTERVAL: Duration = Duration::from_micros(5); + +/// Aligned read struct for disk IO, it takes the ownership of the AlignedBoxedSlice and returns the AlignedBoxWithSlice data immutably. 
+pub struct AlignedRead<'a, T> { + /// where to read from + /// offset needs to be aligned with DISK_IO_ALIGNMENT + offset: u64, + + /// where to read into + /// aligned_buf and its len need to be aligned with DISK_IO_ALIGNMENT + aligned_buf: &'a mut [T], +} + +impl<'a, T> AlignedRead<'a, T> { + pub fn new(offset: u64, aligned_buf: &'a mut [T]) -> ANNResult { + Self::assert_is_aligned(offset as usize)?; + Self::assert_is_aligned(std::mem::size_of_val(aligned_buf))?; + + Ok(Self { + offset, + aligned_buf, + }) + } + + fn assert_is_aligned(val: usize) -> ANNResult<()> { + match val % DISK_IO_ALIGNMENT { + 0 => Ok(()), + _ => Err(ANNError::log_disk_io_request_alignment_error(format!( + "The offset or length of AlignedRead request is not {} bytes aligned", + DISK_IO_ALIGNMENT + ))), + } + } + + pub fn aligned_buf(&self) -> &[T] { + self.aligned_buf + } +} + +pub struct WindowsAlignedFileReader { + file_name: String, + + // ctx_map is the mapping from thread id to io context. It is hashmap behind a sharded lock to allow concurrent access from multiple threads. + // ShardedLock: shardedlock provides an implementation of a reader-writer lock that offers concurrent read access to the shared data while allowing exclusive write access. + // It achieves better scalability by dividing the shared data into multiple shards, and each with its own internal lock. + // Multiple threads can read from different shards simultaneously, reducing contention. + // https://docs.rs/crossbeam/0.8.2/crossbeam/sync/struct.ShardedLock.html + // Comparing to RwLock, ShardedLock provides higher concurrency for read operations and is suitable for read heavy workloads. + // The value of the hashmap is an Arc to allow immutable access to IOContext with automatic reference counting. 
+ ctx_map: Lazy>>>, +} + +impl WindowsAlignedFileReader { + pub fn new(fname: &str) -> ANNResult { + let reader: WindowsAlignedFileReader = WindowsAlignedFileReader { + file_name: fname.to_string(), + ctx_map: Lazy::new(|| ShardedLock::new(HashMap::new())), + }; + + reader.register_thread()?; + Ok(reader) + } + + // Register the io context for a thread if it hasn't been registered. + pub fn register_thread(&self) -> ANNResult<()> { + let mut ctx_map = self.ctx_map.write().map_err(|_| { + ANNError::log_lock_poison_error("unable to acquire read lock on ctx_map".to_string()) + })?; + + let id = thread::current().id(); + if ctx_map.contains_key(&id) { + println!( + "Warning:: Duplicate registration for thread_id : {:?}. Directly call get_ctx to get the thread context data.", + id); + + return Ok(()); + } + + let mut ctx = IOContext::new(); + + match unsafe { FileHandle::new(&self.file_name, AccessMode::Read, ShareMode::Read) } { + Ok(file_handle) => ctx.file_handle = file_handle, + Err(err) => { + return Err(ANNError::log_io_error(err)); + } + } + + // Create a io completion port for the file handle, later it will be used to get the completion status. + match IOCompletionPort::new(&ctx.file_handle, None, 0, 0) { + Ok(io_completion_port) => ctx.io_completion_port = io_completion_port, + Err(err) => { + return Err(ANNError::log_io_error(err)); + } + } + + ctx_map.insert(id, Arc::new(ctx)); + + Ok(()) + } + + // Get the reference counted io context for the current thread. + pub fn get_ctx(&self) -> ANNResult> { + let ctx_map = self.ctx_map.read().map_err(|_| { + ANNError::log_lock_poison_error("unable to acquire read lock on ctx_map".to_string()) + })?; + + let id = thread::current().id(); + match ctx_map.get(&id) { + Some(ctx) => Ok(Arc::clone(ctx)), + None => Err(ANNError::log_index_error(format!( + "unable to find IOContext for thread_id {:?}", + id + ))), + } + } + + // Read the data from the file by sending concurrent io requests in batches. 
+ pub fn read(&self, read_requests: &mut [AlignedRead], ctx: &IOContext) -> ANNResult<()> { + let n_requests = read_requests.len(); + let n_batches = (n_requests + MAX_IO_CONCURRENCY - 1) / MAX_IO_CONCURRENCY; + + let mut overlapped_in_out = + vec![unsafe { std::mem::zeroed::() }; MAX_IO_CONCURRENCY]; + + for batch_idx in 0..n_batches { + let batch_start = MAX_IO_CONCURRENCY * batch_idx; + let batch_size = std::cmp::min(n_requests - batch_start, MAX_IO_CONCURRENCY); + + for j in 0..batch_size { + let req = &mut read_requests[batch_start + j]; + let os = &mut overlapped_in_out[j]; + + match unsafe { + read_file_to_slice(&ctx.file_handle, req.aligned_buf, os, req.offset) + } { + Ok(_) => {} + Err(error) => { + return Err(ANNError::IOError { err: (error) }); + } + } + } + + let mut n_read: DWORD = 0; + let mut n_complete: u64 = 0; + let mut completion_key: ULONG_PTR = 0; + let mut lp_os: *mut OVERLAPPED = ptr::null_mut(); + while n_complete < batch_size as u64 { + match unsafe { + get_queued_completion_status( + &ctx.io_completion_port, + &mut n_read, + &mut completion_key, + &mut lp_os, + IO_COMPLETION_TIMEOUT, + ) + } { + // An IO request completed. + Ok(true) => n_complete += 1, + // No IO request completed, continue to wait. + Ok(false) => { + thread::sleep(ASYNC_IO_COMPLETION_CHECK_INTERVAL); + } + // An error ocurred. 
+ Err(error) => return Err(ANNError::IOError { err: (error) }), + } + } + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::{fs::File, io::BufReader}; + + use bincode::deserialize_from; + use serde::{Deserialize, Serialize}; + + use crate::{common::AlignedBoxWithSlice, model::SECTOR_LEN}; + + use super::*; + pub const TEST_INDEX_PATH: &str = + "./tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_alligned_reader_test.index"; + pub const TRUTH_NODE_DATA_PATH: &str = + "./tests/data/disk_index_node_data_aligned_reader_truth.bin"; + + #[derive(Debug, Serialize, Deserialize)] + struct NodeData { + num_neighbors: u32, + coordinates: Vec, + neighbors: Vec, + } + + impl PartialEq for NodeData { + fn eq(&self, other: &Self) -> bool { + self.num_neighbors == other.num_neighbors + && self.coordinates == other.coordinates + && self.neighbors == other.neighbors + } + } + + #[test] + fn test_new_aligned_file_reader() { + // Replace "test_file_path" with actual file path + let result = WindowsAlignedFileReader::new(TEST_INDEX_PATH); + assert!(result.is_ok()); + + let reader = result.unwrap(); + assert_eq!(reader.file_name, TEST_INDEX_PATH); + } + + #[test] + fn test_read() { + let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap(); + let ctx = reader.get_ctx().unwrap(); + + let read_length = 512; // adjust according to your logic + let num_read = 10; + let mut aligned_mem = AlignedBoxWithSlice::::new(read_length * num_read, 512).unwrap(); + + // create and add AlignedReads to the vector + let mut mem_slices = aligned_mem + .split_into_nonoverlapping_mut_slices(0..aligned_mem.len(), read_length) + .unwrap(); + + let mut aligned_reads: Vec> = mem_slices + .iter_mut() + .enumerate() + .map(|(i, slice)| { + let offset = (i * read_length) as u64; + AlignedRead::new(offset, slice).unwrap() + }) + .collect(); + + let result = reader.read(&mut aligned_reads, &ctx); + assert!(result.is_ok()); + } + + #[test] + fn test_read_disk_index_by_sector() { + let 
reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap(); + let ctx = reader.get_ctx().unwrap(); + + let read_length = SECTOR_LEN; // adjust according to your logic + let num_sector = 10; + let mut aligned_mem = + AlignedBoxWithSlice::::new(read_length * num_sector, 512).unwrap(); + + // Each slice will be used as the buffer for a read request of a sector. + let mut mem_slices = aligned_mem + .split_into_nonoverlapping_mut_slices(0..aligned_mem.len(), read_length) + .unwrap(); + + let mut aligned_reads: Vec> = mem_slices + .iter_mut() + .enumerate() + .map(|(sector_id, slice)| { + let offset = (sector_id * read_length) as u64; + AlignedRead::new(offset, slice).unwrap() + }) + .collect(); + + let result = reader.read(&mut aligned_reads, &ctx); + assert!(result.is_ok()); + + aligned_reads.iter().for_each(|read| { + assert_eq!(read.aligned_buf.len(), SECTOR_LEN); + }); + + let disk_layout_meta = reconstruct_disk_meta(aligned_reads[0].aligned_buf); + assert!(disk_layout_meta.len() > 9); + + let dims = disk_layout_meta[1]; + let num_pts = disk_layout_meta[0]; + let max_node_len = disk_layout_meta[3]; + let max_num_nodes_per_sector = disk_layout_meta[4]; + + assert!(max_node_len * max_num_nodes_per_sector < SECTOR_LEN as u64); + + let num_nbrs_start = (dims as usize) * std::mem::size_of::(); + let nbrs_buf_start = num_nbrs_start + std::mem::size_of::(); + + let mut node_data_array = Vec::with_capacity(max_num_nodes_per_sector as usize * 9); + + // Only validate the first 9 sectors with graph nodes. 
+ (1..9).for_each(|sector_id| { + let sector_data = &mem_slices[sector_id]; + for node_data in sector_data.chunks_exact(max_node_len as usize) { + // Extract coordinates data from the start of the node_data + let coordinates_end = (dims as usize) * std::mem::size_of::(); + let coordinates = node_data[0..coordinates_end] + .chunks_exact(std::mem::size_of::()) + .map(|chunk| f32::from_le_bytes(chunk.try_into().unwrap())) + .collect(); + + // Extract number of neighbors from the node_data + let neighbors_num = u32::from_le_bytes( + node_data[num_nbrs_start..nbrs_buf_start] + .try_into() + .unwrap(), + ); + + let nbors_buf_end = + nbrs_buf_start + (neighbors_num as usize) * std::mem::size_of::(); + + // Extract neighbors from the node data. + let mut neighbors = Vec::new(); + for nbors_data in node_data[nbrs_buf_start..nbors_buf_end] + .chunks_exact(std::mem::size_of::()) + { + let nbors_id = u32::from_le_bytes(nbors_data.try_into().unwrap()); + assert!(nbors_id < num_pts as u32); + neighbors.push(nbors_id); + } + + // Create NodeData struct and push it to the node_data_array + node_data_array.push(NodeData { + num_neighbors: neighbors_num, + coordinates, + neighbors, + }); + } + }); + + // Compare that each node read from the disk index are expected. 
+ let node_data_truth_file = File::open(TRUTH_NODE_DATA_PATH).unwrap(); + let reader = BufReader::new(node_data_truth_file); + + let node_data_vec: Vec = deserialize_from(reader).unwrap(); + for (node_from_node_data_file, node_from_disk_index) in + node_data_vec.iter().zip(node_data_array.iter()) + { + // Verify that the NodeData from the file is equal to the NodeData in node_data_array + assert_eq!(node_from_node_data_file, node_from_disk_index); + } + } + + #[test] + fn test_read_fail_invalid_file() { + let reader = WindowsAlignedFileReader::new("/invalid_path"); + assert!(reader.is_err()); + } + + #[test] + fn test_read_no_requests() { + let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap(); + let ctx = reader.get_ctx().unwrap(); + + let mut read_requests = Vec::>::new(); + let result = reader.read(&mut read_requests, &ctx); + assert!(result.is_ok()); + } + + #[test] + fn test_get_ctx() { + let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap(); + let result = reader.get_ctx(); + assert!(result.is_ok()); + } + + #[test] + fn test_register_thread() { + let reader = WindowsAlignedFileReader::new(TEST_INDEX_PATH).unwrap(); + let result = reader.register_thread(); + assert!(result.is_ok()); + } + + fn reconstruct_disk_meta(buffer: &[u8]) -> Vec { + let size_of_u64 = std::mem::size_of::(); + + let num_values = buffer.len() / size_of_u64; + let mut disk_layout_meta = Vec::with_capacity(num_values); + let meta_data = &buffer[8..]; + + for chunk in meta_data.chunks_exact(size_of_u64) { + let value = u64::from_le_bytes(chunk.try_into().unwrap()); + disk_layout_meta.push(value); + } + + disk_layout_meta + } +} diff --git a/rust/diskann/src/storage/disk_graph_storage.rs b/rust/diskann/src/storage/disk_graph_storage.rs new file mode 100644 index 000000000..448175212 --- /dev/null +++ b/rust/diskann/src/storage/disk_graph_storage.rs @@ -0,0 +1,37 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ * Licensed under the MIT license. + */ +#![warn(missing_docs)] + +//! Disk graph storage + +use std::sync::Arc; + +use crate::{model::{WindowsAlignedFileReader, IOContext, AlignedRead}, common::ANNResult}; + +/// Graph storage for disk index +/// One thread has one storage instance +pub struct DiskGraphStorage { + /// Disk graph reader + disk_graph_reader: Arc, + + /// IOContext of current thread + ctx: Arc, +} + +impl DiskGraphStorage { + /// Create a new DiskGraphStorage instance + pub fn new(disk_graph_reader: Arc) -> ANNResult { + let ctx = disk_graph_reader.get_ctx()?; + Ok(Self { + disk_graph_reader, + ctx, + }) + } + + /// Read disk graph data + pub fn read(&self, read_requests: &mut [AlignedRead]) -> ANNResult<()> { + self.disk_graph_reader.read(read_requests, &self.ctx) + } +} diff --git a/rust/diskann/src/storage/disk_index_storage.rs b/rust/diskann/src/storage/disk_index_storage.rs new file mode 100644 index 000000000..0c558084d --- /dev/null +++ b/rust/diskann/src/storage/disk_index_storage.rs @@ -0,0 +1,363 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use byteorder::{ByteOrder, LittleEndian, ReadBytesExt}; +use std::fs::File; +use std::io::Read; +use std::marker::PhantomData; +use std::{fs, mem}; + +use crate::common::{ANNError, ANNResult}; +use crate::model::NUM_PQ_CENTROIDS; +use crate::storage::PQStorage; +use crate::utils::{convert_types_u32_usize, convert_types_u64_usize, load_bin, save_bin_u64}; +use crate::utils::{ + file_exists, gen_sample_data, get_file_size, round_up, CachedReader, CachedWriter, +}; + +const SECTOR_LEN: usize = 4096; + +/// Todo: Remove the allow(dead_code) when the disk search code is complete +#[allow(dead_code)] +pub struct PQPivotData { + dim: usize, + pq_table: Vec, + centroids: Vec, + chunk_offsets: Vec, +} + +pub struct DiskIndexStorage { + /// Dataset file + dataset_file: String, + + /// Index file path prefix + index_path_prefix: String, + + // TODO: Only a placeholder for T, will be removed later + _marker: PhantomData, + + pq_storage: PQStorage, +} + +impl DiskIndexStorage { + /// Create DiskIndexStorage instance + pub fn new(dataset_file: String, index_path_prefix: String) -> ANNResult { + let pq_storage: PQStorage = PQStorage::new( + &(index_path_prefix.clone() + ".bin_pq_pivots.bin"), + &(index_path_prefix.clone() + ".bin_pq_compressed.bin"), + &dataset_file, + )?; + + Ok(DiskIndexStorage { + dataset_file, + index_path_prefix, + _marker: PhantomData, + pq_storage, + }) + } + + pub fn get_pq_storage(&mut self) -> &mut PQStorage { + &mut self.pq_storage + } + + pub fn dataset_file(&self) -> &String { + &self.dataset_file + } + + pub fn index_path_prefix(&self) -> &String { + &self.index_path_prefix + } + + /// Create disk layout + /// Sector #1: disk_layout_meta + /// Sector #n: num_nodes_per_sector nodes + /// Each node's layout: {full precision vector:[T; DIM]}{num_nbrs: u32}{neighbors: [u32; num_nbrs]} + /// # Arguments + /// * `dataset_file` - dataset file containing full precision vectors + /// * `mem_index_file` - in-memory index graph file + /// * 
`disk_layout_file` - output disk layout file + pub fn create_disk_layout(&self) -> ANNResult<()> { + let mem_index_file = self.mem_index_file(); + let disk_layout_file = self.disk_index_file(); + + // amount to read or write in one shot + let read_blk_size = 64 * 1024 * 1024; + let write_blk_size = read_blk_size; + let mut dataset_reader = CachedReader::new(self.dataset_file.as_str(), read_blk_size)?; + + let num_pts = dataset_reader.read_u32()? as u64; + let dims = dataset_reader.read_u32()? as u64; + + // Create cached reader + writer + let actual_file_size = get_file_size(mem_index_file.as_str())?; + println!("Vamana index file size={}", actual_file_size); + + let mut vamana_reader = File::open(mem_index_file)?; + let mut diskann_writer = CachedWriter::new(disk_layout_file.as_str(), write_blk_size)?; + + let index_file_size = vamana_reader.read_u64::()?; + if index_file_size != actual_file_size { + println!( + "Vamana Index file size does not match expected size per meta-data. file size from file: {}, actual file size: {}", + index_file_size, actual_file_size + ); + } + + let max_degree = vamana_reader.read_u32::()?; + let medoid = vamana_reader.read_u32::()?; + let vamana_frozen_num = vamana_reader.read_u64::()?; + + let mut vamana_frozen_loc = 0; + if vamana_frozen_num == 1 { + vamana_frozen_loc = medoid; + } + + let max_node_len = ((max_degree as u64 + 1) * (mem::size_of::() as u64)) + + (dims * (mem::size_of::() as u64)); + let num_nodes_per_sector = (SECTOR_LEN as u64) / max_node_len; + + println!("medoid: {}B", medoid); + println!("max_node_len: {}B", max_node_len); + println!("num_nodes_per_sector: {}B", num_nodes_per_sector); + + // SECTOR_LEN buffer for each sector + let mut sector_buf = vec![0u8; SECTOR_LEN]; + let mut node_buf = vec![0u8; max_node_len as usize]; + + let num_nbrs_start = (dims as usize) * mem::size_of::(); + let nbrs_buf_start = num_nbrs_start + mem::size_of::(); + + // number of sectors (1 for meta data) + let num_sectors = 
round_up(num_pts, num_nodes_per_sector) / num_nodes_per_sector; + let disk_index_file_size = (num_sectors + 1) * (SECTOR_LEN as u64); + + let disk_layout_meta = vec![ + num_pts, + dims, + medoid as u64, + max_node_len, + num_nodes_per_sector, + vamana_frozen_num, + vamana_frozen_loc as u64, + // append_reorder_data + // We are not supporting this. Temporarily write it into the layout so that + // we can leverage C++ query driver to test the disk index + false as u64, + disk_index_file_size, + ]; + + diskann_writer.write(§or_buf)?; + + let mut cur_node_coords = vec![0u8; (dims as usize) * mem::size_of::()]; + let mut cur_node_id = 0u64; + + for sector in 0..num_sectors { + if sector % 100_000 == 0 { + println!("Sector #{} written", sector); + } + sector_buf.fill(0); + + for sector_node_id in 0..num_nodes_per_sector { + if cur_node_id >= num_pts { + break; + } + + node_buf.fill(0); + + // read cur node's num_nbrs + let num_nbrs = vamana_reader.read_u32::()?; + + // sanity checks on num_nbrs + debug_assert!(num_nbrs > 0); + debug_assert!(num_nbrs <= max_degree); + + // write coords of node first + dataset_reader.read(&mut cur_node_coords)?; + node_buf[..cur_node_coords.len()].copy_from_slice(&cur_node_coords); + + // write num_nbrs + LittleEndian::write_u32( + &mut node_buf[num_nbrs_start..(num_nbrs_start + mem::size_of::())], + num_nbrs, + ); + + // write neighbors + let nbrs_buf = &mut node_buf[nbrs_buf_start + ..(nbrs_buf_start + (num_nbrs as usize) * mem::size_of::())]; + vamana_reader.read_exact(nbrs_buf)?; + + // get offset into sector_buf + let sector_node_buf_start = (sector_node_id * max_node_len) as usize; + let sector_node_buf = &mut sector_buf + [sector_node_buf_start..(sector_node_buf_start + max_node_len as usize)]; + sector_node_buf.copy_from_slice(&node_buf[..(max_node_len as usize)]); + + cur_node_id += 1; + } + + // flush sector to disk + diskann_writer.write(§or_buf)?; + } + + diskann_writer.flush()?; + save_bin_u64( + disk_layout_file.as_str(), + 
&disk_layout_meta, + disk_layout_meta.len(), + 1, + 0, + )?; + + Ok(()) + } + + pub fn index_build_cleanup(&self) -> ANNResult<()> { + fs::remove_file(self.mem_index_file())?; + Ok(()) + } + + pub fn gen_query_warmup_data(&self, sampling_rate: f64) -> ANNResult<()> { + gen_sample_data::( + &self.dataset_file, + &self.warmup_query_prefix(), + sampling_rate, + )?; + Ok(()) + } + + /// Load pre-trained pivot table + pub fn load_pq_pivots_bin( + &self, + num_pq_chunks: &usize, + ) -> ANNResult { + let pq_pivots_path = &self.pq_pivot_file(); + if !file_exists(pq_pivots_path) { + return Err(ANNError::log_pq_error( + "ERROR: PQ k-means pivot file not found.".to_string(), + )); + } + + let (data, offset_num, offset_dim) = load_bin::(pq_pivots_path, 0)?; + let file_offset_data = convert_types_u64_usize(&data, offset_num, offset_dim); + if offset_num != 4 { + let error_message = format!("Error reading pq_pivots file {}. Offsets don't contain correct metadata, # offsets = {}, but expecting 4.", pq_pivots_path, offset_num); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, pivot_num, dim) = load_bin::(pq_pivots_path, file_offset_data[0])?; + let pq_table = data.to_vec(); + if pivot_num != NUM_PQ_CENTROIDS { + let error_message = format!( + "Error reading pq_pivots file {}. file_num_centers = {}, but expecting {} centers.", + pq_pivots_path, pivot_num, NUM_PQ_CENTROIDS + ); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, centroid_dim, nc) = load_bin::(pq_pivots_path, file_offset_data[1])?; + let centroids = data.to_vec(); + if centroid_dim != dim || nc != 1 { + let error_message = format!("Error reading pq_pivots file {}. 
file_dim = {}, file_cols = {} but expecting {} entries in 1 dimension.", pq_pivots_path, centroid_dim, nc, dim); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, chunk_offset_num, nc) = load_bin::(pq_pivots_path, file_offset_data[2])?; + let chunk_offsets = convert_types_u32_usize(&data, chunk_offset_num, nc); + if chunk_offset_num != num_pq_chunks + 1 || nc != 1 { + let error_message = format!("Error reading pq_pivots file at chunk offsets; file has nr={}, nc={} but expecting nr={} and nc=1.", chunk_offset_num, nc, num_pq_chunks + 1); + return Err(ANNError::log_pq_error(error_message)); + } + + Ok(PQPivotData { + dim, + pq_table, + centroids, + chunk_offsets + }) + } + + fn mem_index_file(&self) -> String { + self.index_path_prefix.clone() + "_mem.index" + } + + fn disk_index_file(&self) -> String { + self.index_path_prefix.clone() + "_disk.index" + } + + fn warmup_query_prefix(&self) -> String { + self.index_path_prefix.clone() + "_sample" + } + + pub fn pq_pivot_file(&self) -> String { + self.index_path_prefix.clone() + ".bin_pq_pivots.bin" + } + + pub fn compressed_pq_pivot_file(&self) -> String { + self.index_path_prefix.clone() + ".bin_pq_compressed.bin" + } +} + +#[cfg(test)] +mod disk_index_storage_test { + use std::fs; + + use crate::test_utils::get_test_file_path; + + use super::*; + + const TEST_DATA_FILE: &str = "tests/data/siftsmall_learn_256pts.fbin"; + const DISK_INDEX_PATH_PREFIX: &str = "tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2"; + const TRUTH_DISK_LAYOUT: &str = + "tests/data/truth_disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index"; + + #[test] + fn create_disk_layout_test() { + let storage = DiskIndexStorage::::new( + get_test_file_path(TEST_DATA_FILE), + get_test_file_path(DISK_INDEX_PATH_PREFIX), + ).unwrap(); + storage.create_disk_layout().unwrap(); + + let disk_layout_file = storage.disk_index_file(); + let rust_disk_layout = fs::read(disk_layout_file.as_str()).unwrap(); + let truth_disk_layout = 
fs::read(get_test_file_path(TRUTH_DISK_LAYOUT).as_str()).unwrap(); + + assert!(rust_disk_layout == truth_disk_layout); + + fs::remove_file(disk_layout_file.as_str()).expect("Failed to delete file"); + } + + #[test] + fn load_pivot_test() { + let dim: usize = 128; + let num_pq_chunk: usize = 1; + let pivot_file_prefix: &str = "tests/data/siftsmall_learn"; + let storage = DiskIndexStorage::::new( + get_test_file_path(TEST_DATA_FILE), + pivot_file_prefix.to_string(), + ).unwrap(); + + let pq_pivot_data = + storage.load_pq_pivots_bin(&num_pq_chunk).unwrap(); + + assert_eq!(pq_pivot_data.pq_table.len(), NUM_PQ_CENTROIDS * dim); + assert_eq!(pq_pivot_data.centroids.len(), dim); + + assert_eq!(pq_pivot_data.chunk_offsets[0], 0); + assert_eq!(pq_pivot_data.chunk_offsets[1], dim); + assert_eq!(pq_pivot_data.chunk_offsets.len(), num_pq_chunk + 1); + } + + #[test] + #[should_panic(expected = "ERROR: PQ k-means pivot file not found.")] + fn load_pivot_file_not_exist_test() { + let num_pq_chunk: usize = 1; + let pivot_file_prefix: &str = "tests/data/siftsmall_learn_file_not_exist"; + let storage = DiskIndexStorage::::new( + get_test_file_path(TEST_DATA_FILE), + pivot_file_prefix.to_string(), + ).unwrap(); + let _ = storage.load_pq_pivots_bin(&num_pq_chunk).unwrap(); + } +} diff --git a/rust/diskann/src/storage/mod.rs b/rust/diskann/src/storage/mod.rs new file mode 100644 index 000000000..03c5b8e82 --- /dev/null +++ b/rust/diskann/src/storage/mod.rs @@ -0,0 +1,12 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +mod disk_index_storage; +pub use disk_index_storage::*; + +mod disk_graph_storage; +pub use disk_graph_storage::*; + +mod pq_storage; +pub use pq_storage::*; diff --git a/rust/diskann/src/storage/pq_storage.rs b/rust/diskann/src/storage/pq_storage.rs new file mode 100644 index 000000000..b1d3fa05a --- /dev/null +++ b/rust/diskann/src/storage/pq_storage.rs @@ -0,0 +1,367 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use byteorder::{LittleEndian, ReadBytesExt}; +use rand::distributions::{Distribution, Uniform}; +use std::fs::File; +use std::io::{Read, Seek, SeekFrom, Write}; +use std::mem; + +use crate::common::{ANNError, ANNResult}; +use crate::utils::CachedReader; +use crate::utils::{ + convert_types_u32_usize, convert_types_u64_usize, convert_types_usize_u32, + convert_types_usize_u64, convert_types_usize_u8, save_bin_f32, save_bin_u32, save_bin_u64, +}; +use crate::utils::{file_exists, load_bin, open_file_to_write, METADATA_SIZE}; + +#[derive(Debug)] +pub struct PQStorage { + /// Pivot table path + pivot_file: String, + + /// Compressed pivot path + compressed_pivot_file: String, + + /// Data used to construct PQ table and PQ compressed table + pq_data_file: String, + + /// PQ data reader + pq_data_file_reader: File, +} + +impl PQStorage { + pub fn new( + pivot_file: &str, + compressed_pivot_file: &str, + pq_data_file: &str, + ) -> std::io::Result { + let pq_data_file_reader = File::open(pq_data_file)?; + Ok(Self { + pivot_file: pivot_file.to_string(), + compressed_pivot_file: compressed_pivot_file.to_string(), + pq_data_file: pq_data_file.to_string(), + pq_data_file_reader, + }) + } + + pub fn write_compressed_pivot_metadata(&self, npts: i32, pq_chunk: i32) -> std::io::Result<()> { + let mut writer = open_file_to_write(&self.compressed_pivot_file)?; + writer.write_all(&npts.to_le_bytes())?; + writer.write_all(&pq_chunk.to_le_bytes())?; + Ok(()) + } + + pub fn write_compressed_pivot_data( + 
&self, + compressed_base: &[usize], + num_centers: usize, + block_size: usize, + num_pq_chunks: usize, + ) -> std::io::Result<()> { + let mut writer = open_file_to_write(&self.compressed_pivot_file)?; + writer.seek(SeekFrom::Start((std::mem::size_of::() * 2) as u64))?; + if num_centers > 256 { + writer.write_all(unsafe { + std::slice::from_raw_parts( + compressed_base.as_ptr() as *const u8, + block_size * num_pq_chunks * std::mem::size_of::(), + ) + })?; + } else { + let compressed_base_u8 = + convert_types_usize_u8(compressed_base, block_size, num_pq_chunks); + writer.write_all(&compressed_base_u8)?; + } + Ok(()) + } + + pub fn write_pivot_data( + &self, + full_pivot_data: &[f32], + centroid: &[f32], + chunk_offsets: &[usize], + num_centers: usize, + dim: usize, + ) -> std::io::Result<()> { + let mut cumul_bytes: Vec = vec![0; 4]; + cumul_bytes[0] = METADATA_SIZE; + cumul_bytes[1] = cumul_bytes[0] + + save_bin_f32( + &self.pivot_file, + full_pivot_data, + num_centers, + dim, + cumul_bytes[0], + )?; + cumul_bytes[2] = + cumul_bytes[1] + save_bin_f32(&self.pivot_file, centroid, dim, 1, cumul_bytes[1])?; + + // Because the writer only can write u32, u64 but not usize, so we need to convert the type first. 
+ let chunk_offsets_u64 = convert_types_usize_u32(chunk_offsets, chunk_offsets.len(), 1); + cumul_bytes[3] = cumul_bytes[2] + + save_bin_u32( + &self.pivot_file, + &chunk_offsets_u64, + chunk_offsets.len(), + 1, + cumul_bytes[2], + )?; + + let cumul_bytes_u64 = convert_types_usize_u64(&cumul_bytes, 4, 1); + save_bin_u64(&self.pivot_file, &cumul_bytes_u64, cumul_bytes.len(), 1, 0)?; + + Ok(()) + } + + pub fn pivot_data_exist(&self) -> bool { + file_exists(&self.pivot_file) + } + + pub fn read_pivot_metadata(&self) -> std::io::Result<(usize, usize)> { + let (_, file_num_centers, file_dim) = load_bin::(&self.pivot_file, METADATA_SIZE)?; + Ok((file_num_centers, file_dim)) + } + + pub fn load_pivot_data( + &self, + num_pq_chunks: &usize, + num_centers: &usize, + dim: &usize, + ) -> ANNResult<(Vec, Vec, Vec)> { + // Load file offset data. File saved as offset data(4*1) -> pivot data(centroid num*dim) -> centroid of dim data(dim*1) -> chunk offset data(chunksize+1*1) + // Because we only can write u64 rather than usize, so the file stored as u64 type. Need to convert to usize when use. + let (data, offset_num, nc) = load_bin::(&self.pivot_file, 0)?; + let file_offset_data = convert_types_u64_usize(&data, offset_num, nc); + if offset_num != 4 { + let error_message = format!("Error reading pq_pivots file {}. Offsets don't contain correct metadata, # offsets = {}, but expecting 4.", &self.pivot_file, offset_num); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, pivot_num, pivot_dim) = load_bin::(&self.pivot_file, file_offset_data[0])?; + let full_pivot_data = data; + if pivot_num != *num_centers || pivot_dim != *dim { + let error_message = format!("Error reading pq_pivots file {}. 
file_num_centers = {}, file_dim = {} but expecting {} centers in {} dimensions.", &self.pivot_file, pivot_num, pivot_dim, num_centers, dim); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, centroid_dim, nc) = load_bin::(&self.pivot_file, file_offset_data[1])?; + let centroid = data; + if centroid_dim != *dim || nc != 1 { + let error_message = format!("Error reading pq_pivots file {}. file_dim = {}, file_cols = {} but expecting {} entries in 1 dimension.", &self.pivot_file, centroid_dim, nc, dim); + return Err(ANNError::log_pq_error(error_message)); + } + + let (data, chunk_offset_number, nc) = + load_bin::(&self.pivot_file, file_offset_data[2])?; + let chunk_offsets = convert_types_u32_usize(&data, chunk_offset_number, nc); + if chunk_offset_number != *num_pq_chunks + 1 || nc != 1 { + let error_message = format!("Error reading pq_pivots file at chunk offsets; file has nr={}, nc={} but expecting nr={} and nc=1.", chunk_offset_number, nc, num_pq_chunks + 1); + return Err(ANNError::log_pq_error(error_message)); + } + Ok((full_pivot_data, centroid, chunk_offsets)) + } + + pub fn read_pq_data_metadata(&mut self) -> std::io::Result<(usize, usize)> { + let npts_i32 = self.pq_data_file_reader.read_i32::()?; + let dim_i32 = self.pq_data_file_reader.read_i32::()?; + let num_points = npts_i32 as usize; + let dim = dim_i32 as usize; + Ok((num_points, dim)) + } + + pub fn read_pq_block_data( + &mut self, + cur_block_size: usize, + dim: usize, + ) -> std::io::Result> { + let mut buf = vec![0u8; cur_block_size * dim * std::mem::size_of::()]; + self.pq_data_file_reader.read_exact(&mut buf)?; + + let ptr = buf.as_ptr() as *const T; + let block_data = unsafe { std::slice::from_raw_parts(ptr, cur_block_size * dim) }; + Ok(block_data.to_vec()) + } + + /// streams data from the file, and samples each vector with probability p_val + /// and returns a matrix of size slice_size* ndims as floating point type. 
+ /// the slice_size and ndims are set inside the function. + /// # Arguments + /// * `file_name` - filename where the data is + /// * `p_val` - possibility to sample data + /// * `sampled_vectors` - sampled vector chose by p_val possibility + /// * `slice_size` - how many sampled data return + /// * `dim` - each sample data dimension + pub fn gen_random_slice>( + &self, + mut p_val: f64, + ) -> ANNResult<(Vec, usize, usize)> { + let read_blk_size = 64 * 1024 * 1024; + let mut reader = CachedReader::new(&self.pq_data_file, read_blk_size)?; + + let npts = reader.read_u32()? as usize; + let dim = reader.read_u32()? as usize; + let mut sampled_vectors: Vec = Vec::new(); + let mut slice_size = 0; + p_val = if p_val < 1f64 { p_val } else { 1f64 }; + + let mut generator = rand::thread_rng(); + let distribution = Uniform::from(0.0..1.0); + + for _ in 0..npts { + let mut cur_vector_bytes = vec![0u8; dim * mem::size_of::()]; + reader.read(&mut cur_vector_bytes)?; + let random_value = distribution.sample(&mut generator); + if random_value < p_val { + let ptr = cur_vector_bytes.as_ptr() as *const T; + let cur_vector_t = unsafe { std::slice::from_raw_parts(ptr, dim) }; + sampled_vectors.extend(cur_vector_t.iter().map(|&t| t.into())); + slice_size += 1; + } + } + + Ok((sampled_vectors, slice_size, dim)) + } +} + +#[cfg(test)] +mod pq_storage_tests { + use rand::Rng; + + use super::*; + use crate::utils::gen_random_slice; + + const DATA_FILE: &str = "tests/data/siftsmall_learn.bin"; + const PQ_PIVOT_PATH: &str = "tests/data/siftsmall_learn.bin_pq_pivots.bin"; + const PQ_COMPRESSED_PATH: &str = "tests/data/empty_pq_compressed.bin"; + + #[test] + fn new_test() { + let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE); + assert!(result.is_ok()); + } + + #[test] + fn write_compressed_pivot_metadata_test() { + let compress_pivot_path = "write_compressed_pivot_metadata_test.bin"; + let result = PQStorage::new(PQ_PIVOT_PATH, compress_pivot_path, DATA_FILE).unwrap(); 
+ + _ = result.write_compressed_pivot_metadata(100, 20); + let mut result_reader = File::open(compress_pivot_path).unwrap(); + let npts_i32 = result_reader.read_i32::().unwrap(); + let dim_i32 = result_reader.read_i32::().unwrap(); + + assert_eq!(npts_i32, 100); + assert_eq!(dim_i32, 20); + + std::fs::remove_file(compress_pivot_path).unwrap(); + } + + #[test] + fn write_compressed_pivot_data_test() { + let compress_pivot_path = "write_compressed_pivot_data_test.bin"; + let result = PQStorage::new(PQ_PIVOT_PATH, compress_pivot_path, DATA_FILE).unwrap(); + + let mut rng = rand::thread_rng(); + + let num_centers = 256; + let block_size = 4; + let num_pq_chunks = 2; + let compressed_base: Vec = (0..block_size * num_pq_chunks) + .map(|_| rng.gen_range(0..num_centers)) + .collect(); + _ = result.write_compressed_pivot_data( + &compressed_base, + num_centers, + block_size, + num_pq_chunks, + ); + + let mut result_reader = File::open(compress_pivot_path).unwrap(); + _ = result_reader.read_i32::().unwrap(); + _ = result_reader.read_i32::().unwrap(); + let mut buf = vec![0u8; block_size * num_pq_chunks * std::mem::size_of::()]; + result_reader.read_exact(&mut buf).unwrap(); + + let ptr = buf.as_ptr() as *const u8; + let block_data = unsafe { std::slice::from_raw_parts(ptr, block_size * num_pq_chunks) }; + + for index in 0..block_data.len() { + assert_eq!(compressed_base[index], block_data[index] as usize); + } + std::fs::remove_file(compress_pivot_path).unwrap(); + } + + #[test] + fn pivot_data_exist_test() { + let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap(); + assert!(result.pivot_data_exist()); + + let pivot_path = "not_exist_pivot_path.bin"; + let result = PQStorage::new(pivot_path, PQ_COMPRESSED_PATH, DATA_FILE).unwrap(); + assert!(!result.pivot_data_exist()); + } + + #[test] + fn read_pivot_metadata_test() { + let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap(); + let (npt, dim) = 
result.read_pivot_metadata().unwrap(); + + assert_eq!(npt, 256); + assert_eq!(dim, 128); + } + + #[test] + fn load_pivot_data_test() { + let result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap(); + let (pq_pivot_data, centroids, chunk_offsets) = + result.load_pivot_data(&1, &256, &128).unwrap(); + + assert_eq!(pq_pivot_data.len(), 256 * 128); + assert_eq!(centroids.len(), 128); + assert_eq!(chunk_offsets.len(), 2); + } + + #[test] + fn read_pq_data_metadata_test() { + let mut result = PQStorage::new(PQ_PIVOT_PATH, PQ_COMPRESSED_PATH, DATA_FILE).unwrap(); + let (npt, dim) = result.read_pq_data_metadata().unwrap(); + + assert_eq!(npt, 25000); + assert_eq!(dim, 128); + } + + #[test] + fn gen_random_slice_test() { + let file_name = "gen_random_slice_test.bin"; + //npoints=2, dim=8 + let data: [u8; 72] = [ + 2, 0, 0, 0, 8, 0, 0, 0, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, + 0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, + 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, + 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41, + 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41, + ]; + std::fs::write(file_name, data).expect("Failed to write sample file"); + + let (sampled_vectors, slice_size, ndims) = + gen_random_slice::(file_name, 1f64).unwrap(); + let mut start = 8; + (0..sampled_vectors.len()).for_each(|i| { + assert_eq!(sampled_vectors[i].to_le_bytes(), data[start..start + 4]); + start += 4; + }); + assert_eq!(sampled_vectors.len(), 16); + assert_eq!(slice_size, 2); + assert_eq!(ndims, 8); + + let (sampled_vectors, slice_size, ndims) = + gen_random_slice::(file_name, 0f64).unwrap(); + assert_eq!(sampled_vectors.len(), 0); + assert_eq!(slice_size, 0); + assert_eq!(ndims, 8); + + std::fs::remove_file(file_name).expect("Failed to delete file"); + } +} diff --git 
a/rust/diskann/src/test_utils/inmem_index_initialization.rs b/rust/diskann/src/test_utils/inmem_index_initialization.rs new file mode 100644 index 000000000..db3b58179 --- /dev/null +++ b/rust/diskann/src/test_utils/inmem_index_initialization.rs @@ -0,0 +1,74 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use vector::Metric; + +use crate::index::InmemIndex; +use crate::model::configuration::index_write_parameters::IndexWriteParametersBuilder; +use crate::model::{IndexConfiguration}; +use crate::model::vertex::DIM_128; +use crate::utils::{file_exists, load_metadata_from_file}; + +use super::get_test_file_path; + +// f32, 128 DIM and 256 points source data +const TEST_DATA_FILE: &str = "tests/data/siftsmall_learn_256pts.fbin"; +const NUM_POINTS_TO_LOAD: usize = 256; + +pub fn create_index_with_test_data() -> InmemIndex { + let index_write_parameters = IndexWriteParametersBuilder::new(50, 4).with_alpha(1.2).build(); + let config = IndexConfiguration::new( + Metric::L2, + 128, + 128, + 256, + false, + 0, + false, + 0, + 1.0f32, + index_write_parameters); + let mut index: InmemIndex = InmemIndex::new(config).unwrap(); + + build_test_index(&mut index, get_test_file_path(TEST_DATA_FILE).as_str(), NUM_POINTS_TO_LOAD); + + index.start = index.dataset.calculate_medoid_point_id().unwrap(); + + index +} + +fn build_test_index(index: &mut InmemIndex, filename: &str, num_points_to_load: usize) { + if !file_exists(filename) { + panic!("ERROR: Data file {} does not exist.", filename); + } + + let (file_num_points, file_dim) = load_metadata_from_file(filename).unwrap(); + if file_num_points > index.configuration.max_points { + panic!( + "ERROR: Driver requests loading {} points and file has {} points, + but index can support only {} points as specified in configuration.", + num_points_to_load, file_num_points, index.configuration.max_points + ); + } + + if num_points_to_load > file_num_points { + panic!( + "ERROR: 
Driver requests loading {} points and file has only {} points.", + num_points_to_load, file_num_points + ); + } + + if file_dim != index.configuration.dim { + panic!( + "ERROR: Driver requests loading {} dimension, but file has {} dimension.", + index.configuration.dim, file_dim + ); + } + + index.dataset.build_from_file(filename, num_points_to_load).unwrap(); + + println!("Using only first {} from file.", num_points_to_load); + + index.num_active_pts = num_points_to_load; +} diff --git a/rust/diskann/src/test_utils/mod.rs b/rust/diskann/src/test_utils/mod.rs new file mode 100644 index 000000000..fc8de5f30 --- /dev/null +++ b/rust/diskann/src/test_utils/mod.rs @@ -0,0 +1,11 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +pub mod inmem_index_initialization; + +/// test files should be placed under tests folder +pub fn get_test_file_path(relative_path: &str) -> String { + format!("{}/{}", env!("CARGO_MANIFEST_DIR"), relative_path) +} + diff --git a/rust/diskann/src/utils/bit_vec_extension.rs b/rust/diskann/src/utils/bit_vec_extension.rs new file mode 100644 index 000000000..9571a726e --- /dev/null +++ b/rust/diskann/src/utils/bit_vec_extension.rs @@ -0,0 +1,45 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use std::cmp::Ordering; + +use bit_vec::BitVec; + +pub trait BitVecExtension { + fn resize(&mut self, new_len: usize, value: bool); +} + +impl BitVecExtension for BitVec { + fn resize(&mut self, new_len: usize, value: bool) { + let old_len = self.len(); + match new_len.cmp(&old_len) { + Ordering::Less => self.truncate(new_len), + Ordering::Greater => self.grow(new_len - old_len, value), + Ordering::Equal => {} + } + } +} + +#[cfg(test)] +mod bit_vec_extension_test { + use super::*; + + #[test] + fn resize_test() { + let mut bitset = BitVec::new(); + + bitset.resize(10, false); + assert_eq!(bitset.len(), 10); + assert!(bitset.none()); + + bitset.resize(11, true); + assert_eq!(bitset.len(), 11); + assert!(bitset[10]); + + bitset.resize(5, false); + assert_eq!(bitset.len(), 5); + assert!(bitset.none()); + } +} + diff --git a/rust/diskann/src/utils/cached_reader.rs b/rust/diskann/src/utils/cached_reader.rs new file mode 100644 index 000000000..1a21f1a77 --- /dev/null +++ b/rust/diskann/src/utils/cached_reader.rs @@ -0,0 +1,160 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use std::fs::File; +use std::io::{Seek, Read}; + +use crate::common::{ANNResult, ANNError}; + +/// Sequential cached reads +pub struct CachedReader { + /// File reader + reader: File, + + /// # bytes to cache in one shot read + cache_size: u64, + + /// Underlying buf for cache + cache_buf: Vec, + + /// Offset into cache_buf for cur_pos + cur_off: u64, + + /// File size + fsize: u64, +} + +impl CachedReader { + pub fn new(filename: &str, cache_size: u64) -> std::io::Result { + let mut reader = File::open(filename)?; + let metadata = reader.metadata()?; + let fsize = metadata.len(); + + let cache_size = cache_size.min(fsize); + let mut cache_buf = vec![0; cache_size as usize]; + reader.read_exact(&mut cache_buf)?; + println!("Opened: {}, size: {}, cache_size: {}", filename, fsize, cache_size); + + Ok(Self { + reader, + cache_size, + cache_buf, + cur_off: 0, + fsize, + }) + } + + pub fn get_file_size(&self) -> u64 { + self.fsize + } + + pub fn read(&mut self, read_buf: &mut [u8]) -> ANNResult<()> { + let n_bytes = read_buf.len() as u64; + if n_bytes <= (self.cache_size - self.cur_off) { + // case 1: cache contains all data + read_buf.copy_from_slice(&self.cache_buf[(self.cur_off as usize)..(self.cur_off as usize + n_bytes as usize)]); + self.cur_off += n_bytes; + } else { + // case 2: cache contains some data + let cached_bytes = self.cache_size - self.cur_off; + if n_bytes - cached_bytes > self.fsize - self.reader.stream_position()? 
{ + return Err(ANNError::log_index_error(format!( + "Reading beyond end of file, n_bytes: {} cached_bytes: {} fsize: {} current pos: {}", + n_bytes, cached_bytes, self.fsize, self.reader.stream_position()?)) + ); + } + + read_buf[..cached_bytes as usize].copy_from_slice(&self.cache_buf[self.cur_off as usize..]); + // go to disk and fetch more data + self.reader.read_exact(&mut read_buf[cached_bytes as usize..])?; + // reset cur off + self.cur_off = self.cache_size; + + let size_left = self.fsize - self.reader.stream_position()?; + if size_left >= self.cache_size { + self.reader.read_exact(&mut self.cache_buf)?; + self.cur_off = 0; + } + // note that if size_left < cache_size, then cur_off = cache_size, + // so subsequent reads will all be directly from file + } + Ok(()) + } + + pub fn read_u32(&mut self) -> ANNResult { + let mut bytes = [0u8; 4]; + self.read(&mut bytes)?; + Ok(u32::from_le_bytes(bytes)) + } +} + +#[cfg(test)] +mod cached_reader_test { + use std::fs; + + use super::*; + + #[test] + fn cached_reader_works() { + let file_name = "cached_reader_works_test.bin"; + //npoints=2, dim=8, 2 vectors [1.0;8] [2.0;8] + let data: [u8; 72] = [2, 0, 1, 2, 8, 0, 1, 3, + 0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40, + 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, + 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, + 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41]; + std::fs::write(file_name, data).expect("Failed to write sample file"); + + let mut reader = CachedReader::new(file_name, 8).unwrap(); + assert_eq!(reader.get_file_size(), 72); + assert_eq!(reader.cache_size, 8); + + let mut all_from_cache_buf = vec![0; 4]; + reader.read(all_from_cache_buf.as_mut_slice()).unwrap(); + assert_eq!(all_from_cache_buf, [2, 0, 1, 2]); + assert_eq!(reader.cur_off, 4); + + let mut 
partial_from_cache_buf = vec![0; 6]; + reader.read(partial_from_cache_buf.as_mut_slice()).unwrap(); + assert_eq!(partial_from_cache_buf, [8, 0, 1, 3, 0x00, 0x01]); + assert_eq!(reader.cur_off, 0); + + let mut over_cache_size_buf = vec![0; 60]; + reader.read(over_cache_size_buf.as_mut_slice()).unwrap(); + assert_eq!( + over_cache_size_buf, + [0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40, + 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, + 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, + 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11] + ); + + let mut remaining_less_than_cache_size_buf = vec![0; 2]; + reader.read(remaining_less_than_cache_size_buf.as_mut_slice()).unwrap(); + assert_eq!(remaining_less_than_cache_size_buf, [0x80, 0x41]); + assert_eq!(reader.cur_off, reader.cache_size); + + fs::remove_file(file_name).expect("Failed to delete file"); + } + + #[test] + #[should_panic(expected = "n_bytes: 73 cached_bytes: 8 fsize: 72 current pos: 8")] + fn failed_for_reading_beyond_end_of_file() { + let file_name = "failed_for_reading_beyond_end_of_file_test.bin"; + //npoints=2, dim=8, 2 vectors [1.0;8] [2.0;8] + let data: [u8; 72] = [2, 0, 1, 2, 8, 0, 1, 3, + 0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40, + 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, + 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, + 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41]; + std::fs::write(file_name, data).expect("Failed to write sample file"); + + let mut reader = CachedReader::new(file_name, 8).unwrap(); + fs::remove_file(file_name).expect("Failed to delete file"); + + let mut over_size_buf = vec![0; 73]; + 
reader.read(over_size_buf.as_mut_slice()).unwrap(); + } +} + diff --git a/rust/diskann/src/utils/cached_writer.rs b/rust/diskann/src/utils/cached_writer.rs new file mode 100644 index 000000000..d3929bef2 --- /dev/null +++ b/rust/diskann/src/utils/cached_writer.rs @@ -0,0 +1,142 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::io::{Write, Seek, SeekFrom}; +use std::fs::{OpenOptions, File}; +use std::path::Path; + +pub struct CachedWriter { + /// File writer + writer: File, + + /// # bytes to cache for one shot write + cache_size: u64, + + /// Underlying buf for cache + cache_buf: Vec, + + /// Offset into cache_buf for cur_pos + cur_off: u64, + + /// File size + fsize: u64, +} + +impl CachedWriter { + pub fn new(filename: &str, cache_size: u64) -> std::io::Result { + let writer = OpenOptions::new() + .write(true) + .create(true) + .open(Path::new(filename))?; + + if cache_size == 0 { + return Err(std::io::Error::new(std::io::ErrorKind::Other, "Cache size must be greater than 0")); + } + + println!("Opened: {}, cache_size: {}", filename, cache_size); + Ok(Self { + writer, + cache_size, + cache_buf: vec![0; cache_size as usize], + cur_off: 0, + fsize: 0, + }) + } + + pub fn flush(&mut self) -> std::io::Result<()> { + // dump any remaining data in memory + if self.cur_off > 0 { + self.flush_cache()?; + } + + self.writer.flush()?; + println!("Finished writing {}B", self.fsize); + Ok(()) + } + + pub fn get_file_size(&self) -> u64 { + self.fsize + } + + /// Writes n_bytes from write_buf to the underlying cache + pub fn write(&mut self, write_buf: &[u8]) -> std::io::Result<()> { + let n_bytes = write_buf.len() as u64; + if n_bytes <= (self.cache_size - self.cur_off) { + // case 1: cache can take all data + self.cache_buf[(self.cur_off as usize)..((self.cur_off + n_bytes) as usize)].copy_from_slice(&write_buf[..n_bytes as usize]); + self.cur_off += n_bytes; + } else { + // case 2: cache cant take all data 
+ // go to disk and write existing cache data + self.writer.write_all(&self.cache_buf[..self.cur_off as usize])?; + self.fsize += self.cur_off; + // write the new data to disk + self.writer.write_all(write_buf)?; + self.fsize += n_bytes; + // clear cache data and reset cur_off + self.cache_buf.fill(0); + self.cur_off = 0; + } + Ok(()) + } + + pub fn reset(&mut self) -> std::io::Result<()> { + self.flush_cache()?; + self.writer.seek(SeekFrom::Start(0))?; + Ok(()) + } + + fn flush_cache(&mut self) -> std::io::Result<()> { + self.writer.write_all(&self.cache_buf[..self.cur_off as usize])?; + self.fsize += self.cur_off; + self.cache_buf.fill(0); + self.cur_off = 0; + Ok(()) + } +} + +impl Drop for CachedWriter { + fn drop(&mut self) { + let _ = self.flush(); + } +} + +#[cfg(test)] +mod cached_writer_test { + use std::fs; + + use super::*; + + #[test] + fn cached_writer_works() { + let file_name = "cached_writer_works_test.bin"; + //npoints=2, dim=8, 2 vectors [1.0;8] [2.0;8] + let data: [u8; 72] = [2, 0, 1, 2, 8, 0, 1, 3, + 0x00, 0x01, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40, + 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, + 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, + 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x11, 0x80, 0x41]; + + let mut writer = CachedWriter::new(file_name, 8).unwrap(); + assert_eq!(writer.get_file_size(), 0); + assert_eq!(writer.cache_size, 8); + assert_eq!(writer.get_file_size(), 0); + + let cache_all_buf = &data[0..4]; + writer.write(cache_all_buf).unwrap(); + assert_eq!(&writer.cache_buf[..4], cache_all_buf); + assert_eq!(&writer.cache_buf[4..], vec![0; 4]); + assert_eq!(writer.cur_off, 4); + assert_eq!(writer.get_file_size(), 0); + + let write_all_buf = &data[4..10]; + writer.write(write_all_buf).unwrap(); + assert_eq!(writer.cache_buf, vec![0; 8]); + 
assert_eq!(writer.cur_off, 0); + assert_eq!(writer.get_file_size(), 10); + + fs::remove_file(file_name).expect("Failed to delete file"); + } +} + diff --git a/rust/diskann/src/utils/file_util.rs b/rust/diskann/src/utils/file_util.rs new file mode 100644 index 000000000..f187d0128 --- /dev/null +++ b/rust/diskann/src/utils/file_util.rs @@ -0,0 +1,377 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! File operations + +use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; +use std::{mem, io}; +use std::fs::{self, File, OpenOptions}; +use std::io::{Read, BufReader, Write, Seek, SeekFrom}; +use std::path::Path; + +use crate::model::data_store::DatasetDto; + +/// Read metadata of data file. +pub fn load_metadata_from_file(file_name: &str) -> std::io::Result<(usize, usize)> { + let file = File::open(file_name)?; + let mut reader = BufReader::new(file); + + let npoints = reader.read_i32::()? as usize; + let ndims = reader.read_i32::()? as usize; + + Ok((npoints, ndims)) +} + +/// Read the deleted vertex ids from file. +pub fn load_ids_to_delete_from_file(file_name: &str) -> std::io::Result<(usize, Vec)> { + // The first 4 bytes are the number of vector ids. + // The rest of the file are the vector ids in the format of usize. + // The vector ids are sorted in ascending order. + let mut file = File::open(file_name)?; + let num_ids = file.read_u32::()? as usize; + + let mut ids = Vec::with_capacity(num_ids); + for _ in 0..num_ids { + let id = file.read_u32::()?; + ids.push(id); + } + + Ok((num_ids, ids)) +} + +/// Copy data from file +/// # Arguments +/// * `bin_file` - filename where the data is +/// * `data` - destination dataset dto to which the data is copied +/// * `pts_offset` - offset of points. 
data will be loaded after this point in dataset +/// * `npts` - number of points read from bin_file +/// * `dim` - point dimension read from bin_file +/// * `rounded_dim` - rounded dimension (padding zero if it's > dim) +/// # Return +/// * `npts` - number of points read from bin_file +/// * `dim` - point dimension read from bin_file +pub fn copy_aligned_data_from_file( + bin_file: &str, + dataset_dto: DatasetDto, + pts_offset: usize, +) -> std::io::Result<(usize, usize)> { + let mut reader = File::open(bin_file)?; + + let npts = reader.read_i32::()? as usize; + let dim = reader.read_i32::()? as usize; + let rounded_dim = dataset_dto.rounded_dim; + let offset = pts_offset * rounded_dim; + + for i in 0..npts { + let data_slice = &mut dataset_dto.data[offset + i * rounded_dim..offset + i * rounded_dim + dim]; + let mut buf = vec![0u8; dim * mem::size_of::()]; + reader.read_exact(&mut buf)?; + + let ptr = buf.as_ptr() as *const T; + let temp_slice = unsafe { std::slice::from_raw_parts(ptr, dim) }; + data_slice.copy_from_slice(temp_slice); + + (i * rounded_dim + dim..i * rounded_dim + rounded_dim).for_each(|j| { + dataset_dto.data[j] = T::default(); + }); + } + + Ok((npts, dim)) +} + +/// Open a file to write +/// # Arguments +/// * `writer` - mutable File reference +/// * `file_name` - file name +#[inline] +pub fn open_file_to_write(file_name: &str) -> std::io::Result { + OpenOptions::new() + .write(true) + .create(true) + .open(Path::new(file_name)) +} + +/// Delete a file +/// # Arguments +/// * `file_name` - file name +pub fn delete_file(file_name: &str) -> std::io::Result<()> { + if file_exists(file_name) { + fs::remove_file(file_name)?; + } + + Ok(()) +} + +/// Check whether file exists or not +pub fn file_exists(filename: &str) -> bool { + std::path::Path::new(filename).exists() +} + +/// Save data to file +/// # Arguments +/// * `filename` - filename where the data is +/// * `data` - information data +/// * `npts` - number of points +/// * `ndims` - point 
dimension +/// * `aligned_dim` - aligned dimension +/// * `offset` - data offset in file +pub fn save_data_in_base_dimensions( + filename: &str, + data: &mut [T], + npts: usize, + ndims: usize, + aligned_dim: usize, + offset: usize, +) -> std::io::Result { + let mut writer = open_file_to_write(filename)?; + let npts_i32 = npts as i32; + let ndims_i32 = ndims as i32; + let bytes_written = 2 * std::mem::size_of::() + npts * ndims * (std::mem::size_of::()); + + writer.seek(std::io::SeekFrom::Start(offset as u64))?; + writer.write_all(&npts_i32.to_le_bytes())?; + writer.write_all(&ndims_i32.to_le_bytes())?; + let data_ptr = data.as_ptr() as *const u8; + for i in 0..npts { + let middle_offset = i * aligned_dim * std::mem::size_of::(); + let middle_slice = unsafe { std::slice::from_raw_parts(data_ptr.add(middle_offset), ndims * std::mem::size_of::()) }; + writer.write_all(middle_slice)?; + } + writer.flush()?; + Ok(bytes_written) +} + +/// Read data file +/// # Arguments +/// * `bin_file` - filename where the data is +/// * `file_offset` - data offset in file +/// * `data` - information data +/// * `npts` - number of points +/// * `ndims` - point dimension +pub fn load_bin( + bin_file: &str, + file_offset: usize) -> std::io::Result<(Vec, usize, usize)> +{ + let mut reader = File::open(bin_file)?; + reader.seek(std::io::SeekFrom::Start(file_offset as u64))?; + let npts = reader.read_i32::()? as usize; + let dim = reader.read_i32::()? as usize; + + let size = npts * dim * std::mem::size_of::(); + let mut buf = vec![0u8; size]; + reader.read_exact(&mut buf)?; + + let ptr = buf.as_ptr() as *const T; + let data = unsafe { std::slice::from_raw_parts(ptr, npts * dim)}; + + Ok((data.to_vec(), npts, dim)) +} + +/// Get file size +pub fn get_file_size(filename: &str) -> io::Result { + let reader = File::open(filename)?; + let metadata = reader.metadata()?; + Ok(metadata.len()) +} + +macro_rules! 
save_bin { + ($name:ident, $t:ty, $write_func:ident) => { + /// Write data into file + pub fn $name(filename: &str, data: &[$t], num_pts: usize, dims: usize, offset: usize) -> std::io::Result { + let mut writer = open_file_to_write(filename)?; + + println!("Writing bin: {}", filename); + writer.seek(SeekFrom::Start(offset as u64))?; + let num_pts_i32 = num_pts as i32; + let dims_i32 = dims as i32; + let bytes_written = num_pts * dims * mem::size_of::<$t>() + 2 * mem::size_of::(); + + writer.write_i32::(num_pts_i32)?; + writer.write_i32::(dims_i32)?; + println!("bin: #pts = {}, #dims = {}, size = {}B", num_pts, dims, bytes_written); + + for item in data.iter() { + writer.$write_func::(*item)?; + } + + writer.flush()?; + + println!("Finished writing bin."); + Ok(bytes_written) + } + }; +} + +save_bin!(save_bin_f32, f32, write_f32); +save_bin!(save_bin_u64, u64, write_u64); +save_bin!(save_bin_u32, u32, write_u32); + +#[cfg(test)] +mod file_util_test { + use crate::model::data_store::InmemDataset; + use std::fs; + use super::*; + + pub const DIM_8: usize = 8; + + #[test] + fn load_metadata_test() { + let file_name = "test_load_metadata_test.bin"; + let data = [200, 0, 0, 0, 128, 0, 0, 0]; // 200 and 128 in little endian bytes + std::fs::write(file_name, data).expect("Failed to write sample file"); + match load_metadata_from_file(file_name) { + Ok((npoints, ndims)) => { + assert!(npoints == 200); + assert!(ndims == 128); + }, + Err(_e) => {}, + } + fs::remove_file(file_name).expect("Failed to delete file"); + } + + #[test] + fn load_data_test() { + let file_name = "test_load_data_test.bin"; + //npoints=2, dim=8, 2 vectors [1.0;8] [2.0;8] + let data: [u8; 72] = [2, 0, 0, 0, 8, 0, 0, 0, + 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40, + 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, + 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 
0x41, + 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41]; + std::fs::write(file_name, data).expect("Failed to write sample file"); + + let mut dataset = InmemDataset::::new(2, 1f32).unwrap(); + + match copy_aligned_data_from_file(file_name, dataset.into_dto(), 0) { + Ok((num_points, dim)) => { + fs::remove_file(file_name).expect("Failed to delete file"); + assert!(num_points == 2); + assert!(dim == 8); + assert!(dataset.data.len() == 16); + + let first_vertex = dataset.get_vertex(0).unwrap(); + let second_vertex = dataset.get_vertex(1).unwrap(); + + assert!(*first_vertex.vector() == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]); + assert!(*second_vertex.vector() == [9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0]); + }, + Err(e) => { + fs::remove_file(file_name).expect("Failed to delete file"); + panic!("{}", e) + }, + } + } + + #[test] + fn open_file_to_write_test() { + let file_name = "test_open_file_to_write_test.bin"; + let mut writer = File::create(file_name).unwrap(); + let data = [200, 0, 0, 0, 128, 0, 0, 0]; + writer.write(&data).expect("Failed to write sample file"); + + let _ = open_file_to_write(file_name); + + fs::remove_file(file_name).expect("Failed to delete file"); + } + + #[test] + fn delete_file_test() { + let file_name = "test_delete_file_test.bin"; + let mut file = File::create(file_name).unwrap(); + writeln!(file, "test delete file").unwrap(); + + let result = delete_file(file_name); + + assert!(result.is_ok()); + assert!(fs::metadata(file_name).is_err()); + } + + #[test] + fn save_data_in_base_dimensions_test() { + //npoints=2, dim=8 + let mut data: [u8; 72] = [2, 0, 0, 0, 8, 0, 0, 0, + 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40, + 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, + 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, + 0x00, 0x00, 0x50, 0x41, 
0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41]; + let num_points = 2; + let dim = DIM_8; + let data_file = "save_data_in_base_dimensions_test.data"; + match save_data_in_base_dimensions(data_file, &mut data, num_points, dim, DIM_8, 0) { + Ok(num) => { + assert!(file_exists(data_file)); + assert_eq!(num, 2 * std::mem::size_of::() + num_points * dim * std::mem::size_of::()); + fs::remove_file(data_file).expect("Failed to delete file"); + }, + Err(e) => { + fs::remove_file(data_file).expect("Failed to delete file"); + panic!("{}", e) + } + } + } + + #[test] + fn save_bin_test() { + let filename = "save_bin_test"; + let data = vec![0u64, 1u64, 2u64]; + let num_pts = data.len(); + let dims = 1; + let bytes_written = save_bin_u64(filename, &data, num_pts, dims, 0).unwrap(); + assert_eq!(bytes_written, 32); + + let mut file = File::open(filename).unwrap(); + let mut buffer = vec![]; + + let npts_read = file.read_i32::().unwrap() as usize; + let dims_read = file.read_i32::().unwrap() as usize; + + file.read_to_end(&mut buffer).unwrap(); + let data_read: Vec = buffer + .chunks_exact(8) + .map(|b| u64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]])) + .collect(); + + std::fs::remove_file(filename).unwrap(); + + assert_eq!(num_pts, npts_read); + assert_eq!(dims, dims_read); + assert_eq!(data, data_read); + } + + #[test] + fn load_bin_test() { + let file_name = "load_bin_test"; + let data = vec![0u64, 1u64, 2u64]; + let num_pts = data.len(); + let dims = 1; + let bytes_written = save_bin_u64(file_name, &data, num_pts, dims, 0).unwrap(); + assert_eq!(bytes_written, 32); + + let (load_data, load_num_pts, load_dims) = load_bin::(file_name, 0).unwrap(); + assert_eq!(load_num_pts, num_pts); + assert_eq!(load_dims, dims); + assert_eq!(load_data, data); + std::fs::remove_file(file_name).unwrap(); + } + + #[test] + fn load_bin_offset_test() { + let offset:usize = 32; + let file_name = "load_bin_offset_test"; + let data = vec![0u64, 1u64, 
2u64]; + let num_pts = data.len(); + let dims = 1; + let bytes_written = save_bin_u64(file_name, &data, num_pts, dims, offset).unwrap(); + assert_eq!(bytes_written, 32); + + let (load_data, load_num_pts, load_dims) = load_bin::(file_name, offset).unwrap(); + assert_eq!(load_num_pts, num_pts); + assert_eq!(load_dims, dims); + assert_eq!(load_data, data); + std::fs::remove_file(file_name).unwrap(); + } +} + diff --git a/rust/diskann/src/utils/hashset_u32.rs b/rust/diskann/src/utils/hashset_u32.rs new file mode 100644 index 000000000..15db687d6 --- /dev/null +++ b/rust/diskann/src/utils/hashset_u32.rs @@ -0,0 +1,46 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use hashbrown::HashSet; +use std::{hash::BuildHasherDefault, ops::{Deref, DerefMut}}; +use fxhash::FxHasher; + +lazy_static::lazy_static! { + /// Singleton hasher. + static ref HASHER: BuildHasherDefault = { + BuildHasherDefault::::default() + }; +} + +pub struct HashSetForU32 { + hashset: HashSet::>, +} + +impl HashSetForU32 { + pub fn with_capacity(capacity: usize) -> HashSetForU32 { + let hashset = HashSet::>::with_capacity_and_hasher(capacity, HASHER.clone()); + HashSetForU32 { + hashset + } + } +} + +impl Deref for HashSetForU32 { + type Target = HashSet::>; + + fn deref(&self) -> &Self::Target { + &self.hashset + } +} + +impl DerefMut for HashSetForU32 { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.hashset + } +} + diff --git a/rust/diskann/src/utils/kmeans.rs b/rust/diskann/src/utils/kmeans.rs new file mode 100644 index 000000000..d1edffad7 --- /dev/null +++ b/rust/diskann/src/utils/kmeans.rs @@ -0,0 +1,430 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! 
Aligned allocator + +use rand::{distributions::Uniform, prelude::Distribution, thread_rng}; +use rayon::prelude::*; +use std::cmp::min; + +use crate::common::ANNResult; +use crate::utils::math_util::{calc_distance, compute_closest_centers, compute_vecs_l2sq}; + +/// Run Lloyds one iteration +/// Given data in row-major num_points * dim, and centers in row-major +/// num_centers * dim and squared lengths of ata points, output the closest +/// center to each data point, update centers, and also return inverted index. +/// If closest_centers == NULL, will allocate memory and return. +/// Similarly, if closest_docs == NULL, will allocate memory and return. +#[allow(clippy::too_many_arguments)] +fn lloyds_iter( + data: &[f32], + num_points: usize, + dim: usize, + centers: &mut [f32], + num_centers: usize, + docs_l2sq: &[f32], + mut closest_docs: &mut Vec>, + closest_center: &mut [u32], +) -> ANNResult { + let compute_residual = true; + + closest_docs.iter_mut().for_each(|doc| doc.clear()); + + compute_closest_centers( + data, + num_points, + dim, + centers, + num_centers, + 1, + closest_center, + Some(&mut closest_docs), + Some(docs_l2sq), + )?; + + centers.fill(0.0); + + centers + .par_chunks_mut(dim) + .enumerate() + .for_each(|(c, center)| { + let mut cluster_sum = vec![0.0; dim]; + for &doc_index in &closest_docs[c] { + let current = &data[doc_index * dim..(doc_index + 1) * dim]; + for (j, current_val) in current.iter().enumerate() { + cluster_sum[j] += *current_val as f64; + } + } + if !closest_docs[c].is_empty() { + for (i, sum_val) in cluster_sum.iter().enumerate() { + center[i] = (*sum_val / closest_docs[c].len() as f64) as f32; + } + } + }); + + let mut residual = 0.0; + if compute_residual { + let buf_pad: usize = 32; + let chunk_size: usize = 2 * 8192; + let nchunks = + num_points / chunk_size + (if num_points % chunk_size == 0 { 0 } else { 1 } as usize); + + let mut residuals: Vec = vec![0.0; nchunks * buf_pad]; + + residuals + .par_iter_mut() + .enumerate() 
+ .for_each(|(chunk, res)| { + for d in (chunk * chunk_size)..min(num_points, (chunk + 1) * chunk_size) { + *res += calc_distance( + &data[d * dim..(d + 1) * dim], + ¢ers[closest_center[d] as usize * dim..], + dim, + ); + } + }); + + for chunk in 0..nchunks { + residual += residuals[chunk * buf_pad]; + } + } + + Ok(residual) +} + +/// Run Lloyds until max_reps or stopping criterion +/// If you pass NULL for closest_docs and closest_center, it will NOT return +/// the results, else it will assume appropriate allocation as closest_docs = +/// new vec [num_centers], and closest_center = new size_t[num_points] +/// Final centers are output in centers as row-major num_centers * dim. +fn run_lloyds( + data: &[f32], + num_points: usize, + dim: usize, + centers: &mut [f32], + num_centers: usize, + max_reps: usize, +) -> ANNResult<(Vec>, Vec, f32)> { + let mut residual = f32::MAX; + + let mut closest_docs = vec![Vec::new(); num_centers]; + let mut closest_center = vec![0; num_points]; + + let mut docs_l2sq = vec![0.0; num_points]; + compute_vecs_l2sq(&mut docs_l2sq, data, num_points, dim); + + let mut old_residual; + + for i in 0..max_reps { + old_residual = residual; + + residual = lloyds_iter( + data, + num_points, + dim, + centers, + num_centers, + &docs_l2sq, + &mut closest_docs, + &mut closest_center, + )?; + + if (i != 0 && (old_residual - residual) / residual < 0.00001) || (residual < f32::EPSILON) { + println!( + "Residuals unchanged: {} becomes {}. 
Early termination.", + old_residual, residual + ); + break; + } + } + + Ok((closest_docs, closest_center, residual)) +} + +/// Assume memory allocated for pivot_data as new float[num_centers * dim] +/// and select randomly num_centers points as pivots +fn selecting_pivots( + data: &[f32], + num_points: usize, + dim: usize, + pivot_data: &mut [f32], + num_centers: usize, +) { + let mut picked = Vec::new(); + let mut rng = thread_rng(); + let distribution = Uniform::from(0..num_points); + + for j in 0..num_centers { + let mut tmp_pivot = distribution.sample(&mut rng); + while picked.contains(&tmp_pivot) { + tmp_pivot = distribution.sample(&mut rng); + } + picked.push(tmp_pivot); + let data_offset = tmp_pivot * dim; + let pivot_offset = j * dim; + pivot_data[pivot_offset..pivot_offset + dim] + .copy_from_slice(&data[data_offset..data_offset + dim]); + } +} + +/// Select pivots in k-means++ algorithm +/// Points that are farther away from the already chosen centroids +/// have a higher probability of being selected as the next centroid. +/// The k-means++ algorithm helps avoid poor initial centroid +/// placement that can result in suboptimal clustering. +fn k_meanspp_selecting_pivots( + data: &[f32], + num_points: usize, + dim: usize, + pivot_data: &mut [f32], + num_centers: usize, +) { + if num_points > (1 << 23) { + println!("ERROR: n_pts {} currently not supported for k-means++, maximum is 8388608. 
Falling back to random pivot selection.", num_points); + selecting_pivots(data, num_points, dim, pivot_data, num_centers); + return; + } + + let mut picked: Vec = Vec::new(); + let mut rng = thread_rng(); + let real_distribution = Uniform::from(0.0..1.0); + let int_distribution = Uniform::from(0..num_points); + + let init_id = int_distribution.sample(&mut rng); + let mut num_picked = 1; + + picked.push(init_id); + let init_data_offset = init_id * dim; + pivot_data[0..dim].copy_from_slice(&data[init_data_offset..init_data_offset + dim]); + + let mut dist = vec![0.0; num_points]; + + dist.par_iter_mut().enumerate().for_each(|(i, dist_i)| { + *dist_i = calc_distance( + &data[i * dim..(i + 1) * dim], + &data[init_id * dim..(init_id + 1) * dim], + dim, + ); + }); + + let mut dart_val: f64; + let mut tmp_pivot = 0; + let mut sum_flag = false; + + while num_picked < num_centers { + dart_val = real_distribution.sample(&mut rng); + + let mut sum: f64 = 0.0; + for item in dist.iter().take(num_points) { + sum += *item as f64; + } + if sum == 0.0 { + sum_flag = true; + } + + dart_val *= sum; + + let mut prefix_sum: f64 = 0.0; + for (i, pivot) in dist.iter().enumerate().take(num_points) { + tmp_pivot = i; + if dart_val >= prefix_sum && dart_val < (prefix_sum + *pivot as f64) { + break; + } + + prefix_sum += *pivot as f64; + } + + if picked.contains(&tmp_pivot) && !sum_flag { + continue; + } + + picked.push(tmp_pivot); + let pivot_offset = num_picked * dim; + let data_offset = tmp_pivot * dim; + pivot_data[pivot_offset..pivot_offset + dim] + .copy_from_slice(&data[data_offset..data_offset + dim]); + + dist.par_iter_mut().enumerate().for_each(|(i, dist_i)| { + *dist_i = (*dist_i).min(calc_distance( + &data[i * dim..(i + 1) * dim], + &data[tmp_pivot * dim..(tmp_pivot + 1) * dim], + dim, + )); + }); + + num_picked += 1; + } +} + +/// k-means algorithm interface +pub fn k_means_clustering( + data: &[f32], + num_points: usize, + dim: usize, + centers: &mut [f32], + num_centers: 
usize, + max_reps: usize, +) -> ANNResult<(Vec>, Vec, f32)> { + k_meanspp_selecting_pivots(data, num_points, dim, centers, num_centers); + let (closest_docs, closest_center, residual) = + run_lloyds(data, num_points, dim, centers, num_centers, max_reps)?; + Ok((closest_docs, closest_center, residual)) +} + +#[cfg(test)] +mod kmeans_test { + use super::*; + use approx::assert_relative_eq; + use rand::Rng; + + #[test] + fn lloyds_iter_test() { + let dim = 2; + let num_points = 10; + let num_centers = 3; + + let data: Vec = (1..=num_points * dim).map(|x| x as f32).collect(); + let mut centers = [1.0, 2.0, 7.0, 8.0, 19.0, 20.0]; + + let mut closest_docs: Vec> = vec![vec![]; num_centers]; + let mut closest_center: Vec = vec![0; num_points]; + let docs_l2sq: Vec = data + .chunks(dim) + .map(|chunk| chunk.iter().map(|val| val.powi(2)).sum()) + .collect(); + + let residual = lloyds_iter( + &data, + num_points, + dim, + &mut centers, + num_centers, + &docs_l2sq, + &mut closest_docs, + &mut closest_center, + ) + .unwrap(); + + let expected_centers: [f32; 6] = [2.0, 3.0, 9.0, 10.0, 17.0, 18.0]; + let expected_closest_docs: Vec> = + vec![vec![0, 1], vec![2, 3, 4, 5, 6], vec![7, 8, 9]]; + let expected_closest_center: [u32; 10] = [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]; + let expected_residual: f32 = 100.0; + + // sort data for assert + centers.sort_by(|a, b| a.partial_cmp(b).unwrap()); + for inner_vec in &mut closest_docs { + inner_vec.sort(); + } + closest_center.sort_by(|a, b| a.partial_cmp(b).unwrap()); + + assert_eq!(centers, expected_centers); + assert_eq!(closest_docs, expected_closest_docs); + assert_eq!(closest_center, expected_closest_center); + assert_relative_eq!(residual, expected_residual, epsilon = 1.0e-6_f32); + } + + #[test] + fn run_lloyds_test() { + let dim = 2; + let num_points = 10; + let num_centers = 3; + let max_reps = 5; + + let data: Vec = (1..=num_points * dim).map(|x| x as f32).collect(); + let mut centers = [1.0, 2.0, 7.0, 8.0, 19.0, 20.0]; + + let (mut 
closest_docs, mut closest_center, residual) = + run_lloyds(&data, num_points, dim, &mut centers, num_centers, max_reps).unwrap(); + + let expected_centers: [f32; 6] = [3.0, 4.0, 10.0, 11.0, 17.0, 18.0]; + let expected_closest_docs: Vec> = + vec![vec![0, 1, 2], vec![3, 4, 5, 6], vec![7, 8, 9]]; + let expected_closest_center: [u32; 10] = [0, 0, 0, 1, 1, 1, 1, 2, 2, 2]; + let expected_residual: f32 = 72.0; + + // sort data for assert + centers.sort_by(|a, b| a.partial_cmp(b).unwrap()); + for inner_vec in &mut closest_docs { + inner_vec.sort(); + } + closest_center.sort_by(|a, b| a.partial_cmp(b).unwrap()); + + assert_eq!(centers, expected_centers); + assert_eq!(closest_docs, expected_closest_docs); + assert_eq!(closest_center, expected_closest_center); + assert_relative_eq!(residual, expected_residual, epsilon = 1.0e-6_f32); + } + + #[test] + fn selecting_pivots_test() { + let dim = 2; + let num_points = 10; + let num_centers = 3; + + // Generate some random data points + let mut rng = rand::thread_rng(); + let data: Vec = (0..num_points * dim).map(|_| rng.gen()).collect(); + + let mut pivot_data = vec![0.0; num_centers * dim]; + + selecting_pivots(&data, num_points, dim, &mut pivot_data, num_centers); + + // Verify that each pivot point corresponds to a point in the data + for i in 0..num_centers { + let pivot_offset = i * dim; + let pivot = &pivot_data[pivot_offset..(pivot_offset + dim)]; + + // Make sure the pivot is found in the data + let mut found = false; + for j in 0..num_points { + let data_offset = j * dim; + let point = &data[data_offset..(data_offset + dim)]; + + if pivot == point { + found = true; + break; + } + } + assert!(found, "Pivot not found in data"); + } + } + + #[test] + fn k_meanspp_selecting_pivots_test() { + let dim = 2; + let num_points = 10; + let num_centers = 3; + + // Generate some random data points + let mut rng = rand::thread_rng(); + let data: Vec = (0..num_points * dim).map(|_| rng.gen()).collect(); + + let mut pivot_data = vec![0.0; 
num_centers * dim]; + + k_meanspp_selecting_pivots(&data, num_points, dim, &mut pivot_data, num_centers); + + // Verify that each pivot point corresponds to a point in the data + for i in 0..num_centers { + let pivot_offset = i * dim; + let pivot = &pivot_data[pivot_offset..pivot_offset + dim]; + + // Make sure the pivot is found in the data + let mut found = false; + for j in 0..num_points { + let data_offset = j * dim; + let point = &data[data_offset..data_offset + dim]; + + if pivot == point { + found = true; + break; + } + } + assert!(found, "Pivot not found in data"); + } + } +} diff --git a/rust/diskann/src/utils/math_util.rs b/rust/diskann/src/utils/math_util.rs new file mode 100644 index 000000000..ef30c76ff --- /dev/null +++ b/rust/diskann/src/utils/math_util.rs @@ -0,0 +1,481 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Aligned allocator + +extern crate cblas; +extern crate openblas_src; + +use cblas::{sgemm, snrm2, Layout, Transpose}; +use rayon::prelude::*; +use std::{ + cmp::{min, Ordering}, + collections::BinaryHeap, + sync::{Arc, Mutex}, +}; + +use crate::common::{ANNError, ANNResult}; + +struct PivotContainer { + piv_id: usize, + piv_dist: f32, +} + +impl PartialOrd for PivotContainer { + fn partial_cmp(&self, other: &Self) -> Option { + other.piv_dist.partial_cmp(&self.piv_dist) + } +} + +impl Ord for PivotContainer { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + // Treat NaN as less than all other values. + // piv_dist should never be NaN. 
+ self.partial_cmp(other).unwrap_or(Ordering::Less) + } +} + +impl PartialEq for PivotContainer { + fn eq(&self, other: &Self) -> bool { + self.piv_dist == other.piv_dist + } +} + +impl Eq for PivotContainer {} + +/// Calculate the Euclidean distance between two vectors +pub fn calc_distance(vec_1: &[f32], vec_2: &[f32], dim: usize) -> f32 { + let mut dist = 0.0; + for j in 0..dim { + let diff = vec_1[j] - vec_2[j]; + dist += diff * diff; + } + dist +} + +/// Compute L2-squared norms of data stored in row-major num_points * dim, +/// need to be pre-allocated +pub fn compute_vecs_l2sq(vecs_l2sq: &mut [f32], data: &[f32], num_points: usize, dim: usize) { + assert_eq!(vecs_l2sq.len(), num_points); + + vecs_l2sq + .par_iter_mut() + .enumerate() + .for_each(|(n_iter, vec_l2sq)| { + let slice = &data[n_iter * dim..(n_iter + 1) * dim]; + let norm = unsafe { snrm2(dim as i32, slice, 1) }; + *vec_l2sq = norm * norm; + }); +} + +/// Calculate k closest centers to data of num_points * dim (row-major) +/// Centers is num_centers * dim (row-major) +/// data_l2sq has pre-computed squared norms of data +/// centers_l2sq has pre-computed squared norms of centers +/// Pre-allocated center_index will contain id of nearest center +/// Pre-allocated dist_matrix should be num_points * num_centers and contain squared distances +/// Default value of k is 1 +/// Ideally used only by compute_closest_centers +#[allow(clippy::too_many_arguments)] +pub fn compute_closest_centers_in_block( + data: &[f32], + num_points: usize, + dim: usize, + centers: &[f32], + num_centers: usize, + docs_l2sq: &[f32], + centers_l2sq: &[f32], + center_index: &mut [u32], + dist_matrix: &mut [f32], + k: usize, +) -> ANNResult<()> { + if k > num_centers { + return Err(ANNError::log_index_error(format!( + "ERROR: k ({}) > num_centers({})", + k, num_centers + ))); + } + + let ones_a: Vec = vec![1.0; num_centers]; + let ones_b: Vec = vec![1.0; num_points]; + + unsafe { + sgemm( + Layout::RowMajor, + Transpose::None, + 
Transpose::Ordinary, + num_points as i32, + num_centers as i32, + 1, + 1.0, + docs_l2sq, + 1, + &ones_a, + 1, + 0.0, + dist_matrix, + num_centers as i32, + ); + } + + unsafe { + sgemm( + Layout::RowMajor, + Transpose::None, + Transpose::Ordinary, + num_points as i32, + num_centers as i32, + 1, + 1.0, + &ones_b, + 1, + centers_l2sq, + 1, + 1.0, + dist_matrix, + num_centers as i32, + ); + } + + unsafe { + sgemm( + Layout::RowMajor, + Transpose::None, + Transpose::Ordinary, + num_points as i32, + num_centers as i32, + dim as i32, + -2.0, + data, + dim as i32, + centers, + dim as i32, + 1.0, + dist_matrix, + num_centers as i32, + ); + } + + if k == 1 { + center_index + .par_iter_mut() + .enumerate() + .for_each(|(i, center_idx)| { + let mut min = f32::MAX; + let current = &dist_matrix[i * num_centers..(i + 1) * num_centers]; + let mut min_idx = 0; + for (j, &distance) in current.iter().enumerate() { + if distance < min { + min = distance; + min_idx = j; + } + } + *center_idx = min_idx as u32; + }); + } else { + center_index + .par_chunks_mut(k) + .enumerate() + .for_each(|(i, center_chunk)| { + let current = &dist_matrix[i * num_centers..(i + 1) * num_centers]; + let mut top_k_queue = BinaryHeap::new(); + for (j, &distance) in current.iter().enumerate() { + let this_piv = PivotContainer { + piv_id: j, + piv_dist: distance, + }; + if top_k_queue.len() < k { + top_k_queue.push(this_piv); + } else { + // Safe unwrap, top_k_queue is not empty + #[allow(clippy::unwrap_used)] + let mut top = top_k_queue.peek_mut().unwrap(); + if this_piv.piv_dist < top.piv_dist { + *top = this_piv; + } + } + } + for (_j, center_idx) in center_chunk.iter_mut().enumerate() { + if let Some(this_piv) = top_k_queue.pop() { + *center_idx = this_piv.piv_id as u32; + } else { + break; + } + } + }); + } + + Ok(()) +} + +/// Given data in num_points * new_dim row major +/// Pivots stored in full_pivot_data as num_centers * new_dim row major +/// Calculate the k closest pivot for each point and store 
it in vector +/// closest_centers_ivf (row major, num_points*k) (which needs to be allocated +/// outside) Additionally, if inverted index is not null (and pre-allocated), +/// it will return inverted index for each center, assuming each of the inverted +/// indices is an empty vector. Additionally, if pts_norms_squared is not null, +/// then it will assume that point norms are pre-computed and use those values +#[allow(clippy::too_many_arguments)] +pub fn compute_closest_centers( + data: &[f32], + num_points: usize, + dim: usize, + pivot_data: &[f32], + num_centers: usize, + k: usize, + closest_centers_ivf: &mut [u32], + mut inverted_index: Option<&mut Vec>>, + pts_norms_squared: Option<&[f32]>, +) -> ANNResult<()> { + if k > num_centers { + return Err(ANNError::log_index_error(format!( + "ERROR: k ({}) > num_centers({})", + k, num_centers + ))); + } + + let _is_norm_given_for_pts = pts_norms_squared.is_some(); + + let mut pivs_norms_squared = vec![0.0; num_centers]; + + let mut pts_norms_squared = if let Some(pts_norms) = pts_norms_squared { + pts_norms.to_vec() + } else { + let mut norms_squared = vec![0.0; num_points]; + compute_vecs_l2sq(&mut norms_squared, data, num_points, dim); + norms_squared + }; + + compute_vecs_l2sq(&mut pivs_norms_squared, pivot_data, num_centers, dim); + + let par_block_size = num_points; + let n_blocks = if num_points % par_block_size == 0 { + num_points / par_block_size + } else { + num_points / par_block_size + 1 + }; + + let mut closest_centers = vec![0u32; par_block_size * k]; + let mut distance_matrix = vec![0.0; num_centers * par_block_size]; + + for cur_blk in 0..n_blocks { + let data_cur_blk = &data[cur_blk * par_block_size * dim..]; + let num_pts_blk = min(par_block_size, num_points - cur_blk * par_block_size); + let pts_norms_blk = &mut pts_norms_squared[cur_blk * par_block_size..]; + + compute_closest_centers_in_block( + data_cur_blk, + num_pts_blk, + dim, + pivot_data, + num_centers, + pts_norms_blk, + 
&pivs_norms_squared, + &mut closest_centers, + &mut distance_matrix, + k, + )?; + + closest_centers_ivf.clone_from_slice(&closest_centers); + + if let Some(inverted_index_inner) = inverted_index.as_mut() { + let inverted_index_arc = Arc::new(Mutex::new(inverted_index_inner)); + + (0..num_points) + .into_par_iter() + .try_for_each(|j| -> ANNResult<()> { + let this_center_id = closest_centers[j] as usize; + let mut guard = inverted_index_arc.lock().map_err(|err| { + ANNError::log_index_error(format!( + "PoisonError: Lock poisoned when acquiring inverted_index_arc, err={}", + err + )) + })?; + guard[this_center_id].push(j); + + Ok(()) + })?; + } + } + + Ok(()) +} + +/// If to_subtract is true, will subtract nearest center from each row. +/// Else will add. +/// Output will be in data_load itself. +/// Nearest centers need to be provided in closest_centers. +pub fn process_residuals( + data_load: &mut [f32], + num_points: usize, + dim: usize, + cur_pivot_data: &[f32], + num_centers: usize, + closest_centers: &[u32], + to_subtract: bool, +) { + println!( + "Processing residuals of {} points in {} dimensions using {} centers", + num_points, dim, num_centers + ); + + data_load + .par_chunks_mut(dim) + .enumerate() + .for_each(|(n_iter, chunk)| { + let cur_pivot_index = closest_centers[n_iter] as usize * dim; + for d_iter in 0..dim { + if to_subtract { + chunk[d_iter] -= cur_pivot_data[cur_pivot_index + d_iter]; + } else { + chunk[d_iter] += cur_pivot_data[cur_pivot_index + d_iter]; + } + } + }); +} + +#[cfg(test)] +mod math_util_test { + use super::*; + use approx::assert_abs_diff_eq; + + #[test] + fn calc_distance_test() { + let vec1 = vec![1.0, 2.0, 3.0]; + let vec2 = vec![4.0, 5.0, 6.0]; + let dim = vec1.len(); + + let dist = calc_distance(&vec1, &vec2, dim); + + let expected = 27.0; + + assert_eq!(dist, expected); + } + + #[test] + fn compute_vecs_l2sq_test() { + let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]; + let num_points = 2; + let dim = 3; + let mut vecs_l2sq = 
vec![0.0; num_points]; + + compute_vecs_l2sq(&mut vecs_l2sq, &data, num_points, dim); + + let expected = vec![14.0, 77.0]; + + assert_eq!(vecs_l2sq.len(), num_points); + assert_abs_diff_eq!(vecs_l2sq[0], expected[0], epsilon = 1e-6); + assert_abs_diff_eq!(vecs_l2sq[1], expected[1], epsilon = 1e-6); + } + + #[test] + fn compute_closest_centers_in_block_test() { + let num_points = 10; + let dim = 5; + let num_centers = 3; + let data = vec![ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, + 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, + 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, + 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, + ]; + let centers = vec![ + 1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 31.0, 32.0, 33.0, 34.0, 35.0, + ]; + let mut docs_l2sq = vec![0.0; num_points]; + compute_vecs_l2sq(&mut docs_l2sq, &data, num_points, dim); + let mut centers_l2sq = vec![0.0; num_centers]; + compute_vecs_l2sq(&mut centers_l2sq, ¢ers, num_centers, dim); + let mut center_index = vec![0; num_points]; + let mut dist_matrix = vec![0.0; num_points * num_centers]; + let k = 1; + + compute_closest_centers_in_block( + &data, + num_points, + dim, + ¢ers, + num_centers, + &docs_l2sq, + ¢ers_l2sq, + &mut center_index, + &mut dist_matrix, + k, + ) + .unwrap(); + + assert_eq!(center_index.len(), num_points); + let expected_center_index = vec![0, 0, 0, 1, 1, 1, 2, 2, 2, 2]; + assert_abs_diff_eq!(*center_index, expected_center_index); + + assert_eq!(dist_matrix.len(), num_points * num_centers); + let expected_dist_matrix = vec![ + 0.0, 2000.0, 4500.0, 125.0, 1125.0, 3125.0, 500.0, 500.0, 2000.0, 1125.0, 125.0, + 1125.0, 2000.0, 0.0, 500.0, 3125.0, 125.0, 125.0, 4500.0, 500.0, 0.0, 6125.0, 1125.0, + 125.0, 8000.0, 2000.0, 500.0, 10125.0, 3125.0, 1125.0, + ]; + assert_abs_diff_eq!(*dist_matrix, expected_dist_matrix, epsilon = 1e-2); + } + + #[test] + fn 
test_compute_closest_centers() { + let num_points = 4; + let dim = 3; + let num_centers = 2; + let mut data = vec![ + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + ]; + let pivot_data = vec![1.0, 2.0, 3.0, 10.0, 11.0, 12.0]; + let k = 1; + + let mut closest_centers_ivf = vec![0u32; num_points * k]; + let mut inverted_index: Vec> = vec![vec![], vec![]]; + + compute_closest_centers( + &data, + num_points, + dim, + &pivot_data, + num_centers, + k, + &mut closest_centers_ivf, + Some(&mut inverted_index), + None, + ) + .unwrap(); + + assert_eq!(closest_centers_ivf, vec![0, 0, 1, 1]); + + for vec in inverted_index.iter_mut() { + vec.sort_unstable(); + } + assert_eq!(inverted_index, vec![vec![0, 1], vec![2, 3]]); + } + + #[test] + fn process_residuals_test() { + let mut data_load = vec![1.0, 2.0, 3.0, 4.0]; + let num_points = 2; + let dim = 2; + let cur_pivot_data = vec![0.5, 1.5, 2.5, 3.5]; + let num_centers = 2; + let closest_centers = vec![0, 1]; + let to_subtract = true; + + process_residuals( + &mut data_load, + num_points, + dim, + &cur_pivot_data, + num_centers, + &closest_centers, + to_subtract, + ); + + assert_eq!(data_load, vec![0.5, 0.5, 0.5, 0.5]); + } +} diff --git a/rust/diskann/src/utils/mod.rs b/rust/diskann/src/utils/mod.rs new file mode 100644 index 000000000..df174f8f0 --- /dev/null +++ b/rust/diskann/src/utils/mod.rs @@ -0,0 +1,34 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +pub mod file_util; +pub use file_util::*; + +#[allow(clippy::module_inception)] +pub mod utils; +pub use utils::*; + +pub mod bit_vec_extension; +pub use bit_vec_extension::*; + +pub mod rayon_util; +pub use rayon_util::*; + +pub mod timer; +pub use timer::*; + +pub mod cached_reader; +pub use cached_reader::*; + +pub mod cached_writer; +pub use cached_writer::*; + +pub mod partition; +pub use partition::*; + +pub mod math_util; +pub use math_util::*; + +pub mod kmeans; +pub use kmeans::*; diff --git a/rust/diskann/src/utils/partition.rs b/rust/diskann/src/utils/partition.rs new file mode 100644 index 000000000..dbe686226 --- /dev/null +++ b/rust/diskann/src/utils/partition.rs @@ -0,0 +1,151 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::mem; +use std::{fs::File, path::Path}; +use std::io::{Write, Seek, SeekFrom}; +use rand::distributions::{Distribution, Uniform}; + +use crate::common::ANNResult; + +use super::CachedReader; + +/// streams data from the file, and samples each vector with probability p_val +/// and returns a matrix of size slice_size* ndims as floating point type. +/// the slice_size and ndims are set inside the function. +/// # Arguments +/// * `file_name` - filename where the data is +/// * `p_val` - possibility to sample data +/// * `sampled_vectors` - sampled vector chose by p_val possibility +/// * `slice_size` - how many sampled data return +/// * `dim` - each sample data dimension +pub fn gen_random_slice>(data_file: &str, mut p_val: f64) -> ANNResult<(Vec, usize, usize)> { + let read_blk_size = 64 * 1024 * 1024; + let mut reader = CachedReader::new(data_file, read_blk_size)?; + + let npts = reader.read_u32()? as usize; + let dim = reader.read_u32()? 
as usize; + let mut sampled_vectors: Vec = Vec::new(); + let mut slice_size = 0; + p_val = if p_val < 1f64 { p_val } else { 1f64 }; + + let mut generator = rand::thread_rng(); + let distribution = Uniform::from(0.0..1.0); + + for _ in 0..npts { + let mut cur_vector_bytes = vec![0u8; dim * mem::size_of::()]; + reader.read(&mut cur_vector_bytes)?; + let random_value = distribution.sample(&mut generator); + if random_value < p_val { + let ptr = cur_vector_bytes.as_ptr() as *const T; + let cur_vector_t = unsafe { std::slice::from_raw_parts(ptr, dim) }; + sampled_vectors.extend(cur_vector_t.iter().map(|&t| t.into())); + slice_size += 1; + } + } + + Ok((sampled_vectors, slice_size, dim)) +} + +/// Generate random sample data and write into output_file +pub fn gen_sample_data(data_file: &str, output_file: &str, sampling_rate: f64) -> ANNResult<()> { + let read_blk_size = 64 * 1024 * 1024; + let mut reader = CachedReader::new(data_file, read_blk_size)?; + + let sample_data_path = format!("{}_data.bin", output_file); + let sample_ids_path = format!("{}_ids.bin", output_file); + let mut sample_data_writer = File::create(Path::new(&sample_data_path))?; + let mut sample_id_writer = File::create(Path::new(&sample_ids_path))?; + + let mut num_sampled_pts = 0u32; + let one_const = 1u32; + let mut generator = rand::thread_rng(); + let distribution = Uniform::from(0.0..1.0); + + let npts_u32 = reader.read_u32()?; + let dim_u32 = reader.read_u32()?; + let dim = dim_u32 as usize; + sample_data_writer.write_all(&num_sampled_pts.to_le_bytes())?; + sample_data_writer.write_all(&dim_u32.to_le_bytes())?; + sample_id_writer.write_all(&num_sampled_pts.to_le_bytes())?; + sample_id_writer.write_all(&one_const.to_le_bytes())?; + + for id in 0..npts_u32 { + let mut cur_row_bytes = vec![0u8; dim * mem::size_of::()]; + reader.read(&mut cur_row_bytes)?; + let random_value = distribution.sample(&mut generator); + if random_value < sampling_rate { + sample_data_writer.write_all(&cur_row_bytes)?; + 
sample_id_writer.write_all(&id.to_le_bytes())?; + num_sampled_pts += 1; + } + } + + sample_data_writer.seek(SeekFrom::Start(0))?; + sample_data_writer.write_all(&num_sampled_pts.to_le_bytes())?; + sample_id_writer.seek(SeekFrom::Start(0))?; + sample_id_writer.write_all(&num_sampled_pts.to_le_bytes())?; + println!("Wrote {} points to sample file: {}", num_sampled_pts, sample_data_path); + + Ok(()) +} + +#[cfg(test)] +mod partition_test { + use std::{fs, io::Read}; + use byteorder::{ReadBytesExt, LittleEndian}; + + use crate::utils::file_exists; + + use super::*; + + #[test] + fn gen_sample_data_test() { + let file_name = "gen_sample_data_test.bin"; + //npoints=2, dim=8 + let data: [u8; 72] = [2, 0, 0, 0, 8, 0, 0, 0, + 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x40, 0x40, 0x00, 0x00, 0x80, 0x40, + 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40, 0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, + 0x00, 0x00, 0x10, 0x41, 0x00, 0x00, 0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, + 0x00, 0x00, 0x50, 0x41, 0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41]; + std::fs::write(file_name, data).expect("Failed to write sample file"); + + let sample_file_prefix = file_name.to_string() + "_sample"; + gen_sample_data::(file_name, sample_file_prefix.as_str(), 1f64).unwrap(); + + let sample_data_path = format!("{}_data.bin", sample_file_prefix); + let sample_ids_path = format!("{}_ids.bin", sample_file_prefix); + assert!(file_exists(sample_data_path.as_str())); + assert!(file_exists(sample_ids_path.as_str())); + + let mut data_file_reader = File::open(sample_data_path.as_str()).unwrap(); + let mut ids_file_reader = File::open(sample_ids_path.as_str()).unwrap(); + + let mut num_sampled_pts = data_file_reader.read_u32::().unwrap(); + assert_eq!(num_sampled_pts, 2); + num_sampled_pts = ids_file_reader.read_u32::().unwrap(); + assert_eq!(num_sampled_pts, 2); + + let dim = data_file_reader.read_u32::().unwrap() as usize; + 
assert_eq!(dim, 8); + assert_eq!(ids_file_reader.read_u32::().unwrap(), 1); + + let mut start = 8; + for i in 0..num_sampled_pts { + let mut data_bytes = vec![0u8; dim * 4]; + data_file_reader.read_exact(&mut data_bytes).unwrap(); + assert_eq!(data_bytes, data[start..start + dim * 4]); + + let id = ids_file_reader.read_u32::().unwrap(); + assert_eq!(id, i); + + start += dim * 4; + } + + fs::remove_file(file_name).expect("Failed to delete file"); + fs::remove_file(sample_data_path.as_str()).expect("Failed to delete file"); + fs::remove_file(sample_ids_path.as_str()).expect("Failed to delete file"); + } +} + diff --git a/rust/diskann/src/utils/rayon_util.rs b/rust/diskann/src/utils/rayon_util.rs new file mode 100644 index 000000000..f8174ee59 --- /dev/null +++ b/rust/diskann/src/utils/rayon_util.rs @@ -0,0 +1,33 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::ops::Range; +use rayon::prelude::{IntoParallelIterator, ParallelIterator}; + +use crate::common::ANNResult; + +/// based on thread_num, execute the task in parallel using Rayon or serial +#[inline] +pub fn execute_with_rayon(range: Range, num_threads: u32, f: F) -> ANNResult<()> +where F: Fn(usize) -> ANNResult<()> + Sync + Send + Copy +{ + if num_threads == 1 { + for i in range { + f(i)?; + } + Ok(()) + } else { + range.into_par_iter().try_for_each(f) + } +} + +/// set the thread count of Rayon, otherwise it will use threads as many as logical cores. +#[inline] +pub fn set_rayon_num_threads(num_threads: u32) { + std::env::set_var( + "RAYON_NUM_THREADS", + num_threads.to_string(), + ); +} + diff --git a/rust/diskann/src/utils/timer.rs b/rust/diskann/src/utils/timer.rs new file mode 100644 index 000000000..2f4b38ba7 --- /dev/null +++ b/rust/diskann/src/utils/timer.rs @@ -0,0 +1,101 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use platform::*; +use std::time::{Duration, Instant}; + +#[derive(Clone)] +pub struct Timer { + check_point: Instant, + pid: Option, + cycles: Option, +} + +impl Default for Timer { + fn default() -> Self { + Self::new() + } +} + +impl Timer { + pub fn new() -> Timer { + let pid = get_process_handle(); + let cycles = get_process_cycle_time(pid); + Timer { + check_point: Instant::now(), + pid, + cycles, + } + } + + pub fn reset(&mut self) { + self.check_point = Instant::now(); + self.cycles = get_process_cycle_time(self.pid); + } + + pub fn elapsed(&self) -> Duration { + Instant::now().duration_since(self.check_point) + } + + pub fn elapsed_seconds(&self) -> f64 { + self.elapsed().as_secs_f64() + } + + pub fn elapsed_gcycles(&self) -> f32 { + let cur_cycles = get_process_cycle_time(self.pid); + if let (Some(cur_cycles), Some(cycles)) = (cur_cycles, self.cycles) { + let spent_cycles = + ((cur_cycles - cycles) as f64 * 1.0f64) / (1024 * 1024 * 1024) as f64; + return spent_cycles as f32; + } + + 0.0 + } + + pub fn elapsed_seconds_for_step(&self, step: &str) -> String { + format!( + "Time for {}: {:.3} seconds, {:.3}B cycles", + step, + self.elapsed_seconds(), + self.elapsed_gcycles() + ) + } +} + +#[cfg(test)] +mod timer_tests { + use super::*; + use std::{thread, time}; + + #[test] + fn test_new() { + let timer = Timer::new(); + assert!(timer.check_point.elapsed().as_secs() < 1); + if cfg!(windows) { + assert!(timer.pid.is_some()); + assert!(timer.cycles.is_some()); + } + else { + assert!(timer.pid.is_none()); + assert!(timer.cycles.is_none()); + } + } + + #[test] + fn test_reset() { + let mut timer = Timer::new(); + thread::sleep(time::Duration::from_millis(100)); + timer.reset(); + assert!(timer.check_point.elapsed().as_millis() < 10); + } + + #[test] + fn test_elapsed() { + let timer = Timer::new(); + thread::sleep(time::Duration::from_millis(100)); + assert!(timer.elapsed().as_millis() > 100); + assert!(timer.elapsed_seconds() > 0.1); + } +} + diff --git 
a/rust/diskann/src/utils/utils.rs b/rust/diskann/src/utils/utils.rs new file mode 100644 index 000000000..2e80676af --- /dev/null +++ b/rust/diskann/src/utils/utils.rs @@ -0,0 +1,154 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::sync::Mutex; +use num_traits::Num; + +/// Non recursive mutex +pub type NonRecursiveMutex = Mutex<()>; + +/// Round up X to the nearest multiple of Y +#[inline] +pub fn round_up(x: T, y: T) -> T +where T : Num + Copy +{ + div_round_up(x, y) * y +} + +/// Rounded-up division +#[inline] +pub fn div_round_up(x: T, y: T) -> T +where T : Num + Copy +{ + (x / y) + if x % y != T::zero() {T::one()} else {T::zero()} +} + +/// Round down X to the nearest multiple of Y +#[inline] +pub fn round_down(x: T, y: T) -> T +where T : Num + Copy +{ + (x / y) * y +} + +/// Is aligned +#[inline] +pub fn is_aligned(x: T, y: T) -> bool +where T : Num + Copy +{ + x % y == T::zero() +} + +#[inline] +pub fn is_512_aligned(x: u64) -> bool { + is_aligned(x, 512) +} + +#[inline] +pub fn is_4096_aligned(x: u64) -> bool { + is_aligned(x, 4096) +} + +/// all metadata of individual sub-component files is written in first 4KB for unified files +pub const METADATA_SIZE: usize = 4096; + +pub const BUFFER_SIZE_FOR_CACHED_IO: usize = 1024 * 1048576; + +pub const PBSTR: &str = "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||"; + +pub const PBWIDTH: usize = 60; + +macro_rules! 
convert_types { + ($name:ident, $intput_type:ty, $output_type:ty) => { + /// Write data into file + pub fn $name(srcmat: &[$intput_type], npts: usize, dim: usize) -> Vec<$output_type> { + let mut destmat: Vec<$output_type> = Vec::new(); + for i in 0..npts { + for j in 0..dim { + destmat.push(srcmat[i * dim + j] as $output_type); + } + } + destmat + } + }; +} +convert_types!(convert_types_usize_u8, usize, u8); +convert_types!(convert_types_usize_u32, usize, u32); +convert_types!(convert_types_usize_u64, usize, u64); +convert_types!(convert_types_u64_usize, u64, usize); +convert_types!(convert_types_u32_usize, u32, usize); + +#[cfg(test)] +mod file_util_test { + use super::*; + use std::any::type_name; + + #[test] + fn round_up_test() { + assert_eq!(round_up(252, 8), 256); + assert_eq!(round_up(256, 8), 256); + } + + #[test] + fn div_round_up_test() { + assert_eq!(div_round_up(252, 8), 32); + assert_eq!(div_round_up(256, 8), 32); + } + + #[test] + fn round_down_test() { + assert_eq!(round_down(252, 8), 248); + assert_eq!(round_down(256, 8), 256); + } + + #[test] + fn is_aligned_test() { + assert!(!is_aligned(252, 8)); + assert!(is_aligned(256, 8)); + } + + #[test] + fn is_512_aligned_test() { + assert!(!is_512_aligned(520)); + assert!(is_512_aligned(512)); + } + + #[test] + fn is_4096_aligned_test() { + assert!(!is_4096_aligned(4090)); + assert!(is_4096_aligned(4096)); + } + + #[test] + fn convert_types_test() { + let data = vec![0u64, 1u64, 2u64]; + let output = convert_types_u64_usize(&data, 3, 1); + assert_eq!(output.len(), 3); + assert_eq!(type_of(output[0]), "usize"); + assert_eq!(output[0], 0usize); + + let data = vec![0usize, 1usize, 2usize]; + let output = convert_types_usize_u8(&data, 3, 1); + assert_eq!(output.len(), 3); + assert_eq!(type_of(output[0]), "u8"); + assert_eq!(output[0], 0u8); + + let data = vec![0usize, 1usize, 2usize]; + let output = convert_types_usize_u64(&data, 3, 1); + assert_eq!(output.len(), 3); + assert_eq!(type_of(output[0]), "u64"); 
+ assert_eq!(output[0], 0u64); + + let data = vec![0u32, 1u32, 2u32]; + let output = convert_types_u32_usize(&data, 3, 1); + assert_eq!(output.len(), 3); + assert_eq!(type_of(output[0]), "usize"); + assert_eq!(output[0],0usize); + } + + fn type_of(_: T) -> &'static str { + type_name::() + } +} + diff --git a/rust/diskann/tests/data/delete_set_50pts.bin b/rust/diskann/tests/data/delete_set_50pts.bin new file mode 100644 index 000000000..8d520e7c7 Binary files /dev/null and b/rust/diskann/tests/data/delete_set_50pts.bin differ diff --git a/rust/diskann/tests/data/disk_index_node_data_aligned_reader_truth.bin b/rust/diskann/tests/data/disk_index_node_data_aligned_reader_truth.bin new file mode 100644 index 000000000..737a1a34d Binary files /dev/null and b/rust/diskann/tests/data/disk_index_node_data_aligned_reader_truth.bin differ diff --git a/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_alligned_reader_test.index b/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_alligned_reader_test.index new file mode 100644 index 000000000..55fcbb58d Binary files /dev/null and b/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_alligned_reader_test.index differ diff --git a/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index b/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index new file mode 100644 index 000000000..88a86b7da Binary files /dev/null and b/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index differ diff --git a/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_mem.index b/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_mem.index new file mode 100644 index 000000000..974535776 Binary files /dev/null and b/rust/diskann/tests/data/disk_index_siftsmall_learn_256pts_R4_L50_A1.2_mem.index differ diff --git a/rust/diskann/tests/data/siftsmall_learn.bin 
b/rust/diskann/tests/data/siftsmall_learn.bin new file mode 100644 index 000000000..e08c7af7a Binary files /dev/null and b/rust/diskann/tests/data/siftsmall_learn.bin differ diff --git a/rust/diskann/tests/data/siftsmall_learn.bin_pq_compressed.bin b/rust/diskann/tests/data/siftsmall_learn.bin_pq_compressed.bin new file mode 100644 index 000000000..5f1ddab29 Binary files /dev/null and b/rust/diskann/tests/data/siftsmall_learn.bin_pq_compressed.bin differ diff --git a/rust/diskann/tests/data/siftsmall_learn.bin_pq_pivots.bin b/rust/diskann/tests/data/siftsmall_learn.bin_pq_pivots.bin new file mode 100644 index 000000000..e84f8d8a9 Binary files /dev/null and b/rust/diskann/tests/data/siftsmall_learn.bin_pq_pivots.bin differ diff --git a/rust/diskann/tests/data/siftsmall_learn_256pts.fbin b/rust/diskann/tests/data/siftsmall_learn_256pts.fbin new file mode 100644 index 000000000..357a9db87 Binary files /dev/null and b/rust/diskann/tests/data/siftsmall_learn_256pts.fbin differ diff --git a/rust/diskann/tests/data/siftsmall_learn_256pts_2.fbin b/rust/diskann/tests/data/siftsmall_learn_256pts_2.fbin new file mode 100644 index 000000000..9528e4bd9 Binary files /dev/null and b/rust/diskann/tests/data/siftsmall_learn_256pts_2.fbin differ diff --git a/rust/diskann/tests/data/truth_disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index b/rust/diskann/tests/data/truth_disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index new file mode 100644 index 000000000..55fcbb58d Binary files /dev/null and b/rust/diskann/tests/data/truth_disk_index_siftsmall_learn_256pts_R4_L50_A1.2_disk.index differ diff --git a/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_1+2_R4_L50_A1.2 b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_1+2_R4_L50_A1.2 new file mode 100644 index 000000000..9c803c3fa Binary files /dev/null and b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_1+2_R4_L50_A1.2 differ diff --git 
a/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_1+2_saturated_R4_L50_A1.2 b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_1+2_saturated_R4_L50_A1.2 new file mode 100644 index 000000000..a9dac1013 Binary files /dev/null and b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_1+2_saturated_R4_L50_A1.2 differ diff --git a/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_R4_L50_A1.2 b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_R4_L50_A1.2 new file mode 100644 index 000000000..817009044 Binary files /dev/null and b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_R4_L50_A1.2 differ diff --git a/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_R4_L50_A1.2.data b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_R4_L50_A1.2.data new file mode 100644 index 000000000..357a9db87 Binary files /dev/null and b/rust/diskann/tests/data/truth_index_siftsmall_learn_256pts_R4_L50_A1.2.data differ diff --git a/rust/logger/Cargo.toml b/rust/logger/Cargo.toml new file mode 100644 index 000000000..e750d9530 --- /dev/null +++ b/rust/logger/Cargo.toml @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
+[package] +name = "logger" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +lazy_static = "1.4.0" +log="0.4.17" +once_cell = "1.17.1" +prost = "0.11.9" +prost-types = "0.11.9" +thiserror = "1.0.40" +win_etw_macros="0.1.8" +win_etw_provider="0.1.8" + +[build-dependencies] +prost-build = "0.11.9" + +[[example]] +name="trace_example" +path= "src/examples/trace_example.rs" + +[target."cfg(target_os=\"windows\")".build-dependencies.vcpkg] +version = "0.2" + diff --git a/rust/logger/build.rs b/rust/logger/build.rs new file mode 100644 index 000000000..76058f768 --- /dev/null +++ b/rust/logger/build.rs @@ -0,0 +1,33 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::env; + +extern crate prost_build; + +fn main() { + let protopkg = vcpkg::find_package("protobuf").unwrap(); + let protobuf_path = protopkg.link_paths[0].parent().unwrap(); + + let protobuf_bin_path = protobuf_path + .join("tools") + .join("protobuf") + .join("protoc.exe") + .to_str() + .unwrap() + .to_string(); + env::set_var("PROTOC", protobuf_bin_path); + + let protobuf_inc_path = protobuf_path + .join("include") + .join("google") + .join("protobuf") + .to_str() + .unwrap() + .to_string(); + env::set_var("PROTOC_INCLUDE", protobuf_inc_path); + + prost_build::compile_protos(&["src/indexlog.proto"], &["src/"]).unwrap(); +} + diff --git a/rust/logger/src/error_logger.rs b/rust/logger/src/error_logger.rs new file mode 100644 index 000000000..50069b477 --- /dev/null +++ b/rust/logger/src/error_logger.rs @@ -0,0 +1,29 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use crate::log_error::LogError; +use crate::logger::indexlog::{ErrorLog, Log, LogLevel}; +use crate::message_handler::send_log; + +pub fn log_error(error_message: String) -> Result<(), LogError> { + let mut log = Log::default(); + let error_log = ErrorLog { + log_level: LogLevel::Error as i32, + error_message, + }; + log.error_log = Some(error_log); + + send_log(log) +} + +#[cfg(test)] +mod error_logger_test { + use super::*; + + #[test] + fn log_error_works() { + log_error(String::from("Error")).unwrap(); + } +} + diff --git a/rust/logger/src/examples/trace_example.rs b/rust/logger/src/examples/trace_example.rs new file mode 100644 index 000000000..7933a5699 --- /dev/null +++ b/rust/logger/src/examples/trace_example.rs @@ -0,0 +1,30 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use log::{debug, info, log_enabled, warn, Level}; +use logger::trace_logger::TraceLogger; + +// cargo run --example trace_example + +fn main() { + static LOGGER: TraceLogger = TraceLogger {}; + log::set_logger(&LOGGER) + .map(|()| log::set_max_level(log::LevelFilter::Trace)) + .unwrap(); + + info!("Rust logging n = {}", 42); + warn!("This is too much fun!"); + debug!("Maybe we can make this code work"); + + let error_is_enabled = log_enabled!(Level::Error); + let warn_is_enabled = log_enabled!(Level::Warn); + let info_is_enabled = log_enabled!(Level::Info); + let debug_is_enabled = log_enabled!(Level::Debug); + let trace_is_enabled = log_enabled!(Level::Trace); + println!( + "is_enabled? 
error: {:5?}, warn: {:5?}, info: {:5?}, debug: {:5?}, trace: {:5?}", + error_is_enabled, warn_is_enabled, info_is_enabled, debug_is_enabled, trace_is_enabled, + ); +} + diff --git a/rust/logger/src/indexlog.proto b/rust/logger/src/indexlog.proto new file mode 100644 index 000000000..68310ae41 --- /dev/null +++ b/rust/logger/src/indexlog.proto @@ -0,0 +1,50 @@ +syntax = "proto3"; + +package diskann_logger; + +message Log { + IndexConstructionLog IndexConstructionLog = 1; + DiskIndexConstructionLog DiskIndexConstructionLog = 2; + ErrorLog ErrorLog = 3; + TraceLog TraceLog = 100; +} + +enum LogLevel { + UNSPECIFIED = 0; + Error = 1; + Warn = 2; + Info = 3; + Debug = 4; + Trace = 5; +} + +message IndexConstructionLog { + float PercentageComplete = 1; + float TimeSpentInSeconds = 2; + float GCyclesSpent = 3; + LogLevel LogLevel = 4; +} + +message DiskIndexConstructionLog { + DiskIndexConstructionCheckpoint checkpoint = 1; + float TimeSpentInSeconds = 2; + float GCyclesSpent = 3; + LogLevel LogLevel = 4; +} + +enum DiskIndexConstructionCheckpoint { + None = 0; + PqConstruction = 1; + InmemIndexBuild = 2; + DiskLayout = 3; +} + +message TraceLog { + string LogLine = 1; + LogLevel LogLevel = 2; +} + +message ErrorLog { + string ErrorMessage = 1; + LogLevel LogLevel = 2; +} \ No newline at end of file diff --git a/rust/logger/src/lib.rs b/rust/logger/src/lib.rs new file mode 100644 index 000000000..6cfe2d589 --- /dev/null +++ b/rust/logger/src/lib.rs @@ -0,0 +1,19 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +#![cfg_attr( + not(test), + warn(clippy::panic, clippy::unwrap_used, clippy::expect_used) +)] + +pub mod logger { + pub mod indexlog { + include!(concat!(env!("OUT_DIR"), "/diskann_logger.rs")); + } +} + +pub mod error_logger; +pub mod log_error; +pub mod message_handler; +pub mod trace_logger; diff --git a/rust/logger/src/log_error.rs b/rust/logger/src/log_error.rs new file mode 100644 index 000000000..149d094a2 --- /dev/null +++ b/rust/logger/src/log_error.rs @@ -0,0 +1,27 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::sync::mpsc::SendError; + +use crate::logger::indexlog::Log; + +#[derive(thiserror::Error, Debug, Clone)] +pub enum LogError { + /// Sender failed to send message to the channel + #[error("IOError: {err}")] + SendError { + #[from] + err: SendError, + }, + + /// PoisonError which can be returned whenever a lock is acquired + /// Both Mutexes and RwLocks are poisoned whenever a thread fails while the lock is held + #[error("LockPoisonError: {err}")] + LockPoisonError { err: String }, + + /// Failed to create EtwPublisher + #[error("EtwProviderError: {err:?}")] + ETWProviderError { err: win_etw_provider::Error }, +} + diff --git a/rust/logger/src/message_handler.rs b/rust/logger/src/message_handler.rs new file mode 100644 index 000000000..37f352a28 --- /dev/null +++ b/rust/logger/src/message_handler.rs @@ -0,0 +1,167 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use crate::log_error::LogError; +use crate::logger::indexlog::DiskIndexConstructionCheckpoint; +use crate::logger::indexlog::Log; +use crate::logger::indexlog::LogLevel; + +use std::sync::mpsc::{self, Sender}; +use std::sync::Mutex; +use std::thread; + +use win_etw_macros::trace_logging_provider; + +trait MessagePublisher { + fn publish(&self, log_level: LogLevel, message: &str); +} + +// ETW provider - the GUID specified here is that of the default provider for Geneva Metric Extensions +// We are just using it as a placeholder until we have a version of OpenTelemetry exporter for Rust +#[trace_logging_provider(guid = "edc24920-e004-40f6-a8e1-0e6e48f39d84")] +trait EtwTraceProvider { + fn write(msg: &str); +} + +struct EtwPublisher { + provider: EtwTraceProvider, + publish_to_stdout: bool, +} + +impl EtwPublisher { + pub fn new() -> Result { + let provider = EtwTraceProvider::new(); + Ok(EtwPublisher { + provider, + publish_to_stdout: true, + }) + } +} + +fn log_level_to_etw(level: LogLevel) -> win_etw_provider::Level { + match level { + LogLevel::Error => win_etw_provider::Level::ERROR, + LogLevel::Warn => win_etw_provider::Level::WARN, + LogLevel::Info => win_etw_provider::Level::INFO, + LogLevel::Debug => win_etw_provider::Level::VERBOSE, + LogLevel::Trace => win_etw_provider::Level(6), + LogLevel::Unspecified => win_etw_provider::Level(6), + } +} + +fn i32_to_log_level(value: i32) -> LogLevel { + match value { + 0 => LogLevel::Unspecified, + 1 => LogLevel::Error, + 2 => LogLevel::Warn, + 3 => LogLevel::Info, + 4 => LogLevel::Debug, + 5 => LogLevel::Trace, + _ => LogLevel::Unspecified, + } +} + +impl MessagePublisher for EtwPublisher { + fn publish(&self, log_level: LogLevel, message: &str) { + let options = win_etw_provider::EventOptions { + level: Some(log_level_to_etw(log_level)), + ..Default::default() + }; + self.provider.write(Some(&options), message); + + if self.publish_to_stdout { + println!("{}", message); + } + } +} + +struct MessageProcessor { 
+ sender: Mutex>, +} + +impl MessageProcessor { + pub fn start_processing() -> Self { + let (sender, receiver) = mpsc::channel::(); + thread::spawn(move || -> Result<(), LogError> { + for message in receiver { + // Process the received message + if let Some(indexlog) = message.index_construction_log { + let str = format!( + "Time for {}% of index build completed: {:.3} seconds, {:.3}B cycles", + indexlog.percentage_complete, + indexlog.time_spent_in_seconds, + indexlog.g_cycles_spent + ); + publish(i32_to_log_level(indexlog.log_level), &str)?; + } + + if let Some(disk_index_log) = message.disk_index_construction_log { + let str = format!( + "Time for disk index build [Checkpoint: {:?}] completed: {:.3} seconds, {:.3}B cycles", + DiskIndexConstructionCheckpoint::from_i32(disk_index_log.checkpoint).unwrap_or(DiskIndexConstructionCheckpoint::None), + disk_index_log.time_spent_in_seconds, + disk_index_log.g_cycles_spent + ); + publish(i32_to_log_level(disk_index_log.log_level), &str)?; + } + + if let Some(tracelog) = message.trace_log { + let str = format!("{}:{}", tracelog.log_level, tracelog.log_line); + publish(i32_to_log_level(tracelog.log_level), &str)?; + } + + if let Some(err) = message.error_log { + publish(i32_to_log_level(err.log_level), &err.error_message)?; + } + } + + Ok(()) + }); + + let sender = Mutex::new(sender); + MessageProcessor { sender } + } + + /// Log the message. + fn log(&self, message: Log) -> Result<(), LogError> { + Ok(self + .sender + .lock() + .map_err(|err| LogError::LockPoisonError { + err: err.to_string(), + })? + .send(message)?) + } +} + +lazy_static::lazy_static! { + /// Singleton logger. + static ref PROCESSOR: MessageProcessor = { + + MessageProcessor::start_processing() + }; +} + +lazy_static::lazy_static! { + /// Singleton publisher. + static ref PUBLISHER: Result = { + EtwPublisher::new() + }; +} + +/// Send a message to the logging system. 
+pub fn send_log(message: Log) -> Result<(), LogError> { + PROCESSOR.log(message) +} + +fn publish(log_level: LogLevel, message: &str) -> Result<(), LogError> { + match *PUBLISHER { + Ok(ref etw_publisher) => { + etw_publisher.publish(log_level, message); + Ok(()) + } + Err(ref err) => Err(LogError::ETWProviderError { err: err.clone() }), + } +} + diff --git a/rust/logger/src/trace_logger.rs b/rust/logger/src/trace_logger.rs new file mode 100644 index 000000000..96ef38611 --- /dev/null +++ b/rust/logger/src/trace_logger.rs @@ -0,0 +1,41 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use crate::logger::indexlog::{Log, TraceLog}; +use crate::message_handler::send_log; + +use log; + +pub struct TraceLogger {} + +fn level_to_i32(value: log::Level) -> i32 { + match value { + log::Level::Error => 1, + log::Level::Warn => 2, + log::Level::Info => 3, + log::Level::Debug => 4, + log::Level::Trace => 5, + } +} + +impl log::Log for TraceLogger { + fn enabled(&self, metadata: &log::Metadata) -> bool { + metadata.level() <= log::max_level() + } + + fn log(&self, record: &log::Record) { + let message = record.args().to_string(); + let metadata = record.metadata(); + let mut log = Log::default(); + let trace_log = TraceLog { + log_line: message, + log_level: level_to_i32(metadata.level()), + }; + log.trace_log = Some(trace_log); + let _ = send_log(log); + } + + fn flush(&self) {} +} + diff --git a/rust/platform/Cargo.toml b/rust/platform/Cargo.toml new file mode 100644 index 000000000..057f9e852 --- /dev/null +++ b/rust/platform/Cargo.toml @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
+[package] +name = "platform" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +log="0.4.18" +winapi = { version = "0.3.9", features = ["errhandlingapi", "fileapi", "ioapiset", "handleapi", "winnt", "minwindef", "basetsd", "winerror", "winbase"] } + diff --git a/rust/platform/src/file_handle.rs b/rust/platform/src/file_handle.rs new file mode 100644 index 000000000..23da8796a --- /dev/null +++ b/rust/platform/src/file_handle.rs @@ -0,0 +1,212 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::ffi::CString; +use std::{io, ptr}; + +use winapi::um::fileapi::OPEN_EXISTING; +use winapi::um::winbase::{FILE_FLAG_NO_BUFFERING, FILE_FLAG_OVERLAPPED, FILE_FLAG_RANDOM_ACCESS}; +use winapi::um::winnt::{FILE_SHARE_DELETE, FILE_SHARE_READ, FILE_SHARE_WRITE, GENERIC_READ, GENERIC_WRITE}; + +use winapi::{ + shared::minwindef::DWORD, + um::{ + errhandlingapi::GetLastError, + fileapi::CreateFileA, + handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, + winnt::HANDLE, + }, +}; + +pub const FILE_ATTRIBUTE_READONLY: DWORD = 0x00000001; + +/// `AccessMode` determines how a file can be accessed. +/// These modes are used when creating or opening a file to decide what operations are allowed +/// to be performed on the file. +/// +/// # Variants +/// +/// - `Read`: The file is opened in read-only mode. +/// +/// - `Write`: The file is opened in write-only mode. +/// +/// - `ReadWrite`: The file is opened for both reading and writing. +pub enum AccessMode { + Read, + Write, + ReadWrite, +} + +/// `ShareMode` determines how a file can be shared. +/// +/// These modes are used when creating or opening a file to decide what operations other +/// opening instances of the file can perform on it. +/// # Variants +/// - `None`: Prevents other processes from opening a file if they request delete, +/// read, or write access. 
+/// +/// - `Read`: Allows subsequent open operations on the same file to request read access. +/// +/// - `Write`: Allows subsequent open operations on the same file file to request write access. +/// +/// - `Delete`: Allows subsequent open operations on the same file file to request delete access. +pub enum ShareMode { + None, + Read, + Write, + Delete, +} + +/// # Windows File Handle Wrapper +/// +/// Introduces a Rust-friendly wrapper around the native Windows `HANDLE` object, `FileHandle`. +/// `FileHandle` provides safe creation and automatic cleanup of Windows file handles, leveraging Rust's ownership model. + +/// `FileHandle` struct that wraps a native Windows `HANDLE` object +#[cfg(target_os = "windows")] +pub struct FileHandle { + handle: HANDLE, +} + +impl FileHandle { + /// Creates a new `FileHandle` by opening an existing file with the given access and shared mode. + /// + /// This function is marked unsafe because it creates a raw pointer to the filename and try to create + /// a Windows `HANDLE` object without checking if you have sufficient permissions. + /// + /// # Safety + /// + /// Ensure that the file specified by `file_name` is valid and the calling process has + /// sufficient permissions to perform the specified `access_mode` and `share_mode` operations. + /// + /// # Parameters + /// + /// - `file_name`: The name of the file. + /// - `access_mode`: The access mode to be used for the file. + /// - `share_mode`: The share mode to be used for the file + /// + /// # Errors + /// This function will return an error if the `file_name` is invalid or if the file cannot + /// be opened with the specified `access_mode` and `share_mode`. + pub unsafe fn new( + file_name: &str, + access_mode: AccessMode, + share_mode: ShareMode, + ) -> io::Result { + let file_name_c = CString::new(file_name).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidData, + format!("Invalid file name. 
{}", file_name), + ) + })?; + + let dw_desired_access = match access_mode { + AccessMode::Read => GENERIC_READ, + AccessMode::Write => GENERIC_WRITE, + AccessMode::ReadWrite => GENERIC_READ | GENERIC_WRITE, + }; + + let dw_share_mode = match share_mode { + ShareMode::None => 0, + ShareMode::Read => FILE_SHARE_READ, + ShareMode::Write => FILE_SHARE_WRITE, + ShareMode::Delete => FILE_SHARE_DELETE, + }; + + let dw_flags_and_attributes = FILE_ATTRIBUTE_READONLY + | FILE_FLAG_NO_BUFFERING + | FILE_FLAG_OVERLAPPED + | FILE_FLAG_RANDOM_ACCESS; + + let handle = unsafe { + CreateFileA( + file_name_c.as_ptr(), + dw_desired_access, + dw_share_mode, + ptr::null_mut(), + OPEN_EXISTING, + dw_flags_and_attributes, + ptr::null_mut(), + ) + }; + + if handle == INVALID_HANDLE_VALUE { + let error_code = unsafe { GetLastError() }; + Err(io::Error::from_raw_os_error(error_code as i32)) + } else { + Ok(Self { handle }) + } + } + + pub fn raw_handle(&self) -> HANDLE { + self.handle + } +} + +impl Drop for FileHandle { + /// Automatically closes the `FileHandle` when it goes out of scope. + /// Any errors in closing the handle are logged, as `Drop` does not support returning `Result`. + fn drop(&mut self) { + let result = unsafe { CloseHandle(self.handle) }; + if result == 0 { + let error_code = unsafe { GetLastError() }; + let error = io::Error::from_raw_os_error(error_code as i32); + + // Only log the error if dropping the handle fails, since Rust's Drop trait does not support returning Result types from the drop method, + // and panicking in the drop method is considered bad practice + log::warn!("Error when dropping IOCompletionPort: {:?}", error); + } + } +} + +/// Returns a `FileHandle` with an `INVALID_HANDLE_VALUE`. 
+impl Default for FileHandle { + fn default() -> Self { + Self { + handle: INVALID_HANDLE_VALUE, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::fs::File; + use std::path::Path; + + #[test] + fn test_create_file() { + // Create a dummy file + let dummy_file_path = "dummy_file.txt"; + { + let _file = File::create(dummy_file_path).expect("Failed to create dummy file."); + } + + let path = Path::new(dummy_file_path); + { + let file_handle = unsafe { + FileHandle::new(path.to_str().unwrap(), AccessMode::Read, ShareMode::Read) + }; + + // Check that the file handle is valid + assert!(file_handle.is_ok()); + } + + // Try to delete the file. If the handle was correctly dropped, this should succeed. + match std::fs::remove_file(dummy_file_path) { + Ok(()) => (), // File was deleted successfully, which means the handle was closed. + Err(e) => panic!("Failed to delete file: {}", e), // Failed to delete the file, likely because the handle is still open. + } + } + + #[test] + fn test_file_not_found() { + let path = Path::new("non_existent_file.txt"); + let file_handle = + unsafe { FileHandle::new(path.to_str().unwrap(), AccessMode::Read, ShareMode::Read) }; + + // Check that opening a non-existent file returns an error + assert!(file_handle.is_err()); + } +} diff --git a/rust/platform/src/file_io.rs b/rust/platform/src/file_io.rs new file mode 100644 index 000000000..e5de24773 --- /dev/null +++ b/rust/platform/src/file_io.rs @@ -0,0 +1,154 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +/// The module provides unsafe wrappers around two Windows API functions: `ReadFile` and `GetQueuedCompletionStatus`. +/// +/// These wrappers aim to simplify and abstract the use of these functions, providing easier error handling and a safer interface. +/// They return standard Rust `io::Result` types for convenience and consistency with the rest of the Rust standard library. 
+use std::io; +use std::ptr; + +use winapi::{ + ctypes::c_void, + shared::{ + basetsd::ULONG_PTR, + minwindef::{DWORD, FALSE}, + winerror::{ERROR_IO_PENDING, WAIT_TIMEOUT}, + }, + um::{ + errhandlingapi::GetLastError, fileapi::ReadFile, ioapiset::GetQueuedCompletionStatus, + minwinbase::OVERLAPPED, + }, +}; + +use crate::FileHandle; +use crate::IOCompletionPort; + +/// Asynchronously queue a read request from a file into a buffer slice. +/// +/// Wraps the unsafe Windows API function `ReadFile`, making it safe to call only when the overlapped buffer +/// remains valid and unchanged anywhere else during the entire async operation. +/// +/// Returns a boolean indicating whether the read operation completed synchronously or is pending. +/// +/// # Safety +/// +/// This function is marked as `unsafe` because it uses raw pointers and requires the caller to ensure +/// that the buffer slice and the overlapped buffer stay valid during the whole async operation. +pub unsafe fn read_file_to_slice( + file_handle: &FileHandle, + buffer_slice: &mut [T], + overlapped: *mut OVERLAPPED, + offset: u64, +) -> io::Result { + let num_bytes = std::mem::size_of_val(buffer_slice); + unsafe { + ptr::write(overlapped, std::mem::zeroed()); + (*overlapped).u.s_mut().Offset = offset as u32; + (*overlapped).u.s_mut().OffsetHigh = (offset >> 32) as u32; + } + + let result = unsafe { + ReadFile( + file_handle.raw_handle(), + buffer_slice.as_mut_ptr() as *mut c_void, + num_bytes as DWORD, + ptr::null_mut(), + overlapped, + ) + }; + + match result { + FALSE => { + let error = unsafe { GetLastError() }; + if error != ERROR_IO_PENDING { + Err(io::Error::from_raw_os_error(error as i32)) + } else { + Ok(false) + } + } + _ => Ok(true), + } +} + +/// Retrieves the results of an asynchronous I/O operation on an I/O completion port. 
+/// +/// Wraps the unsafe Windows API function `GetQueuedCompletionStatus`, making it safe to call only when the overlapped buffer +/// remains valid and unchanged anywhere else during the entire async operation. +/// +/// Returns a boolean indicating whether an I/O operation completed synchronously or is still pending. +/// +/// # Safety +/// +/// This function is marked as `unsafe` because it uses raw pointers and requires the caller to ensure +/// that the overlapped buffer stays valid during the whole async operation. +pub unsafe fn get_queued_completion_status( + completion_port: &IOCompletionPort, + lp_number_of_bytes: &mut DWORD, + lp_completion_key: &mut ULONG_PTR, + lp_overlapped: *mut *mut OVERLAPPED, + dw_milliseconds: DWORD, +) -> io::Result { + let result = unsafe { + GetQueuedCompletionStatus( + completion_port.raw_handle(), + lp_number_of_bytes, + lp_completion_key, + lp_overlapped, + dw_milliseconds, + ) + }; + + match result { + 0 => { + let error = unsafe { GetLastError() }; + if error == WAIT_TIMEOUT { + Ok(false) + } else { + Err(io::Error::from_raw_os_error(error as i32)) + } + } + _ => Ok(true), + } +} + +#[cfg(test)] +mod tests { + use crate::file_handle::{AccessMode, ShareMode}; + + use super::*; + use std::fs::File; + use std::io::Write; + use std::path::Path; + + #[test] + fn test_read_file_to_slice() { + // Create a temporary file and write some data into it + let path = Path::new("temp.txt"); + { + let mut file = File::create(path).unwrap(); + file.write_all(b"Hello, world!").unwrap(); + } + + let mut buffer: [u8; 512] = [0; 512]; + let mut overlapped = unsafe { std::mem::zeroed::() }; + { + let file_handle = unsafe { + FileHandle::new(path.to_str().unwrap(), AccessMode::Read, ShareMode::Read) + } + .unwrap(); + + // Call the function under test + let result = + unsafe { read_file_to_slice(&file_handle, &mut buffer, &mut overlapped, 0) }; + + assert!(result.is_ok()); + let result_str = std::str::from_utf8(&buffer[.."Hello, 
world!".len()]).unwrap(); + assert_eq!(result_str, "Hello, world!"); + } + + // Clean up + std::fs::remove_file("temp.txt").unwrap(); + } +} diff --git a/rust/platform/src/io_completion_port.rs b/rust/platform/src/io_completion_port.rs new file mode 100644 index 000000000..5bb332281 --- /dev/null +++ b/rust/platform/src/io_completion_port.rs @@ -0,0 +1,142 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +use std::io; + +use winapi::{ + ctypes::c_void, + shared::{basetsd::ULONG_PTR, minwindef::DWORD}, + um::{ + errhandlingapi::GetLastError, + handleapi::{CloseHandle, INVALID_HANDLE_VALUE}, + ioapiset::CreateIoCompletionPort, + winnt::HANDLE, + }, +}; + +use crate::FileHandle; + +/// This module provides a safe and idiomatic Rust interface over the IOCompletionPort handle and associated Windows API functions. +/// This struct represents an I/O completion port, which is an object used in asynchronous I/O operations on Windows. +pub struct IOCompletionPort { + io_completion_port: HANDLE, +} + +impl IOCompletionPort { + /// Create a new IOCompletionPort. + /// This function wraps the Windows CreateIoCompletionPort function, providing error handling and automatic resource management. + /// + /// # Arguments + /// + /// * `file_handle` - A reference to a FileHandle to associate with the IOCompletionPort. + /// * `existing_completion_port` - An optional reference to an existing IOCompletionPort. If provided, the new IOCompletionPort will be associated with it. + /// * `completion_key` - The completion key associated with the file handle. + /// * `number_of_concurrent_threads` - The maximum number of threads that the operating system can allow to concurrently process I/O completion packets for the I/O completion port. + /// + /// # Return + /// + /// Returns a Result with the new IOCompletionPort if successful, or an io::Error if the function fails. 
+ pub fn new( + file_handle: &FileHandle, + existing_completion_port: Option<&IOCompletionPort>, + completion_key: ULONG_PTR, + number_of_concurrent_threads: DWORD, + ) -> io::Result { + let io_completion_port = unsafe { + CreateIoCompletionPort( + file_handle.raw_handle(), + existing_completion_port + .map_or(std::ptr::null_mut::(), |io_completion_port| { + io_completion_port.raw_handle() + }), + completion_key, + number_of_concurrent_threads, + ) + }; + + if io_completion_port == INVALID_HANDLE_VALUE { + let error_code = unsafe { GetLastError() }; + return Err(io::Error::from_raw_os_error(error_code as i32)); + } + + Ok(IOCompletionPort { io_completion_port }) + } + + pub fn raw_handle(&self) -> HANDLE { + self.io_completion_port + } +} + +impl Drop for IOCompletionPort { + /// Drop method for IOCompletionPort. + /// This wraps the Windows CloseHandle function, providing automatic resource cleanup when the IOCompletionPort is dropped. + /// If an error occurs while dropping, it is logged and the drop continues. This is because panicking in Drop can cause unwinding issues. + fn drop(&mut self) { + let result = unsafe { CloseHandle(self.io_completion_port) }; + if result == 0 { + let error_code = unsafe { GetLastError() }; + let error = io::Error::from_raw_os_error(error_code as i32); + + // Only log the error if dropping the handle fails, since Rust's Drop trait does not support returning Result types from the drop method, + // and panicking in the drop method is considered bad practice + log::warn!("Error when dropping IOCompletionPort: {:?}", error); + } + } +} + +impl Default for IOCompletionPort { + /// Create a default IOCompletionPort, whose handle is set to INVALID_HANDLE_VALUE. + /// Returns a new IOCompletionPort with handle set to INVALID_HANDLE_VALUE. 
+ fn default() -> Self { + Self { + io_completion_port: INVALID_HANDLE_VALUE, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::file_handle::{AccessMode, ShareMode}; + + #[test] + fn create_io_completion_port() { + let file_name = "../diskann/tests/data/delete_set_50pts.bin"; + let file_handle = unsafe { FileHandle::new(file_name, AccessMode::Read, ShareMode::Read) } + .expect("Failed to create file handle."); + + let io_completion_port = IOCompletionPort::new(&file_handle, None, 0, 0); + + assert!( + io_completion_port.is_ok(), + "Failed to create IOCompletionPort." + ); + } + + #[test] + fn drop_io_completion_port() { + let file_name = "../diskann/tests/data/delete_set_50pts.bin"; + let file_handle = unsafe { FileHandle::new(file_name, AccessMode::Read, ShareMode::Read) } + .expect("Failed to create file handle."); + + let io_completion_port = IOCompletionPort::new(&file_handle, None, 0, 0) + .expect("Failed to create IOCompletionPort."); + + // After this line, io_completion_port goes out of scope and its Drop trait will be called. + let _ = io_completion_port; + // We have no easy way to test that the Drop trait works correctly, but if it doesn't, + // a resource leak or other problem may become apparent in later tests or in real use of the code. + } + + #[test] + fn default_io_completion_port() { + let io_completion_port = IOCompletionPort::default(); + assert_eq!( + io_completion_port.raw_handle(), + INVALID_HANDLE_VALUE, + "Default IOCompletionPort did not have INVALID_HANDLE_VALUE." + ); + } +} + diff --git a/rust/platform/src/lib.rs b/rust/platform/src/lib.rs new file mode 100644 index 000000000..e28257078 --- /dev/null +++ b/rust/platform/src/lib.rs @@ -0,0 +1,20 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +#![cfg_attr( + not(test), + warn(clippy::panic, clippy::unwrap_used, clippy::expect_used) +)] + +pub mod perf; +pub use perf::{get_process_cycle_time, get_process_handle}; + +pub mod file_io; +pub use file_io::{get_queued_completion_status, read_file_to_slice}; + +pub mod file_handle; +pub use file_handle::FileHandle; + +pub mod io_completion_port; +pub use io_completion_port::IOCompletionPort; diff --git a/rust/platform/src/perf.rs b/rust/platform/src/perf.rs new file mode 100644 index 000000000..1ea146f9a --- /dev/null +++ b/rust/platform/src/perf.rs @@ -0,0 +1,50 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#[cfg(target_os = "windows")] +#[link(name = "kernel32")] +extern "system" { + fn OpenProcess(dwDesiredAccess: u32, bInheritHandle: bool, dwProcessId: u32) -> usize; + fn QueryProcessCycleTime(hProcess: usize, lpCycleTime: *mut u64) -> bool; + fn GetCurrentProcessId() -> u32; +} + +/// Get current process handle. +pub fn get_process_handle() -> Option { + if cfg!(windows) { + const PROCESS_QUERY_INFORMATION: u32 = 0x0400; + const PROCESS_VM_READ: u32 = 0x0010; + + unsafe { + let current_process_id = GetCurrentProcessId(); + let handle = OpenProcess( + PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, + false, + current_process_id, + ); + if handle == 0 { + None + } else { + Some(handle) + } + } + } else { + None + } +} + +pub fn get_process_cycle_time(process_handle: Option) -> Option { + let mut cycle_time: u64 = 0; + if cfg!(windows) { + if let Some(handle) = process_handle { + let result = unsafe { QueryProcessCycleTime(handle, &mut cycle_time as *mut u64) }; + if result { + return Some(cycle_time); + } + } + } + + None +} + diff --git a/rust/project.code-workspace b/rust/project.code-workspace new file mode 100644 index 000000000..29bed0024 --- /dev/null +++ b/rust/project.code-workspace @@ -0,0 +1,58 @@ +{ + "folders": [ + { + "path": "." 
+ } + ], + "settings": { + "search.exclude": { + "target": true, + }, + "files.exclude": { + "target": true, + }, + "rust-analyzer.linkedProjects": [ + ".\\vector\\Cargo.toml", + ".\\vector\\Cargo.toml", + ".\\vector\\Cargo.toml", + ".\\diskann\\Cargo.toml" + ], + "[rust]": { + "editor.defaultFormatter": "rust-lang.rust-analyzer", + "editor.formatOnSave": true, + } + }, + "launch": { + "version": "0.2.0", + "configurations": [ + { + "name": "Build memory index", + "type": "cppvsdbg", + "request": "launch", + "program": "${workspaceRoot}\\target\\debug\\build_memory_index.exe", + "args": [ + "--data_type", + "float", + "--dist_fn", + "l2", + "--data_path", + ".\\base1m.fbin", + "--index_path_prefix", + ".\\rust_index_sift_base_R32_L50_A1.2_T1", + "-R", + "64", + "-L", + "100", + "--alpha", + "1.2", + "-T", + "1" + ], + "stopAtEntry": false, + "cwd": "c:\\data", + "environment": [], + "externalConsole": true + }, + ] + } +} \ No newline at end of file diff --git a/rust/readme.md b/rust/readme.md new file mode 100644 index 000000000..a6c5a1bd4 --- /dev/null +++ b/rust/readme.md @@ -0,0 +1,25 @@ + +# readme + +run commands under disnann_rust directory. + +build: +``` +cargo build // Debug + +cargo build -r // Release +``` + + +run: +``` +cargo run // Debug + +cargo run -r // Release +``` + + +test: +``` +cargo test +``` diff --git a/rust/rust-toolchain.toml b/rust/rust-toolchain.toml new file mode 100644 index 000000000..183a72c9c --- /dev/null +++ b/rust/rust-toolchain.toml @@ -0,0 +1,4 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. +[toolchain] +channel = "stable" diff --git a/rust/vector/Cargo.toml b/rust/vector/Cargo.toml new file mode 100644 index 000000000..709a2905c --- /dev/null +++ b/rust/vector/Cargo.toml @@ -0,0 +1,24 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. 
+[package] +name = "vector" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +half = "2.2.1" +thiserror = "1.0.40" +bytemuck = "1.7.0" + +[build-dependencies] +cc = "1.0.79" + +[dev-dependencies] +base64 = "0.21.2" +bincode = "1.3.3" +serde = "1.0.163" +approx = "0.5.1" +rand = "0.8.5" + diff --git a/rust/vector/build.rs b/rust/vector/build.rs new file mode 100644 index 000000000..2d36c213c --- /dev/null +++ b/rust/vector/build.rs @@ -0,0 +1,29 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +fn main() { + println!("cargo:rerun-if-changed=distance.c"); + if cfg!(target_os = "macos") { + std::env::set_var("CFLAGS", "-mavx2 -mfma -Wno-error -MP -O2 -D NDEBUG -D MKL_ILP64 -D USE_AVX2 -D USE_ACCELERATED_PQ -D NOMINMAX -D _TARGET_ARM_APPLE_DARWIN"); + + cc::Build::new() + .file("distance.c") + .warnings_into_errors(true) + .debug(false) + .target("x86_64-apple-darwin") + .compile("nativefunctions.lib"); + } else { + std::env::set_var("CFLAGS", "/permissive- /MP /ifcOutput /GS- /W3 /Gy /Zi /Gm- /O2 /Ob2 /Zc:inline /fp:fast /D NDEBUG /D MKL_ILP64 /D USE_AVX2 /D USE_ACCELERATED_PQ /D NOMINMAX /fp:except- /errorReport:prompt /WX /openmp:experimental /Zc:forScope /GR /arch:AVX2 /Gd /Oy /Oi /MD /std:c++14 /FC /EHsc /nologo /Ot"); + // std::env::set_var("CFLAGS", "/permissive- /MP /ifcOutput /GS- /W3 /Gy /Zi /Gm- /Obd /Zc:inline /fp:fast /D DEBUG /D MKL_ILP64 /D USE_AVX2 /D USE_ACCELERATED_PQ /D NOMINMAX /fp:except- /errorReport:prompt /WX /openmp:experimental /Zc:forScope /GR /arch:AVX512 /Gd /Oy /Oi /MD /std:c++14 /FC /EHsc /nologo /Ot"); + + cc::Build::new() + .file("distance.c") + .warnings_into_errors(true) + .debug(false) + .compile("nativefunctions"); + + println!("cargo:rustc-link-arg=nativefunctions.lib"); + } +} + diff --git a/rust/vector/distance.c b/rust/vector/distance.c new file mode 100644 index 
000000000..ee5333a53 --- /dev/null +++ b/rust/vector/distance.c @@ -0,0 +1,35 @@ +#include +#include + +inline __m256i load_128bit_to_256bit(const __m128i *ptr) +{ + __m128i value128 = _mm_loadu_si128(ptr); + __m256i value256 = _mm256_castsi128_si256(value128); + return _mm256_inserti128_si256(value256, _mm_setzero_si128(), 1); +} + +float distance_compare_avx512f_f16(const unsigned char *vec1, const unsigned char *vec2, size_t size) +{ + __m512 sum_squared_diff = _mm512_setzero_ps(); + + for (int i = 0; i < size / 16; i += 1) + { + __m512 v1 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(vec1 + i * 2 * 16))); + __m512 v2 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(vec2 + i * 2 * 16))); + + __m512 diff = _mm512_sub_ps(v1, v2); + sum_squared_diff = _mm512_fmadd_ps(diff, diff, sum_squared_diff); + } + + size_t i = (size / 16) * 16; + + if (i != size) + { + __m512 va = _mm512_cvtph_ps(load_128bit_to_256bit((const __m128i *)(vec1 + i * 2))); + __m512 vb = _mm512_cvtph_ps(load_128bit_to_256bit((const __m128i *)(vec2 + i * 2))); + __m512 diff512 = _mm512_sub_ps(va, vb); + sum_squared_diff = _mm512_fmadd_ps(diff512, diff512, sum_squared_diff); + } + + return _mm512_reduce_add_ps(sum_squared_diff); +} diff --git a/rust/vector/src/distance.rs b/rust/vector/src/distance.rs new file mode 100644 index 000000000..8ca6cb250 --- /dev/null +++ b/rust/vector/src/distance.rs @@ -0,0 +1,442 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use crate::l2_float_distance::{distance_l2_vector_f16, distance_l2_vector_f32}; +use crate::{Half, Metric}; + +/// Distance contract for full-precision vertex +pub trait FullPrecisionDistance { + /// Get the distance between vertex a and vertex b + fn distance_compare(a: &[T; N], b: &[T; N], vec_type: Metric) -> f32; +} + +// reason = "Not supported Metric type Metric::Cosine" +#[allow(clippy::panic)] +impl FullPrecisionDistance for [f32; N] { + /// Calculate distance between two f32 Vertex + #[inline(always)] + fn distance_compare(a: &[f32; N], b: &[f32; N], metric: Metric) -> f32 { + match metric { + Metric::L2 => distance_l2_vector_f32::(a, b), + _ => panic!("Not supported Metric type {:?}", metric), + } + } +} + +// reason = "Not supported Metric type Metric::Cosine" +#[allow(clippy::panic)] +impl FullPrecisionDistance for [Half; N] { + fn distance_compare(a: &[Half; N], b: &[Half; N], metric: Metric) -> f32 { + match metric { + Metric::L2 => distance_l2_vector_f16::(a, b), + _ => panic!("Not supported Metric type {:?}", metric), + } + } +} + +// reason = "Not yet supported Vector i8" +#[allow(clippy::panic)] +impl FullPrecisionDistance for [i8; N] { + fn distance_compare(_a: &[i8; N], _b: &[i8; N], _metric: Metric) -> f32 { + panic!("Not supported VectorType i8") + } +} + +// reason = "Not yet supported Vector u8" +#[allow(clippy::panic)] +impl FullPrecisionDistance for [u8; N] { + fn distance_compare(_a: &[u8; N], _b: &[u8; N], _metric: Metric) -> f32 { + panic!("Not supported VectorType u8") + } +} + +#[cfg(test)] +mod distance_test { + use super::*; + + #[repr(C, align(32))] + pub struct F32Slice112([f32; 112]); + + #[repr(C, align(32))] + pub struct F16Slice112([Half; 112]); + + fn get_turing_test_data() -> (F32Slice112, F32Slice112) { + let a_slice: [f32; 112] = [ + 0.13961786, + -0.031577103, + -0.09567415, + 0.06695563, + -0.1588727, + 0.089852564, + -0.019837005, + 0.07497972, + 0.010418192, + -0.054594643, + 0.08613386, + -0.05103466, + 
0.16568437, + -0.02703799, + 0.00728657, + -0.15313251, + 0.16462992, + -0.030570814, + 0.11635703, + 0.23938893, + 0.018022912, + -0.12646551, + 0.018048918, + -0.035986554, + 0.031986624, + -0.015286017, + 0.010117953, + -0.032691937, + 0.12163067, + -0.04746277, + 0.010213069, + -0.043672588, + -0.099362016, + 0.06599016, + -0.19397286, + -0.13285528, + -0.22040887, + 0.017690737, + -0.104262285, + -0.0044555613, + -0.07383778, + -0.108652934, + 0.13399786, + 0.054912474, + 0.20181285, + 0.1795591, + -0.05425621, + -0.10765217, + 0.1405377, + -0.14101997, + -0.12017701, + 0.011565498, + 0.06952187, + 0.060136646, + 0.0023214167, + 0.04204699, + 0.048470616, + 0.17398086, + 0.024218207, + -0.15626553, + -0.11291045, + -0.09688122, + 0.14393932, + -0.14713104, + -0.108876854, + 0.035279203, + -0.05440188, + 0.017205412, + 0.011413814, + 0.04009471, + 0.11070237, + -0.058998976, + 0.07260045, + -0.057893746, + -0.0036240944, + -0.0064988653, + -0.13842176, + -0.023219328, + 0.0035885905, + -0.0719257, + -0.21335067, + 0.11415403, + -0.0059823603, + 0.12091869, + 0.08136634, + -0.10769281, + 0.024518685, + 0.0009200326, + -0.11628049, + 0.07448965, + 0.13736208, + -0.04144517, + -0.16426727, + -0.06380103, + -0.21386267, + 0.022373492, + -0.05874115, + 0.017314062, + -0.040344074, + 0.01059176, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ]; + let b_slice: [f32; 112] = [ + -0.07209058, + -0.17755842, + -0.030627966, + 0.163028, + -0.2233766, + 0.057412963, + 0.0076995124, + -0.017121306, + -0.015759075, + -0.026947778, + -0.010282468, + -0.23968373, + -0.021486737, + -0.09903155, + 0.09361805, + 0.0042711576, + -0.08695552, + -0.042165346, + 0.064218745, + -0.06707651, + 0.07846054, + 0.12235762, + -0.060716823, + 0.18496591, + -0.13023394, + 0.022469055, + 0.056764495, + 0.07168404, + -0.08856144, + -0.15343173, + 0.099879816, + -0.033529017, + 0.0795304, + -0.009242254, + -0.10254546, + 0.13086525, + -0.101518914, + 
-0.1031299, + -0.056826904, + 0.033196196, + 0.044143833, + -0.049787212, + -0.018148342, + -0.11172959, + -0.06776237, + -0.09185828, + -0.24171598, + 0.05080982, + -0.0727684, + 0.045031235, + -0.11363879, + -0.063389264, + 0.105850354, + -0.19847773, + 0.08828623, + -0.087071925, + 0.033512704, + 0.16118294, + 0.14111553, + 0.020884402, + -0.088860825, + 0.018745849, + 0.047522716, + -0.03665169, + 0.15726231, + -0.09930561, + 0.057844743, + -0.10532736, + -0.091297254, + 0.067029804, + 0.04153976, + 0.06393326, + 0.054578528, + 0.0038539872, + 0.1023088, + -0.10653885, + -0.108500294, + -0.046606563, + 0.020439683, + -0.120957725, + -0.13334097, + -0.13425854, + -0.20481694, + 0.07009538, + 0.08660361, + -0.0096641015, + 0.095316306, + -0.002898167, + -0.19680002, + 0.08466311, + 0.04812689, + -0.028978813, + 0.04780206, + -0.2001506, + -0.036866356, + -0.023720587, + 0.10731964, + 0.05517358, + -0.09580819, + 0.14595725, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ]; + + (F32Slice112(a_slice), F32Slice112(b_slice)) + } + + fn get_turing_test_data_f16() -> (F16Slice112, F16Slice112) { + let (a_slice, b_slice) = get_turing_test_data(); + let a_data = a_slice.0.iter().map(|x| Half::from_f32(*x)); + let b_data = b_slice.0.iter().map(|x| Half::from_f32(*x)); + + ( + F16Slice112(a_data.collect::>().try_into().unwrap()), + F16Slice112(b_data.collect::>().try_into().unwrap()), + ) + } + + use crate::test_util::*; + use approx::assert_abs_diff_eq; + + #[test] + fn test_dist_l2_float_turing() { + // two vectors are allocated in the contiguous heap memory + let (a_slice, b_slice) = get_turing_test_data(); + let distance = <[f32; 112] as FullPrecisionDistance>::distance_compare( + &a_slice.0, + &b_slice.0, + Metric::L2, + ); + + assert_abs_diff_eq!( + distance, + no_vector_compare_f32(&a_slice.0, &b_slice.0), + epsilon = 1e-6 + ); + } + + #[test] + fn test_dist_l2_f16_turing() { + // two vectors are allocated in the contiguous 
heap memory + let (a_slice, b_slice) = get_turing_test_data_f16(); + let distance = <[Half; 112] as FullPrecisionDistance>::distance_compare( + &a_slice.0, + &b_slice.0, + Metric::L2, + ); + + // Note the variance between the full 32 bit precision and the 16 bit precision + assert_eq!(distance, no_vector_compare_f16(&a_slice.0, &b_slice.0)); + } + + #[test] + fn distance_test() { + #[repr(C, align(32))] + struct Vector32ByteAligned { + v: [f32; 512], + } + + // two vectors are allocated in the contiguous heap memory + let two_vec = Box::new(Vector32ByteAligned { + v: [ + 69.02492, 78.84786, 63.125072, 90.90581, 79.2592, 70.81731, 3.0829668, 33.33287, + 20.777142, 30.147898, 23.681915, 42.553043, 12.602162, 7.3808074, 19.157589, + 65.6791, 76.44677, 76.89124, 86.40756, 84.70118, 87.86142, 16.126896, 5.1277637, + 95.11038, 83.946945, 22.735607, 11.548555, 59.51482, 24.84603, 15.573776, 78.27185, + 71.13179, 38.574017, 80.0228, 13.175261, 62.887978, 15.205181, 18.89392, 96.13162, + 87.55455, 34.179806, 62.920044, 4.9305916, 54.349373, 21.731495, 14.982187, + 40.262867, 20.15214, 36.61963, 72.450806, 55.565, 95.5375, 93.73356, 95.36308, + 66.30762, 58.0397, 18.951357, 67.11702, 43.043316, 30.65622, 99.85361, 2.5889993, + 27.844774, 39.72441, 46.463238, 71.303764, 90.45308, 36.390602, 63.344395, + 26.427078, 35.99528, 82.35505, 32.529175, 23.165905, 74.73179, 9.856939, 59.38126, + 35.714924, 79.81213, 46.704124, 24.47884, 36.01743, 0.46678782, 29.528152, + 1.8980742, 24.68853, 75.58984, 98.72279, 68.62601, 11.890173, 49.49361, 55.45572, + 72.71067, 34.107483, 51.357758, 76.400635, 81.32725, 66.45081, 17.848074, + 62.398876, 94.20444, 2.10886, 17.416393, 64.88253, 29.000723, 62.434315, 53.907238, + 70.51412, 78.70744, 55.181683, 64.45116, 23.419212, 53.68544, 43.506958, 46.89598, + 35.905994, 64.51397, 91.95555, 20.322979, 74.80128, 97.548744, 58.312725, 78.81985, + 31.911612, 14.445949, 49.85094, 70.87396, 40.06766, 7.129991, 78.48008, 75.21636, + 93.623604, 95.95479, 
29.571129, 22.721554, 26.73875, 52.075504, 56.783104, + 94.65493, 61.778534, 85.72401, 85.369514, 29.922367, 41.410553, 94.12884, + 80.276855, 55.604828, 54.70947, 74.07216, 44.61955, 31.38113, 68.48596, 34.56782, + 14.424729, 48.204506, 9.675444, 32.01946, 92.32695, 36.292683, 78.31955, 98.05327, + 14.343918, 46.017002, 95.90888, 82.63626, 16.873539, 3.698051, 7.8042626, + 64.194405, 96.71023, 67.93692, 21.618402, 51.92182, 22.834194, 61.56986, 19.749891, + 55.31206, 38.29552, 67.57593, 67.145836, 38.92673, 94.95708, 72.38746, 90.70901, + 69.43995, 9.394085, 31.646872, 88.20112, 9.134722, 99.98214, 5.423498, 41.51995, + 76.94409, 77.373276, 3.2966614, 9.611201, 57.231106, 30.747868, 76.10228, 91.98308, + 70.893585, 0.9067178, 43.96515, 16.321218, 27.734184, 83.271835, 88.23312, + 87.16445, 5.556643, 15.627432, 58.547127, 93.6459, 40.539192, 49.124157, 91.13276, + 57.485855, 8.827019, 4.9690843, 46.511234, 53.91469, 97.71925, 20.135271, + 23.353004, 70.92099, 93.38748, 87.520134, 51.684677, 29.89813, 9.110392, 65.809204, + 34.16554, 93.398605, 84.58669, 96.409645, 9.876037, 94.767784, 99.21523, 1.9330144, + 94.92429, 75.12728, 17.218828, 97.89164, 35.476578, 77.629456, 69.573746, + 40.200542, 42.117836, 5.861628, 75.45282, 82.73633, 0.98086596, 77.24894, + 11.248695, 61.070026, 52.692616, 80.5449, 80.76036, 29.270136, 67.60252, 48.782394, + 95.18851, 83.47162, 52.068756, 46.66002, 90.12216, 15.515327, 33.694042, 96.963036, + 73.49627, 62.805485, 44.715607, 59.98627, 3.8921833, 37.565327, 29.69184, + 39.429665, 83.46899, 44.286453, 21.54851, 56.096413, 18.169249, 5.214751, + 14.691341, 99.779335, 26.32643, 67.69903, 36.41243, 67.27333, 12.157213, 96.18984, + 2.438283, 78.14289, 0.14715195, 98.769, 53.649532, 21.615898, 39.657497, 95.45616, + 18.578386, 71.47976, 22.348118, 17.85519, 6.3717127, 62.176777, 22.033644, + 23.178005, 79.44858, 89.70233, 37.21273, 71.86182, 21.284317, 52.908623, 30.095518, + 63.64478, 77.55823, 80.04871, 15.133011, 30.439043, 70.16561, 
4.4014096, 89.28944, + 26.29093, 46.827854, 11.764729, 61.887516, 47.774887, 57.19503, 59.444664, + 28.592825, 98.70386, 1.2497544, 82.28431, 46.76423, 83.746124, 53.032673, 86.53457, + 99.42168, 90.184, 92.27852, 9.059965, 71.75723, 70.45299, 10.924053, 68.329704, + 77.27232, 6.677854, 75.63629, 57.370533, 17.09031, 10.554659, 99.56178, 37.53221, + 72.311104, 75.7565, 65.2042, 36.096478, 64.69502, 38.88497, 64.33723, 84.87812, + 66.84958, 8.508932, 79.134, 83.431015, 66.72124, 61.801838, 64.30524, 37.194263, + 77.94725, 89.705185, 23.643505, 19.505919, 48.40264, 43.01083, 21.171177, + 18.717121, 10.805857, 69.66983, 77.85261, 57.323063, 3.28964, 38.758026, 5.349946, + 7.46572, 57.485138, 30.822384, 33.9411, 95.53746, 65.57723, 42.1077, 28.591347, + 11.917269, 5.031073, 31.835615, 19.34116, 85.71027, 87.4516, 1.3798475, 70.70583, + 51.988052, 45.217144, 14.308596, 54.557167, 86.18323, 79.13666, 76.866745, + 46.010685, 79.739235, 44.667603, 39.36416, 72.605896, 73.83187, 13.137412, + 6.7911267, 63.952374, 10.082436, 86.00318, 99.760376, 92.84948, 63.786434, + 3.4429908, 18.244314, 75.65299, 14.964747, 70.126366, 80.89449, 91.266655, + 96.58798, 46.439327, 38.253975, 87.31036, 21.093178, 37.19671, 58.28973, 9.75231, + 12.350321, 25.75115, 87.65073, 53.610504, 36.850048, 18.66356, 94.48941, 83.71898, + 44.49315, 44.186737, 19.360733, 84.365974, 46.76272, 44.924366, 50.279808, + 54.868866, 91.33004, 18.683397, 75.13282, 15.070831, 47.04839, 53.780903, + 26.911152, 74.65651, 57.659935, 25.604189, 37.235474, 65.39667, 53.952206, + 40.37131, 59.173275, 96.00756, 54.591274, 10.787476, 69.51549, 31.970142, + 25.408005, 55.972492, 85.01888, 97.48981, 91.006134, 28.98619, 97.151276, + 34.388496, 47.498177, 11.985874, 64.73775, 33.877014, 13.370312, 34.79146, + 86.19321, 15.019405, 94.07832, 93.50433, 60.168625, 50.95409, 38.27827, 47.458614, + 32.83715, 69.54998, 69.0361, 84.1418, 34.270298, 74.23852, 70.707466, 78.59845, + 9.651399, 24.186779, 58.255756, 53.72362, 92.46477, 
97.75528, 20.257462, 30.122698, + 50.41517, 28.156603, 42.644154, + ], + }); + + let distance = compare::(256, Metric::L2, &two_vec.v); + + assert_eq!(distance, 429141.2); + } + + fn compare(dim: usize, metric: Metric, v: &[f32]) -> f32 + where + for<'a> [T; N]: FullPrecisionDistance, + { + let a_ptr = v.as_ptr(); + let b_ptr = unsafe { a_ptr.add(dim) }; + + let a_ref = + <&[f32; N]>::try_from(unsafe { std::slice::from_raw_parts(a_ptr, dim) }).unwrap(); + let b_ref = + <&[f32; N]>::try_from(unsafe { std::slice::from_raw_parts(b_ptr, dim) }).unwrap(); + + <[f32; N]>::distance_compare(a_ref, b_ref, metric) + } +} diff --git a/rust/vector/src/distance_test.rs b/rust/vector/src/distance_test.rs new file mode 100644 index 000000000..0def0264a --- /dev/null +++ b/rust/vector/src/distance_test.rs @@ -0,0 +1,152 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#[cfg(test)] +mod e2e_test { + + #[repr(C, align(32))] + pub struct F32Slice104([f32; 104]); + + #[repr(C, align(32))] + pub struct F16Slice104([Half; 104]); + + use approx::assert_abs_diff_eq; + + use crate::half::Half; + use crate::l2_float_distance::{distance_l2_vector_f16, distance_l2_vector_f32}; + + fn no_vector_compare_f32(a: &[f32], b: &[f32]) -> f32 { + let mut sum = 0.0; + for i in 0..a.len() { + let a_f32 = a[i]; + let b_f32 = b[i]; + let diff = a_f32 - b_f32; + sum += diff * diff; + } + sum + } + + fn no_vector_compare(a: &[Half], b: &[Half]) -> f32 { + let mut sum = 0.0; + for i in 0..a.len() { + let a_f32 = a[i].to_f32(); + let b_f32 = b[i].to_f32(); + let diff = a_f32 - b_f32; + sum += diff * diff; + } + sum + } + + #[test] + fn avx2_matches_novector() { + for i in 1..3 { + let (f1, f2) = get_test_data(0, i); + + let distance_f32x8 = distance_l2_vector_f32::<104>(&f1.0, &f2.0); + let distance = no_vector_compare_f32(&f1.0, &f2.0); + + assert_abs_diff_eq!(distance, distance_f32x8, epsilon = 1e-6); + } + } + + #[test] + fn 
avx2_matches_novector_random() { + let (f1, f2) = get_test_data_random(); + + let distance_f32x8 = distance_l2_vector_f32::<104>(&f1.0, &f2.0); + let distance = no_vector_compare_f32(&f1.0, &f2.0); + + assert_abs_diff_eq!(distance, distance_f32x8, epsilon = 1e-4); + } + + #[test] + fn avx_f16_matches_novector() { + for i in 1..3 { + let (f1, f2) = get_test_data_f16(0, i); + let _a_slice = f1.0.map(|x| x.to_f32().to_string()).join(", "); + let _b_slice = f2.0.map(|x| x.to_f32().to_string()).join(", "); + + let expected = no_vector_compare(f1.0[0..].as_ref(), f2.0[0..].as_ref()); + let distance_f16x8 = distance_l2_vector_f16::<104>(&f1.0, &f2.0); + + assert_abs_diff_eq!(distance_f16x8, expected, epsilon = 1e-4); + } + } + + #[test] + fn avx_f16_matches_novector_random() { + let (f1, f2) = get_test_data_f16_random(); + + let expected = no_vector_compare(f1.0[0..].as_ref(), f2.0[0..].as_ref()); + let distance_f16x8 = distance_l2_vector_f16::<104>(&f1.0, &f2.0); + + assert_abs_diff_eq!(distance_f16x8, expected, epsilon = 1e-4); + } + + fn get_test_data_f16(i1: usize, i2: usize) -> (F16Slice104, F16Slice104) { + let (a_slice, b_slice) = get_test_data(i1, i2); + let a_data = a_slice.0.iter().map(|x| Half::from_f32(*x)); + let b_data = b_slice.0.iter().map(|x| Half::from_f32(*x)); + + ( + F16Slice104(a_data.collect::>().try_into().unwrap()), + F16Slice104(b_data.collect::>().try_into().unwrap()), + ) + } + + fn get_test_data(i1: usize, i2: usize) -> (F32Slice104, F32Slice104) { + use base64::{engine::general_purpose, Engine as _}; + + let b64 = general_purpose::STANDARD.decode(TEST_DATA).unwrap(); + + let decoded: Vec> = bincode::deserialize(&b64).unwrap(); + debug_assert!(decoded.len() > i1); + debug_assert!(decoded.len() > i2); + + let mut f1 = F32Slice104([0.0; 104]); + let v1 = &decoded[i1]; + debug_assert!(v1.len() == 104); + f1.0.copy_from_slice(v1); + + let mut f2 = F32Slice104([0.0; 104]); + let v2 = &decoded[i2]; + debug_assert!(v2.len() == 104); + 
f2.0.copy_from_slice(v2); + + (f1, f2) + } + + fn get_test_data_f16_random() -> (F16Slice104, F16Slice104) { + let (a_slice, b_slice) = get_test_data_random(); + let a_data = a_slice.0.iter().map(|x| Half::from_f32(*x)); + let b_data = b_slice.0.iter().map(|x| Half::from_f32(*x)); + + ( + F16Slice104(a_data.collect::>().try_into().unwrap()), + F16Slice104(b_data.collect::>().try_into().unwrap()), + ) + } + + fn get_test_data_random() -> (F32Slice104, F32Slice104) { + use rand::Rng; + + let mut rng = rand::thread_rng(); + let mut f1 = F32Slice104([0.0; 104]); + + for i in 0..104 { + f1.0[i] = rng.gen_range(-1.0..1.0); + } + + let mut f2 = F32Slice104([0.0; 104]); + + for i in 0..104 { + f2.0[i] = rng.gen_range(-1.0..1.0); + } + + (f1, f2) + } + + const TEST_DATA: &str = "BQAAAAAAAABoAAAAAAAAAPz3Dj7+VgG9z/DDvQkgiT2GryK+nwS4PTeBorz4jpk9ELEqPKKeX73zZrA9uAlRvSqpKT7Gft28LsTuO8XOHL6/lCg+pW/6vJhM7j1fInU+yaSTPC2AAb5T25M8o2YTvWgEAz00cnq8xcUlPPvnBb2AGfk9UmhCvbdUJzwH4jK9UH7Lvdklhz3SoEa+NwsIvt2yYb4q7JA8d4fVvfX/kbtDOJe9boXevbw2CT7n62A9B6hOPlfeNz7CO169vnjcvR3pDz6KZxC+XR/2vTd9PTx7YY492FF2PekiGDt3OSw9IIlGPQooMj5DZcY8EgQgvpg9572paca91GQTPoWpFr7U+t697YAQPYHUXr1d8ow8AQE7PFo6JD3tt+I96ahxvYuvlD3+IW29N4Jtu2/01Ltvvg2+dja+vI8uazvITZO9mXhavpfJ6T2tB8S7OKT3PWWjpj0Mjty9advIPFgucTp3JO69CI6YPaWoDD5pwim9rjUovh2qgr3R/lq+nUi3PI+acL041o081D8lvRCJLTwAAAAAAAAAAAAAAAAAAAAAaAAAAAAAAAA6pJO94NE1voDn+rzQ8CY+1rxkvtspaz0xTPw7+0GMvC0ZgbyWwdy8zHcovKdvdb70BLC8DtHKvdK6vz0R9Ys7vBWyvZK1LL0ehYM9aV+JveuvoD2ilvo9NLJ4vbRnPT4MXAW+BhG4POOBaD0Vz5I9s1+1vTUdHb7Kjcw9uVUJvdbgoj3TbBe8WwPSvYoBBj4m6c+9xTXTvVTDaL28+Ac9KtA0Pa3tS73Vq5S8fNLkvf/Gir0yILy9ZYR3vvUdUD2ZB5W9rHI4PXS76L070oG9EsjYPb89S75pz7Q9xFKyvZ5ECT0kDSU+l4AQPsQVqzyq/LW95ZCZPC6nQj0VIBa9XwkhPr1gy72c7mw937XXvQ76ur3sRok9mCUqPXHvgj28jV89LZN8O0eH0T0KMdq9ZzXevYbmPr0fcac8r7j3vYmKCL4Sewm+iLtRviuOjz08XbE9LlYevDI1wz0s7z278oVJvtpjrT20IEU9+mTtvBjMQz1H9Ey+LQEXva1Rwrxmyts9sf1hPRY3xL3RdRU+AAAAAAAAAAAAAAAAAAAAAGgAAAAAAAAARqSTvbYJpLx1x869cW67PeeJhb7/cBu9m0eFPQO3oL0I+L49YQDavTYSez3SmTg96hB
GPuh4oL2x2ow6WdCUO6XUSz4xcU88GReAvVfekj0Ph3Y9z43hvBzT5z1I2my9UVy3vAj8jL08Gtm9CfJcPRihTr1+8Yu9TiP+PNrJa77Dfa09IhpEPesJNr0XzFU8yye3PZKFyz3uzJ09FLRUvYq3l73X4X07DDUzvq9VXjwWtg8+JrzYPcFCkr0jDCg9T9zlvZbZjz4Y8pM89xo8PgAcfbvYSnY8XoFKvO05/L36yzE8J+5yPqfe5r2AZFq8ULRDvnkTgrw+S7q9qGYLvQDZYL1T8d09bFikvZw3+jsYLdO8H3GVveHBYT4gnsE8ZBIJPpzOEj7OSDC+ZYu+vFc1Erzko4M9GqLtPBHH5TwpeRs+miC4PBHH5Tw9Z9k9VUsUPjnppj0oC5C9mcqDvY7y1rxdvZU8PdFAPov9lz0bOmq94kdyPBBokTxtOj89fu4avSsazj1P7iE+x8YkPAAAAAAAAAAAAAAAAAAAAABoAAAAAAAAAHEruT3mgKM8JnEvvAsfHL63906+ifhgvldl1r14OeO9waUyuw3yUzx+PDW9UbDhPQP4Lb4KRRk+Oky2vaLfaT30mrA9YMeZPfzPMz4h42M+XfCHva4AGr6MOSM+iBOzvdsaE7xFxgI+gJGXvVMzE75kHY+8oAWNvVqNK7yOx589fU3lvVVPg730Cwk+DKkEPWYtxjqQ2MK9H0T+vTnGQj2yq5w8L49BvrEJrzyB4Yo9AXV7PYGCLr3MxsG9oWM7PTyu8TzEOhW+dyWrvUTxHD2nL+c9+VKFPcthhLsc0PM8FdyPPeLj/z1WAHS8ZvW2PGg4Cb5u3IU9g4CovSHW+L2CWoG++nZnPAi2ST3HmUC9P5rJuxQbU765lwU+7FLBPUPTfL0uGgk+yKy2PYwXaT1I4I+9AU6VPQ5QaDx9mdE8Qg8zPfGCUjzD/io9rr+BvTNDqT0MFNi9mHatvS1iJD0nVrK78WmIPE0QsL3PAQq9cMRgPWXmmr3yTcw9UcXrPccwa76+cBq+5iVOvUg9c70AAAAAAAAAAAAAAAAAAAAAaAAAAAAAAAB/K7k9hCsnPUJXJr2Wg4a9MEtXve33Sj0VJZ89pciEvWLqwLzUgyu8ADTGPAVenL2UZ/c96YtMved+Wr3LUro9H8a7vGTSA77C5n69Lf3pPQj4KD5cFKq9fZ0uvvYQCT7b23G9XGMCPrGuy736Z9A9kZzFPSuCSD7/9/07Y4/6POxLir3/JBS9qFKMvkSzjryPgVY+ugq8PC9yhbsXaiq+O6WfPcvFK7vZXAy+goAQvXpHHj5jwPI87eokvrySET5QoOm8h8ixOhXzKb5s8+A9sjcJPjiLAz598yQ9yCYSPq6eGz4rvjE82lvGvWuIOLx23zK9hHg8vTWOv70/Tse81fA6Pr2wNz34Eza+2Uj3PZ3trr0aXAI9PCkKPiybe721P9U9QkNLO927jT3LpRA+mpJUvUeU6rwC/Qa+lr4Cvgrpnj1pQ/i9TxhSvJqYr72RS6y8aQLTPQzPiz3vSRY94NfrPJl6LL2adjO8iYfPuhRzZz2f7R8+iVskPcUeXr12ZiI+nd3xvIYv8bwqYlg+AAAAAAAAAAAAAAAAAAAAAA=="; +} + diff --git a/rust/vector/src/half.rs b/rust/vector/src/half.rs new file mode 100644 index 000000000..87d7df6a1 --- /dev/null +++ b/rust/vector/src/half.rs @@ -0,0 +1,82 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use bytemuck::{Pod, Zeroable}; +use half::f16; +use std::convert::AsRef; +use std::fmt; + +// Define the Half type as a new type over f16. +// the memory layout of the Half struct will be the same as the memory layout of the f16 type itself. +// The Half struct serves as a simple wrapper around the f16 type and does not introduce any additional memory overhead. +// Test function: +// use half::f16; +// pub struct Half(f16); +// fn main() { +// let size_of_half = std::mem::size_of::(); +// let alignment_of_half = std::mem::align_of::(); +// println!("Size of Half: {} bytes", size_of_half); +// println!("Alignment of Half: {} bytes", alignment_of_half); +// } +// Output: +// Size of Half: 2 bytes +// Alignment of Half: 2 bytes +pub struct Half(f16); + +unsafe impl Pod for Half {} +unsafe impl Zeroable for Half {} + +// Implement From for Half +impl From for f32 { + fn from(val: Half) -> Self { + val.0.to_f32() + } +} + +// Implement AsRef for Half so that it can be used in distance_compare. +impl AsRef for Half { + fn as_ref(&self) -> &f16 { + &self.0 + } +} + +// Implement From for Half. +impl Half { + pub fn from_f32(value: f32) -> Self { + Self(f16::from_f32(value)) + } +} + +// Implement Default for Half. +impl Default for Half { + fn default() -> Self { + Self(f16::from_f32(Default::default())) + } +} + +// Implement Clone for Half. +impl Clone for Half { + fn clone(&self) -> Self { + Half(self.0) + } +} + +// Implement PartialEq for Half. 
+impl fmt::Debug for Half { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "Half({:?})", self.0) + } +} + +impl Copy for Half {} + +impl Half { + pub fn to_f32(&self) -> f32 { + self.0.to_f32() + } +} + +unsafe impl Send for Half {} +unsafe impl Sync for Half {} + diff --git a/rust/vector/src/l2_float_distance.rs b/rust/vector/src/l2_float_distance.rs new file mode 100644 index 000000000..b818899bf --- /dev/null +++ b/rust/vector/src/l2_float_distance.rs @@ -0,0 +1,78 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] + +//! Distance calculation for L2 Metric + +#[cfg(not(target_feature = "avx2"))] +compile_error!("Library must be compiled with -C target-feature=+avx2"); + +use std::arch::x86_64::*; + +use crate::Half; + +/// Calculate the distance by vector arithmetic +#[inline(never)] +pub fn distance_l2_vector_f16(a: &[Half; N], b: &[Half; N]) -> f32 { + debug_assert_eq!(N % 8, 0); + + // make sure the addresses are bytes aligned + debug_assert_eq!(a.as_ptr().align_offset(32), 0); + debug_assert_eq!(b.as_ptr().align_offset(32), 0); + + unsafe { + let mut sum = _mm256_setzero_ps(); + let a_ptr = a.as_ptr() as *const __m128i; + let b_ptr = b.as_ptr() as *const __m128i; + + // Iterate over the elements in steps of 8 + for i in (0..N).step_by(8) { + let a_vec = _mm256_cvtph_ps(_mm_load_si128(a_ptr.add(i / 8))); + let b_vec = _mm256_cvtph_ps(_mm_load_si128(b_ptr.add(i / 8))); + + let diff = _mm256_sub_ps(a_vec, b_vec); + sum = _mm256_fmadd_ps(diff, diff, sum); + } + + let x128: __m128 = _mm_add_ps(_mm256_extractf128_ps(sum, 1), _mm256_castps256_ps128(sum)); + /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */ + let x64: __m128 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */ + let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + /* Conversion to float is a no-op on x86-64 */ + 
_mm_cvtss_f32(x32) + } +} + +/// Calculate the distance by vector arithmetic +#[inline(never)] +pub fn distance_l2_vector_f32(a: &[f32; N], b: &[f32; N]) -> f32 { + debug_assert_eq!(N % 8, 0); + + // make sure the addresses are bytes aligned + debug_assert_eq!(a.as_ptr().align_offset(32), 0); + debug_assert_eq!(b.as_ptr().align_offset(32), 0); + + unsafe { + let mut sum = _mm256_setzero_ps(); + + // Iterate over the elements in steps of 8 + for i in (0..N).step_by(8) { + let a_vec = _mm256_load_ps(&a[i]); + let b_vec = _mm256_load_ps(&b[i]); + let diff = _mm256_sub_ps(a_vec, b_vec); + sum = _mm256_fmadd_ps(diff, diff, sum); + } + + let x128: __m128 = _mm_add_ps(_mm256_extractf128_ps(sum, 1), _mm256_castps256_ps128(sum)); + /* ( -, -, x1+x3+x5+x7, x0+x2+x4+x6 ) */ + let x64: __m128 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); + /* ( -, -, -, x0+x1+x2+x3+x4+x5+x6+x7 ) */ + let x32: __m128 = _mm_add_ss(x64, _mm_shuffle_ps(x64, x64, 0x55)); + /* Conversion to float is a no-op on x86-64 */ + _mm_cvtss_f32(x32) + } +} + diff --git a/rust/vector/src/lib.rs b/rust/vector/src/lib.rs new file mode 100644 index 000000000..d221070b5 --- /dev/null +++ b/rust/vector/src/lib.rs @@ -0,0 +1,26 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#![cfg_attr( + not(test), + warn(clippy::panic, clippy::unwrap_used, clippy::expect_used) +)] + +// #![feature(stdsimd)] +// mod f32x16; +// Uncomment above 2 to experiment with f32x16 +mod distance; +mod half; +mod l2_float_distance; +mod metric; +mod utils; + +pub use crate::half::Half; +pub use distance::FullPrecisionDistance; +pub use metric::Metric; +pub use utils::prefetch_vector; + +#[cfg(test)] +mod distance_test; +mod test_util; diff --git a/rust/vector/src/metric.rs b/rust/vector/src/metric.rs new file mode 100644 index 000000000..c60ef291b --- /dev/null +++ b/rust/vector/src/metric.rs @@ -0,0 +1,36 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. 
+ * Licensed under the MIT license. + */ +#![warn(missing_debug_implementations, missing_docs)] +use std::str::FromStr; + +/// Distance metric +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum Metric { + /// Squared Euclidean (L2-Squared) + L2, + + /// Cosine similarity + /// TODO: T should be float for Cosine distance + Cosine, +} + +#[derive(thiserror::Error, Debug)] +pub enum ParseMetricError { + #[error("Invalid format for Metric: {0}")] + InvalidFormat(String), +} + +impl FromStr for Metric { + type Err = ParseMetricError; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "l2" => Ok(Metric::L2), + "cosine" => Ok(Metric::Cosine), + _ => Err(ParseMetricError::InvalidFormat(String::from(s))), + } + } +} + diff --git a/rust/vector/src/test_util.rs b/rust/vector/src/test_util.rs new file mode 100644 index 000000000..7cfc92985 --- /dev/null +++ b/rust/vector/src/test_util.rs @@ -0,0 +1,29 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. + */ +#[cfg(test)] +use crate::Half; + +#[cfg(test)] +pub fn no_vector_compare_f16(a: &[Half], b: &[Half]) -> f32 { + let mut sum = 0.0; + debug_assert_eq!(a.len(), b.len()); + + for i in 0..a.len() { + sum += (a[i].to_f32() - b[i].to_f32()).powi(2); + } + sum +} + +#[cfg(test)] +pub fn no_vector_compare_f32(a: &[f32], b: &[f32]) -> f32 { + let mut sum = 0.0; + debug_assert_eq!(a.len(), b.len()); + + for i in 0..a.len() { + sum += (a[i] - b[i]).powi(2); + } + sum +} + diff --git a/rust/vector/src/utils.rs b/rust/vector/src/utils.rs new file mode 100644 index 000000000..a61c99aad --- /dev/null +++ b/rust/vector/src/utils.rs @@ -0,0 +1,21 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0}; + +/// Prefetch the given vector in chunks of 64 bytes, which is a cache line size +/// NOTE: good efficiency when total_vec_size is integral multiple of 64 +#[inline] +pub fn prefetch_vector(vec: &[T]) { + let vec_ptr = vec.as_ptr() as *const i8; + let vecsize = std::mem::size_of_val(vec); + let max_prefetch_size = (vecsize / 64) * 64; + + for d in (0..max_prefetch_size).step_by(64) { + unsafe { + _mm_prefetch(vec_ptr.add(d), _MM_HINT_T0); + } + } +} + diff --git a/rust/vector_base64/Cargo.toml b/rust/vector_base64/Cargo.toml new file mode 100644 index 000000000..6f50ad96e --- /dev/null +++ b/rust/vector_base64/Cargo.toml @@ -0,0 +1,15 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT license. +[package] +name = "vector_base64" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +base64 = "0.21.2" +bincode = "1.3.3" +half = "2.2.1" +serde = "1.0.163" + diff --git a/rust/vector_base64/src/main.rs b/rust/vector_base64/src/main.rs new file mode 100644 index 000000000..2867436a9 --- /dev/null +++ b/rust/vector_base64/src/main.rs @@ -0,0 +1,82 @@ +/* + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT license. 
+ */ +use std::fs::File; +use std::io::{self, BufReader, Read}; +use std::{env, vec}; + +fn main() -> io::Result<()> { + // Retrieve command-line arguments + let args: Vec = env::args().collect(); + + // Check if the correct number of arguments is provided + if args.len() != 4 { + print_usage(); + return Ok(()); + } + + // Retrieve the input and output file paths from the arguments + let input_file_path = &args[1]; + let item_count: usize = args[2].parse::().unwrap(); + let return_dimension: usize = args[3].parse::().unwrap(); + + // Open the input file for reading + let mut input_file = BufReader::new(File::open(input_file_path)?); + + // Read the first 8 bytes as metadata + let mut metadata = [0; 8]; + input_file.read_exact(&mut metadata)?; + + // Extract the number of points and dimension from the metadata + let _ = i32::from_le_bytes(metadata[..4].try_into().unwrap()); + let mut dimension: usize = (i32::from_le_bytes(metadata[4..].try_into().unwrap())) as usize; + if return_dimension < dimension { + dimension = return_dimension; + } + + let mut float_array = Vec::>::with_capacity(item_count); + + // Process each data point + for _ in 0..item_count { + // Read one data point from the input file + let mut buffer = vec![0; dimension * std::mem::size_of::()]; + match input_file.read_exact(&mut buffer) { + Ok(()) => { + let mut float_data = buffer + .chunks_exact(4) + .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect::>(); + + let mut i = return_dimension; + while i > dimension { + float_data.push(0.0); + i -= 1; + } + + float_array.push(float_data); + } + Err(err) => { + println!("Error: {}", err); + break; + } + } + } + + use base64::{engine::general_purpose, Engine as _}; + + let encoded: Vec = bincode::serialize(&float_array).unwrap(); + let b64 = general_purpose::STANDARD.encode(encoded); + println!("Float {}", b64); + + Ok(()) +} + +/// Prints the usage information +fn print_usage() { + println!("Usage: program_name 
input_file "); + println!( + "Itemcount is the number of items to convert. Expand to dimension if provided is smaller" + ); +} +