From a60f076bd3af74d1877c6f81df80eaf68e1f8ff7 Mon Sep 17 00:00:00 2001 From: usamoi Date: Tue, 20 Feb 2024 17:32:38 +0800 Subject: [PATCH] refactor: rework vector abstraction Signed-off-by: usamoi --- .cargo/config.toml | 3 - Cargo.lock | 75 +- Cargo.toml | 13 +- crates/base/Cargo.toml | 14 +- crates/base/src/distance.rs | 9 + crates/base/src/global/mod.rs | 93 +++ .../src/global/svecf32.rs} | 84 ++- crates/base/src/global/svecf32_cos.rs | 102 +++ crates/base/src/global/svecf32_dot.rs | 102 +++ crates/base/src/global/svecf32_l2.rs | 98 +++ .../f16.rs => base/src/global/vecf16.rs} | 4 +- .../src/global/vecf16_cos.rs} | 85 +-- .../src/global/vecf16_dot.rs} | 89 +-- .../src/global/vecf16_l2.rs} | 85 +-- .../f32.rs => base/src/global/vecf32.rs} | 3 +- .../src/global/vecf32_cos.rs} | 84 +-- .../src/global/vecf32_dot.rs} | 88 +-- .../src/global/vecf32_l2.rs} | 86 +-- crates/base/src/index.rs | 390 ++++++++++ crates/base/src/lib.rs | 6 +- crates/base/src/scalar/f16.rs | 29 +- crates/base/src/scalar/f32.rs | 30 +- crates/base/src/scalar/mod.rs | 25 +- crates/base/src/search.rs | 44 ++ crates/base/src/sys.rs | 44 -- crates/base/src/vector/mod.rs | 57 +- crates/base/src/vector/sparse_f32.rs | 64 -- crates/base/src/vector/svecf32.rs | 183 +++++ crates/base/src/vector/vecf16.rs | 99 +++ crates/base/src/vector/vecf32.rs | 99 +++ crates/base/src/worker.rs | 43 ++ crates/c/Cargo.toml | 3 + crates/c/tests/x86_64.rs | 12 +- crates/detect/Cargo.toml | 3 + crates/interprocess-atomic-wait/Cargo.toml | 5 +- crates/interprocess-atomic-wait/src/lib.rs | 3 + crates/memfd/Cargo.toml | 5 +- crates/send_fd/Cargo.toml | 5 +- crates/service/Cargo.toml | 14 +- .../algorithms/clustering/elkan_k_means.rs | 29 +- crates/service/src/algorithms/flat.rs | 11 +- crates/service/src/algorithms/hnsw.rs | 23 +- .../service/src/algorithms/ivf/ivf_naive.rs | 36 +- crates/service/src/algorithms/ivf/ivf_pq.rs | 40 +- crates/service/src/algorithms/ivf/mod.rs | 30 +- .../src/algorithms/quantization/mod.rs | 57 +- .../src/algorithms/quantization/product.rs | 131 ++-- .../src/algorithms/quantization/scalar.rs | 33 +- .../src/algorithms/quantization/trivial.rs | 17 +- crates/service/src/algorithms/raw.rs | 12 +- crates/service/src/index/indexing/flat.rs | 25 +- crates/service/src/index/indexing/hnsw.rs | 42 +- crates/service/src/index/indexing/ivf.rs | 56 +- crates/service/src/index/indexing/mod.rs | 73 +- crates/service/src/index/mod.rs | 71 +- .../service/src/index/optimizing/indexing.rs | 3 +- crates/service/src/index/optimizing/mod.rs | 49 -- crates/service/src/index/segments/growing.rs | 12 +- crates/service/src/index/segments/mod.rs | 42 -- crates/service/src/index/segments/sealed.rs | 6 +- crates/service/src/instance/metadata.rs | 2 +- crates/service/src/instance/mod.rs | 332 +++++---- crates/service/src/lib.rs | 19 +- crates/service/src/prelude.rs | 33 + crates/service/src/prelude/global/mod.rs | 163 ----- .../src/prelude/global/sparse_f32_cos.rs | 104 --- .../src/prelude/global/sparse_f32_dot.rs | 104 --- .../src/prelude/global/sparse_f32_l2.rs | 100 --- crates/service/src/prelude/mod.rs | 13 - crates/service/src/prelude/storage/dense.rs | 65 -- crates/service/src/prelude/storage/mod.rs | 24 - crates/service/src/storage/dense.rs | 107 +++ crates/service/src/storage/mod.rs | 60 ++ .../src/{prelude => }/storage/sparse.rs | 23 +- crates/service/src/worker/mod.rs | 35 +- src/bgworker/mod.rs | 4 +- src/bgworker/normal.rs | 14 +- src/datatype/binary_svecf32.rs | 59 ++ src/datatype/binary_vecf16.rs | 45 ++ src/datatype/binary_vecf32.rs | 45 ++ src/datatype/{casts_f32.rs => casts.rs} | 45 +- src/datatype/functions.rs | 45 ++ src/datatype/memory_svecf32.rs | 181 +++++ src/datatype/memory_vecf16.rs | 162 +++++ src/datatype/memory_vecf32.rs | 158 ++++ src/datatype/mod.rs | 18 +- src/datatype/operators_svecf32.rs | 127 ++-- src/datatype/operators_vecf16.rs | 49 +- src/datatype/operators_vecf32.rs | 49 +- src/datatype/subscript_svecf32.rs | 202 ++++++ src/datatype/subscript_vecf16.rs | 185 +++++ src/datatype/subscript_vecf32.rs | 185 +++++ src/datatype/svecf32.rs | 673 ------------------ src/datatype/text_svecf32.rs | 96 +++ src/datatype/text_vecf16.rs | 41 ++ src/datatype/text_vecf32.rs | 41 ++ src/datatype/typmod.rs | 4 +- src/datatype/vecf16.rs | 516 -------------- src/datatype/vecf32.rs | 516 -------------- src/embedding/mod.rs | 6 +- src/gucs/executing.rs | 2 +- src/index/am.rs | 1 - src/index/am_build.rs | 1 - src/index/am_scan.rs | 3 +- src/index/am_setup.rs | 35 +- src/index/am_update.rs | 3 +- src/index/hook_transaction.rs | 1 - src/index/utils.rs | 29 +- src/index/views.rs | 2 - src/ipc/mod.rs | 10 +- src/lib.rs | 6 +- src/prelude/error.rs | 23 +- src/prelude/mod.rs | 8 + src/prelude/sys.rs | 2 +- src/sql/finalize.sql | 386 +++++----- 115 files changed, 4176 insertions(+), 4064 deletions(-) create mode 100644 crates/base/src/distance.rs create mode 100644 crates/base/src/global/mod.rs rename crates/{service/src/prelude/global/sparse_f32.rs => base/src/global/svecf32.rs} (57%) create mode 100644 crates/base/src/global/svecf32_cos.rs create mode 100644 crates/base/src/global/svecf32_dot.rs create mode 100644 crates/base/src/global/svecf32_l2.rs rename crates/{service/src/prelude/global/f16.rs => base/src/global/vecf16.rs} (98%) rename crates/{service/src/prelude/global/f16_cos.rs => base/src/global/vecf16_cos.rs} (78%) rename crates/{service/src/prelude/global/f16_dot.rs => base/src/global/vecf16_dot.rs} (71%) rename crates/{service/src/prelude/global/f16_l2.rs => base/src/global/vecf16_l2.rs} (72%) rename crates/{service/src/prelude/global/f32.rs => base/src/global/vecf32.rs} (97%) rename crates/{service/src/prelude/global/f32_cos.rs => base/src/global/vecf32_cos.rs} (77%) rename crates/{service/src/prelude/global/f32_dot.rs => base/src/global/vecf32_dot.rs} (70%) rename crates/{service/src/prelude/global/f32_l2.rs => base/src/global/vecf32_l2.rs} (70%) create mode 100644 crates/base/src/index.rs delete mode 100644 crates/base/src/sys.rs delete mode 100644 crates/base/src/vector/sparse_f32.rs create mode 100644 crates/base/src/vector/svecf32.rs create mode 100644 crates/base/src/vector/vecf16.rs create mode 100644 crates/base/src/vector/vecf32.rs create mode 100644 crates/base/src/worker.rs create mode 100644 crates/service/src/prelude.rs delete mode 100644 crates/service/src/prelude/global/mod.rs delete mode 100644 crates/service/src/prelude/global/sparse_f32_cos.rs delete mode 100644 crates/service/src/prelude/global/sparse_f32_dot.rs delete mode 100644 crates/service/src/prelude/global/sparse_f32_l2.rs delete mode 100644 crates/service/src/prelude/mod.rs delete mode 100644 crates/service/src/prelude/storage/dense.rs delete mode 100644 crates/service/src/prelude/storage/mod.rs create mode 100644 crates/service/src/storage/dense.rs create mode 100644 crates/service/src/storage/mod.rs rename crates/service/src/{prelude => }/storage/sparse.rs (79%) create mode 100644 src/datatype/binary_svecf32.rs create mode 100644 src/datatype/binary_vecf16.rs create mode 100644 src/datatype/binary_vecf32.rs rename src/datatype/{casts_f32.rs => casts.rs} (58%) create mode 100644 src/datatype/functions.rs create mode 100644 src/datatype/memory_svecf32.rs create mode 100644 src/datatype/memory_vecf16.rs create mode 100644 src/datatype/memory_vecf32.rs create mode 100644 src/datatype/subscript_svecf32.rs create mode 100644 src/datatype/subscript_vecf16.rs create mode 100644 src/datatype/subscript_vecf32.rs delete mode 100644 src/datatype/svecf32.rs create mode 100644 src/datatype/text_svecf32.rs create mode 100644 src/datatype/text_vecf16.rs create mode 100644 src/datatype/text_vecf32.rs delete mode 100644 src/datatype/vecf16.rs delete mode 100644 src/datatype/vecf32.rs diff --git a/.cargo/config.toml b/.cargo/config.toml index c4db64902..13c456b5d 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -1,6 +1,3 @@ -[build] -rustdocflags = ["--document-private-items"] - [target.'cfg(target_os="macos")'] # Postgres symbols won't be available until runtime rustflags = ["-Clink-arg=-Wl,-undefined,dynamic_lookup"] diff --git a/Cargo.lock b/Cargo.lock index 8b546b074..d7dfe3da3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -319,7 +319,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -372,27 +372,15 @@ dependencies = [ name = "base" version = "0.0.0" dependencies = [ - "arc-swap", - "bincode", "bytemuck", - "byteorder", "c", - "crc32fast", - "crossbeam", - "dashmap", "detect", "half 2.3.1", "libc", - "log", - "memmap2", "multiversion", "num-traits", - "parking_lot", "rand", - "rayon", - "rustix", "serde", - "serde_json", "thiserror", "uuid", "validator", @@ -441,7 +429,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -510,9 +498,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.15.0" +version = "3.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d32a994c2b3ca201d9b263612a374263f05e7adde37c4707f693dcd375076d1f" +checksum = "c764d619ca78fccbf3069b37bd7af92577f044bb15236036662d79b6559f25b7" [[package]] name = "bytemuck" @@ -531,7 +519,7 @@ checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -568,12 +556,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.83" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" -dependencies = [ - "libc", -] +checksum = "7f9fa1897e4325be0d68d48df6aa1a71ac2ed4d27723887e7754192705350730" [[package]] name = "cexpr" @@ -640,7 +625,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -929,7 +914,7 @@ checksum = "f282cfdfe92516eb26c2af8589c274c7c17681f5ecc03c18255fe741c6aa64eb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -1152,7 +1137,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -2599,9 +2584,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.196" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" +checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" dependencies = [ "serde_derive", ] @@ -2618,20 +2603,20 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.196" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] name = "serde_json" -version = "1.0.113" +version = "1.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" +checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" dependencies = [ "itoa", "ryu", @@ -2678,16 +2663,11 @@ dependencies = [ "bincode", "bytemuck", "byteorder", - "c", "crc32fast", "crossbeam", "dashmap", - "detect", - "half 2.3.1", - "libc", "log", "memmap2", - "multiversion", "num-traits", "parking_lot", "rand", @@ -2846,9 +2826,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.49" +version = "2.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "915aea9e586f80826ee59f8453c1101f9d1c4b3964cd2460185ee8e299ada496" +checksum = "74f1bdc9872430ce9b75da68329d1c1746faf50ffac5f19e02b71e37ff881ffb" dependencies = [ "proc-macro2", "quote", @@ -2949,7 +2929,7 @@ checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -3002,7 +2982,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -3167,9 +3147,9 @@ checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" -version = "0.1.22" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" dependencies = [ "tinyvec", ] @@ -3279,7 +3259,6 @@ dependencies = [ "detect", "embedding", "env_logger", - "half 2.3.1", "httpmock", "interprocess_atomic_wait", "libc", @@ -3368,7 +3347,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", "wasm-bindgen-shared", ] @@ -3402,7 +3381,7 @@ checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3604,9 +3583,9 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "winnow" -version = "0.6.1" +version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d90f4e0f530c4c69f62b80d839e9ef3855edc9cba471a160c4d692deed62b401" +checksum = "7a4191c47f15cc3ec71fcb4913cb83d58def65dd3787610213c649283b5ce178" dependencies = [ "memchr", ] diff --git a/Cargo.toml b/Cargo.toml index 0f3f96bfe..de85e0490 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,10 @@ edition.workspace = true [lib] crate-type = ["cdylib", "lib"] +[[bin]] +name = "pgrx_embed_vectors" +path = "./src/bin/pgrx_embed.rs" + [features] default = ["pg15"] pg14 = ["pgrx/pg14", "pgrx-tests/pg14"] @@ -18,7 +22,6 @@ arrayvec.workspace = true bincode.workspace = true bytemuck.workspace = true byteorder.workspace = true -half.workspace = true libc.workspace = true log.workspace = true memmap2.workspace = true @@ -50,9 +53,6 @@ pgrx = { git = "https://github.com/tensorchord/pgrx.git", branch = "v0.11.3-patc pgrx-tests = { git = "https://github.com/tensorchord/pgrx.git", branch = "v0.11.3-patch" } [lints] -clippy.needless_range_loop = "allow" -clippy.len_zero = "allow" -clippy.too_many_arguments = "allow" rust.unsafe_op_in_unsafe_fn = "deny" rust.unused_lifetimes = "warn" rust.unused_qualifications = "warn" @@ -90,6 +90,11 @@ thiserror = "~1.0" uuid = { version = "1.7.0", features = ["v4", "serde"] } validator = { version = "~0.16", features = ["derive"] } +[workspace.lints] +rust.unsafe_op_in_unsafe_fn = "forbid" +rust.unused_lifetimes = "warn" +rust.unused_qualifications = "warn" + [profile.dev] panic = "unwind" diff --git a/crates/base/Cargo.toml b/crates/base/Cargo.toml index 79c2f8d59..6fc38d840 100644 --- a/crates/base/Cargo.toml +++ b/crates/base/Cargo.toml @@ -4,35 +4,25 @@ version.workspace = true edition.workspace = true [dependencies] -bincode.workspace = true bytemuck.workspace = true -byteorder.workspace = true half.workspace = true libc.workspace = true -log.workspace = true -memmap2.workspace = true num-traits.workspace = true rand.workspace = true -rustix.workspace = true serde.workspace = true -serde_json.workspace = true thiserror.workspace = true uuid.workspace = true validator.workspace = true c = { path = "../c" } detect = { path = "../detect" } -crc32fast = "1.4.0" -crossbeam = "0.8.4" -dashmap = "5.5.3" -parking_lot = "0.12.1" -rayon = "1.8.1" -arc-swap = "1.6.0" multiversion = "0.7.3" [lints] clippy.derivable_impls = "allow" clippy.len_without_is_empty = "allow" +clippy.len_zero = "allow" clippy.needless_range_loop = "allow" +clippy.nonminimal_bool = "allow" clippy.too_many_arguments = "allow" rust.internal_features = "allow" rust.unsafe_op_in_unsafe_fn = "forbid" diff --git a/crates/base/src/distance.rs b/crates/base/src/distance.rs new file mode 100644 index 000000000..c9129bb86 --- /dev/null +++ b/crates/base/src/distance.rs @@ -0,0 +1,9 @@ +use serde::{Deserialize, Serialize}; + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum DistanceKind { + L2, + Cos, + Dot, +} diff --git a/crates/base/src/global/mod.rs b/crates/base/src/global/mod.rs new file mode 100644 index 000000000..00fa2939e --- /dev/null +++ b/crates/base/src/global/mod.rs @@ -0,0 +1,93 @@ +mod svecf32; +mod svecf32_cos; +mod svecf32_dot; +mod svecf32_l2; +mod vecf16; +mod vecf16_cos; +mod vecf16_dot; +mod vecf16_l2; +mod vecf32; +mod vecf32_cos; +mod vecf32_dot; +mod vecf32_l2; + +pub use svecf32_cos::SVecf32Cos; +pub use svecf32_dot::SVecf32Dot; +pub use svecf32_l2::SVecf32L2; +pub use vecf16_cos::Vecf16Cos; +pub use vecf16_dot::Vecf16Dot; +pub use vecf16_l2::Vecf16L2; +pub use vecf32_cos::Vecf32Cos; +pub use vecf32_dot::Vecf32Dot; +pub use vecf32_l2::Vecf32L2; + +use crate::distance::*; +use crate::scalar::*; +use crate::vector::*; + +pub trait GlobalElkanKMeans: Global { + fn elkan_k_means_normalize(vector: &mut [Scalar]); + fn elkan_k_means_normalize2(vector: &mut Self::VectorOwned); + fn elkan_k_means_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32; + fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar]) -> F32; +} + +pub trait GlobalScalarQuantization: Global { + fn scalar_quantization_distance( + dims: u16, + max: &[Scalar], + min: &[Scalar], + lhs: Borrowed<'_, Self>, + rhs: &[u8], + ) -> F32; + fn scalar_quantization_distance2( + dims: u16, + max: &[Scalar], + min: &[Scalar], + lhs: &[u8], + rhs: &[u8], + ) -> F32; +} + +pub trait GlobalProductQuantization: Global { + type ProductQuantizationL2: Global + + GlobalElkanKMeans + + GlobalProductQuantization; + fn product_quantization_distance( + dims: u16, + ratio: u16, + centroids: &[Scalar], + lhs: Borrowed<'_, Self>, + rhs: &[u8], + ) -> F32; + fn product_quantization_distance2( + dims: u16, + ratio: u16, + centroids: &[Scalar], + lhs: &[u8], + rhs: &[u8], + ) -> F32; + fn product_quantization_distance_with_delta( + dims: u16, + ratio: u16, + centroids: &[Scalar], + lhs: Borrowed<'_, Self>, + rhs: &[u8], + delta: &[Scalar], + ) -> F32; + fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32; + fn product_quantization_dense_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32; +} + +pub trait Global: Copy + 'static { + type VectorOwned: VectorOwned; + + const VECTOR_KIND: VectorKind; + const DISTANCE_KIND: DistanceKind; + + fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32; +} + +pub type Owned = ::VectorOwned; +pub type Borrowed<'a, T> = <::VectorOwned as VectorOwned>::Borrowed<'a>; +pub type Scalar = <::VectorOwned as VectorOwned>::Scalar; diff --git a/crates/service/src/prelude/global/sparse_f32.rs b/crates/base/src/global/svecf32.rs similarity index 57% rename from crates/service/src/prelude/global/sparse_f32.rs rename to crates/base/src/global/svecf32.rs index 0490b8b29..ee15fa420 100644 --- a/crates/service/src/prelude/global/sparse_f32.rs +++ b/crates/base/src/global/svecf32.rs @@ -1,4 +1,7 @@ -use crate::prelude::*; +use super::SVecf32Owned; +use crate::scalar::*; +use crate::vector::*; +use num_traits::{Float, Zero}; #[inline(always)] #[multiversion::multiversion(targets( @@ -7,19 +10,19 @@ use crate::prelude::*; "x86_64/x86-64-v2", "aarch64+neon" ))] -pub fn cosine<'a>(lhs: SparseF32Ref<'a>, rhs: SparseF32Ref<'a>) -> F32 { +pub fn cosine<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { let mut lhs_pos = 0; let mut rhs_pos = 0; - let size1 = lhs.length() as usize; - let size2 = rhs.length() as usize; + let size1 = lhs.len() as usize; + let size2 = rhs.len() as usize; let mut xy = F32::zero(); let mut x2 = F32::zero(); let mut y2 = F32::zero(); while lhs_pos < size1 && rhs_pos < size2 { - let lhs_index = lhs.indexes[lhs_pos]; - let rhs_index = rhs.indexes[rhs_pos]; - let lhs_value = lhs.values[lhs_pos]; - let rhs_value = rhs.values[rhs_pos]; + let lhs_index = lhs.indexes()[lhs_pos]; + let rhs_index = rhs.indexes()[rhs_pos]; + let lhs_value = lhs.values()[lhs_pos]; + let rhs_value = rhs.values()[rhs_pos]; xy += F32((lhs_index == rhs_index) as u32 as f32) * lhs_value * rhs_value; x2 += F32((lhs_index <= rhs_index) as u32 as f32) * lhs_value * lhs_value; y2 += F32((lhs_index >= rhs_index) as u32 as f32) * rhs_value * rhs_value; @@ -27,10 +30,10 @@ pub fn cosine<'a>(lhs: SparseF32Ref<'a>, rhs: SparseF32Ref<'a>) -> F32 { rhs_pos += (lhs_index >= rhs_index) as usize; } for i in lhs_pos..size1 { - x2 += lhs.values[i] * lhs.values[i]; + x2 += lhs.values()[i] * lhs.values()[i]; } for i in rhs_pos..size2 { - y2 += rhs.values[i] * rhs.values[i]; + y2 += rhs.values()[i] * rhs.values()[i]; } xy / (x2 * y2).sqrt() } @@ -42,17 +45,17 @@ pub fn cosine<'a>(lhs: SparseF32Ref<'a>, rhs: SparseF32Ref<'a>) -> F32 { "x86_64/x86-64-v2", "aarch64+neon" ))] -pub fn dot<'a>(lhs: SparseF32Ref<'a>, rhs: SparseF32Ref<'a>) -> F32 { +pub fn dot<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { let mut lhs_pos = 0; let mut rhs_pos = 0; - let size1 = lhs.length() as usize; - let size2 = rhs.length() as usize; + let size1 = lhs.len() as usize; + let size2 = rhs.len() as usize; let mut xy = F32::zero(); while lhs_pos < size1 && rhs_pos < size2 { - let lhs_index = lhs.indexes[lhs_pos]; - let rhs_index = rhs.indexes[rhs_pos]; - let lhs_value = lhs.values[lhs_pos]; - let rhs_value = rhs.values[rhs_pos]; + let lhs_index = lhs.indexes()[lhs_pos]; + let rhs_index = rhs.indexes()[rhs_pos]; + let lhs_value = lhs.values()[lhs_pos]; + let rhs_value = rhs.values()[rhs_pos]; xy += F32((lhs_index == rhs_index) as u32 as f32) * lhs_value * rhs_value; lhs_pos += (lhs_index <= rhs_index) as usize; rhs_pos += (lhs_index >= rhs_index) as usize; @@ -67,10 +70,10 @@ pub fn dot<'a>(lhs: SparseF32Ref<'a>, rhs: SparseF32Ref<'a>) -> F32 { "x86_64/x86-64-v2", "aarch64+neon" ))] -pub fn dot_2<'a>(lhs: SparseF32Ref<'a>, rhs: &[F32]) -> F32 { +pub fn dot_2<'a>(lhs: SVecf32Borrowed<'a>, rhs: &[F32]) -> F32 { let mut xy = F32::zero(); - for i in 0..lhs.indexes.len() { - xy += lhs.values[i] * rhs[lhs.indexes[i] as usize]; + for i in 0..lhs.len() as usize { + xy += lhs.values()[i] * rhs[lhs.indexes()[i] as usize]; } xy } @@ -82,17 +85,17 @@ pub fn dot_2<'a>(lhs: SparseF32Ref<'a>, rhs: &[F32]) -> F32 { "x86_64/x86-64-v2", "aarch64+neon" ))] -pub fn sl2<'a>(lhs: SparseF32Ref<'a>, rhs: SparseF32Ref<'a>) -> F32 { +pub fn sl2<'a>(lhs: SVecf32Borrowed<'a>, rhs: SVecf32Borrowed<'a>) -> F32 { let mut lhs_pos = 0; let mut rhs_pos = 0; - let size1 = lhs.length() as usize; - let size2 = rhs.length() as usize; + let size1 = lhs.len() as usize; + let size2 = rhs.len() as usize; let mut d2 = F32::zero(); while lhs_pos < size1 && rhs_pos < size2 { - let lhs_index = lhs.indexes[lhs_pos]; - let rhs_index = rhs.indexes[rhs_pos]; - let lhs_value = lhs.values[lhs_pos]; - let rhs_value = rhs.values[rhs_pos]; + let lhs_index = lhs.indexes()[lhs_pos]; + let rhs_index = rhs.indexes()[rhs_pos]; + let lhs_value = lhs.values()[lhs_pos]; + let rhs_value = rhs.values()[rhs_pos]; let d = F32((lhs_index <= rhs_index) as u32 as f32) * lhs_value - F32((lhs_index >= rhs_index) as u32 as f32) * rhs_value; d2 += d * d; @@ -100,10 +103,10 @@ pub fn sl2<'a>(lhs: SparseF32Ref<'a>, rhs: SparseF32Ref<'a>) -> F32 { rhs_pos += (lhs_index >= rhs_index) as usize; } for i in lhs_pos..size1 { - d2 += lhs.values[i] * lhs.values[i]; + d2 += lhs.values()[i] * lhs.values()[i]; } for i in rhs_pos..size2 { - d2 += rhs.values[i] * rhs.values[i]; + d2 += rhs.values()[i] * rhs.values()[i]; } d2 } @@ -115,13 +118,14 @@ pub fn sl2<'a>(lhs: SparseF32Ref<'a>, rhs: SparseF32Ref<'a>) -> F32 { "x86_64/x86-64-v2", "aarch64+neon" ))] -pub fn sl2_2<'a>(lhs: SparseF32Ref<'a>, rhs: &[F32]) -> F32 { +pub fn sl2_2<'a>(lhs: SVecf32Borrowed<'a>, rhs: &[F32]) -> F32 { let mut d2 = F32::zero(); let mut lhs_pos: u16 = 0; let mut rhs_pos: u16 = 0; - while lhs_pos < lhs.length() { - let index_eq = lhs.indexes[lhs_pos as usize] == rhs_pos; - let d = F32(index_eq as u32 as f32) * lhs.values[lhs_pos as usize] - rhs[rhs_pos as usize]; + while lhs_pos < lhs.len() { + let index_eq = lhs.indexes()[lhs_pos as usize] == rhs_pos; + let d = + F32(index_eq as u32 as f32) * lhs.values()[lhs_pos as usize] - rhs[rhs_pos as usize]; d2 += d * d; lhs_pos += index_eq as u16; rhs_pos += 1; @@ -139,9 +143,9 @@ pub fn sl2_2<'a>(lhs: SparseF32Ref<'a>, rhs: &[F32]) -> F32 { "x86_64/x86-64-v2", "aarch64+neon" ))] -pub fn length<'a>(vector: SparseF32Ref<'a>) -> F32 { +pub fn length<'a>(vector: SVecf32Borrowed<'a>) -> F32 { let mut dot = F32::zero(); - for &i in vector.values { + for &i in vector.values() { dot += i * i; } dot.sqrt() @@ -154,9 +158,13 @@ pub fn length<'a>(vector: SparseF32Ref<'a>) -> F32 { "x86_64/x86-64-v2", "aarch64+neon" ))] -pub fn l2_normalize(vector: &mut SparseF32) { - let l = length(SparseF32Ref::from(vector as &SparseF32)); - for i in vector.values.iter_mut() { +pub fn l2_normalize(vector: &mut SVecf32Owned) { + let l = length(vector.for_borrow()); + let dims = vector.dims(); + let indexes = vector.indexes().to_vec(); + let mut values = vector.values().to_vec(); + for i in values.iter_mut() { *i /= l; } + *vector = SVecf32Owned::new(dims, indexes, values); } diff --git a/crates/base/src/global/svecf32_cos.rs b/crates/base/src/global/svecf32_cos.rs new file mode 100644 index 000000000..09d0cf05d --- /dev/null +++ b/crates/base/src/global/svecf32_cos.rs @@ -0,0 +1,102 @@ +use super::*; +use crate::distance::*; +use crate::scalar::*; +use crate::vector::*; +use num_traits::Float; + +#[derive(Debug, Clone, Copy)] +pub enum SVecf32Cos {} + +impl Global for SVecf32Cos { + type VectorOwned = SVecf32Owned; + + const VECTOR_KIND: VectorKind = VectorKind::SVecf32; + const DISTANCE_KIND: DistanceKind = DistanceKind::Cos; + + fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 { + F32(1.0) - super::svecf32::cosine(lhs, rhs) + } +} + +impl GlobalElkanKMeans for SVecf32Cos { + fn elkan_k_means_normalize(vector: &mut [Scalar]) { + super::vecf32::l2_normalize(vector) + } + + fn elkan_k_means_normalize2(vector: &mut SVecf32Owned) { + super::svecf32::l2_normalize(vector) + } + + fn elkan_k_means_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf32::dot(lhs, rhs).acos() + } + + fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar]) -> F32 { + super::svecf32::dot_2(lhs, rhs).acos() + } +} + +impl GlobalScalarQuantization for SVecf32Cos { + fn scalar_quantization_distance( + _dims: u16, + _max: &[F32], + _min: &[F32], + _lhs: Borrowed<'_, Self>, + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn scalar_quantization_distance2( + _dims: u16, + _max: &[Scalar], + _min: &[Scalar], + _lhs: &[u8], + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } +} + +impl GlobalProductQuantization for SVecf32Cos { + type ProductQuantizationL2 = SVecf32L2; + + fn product_quantization_distance( + _dims: u16, + _ratio: u16, + _centroids: &[Scalar], + _lhs: Borrowed<'_, Self>, + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_distance2( + _dims: u16, + _ratio: u16, + _centroids: &[Scalar], + _lhs: &[u8], + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_distance_with_delta( + _dims: u16, + _ratio: u16, + _centroids: &[Scalar], + _lhs: Borrowed<'_, Self>, + _rhs: &[u8], + _delta: &[Scalar], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf32::sl2(lhs, rhs) + } + + fn product_quantization_dense_distance(_: &[Scalar], _: &[Scalar]) -> F32 { + unimplemented!() + } +} diff --git a/crates/base/src/global/svecf32_dot.rs b/crates/base/src/global/svecf32_dot.rs new file mode 100644 index 000000000..4aca703d1 --- /dev/null +++ b/crates/base/src/global/svecf32_dot.rs @@ -0,0 +1,102 @@ +use super::*; +use crate::distance::*; +use crate::scalar::*; +use crate::vector::*; +use num_traits::Float; + +#[derive(Debug, Clone, Copy)] +pub enum SVecf32Dot {} + +impl Global for SVecf32Dot { + type VectorOwned = SVecf32Owned; + + const VECTOR_KIND: VectorKind = VectorKind::SVecf32; + const DISTANCE_KIND: DistanceKind = DistanceKind::Dot; + + fn distance(lhs: Borrowed<'_, Self>, rhs: Borrowed<'_, Self>) -> F32 { + super::svecf32::dot(lhs, rhs) * (-1.0) + } +} + +impl GlobalElkanKMeans for SVecf32Dot { + fn elkan_k_means_normalize(vector: &mut [Scalar]) { + super::vecf32::l2_normalize(vector) + } + + fn elkan_k_means_normalize2(vector: &mut SVecf32Owned) { + super::svecf32::l2_normalize(vector) + } + + fn elkan_k_means_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf32::dot(lhs, rhs).acos() + } + + fn elkan_k_means_distance2(lhs: Borrowed<'_, Self>, rhs: &[Scalar]) -> F32 { + super::svecf32::dot_2(lhs, rhs).acos() + } +} + +impl GlobalScalarQuantization for SVecf32Dot { + fn scalar_quantization_distance( + _dims: u16, + _max: &[Scalar], + _min: &[Scalar], + _lhs: Borrowed<'_, Self>, + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn scalar_quantization_distance2( + _dims: u16, + _max: &[Scalar], + _min: &[Scalar], + _lhs: &[u8], + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } +} + +impl GlobalProductQuantization for SVecf32Dot { + type ProductQuantizationL2 = SVecf32L2; + + fn product_quantization_distance( + _dims: u16, + _ratio: u16, + _centroids: &[Scalar], + _lhs: Borrowed<'_, Self>, + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_distance2( + _dims: u16, + _ratio: u16, + _centroids: &[Scalar], + _lhs: &[u8], + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_distance_with_delta( + _dims: u16, + _ratio: u16, + _centroids: &[Scalar], + _lhs: Borrowed<'_, Self>, + _rhs: &[u8], + _delta: &[Scalar], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf32::sl2(lhs, rhs) + } + + fn product_quantization_dense_distance(_: &[Scalar], _: &[Scalar]) -> F32 { + unimplemented!() + } +} diff --git a/crates/base/src/global/svecf32_l2.rs b/crates/base/src/global/svecf32_l2.rs new file mode 100644 index 000000000..90b9a2bb8 --- /dev/null +++ b/crates/base/src/global/svecf32_l2.rs @@ -0,0 +1,98 @@ +use super::*; +use crate::distance::*; +use crate::scalar::*; +use crate::vector::*; +use num_traits::Float; + +#[derive(Debug, Clone, Copy)] +pub enum SVecf32L2 {} + +impl Global for SVecf32L2 { + type VectorOwned = SVecf32Owned; + + const VECTOR_KIND: VectorKind = VectorKind::SVecf32; + const DISTANCE_KIND: DistanceKind = DistanceKind::L2; + + fn distance(lhs: SVecf32Borrowed<'_>, rhs: SVecf32Borrowed<'_>) -> F32 { + super::svecf32::sl2(lhs, rhs) + } +} + +impl GlobalElkanKMeans for SVecf32L2 { + fn elkan_k_means_normalize(_: &mut [Scalar]) {} + + fn elkan_k_means_normalize2(_: &mut SVecf32Owned) {} + + fn elkan_k_means_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf32::sl2(lhs, rhs).sqrt() + } + + fn elkan_k_means_distance2(lhs: SVecf32Borrowed<'_>, rhs: &[Scalar]) -> F32 { + super::svecf32::sl2_2(lhs, rhs).sqrt() + } +} + +impl GlobalScalarQuantization for SVecf32L2 { + fn scalar_quantization_distance( + _dims: u16, + _max: &[Scalar], + _min: &[Scalar], + _lhs: SVecf32Borrowed<'_>, + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn scalar_quantization_distance2( + _dims: u16, + _max: &[Scalar], + _min: &[Scalar], + _lhs: &[u8], + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } +} + +impl GlobalProductQuantization for SVecf32L2 { + type ProductQuantizationL2 = SVecf32L2; + + fn product_quantization_distance( + _dims: u16, + _ratio: u16, + _centroids: &[Scalar], + _lhs: SVecf32Borrowed<'_>, + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_distance2( + _dims: u16, + _ratio: u16, + _centroids: &[Scalar], + _lhs: &[u8], + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_distance_with_delta( + _dims: u16, + _ratio: u16, + _centroids: &[Scalar], + _lhs: SVecf32Borrowed<'_>, + _rhs: &[u8], + _delta: &[Scalar], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf32::sl2(lhs, rhs) + } + + fn product_quantization_dense_distance(_: &[Scalar], _: &[Scalar]) -> F32 { + unimplemented!() + } +} diff --git a/crates/service/src/prelude/global/f16.rs b/crates/base/src/global/vecf16.rs similarity index 98% rename from crates/service/src/prelude/global/f16.rs rename to crates/base/src/global/vecf16.rs index be5c560d2..f225bcc2f 100644 --- a/crates/service/src/prelude/global/f16.rs +++ b/crates/base/src/global/vecf16.rs @@ -1,5 +1,5 @@ -use crate::prelude::*; -use base::scalar::FloatCast; +use crate::scalar::*; +use num_traits::{Float, Zero}; pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { #[inline(always)] diff --git a/crates/service/src/prelude/global/f16_cos.rs b/crates/base/src/global/vecf16_cos.rs similarity index 78% rename from crates/service/src/prelude/global/f16_cos.rs rename to crates/base/src/global/vecf16_cos.rs index 64c648d03..b6b2ed2f6 100644 --- a/crates/service/src/prelude/global/f16_cos.rs +++ b/crates/base/src/global/vecf16_cos.rs @@ -1,69 +1,56 @@ -use crate::prelude::*; -use base::scalar::FloatCast; -use std::borrow::Cow; +use super::*; +use crate::distance::*; +use crate::scalar::*; +use crate::vector::*; +use num_traits::{Float, Zero}; #[derive(Debug, Clone, Copy)] -pub enum F16Cos {} +pub enum Vecf16Cos {} -impl G for F16Cos { - type Scalar = F16; - type Storage = DenseMmap; - type L2 = F16L2; - type VectorOwned = Vec; - type VectorRef<'a> = &'a [F16]; +impl Global for Vecf16Cos { + type VectorOwned = Vecf16Owned; - const DISTANCE: Distance = Distance::Cos; - const KIND: Kind = Kind::F16; + const VECTOR_KIND: VectorKind = VectorKind::Vecf16; + const DISTANCE_KIND: DistanceKind = DistanceKind::Cos; - fn owned_to_ref(vector: &Vec) -> &[F16] { - vector - } - - fn ref_to_owned(vector: &[F16]) -> Vec { - vector.to_vec() - } - - fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F16]> { - Cow::Borrowed(vector) - } - - fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { - F32(1.0) - super::f16::cosine(lhs, rhs) - } - - fn distance2(lhs: &[F16], rhs: &[F16]) -> F32 { - F32(1.0) - super::f16::cosine(lhs, rhs) + fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> F32 { + F32(1.0) - super::vecf16::cosine(lhs.slice(), rhs.slice()) } +} +impl GlobalElkanKMeans for Vecf16Cos { fn elkan_k_means_normalize(vector: &mut [F16]) { - super::f16::l2_normalize(vector) + super::vecf16::l2_normalize(vector) } - fn elkan_k_means_normalize2(vector: &mut Vec) { - super::f16::l2_normalize(vector) + fn elkan_k_means_normalize2(vector: &mut Vecf16Owned) { + super::vecf16::l2_normalize(vector.slice_mut()) } fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16::dot(lhs, rhs).acos() + super::vecf16::dot(lhs, rhs).acos() } - fn elkan_k_means_distance2(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16::dot(lhs, rhs).acos() + fn elkan_k_means_distance2(lhs: Vecf16Borrowed<'_>, rhs: &[F16]) -> F32 { + super::vecf16::dot(lhs.slice(), rhs).acos() } +} +impl GlobalScalarQuantization for Vecf16Cos { #[multiversion::multiversion(targets( "x86_64/x86-64-v4", "x86_64/x86-64-v3", "x86_64/x86-64-v2", "aarch64+neon" ))] - fn scalar_quantization_distance( + fn scalar_quantization_distance<'a>( dims: u16, max: &[F16], min: &[F16], - lhs: &[F16], + lhs: Vecf16Borrowed<'a>, rhs: &[u8], ) -> F32 { + let lhs = lhs.slice(); let mut xy = F32::zero(); let mut x2 = F32::zero(); let mut y2 = F32::zero(); @@ -102,6 +89,10 @@ impl G for F16Cos { } F32(1.0) - xy / (x2 * y2).sqrt() } +} + +impl GlobalProductQuantization for Vecf16Cos { + type ProductQuantizationL2 = Vecf16L2; #[multiversion::multiversion(targets( "x86_64/x86-64-v4", @@ -109,13 +100,14 @@ impl G for F16Cos { "x86_64/x86-64-v2", "aarch64+neon" ))] - fn product_quantization_distance( + fn product_quantization_distance<'a>( dims: u16, ratio: u16, centroids: &[F16], - lhs: &[F16], + lhs: Vecf16Borrowed<'a>, rhs: &[u8], ) -> F32 { + let lhs = lhs.slice(); let width = dims.div_ceil(ratio); let mut xy = F32::zero(); let mut x2 = F32::zero(); @@ -170,14 +162,15 @@ impl G for F16Cos { "x86_64/x86-64-v2", "aarch64+neon" ))] - fn product_quantization_distance_with_delta( + fn product_quantization_distance_with_delta<'a>( dims: u16, ratio: u16, centroids: &[F16], - lhs: &[F16], + lhs: Vecf16Borrowed<'a>, rhs: &[u8], delta: &[F16], ) -> F32 { + let lhs = lhs.slice(); let width = dims.div_ceil(ratio); let mut xy = F32::zero(); let mut x2 = F32::zero(); @@ -195,6 +188,14 @@ impl G for F16Cos { } F32(1.0) - xy / (x2 * y2).sqrt() } + + fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf16::sl2(lhs, rhs) + } + + fn product_quantization_dense_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + F32(1.0) - super::vecf16::cosine(lhs, rhs) + } } #[inline(always)] diff --git a/crates/service/src/prelude/global/f16_dot.rs b/crates/base/src/global/vecf16_dot.rs similarity index 71% rename from crates/service/src/prelude/global/f16_dot.rs rename to crates/base/src/global/vecf16_dot.rs index 0d210a421..affa5573e 100644 --- a/crates/service/src/prelude/global/f16_dot.rs +++ b/crates/base/src/global/vecf16_dot.rs @@ -1,69 +1,56 @@ -use crate::prelude::*; -use base::scalar::FloatCast; -use std::borrow::Cow; +use super::*; +use crate::distance::*; +use crate::scalar::*; +use crate::vector::*; +use num_traits::{Float, Zero}; #[derive(Debug, Clone, Copy)] -pub enum F16Dot {} +pub enum Vecf16Dot {} -impl G for F16Dot { - type Scalar = F16; - type Storage = DenseMmap; - type L2 = F16L2; - type VectorOwned = Vec; - type VectorRef<'a> = &'a [F16]; +impl Global for Vecf16Dot { + type VectorOwned = Vecf16Owned; - const DISTANCE: Distance = Distance::Dot; - const KIND: Kind = Kind::F16; + const VECTOR_KIND: VectorKind = VectorKind::Vecf16; + const DISTANCE_KIND: DistanceKind = DistanceKind::Dot; - fn owned_to_ref(vector: &Vec) -> &[F16] { - vector - } - - fn ref_to_owned(vector: &[F16]) -> Vec { - vector.to_vec() - } - - fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F16]> { - Cow::Borrowed(vector) - } - - fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16::dot(lhs, rhs) * (-1.0) - } - - fn distance2(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16::dot(lhs, rhs) * (-1.0) + fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> F32 { + super::vecf16::dot(lhs.slice(), rhs.slice()) * (-1.0) } +} +impl GlobalElkanKMeans for Vecf16Dot { fn elkan_k_means_normalize(vector: &mut [F16]) { - super::f16::l2_normalize(vector) + super::vecf16::l2_normalize(vector) } - fn elkan_k_means_normalize2(vector: &mut Vec) { - super::f16::l2_normalize(vector) + fn elkan_k_means_normalize2(vector: &mut Vecf16Owned) { + super::vecf16::l2_normalize(vector.slice_mut()) } fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16::dot(lhs, rhs).acos() + super::vecf16::dot(lhs, rhs).acos() } - fn elkan_k_means_distance2(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16::dot(lhs, rhs).acos() + fn elkan_k_means_distance2(lhs: Vecf16Borrowed<'_>, rhs: &[F16]) -> F32 { + super::vecf16::dot(lhs.slice(), rhs).acos() } +} +impl GlobalScalarQuantization for Vecf16Dot { #[multiversion::multiversion(targets( "x86_64/x86-64-v4", "x86_64/x86-64-v3", "x86_64/x86-64-v2", "aarch64+neon" ))] - fn scalar_quantization_distance( + fn scalar_quantization_distance<'a>( dims: u16, max: &[F16], min: &[F16], - lhs: &[F16], + lhs: Vecf16Borrowed<'a>, rhs: &[u8], ) -> F32 { + let lhs = lhs.slice(); let mut xy = F32::zero(); for i in 0..dims as usize { let _x = lhs[i].to_f(); @@ -94,6 +81,10 @@ impl G for F16Dot { } xy * (-1.0) } +} + +impl GlobalProductQuantization for Vecf16Dot { + type ProductQuantizationL2 = Vecf16L2; #[multiversion::multiversion(targets( "x86_64/x86-64-v4", @@ -101,13 +92,14 @@ impl G for F16Dot { "x86_64/x86-64-v2", "aarch64+neon" ))] - fn product_quantization_distance( + fn product_quantization_distance<'a>( dims: u16, ratio: u16, centroids: &[F16], - lhs: &[F16], + lhs: Vecf16Borrowed<'a>, rhs: &[u8], ) -> F32 { + let lhs = lhs.slice(); let width = dims.div_ceil(ratio); let mut xy = F32::zero(); for i in 0..width { @@ -115,7 +107,7 @@ impl G for F16Dot { let lhs = &lhs[(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = super::f16::dot(lhs, rhs); + let _xy = super::vecf16::dot(lhs, rhs); xy += _xy; } xy * (-1.0) @@ -142,7 +134,7 @@ impl G for F16Dot { let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = super::f16::dot(lhs, rhs); + let _xy = super::vecf16::dot(lhs, rhs); xy += _xy; } xy * (-1.0) @@ -154,14 +146,15 @@ impl G for F16Dot { "x86_64/x86-64-v2", "aarch64+neon" ))] - fn product_quantization_distance_with_delta( + fn product_quantization_distance_with_delta<'a>( dims: u16, ratio: u16, centroids: &[F16], - lhs: &[F16], + lhs: Vecf16Borrowed<'a>, rhs: &[u8], delta: &[F16], ) -> F32 { + let lhs = lhs.slice(); let width = dims.div_ceil(ratio); let mut xy = F32::zero(); for i in 0..width { @@ -175,6 +168,14 @@ impl G for F16Dot { } xy * (-1.0) } + + fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf16::sl2(lhs, rhs) + } + + fn product_quantization_dense_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf16::dot(lhs, rhs) * (-1.0) + } } #[inline(always)] diff --git a/crates/service/src/prelude/global/f16_l2.rs b/crates/base/src/global/vecf16_l2.rs similarity index 72% rename from crates/service/src/prelude/global/f16_l2.rs rename to crates/base/src/global/vecf16_l2.rs index 14e7bddbb..bb4fba973 100644 --- a/crates/service/src/prelude/global/f16_l2.rs +++ b/crates/base/src/global/vecf16_l2.rs @@ -1,65 +1,52 @@ -use crate::prelude::*; -use base::scalar::FloatCast; -use std::borrow::Cow; +use super::*; +use crate::distance::*; +use crate::scalar::*; +use crate::vector::*; +use num_traits::{Float, Zero}; #[derive(Debug, Clone, Copy)] -pub enum F16L2 {} +pub enum Vecf16L2 {} -impl G for F16L2 { - type Scalar = F16; - type Storage = DenseMmap; - type L2 = F16L2; - type VectorOwned = Vec; - type VectorRef<'a> = &'a [F16]; +impl Global for Vecf16L2 { + type VectorOwned = Vecf16Owned; - const DISTANCE: Distance = Distance::L2; - const KIND: Kind = Kind::F16; + const VECTOR_KIND: VectorKind = VectorKind::Vecf16; + const DISTANCE_KIND: DistanceKind = DistanceKind::L2; - fn owned_to_ref(vector: &Vec) -> &[F16] { - vector - } - - fn ref_to_owned(vector: &[F16]) -> Vec { - vector.to_vec() - } - - fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F16]> { - Cow::Borrowed(vector) - } - - fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16::sl2(lhs, rhs) - } - - fn distance2(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16::sl2(lhs, rhs) + fn distance(lhs: Vecf16Borrowed<'_>, rhs: Vecf16Borrowed<'_>) -> F32 { + super::vecf16::sl2(lhs.slice(), rhs.slice()) } +} +impl GlobalElkanKMeans for Vecf16L2 { fn elkan_k_means_normalize(_: &mut [F16]) {} - fn elkan_k_means_normalize2(_: &mut Vec) {} + fn elkan_k_means_normalize2(_: &mut Vecf16Owned) {} fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16::sl2(lhs, rhs).sqrt() + super::vecf16::sl2(lhs, rhs).sqrt() } - fn elkan_k_means_distance2(lhs: &[F16], rhs: &[F16]) -> F32 { - super::f16::sl2(lhs, rhs).sqrt() + fn elkan_k_means_distance2(lhs: Vecf16Borrowed<'_>, rhs: &[F16]) -> F32 { + super::vecf16::sl2(lhs.slice(), rhs).sqrt() } +} +impl GlobalScalarQuantization for Vecf16L2 { #[multiversion::multiversion(targets( "x86_64/x86-64-v4", "x86_64/x86-64-v3", "x86_64/x86-64-v2", "aarch64+neon" ))] - fn scalar_quantization_distance( + fn scalar_quantization_distance<'a>( dims: u16, max: &[F16], min: &[F16], - lhs: &[F16], + lhs: Vecf16Borrowed<'a>, rhs: &[u8], ) -> F32 { + let lhs = lhs.slice(); let mut result = F32::zero(); for i in 0..dims as usize { let _x = lhs[i].to_f(); @@ -90,6 +77,10 @@ impl G for F16L2 { } result } +} + +impl GlobalProductQuantization for Vecf16L2 { + type ProductQuantizationL2 = Vecf16L2; #[multiversion::multiversion(targets( "x86_64/x86-64-v4", @@ -97,13 +88,14 @@ impl G for F16L2 { "x86_64/x86-64-v2", "aarch64+neon" ))] - fn product_quantization_distance( + fn product_quantization_distance<'a>( dims: u16, ratio: u16, centroids: &[F16], - lhs: &[F16], + lhs: Vecf16Borrowed<'a>, rhs: &[u8], ) -> F32 { + let lhs = lhs.slice(); let width = dims.div_ceil(ratio); let mut result = F32::zero(); for i in 0..width { @@ -111,7 +103,7 @@ impl G for F16L2 { let lhs = &lhs[(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += super::f16::sl2(lhs, rhs); + result += super::vecf16::sl2(lhs, rhs); } result } @@ -137,7 +129,7 @@ impl G for F16L2 { let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += super::f16::sl2(lhs, rhs); + result += super::vecf16::sl2(lhs, rhs); } result } @@ -148,14 +140,15 @@ impl G for F16L2 { "x86_64/x86-64-v2", "aarch64+neon" ))] - fn product_quantization_distance_with_delta( + fn product_quantization_distance_with_delta<'a>( dims: u16, ratio: u16, centroids: &[F16], - lhs: &[F16], + lhs: Vecf16Borrowed<'a>, rhs: &[u8], delta: &[F16], ) -> F32 { + let lhs = lhs.slice(); let width = dims.div_ceil(ratio); let mut result = F32::zero(); for i in 0..width { @@ -168,6 +161,14 @@ impl G for F16L2 { } result } + + fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf16::sl2(lhs, rhs) + } + + fn product_quantization_dense_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf16::sl2(lhs, rhs) + } } #[inline(always)] diff --git a/crates/service/src/prelude/global/f32.rs b/crates/base/src/global/vecf32.rs similarity index 97% rename from crates/service/src/prelude/global/f32.rs rename to crates/base/src/global/vecf32.rs index 962c50f39..5e873643b 100644 --- a/crates/service/src/prelude/global/f32.rs +++ b/crates/base/src/global/vecf32.rs @@ -1,4 +1,5 @@ -use crate::prelude::*; +use crate::scalar::*; +use num_traits::{Float, Zero}; #[inline(always)] #[multiversion::multiversion(targets( diff --git a/crates/service/src/prelude/global/f32_cos.rs b/crates/base/src/global/vecf32_cos.rs similarity index 77% rename from crates/service/src/prelude/global/f32_cos.rs rename to crates/base/src/global/vecf32_cos.rs index c9e75f92a..2c4799dac 100644 --- a/crates/service/src/prelude/global/f32_cos.rs +++ b/crates/base/src/global/vecf32_cos.rs @@ -1,68 +1,56 @@ -use crate::prelude::*; -use std::borrow::Cow; +use super::*; +use crate::distance::*; +use crate::scalar::*; +use crate::vector::*; +use num_traits::{Float, Zero}; #[derive(Debug, Clone, Copy)] -pub enum F32Cos {} +pub enum Vecf32Cos {} -impl G for F32Cos { - type Scalar = F32; - type Storage = DenseMmap; - type L2 = F32L2; - type VectorOwned = Vec; - type VectorRef<'a> = &'a [F32]; +impl Global for Vecf32Cos { + type VectorOwned = Vecf32Owned; - const DISTANCE: Distance = Distance::Cos; - const KIND: Kind = Kind::F32; + const VECTOR_KIND: VectorKind = VectorKind::Vecf32; + const DISTANCE_KIND: DistanceKind = DistanceKind::Cos; - fn owned_to_ref(vector: &Vec) -> &[F32] { - vector - } - - fn ref_to_owned(vector: &[F32]) -> Vec { - vector.to_vec() - } - - fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F32]> { - Cow::Borrowed(vector) - } - - fn distance(lhs: &[F32], rhs: &[F32]) -> F32 { - F32(1.0) - super::f32::cosine(lhs, rhs) - } - - fn distance2(lhs: &[F32], rhs: &[F32]) -> F32 { - F32(1.0) - super::f32::cosine(lhs, rhs) + fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> F32 { + F32(1.0) - super::vecf32::cosine(lhs.slice(), rhs.slice()) } +} +impl GlobalElkanKMeans for Vecf32Cos { fn elkan_k_means_normalize(vector: &mut [F32]) { - super::f32::l2_normalize(vector) + super::vecf32::l2_normalize(vector) } - fn elkan_k_means_normalize2(vector: &mut Vec) { - super::f32::l2_normalize(vector) + fn elkan_k_means_normalize2(vector: &mut Vecf32Owned) { + super::vecf32::l2_normalize(vector.slice_mut()) } fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 { - super::f32::dot(lhs, rhs).acos() + super::vecf32::dot(lhs, rhs).acos() } - fn elkan_k_means_distance2(lhs: &[F32], rhs: &[F32]) -> F32 { - super::f32::dot(lhs, rhs).acos() + fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[F32]) -> F32 { + super::vecf32::dot(lhs.slice(), rhs).acos() } +} +impl GlobalScalarQuantization for Vecf32Cos { #[multiversion::multiversion(targets( "x86_64/x86-64-v4", "x86_64/x86-64-v3", "x86_64/x86-64-v2", "aarch64+neon" ))] - fn scalar_quantization_distance( + fn scalar_quantization_distance<'a>( dims: u16, max: &[F32], min: &[F32], - lhs: &[F32], + lhs: Vecf32Borrowed<'a>, rhs: &[u8], ) -> F32 { + let lhs = lhs.slice(); let mut xy = F32::zero(); let mut x2 = F32::zero(); let mut y2 = F32::zero(); @@ -101,6 +89,10 @@ impl G for F32Cos { } F32(1.0) - xy / (x2 * y2).sqrt() } +} + +impl GlobalProductQuantization for Vecf32Cos { + type ProductQuantizationL2 = Vecf32L2; #[multiversion::multiversion(targets( "x86_64/x86-64-v4", @@ -108,13 +100,14 @@ impl G for F32Cos { "x86_64/x86-64-v2", "aarch64+neon" ))] - fn product_quantization_distance( + fn product_quantization_distance<'a>( dims: u16, ratio: u16, centroids: &[F32], - lhs: &[F32], + lhs: Vecf32Borrowed<'a>, rhs: &[u8], ) -> F32 { + let lhs = lhs.slice(); let width = dims.div_ceil(ratio); let mut xy = F32::zero(); let mut x2 = F32::zero(); @@ -169,14 +162,15 @@ impl G for F32Cos { "x86_64/x86-64-v2", "aarch64+neon" ))] - fn product_quantization_distance_with_delta( + fn product_quantization_distance_with_delta<'a>( dims: u16, ratio: u16, centroids: &[F32], - lhs: &[F32], + lhs: Vecf32Borrowed<'a>, rhs: &[u8], delta: &[F32], ) -> F32 { + let lhs = lhs.slice(); let width = dims.div_ceil(ratio); let mut xy = F32::zero(); let mut x2 = F32::zero(); @@ -194,6 +188,14 @@ impl G for F32Cos { } F32(1.0) - xy / (x2 * y2).sqrt() } + + fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf32::sl2(lhs, rhs) + } + + fn product_quantization_dense_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + F32(1.0) - super::vecf32::cosine(lhs, rhs) + } } #[inline(always)] diff --git a/crates/service/src/prelude/global/f32_dot.rs b/crates/base/src/global/vecf32_dot.rs similarity index 70% rename from crates/service/src/prelude/global/f32_dot.rs rename to crates/base/src/global/vecf32_dot.rs index d4f58632b..7d82d9348 100644 --- a/crates/service/src/prelude/global/f32_dot.rs +++ b/crates/base/src/global/vecf32_dot.rs @@ -1,68 +1,56 @@ -use crate::prelude::*; -use std::borrow::Cow; +use super::*; +use crate::distance::*; +use crate::scalar::*; +use crate::vector::*; +use num_traits::{Float, Zero}; #[derive(Debug, Clone, Copy)] -pub enum F32Dot {} +pub enum Vecf32Dot {} -impl G for F32Dot { - type Scalar = F32; - type Storage = DenseMmap; - type L2 = F32L2; - type VectorOwned = Vec; - type VectorRef<'a> = &'a [F32]; +impl Global for Vecf32Dot { + type VectorOwned = Vecf32Owned; - const DISTANCE: Distance = Distance::Dot; - const KIND: Kind = Kind::F32; + const DISTANCE_KIND: DistanceKind = DistanceKind::Dot; + const VECTOR_KIND: VectorKind = VectorKind::Vecf32; - fn owned_to_ref(vector: &Vec) -> &[F32] { - vector - } - - fn ref_to_owned(vector: &[F32]) -> Vec { - vector.to_vec() - } - - fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F32]> { - Cow::Borrowed(vector) - } - - fn distance(lhs: &[F32], rhs: &[F32]) -> F32 { - super::f32::dot(lhs, rhs) * (-1.0) - } - - fn distance2(lhs: &[F32], rhs: &[F32]) -> F32 { - super::f32::dot(lhs, rhs) * (-1.0) + fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> F32 { + super::vecf32::dot(lhs.slice(), rhs.slice()) * (-1.0) } +} +impl GlobalElkanKMeans for Vecf32Dot { fn elkan_k_means_normalize(vector: &mut [F32]) { - super::f32::l2_normalize(vector) + super::vecf32::l2_normalize(vector) } - fn elkan_k_means_normalize2(vector: &mut Vec) { - super::f32::l2_normalize(vector) + fn elkan_k_means_normalize2(vector: &mut Vecf32Owned) { + super::vecf32::l2_normalize(vector.slice_mut()) } fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 { - super::f32::dot(lhs, rhs).acos() + super::vecf32::dot(lhs, rhs).acos() } - fn elkan_k_means_distance2(lhs: &[F32], rhs: &[F32]) -> F32 { - super::f32::dot(lhs, rhs).acos() + fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[F32]) -> F32 { + super::vecf32::dot(lhs.slice(), rhs).acos() } +} +impl GlobalScalarQuantization for Vecf32Dot { #[multiversion::multiversion(targets( "x86_64/x86-64-v4", "x86_64/x86-64-v3", "x86_64/x86-64-v2", "aarch64+neon" ))] - fn scalar_quantization_distance( + fn scalar_quantization_distance<'a>( dims: u16, max: &[F32], min: &[F32], - lhs: &[F32], + lhs: Vecf32Borrowed<'a>, rhs: &[u8], ) -> F32 { + let lhs = lhs.slice(); let mut xy = F32::zero(); for i in 0..dims as usize { let _x = lhs[i]; @@ -93,6 +81,10 @@ impl G for F32Dot { } xy * (-1.0) } +} + +impl GlobalProductQuantization for Vecf32Dot { + type ProductQuantizationL2 = Vecf32L2; #[multiversion::multiversion(targets( "x86_64/x86-64-v4", @@ -100,13 +92,14 @@ impl G for F32Dot { "x86_64/x86-64-v2", "aarch64+neon" ))] - fn product_quantization_distance( + fn product_quantization_distance<'a>( dims: u16, ratio: u16, centroids: &[F32], - lhs: &[F32], + lhs: Vecf32Borrowed<'a>, rhs: &[u8], ) -> F32 { + let lhs = lhs.slice(); let width = dims.div_ceil(ratio); let mut xy = F32::zero(); for i in 0..width { @@ -114,7 +107,7 @@ impl G for F32Dot { let lhs = &lhs[(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = super::f32::dot(lhs, rhs); + let _xy = super::vecf32::dot(lhs, rhs); xy += _xy; } xy * (-1.0) @@ -141,7 +134,7 @@ impl G for F32Dot { let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = super::f32::dot(lhs, rhs); + let _xy = super::vecf32::dot(lhs, rhs); xy += _xy; } xy * (-1.0) @@ -153,14 +146,15 @@ impl G for F32Dot { "x86_64/x86-64-v2", "aarch64+neon" ))] - fn product_quantization_distance_with_delta( + fn product_quantization_distance_with_delta<'a>( dims: u16, ratio: u16, centroids: &[F32], - lhs: &[F32], + lhs: Vecf32Borrowed<'a>, rhs: &[u8], delta: &[F32], ) -> F32 { + let lhs = lhs.slice(); let width = dims.div_ceil(ratio); let mut xy = F32::zero(); for i in 0..width { @@ -174,6 +168,14 @@ impl G for F32Dot { } xy * (-1.0) } + + fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf32::sl2(lhs, rhs) + } + + fn product_quantization_dense_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf32::dot(lhs, rhs) * (-1.0) + } } #[inline(always)] diff --git a/crates/service/src/prelude/global/f32_l2.rs b/crates/base/src/global/vecf32_l2.rs similarity index 70% rename from crates/service/src/prelude/global/f32_l2.rs rename to crates/base/src/global/vecf32_l2.rs index 815d67303..37ba6ab7b 100644 --- a/crates/service/src/prelude/global/f32_l2.rs +++ b/crates/base/src/global/vecf32_l2.rs @@ -1,64 +1,52 @@ -use crate::prelude::*; -use std::borrow::Cow; +use super::*; +use crate::distance::*; +use crate::scalar::*; +use crate::vector::*; +use num_traits::{Float, Zero}; #[derive(Debug, Clone, Copy)] -pub enum F32L2 {} +pub enum Vecf32L2 {} -impl G for F32L2 { - type Scalar = F32; - type Storage = DenseMmap; - type L2 = F32L2; - type VectorOwned = Vec; - type VectorRef<'a> = &'a [F32]; +impl Global for Vecf32L2 { + type VectorOwned = Vecf32Owned; - const DISTANCE: Distance = Distance::L2; - const KIND: Kind = Kind::F32; + const VECTOR_KIND: VectorKind = VectorKind::Vecf32; + const DISTANCE_KIND: DistanceKind = DistanceKind::L2; - fn owned_to_ref(vector: &Vec) -> &[F32] { - vector - } - - fn ref_to_owned(vector: &[F32]) -> Vec { - vector.to_vec() - } - - fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F32]> { - Cow::Borrowed(vector) - } - - fn distance(lhs: &[F32], rhs: &[F32]) -> F32 { - super::f32::sl2(lhs, rhs) - } - - fn distance2(lhs: &[F32], rhs: &[F32]) -> F32 { - super::f32::sl2(lhs, rhs) + fn distance(lhs: Vecf32Borrowed<'_>, rhs: Vecf32Borrowed<'_>) -> F32 { + super::vecf32::sl2(lhs.slice(), rhs.slice()) } +} +impl GlobalElkanKMeans for Vecf32L2 { fn elkan_k_means_normalize(_: &mut [F32]) {} - fn elkan_k_means_normalize2(_: &mut Vec) {} + fn elkan_k_means_normalize2(_: &mut Vecf32Owned) {} - fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 { - super::f32::sl2(lhs, rhs).sqrt() + fn elkan_k_means_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf32::sl2(lhs, rhs).sqrt() } - fn elkan_k_means_distance2(lhs: &[F32], rhs: &[F32]) -> F32 { - super::f32::sl2(lhs, rhs).sqrt() + fn elkan_k_means_distance2(lhs: Vecf32Borrowed<'_>, rhs: &[Scalar]) -> F32 { + super::vecf32::sl2(lhs.slice(), rhs).sqrt() } +} +impl GlobalScalarQuantization for Vecf32L2 { #[multiversion::multiversion(targets( "x86_64/x86-64-v4", "x86_64/x86-64-v3", "x86_64/x86-64-v2", "aarch64+neon" ))] - fn scalar_quantization_distance( + fn scalar_quantization_distance<'a>( dims: u16, max: &[F32], min: &[F32], - lhs: &[F32], + lhs: Vecf32Borrowed<'a>, rhs: &[u8], ) -> F32 { + let lhs = lhs.slice(); let mut result = F32::zero(); for i in 0..dims as usize { let _x = lhs[i]; @@ -89,6 +77,10 @@ impl G for F32L2 { } result } +} + +impl GlobalProductQuantization for Vecf32L2 { + type ProductQuantizationL2 = Vecf32L2; #[multiversion::multiversion(targets( "x86_64/x86-64-v4", @@ -96,13 +88,14 @@ impl G for F32L2 { "x86_64/x86-64-v2", "aarch64+neon" ))] - fn product_quantization_distance( + fn product_quantization_distance<'a>( dims: u16, ratio: u16, centroids: &[F32], - lhs: &[F32], + lhs: Vecf32Borrowed<'a>, rhs: &[u8], ) -> F32 { + let lhs = lhs.slice(); let width = dims.div_ceil(ratio); let mut result = F32::zero(); for i in 0..width { @@ -110,7 +103,7 @@ impl G for F32L2 { let lhs = &lhs[(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += super::f32::sl2(lhs, rhs); + result += super::vecf32::sl2(lhs, rhs); } result } @@ -136,7 +129,7 @@ impl G for F32L2 { let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += super::f32::sl2(lhs, rhs); + result += super::vecf32::sl2(lhs, rhs); } result } @@ -147,14 +140,15 @@ impl G for F32L2 { "x86_64/x86-64-v2", "aarch64+neon" ))] - fn product_quantization_distance_with_delta( + fn product_quantization_distance_with_delta<'a>( dims: u16, ratio: u16, centroids: &[F32], - lhs: &[F32], + lhs: Vecf32Borrowed<'a>, rhs: &[u8], delta: &[F32], ) -> F32 { + let lhs = lhs.slice(); let width = dims.div_ceil(ratio); let mut result = F32::zero(); for i in 0..width { @@ -167,6 +161,14 @@ impl G for F32L2 { } result } + + fn product_quantization_l2_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf32::sl2(lhs, rhs) + } + + fn product_quantization_dense_distance(lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + super::vecf32::sl2(lhs, rhs) + } } #[inline(always)] diff --git a/crates/base/src/index.rs b/crates/base/src/index.rs new file mode 100644 index 000000000..c7a8a3825 --- /dev/null +++ b/crates/base/src/index.rs @@ -0,0 +1,390 @@ +use crate::distance::*; +use crate::vector::*; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; +use validator::{Validate, ValidationError}; + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +#[validate(schema(function = "IndexOptions::validate_index_options"))] +pub struct IndexOptions { + #[validate] + pub vector: VectorOptions, + #[validate] + pub segment: SegmentsOptions, + #[validate] + pub optimizing: OptimizingOptions, + #[validate] + pub indexing: IndexingOptions, +} + +impl IndexOptions { + fn validate_index_options(options: &IndexOptions) -> Result<(), ValidationError> { + if options.vector.v != VectorKind::SVecf32 { + return Ok(()); + } + let is_trivial = match &options.indexing { + IndexingOptions::Flat(x) => matches!(x.quantization, QuantizationOptions::Trivial(_)), + IndexingOptions::Ivf(x) => matches!(x.quantization, QuantizationOptions::Trivial(_)), + IndexingOptions::Hnsw(x) => matches!(x.quantization, QuantizationOptions::Trivial(_)), + }; + if !is_trivial { + return Err(ValidationError::new( + "Quantization is not supported for svector.", + )); + } + Ok(()) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct VectorOptions { + #[validate(range(min = 1, max = 65535))] + #[serde(rename = "dimensions")] + pub dims: u16, + #[serde(rename = "distance")] + pub d: DistanceKind, + #[serde(rename = "vector")] + pub v: VectorKind, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +#[validate(schema(function = "Self::validate_0"))] +pub struct SegmentsOptions { + #[serde(default = "SegmentsOptions::default_max_growing_segment_size")] + #[validate(range(min = 1, max = 4_000_000_000))] + pub max_growing_segment_size: u32, + #[serde(default = "SegmentsOptions::default_max_sealed_segment_size")] + #[validate(range(min = 1, max = 4_000_000_000))] + pub max_sealed_segment_size: u32, +} + +impl SegmentsOptions { + fn default_max_growing_segment_size() -> u32 { + 20_000 + } + fn default_max_sealed_segment_size() -> u32 { + 1_000_000 + } + // max_growing_segment_size <= max_sealed_segment_size + fn validate_0(&self) -> Result<(), ValidationError> { + if self.max_growing_segment_size > self.max_sealed_segment_size { + return Err(ValidationError::new( + "`max_growing_segment_size` must be less than or equal to `max_sealed_segment_size`", + )); + } + Ok(()) + } +} + +impl Default for SegmentsOptions { + fn default() -> Self { + Self { + max_growing_segment_size: Self::default_max_growing_segment_size(), + max_sealed_segment_size: Self::default_max_sealed_segment_size(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct OptimizingOptions { + #[serde(default = "OptimizingOptions::default_sealing_secs")] + #[validate(range(min = 1, max = 60))] + pub sealing_secs: u64, + #[serde(default = "OptimizingOptions::default_sealing_size")] + #[validate(range(min = 1, max = 4_000_000_000))] + pub sealing_size: u32, + #[serde(default = "OptimizingOptions::default_delete_threshold")] + #[validate(range(min = 0.01, max = 1.00))] + pub delete_threshold: f64, + #[serde(default = "OptimizingOptions::default_optimizing_threads")] + #[validate(range(min = 1, max = 65535))] + pub optimizing_threads: usize, +} + +impl OptimizingOptions { + fn default_sealing_secs() -> u64 { + 60 + } + fn default_sealing_size() -> u32 { + 1 + } + fn default_delete_threshold() -> f64 { + 0.2 + } + fn default_optimizing_threads() -> usize { + match std::thread::available_parallelism() { + Ok(threads) => (threads.get() as f64).sqrt() as _, + Err(_) => 1, + } + } +} + +impl Default for OptimizingOptions { + fn default() -> Self { + Self { + sealing_secs: Self::default_sealing_secs(), + sealing_size: Self::default_sealing_size(), + delete_threshold: Self::default_delete_threshold(), + optimizing_threads: Self::default_optimizing_threads(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +#[serde(rename_all = "snake_case")] +pub enum IndexingOptions { + Flat(FlatIndexingOptions), + Ivf(IvfIndexingOptions), + Hnsw(HnswIndexingOptions), +} + +impl IndexingOptions { + pub fn unwrap_flat(self) -> FlatIndexingOptions { + let IndexingOptions::Flat(x) = self else { + unreachable!() + }; + x + } + pub fn unwrap_ivf(self) -> IvfIndexingOptions { + let IndexingOptions::Ivf(x) = self else { + unreachable!() + }; + x + } + pub fn unwrap_hnsw(self) -> HnswIndexingOptions { + let IndexingOptions::Hnsw(x) = self else { + unreachable!() + }; + x + } +} + +impl Default for IndexingOptions { + fn default() -> Self { + Self::Hnsw(Default::default()) + } +} + +impl Validate for IndexingOptions { + fn validate(&self) -> Result<(), validator::ValidationErrors> { + match self { + Self::Flat(x) => x.validate(), + Self::Ivf(x) => x.validate(), + Self::Hnsw(x) => x.validate(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct FlatIndexingOptions { + #[serde(default)] + #[validate] + pub quantization: QuantizationOptions, +} + +impl Default for FlatIndexingOptions { + fn default() -> Self { + Self { + quantization: QuantizationOptions::default(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct IvfIndexingOptions { + #[serde(default = "IvfIndexingOptions::default_least_iterations")] + #[validate(range(min = 1, max = 1_000_000))] + pub least_iterations: u32, + #[serde(default = "IvfIndexingOptions::default_iterations")] + #[validate(range(min = 1, max = 1_000_000))] + pub iterations: u32, + #[serde(default = "IvfIndexingOptions::default_nlist")] + #[validate(range(min = 1, max = 1_000_000))] + pub nlist: u32, + #[serde(default = "IvfIndexingOptions::default_nsample")] + #[validate(range(min = 1, max = 1_000_000))] + pub nsample: u32, + #[serde(default)] + #[validate] + pub quantization: QuantizationOptions, +} + +impl IvfIndexingOptions { + fn default_least_iterations() -> u32 { + 16 + } + fn default_iterations() -> u32 { + 500 + } + fn default_nlist() -> u32 { + 1000 + } + fn default_nsample() -> u32 { + 65536 + } +} + +impl Default for IvfIndexingOptions { + fn default() -> Self { + Self { + least_iterations: Self::default_least_iterations(), + iterations: Self::default_iterations(), + nlist: Self::default_nlist(), + nsample: Self::default_nsample(), + quantization: Default::default(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct HnswIndexingOptions { + #[serde(default = "HnswIndexingOptions::default_m")] + #[validate(range(min = 4, max = 128))] + pub m: u32, + #[serde(default = "HnswIndexingOptions::default_ef_construction")] + #[validate(range(min = 10, max = 2000))] + pub ef_construction: usize, + #[serde(default)] + #[validate] + pub quantization: QuantizationOptions, +} + +impl HnswIndexingOptions { + fn default_m() -> u32 { + 12 + } + fn default_ef_construction() -> usize { + 300 + } +} + +impl Default for HnswIndexingOptions { + fn default() -> Self { + Self { + m: Self::default_m(), + ef_construction: Self::default_ef_construction(), + quantization: Default::default(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +#[serde(rename_all = "snake_case")] +pub enum QuantizationOptions { + Trivial(TrivialQuantizationOptions), + Scalar(ScalarQuantizationOptions), + Product(ProductQuantizationOptions), +} + +impl Validate for QuantizationOptions { + fn validate(&self) -> Result<(), validator::ValidationErrors> { + match self { + Self::Trivial(x) => x.validate(), + Self::Scalar(x) => x.validate(), + Self::Product(x) => x.validate(), + } + } +} + +impl Default for QuantizationOptions { + fn default() -> Self { + Self::Trivial(Default::default()) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct TrivialQuantizationOptions {} + +impl Default for TrivialQuantizationOptions { + fn default() -> Self { + Self {} + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct ScalarQuantizationOptions {} + +impl Default for ScalarQuantizationOptions { + fn default() -> Self { + Self {} + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +#[serde(deny_unknown_fields)] +pub struct ProductQuantizationOptions { + #[serde(default = "ProductQuantizationOptions::default_sample")] + #[validate(range(min = 1, max = 1_000_000))] + pub sample: u32, + #[serde(default)] + pub ratio: ProductQuantizationOptionsRatio, +} + +impl ProductQuantizationOptions { + fn default_sample() -> u32 { + 65535 + } +} + +impl Default for ProductQuantizationOptions { + fn default() -> Self { + Self { + sample: Self::default_sample(), + ratio: Default::default(), + } + } +} + +#[repr(u16)] +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +#[serde(rename_all = "snake_case")] +pub enum ProductQuantizationOptionsRatio { + X4 = 1, + X8 = 2, + X16 = 4, + X32 = 8, + X64 = 16, +} + +impl Default for ProductQuantizationOptionsRatio { + fn default() -> Self { + Self::X4 + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, Validate)] +pub struct SearchOptions { + pub prefilter_enable: bool, + #[validate(range(min = 1, max = 65535))] + pub hnsw_ef_search: usize, + #[validate(range(min = 1, max = 1_000_000))] + pub ivf_nprobe: u32, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct IndexStat { + pub indexing: bool, + pub segments: Vec, + pub options: IndexOptions, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct SegmentStat { + pub id: Uuid, + #[serde(rename = "type")] + pub typ: String, + pub length: usize, + pub size: u64, +} diff --git a/crates/base/src/lib.rs b/crates/base/src/lib.rs index 653153e13..c1831e696 100644 --- a/crates/base/src/lib.rs +++ b/crates/base/src/lib.rs @@ -1,7 +1,11 @@ #![feature(core_intrinsics)] +#![feature(avx512_target_feature)] +pub mod distance; pub mod error; +pub mod global; +pub mod index; pub mod scalar; pub mod search; -pub mod sys; pub mod vector; +pub mod worker; diff --git a/crates/base/src/scalar/f16.rs b/crates/base/src/scalar/f16.rs index da5735b37..005211c10 100644 --- a/crates/base/src/scalar/f16.rs +++ b/crates/base/src/scalar/f16.rs @@ -1,4 +1,5 @@ -use super::FloatCast; +use super::ScalarLike; +use crate::scalar::F32; use half::f16; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; @@ -531,16 +532,6 @@ impl FromStr for F16 { } } -impl FloatCast for F16 { - fn from_f32(x: f32) -> Self { - Self(f16::from_f32(x)) - } - - fn to_f32(self) -> f32 { - f16::to_f32(self.0) - } -} - impl From for F16 { fn from(value: f16) -> Self { Self(value) @@ -651,3 +642,19 @@ mod intrinsics { lhs % rhs } } + +impl ScalarLike for F16 { + fn from_f32(x: f32) -> Self { + Self(f16::from_f32(x)) + } + + fn to_f32(self) -> f32 { + f16::to_f32(self.0) + } + fn from_f(x: F32) -> Self { + Self::from_f32(x.0) + } + fn to_f(self) -> F32 { + F32(Self::to_f32(self)) + } +} diff --git a/crates/base/src/scalar/f32.rs b/crates/base/src/scalar/f32.rs index c6e431bcb..84f624e3c 100644 --- a/crates/base/src/scalar/f32.rs +++ b/crates/base/src/scalar/f32.rs @@ -1,4 +1,4 @@ -use super::FloatCast; +use super::ScalarLike; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; use std::fmt::{Debug, Display}; @@ -530,16 +530,6 @@ impl FromStr for F32 { } } -impl FloatCast for F32 { - fn from_f32(x: f32) -> Self { - Self(x) - } - - fn to_f32(self) -> f32 { - self.0 - } -} - impl From for F32 { fn from(value: f32) -> Self { Self(value) @@ -630,3 +620,21 @@ impl RemAssign for F32 { unsafe { self.0 = std::intrinsics::frem_fast(self.0, rhs) } } } + +impl ScalarLike for F32 { + fn from_f32(x: f32) -> Self { + Self(x) + } + + fn to_f32(self) -> f32 { + self.0 + } + + fn from_f(x: F32) -> Self { + Self::from_f32(x.0) + } + + fn to_f(self) -> F32 { + F32(Self::to_f32(self)) + } +} diff --git a/crates/base/src/scalar/mod.rs b/crates/base/src/scalar/mod.rs index 8e30d33c7..1d09f11c1 100644 --- a/crates/base/src/scalar/mod.rs +++ b/crates/base/src/scalar/mod.rs @@ -4,13 +4,24 @@ mod f32; pub use f16::F16; pub use f32::F32; -pub trait FloatCast: Sized { +pub trait ScalarLike: + Copy + + Send + + Sync + + std::fmt::Debug + + std::fmt::Display + + serde::Serialize + + for<'a> serde::Deserialize<'a> + + Ord + + bytemuck::Zeroable + + bytemuck::Pod + + num_traits::Float + + num_traits::Zero + + num_traits::NumOps + + num_traits::NumAssignOps +{ fn from_f32(x: f32) -> Self; fn to_f32(self) -> f32; - fn from_f(x: F32) -> Self { - Self::from_f32(x.0) - } - fn to_f(self) -> F32 { - F32(Self::to_f32(self)) - } + fn from_f(x: F32) -> Self; + fn to_f(self) -> F32; } diff --git a/crates/base/src/search.rs b/crates/base/src/search.rs index c5e946bd5..04c79f580 100644 --- a/crates/base/src/search.rs +++ b/crates/base/src/search.rs @@ -1,4 +1,6 @@ use crate::scalar::F32; +use serde::{Deserialize, Serialize}; +use std::{fmt::Display, num::ParseIntError, str::FromStr}; pub type Payload = u64; @@ -11,3 +13,45 @@ pub struct Element { pub distance: F32, pub payload: Payload, } + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub struct Handle { + pub newtype: u32, +} + +impl Handle { + pub fn as_u32(self) -> u32 { + self.newtype + } +} + +impl Display for Handle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.as_u32()) + } +} + +impl FromStr for Handle { + type Err = ParseIntError; + + fn from_str(s: &str) -> Result { + Ok(Handle { + newtype: u32::from_str(s)?, + }) + } +} + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub struct Pointer { + pub newtype: u64, +} + +impl Pointer { + pub fn from_u48(value: u64) -> Self { + assert!(value < (1u64 << 48)); + Self { newtype: value } + } + pub fn as_u48(self) -> u64 { + self.newtype + } +} diff --git a/crates/base/src/sys.rs b/crates/base/src/sys.rs deleted file mode 100644 index 5229bfd55..000000000 --- a/crates/base/src/sys.rs +++ /dev/null @@ -1,44 +0,0 @@ -use serde::{Deserialize, Serialize}; -use std::{fmt::Display, num::ParseIntError, str::FromStr}; - -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] -pub struct Handle { - pub newtype: u32, -} - -impl Handle { - pub fn as_u32(self) -> u32 { - self.newtype - } -} - -impl Display for Handle { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.as_u32()) - } -} - -impl FromStr for Handle { - type Err = ParseIntError; - - fn from_str(s: &str) -> Result { - Ok(Handle { - newtype: u32::from_str(s)?, - }) - } -} - -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] -pub struct Pointer { - pub newtype: u64, -} - -impl Pointer { - pub fn from_u48(value: u64) -> Self { - assert!(value < (1u64 << 48)); - Self { newtype: value } - } - pub fn as_u48(self) -> u64 { - self.newtype - } -} diff --git a/crates/base/src/vector/mod.rs b/crates/base/src/vector/mod.rs index 8f6117772..20925d3e7 100644 --- a/crates/base/src/vector/mod.rs +++ b/crates/base/src/vector/mod.rs @@ -1,19 +1,54 @@ -mod sparse_f32; +mod svecf32; +mod vecf16; +mod vecf32; -pub use sparse_f32::{SparseF32, SparseF32Ref}; +pub use svecf32::{SVecf32Borrowed, SVecf32Owned}; +pub use vecf16::{Vecf16Borrowed, Vecf16Owned}; +pub use vecf32::{Vecf32Borrowed, Vecf32Owned}; + +use crate::scalar::ScalarLike; +use serde::{Deserialize, Serialize}; + +#[repr(u8)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] +pub enum VectorKind { + Vecf32, + Vecf16, + SVecf32, +} + +pub trait VectorOwned: Clone + Serialize + for<'a> Deserialize<'a> + 'static { + type Scalar: ScalarLike; + type Borrowed<'a>: VectorBorrowed; + + fn for_borrow(&self) -> Self::Borrowed<'_>; + + fn dims(&self) -> u16; + + fn to_vec(&self) -> Vec; +} + +pub trait VectorBorrowed: Copy { + type Scalar: ScalarLike; + type Owned: VectorOwned; + + fn for_own(&self) -> Self::Owned; -pub trait Vector { fn dims(&self) -> u16; + + fn to_vec(&self) -> Vec; } -impl Vector for Vec { - fn dims(&self) -> u16 { - self.len().try_into().unwrap() - } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum OwnedVector { + Vecf32(Vecf32Owned), + Vecf16(Vecf16Owned), + SVecF32(SVecf32Owned), } -impl<'a, T> Vector for &'a [T] { - fn dims(&self) -> u16 { - self.len().try_into().unwrap() - } +#[derive(Debug, Clone)] +pub enum BorrowedVector<'a> { + Vecf32(Vecf32Borrowed<'a>), + Vecf16(Vecf16Borrowed<'a>), + SVecF32(SVecf32Borrowed<'a>), } diff --git a/crates/base/src/vector/sparse_f32.rs b/crates/base/src/vector/sparse_f32.rs deleted file mode 100644 index d52903205..000000000 --- a/crates/base/src/vector/sparse_f32.rs +++ /dev/null @@ -1,64 +0,0 @@ -use super::Vector; -use crate::scalar::F32; -use num_traits::Zero; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct SparseF32 { - pub dims: u16, - pub indexes: Vec, - pub values: Vec, -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct SparseF32Ref<'a> { - pub dims: u16, - pub indexes: &'a [u16], - pub values: &'a [F32], -} - -impl<'a> From> for SparseF32 { - fn from(value: SparseF32Ref<'a>) -> Self { - Self { - dims: value.dims, - indexes: value.indexes.to_vec(), - values: value.values.to_vec(), - } - } -} - -impl<'a> From<&'a SparseF32> for SparseF32Ref<'a> { - fn from(value: &'a SparseF32) -> Self { - Self { - dims: value.dims, - indexes: &value.indexes, - values: &value.values, - } - } -} - -impl Vector for SparseF32 { - fn dims(&self) -> u16 { - self.dims - } -} - -impl<'a> Vector for SparseF32Ref<'a> { - fn dims(&self) -> u16 { - self.dims - } -} - -impl<'a> SparseF32Ref<'a> { - pub fn to_dense(&self) -> Vec { - let mut dense = vec![F32::zero(); self.dims as usize]; - for (&index, &value) in self.indexes.iter().zip(self.values.iter()) { - dense[index as usize] = value; - } - dense - } - - pub fn length(&self) -> u16 { - self.indexes.len().try_into().unwrap() - } -} diff --git a/crates/base/src/vector/svecf32.rs b/crates/base/src/vector/svecf32.rs new file mode 100644 index 000000000..39f72e631 --- /dev/null +++ b/crates/base/src/vector/svecf32.rs @@ -0,0 +1,183 @@ +use super::{VectorBorrowed, VectorOwned}; +use crate::scalar::F32; +use num_traits::Zero; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SVecf32Owned { + dims: u16, + indexes: Vec, + values: Vec, +} + +impl SVecf32Owned { + #[inline(always)] + pub fn new(dims: u16, indexes: Vec, values: Vec) -> Self { + Self::new_checked(dims, indexes, values).unwrap() + } + #[inline(always)] + pub fn new_checked(dims: u16, indexes: Vec, values: Vec) -> Option { + if dims == 0 { + return None; + } + if indexes.len() != values.len() { + return None; + } + let len = indexes.len(); + for i in 1..len { + if !(indexes[i - 1] < indexes[i]) { + return None; + } + } + if len != 0 && !(indexes[len - 1] < dims) { + return None; + } + for i in 0..len { + if values[i].is_zero() { + return None; + } + } + unsafe { Some(Self::new_unchecked(dims, indexes, values)) } + } + /// # Safety + /// + /// * `dims` must be in `1..=65535`. + /// * `indexes.len()` must be equal to `values.len()`. + /// * `indexes` must be a strictly increasing sequence and the last in the sequence must be less than `dims`. + /// * A floating number in `values` must not be positive zero or negative zero. + #[inline(always)] + pub unsafe fn new_unchecked(dims: u16, indexes: Vec, values: Vec) -> Self { + Self { + dims, + indexes, + values, + } + } + #[inline(always)] + pub fn indexes(&self) -> &[u16] { + &self.indexes + } + #[inline(always)] + pub fn values(&self) -> &[F32] { + &self.values + } +} + +impl VectorOwned for SVecf32Owned { + type Scalar = F32; + type Borrowed<'a> = SVecf32Borrowed<'a>; + + #[inline(always)] + fn dims(&self) -> u16 { + self.dims + } + + fn for_borrow(&self) -> SVecf32Borrowed<'_> { + SVecf32Borrowed { + dims: self.dims, + indexes: &self.indexes, + values: &self.values, + } + } + + fn to_vec(&self) -> Vec { + let mut dense = vec![F32::zero(); self.dims as usize]; + for (&index, &value) in self.indexes.iter().zip(self.values.iter()) { + dense[index as usize] = value; + } + dense + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct SVecf32Borrowed<'a> { + dims: u16, + indexes: &'a [u16], + values: &'a [F32], +} + +impl<'a> SVecf32Borrowed<'a> { + #[inline(always)] + pub fn new(dims: u16, indexes: &'a [u16], values: &'a [F32]) -> Self { + Self::new_checked(dims, indexes, values).unwrap() + } + #[inline(always)] + pub fn new_checked(dims: u16, indexes: &'a [u16], values: &'a [F32]) -> Option { + if dims == 0 { + return None; + } + if indexes.len() != values.len() { + return None; + } + let len = indexes.len(); + for i in 1..len { + if !(indexes[i - 1] < indexes[i]) { + return None; + } + } + if len != 0 && !(indexes[len - 1] < dims) { + return None; + } + for i in 0..len { + if values[i].is_zero() { + return None; + } + } + unsafe { Some(Self::new_unchecked(dims, indexes, values)) } + } + /// # Safety + /// + /// * `dims` must be in `1..=65535`. + /// * `indexes.len()` must be equal to `values.len()`. + /// * `indexes` must be a strictly increasing sequence and the last in the sequence must be less than `dims`. + /// * A floating number in `values` must not be positive zero or negative zero. + #[inline(always)] + pub unsafe fn new_unchecked(dims: u16, indexes: &'a [u16], values: &'a [F32]) -> Self { + Self { + dims, + indexes, + values, + } + } + #[inline(always)] + pub fn indexes(&self) -> &[u16] { + self.indexes + } + #[inline(always)] + pub fn values(&self) -> &[F32] { + self.values + } +} + +impl<'a> VectorBorrowed for SVecf32Borrowed<'a> { + type Scalar = F32; + type Owned = SVecf32Owned; + + #[inline(always)] + fn dims(&self) -> u16 { + self.dims + } + + fn for_own(&self) -> SVecf32Owned { + SVecf32Owned { + dims: self.dims, + indexes: self.indexes.to_vec(), + values: self.values.to_vec(), + } + } + + fn to_vec(&self) -> Vec { + let mut dense = vec![F32::zero(); self.dims as usize]; + for (&index, &value) in self.indexes.iter().zip(self.values.iter()) { + dense[index as usize] = value; + } + dense + } +} + +impl<'a> SVecf32Borrowed<'a> { + #[inline(always)] + pub fn len(&self) -> u16 { + self.indexes.len().try_into().unwrap() + } +} diff --git a/crates/base/src/vector/vecf16.rs b/crates/base/src/vector/vecf16.rs new file mode 100644 index 000000000..9a52268c6 --- /dev/null +++ b/crates/base/src/vector/vecf16.rs @@ -0,0 +1,99 @@ +use super::{VectorBorrowed, VectorOwned}; +use crate::scalar::F16; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[repr(transparent)] +pub struct Vecf16Owned(Vec); + +impl Vecf16Owned { + #[inline(always)] + pub fn new(slice: Vec) -> Self { + Self::new_checked(slice).unwrap() + } + #[inline(always)] + pub fn new_checked(slice: Vec) -> Option { + if !(1 <= slice.len() && slice.len() <= 65535) { + return None; + } + Some(unsafe { Self::new_unchecked(slice) }) + } + /// # Safety + /// + /// * `slice.len()` must not be zero. + #[inline(always)] + pub unsafe fn new_unchecked(slice: Vec) -> Self { + Self(slice) + } + #[inline(always)] + pub fn slice(&self) -> &[F16] { + self.0.as_slice() + } + #[inline(always)] + pub fn slice_mut(&mut self) -> &mut [F16] { + self.0.as_mut_slice() + } +} + +impl VectorOwned for Vecf16Owned { + type Scalar = F16; + type Borrowed<'a> = Vecf16Borrowed<'a>; + + fn dims(&self) -> u16 { + self.0.len() as u16 + } + + fn for_borrow(&self) -> Vecf16Borrowed<'_> { + Vecf16Borrowed(self.0.as_slice()) + } + + fn to_vec(&self) -> Vec { + self.0.clone() + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(transparent)] +pub struct Vecf16Borrowed<'a>(&'a [F16]); + +impl<'a> Vecf16Borrowed<'a> { + #[inline(always)] + pub fn new(slice: &'a [F16]) -> Self { + Self::new_checked(slice).unwrap() + } + #[inline(always)] + pub fn new_checked(slice: &'a [F16]) -> Option { + if !(1 <= slice.len() && slice.len() <= 65535) { + return None; + } + Some(unsafe { Self::new_unchecked(slice) }) + } + /// # Safety + /// + /// * `slice.len()` must not be zero. + #[inline(always)] + pub unsafe fn new_unchecked(slice: &'a [F16]) -> Self { + Self(slice) + } + #[inline(always)] + pub fn slice(&self) -> &[F16] { + self.0 + } +} + +impl<'a> VectorBorrowed for Vecf16Borrowed<'a> { + type Scalar = F16; + type Owned = Vecf16Owned; + + fn dims(&self) -> u16 { + self.0.len() as u16 + } + + fn for_own(&self) -> Vecf16Owned { + Vecf16Owned(self.0.to_vec()) + } + + fn to_vec(&self) -> Vec { + self.0.to_vec() + } +} diff --git a/crates/base/src/vector/vecf32.rs b/crates/base/src/vector/vecf32.rs new file mode 100644 index 000000000..58c546c30 --- /dev/null +++ b/crates/base/src/vector/vecf32.rs @@ -0,0 +1,99 @@ +use super::{VectorBorrowed, VectorOwned}; +use crate::scalar::F32; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[repr(transparent)] +pub struct Vecf32Owned(Vec); + +impl Vecf32Owned { + #[inline(always)] + pub fn new(slice: Vec) -> Self { + Self::new_checked(slice).unwrap() + } + #[inline(always)] + pub fn new_checked(slice: Vec) -> Option { + if !(1 <= slice.len() && slice.len() <= 65535) { + return None; + } + Some(unsafe { Self::new_unchecked(slice) }) + } + /// # Safety + /// + /// * `slice.len()` must not be zero. + #[inline(always)] + pub unsafe fn new_unchecked(slice: Vec) -> Self { + Self(slice) + } + #[inline(always)] + pub fn slice(&self) -> &[F32] { + self.0.as_slice() + } + #[inline(always)] + pub fn slice_mut(&mut self) -> &mut [F32] { + self.0.as_mut_slice() + } +} + +impl VectorOwned for Vecf32Owned { + type Scalar = F32; + type Borrowed<'a> = Vecf32Borrowed<'a>; + + fn dims(&self) -> u16 { + self.0.len() as u16 + } + + fn for_borrow(&self) -> Vecf32Borrowed<'_> { + Vecf32Borrowed(self.0.as_slice()) + } + + fn to_vec(&self) -> Vec { + self.0.clone() + } +} + +#[derive(Debug, Clone, Copy)] +#[repr(transparent)] +pub struct Vecf32Borrowed<'a>(&'a [F32]); + +impl<'a> Vecf32Borrowed<'a> { + #[inline(always)] + pub fn new(slice: &'a [F32]) -> Self { + Self::new_checked(slice).unwrap() + } + #[inline(always)] + pub fn new_checked(slice: &'a [F32]) -> Option { + if !(1 <= slice.len() && slice.len() <= 65535) { + return None; + } + Some(unsafe { Self::new_unchecked(slice) }) + } + /// # Safety + /// + /// * `slice.len()` must not be zero. + #[inline(always)] + pub unsafe fn new_unchecked(slice: &'a [F32]) -> Self { + Self(slice) + } + #[inline(always)] + pub fn slice(&self) -> &[F32] { + self.0 + } +} + +impl<'a> VectorBorrowed for Vecf32Borrowed<'a> { + type Scalar = F32; + type Owned = Vecf32Owned; + + fn dims(&self) -> u16 { + self.0.len() as u16 + } + + fn for_own(&self) -> Vecf32Owned { + Vecf32Owned(self.0.to_vec()) + } + + fn to_vec(&self) -> Vec { + self.0.to_vec() + } +} diff --git a/crates/base/src/worker.rs b/crates/base/src/worker.rs new file mode 100644 index 000000000..124df5804 --- /dev/null +++ b/crates/base/src/worker.rs @@ -0,0 +1,43 @@ +use crate::error::*; +use crate::index::*; +use crate::search::*; +use crate::vector::*; + +pub trait WorkerOperations { + fn create(&self, handle: Handle, options: IndexOptions) -> Result<(), CreateError>; + fn drop(&self, handle: Handle) -> Result<(), DropError>; + fn flush(&self, handle: Handle) -> Result<(), FlushError>; + fn insert( + &self, + handle: Handle, + vector: OwnedVector, + pointer: Pointer, + ) -> Result<(), InsertError>; + fn delete(&self, handle: Handle, pointer: Pointer) -> Result<(), DeleteError>; + fn view_basic(&self, handle: Handle) -> Result; + fn view_vbase(&self, handle: Handle) -> Result; + fn view_list(&self, handle: Handle) -> Result; + fn stat(&self, handle: Handle) -> Result; +} + +pub trait ViewBasicOperations { + fn basic<'a, F: Fn(Pointer) -> bool + Clone + 'a>( + &'a self, + vector: &'a OwnedVector, + opts: &'a SearchOptions, + filter: F, + ) -> Result + 'a>, BasicError>; +} + +pub trait ViewVbaseOperations { + fn vbase<'a, F: FnMut(Pointer) -> bool + Clone + 'a>( + &'a self, + vector: &'a OwnedVector, + opts: &'a SearchOptions, + filter: F, + ) -> Result + 'a>, VbaseError>; +} + +pub trait ViewListOperations { + fn list(&self) -> Result + '_>, ListError>; +} diff --git a/crates/c/Cargo.toml b/crates/c/Cargo.toml index 1c8631903..fcef459aa 100644 --- a/crates/c/Cargo.toml +++ b/crates/c/Cargo.toml @@ -10,3 +10,6 @@ detect = { path = "../detect" } [build-dependencies] cc = "1.0" + +[lints] +workspace = true diff --git a/crates/c/tests/x86_64.rs b/crates/c/tests/x86_64.rs index 99e2a7fb5..3363dec58 100644 --- a/crates/c/tests/x86_64.rs +++ b/crates/c/tests/x86_64.rs @@ -10,8 +10,8 @@ fn test_v_f16_cosine() { let mut xx = 0.0f32; let mut yy = 0.0f32; for i in 0..n { - let x = a.add(i).cast::().read().to_f32(); - let y = b.add(i).cast::().read().to_f32(); + let x = unsafe { a.add(i).cast::().read() }.to_f32(); + let y = unsafe { b.add(i).cast::().read() }.to_f32(); xy += x * y; xx += x * x; yy += y * y; @@ -53,8 +53,8 @@ fn test_v_f16_dot() { unsafe fn v_f16_dot(a: *const u16, b: *const u16, n: usize) -> f32 { let mut xy = 0.0f32; for i in 0..n { - let x = a.add(i).cast::().read().to_f32(); - let y = b.add(i).cast::().read().to_f32(); + let x = unsafe { a.add(i).cast::().read() }.to_f32(); + let y = unsafe { b.add(i).cast::().read() }.to_f32(); xy += x * y; } xy @@ -94,8 +94,8 @@ fn test_v_f16_sl2() { unsafe fn v_f16_sl2(a: *const u16, b: *const u16, n: usize) -> f32 { let mut dd = 0.0f32; for i in 0..n { - let x = a.add(i).cast::().read().to_f32(); - let y = b.add(i).cast::().read().to_f32(); + let x = unsafe { a.add(i).cast::().read() }.to_f32(); + let y = unsafe { b.add(i).cast::().read() }.to_f32(); let d = x - y; dd += d * d; } diff --git a/crates/detect/Cargo.toml b/crates/detect/Cargo.toml index 1bc7a99f0..e2cd91c15 100644 --- a/crates/detect/Cargo.toml +++ b/crates/detect/Cargo.toml @@ -6,3 +6,6 @@ edition.workspace = true [dependencies] rustix.workspace = true std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" } + +[lints] +workspace = true diff --git a/crates/interprocess-atomic-wait/Cargo.toml b/crates/interprocess-atomic-wait/Cargo.toml index 2f36d9edb..f478614b1 100644 --- a/crates/interprocess-atomic-wait/Cargo.toml +++ b/crates/interprocess-atomic-wait/Cargo.toml @@ -10,7 +10,4 @@ libc.workspace = true ulock-sys = "0.1.0" [lints] -rust.internal_features = "allow" -rust.unsafe_op_in_unsafe_fn = "forbid" -rust.unused_lifetimes = "warn" -rust.unused_qualifications = "warn" +workspace = true diff --git a/crates/interprocess-atomic-wait/src/lib.rs b/crates/interprocess-atomic-wait/src/lib.rs index 324bd8a70..cffae866e 100644 --- a/crates/interprocess-atomic-wait/src/lib.rs +++ b/crates/interprocess-atomic-wait/src/lib.rs @@ -89,3 +89,6 @@ pub fn wake(futex: &AtomicU32) { ); }; } + +#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "freebsd")))] +compile_error!("Target is not supported."); diff --git a/crates/memfd/Cargo.toml b/crates/memfd/Cargo.toml index 2bf02ee22..2975803ae 100644 --- a/crates/memfd/Cargo.toml +++ b/crates/memfd/Cargo.toml @@ -9,7 +9,4 @@ rustix.workspace = true detect = { path = "../detect" } [lints] -rust.internal_features = "allow" -rust.unsafe_op_in_unsafe_fn = "forbid" -rust.unused_lifetimes = "warn" -rust.unused_qualifications = "warn" +workspace = true diff --git a/crates/send_fd/Cargo.toml b/crates/send_fd/Cargo.toml index fc50260e0..375ed2fe0 100644 --- a/crates/send_fd/Cargo.toml +++ b/crates/send_fd/Cargo.toml @@ -9,7 +9,4 @@ log.workspace = true rustix.workspace = true [lints] -rust.internal_features = "allow" -rust.unsafe_op_in_unsafe_fn = "forbid" -rust.unused_lifetimes = "warn" -rust.unused_qualifications = "warn" +workspace = true diff --git a/crates/service/Cargo.toml b/crates/service/Cargo.toml index 8cd06be37..4bd885cb4 100644 --- a/crates/service/Cargo.toml +++ b/crates/service/Cargo.toml @@ -7,8 +7,6 @@ edition.workspace = true bincode.workspace = true bytemuck.workspace = true byteorder.workspace = true -half.workspace = true -libc.workspace = true log.workspace = true memmap2.workspace = true num-traits.workspace = true @@ -20,22 +18,12 @@ thiserror.workspace = true uuid.workspace = true validator.workspace = true base = { path = "../base" } -c = { path = "../c" } -detect = { path = "../detect" } crc32fast = "1.4.0" crossbeam = "0.8.4" dashmap = "5.5.3" parking_lot = "0.12.1" rayon = "1.8.1" arc-swap = "1.6.0" -multiversion = "0.7.3" [lints] -clippy.derivable_impls = "allow" -clippy.len_without_is_empty = "allow" -clippy.needless_range_loop = "allow" -clippy.too_many_arguments = "allow" -rust.internal_features = "allow" -rust.unsafe_op_in_unsafe_fn = "forbid" -rust.unused_lifetimes = "warn" -rust.unused_qualifications = "warn" +workspace = true diff --git a/crates/service/src/algorithms/clustering/elkan_k_means.rs b/crates/service/src/algorithms/clustering/elkan_k_means.rs index 9dc7fd97b..3cb420fe5 100644 --- a/crates/service/src/algorithms/clustering/elkan_k_means.rs +++ b/crates/service/src/algorithms/clustering/elkan_k_means.rs @@ -1,27 +1,26 @@ use crate::prelude::*; use crate::utils::vec2::Vec2; -use base::scalar::FloatCast; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; use rayon::slice::ParallelSliceMut; use std::ops::{Index, IndexMut}; -pub struct ElkanKMeans { +pub struct ElkanKMeans { dims: u16, c: usize, - pub centroids: Vec2, + pub centroids: Vec2>, lowerbound: Square, upperbound: Vec, assign: Vec, rand: StdRng, - samples: Vec2, + samples: Vec2>, } const DELTA: f32 = 1.0 / 1024.0; -impl ElkanKMeans { - pub fn new(c: usize, samples: Vec2) -> Self { +impl ElkanKMeans { + pub fn new(c: usize, samples: Vec2>) -> Self { let n = samples.len(); let dims = samples.dims(); @@ -105,14 +104,14 @@ impl ElkanKMeans { centroids[i].copy_from_slice(&samples[*index]); } else { let rand_centroids: Vec<_> = (0..dims) - .map(|_| S::Scalar::from_f32(rand.gen_range(0.0..1.0f32))) + .map(|_| Scalar::::from_f32(rand.gen_range(0.0..1.0f32))) .collect(); centroids[i].copy_from_slice(rand_centroids.as_slice()); } } for i in n..c { let rand_centroids: Vec<_> = (0..dims) - .map(|_| S::Scalar::from_f32(rand.gen_range(0.0..1.0f32))) + .map(|_| Scalar::::from_f32(rand.gen_range(0.0..1.0f32))) .collect(); centroids[i].copy_from_slice(rand_centroids.as_slice()); } @@ -204,7 +203,7 @@ impl ElkanKMeans { // Step 4, 7 let old = std::mem::replace(centroids, Vec2::new(dims, c)); let mut count = vec![F32::zero(); c]; - centroids.fill(S::Scalar::zero()); + centroids.fill(Scalar::::zero()); for i in 0..n { for j in 0..dims as usize { centroids[self.assign[i]][j] += samples[i][j]; @@ -216,7 +215,7 @@ impl ElkanKMeans { continue; } for dim in 0..dims as usize { - centroids[i][dim] /= S::Scalar::from_f32(count[i].into()); + centroids[i][dim] /= Scalar::::from_f32(count[i].into()); } } for i in 0..c { @@ -235,11 +234,11 @@ impl ElkanKMeans { centroids.copy_within(o, i); for dim in 0..dims as usize { if dim % 2 == 0 { - centroids[i][dim] *= S::Scalar::from_f32(1.0 + DELTA); - centroids[o][dim] *= S::Scalar::from_f32(1.0 - DELTA); + centroids[i][dim] *= Scalar::::from_f32(1.0 + DELTA); + centroids[o][dim] *= Scalar::::from_f32(1.0 - DELTA); } else { - centroids[i][dim] *= S::Scalar::from_f32(1.0 - DELTA); - centroids[o][dim] *= S::Scalar::from_f32(1.0 + DELTA); + centroids[i][dim] *= Scalar::::from_f32(1.0 - DELTA); + centroids[o][dim] *= Scalar::::from_f32(1.0 + DELTA); } } count[i] = count[o] / 2.0; @@ -267,7 +266,7 @@ impl ElkanKMeans { change == 0 } - pub fn finish(self) -> Vec2 { + pub fn finish(self) -> Vec2> { self.centroids } } diff --git a/crates/service/src/algorithms/flat.rs b/crates/service/src/algorithms/flat.rs index 7ee1f1a5b..fdd99d0cc 100644 --- a/crates/service/src/algorithms/flat.rs +++ b/crates/service/src/algorithms/flat.rs @@ -2,7 +2,6 @@ use super::quantization::Quantization; use super::raw::Raw; use crate::index::segments::growing::GrowingSegment; use crate::index::segments::sealed::SealedSegment; -use crate::index::{IndexOptions, SearchOptions}; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; use std::cmp::Reverse; @@ -36,7 +35,7 @@ impl Flat { pub fn basic( &self, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, _opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -45,7 +44,7 @@ impl Flat { pub fn vbase<'a>( &'a self, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, _opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { @@ -56,7 +55,7 @@ impl Flat { self.mmap.raw.len() } - pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + pub fn vector(&self, i: u32) -> Borrowed<'_, S> { self.mmap.raw.vector(i) } @@ -125,7 +124,7 @@ pub fn open(path: &Path, options: IndexOptions) -> FlatMmap { pub fn basic( mmap: &FlatMmap, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, mut filter: impl Filter, ) -> BinaryHeap> { let mut result = BinaryHeap::new(); @@ -141,7 +140,7 @@ pub fn basic( pub fn vbase<'a, S: G>( mmap: &'a FlatMmap, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, mut filter: impl Filter + 'a, ) -> (Vec, Box + 'a>) { let mut result = Vec::new(); diff --git a/crates/service/src/algorithms/hnsw.rs b/crates/service/src/algorithms/hnsw.rs index cf2ced090..b3084bb94 100644 --- a/crates/service/src/algorithms/hnsw.rs +++ b/crates/service/src/algorithms/hnsw.rs @@ -1,10 +1,7 @@ use super::quantization::Quantization; use super::raw::Raw; -use crate::index::indexing::hnsw::HnswIndexingOptions; use crate::index::segments::growing::GrowingSegment; use crate::index::segments::sealed::SealedSegment; -use crate::index::IndexOptions; -use crate::index::SearchOptions; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; use crate::utils::element_heap::ElementHeap; @@ -44,7 +41,7 @@ impl Hnsw { pub fn basic( &self, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -53,7 +50,7 @@ impl Hnsw { pub fn vbase<'a>( &'a self, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { @@ -64,7 +61,7 @@ impl Hnsw { self.mmap.raw.len() } - pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + pub fn vector(&self, i: u32) -> Borrowed<'_, S> { self.mmap.raw.vector(i) } @@ -171,7 +168,7 @@ pub fn make( graph: &HnswRamGraph, levels: RangeInclusive, u: u32, - target: S::VectorRef<'_>, + target: Borrowed<'_, S>, ) -> u32 { let mut u = u; let mut u_dis = quantization.distance(target, u); @@ -196,7 +193,7 @@ pub fn make( quantization: &Quantization, graph: &HnswRamGraph, visited: &mut VisitedGuard, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, s: u32, k: usize, i: u8, @@ -395,7 +392,7 @@ pub fn open(path: &Path, options: IndexOptions) -> HnswMmap { pub fn basic( mmap: &HnswMmap, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, ef_search: usize, filter: impl Filter, ) -> BinaryHeap> { @@ -409,7 +406,7 @@ pub fn basic( pub fn vbase<'a, S: G>( mmap: &'a HnswMmap, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, range: usize, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { @@ -467,7 +464,7 @@ pub fn fast_search( mmap: &HnswMmap, levels: RangeInclusive, u: u32, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, mut filter: impl Filter, ) -> u32 { let mut u = u; @@ -497,7 +494,7 @@ pub fn local_search_basic( mmap: &HnswMmap, k: usize, s: u32, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, mut filter: impl Filter, ) -> ElementHeap { let mut visited = mmap.visited.fetch(); @@ -541,7 +538,7 @@ pub fn local_search_basic( pub fn local_search_vbase<'a, S: G>( mmap: &'a HnswMmap, s: u32, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, mut filter: impl Filter + 'a, ) -> impl Iterator + 'a { let mut visited = mmap.visited.fetch2(); diff --git a/crates/service/src/algorithms/ivf/ivf_naive.rs b/crates/service/src/algorithms/ivf/ivf_naive.rs index 278570236..acdc6eea6 100644 --- a/crates/service/src/algorithms/ivf/ivf_naive.rs +++ b/crates/service/src/algorithms/ivf/ivf_naive.rs @@ -1,12 +1,8 @@ use crate::algorithms::clustering::elkan_k_means::ElkanKMeans; use crate::algorithms::quantization::Quantization; use crate::algorithms::raw::Raw; -use crate::index::indexing::ivf::IvfIndexingOptions; use crate::index::segments::growing::GrowingSegment; use crate::index::segments::sealed::SealedSegment; -use crate::index::IndexOptions; -use crate::index::SearchOptions; -use crate::index::VectorOptions; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; use crate::utils::element_heap::ElementHeap; @@ -49,7 +45,7 @@ impl IvfNaive { self.mmap.raw.len() } - pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + pub fn vector(&self, i: u32) -> Borrowed<'_, S> { self.mmap.raw.vector(i) } @@ -59,7 +55,7 @@ impl IvfNaive { pub fn basic( &self, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -68,7 +64,7 @@ impl IvfNaive { pub fn vbase<'a>( &'a self, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { @@ -87,7 +83,7 @@ pub struct IvfRam { // ---------------------- nlist: u32, // ---------------------- - centroids: Vec2, + centroids: Vec2>, ptr: Vec, payloads: Vec, } @@ -103,7 +99,7 @@ pub struct IvfMmap { // ---------------------- nlist: u32, // ---------------------- - centroids: MmapArray, + centroids: MmapArray>, ptr: MmapArray, payloads: MmapArray, } @@ -112,7 +108,7 @@ unsafe impl Send for IvfMmap {} unsafe impl Sync for IvfMmap {} impl IvfMmap { - fn centroids(&self, i: u32) -> &[S::Scalar] { + fn centroids(&self, i: u32) -> &[Scalar] { let s = i as usize * self.dims as usize; let e = (i + 1) as usize * self.dims as usize; &self.centroids[s..e] @@ -133,7 +129,7 @@ pub fn make( nsample, quantization: quantization_opts, } = options.indexing.clone().unwrap_ivf(); - let raw = Arc::new(Raw::create( + let raw = Arc::new(Raw::::create( &path.join("raw"), options.clone(), sealed, @@ -144,7 +140,7 @@ pub fn make( let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec(); let mut samples = Vec2::new(dims, m as usize); for i in 0..m { - samples[i as usize].copy_from_slice(S::to_dense(raw.vector(f[i as usize] as u32)).as_ref()); + samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32).to_vec().as_ref()); S::elkan_k_means_normalize(&mut samples[i as usize]); } let mut k_means = ElkanKMeans::::new(nlist as usize, samples); @@ -159,11 +155,11 @@ pub fn make( let centroids = k_means.finish(); let mut idx = vec![0usize; n as usize]; idx.par_iter_mut().enumerate().for_each(|(i, x)| { - let mut vector = S::ref_to_owned(raw.vector(i as u32)); + let mut vector = raw.vector(i as u32).for_own(); S::elkan_k_means_normalize2(&mut vector); let mut result = (F32::infinity(), 0); for i in 0..nlist as usize { - let dis = S::elkan_k_means_distance2(S::owned_to_ref(&vector), ¢roids[i]); + let dis = S::elkan_k_means_distance2(vector.for_borrow(), ¢roids[i]); result = std::cmp::min(result, (dis, i)); } *x = result.1; @@ -247,16 +243,16 @@ pub fn open(path: &Path, options: IndexOptions) -> IvfMmap { pub fn basic( mmap: &IvfMmap, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, nprobe: u32, mut filter: impl Filter, ) -> BinaryHeap> { - let mut target = S::ref_to_owned(vector); + let mut target = vector.for_own(); S::elkan_k_means_normalize2(&mut target); let mut lists = ElementHeap::new(nprobe as usize); for i in 0..mmap.nlist { let centroid = mmap.centroids(i); - let distance = S::elkan_k_means_distance2(S::owned_to_ref(&target), centroid); + let distance = S::elkan_k_means_distance2(target.for_borrow(), centroid); if lists.check(distance) { lists.push(Element { distance, @@ -282,16 +278,16 @@ pub fn basic( pub fn vbase<'a, S: G>( mmap: &'a IvfMmap, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, nprobe: u32, mut filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { - let mut target = S::ref_to_owned(vector); + let mut target = vector.for_own(); S::elkan_k_means_normalize2(&mut target); let mut lists = ElementHeap::new(nprobe as usize); for i in 0..mmap.nlist { let centroid = mmap.centroids(i); - let distance = S::elkan_k_means_distance2(S::owned_to_ref(&target), centroid); + let distance = S::elkan_k_means_distance2(target.for_borrow(), centroid); if lists.check(distance) { lists.push(Element { distance, diff --git a/crates/service/src/algorithms/ivf/ivf_pq.rs b/crates/service/src/algorithms/ivf/ivf_pq.rs index e2ca36b17..5c1eac304 100644 --- a/crates/service/src/algorithms/ivf/ivf_pq.rs +++ b/crates/service/src/algorithms/ivf/ivf_pq.rs @@ -2,12 +2,8 @@ use crate::algorithms::clustering::elkan_k_means::ElkanKMeans; use crate::algorithms::quantization::product::ProductQuantization; use crate::algorithms::quantization::Quan; use crate::algorithms::raw::Raw; -use crate::index::indexing::ivf::IvfIndexingOptions; use crate::index::segments::growing::GrowingSegment; use crate::index::segments::sealed::SealedSegment; -use crate::index::IndexOptions; -use crate::index::SearchOptions; -use crate::index::VectorOptions; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; use crate::utils::element_heap::ElementHeap; @@ -51,7 +47,7 @@ impl IvfPq { self.mmap.raw.len() } - pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + pub fn vector(&self, i: u32) -> Borrowed<'_, S> { self.mmap.raw.vector(i) } @@ -61,7 +57,7 @@ impl IvfPq { pub fn basic( &self, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -70,7 +66,7 @@ impl IvfPq { pub fn vbase<'a>( &'a self, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { @@ -89,7 +85,7 @@ pub struct IvfRam { // ---------------------- nlist: u32, // ---------------------- - centroids: Vec2, + centroids: Vec2>, ptr: Vec, payloads: Vec, } @@ -105,7 +101,7 @@ pub struct IvfMmap { // ---------------------- nlist: u32, // ---------------------- - centroids: MmapArray, + centroids: MmapArray>, ptr: MmapArray, payloads: MmapArray, } @@ -114,7 +110,7 @@ unsafe impl Send for IvfMmap {} unsafe impl Sync for IvfMmap {} impl IvfMmap { - fn centroids(&self, i: u32) -> &[S::Scalar] { + fn centroids(&self, i: u32) -> &[Scalar] { let s = i as usize * self.dims as usize; let e = (i + 1) as usize * self.dims as usize; &self.centroids[s..e] @@ -135,7 +131,7 @@ pub fn make( nsample, quantization: quantization_opts, } = options.indexing.clone().unwrap_ivf(); - let raw = Arc::new(Raw::create( + let raw = Arc::new(Raw::::create( &path.join("raw"), options.clone(), sealed, @@ -146,7 +142,7 @@ pub fn make( let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec(); let mut samples = Vec2::new(dims, m as usize); for i in 0..m { - samples[i as usize].copy_from_slice(S::to_dense(raw.vector(f[i as usize] as u32)).as_ref()); + samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32).to_vec().as_ref()); S::elkan_k_means_normalize(&mut samples[i as usize]); } let mut k_means = ElkanKMeans::::new(nlist as usize, samples); @@ -161,11 +157,11 @@ pub fn make( let centroids = k_means.finish(); let mut idx = vec![0usize; n as usize]; idx.par_iter_mut().enumerate().for_each(|(i, x)| { - let mut vector = S::ref_to_owned(raw.vector(i as u32)); + let mut vector = raw.vector(i as u32).for_own(); S::elkan_k_means_normalize2(&mut vector); let mut result = (F32::infinity(), 0); for i in 0..nlist as usize { - let dis = S::elkan_k_means_distance2(S::owned_to_ref(&vector), ¢roids[i]); + let dis = S::elkan_k_means_distance2(vector.for_borrow(), ¢roids[i]); result = std::cmp::min(result, (dis, i)); } *x = result.1; @@ -194,7 +190,7 @@ pub fn make( .enumerate() .for_each(|(i, v)| { for j in 0..dims { - v[j as usize] = S::to_dense(raw.vector(ids[i])).as_ref()[j as usize] + v[j as usize] = raw.vector(ids[i]).to_vec()[j as usize] - centroids[idx[ids[i] as usize]][j as usize]; } }); @@ -263,14 +259,15 @@ pub fn open(path: &Path, options: IndexOptions) -> IvfMmap { pub fn basic( mmap: &IvfMmap, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, nprobe: u32, mut filter: impl Filter, ) -> BinaryHeap> { + let dense = vector.to_vec(); let mut lists = ElementHeap::new(nprobe as usize); for i in 0..mmap.nlist { let centroid = mmap.centroids(i); - let distance = S::distance2(vector, centroid); + let distance = S::product_quantization_dense_distance(&dense, centroid); if lists.check(distance) { lists.push(Element { distance, @@ -278,7 +275,7 @@ pub fn basic( }); } } - let runtime_table = mmap.quantization.init_query(S::to_dense(vector).as_ref()); + let runtime_table = mmap.quantization.init_query(vector.to_vec().as_ref()); let lists = lists.into_sorted_vec(); let mut result = BinaryHeap::new(); for i in lists.iter() { @@ -306,14 +303,15 @@ pub fn basic( pub fn vbase<'a, S: G>( mmap: &'a IvfMmap, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, nprobe: u32, mut filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { + let dense = vector.to_vec(); let mut lists = ElementHeap::new(nprobe as usize); for i in 0..mmap.nlist { let centroid = mmap.centroids(i); - let distance = S::distance2(vector, centroid); + let distance = S::product_quantization_dense_distance(&dense, centroid); if lists.check(distance) { lists.push(Element { distance, @@ -321,7 +319,7 @@ pub fn vbase<'a, S: G>( }); } } - let runtime_table = mmap.quantization.init_query(S::to_dense(vector).as_ref()); + let runtime_table = mmap.quantization.init_query(vector.to_vec().as_ref()); let lists = lists.into_sorted_vec(); let mut result = Vec::new(); for i in lists.iter() { diff --git a/crates/service/src/algorithms/ivf/mod.rs b/crates/service/src/algorithms/ivf/mod.rs index 877e36c45..6b9f58295 100644 --- a/crates/service/src/algorithms/ivf/mod.rs +++ b/crates/service/src/algorithms/ivf/mod.rs @@ -5,8 +5,6 @@ use self::ivf_naive::IvfNaive; use self::ivf_pq::IvfPq; use crate::index::segments::growing::GrowingSegment; use crate::index::segments::sealed::SealedSegment; -use crate::index::IndexOptions; -use crate::index::SearchOptions; use crate::prelude::*; use std::cmp::Reverse; use std::collections::BinaryHeap; @@ -25,13 +23,10 @@ impl Ivf { sealed: Vec>>, growing: Vec>>, ) -> Self { - if options - .indexing - .clone() - .unwrap_ivf() - .quantization - .is_product_quantization() - { + if matches!( + options.indexing.clone().unwrap_ivf().quantization, + QuantizationOptions::Product(_) + ) { Self::Pq(IvfPq::create(path, options, sealed, growing)) } else { Self::Naive(IvfNaive::create(path, options, sealed, growing)) @@ -39,13 +34,10 @@ impl Ivf { } pub fn open(path: &Path, options: IndexOptions) -> Self { - if options - .indexing - .clone() - .unwrap_ivf() - .quantization - .is_product_quantization() - { + if matches!( + options.indexing.clone().unwrap_ivf().quantization, + QuantizationOptions::Product(_) + ) { Self::Pq(IvfPq::open(path, options)) } else { Self::Naive(IvfNaive::open(path, options)) @@ -59,7 +51,7 @@ impl Ivf { } } - pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + pub fn vector(&self, i: u32) -> Borrowed<'_, S> { match self { Ivf::Naive(x) => x.vector(i), Ivf::Pq(x) => x.vector(i), @@ -75,7 +67,7 @@ impl Ivf { pub fn basic( &self, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -87,7 +79,7 @@ impl Ivf { pub fn vbase<'a>( &'a self, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { diff --git a/crates/service/src/algorithms/quantization/mod.rs b/crates/service/src/algorithms/quantization/mod.rs index 778475016..39e6326f2 100644 --- a/crates/service/src/algorithms/quantization/mod.rs +++ b/crates/service/src/algorithms/quantization/mod.rs @@ -2,60 +2,13 @@ pub mod product; pub mod scalar; pub mod trivial; -use self::product::{ProductQuantization, ProductQuantizationOptions}; -use self::scalar::{ScalarQuantization, ScalarQuantizationOptions}; -use self::trivial::{TrivialQuantization, TrivialQuantizationOptions}; +use self::product::ProductQuantization; +use self::scalar::ScalarQuantization; +use self::trivial::TrivialQuantization; use super::raw::Raw; -use crate::index::IndexOptions; use crate::prelude::*; -use serde::{Deserialize, Serialize}; -use std::fmt::Debug; use std::path::Path; use std::sync::Arc; -use validator::Validate; - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -#[serde(rename_all = "snake_case")] -pub enum QuantizationOptions { - Trivial(TrivialQuantizationOptions), - Scalar(ScalarQuantizationOptions), - Product(ProductQuantizationOptions), -} - -impl Validate for QuantizationOptions { - fn validate(&self) -> Result<(), validator::ValidationErrors> { - match self { - Self::Trivial(x) => x.validate(), - Self::Scalar(x) => x.validate(), - Self::Product(x) => x.validate(), - } - } -} - -impl Default for QuantizationOptions { - fn default() -> Self { - Self::Trivial(Default::default()) - } -} - -impl QuantizationOptions { - fn _unwrap_scalar_quantization(self) -> ScalarQuantizationOptions { - match self { - Self::Scalar(x) => x, - _ => unreachable!(), - } - } - fn unwrap_product_quantization(self) -> ProductQuantizationOptions { - match self { - Self::Product(x) => x, - _ => unreachable!(), - } - } - pub fn is_product_quantization(&self) -> bool { - matches!(self, Self::Product(_)) - } -} pub trait Quan { fn create( @@ -71,7 +24,7 @@ pub trait Quan { quantization_options: QuantizationOptions, raw: &Arc>, ) -> Self; - fn distance(&self, lhs: S::VectorRef<'_>, rhs: u32) -> F32; + fn distance(&self, lhs: Borrowed<'_, S>, rhs: u32) -> F32; fn distance2(&self, lhs: u32, rhs: u32) -> F32; } @@ -142,7 +95,7 @@ impl Quantization { } } - pub fn distance(&self, lhs: S::VectorRef<'_>, rhs: u32) -> F32 { + pub fn distance(&self, lhs: Borrowed<'_, S>, rhs: u32) -> F32 { use Quantization::*; match self { Trivial(x) => x.distance(lhs, rhs), diff --git a/crates/service/src/algorithms/quantization/product.rs b/crates/service/src/algorithms/quantization/product.rs index 52f233c0b..d957ac3bc 100644 --- a/crates/service/src/algorithms/quantization/product.rs +++ b/crates/service/src/algorithms/quantization/product.rs @@ -2,7 +2,6 @@ use crate::algorithms::clustering::elkan_k_means::ElkanKMeans; use crate::algorithms::quantization::Quan; use crate::algorithms::quantization::QuantizationOptions; use crate::algorithms::raw::Raw; -use crate::index::IndexOptions; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; use crate::utils::mmap_array::MmapArray; @@ -12,59 +11,13 @@ use rand::thread_rng; use rayon::iter::IndexedParallelIterator; use rayon::iter::ParallelIterator; use rayon::slice::ParallelSliceMut; -use serde::{Deserialize, Serialize}; -use std::fmt::Debug; use std::path::Path; use std::sync::Arc; -use validator::Validate; - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -#[serde(deny_unknown_fields)] -pub struct ProductQuantizationOptions { - #[serde(default = "ProductQuantizationOptions::default_sample")] - #[validate(range(min = 1, max = 1_000_000))] - pub sample: u32, - #[serde(default)] - pub ratio: ProductQuantizationOptionsRatio, -} - -impl ProductQuantizationOptions { - fn default_sample() -> u32 { - 65535 - } -} - -impl Default for ProductQuantizationOptions { - fn default() -> Self { - Self { - sample: Self::default_sample(), - ratio: Default::default(), - } - } -} - -#[repr(u16)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -#[serde(rename_all = "snake_case")] -pub enum ProductQuantizationOptionsRatio { - X4 = 1, - X8 = 2, - X16 = 4, - X32 = 8, - X64 = 16, -} - -impl Default for ProductQuantizationOptionsRatio { - fn default() -> Self { - Self::X4 - } -} pub struct ProductQuantization { dims: u16, ratio: u16, - centroids: Vec, + centroids: Vec>, codes: MmapArray, precomputed_table: Vec, } @@ -105,6 +58,9 @@ impl Quan for ProductQuantization { quantization_options: QuantizationOptions, _: &Arc>, ) -> Self { + let QuantizationOptions::Product(quantization_options) = quantization_options else { + unreachable!() + }; let centroids = serde_json::from_slice(&std::fs::read(path.join("centroids")).unwrap()).unwrap(); let codes = MmapArray::open(&path.join("codes")); @@ -112,14 +68,14 @@ impl Quan for ProductQuantization { serde_json::from_slice(&std::fs::read(path.join("table")).unwrap()).unwrap(); Self { dims: options.vector.dims, - ratio: quantization_options.unwrap_product_quantization().ratio as _, + ratio: quantization_options.ratio as _, centroids, codes, precomputed_table, } } - fn distance(&self, lhs: S::VectorRef<'_>, rhs: u32) -> F32 { + fn distance(&self, lhs: Borrowed<'_, S>, rhs: u32) -> F32 { let dims = self.dims; let ratio = self.ratio; let rhs = self.codes(rhs); @@ -145,38 +101,35 @@ impl ProductQuantization { permutation: Vec, ) -> Self where - F: Fn(u32, &mut [S::Scalar]), + F: Fn(u32, &mut [Scalar]), { - assert!( - S::KIND != Kind::SparseF32, - "Product quantization is not supported for sparse vectors." - ); - std::fs::create_dir(path).unwrap(); - let quantization_options = quantization_options.unwrap_product_quantization(); + let QuantizationOptions::Product(quantization_options) = quantization_options else { + unreachable!() + }; let dims = options.vector.dims; let ratio = quantization_options.ratio as u16; let n = raw.len(); let m = std::cmp::min(n, quantization_options.sample); let samples = { let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec(); - let mut samples = Vec2::::new(options.vector.dims, m as usize); + let mut samples = Vec2::>::new(options.vector.dims, m as usize); for i in 0..m { samples[i as usize] - .copy_from_slice(S::to_dense(raw.vector(f[i as usize] as u32)).as_ref()); + .copy_from_slice(raw.vector(f[i as usize] as u32).to_vec().as_ref()); } samples }; let width = dims.div_ceil(ratio); - let mut centroids = vec![S::Scalar::zero(); 256 * dims as usize]; + let mut centroids = vec![Scalar::::zero(); 256 * dims as usize]; for i in 0..width { let subdims = std::cmp::min(ratio, dims - ratio * i); - let mut subsamples = Vec2::::new(subdims, m as usize); + let mut subsamples = Vec2::>::new(subdims, m as usize); for j in 0..m { let src = &samples[j as usize][(i * ratio) as usize..][..subdims as usize]; subsamples[j as usize].copy_from_slice(src); } - let mut k_means = ElkanKMeans::::new(256, subsamples); + let mut k_means = ElkanKMeans::::new(256, subsamples); for _ in 0..25 { if k_means.iterate() { break; @@ -189,7 +142,7 @@ impl ProductQuantization { } } let codes_iter = (0..n).flat_map(|i| { - let mut vector = S::to_dense(raw.vector(permutation[i as usize])).to_vec(); + let mut vector = raw.vector(permutation[i as usize]).to_vec(); normalizer(permutation[i as usize], &mut vector); let width = dims.div_ceil(ratio); let mut result = Vec::with_capacity(width as usize); @@ -201,7 +154,7 @@ impl ProductQuantization { for j in 0u8..=255 { let right = ¢roids[j as usize * dims as usize..][(i * ratio) as usize..] [..subdims as usize]; - let dis = S::L2::distance(left, right); + let dis = S::product_quantization_l2_distance(left, right); if dis < minimal { minimal = dis; target = j; @@ -231,10 +184,12 @@ impl ProductQuantization { path: &Path, options: IndexOptions, quantization_options: QuantizationOptions, - raw: &Vec2, + raw: &Vec2>, ) -> Self { std::fs::create_dir(path).unwrap(); - let quantization_options = quantization_options.unwrap_product_quantization(); + let QuantizationOptions::Product(quantization_options) = quantization_options else { + unreachable!() + }; let dims = options.vector.dims; let ratio = quantization_options.ratio as u16; let n = raw.len(); @@ -249,7 +204,7 @@ impl ProductQuantization { }; let width = dims.div_ceil(ratio); // a temp layout (width * 256 * subdims) for par_chunks_mut - let mut tmp_centroids = vec![S::Scalar::zero(); 256 * dims as usize]; + let mut tmp_centroids = vec![Scalar::::zero(); 256 * dims as usize]; // this par_for parallelizes over sub quantizers tmp_centroids .par_chunks_mut(256 * ratio as usize) @@ -262,7 +217,7 @@ impl ProductQuantization { let src = &samples[j][i * ratio as usize..][..subdims]; subsamples[j].copy_from_slice(src); } - let mut k_means = ElkanKMeans::::new(256, subsamples); + let mut k_means = ElkanKMeans::::new(256, subsamples); for _ in 0..25 { if k_means.iterate() { break; @@ -274,7 +229,7 @@ impl ProductQuantization { } }); // transform back to normal layout (256 * width * subdims) - let mut centroids = vec![S::Scalar::zero(); 256 * dims as usize]; + let mut centroids = vec![Scalar::::zero(); 256 * dims as usize]; centroids .par_chunks_mut(dims as usize) .enumerate() @@ -301,7 +256,9 @@ impl ProductQuantization { for j in 0u8..=255 { let right = ¢roids[j as usize * dims as usize..] [(i * ratio) as usize..][..subdims as usize]; - let dis = S::L2::distance(left, right); + let dis = S::ProductQuantizationL2::product_quantization_dense_distance( + left, right, + ); if dis < minimal { minimal = dis; target = j; @@ -327,7 +284,7 @@ impl ProductQuantization { } // compute term3 at build time - pub fn precompute_table(&mut self, path: &Path, coarse_centroids: &Vec2) { + pub fn precompute_table(&mut self, path: &Path, coarse_centroids: &Vec2>) { let nlist = coarse_centroids.len(); let dims = self.dims; let ratio = self.ratio; @@ -358,8 +315,8 @@ impl ProductQuantization { } // compute term2 at query time - pub fn init_query(&self, query: &[S::Scalar]) -> Vec { - if S::DISTANCE == Distance::Cos { + pub fn init_query(&self, query: &[Scalar]) -> Vec { + if S::DISTANCE_KIND == DistanceKind::Cos { return Vec::new(); } let dims = self.dims; @@ -382,26 +339,26 @@ impl ProductQuantization { // add up all terms given codes pub fn distance_with_codes( &self, - lhs: S::VectorRef<'_>, + lhs: Borrowed<'_, S>, rhs: u32, - delta: &[S::Scalar], + delta: &[Scalar], key: usize, coarse_dis: F32, runtime_table: &[F32], ) -> F32 { - if S::DISTANCE == Distance::Cos { + if S::DISTANCE_KIND == DistanceKind::Cos { return self.distance_with_delta(lhs, rhs, delta); } let mut result = coarse_dis; let codes = self.codes(rhs); let width = self.dims.div_ceil(self.ratio); let precomputed_table = &self.precomputed_table[key * width as usize * 256..]; - if S::DISTANCE == Distance::L2 { + if S::DISTANCE_KIND == DistanceKind::L2 { for i in 0..width { result += precomputed_table[i as usize * 256 + codes[i as usize] as usize] + F32(2.0) * runtime_table[i as usize * 256 + codes[i as usize] as usize]; } - } else if S::DISTANCE == Distance::Dot { + } else if S::DISTANCE_KIND == DistanceKind::Dot { for i in 0..width { result += runtime_table[i as usize * 256 + codes[i as usize] as usize]; } @@ -409,10 +366,26 @@ impl ProductQuantization { result } - pub fn distance_with_delta(&self, lhs: S::VectorRef<'_>, rhs: u32, delta: &[S::Scalar]) -> F32 { + pub fn distance_with_delta(&self, lhs: Borrowed<'_, S>, rhs: u32, delta: &[Scalar]) -> F32 { let dims = self.dims; let ratio = self.ratio; let rhs = self.codes(rhs); S::product_quantization_distance_with_delta(dims, ratio, &self.centroids, lhs, rhs, delta) } } + +pub fn squared_norm(dims: u16, vec: &[Scalar]) -> F32 { + let mut result = F32::zero(); + for i in 0..dims as usize { + result += F32((vec[i] * vec[i]).to_f32()); + } + result +} + +pub fn inner_product(dims: u16, lhs: &[Scalar], rhs: &[Scalar]) -> F32 { + let mut result = F32::zero(); + for i in 0..dims as usize { + result += F32((lhs[i] * rhs[i]).to_f32()); + } + result +} diff --git a/crates/service/src/algorithms/quantization/scalar.rs b/crates/service/src/algorithms/quantization/scalar.rs index d0e4e8c73..2938ffd4e 100644 --- a/crates/service/src/algorithms/quantization/scalar.rs +++ b/crates/service/src/algorithms/quantization/scalar.rs @@ -1,30 +1,16 @@ use crate::algorithms::quantization::Quan; use crate::algorithms::quantization::QuantizationOptions; use crate::algorithms::raw::Raw; -use crate::index::IndexOptions; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; use crate::utils::mmap_array::MmapArray; -use base::scalar::FloatCast; -use serde::{Deserialize, Serialize}; use std::path::Path; use std::sync::Arc; -use validator::Validate; - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -#[serde(deny_unknown_fields)] -pub struct ScalarQuantizationOptions {} - -impl Default for ScalarQuantizationOptions { - fn default() -> Self { - Self {} - } -} pub struct ScalarQuantization { dims: u16, - max: Vec, - min: Vec, + max: Vec>, + min: Vec>, codes: MmapArray, } @@ -47,18 +33,13 @@ impl Quan for ScalarQuantization { raw: &Arc>, permutation: Vec, // permutation is the mapping from placements to original ids ) -> Self { - assert!( - S::KIND != Kind::SparseF32, - "Scalar quantization is not supported for sparse vectors." - ); - std::fs::create_dir(path).unwrap(); let dims = options.vector.dims; - let mut max = vec![S::Scalar::neg_infinity(); dims as usize]; - let mut min = vec![S::Scalar::infinity(); dims as usize]; + let mut max = vec![Scalar::::neg_infinity(); dims as usize]; + let mut min = vec![Scalar::::infinity(); dims as usize]; let n = raw.len(); for i in 0..n { - let vector = S::to_dense(raw.vector(permutation[i as usize])); + let vector = raw.vector(permutation[i as usize]).to_vec(); for j in 0..dims as usize { max[j] = std::cmp::max(max[j], vector[j]); min[j] = std::cmp::min(min[j], vector[j]); @@ -67,7 +48,7 @@ impl Quan for ScalarQuantization { std::fs::write(path.join("max"), serde_json::to_string(&max).unwrap()).unwrap(); std::fs::write(path.join("min"), serde_json::to_string(&min).unwrap()).unwrap(); let codes_iter = (0..n).flat_map(|i| { - let vector = S::to_dense(raw.vector(permutation[i as usize])); + let vector = raw.vector(permutation[i as usize]).to_vec(); let mut result = vec![0u8; dims as usize]; for i in 0..dims as usize { let w = (((vector[i] - min[i]) / (max[i] - min[i])).to_f32() * 256.0) as u32; @@ -98,7 +79,7 @@ impl Quan for ScalarQuantization { } } - fn distance(&self, lhs: S::VectorRef<'_>, rhs: u32) -> F32 { + fn distance(&self, lhs: Borrowed<'_, S>, rhs: u32) -> F32 { let dims = self.dims; let rhs = self.codes(rhs); S::scalar_quantization_distance(dims, &self.max, &self.min, lhs, rhs) diff --git a/crates/service/src/algorithms/quantization/trivial.rs b/crates/service/src/algorithms/quantization/trivial.rs index d2fb75562..2350fa62e 100644 --- a/crates/service/src/algorithms/quantization/trivial.rs +++ b/crates/service/src/algorithms/quantization/trivial.rs @@ -1,23 +1,10 @@ use crate::algorithms::quantization::Quan; use crate::algorithms::quantization::QuantizationOptions; use crate::algorithms::raw::Raw; -use crate::index::IndexOptions; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; -use serde::{Deserialize, Serialize}; use std::path::Path; use std::sync::Arc; -use validator::Validate; - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -#[serde(deny_unknown_fields)] -pub struct TrivialQuantizationOptions {} - -impl Default for TrivialQuantizationOptions { - fn default() -> Self { - Self {} - } -} pub struct TrivialQuantization { raw: Arc>, @@ -25,7 +12,7 @@ pub struct TrivialQuantization { } impl TrivialQuantization { - pub fn codes(&self, i: u32) -> S::VectorRef<'_> { + pub fn codes(&self, i: u32) -> Borrowed<'_, S> { self.raw.vector(self.permutation[i as usize]) } } @@ -62,7 +49,7 @@ impl Quan for TrivialQuantization { } } - fn distance(&self, lhs: S::VectorRef<'_>, rhs: u32) -> F32 { + fn distance(&self, lhs: Borrowed<'_, S>, rhs: u32) -> F32 { S::distance(lhs, self.codes(rhs)) } diff --git a/crates/service/src/algorithms/raw.rs b/crates/service/src/algorithms/raw.rs index b501f7330..ee748732e 100644 --- a/crates/service/src/algorithms/raw.rs +++ b/crates/service/src/algorithms/raw.rs @@ -1,7 +1,7 @@ use crate::index::segments::growing::GrowingSegment; use crate::index::segments::sealed::SealedSegment; -use crate::index::IndexOptions; use crate::prelude::*; +use crate::storage::Storage; use std::path::Path; use std::sync::Arc; @@ -10,11 +10,11 @@ pub struct Raw { } impl Raw { - pub fn create G = S::VectorRef<'a>>>( + pub fn create( path: &Path, options: IndexOptions, - sealed: Vec>>, - growing: Vec>>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { std::fs::create_dir(path).unwrap(); let ram = make(sealed, growing, options); @@ -29,7 +29,7 @@ impl Raw { self.mmap.len() } - pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + pub fn vector(&self, i: u32) -> Borrowed<'_, S> { self.mmap.vector(i) } @@ -63,7 +63,7 @@ impl RawRam { + self.growing.iter().map(|x| x.len()).sum::() } - pub fn vector(&self, mut index: u32) -> S::VectorRef<'_> { + pub fn vector(&self, mut index: u32) -> Borrowed<'_, S> { for x in self.sealed.iter() { if index < x.len() { return x.vector(index); diff --git a/crates/service/src/index/indexing/flat.rs b/crates/service/src/index/indexing/flat.rs index 71cc83edc..f6660a21a 100644 --- a/crates/service/src/index/indexing/flat.rs +++ b/crates/service/src/index/indexing/flat.rs @@ -1,32 +1,13 @@ use super::AbstractIndexing; -use crate::algorithms::quantization::QuantizationOptions; use crate::index::segments::growing::GrowingSegment; use crate::index::IndexOptions; use crate::index::SearchOptions; use crate::prelude::*; use crate::{algorithms::flat::Flat, index::segments::sealed::SealedSegment}; -use serde::{Deserialize, Serialize}; use std::cmp::Reverse; use std::collections::BinaryHeap; use std::path::Path; use std::sync::Arc; -use validator::Validate; - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -#[serde(deny_unknown_fields)] -pub struct FlatIndexingOptions { - #[serde(default)] - #[validate] - pub quantization: QuantizationOptions, -} - -impl Default for FlatIndexingOptions { - fn default() -> Self { - Self { - quantization: QuantizationOptions::default(), - } - } -} pub struct FlatIndexing { raw: Flat, @@ -45,7 +26,7 @@ impl AbstractIndexing for FlatIndexing { fn basic( &self, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -54,7 +35,7 @@ impl AbstractIndexing for FlatIndexing { fn vbase<'a>( &'a self, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { @@ -67,7 +48,7 @@ impl FlatIndexing { self.raw.len() } - pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + pub fn vector(&self, i: u32) -> Borrowed<'_, S> { self.raw.vector(i) } diff --git a/crates/service/src/index/indexing/hnsw.rs b/crates/service/src/index/indexing/hnsw.rs index d283c1778..a99710bde 100644 --- a/crates/service/src/index/indexing/hnsw.rs +++ b/crates/service/src/index/indexing/hnsw.rs @@ -1,50 +1,14 @@ use super::AbstractIndexing; use crate::algorithms::hnsw::Hnsw; -use crate::algorithms::quantization::QuantizationOptions; use crate::index::segments::growing::GrowingSegment; use crate::index::segments::sealed::SealedSegment; use crate::index::IndexOptions; use crate::index::SearchOptions; use crate::prelude::*; -use serde::{Deserialize, Serialize}; use std::cmp::Reverse; use std::collections::BinaryHeap; use std::path::Path; use std::sync::Arc; -use validator::Validate; - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -#[serde(deny_unknown_fields)] -pub struct HnswIndexingOptions { - #[serde(default = "HnswIndexingOptions::default_m")] - #[validate(range(min = 4, max = 128))] - pub m: u32, - #[serde(default = "HnswIndexingOptions::default_ef_construction")] - #[validate(range(min = 10, max = 2000))] - pub ef_construction: usize, - #[serde(default)] - #[validate] - pub quantization: QuantizationOptions, -} - -impl HnswIndexingOptions { - fn default_m() -> u32 { - 12 - } - fn default_ef_construction() -> usize { - 300 - } -} - -impl Default for HnswIndexingOptions { - fn default() -> Self { - Self { - m: Self::default_m(), - ef_construction: Self::default_ef_construction(), - quantization: Default::default(), - } - } -} pub struct HnswIndexing { raw: Hnsw, @@ -63,7 +27,7 @@ impl AbstractIndexing for HnswIndexing { fn basic( &self, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -72,7 +36,7 @@ impl AbstractIndexing for HnswIndexing { fn vbase<'a>( &'a self, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { @@ -85,7 +49,7 @@ impl HnswIndexing { self.raw.len() } - pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + pub fn vector(&self, i: u32) -> Borrowed<'_, S> { self.raw.vector(i) } diff --git a/crates/service/src/index/indexing/ivf.rs b/crates/service/src/index/indexing/ivf.rs index 6959d7f51..85c06c7d5 100644 --- a/crates/service/src/index/indexing/ivf.rs +++ b/crates/service/src/index/indexing/ivf.rs @@ -1,64 +1,14 @@ use super::AbstractIndexing; use crate::algorithms::ivf::Ivf; -use crate::algorithms::quantization::QuantizationOptions; use crate::index::segments::growing::GrowingSegment; use crate::index::segments::sealed::SealedSegment; use crate::index::IndexOptions; use crate::index::SearchOptions; use crate::prelude::*; -use serde::{Deserialize, Serialize}; use std::cmp::Reverse; use std::collections::BinaryHeap; use std::path::Path; use std::sync::Arc; -use validator::Validate; - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -#[serde(deny_unknown_fields)] -pub struct IvfIndexingOptions { - #[serde(default = "IvfIndexingOptions::default_least_iterations")] - #[validate(range(min = 1, max = 1_000_000))] - pub least_iterations: u32, - #[serde(default = "IvfIndexingOptions::default_iterations")] - #[validate(range(min = 1, max = 1_000_000))] - pub iterations: u32, - #[serde(default = "IvfIndexingOptions::default_nlist")] - #[validate(range(min = 1, max = 1_000_000))] - pub nlist: u32, - #[serde(default = "IvfIndexingOptions::default_nsample")] - #[validate(range(min = 1, max = 1_000_000))] - pub nsample: u32, - #[serde(default)] - #[validate] - pub quantization: QuantizationOptions, -} - -impl IvfIndexingOptions { - fn default_least_iterations() -> u32 { - 16 - } - fn default_iterations() -> u32 { - 500 - } - fn default_nlist() -> u32 { - 1000 - } - fn default_nsample() -> u32 { - 65536 - } -} - -impl Default for IvfIndexingOptions { - fn default() -> Self { - Self { - least_iterations: Self::default_least_iterations(), - iterations: Self::default_iterations(), - nlist: Self::default_nlist(), - nsample: Self::default_nsample(), - quantization: Default::default(), - } - } -} pub struct IvfIndexing { raw: Ivf, @@ -77,7 +27,7 @@ impl AbstractIndexing for IvfIndexing { fn basic( &self, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -86,7 +36,7 @@ impl AbstractIndexing for IvfIndexing { fn vbase<'a>( &'a self, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { @@ -99,7 +49,7 @@ impl IvfIndexing { self.raw.len() } - pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + pub fn vector(&self, i: u32) -> Borrowed<'_, S> { self.raw.vector(i) } diff --git a/crates/service/src/index/indexing/mod.rs b/crates/service/src/index/indexing/mod.rs index de4ddd7b4..111f5a957 100644 --- a/crates/service/src/index/indexing/mod.rs +++ b/crates/service/src/index/indexing/mod.rs @@ -2,75 +2,18 @@ pub mod flat; pub mod hnsw; pub mod ivf; -use self::flat::{FlatIndexing, FlatIndexingOptions}; -use self::hnsw::{HnswIndexing, HnswIndexingOptions}; -use self::ivf::{IvfIndexing, IvfIndexingOptions}; +use self::flat::FlatIndexing; +use self::hnsw::HnswIndexing; +use self::ivf::IvfIndexing; use super::segments::growing::GrowingSegment; use super::segments::sealed::SealedSegment; use super::IndexOptions; -use crate::algorithms::quantization::QuantizationOptions; use crate::index::SearchOptions; use crate::prelude::*; -use serde::{Deserialize, Serialize}; use std::cmp::Reverse; use std::collections::BinaryHeap; use std::path::Path; use std::sync::Arc; -use validator::Validate; - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -#[serde(rename_all = "snake_case")] -pub enum IndexingOptions { - Flat(FlatIndexingOptions), - Ivf(IvfIndexingOptions), - Hnsw(HnswIndexingOptions), -} - -impl IndexingOptions { - pub fn unwrap_flat(self) -> FlatIndexingOptions { - let IndexingOptions::Flat(x) = self else { - unreachable!() - }; - x - } - pub fn unwrap_ivf(self) -> IvfIndexingOptions { - let IndexingOptions::Ivf(x) = self else { - unreachable!() - }; - x - } - pub fn unwrap_hnsw(self) -> HnswIndexingOptions { - let IndexingOptions::Hnsw(x) = self else { - unreachable!() - }; - x - } - pub fn has_quantization(&self) -> bool { - let option = match self { - Self::Flat(x) => &x.quantization, - Self::Ivf(x) => &x.quantization, - Self::Hnsw(x) => &x.quantization, - }; - !matches!(option, QuantizationOptions::Trivial(_)) - } -} - -impl Default for IndexingOptions { - fn default() -> Self { - Self::Hnsw(Default::default()) - } -} - -impl Validate for IndexingOptions { - fn validate(&self) -> Result<(), validator::ValidationErrors> { - match self { - Self::Flat(x) => x.validate(), - Self::Ivf(x) => x.validate(), - Self::Hnsw(x) => x.validate(), - } - } -} pub trait AbstractIndexing { fn create( @@ -81,13 +24,13 @@ pub trait AbstractIndexing { ) -> Self; fn basic( &self, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap>; fn vbase<'a>( &'a self, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box + 'a>); @@ -121,7 +64,7 @@ impl DynamicIndexing { pub fn basic( &self, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -134,7 +77,7 @@ impl DynamicIndexing { pub fn vbase<'a>( &'a self, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { @@ -153,7 +96,7 @@ impl DynamicIndexing { } } - pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + pub fn vector(&self, i: u32) -> Borrowed<'_, S> { match self { DynamicIndexing::Flat(x) => x.vector(i), DynamicIndexing::Ivf(x) => x.vector(i), diff --git a/crates/service/src/index/mod.rs b/crates/service/src/index/mod.rs index 0093fad72..69384d119 100644 --- a/crates/service/src/index/mod.rs +++ b/crates/service/src/index/mod.rs @@ -4,11 +4,8 @@ pub mod optimizing; pub mod segments; use self::delete::Delete; -use self::indexing::IndexingOptions; -use self::optimizing::OptimizingOptions; use self::segments::growing::GrowingSegment; use self::segments::sealed::SealedSegment; -use self::segments::SegmentsOptions; use crate::index::optimizing::indexing::OptimizerIndexing; use crate::index::optimizing::sealing::OptimizerSealing; use crate::prelude::*; @@ -29,73 +26,11 @@ use std::time::Instant; use thiserror::Error; use uuid::Uuid; use validator::Validate; -use validator::ValidationError; #[derive(Debug, Error)] #[error("The index view is outdated.")] pub struct OutdatedError; -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -#[serde(deny_unknown_fields)] -pub struct VectorOptions { - #[validate(range(min = 1, max = 65535))] - #[serde(rename = "dimensions")] - pub dims: u16, - #[serde(rename = "distance")] - pub d: Distance, - #[serde(rename = "kind")] - pub k: Kind, -} - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -#[serde(deny_unknown_fields)] -#[validate(schema(function = "validate_index_options"))] -pub struct IndexOptions { - #[validate] - pub vector: VectorOptions, - #[validate] - pub segment: SegmentsOptions, - #[validate] - pub optimizing: OptimizingOptions, - #[validate] - pub indexing: IndexingOptions, -} - -fn validate_index_options(options: &IndexOptions) -> Result<(), ValidationError> { - if options.vector.k == Kind::SparseF32 && options.indexing.has_quantization() { - return Err(ValidationError::new( - "quantization is not supported for sparse vector", - )); - } - Ok(()) -} - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -pub struct SearchOptions { - pub prefilter_enable: bool, - #[validate(range(min = 1, max = 65535))] - pub hnsw_ef_search: usize, - #[validate(range(min = 1, max = 1_000_000))] - pub ivf_nprobe: u32, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct SegmentStat { - pub id: Uuid, - #[serde(rename = "type")] - pub typ: String, - pub length: usize, - pub size: u64, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct IndexStat { - pub indexing: bool, - pub segments: Vec, - pub options: IndexOptions, -} - pub struct Index { path: PathBuf, options: IndexOptions, @@ -318,7 +253,7 @@ pub struct IndexView { impl IndexView { pub fn basic<'a, F: Fn(Pointer) -> bool + Clone + 'a>( &'a self, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, opts: &'a SearchOptions, filter: F, ) -> Result + 'a, BasicError> { @@ -396,7 +331,7 @@ impl IndexView { } pub fn vbase<'a, F: FnMut(Pointer) -> bool + Clone + 'a>( &'a self, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, opts: &'a SearchOptions, filter: F, ) -> Result + 'a, VbaseError> { @@ -490,7 +425,7 @@ impl IndexView { } pub fn insert( &self, - vector: S::VectorOwned, + vector: Owned, pointer: Pointer, ) -> Result, InsertError> { if self.options.vector.dims != vector.dims() { diff --git a/crates/service/src/index/optimizing/indexing.rs b/crates/service/src/index/optimizing/indexing.rs index 35c697d95..6aa9c11cf 100644 --- a/crates/service/src/index/optimizing/indexing.rs +++ b/crates/service/src/index/optimizing/indexing.rs @@ -5,6 +5,7 @@ use crate::prelude::*; use std::cmp::Reverse; use std::sync::Arc; use std::time::Instant; +use thiserror::Error; use uuid::Uuid; pub struct OptimizerIndexing { @@ -76,7 +77,7 @@ impl Seg { } } -#[derive(Debug, thiserror::Error)] +#[derive(Debug, Error)] #[error("Interrupted, retry again.")] pub struct RetryError; diff --git a/crates/service/src/index/optimizing/mod.rs b/crates/service/src/index/optimizing/mod.rs index 2525b5057..ab4ba8d0a 100644 --- a/crates/service/src/index/optimizing/mod.rs +++ b/crates/service/src/index/optimizing/mod.rs @@ -1,52 +1,3 @@ pub mod indexing; pub mod sealing; pub mod vacuum; - -use serde::{Deserialize, Serialize}; -use validator::Validate; - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -#[serde(deny_unknown_fields)] -pub struct OptimizingOptions { - #[serde(default = "OptimizingOptions::default_sealing_secs")] - #[validate(range(min = 1, max = 60))] - pub sealing_secs: u64, - #[serde(default = "OptimizingOptions::default_sealing_size")] - #[validate(range(min = 1, max = 4_000_000_000))] - pub sealing_size: u32, - #[serde(default = "OptimizingOptions::default_delete_threshold")] - #[validate(range(min = 0.01, max = 1.00))] - pub delete_threshold: f64, - #[serde(default = "OptimizingOptions::default_optimizing_threads")] - #[validate(range(min = 1, max = 65535))] - pub optimizing_threads: usize, -} - -impl OptimizingOptions { - fn default_sealing_secs() -> u64 { - 60 - } - fn default_sealing_size() -> u32 { - 1 - } - fn default_delete_threshold() -> f64 { - 0.2 - } - fn default_optimizing_threads() -> usize { - match std::thread::available_parallelism() { - Ok(threads) => (threads.get() as f64).sqrt() as _, - Err(_) => 1, - } - } -} - -impl Default for OptimizingOptions { - fn default() -> Self { - Self { - sealing_secs: Self::default_sealing_secs(), - sealing_size: Self::default_sealing_size(), - delete_threshold: Self::default_delete_threshold(), - optimizing_threads: Self::default_optimizing_threads(), - } - } -} diff --git a/crates/service/src/index/segments/growing.rs b/crates/service/src/index/segments/growing.rs index 0a353d36d..65cf5647c 100644 --- a/crates/service/src/index/segments/growing.rs +++ b/crates/service/src/index/segments/growing.rs @@ -173,13 +173,13 @@ impl GrowingSegment { } } - pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + pub fn vector(&self, i: u32) -> Borrowed<'_, S> { let i = i as usize; if i >= self.len.load(Ordering::Acquire) { panic!("Out of bound."); } let log = unsafe { (*self.vec[i].get()).assume_init_ref() }; - S::owned_to_ref(&log.vector) + log.vector.for_borrow() } pub fn payload(&self, i: u32) -> Payload { @@ -193,7 +193,7 @@ impl GrowingSegment { pub fn basic( &self, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, _opts: &SearchOptions, mut filter: impl Filter, ) -> BinaryHeap> { @@ -202,7 +202,7 @@ impl GrowingSegment { for i in 0..n { let log = unsafe { (*self.vec[i].get()).assume_init_ref() }; if filter.check(log.payload) { - let distance = S::distance(vector, S::owned_to_ref(&log.vector)); + let distance = S::distance(vector, log.vector.for_borrow()); result.push(Reverse(Element { distance, payload: log.payload, @@ -214,7 +214,7 @@ impl GrowingSegment { pub fn vbase<'a>( &'a self, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, _opts: &SearchOptions, mut filter: impl Filter + 'a, ) -> (Vec, Box + 'a>) { @@ -223,7 +223,7 @@ impl GrowingSegment { for i in 0..n { let log = unsafe { (*self.vec[i].get()).assume_init_ref() }; if filter.check(log.payload) { - let distance = S::distance(vector, S::owned_to_ref(&log.vector)); + let distance = S::distance(vector, log.vector.for_borrow()); result.push(Element { distance, payload: log.payload, diff --git a/crates/service/src/index/segments/mod.rs b/crates/service/src/index/segments/mod.rs index 85822b983..dd3e1451b 100644 --- a/crates/service/src/index/segments/mod.rs +++ b/crates/service/src/index/segments/mod.rs @@ -2,50 +2,8 @@ pub mod growing; pub mod sealed; use super::IndexTracker; -use serde::{Deserialize, Serialize}; use std::path::PathBuf; use std::sync::Arc; -use validator::Validate; -use validator::ValidationError; - -#[derive(Debug, Clone, Serialize, Deserialize, Validate)] -#[serde(deny_unknown_fields)] -#[validate(schema(function = "Self::validate_0"))] -pub struct SegmentsOptions { - #[serde(default = "SegmentsOptions::default_max_growing_segment_size")] - #[validate(range(min = 1, max = 4_000_000_000))] - pub max_growing_segment_size: u32, - #[serde(default = "SegmentsOptions::default_max_sealed_segment_size")] - #[validate(range(min = 1, max = 4_000_000_000))] - pub max_sealed_segment_size: u32, -} - -impl SegmentsOptions { - fn default_max_growing_segment_size() -> u32 { - 20_000 - } - fn default_max_sealed_segment_size() -> u32 { - 1_000_000 - } - // max_growing_segment_size <= max_sealed_segment_size - fn validate_0(&self) -> Result<(), ValidationError> { - if self.max_growing_segment_size > self.max_sealed_segment_size { - return Err(ValidationError::new( - "`max_growing_segment_size` must be less than or equal to `max_sealed_segment_size`", - )); - } - Ok(()) - } -} - -impl Default for SegmentsOptions { - fn default() -> Self { - Self { - max_growing_segment_size: Self::default_max_growing_segment_size(), - max_sealed_segment_size: Self::default_max_sealed_segment_size(), - } - } -} #[derive(Debug, Clone)] pub struct SegmentTracker { diff --git a/crates/service/src/index/segments/sealed.rs b/crates/service/src/index/segments/sealed.rs index 07ffb8f52..137a33d71 100644 --- a/crates/service/src/index/segments/sealed.rs +++ b/crates/service/src/index/segments/sealed.rs @@ -65,7 +65,7 @@ impl SealedSegment { pub fn basic( &self, - vector: S::VectorRef<'_>, + vector: Borrowed<'_, S>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -74,7 +74,7 @@ impl SealedSegment { pub fn vbase<'a>( &'a self, - vector: S::VectorRef<'a>, + vector: Borrowed<'a, S>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box + 'a>) { @@ -85,7 +85,7 @@ impl SealedSegment { self.indexing.len() } - pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + pub fn vector(&self, i: u32) -> Borrowed<'_, S> { self.indexing.vector(i) } diff --git a/crates/service/src/instance/metadata.rs b/crates/service/src/instance/metadata.rs index eed22ec52..eb2124e17 100644 --- a/crates/service/src/instance/metadata.rs +++ b/crates/service/src/instance/metadata.rs @@ -18,7 +18,7 @@ pub struct Metadata { } impl Metadata { - const VERSION: u64 = 3; + const VERSION: u64 = 4; const SOFT_VERSION: u64 = 1; } diff --git a/crates/service/src/instance/mod.rs b/crates/service/src/instance/mod.rs index 3007458f4..da5b6f2c6 100644 --- a/crates/service/src/instance/mod.rs +++ b/crates/service/src/instance/mod.rs @@ -1,92 +1,74 @@ pub mod metadata; use crate::index::Index; -use crate::index::IndexOptions; -use crate::index::IndexStat; use crate::index::IndexView; use crate::index::OutdatedError; -use crate::index::SearchOptions; use crate::prelude::*; +use base::worker::*; use std::path::PathBuf; use std::sync::Arc; -pub trait InstanceViewOperations { - fn basic<'a, F: Fn(Pointer) -> bool + Clone + 'a>( - &'a self, - vector: &'a DynamicVector, - opts: &'a SearchOptions, - filter: F, - ) -> Result + 'a>, BasicError>; - fn vbase<'a, F: FnMut(Pointer) -> bool + Clone + 'a>( - &'a self, - vector: &'a DynamicVector, - opts: &'a SearchOptions, - filter: F, - ) -> Result + 'a>, VbaseError>; - fn list(&self) -> Result + '_>, ListError>; -} - #[derive(Clone)] pub enum Instance { - F32Cos(Arc>), - F32Dot(Arc>), - F32L2(Arc>), - F16Cos(Arc>), - F16Dot(Arc>), - F16L2(Arc>), - SparseF32L2(Arc>), - SparseF32Cos(Arc>), - SparseF32Dot(Arc>), + Vecf32Cos(Arc>), + Vecf32Dot(Arc>), + Vecf32L2(Arc>), + Vecf16Cos(Arc>), + Vecf16Dot(Arc>), + Vecf16L2(Arc>), + SVecf32L2(Arc>), + SVecf32Cos(Arc>), + SVecf32Dot(Arc>), Upgrade, } impl Instance { pub fn create(path: PathBuf, options: IndexOptions) -> Result { - match (options.vector.d, options.vector.k) { - (Distance::Cos, Kind::F32) => { + match (options.vector.d, options.vector.v) { + (DistanceKind::Cos, VectorKind::Vecf32) => { let index = Index::create(path.clone(), options)?; self::metadata::Metadata::write(path.join("metadata")); - Ok(Self::F32Cos(index)) + Ok(Self::Vecf32Cos(index)) } - (Distance::Dot, Kind::F32) => { + (DistanceKind::Dot, VectorKind::Vecf32) => { let index = Index::create(path.clone(), options)?; self::metadata::Metadata::write(path.join("metadata")); - Ok(Self::F32Dot(index)) + Ok(Self::Vecf32Dot(index)) } - (Distance::L2, Kind::F32) => { + (DistanceKind::L2, VectorKind::Vecf32) => { let index = Index::create(path.clone(), options)?; self::metadata::Metadata::write(path.join("metadata")); - Ok(Self::F32L2(index)) + Ok(Self::Vecf32L2(index)) } - (Distance::Cos, Kind::F16) => { + (DistanceKind::Cos, VectorKind::Vecf16) => { let index = Index::create(path.clone(), options)?; self::metadata::Metadata::write(path.join("metadata")); - Ok(Self::F16Cos(index)) + Ok(Self::Vecf16Cos(index)) } - (Distance::Dot, Kind::F16) => { + (DistanceKind::Dot, VectorKind::Vecf16) => { let index = Index::create(path.clone(), options)?; self::metadata::Metadata::write(path.join("metadata")); - Ok(Self::F16Dot(index)) + Ok(Self::Vecf16Dot(index)) } - (Distance::L2, Kind::F16) => { + (DistanceKind::L2, VectorKind::Vecf16) => { let index = Index::create(path.clone(), options)?; self::metadata::Metadata::write(path.join("metadata")); - Ok(Self::F16L2(index)) + Ok(Self::Vecf16L2(index)) } - (Distance::L2, Kind::SparseF32) => { + (DistanceKind::L2, VectorKind::SVecf32) => { let index = Index::create(path.clone(), options)?; self::metadata::Metadata::write(path.join("metadata")); - Ok(Self::SparseF32L2(index)) + Ok(Self::SVecf32L2(index)) } - (Distance::Cos, Kind::SparseF32) => { + (DistanceKind::Cos, VectorKind::SVecf32) => { let index = Index::create(path.clone(), options)?; self::metadata::Metadata::write(path.join("metadata")); - Ok(Self::SparseF32Cos(index)) + Ok(Self::SVecf32Cos(index)) } - (Distance::Dot, Kind::SparseF32) => { + (DistanceKind::Dot, VectorKind::SVecf32) => { let index = Index::create(path.clone(), options)?; self::metadata::Metadata::write(path.join("metadata")); - Ok(Self::SparseF32Dot(index)) + Ok(Self::SVecf32Dot(index)) } } } @@ -97,160 +79,170 @@ impl Instance { let options = serde_json::from_slice::(&std::fs::read(path.join("options")).unwrap()) .unwrap(); - match (options.vector.d, options.vector.k) { - (Distance::Cos, Kind::F32) => Self::F32Cos(Index::open(path)), - (Distance::Dot, Kind::F32) => Self::F32Dot(Index::open(path)), - (Distance::L2, Kind::F32) => Self::F32L2(Index::open(path)), - (Distance::Cos, Kind::F16) => Self::F16Cos(Index::open(path)), - (Distance::Dot, Kind::F16) => Self::F16Dot(Index::open(path)), - (Distance::L2, Kind::F16) => Self::F16L2(Index::open(path)), - (Distance::L2, Kind::SparseF32) => Self::SparseF32L2(Index::open(path)), - (Distance::Cos, Kind::SparseF32) => Self::SparseF32Cos(Index::open(path)), - (Distance::Dot, Kind::SparseF32) => Self::SparseF32Dot(Index::open(path)), + match (options.vector.d, options.vector.v) { + (DistanceKind::Cos, VectorKind::Vecf32) => Self::Vecf32Cos(Index::open(path)), + (DistanceKind::Dot, VectorKind::Vecf32) => Self::Vecf32Dot(Index::open(path)), + (DistanceKind::L2, VectorKind::Vecf32) => Self::Vecf32L2(Index::open(path)), + (DistanceKind::Cos, VectorKind::Vecf16) => Self::Vecf16Cos(Index::open(path)), + (DistanceKind::Dot, VectorKind::Vecf16) => Self::Vecf16Dot(Index::open(path)), + (DistanceKind::L2, VectorKind::Vecf16) => Self::Vecf16L2(Index::open(path)), + (DistanceKind::L2, VectorKind::SVecf32) => Self::SVecf32L2(Index::open(path)), + (DistanceKind::Cos, VectorKind::SVecf32) => Self::SVecf32Cos(Index::open(path)), + (DistanceKind::Dot, VectorKind::SVecf32) => Self::SVecf32Dot(Index::open(path)), } } pub fn refresh(&self) { match self { - Instance::F32Cos(x) => x.refresh(), - Instance::F32Dot(x) => x.refresh(), - Instance::F32L2(x) => x.refresh(), - Instance::F16Cos(x) => x.refresh(), - Instance::F16Dot(x) => x.refresh(), - Instance::F16L2(x) => x.refresh(), - Instance::SparseF32L2(x) => x.refresh(), - Instance::SparseF32Cos(x) => x.refresh(), - Instance::SparseF32Dot(x) => x.refresh(), + Instance::Vecf32Cos(x) => x.refresh(), + Instance::Vecf32Dot(x) => x.refresh(), + Instance::Vecf32L2(x) => x.refresh(), + Instance::Vecf16Cos(x) => x.refresh(), + Instance::Vecf16Dot(x) => x.refresh(), + Instance::Vecf16L2(x) => x.refresh(), + Instance::SVecf32L2(x) => x.refresh(), + Instance::SVecf32Cos(x) => x.refresh(), + Instance::SVecf32Dot(x) => x.refresh(), Instance::Upgrade => (), } } pub fn view(&self) -> Option { match self { - Instance::F32Cos(x) => Some(InstanceView::F32Cos(x.view())), - Instance::F32Dot(x) => Some(InstanceView::F32Dot(x.view())), - Instance::F32L2(x) => Some(InstanceView::F32L2(x.view())), - Instance::F16Cos(x) => Some(InstanceView::F16Cos(x.view())), - Instance::F16Dot(x) => Some(InstanceView::F16Dot(x.view())), - Instance::F16L2(x) => Some(InstanceView::F16L2(x.view())), - Instance::SparseF32L2(x) => Some(InstanceView::SparseF32L2(x.view())), - Instance::SparseF32Cos(x) => Some(InstanceView::SparseF32Cos(x.view())), - Instance::SparseF32Dot(x) => Some(InstanceView::SparseF32Dot(x.view())), + Instance::Vecf32Cos(x) => Some(InstanceView::Vecf32Cos(x.view())), + Instance::Vecf32Dot(x) => Some(InstanceView::Vecf32Dot(x.view())), + Instance::Vecf32L2(x) => Some(InstanceView::Vecf32L2(x.view())), + Instance::Vecf16Cos(x) => Some(InstanceView::Vecf16Cos(x.view())), + Instance::Vecf16Dot(x) => Some(InstanceView::Vecf16Dot(x.view())), + Instance::Vecf16L2(x) => Some(InstanceView::Vecf16L2(x.view())), + Instance::SVecf32L2(x) => Some(InstanceView::SVecf32L2(x.view())), + Instance::SVecf32Cos(x) => Some(InstanceView::SVecf32Cos(x.view())), + Instance::SVecf32Dot(x) => Some(InstanceView::SVecf32Dot(x.view())), Instance::Upgrade => None, } } pub fn stat(&self) -> Option { match self { - Instance::F32Cos(x) => Some(x.stat()), - Instance::F32Dot(x) => Some(x.stat()), - Instance::F32L2(x) => Some(x.stat()), - Instance::F16Cos(x) => Some(x.stat()), - Instance::F16Dot(x) => Some(x.stat()), - Instance::F16L2(x) => Some(x.stat()), - Instance::SparseF32L2(x) => Some(x.stat()), - Instance::SparseF32Cos(x) => Some(x.stat()), - Instance::SparseF32Dot(x) => Some(x.stat()), + Instance::Vecf32Cos(x) => Some(x.stat()), + Instance::Vecf32Dot(x) => Some(x.stat()), + Instance::Vecf32L2(x) => Some(x.stat()), + Instance::Vecf16Cos(x) => Some(x.stat()), + Instance::Vecf16Dot(x) => Some(x.stat()), + Instance::Vecf16L2(x) => Some(x.stat()), + Instance::SVecf32L2(x) => Some(x.stat()), + Instance::SVecf32Cos(x) => Some(x.stat()), + Instance::SVecf32Dot(x) => Some(x.stat()), Instance::Upgrade => None, } } } pub enum InstanceView { - F32Cos(Arc>), - F32Dot(Arc>), - F32L2(Arc>), - F16Cos(Arc>), - F16Dot(Arc>), - F16L2(Arc>), - SparseF32Cos(Arc>), - SparseF32Dot(Arc>), - SparseF32L2(Arc>), + Vecf32Cos(Arc>), + Vecf32Dot(Arc>), + Vecf32L2(Arc>), + Vecf16Cos(Arc>), + Vecf16Dot(Arc>), + Vecf16L2(Arc>), + SVecf32Cos(Arc>), + SVecf32Dot(Arc>), + SVecf32L2(Arc>), } -impl InstanceViewOperations for InstanceView { +impl ViewBasicOperations for InstanceView { fn basic<'a, F: Fn(Pointer) -> bool + Clone + 'a>( &'a self, - vector: &'a DynamicVector, + vector: &'a OwnedVector, opts: &'a SearchOptions, filter: F, ) -> Result + 'a>, BasicError> { match (self, vector) { - (InstanceView::F32Cos(x), DynamicVector::F32(vector)) => { - Ok(Box::new(x.basic(vector, opts, filter)?) as Box>) + (InstanceView::Vecf32Cos(x), OwnedVector::Vecf32(vector)) => { + Ok(Box::new(x.basic(vector.for_borrow(), opts, filter)?) + as Box>) } - (InstanceView::F32Dot(x), DynamicVector::F32(vector)) => { - Ok(Box::new(x.basic(vector, opts, filter)?)) + (InstanceView::Vecf32Dot(x), OwnedVector::Vecf32(vector)) => { + Ok(Box::new(x.basic(vector.for_borrow(), opts, filter)?)) } - (InstanceView::F32L2(x), DynamicVector::F32(vector)) => { - Ok(Box::new(x.basic(vector, opts, filter)?)) + (InstanceView::Vecf32L2(x), OwnedVector::Vecf32(vector)) => { + Ok(Box::new(x.basic(vector.for_borrow(), opts, filter)?)) } - (InstanceView::F16Cos(x), DynamicVector::F16(vector)) => { - Ok(Box::new(x.basic(vector, opts, filter)?)) + (InstanceView::Vecf16Cos(x), OwnedVector::Vecf16(vector)) => { + Ok(Box::new(x.basic(vector.for_borrow(), opts, filter)?)) } - (InstanceView::F16Dot(x), DynamicVector::F16(vector)) => { - Ok(Box::new(x.basic(vector, opts, filter)?)) + (InstanceView::Vecf16Dot(x), OwnedVector::Vecf16(vector)) => { + Ok(Box::new(x.basic(vector.for_borrow(), opts, filter)?)) } - (InstanceView::F16L2(x), DynamicVector::F16(vector)) => { - Ok(Box::new(x.basic(vector, opts, filter)?)) + (InstanceView::Vecf16L2(x), OwnedVector::Vecf16(vector)) => { + Ok(Box::new(x.basic(vector.for_borrow(), opts, filter)?)) } - (InstanceView::SparseF32Cos(x), DynamicVector::SparseF32(vector)) => { - Ok(Box::new(x.basic(vector.into(), opts, filter)?)) + (InstanceView::SVecf32Cos(x), OwnedVector::SVecF32(vector)) => { + Ok(Box::new(x.basic(vector.for_borrow(), opts, filter)?)) } - (InstanceView::SparseF32Dot(x), DynamicVector::SparseF32(vector)) => { - Ok(Box::new(x.basic(vector.into(), opts, filter)?)) + (InstanceView::SVecf32Dot(x), OwnedVector::SVecF32(vector)) => { + Ok(Box::new(x.basic(vector.for_borrow(), opts, filter)?)) } - (InstanceView::SparseF32L2(x), DynamicVector::SparseF32(vector)) => { - Ok(Box::new(x.basic(vector.into(), opts, filter)?)) + (InstanceView::SVecf32L2(x), OwnedVector::SVecF32(vector)) => { + Ok(Box::new(x.basic(vector.for_borrow(), opts, filter)?)) } _ => Err(BasicError::InvalidVector), } } +} + +impl ViewVbaseOperations for InstanceView { fn vbase<'a, F: FnMut(Pointer) -> bool + Clone + 'a>( &'a self, - vector: &'a DynamicVector, + vector: &'a OwnedVector, opts: &'a SearchOptions, filter: F, ) -> Result + 'a>, VbaseError> { match (self, vector) { - (InstanceView::F32Cos(x), DynamicVector::F32(vector)) => { - Ok(Box::new(x.vbase(vector, opts, filter)?) as Box>) + (InstanceView::Vecf32Cos(x), OwnedVector::Vecf32(vector)) => { + Ok(Box::new(x.vbase(vector.for_borrow(), opts, filter)?) + as Box>) } - (InstanceView::F32Dot(x), DynamicVector::F32(vector)) => { - Ok(Box::new(x.vbase(vector, opts, filter)?)) + (InstanceView::Vecf32Dot(x), OwnedVector::Vecf32(vector)) => { + Ok(Box::new(x.vbase(vector.for_borrow(), opts, filter)?)) } - (InstanceView::F32L2(x), DynamicVector::F32(vector)) => { - Ok(Box::new(x.vbase(vector, opts, filter)?)) + (InstanceView::Vecf32L2(x), OwnedVector::Vecf32(vector)) => { + Ok(Box::new(x.vbase(vector.for_borrow(), opts, filter)?)) } - (InstanceView::F16Cos(x), DynamicVector::F16(vector)) => { - Ok(Box::new(x.vbase(vector, opts, filter)?)) + (InstanceView::Vecf16Cos(x), OwnedVector::Vecf16(vector)) => { + Ok(Box::new(x.vbase(vector.for_borrow(), opts, filter)?)) } - (InstanceView::F16Dot(x), DynamicVector::F16(vector)) => { - Ok(Box::new(x.vbase(vector, opts, filter)?)) + (InstanceView::Vecf16Dot(x), OwnedVector::Vecf16(vector)) => { + Ok(Box::new(x.vbase(vector.for_borrow(), opts, filter)?)) } - (InstanceView::F16L2(x), DynamicVector::F16(vector)) => { - Ok(Box::new(x.vbase(vector, opts, filter)?)) + (InstanceView::Vecf16L2(x), OwnedVector::Vecf16(vector)) => { + Ok(Box::new(x.vbase(vector.for_borrow(), opts, filter)?)) } - (InstanceView::SparseF32Cos(x), DynamicVector::SparseF32(vector)) => { - Ok(Box::new(x.vbase(vector.into(), opts, filter)?)) + (InstanceView::SVecf32Cos(x), OwnedVector::SVecF32(vector)) => { + Ok(Box::new(x.vbase(vector.for_borrow(), opts, filter)?)) } - (InstanceView::SparseF32Dot(x), DynamicVector::SparseF32(vector)) => { - Ok(Box::new(x.vbase(vector.into(), opts, filter)?)) + (InstanceView::SVecf32Dot(x), OwnedVector::SVecF32(vector)) => { + Ok(Box::new(x.vbase(vector.for_borrow(), opts, filter)?)) } - (InstanceView::SparseF32L2(x), DynamicVector::SparseF32(vector)) => { - Ok(Box::new(x.vbase(vector.into(), opts, filter)?)) + (InstanceView::SVecf32L2(x), OwnedVector::SVecF32(vector)) => { + Ok(Box::new(x.vbase(vector.for_borrow(), opts, filter)?)) } _ => Err(VbaseError::InvalidVector), } } +} + +impl ViewListOperations for InstanceView { fn list(&self) -> Result + '_>, ListError> { match self { - InstanceView::F32Cos(x) => Ok(Box::new(x.list()?) as Box>), - InstanceView::F32Dot(x) => Ok(Box::new(x.list()?)), - InstanceView::F32L2(x) => Ok(Box::new(x.list()?)), - InstanceView::F16Cos(x) => Ok(Box::new(x.list()?)), - InstanceView::F16Dot(x) => Ok(Box::new(x.list()?)), - InstanceView::F16L2(x) => Ok(Box::new(x.list()?)), - InstanceView::SparseF32Cos(x) => Ok(Box::new(x.list()?)), - InstanceView::SparseF32Dot(x) => Ok(Box::new(x.list()?)), - InstanceView::SparseF32L2(x) => Ok(Box::new(x.list()?)), + InstanceView::Vecf32Cos(x) => { + Ok(Box::new(x.list()?) as Box>) + } + InstanceView::Vecf32Dot(x) => Ok(Box::new(x.list()?)), + InstanceView::Vecf32L2(x) => Ok(Box::new(x.list()?)), + InstanceView::Vecf16Cos(x) => Ok(Box::new(x.list()?)), + InstanceView::Vecf16Dot(x) => Ok(Box::new(x.list()?)), + InstanceView::Vecf16L2(x) => Ok(Box::new(x.list()?)), + InstanceView::SVecf32Cos(x) => Ok(Box::new(x.list()?)), + InstanceView::SVecf32Dot(x) => Ok(Box::new(x.list()?)), + InstanceView::SVecf32L2(x) => Ok(Box::new(x.list()?)), } } } @@ -258,52 +250,50 @@ impl InstanceViewOperations for InstanceView { impl InstanceView { pub fn insert( &self, - vector: DynamicVector, + vector: OwnedVector, pointer: Pointer, ) -> Result, InsertError> { match (self, vector) { - (InstanceView::F32Cos(x), DynamicVector::F32(vector)) => x.insert(vector, pointer), - (InstanceView::F32Dot(x), DynamicVector::F32(vector)) => x.insert(vector, pointer), - (InstanceView::F32L2(x), DynamicVector::F32(vector)) => x.insert(vector, pointer), - (InstanceView::F16Cos(x), DynamicVector::F16(vector)) => x.insert(vector, pointer), - (InstanceView::F16Dot(x), DynamicVector::F16(vector)) => x.insert(vector, pointer), - (InstanceView::F16L2(x), DynamicVector::F16(vector)) => x.insert(vector, pointer), - (InstanceView::SparseF32Cos(x), DynamicVector::SparseF32(vector)) => { - x.insert(vector, pointer) - } - (InstanceView::SparseF32Dot(x), DynamicVector::SparseF32(vector)) => { + (InstanceView::Vecf32Cos(x), OwnedVector::Vecf32(vector)) => x.insert(vector, pointer), + (InstanceView::Vecf32Dot(x), OwnedVector::Vecf32(vector)) => x.insert(vector, pointer), + (InstanceView::Vecf32L2(x), OwnedVector::Vecf32(vector)) => x.insert(vector, pointer), + (InstanceView::Vecf16Cos(x), OwnedVector::Vecf16(vector)) => x.insert(vector, pointer), + (InstanceView::Vecf16Dot(x), OwnedVector::Vecf16(vector)) => x.insert(vector, pointer), + (InstanceView::Vecf16L2(x), OwnedVector::Vecf16(vector)) => x.insert(vector, pointer), + (InstanceView::SVecf32Cos(x), OwnedVector::SVecF32(vector)) => { x.insert(vector, pointer) } - (InstanceView::SparseF32L2(x), DynamicVector::SparseF32(vector)) => { + (InstanceView::SVecf32Dot(x), OwnedVector::SVecF32(vector)) => { x.insert(vector, pointer) } + (InstanceView::SVecf32L2(x), OwnedVector::SVecF32(vector)) => x.insert(vector, pointer), _ => Err(InsertError::InvalidVector), } } pub fn delete(&self, pointer: Pointer) -> Result<(), DeleteError> { match self { - InstanceView::F32Cos(x) => x.delete(pointer), - InstanceView::F32Dot(x) => x.delete(pointer), - InstanceView::F32L2(x) => x.delete(pointer), - InstanceView::F16Cos(x) => x.delete(pointer), - InstanceView::F16Dot(x) => x.delete(pointer), - InstanceView::F16L2(x) => x.delete(pointer), - InstanceView::SparseF32Cos(x) => x.delete(pointer), - InstanceView::SparseF32Dot(x) => x.delete(pointer), - InstanceView::SparseF32L2(x) => x.delete(pointer), + InstanceView::Vecf32Cos(x) => x.delete(pointer), + InstanceView::Vecf32Dot(x) => x.delete(pointer), + InstanceView::Vecf32L2(x) => x.delete(pointer), + InstanceView::Vecf16Cos(x) => x.delete(pointer), + InstanceView::Vecf16Dot(x) => x.delete(pointer), + InstanceView::Vecf16L2(x) => x.delete(pointer), + InstanceView::SVecf32Cos(x) => x.delete(pointer), + InstanceView::SVecf32Dot(x) => x.delete(pointer), + InstanceView::SVecf32L2(x) => x.delete(pointer), } } pub fn flush(&self) -> Result<(), FlushError> { match self { - InstanceView::F32Cos(x) => x.flush(), - InstanceView::F32Dot(x) => x.flush(), - InstanceView::F32L2(x) => x.flush(), - InstanceView::F16Cos(x) => x.flush(), - InstanceView::F16Dot(x) => x.flush(), - InstanceView::F16L2(x) => x.flush(), - InstanceView::SparseF32Cos(x) => x.flush(), - InstanceView::SparseF32Dot(x) => x.flush(), - InstanceView::SparseF32L2(x) => x.flush(), + InstanceView::Vecf32Cos(x) => x.flush(), + InstanceView::Vecf32Dot(x) => x.flush(), + InstanceView::Vecf32L2(x) => x.flush(), + InstanceView::Vecf16Cos(x) => x.flush(), + InstanceView::Vecf16Dot(x) => x.flush(), + InstanceView::Vecf16L2(x) => x.flush(), + InstanceView::SVecf32Cos(x) => x.flush(), + InstanceView::SVecf32Dot(x) => x.flush(), + InstanceView::SVecf32L2(x) => x.flush(), } } } diff --git a/crates/service/src/lib.rs b/crates/service/src/lib.rs index 1ae36c6a9..f74e734c3 100644 --- a/crates/service/src/lib.rs +++ b/crates/service/src/lib.rs @@ -1,10 +1,13 @@ -#![feature(core_intrinsics)] -#![feature(avx512_target_feature)] - -pub mod algorithms; -pub mod index; -pub mod instance; -pub mod prelude; -pub mod worker; +#![allow(clippy::derivable_impls)] +#![allow(clippy::len_without_is_empty)] +#![allow(clippy::needless_range_loop)] +mod algorithms; +mod index; +mod instance; +mod prelude; +mod storage; mod utils; +mod worker; + +pub use worker::Worker; diff --git a/crates/service/src/prelude.rs b/crates/service/src/prelude.rs new file mode 100644 index 000000000..ae3856f56 --- /dev/null +++ b/crates/service/src/prelude.rs @@ -0,0 +1,33 @@ +pub use base::distance::*; +pub use base::error::*; +pub use base::global::*; +pub use base::index::*; +pub use base::scalar::*; +pub use base::search::*; +pub use base::vector::*; +pub use num_traits::{Float, Zero}; + +use crate::storage::GlobalStorage; + +pub trait G: + Global + GlobalElkanKMeans + GlobalProductQuantization + GlobalScalarQuantization + GlobalStorage +{ +} + +impl G for SVecf32Cos {} + +impl G for SVecf32Dot {} + +impl G for SVecf32L2 {} + +impl G for Vecf16Cos {} + +impl G for Vecf16Dot {} + +impl G for Vecf16L2 {} + +impl G for Vecf32Cos {} + +impl G for Vecf32Dot {} + +impl G for Vecf32L2 {} diff --git a/crates/service/src/prelude/global/mod.rs b/crates/service/src/prelude/global/mod.rs deleted file mode 100644 index 857847433..000000000 --- a/crates/service/src/prelude/global/mod.rs +++ /dev/null @@ -1,163 +0,0 @@ -mod f16; -mod f16_cos; -mod f16_dot; -mod f16_l2; -mod f32; -mod f32_cos; -mod f32_dot; -mod f32_l2; -mod sparse_f32; -mod sparse_f32_cos; -mod sparse_f32_dot; -mod sparse_f32_l2; - -pub use f16_cos::F16Cos; -pub use f16_dot::F16Dot; -pub use f16_l2::F16L2; -pub use f32_cos::F32Cos; -pub use f32_dot::F32Dot; -pub use f32_l2::F32L2; -pub use sparse_f32_cos::SparseF32Cos; -pub use sparse_f32_dot::SparseF32Dot; -pub use sparse_f32_l2::SparseF32L2; - -use crate::prelude::*; -use base::scalar::FloatCast; -use serde::{Deserialize, Serialize}; -use std::{ - borrow::Cow, - fmt::{Debug, Display}, -}; - -pub trait G: Copy + Debug + 'static { - type Scalar: Copy - + Send - + Sync - + Debug - + Display - + Serialize - + for<'a> Deserialize<'a> - + Ord - + bytemuck::Zeroable - + bytemuck::Pod - + Float - + Zero - + num_traits::NumOps - + num_traits::NumAssignOps - + FloatCast; - type Storage: for<'a> Storage = Self::VectorRef<'a>>; - type L2: for<'a> G = &'a [Self::Scalar]>; - type VectorOwned: Vector + Clone + Serialize + for<'a> Deserialize<'a>; - type VectorRef<'a>: Vector + Copy + 'a - where - Self: 'a; - - const DISTANCE: Distance; - const KIND: Kind; - - fn owned_to_ref(vector: &Self::VectorOwned) -> Self::VectorRef<'_>; - fn ref_to_owned(vector: Self::VectorRef<'_>) -> Self::VectorOwned; - fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [Self::Scalar]>; - fn distance(lhs: Self::VectorRef<'_>, rhs: Self::VectorRef<'_>) -> F32; - fn distance2(lhs: Self::VectorRef<'_>, rhs: &[Self::Scalar]) -> F32; - - fn elkan_k_means_normalize(vector: &mut [Self::Scalar]); - fn elkan_k_means_normalize2(vector: &mut Self::VectorOwned); - fn elkan_k_means_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32; - fn elkan_k_means_distance2(lhs: Self::VectorRef<'_>, rhs: &[Self::Scalar]) -> F32; - - fn scalar_quantization_distance( - dims: u16, - max: &[Self::Scalar], - min: &[Self::Scalar], - lhs: Self::VectorRef<'_>, - rhs: &[u8], - ) -> F32; - fn scalar_quantization_distance2( - dims: u16, - max: &[Self::Scalar], - min: &[Self::Scalar], - lhs: &[u8], - rhs: &[u8], - ) -> F32; - - fn product_quantization_distance( - dims: u16, - ratio: u16, - centroids: &[Self::Scalar], - lhs: Self::VectorRef<'_>, - rhs: &[u8], - ) -> F32; - fn product_quantization_distance2( - dims: u16, - ratio: u16, - centroids: &[Self::Scalar], - lhs: &[u8], - rhs: &[u8], - ) -> F32; - fn product_quantization_distance_with_delta( - dims: u16, - ratio: u16, - centroids: &[Self::Scalar], - lhs: Self::VectorRef<'_>, - rhs: &[u8], - delta: &[Self::Scalar], - ) -> F32; -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum DynamicVector { - F32(Vec), - F16(Vec), - SparseF32(SparseF32), -} - -impl From> for DynamicVector { - fn from(value: Vec) -> Self { - Self::F32(value) - } -} - -impl From> for DynamicVector { - fn from(value: Vec) -> Self { - Self::F16(value) - } -} - -impl From for DynamicVector { - fn from(value: SparseF32) -> Self { - Self::SparseF32(value) - } -} - -#[repr(u8)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] -pub enum Distance { - L2, - Cos, - Dot, -} - -#[repr(u8)] -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] -pub enum Kind { - F32, - F16, - SparseF32, -} - -pub fn squared_norm(dims: u16, vec: &[S::Scalar]) -> F32 { - let mut result = F32::zero(); - for i in 0..dims as usize { - result += F32((vec[i] * vec[i]).to_f32()); - } - result -} - -pub fn inner_product(dims: u16, lhs: &[S::Scalar], rhs: &[S::Scalar]) -> F32 { - let mut result = F32::zero(); - for i in 0..dims as usize { - result += F32((lhs[i] * rhs[i]).to_f32()); - } - result -} diff --git a/crates/service/src/prelude/global/sparse_f32_cos.rs b/crates/service/src/prelude/global/sparse_f32_cos.rs deleted file mode 100644 index 372f8d744..000000000 --- a/crates/service/src/prelude/global/sparse_f32_cos.rs +++ /dev/null @@ -1,104 +0,0 @@ -use std::borrow::Cow; - -use crate::prelude::*; - -#[derive(Debug, Clone, Copy)] -pub enum SparseF32Cos {} - -impl G for SparseF32Cos { - type Scalar = F32; - type Storage = SparseMmap; - type L2 = F32L2; - type VectorOwned = SparseF32; - type VectorRef<'a> = SparseF32Ref<'a>; - - const DISTANCE: Distance = Distance::Cos; - const KIND: Kind = Kind::SparseF32; - - fn owned_to_ref(vector: &SparseF32) -> SparseF32Ref<'_> { - SparseF32Ref::from(vector) - } - - fn ref_to_owned(vector: SparseF32Ref<'_>) -> SparseF32 { - SparseF32::from(vector) - } - - fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F32]> { - Cow::Owned(vector.to_dense()) - } - - fn distance(lhs: Self::VectorRef<'_>, rhs: Self::VectorRef<'_>) -> F32 { - F32(1.0) - super::sparse_f32::cosine(lhs, rhs) - } - - fn distance2(_lhs: Self::VectorRef<'_>, _rhs: &[Self::Scalar]) -> F32 { - unimplemented!() - } - - fn elkan_k_means_normalize(vector: &mut [Self::Scalar]) { - super::f32::l2_normalize(vector) - } - - fn elkan_k_means_normalize2(vector: &mut SparseF32) { - super::sparse_f32::l2_normalize(vector) - } - - fn elkan_k_means_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32 { - super::f32::dot(lhs, rhs).acos() - } - - fn elkan_k_means_distance2(lhs: Self::VectorRef<'_>, rhs: &[Self::Scalar]) -> F32 { - super::sparse_f32::dot_2(lhs, rhs).acos() - } - - fn scalar_quantization_distance( - _dims: u16, - _max: &[F32], - _min: &[F32], - _lhs: Self::VectorRef<'_>, - _rhs: &[u8], - ) -> F32 { - unimplemented!() - } - - fn scalar_quantization_distance2( - _dims: u16, - _max: &[Self::Scalar], - _min: &[Self::Scalar], - _lhs: &[u8], - _rhs: &[u8], - ) -> F32 { - unimplemented!() - } - - fn product_quantization_distance( - _dims: u16, - _ratio: u16, - _centroids: &[Self::Scalar], - _lhs: Self::VectorRef<'_>, - _rhs: &[u8], - ) -> F32 { - unimplemented!() - } - - fn product_quantization_distance2( - _dims: u16, - _ratio: u16, - _centroids: &[Self::Scalar], - _lhs: &[u8], - _rhs: &[u8], - ) -> F32 { - unimplemented!() - } - - fn product_quantization_distance_with_delta( - _dims: u16, - _ratio: u16, - _centroids: &[Self::Scalar], - _lhs: Self::VectorRef<'_>, - _rhs: &[u8], - _delta: &[Self::Scalar], - ) -> F32 { - unimplemented!() - } -} diff --git a/crates/service/src/prelude/global/sparse_f32_dot.rs b/crates/service/src/prelude/global/sparse_f32_dot.rs deleted file mode 100644 index 8c19bc987..000000000 --- a/crates/service/src/prelude/global/sparse_f32_dot.rs +++ /dev/null @@ -1,104 +0,0 @@ -use std::borrow::Cow; - -use crate::prelude::*; - -#[derive(Debug, Clone, Copy)] -pub enum SparseF32Dot {} - -impl G for SparseF32Dot { - type Scalar = F32; - type Storage = SparseMmap; - type L2 = F32L2; - type VectorOwned = SparseF32; - type VectorRef<'a> = SparseF32Ref<'a>; - - const DISTANCE: Distance = Distance::Dot; - const KIND: Kind = Kind::SparseF32; - - fn owned_to_ref(vector: &SparseF32) -> SparseF32Ref<'_> { - SparseF32Ref::from(vector) - } - - fn ref_to_owned(vector: SparseF32Ref<'_>) -> SparseF32 { - SparseF32::from(vector) - } - - fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F32]> { - Cow::Owned(vector.to_dense()) - } - - fn distance(lhs: Self::VectorRef<'_>, rhs: Self::VectorRef<'_>) -> F32 { - super::sparse_f32::dot(lhs, rhs) * (-1.0) - } - - fn distance2(_lhs: Self::VectorRef<'_>, _rhs: &[Self::Scalar]) -> F32 { - unimplemented!() - } - - fn elkan_k_means_normalize(vector: &mut [Self::Scalar]) { - super::f32::l2_normalize(vector) - } - - fn elkan_k_means_normalize2(vector: &mut SparseF32) { - super::sparse_f32::l2_normalize(vector) - } - - fn elkan_k_means_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32 { - super::f32::dot(lhs, rhs).acos() - } - - fn elkan_k_means_distance2(lhs: Self::VectorRef<'_>, rhs: &[Self::Scalar]) -> F32 { - super::sparse_f32::dot_2(lhs, rhs).acos() - } - - fn scalar_quantization_distance( - _dims: u16, - _max: &[Self::Scalar], - _min: &[Self::Scalar], - _lhs: Self::VectorRef<'_>, - _rhs: &[u8], - ) -> F32 { - unimplemented!() - } - - fn scalar_quantization_distance2( - _dims: u16, - _max: &[Self::Scalar], - _min: &[Self::Scalar], - _lhs: &[u8], - _rhs: &[u8], - ) -> F32 { - unimplemented!() - } - - fn product_quantization_distance( - _dims: u16, - _ratio: u16, - _centroids: &[Self::Scalar], - _lhs: Self::VectorRef<'_>, - _rhs: &[u8], - ) -> F32 { - unimplemented!() - } - - fn product_quantization_distance2( - _dims: u16, - _ratio: u16, - _centroids: &[Self::Scalar], - _lhs: &[u8], - _rhs: &[u8], - ) -> F32 { - unimplemented!() - } - - fn product_quantization_distance_with_delta( - _dims: u16, - _ratio: u16, - _centroids: &[Self::Scalar], - _lhs: Self::VectorRef<'_>, - _rhs: &[u8], - _delta: &[Self::Scalar], - ) -> F32 { - unimplemented!() - } -} diff --git a/crates/service/src/prelude/global/sparse_f32_l2.rs b/crates/service/src/prelude/global/sparse_f32_l2.rs deleted file mode 100644 index b17559e7b..000000000 --- a/crates/service/src/prelude/global/sparse_f32_l2.rs +++ /dev/null @@ -1,100 +0,0 @@ -use std::borrow::Cow; - -use crate::prelude::*; - -#[derive(Debug, Clone, Copy)] -pub enum SparseF32L2 {} - -impl G for SparseF32L2 { - type Scalar = F32; - type Storage = SparseMmap; - type L2 = F32L2; - type VectorOwned = SparseF32; - type VectorRef<'a> = SparseF32Ref<'a>; - - const DISTANCE: Distance = Distance::L2; - const KIND: Kind = Kind::SparseF32; - - fn owned_to_ref(vector: &SparseF32) -> SparseF32Ref<'_> { - SparseF32Ref::from(vector) - } - - fn ref_to_owned(vector: SparseF32Ref<'_>) -> SparseF32 { - SparseF32::from(vector) - } - - fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F32]> { - Cow::Owned(vector.to_dense()) - } - - fn distance(lhs: SparseF32Ref<'_>, rhs: SparseF32Ref<'_>) -> F32 { - super::sparse_f32::sl2(lhs, rhs) - } - - fn distance2(_lhs: Self::VectorRef<'_>, _rhs: &[Self::Scalar]) -> F32 { - unimplemented!() - } - - fn elkan_k_means_normalize(_: &mut [Self::Scalar]) {} - - fn elkan_k_means_normalize2(_: &mut SparseF32) {} - - fn elkan_k_means_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32 { - super::f32::sl2(lhs, rhs).sqrt() - } - - fn elkan_k_means_distance2(lhs: SparseF32Ref<'_>, rhs: &[Self::Scalar]) -> F32 { - super::sparse_f32::sl2_2(lhs, rhs).sqrt() - } - - fn scalar_quantization_distance( - _dims: u16, - _max: &[Self::Scalar], - _min: &[Self::Scalar], - _lhs: SparseF32Ref<'_>, - _rhs: &[u8], - ) -> F32 { - unimplemented!() - } - - fn scalar_quantization_distance2( - _dims: u16, - _max: &[Self::Scalar], - _min: &[Self::Scalar], - _lhs: &[u8], - _rhs: &[u8], - ) -> F32 { - unimplemented!() - } - - fn product_quantization_distance( - _dims: u16, - _ratio: u16, - _centroids: &[Self::Scalar], - _lhs: SparseF32Ref<'_>, - _rhs: &[u8], - ) -> F32 { - unimplemented!() - } - - fn product_quantization_distance2( - _dims: u16, - _ratio: u16, - _centroids: &[Self::Scalar], - _lhs: &[u8], - _rhs: &[u8], - ) -> F32 { - unimplemented!() - } - - fn product_quantization_distance_with_delta( - _dims: u16, - _ratio: u16, - _centroids: &[Self::Scalar], - _lhs: SparseF32Ref<'_>, - _rhs: &[u8], - _delta: &[Self::Scalar], - ) -> F32 { - unimplemented!() - } -} diff --git a/crates/service/src/prelude/mod.rs b/crates/service/src/prelude/mod.rs deleted file mode 100644 index a559acb7e..000000000 --- a/crates/service/src/prelude/mod.rs +++ /dev/null @@ -1,13 +0,0 @@ -mod global; -mod storage; - -pub use self::global::*; -pub use self::storage::{DenseMmap, SparseMmap, Storage}; - -pub use base::error::*; -pub use base::scalar::{F16, F32}; -pub use base::search::{Element, Filter, Payload}; -pub use base::sys::{Handle, Pointer}; -pub use base::vector::{SparseF32, SparseF32Ref, Vector}; - -pub use num_traits::{Float, Zero}; diff --git a/crates/service/src/prelude/storage/dense.rs b/crates/service/src/prelude/storage/dense.rs deleted file mode 100644 index 3cb85b35c..000000000 --- a/crates/service/src/prelude/storage/dense.rs +++ /dev/null @@ -1,65 +0,0 @@ -use crate::algorithms::raw::RawRam; -use crate::index::IndexOptions; -use crate::prelude::*; -use crate::utils::mmap_array::MmapArray; -use std::path::Path; - -pub struct DenseMmap { - vectors: MmapArray, - payload: MmapArray, - dims: u16, -} - -impl Storage for DenseMmap -where - T: Copy + bytemuck::Pod, -{ - type VectorRef<'a> = &'a [T]; - - fn dims(&self) -> u16 { - self.dims - } - - fn len(&self) -> u32 { - self.payload.len() as u32 - } - - fn vector(&self, i: u32) -> &[T] { - let s = i as usize * self.dims as usize; - let e = (i + 1) as usize * self.dims as usize; - &self.vectors[s..e] - } - - fn payload(&self, i: u32) -> Payload { - self.payload[i as usize] - } - - fn open(path: &Path, options: IndexOptions) -> Self - where - Self: Sized, - { - let vectors = MmapArray::open(&path.join("vectors")); - let payload = MmapArray::open(&path.join("payload")); - Self { - vectors, - payload, - dims: options.vector.dims, - } - } - - fn save G = Self::VectorRef<'a>>>( - path: &Path, - ram: RawRam, - ) -> Self { - let n = ram.len(); - let vectors_iter = (0..n).flat_map(|i| ram.vector(i).iter()).copied(); - let payload_iter = (0..n).map(|i| ram.payload(i)); - let vectors = MmapArray::create(&path.join("vectors"), vectors_iter); - let payload = MmapArray::create(&path.join("payload"), payload_iter); - Self { - vectors, - payload, - dims: ram.dims(), - } - } -} diff --git a/crates/service/src/prelude/storage/mod.rs b/crates/service/src/prelude/storage/mod.rs deleted file mode 100644 index 96c977a41..000000000 --- a/crates/service/src/prelude/storage/mod.rs +++ /dev/null @@ -1,24 +0,0 @@ -mod dense; -mod sparse; - -pub use dense::DenseMmap; -pub use sparse::SparseMmap; - -use crate::algorithms::raw::RawRam; -use crate::index::IndexOptions; -use crate::prelude::*; -use std::path::Path; - -pub trait Storage { - type VectorRef<'a>: Copy + 'a - where - Self: 'a; - - fn dims(&self) -> u16; - fn len(&self) -> u32; - fn vector(&self, i: u32) -> Self::VectorRef<'_>; - fn payload(&self, i: u32) -> Payload; - fn open(path: &Path, options: IndexOptions) -> Self; - fn save G = Self::VectorRef<'a>>>(path: &Path, ram: RawRam) - -> Self; -} diff --git a/crates/service/src/storage/dense.rs b/crates/service/src/storage/dense.rs new file mode 100644 index 000000000..91114a2d6 --- /dev/null +++ b/crates/service/src/storage/dense.rs @@ -0,0 +1,107 @@ +use crate::algorithms::raw::RawRam; +use crate::prelude::*; +use crate::storage::Storage; +use crate::utils::mmap_array::MmapArray; +use std::path::Path; + +pub struct DenseMmap { + vectors: MmapArray, + payload: MmapArray, + dims: u16, +} + +impl Storage for DenseMmap { + type VectorOwned = Vecf32Owned; + + fn dims(&self) -> u16 { + self.dims + } + + fn len(&self) -> u32 { + self.payload.len() as u32 + } + + fn vector(&self, i: u32) -> Vecf32Borrowed<'_> { + let s = i as usize * self.dims as usize; + let e = (i + 1) as usize * self.dims as usize; + Vecf32Borrowed::new(&self.vectors[s..e]) + } + + fn payload(&self, i: u32) -> Payload { + self.payload[i as usize] + } + + fn open(path: &Path, options: IndexOptions) -> Self + where + Self: Sized, + { + let vectors = MmapArray::open(&path.join("vectors")); + let payload = MmapArray::open(&path.join("payload")); + Self { + vectors, + payload, + dims: options.vector.dims, + } + } + + fn save>(path: &Path, ram: RawRam) -> Self { + let n = ram.len(); + let vectors_iter = (0..n).flat_map(|i| ram.vector(i).to_vec()); + let payload_iter = (0..n).map(|i| ram.payload(i)); + let vectors = MmapArray::create(&path.join("vectors"), vectors_iter); + let payload = MmapArray::create(&path.join("payload"), payload_iter); + Self { + vectors, + payload, + dims: ram.dims(), + } + } +} + +impl Storage for DenseMmap { + type VectorOwned = Vecf16Owned; + + fn dims(&self) -> u16 { + self.dims + } + + fn len(&self) -> u32 { + self.payload.len() as u32 + } + + fn vector(&self, i: u32) -> Vecf16Borrowed { + let s = i as usize * self.dims as usize; + let e = (i + 1) as usize * self.dims as usize; + Vecf16Borrowed::new(&self.vectors[s..e]) + } + + fn payload(&self, i: u32) -> Payload { + self.payload[i as usize] + } + + fn open(path: &Path, options: IndexOptions) -> Self + where + Self: Sized, + { + let vectors = MmapArray::open(&path.join("vectors")); + let payload = MmapArray::open(&path.join("payload")); + Self { + vectors, + payload, + dims: options.vector.dims, + } + } + + fn save>(path: &Path, ram: RawRam) -> Self { + let n = ram.len(); + let vectors_iter = (0..n).flat_map(|i| ram.vector(i).to_vec()); + let payload_iter = (0..n).map(|i| ram.payload(i)); + let vectors = MmapArray::create(&path.join("vectors"), vectors_iter); + let payload = MmapArray::create(&path.join("payload"), payload_iter); + Self { + vectors, + payload, + dims: ram.dims(), + } + } +} diff --git a/crates/service/src/storage/mod.rs b/crates/service/src/storage/mod.rs new file mode 100644 index 000000000..d6175b791 --- /dev/null +++ b/crates/service/src/storage/mod.rs @@ -0,0 +1,60 @@ +mod dense; +mod sparse; + +pub use dense::DenseMmap; +pub use sparse::SparseMmap; + +use crate::algorithms::raw::RawRam; +use crate::prelude::*; +use std::path::Path; + +pub trait Storage { + type VectorOwned: VectorOwned; + + fn dims(&self) -> u16; + fn len(&self) -> u32; + fn vector(&self, i: u32) -> ::Borrowed<'_>; + fn payload(&self, i: u32) -> Payload; + fn open(path: &Path, options: IndexOptions) -> Self; + fn save>(path: &Path, ram: RawRam) -> Self; +} + +pub trait GlobalStorage: Global { + type Storage: Storage; +} + +impl GlobalStorage for SVecf32Cos { + type Storage = SparseMmap; +} + +impl GlobalStorage for SVecf32Dot { + type Storage = SparseMmap; +} + +impl GlobalStorage for SVecf32L2 { + type Storage = SparseMmap; +} + +impl GlobalStorage for Vecf16Cos { + type Storage = DenseMmap; +} + +impl GlobalStorage for Vecf16Dot { + type Storage = DenseMmap; +} + +impl GlobalStorage for Vecf16L2 { + type Storage = DenseMmap; +} + +impl GlobalStorage for Vecf32Cos { + type Storage = DenseMmap; +} + +impl GlobalStorage for Vecf32Dot { + type Storage = DenseMmap; +} + +impl GlobalStorage for Vecf32L2 { + type Storage = DenseMmap; +} diff --git a/crates/service/src/prelude/storage/sparse.rs b/crates/service/src/storage/sparse.rs similarity index 79% rename from crates/service/src/prelude/storage/sparse.rs rename to crates/service/src/storage/sparse.rs index d95e27276..f7c642f86 100644 --- a/crates/service/src/prelude/storage/sparse.rs +++ b/crates/service/src/storage/sparse.rs @@ -1,6 +1,6 @@ use crate::algorithms::raw::RawRam; -use crate::index::IndexOptions; use crate::prelude::*; +use crate::storage::Storage; use crate::utils::mmap_array::MmapArray; use std::path::Path; @@ -13,7 +13,7 @@ pub struct SparseMmap { } impl Storage for SparseMmap { - type VectorRef<'a> = SparseF32Ref<'a>; + type VectorOwned = SVecf32Owned; fn dims(&self) -> u16 { self.dims @@ -23,13 +23,11 @@ impl Storage for SparseMmap { self.payload.len() as u32 } - fn vector(&self, i: u32) -> SparseF32Ref<'_> { + fn vector(&self, i: u32) -> SVecf32Borrowed<'_> { let s = self.offsets[i as usize]; let e = self.offsets[i as usize + 1]; - SparseF32Ref { - dims: self.dims, - indexes: &self.indexes[s..e], - values: &self.values[s..e], + unsafe { + SVecf32Borrowed::new_unchecked(self.dims, &self.indexes[s..e], &self.values[s..e]) } } @@ -54,15 +52,12 @@ impl Storage for SparseMmap { } } - fn save G = Self::VectorRef<'a>>>( - path: &Path, - ram: RawRam, - ) -> Self { + fn save>(path: &Path, ram: RawRam) -> Self { let n = ram.len(); - let indexes_iter = (0..n).flat_map(|i| ram.vector(i).indexes.iter().copied()); - let values_iter = (0..n).flat_map(|i| ram.vector(i).values.iter().copied()); + let indexes_iter = (0..n).flat_map(|i| ram.vector(i).indexes().to_vec()); + let values_iter = (0..n).flat_map(|i| ram.vector(i).values().to_vec()); let offsets_iter = std::iter::once(0) - .chain((0..n).map(|i| ram.vector(i).length() as usize)) + .chain((0..n).map(|i| ram.vector(i).len() as usize)) .scan(0, |state, x| { *state += x; Some(*state) diff --git a/crates/service/src/worker/mod.rs b/crates/service/src/worker/mod.rs index 3ccc77862..104b66740 100644 --- a/crates/service/src/worker/mod.rs +++ b/crates/service/src/worker/mod.rs @@ -1,37 +1,18 @@ pub mod metadata; -use crate::index::{IndexOptions, IndexStat}; -use crate::instance::{Instance, InstanceView, InstanceViewOperations}; +use crate::instance::*; use crate::prelude::*; use crate::utils::clean::clean; use crate::utils::dir_ops::sync_dir; use crate::utils::file_atomic::FileAtomic; use arc_swap::ArcSwap; +use base::worker::*; use parking_lot::Mutex; use serde::{Deserialize, Serialize}; use std::collections::{HashMap, HashSet}; use std::path::PathBuf; use std::sync::Arc; -pub trait WorkerOperations { - type InstanceView: InstanceViewOperations; - - fn create(&self, handle: Handle, options: IndexOptions) -> Result<(), CreateError>; - fn drop(&self, handle: Handle) -> Result<(), DropError>; - fn flush(&self, handle: Handle) -> Result<(), FlushError>; - fn insert( - &self, - handle: Handle, - vector: DynamicVector, - pointer: Pointer, - ) -> Result<(), InsertError>; - fn delete(&self, handle: Handle, pointer: Pointer) -> Result<(), DeleteError>; - fn basic_view(&self, handle: Handle) -> Result; - fn vbase_view(&self, handle: Handle) -> Result; - fn list_view(&self, handle: Handle) -> Result; - fn stat(&self, handle: Handle) -> Result; -} - pub struct Worker { path: PathBuf, protect: Mutex, @@ -81,14 +62,12 @@ impl Worker { view: ArcSwap::new(view), }) } - pub fn view(&self) -> Arc { + fn view(&self) -> Arc { self.view.load_full() } } impl WorkerOperations for Worker { - type InstanceView = InstanceView; - fn create(&self, handle: Handle, options: IndexOptions) -> Result<(), CreateError> { use std::collections::hash_map::Entry; let mut protect = self.protect.lock(); @@ -122,7 +101,7 @@ impl WorkerOperations for Worker { fn insert( &self, handle: Handle, - vector: DynamicVector, + vector: OwnedVector, pointer: Pointer, ) -> Result<(), InsertError> { let view = self.view(); @@ -143,17 +122,17 @@ impl WorkerOperations for Worker { view.delete(pointer)?; Ok(()) } - fn basic_view(&self, handle: Handle) -> Result { + fn view_basic(&self, handle: Handle) -> Result { let view = self.view(); let instance = view.get(handle).ok_or(BasicError::NotExist)?; instance.view().ok_or(BasicError::Upgrade) } - fn vbase_view(&self, handle: Handle) -> Result { + fn view_vbase(&self, handle: Handle) -> Result { let view = self.view(); let instance = view.get(handle).ok_or(VbaseError::NotExist)?; instance.view().ok_or(VbaseError::Upgrade) } - fn list_view(&self, handle: Handle) -> Result { + fn view_list(&self, handle: Handle) -> Result { let view = self.view(); let instance = view.get(handle).ok_or(ListError::NotExist)?; instance.view().ok_or(ListError::Upgrade) diff --git a/src/bgworker/mod.rs b/src/bgworker/mod.rs index 4e98a3908..bc903b3fc 100644 --- a/src/bgworker/mod.rs +++ b/src/bgworker/mod.rs @@ -5,7 +5,7 @@ use std::sync::atomic::{AtomicBool, Ordering}; static STARTED: AtomicBool = AtomicBool::new(false); pub unsafe fn init() { - use service::worker::Worker; + use service::Worker; let path = std::path::Path::new("pg_vectors"); if !path.try_exists().unwrap() || Worker::check(path.to_owned()) { use pgrx::bgworkers::BackgroundWorkerBuilder; @@ -65,7 +65,7 @@ extern "C" fn _vectors_main(_arg: pgrx::pg_sys::Datum) { std::alloc::set_alloc_error_hook(|layout| { std::panic::panic_any(AllocErrorPanicPayload { layout }); }); - use service::worker::Worker; + use service::Worker; use std::path::Path; let path = Path::new("pg_vectors"); if path.try_exists().unwrap() { diff --git a/src/bgworker/normal.rs b/src/bgworker/normal.rs index c96fbc71b..7979f9449 100644 --- a/src/bgworker/normal.rs +++ b/src/bgworker/normal.rs @@ -1,6 +1,7 @@ use crate::ipc::ConnectionError; use crate::ipc::ServerRpcHandler; -use service::worker::Worker; +use service::Worker; +use std::convert::Infallible; use std::sync::Arc; pub fn normal(worker: Arc) { @@ -57,10 +58,9 @@ pub fn normal(worker: Arc) { }); } -fn session(worker: Arc, handler: ServerRpcHandler) -> Result { +fn session(worker: Arc, handler: ServerRpcHandler) -> Result { use crate::ipc::ServerRpcHandle; - use service::instance::InstanceViewOperations; - use service::worker::WorkerOperations; + use base::worker::*; let mut handler = handler; loop { match handler.handle()? { @@ -95,7 +95,7 @@ fn session(worker: Arc, handler: ServerRpcHandler) -> Result { - let v = match worker.basic_view(handle) { + let v = match worker.view_basic(handle) { Ok(x) => x, Err(e) => { handler = x.error_err(e)?; @@ -127,7 +127,7 @@ fn session(worker: Arc, handler: ServerRpcHandler) -> Result { - let v = match worker.vbase_view(handle) { + let v = match worker.view_vbase(handle) { Ok(x) => x, Err(e) => { handler = x.error_err(e)?; @@ -154,7 +154,7 @@ fn session(worker: Arc, handler: ServerRpcHandler) -> Result { - let v = match worker.list_view(handle) { + let v = match worker.view_list(handle) { Ok(x) => x, Err(e) => { handler = x.error_err(e)?; diff --git a/src/datatype/binary_svecf32.rs b/src/datatype/binary_svecf32.rs new file mode 100644 index 000000000..a6bbedc4d --- /dev/null +++ b/src/datatype/binary_svecf32.rs @@ -0,0 +1,59 @@ +use super::memory_svecf32::SVecf32Input; +use super::memory_svecf32::SVecf32Output; +use base::scalar::F32; +use base::vector::SVecf32Borrowed; +use pgrx::datum::IntoDatum; +use pgrx::pg_sys::Datum; +use pgrx::pg_sys::Oid; +use std::ffi::c_char; + +#[pgrx::pg_extern(sql = "\ +CREATE FUNCTION _vectors_svecf32_send(svector) RETURNS bytea +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] +fn _vectors_svecf32_send(vector: SVecf32Input<'_>) -> Datum { + use pgrx::pg_sys::StringInfoData; + unsafe { + let mut buf = StringInfoData::default(); + let dims = vector.dims() as u16; + let len = vector.len() as u16; + let x = vector.for_borrow(); + let b_indexes = std::mem::size_of::() * len as usize; + let b_values = std::mem::size_of::() * len as usize; + pgrx::pg_sys::pq_begintypsend(&mut buf); + pgrx::pg_sys::pq_sendbytes(&mut buf, (&dims) as *const u16 as _, 2); + pgrx::pg_sys::pq_sendbytes(&mut buf, (&len) as *const u16 as _, 2); + pgrx::pg_sys::pq_sendbytes(&mut buf, x.indexes().as_ptr() as _, b_indexes as _); + pgrx::pg_sys::pq_sendbytes(&mut buf, x.values().as_ptr() as _, b_values as _); + Datum::from(pgrx::pg_sys::pq_endtypsend(&mut buf)) + } +} + +#[pgrx::pg_extern(sql = " +CREATE FUNCTION _vectors_svecf32_recv(internal, oid, integer) RETURNS svector +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] +fn _vectors_svecf32_recv(internal: pgrx::Internal, _oid: Oid, _typmod: i32) -> SVecf32Output { + use pgrx::pg_sys::StringInfo; + unsafe { + let buf: StringInfo = internal.into_datum().unwrap().cast_mut_ptr(); + let dims = (pgrx::pg_sys::pq_getmsgbytes(buf, 2) as *const u16).read_unaligned(); + let len = (pgrx::pg_sys::pq_getmsgbytes(buf, 2) as *const u16).read_unaligned(); + + let b_indexes = std::mem::size_of::() * len as usize; + let p_indexes = pgrx::pg_sys::pq_getmsgbytes(buf, b_indexes as _); + let mut indexes = Vec::::with_capacity(len as usize); + std::ptr::copy(p_indexes, indexes.as_mut_ptr().cast::(), b_indexes); + indexes.set_len(len as usize); + + let b_values = std::mem::size_of::() * len as usize; + let p_values = pgrx::pg_sys::pq_getmsgbytes(buf, b_values as _); + let mut values = Vec::::with_capacity(len as usize); + std::ptr::copy(p_values, values.as_mut_ptr().cast::(), b_values); + values.set_len(len as usize); + + if let Some(x) = SVecf32Borrowed::new_checked(dims, &indexes, &values) { + SVecf32Output::new(x) + } else { + pgrx::error!("detect data corruption"); + } + } +} diff --git a/src/datatype/binary_vecf16.rs b/src/datatype/binary_vecf16.rs new file mode 100644 index 000000000..3c7aa7efa --- /dev/null +++ b/src/datatype/binary_vecf16.rs @@ -0,0 +1,45 @@ +use super::memory_vecf16::{Vecf16Input, Vecf16Output}; +use base::scalar::F16; +use base::vector::Vecf16Borrowed; +use pgrx::datum::IntoDatum; +use pgrx::pg_sys::{Datum, Oid}; +use std::ffi::c_char; + +#[pgrx::pg_extern(sql = "\ +CREATE FUNCTION _vectors_vecf16_send(vecf16) RETURNS bytea +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] +fn _vectors_vecf16_send(vector: Vecf16Input<'_>) -> Datum { + use pgrx::pg_sys::StringInfoData; + unsafe { + let mut buf = StringInfoData::default(); + let dims = vector.dims() as u16; + let b_slice = std::mem::size_of::() * dims as usize; + pgrx::pg_sys::pq_begintypsend(&mut buf); + pgrx::pg_sys::pq_sendbytes(&mut buf, (&dims) as *const u16 as _, 2); + pgrx::pg_sys::pq_sendbytes(&mut buf, vector.slice().as_ptr() as _, b_slice as _); + Datum::from(pgrx::pg_sys::pq_endtypsend(&mut buf)) + } +} + +#[pgrx::pg_extern(sql = "\ +CREATE FUNCTION _vectors_vecf16_recv(internal, oid, integer) RETURNS vecf16 +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] +fn _vectors_vecf16_recv(internal: pgrx::Internal, _oid: Oid, _typmod: i32) -> Vecf16Output { + use pgrx::pg_sys::StringInfo; + unsafe { + let buf: StringInfo = internal.into_datum().unwrap().cast_mut_ptr(); + let dims = (pgrx::pg_sys::pq_getmsgbytes(buf, 2) as *const u16).read_unaligned(); + + let b_slice = std::mem::size_of::() * dims as usize; + let p_slice = pgrx::pg_sys::pq_getmsgbytes(buf, b_slice as _); + let mut slice = Vec::::with_capacity(dims as usize); + std::ptr::copy(p_slice, slice.as_mut_ptr().cast::(), b_slice); + slice.set_len(dims as usize); + + if let Some(x) = Vecf16Borrowed::new_checked(&slice) { + Vecf16Output::new(x) + } else { + pgrx::error!("detect data corruption"); + } + } +} diff --git a/src/datatype/binary_vecf32.rs b/src/datatype/binary_vecf32.rs new file mode 100644 index 000000000..4af913f3b --- /dev/null +++ b/src/datatype/binary_vecf32.rs @@ -0,0 +1,45 @@ +use super::memory_vecf32::{Vecf32Input, Vecf32Output}; +use base::scalar::F32; +use base::vector::Vecf32Borrowed; +use pgrx::datum::IntoDatum; +use pgrx::pg_sys::{Datum, Oid}; +use std::ffi::c_char; + +#[pgrx::pg_extern(sql = "\ +CREATE FUNCTION _vectors_vecf32_send(vector) RETURNS bytea +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] +fn _vectors_vecf32_send(vector: Vecf32Input<'_>) -> Datum { + use pgrx::pg_sys::StringInfoData; + unsafe { + let mut buf = StringInfoData::default(); + let dims = vector.dims() as u16; + let b_slice = std::mem::size_of::() * dims as usize; + pgrx::pg_sys::pq_begintypsend(&mut buf); + pgrx::pg_sys::pq_sendbytes(&mut buf, (&dims) as *const u16 as _, 2); + pgrx::pg_sys::pq_sendbytes(&mut buf, vector.slice().as_ptr() as _, b_slice as _); + Datum::from(pgrx::pg_sys::pq_endtypsend(&mut buf)) + } +} + +#[pgrx::pg_extern(sql = " +CREATE FUNCTION _vectors_vecf32_recv(internal, oid, integer) RETURNS vector +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] +fn _vectors_vecf32_recv(internal: pgrx::Internal, _oid: Oid, _typmod: i32) -> Vecf32Output { + use pgrx::pg_sys::StringInfo; + unsafe { + let buf: StringInfo = internal.into_datum().unwrap().cast_mut_ptr(); + let dims = (pgrx::pg_sys::pq_getmsgbytes(buf, 2) as *const u16).read_unaligned(); + + let b_slice = std::mem::size_of::() * dims as usize; + let p_slice = pgrx::pg_sys::pq_getmsgbytes(buf, b_slice as _); + let mut slice = Vec::::with_capacity(dims as usize); + std::ptr::copy(p_slice, slice.as_mut_ptr().cast::(), b_slice); + slice.set_len(dims as usize); + + if let Some(x) = Vecf32Borrowed::new_checked(&slice) { + Vecf32Output::new(x) + } else { + pgrx::error!("detect data corruption"); + } + } +} diff --git a/src/datatype/casts_f32.rs b/src/datatype/casts.rs similarity index 58% rename from src/datatype/casts_f32.rs rename to src/datatype/casts.rs index 123434ca7..21bd7d31a 100644 --- a/src/datatype/casts_f32.rs +++ b/src/datatype/casts.rs @@ -1,9 +1,7 @@ -use crate::datatype::svecf32::{SVecf32, SVecf32Input, SVecf32Output}; -use crate::datatype::vecf16::{Vecf16, Vecf16Input, Vecf16Output}; -use crate::datatype::vecf32::{Vecf32, Vecf32Input, Vecf32Output}; -use crate::prelude::check_value_dimensions; -use base::scalar::FloatCast; -use service::prelude::*; +use crate::datatype::memory_svecf32::{SVecf32Input, SVecf32Output}; +use crate::datatype::memory_vecf16::{Vecf16Input, Vecf16Output}; +use crate::datatype::memory_vecf32::{Vecf32Input, Vecf32Output}; +use crate::prelude::*; #[pgrx::pg_extern(immutable, parallel_safe, strict)] fn _vectors_cast_array_to_vecf32( @@ -11,12 +9,12 @@ fn _vectors_cast_array_to_vecf32( _typmod: i32, _explicit: bool, ) -> Vecf32Output { - check_value_dimensions(array.len()); - let mut data = vec![F32::zero(); array.len()]; + check_value_dims(array.len()); + let mut slice = vec![F32::zero(); array.len()]; for (i, x) in array.iter().enumerate() { - data[i] = F32(x.unwrap_or(f32::NAN)); + slice[i] = F32(x.unwrap_or(f32::NAN)); } - Vecf32::new_in_postgres(&data) + Vecf32Output::new(Vecf32Borrowed::new(&slice)) } #[pgrx::pg_extern(immutable, parallel_safe, strict)] @@ -25,7 +23,7 @@ fn _vectors_cast_vecf32_to_array( _typmod: i32, _explicit: bool, ) -> Vec { - vector.data().iter().map(|x| x.to_f32()).collect() + vector.slice().iter().map(|x| x.to_f32()).collect() } #[pgrx::pg_extern(immutable, parallel_safe, strict)] @@ -34,9 +32,9 @@ fn _vectors_cast_vecf32_to_vecf16( _typmod: i32, _explicit: bool, ) -> Vecf16Output { - let data: Vec = vector.data().iter().map(|&x| F16::from_f(x)).collect(); + let slice: Vec = vector.slice().iter().map(|&x| F16::from_f(x)).collect(); - Vecf16::new_in_postgres(&data) + Vecf16Output::new(Vecf16Borrowed::new(&slice)) } #[pgrx::pg_extern(immutable, parallel_safe, strict)] @@ -45,9 +43,9 @@ fn _vectors_cast_vecf16_to_vecf32( _typmod: i32, _explicit: bool, ) -> Vecf32Output { - let data: Vec = vector.data().iter().map(|&x| x.to_f()).collect(); + let slice: Vec = vector.slice().iter().map(|&x| x.to_f()).collect(); - Vecf32::new_in_postgres(&data) + Vecf32Output::new(Vecf32Borrowed::new(&slice)) } #[pgrx::pg_extern(immutable, parallel_safe, strict)] @@ -59,7 +57,7 @@ fn _vectors_cast_vecf32_to_svecf32( let mut indexes = Vec::new(); let mut values = Vec::new(); vector - .data() + .slice() .iter() .enumerate() .filter(|(_, x)| !x.is_zero()) @@ -68,11 +66,11 @@ fn _vectors_cast_vecf32_to_svecf32( values.push(x); }); - SVecf32::new_in_postgres(SparseF32Ref { - dims: vector.len() as u16, - indexes: &indexes, - values: &values, - }) + SVecf32Output::new(SVecf32Borrowed::new( + vector.dims() as u16, + &indexes, + &values, + )) } #[pgrx::pg_extern(immutable, parallel_safe, strict)] @@ -81,6 +79,7 @@ fn _vectors_cast_svecf32_to_vecf32( _typmod: i32, _explicit: bool, ) -> Vecf32Output { - let data = vector.data().to_dense(); - Vecf32::new_in_postgres(&data) + let slice = vector.for_borrow().to_vec(); + + Vecf32Output::new(Vecf32Borrowed::new(&slice)) } diff --git a/src/datatype/functions.rs b/src/datatype/functions.rs new file mode 100644 index 000000000..d32a18f8a --- /dev/null +++ b/src/datatype/functions.rs @@ -0,0 +1,45 @@ +use super::memory_svecf32::SVecf32Output; +use crate::prelude::*; +use base::scalar::F32; +use base::vector::SVecf32Borrowed; + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_to_svector( + dims: i32, + index: pgrx::Array, + value: pgrx::Array, +) -> SVecf32Output { + let dims = check_value_dims(dims as usize); + if index.len() != value.len() { + bad_literal("Lengths of index and value are not matched."); + } + if index.contains_nulls() || value.contains_nulls() { + bad_literal("Index or value contains nulls."); + } + let mut vector: Vec<(u16, F32)> = index + .iter_deny_null() + .zip(value.iter_deny_null()) + .map(|(index, value)| { + if index < 0 || index >= dims.get() as i32 { + bad_literal("Index out of bound."); + } + (index as u16, F32(value)) + }) + .collect(); + vector.sort_unstable_by_key(|x| x.0); + if vector.len() > 1 { + for i in 0..vector.len() - 1 { + if vector[i].0 == vector[i + 1].0 { + bad_literal("Duplicated index."); + } + } + } + + let mut indexes = Vec::::with_capacity(vector.len()); + let mut values = Vec::::with_capacity(vector.len()); + for x in vector { + indexes.push(x.0); + values.push(x.1); + } + SVecf32Output::new(SVecf32Borrowed::new(dims.get(), &indexes, &values)) +} diff --git a/src/datatype/memory_svecf32.rs b/src/datatype/memory_svecf32.rs new file mode 100644 index 000000000..a3f79579c --- /dev/null +++ b/src/datatype/memory_svecf32.rs @@ -0,0 +1,181 @@ +use crate::prelude::*; +use pgrx::pg_sys::Datum; +use pgrx::pg_sys::Oid; +use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; +use pgrx::pgrx_sql_entity_graph::metadata::Returns; +use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; +use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; +use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; +use pgrx::FromDatum; +use pgrx::IntoDatum; +use std::alloc::Layout; +use std::ops::Deref; +use std::ptr::NonNull; + +#[repr(C, align(8))] +pub struct SVecf32Header { + varlena: u32, + dims: u16, + kind: u16, + len: u16, + reserved: [u8; 6], + phantom: [u8; 0], +} + +impl SVecf32Header { + fn varlena(size: usize) -> u32 { + (size << 2) as u32 + } + fn layout(len: usize) -> Layout { + u16::try_from(len).expect("Vector is too large."); + let layout = Layout::new::(); + let layout1 = Layout::array::(len).unwrap(); + let layout2 = Layout::array::(len).unwrap(); + let layout = layout.extend(layout1).unwrap().0.pad_to_align(); + layout.extend(layout2).unwrap().0.pad_to_align() + } + pub fn dims(&self) -> usize { + self.dims as usize + } + pub fn len(&self) -> usize { + self.len as usize + } + fn indexes(&self) -> &[u16] { + let ptr = self.phantom.as_ptr().cast(); + unsafe { std::slice::from_raw_parts(ptr, self.len as usize) } + } + fn values(&self) -> &[F32] { + let len = self.len as usize; + unsafe { + let ptr = self.phantom.as_ptr().cast::().add(len); + let offset = ptr.align_offset(8); + let ptr = ptr.add(offset).cast(); + std::slice::from_raw_parts(ptr, len) + } + } + pub fn for_borrow(&self) -> SVecf32Borrowed<'_> { + unsafe { SVecf32Borrowed::new_unchecked(self.dims, self.indexes(), self.values()) } + } +} + +pub enum SVecf32Input<'a> { + Owned(SVecf32Output), + Borrowed(&'a SVecf32Header), +} + +impl<'a> SVecf32Input<'a> { + unsafe fn new(p: NonNull) -> Self { + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() + }; + if p != q { + SVecf32Input::Owned(SVecf32Output(q)) + } else { + unsafe { SVecf32Input::Borrowed(p.as_ref()) } + } + } +} + +impl Deref for SVecf32Input<'_> { + type Target = SVecf32Header; + + fn deref(&self) -> &Self::Target { + match self { + SVecf32Input::Owned(x) => x, + SVecf32Input::Borrowed(x) => x, + } + } +} + +pub struct SVecf32Output(NonNull); + +impl SVecf32Output { + pub fn new(vector: SVecf32Borrowed<'_>) -> SVecf32Output { + unsafe { + let layout = SVecf32Header::layout(vector.len() as usize); + let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut SVecf32Header; + ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); + std::ptr::addr_of_mut!((*ptr).varlena).write(SVecf32Header::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).dims).write(vector.dims()); + std::ptr::addr_of_mut!((*ptr).kind).write(2); + std::ptr::addr_of_mut!((*ptr).len).write(vector.len()); + std::ptr::addr_of_mut!((*ptr).reserved).write([0; 6]); + let mut data_ptr = (*ptr).phantom.as_mut_ptr().cast::(); + std::ptr::copy_nonoverlapping( + vector.indexes().as_ptr(), + data_ptr, + vector.len() as usize, + ); + data_ptr = data_ptr.add(vector.len() as usize); + let offset = data_ptr.align_offset(8); + std::ptr::write_bytes(data_ptr, 0, offset); + data_ptr = data_ptr.add(offset); + std::ptr::copy_nonoverlapping( + vector.values().as_ptr(), + data_ptr.cast(), + vector.len() as usize, + ); + SVecf32Output(NonNull::new(ptr).unwrap()) + } + } + pub fn into_raw(self) -> *mut SVecf32Header { + let result = self.0.as_ptr(); + std::mem::forget(self); + result + } +} + +impl Deref for SVecf32Output { + type Target = SVecf32Header; + + fn deref(&self) -> &Self::Target { + unsafe { self.0.as_ref() } + } +} + +impl Drop for SVecf32Output { + fn drop(&mut self) { + unsafe { + pgrx::pg_sys::pfree(self.0.as_ptr() as _); + } + } +} + +impl<'a> FromDatum for SVecf32Input<'a> { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); + unsafe { Some(SVecf32Input::new(ptr)) } + } + } +} + +impl IntoDatum for SVecf32Output { + fn into_datum(self) -> Option { + Some(Datum::from(self.into_raw() as *mut ())) + } + + fn type_oid() -> Oid { + pgrx::wrappers::regtypein("vectors.svector") + } +} + +unsafe impl SqlTranslatable for SVecf32Input<'_> { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("svector"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("svector")))) + } +} + +unsafe impl SqlTranslatable for SVecf32Output { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("svector"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("svector")))) + } +} diff --git a/src/datatype/memory_vecf16.rs b/src/datatype/memory_vecf16.rs new file mode 100644 index 000000000..a8d8618c4 --- /dev/null +++ b/src/datatype/memory_vecf16.rs @@ -0,0 +1,162 @@ +use crate::prelude::*; +use pgrx::pg_sys::Datum; +use pgrx::pg_sys::Oid; +use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; +use pgrx::pgrx_sql_entity_graph::metadata::Returns; +use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; +use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; +use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; +use pgrx::FromDatum; +use pgrx::IntoDatum; +use std::alloc::Layout; +use std::ops::Deref; +use std::ptr::NonNull; + +#[repr(C, align(8))] +pub struct Vecf16Header { + varlena: u32, + dims: u16, + kind: u16, + phantom: [F16; 0], +} + +impl Vecf16Header { + fn varlena(size: usize) -> u32 { + (size << 2) as u32 + } + fn layout(len: usize) -> Layout { + u16::try_from(len).expect("Vector is too large."); + let layout_alpha = Layout::new::(); + let layout_beta = Layout::array::(len).unwrap(); + let layout = layout_alpha.extend(layout_beta).unwrap().0; + layout.pad_to_align() + } + pub fn dims(&self) -> usize { + self.dims as usize + } + pub fn slice(&self) -> &[F16] { + debug_assert_eq!(self.varlena & 3, 0); + // TODO: force checking it in the future + // debug_assert_eq!(self.kind, 1); + // debug_assert_eq!(self.reserved, 0); + unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.dims as usize) } + } + pub fn for_borrow(&self) -> Vecf16Borrowed<'_> { + unsafe { Vecf16Borrowed::new_unchecked(self.slice()) } + } +} + +impl Deref for Vecf16Header { + type Target = [F16]; + + fn deref(&self) -> &Self::Target { + self.slice() + } +} + +pub enum Vecf16Input<'a> { + Owned(Vecf16Output), + Borrowed(&'a Vecf16Header), +} + +impl<'a> Vecf16Input<'a> { + unsafe fn new(p: NonNull) -> Self { + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() + }; + if p != q { + Vecf16Input::Owned(Vecf16Output(q)) + } else { + unsafe { Vecf16Input::Borrowed(p.as_ref()) } + } + } +} + +impl Deref for Vecf16Input<'_> { + type Target = Vecf16Header; + + fn deref(&self) -> &Self::Target { + match self { + Vecf16Input::Owned(x) => x, + Vecf16Input::Borrowed(x) => x, + } + } +} + +pub struct Vecf16Output(NonNull); + +impl Vecf16Output { + pub fn new(vector: Vecf16Borrowed<'_>) -> Vecf16Output { + unsafe { + let slice = vector.slice(); + let layout = Vecf16Header::layout(slice.len()); + let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf16Header; + ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); + std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16Header::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).kind).write(1); + std::ptr::addr_of_mut!((*ptr).dims).write(slice.len() as u16); + std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); + Vecf16Output(NonNull::new(ptr).unwrap()) + } + } + pub fn into_raw(self) -> *mut Vecf16Header { + let result = self.0.as_ptr(); + std::mem::forget(self); + result + } +} + +impl Deref for Vecf16Output { + type Target = Vecf16Header; + + fn deref(&self) -> &Self::Target { + unsafe { self.0.as_ref() } + } +} + +impl Drop for Vecf16Output { + fn drop(&mut self) { + unsafe { + pgrx::pg_sys::pfree(self.0.as_ptr() as _); + } + } +} + +impl<'a> FromDatum for Vecf16Input<'a> { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); + unsafe { Some(Vecf16Input::new(ptr)) } + } + } +} + +impl IntoDatum for Vecf16Output { + fn into_datum(self) -> Option { + Some(Datum::from(self.into_raw() as *mut ())) + } + + fn type_oid() -> Oid { + pgrx::wrappers::regtypein("vectors.vecf16") + } +} + +unsafe impl SqlTranslatable for Vecf16Input<'_> { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("vecf16"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("vecf16")))) + } +} + +unsafe impl SqlTranslatable for Vecf16Output { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("vecf16"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("vecf16")))) + } +} diff --git a/src/datatype/memory_vecf32.rs b/src/datatype/memory_vecf32.rs new file mode 100644 index 000000000..70f1957dc --- /dev/null +++ b/src/datatype/memory_vecf32.rs @@ -0,0 +1,158 @@ +use crate::prelude::*; +use pgrx::pg_sys::Datum; +use pgrx::pg_sys::Oid; +use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; +use pgrx::pgrx_sql_entity_graph::metadata::Returns; +use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; +use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; +use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; +use pgrx::FromDatum; +use pgrx::IntoDatum; +use std::alloc::Layout; +use std::ops::Deref; +use std::ptr::NonNull; + +#[repr(C, align(8))] +pub struct Vecf32Header { + varlena: u32, + dims: u16, + kind: u16, + phantom: [F32; 0], +} + +impl Vecf32Header { + fn varlena(size: usize) -> u32 { + (size << 2) as u32 + } + fn layout(len: usize) -> Layout { + u16::try_from(len).expect("Vector is too large."); + let layout_alpha = Layout::new::(); + let layout_beta = Layout::array::(len).unwrap(); + let layout = layout_alpha.extend(layout_beta).unwrap().0; + layout.pad_to_align() + } + pub fn dims(&self) -> usize { + self.dims as usize + } + pub fn slice(&self) -> &[F32] { + unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.dims as usize) } + } + pub fn for_borrow(&self) -> Vecf32Borrowed<'_> { + unsafe { Vecf32Borrowed::new_unchecked(self.slice()) } + } +} + +impl Deref for Vecf32Header { + type Target = [F32]; + + fn deref(&self) -> &Self::Target { + self.slice() + } +} + +pub enum Vecf32Input<'a> { + Owned(Vecf32Output), + Borrowed(&'a Vecf32Header), +} + +impl<'a> Vecf32Input<'a> { + unsafe fn new(p: NonNull) -> Self { + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() + }; + if p != q { + Vecf32Input::Owned(Vecf32Output(q)) + } else { + unsafe { Vecf32Input::Borrowed(p.as_ref()) } + } + } +} + +impl Deref for Vecf32Input<'_> { + type Target = Vecf32Header; + + fn deref(&self) -> &Self::Target { + match self { + Vecf32Input::Owned(x) => x, + Vecf32Input::Borrowed(x) => x, + } + } +} + +pub struct Vecf32Output(NonNull); + +impl Vecf32Output { + pub fn new(vector: Vecf32Borrowed<'_>) -> Vecf32Output { + unsafe { + let slice = vector.slice(); + let layout = Vecf32Header::layout(slice.len()); + let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf32Header; + ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); + std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32Header::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).kind).write(0); + std::ptr::addr_of_mut!((*ptr).dims).write(slice.len() as u16); + std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); + Vecf32Output(NonNull::new(ptr).unwrap()) + } + } + pub fn into_raw(self) -> *mut Vecf32Header { + let result = self.0.as_ptr(); + std::mem::forget(self); + result + } +} + +impl Deref for Vecf32Output { + type Target = Vecf32Header; + + fn deref(&self) -> &Self::Target { + unsafe { self.0.as_ref() } + } +} + +impl Drop for Vecf32Output { + fn drop(&mut self) { + unsafe { + pgrx::pg_sys::pfree(self.0.as_ptr() as _); + } + } +} + +impl<'a> FromDatum for Vecf32Input<'a> { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); + unsafe { Some(Vecf32Input::new(ptr)) } + } + } +} + +impl IntoDatum for Vecf32Output { + fn into_datum(self) -> Option { + Some(Datum::from(self.into_raw() as *mut ())) + } + + fn type_oid() -> Oid { + pgrx::wrappers::regtypein("vectors.vector") + } +} + +unsafe impl SqlTranslatable for Vecf32Input<'_> { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("vector"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("vector")))) + } +} + +unsafe impl SqlTranslatable for Vecf32Output { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("vector"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("vector")))) + } +} diff --git a/src/datatype/mod.rs b/src/datatype/mod.rs index 0def2db1f..0cbb6bada 100644 --- a/src/datatype/mod.rs +++ b/src/datatype/mod.rs @@ -1,8 +1,18 @@ -pub mod casts_f32; +pub mod binary_svecf32; +pub mod binary_vecf16; +pub mod binary_vecf32; +pub mod casts; +pub mod functions; +pub mod memory_svecf32; +pub mod memory_vecf16; +pub mod memory_vecf32; pub mod operators_svecf32; pub mod operators_vecf16; pub mod operators_vecf32; -pub mod svecf32; +pub mod subscript_svecf32; +pub mod subscript_vecf16; +pub mod subscript_vecf32; +pub mod text_svecf32; +pub mod text_vecf16; +pub mod text_vecf32; pub mod typmod; -pub mod vecf16; -pub mod vecf32; diff --git a/src/datatype/operators_svecf32.rs b/src/datatype/operators_svecf32.rs index 03f3616f7..f2c9a9cd4 100644 --- a/src/datatype/operators_svecf32.rs +++ b/src/datatype/operators_svecf32.rs @@ -1,12 +1,11 @@ -use crate::datatype::svecf32::{SVecf32, SVecf32Input, SVecf32Output}; +use crate::datatype::memory_svecf32::{SVecf32Input, SVecf32Output}; use crate::prelude::*; -use base::scalar::FloatCast; -use service::prelude::*; +use base::global::*; use std::ops::Deref; #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_svecf32_operator_add(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> SVecf32Output { - check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); + check_matched_dims(lhs.dims() as _, rhs.dims() as _); let size1 = lhs.len(); let size2 = rhs.len(); @@ -15,13 +14,13 @@ fn _vectors_svecf32_operator_add(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) - let mut pos = 0; let mut indexes = vec![0u16; size1 + size2]; let mut values = vec![F32::zero(); size1 + size2]; - let lhs = lhs.data(); - let rhs = rhs.data(); + let lhs = lhs.for_borrow(); + let rhs = rhs.for_borrow(); while pos1 < size1 && pos2 < size2 { - let lhs_index = lhs.indexes[pos1]; - let rhs_index = rhs.indexes[pos2]; - let lhs_value = lhs.values[pos1]; - let rhs_value = rhs.values[pos2]; + let lhs_index = lhs.indexes()[pos1]; + let rhs_index = rhs.indexes()[pos2]; + let lhs_value = lhs.values()[pos1]; + let rhs_value = rhs.values()[pos2]; indexes[pos] = lhs_index.min(rhs_index); values[pos] = F32((lhs_index <= rhs_index) as u32 as f32) * lhs_value + F32((lhs_index >= rhs_index) as u32 as f32) * rhs_value; @@ -30,28 +29,24 @@ fn _vectors_svecf32_operator_add(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) - pos += (!values[pos].is_zero()) as usize; } for i in pos1..size1 { - indexes[pos] = lhs.indexes[i]; - values[pos] = lhs.values[i]; + indexes[pos] = lhs.indexes()[i]; + values[pos] = lhs.values()[i]; pos += 1; } for i in pos2..size2 { - indexes[pos] = rhs.indexes[i]; - values[pos] = rhs.values[i]; + indexes[pos] = rhs.indexes()[i]; + values[pos] = rhs.values()[i]; pos += 1; } indexes.truncate(pos); values.truncate(pos); - SVecf32::new_in_postgres(SparseF32Ref { - dims: lhs.dims(), - indexes: &indexes, - values: &values, - }) + SVecf32Output::new(SVecf32Borrowed::new(lhs.dims(), &indexes, &values)) } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_svecf32_operator_minus(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> SVecf32Output { - check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); + check_matched_dims(lhs.dims() as _, rhs.dims() as _); let size1 = lhs.len(); let size2 = rhs.len(); @@ -60,13 +55,13 @@ fn _vectors_svecf32_operator_minus(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) let mut pos = 0; let mut indexes = vec![0u16; size1 + size2]; let mut values = vec![F32::zero(); size1 + size2]; - let lhs = lhs.data(); - let rhs = rhs.data(); + let lhs = lhs.for_borrow(); + let rhs = rhs.for_borrow(); while pos1 < size1 && pos2 < size2 { - let lhs_index = lhs.indexes[pos1]; - let rhs_index = rhs.indexes[pos2]; - let lhs_value = lhs.values[pos1]; - let rhs_value = rhs.values[pos2]; + let lhs_index = lhs.indexes()[pos1]; + let rhs_index = rhs.indexes()[pos2]; + let lhs_value = lhs.values()[pos1]; + let rhs_value = rhs.values()[pos2]; indexes[pos] = lhs_index.min(rhs_index); values[pos] = F32((lhs_index <= rhs_index) as u32 as f32) * lhs_value - F32((lhs_index >= rhs_index) as u32 as f32) * rhs_value; @@ -75,75 +70,101 @@ fn _vectors_svecf32_operator_minus(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) pos += (!values[pos].is_zero()) as usize; } for i in pos1..size1 { - indexes[pos] = lhs.indexes[i]; - values[pos] = lhs.values[i]; + indexes[pos] = lhs.indexes()[i]; + values[pos] = lhs.values()[i]; pos += 1; } for i in pos2..size2 { - indexes[pos] = rhs.indexes[i]; - values[pos] = -rhs.values[i]; + indexes[pos] = rhs.indexes()[i]; + values[pos] = -rhs.values()[i]; pos += 1; } indexes.truncate(pos); values.truncate(pos); - SVecf32::new_in_postgres(SparseF32Ref { - dims: lhs.dims(), - indexes: &indexes, - values: &values, - }) + SVecf32Output::new(SVecf32Borrowed::new(lhs.dims(), &indexes, &values)) } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_svecf32_operator_lt(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> bool { - check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); - lhs.deref() < rhs.deref() + check_matched_dims(lhs.dims() as _, rhs.dims() as _); + compare(lhs, rhs).is_lt() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_svecf32_operator_lte(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> bool { - check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); - lhs.deref() <= rhs.deref() + check_matched_dims(lhs.dims() as _, rhs.dims() as _); + compare(lhs, rhs).is_le() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_svecf32_operator_gt(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> bool { - check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); - lhs.deref() > rhs.deref() + check_matched_dims(lhs.dims() as _, rhs.dims() as _); + compare(lhs, rhs).is_gt() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_svecf32_operator_gte(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> bool { - check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); - lhs.deref() >= rhs.deref() + check_matched_dims(lhs.dims() as _, rhs.dims() as _); + compare(lhs, rhs).is_ge() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_svecf32_operator_eq(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> bool { - check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); - lhs.deref() == rhs.deref() + check_matched_dims(lhs.dims() as _, rhs.dims() as _); + lhs.deref().for_borrow() == rhs.deref().for_borrow() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_svecf32_operator_neq(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> bool { - check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); - lhs.deref() != rhs.deref() + check_matched_dims(lhs.dims() as _, rhs.dims() as _); + lhs.deref().for_borrow() != rhs.deref().for_borrow() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_svecf32_operator_cosine(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> f32 { - check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); - SparseF32Cos::distance(lhs.data(), rhs.data()).to_f32() + check_matched_dims(lhs.dims() as _, rhs.dims() as _); + SVecf32Cos::distance(lhs.for_borrow(), rhs.for_borrow()).to_f32() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_svecf32_operator_dot(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> f32 { - check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); - SparseF32Dot::distance(lhs.data(), rhs.data()).to_f32() + check_matched_dims(lhs.dims() as _, rhs.dims() as _); + SVecf32Dot::distance(lhs.for_borrow(), rhs.for_borrow()).to_f32() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_svecf32_operator_l2(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> f32 { - check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); - SparseF32L2::distance(lhs.data(), rhs.data()).to_f32() + check_matched_dims(lhs.dims() as _, rhs.dims() as _); + SVecf32L2::distance(lhs.for_borrow(), rhs.for_borrow()).to_f32() +} + +fn compare(a: SVecf32Input<'_>, b: SVecf32Input<'_>) -> std::cmp::Ordering { + use std::cmp::Ordering; + assert!(a.dims() == b.dims()); + let lhs = a.for_borrow(); + let rhs = b.for_borrow(); + let mut pos = 0; + let size1 = lhs.len() as usize; + let size2 = rhs.len() as usize; + while pos < size1 && pos < size2 { + let lhs_index = lhs.indexes()[pos]; + let rhs_index = rhs.indexes()[pos]; + let lhs_value = lhs.values()[pos]; + let rhs_value = rhs.values()[pos]; + match lhs_index.cmp(&rhs_index) { + Ordering::Less => return lhs_value.cmp(&F32::zero()), + Ordering::Greater => return F32::zero().cmp(&rhs_value), + Ordering::Equal => match lhs_value.cmp(&rhs_value) { + Ordering::Equal => {} + x => return x, + }, + } + pos += 1; + } + match size1.cmp(&size2) { + Ordering::Less => F32::zero().cmp(&rhs.values()[pos]), + Ordering::Greater => lhs.values()[pos].cmp(&F32::zero()), + Ordering::Equal => Ordering::Equal, + } } diff --git a/src/datatype/operators_vecf16.rs b/src/datatype/operators_vecf16.rs index 191c0e364..7320bc3b0 100644 --- a/src/datatype/operators_vecf16.rs +++ b/src/datatype/operators_vecf16.rs @@ -1,79 +1,78 @@ -use crate::datatype::vecf16::{Vecf16, Vecf16Input, Vecf16Output}; +use crate::datatype::memory_vecf16::{Vecf16Input, Vecf16Output}; use crate::prelude::*; -use base::scalar::FloatCast; -use service::prelude::*; +use base::global::*; use std::ops::Deref; #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_add(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Output { - let n = check_matched_dimensions(lhs.len(), rhs.len()); + let n = check_matched_dims(lhs.dims(), rhs.dims()); let mut v = vec![F16::zero(); n]; for i in 0..n { v[i] = lhs[i] + rhs[i]; } - Vecf16::new_in_postgres(&v) + Vecf16Output::new(Vecf16Borrowed::new(&v)) } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_minus(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Output { - let n = check_matched_dimensions(lhs.len(), rhs.len()); + let n = check_matched_dims(lhs.dims(), rhs.dims()); let mut v = vec![F16::zero(); n]; for i in 0..n { v[i] = lhs[i] - rhs[i]; } - Vecf16::new_in_postgres(&v) + Vecf16Output::new(Vecf16Borrowed::new(&v)) } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_lt(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { - check_matched_dimensions(lhs.len(), rhs.len()); - lhs.deref() < rhs.deref() + check_matched_dims(lhs.dims(), rhs.dims()); + lhs.deref().slice() < rhs.deref().slice() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_lte(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { - check_matched_dimensions(lhs.len(), rhs.len()); - lhs.deref() <= rhs.deref() + check_matched_dims(lhs.dims(), rhs.dims()); + lhs.deref().slice() <= rhs.deref().slice() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_gt(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { - check_matched_dimensions(lhs.len(), rhs.len()); - lhs.deref() > rhs.deref() + check_matched_dims(lhs.dims(), rhs.dims()); + lhs.deref().slice() > rhs.deref().slice() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_gte(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { - check_matched_dimensions(lhs.len(), rhs.len()); - lhs.deref() >= rhs.deref() + check_matched_dims(lhs.dims(), rhs.dims()); + lhs.deref().slice() >= rhs.deref().slice() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_eq(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { - check_matched_dimensions(lhs.len(), rhs.len()); - lhs.deref() == rhs.deref() + check_matched_dims(lhs.dims(), rhs.dims()); + lhs.deref().slice() == rhs.deref().slice() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_neq(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { - check_matched_dimensions(lhs.len(), rhs.len()); - lhs.deref() != rhs.deref() + check_matched_dims(lhs.dims(), rhs.dims()); + lhs.deref().slice() != rhs.deref().slice() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_cosine(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 { - check_matched_dimensions(lhs.len(), rhs.len()); - F16Cos::distance(&lhs, &rhs).to_f32() + check_matched_dims(lhs.dims(), rhs.dims()); + Vecf16Cos::distance(lhs.for_borrow(), rhs.for_borrow()).to_f32() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_dot(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 { - check_matched_dimensions(lhs.len(), rhs.len()); - F16Dot::distance(&lhs, &rhs).to_f32() + check_matched_dims(lhs.dims(), rhs.dims()); + Vecf16Dot::distance(lhs.for_borrow(), rhs.for_borrow()).to_f32() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_l2(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 { - check_matched_dimensions(lhs.len(), rhs.len()); - F16L2::distance(&lhs, &rhs).to_f32() + check_matched_dims(lhs.dims(), rhs.dims()); + Vecf16L2::distance(lhs.for_borrow(), rhs.for_borrow()).to_f32() } diff --git a/src/datatype/operators_vecf32.rs b/src/datatype/operators_vecf32.rs index 50649f8d8..2f64bef70 100644 --- a/src/datatype/operators_vecf32.rs +++ b/src/datatype/operators_vecf32.rs @@ -1,79 +1,78 @@ -use crate::datatype::vecf32::{Vecf32, Vecf32Input, Vecf32Output}; +use crate::datatype::memory_vecf32::{Vecf32Input, Vecf32Output}; use crate::prelude::*; -use base::scalar::FloatCast; -use service::prelude::*; +use base::global::*; use std::ops::Deref; #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_add(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> Vecf32Output { - let n = check_matched_dimensions(lhs.len(), rhs.len()); + let n = check_matched_dims(lhs.dims(), rhs.dims()); let mut v = vec![F32::zero(); n]; for i in 0..n { v[i] = lhs[i] + rhs[i]; } - Vecf32::new_in_postgres(&v) + Vecf32Output::new(Vecf32Borrowed::new(&v)) } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_minus(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> Vecf32Output { - let n = check_matched_dimensions(lhs.len(), rhs.len()); + let n = check_matched_dims(lhs.dims(), rhs.dims()); let mut v = vec![F32::zero(); n]; for i in 0..n { v[i] = lhs[i] - rhs[i]; } - Vecf32::new_in_postgres(&v) + Vecf32Output::new(Vecf32Borrowed::new(&v)) } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_lt(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { - check_matched_dimensions(lhs.len(), rhs.len()); - lhs.deref() < rhs.deref() + check_matched_dims(lhs.dims(), rhs.dims()); + lhs.deref().slice() < rhs.deref().slice() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_lte(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { - check_matched_dimensions(lhs.len(), rhs.len()); - lhs.deref() <= rhs.deref() + check_matched_dims(lhs.dims(), rhs.dims()); + lhs.deref().slice() <= rhs.deref().slice() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_gt(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { - check_matched_dimensions(lhs.len(), rhs.len()); - lhs.deref() > rhs.deref() + check_matched_dims(lhs.dims(), rhs.dims()); + lhs.deref().slice() > rhs.deref().slice() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_gte(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { - check_matched_dimensions(lhs.len(), rhs.len()); - lhs.deref() >= rhs.deref() + check_matched_dims(lhs.dims(), rhs.dims()); + lhs.deref().slice() >= rhs.deref().slice() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_eq(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { - check_matched_dimensions(lhs.len(), rhs.len()); - lhs.deref() == rhs.deref() + check_matched_dims(lhs.dims(), rhs.dims()); + lhs.deref().slice() == rhs.deref().slice() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_neq(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { - check_matched_dimensions(lhs.len(), rhs.len()); - lhs.deref() != rhs.deref() + check_matched_dims(lhs.dims(), rhs.dims()); + lhs.deref().slice() != rhs.deref().slice() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_cosine(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 { - check_matched_dimensions(lhs.len(), rhs.len()); - F32Cos::distance(&lhs, &rhs).to_f32() + check_matched_dims(lhs.dims(), rhs.dims()); + Vecf32Cos::distance(lhs.for_borrow(), rhs.for_borrow()).to_f32() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_dot(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 { - check_matched_dimensions(lhs.len(), rhs.len()); - F32Dot::distance(&lhs, &rhs).to_f32() + check_matched_dims(lhs.dims(), rhs.dims()); + Vecf32Dot::distance(lhs.for_borrow(), rhs.for_borrow()).to_f32() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_l2(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 { - check_matched_dimensions(lhs.len(), rhs.len()); - F32L2::distance(&lhs, &rhs).to_f32() + check_matched_dims(lhs.dims(), rhs.dims()); + Vecf32L2::distance(lhs.for_borrow(), rhs.for_borrow()).to_f32() } diff --git a/src/datatype/subscript_svecf32.rs b/src/datatype/subscript_svecf32.rs new file mode 100644 index 000000000..1058da9de --- /dev/null +++ b/src/datatype/subscript_svecf32.rs @@ -0,0 +1,202 @@ +use crate::datatype::memory_svecf32::{SVecf32Input, SVecf32Output}; +use base::vector::SVecf32Borrowed; +use pgrx::datum::FromDatum; +use pgrx::pg_sys::Datum; + +#[pgrx::pg_extern(sql = "\ +CREATE FUNCTION _vectors_svecf32_subscript(internal) RETURNS internal +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] +fn _vectors_svecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { + #[pgrx::pg_guard] + unsafe extern "C" fn transform( + subscript: *mut pgrx::pg_sys::SubscriptingRef, + indirection: *mut pgrx::pg_sys::List, + pstate: *mut pgrx::pg_sys::ParseState, + is_slice: bool, + is_assignment: bool, + ) { + unsafe { + if (*indirection).length != 1 { + pgrx::pg_sys::error!("type svector does only support one subscript"); + } + if !is_slice { + pgrx::pg_sys::error!("type svector does only support slice fetch"); + } + if is_assignment { + pgrx::pg_sys::error!("type svector does not support subscripted assignment"); + } + let subscript = &mut *subscript; + let ai = (*(*indirection).elements.add(0)).ptr_value as *mut pgrx::pg_sys::A_Indices; + subscript.refupperindexpr = pgrx::pg_sys::lappend( + std::ptr::null_mut(), + if !(*ai).uidx.is_null() { + let subexpr = + pgrx::pg_sys::transformExpr(pstate, (*ai).uidx, (*pstate).p_expr_kind); + let subexpr = pgrx::pg_sys::coerce_to_target_type( + pstate, + subexpr, + pgrx::pg_sys::exprType(subexpr), + pgrx::pg_sys::INT4OID, + -1, + pgrx::pg_sys::CoercionContext_COERCION_ASSIGNMENT, + pgrx::pg_sys::CoercionForm_COERCE_IMPLICIT_CAST, + -1, + ); + if subexpr.is_null() { + pgrx::error!("svector subscript must have type integer"); + } + subexpr.cast() + } else { + std::ptr::null_mut() + }, + ); + subscript.reflowerindexpr = pgrx::pg_sys::lappend( + std::ptr::null_mut(), + if !(*ai).lidx.is_null() { + let subexpr = + pgrx::pg_sys::transformExpr(pstate, (*ai).lidx, (*pstate).p_expr_kind); + let subexpr = pgrx::pg_sys::coerce_to_target_type( + pstate, + subexpr, + pgrx::pg_sys::exprType(subexpr), + pgrx::pg_sys::INT4OID, + -1, + pgrx::pg_sys::CoercionContext_COERCION_ASSIGNMENT, + pgrx::pg_sys::CoercionForm_COERCE_IMPLICIT_CAST, + -1, + ); + if subexpr.is_null() { + pgrx::error!("svector subscript must have type integer"); + } + subexpr.cast() + } else { + std::ptr::null_mut() + }, + ); + subscript.refrestype = subscript.refcontainertype; + } + } + #[pgrx::pg_guard] + unsafe extern "C" fn exec_setup( + _subscript: *const pgrx::pg_sys::SubscriptingRef, + state: *mut pgrx::pg_sys::SubscriptingRefState, + steps: *mut pgrx::pg_sys::SubscriptExecSteps, + ) { + #[derive(Default)] + struct Workspace { + range: Option<(Option, Option)>, + } + #[pgrx::pg_guard] + unsafe extern "C" fn sbs_check_subscripts( + _state: *mut pgrx::pg_sys::ExprState, + op: *mut pgrx::pg_sys::ExprEvalStep, + _econtext: *mut pgrx::pg_sys::ExprContext, + ) -> bool { + unsafe { + let state = &mut *(*op).d.sbsref.state; + let workspace = &mut *(state.workspace as *mut Workspace); + workspace.range = None; + let mut end = None; + let mut start = None; + if state.upperprovided.read() { + if !state.upperindexnull.read() { + let upper = state.upperindex.read().value() as i32; + if upper >= 0 { + end = Some(upper as usize); + } else { + (*op).resnull.write(true); + return false; + } + } else { + (*op).resnull.write(true); + return false; + } + } + if state.lowerprovided.read() { + if !state.lowerindexnull.read() { + let lower = state.lowerindex.read().value() as i32; + if lower >= 0 { + start = Some(lower as usize); + } else { + (*op).resnull.write(true); + return false; + } + } else { + (*op).resnull.write(true); + return false; + } + } + workspace.range = Some((start, end)); + true + } + } + #[pgrx::pg_guard] + unsafe extern "C" fn sbs_fetch( + _state: *mut pgrx::pg_sys::ExprState, + op: *mut pgrx::pg_sys::ExprEvalStep, + _econtext: *mut pgrx::pg_sys::ExprContext, + ) { + unsafe { + let state = &mut *(*op).d.sbsref.state; + let workspace = &mut *(state.workspace as *mut Workspace); + let input = + SVecf32Input::from_datum((*op).resvalue.read(), (*op).resnull.read()).unwrap(); + let dims = input.dims() as u16; + let Some((start, end)) = workspace.range else { + (*op).resnull.write(true); + return; + }; + let start: u16 = match start.unwrap_or(0).try_into() { + Ok(x) => x, + Err(_) => { + (*op).resnull.write(true); + return; + } + }; + let end: u16 = match end.unwrap_or(dims as usize).try_into() { + Ok(x) => x, + Err(_) => { + (*op).resnull.write(true); + return; + } + }; + if start >= end || end > dims { + (*op).resnull.write(true); + return; + } + let svecf32 = input.for_borrow(); + let start_index = svecf32.indexes().partition_point(|&x| x < start); + let end_index = svecf32.indexes().partition_point(|&x| x < end); + let mut indexes = svecf32.indexes()[start_index..end_index].to_vec(); + indexes.iter_mut().for_each(|x| *x -= start); + let output = SVecf32Output::new(SVecf32Borrowed::new( + end - start, + &indexes, + &svecf32.values()[start_index..end_index], + )); + (*op).resnull.write(false); + (*op).resvalue.write(Datum::from(output.into_raw())); + } + } + unsafe { + let state = &mut *state; + let steps = &mut *steps; + assert!(state.numlower == 1); + assert!(state.numupper == 1); + state.workspace = pgrx::pg_sys::palloc(std::mem::size_of::()); + std::ptr::write::(state.workspace.cast(), Workspace::default()); + steps.sbs_check_subscripts = Some(sbs_check_subscripts); + steps.sbs_fetch = Some(sbs_fetch); + steps.sbs_assign = None; + steps.sbs_fetch_old = None; + } + } + static SBSROUTINES: pgrx::pg_sys::SubscriptRoutines = pgrx::pg_sys::SubscriptRoutines { + transform: Some(transform), + exec_setup: Some(exec_setup), + fetch_strict: true, + fetch_leakproof: false, + store_leakproof: false, + }; + std::ptr::addr_of!(SBSROUTINES).into() +} diff --git a/src/datatype/subscript_vecf16.rs b/src/datatype/subscript_vecf16.rs new file mode 100644 index 000000000..eca3afd6d --- /dev/null +++ b/src/datatype/subscript_vecf16.rs @@ -0,0 +1,185 @@ +use crate::datatype::memory_vecf16::{Vecf16Input, Vecf16Output}; +use base::vector::Vecf16Borrowed; +use pgrx::datum::FromDatum; +use pgrx::pg_sys::Datum; + +#[pgrx::pg_extern(sql = "\ +CREATE FUNCTION _vectors_vecf16_subscript(internal) RETURNS internal +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] +fn _vectors_vecf16_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { + #[pgrx::pg_guard] + unsafe extern "C" fn transform( + subscript: *mut pgrx::pg_sys::SubscriptingRef, + indirection: *mut pgrx::pg_sys::List, + pstate: *mut pgrx::pg_sys::ParseState, + is_slice: bool, + is_assignment: bool, + ) { + unsafe { + if (*indirection).length != 1 { + pgrx::pg_sys::error!("type vecf16 does only support one subscript"); + } + if !is_slice { + pgrx::pg_sys::error!("type vecf16 does only support slice fetch"); + } + if is_assignment { + pgrx::pg_sys::error!("type vecf16 does not support subscripted assignment"); + } + let subscript = &mut *subscript; + let ai = (*(*indirection).elements.add(0)).ptr_value as *mut pgrx::pg_sys::A_Indices; + subscript.refupperindexpr = pgrx::pg_sys::lappend( + std::ptr::null_mut(), + if !(*ai).uidx.is_null() { + let subexpr = + pgrx::pg_sys::transformExpr(pstate, (*ai).uidx, (*pstate).p_expr_kind); + let subexpr = pgrx::pg_sys::coerce_to_target_type( + pstate, + subexpr, + pgrx::pg_sys::exprType(subexpr), + pgrx::pg_sys::INT4OID, + -1, + pgrx::pg_sys::CoercionContext_COERCION_ASSIGNMENT, + pgrx::pg_sys::CoercionForm_COERCE_IMPLICIT_CAST, + -1, + ); + if subexpr.is_null() { + pgrx::error!("vecf16 subscript must have type integer"); + } + subexpr.cast() + } else { + std::ptr::null_mut() + }, + ); + subscript.reflowerindexpr = pgrx::pg_sys::lappend( + std::ptr::null_mut(), + if !(*ai).lidx.is_null() { + let subexpr = + pgrx::pg_sys::transformExpr(pstate, (*ai).lidx, (*pstate).p_expr_kind); + let subexpr = pgrx::pg_sys::coerce_to_target_type( + pstate, + subexpr, + pgrx::pg_sys::exprType(subexpr), + pgrx::pg_sys::INT4OID, + -1, + pgrx::pg_sys::CoercionContext_COERCION_ASSIGNMENT, + pgrx::pg_sys::CoercionForm_COERCE_IMPLICIT_CAST, + -1, + ); + if subexpr.is_null() { + pgrx::error!("vecf16 subscript must have type integer"); + } + subexpr.cast() + } else { + std::ptr::null_mut() + }, + ); + subscript.refrestype = subscript.refcontainertype; + } + } + #[pgrx::pg_guard] + unsafe extern "C" fn exec_setup( + _subscript: *const pgrx::pg_sys::SubscriptingRef, + state: *mut pgrx::pg_sys::SubscriptingRefState, + steps: *mut pgrx::pg_sys::SubscriptExecSteps, + ) { + #[derive(Default)] + struct Workspace { + range: Option<(Option, Option)>, + } + #[pgrx::pg_guard] + unsafe extern "C" fn sbs_check_subscripts( + _state: *mut pgrx::pg_sys::ExprState, + op: *mut pgrx::pg_sys::ExprEvalStep, + _econtext: *mut pgrx::pg_sys::ExprContext, + ) -> bool { + unsafe { + let state = &mut *(*op).d.sbsref.state; + let workspace = &mut *(state.workspace as *mut Workspace); + workspace.range = None; + let mut end = None; + let mut start = None; + if state.upperprovided.read() { + if !state.upperindexnull.read() { + let upper = state.upperindex.read().value() as i32; + if upper >= 0 { + end = Some(upper as usize); + } else { + (*op).resnull.write(true); + return false; + } + } else { + (*op).resnull.write(true); + return false; + } + } + if state.lowerprovided.read() { + if !state.lowerindexnull.read() { + let lower = state.lowerindex.read().value() as i32; + if lower >= 0 { + start = Some(lower as usize); + } else { + (*op).resnull.write(true); + return false; + } + } else { + (*op).resnull.write(true); + return false; + } + } + workspace.range = Some((start, end)); + true + } + } + #[pgrx::pg_guard] + unsafe extern "C" fn sbs_fetch( + _state: *mut pgrx::pg_sys::ExprState, + op: *mut pgrx::pg_sys::ExprEvalStep, + _econtext: *mut pgrx::pg_sys::ExprContext, + ) { + unsafe { + let state = &mut *(*op).d.sbsref.state; + let workspace = &mut *(state.workspace as *mut Workspace); + let input = + Vecf16Input::from_datum((*op).resvalue.read(), (*op).resnull.read()).unwrap(); + let slice = match workspace.range { + Some((None, None)) => input.slice().get(..), + Some((None, Some(y))) => input.slice().get(..y), + Some((Some(x), None)) => input.slice().get(x..), + Some((Some(x), Some(y))) => input.slice().get(x..y), + None => None, + }; + if let Some(slice) = slice { + if !slice.is_empty() { + let output = Vecf16Output::new(Vecf16Borrowed::new(slice)); + (*op).resnull.write(false); + (*op).resvalue.write(Datum::from(output.into_raw())); + } else { + (*op).resnull.write(true); + } + } else { + (*op).resnull.write(true); + } + } + } + unsafe { + let state = &mut *state; + let steps = &mut *steps; + assert!(state.numlower == 1); + assert!(state.numupper == 1); + state.workspace = pgrx::pg_sys::palloc(std::mem::size_of::()); + std::ptr::write::(state.workspace.cast(), Workspace::default()); + steps.sbs_check_subscripts = Some(sbs_check_subscripts); + steps.sbs_fetch = Some(sbs_fetch); + steps.sbs_assign = None; + steps.sbs_fetch_old = None; + } + } + static SBSROUTINES: pgrx::pg_sys::SubscriptRoutines = pgrx::pg_sys::SubscriptRoutines { + transform: Some(transform), + exec_setup: Some(exec_setup), + fetch_strict: true, + fetch_leakproof: false, + store_leakproof: false, + }; + std::ptr::addr_of!(SBSROUTINES).into() +} diff --git a/src/datatype/subscript_vecf32.rs b/src/datatype/subscript_vecf32.rs new file mode 100644 index 000000000..58331ab33 --- /dev/null +++ b/src/datatype/subscript_vecf32.rs @@ -0,0 +1,185 @@ +use crate::datatype::memory_vecf32::{Vecf32Input, Vecf32Output}; +use base::vector::Vecf32Borrowed; +use pgrx::datum::FromDatum; +use pgrx::pg_sys::Datum; + +#[pgrx::pg_extern(sql = "\ +CREATE FUNCTION _vectors_vecf32_subscript(internal) RETURNS internal +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] +fn _vectors_vecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { + #[pgrx::pg_guard] + unsafe extern "C" fn transform( + subscript: *mut pgrx::pg_sys::SubscriptingRef, + indirection: *mut pgrx::pg_sys::List, + pstate: *mut pgrx::pg_sys::ParseState, + is_slice: bool, + is_assignment: bool, + ) { + unsafe { + if (*indirection).length != 1 { + pgrx::pg_sys::error!("type vector does only support one subscript"); + } + if !is_slice { + pgrx::pg_sys::error!("type vector does only support slice fetch"); + } + if is_assignment { + pgrx::pg_sys::error!("type vector does not support subscripted assignment"); + } + let subscript = &mut *subscript; + let ai = (*(*indirection).elements.add(0)).ptr_value as *mut pgrx::pg_sys::A_Indices; + subscript.refupperindexpr = pgrx::pg_sys::lappend( + std::ptr::null_mut(), + if !(*ai).uidx.is_null() { + let subexpr = + pgrx::pg_sys::transformExpr(pstate, (*ai).uidx, (*pstate).p_expr_kind); + let subexpr = pgrx::pg_sys::coerce_to_target_type( + pstate, + subexpr, + pgrx::pg_sys::exprType(subexpr), + pgrx::pg_sys::INT4OID, + -1, + pgrx::pg_sys::CoercionContext_COERCION_ASSIGNMENT, + pgrx::pg_sys::CoercionForm_COERCE_IMPLICIT_CAST, + -1, + ); + if subexpr.is_null() { + pgrx::error!("vector subscript must have type integer"); + } + subexpr.cast() + } else { + std::ptr::null_mut() + }, + ); + subscript.reflowerindexpr = pgrx::pg_sys::lappend( + std::ptr::null_mut(), + if !(*ai).lidx.is_null() { + let subexpr = + pgrx::pg_sys::transformExpr(pstate, (*ai).lidx, (*pstate).p_expr_kind); + let subexpr = pgrx::pg_sys::coerce_to_target_type( + pstate, + subexpr, + pgrx::pg_sys::exprType(subexpr), + pgrx::pg_sys::INT4OID, + -1, + pgrx::pg_sys::CoercionContext_COERCION_ASSIGNMENT, + pgrx::pg_sys::CoercionForm_COERCE_IMPLICIT_CAST, + -1, + ); + if subexpr.is_null() { + pgrx::error!("vector subscript must have type integer"); + } + subexpr.cast() + } else { + std::ptr::null_mut() + }, + ); + subscript.refrestype = subscript.refcontainertype; + } + } + #[pgrx::pg_guard] + unsafe extern "C" fn exec_setup( + _subscript: *const pgrx::pg_sys::SubscriptingRef, + state: *mut pgrx::pg_sys::SubscriptingRefState, + steps: *mut pgrx::pg_sys::SubscriptExecSteps, + ) { + #[derive(Default)] + struct Workspace { + range: Option<(Option, Option)>, + } + #[pgrx::pg_guard] + unsafe extern "C" fn sbs_check_subscripts( + _state: *mut pgrx::pg_sys::ExprState, + op: *mut pgrx::pg_sys::ExprEvalStep, + _econtext: *mut pgrx::pg_sys::ExprContext, + ) -> bool { + unsafe { + let state = &mut *(*op).d.sbsref.state; + let workspace = &mut *(state.workspace as *mut Workspace); + workspace.range = None; + let mut end = None; + let mut start = None; + if state.upperprovided.read() { + if !state.upperindexnull.read() { + let upper = state.upperindex.read().value() as i32; + if upper >= 0 { + end = Some(upper as usize); + } else { + (*op).resnull.write(true); + return false; + } + } else { + (*op).resnull.write(true); + return false; + } + } + if state.lowerprovided.read() { + if !state.lowerindexnull.read() { + let lower = state.lowerindex.read().value() as i32; + if lower >= 0 { + start = Some(lower as usize); + } else { + (*op).resnull.write(true); + return false; + } + } else { + (*op).resnull.write(true); + return false; + } + } + workspace.range = Some((start, end)); + true + } + } + #[pgrx::pg_guard] + unsafe extern "C" fn sbs_fetch( + _state: *mut pgrx::pg_sys::ExprState, + op: *mut pgrx::pg_sys::ExprEvalStep, + _econtext: *mut pgrx::pg_sys::ExprContext, + ) { + unsafe { + let state = &mut *(*op).d.sbsref.state; + let workspace = &mut *(state.workspace as *mut Workspace); + let input = + Vecf32Input::from_datum((*op).resvalue.read(), (*op).resnull.read()).unwrap(); + let slice = match workspace.range { + Some((None, None)) => input.slice().get(..), + Some((None, Some(y))) => input.slice().get(..y), + Some((Some(x), None)) => input.slice().get(x..), + Some((Some(x), Some(y))) => input.slice().get(x..y), + None => None, + }; + if let Some(slice) = slice { + if !slice.is_empty() { + let output = Vecf32Output::new(Vecf32Borrowed::new(slice)); + (*op).resnull.write(false); + (*op).resvalue.write(Datum::from(output.into_raw())); + } else { + (*op).resnull.write(true); + } + } else { + (*op).resnull.write(true); + } + } + } + unsafe { + let state = &mut *state; + let steps = &mut *steps; + assert!(state.numlower == 1); + assert!(state.numupper == 1); + state.workspace = pgrx::pg_sys::palloc(std::mem::size_of::()); + std::ptr::write::(state.workspace.cast(), Workspace::default()); + steps.sbs_check_subscripts = Some(sbs_check_subscripts); + steps.sbs_fetch = Some(sbs_fetch); + steps.sbs_assign = None; + steps.sbs_fetch_old = None; + } + } + static SBSROUTINES: pgrx::pg_sys::SubscriptRoutines = pgrx::pg_sys::SubscriptRoutines { + transform: Some(transform), + exec_setup: Some(exec_setup), + fetch_strict: true, + fetch_leakproof: false, + store_leakproof: false, + }; + std::ptr::addr_of!(SBSROUTINES).into() +} diff --git a/src/datatype/svecf32.rs b/src/datatype/svecf32.rs deleted file mode 100644 index 2f585d10a..000000000 --- a/src/datatype/svecf32.rs +++ /dev/null @@ -1,673 +0,0 @@ -use crate::prelude::*; -use pgrx::pg_sys::Datum; -use pgrx::pg_sys::Oid; -use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; -use pgrx::pgrx_sql_entity_graph::metadata::Returns; -use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; -use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; -use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; -use pgrx::FromDatum; -use pgrx::IntoDatum; -use service::prelude::*; -use std::alloc::Layout; -use std::cmp::Ordering; -use std::ffi::CStr; -use std::ffi::CString; -use std::ops::Deref; -use std::ops::DerefMut; -use std::ptr::NonNull; - -#[repr(C, align(8))] -pub struct SVecf32 { - varlena: u32, - dims: u16, - kind: u8, - reserved: u8, - len: u16, - padding: [u8; 6], - phantom: [u8; 0], -} - -impl SVecf32 { - fn varlena(size: usize) -> u32 { - (size << 2) as u32 - } - fn layout(len: usize) -> Layout { - u16::try_from(len).expect("Vector is too large."); - let layout = Layout::new::(); - let layout1 = Layout::array::(len).unwrap(); - let layout2 = Layout::array::(len).unwrap(); - let layout = layout.extend(layout1).unwrap().0.pad_to_align(); - layout.extend(layout2).unwrap().0.pad_to_align() - } - pub fn new_in_postgres(vector: SparseF32Ref<'_>) -> SVecf32Output { - unsafe { - let layout = SVecf32::layout(vector.length() as usize); - let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut SVecf32; - ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); - std::ptr::addr_of_mut!((*ptr).varlena).write(SVecf32::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).dims).write(vector.dims); - std::ptr::addr_of_mut!((*ptr).kind).write(2); - std::ptr::addr_of_mut!((*ptr).reserved).write(0); - std::ptr::addr_of_mut!((*ptr).len).write(vector.length()); - std::ptr::addr_of_mut!((*ptr).padding).write(std::mem::zeroed()); - let mut data_ptr = (*ptr).phantom.as_mut_ptr().cast::(); - std::ptr::copy_nonoverlapping( - vector.indexes.as_ptr(), - data_ptr, - vector.length() as usize, - ); - data_ptr = data_ptr.add(vector.length() as usize); - let offset = data_ptr.align_offset(8); - std::ptr::write_bytes(data_ptr, 0, offset); - data_ptr = data_ptr.add(offset); - std::ptr::copy_nonoverlapping( - vector.values.as_ptr(), - data_ptr.cast(), - vector.length() as usize, - ); - SVecf32Output(NonNull::new(ptr).unwrap()) - } - } - pub fn new_zeroed_in_postgres(len: usize) -> SVecf32Output { - unsafe { - let layout = SVecf32::layout(len); - let ptr = pgrx::pg_sys::palloc0(layout.size()) as *mut SVecf32; - ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); - std::ptr::addr_of_mut!((*ptr).varlena).write(SVecf32::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).kind).write(2); - std::ptr::addr_of_mut!((*ptr).reserved).write(0); - std::ptr::addr_of_mut!((*ptr).len).write(len as u16); - SVecf32Output(NonNull::new(ptr).unwrap()) - } - } - pub fn dims(&self) -> u16 { - self.dims - } - pub fn len(&self) -> usize { - self.len as usize - } - fn indexes(&self) -> &[u16] { - let ptr = self.phantom.as_ptr().cast(); - unsafe { std::slice::from_raw_parts(ptr, self.len as usize) } - } - fn values(&self) -> &[F32] { - let len = self.len as usize; - unsafe { - let ptr = self.phantom.as_ptr().cast::().add(len); - let offset = ptr.align_offset(8); - let ptr = ptr.add(offset).cast(); - std::slice::from_raw_parts(ptr, len) - } - } - fn indexes_mut(&mut self) -> &mut [u16] { - let ptr = self.phantom.as_mut_ptr().cast(); - unsafe { std::slice::from_raw_parts_mut(ptr, self.len as usize) } - } - fn values_mut(&mut self) -> &mut [F32] { - let len = self.len as usize; - unsafe { - let ptr = self.phantom.as_mut_ptr().cast::().add(len); - let offset = ptr.align_offset(8); - let ptr = ptr.add(offset).cast(); - std::slice::from_raw_parts_mut(ptr, len) - } - } - pub fn data(&self) -> SparseF32Ref<'_> { - debug_assert_eq!(self.varlena & 3, 0); - debug_assert_eq!(self.kind, 2); - SparseF32Ref { - dims: self.dims, - indexes: self.indexes(), - values: self.values(), - } - } -} - -impl PartialEq for SVecf32 { - fn eq(&self, other: &Self) -> bool { - self.data() == other.data() - } -} - -impl Eq for SVecf32 {} - -impl PartialOrd for SVecf32 { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for SVecf32 { - fn cmp(&self, other: &Self) -> Ordering { - assert!(self.dims() == other.dims()); - let lhs = self.data(); - let rhs = other.data(); - let mut pos = 0; - let size1 = lhs.length() as usize; - let size2 = rhs.length() as usize; - while pos < size1 && pos < size2 { - let lhs_index = lhs.indexes[pos]; - let rhs_index = rhs.indexes[pos]; - let lhs_value = lhs.values[pos]; - let rhs_value = rhs.values[pos]; - match lhs_index.cmp(&rhs_index) { - Ordering::Less => return lhs_value.cmp(&F32::zero()), - Ordering::Greater => return F32::zero().cmp(&rhs_value), - Ordering::Equal => match lhs_value.cmp(&rhs_value) { - Ordering::Equal => {} - x => return x, - }, - } - pos += 1; - } - match size1.cmp(&size2) { - Ordering::Less => F32::zero().cmp(&rhs.values[pos]), - Ordering::Greater => lhs.values[pos].cmp(&F32::zero()), - Ordering::Equal => Ordering::Equal, - } - } -} - -pub enum SVecf32Input<'a> { - Owned(SVecf32Output), - Borrowed(&'a SVecf32), -} - -impl<'a> SVecf32Input<'a> { - pub unsafe fn new(p: NonNull) -> Self { - let q = unsafe { - NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() - }; - if p != q { - SVecf32Input::Owned(SVecf32Output(q)) - } else { - unsafe { SVecf32Input::Borrowed(p.as_ref()) } - } - } -} - -impl Deref for SVecf32Input<'_> { - type Target = SVecf32; - - fn deref(&self) -> &Self::Target { - match self { - SVecf32Input::Owned(x) => x, - SVecf32Input::Borrowed(x) => x, - } - } -} - -pub struct SVecf32Output(NonNull); - -impl SVecf32Output { - pub fn into_raw(self) -> *mut SVecf32 { - let result = self.0.as_ptr(); - std::mem::forget(self); - result - } -} - -impl Deref for SVecf32Output { - type Target = SVecf32; - - fn deref(&self) -> &Self::Target { - unsafe { self.0.as_ref() } - } -} - -impl DerefMut for SVecf32Output { - fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { self.0.as_mut() } - } -} - -impl Drop for SVecf32Output { - fn drop(&mut self) { - unsafe { - pgrx::pg_sys::pfree(self.0.as_ptr() as _); - } - } -} - -impl<'a> FromDatum for SVecf32Input<'a> { - unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { - if is_null { - None - } else { - let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); - unsafe { Some(SVecf32Input::new(ptr)) } - } - } -} - -impl IntoDatum for SVecf32Output { - fn into_datum(self) -> Option { - Some(Datum::from(self.into_raw() as *mut ())) - } - - fn type_oid() -> Oid { - pgrx::wrappers::regtypein("vectors.svector") - } -} - -unsafe impl SqlTranslatable for SVecf32Input<'_> { - fn argument_sql() -> Result { - Ok(SqlMapping::As(String::from("svector"))) - } - fn return_sql() -> Result { - Ok(Returns::One(SqlMapping::As(String::from("svector")))) - } -} - -unsafe impl SqlTranslatable for SVecf32Output { - fn argument_sql() -> Result { - Ok(SqlMapping::As(String::from("svector"))) - } - fn return_sql() -> Result { - Ok(Returns::One(SqlMapping::As(String::from("svector")))) - } -} - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output { - fn solve(option: Option, hint: &str) -> T { - if let Some(x) = option { - x - } else { - bad_literal(hint); - } - } - #[derive(Debug, Clone, Copy, PartialEq, Eq)] - enum State { - MatchingLeft, - Reading, - MatchedRight, - } - use State::*; - let input = input.to_bytes(); - let mut indexes = Vec::::new(); - let mut values = Vec::::new(); - let mut state = MatchingLeft; - let mut token: Option = None; - let mut index = 0; - for &c in input { - match (state, c) { - (MatchingLeft, b'[') => { - state = Reading; - } - (Reading, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => { - let token = token.get_or_insert(String::new()); - token.push(char::from_u32(c as u32).unwrap()); - } - (Reading, b',') => { - let token = solve(token.take(), "Expect a number."); - let value: F32 = solve(token.parse().ok(), "Bad number."); - if !value.is_zero() { - indexes.push(index); - values.push(value); - } - index = match index.checked_add(1) { - Some(x) => x, - None => check_value_dimensions(65536).get(), - }; - } - (Reading, b']') => { - if let Some(token) = token.take() { - let value: F32 = solve(token.parse().ok(), "Bad number."); - if !value.is_zero() { - indexes.push(index); - values.push(value); - } - index = match index.checked_add(1) { - Some(x) => x, - None => check_value_dimensions(65536).get(), - }; - } - state = MatchedRight; - } - (_, b' ') => {} - _ => { - bad_literal(&format!("Bad character with ascii {:#x}.", c)); - } - } - } - if state != MatchedRight { - bad_literal("Bad sequence"); - } - SVecf32::new_in_postgres(SparseF32Ref { - dims: check_value_dimensions(index as usize).get(), - indexes: &indexes, - values: &values, - }) -} - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn _vectors_svecf32_out(vector: SVecf32Input<'_>) -> CString { - let mut buffer = String::new(); - buffer.push('['); - let vec = vector.data().to_dense(); - let mut iter = vec.iter(); - if let Some(x) = iter.next() { - buffer.push_str(format!("{}", x).as_str()); - } - for x in iter { - buffer.push_str(format!(", {}", x).as_str()); - } - buffer.push(']'); - CString::new(buffer).unwrap() -} - -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_svecf32_subscript(internal) RETURNS internal -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_svecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { - #[pgrx::pg_guard] - unsafe extern "C" fn transform( - subscript: *mut pgrx::pg_sys::SubscriptingRef, - indirection: *mut pgrx::pg_sys::List, - pstate: *mut pgrx::pg_sys::ParseState, - is_slice: bool, - is_assignment: bool, - ) { - unsafe { - if (*indirection).length != 1 { - pgrx::pg_sys::error!("type svector does only support one subscript"); - } - if !is_slice { - pgrx::pg_sys::error!("type svector does only support slice fetch"); - } - if is_assignment { - pgrx::pg_sys::error!("type svector does not support subscripted assignment"); - } - let subscript = &mut *subscript; - let ai = (*(*indirection).elements.add(0)).ptr_value as *mut pgrx::pg_sys::A_Indices; - subscript.refupperindexpr = pgrx::pg_sys::lappend( - std::ptr::null_mut(), - if !(*ai).uidx.is_null() { - let subexpr = - pgrx::pg_sys::transformExpr(pstate, (*ai).uidx, (*pstate).p_expr_kind); - let subexpr = pgrx::pg_sys::coerce_to_target_type( - pstate, - subexpr, - pgrx::pg_sys::exprType(subexpr), - pgrx::pg_sys::INT4OID, - -1, - pgrx::pg_sys::CoercionContext_COERCION_ASSIGNMENT, - pgrx::pg_sys::CoercionForm_COERCE_IMPLICIT_CAST, - -1, - ); - if subexpr.is_null() { - pgrx::error!("svector subscript must have type integer"); - } - subexpr.cast() - } else { - std::ptr::null_mut() - }, - ); - subscript.reflowerindexpr = pgrx::pg_sys::lappend( - std::ptr::null_mut(), - if !(*ai).lidx.is_null() { - let subexpr = - pgrx::pg_sys::transformExpr(pstate, (*ai).lidx, (*pstate).p_expr_kind); - let subexpr = pgrx::pg_sys::coerce_to_target_type( - pstate, - subexpr, - pgrx::pg_sys::exprType(subexpr), - pgrx::pg_sys::INT4OID, - -1, - pgrx::pg_sys::CoercionContext_COERCION_ASSIGNMENT, - pgrx::pg_sys::CoercionForm_COERCE_IMPLICIT_CAST, - -1, - ); - if subexpr.is_null() { - pgrx::error!("svector subscript must have type integer"); - } - subexpr.cast() - } else { - std::ptr::null_mut() - }, - ); - subscript.refrestype = subscript.refcontainertype; - } - } - #[pgrx::pg_guard] - unsafe extern "C" fn exec_setup( - _subscript: *const pgrx::pg_sys::SubscriptingRef, - state: *mut pgrx::pg_sys::SubscriptingRefState, - steps: *mut pgrx::pg_sys::SubscriptExecSteps, - ) { - #[derive(Default)] - struct Workspace { - range: Option<(Option, Option)>, - } - #[pgrx::pg_guard] - unsafe extern "C" fn sbs_check_subscripts( - _state: *mut pgrx::pg_sys::ExprState, - op: *mut pgrx::pg_sys::ExprEvalStep, - _econtext: *mut pgrx::pg_sys::ExprContext, - ) -> bool { - unsafe { - let state = &mut *(*op).d.sbsref.state; - let workspace = &mut *(state.workspace as *mut Workspace); - workspace.range = None; - let mut end = None; - let mut start = None; - if state.upperprovided.read() { - if !state.upperindexnull.read() { - let upper = state.upperindex.read().value() as i32; - if upper >= 0 { - end = Some(upper as usize); - } else { - (*op).resnull.write(true); - return false; - } - } else { - (*op).resnull.write(true); - return false; - } - } - if state.lowerprovided.read() { - if !state.lowerindexnull.read() { - let lower = state.lowerindex.read().value() as i32; - if lower >= 0 { - start = Some(lower as usize); - } else { - (*op).resnull.write(true); - return false; - } - } else { - (*op).resnull.write(true); - return false; - } - } - workspace.range = Some((start, end)); - true - } - } - #[pgrx::pg_guard] - unsafe extern "C" fn sbs_fetch( - _state: *mut pgrx::pg_sys::ExprState, - op: *mut pgrx::pg_sys::ExprEvalStep, - _econtext: *mut pgrx::pg_sys::ExprContext, - ) { - unsafe { - let state = &mut *(*op).d.sbsref.state; - let workspace = &mut *(state.workspace as *mut Workspace); - let input = - SVecf32Input::from_datum((*op).resvalue.read(), (*op).resnull.read()).unwrap(); - let Some((start, end)) = workspace.range else { - (*op).resnull.write(true); - return; - }; - let start: u16 = match start.unwrap_or(0).try_into() { - Ok(x) => x, - Err(_) => { - (*op).resnull.write(true); - return; - } - }; - let end: u16 = match end.unwrap_or(input.dims() as usize).try_into() { - Ok(x) => x, - Err(_) => { - (*op).resnull.write(true); - return; - } - }; - if start >= end || end > input.dims() { - (*op).resnull.write(true); - return; - } - let data = input.data(); - let start_index = data.indexes.partition_point(|&x| x < start); - let end_index = data.indexes.partition_point(|&x| x < end); - let mut indexes = data.indexes[start_index..end_index].to_vec(); - indexes.iter_mut().for_each(|x| *x -= start); - let output = SVecf32::new_in_postgres(SparseF32Ref { - dims: end - start, - indexes: &indexes, - values: &data.values[start_index..end_index], - }); - (*op).resnull.write(false); - (*op).resvalue.write(Datum::from(output.into_raw())); - } - } - unsafe { - let state = &mut *state; - let steps = &mut *steps; - assert!(state.numlower == 1); - assert!(state.numupper == 1); - state.workspace = pgrx::pg_sys::palloc(std::mem::size_of::()); - std::ptr::write::(state.workspace.cast(), Workspace::default()); - steps.sbs_check_subscripts = Some(sbs_check_subscripts); - steps.sbs_fetch = Some(sbs_fetch); - steps.sbs_assign = None; - steps.sbs_fetch_old = None; - } - } - static SBSROUTINES: pgrx::pg_sys::SubscriptRoutines = pgrx::pg_sys::SubscriptRoutines { - transform: Some(transform), - exec_setup: Some(exec_setup), - fetch_strict: true, - fetch_leakproof: false, - store_leakproof: false, - }; - std::ptr::addr_of!(SBSROUTINES).into() -} - -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_svecf32_send(svector) RETURNS bytea -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_svecf32_send(vector: SVecf32Input<'_>) -> Datum { - use pgrx::pg_sys::StringInfoData; - unsafe { - let mut buf = StringInfoData::default(); - let dims = vector.dims; - let len = vector.len; - let data = vector.data(); - pgrx::pg_sys::pq_begintypsend(&mut buf); - pgrx::pg_sys::pq_sendbytes(&mut buf, (&dims) as *const u16 as _, 2); - pgrx::pg_sys::pq_sendbytes(&mut buf, (&len) as *const u16 as _, 2); - pgrx::pg_sys::pq_sendbytes( - &mut buf, - data.indexes.as_ptr() as _, - (std::mem::size_of::() * len as usize) as _, - ); - pgrx::pg_sys::pq_sendbytes( - &mut buf, - data.values.as_ptr() as _, - (std::mem::size_of::() * len as usize) as _, - ); - Datum::from(pgrx::pg_sys::pq_endtypsend(&mut buf)) - } -} - -#[pgrx::pg_extern(sql = " -CREATE FUNCTION _vectors_svecf32_recv(internal, oid, integer) RETURNS svector -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_svecf32_recv(internal: pgrx::Internal, _oid: Oid, _typmod: i32) -> SVecf32Output { - use pgrx::pg_sys::StringInfo; - unsafe { - let buf: StringInfo = internal.into_datum().unwrap().cast_mut_ptr(); - let dims = (pgrx::pg_sys::pq_getmsgbytes(buf, 2) as *const u16).read_unaligned(); - let len = (pgrx::pg_sys::pq_getmsgbytes(buf, 2) as *const u16).read_unaligned(); - if dims == 0 || len == 0 { - pgrx::error!("data corruption is detected"); - } - let indexes_bytes = std::mem::size_of::() * len as usize; - let indexes_ptr = pgrx::pg_sys::pq_getmsgbytes(buf, indexes_bytes as _); - let values_bytes = std::mem::size_of::() * len as usize; - let values_ptr = pgrx::pg_sys::pq_getmsgbytes(buf, values_bytes as _); - let mut output = SVecf32::new_zeroed_in_postgres(len as usize); - - let indexes = std::slice::from_raw_parts(indexes_ptr as *const u16, len as usize); - if len > 1 { - for i in 0..len as usize - 1 { - if indexes[i] >= indexes[i + 1] { - pgrx::error!("data corruption is detected"); - } - } - } - if indexes[len as usize - 1] >= dims { - pgrx::error!("data corruption is detected"); - } - - output.dims = dims; - std::ptr::copy( - indexes_ptr, - output.indexes_mut().as_mut_ptr() as _, - indexes_bytes, - ); - std::ptr::copy( - values_ptr, - output.values_mut().as_mut_ptr() as _, - values_bytes, - ); - output - } -} - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn _vectors_to_svector( - dims: i32, - index: pgrx::Array, - value: pgrx::Array, -) -> SVecf32Output { - let dims = check_value_dimensions(dims as usize); - if index.len() != value.len() { - bad_literal("Lengths of index and value are not matched."); - } - if index.contains_nulls() || value.contains_nulls() { - bad_literal("Index or value contains nulls."); - } - let mut vector: Vec<(u16, F32)> = index - .iter_deny_null() - .zip(value.iter_deny_null()) - .map(|(index, value)| { - if index < 0 || index >= dims.get() as i32 { - bad_literal("Index out of bound."); - } - (index as u16, F32(value)) - }) - .collect(); - vector.sort_unstable_by_key(|x| x.0); - if vector.len() > 1 { - for i in 0..vector.len() - 1 { - if vector[i].0 == vector[i + 1].0 { - bad_literal("Duplicated index."); - } - } - } - - let mut indexes = Vec::::with_capacity(vector.len()); - let mut values = Vec::::with_capacity(vector.len()); - for x in vector { - indexes.push(x.0); - values.push(x.1); - } - SVecf32::new_in_postgres(SparseF32Ref { - dims: dims.get(), - indexes: &indexes, - values: &values, - }) -} diff --git a/src/datatype/text_svecf32.rs b/src/datatype/text_svecf32.rs new file mode 100644 index 000000000..0110716f2 --- /dev/null +++ b/src/datatype/text_svecf32.rs @@ -0,0 +1,96 @@ +use super::memory_svecf32::SVecf32Output; +use crate::datatype::memory_svecf32::SVecf32Input; +use crate::prelude::*; +use base::scalar::F32; +use base::vector::{SVecf32Borrowed, VectorBorrowed}; +use pgrx::pg_sys::Oid; +use std::ffi::{CStr, CString}; + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output { + fn solve(option: Option, hint: &str) -> T { + if let Some(x) = option { + x + } else { + bad_literal(hint); + } + } + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + enum State { + MatchingLeft, + Reading, + MatchedRight, + } + use State::*; + let input = input.to_bytes(); + let mut indexes = Vec::::new(); + let mut values = Vec::::new(); + let mut state = MatchingLeft; + let mut token: Option = None; + let mut index = 0; + for &c in input { + match (state, c) { + (MatchingLeft, b'[') => { + state = Reading; + } + (Reading, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => { + let token = token.get_or_insert(String::new()); + token.push(char::from_u32(c as u32).unwrap()); + } + (Reading, b',') => { + let token = solve(token.take(), "Expect a number."); + let value: F32 = solve(token.parse().ok(), "Bad number."); + if !value.is_zero() { + indexes.push(index); + values.push(value); + } + index = match index.checked_add(1) { + Some(x) => x, + None => check_value_dims(65536).get(), + }; + } + (Reading, b']') => { + if let Some(token) = token.take() { + let value: F32 = solve(token.parse().ok(), "Bad number."); + if !value.is_zero() { + indexes.push(index); + values.push(value); + } + index = match index.checked_add(1) { + Some(x) => x, + None => check_value_dims(65536).get(), + }; + } + state = MatchedRight; + } + (_, b' ') => {} + _ => { + bad_literal(&format!("Bad character with ascii {:#x}.", c)); + } + } + } + if state != MatchedRight { + bad_literal("Bad sequence"); + } + SVecf32Output::new(SVecf32Borrowed::new( + check_value_dims(index as usize).get(), + &indexes, + &values, + )) +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_svecf32_out(vector: SVecf32Input<'_>) -> CString { + let mut buffer = String::new(); + buffer.push('['); + let vec = vector.for_borrow().to_vec(); + let mut iter = vec.iter(); + if let Some(x) = iter.next() { + buffer.push_str(format!("{}", x).as_str()); + } + for x in iter { + buffer.push_str(format!(", {}", x).as_str()); + } + buffer.push(']'); + CString::new(buffer).unwrap() +} diff --git a/src/datatype/text_vecf16.rs b/src/datatype/text_vecf16.rs new file mode 100644 index 000000000..cbdee48d0 --- /dev/null +++ b/src/datatype/text_vecf16.rs @@ -0,0 +1,41 @@ +use super::memory_vecf16::Vecf16Output; +use crate::datatype::memory_vecf16::Vecf16Input; +use crate::datatype::typmod::Typmod; +use crate::prelude::*; +use base::vector::Vecf16Borrowed; +use pgrx::pg_sys::Oid; +use std::ffi::{CStr, CString}; + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_vecf16_in(input: &CStr, _oid: Oid, typmod: i32) -> Vecf16Output { + use crate::utils::parse::parse_vector; + let reserve = Typmod::parse_from_i32(typmod) + .unwrap() + .dims() + .map(|x| x.get()) + .unwrap_or(0); + let v = parse_vector(input.to_bytes(), reserve as usize, |s| s.parse().ok()); + match v { + Err(e) => { + bad_literal(&e.to_string()); + } + Ok(vector) => { + check_value_dims(vector.len()); + Vecf16Output::new(Vecf16Borrowed::new(&vector)) + } + } +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_vecf16_out(vector: Vecf16Input<'_>) -> CString { + let mut buffer = String::new(); + buffer.push('['); + if let Some(&x) = vector.slice().first() { + buffer.push_str(format!("{}", x).as_str()); + } + for &x in vector.slice().iter().skip(1) { + buffer.push_str(format!(", {}", x).as_str()); + } + buffer.push(']'); + CString::new(buffer).unwrap() +} diff --git a/src/datatype/text_vecf32.rs b/src/datatype/text_vecf32.rs new file mode 100644 index 000000000..964329bcd --- /dev/null +++ b/src/datatype/text_vecf32.rs @@ -0,0 +1,41 @@ +use super::memory_vecf32::Vecf32Output; +use crate::datatype::memory_vecf32::Vecf32Input; +use crate::datatype::typmod::Typmod; +use crate::prelude::*; +use base::vector::Vecf32Borrowed; +use pgrx::pg_sys::Oid; +use std::ffi::{CStr, CString}; + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_vecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> Vecf32Output { + use crate::utils::parse::parse_vector; + let reserve = Typmod::parse_from_i32(typmod) + .unwrap() + .dims() + .map(|x| x.get()) + .unwrap_or(0); + let v = parse_vector(input.to_bytes(), reserve as usize, |s| s.parse().ok()); + match v { + Err(e) => { + bad_literal(&e.to_string()); + } + Ok(vector) => { + check_value_dims(vector.len()); + Vecf32Output::new(Vecf32Borrowed::new(&vector)) + } + } +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_vecf32_out(vector: Vecf32Input<'_>) -> CString { + let mut buffer = String::new(); + buffer.push('['); + if let Some(&x) = vector.slice().first() { + buffer.push_str(format!("{}", x).as_str()); + } + for &x in vector.slice().iter().skip(1) { + buffer.push_str(format!(", {}", x).as_str()); + } + buffer.push(']'); + CString::new(buffer).unwrap() +} diff --git a/src/datatype/typmod.rs b/src/datatype/typmod.rs index ea8802416..84ebca79c 100644 --- a/src/datatype/typmod.rs +++ b/src/datatype/typmod.rs @@ -50,10 +50,10 @@ fn _vectors_typmod_in(list: Array<&CStr>) -> i32 { -1 } else if list.len() == 1 { let s = list.get(0).unwrap().unwrap().to_str().unwrap(); - let typmod = Typmod::Dims(check_type_dimensions(s.parse::().ok())); + let typmod = Typmod::Dims(check_type_dims(s.parse::().ok())); typmod.into_i32() } else { - check_type_dimensions(None); + check_type_dims(None); unreachable!() } } diff --git a/src/datatype/vecf16.rs b/src/datatype/vecf16.rs deleted file mode 100644 index 24e3bfd64..000000000 --- a/src/datatype/vecf16.rs +++ /dev/null @@ -1,516 +0,0 @@ -use crate::datatype::typmod::Typmod; -use crate::prelude::*; -use pgrx::pg_sys::Datum; -use pgrx::pg_sys::Oid; -use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; -use pgrx::pgrx_sql_entity_graph::metadata::Returns; -use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; -use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; -use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; -use pgrx::FromDatum; -use pgrx::IntoDatum; -use service::prelude::*; -use std::alloc::Layout; -use std::cmp::Ordering; -use std::ffi::CStr; -use std::ffi::CString; -use std::ops::Deref; -use std::ops::DerefMut; -use std::ops::Index; -use std::ops::IndexMut; -use std::ptr::NonNull; - -#[repr(C, align(8))] -pub struct Vecf16 { - varlena: u32, - len: u16, - kind: u8, - reserved: u8, - phantom: [F16; 0], -} - -impl Vecf16 { - fn varlena(size: usize) -> u32 { - (size << 2) as u32 - } - fn layout(len: usize) -> Layout { - u16::try_from(len).expect("Vector is too large."); - let layout_alpha = Layout::new::(); - let layout_beta = Layout::array::(len).unwrap(); - let layout = layout_alpha.extend(layout_beta).unwrap().0; - layout.pad_to_align() - } - pub fn new_in_postgres(slice: &[F16]) -> Vecf16Output { - unsafe { - assert!(1 <= slice.len() && slice.len() <= 65535); - let layout = Vecf16::layout(slice.len()); - let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf16; - ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); - std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).kind).write(1); - std::ptr::addr_of_mut!((*ptr).reserved).write(0); - std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); - std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); - Vecf16Output(NonNull::new(ptr).unwrap()) - } - } - pub fn new_zeroed_in_postgres(size: usize) -> Vecf16Output { - unsafe { - assert!(u16::try_from(size).is_ok()); - let layout = Vecf16::layout(size); - let ptr = pgrx::pg_sys::palloc0(layout.size()) as *mut Vecf16; - ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); - std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf16::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).kind).write(1); - std::ptr::addr_of_mut!((*ptr).reserved).write(0); - std::ptr::addr_of_mut!((*ptr).len).write(size as u16); - Vecf16Output(NonNull::new(ptr).unwrap()) - } - } - pub fn len(&self) -> usize { - self.len as usize - } - pub fn data(&self) -> &[F16] { - debug_assert_eq!(self.varlena & 3, 0); - // TODO: force checking it in the future - // debug_assert_eq!(self.kind, 1); - // debug_assert_eq!(self.reserved, 0); - unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.len as usize) } - } - pub fn data_mut(&mut self) -> &mut [F16] { - debug_assert_eq!(self.varlena & 3, 0); - // TODO: force checking it in the future - // debug_assert_eq!(self.kind, 1); - // debug_assert_eq!(self.reserved, 0); - unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) } - } -} - -impl Deref for Vecf16 { - type Target = [F16]; - - fn deref(&self) -> &Self::Target { - self.data() - } -} - -impl DerefMut for Vecf16 { - fn deref_mut(&mut self) -> &mut Self::Target { - self.data_mut() - } -} - -impl AsRef<[F16]> for Vecf16 { - fn as_ref(&self) -> &[F16] { - self.data() - } -} - -impl AsMut<[F16]> for Vecf16 { - fn as_mut(&mut self) -> &mut [F16] { - self.data_mut() - } -} - -impl Index for Vecf16 { - type Output = F16; - - fn index(&self, index: usize) -> &Self::Output { - self.data().index(index) - } -} - -impl IndexMut for Vecf16 { - fn index_mut(&mut self, index: usize) -> &mut Self::Output { - self.data_mut().index_mut(index) - } -} - -impl PartialEq for Vecf16 { - fn eq(&self, other: &Self) -> bool { - if self.len() != other.len() { - return false; - } - let n = self.len(); - for i in 0..n { - if self[i] != other[i] { - return false; - } - } - true - } -} - -impl Eq for Vecf16 {} - -impl PartialOrd for Vecf16 { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for Vecf16 { - fn cmp(&self, other: &Self) -> Ordering { - use Ordering::*; - if let x @ Less | x @ Greater = self.len().cmp(&other.len()) { - return x; - } - let n = self.len(); - for i in 0..n { - if let x @ Less | x @ Greater = self[i].cmp(&other[i]) { - return x; - } - } - Equal - } -} - -pub enum Vecf16Input<'a> { - Owned(Vecf16Output), - Borrowed(&'a Vecf16), -} - -impl<'a> Vecf16Input<'a> { - pub unsafe fn new(p: NonNull) -> Self { - let q = unsafe { - NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() - }; - if p != q { - Vecf16Input::Owned(Vecf16Output(q)) - } else { - unsafe { Vecf16Input::Borrowed(p.as_ref()) } - } - } -} - -impl Deref for Vecf16Input<'_> { - type Target = Vecf16; - - fn deref(&self) -> &Self::Target { - match self { - Vecf16Input::Owned(x) => x, - Vecf16Input::Borrowed(x) => x, - } - } -} - -pub struct Vecf16Output(NonNull); - -impl Vecf16Output { - pub fn into_raw(self) -> *mut Vecf16 { - let result = self.0.as_ptr(); - std::mem::forget(self); - result - } -} - -impl Deref for Vecf16Output { - type Target = Vecf16; - - fn deref(&self) -> &Self::Target { - unsafe { self.0.as_ref() } - } -} - -impl DerefMut for Vecf16Output { - fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { self.0.as_mut() } - } -} - -impl Drop for Vecf16Output { - fn drop(&mut self) { - unsafe { - pgrx::pg_sys::pfree(self.0.as_ptr() as _); - } - } -} - -impl<'a> FromDatum for Vecf16Input<'a> { - unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { - if is_null { - None - } else { - let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); - unsafe { Some(Vecf16Input::new(ptr)) } - } - } -} - -impl IntoDatum for Vecf16Output { - fn into_datum(self) -> Option { - Some(Datum::from(self.into_raw() as *mut ())) - } - - fn type_oid() -> Oid { - pgrx::wrappers::regtypein("vectors.vecf16") - } -} - -unsafe impl SqlTranslatable for Vecf16Input<'_> { - fn argument_sql() -> Result { - Ok(SqlMapping::As(String::from("vecf16"))) - } - fn return_sql() -> Result { - Ok(Returns::One(SqlMapping::As(String::from("vecf16")))) - } -} - -unsafe impl SqlTranslatable for Vecf16Output { - fn argument_sql() -> Result { - Ok(SqlMapping::As(String::from("vecf16"))) - } - fn return_sql() -> Result { - Ok(Returns::One(SqlMapping::As(String::from("vecf16")))) - } -} - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn _vectors_vecf16_in(input: &CStr, _oid: Oid, typmod: i32) -> Vecf16Output { - use crate::utils::parse::parse_vector; - let reserve = Typmod::parse_from_i32(typmod) - .unwrap() - .dims() - .map(|x| x.get()) - .unwrap_or(0); - let v = parse_vector(input.to_bytes(), reserve as usize, |s| s.parse().ok()); - match v { - Err(e) => { - bad_literal(&e.to_string()); - } - Ok(vector) => { - check_value_dimensions(vector.len()); - Vecf16::new_in_postgres(&vector) - } - } -} - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn _vectors_vecf16_out(vector: Vecf16Input<'_>) -> CString { - let mut buffer = String::new(); - buffer.push('['); - if let Some(&x) = vector.data().first() { - buffer.push_str(format!("{}", x).as_str()); - } - for &x in vector.data().iter().skip(1) { - buffer.push_str(format!(", {}", x).as_str()); - } - buffer.push(']'); - CString::new(buffer).unwrap() -} - -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_vecf16_subscript(internal) RETURNS internal -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_vecf16_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { - #[pgrx::pg_guard] - unsafe extern "C" fn transform( - subscript: *mut pgrx::pg_sys::SubscriptingRef, - indirection: *mut pgrx::pg_sys::List, - pstate: *mut pgrx::pg_sys::ParseState, - is_slice: bool, - is_assignment: bool, - ) { - unsafe { - if (*indirection).length != 1 { - pgrx::pg_sys::error!("type vecf16 does only support one subscript"); - } - if !is_slice { - pgrx::pg_sys::error!("type vecf16 does only support slice fetch"); - } - if is_assignment { - pgrx::pg_sys::error!("type vecf16 does not support subscripted assignment"); - } - let subscript = &mut *subscript; - let ai = (*(*indirection).elements.add(0)).ptr_value as *mut pgrx::pg_sys::A_Indices; - subscript.refupperindexpr = pgrx::pg_sys::lappend( - std::ptr::null_mut(), - if !(*ai).uidx.is_null() { - let subexpr = - pgrx::pg_sys::transformExpr(pstate, (*ai).uidx, (*pstate).p_expr_kind); - let subexpr = pgrx::pg_sys::coerce_to_target_type( - pstate, - subexpr, - pgrx::pg_sys::exprType(subexpr), - pgrx::pg_sys::INT4OID, - -1, - pgrx::pg_sys::CoercionContext_COERCION_ASSIGNMENT, - pgrx::pg_sys::CoercionForm_COERCE_IMPLICIT_CAST, - -1, - ); - if subexpr.is_null() { - pgrx::error!("vecf16 subscript must have type integer"); - } - subexpr.cast() - } else { - std::ptr::null_mut() - }, - ); - subscript.reflowerindexpr = pgrx::pg_sys::lappend( - std::ptr::null_mut(), - if !(*ai).lidx.is_null() { - let subexpr = - pgrx::pg_sys::transformExpr(pstate, (*ai).lidx, (*pstate).p_expr_kind); - let subexpr = pgrx::pg_sys::coerce_to_target_type( - pstate, - subexpr, - pgrx::pg_sys::exprType(subexpr), - pgrx::pg_sys::INT4OID, - -1, - pgrx::pg_sys::CoercionContext_COERCION_ASSIGNMENT, - pgrx::pg_sys::CoercionForm_COERCE_IMPLICIT_CAST, - -1, - ); - if subexpr.is_null() { - pgrx::error!("vecf16 subscript must have type integer"); - } - subexpr.cast() - } else { - std::ptr::null_mut() - }, - ); - subscript.refrestype = subscript.refcontainertype; - } - } - #[pgrx::pg_guard] - unsafe extern "C" fn exec_setup( - _subscript: *const pgrx::pg_sys::SubscriptingRef, - state: *mut pgrx::pg_sys::SubscriptingRefState, - steps: *mut pgrx::pg_sys::SubscriptExecSteps, - ) { - #[derive(Default)] - struct Workspace { - range: Option<(Option, Option)>, - } - #[pgrx::pg_guard] - unsafe extern "C" fn sbs_check_subscripts( - _state: *mut pgrx::pg_sys::ExprState, - op: *mut pgrx::pg_sys::ExprEvalStep, - _econtext: *mut pgrx::pg_sys::ExprContext, - ) -> bool { - unsafe { - let state = &mut *(*op).d.sbsref.state; - let workspace = &mut *(state.workspace as *mut Workspace); - workspace.range = None; - let mut end = None; - let mut start = None; - if state.upperprovided.read() { - if !state.upperindexnull.read() { - let upper = state.upperindex.read().value() as i32; - if upper >= 0 { - end = Some(upper as usize); - } else { - (*op).resnull.write(true); - return false; - } - } else { - (*op).resnull.write(true); - return false; - } - } - if state.lowerprovided.read() { - if !state.lowerindexnull.read() { - let lower = state.lowerindex.read().value() as i32; - if lower >= 0 { - start = Some(lower as usize); - } else { - (*op).resnull.write(true); - return false; - } - } else { - (*op).resnull.write(true); - return false; - } - } - workspace.range = Some((start, end)); - true - } - } - #[pgrx::pg_guard] - unsafe extern "C" fn sbs_fetch( - _state: *mut pgrx::pg_sys::ExprState, - op: *mut pgrx::pg_sys::ExprEvalStep, - _econtext: *mut pgrx::pg_sys::ExprContext, - ) { - unsafe { - let state = &mut *(*op).d.sbsref.state; - let workspace = &mut *(state.workspace as *mut Workspace); - let input = - Vecf16Input::from_datum((*op).resvalue.read(), (*op).resnull.read()).unwrap(); - let slice = match workspace.range { - Some((None, None)) => input.data().get(..), - Some((None, Some(y))) => input.data().get(..y), - Some((Some(x), None)) => input.data().get(x..), - Some((Some(x), Some(y))) => input.data().get(x..y), - None => None, - }; - if let Some(slice) = slice { - if !slice.is_empty() { - let output = Vecf16::new_in_postgres(slice); - (*op).resnull.write(false); - (*op).resvalue.write(Datum::from(output.into_raw())); - } else { - (*op).resnull.write(true); - } - } else { - (*op).resnull.write(true); - } - } - } - unsafe { - let state = &mut *state; - let steps = &mut *steps; - assert!(state.numlower == 1); - assert!(state.numupper == 1); - state.workspace = pgrx::pg_sys::palloc(std::mem::size_of::()); - std::ptr::write::(state.workspace.cast(), Workspace::default()); - steps.sbs_check_subscripts = Some(sbs_check_subscripts); - steps.sbs_fetch = Some(sbs_fetch); - steps.sbs_assign = None; - steps.sbs_fetch_old = None; - } - } - static SBSROUTINES: pgrx::pg_sys::SubscriptRoutines = pgrx::pg_sys::SubscriptRoutines { - transform: Some(transform), - exec_setup: Some(exec_setup), - fetch_strict: true, - fetch_leakproof: false, - store_leakproof: false, - }; - std::ptr::addr_of!(SBSROUTINES).into() -} - -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_vecf16_send(vecf16) RETURNS bytea -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_vecf16_send(vector: Vecf16Input<'_>) -> Datum { - use pgrx::pg_sys::StringInfoData; - unsafe { - let mut buf = StringInfoData::default(); - let len = vector.len; - let bytes = std::mem::size_of::() * len as usize; - pgrx::pg_sys::pq_begintypsend(&mut buf); - pgrx::pg_sys::pq_sendbytes(&mut buf, (&len) as *const u16 as _, 2); - pgrx::pg_sys::pq_sendbytes(&mut buf, vector.data().as_ptr() as _, bytes as _); - Datum::from(pgrx::pg_sys::pq_endtypsend(&mut buf)) - } -} - -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_vecf16_recv(internal, oid, integer) RETURNS vecf16 -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_vecf16_recv(internal: pgrx::Internal, _oid: Oid, _typmod: i32) -> Vecf16Output { - use pgrx::pg_sys::StringInfo; - unsafe { - let buf: StringInfo = internal.into_datum().unwrap().cast_mut_ptr(); - let len = (pgrx::pg_sys::pq_getmsgbytes(buf, 2) as *const u16).read_unaligned(); - if len == 0 { - pgrx::error!("data corruption is detected"); - } - let bytes = std::mem::size_of::() * len as usize; - let ptr = pgrx::pg_sys::pq_getmsgbytes(buf, bytes as _); - let mut output = Vecf16::new_zeroed_in_postgres(len as usize); - std::ptr::copy(ptr, output.data_mut().as_mut_ptr() as _, bytes); - output - } -} diff --git a/src/datatype/vecf32.rs b/src/datatype/vecf32.rs deleted file mode 100644 index dd59882e4..000000000 --- a/src/datatype/vecf32.rs +++ /dev/null @@ -1,516 +0,0 @@ -use crate::datatype::typmod::Typmod; -use crate::prelude::*; -use pgrx::pg_sys::Datum; -use pgrx::pg_sys::Oid; -use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; -use pgrx::pgrx_sql_entity_graph::metadata::Returns; -use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; -use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; -use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; -use pgrx::FromDatum; -use pgrx::IntoDatum; -use service::prelude::*; -use std::alloc::Layout; -use std::cmp::Ordering; -use std::ffi::CStr; -use std::ffi::CString; -use std::ops::Deref; -use std::ops::DerefMut; -use std::ops::Index; -use std::ops::IndexMut; -use std::ptr::NonNull; - -#[repr(C, align(8))] -pub struct Vecf32 { - varlena: u32, - len: u16, - kind: u8, - reserved: u8, - phantom: [F32; 0], -} - -impl Vecf32 { - fn varlena(size: usize) -> u32 { - (size << 2) as u32 - } - fn layout(len: usize) -> Layout { - u16::try_from(len).expect("Vector is too large."); - let layout_alpha = Layout::new::(); - let layout_beta = Layout::array::(len).unwrap(); - let layout = layout_alpha.extend(layout_beta).unwrap().0; - layout.pad_to_align() - } - pub fn new_in_postgres(slice: &[F32]) -> Vecf32Output { - unsafe { - assert!(1 <= slice.len() && slice.len() <= 65535); - let layout = Vecf32::layout(slice.len()); - let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut Vecf32; - ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); - std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).kind).write(0); - std::ptr::addr_of_mut!((*ptr).reserved).write(0); - std::ptr::addr_of_mut!((*ptr).len).write(slice.len() as u16); - std::ptr::copy_nonoverlapping(slice.as_ptr(), (*ptr).phantom.as_mut_ptr(), slice.len()); - Vecf32Output(NonNull::new(ptr).unwrap()) - } - } - pub fn new_zeroed_in_postgres(size: usize) -> Vecf32Output { - unsafe { - assert!(u16::try_from(size).is_ok()); - let layout = Vecf32::layout(size); - let ptr = pgrx::pg_sys::palloc0(layout.size()) as *mut Vecf32; - ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); - std::ptr::addr_of_mut!((*ptr).varlena).write(Vecf32::varlena(layout.size())); - std::ptr::addr_of_mut!((*ptr).kind).write(0); - std::ptr::addr_of_mut!((*ptr).reserved).write(0); - std::ptr::addr_of_mut!((*ptr).len).write(size as u16); - Vecf32Output(NonNull::new(ptr).unwrap()) - } - } - pub fn len(&self) -> usize { - self.len as usize - } - pub fn data(&self) -> &[F32] { - debug_assert_eq!(self.varlena & 3, 0); - // TODO: force checking it in the future - // debug_assert_eq!(self.kind, 0); - // debug_assert_eq!(self.reserved, 1); - unsafe { std::slice::from_raw_parts(self.phantom.as_ptr(), self.len as usize) } - } - pub fn data_mut(&mut self) -> &mut [F32] { - debug_assert_eq!(self.varlena & 3, 0); - // TODO: force checking it in the future - // debug_assert_eq!(self.kind, 0); - // debug_assert_eq!(self.reserved, 1); - unsafe { std::slice::from_raw_parts_mut(self.phantom.as_mut_ptr(), self.len as usize) } - } -} - -impl Deref for Vecf32 { - type Target = [F32]; - - fn deref(&self) -> &Self::Target { - self.data() - } -} - -impl DerefMut for Vecf32 { - fn deref_mut(&mut self) -> &mut Self::Target { - self.data_mut() - } -} - -impl AsRef<[F32]> for Vecf32 { - fn as_ref(&self) -> &[F32] { - self.data() - } -} - -impl AsMut<[F32]> for Vecf32 { - fn as_mut(&mut self) -> &mut [F32] { - self.data_mut() - } -} - -impl Index for Vecf32 { - type Output = F32; - - fn index(&self, index: usize) -> &Self::Output { - self.data().index(index) - } -} - -impl IndexMut for Vecf32 { - fn index_mut(&mut self, index: usize) -> &mut Self::Output { - self.data_mut().index_mut(index) - } -} - -impl PartialEq for Vecf32 { - fn eq(&self, other: &Self) -> bool { - if self.len() != other.len() { - return false; - } - let n = self.len(); - for i in 0..n { - if self[i] != other[i] { - return false; - } - } - true - } -} - -impl Eq for Vecf32 {} - -impl PartialOrd for Vecf32 { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for Vecf32 { - fn cmp(&self, other: &Self) -> Ordering { - use Ordering::*; - if let x @ Less | x @ Greater = self.len().cmp(&other.len()) { - return x; - } - let n = self.len(); - for i in 0..n { - if let x @ Less | x @ Greater = self[i].cmp(&other[i]) { - return x; - } - } - Equal - } -} - -pub enum Vecf32Input<'a> { - Owned(Vecf32Output), - Borrowed(&'a Vecf32), -} - -impl<'a> Vecf32Input<'a> { - pub unsafe fn new(p: NonNull) -> Self { - let q = unsafe { - NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() - }; - if p != q { - Vecf32Input::Owned(Vecf32Output(q)) - } else { - unsafe { Vecf32Input::Borrowed(p.as_ref()) } - } - } -} - -impl Deref for Vecf32Input<'_> { - type Target = Vecf32; - - fn deref(&self) -> &Self::Target { - match self { - Vecf32Input::Owned(x) => x, - Vecf32Input::Borrowed(x) => x, - } - } -} - -pub struct Vecf32Output(NonNull); - -impl Vecf32Output { - pub fn into_raw(self) -> *mut Vecf32 { - let result = self.0.as_ptr(); - std::mem::forget(self); - result - } -} - -impl Deref for Vecf32Output { - type Target = Vecf32; - - fn deref(&self) -> &Self::Target { - unsafe { self.0.as_ref() } - } -} - -impl DerefMut for Vecf32Output { - fn deref_mut(&mut self) -> &mut Self::Target { - unsafe { self.0.as_mut() } - } -} - -impl Drop for Vecf32Output { - fn drop(&mut self) { - unsafe { - pgrx::pg_sys::pfree(self.0.as_ptr() as _); - } - } -} - -impl<'a> FromDatum for Vecf32Input<'a> { - unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { - if is_null { - None - } else { - let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); - unsafe { Some(Vecf32Input::new(ptr)) } - } - } -} - -impl IntoDatum for Vecf32Output { - fn into_datum(self) -> Option { - Some(Datum::from(self.into_raw() as *mut ())) - } - - fn type_oid() -> Oid { - pgrx::wrappers::regtypein("vectors.vector") - } -} - -unsafe impl SqlTranslatable for Vecf32Input<'_> { - fn argument_sql() -> Result { - Ok(SqlMapping::As(String::from("vector"))) - } - fn return_sql() -> Result { - Ok(Returns::One(SqlMapping::As(String::from("vector")))) - } -} - -unsafe impl SqlTranslatable for Vecf32Output { - fn argument_sql() -> Result { - Ok(SqlMapping::As(String::from("vector"))) - } - fn return_sql() -> Result { - Ok(Returns::One(SqlMapping::As(String::from("vector")))) - } -} - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn _vectors_vecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> Vecf32Output { - use crate::utils::parse::parse_vector; - let reserve = Typmod::parse_from_i32(typmod) - .unwrap() - .dims() - .map(|x| x.get()) - .unwrap_or(0); - let v = parse_vector(input.to_bytes(), reserve as usize, |s| s.parse().ok()); - match v { - Err(e) => { - bad_literal(&e.to_string()); - } - Ok(vector) => { - check_value_dimensions(vector.len()); - Vecf32::new_in_postgres(&vector) - } - } -} - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] -fn _vectors_vecf32_out(vector: Vecf32Input<'_>) -> CString { - let mut buffer = String::new(); - buffer.push('['); - if let Some(&x) = vector.data().first() { - buffer.push_str(format!("{}", x).as_str()); - } - for &x in vector.data().iter().skip(1) { - buffer.push_str(format!(", {}", x).as_str()); - } - buffer.push(']'); - CString::new(buffer).unwrap() -} - -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_vecf32_subscript(internal) RETURNS internal -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_vecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { - #[pgrx::pg_guard] - unsafe extern "C" fn transform( - subscript: *mut pgrx::pg_sys::SubscriptingRef, - indirection: *mut pgrx::pg_sys::List, - pstate: *mut pgrx::pg_sys::ParseState, - is_slice: bool, - is_assignment: bool, - ) { - unsafe { - if (*indirection).length != 1 { - pgrx::pg_sys::error!("type vector does only support one subscript"); - } - if !is_slice { - pgrx::pg_sys::error!("type vector does only support slice fetch"); - } - if is_assignment { - pgrx::pg_sys::error!("type vector does not support subscripted assignment"); - } - let subscript = &mut *subscript; - let ai = (*(*indirection).elements.add(0)).ptr_value as *mut pgrx::pg_sys::A_Indices; - subscript.refupperindexpr = pgrx::pg_sys::lappend( - std::ptr::null_mut(), - if !(*ai).uidx.is_null() { - let subexpr = - pgrx::pg_sys::transformExpr(pstate, (*ai).uidx, (*pstate).p_expr_kind); - let subexpr = pgrx::pg_sys::coerce_to_target_type( - pstate, - subexpr, - pgrx::pg_sys::exprType(subexpr), - pgrx::pg_sys::INT4OID, - -1, - pgrx::pg_sys::CoercionContext_COERCION_ASSIGNMENT, - pgrx::pg_sys::CoercionForm_COERCE_IMPLICIT_CAST, - -1, - ); - if subexpr.is_null() { - pgrx::error!("vector subscript must have type integer"); - } - subexpr.cast() - } else { - std::ptr::null_mut() - }, - ); - subscript.reflowerindexpr = pgrx::pg_sys::lappend( - std::ptr::null_mut(), - if !(*ai).lidx.is_null() { - let subexpr = - pgrx::pg_sys::transformExpr(pstate, (*ai).lidx, (*pstate).p_expr_kind); - let subexpr = pgrx::pg_sys::coerce_to_target_type( - pstate, - subexpr, - pgrx::pg_sys::exprType(subexpr), - pgrx::pg_sys::INT4OID, - -1, - pgrx::pg_sys::CoercionContext_COERCION_ASSIGNMENT, - pgrx::pg_sys::CoercionForm_COERCE_IMPLICIT_CAST, - -1, - ); - if subexpr.is_null() { - pgrx::error!("vector subscript must have type integer"); - } - subexpr.cast() - } else { - std::ptr::null_mut() - }, - ); - subscript.refrestype = subscript.refcontainertype; - } - } - #[pgrx::pg_guard] - unsafe extern "C" fn exec_setup( - _subscript: *const pgrx::pg_sys::SubscriptingRef, - state: *mut pgrx::pg_sys::SubscriptingRefState, - steps: *mut pgrx::pg_sys::SubscriptExecSteps, - ) { - #[derive(Default)] - struct Workspace { - range: Option<(Option, Option)>, - } - #[pgrx::pg_guard] - unsafe extern "C" fn sbs_check_subscripts( - _state: *mut pgrx::pg_sys::ExprState, - op: *mut pgrx::pg_sys::ExprEvalStep, - _econtext: *mut pgrx::pg_sys::ExprContext, - ) -> bool { - unsafe { - let state = &mut *(*op).d.sbsref.state; - let workspace = &mut *(state.workspace as *mut Workspace); - workspace.range = None; - let mut end = None; - let mut start = None; - if state.upperprovided.read() { - if !state.upperindexnull.read() { - let upper = state.upperindex.read().value() as i32; - if upper >= 0 { - end = Some(upper as usize); - } else { - (*op).resnull.write(true); - return false; - } - } else { - (*op).resnull.write(true); - return false; - } - } - if state.lowerprovided.read() { - if !state.lowerindexnull.read() { - let lower = state.lowerindex.read().value() as i32; - if lower >= 0 { - start = Some(lower as usize); - } else { - (*op).resnull.write(true); - return false; - } - } else { - (*op).resnull.write(true); - return false; - } - } - workspace.range = Some((start, end)); - true - } - } - #[pgrx::pg_guard] - unsafe extern "C" fn sbs_fetch( - _state: *mut pgrx::pg_sys::ExprState, - op: *mut pgrx::pg_sys::ExprEvalStep, - _econtext: *mut pgrx::pg_sys::ExprContext, - ) { - unsafe { - let state = &mut *(*op).d.sbsref.state; - let workspace = &mut *(state.workspace as *mut Workspace); - let input = - Vecf32Input::from_datum((*op).resvalue.read(), (*op).resnull.read()).unwrap(); - let slice = match workspace.range { - Some((None, None)) => input.data().get(..), - Some((None, Some(y))) => input.data().get(..y), - Some((Some(x), None)) => input.data().get(x..), - Some((Some(x), Some(y))) => input.data().get(x..y), - None => None, - }; - if let Some(slice) = slice { - if !slice.is_empty() { - let output = Vecf32::new_in_postgres(slice); - (*op).resnull.write(false); - (*op).resvalue.write(Datum::from(output.into_raw())); - } else { - (*op).resnull.write(true); - } - } else { - (*op).resnull.write(true); - } - } - } - unsafe { - let state = &mut *state; - let steps = &mut *steps; - assert!(state.numlower == 1); - assert!(state.numupper == 1); - state.workspace = pgrx::pg_sys::palloc(std::mem::size_of::()); - std::ptr::write::(state.workspace.cast(), Workspace::default()); - steps.sbs_check_subscripts = Some(sbs_check_subscripts); - steps.sbs_fetch = Some(sbs_fetch); - steps.sbs_assign = None; - steps.sbs_fetch_old = None; - } - } - static SBSROUTINES: pgrx::pg_sys::SubscriptRoutines = pgrx::pg_sys::SubscriptRoutines { - transform: Some(transform), - exec_setup: Some(exec_setup), - fetch_strict: true, - fetch_leakproof: false, - store_leakproof: false, - }; - std::ptr::addr_of!(SBSROUTINES).into() -} - -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_vecf32_send(vector) RETURNS bytea -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_vecf32_send(vector: Vecf32Input<'_>) -> Datum { - use pgrx::pg_sys::StringInfoData; - unsafe { - let mut buf = StringInfoData::default(); - let len = vector.len; - let bytes = std::mem::size_of::() * len as usize; - pgrx::pg_sys::pq_begintypsend(&mut buf); - pgrx::pg_sys::pq_sendbytes(&mut buf, (&len) as *const u16 as _, 2); - pgrx::pg_sys::pq_sendbytes(&mut buf, vector.data().as_ptr() as _, bytes as _); - Datum::from(pgrx::pg_sys::pq_endtypsend(&mut buf)) - } -} - -#[pgrx::pg_extern(sql = " -CREATE FUNCTION _vectors_vecf32_recv(internal, oid, integer) RETURNS vector -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_vecf32_recv(internal: pgrx::Internal, _oid: Oid, _typmod: i32) -> Vecf32Output { - use pgrx::pg_sys::StringInfo; - unsafe { - let buf: StringInfo = internal.into_datum().unwrap().cast_mut_ptr(); - let len = (pgrx::pg_sys::pq_getmsgbytes(buf, 2) as *const u16).read_unaligned(); - if len == 0 { - pgrx::error!("data corruption is detected"); - } - let bytes = std::mem::size_of::() * len as usize; - let ptr = pgrx::pg_sys::pq_getmsgbytes(buf, bytes as _); - let mut output = Vecf32::new_zeroed_in_postgres(len as usize); - std::ptr::copy(ptr, output.data_mut().as_mut_ptr() as _, bytes); - output - } -} diff --git a/src/embedding/mod.rs b/src/embedding/mod.rs index 1a3fe9f23..7ea35378b 100644 --- a/src/embedding/mod.rs +++ b/src/embedding/mod.rs @@ -1,8 +1,8 @@ -use crate::datatype::vecf32::{Vecf32, Vecf32Output}; +use crate::datatype::memory_vecf32::Vecf32Output; use crate::gucs::embedding::openai_options; +use crate::prelude::*; use embedding::openai_embedding; use pgrx::error; -use service::prelude::F32; #[pgrx::pg_extern(volatile, strict)] fn _vectors_text2vec_openai(input: String, model: String) -> Vecf32Output { @@ -16,5 +16,5 @@ fn _vectors_text2vec_openai(input: String, model: String) -> Vecf32Output { Err(e) => error!("{}", e.to_string()), }; - Vecf32::new_in_postgres(&embedding) + Vecf32Output::new(Vecf32Borrowed::new(&embedding)) } diff --git a/src/gucs/executing.rs b/src/gucs/executing.rs index 08d874846..95cff8faf 100644 --- a/src/gucs/executing.rs +++ b/src/gucs/executing.rs @@ -1,5 +1,5 @@ +use crate::prelude::*; use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting}; -use service::index::SearchOptions; static ENABLE_PREFILTER: GucSetting = GucSetting::::new(true); diff --git a/src/index/am.rs b/src/index/am.rs index 202e30b86..7e3cec4df 100644 --- a/src/index/am.rs +++ b/src/index/am.rs @@ -9,7 +9,6 @@ use crate::index::utils::from_datum; use crate::prelude::*; use crate::utils::cells::PgCell; use pgrx::pg_sys::Datum; -use service::prelude::*; static RELOPT_KIND: PgCell = unsafe { PgCell::new(0) }; diff --git a/src/index/am_build.rs b/src/index/am_build.rs index 38652c3b4..4b21c5afe 100644 --- a/src/index/am_build.rs +++ b/src/index/am_build.rs @@ -5,7 +5,6 @@ use crate::index::utils::from_datum; use crate::ipc::ClientRpc; use crate::prelude::*; use pgrx::pg_sys::{IndexBuildResult, IndexInfo, RelationData}; -use service::prelude::*; pub struct Builder { pub rpc: ClientRpc, diff --git a/src/index/am_scan.rs b/src/index/am_scan.rs index 6bbf09709..246c455f2 100644 --- a/src/index/am_scan.rs +++ b/src/index/am_scan.rs @@ -7,12 +7,11 @@ use crate::index::utils::from_datum; use crate::ipc::{ClientBasic, ClientVbase}; use crate::prelude::*; use pgrx::FromDatum; -use service::prelude::*; pub enum Scanner { Initial { node: Option<*mut pgrx::pg_sys::IndexScanState>, - vector: Option, + vector: Option, }, Basic { node: *mut pgrx::pg_sys::IndexScanState, diff --git a/src/index/am_setup.rs b/src/index/am_setup.rs index 6e1b1bfda..f71a03688 100644 --- a/src/index/am_setup.rs +++ b/src/index/am_setup.rs @@ -3,11 +3,6 @@ use crate::datatype::typmod::Typmod; use crate::prelude::*; use serde::Deserialize; -use service::index::indexing::IndexingOptions; -use service::index::optimizing::OptimizingOptions; -use service::index::segments::SegmentsOptions; -use service::index::{IndexOptions, VectorOptions}; -use service::prelude::*; use std::ffi::CStr; pub fn helper_offset() -> usize { @@ -18,7 +13,9 @@ pub fn helper_size() -> usize { std::mem::size_of::() } -pub unsafe fn convert_opclass_to_distance(opclass: pgrx::pg_sys::Oid) -> (Distance, Kind) { +pub unsafe fn convert_opclass_to_distance( + opclass: pgrx::pg_sys::Oid, +) -> (DistanceKind, VectorKind) { let opclass_cache_id = pgrx::pg_sys::SysCacheIdentifier_CLAOID as _; let tuple = pgrx::pg_sys::SearchSysCache1(opclass_cache_id, opclass.into()); assert!( @@ -32,7 +29,9 @@ pub unsafe fn convert_opclass_to_distance(opclass: pgrx::pg_sys::Oid) -> (Distan result } -pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> (Distance, Kind) { +pub unsafe fn convert_opfamily_to_distance( + opfamily: pgrx::pg_sys::Oid, +) -> (DistanceKind, VectorKind) { let opfamily_cache_id = pgrx::pg_sys::SysCacheIdentifier_OPFAMILYOID as _; let opstrategy_cache_id = pgrx::pg_sys::SysCacheIdentifier_AMOPSTRATEGY as _; let tuple = pgrx::pg_sys::SearchSysCache1(opfamily_cache_id, opfamily.into()); @@ -56,23 +55,23 @@ pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> (Dist let operator = (*amop).amopopr; let result; if operator == regoperatorin("vectors.<->(vectors.vector,vectors.vector)") { - result = (Distance::L2, Kind::F32); + result = (DistanceKind::L2, VectorKind::Vecf32); } else if operator == regoperatorin("vectors.<#>(vectors.vector,vectors.vector)") { - result = (Distance::Dot, Kind::F32); + result = (DistanceKind::Dot, VectorKind::Vecf32); } else if operator == regoperatorin("vectors.<=>(vectors.vector,vectors.vector)") { - result = (Distance::Cos, Kind::F32); + result = (DistanceKind::Cos, VectorKind::Vecf32); } else if operator == regoperatorin("vectors.<->(vectors.vecf16,vectors.vecf16)") { - result = (Distance::L2, Kind::F16); + result = (DistanceKind::L2, VectorKind::Vecf16); } else if operator == regoperatorin("vectors.<#>(vectors.vecf16,vectors.vecf16)") { - result = (Distance::Dot, Kind::F16); + result = (DistanceKind::Dot, VectorKind::Vecf16); } else if operator == regoperatorin("vectors.<=>(vectors.vecf16,vectors.vecf16)") { - result = (Distance::Cos, Kind::F16); + result = (DistanceKind::Cos, VectorKind::Vecf16); } else if operator == regoperatorin("vectors.<->(vectors.svector,vectors.svector)") { - result = (Distance::L2, Kind::SparseF32); + result = (DistanceKind::L2, VectorKind::SVecf32); } else if operator == regoperatorin("vectors.<#>(vectors.svector,vectors.svector)") { - result = (Distance::Dot, Kind::SparseF32); + result = (DistanceKind::Dot, VectorKind::SVecf32); } else if operator == regoperatorin("vectors.<=>(vectors.svector,vectors.svector)") { - result = (Distance::Cos, Kind::SparseF32); + result = (DistanceKind::Cos, VectorKind::SVecf32); } else { bad_opclass(); }; @@ -91,11 +90,11 @@ pub unsafe fn options(index_relation: pgrx::pg_sys::Relation) -> IndexOptions { let attrs = (*(*index_relation).rd_att).attrs.as_slice(1); let attr = &attrs[0]; let typmod = Typmod::parse_from_i32(attr.type_mod()).unwrap(); - let dims = check_column_dimensions(typmod.dims()).get(); + let dims = check_column_dims(typmod.dims()).get(); // get other options let parsed = get_parsed_from_varlena((*index_relation).rd_options); IndexOptions { - vector: VectorOptions { dims, d, k }, + vector: VectorOptions { dims, d, v: k }, segment: parsed.segment, optimizing: parsed.optimizing, indexing: parsed.indexing, diff --git a/src/index/am_update.rs b/src/index/am_update.rs index 638db50ca..5d6fdbe99 100644 --- a/src/index/am_update.rs +++ b/src/index/am_update.rs @@ -1,8 +1,7 @@ use crate::index::hook_transaction::callback_dirty; use crate::prelude::*; -use service::prelude::*; -pub fn update_insert(handle: Handle, vector: DynamicVector, tid: pgrx::pg_sys::ItemPointerData) { +pub fn update_insert(handle: Handle, vector: OwnedVector, tid: pgrx::pg_sys::ItemPointerData) { callback_dirty(handle); let pointer = Pointer::from_sys(tid); diff --git a/src/index/hook_transaction.rs b/src/index/hook_transaction.rs index 530fd6a32..9af88ded6 100644 --- a/src/index/hook_transaction.rs +++ b/src/index/hook_transaction.rs @@ -1,6 +1,5 @@ use crate::prelude::*; use crate::utils::cells::PgRefCell; -use service::prelude::*; use std::collections::BTreeSet; use std::ops::DerefMut; diff --git a/src/index/utils.rs b/src/index/utils.rs index 132fbe9ca..ebd285454 100644 --- a/src/index/utils.rs +++ b/src/index/utils.rs @@ -1,27 +1,32 @@ #![allow(unsafe_op_in_unsafe_fn)] -use crate::datatype::svecf32::SVecf32; -use crate::datatype::vecf16::Vecf16; -use crate::datatype::vecf32::Vecf32; -use service::prelude::*; +use crate::datatype::memory_svecf32::SVecf32Header; +use crate::datatype::memory_vecf16::Vecf16Header; +use crate::datatype::memory_vecf32::Vecf32Header; +use crate::prelude::*; #[repr(C, align(8))] struct Header { varlena: u32, - len: u16, - kind: u8, - reserved: u8, + dims: u16, + kind: u16, } -pub unsafe fn from_datum(datum: pgrx::pg_sys::Datum) -> DynamicVector { +pub unsafe fn from_datum(datum: pgrx::pg_sys::Datum) -> OwnedVector { let p = datum.cast_mut_ptr::(); let q = pgrx::pg_sys::pg_detoast_datum(p); let vector = match (*q.cast::
()).kind { - 0 => DynamicVector::F32((*q.cast::()).data().to_vec()), - 1 => DynamicVector::F16((*q.cast::()).data().to_vec()), + 0 => { + let v = &*q.cast::(); + OwnedVector::Vecf32(v.for_borrow().for_own()) + } + 1 => { + let v = &*q.cast::(); + OwnedVector::Vecf16(v.for_borrow().for_own()) + } 2 => { - let svec = &*q.cast::(); - DynamicVector::SparseF32(SparseF32::from(svec.data())) + let v = &*q.cast::(); + OwnedVector::SVecF32(v.for_borrow().for_own()) } _ => unreachable!(), }; diff --git a/src/index/views.rs b/src/index/views.rs index 5b2d316e7..6163a538d 100644 --- a/src/index/views.rs +++ b/src/index/views.rs @@ -1,12 +1,10 @@ use crate::prelude::*; -use service::prelude::*; #[pgrx::pg_extern(volatile, strict)] fn _vectors_index_stat( oid: pgrx::pg_sys::Oid, ) -> pgrx::composite_type!('static, "vectors.vector_index_stat") { use pgrx::heap_tuple::PgHeapTuple; - use service::index::IndexStat; let id = Handle::from_sys(oid); let mut res = PgHeapTuple::new_composite_type("vectors.vector_index_stat").unwrap(); let mut rpc = check_client(crate::ipc::client()); diff --git a/src/ipc/mod.rs b/src/ipc/mod.rs index 5c9d46cd7..3ce17a076 100644 --- a/src/ipc/mod.rs +++ b/src/ipc/mod.rs @@ -7,10 +7,6 @@ use crate::ipc::transport::Packet; use crate::prelude::*; use crate::utils::cells::PgRefCell; use serde::{Deserialize, Serialize}; -use service::index::IndexOptions; -use service::index::IndexStat; -use service::index::SearchOptions; -use service::prelude::*; #[derive(Debug, Clone)] pub enum ConnectionError { @@ -323,10 +319,10 @@ defines! { unary create(handle: Handle, options: IndexOptions) -> (); unary drop(handle: Handle) -> (); unary flush(handle: Handle) -> (); - unary insert(handle: Handle, vector: DynamicVector, pointer: Pointer) -> (); + unary insert(handle: Handle, vector: OwnedVector, pointer: Pointer) -> (); unary delete(handle: Handle, pointer: Pointer) -> (); - stream basic(handle: Handle, vector: DynamicVector, opts: SearchOptions) -> Pointer; - stream vbase(handle: Handle, vector: DynamicVector, opts: SearchOptions) -> Pointer; + stream basic(handle: Handle, vector: OwnedVector, opts: SearchOptions) -> Pointer; + stream vbase(handle: Handle, vector: OwnedVector, opts: SearchOptions) -> Pointer; stream list(handle: Handle) -> Pointer; unary stat(handle: Handle) -> IndexStat; } diff --git a/src/lib.rs b/src/lib.rs index 54d25db80..2e70cf7d2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,9 @@ //! Postgres vector extension. //! //! Provides an easy-to-use extension for vector similarity search. -#![feature(never_type)] #![feature(alloc_error_hook)] +#![allow(clippy::needless_range_loop)] +#![allow(clippy::too_many_arguments)] mod bgworker; mod datatype; @@ -33,9 +34,6 @@ unsafe extern "C" fn _PG_init() { } } -#[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "freebsd")))] -compile_error!("Target is not supported."); - #[cfg(not(all(target_endian = "little", target_pointer_width = "64")))] compile_error!("Target is not supported."); diff --git a/src/prelude/error.rs b/src/prelude/error.rs index 571a03873..c56ed9ea8 100644 --- a/src/prelude/error.rs +++ b/src/prelude/error.rs @@ -19,8 +19,8 @@ INFORMATION: GUC = {key}, hint = {hint}" ); } -pub fn check_type_dimensions(dimensions: Option) -> NonZeroU16 { - match dimensions { +pub fn check_type_dims(dims: Option) -> NonZeroU16 { + match dims { None => { error!( "\ @@ -32,11 +32,8 @@ ADVICE: Check if modifier of the type is an integer among 1 and 65535." } } -pub fn check_value_dimensions(dimensions: usize) -> NonZeroU16 { - match u16::try_from(dimensions) - .and_then(NonZeroU16::try_from) - .ok() - { +pub fn check_value_dims(dims: usize) -> NonZeroU16 { + match u16::try_from(dims).and_then(NonZeroU16::try_from).ok() { None => { error!( "\ @@ -57,20 +54,20 @@ INFORMATION: hint = {hint}" } #[inline(always)] -pub fn check_matched_dimensions(left_dimensions: usize, right_dimensions: usize) -> usize { - if left_dimensions != right_dimensions { +pub fn check_matched_dims(left_dims: usize, right_dims: usize) -> usize { + if left_dims != right_dims { error!( "\ pgvecto.rs: Operands of the operator differs in dimensions or scalar type. -INFORMATION: left_dimensions = {left_dimensions}, right_dimensions = {right_dimensions}", +INFORMATION: left_dimensions = {left_dims}, right_dimensions = {right_dims}", ) } - left_dimensions + left_dims } #[inline(always)] -pub fn check_column_dimensions(dimensions: Option) -> NonZeroU16 { - match dimensions { +pub fn check_column_dims(dims: Option) -> NonZeroU16 { + match dims { None => error!( "\ pgvecto.rs: Dimensions type modifier of a vector column is needed for building the index.", diff --git a/src/prelude/mod.rs b/src/prelude/mod.rs index af364e499..fa77bc414 100644 --- a/src/prelude/mod.rs +++ b/src/prelude/mod.rs @@ -3,3 +3,11 @@ mod sys; pub use error::*; pub use sys::{FromSys, IntoSys}; + +pub use base::distance::*; +pub use base::error::*; +pub use base::index::*; +pub use base::scalar::*; +pub use base::search::*; +pub use base::vector::*; +pub use num_traits::Zero; diff --git a/src/prelude/sys.rs b/src/prelude/sys.rs index 0c56d5d72..ea596fdb8 100644 --- a/src/prelude/sys.rs +++ b/src/prelude/sys.rs @@ -1,4 +1,4 @@ -use service::prelude::*; +use crate::prelude::*; pub trait FromSys { fn from_sys(sys: T) -> Self; diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql index b490e6684..a29560753 100644 --- a/src/sql/finalize.sql +++ b/src/sql/finalize.sql @@ -5,9 +5,9 @@ CREATE TYPE vector ( INPUT = _vectors_vecf32_in, OUTPUT = _vectors_vecf32_out, - RECEIVE = _vectors_vecf32_recv, - SEND = _vectors_vecf32_send, - SUBSCRIPT = _vectors_vecf32_subscript, + RECEIVE = _vectors_vecf32_recv, + SEND = _vectors_vecf32_send, + SUBSCRIPT = _vectors_vecf32_subscript, TYPMOD_IN = _vectors_typmod_in, TYPMOD_OUT = _vectors_typmod_out, STORAGE = EXTERNAL, @@ -18,9 +18,9 @@ CREATE TYPE vector ( CREATE TYPE vecf16 ( INPUT = _vectors_vecf16_in, OUTPUT = _vectors_vecf16_out, - RECEIVE = _vectors_vecf16_recv, - SEND = _vectors_vecf16_send, - SUBSCRIPT = _vectors_vecf16_subscript, + RECEIVE = _vectors_vecf16_recv, + SEND = _vectors_vecf16_send, + SUBSCRIPT = _vectors_vecf16_subscript, TYPMOD_IN = _vectors_typmod_in, TYPMOD_OUT = _vectors_typmod_out, STORAGE = EXTERNAL, @@ -31,9 +31,9 @@ CREATE TYPE vecf16 ( CREATE TYPE svector ( INPUT = _vectors_svecf32_in, OUTPUT = _vectors_svecf32_out, - RECEIVE = _vectors_svecf32_recv, - SEND = _vectors_svecf32_send, - SUBSCRIPT = _vectors_svecf32_subscript, + RECEIVE = _vectors_svecf32_recv, + SEND = _vectors_svecf32_send, + SUBSCRIPT = _vectors_svecf32_subscript, TYPMOD_IN = _vectors_typmod_in, TYPMOD_OUT = _vectors_typmod_out, STORAGE = EXTERNAL, @@ -55,287 +55,287 @@ CREATE TYPE vector_index_stat AS ( -- List of operators CREATE OPERATOR + ( - PROCEDURE = _vectors_vecf32_operator_add, - LEFTARG = vector, - RIGHTARG = vector, - COMMUTATOR = + + PROCEDURE = _vectors_vecf32_operator_add, + LEFTARG = vector, + RIGHTARG = vector, + COMMUTATOR = + ); CREATE OPERATOR + ( - PROCEDURE = _vectors_vecf16_operator_add, - LEFTARG = vecf16, - RIGHTARG = vecf16, - COMMUTATOR = + + PROCEDURE = _vectors_vecf16_operator_add, + LEFTARG = vecf16, + RIGHTARG = vecf16, + COMMUTATOR = + ); CREATE OPERATOR + ( - PROCEDURE = _vectors_svecf32_operator_add, - LEFTARG = svector, - RIGHTARG = svector, - COMMUTATOR = + + PROCEDURE = _vectors_svecf32_operator_add, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = + ); CREATE OPERATOR - ( - PROCEDURE = _vectors_vecf32_operator_minus, - LEFTARG = vector, - RIGHTARG = vector + PROCEDURE = _vectors_vecf32_operator_minus, + LEFTARG = vector, + RIGHTARG = vector ); CREATE OPERATOR - ( - PROCEDURE = _vectors_vecf16_operator_minus, - LEFTARG = vecf16, - RIGHTARG = vecf16 + PROCEDURE = _vectors_vecf16_operator_minus, + LEFTARG = vecf16, + RIGHTARG = vecf16 ); CREATE OPERATOR - ( - PROCEDURE = _vectors_svecf32_operator_minus, - LEFTARG = svector, - RIGHTARG = svector + PROCEDURE = _vectors_svecf32_operator_minus, + LEFTARG = svector, + RIGHTARG = svector ); CREATE OPERATOR = ( - PROCEDURE = _vectors_vecf32_operator_eq, - LEFTARG = vector, - RIGHTARG = vector, - COMMUTATOR = =, - NEGATOR = <>, - RESTRICT = eqsel, - JOIN = eqjoinsel + PROCEDURE = _vectors_vecf32_operator_eq, + LEFTARG = vector, + RIGHTARG = vector, + COMMUTATOR = =, + NEGATOR = <>, + RESTRICT = eqsel, + JOIN = eqjoinsel ); CREATE OPERATOR = ( - PROCEDURE = _vectors_vecf16_operator_eq, - LEFTARG = vecf16, - RIGHTARG = vecf16, - COMMUTATOR = =, - NEGATOR = <>, - RESTRICT = eqsel, - JOIN = eqjoinsel + PROCEDURE = _vectors_vecf16_operator_eq, + LEFTARG = vecf16, + RIGHTARG = vecf16, + COMMUTATOR = =, + NEGATOR = <>, + RESTRICT = eqsel, + JOIN = eqjoinsel ); CREATE OPERATOR = ( - PROCEDURE = _vectors_svecf32_operator_eq, - LEFTARG = svector, - RIGHTARG = svector, - COMMUTATOR = =, - NEGATOR = <>, - RESTRICT = eqsel, - JOIN = eqjoinsel + PROCEDURE = _vectors_svecf32_operator_eq, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = =, + NEGATOR = <>, + RESTRICT = eqsel, + JOIN = eqjoinsel ); CREATE OPERATOR <> ( - PROCEDURE = _vectors_vecf32_operator_neq, - LEFTARG = vector, - RIGHTARG = vector, - COMMUTATOR = <>, - NEGATOR = =, - RESTRICT = eqsel, - JOIN = eqjoinsel + PROCEDURE = _vectors_vecf32_operator_neq, + LEFTARG = vector, + RIGHTARG = vector, + COMMUTATOR = <>, + NEGATOR = =, + RESTRICT = eqsel, + JOIN = eqjoinsel ); CREATE OPERATOR <> ( - PROCEDURE = _vectors_vecf16_operator_neq, - LEFTARG = vecf16, - RIGHTARG = vecf16, - COMMUTATOR = <>, - NEGATOR = =, - RESTRICT = eqsel, - JOIN = eqjoinsel + PROCEDURE = _vectors_vecf16_operator_neq, + LEFTARG = vecf16, + RIGHTARG = vecf16, + COMMUTATOR = <>, + NEGATOR = =, + RESTRICT = eqsel, + JOIN = eqjoinsel ); CREATE OPERATOR <> ( - PROCEDURE = _vectors_svecf32_operator_neq, - LEFTARG = svector, - RIGHTARG = svector, - COMMUTATOR = <>, - NEGATOR = =, - RESTRICT = eqsel, - JOIN = eqjoinsel + PROCEDURE = _vectors_svecf32_operator_neq, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = <>, + NEGATOR = =, + RESTRICT = eqsel, + JOIN = eqjoinsel ); CREATE OPERATOR < ( - PROCEDURE = _vectors_vecf32_operator_lt, - LEFTARG = vector, - RIGHTARG = vector, - COMMUTATOR = >, - NEGATOR = >=, - RESTRICT = scalarltsel, - JOIN = scalarltjoinsel + PROCEDURE = _vectors_vecf32_operator_lt, + LEFTARG = vector, + RIGHTARG = vector, + COMMUTATOR = >, + NEGATOR = >=, + RESTRICT = scalarltsel, + JOIN = scalarltjoinsel ); CREATE OPERATOR < ( - PROCEDURE = _vectors_vecf16_operator_lt, - LEFTARG = vecf16, - RIGHTARG = vecf16, - COMMUTATOR = >, - NEGATOR = >=, - RESTRICT = scalarltsel, - JOIN = scalarltjoinsel + PROCEDURE = _vectors_vecf16_operator_lt, + LEFTARG = vecf16, + RIGHTARG = vecf16, + COMMUTATOR = >, + NEGATOR = >=, + RESTRICT = scalarltsel, + JOIN = scalarltjoinsel ); CREATE OPERATOR < ( - PROCEDURE = _vectors_svecf32_operator_lt, - LEFTARG = svector, - RIGHTARG = svector, - COMMUTATOR = >, - NEGATOR = >=, - RESTRICT = scalarltsel, - JOIN = scalarltjoinsel + PROCEDURE = _vectors_svecf32_operator_lt, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = >, + NEGATOR = >=, + RESTRICT = scalarltsel, + JOIN = scalarltjoinsel ); CREATE OPERATOR > ( - PROCEDURE = _vectors_vecf32_operator_gt, - LEFTARG = vector, - RIGHTARG = vector, - COMMUTATOR = <, - NEGATOR = <=, - RESTRICT = scalargtsel, - JOIN = scalargtjoinsel + PROCEDURE = _vectors_vecf32_operator_gt, + LEFTARG = vector, + RIGHTARG = vector, + COMMUTATOR = <, + NEGATOR = <=, + RESTRICT = scalargtsel, + JOIN = scalargtjoinsel ); CREATE OPERATOR > ( - PROCEDURE = _vectors_vecf16_operator_gt, - LEFTARG = vecf16, - RIGHTARG = vecf16, - COMMUTATOR = <, - NEGATOR = <=, - RESTRICT = scalargtsel, - JOIN = scalargtjoinsel + PROCEDURE = _vectors_vecf16_operator_gt, + LEFTARG = vecf16, + RIGHTARG = vecf16, + COMMUTATOR = <, + NEGATOR = <=, + RESTRICT = scalargtsel, + JOIN = scalargtjoinsel ); CREATE OPERATOR > ( - PROCEDURE = _vectors_svecf32_operator_gt, - LEFTARG = svector, - RIGHTARG = svector, - COMMUTATOR = <, - NEGATOR = <=, - RESTRICT = scalargtsel, - JOIN = scalargtjoinsel + PROCEDURE = _vectors_svecf32_operator_gt, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = <, + NEGATOR = <=, + RESTRICT = scalargtsel, + JOIN = scalargtjoinsel ); CREATE OPERATOR <= ( - PROCEDURE = _vectors_vecf32_operator_lte, - LEFTARG = vector, - RIGHTARG = vector, - COMMUTATOR = >=, - NEGATOR = >, - RESTRICT = scalarltsel, - JOIN = scalarltjoinsel + PROCEDURE = _vectors_vecf32_operator_lte, + LEFTARG = vector, + RIGHTARG = vector, + COMMUTATOR = >=, + NEGATOR = >, + RESTRICT = scalarltsel, + JOIN = scalarltjoinsel ); CREATE OPERATOR <= ( - PROCEDURE = _vectors_vecf16_operator_lte, - LEFTARG = vecf16, - RIGHTARG = vecf16, - COMMUTATOR = >=, - NEGATOR = >, - RESTRICT = scalarltsel, - JOIN = scalarltjoinsel + PROCEDURE = _vectors_vecf16_operator_lte, + LEFTARG = vecf16, + RIGHTARG = vecf16, + COMMUTATOR = >=, + NEGATOR = >, + RESTRICT = scalarltsel, + JOIN = scalarltjoinsel ); CREATE OPERATOR <= ( - PROCEDURE = _vectors_svecf32_operator_lte, - LEFTARG = svector, - RIGHTARG = svector, - COMMUTATOR = >=, - NEGATOR = >, - RESTRICT = scalarltsel, - JOIN = scalarltjoinsel + PROCEDURE = _vectors_svecf32_operator_lte, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = >=, + NEGATOR = >, + RESTRICT = scalarltsel, + JOIN = scalarltjoinsel ); CREATE OPERATOR >= ( - PROCEDURE = _vectors_vecf32_operator_gte, - LEFTARG = vector, - RIGHTARG = vector, - COMMUTATOR = <=, - NEGATOR = <, - RESTRICT = scalargtsel, - JOIN = scalargtjoinsel + PROCEDURE = _vectors_vecf32_operator_gte, + LEFTARG = vector, + RIGHTARG = vector, + COMMUTATOR = <=, + NEGATOR = <, + RESTRICT = scalargtsel, + JOIN = scalargtjoinsel ); CREATE OPERATOR >= ( - PROCEDURE = _vectors_vecf16_operator_gte, - LEFTARG = vecf16, - RIGHTARG = vecf16, - COMMUTATOR = <=, - NEGATOR = <, - RESTRICT = scalargtsel, - JOIN = scalargtjoinsel + PROCEDURE = _vectors_vecf16_operator_gte, + LEFTARG = vecf16, + RIGHTARG = vecf16, + COMMUTATOR = <=, + NEGATOR = <, + RESTRICT = scalargtsel, + JOIN = scalargtjoinsel ); CREATE OPERATOR >= ( - PROCEDURE = _vectors_svecf32_operator_gte, - LEFTARG = svector, - RIGHTARG = svector, - COMMUTATOR = <=, - NEGATOR = <, - RESTRICT = scalargtsel, - JOIN = scalargtjoinsel + PROCEDURE = _vectors_svecf32_operator_gte, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = <=, + NEGATOR = <, + RESTRICT = scalargtsel, + JOIN = scalargtjoinsel ); CREATE OPERATOR <-> ( - PROCEDURE = _vectors_vecf32_operator_l2, - LEFTARG = vector, - RIGHTARG = vector, - COMMUTATOR = <-> + PROCEDURE = _vectors_vecf32_operator_l2, + LEFTARG = vector, + RIGHTARG = vector, + COMMUTATOR = <-> ); CREATE OPERATOR <-> ( - PROCEDURE = _vectors_vecf16_operator_l2, - LEFTARG = vecf16, - RIGHTARG = vecf16, - COMMUTATOR = <-> + PROCEDURE = _vectors_vecf16_operator_l2, + LEFTARG = vecf16, + RIGHTARG = vecf16, + COMMUTATOR = <-> ); CREATE OPERATOR <-> ( - PROCEDURE = _vectors_svecf32_operator_l2, - LEFTARG = svector, - RIGHTARG = svector, - COMMUTATOR = <-> + PROCEDURE = _vectors_svecf32_operator_l2, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = <-> ); CREATE OPERATOR <#> ( - PROCEDURE = _vectors_vecf32_operator_dot, - LEFTARG = vector, - RIGHTARG = vector, - COMMUTATOR = <#> + PROCEDURE = _vectors_vecf32_operator_dot, + LEFTARG = vector, + RIGHTARG = vector, + COMMUTATOR = <#> ); CREATE OPERATOR <#> ( - PROCEDURE = _vectors_vecf16_operator_dot, - LEFTARG = vecf16, - RIGHTARG = vecf16, - COMMUTATOR = <#> + PROCEDURE = _vectors_vecf16_operator_dot, + LEFTARG = vecf16, + RIGHTARG = vecf16, + COMMUTATOR = <#> ); CREATE OPERATOR <#> ( - PROCEDURE = _vectors_svecf32_operator_dot, - LEFTARG = svector, - RIGHTARG = svector, - COMMUTATOR = <#> + PROCEDURE = _vectors_svecf32_operator_dot, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = <#> ); CREATE OPERATOR <=> ( - PROCEDURE = _vectors_vecf32_operator_cosine, - LEFTARG = vector, - RIGHTARG = vector, - COMMUTATOR = <=> + PROCEDURE = _vectors_vecf32_operator_cosine, + LEFTARG = vector, + RIGHTARG = vector, + COMMUTATOR = <=> ); CREATE OPERATOR <=> ( - PROCEDURE = _vectors_vecf16_operator_cosine, - LEFTARG = vecf16, - RIGHTARG = vecf16, - COMMUTATOR = <=> + PROCEDURE = _vectors_vecf16_operator_cosine, + LEFTARG = vecf16, + RIGHTARG = vecf16, + COMMUTATOR = <=> ); CREATE OPERATOR <=> ( - PROCEDURE = _vectors_svecf32_operator_cosine, - LEFTARG = svector, - RIGHTARG = svector, - COMMUTATOR = <=> + PROCEDURE = _vectors_svecf32_operator_cosine, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = <=> ); -- List of functions @@ -343,7 +343,7 @@ CREATE OPERATOR <=> ( CREATE FUNCTION pgvectors_upgrade() RETURNS void STRICT LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_pgvectors_upgrade_wrapper'; -CREATE FUNCTION to_svector(dims INT, indices INT[], vals real[]) RETURNS svector +CREATE FUNCTION to_svector(dims INT, indexes INT[], "values" real[]) RETURNS svector IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_to_svector_wrapper'; CREATE FUNCTION text2vec_openai(input TEXT, model TEXT) RETURNS vector