diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index f756ddb28..218216796 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -46,8 +46,6 @@ jobs: strategy: matrix: include: - - { version: 12, os: "ubuntu-latest" } - - { version: 13, os: "ubuntu-latest" } - { version: 14, os: "ubuntu-latest" } - { version: 15, os: "ubuntu-latest" } - { version: 16, os: "ubuntu-latest" } @@ -56,19 +54,67 @@ jobs: VERSION: ${{ matrix.version }} OS: ${{ matrix.os }} steps: - - uses: actions/checkout@v3 - - uses: actions/cache/restore@v3 + - uses: actions/checkout@v4 + - uses: actions/cache/restore@v4 + id: cache with: path: | ~/.cargo/registry/index/ ~/.cargo/registry/cache/ ~/.cargo/git/db/ - key: cargo-${{ matrix.os }}-pg${{ matrix.version }}-${{ hashFiles('./Cargo.lock') }} - restore-keys: cargo-${{ matrix.os }}-pg${{ matrix.version }} - - uses: mozilla-actions/sccache-action@v0.0.3 + key: ${{ github.job }}-${{ matrix.version }}-${{ matrix.os }}-${{ hashFiles('./Cargo.lock') }} + - uses: mozilla-actions/sccache-action@v0.0.4 - name: Setup shell: bash - run: ./scripts/ci_setup.sh + run: | + ./scripts/ci_setup.sh + curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash + cargo binstall sqllogictest-bin -y --force + cargo install cargo-pgrx@$(grep 'pgrx = {' Cargo.toml | cut -d '"' -f 2 | head -n 1) --debug + cargo pgrx init --pg$VERSION=$(which pg_config) + - name: Install release + run: ./scripts/ci_install.sh + - uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Test + run: ./tests/tests.sh + - uses: actions/cache/save@v4 + if: ${{ !steps.cache.outputs.cache-hit }} + with: + path: | + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + key: ${{ github.job }}-${{ matrix.version }}-${{ matrix.os }}-${{ hashFiles('./Cargo.lock') }} + debug_check: + strategy: + matrix: + include: + - { version: 14, os: "ubuntu-latest" } + - { version: 15, os: "ubuntu-latest" } + - { version: 16, os: "ubuntu-latest" } + runs-on: ${{ matrix.os }} + env: + VERSION: ${{ matrix.version }} + OS: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + - uses: actions/cache/restore@v4 + id: cache + with: + path: | + ~/.cargo/registry/index/ + ~/.cargo/registry/cache/ + ~/.cargo/git/db/ + key: ${{ github.job }}-${{ matrix.version }}-${{ matrix.os }}-${{ hashFiles('./Cargo.lock') }} + - uses: mozilla-actions/sccache-action@v0.0.4 + - name: Setup + shell: bash + run: | + ./scripts/ci_setup.sh + cargo install cargo-pgrx@$(grep 'pgrx = {' Cargo.toml | cut -d '"' -f 2 | head -n 1) --debug + cargo pgrx init --pg$VERSION=$(which pg_config) - name: Format check run: cargo fmt --check - name: Semantic check @@ -82,17 +128,11 @@ jobs: - name: Test run: | cargo test --all --no-fail-fast --no-default-features --features "pg${{ matrix.version }} pg_test" --target x86_64-unknown-linux-gnu -- --nocapture - - name: Install release - run: ./scripts/ci_install.sh - - uses: actions/setup-python@v5 - with: - python-version: "3.10" - - name: Test 2 - run: ./tests/tests.sh - - uses: actions/cache/save@v3 + - uses: actions/cache/save@v4 + if: ${{ !steps.cache.outputs.cache-hit }} with: path: | ~/.cargo/registry/index/ ~/.cargo/registry/cache/ ~/.cargo/git/db/ - key: cargo-${{ matrix.os }}-pg${{ matrix.version }}-${{ hashFiles('./Cargo.lock') }} + key: ${{ github.job }}-${{ matrix.version }}-${{ matrix.os }}-${{ hashFiles('./Cargo.lock') }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c673cdc00..c9cae5330 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -55,14 +55,6 @@ jobs: run: | sed -i "s/@CARGO_VERSION@/${{ needs.semver.outputs.version }}/g" ./vectors.control cat ./vectors.control - - uses: actions/cache/restore@v3 - with: - path: | - ~/.cargo/registry/index/ - ~/.cargo/registry/cache/ - ~/.cargo/git/db/ - key: cargo-${{ runner.os }}-pg${{ matrix.version }}-${{ hashFiles('./Cargo.lock') }} - restore-keys: cargo-${{ runner.os }}-pg${{ matrix.version }} - uses: mozilla-actions/sccache-action@v0.0.3 - name: Prepare run: | diff --git a/Cargo.lock b/Cargo.lock index 0ef1ea605..609d782de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,17 +26,59 @@ dependencies = [ "memchr", ] +[[package]] +name = "anstream" +version = "0.6.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96b09b5178381e0874812a9b157f7fe84982617e48f71f4e3235482775e5b540" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "utf8parse", +] + [[package]] name = "anstyle" -version = "1.0.5" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" + +[[package]] +name = "anstyle-parse" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c75ac65da39e5fe5ab759307499ddad880d724eed2f6ce5b5e8a26f4f387928c" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2faccea4cc4ab4a667ce676a30e8ec13922a692c99bb8f5b11f1502c72e04220" +checksum = "e28923312444cdd728e4738b3f9c9cac739500909bb3d3c94b43551b16517648" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cd54b81ec8d6180e24654d0b371ad22fc3dd083b6ff8ba325b72e00c87660a7" +dependencies = [ + "anstyle", + "windows-sys 0.52.0", +] [[package]] name = "anyhow" -version = "1.0.79" +version = "1.0.80" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca" +checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1" [[package]] name = "arc-swap" @@ -49,9 +91,6 @@ name = "arrayvec" version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" -dependencies = [ - "serde", -] [[package]] name = "async-trait" @@ -61,7 +100,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.49", ] [[package]] @@ -104,6 +143,36 @@ dependencies = [ "rustc-demangle", ] +[[package]] +name = "base" +version = "0.0.0" +dependencies = [ + "arc-swap", + "bincode", + "bytemuck", + "byteorder", + "c", + "crc32fast", + "crossbeam", + "dashmap", + "detect", + "half 2.3.1", + "libc", + "log", + "memmap2", + "multiversion", + "num-traits", + "parking_lot", + "rand", + "rayon", + "rustix", + "serde", + "serde_json", + "thiserror", + "uuid", + "validator", +] + [[package]] name = "base64" version = "0.21.7" @@ -121,22 +190,22 @@ dependencies = [ [[package]] name = "bindgen" -version = "0.69.2" +version = "0.69.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4c69fae65a523209d34240b60abe0c42d33d1045d445c0839d8a4894a736e2d" +checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" dependencies = [ "bitflags 2.4.2", "cexpr", "clang-sys", + "itertools", "lazy_static", "lazycell", - "peeking_take_while", "proc-macro2", "quote", "regex", "rustc-hash", "shlex", - "syn 2.0.48", + "syn 2.0.49", ] [[package]] @@ -189,15 +258,15 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.14.0" +version = "3.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" +checksum = "d32a994c2b3ca201d9b263612a374263f05e7adde37c4707f693dcd375076d1f" [[package]] name = "bytemuck" -version = "1.14.1" +version = "1.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed2490600f404f2b94c167e31d3ed1d5f3c225a0f3b80230053b3e0b7b962bd9" +checksum = "a2ef034f05691a48569bd920a96c81b9d91bbad1ab5ac7c4616c1f6ef36cb79f" dependencies = [ "bytemuck_derive", ] @@ -210,7 +279,7 @@ checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.49", ] [[package]] @@ -282,9 +351,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.4.18" +version = "4.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e578d6ec4194633722ccf9544794b71b1385c3c027efe0c55db226fc880865c" +checksum = "c918d541ef2913577a0f9566e9ce27cb35b6df072075769e0b26cb5a554520da" dependencies = [ "clap_builder", "clap_derive", @@ -302,9 +371,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.4.18" +version = "4.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4df4df40ec50c46000231c914968278b1eb05098cf8f1b3a518a95030e71d1c7" +checksum = "9f3e7391dad68afb0c2ede1bf619f579a3dc9c2ec67f089baa397123a2f3d1eb" dependencies = [ "anstyle", "clap_lex", @@ -312,21 +381,27 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.4.7" +version = "4.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442" +checksum = "307bc0538d5f0f83b8248db3087aa92fe504e4691294d0c96c0eabc33f47ba47" dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.49", ] [[package]] name = "clap_lex" -version = "0.6.0" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" + +[[package]] +name = "colorchoice" +version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" +checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" [[package]] name = "convert_case" @@ -354,9 +429,9 @@ dependencies = [ [[package]] name = "crc32fast" -version = "1.3.2" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d" +checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" dependencies = [ "cfg-if", ] @@ -506,9 +581,9 @@ checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" [[package]] name = "either" -version = "1.9.0" +version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" [[package]] name = "enum-map" @@ -527,20 +602,30 @@ checksum = "f282cfdfe92516eb26c2af8589c274c7c17681f5ecc03c18255fe741c6aa64eb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.49", +] + +[[package]] +name = "env_filter" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea" +dependencies = [ + "log", + "regex", ] [[package]] name = "env_logger" -version = "0.10.2" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580" +checksum = "6c012a26a7f605efc424dd53697843a72be7dc86ad2d01f7814337794a12231d" dependencies = [ + "anstream", + "anstyle", + "env_filter", "humantime", - "is-terminal", "log", - "regex", - "termcolor", ] [[package]] @@ -638,7 +723,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.49", ] [[package]] @@ -756,12 +841,6 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" -[[package]] -name = "hermit-abi" -version = "0.3.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d3d0e0f38255e7fa3cf31335b3a56f05febd18025f4db5ef7a0cfb4f8da651f" - [[package]] name = "hmac" version = "0.12.1" @@ -811,23 +890,29 @@ checksum = "ce23b50ad8242c51a442f3ff322d56b02f08852c77e4c0b4d3fd684abc89c683" [[package]] name = "indexmap" -version = "2.2.2" +version = "2.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "824b2ae422412366ba479e8111fd301f7b5faece8149317bb81925979a53f520" +checksum = "233cf39063f058ea2caae4091bf4a3ef70a653afbc026f5c4a4135d114e3c177" dependencies = [ "equivalent", "hashbrown", ] [[package]] -name = "is-terminal" -version = "0.4.10" +name = "interprocess_atomic_wait" +version = "0.0.0" +dependencies = [ + "libc", + "ulock-sys", +] + +[[package]] +name = "itertools" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" dependencies = [ - "hermit-abi", - "rustix", - "windows-sys 0.52.0", + "either", ] [[package]] @@ -838,9 +923,9 @@ checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] name = "js-sys" -version = "0.3.67" +version = "0.3.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a1d36f1235bc969acba30b7f5990b864423a6068a10f7c90ae8f0112e3a59d1" +checksum = "406cda4b368d531c842222cf9d2600a9a4acce8d29423695379c6868a143a9ee" dependencies = [ "wasm-bindgen", ] @@ -928,6 +1013,15 @@ version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" +[[package]] +name = "memfd" +version = "0.0.0" +dependencies = [ + "detect", + "rand", + "rustix", +] + [[package]] name = "memmap2" version = "0.9.4" @@ -954,9 +1048,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.7.1" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7810e0be55b428ada41041c41f32c9f1a42817901b4ccf45fa3d4b6561e74c7" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" dependencies = [ "adler", ] @@ -1015,9 +1109,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.17" +version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" dependencies = [ "autocfg", "libm", @@ -1073,6 +1167,12 @@ dependencies = [ "windows-targets 0.48.5", ] +[[package]] +name = "paste" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" + [[package]] name = "pathsearch" version = "0.2.0" @@ -1083,12 +1183,6 @@ dependencies = [ "libc", ] -[[package]] -name = "peeking_take_while" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" - [[package]] name = "percent-encoding" version = "2.3.1" @@ -1097,9 +1191,9 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pest" -version = "2.7.6" +version = "2.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f200d8d83c44a45b21764d1916299752ca035d15ecd46faca3e9a2a2bf6ad06" +checksum = "219c0dcc30b6a27553f9cc242972b67f75b60eb0db71f0b5462f38b058c41546" dependencies = [ "memchr", "thiserror", @@ -1526,7 +1620,7 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" dependencies = [ - "semver 1.0.21", + "semver 1.0.22", ] [[package]] @@ -1556,9 +1650,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" [[package]] name = "same-file" @@ -1592,9 +1686,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.21" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" +checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" [[package]] name = "semver-parser" @@ -1605,6 +1699,15 @@ dependencies = [ "pest", ] +[[package]] +name = "send_fd" +version = "0.0.0" +dependencies = [ + "libc", + "log", + "rustix", +] + [[package]] name = "seq-macro" version = "0.3.5" @@ -1638,7 +1741,7 @@ checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.49", ] [[package]] @@ -1666,7 +1769,7 @@ name = "service" version = "0.0.0" dependencies = [ "arc-swap", - "arrayvec", + "base", "bincode", "bytemuck", "byteorder", @@ -1679,7 +1782,6 @@ dependencies = [ "libc", "log", "memmap2", - "memoffset", "multiversion", "num-traits", "parking_lot", @@ -1689,7 +1791,6 @@ dependencies = [ "serde", "serde_json", "thiserror", - "ulock-sys", "uuid", "validator", ] @@ -1802,9 +1903,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.48" +version = "2.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +checksum = "915aea9e586f80826ee59f8453c1101f9d1c4b3964cd2460185ee8e299ada496" dependencies = [ "proc-macro2", "quote", @@ -1840,44 +1941,34 @@ checksum = "cfb5fa503293557c5158bd215fdc225695e567a77e453f5d4452a50a193969bd" [[package]] name = "tempfile" -version = "3.9.0" +version = "3.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01ce4141aa927a6d1bd34a041795abd0db1cccba5d5f24b009f694bdf3a1f3fa" +checksum = "a365e8cd18e44762ef95d87f284f4b5cd04107fec2ff3052bd6a3e6069669e67" dependencies = [ "cfg-if", "fastrand", - "redox_syscall", "rustix", "windows-sys 0.52.0", ] -[[package]] -name = "termcolor" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" -dependencies = [ - "winapi-util", -] - [[package]] name = "thiserror" -version = "1.0.56" +version = "1.0.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" +checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.56" +version = "1.0.57" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" +checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.49", ] [[package]] @@ -1897,9 +1988,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.35.1" +version = "1.36.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" +checksum = "61285f6515fa018fb2d1e46eb21223fff441ee8db5d0f1435e8ab4f5cdb80931" dependencies = [ "backtrace", "bytes", @@ -1952,9 +2043,9 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.9" +version = "0.8.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6a4b9e8023eb94392d3dca65d717c53abc5dad49c07cb65bb8fcd87115fa325" +checksum = "9a9aad4a3066010876e8dcf5a8a06e70a558751117a145c6ce2b82c2e2054290" dependencies = [ "serde", "serde_spanned", @@ -1973,9 +2064,9 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.21.1" +version = "0.22.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6a8534fd7f78b5405e860340ad6575217ce99f38d4d5c8f2442cb5ecb50090e1" +checksum = "2c1b5fd4128cc8d3e0cb74d4ed9a9cc7c7284becd4df68f5f940e1ad123606f6" dependencies = [ "indexmap", "serde", @@ -2059,9 +2150,9 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.10.1" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" +checksum = "d4c87d22b6e3f4a18d4d40ef354e97c90fcb14dd91d7dc0aa9d8a1172ebf7202" [[package]] name = "url" @@ -2074,6 +2165,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf8parse" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" + [[package]] name = "uuid" version = "1.7.0" @@ -2131,19 +2228,25 @@ name = "vectors" version = "0.0.0" dependencies = [ "arrayvec", + "base", "bincode", "bytemuck", "byteorder", "detect", "env_logger", "half 2.3.1", + "interprocess_atomic_wait", "libc", "log", + "memfd", + "memmap2", "num-traits", + "paste", "pgrx", "pgrx-tests", "rand", "rustix", + "send_fd", "serde", "serde_json", "service", @@ -2185,9 +2288,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" -version = "0.2.90" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1223296a201415c7fad14792dbefaace9bd52b62d33453ade1c5b5f07555406" +checksum = "c1e124130aee3fb58c5bdd6b639a0509486b0338acaaae0c84a5124b0f588b7f" dependencies = [ "cfg-if", "wasm-bindgen-macro", @@ -2195,24 +2298,24 @@ dependencies = [ [[package]] name = "wasm-bindgen-backend" -version = "0.2.90" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcdc935b63408d58a32f8cc9738a0bffd8f05cc7c002086c6ef20b7312ad9dcd" +checksum = "c9e7e1900c352b609c8488ad12639a311045f40a35491fb69ba8c12f758af70b" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.49", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-macro" -version = "0.2.90" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e4c238561b2d428924c49815533a8b9121c664599558a5d9ec51f8a1740a999" +checksum = "b30af9e2d358182b5c7449424f017eba305ed32a7010509ede96cdc4696c46ed" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -2220,28 +2323,28 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.90" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bae1abb6806dc1ad9e560ed242107c0f6c84335f1749dd4e8ddb012ebd5e25a7" +checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" dependencies = [ "proc-macro2", "quote", - "syn 2.0.48", + "syn 2.0.49", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" -version = "0.2.90" +version = "0.2.91" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d91413b1c31d7539ba5ef2451af3f0b833a005eb27a631cec32bc0635a8602b" +checksum = "4f186bd2dcf04330886ce82d6f33dd75a7bfcf69ecf5763b89fcde53b6ac9838" [[package]] name = "web-sys" -version = "0.3.67" +version = "0.3.68" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "58cd2333b6e0be7a39605f0e255892fd7418a682d8da8fe042fe25128794d2ed" +checksum = "96565907687f7aceb35bc5fc03770a8a0471d82e479f25832f54a0e3f4b28446" dependencies = [ "js-sys", "wasm-bindgen", @@ -2422,9 +2525,9 @@ checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" [[package]] name = "winnow" -version = "0.5.36" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "818ce546a11a9986bc24f93d0cdf38a8a1a400f1473ea8c82e59f6e0ffab9249" +checksum = "d90f4e0f530c4c69f62b80d839e9ef3855edc9cba471a160c4d692deed62b401" dependencies = [ "memchr", ] diff --git a/Cargo.toml b/Cargo.toml index 4c6fe2852..5b7e5896b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,33 +8,37 @@ crate-type = ["cdylib", "lib"] [features] default = ["pg15"] -pg12 = ["pgrx/pg12", "pgrx-tests/pg12"] -pg13 = ["pgrx/pg13", "pgrx-tests/pg13"] pg14 = ["pgrx/pg14", "pgrx-tests/pg14"] pg15 = ["pgrx/pg15", "pgrx-tests/pg15"] pg16 = ["pgrx/pg16", "pgrx-tests/pg16"] pg_test = [] [dependencies] +arrayvec.workspace = true +bincode.workspace = true +bytemuck.workspace = true +byteorder.workspace = true +half.workspace = true libc.workspace = true log.workspace = true +memmap2.workspace = true +num-traits.workspace = true +paste.workspace = true +rand.workspace = true +rustix.workspace = true serde.workspace = true serde_json.workspace = true -validator.workspace = true -rustix.workspace = true thiserror.workspace = true -byteorder.workspace = true -bincode.workspace = true -half.workspace = true -num-traits.workspace = true -rand.workspace = true -bytemuck.workspace = true -service = { path = "crates/service" } +validator.workspace = true +base = { path = "crates/base" } detect = { path = "crates/detect" } +send_fd = { path = "crates/send_fd" } +service = { path = "crates/service" } +interprocess_atomic_wait = { path = "crates/interprocess-atomic-wait" } +memfd = { path = "crates/memfd" } pgrx = { version = "0.11.3", default-features = false, features = [] } -env_logger = "0.10.0" -toml = "0.8.8" -arrayvec = "0.7.4" +env_logger = "0.11.2" +toml = "0.8.10" [dev-dependencies] pgrx-tests = "0.11.3" @@ -60,24 +64,29 @@ version = "0.0.0" edition = "2021" [workspace.dependencies] -libc = "~0.2" -log = "~0.4" -serde = "~1.0" -serde_json = "1" -thiserror = "~1.0" +arrayvec = "~0.7" bincode = "~1.3" -byteorder = "~1.5" bytemuck = { version = "~1.14", features = ["extern_crate_alloc"] } +byteorder = "~1.5" half = { version = "~2.3", features = [ "bytemuck", "num-traits", "serde", "use-intrinsics", + "rand_distr", ] } +libc = "~0.2" +log = "~0.4" +memmap2 = "0.9.4" num-traits = "~0.2" -validator = { version = "~0.16", features = ["derive"] } +paste = "~1.0" +rand = "0.8.5" rustix = { version = "~0.38", features = ["fs", "net", "mm"] } -rand = "~0.8" +serde = "~1.0" +serde_json = "~1.0" +thiserror = "~1.0" +uuid = { version = "1.7.0", features = ["v4", "serde"] } +validator = { version = "~0.16", features = ["derive"] } [profile.dev] panic = "unwind" diff --git a/README.md b/README.md index 956b34daa..fca3e4a92 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ docker run \ --name pgvecto-rs-demo \ -e POSTGRES_PASSWORD=mysecretpassword \ -p 5432:5432 \ - -d tensorchord/pgvecto-rs:pg16-v0.1.14-beta + -d tensorchord/pgvecto-rs:pg16-v0.2.0 ``` Then you can connect to the database using the `psql` command line tool. The default username is `postgres`, and the default password is `mysecretpassword`. diff --git a/bindings/python/tests/__init__.py b/bindings/python/tests/__init__.py index 0a2e9ace2..b4b7391f7 100644 --- a/bindings/python/tests/__init__.py +++ b/bindings/python/tests/__init__.py @@ -55,7 +55,7 @@ OP_NEG_DOT_PROD_DIS = [1, 2, 4] EXPECTED_NEG_DOT_PROD_DIS = [-17.0, 80.64, -7.0] OP_NEG_COS_DIS = [3, 2, 1] -EXPECTED_NEG_COS_DIS = [-0.7142857, 0.5199225, -0.92582005] +EXPECTED_NEG_COS_DIS = [0.28571427, 1.5199225, 0.07417989] # ==== test_delete ==== LEN_AFT_DEL = 2 diff --git a/crates/base/Cargo.toml b/crates/base/Cargo.toml new file mode 100644 index 000000000..79c2f8d59 --- /dev/null +++ b/crates/base/Cargo.toml @@ -0,0 +1,40 @@ +[package] +name = "base" +version.workspace = true +edition.workspace = true + +[dependencies] +bincode.workspace = true +bytemuck.workspace = true +byteorder.workspace = true +half.workspace = true +libc.workspace = true +log.workspace = true +memmap2.workspace = true +num-traits.workspace = true +rand.workspace = true +rustix.workspace = true +serde.workspace = true +serde_json.workspace = true +thiserror.workspace = true +uuid.workspace = true +validator.workspace = true +c = { path = "../c" } +detect = { path = "../detect" } +crc32fast = "1.4.0" +crossbeam = "0.8.4" +dashmap = "5.5.3" +parking_lot = "0.12.1" +rayon = "1.8.1" +arc-swap = "1.6.0" +multiversion = "0.7.3" + +[lints] +clippy.derivable_impls = "allow" +clippy.len_without_is_empty = "allow" +clippy.needless_range_loop = "allow" +clippy.too_many_arguments = "allow" +rust.internal_features = "allow" +rust.unsafe_op_in_unsafe_fn = "forbid" +rust.unused_lifetimes = "warn" +rust.unused_qualifications = "warn" diff --git a/crates/base/src/error.rs b/crates/base/src/error.rs new file mode 100644 index 000000000..91781b7f5 --- /dev/null +++ b/crates/base/src/error.rs @@ -0,0 +1,95 @@ +use serde::{Deserialize, Serialize}; +use thiserror::Error; + +// control plane + +#[must_use] +#[derive(Debug, Clone, Error, Serialize, Deserialize)] +pub enum CreateError { + #[error("Index of given name already exists.")] + Exist, + #[error("Invalid index options.")] + InvalidIndexOptions { reason: String }, +} + +#[must_use] +#[derive(Debug, Clone, Error, Serialize, Deserialize)] +pub enum DropError { + #[error("Index not found.")] + NotExist, +} + +// data plane + +#[must_use] +#[derive(Debug, Clone, Error, Serialize, Deserialize)] +pub enum FlushError { + #[error("Index not found.")] + NotExist, + #[error("Maintenance should be done.")] + Upgrade, +} + +#[must_use] +#[derive(Debug, Clone, Error, Serialize, Deserialize)] +pub enum InsertError { + #[error("Index not found.")] + NotExist, + #[error("Maintenance should be done.")] + Upgrade, + #[error("Invalid vector.")] + InvalidVector, +} + +#[must_use] +#[derive(Debug, Clone, Error, Serialize, Deserialize)] +pub enum DeleteError { + #[error("Index not found.")] + NotExist, + #[error("Maintenance should be done.")] + Upgrade, +} + +#[must_use] +#[derive(Debug, Clone, Error, Serialize, Deserialize)] +pub enum BasicError { + #[error("Index not found.")] + NotExist, + #[error("Maintenance should be done.")] + Upgrade, + #[error("Invalid vector.")] + InvalidVector, + #[error("Invalid search options.")] + InvalidSearchOptions { reason: String }, +} + +#[must_use] +#[derive(Debug, Clone, Error, Serialize, Deserialize)] +pub enum VbaseError { + #[error("Index not found.")] + NotExist, + #[error("Maintenance should be done.")] + Upgrade, + #[error("Invalid vector.")] + InvalidVector, + #[error("Invalid search options.")] + InvalidSearchOptions { reason: String }, +} + +#[must_use] +#[derive(Debug, Clone, Error, Serialize, Deserialize)] +pub enum ListError { + #[error("Index not found.")] + NotExist, + #[error("Maintenance should be done.")] + Upgrade, +} + +#[must_use] +#[derive(Debug, Clone, Error, Serialize, Deserialize)] +pub enum StatError { + #[error("Index not found.")] + NotExist, + #[error("Maintenance should be done.")] + Upgrade, +} diff --git a/crates/base/src/lib.rs b/crates/base/src/lib.rs new file mode 100644 index 000000000..653153e13 --- /dev/null +++ b/crates/base/src/lib.rs @@ -0,0 +1,7 @@ +#![feature(core_intrinsics)] + +pub mod error; +pub mod scalar; +pub mod search; +pub mod sys; +pub mod vector; diff --git a/crates/service/src/prelude/scalar/f16.rs b/crates/base/src/scalar/f16.rs similarity index 99% rename from crates/service/src/prelude/scalar/f16.rs rename to crates/base/src/scalar/f16.rs index 467542f06..da5735b37 100644 --- a/crates/service/src/prelude/scalar/f16.rs +++ b/crates/base/src/scalar/f16.rs @@ -1,4 +1,4 @@ -use crate::prelude::global::FloatCast; +use super::FloatCast; use half::f16; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; diff --git a/crates/service/src/prelude/scalar/f32.rs b/crates/base/src/scalar/f32.rs similarity index 99% rename from crates/service/src/prelude/scalar/f32.rs rename to crates/base/src/scalar/f32.rs index a4e70a10a..c6e431bcb 100644 --- a/crates/service/src/prelude/scalar/f32.rs +++ b/crates/base/src/scalar/f32.rs @@ -1,4 +1,4 @@ -use crate::prelude::global::FloatCast; +use super::FloatCast; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; use std::fmt::{Debug, Display}; diff --git a/crates/base/src/scalar/mod.rs b/crates/base/src/scalar/mod.rs new file mode 100644 index 000000000..8e30d33c7 --- /dev/null +++ b/crates/base/src/scalar/mod.rs @@ -0,0 +1,16 @@ +mod f16; +mod f32; + +pub use f16::F16; +pub use f32::F32; + +pub trait FloatCast: Sized { + fn from_f32(x: f32) -> Self; + fn to_f32(self) -> f32; + fn from_f(x: F32) -> Self { + Self::from_f32(x.0) + } + fn to_f(self) -> F32 { + F32(Self::to_f32(self)) + } +} diff --git a/crates/service/src/prelude/search.rs b/crates/base/src/search.rs similarity index 90% rename from crates/service/src/prelude/search.rs rename to crates/base/src/search.rs index 2009730ee..c5e946bd5 100644 --- a/crates/service/src/prelude/search.rs +++ b/crates/base/src/search.rs @@ -1,4 +1,4 @@ -use crate::prelude::F32; +use crate::scalar::F32; pub type Payload = u64; diff --git a/crates/service/src/prelude/sys.rs b/crates/base/src/sys.rs similarity index 100% rename from crates/service/src/prelude/sys.rs rename to crates/base/src/sys.rs diff --git a/crates/base/src/vector/mod.rs b/crates/base/src/vector/mod.rs new file mode 100644 index 000000000..8f6117772 --- /dev/null +++ b/crates/base/src/vector/mod.rs @@ -0,0 +1,19 @@ +mod sparse_f32; + +pub use sparse_f32::{SparseF32, SparseF32Ref}; + +pub trait Vector { + fn dims(&self) -> u16; +} + +impl Vector for Vec { + fn dims(&self) -> u16 { + self.len().try_into().unwrap() + } +} + +impl<'a, T> Vector for &'a [T] { + fn dims(&self) -> u16 { + self.len().try_into().unwrap() + } +} diff --git a/crates/base/src/vector/sparse_f32.rs b/crates/base/src/vector/sparse_f32.rs new file mode 100644 index 000000000..d52903205 --- /dev/null +++ b/crates/base/src/vector/sparse_f32.rs @@ -0,0 +1,64 @@ +use super::Vector; +use crate::scalar::F32; +use num_traits::Zero; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SparseF32 { + pub dims: u16, + pub indexes: Vec, + pub values: Vec, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct SparseF32Ref<'a> { + pub dims: u16, + pub indexes: &'a [u16], + pub values: &'a [F32], +} + +impl<'a> From> for SparseF32 { + fn from(value: SparseF32Ref<'a>) -> Self { + Self { + dims: value.dims, + indexes: value.indexes.to_vec(), + values: value.values.to_vec(), + } + } +} + +impl<'a> From<&'a SparseF32> for SparseF32Ref<'a> { + fn from(value: &'a SparseF32) -> Self { + Self { + dims: value.dims, + indexes: &value.indexes, + values: &value.values, + } + } +} + +impl Vector for SparseF32 { + fn dims(&self) -> u16 { + self.dims + } +} + +impl<'a> Vector for SparseF32Ref<'a> { + fn dims(&self) -> u16 { + self.dims + } +} + +impl<'a> SparseF32Ref<'a> { + pub fn to_dense(&self) -> Vec { + let mut dense = vec![F32::zero(); self.dims as usize]; + for (&index, &value) in self.indexes.iter().zip(self.values.iter()) { + dense[index as usize] = value; + } + dense + } + + pub fn length(&self) -> u16 { + self.indexes.len().try_into().unwrap() + } +} diff --git a/crates/c/Cargo.toml b/crates/c/Cargo.toml index f0f0274bb..1c8631903 100644 --- a/crates/c/Cargo.toml +++ b/crates/c/Cargo.toml @@ -4,9 +4,9 @@ version.workspace = true edition.workspace = true [dev-dependencies] -half = { version = "~2.3", features = ["use-intrinsics", "rand_distr"] } +half.workspace = true +rand.workspace = true detect = { path = "../detect" } -rand = "0.8.5" [build-dependencies] cc = "1.0" diff --git a/crates/detect/Cargo.toml b/crates/detect/Cargo.toml index aaae19282..1bc7a99f0 100644 --- a/crates/detect/Cargo.toml +++ b/crates/detect/Cargo.toml @@ -4,5 +4,5 @@ version.workspace = true edition.workspace = true [dependencies] -std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" } rustix.workspace = true +std_detect = { git = "https://github.com/tensorchord/stdarch.git", branch = "avx512fp16" } diff --git a/crates/interprocess-atomic-wait/Cargo.toml b/crates/interprocess-atomic-wait/Cargo.toml new file mode 100644 index 000000000..2f36d9edb --- /dev/null +++ b/crates/interprocess-atomic-wait/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "interprocess_atomic_wait" +version.workspace = true +edition.workspace = true + +[dependencies] +libc.workspace = true + +[target.'cfg(target_os = "macos")'.dependencies] +ulock-sys = "0.1.0" + +[lints] +rust.internal_features = "allow" +rust.unsafe_op_in_unsafe_fn = "forbid" +rust.unused_lifetimes = "warn" +rust.unused_qualifications = "warn" diff --git a/crates/interprocess-atomic-wait/src/lib.rs b/crates/interprocess-atomic-wait/src/lib.rs new file mode 100644 index 000000000..324bd8a70 --- /dev/null +++ b/crates/interprocess-atomic-wait/src/lib.rs @@ -0,0 +1,91 @@ +use std::sync::atomic::AtomicU32; +use std::time::Duration; + +#[cfg(target_os = "linux")] +#[inline(always)] +pub fn wait(futex: &AtomicU32, value: u32, timeout: Duration) { + let timeout = libc::timespec { + tv_sec: i64::try_from(timeout.as_secs()).expect("Timeout is overflow."), + tv_nsec: timeout.subsec_nanos().into(), + }; + unsafe { + libc::syscall( + libc::SYS_futex, + futex.as_ptr(), + libc::FUTEX_WAIT, + value, + &timeout, + ); + } +} + +#[cfg(target_os = "linux")] +#[inline(always)] +pub fn wake(futex: &AtomicU32) { + unsafe { + libc::syscall(libc::SYS_futex, futex.as_ptr(), libc::FUTEX_WAKE, i32::MAX); + } +} + +#[cfg(target_os = "macos")] +#[inline(always)] +pub fn wait(futex: &AtomicU32, value: u32, timeout: Duration) { + let timeout = u32::try_from(timeout.as_millis()).expect("Timeout is overflow."); + unsafe { + // https://github.com/apple-oss-distributions/xnu/blob/main/bsd/kern/sys_ulock.c#L531 + ulock_sys::__ulock_wait( + ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED, + futex.as_ptr().cast(), + value as _, + timeout, + ); + } +} + +#[cfg(target_os = "macos")] +#[inline(always)] +pub fn wake(futex: &AtomicU32) { + unsafe { + ulock_sys::__ulock_wake( + ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED, + futex.as_ptr().cast(), + 0, + ); + } +} + +#[cfg(target_os = "freebsd")] +#[inline(always)] +pub fn wait(futex: &AtomicU32, value: u32, timeout: Duration) { + let ptr: *const AtomicU32 = futex; + let mut timeout = libc::timespec { + tv_sec: i64::try_from(timeout.as_secs()).expect("Timeout is overflow."), + tv_nsec: timeout.subsec_nanos().into(), + }; + unsafe { + // https://github.com/freebsd/freebsd-src/blob/main/sys/kern/kern_umtx.c#L3943 + // https://github.com/freebsd/freebsd-src/blob/main/sys/kern/kern_umtx.c#L3836 + libc::_umtx_op( + ptr as *mut libc::c_void, + libc::UMTX_OP_WAIT_UINT, + value as libc::c_ulong, + std::mem::size_of_val(&timeout) as *mut std::ffi::c_void, + std::ptr::addr_of_mut!(timeout).cast(), + ); + }; +} + +#[cfg(target_os = "freebsd")] +#[inline(always)] +pub fn wake(futex: &AtomicU32) { + let ptr: *const AtomicU32 = futex; + unsafe { + libc::_umtx_op( + ptr as *mut libc::c_void, + libc::UMTX_OP_WAKE, + i32::MAX as libc::c_ulong, + core::ptr::null_mut(), + core::ptr::null_mut(), + ); + }; +} diff --git a/crates/memfd/Cargo.toml b/crates/memfd/Cargo.toml new file mode 100644 index 000000000..2bf02ee22 --- /dev/null +++ b/crates/memfd/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "memfd" +version.workspace = true +edition.workspace = true + +[dependencies] +rand.workspace = true +rustix.workspace = true +detect = { path = "../detect" } + +[lints] +rust.internal_features = "allow" +rust.unsafe_op_in_unsafe_fn = "forbid" +rust.unused_lifetimes = "warn" +rust.unused_qualifications = "warn" diff --git a/crates/memfd/src/lib.rs b/crates/memfd/src/lib.rs new file mode 100644 index 000000000..54cdd42e4 --- /dev/null +++ b/crates/memfd/src/lib.rs @@ -0,0 +1,70 @@ +use std::os::fd::OwnedFd; + +#[cfg(target_os = "linux")] +pub fn memfd_create() -> std::io::Result { + if detect::linux::detect_memfd() { + use rustix::fs::MemfdFlags; + Ok(rustix::fs::memfd_create( + format!(".memfd.MEMFD.{:x}", std::process::id()), + MemfdFlags::empty(), + )?) + } else { + use rustix::fs::Mode; + use rustix::fs::OFlags; + // POSIX fcntl locking do not support shmem, so we use a regular file here. + // reference: https://man7.org/linux/man-pages/man3/fcntl.3p.html + // However, Linux shmem supports fcntl locking. + let name = format!( + ".shm.MEMFD.{:x}.{:x}", + std::process::id(), + rand::random::() + ); + let fd = rustix::fs::open( + &name, + OFlags::RDWR | OFlags::CREATE | OFlags::EXCL, + Mode::RUSR | Mode::WUSR, + )?; + rustix::fs::unlink(&name)?; + Ok(fd) + } +} + +#[cfg(target_os = "macos")] +pub fn memfd_create() -> std::io::Result { + use rustix::fs::Mode; + use rustix::fs::OFlags; + // POSIX fcntl locking do not support shmem, so we use a regular file here. + // reference: https://man7.org/linux/man-pages/man3/fcntl.3p.html + let name = format!( + ".shm.MEMFD.{:x}.{:x}", + std::process::id(), + rand::random::() + ); + let fd = rustix::fs::open( + &name, + OFlags::RDWR | OFlags::CREATE | OFlags::EXCL, + Mode::RUSR | Mode::WUSR, + )?; + rustix::fs::unlink(&name)?; + Ok(fd) +} + +#[cfg(target_os = "freebsd")] +pub fn memfd_create() -> std::io::Result { + use rustix::fs::Mode; + use rustix::fs::OFlags; + // POSIX fcntl locking do not support shmem, so we use a regular file here. + // reference: https://man7.org/linux/man-pages/man3/fcntl.3p.html + let name = format!( + ".shm.MEMFD.{:x}.{:x}", + std::process::id(), + rand::random::() + ); + let fd = rustix::fs::open( + &name, + OFlags::RDWR | OFlags::CREATE | OFlags::EXCL, + Mode::RUSR | Mode::WUSR, + )?; + rustix::fs::unlink(&name)?; + Ok(fd) +} diff --git a/crates/send_fd/Cargo.toml b/crates/send_fd/Cargo.toml new file mode 100644 index 000000000..fc50260e0 --- /dev/null +++ b/crates/send_fd/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "send_fd" +version.workspace = true +edition.workspace = true + +[dependencies] +libc.workspace = true +log.workspace = true +rustix.workspace = true + +[lints] +rust.internal_features = "allow" +rust.unsafe_op_in_unsafe_fn = "forbid" +rust.unused_lifetimes = "warn" +rust.unused_qualifications = "warn" diff --git a/src/utils/file_socket.rs b/crates/send_fd/src/lib.rs similarity index 97% rename from src/utils/file_socket.rs rename to crates/send_fd/src/lib.rs index 3d7a2a43b..a1adc42f8 100644 --- a/src/utils/file_socket.rs +++ b/crates/send_fd/src/lib.rs @@ -6,12 +6,12 @@ use std::io::{IoSlice, IoSliceMut}; use std::os::unix::net::UnixStream; #[repr(C)] -pub struct FileSocket { +pub struct SendFd { tx: OwnedFd, rx: OwnedFd, } -impl FileSocket { +impl SendFd { pub fn new() -> std::io::Result { let (tx, rx) = UnixStream::pair()?; Ok(Self { @@ -47,7 +47,7 @@ fn recv_fd(rx: BorrowedFd<'_>) -> std::io::Result { let mut control = RecvAncillaryBuffer::new(&mut buffer.0); let mut buffer_ios = [b'.']; let ios = IoSliceMut::new(&mut buffer_ios); - let returned = rustix::net::recvmsg(rx, &mut [ios], &mut control, RecvFlags::CMSG_CLOEXEC)?; + let returned = rustix::net::recvmsg(rx, &mut [ios], &mut control, RecvFlags::empty())?; if returned.flags.bits() & libc::MSG_CTRUNC as u32 != 0 { log::warn!("Ancillary is truncated."); } diff --git a/crates/service/Cargo.toml b/crates/service/Cargo.toml index a7d7b3ebe..8cd06be37 100644 --- a/crates/service/Cargo.toml +++ b/crates/service/Cargo.toml @@ -4,36 +4,32 @@ version.workspace = true edition.workspace = true [dependencies] +bincode.workspace = true +bytemuck.workspace = true +byteorder.workspace = true +half.workspace = true libc.workspace = true log.workspace = true +memmap2.workspace = true +num-traits.workspace = true +rand.workspace = true +rustix.workspace = true serde.workspace = true serde_json.workspace = true -validator.workspace = true -rustix.workspace = true thiserror.workspace = true -byteorder.workspace = true -bincode.workspace = true -half.workspace = true -num-traits.workspace = true -rand.workspace = true -bytemuck.workspace = true +uuid.workspace = true +validator.workspace = true +base = { path = "../base" } c = { path = "../c" } detect = { path = "../detect" } -crc32fast = "1.3.2" -crossbeam = "0.8.2" -dashmap = "5.4.0" +crc32fast = "1.4.0" +crossbeam = "0.8.4" +dashmap = "5.5.3" parking_lot = "0.12.1" -memoffset = "0.9.0" -arrayvec = { version = "0.7.3", features = ["serde"] } -memmap2 = "0.9.0" -rayon = "1.6.1" -uuid = { version = "1.6.1", features = ["v4", "serde"] } +rayon = "1.8.1" arc-swap = "1.6.0" multiversion = "0.7.3" -[target.'cfg(target_os = "macos")'.dependencies] -ulock-sys = "0.1.0" - [lints] clippy.derivable_impls = "allow" clippy.len_without_is_empty = "allow" diff --git a/crates/service/src/algorithms/clustering/elkan_k_means.rs b/crates/service/src/algorithms/clustering/elkan_k_means.rs index 376ed313b..9dc7fd97b 100644 --- a/crates/service/src/algorithms/clustering/elkan_k_means.rs +++ b/crates/service/src/algorithms/clustering/elkan_k_means.rs @@ -1,5 +1,6 @@ use crate::prelude::*; use crate::utils::vec2::Vec2; +use base::scalar::FloatCast; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator}; @@ -9,18 +10,18 @@ use std::ops::{Index, IndexMut}; pub struct ElkanKMeans { dims: u16, c: usize, - pub centroids: Vec2, + pub centroids: Vec2, lowerbound: Square, upperbound: Vec, assign: Vec, rand: StdRng, - samples: Vec2, + samples: Vec2, } const DELTA: f32 = 1.0 / 1024.0; impl ElkanKMeans { - pub fn new(c: usize, samples: Vec2) -> Self { + pub fn new(c: usize, samples: Vec2) -> Self { let n = samples.len(); let dims = samples.dims(); @@ -266,7 +267,7 @@ impl ElkanKMeans { change == 0 } - pub fn finish(self) -> Vec2 { + pub fn finish(self) -> Vec2 { self.centroids } } diff --git a/crates/service/src/algorithms/flat.rs b/crates/service/src/algorithms/flat.rs index 3771b0ed6..7ee1f1a5b 100644 --- a/crates/service/src/algorithms/flat.rs +++ b/crates/service/src/algorithms/flat.rs @@ -8,7 +8,7 @@ use crate::utils::dir_ops::sync_dir; use std::cmp::Reverse; use std::collections::BinaryHeap; use std::fs::create_dir; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; pub struct Flat { @@ -17,37 +17,26 @@ pub struct Flat { impl Flat { pub fn create( - path: PathBuf, + path: &Path, options: IndexOptions, sealed: Vec>>, growing: Vec>>, ) -> Self { - create_dir(&path).unwrap(); - let ram = make(path.clone(), sealed, growing, options.clone()); - let mmap = save(ram, path.clone()); - sync_dir(&path); + create_dir(path).unwrap(); + let ram = make(path, sealed, growing, options); + let mmap = save(path, ram); + sync_dir(path); Self { mmap } } - pub fn open(path: PathBuf, options: IndexOptions) -> Self { - let mmap = load(path, options.clone()); - Self { mmap } - } - - pub fn len(&self) -> u32 { - self.mmap.raw.len() - } - pub fn vector(&self, i: u32) -> &[S::Scalar] { - self.mmap.raw.vector(i) - } - - pub fn payload(&self, i: u32) -> Payload { - self.mmap.raw.payload(i) + pub fn open(path: &Path, options: IndexOptions) -> Self { + let mmap = open(path, options); + Self { mmap } } pub fn basic( &self, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, _opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -56,12 +45,24 @@ impl Flat { pub fn vbase<'a>( &'a self, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, _opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { vbase(&self.mmap, vector, filter) } + + pub fn len(&self) -> u32 { + self.mmap.raw.len() + } + + pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + self.mmap.raw.vector(i) + } + + pub fn payload(&self, i: u32) -> Payload { + self.mmap.raw.payload(i) + } } unsafe impl Send for Flat {} @@ -81,20 +82,20 @@ unsafe impl Send for FlatMmap {} unsafe impl Sync for FlatMmap {} pub fn make( - path: PathBuf, + path: &Path, sealed: Vec>>, growing: Vec>>, options: IndexOptions, ) -> FlatRam { let idx_opts = options.indexing.clone().unwrap_flat(); let raw = Arc::new(Raw::create( - path.join("raw"), + &path.join("raw"), options.clone(), sealed, growing, )); let quantization = Quantization::create( - path.join("quantization"), + &path.join("quantization"), options.clone(), idx_opts.quantization, &raw, @@ -103,18 +104,18 @@ pub fn make( FlatRam { raw, quantization } } -pub fn save(ram: FlatRam, _: PathBuf) -> FlatMmap { +pub fn save(_: &Path, ram: FlatRam) -> FlatMmap { FlatMmap { raw: ram.raw, quantization: ram.quantization, } } -pub fn load(path: PathBuf, options: IndexOptions) -> FlatMmap { +pub fn open(path: &Path, options: IndexOptions) -> FlatMmap { let idx_opts = options.indexing.clone().unwrap_flat(); - let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); + let raw = Arc::new(Raw::open(&path.join("raw"), options.clone())); let quantization = Quantization::open( - path.join("quantization"), + &path.join("quantization"), options.clone(), idx_opts.quantization, &raw, @@ -124,7 +125,7 @@ pub fn load(path: PathBuf, options: IndexOptions) -> FlatMmap { pub fn basic( mmap: &FlatMmap, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, mut filter: impl Filter, ) -> BinaryHeap> { let mut result = BinaryHeap::new(); @@ -140,7 +141,7 @@ pub fn basic( pub fn vbase<'a, S: G>( mmap: &'a FlatMmap, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, mut filter: impl Filter + 'a, ) -> (Vec, Box + 'a>) { let mut result = Vec::new(); diff --git a/crates/service/src/algorithms/hnsw.rs b/crates/service/src/algorithms/hnsw.rs index efa830c91..cf2ced090 100644 --- a/crates/service/src/algorithms/hnsw.rs +++ b/crates/service/src/algorithms/hnsw.rs @@ -16,7 +16,7 @@ use std::cmp::Reverse; use std::collections::BinaryHeap; use std::fs::create_dir; use std::ops::RangeInclusive; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; pub struct Hnsw { @@ -25,37 +25,26 @@ pub struct Hnsw { impl Hnsw { pub fn create( - path: PathBuf, + path: &Path, options: IndexOptions, sealed: Vec>>, growing: Vec>>, ) -> Self { - create_dir(&path).unwrap(); - let ram = make(path.clone(), sealed, growing, options.clone()); - let mmap = save(ram, path.clone()); - sync_dir(&path); + create_dir(path).unwrap(); + let ram = make(path, sealed, growing, options); + let mmap = save(ram, path); + sync_dir(path); Self { mmap } } - pub fn open(path: PathBuf, options: IndexOptions) -> Self { - let mmap = load(path, options.clone()); - Self { mmap } - } - - pub fn len(&self) -> u32 { - self.mmap.raw.len() - } - pub fn vector(&self, i: u32) -> &[S::Scalar] { - self.mmap.raw.vector(i) - } - - pub fn payload(&self, i: u32) -> Payload { - self.mmap.raw.payload(i) + pub fn open(path: &Path, options: IndexOptions) -> Self { + let mmap = open(path, options); + Self { mmap } } pub fn basic( &self, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -64,12 +53,24 @@ impl Hnsw { pub fn vbase<'a>( &'a self, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { vbase(&self.mmap, vector, opts.hnsw_ef_search, filter) } + + pub fn len(&self) -> u32 { + self.mmap.raw.len() + } + + pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + self.mmap.raw.vector(i) + } + + pub fn payload(&self, i: u32) -> Payload { + self.mmap.raw.payload(i) + } } unsafe impl Send for Hnsw {} @@ -128,7 +129,7 @@ unsafe impl Pod for HnswMmapEdge {} unsafe impl Zeroable for HnswMmapEdge {} pub fn make( - path: PathBuf, + path: &Path, sealed: Vec>>, growing: Vec>>, options: IndexOptions, @@ -139,13 +140,13 @@ pub fn make( quantization: quantization_opts, } = options.indexing.clone().unwrap_hnsw(); let raw = Arc::new(Raw::create( - path.join("raw"), + &path.join("raw"), options.clone(), sealed, growing, )); let quantization = Quantization::create( - path.join("quantization"), + &path.join("quantization"), options.clone(), quantization_opts, &raw, @@ -170,7 +171,7 @@ pub fn make( graph: &HnswRamGraph, levels: RangeInclusive, u: u32, - target: &[S::Scalar], + target: S::VectorRef<'_>, ) -> u32 { let mut u = u; let mut u_dis = quantization.distance(target, u); @@ -195,7 +196,7 @@ pub fn make( quantization: &Quantization, graph: &HnswRamGraph, visited: &mut VisitedGuard, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, s: u32, k: usize, i: u8, @@ -335,9 +336,9 @@ pub fn make( } } -pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap { +pub fn save(mut ram: HnswRam, path: &Path) -> HnswMmap { let edges = MmapArray::create( - path.join("edges"), + &path.join("edges"), ram.graph .vertexs .iter_mut() @@ -345,13 +346,13 @@ pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap { .flat_map(|v| &v.get_mut().edges) .map(|&(_0, _1)| HnswMmapEdge(_0, _1)), ); - let by_layer_id = MmapArray::create(path.join("by_layer_id"), { + let by_layer_id = MmapArray::create(&path.join("by_layer_id"), { let iter = ram.graph.vertexs.iter_mut(); let iter = iter.flat_map(|v| v.layers.iter_mut()); let iter = iter.map(|v| v.get_mut().edges.len()); caluate_offsets(iter) }); - let by_vertex_id = MmapArray::create(path.join("by_vertex_id"), { + let by_vertex_id = MmapArray::create(&path.join("by_vertex_id"), { let iter = ram.graph.vertexs.iter_mut(); let iter = iter.map(|v| v.layers.len()); caluate_offsets(iter) @@ -367,18 +368,18 @@ pub fn save(mut ram: HnswRam, path: PathBuf) -> HnswMmap { } } -pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap { +pub fn open(path: &Path, options: IndexOptions) -> HnswMmap { let idx_opts = options.indexing.clone().unwrap_hnsw(); - let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); + let raw = Arc::new(Raw::open(&path.join("raw"), options.clone())); let quantization = Quantization::open( - path.join("quantization"), + &path.join("quantization"), options.clone(), idx_opts.quantization, &raw, ); - let edges = MmapArray::open(path.join("edges")); - let by_layer_id = MmapArray::open(path.join("by_layer_id")); - let by_vertex_id = MmapArray::open(path.join("by_vertex_id")); + let edges = MmapArray::open(&path.join("edges")); + let by_layer_id = MmapArray::open(&path.join("by_layer_id")); + let by_vertex_id = MmapArray::open(&path.join("by_vertex_id")); let idx_opts = options.indexing.unwrap_hnsw(); let n = raw.len(); HnswMmap { @@ -394,7 +395,7 @@ pub fn load(path: PathBuf, options: IndexOptions) -> HnswMmap { pub fn basic( mmap: &HnswMmap, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, ef_search: usize, filter: impl Filter, ) -> BinaryHeap> { @@ -408,7 +409,7 @@ pub fn basic( pub fn vbase<'a, S: G>( mmap: &'a HnswMmap, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, range: usize, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { @@ -466,7 +467,7 @@ pub fn fast_search( mmap: &HnswMmap, levels: RangeInclusive, u: u32, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, mut filter: impl Filter, ) -> u32 { let mut u = u; @@ -496,7 +497,7 @@ pub fn local_search_basic( mmap: &HnswMmap, k: usize, s: u32, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, mut filter: impl Filter, ) -> ElementHeap { let mut visited = mmap.visited.fetch(); @@ -540,7 +541,7 @@ pub fn local_search_basic( pub fn local_search_vbase<'a, S: G>( mmap: &'a HnswMmap, s: u32, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, mut filter: impl Filter + 'a, ) -> impl Iterator + 'a { let mut visited = mmap.visited.fetch2(); @@ -618,7 +619,7 @@ impl VisitedPool { locked_buffers: Mutex::new(Vec::new()), } } - pub fn fetch(&self) -> VisitedGuard<'_> { + pub fn fetch(&self) -> VisitedGuard { let buffer = self .locked_buffers .lock() diff --git a/crates/service/src/algorithms/ivf/ivf_naive.rs b/crates/service/src/algorithms/ivf/ivf_naive.rs index c3da69690..278570236 100644 --- a/crates/service/src/algorithms/ivf/ivf_naive.rs +++ b/crates/service/src/algorithms/ivf/ivf_naive.rs @@ -19,7 +19,7 @@ use rayon::prelude::ParallelIterator; use std::cmp::Reverse; use std::collections::BinaryHeap; use std::fs::create_dir; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; pub struct IvfNaive { @@ -28,20 +28,20 @@ pub struct IvfNaive { impl IvfNaive { pub fn create( - path: PathBuf, + path: &Path, options: IndexOptions, sealed: Vec>>, growing: Vec>>, ) -> Self { - create_dir(&path).unwrap(); - let ram = make(path.clone(), sealed, growing, options); - let mmap = save(ram, path.clone()); - sync_dir(&path); + create_dir(path).unwrap(); + let ram = make(path, sealed, growing, options); + let mmap = save(ram, path); + sync_dir(path); Self { mmap } } - pub fn open(path: PathBuf, options: IndexOptions) -> Self { - let mmap = load(path.clone(), options); + pub fn open(path: &Path, options: IndexOptions) -> Self { + let mmap = open(path, options); Self { mmap } } @@ -49,7 +49,7 @@ impl IvfNaive { self.mmap.raw.len() } - pub fn vector(&self, i: u32) -> &[S::Scalar] { + pub fn vector(&self, i: u32) -> S::VectorRef<'_> { self.mmap.raw.vector(i) } @@ -59,7 +59,7 @@ impl IvfNaive { pub fn basic( &self, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -68,7 +68,7 @@ impl IvfNaive { pub fn vbase<'a>( &'a self, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { @@ -87,7 +87,7 @@ pub struct IvfRam { // ---------------------- nlist: u32, // ---------------------- - centroids: Vec2, + centroids: Vec2, ptr: Vec, payloads: Vec, } @@ -120,7 +120,7 @@ impl IvfMmap { } pub fn make( - path: PathBuf, + path: &Path, sealed: Vec>>, growing: Vec>>, options: IndexOptions, @@ -134,7 +134,7 @@ pub fn make( quantization: quantization_opts, } = options.indexing.clone().unwrap_ivf(); let raw = Arc::new(Raw::create( - path.join("raw"), + &path.join("raw"), options.clone(), sealed, growing, @@ -144,10 +144,10 @@ pub fn make( let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec(); let mut samples = Vec2::new(dims, m as usize); for i in 0..m { - samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32)); + samples[i as usize].copy_from_slice(S::to_dense(raw.vector(f[i as usize] as u32)).as_ref()); S::elkan_k_means_normalize(&mut samples[i as usize]); } - let mut k_means = ElkanKMeans::new(nlist as usize, samples); + let mut k_means = ElkanKMeans::::new(nlist as usize, samples); for _ in 0..least_iterations { k_means.iterate(); } @@ -159,11 +159,11 @@ pub fn make( let centroids = k_means.finish(); let mut idx = vec![0usize; n as usize]; idx.par_iter_mut().enumerate().for_each(|(i, x)| { - let mut vector = raw.vector(i as u32).to_vec(); - S::elkan_k_means_normalize(&mut vector); + let mut vector = S::ref_to_owned(raw.vector(i as u32)); + S::elkan_k_means_normalize2(&mut vector); let mut result = (F32::infinity(), 0); for i in 0..nlist as usize { - let dis = S::elkan_k_means_distance(&vector, ¢roids[i]); + let dis = S::elkan_k_means_distance2(S::owned_to_ref(&vector), ¢roids[i]); result = std::cmp::min(result, (dis, i)); } *x = result.1; @@ -181,7 +181,7 @@ pub fn make( .copied(), ); let quantization = Quantization::create( - path.join("quantization"), + &path.join("quantization"), options.clone(), quantization_opts, &raw, @@ -202,15 +202,15 @@ pub fn make( } } -pub fn save(ram: IvfRam, path: PathBuf) -> IvfMmap { +pub fn save(ram: IvfRam, path: &Path) -> IvfMmap { let centroids = MmapArray::create( - path.join("centroids"), + &path.join("centroids"), (0..ram.nlist) .flat_map(|i| &ram.centroids[i as usize]) .copied(), ); - let ptr = MmapArray::create(path.join("ptr"), ram.ptr.iter().copied()); - let payloads = MmapArray::create(path.join("payload"), ram.payloads.iter().copied()); + let ptr = MmapArray::create(&path.join("ptr"), ram.ptr.iter().copied()); + let payloads = MmapArray::create(&path.join("payload"), ram.payloads.iter().copied()); IvfMmap { raw: ram.raw, quantization: ram.quantization, @@ -222,17 +222,17 @@ pub fn save(ram: IvfRam, path: PathBuf) -> IvfMmap { } } -pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { - let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); +pub fn open(path: &Path, options: IndexOptions) -> IvfMmap { + let raw = Arc::new(Raw::open(&path.join("raw"), options.clone())); let quantization = Quantization::open( - path.join("quantization"), + &path.join("quantization"), options.clone(), options.indexing.clone().unwrap_ivf().quantization, &raw, ); - let centroids = MmapArray::open(path.join("centroids")); - let ptr = MmapArray::open(path.join("ptr")); - let payloads = MmapArray::open(path.join("payload")); + let centroids = MmapArray::open(&path.join("centroids")); + let ptr = MmapArray::open(&path.join("ptr")); + let payloads = MmapArray::open(&path.join("payload")); let IvfIndexingOptions { nlist, .. } = options.indexing.unwrap_ivf(); IvfMmap { raw, @@ -247,16 +247,16 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { pub fn basic( mmap: &IvfMmap, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, nprobe: u32, mut filter: impl Filter, ) -> BinaryHeap> { - let mut target = vector.to_vec(); - S::elkan_k_means_normalize(&mut target); + let mut target = S::ref_to_owned(vector); + S::elkan_k_means_normalize2(&mut target); let mut lists = ElementHeap::new(nprobe as usize); for i in 0..mmap.nlist { let centroid = mmap.centroids(i); - let distance = S::elkan_k_means_distance(&target, centroid); + let distance = S::elkan_k_means_distance2(S::owned_to_ref(&target), centroid); if lists.check(distance) { lists.push(Element { distance, @@ -282,16 +282,16 @@ pub fn basic( pub fn vbase<'a, S: G>( mmap: &'a IvfMmap, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, nprobe: u32, mut filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { - let mut target = vector.to_vec(); - S::elkan_k_means_normalize(&mut target); + let mut target = S::ref_to_owned(vector); + S::elkan_k_means_normalize2(&mut target); let mut lists = ElementHeap::new(nprobe as usize); for i in 0..mmap.nlist { let centroid = mmap.centroids(i); - let distance = S::elkan_k_means_distance(&target, centroid); + let distance = S::elkan_k_means_distance2(S::owned_to_ref(&target), centroid); if lists.check(distance) { lists.push(Element { distance, diff --git a/crates/service/src/algorithms/ivf/ivf_pq.rs b/crates/service/src/algorithms/ivf/ivf_pq.rs index 4565def8d..e2ca36b17 100644 --- a/crates/service/src/algorithms/ivf/ivf_pq.rs +++ b/crates/service/src/algorithms/ivf/ivf_pq.rs @@ -17,11 +17,11 @@ use rand::seq::index::sample; use rand::thread_rng; use rayon::iter::IntoParallelRefMutIterator; use rayon::iter::{IndexedParallelIterator, ParallelIterator}; -use rayon::slice::ParallelSliceMut; +use rayon::prelude::ParallelSliceMut; use std::cmp::Reverse; use std::collections::BinaryHeap; use std::fs::create_dir; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; pub struct IvfPq { @@ -30,20 +30,20 @@ pub struct IvfPq { impl IvfPq { pub fn create( - path: PathBuf, + path: &Path, options: IndexOptions, sealed: Vec>>, growing: Vec>>, ) -> Self { - create_dir(&path).unwrap(); - let ram = make(path.clone(), sealed, growing, options); - let mmap = save(ram, path.clone()); - sync_dir(&path); + create_dir(path).unwrap(); + let ram = make(path, sealed, growing, options); + let mmap = save(ram, path); + sync_dir(path); Self { mmap } } - pub fn open(path: PathBuf, options: IndexOptions) -> Self { - let mmap = load(path.clone(), options); + pub fn open(path: &Path, options: IndexOptions) -> Self { + let mmap = open(path, options); Self { mmap } } @@ -51,7 +51,7 @@ impl IvfPq { self.mmap.raw.len() } - pub fn vector(&self, i: u32) -> &[S::Scalar] { + pub fn vector(&self, i: u32) -> S::VectorRef<'_> { self.mmap.raw.vector(i) } @@ -61,7 +61,7 @@ impl IvfPq { pub fn basic( &self, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -70,7 +70,7 @@ impl IvfPq { pub fn vbase<'a>( &'a self, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { @@ -89,7 +89,7 @@ pub struct IvfRam { // ---------------------- nlist: u32, // ---------------------- - centroids: Vec2, + centroids: Vec2, ptr: Vec, payloads: Vec, } @@ -122,7 +122,7 @@ impl IvfMmap { } pub fn make( - path: PathBuf, + path: &Path, sealed: Vec>>, growing: Vec>>, options: IndexOptions, @@ -136,7 +136,7 @@ pub fn make( quantization: quantization_opts, } = options.indexing.clone().unwrap_ivf(); let raw = Arc::new(Raw::create( - path.join("raw"), + &path.join("raw"), options.clone(), sealed, growing, @@ -146,10 +146,10 @@ pub fn make( let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec(); let mut samples = Vec2::new(dims, m as usize); for i in 0..m { - samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32)); + samples[i as usize].copy_from_slice(S::to_dense(raw.vector(f[i as usize] as u32)).as_ref()); S::elkan_k_means_normalize(&mut samples[i as usize]); } - let mut k_means = ElkanKMeans::new(nlist as usize, samples); + let mut k_means = ElkanKMeans::::new(nlist as usize, samples); for _ in 0..least_iterations { k_means.iterate(); } @@ -161,11 +161,11 @@ pub fn make( let centroids = k_means.finish(); let mut idx = vec![0usize; n as usize]; idx.par_iter_mut().enumerate().for_each(|(i, x)| { - let mut vector = raw.vector(i as u32).to_vec(); - S::elkan_k_means_normalize(&mut vector); + let mut vector = S::ref_to_owned(raw.vector(i as u32)); + S::elkan_k_means_normalize2(&mut vector); let mut result = (F32::infinity(), 0); for i in 0..nlist as usize { - let dis = S::elkan_k_means_distance(&vector, ¢roids[i]); + let dis = S::elkan_k_means_distance2(S::owned_to_ref(&vector), ¢roids[i]); result = std::cmp::min(result, (dis, i)); } *x = result.1; @@ -186,27 +186,27 @@ pub fn make( .flat_map(|i| &invlists_payloads[i as usize]) .copied(), ); - sync_dir(&path); + sync_dir(path); let residuals = { - let mut residuals = Vec2::::new(options.vector.dims, n as usize); + let mut residuals = Vec2::new(options.vector.dims, n as usize); residuals .par_chunks_mut(dims as usize) .enumerate() .for_each(|(i, v)| { for j in 0..dims { - v[j as usize] = raw.vector(ids[i])[j as usize] + v[j as usize] = S::to_dense(raw.vector(ids[i])).as_ref()[j as usize] - centroids[idx[ids[i] as usize]][j as usize]; } }); residuals }; let mut quantization = ProductQuantization::encode( - path.join("quantization"), + &path.join("quantization"), options.clone(), quantization_opts, &residuals, ); - quantization.precompute_table(path.join("quantization"), ¢roids); + quantization.precompute_table(&path.join("quantization"), ¢roids); IvfRam { raw, quantization, @@ -218,15 +218,15 @@ pub fn make( } } -pub fn save(ram: IvfRam, path: PathBuf) -> IvfMmap { +pub fn save(ram: IvfRam, path: &Path) -> IvfMmap { let centroids = MmapArray::create( - path.join("centroids"), + &path.join("centroids"), (0..ram.nlist) .flat_map(|i| &ram.centroids[i as usize]) .copied(), ); - let ptr = MmapArray::create(path.join("ptr"), ram.ptr.iter().copied()); - let payloads = MmapArray::create(path.join("payload"), ram.payloads.iter().copied()); + let ptr = MmapArray::create(&path.join("ptr"), ram.ptr.iter().copied()); + let payloads = MmapArray::create(&path.join("payload"), ram.payloads.iter().copied()); IvfMmap { raw: ram.raw, quantization: ram.quantization, @@ -238,17 +238,17 @@ pub fn save(ram: IvfRam, path: PathBuf) -> IvfMmap { } } -pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { - let raw = Arc::new(Raw::open(path.join("raw"), options.clone())); - let quantization = ProductQuantization::open( - path.join("quantization"), +pub fn open(path: &Path, options: IndexOptions) -> IvfMmap { + let raw = Arc::new(Raw::open(&path.join("raw"), options.clone())); + let quantization = ProductQuantization::open2( + &path.join("quantization"), options.clone(), options.indexing.clone().unwrap_ivf().quantization, &raw, ); - let centroids = MmapArray::open(path.join("centroids")); - let ptr = MmapArray::open(path.join("ptr")); - let payloads = MmapArray::open(path.join("payload")); + let centroids = MmapArray::open(&path.join("centroids")); + let ptr = MmapArray::open(&path.join("ptr")); + let payloads = MmapArray::open(&path.join("payload")); let IvfIndexingOptions { nlist, .. } = options.indexing.unwrap_ivf(); IvfMmap { raw, @@ -263,16 +263,14 @@ pub fn load(path: PathBuf, options: IndexOptions) -> IvfMmap { pub fn basic( mmap: &IvfMmap, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, nprobe: u32, mut filter: impl Filter, ) -> BinaryHeap> { - let target = vector.to_vec(); - // S::elkan_k_means_normalize(&mut target); let mut lists = ElementHeap::new(nprobe as usize); for i in 0..mmap.nlist { let centroid = mmap.centroids(i); - let distance = S::distance(&target, centroid); + let distance = S::distance2(vector, centroid); if lists.check(distance) { lists.push(Element { distance, @@ -280,7 +278,7 @@ pub fn basic( }); } } - let runtime_table = mmap.quantization.init_query(&target); + let runtime_table = mmap.quantization.init_query(S::to_dense(vector).as_ref()); let lists = lists.into_sorted_vec(); let mut result = BinaryHeap::new(); for i in lists.iter() { @@ -308,16 +306,14 @@ pub fn basic( pub fn vbase<'a, S: G>( mmap: &'a IvfMmap, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, nprobe: u32, mut filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { - let target = vector.to_vec(); - // S::elkan_k_means_normalize(&mut target); let mut lists = ElementHeap::new(nprobe as usize); for i in 0..mmap.nlist { let centroid = mmap.centroids(i); - let distance = S::distance(&target, centroid); + let distance = S::distance2(vector, centroid); if lists.check(distance) { lists.push(Element { distance, @@ -325,7 +321,7 @@ pub fn vbase<'a, S: G>( }); } } - let runtime_table = mmap.quantization.init_query(&target); + let runtime_table = mmap.quantization.init_query(S::to_dense(vector).as_ref()); let lists = lists.into_sorted_vec(); let mut result = Vec::new(); for i in lists.iter() { diff --git a/crates/service/src/algorithms/ivf/mod.rs b/crates/service/src/algorithms/ivf/mod.rs index 8fea67ca7..877e36c45 100644 --- a/crates/service/src/algorithms/ivf/mod.rs +++ b/crates/service/src/algorithms/ivf/mod.rs @@ -10,7 +10,7 @@ use crate::index::SearchOptions; use crate::prelude::*; use std::cmp::Reverse; use std::collections::BinaryHeap; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; pub enum Ivf { @@ -20,7 +20,7 @@ pub enum Ivf { impl Ivf { pub fn create( - path: PathBuf, + path: &Path, options: IndexOptions, sealed: Vec>>, growing: Vec>>, @@ -38,7 +38,7 @@ impl Ivf { } } - pub fn open(path: PathBuf, options: IndexOptions) -> Self { + pub fn open(path: &Path, options: IndexOptions) -> Self { if options .indexing .clone() @@ -59,7 +59,7 @@ impl Ivf { } } - pub fn vector(&self, i: u32) -> &[S::Scalar] { + pub fn vector(&self, i: u32) -> S::VectorRef<'_> { match self { Ivf::Naive(x) => x.vector(i), Ivf::Pq(x) => x.vector(i), @@ -75,7 +75,7 @@ impl Ivf { pub fn basic( &self, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -87,7 +87,7 @@ impl Ivf { pub fn vbase<'a>( &'a self, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { diff --git a/crates/service/src/algorithms/quantization/mod.rs b/crates/service/src/algorithms/quantization/mod.rs index d33d2bf4d..778475016 100644 --- a/crates/service/src/algorithms/quantization/mod.rs +++ b/crates/service/src/algorithms/quantization/mod.rs @@ -10,7 +10,7 @@ use crate::index::IndexOptions; use crate::prelude::*; use serde::{Deserialize, Serialize}; use std::fmt::Debug; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; use validator::Validate; @@ -59,19 +59,19 @@ impl QuantizationOptions { pub trait Quan { fn create( - path: PathBuf, + path: &Path, options: IndexOptions, quantization_options: QuantizationOptions, raw: &Arc>, permutation: Vec, ) -> Self; - fn open( - path: PathBuf, + fn open2( + path: &Path, options: IndexOptions, quantization_options: QuantizationOptions, raw: &Arc>, ) -> Self; - fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32; + fn distance(&self, lhs: S::VectorRef<'_>, rhs: u32) -> F32; fn distance2(&self, lhs: u32, rhs: u32) -> F32; } @@ -83,7 +83,7 @@ pub enum Quantization { impl Quantization { pub fn create( - path: PathBuf, + path: &Path, options: IndexOptions, quantization_options: QuantizationOptions, raw: &Arc>, @@ -115,25 +115,25 @@ impl Quantization { } pub fn open( - path: PathBuf, + path: &Path, options: IndexOptions, quantization_options: QuantizationOptions, raw: &Arc>, ) -> Self { match quantization_options { - QuantizationOptions::Trivial(_) => Self::Trivial(TrivialQuantization::open( + QuantizationOptions::Trivial(_) => Self::Trivial(TrivialQuantization::open2( path, options, quantization_options, raw, )), - QuantizationOptions::Scalar(_) => Self::Scalar(ScalarQuantization::open( + QuantizationOptions::Scalar(_) => Self::Scalar(ScalarQuantization::open2( path, options, quantization_options, raw, )), - QuantizationOptions::Product(_) => Self::Product(ProductQuantization::open( + QuantizationOptions::Product(_) => Self::Product(ProductQuantization::open2( path, options, quantization_options, @@ -142,7 +142,7 @@ impl Quantization { } } - pub fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 { + pub fn distance(&self, lhs: S::VectorRef<'_>, rhs: u32) -> F32 { use Quantization::*; match self { Trivial(x) => x.distance(lhs, rhs), diff --git a/crates/service/src/algorithms/quantization/product.rs b/crates/service/src/algorithms/quantization/product.rs index 5f70fd363..52f233c0b 100644 --- a/crates/service/src/algorithms/quantization/product.rs +++ b/crates/service/src/algorithms/quantization/product.rs @@ -14,7 +14,7 @@ use rayon::iter::ParallelIterator; use rayon::slice::ParallelSliceMut; use serde::{Deserialize, Serialize}; use std::fmt::Debug; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; use validator::Validate; @@ -67,7 +67,6 @@ pub struct ProductQuantization { centroids: Vec, codes: MmapArray, precomputed_table: Vec, - metric: Distance, } unsafe impl Send for ProductQuantization {} @@ -84,7 +83,7 @@ impl ProductQuantization { impl Quan for ProductQuantization { fn create( - path: PathBuf, + path: &Path, options: IndexOptions, quantization_options: QuantizationOptions, raw: &Arc>, @@ -100,15 +99,15 @@ impl Quan for ProductQuantization { ) } - fn open( - path: PathBuf, + fn open2( + path: &Path, options: IndexOptions, quantization_options: QuantizationOptions, _: &Arc>, ) -> Self { let centroids = serde_json::from_slice(&std::fs::read(path.join("centroids")).unwrap()).unwrap(); - let codes = MmapArray::open(path.join("codes")); + let codes = MmapArray::open(&path.join("codes")); let precomputed_table = serde_json::from_slice(&std::fs::read(path.join("table")).unwrap()).unwrap(); Self { @@ -117,11 +116,10 @@ impl Quan for ProductQuantization { centroids, codes, precomputed_table, - metric: options.vector.d, } } - fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 { + fn distance(&self, lhs: S::VectorRef<'_>, rhs: u32) -> F32 { let dims = self.dims; let ratio = self.ratio; let rhs = self.codes(rhs); @@ -139,7 +137,7 @@ impl Quan for ProductQuantization { impl ProductQuantization { pub fn with_normalizer( - path: PathBuf, + path: &Path, options: IndexOptions, quantization_options: QuantizationOptions, raw: &Raw, @@ -149,7 +147,12 @@ impl ProductQuantization { where F: Fn(u32, &mut [S::Scalar]), { - std::fs::create_dir(&path).unwrap(); + assert!( + S::KIND != Kind::SparseF32, + "Product quantization is not supported for sparse vectors." + ); + + std::fs::create_dir(path).unwrap(); let quantization_options = quantization_options.unwrap_product_quantization(); let dims = options.vector.dims; let ratio = quantization_options.ratio as u16; @@ -157,9 +160,10 @@ impl ProductQuantization { let m = std::cmp::min(n, quantization_options.sample); let samples = { let f = sample(&mut thread_rng(), n as usize, m as usize).into_vec(); - let mut samples = Vec2::::new(options.vector.dims, m as usize); + let mut samples = Vec2::::new(options.vector.dims, m as usize); for i in 0..m { - samples[i as usize].copy_from_slice(raw.vector(f[i as usize] as u32)); + samples[i as usize] + .copy_from_slice(S::to_dense(raw.vector(f[i as usize] as u32)).as_ref()); } samples }; @@ -167,7 +171,7 @@ impl ProductQuantization { let mut centroids = vec![S::Scalar::zero(); 256 * dims as usize]; for i in 0..width { let subdims = std::cmp::min(ratio, dims - ratio * i); - let mut subsamples = Vec2::::new(subdims, m as usize); + let mut subsamples = Vec2::::new(subdims, m as usize); for j in 0..m { let src = &samples[j as usize][(i * ratio) as usize..][..subdims as usize]; subsamples[j as usize].copy_from_slice(src); @@ -185,7 +189,7 @@ impl ProductQuantization { } } let codes_iter = (0..n).flat_map(|i| { - let mut vector = raw.vector(permutation[i as usize]).to_vec(); + let mut vector = S::to_dense(raw.vector(permutation[i as usize])).to_vec(); normalizer(permutation[i as usize], &mut vector); let width = dims.div_ceil(ratio); let mut result = Vec::with_capacity(width as usize); @@ -207,30 +211,29 @@ impl ProductQuantization { } result.into_iter() }); - sync_dir(&path); + sync_dir(path); std::fs::write( path.join("centroids"), serde_json::to_string(¢roids).unwrap(), ) .unwrap(); - let codes = MmapArray::create(path.join("codes"), codes_iter); + let codes = MmapArray::create(&path.join("codes"), codes_iter); Self { dims, ratio, centroids, codes, precomputed_table: Vec::new(), - metric: options.vector.d, } } pub fn encode( - path: PathBuf, + path: &Path, options: IndexOptions, quantization_options: QuantizationOptions, - raw: &Vec2, + raw: &Vec2, ) -> Self { - std::fs::create_dir(&path).unwrap(); + std::fs::create_dir(path).unwrap(); let quantization_options = quantization_options.unwrap_product_quantization(); let dims = options.vector.dims; let ratio = quantization_options.ratio as u16; @@ -238,7 +241,7 @@ impl ProductQuantization { let m = std::cmp::min(n, quantization_options.sample as usize); let samples = { let f = sample(&mut thread_rng(), n, m).into_vec(); - let mut samples = Vec2::::new(options.vector.dims, m); + let mut samples = Vec2::new(options.vector.dims, m); for i in 0..m { samples[i].copy_from_slice(&raw[f[i]]); } @@ -254,7 +257,7 @@ impl ProductQuantization { .for_each(|(i, v)| { // i is the index of subquantizer let subdims = std::cmp::min(ratio, dims - ratio * i as u16) as usize; - let mut subsamples = Vec2::::new(subdims as u16, m); + let mut subsamples = Vec2::new(subdims as u16, m); for j in 0..m { let src = &samples[j][i * ratio as usize..][..subdims]; subsamples[j].copy_from_slice(src); @@ -307,25 +310,24 @@ impl ProductQuantization { v[i as usize] = target; } }); - sync_dir(&path); + sync_dir(path); std::fs::write( path.join("centroids"), serde_json::to_string(¢roids).unwrap(), ) .unwrap(); - let codes = MmapArray::create(path.join("codes"), codes.into_iter()); + let codes = MmapArray::create(&path.join("codes"), codes.into_iter()); Self { dims, ratio, centroids, codes, precomputed_table: Vec::new(), - metric: options.vector.d, } } // compute term3 at build time - pub fn precompute_table(&mut self, path: PathBuf, coarse_centroids: &Vec2) { + pub fn precompute_table(&mut self, path: &Path, coarse_centroids: &Vec2) { let nlist = coarse_centroids.len(); let dims = self.dims; let ratio = self.ratio; @@ -357,7 +359,7 @@ impl ProductQuantization { // compute term2 at query time pub fn init_query(&self, query: &[S::Scalar]) -> Vec { - if matches!(self.metric, Distance::Cos) { + if S::DISTANCE == Distance::Cos { return Vec::new(); } let dims = self.dims; @@ -380,26 +382,26 @@ impl ProductQuantization { // add up all terms given codes pub fn distance_with_codes( &self, - lhs: &[S::Scalar], + lhs: S::VectorRef<'_>, rhs: u32, delta: &[S::Scalar], key: usize, coarse_dis: F32, runtime_table: &[F32], ) -> F32 { - if matches!(self.metric, Distance::Cos) { + if S::DISTANCE == Distance::Cos { return self.distance_with_delta(lhs, rhs, delta); } let mut result = coarse_dis; let codes = self.codes(rhs); let width = self.dims.div_ceil(self.ratio); let precomputed_table = &self.precomputed_table[key * width as usize * 256..]; - if matches!(self.metric, Distance::L2) { + if S::DISTANCE == Distance::L2 { for i in 0..width { result += precomputed_table[i as usize * 256 + codes[i as usize] as usize] + F32(2.0) * runtime_table[i as usize * 256 + codes[i as usize] as usize]; } - } else if matches!(self.metric, Distance::Dot) { + } else if S::DISTANCE == Distance::Dot { for i in 0..width { result += runtime_table[i as usize * 256 + codes[i as usize] as usize]; } @@ -407,7 +409,7 @@ impl ProductQuantization { result } - pub fn distance_with_delta(&self, lhs: &[S::Scalar], rhs: u32, delta: &[S::Scalar]) -> F32 { + pub fn distance_with_delta(&self, lhs: S::VectorRef<'_>, rhs: u32, delta: &[S::Scalar]) -> F32 { let dims = self.dims; let ratio = self.ratio; let rhs = self.codes(rhs); diff --git a/crates/service/src/algorithms/quantization/scalar.rs b/crates/service/src/algorithms/quantization/scalar.rs index a81aad53f..d0e4e8c73 100644 --- a/crates/service/src/algorithms/quantization/scalar.rs +++ b/crates/service/src/algorithms/quantization/scalar.rs @@ -5,8 +5,9 @@ use crate::index::IndexOptions; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; use crate::utils::mmap_array::MmapArray; +use base::scalar::FloatCast; use serde::{Deserialize, Serialize}; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; use validator::Validate; @@ -40,19 +41,24 @@ impl ScalarQuantization { impl Quan for ScalarQuantization { fn create( - path: PathBuf, + path: &Path, options: IndexOptions, _: QuantizationOptions, raw: &Arc>, permutation: Vec, // permutation is the mapping from placements to original ids ) -> Self { - std::fs::create_dir(&path).unwrap(); + assert!( + S::KIND != Kind::SparseF32, + "Scalar quantization is not supported for sparse vectors." + ); + + std::fs::create_dir(path).unwrap(); let dims = options.vector.dims; let mut max = vec![S::Scalar::neg_infinity(); dims as usize]; let mut min = vec![S::Scalar::infinity(); dims as usize]; let n = raw.len(); for i in 0..n { - let vector = raw.vector(permutation[i as usize]); + let vector = S::to_dense(raw.vector(permutation[i as usize])); for j in 0..dims as usize { max[j] = std::cmp::max(max[j], vector[j]); min[j] = std::cmp::min(min[j], vector[j]); @@ -61,7 +67,7 @@ impl Quan for ScalarQuantization { std::fs::write(path.join("max"), serde_json::to_string(&max).unwrap()).unwrap(); std::fs::write(path.join("min"), serde_json::to_string(&min).unwrap()).unwrap(); let codes_iter = (0..n).flat_map(|i| { - let vector = raw.vector(permutation[i as usize]); + let vector = S::to_dense(raw.vector(permutation[i as usize])); let mut result = vec![0u8; dims as usize]; for i in 0..dims as usize { let w = (((vector[i] - min[i]) / (max[i] - min[i])).to_f32() * 256.0) as u32; @@ -69,8 +75,8 @@ impl Quan for ScalarQuantization { } result.into_iter() }); - let codes = MmapArray::create(path.join("codes"), codes_iter); - sync_dir(&path); + let codes = MmapArray::create(&path.join("codes"), codes_iter); + sync_dir(path); Self { dims, max, @@ -79,11 +85,11 @@ impl Quan for ScalarQuantization { } } - fn open(path: PathBuf, options: IndexOptions, _: QuantizationOptions, _: &Arc>) -> Self { + fn open2(path: &Path, options: IndexOptions, _: QuantizationOptions, _: &Arc>) -> Self { let dims = options.vector.dims; let max = serde_json::from_slice(&std::fs::read("max").unwrap()).unwrap(); let min = serde_json::from_slice(&std::fs::read("min").unwrap()).unwrap(); - let codes = MmapArray::open(path.join("codes")); + let codes = MmapArray::open(&path.join("codes")); Self { dims, max, @@ -92,7 +98,7 @@ impl Quan for ScalarQuantization { } } - fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 { + fn distance(&self, lhs: S::VectorRef<'_>, rhs: u32) -> F32 { let dims = self.dims; let rhs = self.codes(rhs); S::scalar_quantization_distance(dims, &self.max, &self.min, lhs, rhs) diff --git a/crates/service/src/algorithms/quantization/trivial.rs b/crates/service/src/algorithms/quantization/trivial.rs index 7700b2a16..d2fb75562 100644 --- a/crates/service/src/algorithms/quantization/trivial.rs +++ b/crates/service/src/algorithms/quantization/trivial.rs @@ -5,7 +5,7 @@ use crate::index::IndexOptions; use crate::prelude::*; use crate::utils::dir_ops::sync_dir; use serde::{Deserialize, Serialize}; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; use validator::Validate; @@ -25,7 +25,7 @@ pub struct TrivialQuantization { } impl TrivialQuantization { - pub fn codes(&self, i: u32) -> &[S::Scalar] { + pub fn codes(&self, i: u32) -> S::VectorRef<'_> { self.raw.vector(self.permutation[i as usize]) } } @@ -33,15 +33,15 @@ impl TrivialQuantization { impl Quan for TrivialQuantization { // permutation is the mapping from placements to original ids fn create( - path: PathBuf, + path: &Path, _: IndexOptions, _: QuantizationOptions, raw: &Arc>, permutation: Vec, ) -> Self { // here we cannot modify raw, so we record permutation for translation - std::fs::create_dir(&path).unwrap(); - sync_dir(&path); + std::fs::create_dir(path).unwrap(); + sync_dir(path); std::fs::write( path.join("permutation"), serde_json::to_string(&permutation).unwrap(), @@ -53,7 +53,7 @@ impl Quan for TrivialQuantization { } } - fn open(path: PathBuf, _: IndexOptions, _: QuantizationOptions, raw: &Arc>) -> Self { + fn open2(path: &Path, _: IndexOptions, _: QuantizationOptions, raw: &Arc>) -> Self { let permutation = serde_json::from_slice(&std::fs::read(path.join("permutation")).unwrap()).unwrap(); Self { @@ -62,7 +62,7 @@ impl Quan for TrivialQuantization { } } - fn distance(&self, lhs: &[S::Scalar], rhs: u32) -> F32 { + fn distance(&self, lhs: S::VectorRef<'_>, rhs: u32) -> F32 { S::distance(lhs, self.codes(rhs)) } diff --git a/crates/service/src/algorithms/raw.rs b/crates/service/src/algorithms/raw.rs index f94ba9d59..b501f7330 100644 --- a/crates/service/src/algorithms/raw.rs +++ b/crates/service/src/algorithms/raw.rs @@ -2,61 +2,68 @@ use crate::index::segments::growing::GrowingSegment; use crate::index::segments::sealed::SealedSegment; use crate::index::IndexOptions; use crate::prelude::*; -use crate::utils::mmap_array::MmapArray; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; pub struct Raw { - mmap: RawMmap, + mmap: S::Storage, } impl Raw { - pub fn create( - path: PathBuf, + pub fn create G = S::VectorRef<'a>>>( + path: &Path, options: IndexOptions, - sealed: Vec>>, - growing: Vec>>, + sealed: Vec>>, + growing: Vec>>, ) -> Self { - std::fs::create_dir(&path).unwrap(); + std::fs::create_dir(path).unwrap(); let ram = make(sealed, growing, options); - let mmap = save(ram, path.clone()); - crate::utils::dir_ops::sync_dir(&path); - Self { mmap } - } - - pub fn open(path: PathBuf, options: IndexOptions) -> Self { - let mmap = load(path.clone(), options); + let mmap = S::Storage::save(path, ram); + crate::utils::dir_ops::sync_dir(path); Self { mmap } } +} +impl Raw { pub fn len(&self) -> u32 { self.mmap.len() } - pub fn vector(&self, i: u32) -> &[S::Scalar] { + pub fn vector(&self, i: u32) -> S::VectorRef<'_> { self.mmap.vector(i) } pub fn payload(&self, i: u32) -> Payload { self.mmap.payload(i) } + + pub fn open(path: &Path, options: IndexOptions) -> Self { + Self { + mmap: S::Storage::open(path, options), + } + } } unsafe impl Send for Raw {} unsafe impl Sync for Raw {} -struct RawRam { +pub struct RawRam { sealed: Vec>>, growing: Vec>>, dims: u16, } impl RawRam { - fn len(&self) -> u32 { + pub fn dims(&self) -> u16 { + self.dims + } + + pub fn len(&self) -> u32 { self.sealed.iter().map(|x| x.len()).sum::() + self.growing.iter().map(|x| x.len()).sum::() } - fn vector(&self, mut index: u32) -> &[S::Scalar] { + + pub fn vector(&self, mut index: u32) -> S::VectorRef<'_> { for x in self.sealed.iter() { if index < x.len() { return x.vector(index); @@ -71,7 +78,8 @@ impl RawRam { } panic!("Out of bound.") } - fn payload(&self, mut index: u32) -> Payload { + + pub fn payload(&self, mut index: u32) -> Payload { for x in self.sealed.iter() { if index < x.len() { return x.payload(index); @@ -88,31 +96,6 @@ impl RawRam { } } -struct RawMmap { - vectors: MmapArray, - payload: MmapArray, - dims: u16, -} - -impl RawMmap { - fn len(&self) -> u32 { - self.payload.len() as u32 - } - - fn vector(&self, i: u32) -> &[S::Scalar] { - let s = i as usize * self.dims as usize; - let e = (i + 1) as usize * self.dims as usize; - &self.vectors[s..e] - } - - fn payload(&self, i: u32) -> Payload { - self.payload[i as usize] - } -} - -unsafe impl Send for RawMmap {} -unsafe impl Sync for RawMmap {} - fn make( sealed: Vec>>, growing: Vec>>, @@ -124,26 +107,3 @@ fn make( dims: options.vector.dims, } } - -fn save(ram: RawRam, path: PathBuf) -> RawMmap { - let n = ram.len(); - let vectors_iter = (0..n).flat_map(|i| ram.vector(i)).copied(); - let payload_iter = (0..n).map(|i| ram.payload(i)); - let vectors = MmapArray::create(path.join("vectors"), vectors_iter); - let payload = MmapArray::create(path.join("payload"), payload_iter); - RawMmap { - vectors, - payload, - dims: ram.dims, - } -} - -fn load(path: PathBuf, options: IndexOptions) -> RawMmap { - let vectors = MmapArray::open(path.join("vectors")); - let payload = MmapArray::open(path.join("payload")); - RawMmap { - vectors, - payload, - dims: options.vector.dims, - } -} diff --git a/crates/service/src/algorithms/vamana.rs.txt b/crates/service/src/algorithms/vamana.rs.txt deleted file mode 100644 index 98f6b728e..000000000 --- a/crates/service/src/algorithms/vamana.rs.txt +++ /dev/null @@ -1,456 +0,0 @@ -#![allow(unused)] - -use crate::algorithms::raw::Raw; -use crate::prelude::*; -use crossbeam::atomic::AtomicCell; -use parking_lot::RwLock; -use parking_lot::RwLockReadGuard; -use parking_lot::RwLockWriteGuard; -use rand::distributions::Uniform; -use rand::prelude::SliceRandom; -use rand::Rng; -use rayon::prelude::*; -use std::cmp::Reverse; -use std::collections::{BTreeMap, BinaryHeap, HashSet}; -use std::sync::Arc; - -pub struct VertexWithDistance { - pub id: u32, - pub distance: Scalar, -} - -impl VertexWithDistance { - pub fn new(id: u32, distance: Scalar) -> Self { - Self { id, distance } - } -} - -impl PartialEq for VertexWithDistance { - fn eq(&self, other: &Self) -> bool { - self.distance.eq(&other.distance) - } -} - -impl Eq for VertexWithDistance {} - -impl PartialOrd for VertexWithDistance { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.distance.cmp(&other.distance)) - } -} - -impl Ord for VertexWithDistance { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.distance.cmp(&other.distance) - } -} - -/// DiskANN search state. -pub struct SearchState { - pub visited: HashSet, - candidates: BTreeMap, - heap: BinaryHeap>, - heap_visited: HashSet, - l: usize, - /// Number of results to return. - //TODO: used during search. - k: usize, -} - -impl SearchState { - /// Creates a new search state. - pub(crate) fn new(k: usize, l: usize) -> Self { - Self { - visited: HashSet::new(), - candidates: BTreeMap::new(), - heap: BinaryHeap::new(), - heap_visited: HashSet::new(), - k, - l, - } - } - - /// Return the next unvisited vertex. - fn pop(&mut self) -> Option { - while let Some(vertex) = self.heap.pop() { - if !self.candidates.contains_key(&vertex.0.distance) { - // The vertex has been removed from the candidate lists, - // from [`push()`]. - continue; - } - - self.visited.insert(vertex.0.id); - return Some(vertex.0.id); - } - - None - } - - /// Push a new (unvisited) vertex into the search state. - fn push(&mut self, vertex_id: u32, distance: Scalar) { - assert!(!self.visited.contains(&vertex_id)); - self.heap_visited.insert(vertex_id); - self.heap - .push(Reverse(VertexWithDistance::new(vertex_id, distance))); - self.candidates.insert(distance, vertex_id); - if self.candidates.len() > self.l { - self.candidates.pop_last(); - } - } - - /// Mark a vertex as visited. - fn visit(&mut self, vertex_id: u32) { - self.visited.insert(vertex_id); - } - - // Returns true if the vertex has been visited. - fn is_visited(&self, vertex_id: u32) -> bool { - self.visited.contains(&vertex_id) || self.heap_visited.contains(&vertex_id) - } -} - -pub struct VamanaImpl { - raw: Arc, - - /// neighbors[vertex_id*r..(vertex_id+1)*r] records r neighbors for each vertex - neighbors: Vec>, - - /// neighbor_size[vertex_id] records the actual number of neighbors for each vertex - /// the RwLock is for protecting both the data for size and original data - neighbor_size: Vec>, - - /// the entry for the entire graph, the closet vector to centroid - medoid: u32, - - dims: u16, - r: u32, - alpha: f32, - l: usize, - - d: Distance, -} - -unsafe impl Send for VamanaImpl {} -unsafe impl Sync for VamanaImpl {} - -impl VamanaImpl { - pub fn new( - raw: Arc, - n: u32, - dims: u16, - r: u32, - alpha: f32, - l: usize, - d: Distance, - ) -> Self { - let neighbors = { - let mut result = Vec::new(); - result.resize_with(r as usize * n as usize, || AtomicCell::new(0)); - result - }; - let neighbor_size = unsafe { - let mut result = Vec::new(); - result.resize_with(n as usize, || RwLock::new(0)); - result - }; - let medoid = 0; - - let mut new_vamana = Self { - raw, - neighbors, - neighbor_size, - medoid, - dims, - r, - alpha, - l, - d, - }; - - // 1. init graph with r random neighbors for each node - let rng = rand::thread_rng(); - new_vamana._init_graph(n, rng.clone()); - - // 2. find medoid - new_vamana.medoid = new_vamana._find_medoid(n); - - // 3. iterate pass - new_vamana._one_pass(n, 1.0, r, l, rng.clone()); - - new_vamana._one_pass(n, alpha, r, l, rng.clone()); - - new_vamana - } - - pub fn search(&self, target: Box<[Scalar]>, k: usize, f: F) -> Vec<(Scalar, Payload)> - where - F: FnMut(Payload) -> bool, - { - // TODO: filter - let state = self._greedy_search_with_filter(0, &target, k, k * 2, f); - - let mut results = BinaryHeap::<(Scalar, u32)>::new(); - for (distance, row) in state.candidates { - if results.len() == k { - break; - } - - results.push((distance, row)); - } - let mut res_vec: Vec<(Scalar, Payload)> = results - .iter() - .map(|x| (x.0, self.raw.payload(x.1))) - .collect(); - res_vec.sort(); - res_vec - } - - fn _greedy_search_with_filter( - &self, - start: u32, - query: &[Scalar], - k: usize, - search_size: usize, - mut f: F, - ) -> SearchState - where - F: FnMut(Payload) -> bool, - { - let mut state = SearchState::new(k, search_size); - - let dist = self.d.distance(query, self.raw.vector(start)); - state.push(start, dist); - while let Some(id) = state.pop() { - // only pop id in the search list but not visited - state.visit(id); - { - let guard = self.neighbor_size[id as usize].read(); - let neighbor_ids = self._get_neighbors(id, &guard); - for neighbor_id in neighbor_ids { - let neighbor_id = neighbor_id.load(); - if state.is_visited(neighbor_id) { - continue; - } - - if f(self.raw.payload(neighbor_id)) { - let dist = self.d.distance(query, self.raw.vector(neighbor_id)); - state.push(neighbor_id, dist); // push and retain closet l nodes - } - } - } - } - - state - } - - fn _init_graph(&self, n: u32, mut rng: impl Rng) { - let distribution = Uniform::new(0, n); - for i in 0..n { - let mut neighbor_ids: HashSet = HashSet::new(); - if self.r < n { - while neighbor_ids.len() < self.r as usize { - let neighbor_id = rng.sample(distribution); - if neighbor_id != i { - neighbor_ids.insert(neighbor_id); - } - } - } else { - neighbor_ids = (0..n).collect(); - } - - { - let mut guard = self.neighbor_size[i as usize].write(); - self._set_neighbors(i, &neighbor_ids, &mut guard); - } - } - } - - fn _set_neighbors( - &self, - vertex_index: u32, - neighbor_ids: &HashSet, - guard: &mut RwLockWriteGuard, - ) { - assert!(neighbor_ids.len() <= self.r as usize); - for (i, item) in neighbor_ids.iter().enumerate() { - self.neighbors[vertex_index as usize * self.r as usize + i].store(*item); - } - **guard = neighbor_ids.len() as u32; - } - - fn _get_neighbors( - &self, - vertex_index: u32, - guard: &RwLockReadGuard, - ) -> &[AtomicCell] { - //TODO: store neighbor length - let size = **guard; - &self.neighbors[(vertex_index as usize * self.r as usize) - ..(vertex_index as usize * self.r as usize + size as usize)] - } - - fn _get_neighbors_with_write_guard( - &self, - vertex_index: u32, - guard: &RwLockWriteGuard, - ) -> &[AtomicCell] { - let size = **guard; - &self.neighbors[(vertex_index as usize * self.r as usize) - ..(vertex_index as usize * self.r as usize + size as usize)] - } - - fn _find_medoid(&self, n: u32) -> u32 { - let centroid = self._compute_centroid(n); - let centroid_arr: &[Scalar] = ¢roid; - - let mut medoid_index = 0; - let mut min_dis = Scalar::INFINITY; - for i in 0..n { - let dis = self.d.distance(centroid_arr, self.raw.vector(i)); - if dis < min_dis { - min_dis = dis; - medoid_index = i; - } - } - medoid_index - } - - fn _compute_centroid(&self, n: u32) -> Vec { - let dim = self.dims as usize; - let mut sum = vec![0_f64; dim]; // change to f32 to avoid overflow - for i in 0..n { - let vec = self.raw.vector(i); - for j in 0..dim { - sum[j] += f32::from(vec[j]) as f64; - } - } - - let collection: Vec = sum - .iter() - .map(|v| Scalar::from((*v / n as f64) as f32)) - .collect(); - collection - } - - // r and l leave here for multiple pass extension - fn _one_pass(&self, n: u32, alpha: f32, r: u32, l: usize, mut rng: impl Rng) { - let mut ids = (0..n).collect::>(); - ids.shuffle(&mut rng); - - ids.into_par_iter() - .for_each(|id| self.search_and_prune_for_one_vertex(id, alpha, r, l)); - } - - fn search_and_prune_for_one_vertex(&self, id: u32, alpha: f32, r: u32, l: usize) { - let query = self.raw.vector(id); - let mut state = self._greedy_search(self.medoid, query, 1, l); - state.visited.remove(&id); // in case visited has id itself - let mut new_neighbor_ids: HashSet = HashSet::new(); - { - let mut guard = self.neighbor_size[id as usize].write(); - let neighbor_ids = self._get_neighbors_with_write_guard(id, &guard); - state.visited.extend(neighbor_ids.iter().map(|x| x.load())); - let neighbor_ids = self._robust_prune(id, state.visited, alpha, r); - let neighbor_ids: HashSet = neighbor_ids.into_iter().collect(); - self._set_neighbors(id, &neighbor_ids, &mut guard); - new_neighbor_ids = neighbor_ids; - } - - for &neighbor_id in new_neighbor_ids.iter() { - { - let mut guard = self.neighbor_size[neighbor_id as usize].write(); - let old_neighbors = self._get_neighbors_with_write_guard(neighbor_id, &guard); - let mut old_neighbors: HashSet = - old_neighbors.iter().map(|x| x.load()).collect(); - old_neighbors.insert(id); - if old_neighbors.len() > r as usize { - // need robust prune - let new_neighbors = self._robust_prune(neighbor_id, old_neighbors, alpha, r); - let new_neighbors: HashSet = new_neighbors.into_iter().collect(); - self._set_neighbors(neighbor_id, &new_neighbors, &mut guard); - } else { - self._set_neighbors(neighbor_id, &old_neighbors, &mut guard); - } - } - } - } - - fn _greedy_search( - &self, - start: u32, - query: &[Scalar], - k: usize, - search_size: usize, - ) -> SearchState { - let mut state = SearchState::new(k, search_size); - - let dist = self.d.distance(query, self.raw.vector(start)); - state.push(start, dist); - while let Some(id) = state.pop() { - // only pop id in the search list but not visited - state.visit(id); - { - let guard = self.neighbor_size[id as usize].read(); - let neighbor_ids = self._get_neighbors(id, &guard); - for neighbor_id in neighbor_ids { - let neighbor_id = neighbor_id.load(); - if state.is_visited(neighbor_id) { - continue; - } - - let dist = self.d.distance(query, self.raw.vector(neighbor_id)); - state.push(neighbor_id, dist); // push and retain closet l nodes - } - } - } - - state - } - - fn _robust_prune(&self, id: u32, mut visited: HashSet, alpha: f32, r: u32) -> Vec { - let mut heap: BinaryHeap = visited - .iter() - .map(|v| { - let dist = self.d.distance(self.raw.vector(id), self.raw.vector(*v)); - VertexWithDistance { - id: *v, - distance: dist, - } - }) - .collect(); - - let mut new_neighbor_ids: Vec = vec![]; - while !visited.is_empty() { - if let Some(mut p) = heap.pop() { - while !visited.contains(&p.id) { - match heap.pop() { - Some(value) => { - p = value; - } - None => { - return new_neighbor_ids; - } - } - } - new_neighbor_ids.push(p.id); - if new_neighbor_ids.len() >= r as usize { - break; - } - let mut to_remove: HashSet = HashSet::new(); - for pv in visited.iter() { - let dist_prime = self.d.distance(self.raw.vector(p.id), self.raw.vector(*pv)); - let dist_query = self.d.distance(self.raw.vector(id), self.raw.vector(*pv)); - if Scalar::from(alpha) * dist_prime <= dist_query { - to_remove.insert(*pv); - } - } - for pv in to_remove.iter() { - visited.remove(pv); - } - } else { - return new_neighbor_ids; - } - } - new_neighbor_ids - } -} diff --git a/crates/service/src/index/indexing/flat.rs b/crates/service/src/index/indexing/flat.rs index 13b8ff0f2..71cc83edc 100644 --- a/crates/service/src/index/indexing/flat.rs +++ b/crates/service/src/index/indexing/flat.rs @@ -8,7 +8,7 @@ use crate::{algorithms::flat::Flat, index::segments::sealed::SealedSegment}; use serde::{Deserialize, Serialize}; use std::cmp::Reverse; use std::collections::BinaryHeap; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; use validator::Validate; @@ -34,7 +34,7 @@ pub struct FlatIndexing { impl AbstractIndexing for FlatIndexing { fn create( - path: PathBuf, + path: &Path, options: IndexOptions, sealed: Vec>>, growing: Vec>>, @@ -43,26 +43,9 @@ impl AbstractIndexing for FlatIndexing { Self { raw } } - fn open(path: PathBuf, options: IndexOptions) -> Self { - let raw = Flat::open(path, options); - Self { raw } - } - - fn len(&self) -> u32 { - self.raw.len() - } - - fn vector(&self, i: u32) -> &[S::Scalar] { - self.raw.vector(i) - } - - fn payload(&self, i: u32) -> Payload { - self.raw.payload(i) - } - fn basic( &self, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -71,10 +54,30 @@ impl AbstractIndexing for FlatIndexing { fn vbase<'a>( &'a self, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { self.raw.vbase(vector, opts, filter) } } + +impl FlatIndexing { + pub fn len(&self) -> u32 { + self.raw.len() + } + + pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + self.raw.vector(i) + } + + pub fn payload(&self, i: u32) -> Payload { + self.raw.payload(i) + } + + pub fn open(path: &Path, options: IndexOptions) -> Self { + Self { + raw: Flat::open(path, options), + } + } +} diff --git a/crates/service/src/index/indexing/hnsw.rs b/crates/service/src/index/indexing/hnsw.rs index 362f1b014..d283c1778 100644 --- a/crates/service/src/index/indexing/hnsw.rs +++ b/crates/service/src/index/indexing/hnsw.rs @@ -9,7 +9,8 @@ use crate::prelude::*; use serde::{Deserialize, Serialize}; use std::cmp::Reverse; use std::collections::BinaryHeap; -use std::{path::PathBuf, sync::Arc}; +use std::path::Path; +use std::sync::Arc; use validator::Validate; #[derive(Debug, Clone, Serialize, Deserialize, Validate)] @@ -51,7 +52,7 @@ pub struct HnswIndexing { impl AbstractIndexing for HnswIndexing { fn create( - path: PathBuf, + path: &Path, options: IndexOptions, sealed: Vec>>, growing: Vec>>, @@ -60,26 +61,9 @@ impl AbstractIndexing for HnswIndexing { Self { raw } } - fn open(path: PathBuf, options: IndexOptions) -> Self { - let raw = Hnsw::open(path, options); - Self { raw } - } - - fn len(&self) -> u32 { - self.raw.len() - } - - fn vector(&self, i: u32) -> &[S::Scalar] { - self.raw.vector(i) - } - - fn payload(&self, i: u32) -> Payload { - self.raw.payload(i) - } - fn basic( &self, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -88,10 +72,30 @@ impl AbstractIndexing for HnswIndexing { fn vbase<'a>( &'a self, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { self.raw.vbase(vector, opts, filter) } } + +impl HnswIndexing { + pub fn len(&self) -> u32 { + self.raw.len() + } + + pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + self.raw.vector(i) + } + + pub fn payload(&self, i: u32) -> Payload { + self.raw.payload(i) + } + + pub fn open(path: &Path, options: IndexOptions) -> Self { + Self { + raw: Hnsw::open(path, options), + } + } +} diff --git a/crates/service/src/index/indexing/ivf.rs b/crates/service/src/index/indexing/ivf.rs index 1425c0c7a..6959d7f51 100644 --- a/crates/service/src/index/indexing/ivf.rs +++ b/crates/service/src/index/indexing/ivf.rs @@ -9,7 +9,7 @@ use crate::prelude::*; use serde::{Deserialize, Serialize}; use std::cmp::Reverse; use std::collections::BinaryHeap; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; use validator::Validate; @@ -66,7 +66,7 @@ pub struct IvfIndexing { impl AbstractIndexing for IvfIndexing { fn create( - path: PathBuf, + path: &Path, options: IndexOptions, sealed: Vec>>, growing: Vec>>, @@ -75,26 +75,9 @@ impl AbstractIndexing for IvfIndexing { Self { raw } } - fn open(path: PathBuf, options: IndexOptions) -> Self { - let raw = Ivf::open(path, options); - Self { raw } - } - - fn len(&self) -> u32 { - self.raw.len() - } - - fn vector(&self, i: u32) -> &[S::Scalar] { - self.raw.vector(i) - } - - fn payload(&self, i: u32) -> Payload { - self.raw.payload(i) - } - fn basic( &self, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { @@ -103,10 +86,30 @@ impl AbstractIndexing for IvfIndexing { fn vbase<'a>( &'a self, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box<(dyn Iterator + 'a)>) { self.raw.vbase(vector, opts, filter) } } + +impl IvfIndexing { + pub fn len(&self) -> u32 { + self.raw.len() + } + + pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + self.raw.vector(i) + } + + pub fn payload(&self, i: u32) -> Payload { + self.raw.payload(i) + } + + pub fn open(path: &Path, options: IndexOptions) -> Self { + Self { + raw: Ivf::open(path, options), + } + } +} diff --git a/crates/service/src/index/indexing/mod.rs b/crates/service/src/index/indexing/mod.rs index 9b2468d87..de4ddd7b4 100644 --- a/crates/service/src/index/indexing/mod.rs +++ b/crates/service/src/index/indexing/mod.rs @@ -8,12 +8,13 @@ use self::ivf::{IvfIndexing, IvfIndexingOptions}; use super::segments::growing::GrowingSegment; use super::segments::sealed::SealedSegment; use super::IndexOptions; +use crate::algorithms::quantization::QuantizationOptions; use crate::index::SearchOptions; use crate::prelude::*; use serde::{Deserialize, Serialize}; use std::cmp::Reverse; use std::collections::BinaryHeap; -use std::path::PathBuf; +use std::path::Path; use std::sync::Arc; use validator::Validate; @@ -45,6 +46,14 @@ impl IndexingOptions { }; x } + pub fn has_quantization(&self) -> bool { + let option = match self { + Self::Flat(x) => &x.quantization, + Self::Ivf(x) => &x.quantization, + Self::Hnsw(x) => &x.quantization, + }; + !matches!(option, QuantizationOptions::Trivial(_)) + } } impl Default for IndexingOptions { @@ -63,26 +72,22 @@ impl Validate for IndexingOptions { } } -pub trait AbstractIndexing: Sized { +pub trait AbstractIndexing { fn create( - path: PathBuf, + path: &Path, options: IndexOptions, sealed: Vec>>, growing: Vec>>, ) -> Self; - fn open(path: PathBuf, options: IndexOptions) -> Self; - fn len(&self) -> u32; - fn vector(&self, i: u32) -> &[S::Scalar]; - fn payload(&self, i: u32) -> Payload; fn basic( &self, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap>; fn vbase<'a>( &'a self, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box + 'a>); @@ -96,7 +101,7 @@ pub enum DynamicIndexing { impl DynamicIndexing { pub fn create( - path: PathBuf, + path: &Path, options: IndexOptions, sealed: Vec>>, growing: Vec>>, @@ -114,11 +119,29 @@ impl DynamicIndexing { } } - pub fn open(path: PathBuf, options: IndexOptions) -> Self { - match options.indexing { - IndexingOptions::Flat(_) => Self::Flat(FlatIndexing::open(path, options)), - IndexingOptions::Ivf(_) => Self::Ivf(IvfIndexing::open(path, options)), - IndexingOptions::Hnsw(_) => Self::Hnsw(HnswIndexing::open(path, options)), + pub fn basic( + &self, + vector: S::VectorRef<'_>, + opts: &SearchOptions, + filter: impl Filter, + ) -> BinaryHeap> { + match self { + DynamicIndexing::Flat(x) => x.basic(vector, opts, filter), + DynamicIndexing::Ivf(x) => x.basic(vector, opts, filter), + DynamicIndexing::Hnsw(x) => x.basic(vector, opts, filter), + } + } + + pub fn vbase<'a>( + &'a self, + vector: S::VectorRef<'a>, + opts: &'a SearchOptions, + filter: impl Filter + 'a, + ) -> (Vec, Box<(dyn Iterator + 'a)>) { + match self { + DynamicIndexing::Flat(x) => x.vbase(vector, opts, filter), + DynamicIndexing::Ivf(x) => x.vbase(vector, opts, filter), + DynamicIndexing::Hnsw(x) => x.vbase(vector, opts, filter), } } @@ -130,7 +153,7 @@ impl DynamicIndexing { } } - pub fn vector(&self, i: u32) -> &[S::Scalar] { + pub fn vector(&self, i: u32) -> S::VectorRef<'_> { match self { DynamicIndexing::Flat(x) => x.vector(i), DynamicIndexing::Ivf(x) => x.vector(i), @@ -146,29 +169,11 @@ impl DynamicIndexing { } } - pub fn basic( - &self, - vector: &[S::Scalar], - opts: &SearchOptions, - filter: impl Filter, - ) -> BinaryHeap> { - match self { - DynamicIndexing::Flat(x) => x.basic(vector, opts, filter), - DynamicIndexing::Ivf(x) => x.basic(vector, opts, filter), - DynamicIndexing::Hnsw(x) => x.basic(vector, opts, filter), - } - } - - pub fn vbase<'a>( - &'a self, - vector: &'a [S::Scalar], - opts: &'a SearchOptions, - filter: impl Filter + 'a, - ) -> (Vec, Box<(dyn Iterator + 'a)>) { - match self { - DynamicIndexing::Flat(x) => x.vbase(vector, opts, filter), - DynamicIndexing::Ivf(x) => x.vbase(vector, opts, filter), - DynamicIndexing::Hnsw(x) => x.vbase(vector, opts, filter), + pub fn open(path: &Path, options: IndexOptions) -> Self { + match options.indexing { + IndexingOptions::Flat(_) => Self::Flat(FlatIndexing::open(path, options)), + IndexingOptions::Ivf(_) => Self::Ivf(IvfIndexing::open(path, options)), + IndexingOptions::Hnsw(_) => Self::Hnsw(HnswIndexing::open(path, options)), } } } diff --git a/crates/service/src/index/mod.rs b/crates/service/src/index/mod.rs index a1cd65a17..0093fad72 100644 --- a/crates/service/src/index/mod.rs +++ b/crates/service/src/index/mod.rs @@ -29,6 +29,7 @@ use std::time::Instant; use thiserror::Error; use uuid::Uuid; use validator::Validate; +use validator::ValidationError; #[derive(Debug, Error)] #[error("The index view is outdated.")] @@ -48,6 +49,7 @@ pub struct VectorOptions { #[derive(Debug, Clone, Serialize, Deserialize, Validate)] #[serde(deny_unknown_fields)] +#[validate(schema(function = "validate_index_options"))] pub struct IndexOptions { #[validate] pub vector: VectorOptions, @@ -59,6 +61,15 @@ pub struct IndexOptions { pub indexing: IndexingOptions, } +fn validate_index_options(options: &IndexOptions) -> Result<(), ValidationError> { + if options.vector.k == Kind::SparseF32 && options.indexing.has_quantization() { + return Err(ValidationError::new( + "quantization is not supported for sparse vector", + )); + } + Ok(()) +} + #[derive(Debug, Clone, Serialize, Deserialize, Validate)] pub struct SearchOptions { pub prefilter_enable: bool, @@ -79,13 +90,10 @@ pub struct SegmentStat { } #[derive(Debug, Serialize, Deserialize)] -pub enum IndexStat { - Normal { - indexing: bool, - segments: Vec, - options: IndexOptions, - }, - Upgrade, +pub struct IndexStat { + pub indexing: bool, + pub segments: Vec, + pub options: IndexOptions, } pub struct Index { @@ -100,10 +108,10 @@ pub struct Index { } impl Index { - pub fn create(path: PathBuf, options: IndexOptions) -> Result, ServiceError> { + pub fn create(path: PathBuf, options: IndexOptions) -> Result, CreateError> { if let Err(err) = options.validate() { - return Err(ServiceError::BadOption { - validation: err.to_string(), + return Err(CreateError::InvalidIndexOptions { + reason: err.to_string(), }); } std::fs::create_dir(&path).unwrap(); @@ -147,6 +155,7 @@ impl Index { OptimizerSealing::new(index.clone()).spawn(); Ok(index) } + pub fn open(path: PathBuf) -> Arc { let options = serde_json::from_slice::(&std::fs::read(path.join("options")).unwrap()) @@ -189,6 +198,7 @@ impl Index { tracker.clone(), path.join("segments").join(uuid.to_string()), uuid, + options.clone(), ), ) }) @@ -262,7 +272,7 @@ impl Index { } pub fn stat(&self) -> IndexStat { let view = self.view(); - IndexStat::Normal { + IndexStat { indexing: self.instant_index.load() < self.instant_write.load(), options: self.options().clone(), segments: { @@ -308,12 +318,17 @@ pub struct IndexView { impl IndexView { pub fn basic<'a, F: Fn(Pointer) -> bool + Clone + 'a>( &'a self, - vector: &'a [S::Scalar], + vector: S::VectorRef<'_>, opts: &'a SearchOptions, filter: F, - ) -> Result + 'a, ServiceError> { - if self.options.vector.dims as usize != vector.len() { - return Err(ServiceError::Unmatched); + ) -> Result + 'a, BasicError> { + if self.options.vector.dims != vector.dims() { + return Err(BasicError::InvalidVector); + } + if let Err(err) = opts.validate() { + return Err(BasicError::InvalidSearchOptions { + reason: err.to_string(), + }); } struct Comparer(std::collections::BinaryHeap>); @@ -381,12 +396,17 @@ impl IndexView { } pub fn vbase<'a, F: FnMut(Pointer) -> bool + Clone + 'a>( &'a self, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, opts: &'a SearchOptions, filter: F, - ) -> Result + 'a, ServiceError> { - if self.options.vector.dims as usize != vector.len() { - return Err(ServiceError::Unmatched); + ) -> Result + 'a, VbaseError> { + if self.options.vector.dims != vector.dims() { + return Err(VbaseError::InvalidVector); + } + if let Err(err) = opts.validate() { + return Err(VbaseError::InvalidSearchOptions { + reason: err.to_string(), + }); } struct Filtering<'a, F: 'a> { @@ -448,7 +468,7 @@ impl IndexView { } })) } - pub fn list(&self) -> impl Iterator + '_ { + pub fn list(&self) -> Result + '_, ListError> { let sealed = self .sealed .values() @@ -462,19 +482,21 @@ impl IndexView { .iter() .map(|(_, x)| x) .flat_map(|x| (0..x.len()).map(|i| x.payload(i))); - sealed + let iter = sealed .chain(growing) .chain(write) - .filter_map(|p| self.delete.check(p)) + .filter_map(|p| self.delete.check(p)); + Ok(iter) } pub fn insert( &self, - vector: Vec, + vector: S::VectorOwned, pointer: Pointer, - ) -> Result, ServiceError> { - if self.options.vector.dims as usize != vector.len() { - return Err(ServiceError::Unmatched); + ) -> Result, InsertError> { + if self.options.vector.dims != vector.dims() { + return Err(InsertError::InvalidVector); } + let payload = (pointer.as_u48() << 16) | self.delete.version(pointer) as Payload; if let Some((_, growing)) = self.write.as_ref() { use crate::index::segments::growing::GrowingSegmentInsertError; @@ -486,14 +508,16 @@ impl IndexView { Ok(Err(OutdatedError)) } } - pub fn delete(&self, p: Pointer) { + pub fn delete(&self, p: Pointer) -> Result<(), DeleteError> { self.delete.delete(p); + Ok(()) } - pub fn flush(&self) { + pub fn flush(&self) -> Result<(), FlushError> { self.delete.flush(); if let Some((_, write)) = &self.write { write.flush(); } + Ok(()) } } diff --git a/crates/service/src/index/optimizing/mod.rs b/crates/service/src/index/optimizing/mod.rs index 67972a93e..2525b5057 100644 --- a/crates/service/src/index/optimizing/mod.rs +++ b/crates/service/src/index/optimizing/mod.rs @@ -14,9 +14,9 @@ pub struct OptimizingOptions { #[serde(default = "OptimizingOptions::default_sealing_size")] #[validate(range(min = 1, max = 4_000_000_000))] pub sealing_size: u32, - #[serde(default = "OptimizingOptions::default_deleted_threshold", skip)] + #[serde(default = "OptimizingOptions::default_delete_threshold")] #[validate(range(min = 0.01, max = 1.00))] - pub deleted_threshold: f64, + pub delete_threshold: f64, #[serde(default = "OptimizingOptions::default_optimizing_threads")] #[validate(range(min = 1, max = 65535))] pub optimizing_threads: usize, @@ -29,7 +29,7 @@ impl OptimizingOptions { fn default_sealing_size() -> u32 { 1 } - fn default_deleted_threshold() -> f64 { + fn default_delete_threshold() -> f64 { 0.2 } fn default_optimizing_threads() -> usize { @@ -45,7 +45,7 @@ impl Default for OptimizingOptions { Self { sealing_secs: Self::default_sealing_secs(), sealing_size: Self::default_sealing_size(), - deleted_threshold: Self::default_deleted_threshold(), + delete_threshold: Self::default_delete_threshold(), optimizing_threads: Self::default_optimizing_threads(), } } diff --git a/crates/service/src/index/segments/growing.rs b/crates/service/src/index/segments/growing.rs index fa635f47e..0a353d36d 100644 --- a/crates/service/src/index/segments/growing.rs +++ b/crates/service/src/index/segments/growing.rs @@ -59,7 +59,13 @@ impl GrowingSegment { _tracker: Arc::new(SegmentTracker { path, _tracker }), }) } - pub fn open(_tracker: Arc, path: PathBuf, uuid: Uuid) -> Arc { + + pub fn open( + _tracker: Arc, + path: PathBuf, + uuid: Uuid, + _: IndexOptions, + ) -> Arc { let mut wal = FileWal::open(path.join("wal")); let mut vec = Vec::new(); while let Some(log) = wal.read() { @@ -80,9 +86,11 @@ impl GrowingSegment { _tracker: Arc::new(SegmentTracker { path, _tracker }), }) } + pub fn uuid(&self) -> Uuid { self.uuid } + pub fn is_full(&self) -> bool { let n; { @@ -97,6 +105,7 @@ impl GrowingSegment { } true } + pub fn seal(&self) { let n; { @@ -109,12 +118,14 @@ impl GrowingSegment { } self.wal.lock().sync_all(); } + pub fn flush(&self) { self.wal.lock().sync_all(); } + pub fn insert( &self, - vector: Vec, + vector: S::VectorOwned, payload: Payload, ) -> Result<(), GrowingSegmentInsertError> { let log = Log { vector, payload }; @@ -139,9 +150,11 @@ impl GrowingSegment { .write(&bincode::serialize::>(&log).unwrap()); Ok(()) } + pub fn len(&self) -> u32 { self.len.load(Ordering::Acquire) as u32 } + pub fn stat_growing(&self) -> SegmentStat { SegmentStat { id: self.uuid, @@ -150,6 +163,7 @@ impl GrowingSegment { size: (self.len() as u64) * (std::mem::size_of::>() as u64), } } + pub fn stat_write(&self) -> SegmentStat { SegmentStat { id: self.uuid, @@ -158,14 +172,16 @@ impl GrowingSegment { size: (self.len() as u64) * (std::mem::size_of::>() as u64), } } - pub fn vector(&self, i: u32) -> &[S::Scalar] { + + pub fn vector(&self, i: u32) -> S::VectorRef<'_> { let i = i as usize; if i >= self.len.load(Ordering::Acquire) { panic!("Out of bound."); } let log = unsafe { (*self.vec[i].get()).assume_init_ref() }; - log.vector.as_ref() + S::owned_to_ref(&log.vector) } + pub fn payload(&self, i: u32) -> Payload { let i = i as usize; if i >= self.len.load(Ordering::Acquire) { @@ -174,9 +190,10 @@ impl GrowingSegment { let log = unsafe { (*self.vec[i].get()).assume_init_ref() }; log.payload } + pub fn basic( &self, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, _opts: &SearchOptions, mut filter: impl Filter, ) -> BinaryHeap> { @@ -185,7 +202,7 @@ impl GrowingSegment { for i in 0..n { let log = unsafe { (*self.vec[i].get()).assume_init_ref() }; if filter.check(log.payload) { - let distance = S::distance(vector, &log.vector); + let distance = S::distance(vector, S::owned_to_ref(&log.vector)); result.push(Reverse(Element { distance, payload: log.payload, @@ -194,9 +211,10 @@ impl GrowingSegment { } result } + pub fn vbase<'a>( &'a self, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, _opts: &SearchOptions, mut filter: impl Filter + 'a, ) -> (Vec, Box + 'a>) { @@ -205,7 +223,7 @@ impl GrowingSegment { for i in 0..n { let log = unsafe { (*self.vec[i].get()).assume_init_ref() }; if filter.check(log.payload) { - let distance = S::distance(vector, &log.vector); + let distance = S::distance(vector, S::owned_to_ref(&log.vector)); result.push(Element { distance, payload: log.payload, @@ -232,7 +250,7 @@ impl Drop for GrowingSegment { #[derive(Debug, Clone, Serialize, Deserialize)] struct Log { - vector: Vec, + vector: S::VectorOwned, payload: Payload, } diff --git a/crates/service/src/index/segments/sealed.rs b/crates/service/src/index/segments/sealed.rs index fc9675723..07ffb8f52 100644 --- a/crates/service/src/index/segments/sealed.rs +++ b/crates/service/src/index/segments/sealed.rs @@ -26,7 +26,7 @@ impl SealedSegment { growing: Vec>>, ) -> Arc { std::fs::create_dir(&path).unwrap(); - let indexing = DynamicIndexing::create(path.join("indexing"), options, sealed, growing); + let indexing = DynamicIndexing::create(&path.join("indexing"), options, sealed, growing); sync_dir(&path); Arc::new(Self { uuid, @@ -34,25 +34,25 @@ impl SealedSegment { _tracker: Arc::new(SegmentTracker { path, _tracker }), }) } + pub fn open( _tracker: Arc, path: PathBuf, uuid: Uuid, options: IndexOptions, ) -> Arc { - let indexing = DynamicIndexing::open(path.join("indexing"), options); + let indexing = DynamicIndexing::open(&path.join("indexing"), options); Arc::new(Self { uuid, indexing, _tracker: Arc::new(SegmentTracker { path, _tracker }), }) } + pub fn uuid(&self) -> Uuid { self.uuid } - pub fn len(&self) -> u32 { - self.indexing.len() - } + pub fn stat_sealed(&self) -> SegmentStat { let path = self._tracker.path.join("indexing"); SegmentStat { @@ -62,26 +62,34 @@ impl SealedSegment { size: dir_size(&path).unwrap(), } } - pub fn vector(&self, i: u32) -> &[S::Scalar] { - self.indexing.vector(i) - } - pub fn payload(&self, i: u32) -> Payload { - self.indexing.payload(i) - } + pub fn basic( &self, - vector: &[S::Scalar], + vector: S::VectorRef<'_>, opts: &SearchOptions, filter: impl Filter, ) -> BinaryHeap> { self.indexing.basic(vector, opts, filter) } + pub fn vbase<'a>( &'a self, - vector: &'a [S::Scalar], + vector: S::VectorRef<'a>, opts: &'a SearchOptions, filter: impl Filter + 'a, ) -> (Vec, Box + 'a>) { self.indexing.vbase(vector, opts, filter) } + + pub fn len(&self) -> u32 { + self.indexing.len() + } + + pub fn vector(&self, i: u32) -> S::VectorRef<'_> { + self.indexing.vector(i) + } + + pub fn payload(&self, i: u32) -> Payload { + self.indexing.payload(i) + } } diff --git a/crates/service/src/instance/mod.rs b/crates/service/src/instance/mod.rs index 20b061fc5..3007458f4 100644 --- a/crates/service/src/instance/mod.rs +++ b/crates/service/src/instance/mod.rs @@ -10,6 +10,22 @@ use crate::prelude::*; use std::path::PathBuf; use std::sync::Arc; +pub trait InstanceViewOperations { + fn basic<'a, F: Fn(Pointer) -> bool + Clone + 'a>( + &'a self, + vector: &'a DynamicVector, + opts: &'a SearchOptions, + filter: F, + ) -> Result + 'a>, BasicError>; + fn vbase<'a, F: FnMut(Pointer) -> bool + Clone + 'a>( + &'a self, + vector: &'a DynamicVector, + opts: &'a SearchOptions, + filter: F, + ) -> Result + 'a>, VbaseError>; + fn list(&self) -> Result + '_>, ListError>; +} + #[derive(Clone)] pub enum Instance { F32Cos(Arc>), @@ -18,11 +34,14 @@ pub enum Instance { F16Cos(Arc>), F16Dot(Arc>), F16L2(Arc>), + SparseF32L2(Arc>), + SparseF32Cos(Arc>), + SparseF32Dot(Arc>), Upgrade, } impl Instance { - pub fn create(path: PathBuf, options: IndexOptions) -> Result { + pub fn create(path: PathBuf, options: IndexOptions) -> Result { match (options.vector.d, options.vector.k) { (Distance::Cos, Kind::F32) => { let index = Index::create(path.clone(), options)?; @@ -54,6 +73,21 @@ impl Instance { self::metadata::Metadata::write(path.join("metadata")); Ok(Self::F16L2(index)) } + (Distance::L2, Kind::SparseF32) => { + let index = Index::create(path.clone(), options)?; + self::metadata::Metadata::write(path.join("metadata")); + Ok(Self::SparseF32L2(index)) + } + (Distance::Cos, Kind::SparseF32) => { + let index = Index::create(path.clone(), options)?; + self::metadata::Metadata::write(path.join("metadata")); + Ok(Self::SparseF32Cos(index)) + } + (Distance::Dot, Kind::SparseF32) => { + let index = Index::create(path.clone(), options)?; + self::metadata::Metadata::write(path.join("metadata")); + Ok(Self::SparseF32Dot(index)) + } } } pub fn open(path: PathBuf) -> Self { @@ -70,6 +104,9 @@ impl Instance { (Distance::Cos, Kind::F16) => Self::F16Cos(Index::open(path)), (Distance::Dot, Kind::F16) => Self::F16Dot(Index::open(path)), (Distance::L2, Kind::F16) => Self::F16L2(Index::open(path)), + (Distance::L2, Kind::SparseF32) => Self::SparseF32L2(Index::open(path)), + (Distance::Cos, Kind::SparseF32) => Self::SparseF32Cos(Index::open(path)), + (Distance::Dot, Kind::SparseF32) => Self::SparseF32Dot(Index::open(path)), } } pub fn refresh(&self) { @@ -80,6 +117,9 @@ impl Instance { Instance::F16Cos(x) => x.refresh(), Instance::F16Dot(x) => x.refresh(), Instance::F16L2(x) => x.refresh(), + Instance::SparseF32L2(x) => x.refresh(), + Instance::SparseF32Cos(x) => x.refresh(), + Instance::SparseF32Dot(x) => x.refresh(), Instance::Upgrade => (), } } @@ -91,18 +131,24 @@ impl Instance { Instance::F16Cos(x) => Some(InstanceView::F16Cos(x.view())), Instance::F16Dot(x) => Some(InstanceView::F16Dot(x.view())), Instance::F16L2(x) => Some(InstanceView::F16L2(x.view())), + Instance::SparseF32L2(x) => Some(InstanceView::SparseF32L2(x.view())), + Instance::SparseF32Cos(x) => Some(InstanceView::SparseF32Cos(x.view())), + Instance::SparseF32Dot(x) => Some(InstanceView::SparseF32Dot(x.view())), Instance::Upgrade => None, } } - pub fn stat(&self) -> IndexStat { + pub fn stat(&self) -> Option { match self { - Instance::F32Cos(x) => x.stat(), - Instance::F32Dot(x) => x.stat(), - Instance::F32L2(x) => x.stat(), - Instance::F16Cos(x) => x.stat(), - Instance::F16Dot(x) => x.stat(), - Instance::F16L2(x) => x.stat(), - Instance::Upgrade => IndexStat::Upgrade, + Instance::F32Cos(x) => Some(x.stat()), + Instance::F32Dot(x) => Some(x.stat()), + Instance::F32L2(x) => Some(x.stat()), + Instance::F16Cos(x) => Some(x.stat()), + Instance::F16Dot(x) => Some(x.stat()), + Instance::F16L2(x) => Some(x.stat()), + Instance::SparseF32L2(x) => Some(x.stat()), + Instance::SparseF32Cos(x) => Some(x.stat()), + Instance::SparseF32Dot(x) => Some(x.stat()), + Instance::Upgrade => None, } } } @@ -114,15 +160,18 @@ pub enum InstanceView { F16Cos(Arc>), F16Dot(Arc>), F16L2(Arc>), + SparseF32Cos(Arc>), + SparseF32Dot(Arc>), + SparseF32L2(Arc>), } -impl InstanceView { - pub fn basic<'a, F: Fn(Pointer) -> bool + Clone + 'a>( +impl InstanceViewOperations for InstanceView { + fn basic<'a, F: Fn(Pointer) -> bool + Clone + 'a>( &'a self, vector: &'a DynamicVector, opts: &'a SearchOptions, filter: F, - ) -> Result + 'a, ServiceError> { + ) -> Result + 'a>, BasicError> { match (self, vector) { (InstanceView::F32Cos(x), DynamicVector::F32(vector)) => { Ok(Box::new(x.basic(vector, opts, filter)?) as Box>) @@ -142,15 +191,24 @@ impl InstanceView { (InstanceView::F16L2(x), DynamicVector::F16(vector)) => { Ok(Box::new(x.basic(vector, opts, filter)?)) } - _ => Err(ServiceError::Unmatched), + (InstanceView::SparseF32Cos(x), DynamicVector::SparseF32(vector)) => { + Ok(Box::new(x.basic(vector.into(), opts, filter)?)) + } + (InstanceView::SparseF32Dot(x), DynamicVector::SparseF32(vector)) => { + Ok(Box::new(x.basic(vector.into(), opts, filter)?)) + } + (InstanceView::SparseF32L2(x), DynamicVector::SparseF32(vector)) => { + Ok(Box::new(x.basic(vector.into(), opts, filter)?)) + } + _ => Err(BasicError::InvalidVector), } } - pub fn vbase<'a, F: FnMut(Pointer) -> bool + Clone + 'a>( + fn vbase<'a, F: FnMut(Pointer) -> bool + Clone + 'a>( &'a self, vector: &'a DynamicVector, opts: &'a SearchOptions, filter: F, - ) -> Result + '_, ServiceError> { + ) -> Result + 'a>, VbaseError> { match (self, vector) { (InstanceView::F32Cos(x), DynamicVector::F32(vector)) => { Ok(Box::new(x.vbase(vector, opts, filter)?) as Box>) @@ -170,24 +228,39 @@ impl InstanceView { (InstanceView::F16L2(x), DynamicVector::F16(vector)) => { Ok(Box::new(x.vbase(vector, opts, filter)?)) } - _ => Err(ServiceError::Unmatched), + (InstanceView::SparseF32Cos(x), DynamicVector::SparseF32(vector)) => { + Ok(Box::new(x.vbase(vector.into(), opts, filter)?)) + } + (InstanceView::SparseF32Dot(x), DynamicVector::SparseF32(vector)) => { + Ok(Box::new(x.vbase(vector.into(), opts, filter)?)) + } + (InstanceView::SparseF32L2(x), DynamicVector::SparseF32(vector)) => { + Ok(Box::new(x.vbase(vector.into(), opts, filter)?)) + } + _ => Err(VbaseError::InvalidVector), } } - pub fn list(&self) -> impl Iterator + '_ { + fn list(&self) -> Result + '_>, ListError> { match self { - InstanceView::F32Cos(x) => Box::new(x.list()) as Box>, - InstanceView::F32Dot(x) => Box::new(x.list()), - InstanceView::F32L2(x) => Box::new(x.list()), - InstanceView::F16Cos(x) => Box::new(x.list()), - InstanceView::F16Dot(x) => Box::new(x.list()), - InstanceView::F16L2(x) => Box::new(x.list()), + InstanceView::F32Cos(x) => Ok(Box::new(x.list()?) as Box>), + InstanceView::F32Dot(x) => Ok(Box::new(x.list()?)), + InstanceView::F32L2(x) => Ok(Box::new(x.list()?)), + InstanceView::F16Cos(x) => Ok(Box::new(x.list()?)), + InstanceView::F16Dot(x) => Ok(Box::new(x.list()?)), + InstanceView::F16L2(x) => Ok(Box::new(x.list()?)), + InstanceView::SparseF32Cos(x) => Ok(Box::new(x.list()?)), + InstanceView::SparseF32Dot(x) => Ok(Box::new(x.list()?)), + InstanceView::SparseF32L2(x) => Ok(Box::new(x.list()?)), } } +} + +impl InstanceView { pub fn insert( &self, vector: DynamicVector, pointer: Pointer, - ) -> Result, ServiceError> { + ) -> Result, InsertError> { match (self, vector) { (InstanceView::F32Cos(x), DynamicVector::F32(vector)) => x.insert(vector, pointer), (InstanceView::F32Dot(x), DynamicVector::F32(vector)) => x.insert(vector, pointer), @@ -195,10 +268,19 @@ impl InstanceView { (InstanceView::F16Cos(x), DynamicVector::F16(vector)) => x.insert(vector, pointer), (InstanceView::F16Dot(x), DynamicVector::F16(vector)) => x.insert(vector, pointer), (InstanceView::F16L2(x), DynamicVector::F16(vector)) => x.insert(vector, pointer), - _ => Err(ServiceError::Unmatched), + (InstanceView::SparseF32Cos(x), DynamicVector::SparseF32(vector)) => { + x.insert(vector, pointer) + } + (InstanceView::SparseF32Dot(x), DynamicVector::SparseF32(vector)) => { + x.insert(vector, pointer) + } + (InstanceView::SparseF32L2(x), DynamicVector::SparseF32(vector)) => { + x.insert(vector, pointer) + } + _ => Err(InsertError::InvalidVector), } } - pub fn delete(&self, pointer: Pointer) { + pub fn delete(&self, pointer: Pointer) -> Result<(), DeleteError> { match self { InstanceView::F32Cos(x) => x.delete(pointer), InstanceView::F32Dot(x) => x.delete(pointer), @@ -206,9 +288,12 @@ impl InstanceView { InstanceView::F16Cos(x) => x.delete(pointer), InstanceView::F16Dot(x) => x.delete(pointer), InstanceView::F16L2(x) => x.delete(pointer), + InstanceView::SparseF32Cos(x) => x.delete(pointer), + InstanceView::SparseF32Dot(x) => x.delete(pointer), + InstanceView::SparseF32L2(x) => x.delete(pointer), } } - pub fn flush(&self) { + pub fn flush(&self) -> Result<(), FlushError> { match self { InstanceView::F32Cos(x) => x.flush(), InstanceView::F32Dot(x) => x.flush(), @@ -216,6 +301,9 @@ impl InstanceView { InstanceView::F16Cos(x) => x.flush(), InstanceView::F16Dot(x) => x.flush(), InstanceView::F16L2(x) => x.flush(), + InstanceView::SparseF32Cos(x) => x.flush(), + InstanceView::SparseF32Dot(x) => x.flush(), + InstanceView::SparseF32L2(x) => x.flush(), } } } diff --git a/crates/service/src/prelude/error.rs b/crates/service/src/prelude/error.rs deleted file mode 100644 index f3fa22df9..000000000 --- a/crates/service/src/prelude/error.rs +++ /dev/null @@ -1,37 +0,0 @@ -use serde::{Deserialize, Serialize}; -use thiserror::Error; - -#[must_use] -#[derive(Debug, Clone, Error, Serialize, Deserialize)] -#[rustfmt::skip] -pub enum ServiceError { - #[error("\ -The given index option is invalid. -INFORMATION: reason = {validation:?}\ -")] - BadOption { validation: String }, - #[error("\ -The index is not existing in the background worker. -ADVICE: Drop or rebuild the index.\ -")] - UnknownIndex, -#[error("\ -The index is already existing in the background worker.\ -")] - KnownIndex, - #[error("\ -The given vector is invalid for input. -ADVICE: Check if dimensions and scalar type of the vector is matched with the index.\ -")] - Unmatched, - #[error("\ -The extension is upgraded so all index files are outdated. -ADVICE: Delete all index files. Please read `https://docs.pgvecto.rs/admin/upgrading.html`.\ -")] - Upgrade, - #[error("\ -The extension is upgraded so this index is outdated. -ADVICE: Rebuild the index. Please read `https://docs.pgvecto.rs/admin/upgrading.html`.\ -")] - Upgrade2, -} diff --git a/crates/service/src/prelude/global/f16.rs b/crates/service/src/prelude/global/f16.rs index 8e6a52ddc..be5c560d2 100644 --- a/crates/service/src/prelude/global/f16.rs +++ b/crates/service/src/prelude/global/f16.rs @@ -1,4 +1,5 @@ use crate::prelude::*; +use base::scalar::FloatCast; pub fn cosine(lhs: &[F16], rhs: &[F16]) -> F32 { #[inline(always)] @@ -136,3 +137,34 @@ pub fn sl2(lhs: &[F16], rhs: &[F16]) -> F32 { } sl2(lhs, rhs) } + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +fn length(vector: &[F16]) -> F16 { + let n = vector.len(); + let mut dot = F16::zero(); + for i in 0..n { + dot += vector[i] * vector[i]; + } + dot.sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +pub fn l2_normalize(vector: &mut [F16]) { + let n = vector.len(); + let l = length(vector); + for i in 0..n { + vector[i] /= l; + } +} diff --git a/crates/service/src/prelude/global/f16_cos.rs b/crates/service/src/prelude/global/f16_cos.rs index 4ac156d60..64c648d03 100644 --- a/crates/service/src/prelude/global/f16_cos.rs +++ b/crates/service/src/prelude/global/f16_cos.rs @@ -1,29 +1,56 @@ -use super::G; -use crate::prelude::scalar::F32; use crate::prelude::*; +use base::scalar::FloatCast; +use std::borrow::Cow; #[derive(Debug, Clone, Copy)] pub enum F16Cos {} impl G for F16Cos { type Scalar = F16; + type Storage = DenseMmap; + type L2 = F16L2; + type VectorOwned = Vec; + type VectorRef<'a> = &'a [F16]; const DISTANCE: Distance = Distance::Cos; + const KIND: Kind = Kind::F16; - type L2 = F16L2; + fn owned_to_ref(vector: &Vec) -> &[F16] { + vector + } + + fn ref_to_owned(vector: &[F16]) -> Vec { + vector.to_vec() + } + + fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F16]> { + Cow::Borrowed(vector) + } fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { F32(1.0) - super::f16::cosine(lhs, rhs) } + fn distance2(lhs: &[F16], rhs: &[F16]) -> F32 { + F32(1.0) - super::f16::cosine(lhs, rhs) + } + fn elkan_k_means_normalize(vector: &mut [F16]) { - l2_normalize(vector) + super::f16::l2_normalize(vector) + } + + fn elkan_k_means_normalize2(vector: &mut Vec) { + super::f16::l2_normalize(vector) } fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { super::f16::dot(lhs, rhs).acos() } + fn elkan_k_means_distance2(lhs: &[F16], rhs: &[F16]) -> F32 { + super::f16::dot(lhs, rhs).acos() + } + #[multiversion::multiversion(targets( "x86_64/x86-64-v4", "x86_64/x86-64-v3", @@ -170,37 +197,6 @@ impl G for F16Cos { } } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] -fn length(vector: &[F16]) -> F16 { - let n = vector.len(); - let mut dot = F16::zero(); - for i in 0..n { - dot += vector[i] * vector[i]; - } - dot.sqrt() -} - -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] -fn l2_normalize(vector: &mut [F16]) { - let n = vector.len(); - let l = length(vector); - for i in 0..n { - vector[i] /= l; - } -} - #[inline(always)] #[multiversion::multiversion(targets( "x86_64/x86-64-v4", diff --git a/crates/service/src/prelude/global/f16_dot.rs b/crates/service/src/prelude/global/f16_dot.rs index 085c2b827..0d210a421 100644 --- a/crates/service/src/prelude/global/f16_dot.rs +++ b/crates/service/src/prelude/global/f16_dot.rs @@ -1,29 +1,56 @@ -use super::G; -use crate::prelude::scalar::F32; use crate::prelude::*; +use base::scalar::FloatCast; +use std::borrow::Cow; #[derive(Debug, Clone, Copy)] pub enum F16Dot {} impl G for F16Dot { type Scalar = F16; + type Storage = DenseMmap; + type L2 = F16L2; + type VectorOwned = Vec; + type VectorRef<'a> = &'a [F16]; const DISTANCE: Distance = Distance::Dot; + const KIND: Kind = Kind::F16; - type L2 = F16L2; + fn owned_to_ref(vector: &Vec) -> &[F16] { + vector + } + + fn ref_to_owned(vector: &[F16]) -> Vec { + vector.to_vec() + } + + fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F16]> { + Cow::Borrowed(vector) + } fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { super::f16::dot(lhs, rhs) * (-1.0) } + fn distance2(lhs: &[F16], rhs: &[F16]) -> F32 { + super::f16::dot(lhs, rhs) * (-1.0) + } + fn elkan_k_means_normalize(vector: &mut [F16]) { - l2_normalize(vector) + super::f16::l2_normalize(vector) + } + + fn elkan_k_means_normalize2(vector: &mut Vec) { + super::f16::l2_normalize(vector) } fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { super::f16::dot(lhs, rhs).acos() } + fn elkan_k_means_distance2(lhs: &[F16], rhs: &[F16]) -> F32 { + super::f16::dot(lhs, rhs).acos() + } + #[multiversion::multiversion(targets( "x86_64/x86-64-v4", "x86_64/x86-64-v3", @@ -150,37 +177,6 @@ impl G for F16Dot { } } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] -fn length(vector: &[F16]) -> F16 { - let n = vector.len(); - let mut dot = F16::zero(); - for i in 0..n { - dot += vector[i] * vector[i]; - } - dot.sqrt() -} - -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] -fn l2_normalize(vector: &mut [F16]) { - let n = vector.len(); - let l = length(vector); - for i in 0..n { - vector[i] /= l; - } -} - #[inline(always)] #[multiversion::multiversion(targets( "x86_64/x86-64-v4", diff --git a/crates/service/src/prelude/global/f16_l2.rs b/crates/service/src/prelude/global/f16_l2.rs index 647c6f900..14e7bddbb 100644 --- a/crates/service/src/prelude/global/f16_l2.rs +++ b/crates/service/src/prelude/global/f16_l2.rs @@ -1,28 +1,52 @@ -use super::G; -use crate::prelude::scalar::F16; -use crate::prelude::scalar::F32; use crate::prelude::*; +use base::scalar::FloatCast; +use std::borrow::Cow; #[derive(Debug, Clone, Copy)] pub enum F16L2 {} impl G for F16L2 { type Scalar = F16; + type Storage = DenseMmap; + type L2 = F16L2; + type VectorOwned = Vec; + type VectorRef<'a> = &'a [F16]; const DISTANCE: Distance = Distance::L2; + const KIND: Kind = Kind::F16; - type L2 = F16L2; + fn owned_to_ref(vector: &Vec) -> &[F16] { + vector + } + + fn ref_to_owned(vector: &[F16]) -> Vec { + vector.to_vec() + } + + fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F16]> { + Cow::Borrowed(vector) + } fn distance(lhs: &[F16], rhs: &[F16]) -> F32 { super::f16::sl2(lhs, rhs) } + fn distance2(lhs: &[F16], rhs: &[F16]) -> F32 { + super::f16::sl2(lhs, rhs) + } + fn elkan_k_means_normalize(_: &mut [F16]) {} + fn elkan_k_means_normalize2(_: &mut Vec) {} + fn elkan_k_means_distance(lhs: &[F16], rhs: &[F16]) -> F32 { super::f16::sl2(lhs, rhs).sqrt() } + fn elkan_k_means_distance2(lhs: &[F16], rhs: &[F16]) -> F32 { + super::f16::sl2(lhs, rhs).sqrt() + } + #[multiversion::multiversion(targets( "x86_64/x86-64-v4", "x86_64/x86-64-v3", diff --git a/crates/service/src/prelude/global/f32.rs b/crates/service/src/prelude/global/f32.rs new file mode 100644 index 000000000..962c50f39 --- /dev/null +++ b/crates/service/src/prelude/global/f32.rs @@ -0,0 +1,88 @@ +use crate::prelude::*; + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +pub fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + for i in 0..n { + xy += lhs[i] * rhs[i]; + x2 += lhs[i] * lhs[i]; + y2 += rhs[i] * rhs[i]; + } + xy / (x2 * y2).sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut xy = F32::zero(); + for i in 0..n { + xy += lhs[i] * rhs[i]; + } + xy +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +pub fn sl2(lhs: &[F32], rhs: &[F32]) -> F32 { + assert!(lhs.len() == rhs.len()); + let n = lhs.len(); + let mut d2 = F32::zero(); + for i in 0..n { + let d = lhs[i] - rhs[i]; + d2 += d * d; + } + d2 +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +pub fn length(vector: &[F32]) -> F32 { + let n = vector.len(); + let mut dot = F32::zero(); + for i in 0..n { + dot += vector[i] * vector[i]; + } + dot.sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +pub fn l2_normalize(vector: &mut [F32]) { + let n = vector.len(); + let l = length(vector); + for i in 0..n { + vector[i] /= l; + } +} diff --git a/crates/service/src/prelude/global/f32_cos.rs b/crates/service/src/prelude/global/f32_cos.rs index e989ceb59..c9e75f92a 100644 --- a/crates/service/src/prelude/global/f32_cos.rs +++ b/crates/service/src/prelude/global/f32_cos.rs @@ -1,27 +1,53 @@ -use super::G; -use crate::prelude::scalar::F32; use crate::prelude::*; +use std::borrow::Cow; #[derive(Debug, Clone, Copy)] pub enum F32Cos {} impl G for F32Cos { type Scalar = F32; + type Storage = DenseMmap; + type L2 = F32L2; + type VectorOwned = Vec; + type VectorRef<'a> = &'a [F32]; const DISTANCE: Distance = Distance::Cos; + const KIND: Kind = Kind::F32; - type L2 = F32L2; + fn owned_to_ref(vector: &Vec) -> &[F32] { + vector + } + + fn ref_to_owned(vector: &[F32]) -> Vec { + vector.to_vec() + } + + fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F32]> { + Cow::Borrowed(vector) + } fn distance(lhs: &[F32], rhs: &[F32]) -> F32 { - F32(1.0) - cosine(lhs, rhs) + F32(1.0) - super::f32::cosine(lhs, rhs) + } + + fn distance2(lhs: &[F32], rhs: &[F32]) -> F32 { + F32(1.0) - super::f32::cosine(lhs, rhs) } fn elkan_k_means_normalize(vector: &mut [F32]) { - l2_normalize(vector) + super::f32::l2_normalize(vector) + } + + fn elkan_k_means_normalize2(vector: &mut Vec) { + super::f32::l2_normalize(vector) } fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 { - super::f32_dot::dot(lhs, rhs).acos() + super::f32::dot(lhs, rhs).acos() + } + + fn elkan_k_means_distance2(lhs: &[F32], rhs: &[F32]) -> F32 { + super::f32::dot(lhs, rhs).acos() } #[multiversion::multiversion(targets( @@ -170,58 +196,6 @@ impl G for F32Cos { } } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] -fn length(vector: &[F32]) -> F32 { - let n = vector.len(); - let mut dot = F32::zero(); - for i in 0..n { - dot += vector[i] * vector[i]; - } - dot.sqrt() -} - -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] -fn l2_normalize(vector: &mut [F32]) { - let n = vector.len(); - let l = length(vector); - for i in 0..n { - vector[i] /= l; - } -} - -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] -fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - let mut xy = F32::zero(); - let mut x2 = F32::zero(); - let mut y2 = F32::zero(); - for i in 0..n { - xy += lhs[i] * rhs[i]; - x2 += lhs[i] * lhs[i]; - y2 += rhs[i] * rhs[i]; - } - xy / (x2 * y2).sqrt() -} - #[inline(always)] #[multiversion::multiversion(targets( "x86_64/x86-64-v4", diff --git a/crates/service/src/prelude/global/f32_dot.rs b/crates/service/src/prelude/global/f32_dot.rs index 08b7d5dd2..d4f58632b 100644 --- a/crates/service/src/prelude/global/f32_dot.rs +++ b/crates/service/src/prelude/global/f32_dot.rs @@ -1,27 +1,53 @@ -use super::G; -use crate::prelude::scalar::F32; use crate::prelude::*; +use std::borrow::Cow; #[derive(Debug, Clone, Copy)] pub enum F32Dot {} impl G for F32Dot { type Scalar = F32; + type Storage = DenseMmap; + type L2 = F32L2; + type VectorOwned = Vec; + type VectorRef<'a> = &'a [F32]; const DISTANCE: Distance = Distance::Dot; + const KIND: Kind = Kind::F32; - type L2 = F32L2; + fn owned_to_ref(vector: &Vec) -> &[F32] { + vector + } + + fn ref_to_owned(vector: &[F32]) -> Vec { + vector.to_vec() + } + + fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F32]> { + Cow::Borrowed(vector) + } fn distance(lhs: &[F32], rhs: &[F32]) -> F32 { - dot(lhs, rhs) * (-1.0) + super::f32::dot(lhs, rhs) * (-1.0) + } + + fn distance2(lhs: &[F32], rhs: &[F32]) -> F32 { + super::f32::dot(lhs, rhs) * (-1.0) } fn elkan_k_means_normalize(vector: &mut [F32]) { - l2_normalize(vector) + super::f32::l2_normalize(vector) + } + + fn elkan_k_means_normalize2(vector: &mut Vec) { + super::f32::l2_normalize(vector) } fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 { - dot(lhs, rhs).acos() + super::f32::dot(lhs, rhs).acos() + } + + fn elkan_k_means_distance2(lhs: &[F32], rhs: &[F32]) -> F32 { + super::f32::dot(lhs, rhs).acos() } #[multiversion::multiversion(targets( @@ -88,7 +114,7 @@ impl G for F32Dot { let lhs = &lhs[(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = dot(lhs, rhs); + let _xy = super::f32::dot(lhs, rhs); xy += _xy; } xy * (-1.0) @@ -115,7 +141,7 @@ impl G for F32Dot { let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - let _xy = dot(lhs, rhs); + let _xy = super::f32::dot(lhs, rhs); xy += _xy; } xy * (-1.0) @@ -150,75 +176,6 @@ impl G for F32Dot { } } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] -fn length(vector: &[F32]) -> F32 { - let n = vector.len(); - let mut dot = F32::zero(); - for i in 0..n { - dot += vector[i] * vector[i]; - } - dot.sqrt() -} - -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] -fn l2_normalize(vector: &mut [F32]) { - let n = vector.len(); - let l = length(vector); - for i in 0..n { - vector[i] /= l; - } -} - -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] -fn cosine(lhs: &[F32], rhs: &[F32]) -> F32 { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - let mut xy = F32::zero(); - let mut x2 = F32::zero(); - let mut y2 = F32::zero(); - for i in 0..n { - xy += lhs[i] * rhs[i]; - x2 += lhs[i] * lhs[i]; - y2 += rhs[i] * rhs[i]; - } - xy / (x2 * y2).sqrt() -} - -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] -pub fn dot(lhs: &[F32], rhs: &[F32]) -> F32 { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - let mut xy = F32::zero(); - for i in 0..n { - xy += lhs[i] * rhs[i]; - } - xy -} - #[inline(always)] #[multiversion::multiversion(targets( "x86_64/x86-64-v4", diff --git a/crates/service/src/prelude/global/f32_l2.rs b/crates/service/src/prelude/global/f32_l2.rs index 2672b6714..815d67303 100644 --- a/crates/service/src/prelude/global/f32_l2.rs +++ b/crates/service/src/prelude/global/f32_l2.rs @@ -1,25 +1,49 @@ -use super::G; -use crate::prelude::scalar::F32; use crate::prelude::*; +use std::borrow::Cow; #[derive(Debug, Clone, Copy)] pub enum F32L2 {} impl G for F32L2 { type Scalar = F32; + type Storage = DenseMmap; + type L2 = F32L2; + type VectorOwned = Vec; + type VectorRef<'a> = &'a [F32]; const DISTANCE: Distance = Distance::L2; + const KIND: Kind = Kind::F32; - type L2 = F32L2; + fn owned_to_ref(vector: &Vec) -> &[F32] { + vector + } + + fn ref_to_owned(vector: &[F32]) -> Vec { + vector.to_vec() + } + + fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F32]> { + Cow::Borrowed(vector) + } fn distance(lhs: &[F32], rhs: &[F32]) -> F32 { - distance_squared_l2(lhs, rhs) + super::f32::sl2(lhs, rhs) + } + + fn distance2(lhs: &[F32], rhs: &[F32]) -> F32 { + super::f32::sl2(lhs, rhs) } fn elkan_k_means_normalize(_: &mut [F32]) {} + fn elkan_k_means_normalize2(_: &mut Vec) {} + fn elkan_k_means_distance(lhs: &[F32], rhs: &[F32]) -> F32 { - distance_squared_l2(lhs, rhs).sqrt() + super::f32::sl2(lhs, rhs).sqrt() + } + + fn elkan_k_means_distance2(lhs: &[F32], rhs: &[F32]) -> F32 { + super::f32::sl2(lhs, rhs).sqrt() } #[multiversion::multiversion(targets( @@ -86,7 +110,7 @@ impl G for F32L2 { let lhs = &lhs[(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += distance_squared_l2(lhs, rhs); + result += super::f32::sl2(lhs, rhs); } result } @@ -112,7 +136,7 @@ impl G for F32L2 { let lhs = ¢roids[lhsp..][(i * ratio) as usize..][..k as usize]; let rhsp = rhs[i as usize] as usize * dims as usize; let rhs = ¢roids[rhsp..][(i * ratio) as usize..][..k as usize]; - result += distance_squared_l2(lhs, rhs); + result += super::f32::sl2(lhs, rhs); } result } @@ -145,24 +169,6 @@ impl G for F32L2 { } } -#[inline(always)] -#[multiversion::multiversion(targets( - "x86_64/x86-64-v4", - "x86_64/x86-64-v3", - "x86_64/x86-64-v2", - "aarch64+neon" -))] -pub fn distance_squared_l2(lhs: &[F32], rhs: &[F32]) -> F32 { - assert!(lhs.len() == rhs.len()); - let n = lhs.len(); - let mut d2 = F32::zero(); - for i in 0..n { - let d = lhs[i] - rhs[i]; - d2 += d * d; - } - d2 -} - #[inline(always)] #[multiversion::multiversion(targets( "x86_64/x86-64-v4", diff --git a/crates/service/src/prelude/global/mod.rs b/crates/service/src/prelude/global/mod.rs index 804f20e85..857847433 100644 --- a/crates/service/src/prelude/global/mod.rs +++ b/crates/service/src/prelude/global/mod.rs @@ -2,9 +2,14 @@ mod f16; mod f16_cos; mod f16_dot; mod f16_l2; +mod f32; mod f32_cos; mod f32_dot; mod f32_l2; +mod sparse_f32; +mod sparse_f32_cos; +mod sparse_f32_dot; +mod sparse_f32_l2; pub use f16_cos::F16Cos; pub use f16_dot::F16Dot; @@ -12,10 +17,17 @@ pub use f16_l2::F16L2; pub use f32_cos::F32Cos; pub use f32_dot::F32Dot; pub use f32_l2::F32L2; +pub use sparse_f32_cos::SparseF32Cos; +pub use sparse_f32_dot::SparseF32Dot; +pub use sparse_f32_l2::SparseF32L2; use crate::prelude::*; +use base::scalar::FloatCast; use serde::{Deserialize, Serialize}; -use std::fmt::{Debug, Display}; +use std::{ + borrow::Cow, + fmt::{Debug, Display}, +}; pub trait G: Copy + Debug + 'static { type Scalar: Copy @@ -33,17 +45,32 @@ pub trait G: Copy + Debug + 'static { + num_traits::NumOps + num_traits::NumAssignOps + FloatCast; + type Storage: for<'a> Storage = Self::VectorRef<'a>>; + type L2: for<'a> G = &'a [Self::Scalar]>; + type VectorOwned: Vector + Clone + Serialize + for<'a> Deserialize<'a>; + type VectorRef<'a>: Vector + Copy + 'a + where + Self: 'a; + const DISTANCE: Distance; - type L2: G; + const KIND: Kind; + + fn owned_to_ref(vector: &Self::VectorOwned) -> Self::VectorRef<'_>; + fn ref_to_owned(vector: Self::VectorRef<'_>) -> Self::VectorOwned; + fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [Self::Scalar]>; + fn distance(lhs: Self::VectorRef<'_>, rhs: Self::VectorRef<'_>) -> F32; + fn distance2(lhs: Self::VectorRef<'_>, rhs: &[Self::Scalar]) -> F32; - fn distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32; fn elkan_k_means_normalize(vector: &mut [Self::Scalar]); + fn elkan_k_means_normalize2(vector: &mut Self::VectorOwned); fn elkan_k_means_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32; + fn elkan_k_means_distance2(lhs: Self::VectorRef<'_>, rhs: &[Self::Scalar]) -> F32; + fn scalar_quantization_distance( dims: u16, max: &[Self::Scalar], min: &[Self::Scalar], - lhs: &[Self::Scalar], + lhs: Self::VectorRef<'_>, rhs: &[u8], ) -> F32; fn scalar_quantization_distance2( @@ -53,11 +80,12 @@ pub trait G: Copy + Debug + 'static { lhs: &[u8], rhs: &[u8], ) -> F32; + fn product_quantization_distance( dims: u16, ratio: u16, centroids: &[Self::Scalar], - lhs: &[Self::Scalar], + lhs: Self::VectorRef<'_>, rhs: &[u8], ) -> F32; fn product_quantization_distance2( @@ -71,27 +99,17 @@ pub trait G: Copy + Debug + 'static { dims: u16, ratio: u16, centroids: &[Self::Scalar], - lhs: &[Self::Scalar], + lhs: Self::VectorRef<'_>, rhs: &[u8], delta: &[Self::Scalar], ) -> F32; } -pub trait FloatCast: Sized { - fn from_f32(x: f32) -> Self; - fn to_f32(self) -> f32; - fn from_f(x: F32) -> Self { - Self::from_f32(x.0) - } - fn to_f(self) -> F32 { - F32(Self::to_f32(self)) - } -} - #[derive(Debug, Clone, Serialize, Deserialize)] pub enum DynamicVector { F32(Vec), F16(Vec), + SparseF32(SparseF32), } impl From> for DynamicVector { @@ -106,6 +124,12 @@ impl From> for DynamicVector { } } +impl From for DynamicVector { + fn from(value: SparseF32) -> Self { + Self::SparseF32(value) + } +} + #[repr(u8)] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub enum Distance { @@ -119,6 +143,7 @@ pub enum Distance { pub enum Kind { F32, F16, + SparseF32, } pub fn squared_norm(dims: u16, vec: &[S::Scalar]) -> F32 { diff --git a/crates/service/src/prelude/global/sparse_f32.rs b/crates/service/src/prelude/global/sparse_f32.rs new file mode 100644 index 000000000..0490b8b29 --- /dev/null +++ b/crates/service/src/prelude/global/sparse_f32.rs @@ -0,0 +1,162 @@ +use crate::prelude::*; + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +pub fn cosine<'a>(lhs: SparseF32Ref<'a>, rhs: SparseF32Ref<'a>) -> F32 { + let mut lhs_pos = 0; + let mut rhs_pos = 0; + let size1 = lhs.length() as usize; + let size2 = rhs.length() as usize; + let mut xy = F32::zero(); + let mut x2 = F32::zero(); + let mut y2 = F32::zero(); + while lhs_pos < size1 && rhs_pos < size2 { + let lhs_index = lhs.indexes[lhs_pos]; + let rhs_index = rhs.indexes[rhs_pos]; + let lhs_value = lhs.values[lhs_pos]; + let rhs_value = rhs.values[rhs_pos]; + xy += F32((lhs_index == rhs_index) as u32 as f32) * lhs_value * rhs_value; + x2 += F32((lhs_index <= rhs_index) as u32 as f32) * lhs_value * lhs_value; + y2 += F32((lhs_index >= rhs_index) as u32 as f32) * rhs_value * rhs_value; + lhs_pos += (lhs_index <= rhs_index) as usize; + rhs_pos += (lhs_index >= rhs_index) as usize; + } + for i in lhs_pos..size1 { + x2 += lhs.values[i] * lhs.values[i]; + } + for i in rhs_pos..size2 { + y2 += rhs.values[i] * rhs.values[i]; + } + xy / (x2 * y2).sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +pub fn dot<'a>(lhs: SparseF32Ref<'a>, rhs: SparseF32Ref<'a>) -> F32 { + let mut lhs_pos = 0; + let mut rhs_pos = 0; + let size1 = lhs.length() as usize; + let size2 = rhs.length() as usize; + let mut xy = F32::zero(); + while lhs_pos < size1 && rhs_pos < size2 { + let lhs_index = lhs.indexes[lhs_pos]; + let rhs_index = rhs.indexes[rhs_pos]; + let lhs_value = lhs.values[lhs_pos]; + let rhs_value = rhs.values[rhs_pos]; + xy += F32((lhs_index == rhs_index) as u32 as f32) * lhs_value * rhs_value; + lhs_pos += (lhs_index <= rhs_index) as usize; + rhs_pos += (lhs_index >= rhs_index) as usize; + } + xy +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +pub fn dot_2<'a>(lhs: SparseF32Ref<'a>, rhs: &[F32]) -> F32 { + let mut xy = F32::zero(); + for i in 0..lhs.indexes.len() { + xy += lhs.values[i] * rhs[lhs.indexes[i] as usize]; + } + xy +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +pub fn sl2<'a>(lhs: SparseF32Ref<'a>, rhs: SparseF32Ref<'a>) -> F32 { + let mut lhs_pos = 0; + let mut rhs_pos = 0; + let size1 = lhs.length() as usize; + let size2 = rhs.length() as usize; + let mut d2 = F32::zero(); + while lhs_pos < size1 && rhs_pos < size2 { + let lhs_index = lhs.indexes[lhs_pos]; + let rhs_index = rhs.indexes[rhs_pos]; + let lhs_value = lhs.values[lhs_pos]; + let rhs_value = rhs.values[rhs_pos]; + let d = F32((lhs_index <= rhs_index) as u32 as f32) * lhs_value + - F32((lhs_index >= rhs_index) as u32 as f32) * rhs_value; + d2 += d * d; + lhs_pos += (lhs_index <= rhs_index) as usize; + rhs_pos += (lhs_index >= rhs_index) as usize; + } + for i in lhs_pos..size1 { + d2 += lhs.values[i] * lhs.values[i]; + } + for i in rhs_pos..size2 { + d2 += rhs.values[i] * rhs.values[i]; + } + d2 +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +pub fn sl2_2<'a>(lhs: SparseF32Ref<'a>, rhs: &[F32]) -> F32 { + let mut d2 = F32::zero(); + let mut lhs_pos: u16 = 0; + let mut rhs_pos: u16 = 0; + while lhs_pos < lhs.length() { + let index_eq = lhs.indexes[lhs_pos as usize] == rhs_pos; + let d = F32(index_eq as u32 as f32) * lhs.values[lhs_pos as usize] - rhs[rhs_pos as usize]; + d2 += d * d; + lhs_pos += index_eq as u16; + rhs_pos += 1; + } + for i in rhs_pos..rhs.len() as u16 { + d2 += rhs[i as usize] * rhs[i as usize]; + } + d2 +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +pub fn length<'a>(vector: SparseF32Ref<'a>) -> F32 { + let mut dot = F32::zero(); + for &i in vector.values { + dot += i * i; + } + dot.sqrt() +} + +#[inline(always)] +#[multiversion::multiversion(targets( + "x86_64/x86-64-v4", + "x86_64/x86-64-v3", + "x86_64/x86-64-v2", + "aarch64+neon" +))] +pub fn l2_normalize(vector: &mut SparseF32) { + let l = length(SparseF32Ref::from(vector as &SparseF32)); + for i in vector.values.iter_mut() { + *i /= l; + } +} diff --git a/crates/service/src/prelude/global/sparse_f32_cos.rs b/crates/service/src/prelude/global/sparse_f32_cos.rs new file mode 100644 index 000000000..372f8d744 --- /dev/null +++ b/crates/service/src/prelude/global/sparse_f32_cos.rs @@ -0,0 +1,104 @@ +use std::borrow::Cow; + +use crate::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum SparseF32Cos {} + +impl G for SparseF32Cos { + type Scalar = F32; + type Storage = SparseMmap; + type L2 = F32L2; + type VectorOwned = SparseF32; + type VectorRef<'a> = SparseF32Ref<'a>; + + const DISTANCE: Distance = Distance::Cos; + const KIND: Kind = Kind::SparseF32; + + fn owned_to_ref(vector: &SparseF32) -> SparseF32Ref<'_> { + SparseF32Ref::from(vector) + } + + fn ref_to_owned(vector: SparseF32Ref<'_>) -> SparseF32 { + SparseF32::from(vector) + } + + fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F32]> { + Cow::Owned(vector.to_dense()) + } + + fn distance(lhs: Self::VectorRef<'_>, rhs: Self::VectorRef<'_>) -> F32 { + F32(1.0) - super::sparse_f32::cosine(lhs, rhs) + } + + fn distance2(_lhs: Self::VectorRef<'_>, _rhs: &[Self::Scalar]) -> F32 { + unimplemented!() + } + + fn elkan_k_means_normalize(vector: &mut [Self::Scalar]) { + super::f32::l2_normalize(vector) + } + + fn elkan_k_means_normalize2(vector: &mut SparseF32) { + super::sparse_f32::l2_normalize(vector) + } + + fn elkan_k_means_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32 { + super::f32::dot(lhs, rhs).acos() + } + + fn elkan_k_means_distance2(lhs: Self::VectorRef<'_>, rhs: &[Self::Scalar]) -> F32 { + super::sparse_f32::dot_2(lhs, rhs).acos() + } + + fn scalar_quantization_distance( + _dims: u16, + _max: &[F32], + _min: &[F32], + _lhs: Self::VectorRef<'_>, + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn scalar_quantization_distance2( + _dims: u16, + _max: &[Self::Scalar], + _min: &[Self::Scalar], + _lhs: &[u8], + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_distance( + _dims: u16, + _ratio: u16, + _centroids: &[Self::Scalar], + _lhs: Self::VectorRef<'_>, + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_distance2( + _dims: u16, + _ratio: u16, + _centroids: &[Self::Scalar], + _lhs: &[u8], + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_distance_with_delta( + _dims: u16, + _ratio: u16, + _centroids: &[Self::Scalar], + _lhs: Self::VectorRef<'_>, + _rhs: &[u8], + _delta: &[Self::Scalar], + ) -> F32 { + unimplemented!() + } +} diff --git a/crates/service/src/prelude/global/sparse_f32_dot.rs b/crates/service/src/prelude/global/sparse_f32_dot.rs new file mode 100644 index 000000000..8c19bc987 --- /dev/null +++ b/crates/service/src/prelude/global/sparse_f32_dot.rs @@ -0,0 +1,104 @@ +use std::borrow::Cow; + +use crate::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum SparseF32Dot {} + +impl G for SparseF32Dot { + type Scalar = F32; + type Storage = SparseMmap; + type L2 = F32L2; + type VectorOwned = SparseF32; + type VectorRef<'a> = SparseF32Ref<'a>; + + const DISTANCE: Distance = Distance::Dot; + const KIND: Kind = Kind::SparseF32; + + fn owned_to_ref(vector: &SparseF32) -> SparseF32Ref<'_> { + SparseF32Ref::from(vector) + } + + fn ref_to_owned(vector: SparseF32Ref<'_>) -> SparseF32 { + SparseF32::from(vector) + } + + fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F32]> { + Cow::Owned(vector.to_dense()) + } + + fn distance(lhs: Self::VectorRef<'_>, rhs: Self::VectorRef<'_>) -> F32 { + super::sparse_f32::dot(lhs, rhs) * (-1.0) + } + + fn distance2(_lhs: Self::VectorRef<'_>, _rhs: &[Self::Scalar]) -> F32 { + unimplemented!() + } + + fn elkan_k_means_normalize(vector: &mut [Self::Scalar]) { + super::f32::l2_normalize(vector) + } + + fn elkan_k_means_normalize2(vector: &mut SparseF32) { + super::sparse_f32::l2_normalize(vector) + } + + fn elkan_k_means_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32 { + super::f32::dot(lhs, rhs).acos() + } + + fn elkan_k_means_distance2(lhs: Self::VectorRef<'_>, rhs: &[Self::Scalar]) -> F32 { + super::sparse_f32::dot_2(lhs, rhs).acos() + } + + fn scalar_quantization_distance( + _dims: u16, + _max: &[Self::Scalar], + _min: &[Self::Scalar], + _lhs: Self::VectorRef<'_>, + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn scalar_quantization_distance2( + _dims: u16, + _max: &[Self::Scalar], + _min: &[Self::Scalar], + _lhs: &[u8], + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_distance( + _dims: u16, + _ratio: u16, + _centroids: &[Self::Scalar], + _lhs: Self::VectorRef<'_>, + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_distance2( + _dims: u16, + _ratio: u16, + _centroids: &[Self::Scalar], + _lhs: &[u8], + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_distance_with_delta( + _dims: u16, + _ratio: u16, + _centroids: &[Self::Scalar], + _lhs: Self::VectorRef<'_>, + _rhs: &[u8], + _delta: &[Self::Scalar], + ) -> F32 { + unimplemented!() + } +} diff --git a/crates/service/src/prelude/global/sparse_f32_l2.rs b/crates/service/src/prelude/global/sparse_f32_l2.rs new file mode 100644 index 000000000..b17559e7b --- /dev/null +++ b/crates/service/src/prelude/global/sparse_f32_l2.rs @@ -0,0 +1,100 @@ +use std::borrow::Cow; + +use crate::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum SparseF32L2 {} + +impl G for SparseF32L2 { + type Scalar = F32; + type Storage = SparseMmap; + type L2 = F32L2; + type VectorOwned = SparseF32; + type VectorRef<'a> = SparseF32Ref<'a>; + + const DISTANCE: Distance = Distance::L2; + const KIND: Kind = Kind::SparseF32; + + fn owned_to_ref(vector: &SparseF32) -> SparseF32Ref<'_> { + SparseF32Ref::from(vector) + } + + fn ref_to_owned(vector: SparseF32Ref<'_>) -> SparseF32 { + SparseF32::from(vector) + } + + fn to_dense(vector: Self::VectorRef<'_>) -> Cow<'_, [F32]> { + Cow::Owned(vector.to_dense()) + } + + fn distance(lhs: SparseF32Ref<'_>, rhs: SparseF32Ref<'_>) -> F32 { + super::sparse_f32::sl2(lhs, rhs) + } + + fn distance2(_lhs: Self::VectorRef<'_>, _rhs: &[Self::Scalar]) -> F32 { + unimplemented!() + } + + fn elkan_k_means_normalize(_: &mut [Self::Scalar]) {} + + fn elkan_k_means_normalize2(_: &mut SparseF32) {} + + fn elkan_k_means_distance(lhs: &[Self::Scalar], rhs: &[Self::Scalar]) -> F32 { + super::f32::sl2(lhs, rhs).sqrt() + } + + fn elkan_k_means_distance2(lhs: SparseF32Ref<'_>, rhs: &[Self::Scalar]) -> F32 { + super::sparse_f32::sl2_2(lhs, rhs).sqrt() + } + + fn scalar_quantization_distance( + _dims: u16, + _max: &[Self::Scalar], + _min: &[Self::Scalar], + _lhs: SparseF32Ref<'_>, + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn scalar_quantization_distance2( + _dims: u16, + _max: &[Self::Scalar], + _min: &[Self::Scalar], + _lhs: &[u8], + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_distance( + _dims: u16, + _ratio: u16, + _centroids: &[Self::Scalar], + _lhs: SparseF32Ref<'_>, + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_distance2( + _dims: u16, + _ratio: u16, + _centroids: &[Self::Scalar], + _lhs: &[u8], + _rhs: &[u8], + ) -> F32 { + unimplemented!() + } + + fn product_quantization_distance_with_delta( + _dims: u16, + _ratio: u16, + _centroids: &[Self::Scalar], + _lhs: SparseF32Ref<'_>, + _rhs: &[u8], + _delta: &[Self::Scalar], + ) -> F32 { + unimplemented!() + } +} diff --git a/crates/service/src/prelude/mod.rs b/crates/service/src/prelude/mod.rs index 0dd1ab456..a559acb7e 100644 --- a/crates/service/src/prelude/mod.rs +++ b/crates/service/src/prelude/mod.rs @@ -1,13 +1,13 @@ -mod error; mod global; -mod scalar; -mod search; -mod sys; +mod storage; -pub use self::error::ServiceError; pub use self::global::*; -pub use self::scalar::{F16, F32}; -pub use self::search::{Element, Filter, Payload}; -pub use self::sys::{Handle, Pointer}; +pub use self::storage::{DenseMmap, SparseMmap, Storage}; + +pub use base::error::*; +pub use base::scalar::{F16, F32}; +pub use base::search::{Element, Filter, Payload}; +pub use base::sys::{Handle, Pointer}; +pub use base::vector::{SparseF32, SparseF32Ref, Vector}; pub use num_traits::{Float, Zero}; diff --git a/crates/service/src/prelude/scalar/mod.rs b/crates/service/src/prelude/scalar/mod.rs deleted file mode 100644 index 1894a906f..000000000 --- a/crates/service/src/prelude/scalar/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod f16; -mod f32; - -pub use f16::F16; -pub use f32::F32; diff --git a/crates/service/src/prelude/storage/dense.rs b/crates/service/src/prelude/storage/dense.rs new file mode 100644 index 000000000..3cb85b35c --- /dev/null +++ b/crates/service/src/prelude/storage/dense.rs @@ -0,0 +1,65 @@ +use crate::algorithms::raw::RawRam; +use crate::index::IndexOptions; +use crate::prelude::*; +use crate::utils::mmap_array::MmapArray; +use std::path::Path; + +pub struct DenseMmap { + vectors: MmapArray, + payload: MmapArray, + dims: u16, +} + +impl Storage for DenseMmap +where + T: Copy + bytemuck::Pod, +{ + type VectorRef<'a> = &'a [T]; + + fn dims(&self) -> u16 { + self.dims + } + + fn len(&self) -> u32 { + self.payload.len() as u32 + } + + fn vector(&self, i: u32) -> &[T] { + let s = i as usize * self.dims as usize; + let e = (i + 1) as usize * self.dims as usize; + &self.vectors[s..e] + } + + fn payload(&self, i: u32) -> Payload { + self.payload[i as usize] + } + + fn open(path: &Path, options: IndexOptions) -> Self + where + Self: Sized, + { + let vectors = MmapArray::open(&path.join("vectors")); + let payload = MmapArray::open(&path.join("payload")); + Self { + vectors, + payload, + dims: options.vector.dims, + } + } + + fn save G = Self::VectorRef<'a>>>( + path: &Path, + ram: RawRam, + ) -> Self { + let n = ram.len(); + let vectors_iter = (0..n).flat_map(|i| ram.vector(i).iter()).copied(); + let payload_iter = (0..n).map(|i| ram.payload(i)); + let vectors = MmapArray::create(&path.join("vectors"), vectors_iter); + let payload = MmapArray::create(&path.join("payload"), payload_iter); + Self { + vectors, + payload, + dims: ram.dims(), + } + } +} diff --git a/crates/service/src/prelude/storage/mod.rs b/crates/service/src/prelude/storage/mod.rs new file mode 100644 index 000000000..96c977a41 --- /dev/null +++ b/crates/service/src/prelude/storage/mod.rs @@ -0,0 +1,24 @@ +mod dense; +mod sparse; + +pub use dense::DenseMmap; +pub use sparse::SparseMmap; + +use crate::algorithms::raw::RawRam; +use crate::index::IndexOptions; +use crate::prelude::*; +use std::path::Path; + +pub trait Storage { + type VectorRef<'a>: Copy + 'a + where + Self: 'a; + + fn dims(&self) -> u16; + fn len(&self) -> u32; + fn vector(&self, i: u32) -> Self::VectorRef<'_>; + fn payload(&self, i: u32) -> Payload; + fn open(path: &Path, options: IndexOptions) -> Self; + fn save G = Self::VectorRef<'a>>>(path: &Path, ram: RawRam) + -> Self; +} diff --git a/crates/service/src/prelude/storage/sparse.rs b/crates/service/src/prelude/storage/sparse.rs new file mode 100644 index 000000000..d95e27276 --- /dev/null +++ b/crates/service/src/prelude/storage/sparse.rs @@ -0,0 +1,83 @@ +use crate::algorithms::raw::RawRam; +use crate::index::IndexOptions; +use crate::prelude::*; +use crate::utils::mmap_array::MmapArray; +use std::path::Path; + +pub struct SparseMmap { + indexes: MmapArray, + values: MmapArray, + offsets: MmapArray, + payload: MmapArray, + dims: u16, +} + +impl Storage for SparseMmap { + type VectorRef<'a> = SparseF32Ref<'a>; + + fn dims(&self) -> u16 { + self.dims + } + + fn len(&self) -> u32 { + self.payload.len() as u32 + } + + fn vector(&self, i: u32) -> SparseF32Ref<'_> { + let s = self.offsets[i as usize]; + let e = self.offsets[i as usize + 1]; + SparseF32Ref { + dims: self.dims, + indexes: &self.indexes[s..e], + values: &self.values[s..e], + } + } + + fn payload(&self, i: u32) -> Payload { + self.payload[i as usize] + } + + fn open(path: &Path, options: IndexOptions) -> Self + where + Self: Sized, + { + let indexes = MmapArray::open(&path.join("indexes")); + let values = MmapArray::open(&path.join("values")); + let offsets = MmapArray::open(&path.join("offsets")); + let payload = MmapArray::open(&path.join("payload")); + Self { + indexes, + values, + offsets, + payload, + dims: options.vector.dims, + } + } + + fn save G = Self::VectorRef<'a>>>( + path: &Path, + ram: RawRam, + ) -> Self { + let n = ram.len(); + let indexes_iter = (0..n).flat_map(|i| ram.vector(i).indexes.iter().copied()); + let values_iter = (0..n).flat_map(|i| ram.vector(i).values.iter().copied()); + let offsets_iter = std::iter::once(0) + .chain((0..n).map(|i| ram.vector(i).length() as usize)) + .scan(0, |state, x| { + *state += x; + Some(*state) + }); + let payload_iter = (0..n).map(|i| ram.payload(i)); + let indexes = MmapArray::create(&path.join("indexes"), indexes_iter); + let values = MmapArray::create(&path.join("values"), values_iter); + let offsets = MmapArray::create(&path.join("offsets"), offsets_iter); + let payload = MmapArray::create(&path.join("payload"), payload_iter); + Self { + indexes, + values, + offsets, + payload, + dims: ram.dims(), + } + } +} diff --git a/crates/service/src/utils/mmap_array.rs b/crates/service/src/utils/mmap_array.rs index 23e7bfd7b..fc2e861be 100644 --- a/crates/service/src/utils/mmap_array.rs +++ b/crates/service/src/utils/mmap_array.rs @@ -3,7 +3,7 @@ use std::fs::File; use std::io::{BufWriter, Read, Seek, Write}; use std::ops::Index; use std::ops::{Deref, Range, RangeInclusive}; -use std::path::PathBuf; +use std::path::Path; pub struct MmapArray { info: Information, @@ -15,7 +15,7 @@ impl MmapArray where T: Pod, { - pub fn create(path: PathBuf, iter: I) -> Self + pub fn create(path: &Path, iter: I) -> Self where I: Iterator, { @@ -43,7 +43,7 @@ where _mmap: mmap, } } - pub fn open(path: PathBuf) -> Self { + pub fn open(path: &Path) -> Self { let file = std::fs::OpenOptions::new().read(true).open(path).unwrap(); let info = read_information(&file); let mmap = unsafe { read_mmap(&file, info.len * std::mem::size_of::()) }; diff --git a/crates/service/src/utils/vec2.rs b/crates/service/src/utils/vec2.rs index b671b681e..a07619875 100644 --- a/crates/service/src/utils/vec2.rs +++ b/crates/service/src/utils/vec2.rs @@ -1,13 +1,13 @@ -use crate::prelude::*; +use bytemuck::Zeroable; use std::ops::{Deref, DerefMut, Index, IndexMut}; #[derive(Debug, Clone)] -pub struct Vec2 { +pub struct Vec2 { dims: u16, - v: Vec, + v: Vec, } -impl Vec2 { +impl Vec2 { pub fn new(dims: u16, n: usize) -> Self { Self { dims, @@ -37,29 +37,29 @@ impl Vec2 { } } -impl Index for Vec2 { - type Output = [S::Scalar]; +impl Index for Vec2 { + type Output = [T]; fn index(&self, index: usize) -> &Self::Output { &self.v[self.dims as usize * index..][..self.dims as usize] } } -impl IndexMut for Vec2 { +impl IndexMut for Vec2 { fn index_mut(&mut self, index: usize) -> &mut Self::Output { &mut self.v[self.dims as usize * index..][..self.dims as usize] } } -impl Deref for Vec2 { - type Target = [S::Scalar]; +impl Deref for Vec2 { + type Target = [T]; fn deref(&self) -> &Self::Target { self.v.deref() } } -impl DerefMut for Vec2 { +impl DerefMut for Vec2 { fn deref_mut(&mut self) -> &mut Self::Target { self.v.deref_mut() } diff --git a/crates/service/src/worker/mod.rs b/crates/service/src/worker/mod.rs index de9d5f0a0..3ccc77862 100644 --- a/crates/service/src/worker/mod.rs +++ b/crates/service/src/worker/mod.rs @@ -1,7 +1,7 @@ pub mod metadata; -use crate::index::IndexOptions; -use crate::instance::Instance; +use crate::index::{IndexOptions, IndexStat}; +use crate::instance::{Instance, InstanceView, InstanceViewOperations}; use crate::prelude::*; use crate::utils::clean::clean; use crate::utils::dir_ops::sync_dir; @@ -13,6 +13,25 @@ use std::collections::{HashMap, HashSet}; use std::path::PathBuf; use std::sync::Arc; +pub trait WorkerOperations { + type InstanceView: InstanceViewOperations; + + fn create(&self, handle: Handle, options: IndexOptions) -> Result<(), CreateError>; + fn drop(&self, handle: Handle) -> Result<(), DropError>; + fn flush(&self, handle: Handle) -> Result<(), FlushError>; + fn insert( + &self, + handle: Handle, + vector: DynamicVector, + pointer: Pointer, + ) -> Result<(), InsertError>; + fn delete(&self, handle: Handle, pointer: Pointer) -> Result<(), DeleteError>; + fn basic_view(&self, handle: Handle) -> Result; + fn vbase_view(&self, handle: Handle) -> Result; + fn list_view(&self, handle: Handle) -> Result; + fn stat(&self, handle: Handle) -> Result; +} + pub struct Worker { path: PathBuf, protect: Mutex, @@ -65,11 +84,12 @@ impl Worker { pub fn view(&self) -> Arc { self.view.load_full() } - pub fn instance_create( - &self, - handle: Handle, - options: IndexOptions, - ) -> Result<(), ServiceError> { +} + +impl WorkerOperations for Worker { + type InstanceView = InstanceView; + + fn create(&self, handle: Handle, options: IndexOptions) -> Result<(), CreateError> { use std::collections::hash_map::Entry; let mut protect = self.protect.lock(); match protect.indexes.entry(handle) { @@ -80,15 +100,70 @@ impl Worker { protect.maintain(&self.view); Ok(()) } - Entry::Occupied(_) => Err(ServiceError::KnownIndex), + Entry::Occupied(_) => Err(CreateError::Exist), } } - pub fn instance_destroy(&self, handle: Handle) { + fn drop(&self, handle: Handle) -> Result<(), DropError> { let mut protect = self.protect.lock(); if protect.indexes.remove(&handle).is_some() { protect.maintain(&self.view); + Ok(()) + } else { + Err(DropError::NotExist) } } + fn flush(&self, handle: Handle) -> Result<(), FlushError> { + let view = self.view(); + let instance = view.get(handle).ok_or(FlushError::NotExist)?; + let view = instance.view().ok_or(FlushError::Upgrade)?; + view.flush()?; + Ok(()) + } + fn insert( + &self, + handle: Handle, + vector: DynamicVector, + pointer: Pointer, + ) -> Result<(), InsertError> { + let view = self.view(); + let instance = view.get(handle).ok_or(InsertError::NotExist)?; + loop { + let view = instance.view().ok_or(InsertError::Upgrade)?; + match view.insert(vector.clone(), pointer)? { + Ok(()) => break, + Err(_) => instance.refresh(), + } + } + Ok(()) + } + fn delete(&self, handle: Handle, pointer: Pointer) -> Result<(), DeleteError> { + let view = self.view(); + let instance = view.get(handle).ok_or(DeleteError::NotExist)?; + let view = instance.view().ok_or(DeleteError::Upgrade)?; + view.delete(pointer)?; + Ok(()) + } + fn basic_view(&self, handle: Handle) -> Result { + let view = self.view(); + let instance = view.get(handle).ok_or(BasicError::NotExist)?; + instance.view().ok_or(BasicError::Upgrade) + } + fn vbase_view(&self, handle: Handle) -> Result { + let view = self.view(); + let instance = view.get(handle).ok_or(VbaseError::NotExist)?; + instance.view().ok_or(VbaseError::Upgrade) + } + fn list_view(&self, handle: Handle) -> Result { + let view = self.view(); + let instance = view.get(handle).ok_or(ListError::NotExist)?; + instance.view().ok_or(ListError::Upgrade) + } + fn stat(&self, handle: Handle) -> Result { + let view = self.view(); + let instance = view.get(handle).ok_or(StatError::NotExist)?; + let stat = instance.stat().ok_or(StatError::Upgrade)?; + Ok(stat) + } } pub struct WorkerView { diff --git a/scripts/ci_install.sh b/scripts/ci_install.sh index 7eacbd56b..ed89363a2 100755 --- a/scripts/ci_install.sh +++ b/scripts/ci_install.sh @@ -10,6 +10,3 @@ if [ "$OS" == "ubuntu-latest" ]; then sudo systemctl restart postgresql pg_lsclusters fi -if [ "$OS" == "macos-latest" ]; then - brew services restart postgresql@$VERSION -fi diff --git a/scripts/ci_setup.sh b/scripts/ci_setup.sh index e6593be35..03d2d3196 100755 --- a/scripts/ci_setup.sh +++ b/scripts/ci_setup.sh @@ -25,21 +25,6 @@ if [ "$OS" == "ubuntu-latest" ]; then sudo -iu postgres createuser -s -r runner createdb fi -if [ "$OS" == "macos-latest" ]; then - brew uninstall postgresql - brew install postgresql@$VERSION - export PATH="$PATH:$(brew --prefix postgresql@$VERSION)/bin" - echo "$(brew --prefix postgresql@$VERSION)/bin" >> $GITHUB_PATH - brew services start postgresql@$VERSION - sleep 30 - createdb -fi sudo chmod -R 777 `pg_config --pkglibdir` sudo chmod -R 777 `pg_config --sharedir`/extension - -curl -L --proto '=https' --tlsv1.2 -sSf https://raw.githubusercontent.com/cargo-bins/cargo-binstall/main/install-from-binstall-release.sh | bash -cargo binstall sqllogictest-bin -y --force - -cargo install cargo-pgrx@$(grep 'pgrx = {' Cargo.toml | cut -d '"' -f 2 | head -n 1) --debug -cargo pgrx init --pg$VERSION=$(which pg_config) diff --git a/src/bgworker/mod.rs b/src/bgworker/mod.rs index 04c37678f..4e98a3908 100644 --- a/src/bgworker/mod.rs +++ b/src/bgworker/mod.rs @@ -1,26 +1,35 @@ pub mod normal; -pub mod upgrade; + +use std::sync::atomic::{AtomicBool, Ordering}; + +static STARTED: AtomicBool = AtomicBool::new(false); pub unsafe fn init() { - use pgrx::bgworkers::BackgroundWorkerBuilder; - use pgrx::bgworkers::BgWorkerStartTime; - use std::time::Duration; - BackgroundWorkerBuilder::new("vectors") - .set_library("vectors") - .set_function("_vectors_main") - .set_argument(None) - .enable_shmem_access(None) - .set_start_time(BgWorkerStartTime::PostmasterStart) - .set_restart_time(Some(Duration::from_secs(1))) - .load(); + use service::worker::Worker; + let path = std::path::Path::new("pg_vectors"); + if !path.try_exists().unwrap() || Worker::check(path.to_owned()) { + use pgrx::bgworkers::BackgroundWorkerBuilder; + use pgrx::bgworkers::BgWorkerStartTime; + use std::time::Duration; + BackgroundWorkerBuilder::new("vectors") + .set_library("vectors") + .set_function("_vectors_main") + .set_argument(None) + .enable_shmem_access(None) + .set_start_time(BgWorkerStartTime::PostmasterStart) + .set_restart_time(Some(Duration::from_secs(15))) + .load(); + STARTED.store(true, Ordering::Relaxed); + } } -#[no_mangle] -extern "C" fn _vectors_main(_arg: pgrx::pg_sys::Datum) { - let _ = std::panic::catch_unwind(main); +pub fn is_started() -> bool { + STARTED.load(Ordering::Relaxed) } -fn main() { +#[pgrx::pg_guard] +#[no_mangle] +extern "C" fn _vectors_main(_arg: pgrx::pg_sys::Datum) { pub struct AllocErrorPanicPayload { pub layout: std::alloc::Layout, } @@ -60,12 +69,8 @@ fn main() { use std::path::Path; let path = Path::new("pg_vectors"); if path.try_exists().unwrap() { - if Worker::check(path.to_owned()) { - let worker = Worker::open(path.to_owned()); - self::normal::normal(worker); - } else { - self::upgrade::upgrade(); - } + let worker = Worker::open(path.to_owned()); + self::normal::normal(worker); } else { let worker = Worker::create(path.to_owned()); self::normal::normal(worker); diff --git a/src/bgworker/normal.rs b/src/bgworker/normal.rs index cf87f8ba6..c96fbc71b 100644 --- a/src/bgworker/normal.rs +++ b/src/bgworker/normal.rs @@ -1,7 +1,5 @@ -use crate::ipc::server::RpcHandler; use crate::ipc::ConnectionError; -use service::index::OutdatedError; -use service::prelude::ServiceError; +use crate::ipc::ServerRpcHandler; use service::worker::Worker; use std::sync::Arc; @@ -59,167 +57,128 @@ pub fn normal(worker: Arc) { }); } -fn session(worker: Arc, handler: RpcHandler) -> Result { - use crate::ipc::server::RpcHandle; +fn session(worker: Arc, handler: ServerRpcHandler) -> Result { + use crate::ipc::ServerRpcHandle; + use service::instance::InstanceViewOperations; + use service::worker::WorkerOperations; let mut handler = handler; loop { match handler.handle()? { - // transaction - RpcHandle::Flush { handle, x } => { - let view = worker.view(); - if let Some(instance) = view.get(handle) { - if let Some(view) = instance.view() { - view.flush(); - } - } - handler = x.leave()?; + // control plane + ServerRpcHandle::Create { handle, options, x } => { + handler = x.leave(WorkerOperations::create(worker.as_ref(), handle, options))?; } - RpcHandle::Drop { handle, x } => { - worker.instance_destroy(handle); - handler = x.leave()?; + ServerRpcHandle::Drop { handle, x } => { + handler = x.leave(WorkerOperations::drop(worker.as_ref(), handle))?; } - RpcHandle::Create { handle, options, x } => { - match worker.instance_create(handle, options) { - Ok(()) => (), - Err(e) => x.reset(e)?, - }; - handler = x.leave()?; + // data plane + ServerRpcHandle::Flush { handle, x } => { + handler = x.leave(worker.flush(handle))?; } - // instance - RpcHandle::Insert { + ServerRpcHandle::Insert { handle, vector, pointer, x, } => { - let view = worker.view(); - let Some(instance) = view.get(handle) else { - x.reset(ServiceError::UnknownIndex)?; - }; - loop { - let instance_view = match instance.view() { - Some(x) => x, - None => x.reset(ServiceError::Upgrade2)?, - }; - match instance_view.insert(vector.clone(), pointer) { - Ok(Ok(())) => break, - Ok(Err(OutdatedError)) => instance.refresh(), - Err(e) => x.reset(e)?, - } - } - handler = x.leave()?; + handler = x.leave(worker.insert(handle, vector, pointer))?; } - RpcHandle::Delete { handle, pointer, x } => { - let view = worker.view(); - let Some(instance) = view.get(handle) else { - x.reset(ServiceError::UnknownIndex)?; - }; - let instance_view = match instance.view() { - Some(x) => x, - None => x.reset(ServiceError::Upgrade2)?, - }; - instance_view.delete(pointer); - handler = x.leave()?; + ServerRpcHandle::Delete { handle, pointer, x } => { + handler = x.leave(worker.delete(handle, pointer))?; } - RpcHandle::Stat { handle, x } => { - let view = worker.view(); - let Some(instance) = view.get(handle) else { - x.reset(ServiceError::UnknownIndex)?; - }; - let r = instance.stat(); - handler = x.leave(r)? + ServerRpcHandle::Stat { handle, x } => { + handler = x.leave(worker.stat(handle))?; } - RpcHandle::Basic { + ServerRpcHandle::Basic { handle, vector, opts, x, } => { - use crate::ipc::server::BasicHandle::*; - let view = worker.view(); - let Some(instance) = view.get(handle) else { - x.reset(ServiceError::UnknownIndex)?; - }; - let view = match instance.view() { - Some(x) => x, - None => x.reset(ServiceError::Upgrade2)?, - }; - let mut it = match view.basic(&vector, &opts, |_| true) { + let v = match worker.basic_view(handle) { Ok(x) => x, - Err(e) => x.reset(e)?, + Err(e) => { + handler = x.error_err(e)?; + continue; + } }; - let mut x = x.error()?; - loop { - match x.handle()? { - Next { x: y } => { - x = y.leave(it.next())?; - } - Leave { x } => { - handler = x; - break; + match v.basic(&vector, &opts, |_| true) { + Ok(mut iter) => { + use crate::ipc::ServerBasicHandle; + let mut x = x.error_ok()?; + loop { + match x.handle()? { + ServerBasicHandle::Next { x: y } => { + x = y.leave(iter.next())?; + } + ServerBasicHandle::Leave { x } => { + handler = x; + break; + } + } } } - } + Err(e) => handler = x.error_err(e)?, + }; } - RpcHandle::Vbase { + ServerRpcHandle::Vbase { handle, vector, opts, x, } => { - use crate::ipc::server::VbaseHandle::*; - let view = worker.view(); - let Some(instance) = view.get(handle) else { - x.reset(ServiceError::UnknownIndex)?; - }; - let view = match instance.view() { - Some(x) => x, - None => x.reset(ServiceError::Upgrade2)?, - }; - let mut it = match view.vbase(&vector, &opts, |_| true) { + let v = match worker.vbase_view(handle) { Ok(x) => x, - Err(e) => x.reset(e)?, + Err(e) => { + handler = x.error_err(e)?; + continue; + } }; - let mut x = x.error()?; - loop { - match x.handle()? { - Next { x: y } => { - x = y.leave(it.next())?; - } - Leave { x } => { - handler = x; - break; + match v.vbase(&vector, &opts, |_| true) { + Ok(mut iter) => { + use crate::ipc::ServerVbaseHandle; + let mut x = x.error_ok()?; + loop { + match x.handle()? { + ServerVbaseHandle::Next { x: y } => { + x = y.leave(iter.next())?; + } + ServerVbaseHandle::Leave { x } => { + handler = x; + break; + } + } } } - } - } - RpcHandle::List { handle, x } => { - use crate::ipc::server::ListHandle::*; - let view = worker.view(); - let Some(instance) = view.get(handle) else { - x.reset(ServiceError::UnknownIndex)?; + Err(e) => handler = x.error_err(e)?, }; - let view = match instance.view() { - Some(x) => x, - None => x.reset(ServiceError::Upgrade2)?, + } + ServerRpcHandle::List { handle, x } => { + let v = match worker.list_view(handle) { + Ok(x) => x, + Err(e) => { + handler = x.error_err(e)?; + continue; + } }; - let mut it = view.list(); - let mut x = x.error()?; - loop { - match x.handle()? { - Next { x: y } => { - x = y.leave(it.next())?; - } - Leave { x } => { - handler = x; - break; + match v.list() { + Ok(mut iter) => { + use crate::ipc::ServerListHandle; + let mut x = x.error_ok()?; + loop { + match x.handle()? { + ServerListHandle::Next { x: y } => { + x = y.leave(iter.next())?; + } + ServerListHandle::Leave { x } => { + handler = x; + break; + } + } } } - } - } - // admin - RpcHandle::Upgrade { x } => { - handler = x.leave()?; + Err(e) => handler = x.error_err(e)?, + }; } } } diff --git a/src/bgworker/upgrade.rs b/src/bgworker/upgrade.rs deleted file mode 100644 index 7accc5fd9..000000000 --- a/src/bgworker/upgrade.rs +++ /dev/null @@ -1,78 +0,0 @@ -use crate::ipc::server::RpcHandler; -use crate::ipc::ConnectionError; -use service::prelude::*; - -pub fn upgrade() { - std::thread::scope(|scope| { - scope.spawn({ - move || { - for rpc_handler in crate::ipc::listen_unix() { - std::thread::spawn({ - move || { - log::trace!("Session established."); - let _ = session(rpc_handler); - log::trace!("Session closed."); - } - }); - } - } - }); - scope.spawn({ - move || { - for rpc_handler in crate::ipc::listen_mmap() { - std::thread::spawn({ - move || { - log::trace!("Session established."); - let _ = session(rpc_handler); - log::trace!("Session closed."); - } - }); - } - } - }); - loop { - let mut sig: i32 = 0; - unsafe { - let mut set: libc::sigset_t = std::mem::zeroed(); - libc::sigemptyset(&mut set); - libc::sigaddset(&mut set, libc::SIGHUP); - libc::sigaddset(&mut set, libc::SIGTERM); - libc::sigwait(&set, &mut sig); - } - match sig { - libc::SIGHUP => { - std::process::exit(0); - } - libc::SIGTERM => { - std::process::exit(0); - } - _ => (), - } - } - }); -} - -fn session(handler: RpcHandler) -> Result<(), ConnectionError> { - use crate::ipc::server::RpcHandle; - let mut handler = handler; - loop { - match handler.handle()? { - RpcHandle::Drop { x, .. } => { - // false drop - handler = x.leave()?; - } - RpcHandle::Flush { x, .. } => x.reset(ServiceError::Upgrade)?, - RpcHandle::Create { x, .. } => x.reset(ServiceError::Upgrade)?, - RpcHandle::Insert { x, .. } => x.reset(ServiceError::Upgrade)?, - RpcHandle::Delete { x, .. } => x.reset(ServiceError::Upgrade)?, - RpcHandle::Stat { x, .. } => x.reset(ServiceError::Upgrade)?, - RpcHandle::Basic { x, .. } => x.reset(ServiceError::Upgrade)?, - RpcHandle::Vbase { x, .. } => x.reset(ServiceError::Upgrade)?, - RpcHandle::List { x, .. } => x.reset(ServiceError::Upgrade)?, - RpcHandle::Upgrade { x } => { - let _ = std::fs::remove_dir_all("./pg_vectors"); - handler = x.leave()?; - } - } - } -} diff --git a/src/datatype/casts_f32.rs b/src/datatype/casts_f32.rs index 18189e013..123434ca7 100644 --- a/src/datatype/casts_f32.rs +++ b/src/datatype/casts_f32.rs @@ -1,7 +1,8 @@ -use crate::datatype::vecf16::{Vecf16, Vecf16Output}; +use crate::datatype::svecf32::{SVecf32, SVecf32Input, SVecf32Output}; +use crate::datatype::vecf16::{Vecf16, Vecf16Input, Vecf16Output}; use crate::datatype::vecf32::{Vecf32, Vecf32Input, Vecf32Output}; -use crate::prelude::{FriendlyError, SessionError}; -use half::f16; +use crate::prelude::check_value_dimensions; +use base::scalar::FloatCast; use service::prelude::*; #[pgrx::pg_extern(immutable, parallel_safe, strict)] @@ -10,9 +11,7 @@ fn _vectors_cast_array_to_vecf32( _typmod: i32, _explicit: bool, ) -> Vecf32Output { - if array.is_empty() || array.len() > 65535 { - SessionError::BadValueDimensions.friendly(); - } + check_value_dimensions(array.len()); let mut data = vec![F32::zero(); array.len()]; for (i, x) in array.iter().enumerate() { data[i] = F32(x.unwrap_or(f32::NAN)); @@ -35,13 +34,53 @@ fn _vectors_cast_vecf32_to_vecf16( _typmod: i32, _explicit: bool, ) -> Vecf16Output { - let data: Vec = vector + let data: Vec = vector.data().iter().map(|&x| F16::from_f(x)).collect(); + + Vecf16::new_in_postgres(&data) +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_cast_vecf16_to_vecf32( + vector: Vecf16Input<'_>, + _typmod: i32, + _explicit: bool, +) -> Vecf32Output { + let data: Vec = vector.data().iter().map(|&x| x.to_f()).collect(); + + Vecf32::new_in_postgres(&data) +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_cast_vecf32_to_svecf32( + vector: Vecf32Input<'_>, + _typmod: i32, + _explicit: bool, +) -> SVecf32Output { + let mut indexes = Vec::new(); + let mut values = Vec::new(); + vector .data() .iter() - .map(|x| x.to_f32()) - .map(f16::from_f32) - .map(F16::from) - .collect(); + .enumerate() + .filter(|(_, x)| !x.is_zero()) + .for_each(|(i, &x)| { + indexes.push(i as u16); + values.push(x); + }); - Vecf16::new_in_postgres(&data) + SVecf32::new_in_postgres(SparseF32Ref { + dims: vector.len() as u16, + indexes: &indexes, + values: &values, + }) +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_cast_svecf32_to_vecf32( + vector: SVecf32Input<'_>, + _typmod: i32, + _explicit: bool, +) -> Vecf32Output { + let data = vector.data().to_dense(); + Vecf32::new_in_postgres(&data) } diff --git a/src/datatype/mod.rs b/src/datatype/mod.rs index 1b0ef7a78..0def2db1f 100644 --- a/src/datatype/mod.rs +++ b/src/datatype/mod.rs @@ -1,6 +1,8 @@ pub mod casts_f32; -pub mod operators_f16; -pub mod operators_f32; +pub mod operators_svecf32; +pub mod operators_vecf16; +pub mod operators_vecf32; +pub mod svecf32; pub mod typmod; pub mod vecf16; pub mod vecf32; diff --git a/src/datatype/operators_svecf32.rs b/src/datatype/operators_svecf32.rs new file mode 100644 index 000000000..03f3616f7 --- /dev/null +++ b/src/datatype/operators_svecf32.rs @@ -0,0 +1,149 @@ +use crate::datatype::svecf32::{SVecf32, SVecf32Input, SVecf32Output}; +use crate::prelude::*; +use base::scalar::FloatCast; +use service::prelude::*; +use std::ops::Deref; + +#[pgrx::pg_extern(immutable, parallel_safe)] +fn _vectors_svecf32_operator_add(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> SVecf32Output { + check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); + + let size1 = lhs.len(); + let size2 = rhs.len(); + let mut pos1 = 0; + let mut pos2 = 0; + let mut pos = 0; + let mut indexes = vec![0u16; size1 + size2]; + let mut values = vec![F32::zero(); size1 + size2]; + let lhs = lhs.data(); + let rhs = rhs.data(); + while pos1 < size1 && pos2 < size2 { + let lhs_index = lhs.indexes[pos1]; + let rhs_index = rhs.indexes[pos2]; + let lhs_value = lhs.values[pos1]; + let rhs_value = rhs.values[pos2]; + indexes[pos] = lhs_index.min(rhs_index); + values[pos] = F32((lhs_index <= rhs_index) as u32 as f32) * lhs_value + + F32((lhs_index >= rhs_index) as u32 as f32) * rhs_value; + pos1 += (lhs_index <= rhs_index) as usize; + pos2 += (lhs_index >= rhs_index) as usize; + pos += (!values[pos].is_zero()) as usize; + } + for i in pos1..size1 { + indexes[pos] = lhs.indexes[i]; + values[pos] = lhs.values[i]; + pos += 1; + } + for i in pos2..size2 { + indexes[pos] = rhs.indexes[i]; + values[pos] = rhs.values[i]; + pos += 1; + } + indexes.truncate(pos); + values.truncate(pos); + + SVecf32::new_in_postgres(SparseF32Ref { + dims: lhs.dims(), + indexes: &indexes, + values: &values, + }) +} + +#[pgrx::pg_extern(immutable, parallel_safe)] +fn _vectors_svecf32_operator_minus(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> SVecf32Output { + check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); + + let size1 = lhs.len(); + let size2 = rhs.len(); + let mut pos1 = 0; + let mut pos2 = 0; + let mut pos = 0; + let mut indexes = vec![0u16; size1 + size2]; + let mut values = vec![F32::zero(); size1 + size2]; + let lhs = lhs.data(); + let rhs = rhs.data(); + while pos1 < size1 && pos2 < size2 { + let lhs_index = lhs.indexes[pos1]; + let rhs_index = rhs.indexes[pos2]; + let lhs_value = lhs.values[pos1]; + let rhs_value = rhs.values[pos2]; + indexes[pos] = lhs_index.min(rhs_index); + values[pos] = F32((lhs_index <= rhs_index) as u32 as f32) * lhs_value + - F32((lhs_index >= rhs_index) as u32 as f32) * rhs_value; + pos1 += (lhs_index <= rhs_index) as usize; + pos2 += (lhs_index >= rhs_index) as usize; + pos += (!values[pos].is_zero()) as usize; + } + for i in pos1..size1 { + indexes[pos] = lhs.indexes[i]; + values[pos] = lhs.values[i]; + pos += 1; + } + for i in pos2..size2 { + indexes[pos] = rhs.indexes[i]; + values[pos] = -rhs.values[i]; + pos += 1; + } + indexes.truncate(pos); + values.truncate(pos); + + SVecf32::new_in_postgres(SparseF32Ref { + dims: lhs.dims(), + indexes: &indexes, + values: &values, + }) +} + +#[pgrx::pg_extern(immutable, parallel_safe)] +fn _vectors_svecf32_operator_lt(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> bool { + check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); + lhs.deref() < rhs.deref() +} + +#[pgrx::pg_extern(immutable, parallel_safe)] +fn _vectors_svecf32_operator_lte(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> bool { + check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); + lhs.deref() <= rhs.deref() +} + +#[pgrx::pg_extern(immutable, parallel_safe)] +fn _vectors_svecf32_operator_gt(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> bool { + check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); + lhs.deref() > rhs.deref() +} + +#[pgrx::pg_extern(immutable, parallel_safe)] +fn _vectors_svecf32_operator_gte(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> bool { + check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); + lhs.deref() >= rhs.deref() +} + +#[pgrx::pg_extern(immutable, parallel_safe)] +fn _vectors_svecf32_operator_eq(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> bool { + check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); + lhs.deref() == rhs.deref() +} + +#[pgrx::pg_extern(immutable, parallel_safe)] +fn _vectors_svecf32_operator_neq(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> bool { + check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); + lhs.deref() != rhs.deref() +} + +#[pgrx::pg_extern(immutable, parallel_safe)] +fn _vectors_svecf32_operator_cosine(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> f32 { + check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); + SparseF32Cos::distance(lhs.data(), rhs.data()).to_f32() +} + +#[pgrx::pg_extern(immutable, parallel_safe)] +fn _vectors_svecf32_operator_dot(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> f32 { + check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); + SparseF32Dot::distance(lhs.data(), rhs.data()).to_f32() +} + +#[pgrx::pg_extern(immutable, parallel_safe)] +fn _vectors_svecf32_operator_l2(lhs: SVecf32Input<'_>, rhs: SVecf32Input<'_>) -> f32 { + check_matched_dimensions(lhs.dims() as _, rhs.dims() as _); + SparseF32L2::distance(lhs.data(), rhs.data()).to_f32() +} diff --git a/src/datatype/operators_f16.rs b/src/datatype/operators_vecf16.rs similarity index 50% rename from src/datatype/operators_f16.rs rename to src/datatype/operators_vecf16.rs index 1f3e65a24..191c0e364 100644 --- a/src/datatype/operators_f16.rs +++ b/src/datatype/operators_vecf16.rs @@ -1,18 +1,12 @@ use crate::datatype::vecf16::{Vecf16, Vecf16Input, Vecf16Output}; use crate::prelude::*; +use base::scalar::FloatCast; use service::prelude::*; use std::ops::Deref; #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_add(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Output { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); + let n = check_matched_dimensions(lhs.len(), rhs.len()); let mut v = vec![F16::zero(); n]; for i in 0..n { v[i] = lhs[i] + rhs[i]; @@ -22,14 +16,7 @@ fn _vectors_vecf16_operator_add(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> V #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_minus(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> Vecf16Output { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); + let n = check_matched_dimensions(lhs.len(), rhs.len()); let mut v = vec![F16::zero(); n]; for i in 0..n { v[i] = lhs[i] - rhs[i]; @@ -39,108 +26,54 @@ fn _vectors_vecf16_operator_minus(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_lt(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); lhs.deref() < rhs.deref() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_lte(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); lhs.deref() <= rhs.deref() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_gt(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); lhs.deref() > rhs.deref() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_gte(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); lhs.deref() >= rhs.deref() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_eq(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); lhs.deref() == rhs.deref() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_neq(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> bool { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); lhs.deref() != rhs.deref() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_cosine(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); F16Cos::distance(&lhs, &rhs).to_f32() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_dot(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); F16Dot::distance(&lhs, &rhs).to_f32() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf16_operator_l2(lhs: Vecf16Input<'_>, rhs: Vecf16Input<'_>) -> f32 { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); F16L2::distance(&lhs, &rhs).to_f32() } diff --git a/src/datatype/operators_f32.rs b/src/datatype/operators_vecf32.rs similarity index 50% rename from src/datatype/operators_f32.rs rename to src/datatype/operators_vecf32.rs index 098e92f9e..50649f8d8 100644 --- a/src/datatype/operators_f32.rs +++ b/src/datatype/operators_vecf32.rs @@ -1,18 +1,12 @@ use crate::datatype::vecf32::{Vecf32, Vecf32Input, Vecf32Output}; use crate::prelude::*; +use base::scalar::FloatCast; use service::prelude::*; use std::ops::Deref; #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_add(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> Vecf32Output { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); + let n = check_matched_dimensions(lhs.len(), rhs.len()); let mut v = vec![F32::zero(); n]; for i in 0..n { v[i] = lhs[i] + rhs[i]; @@ -22,14 +16,7 @@ fn _vectors_vecf32_operator_add(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> V #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_minus(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> Vecf32Output { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } - let n = lhs.len(); + let n = check_matched_dimensions(lhs.len(), rhs.len()); let mut v = vec![F32::zero(); n]; for i in 0..n { v[i] = lhs[i] - rhs[i]; @@ -39,108 +26,54 @@ fn _vectors_vecf32_operator_minus(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_lt(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); lhs.deref() < rhs.deref() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_lte(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); lhs.deref() <= rhs.deref() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_gt(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); lhs.deref() > rhs.deref() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_gte(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); lhs.deref() >= rhs.deref() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_eq(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); lhs.deref() == rhs.deref() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_neq(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> bool { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); lhs.deref() != rhs.deref() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_cosine(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); F32Cos::distance(&lhs, &rhs).to_f32() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_dot(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); F32Dot::distance(&lhs, &rhs).to_f32() } #[pgrx::pg_extern(immutable, parallel_safe)] fn _vectors_vecf32_operator_l2(lhs: Vecf32Input<'_>, rhs: Vecf32Input<'_>) -> f32 { - if lhs.len() != rhs.len() { - SessionError::Unmatched { - left_dimensions: lhs.len() as _, - right_dimensions: rhs.len() as _, - } - .friendly(); - } + check_matched_dimensions(lhs.len(), rhs.len()); F32L2::distance(&lhs, &rhs).to_f32() } diff --git a/src/datatype/svecf32.rs b/src/datatype/svecf32.rs new file mode 100644 index 000000000..2f585d10a --- /dev/null +++ b/src/datatype/svecf32.rs @@ -0,0 +1,673 @@ +use crate::prelude::*; +use pgrx::pg_sys::Datum; +use pgrx::pg_sys::Oid; +use pgrx::pgrx_sql_entity_graph::metadata::ArgumentError; +use pgrx::pgrx_sql_entity_graph::metadata::Returns; +use pgrx::pgrx_sql_entity_graph::metadata::ReturnsError; +use pgrx::pgrx_sql_entity_graph::metadata::SqlMapping; +use pgrx::pgrx_sql_entity_graph::metadata::SqlTranslatable; +use pgrx::FromDatum; +use pgrx::IntoDatum; +use service::prelude::*; +use std::alloc::Layout; +use std::cmp::Ordering; +use std::ffi::CStr; +use std::ffi::CString; +use std::ops::Deref; +use std::ops::DerefMut; +use std::ptr::NonNull; + +#[repr(C, align(8))] +pub struct SVecf32 { + varlena: u32, + dims: u16, + kind: u8, + reserved: u8, + len: u16, + padding: [u8; 6], + phantom: [u8; 0], +} + +impl SVecf32 { + fn varlena(size: usize) -> u32 { + (size << 2) as u32 + } + fn layout(len: usize) -> Layout { + u16::try_from(len).expect("Vector is too large."); + let layout = Layout::new::(); + let layout1 = Layout::array::(len).unwrap(); + let layout2 = Layout::array::(len).unwrap(); + let layout = layout.extend(layout1).unwrap().0.pad_to_align(); + layout.extend(layout2).unwrap().0.pad_to_align() + } + pub fn new_in_postgres(vector: SparseF32Ref<'_>) -> SVecf32Output { + unsafe { + let layout = SVecf32::layout(vector.length() as usize); + let ptr = pgrx::pg_sys::palloc(layout.size()) as *mut SVecf32; + ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); + std::ptr::addr_of_mut!((*ptr).varlena).write(SVecf32::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).dims).write(vector.dims); + std::ptr::addr_of_mut!((*ptr).kind).write(2); + std::ptr::addr_of_mut!((*ptr).reserved).write(0); + std::ptr::addr_of_mut!((*ptr).len).write(vector.length()); + std::ptr::addr_of_mut!((*ptr).padding).write(std::mem::zeroed()); + let mut data_ptr = (*ptr).phantom.as_mut_ptr().cast::(); + std::ptr::copy_nonoverlapping( + vector.indexes.as_ptr(), + data_ptr, + vector.length() as usize, + ); + data_ptr = data_ptr.add(vector.length() as usize); + let offset = data_ptr.align_offset(8); + std::ptr::write_bytes(data_ptr, 0, offset); + data_ptr = data_ptr.add(offset); + std::ptr::copy_nonoverlapping( + vector.values.as_ptr(), + data_ptr.cast(), + vector.length() as usize, + ); + SVecf32Output(NonNull::new(ptr).unwrap()) + } + } + pub fn new_zeroed_in_postgres(len: usize) -> SVecf32Output { + unsafe { + let layout = SVecf32::layout(len); + let ptr = pgrx::pg_sys::palloc0(layout.size()) as *mut SVecf32; + ptr.cast::().add(layout.size() - 8).write_bytes(0, 8); + std::ptr::addr_of_mut!((*ptr).varlena).write(SVecf32::varlena(layout.size())); + std::ptr::addr_of_mut!((*ptr).kind).write(2); + std::ptr::addr_of_mut!((*ptr).reserved).write(0); + std::ptr::addr_of_mut!((*ptr).len).write(len as u16); + SVecf32Output(NonNull::new(ptr).unwrap()) + } + } + pub fn dims(&self) -> u16 { + self.dims + } + pub fn len(&self) -> usize { + self.len as usize + } + fn indexes(&self) -> &[u16] { + let ptr = self.phantom.as_ptr().cast(); + unsafe { std::slice::from_raw_parts(ptr, self.len as usize) } + } + fn values(&self) -> &[F32] { + let len = self.len as usize; + unsafe { + let ptr = self.phantom.as_ptr().cast::().add(len); + let offset = ptr.align_offset(8); + let ptr = ptr.add(offset).cast(); + std::slice::from_raw_parts(ptr, len) + } + } + fn indexes_mut(&mut self) -> &mut [u16] { + let ptr = self.phantom.as_mut_ptr().cast(); + unsafe { std::slice::from_raw_parts_mut(ptr, self.len as usize) } + } + fn values_mut(&mut self) -> &mut [F32] { + let len = self.len as usize; + unsafe { + let ptr = self.phantom.as_mut_ptr().cast::().add(len); + let offset = ptr.align_offset(8); + let ptr = ptr.add(offset).cast(); + std::slice::from_raw_parts_mut(ptr, len) + } + } + pub fn data(&self) -> SparseF32Ref<'_> { + debug_assert_eq!(self.varlena & 3, 0); + debug_assert_eq!(self.kind, 2); + SparseF32Ref { + dims: self.dims, + indexes: self.indexes(), + values: self.values(), + } + } +} + +impl PartialEq for SVecf32 { + fn eq(&self, other: &Self) -> bool { + self.data() == other.data() + } +} + +impl Eq for SVecf32 {} + +impl PartialOrd for SVecf32 { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for SVecf32 { + fn cmp(&self, other: &Self) -> Ordering { + assert!(self.dims() == other.dims()); + let lhs = self.data(); + let rhs = other.data(); + let mut pos = 0; + let size1 = lhs.length() as usize; + let size2 = rhs.length() as usize; + while pos < size1 && pos < size2 { + let lhs_index = lhs.indexes[pos]; + let rhs_index = rhs.indexes[pos]; + let lhs_value = lhs.values[pos]; + let rhs_value = rhs.values[pos]; + match lhs_index.cmp(&rhs_index) { + Ordering::Less => return lhs_value.cmp(&F32::zero()), + Ordering::Greater => return F32::zero().cmp(&rhs_value), + Ordering::Equal => match lhs_value.cmp(&rhs_value) { + Ordering::Equal => {} + x => return x, + }, + } + pos += 1; + } + match size1.cmp(&size2) { + Ordering::Less => F32::zero().cmp(&rhs.values[pos]), + Ordering::Greater => lhs.values[pos].cmp(&F32::zero()), + Ordering::Equal => Ordering::Equal, + } + } +} + +pub enum SVecf32Input<'a> { + Owned(SVecf32Output), + Borrowed(&'a SVecf32), +} + +impl<'a> SVecf32Input<'a> { + pub unsafe fn new(p: NonNull) -> Self { + let q = unsafe { + NonNull::new(pgrx::pg_sys::pg_detoast_datum(p.cast().as_ptr()).cast()).unwrap() + }; + if p != q { + SVecf32Input::Owned(SVecf32Output(q)) + } else { + unsafe { SVecf32Input::Borrowed(p.as_ref()) } + } + } +} + +impl Deref for SVecf32Input<'_> { + type Target = SVecf32; + + fn deref(&self) -> &Self::Target { + match self { + SVecf32Input::Owned(x) => x, + SVecf32Input::Borrowed(x) => x, + } + } +} + +pub struct SVecf32Output(NonNull); + +impl SVecf32Output { + pub fn into_raw(self) -> *mut SVecf32 { + let result = self.0.as_ptr(); + std::mem::forget(self); + result + } +} + +impl Deref for SVecf32Output { + type Target = SVecf32; + + fn deref(&self) -> &Self::Target { + unsafe { self.0.as_ref() } + } +} + +impl DerefMut for SVecf32Output { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { self.0.as_mut() } + } +} + +impl Drop for SVecf32Output { + fn drop(&mut self) { + unsafe { + pgrx::pg_sys::pfree(self.0.as_ptr() as _); + } + } +} + +impl<'a> FromDatum for SVecf32Input<'a> { + unsafe fn from_polymorphic_datum(datum: Datum, is_null: bool, _typoid: Oid) -> Option { + if is_null { + None + } else { + let ptr = NonNull::new(datum.cast_mut_ptr::()).unwrap(); + unsafe { Some(SVecf32Input::new(ptr)) } + } + } +} + +impl IntoDatum for SVecf32Output { + fn into_datum(self) -> Option { + Some(Datum::from(self.into_raw() as *mut ())) + } + + fn type_oid() -> Oid { + pgrx::wrappers::regtypein("vectors.svector") + } +} + +unsafe impl SqlTranslatable for SVecf32Input<'_> { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("svector"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("svector")))) + } +} + +unsafe impl SqlTranslatable for SVecf32Output { + fn argument_sql() -> Result { + Ok(SqlMapping::As(String::from("svector"))) + } + fn return_sql() -> Result { + Ok(Returns::One(SqlMapping::As(String::from("svector")))) + } +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_svecf32_in(input: &CStr, _oid: Oid, _typmod: i32) -> SVecf32Output { + fn solve(option: Option, hint: &str) -> T { + if let Some(x) = option { + x + } else { + bad_literal(hint); + } + } + #[derive(Debug, Clone, Copy, PartialEq, Eq)] + enum State { + MatchingLeft, + Reading, + MatchedRight, + } + use State::*; + let input = input.to_bytes(); + let mut indexes = Vec::::new(); + let mut values = Vec::::new(); + let mut state = MatchingLeft; + let mut token: Option = None; + let mut index = 0; + for &c in input { + match (state, c) { + (MatchingLeft, b'[') => { + state = Reading; + } + (Reading, b'0'..=b'9' | b'a'..=b'z' | b'A'..=b'Z' | b'.' | b'+' | b'-') => { + let token = token.get_or_insert(String::new()); + token.push(char::from_u32(c as u32).unwrap()); + } + (Reading, b',') => { + let token = solve(token.take(), "Expect a number."); + let value: F32 = solve(token.parse().ok(), "Bad number."); + if !value.is_zero() { + indexes.push(index); + values.push(value); + } + index = match index.checked_add(1) { + Some(x) => x, + None => check_value_dimensions(65536).get(), + }; + } + (Reading, b']') => { + if let Some(token) = token.take() { + let value: F32 = solve(token.parse().ok(), "Bad number."); + if !value.is_zero() { + indexes.push(index); + values.push(value); + } + index = match index.checked_add(1) { + Some(x) => x, + None => check_value_dimensions(65536).get(), + }; + } + state = MatchedRight; + } + (_, b' ') => {} + _ => { + bad_literal(&format!("Bad character with ascii {:#x}.", c)); + } + } + } + if state != MatchedRight { + bad_literal("Bad sequence"); + } + SVecf32::new_in_postgres(SparseF32Ref { + dims: check_value_dimensions(index as usize).get(), + indexes: &indexes, + values: &values, + }) +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_svecf32_out(vector: SVecf32Input<'_>) -> CString { + let mut buffer = String::new(); + buffer.push('['); + let vec = vector.data().to_dense(); + let mut iter = vec.iter(); + if let Some(x) = iter.next() { + buffer.push_str(format!("{}", x).as_str()); + } + for x in iter { + buffer.push_str(format!(", {}", x).as_str()); + } + buffer.push(']'); + CString::new(buffer).unwrap() +} + +#[pgrx::pg_extern(sql = "\ +CREATE FUNCTION _vectors_svecf32_subscript(internal) RETURNS internal +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] +fn _vectors_svecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { + #[pgrx::pg_guard] + unsafe extern "C" fn transform( + subscript: *mut pgrx::pg_sys::SubscriptingRef, + indirection: *mut pgrx::pg_sys::List, + pstate: *mut pgrx::pg_sys::ParseState, + is_slice: bool, + is_assignment: bool, + ) { + unsafe { + if (*indirection).length != 1 { + pgrx::pg_sys::error!("type svector does only support one subscript"); + } + if !is_slice { + pgrx::pg_sys::error!("type svector does only support slice fetch"); + } + if is_assignment { + pgrx::pg_sys::error!("type svector does not support subscripted assignment"); + } + let subscript = &mut *subscript; + let ai = (*(*indirection).elements.add(0)).ptr_value as *mut pgrx::pg_sys::A_Indices; + subscript.refupperindexpr = pgrx::pg_sys::lappend( + std::ptr::null_mut(), + if !(*ai).uidx.is_null() { + let subexpr = + pgrx::pg_sys::transformExpr(pstate, (*ai).uidx, (*pstate).p_expr_kind); + let subexpr = pgrx::pg_sys::coerce_to_target_type( + pstate, + subexpr, + pgrx::pg_sys::exprType(subexpr), + pgrx::pg_sys::INT4OID, + -1, + pgrx::pg_sys::CoercionContext_COERCION_ASSIGNMENT, + pgrx::pg_sys::CoercionForm_COERCE_IMPLICIT_CAST, + -1, + ); + if subexpr.is_null() { + pgrx::error!("svector subscript must have type integer"); + } + subexpr.cast() + } else { + std::ptr::null_mut() + }, + ); + subscript.reflowerindexpr = pgrx::pg_sys::lappend( + std::ptr::null_mut(), + if !(*ai).lidx.is_null() { + let subexpr = + pgrx::pg_sys::transformExpr(pstate, (*ai).lidx, (*pstate).p_expr_kind); + let subexpr = pgrx::pg_sys::coerce_to_target_type( + pstate, + subexpr, + pgrx::pg_sys::exprType(subexpr), + pgrx::pg_sys::INT4OID, + -1, + pgrx::pg_sys::CoercionContext_COERCION_ASSIGNMENT, + pgrx::pg_sys::CoercionForm_COERCE_IMPLICIT_CAST, + -1, + ); + if subexpr.is_null() { + pgrx::error!("svector subscript must have type integer"); + } + subexpr.cast() + } else { + std::ptr::null_mut() + }, + ); + subscript.refrestype = subscript.refcontainertype; + } + } + #[pgrx::pg_guard] + unsafe extern "C" fn exec_setup( + _subscript: *const pgrx::pg_sys::SubscriptingRef, + state: *mut pgrx::pg_sys::SubscriptingRefState, + steps: *mut pgrx::pg_sys::SubscriptExecSteps, + ) { + #[derive(Default)] + struct Workspace { + range: Option<(Option, Option)>, + } + #[pgrx::pg_guard] + unsafe extern "C" fn sbs_check_subscripts( + _state: *mut pgrx::pg_sys::ExprState, + op: *mut pgrx::pg_sys::ExprEvalStep, + _econtext: *mut pgrx::pg_sys::ExprContext, + ) -> bool { + unsafe { + let state = &mut *(*op).d.sbsref.state; + let workspace = &mut *(state.workspace as *mut Workspace); + workspace.range = None; + let mut end = None; + let mut start = None; + if state.upperprovided.read() { + if !state.upperindexnull.read() { + let upper = state.upperindex.read().value() as i32; + if upper >= 0 { + end = Some(upper as usize); + } else { + (*op).resnull.write(true); + return false; + } + } else { + (*op).resnull.write(true); + return false; + } + } + if state.lowerprovided.read() { + if !state.lowerindexnull.read() { + let lower = state.lowerindex.read().value() as i32; + if lower >= 0 { + start = Some(lower as usize); + } else { + (*op).resnull.write(true); + return false; + } + } else { + (*op).resnull.write(true); + return false; + } + } + workspace.range = Some((start, end)); + true + } + } + #[pgrx::pg_guard] + unsafe extern "C" fn sbs_fetch( + _state: *mut pgrx::pg_sys::ExprState, + op: *mut pgrx::pg_sys::ExprEvalStep, + _econtext: *mut pgrx::pg_sys::ExprContext, + ) { + unsafe { + let state = &mut *(*op).d.sbsref.state; + let workspace = &mut *(state.workspace as *mut Workspace); + let input = + SVecf32Input::from_datum((*op).resvalue.read(), (*op).resnull.read()).unwrap(); + let Some((start, end)) = workspace.range else { + (*op).resnull.write(true); + return; + }; + let start: u16 = match start.unwrap_or(0).try_into() { + Ok(x) => x, + Err(_) => { + (*op).resnull.write(true); + return; + } + }; + let end: u16 = match end.unwrap_or(input.dims() as usize).try_into() { + Ok(x) => x, + Err(_) => { + (*op).resnull.write(true); + return; + } + }; + if start >= end || end > input.dims() { + (*op).resnull.write(true); + return; + } + let data = input.data(); + let start_index = data.indexes.partition_point(|&x| x < start); + let end_index = data.indexes.partition_point(|&x| x < end); + let mut indexes = data.indexes[start_index..end_index].to_vec(); + indexes.iter_mut().for_each(|x| *x -= start); + let output = SVecf32::new_in_postgres(SparseF32Ref { + dims: end - start, + indexes: &indexes, + values: &data.values[start_index..end_index], + }); + (*op).resnull.write(false); + (*op).resvalue.write(Datum::from(output.into_raw())); + } + } + unsafe { + let state = &mut *state; + let steps = &mut *steps; + assert!(state.numlower == 1); + assert!(state.numupper == 1); + state.workspace = pgrx::pg_sys::palloc(std::mem::size_of::()); + std::ptr::write::(state.workspace.cast(), Workspace::default()); + steps.sbs_check_subscripts = Some(sbs_check_subscripts); + steps.sbs_fetch = Some(sbs_fetch); + steps.sbs_assign = None; + steps.sbs_fetch_old = None; + } + } + static SBSROUTINES: pgrx::pg_sys::SubscriptRoutines = pgrx::pg_sys::SubscriptRoutines { + transform: Some(transform), + exec_setup: Some(exec_setup), + fetch_strict: true, + fetch_leakproof: false, + store_leakproof: false, + }; + std::ptr::addr_of!(SBSROUTINES).into() +} + +#[pgrx::pg_extern(sql = "\ +CREATE FUNCTION _vectors_svecf32_send(svector) RETURNS bytea +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] +fn _vectors_svecf32_send(vector: SVecf32Input<'_>) -> Datum { + use pgrx::pg_sys::StringInfoData; + unsafe { + let mut buf = StringInfoData::default(); + let dims = vector.dims; + let len = vector.len; + let data = vector.data(); + pgrx::pg_sys::pq_begintypsend(&mut buf); + pgrx::pg_sys::pq_sendbytes(&mut buf, (&dims) as *const u16 as _, 2); + pgrx::pg_sys::pq_sendbytes(&mut buf, (&len) as *const u16 as _, 2); + pgrx::pg_sys::pq_sendbytes( + &mut buf, + data.indexes.as_ptr() as _, + (std::mem::size_of::() * len as usize) as _, + ); + pgrx::pg_sys::pq_sendbytes( + &mut buf, + data.values.as_ptr() as _, + (std::mem::size_of::() * len as usize) as _, + ); + Datum::from(pgrx::pg_sys::pq_endtypsend(&mut buf)) + } +} + +#[pgrx::pg_extern(sql = " +CREATE FUNCTION _vectors_svecf32_recv(internal, oid, integer) RETURNS svector +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] +fn _vectors_svecf32_recv(internal: pgrx::Internal, _oid: Oid, _typmod: i32) -> SVecf32Output { + use pgrx::pg_sys::StringInfo; + unsafe { + let buf: StringInfo = internal.into_datum().unwrap().cast_mut_ptr(); + let dims = (pgrx::pg_sys::pq_getmsgbytes(buf, 2) as *const u16).read_unaligned(); + let len = (pgrx::pg_sys::pq_getmsgbytes(buf, 2) as *const u16).read_unaligned(); + if dims == 0 || len == 0 { + pgrx::error!("data corruption is detected"); + } + let indexes_bytes = std::mem::size_of::() * len as usize; + let indexes_ptr = pgrx::pg_sys::pq_getmsgbytes(buf, indexes_bytes as _); + let values_bytes = std::mem::size_of::() * len as usize; + let values_ptr = pgrx::pg_sys::pq_getmsgbytes(buf, values_bytes as _); + let mut output = SVecf32::new_zeroed_in_postgres(len as usize); + + let indexes = std::slice::from_raw_parts(indexes_ptr as *const u16, len as usize); + if len > 1 { + for i in 0..len as usize - 1 { + if indexes[i] >= indexes[i + 1] { + pgrx::error!("data corruption is detected"); + } + } + } + if indexes[len as usize - 1] >= dims { + pgrx::error!("data corruption is detected"); + } + + output.dims = dims; + std::ptr::copy( + indexes_ptr, + output.indexes_mut().as_mut_ptr() as _, + indexes_bytes, + ); + std::ptr::copy( + values_ptr, + output.values_mut().as_mut_ptr() as _, + values_bytes, + ); + output + } +} + +#[pgrx::pg_extern(immutable, parallel_safe, strict)] +fn _vectors_to_svector( + dims: i32, + index: pgrx::Array, + value: pgrx::Array, +) -> SVecf32Output { + let dims = check_value_dimensions(dims as usize); + if index.len() != value.len() { + bad_literal("Lengths of index and value are not matched."); + } + if index.contains_nulls() || value.contains_nulls() { + bad_literal("Index or value contains nulls."); + } + let mut vector: Vec<(u16, F32)> = index + .iter_deny_null() + .zip(value.iter_deny_null()) + .map(|(index, value)| { + if index < 0 || index >= dims.get() as i32 { + bad_literal("Index out of bound."); + } + (index as u16, F32(value)) + }) + .collect(); + vector.sort_unstable_by_key(|x| x.0); + if vector.len() > 1 { + for i in 0..vector.len() - 1 { + if vector[i].0 == vector[i + 1].0 { + bad_literal("Duplicated index."); + } + } + } + + let mut indexes = Vec::::with_capacity(vector.len()); + let mut values = Vec::::with_capacity(vector.len()); + for x in vector { + indexes.push(x.0); + values.push(x.1); + } + SVecf32::new_in_postgres(SparseF32Ref { + dims: dims.get(), + indexes: &indexes, + values: &values, + }) +} diff --git a/src/datatype/typmod.rs b/src/datatype/typmod.rs index de48d2179..ea8802416 100644 --- a/src/datatype/typmod.rs +++ b/src/datatype/typmod.rs @@ -11,14 +11,6 @@ pub enum Typmod { } impl Typmod { - pub fn parse_from_str(s: &str) -> Option { - use Typmod::*; - if let Ok(x) = s.parse::() { - Some(Dims(x)) - } else { - None - } - } pub fn parse_from_i32(x: i32) -> Option { use Typmod::*; if x == -1 { @@ -43,11 +35,11 @@ impl Typmod { Dims(x) => i32::from(x.get()), } } - pub fn dims(self) -> Option { + pub fn dims(self) -> Option { use Typmod::*; match self { Any => None, - Dims(dims) => Some(dims.get()), + Dims(dims) => Some(dims), } } } @@ -58,12 +50,11 @@ fn _vectors_typmod_in(list: Array<&CStr>) -> i32 { -1 } else if list.len() == 1 { let s = list.get(0).unwrap().unwrap().to_str().unwrap(); - let typmod = Typmod::parse_from_str(s) - .ok_or(SessionError::BadTypeDimensions) - .friendly(); + let typmod = Typmod::Dims(check_type_dimensions(s.parse::().ok())); typmod.into_i32() } else { - SessionError::BadTypeDimensions.friendly(); + check_type_dimensions(None); + unreachable!() } } diff --git a/src/datatype/vecf16.rs b/src/datatype/vecf16.rs index d8599331c..24e3bfd64 100644 --- a/src/datatype/vecf16.rs +++ b/src/datatype/vecf16.rs @@ -268,19 +268,18 @@ unsafe impl SqlTranslatable for Vecf16Output { #[pgrx::pg_extern(immutable, parallel_safe, strict)] fn _vectors_vecf16_in(input: &CStr, _oid: Oid, typmod: i32) -> Vecf16Output { use crate::utils::parse::parse_vector; - let reserve = Typmod::parse_from_i32(typmod).unwrap().dims().unwrap_or(0); + let reserve = Typmod::parse_from_i32(typmod) + .unwrap() + .dims() + .map(|x| x.get()) + .unwrap_or(0); let v = parse_vector(input.to_bytes(), reserve as usize, |s| s.parse().ok()); match v { Err(e) => { - SessionError::BadLiteral { - hint: e.to_string(), - } - .friendly(); + bad_literal(&e.to_string()); } Ok(vector) => { - if vector.is_empty() || vector.len() > 65535 { - SessionError::BadValueDimensions.friendly(); - } + check_value_dimensions(vector.len()); Vecf16::new_in_postgres(&vector) } } @@ -300,7 +299,6 @@ fn _vectors_vecf16_out(vector: Vecf16Input<'_>) -> CString { CString::new(buffer).unwrap() } -#[cfg(any(feature = "pg14", feature = "pg15", feature = "pg16"))] #[pgrx::pg_extern(sql = "\ CREATE FUNCTION _vectors_vecf16_subscript(internal) RETURNS internal IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] @@ -482,14 +480,6 @@ fn _vectors_vecf16_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { std::ptr::addr_of!(SBSROUTINES).into() } -#[cfg(not(any(feature = "pg14", feature = "pg15", feature = "pg16")))] -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_vecf16_subscript(internal) RETURNS internal -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_vecf16_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { - unreachable!() -} - #[pgrx::pg_extern(sql = "\ CREATE FUNCTION _vectors_vecf16_send(vecf16) RETURNS bytea IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] diff --git a/src/datatype/vecf32.rs b/src/datatype/vecf32.rs index facdac411..dd59882e4 100644 --- a/src/datatype/vecf32.rs +++ b/src/datatype/vecf32.rs @@ -268,19 +268,18 @@ unsafe impl SqlTranslatable for Vecf32Output { #[pgrx::pg_extern(immutable, parallel_safe, strict)] fn _vectors_vecf32_in(input: &CStr, _oid: Oid, typmod: i32) -> Vecf32Output { use crate::utils::parse::parse_vector; - let reserve = Typmod::parse_from_i32(typmod).unwrap().dims().unwrap_or(0); + let reserve = Typmod::parse_from_i32(typmod) + .unwrap() + .dims() + .map(|x| x.get()) + .unwrap_or(0); let v = parse_vector(input.to_bytes(), reserve as usize, |s| s.parse().ok()); match v { Err(e) => { - SessionError::BadLiteral { - hint: e.to_string(), - } - .friendly(); + bad_literal(&e.to_string()); } Ok(vector) => { - if vector.is_empty() || vector.len() > 65535 { - SessionError::BadValueDimensions.friendly(); - } + check_value_dimensions(vector.len()); Vecf32::new_in_postgres(&vector) } } @@ -300,7 +299,6 @@ fn _vectors_vecf32_out(vector: Vecf32Input<'_>) -> CString { CString::new(buffer).unwrap() } -#[cfg(any(feature = "pg14", feature = "pg15", feature = "pg16"))] #[pgrx::pg_extern(sql = "\ CREATE FUNCTION _vectors_vecf32_subscript(internal) RETURNS internal IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] @@ -482,14 +480,6 @@ fn _vectors_vecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { std::ptr::addr_of!(SBSROUTINES).into() } -#[cfg(not(any(feature = "pg14", feature = "pg15", feature = "pg16")))] -#[pgrx::pg_extern(sql = "\ -CREATE FUNCTION _vectors_vecf32_subscript(internal) RETURNS internal -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] -fn _vectors_vecf32_subscript(_fcinfo: pgrx::pg_sys::FunctionCallInfo) -> Datum { - unreachable!() -} - #[pgrx::pg_extern(sql = "\ CREATE FUNCTION _vectors_vecf32_send(vector) RETURNS bytea IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '@FUNCTION_NAME@';")] diff --git a/src/gucs/executing.rs b/src/gucs/executing.rs index 0708d266c..08d874846 100644 --- a/src/gucs/executing.rs +++ b/src/gucs/executing.rs @@ -1,6 +1,5 @@ use pgrx::{GucContext, GucFlags, GucRegistry, GucSetting}; use service::index::SearchOptions; -use validator::Validate; static ENABLE_PREFILTER: GucSetting = GucSetting::::new(true); @@ -40,11 +39,9 @@ pub unsafe fn init() { } pub fn search_options() -> SearchOptions { - let options = SearchOptions { + SearchOptions { prefilter_enable: ENABLE_PREFILTER.get(), hnsw_ef_search: HNSW_EF_SEARCH.get() as usize, ivf_nprobe: IVF_NPROBE.get() as u32, - }; - assert!(options.validate().is_ok()); - options + } } diff --git a/src/index/am.rs b/src/index/am.rs index 5d9a48968..202e30b86 100644 --- a/src/index/am.rs +++ b/src/index/am.rs @@ -22,10 +22,7 @@ pub unsafe fn init() { "".as_pg_cstr(), "".as_pg_cstr(), None, - #[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15", feature = "pg16"))] - { - pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE - }, + pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE, ); } @@ -90,40 +87,6 @@ pub unsafe extern "C" fn amvalidate(opclass_oid: pgrx::pg_sys::Oid) -> bool { true } -#[cfg(feature = "pg12")] -#[pgrx::pg_guard] -pub unsafe extern "C" fn amoptions(reloptions: Datum, validate: bool) -> *mut pgrx::pg_sys::bytea { - use pgrx::pg_sys::AsPgCStr; - let tab: &[pgrx::pg_sys::relopt_parse_elt] = &[pgrx::pg_sys::relopt_parse_elt { - optname: "options".as_pg_cstr(), - opttype: pgrx::pg_sys::relopt_type_RELOPT_TYPE_STRING, - offset: am_setup::helper_offset() as i32, - }]; - let mut noptions = 0; - let options = - pgrx::pg_sys::parseRelOptions(reloptions, validate, RELOPT_KIND.get(), &mut noptions); - if noptions == 0 { - return std::ptr::null_mut(); - } - for relopt in std::slice::from_raw_parts_mut(options, noptions as usize) { - relopt.gen.as_mut().unwrap().lockmode = - pgrx::pg_sys::AccessExclusiveLock as pgrx::pg_sys::LOCKMODE; - } - let rdopts = pgrx::pg_sys::allocateReloptStruct(am_setup::helper_size(), options, noptions); - pgrx::pg_sys::fillRelOptions( - rdopts, - am_setup::helper_size(), - options, - noptions, - validate, - tab.as_ptr(), - tab.len() as i32, - ); - pgrx::pg_sys::pfree(options as pgrx::void_mut_ptr); - rdopts as *mut pgrx::pg_sys::bytea -} - -#[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15", feature = "pg16"))] #[pgrx::pg_guard] pub unsafe extern "C" fn amoptions(reloptions: Datum, validate: bool) -> *mut pgrx::pg_sys::bytea { use pgrx::pg_sys::AsPgCStr; @@ -189,25 +152,6 @@ pub unsafe extern "C" fn ambuildempty(index_relation: pgrx::pg_sys::Relation) { am_build::build(index_relation, None); } -#[cfg(any(feature = "pg12", feature = "pg13"))] -#[pgrx::pg_guard] -pub unsafe extern "C" fn aminsert( - index_relation: pgrx::pg_sys::Relation, - values: *mut Datum, - _is_null: *mut bool, - heap_tid: pgrx::pg_sys::ItemPointer, - _heap_relation: pgrx::pg_sys::Relation, - _check_unique: pgrx::pg_sys::IndexUniqueCheck, - _index_info: *mut pgrx::pg_sys::IndexInfo, -) -> bool { - let oid = (*index_relation).rd_node.relNode; - let id = Handle::from_sys(oid); - let vector = from_datum(*values.add(0)); - am_update::update_insert(id, vector, *heap_tid); - true -} - -#[cfg(any(feature = "pg14", feature = "pg15", feature = "pg16"))] #[pgrx::pg_guard] pub unsafe extern "C" fn aminsert( index_relation: pgrx::pg_sys::Relation, @@ -276,7 +220,7 @@ pub unsafe extern "C" fn ambulkdelete( callback: pgrx::pg_sys::IndexBulkDeleteCallback, callback_state: *mut std::os::raw::c_void, ) -> *mut pgrx::pg_sys::IndexBulkDeleteResult { - #[cfg(any(feature = "pg12", feature = "pg13", feature = "pg14", feature = "pg15"))] + #[cfg(any(feature = "pg14", feature = "pg15"))] let oid = (*(*info).index).rd_node.relNode; #[cfg(feature = "pg16")] let oid = (*(*info).index).rd_locator.relNumber; diff --git a/src/index/am_build.rs b/src/index/am_build.rs index c452cede6..38652c3b4 100644 --- a/src/index/am_build.rs +++ b/src/index/am_build.rs @@ -1,14 +1,14 @@ #![allow(unsafe_op_in_unsafe_fn)] +use crate::index::am_setup::options; use crate::index::utils::from_datum; -use crate::ipc::client::ClientGuard; +use crate::ipc::ClientRpc; use crate::prelude::*; -use crate::{index::am_setup::options, ipc::client::Rpc}; use pgrx::pg_sys::{IndexBuildResult, IndexInfo, RelationData}; use service::prelude::*; pub struct Builder { - pub rpc: ClientGuard, + pub rpc: ClientRpc, pub heap_relation: *mut RelationData, pub index_info: *mut IndexInfo, pub result: *mut IndexBuildResult, @@ -18,14 +18,20 @@ pub unsafe fn build( index: pgrx::pg_sys::Relation, data: Option<(*mut RelationData, *mut IndexInfo, *mut IndexBuildResult)>, ) { - #[cfg(any(feature = "pg12", feature = "pg13", feature = "pg14", feature = "pg15"))] + #[cfg(any(feature = "pg14", feature = "pg15"))] let oid = (*index).rd_node.relNode; #[cfg(feature = "pg16")] let oid = (*index).rd_locator.relNumber; let id = Handle::from_sys(oid); let options = options(index); - let mut rpc = crate::ipc::client::borrow_mut(); - rpc.create(id, options); + let mut rpc = check_client(crate::ipc::client()); + match rpc.create(id, options) { + Ok(()) => (), + Err(CreateError::Exist) => bad_service_exists(), + Err(CreateError::InvalidIndexOptions { reason }) => { + bad_service_invalid_index_options(&reason) + } + } if let Some((heap_relation, index_info, result)) = data { let mut builder = Builder { rpc, @@ -43,28 +49,6 @@ pub unsafe fn build( } } -#[cfg(feature = "pg12")] -#[pgrx::pg_guard] -unsafe extern "C" fn callback( - index_relation: pgrx::pg_sys::Relation, - htup: pgrx::pg_sys::HeapTuple, - values: *mut pgrx::pg_sys::Datum, - _is_null: *mut bool, - _tuple_is_alive: bool, - state: *mut std::os::raw::c_void, -) { - let ctid = &(*htup).t_self; - let oid = (*index_relation).rd_node.relNode; - let id = Handle::from_sys(oid); - let state = &mut *(state as *mut Builder); - let vector = from_datum(*values.add(0)); - let pointer = Pointer::from_sys(*ctid); - state.rpc.insert(id, vector, pointer); - (*state.result).heap_tuples += 1.0; - (*state.result).index_tuples += 1.0; -} - -#[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15", feature = "pg16"))] #[pgrx::pg_guard] unsafe extern "C" fn callback( index_relation: pgrx::pg_sys::Relation, @@ -74,7 +58,7 @@ unsafe extern "C" fn callback( _tuple_is_alive: bool, state: *mut std::os::raw::c_void, ) { - #[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15"))] + #[cfg(any(feature = "pg14", feature = "pg15"))] let oid = (*index_relation).rd_node.relNode; #[cfg(feature = "pg16")] let oid = (*index_relation).rd_locator.relNumber; @@ -82,7 +66,12 @@ unsafe extern "C" fn callback( let state = &mut *(state as *mut Builder); let vector = from_datum(*values.add(0)); let pointer = Pointer::from_sys(*ctid); - state.rpc.insert(id, vector, pointer); + match state.rpc.insert(id, vector, pointer) { + Ok(()) => (), + Err(InsertError::NotExist) => bad_service_not_exist(), + Err(InsertError::Upgrade) => bad_service_upgrade(), + Err(InsertError::InvalidVector) => bad_service_invalid_vector(), + } (*state.result).heap_tuples += 1.0; (*state.result).index_tuples += 1.0; } diff --git a/src/index/am_scan.rs b/src/index/am_scan.rs index 1bc812a14..6bbf09709 100644 --- a/src/index/am_scan.rs +++ b/src/index/am_scan.rs @@ -4,8 +4,7 @@ use crate::gucs::executing::search_options; use crate::gucs::planning::Mode; use crate::gucs::planning::SEARCH_MODE; use crate::index::utils::from_datum; -use crate::ipc::client::ClientGuard; -use crate::ipc::client::{Basic, Vbase}; +use crate::ipc::{ClientBasic, ClientVbase}; use crate::prelude::*; use pgrx::FromDatum; use service::prelude::*; @@ -17,11 +16,11 @@ pub enum Scanner { }, Basic { node: *mut pgrx::pg_sys::IndexScanState, - basic: ClientGuard, + basic: ClientBasic, }, Vbase { node: *mut pgrx::pg_sys::IndexScanState, - vbase: ClientGuard, + vbase: ClientVbase, }, } @@ -91,23 +90,35 @@ pub unsafe fn next_scan(scan: pgrx::pg_sys::IndexScanDesc) -> bool { let node = node.expect("Hook failed."); let vector = vector.as_ref().expect("Scan failed."); - #[cfg(any(feature = "pg12", feature = "pg13", feature = "pg14", feature = "pg15"))] + #[cfg(any(feature = "pg14", feature = "pg15"))] let oid = (*(*scan).indexRelation).rd_node.relNode; #[cfg(feature = "pg16")] let oid = (*(*scan).indexRelation).rd_locator.relNumber; let id = Handle::from_sys(oid); - let rpc = crate::ipc::client::borrow_mut(); + let rpc = check_client(crate::ipc::client()); match SEARCH_MODE.get() { Mode::basic => { let opts = search_options(); - let basic = rpc.basic(id, vector.clone(), opts); + let basic = match rpc.basic(id, vector.clone(), opts) { + Ok(x) => x, + Err((_, BasicError::NotExist)) => bad_service_not_exist(), + Err((_, BasicError::Upgrade)) => bad_service_upgrade(), + Err((_, BasicError::InvalidVector)) => bad_service_invalid_vector(), + Err((_, BasicError::InvalidSearchOptions { reason: _ })) => unreachable!(), + }; *scanner = Scanner::Basic { node, basic }; } Mode::vbase => { let opts = search_options(); - let vbase = rpc.vbase(id, vector.clone(), opts); + let vbase = match rpc.vbase(id, vector.clone(), opts) { + Ok(x) => x, + Err((_, VbaseError::NotExist)) => bad_service_not_exist(), + Err((_, VbaseError::Upgrade)) => bad_service_upgrade(), + Err((_, VbaseError::InvalidVector)) => bad_service_invalid_vector(), + Err((_, VbaseError::InvalidSearchOptions { reason: _ })) => unreachable!(), + }; *scanner = Scanner::Vbase { node, vbase }; } } diff --git a/src/index/am_setup.rs b/src/index/am_setup.rs index 824e998e7..6e1b1bfda 100644 --- a/src/index/am_setup.rs +++ b/src/index/am_setup.rs @@ -67,8 +67,14 @@ pub unsafe fn convert_opfamily_to_distance(opfamily: pgrx::pg_sys::Oid) -> (Dist result = (Distance::Dot, Kind::F16); } else if operator == regoperatorin("vectors.<=>(vectors.vecf16,vectors.vecf16)") { result = (Distance::Cos, Kind::F16); + } else if operator == regoperatorin("vectors.<->(vectors.svector,vectors.svector)") { + result = (Distance::L2, Kind::SparseF32); + } else if operator == regoperatorin("vectors.<#>(vectors.svector,vectors.svector)") { + result = (Distance::Dot, Kind::SparseF32); + } else if operator == regoperatorin("vectors.<=>(vectors.svector,vectors.svector)") { + result = (Distance::Cos, Kind::SparseF32); } else { - SessionError::BadOptions2.friendly(); + bad_opclass(); }; pgrx::pg_sys::ReleaseCatCacheList(list); pgrx::pg_sys::ReleaseSysCache(tuple); @@ -85,7 +91,7 @@ pub unsafe fn options(index_relation: pgrx::pg_sys::Relation) -> IndexOptions { let attrs = (*(*index_relation).rd_att).attrs.as_slice(1); let attr = &attrs[0]; let typmod = Typmod::parse_from_i32(attr.type_mod()).unwrap(); - let dims = typmod.dims().ok_or(SessionError::BadOption1).friendly(); + let dims = check_column_dimensions(typmod.dims()).get(); // get other options let parsed = get_parsed_from_varlena((*index_relation).rd_options); IndexOptions { diff --git a/src/index/am_update.rs b/src/index/am_update.rs index 04647ac92..638db50ca 100644 --- a/src/index/am_update.rs +++ b/src/index/am_update.rs @@ -6,18 +6,32 @@ pub fn update_insert(handle: Handle, vector: DynamicVector, tid: pgrx::pg_sys::I callback_dirty(handle); let pointer = Pointer::from_sys(tid); - let mut rpc = crate::ipc::client::borrow_mut(); - rpc.insert(handle, vector, pointer); + let mut rpc = check_client(crate::ipc::client()); + + match rpc.insert(handle, vector, pointer) { + Ok(()) => (), + Err(InsertError::NotExist) => bad_service_not_exist(), + Err(InsertError::Upgrade) => bad_service_upgrade(), + Err(InsertError::InvalidVector) => bad_service_invalid_vector(), + } } pub fn update_delete(handle: Handle, f: impl Fn(Pointer) -> bool) { callback_dirty(handle); - let mut rpc_list = crate::ipc::client::borrow_mut().list(handle); - let mut rpc = crate::ipc::client::borrow_mut(); + let mut rpc_list = match check_client(crate::ipc::client()).list(handle) { + Ok(x) => x, + Err((_, ListError::NotExist)) => bad_service_not_exist(), + Err((_, ListError::Upgrade)) => bad_service_upgrade(), + }; + let mut rpc = check_client(crate::ipc::client()); while let Some(p) = rpc_list.next() { if f(p) { - rpc.delete(handle, p); + match rpc.delete(handle, p) { + Ok(()) => (), + Err(DeleteError::NotExist) => (), + Err(DeleteError::Upgrade) => (), + } } } rpc_list.leave(); diff --git a/src/index/compat.rs b/src/index/compat.rs index bf55a0456..c820f20fb 100644 --- a/src/index/compat.rs +++ b/src/index/compat.rs @@ -126,10 +126,10 @@ unsafe fn rewrite_opclass(istmt: *mut pgrx::pg_sys::IndexStmt) { if opclass_name.is_null() { continue; } + #[cfg(feature = "pg14")] + let opclass_ptr = (*(opclass_name as *mut pgrx::pg_sys::Value)).val.str_; #[cfg(any(feature = "pg15", feature = "pg16"))] let opclass_ptr = (*(opclass_name as *mut pgrx::pg_sys::String)).sval; - #[cfg(any(feature = "pg12", feature = "pg13", feature = "pg14"))] - let opclass_ptr = (*(opclass_name as *mut pgrx::pg_sys::Value)).val.str_; let opclass = match CStr::from_ptr(opclass_ptr).to_str() { Ok("vector_l2_ops") => "vector_l2_ops", Ok("vector_ip_ops") => "vector_dot_ops", @@ -169,7 +169,6 @@ pub unsafe fn options_from_vec(vec: Vec<*mut pgrx::pg_sys::DefElem>) -> HashMap< options } -#[cfg(any(feature = "pg13", feature = "pg14", feature = "pg15", feature = "pg16"))] pub unsafe fn vec_from_list(l: *mut pgrx::pg_sys::List) -> Vec<*mut T> { let mut vec = Vec::new(); if l.is_null() { @@ -185,24 +184,6 @@ pub unsafe fn vec_from_list(l: *mut pgrx::pg_sys::List) -> Vec<*mut T> { vec } -#[cfg(feature = "pg12")] -pub unsafe fn vec_from_list(l: *mut pgrx::pg_sys::List) -> Vec<*mut T> { - let mut vec: Vec<*mut T> = Vec::new(); - if l.is_null() { - return vec; - } - unsafe { - let length = (*l).length; - let mut elem = (*l).head; - for _ in 0..length { - let e = (*elem).data.ptr_value as *mut T; - vec.push(e); - elem = (*elem).next; - } - } - vec -} - pub unsafe fn list_from_vec(vec: Vec<*mut T>) -> *mut pgrx::pg_sys::List { use std::ptr; if vec.is_empty() { diff --git a/src/index/functions.rs b/src/index/functions.rs index 8f8813276..bc88556b2 100644 --- a/src/index/functions.rs +++ b/src/index/functions.rs @@ -1,8 +1,4 @@ -use crate::ipc::client; - -#[pgrx::pg_extern(immutable, parallel_safe, strict)] +#[pgrx::pg_extern(volatile, strict)] fn _vectors_pgvectors_upgrade() { - let mut client = client::borrow_mut(); - client.upgrade(); - pgrx::warning!("pgvecto.rs is upgraded. Restart PostgreSQL to take effects."); + let _ = std::fs::remove_dir_all("pg_vectors"); } diff --git a/src/index/hook_executor.rs b/src/index/hook_executor.rs index 608712e4f..284a1071e 100644 --- a/src/index/hook_executor.rs +++ b/src/index/hook_executor.rs @@ -53,7 +53,7 @@ unsafe extern "C" fn rewrite_plan_state( } } } - #[cfg(any(feature = "pg12", feature = "pg13", feature = "pg14", feature = "pg15"))] + #[cfg(any(feature = "pg14", feature = "pg15"))] { type PlanstateTreeWalker = unsafe extern "C" fn(*mut pgrx::pg_sys::PlanState, *mut libc::c_void) -> bool; diff --git a/src/index/hook_transaction.rs b/src/index/hook_transaction.rs index efc63aa33..530fd6a32 100644 --- a/src/index/hook_transaction.rs +++ b/src/index/hook_transaction.rs @@ -16,12 +16,14 @@ pub fn commit() { if pending_deletes.is_empty() && pending_dirty.is_empty() { return; } - let mut rpc = crate::ipc::client::borrow_mut(); + let Some(mut rpc) = crate::ipc::client() else { + return; + }; for handle in pending_dirty { - rpc.flush(handle); + let _ = rpc.flush(handle); } for handle in pending_deletes { - rpc.drop(handle); + let _ = rpc.drop(handle); } } @@ -31,13 +33,15 @@ pub fn abort() { if pending_deletes.is_empty() { return; } - let mut rpc = crate::ipc::client::borrow_mut(); + let Some(mut rpc) = crate::ipc::client() else { + return; + }; for handle in pending_deletes { - rpc.drop(handle); + let _ = rpc.drop(handle); } } -#[cfg(any(feature = "pg12", feature = "pg13", feature = "pg14", feature = "pg15"))] +#[cfg(any(feature = "pg14", feature = "pg15"))] fn pending_deletes(for_commit: bool) -> Vec { let mut ptr: *mut pgrx::pg_sys::RelFileNode = std::ptr::null_mut(); let n = unsafe { pgrx::pg_sys::smgrGetPendingDeletes(for_commit, &mut ptr as *mut _) }; diff --git a/src/index/hooks.rs b/src/index/hooks.rs index 960019351..066bfe165 100644 --- a/src/index/hooks.rs +++ b/src/index/hooks.rs @@ -20,7 +20,6 @@ unsafe extern "C" fn vectors_executor_start( } } -#[cfg(any(feature = "pg14", feature = "pg15", feature = "pg16"))] #[pgrx::pg_guard] unsafe extern "C" fn hook_pgvector_compatibility( pstmt: *mut pgrx::pg_sys::PlannedStmt, @@ -62,84 +61,6 @@ unsafe extern "C" fn hook_pgvector_compatibility( } } -#[cfg(feature = "pg13")] -#[pgrx::pg_guard] -unsafe extern "C" fn hook_pgvector_compatibility( - pstmt: *mut pgrx::pg_sys::PlannedStmt, - query_string: *const ::std::os::raw::c_char, - context: pgrx::pg_sys::ProcessUtilityContext, - params: pgrx::pg_sys::ParamListInfo, - query_env: *mut pgrx::pg_sys::QueryEnvironment, - dest: *mut pgrx::pg_sys::DestReceiver, - completion_tag: *mut pgrx::pg_sys::QueryCompletion, -) { - unsafe { - pgvector_stmt_rewrite(pstmt); - } - unsafe { - if let Some(prev_process_utility) = PREV_PROCESS_UTILITY { - prev_process_utility( - pstmt, - query_string, - context, - params, - query_env, - dest, - completion_tag, - ); - } else { - pgrx::pg_sys::standard_ProcessUtility( - pstmt, - query_string, - context, - params, - query_env, - dest, - completion_tag, - ); - } - } -} - -#[cfg(feature = "pg12")] -#[pgrx::pg_guard] -unsafe extern "C" fn hook_pgvector_compatibility( - pstmt: *mut pgrx::pg_sys::PlannedStmt, - query_string: *const ::std::os::raw::c_char, - context: pgrx::pg_sys::ProcessUtilityContext, - params: pgrx::pg_sys::ParamListInfo, - query_env: *mut pgrx::pg_sys::QueryEnvironment, - dest: *mut pgrx::pg_sys::DestReceiver, - completion_tag: *mut ::std::os::raw::c_char, -) { - unsafe { - pgvector_stmt_rewrite(pstmt); - } - unsafe { - if let Some(prev_process_utility) = PREV_PROCESS_UTILITY { - prev_process_utility( - pstmt, - query_string, - context, - params, - query_env, - dest, - completion_tag, - ); - } else { - pgrx::pg_sys::standard_ProcessUtility( - pstmt, - query_string, - context, - params, - query_env, - dest, - completion_tag, - ); - } - } -} - #[pgrx::pg_guard] unsafe extern "C" fn xact_callback(event: pgrx::pg_sys::XactEvent, _data: pgrx::void_mut_ptr) { match event { diff --git a/src/index/utils.rs b/src/index/utils.rs index 8183d3574..132fbe9ca 100644 --- a/src/index/utils.rs +++ b/src/index/utils.rs @@ -1,8 +1,9 @@ #![allow(unsafe_op_in_unsafe_fn)] +use crate::datatype::svecf32::SVecf32; use crate::datatype::vecf16::Vecf16; use crate::datatype::vecf32::Vecf32; -use service::prelude::DynamicVector; +use service::prelude::*; #[repr(C, align(8))] struct Header { @@ -18,6 +19,10 @@ pub unsafe fn from_datum(datum: pgrx::pg_sys::Datum) -> DynamicVector { let vector = match (*q.cast::
()).kind { 0 => DynamicVector::F32((*q.cast::()).data().to_vec()), 1 => DynamicVector::F16((*q.cast::()).data().to_vec()), + 2 => { + let svec = &*q.cast::(); + DynamicVector::SparseF32(SparseF32::from(svec.data())) + } _ => unreachable!(), }; if p != q { diff --git a/src/index/views.rs b/src/index/views.rs index 384c8267f..5b2d316e7 100644 --- a/src/index/views.rs +++ b/src/index/views.rs @@ -9,14 +9,14 @@ fn _vectors_index_stat( use service::index::IndexStat; let id = Handle::from_sys(oid); let mut res = PgHeapTuple::new_composite_type("vectors.vector_index_stat").unwrap(); - let mut rpc = crate::ipc::client::borrow_mut(); + let mut rpc = check_client(crate::ipc::client()); let stat = rpc.stat(id); match stat { - IndexStat::Normal { + Ok(IndexStat { indexing, options, segments, - } => { + }) => { res.set_by_name("idx_status", "NORMAL").unwrap(); res.set_by_name("idx_indexing", indexing).unwrap(); res.set_by_name( @@ -60,7 +60,10 @@ fn _vectors_index_stat( .unwrap(); res } - IndexStat::Upgrade => { + Err(StatError::NotExist) => { + bad_service_not_exist(); + } + Err(StatError::Upgrade) => { res.set_by_name("idx_status", "UPGRADE").unwrap(); res } diff --git a/src/ipc/client/mod.rs b/src/ipc/client/mod.rs deleted file mode 100644 index c1841fda0..000000000 --- a/src/ipc/client/mod.rs +++ /dev/null @@ -1,269 +0,0 @@ -use super::packet::*; -use super::transport::ClientSocket; -use crate::gucs::internal::{Transport, TRANSPORT}; -use crate::prelude::*; -use crate::utils::cells::PgRefCell; -use service::index::IndexOptions; -use service::index::IndexStat; -use service::index::SearchOptions; -use service::prelude::*; -use std::mem::ManuallyDrop; -use std::ops::Deref; -use std::ops::DerefMut; - -pub trait ClientLike: 'static { - fn from_socket(socket: ClientSocket) -> Self; - fn to_socket(self) -> ClientSocket; -} - -pub struct ClientGuard(pub ManuallyDrop); - -impl ClientGuard { - fn map(mut self) -> ClientGuard { - unsafe { - let t = ManuallyDrop::take(&mut self.0); - std::mem::forget(self); - ClientGuard::new(U::from_socket(t.to_socket())) - } - } -} - -impl Deref for ClientGuard { - type Target = T; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -impl DerefMut for ClientGuard { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 - } -} - -pub struct Rpc { - socket: ClientSocket, -} - -impl Rpc { - pub fn new(socket: ClientSocket) -> Self { - Self { socket } - } -} - -impl ClientGuard { - pub fn flush(&mut self, handle: Handle) { - let packet = RpcPacket::Flush { handle }; - self.socket.ok(packet).friendly(); - let flush::FlushPacket::Leave {} = self.socket.recv().friendly(); - } - pub fn drop(&mut self, handle: Handle) { - let packet = RpcPacket::Drop { handle }; - self.socket.ok(packet).friendly(); - let drop::DropPacket::Leave {} = self.socket.recv().friendly(); - } - pub fn create(&mut self, handle: Handle, options: IndexOptions) { - let packet = RpcPacket::Create { handle, options }; - self.socket.ok(packet).friendly(); - let create::CreatePacket::Leave {} = self.socket.recv().friendly(); - } - pub fn basic( - mut self, - handle: Handle, - vector: DynamicVector, - opts: SearchOptions, - ) -> ClientGuard { - let packet = RpcPacket::Basic { - handle, - vector, - opts, - }; - self.socket.ok(packet).friendly(); - let basic::BasicErrorPacket {} = self.socket.recv().friendly(); - ClientGuard::map(self) - } - pub fn delete(&mut self, handle: Handle, pointer: Pointer) { - let packet = RpcPacket::Delete { handle, pointer }; - self.socket.ok(packet).friendly(); - let delete::DeletePacket::Leave {} = self.socket.recv().friendly(); - } - pub fn insert(&mut self, handle: Handle, vector: DynamicVector, pointer: Pointer) { - let packet = RpcPacket::Insert { - handle, - vector, - pointer, - }; - self.socket.ok(packet).friendly(); - let insert::InsertPacket::Leave {} = self.socket.recv().friendly(); - } - pub fn stat(&mut self, handle: Handle) -> IndexStat { - let packet = RpcPacket::Stat { handle }; - self.socket.ok(packet).friendly(); - let stat::StatPacket::Leave { result } = self.socket.recv().friendly(); - result - } - pub fn vbase( - mut self, - handle: Handle, - vector: DynamicVector, - opts: SearchOptions, - ) -> ClientGuard { - let packet = RpcPacket::Vbase { - handle, - vector, - opts, - }; - self.socket.ok(packet).friendly(); - let vbase::VbaseErrorPacket {} = self.socket.recv().friendly(); - ClientGuard::map(self) - } - pub fn list(mut self, handle: Handle) -> ClientGuard { - let packet = RpcPacket::List { handle }; - self.socket.ok(packet).friendly(); - let list::ListErrorPacket {} = self.socket.recv().friendly(); - ClientGuard::map(self) - } - pub fn upgrade(&mut self) { - let packet = RpcPacket::Upgrade {}; - self.socket.ok(packet).friendly(); - let upgrade::UpgradePacket::Leave {} = self.socket.recv().friendly(); - } -} - -impl ClientLike for Rpc { - fn from_socket(socket: ClientSocket) -> Self { - Self { socket } - } - - fn to_socket(self) -> ClientSocket { - self.socket - } -} - -pub struct Vbase { - socket: ClientSocket, -} - -impl Vbase { - pub fn next(&mut self) -> Option { - let packet = vbase::VbasePacket::Next {}; - self.socket.ok(packet).friendly(); - let vbase::VbaseNextPacket { p } = self.socket.recv().friendly(); - p - } -} - -impl ClientGuard { - pub fn leave(mut self) -> ClientGuard { - let packet = vbase::VbasePacket::Leave {}; - self.socket.ok(packet).friendly(); - let vbase::VbaseLeavePacket {} = self.socket.recv().friendly(); - ClientGuard::map(self) - } -} - -impl ClientLike for Vbase { - fn from_socket(socket: ClientSocket) -> Self { - Self { socket } - } - - fn to_socket(self) -> ClientSocket { - self.socket - } -} - -pub struct Basic { - socket: ClientSocket, -} - -impl Basic { - pub fn next(&mut self) -> Option { - let packet = basic::BasicPacket::Next {}; - self.socket.ok(packet).friendly(); - let basic::BasicNextPacket { p } = self.socket.recv().friendly(); - p - } -} - -impl ClientGuard { - pub fn leave(mut self) -> ClientGuard { - let packet = basic::BasicPacket::Leave {}; - self.socket.ok(packet).friendly(); - let basic::BasicLeavePacket {} = self.socket.recv().friendly(); - ClientGuard::map(self) - } -} - -impl ClientLike for Basic { - fn from_socket(socket: ClientSocket) -> Self { - Self { socket } - } - - fn to_socket(self) -> ClientSocket { - self.socket - } -} - -pub struct List { - socket: ClientSocket, -} - -impl List { - pub fn next(&mut self) -> Option { - let packet = list::ListPacket::Next {}; - self.socket.ok(packet).friendly(); - let list::ListNextPacket { p } = self.socket.recv().friendly(); - p - } -} - -impl ClientGuard { - pub fn leave(mut self) -> ClientGuard { - let packet = list::ListPacket::Leave {}; - self.socket.ok(packet).friendly(); - let list::ListLeavePacket {} = self.socket.recv().friendly(); - ClientGuard::map(self) - } -} - -impl ClientLike for List { - fn from_socket(socket: ClientSocket) -> Self { - Self { socket } - } - - fn to_socket(self) -> ClientSocket { - self.socket - } -} - -static CLIENTS: PgRefCell> = unsafe { PgRefCell::new(Vec::new()) }; - -pub fn borrow_mut() -> ClientGuard { - let mut x = CLIENTS.borrow_mut(); - if let Some(socket) = x.pop() { - return ClientGuard::new(Rpc::new(socket)); - } - let socket = match TRANSPORT.get() { - Transport::unix => crate::ipc::connect_unix(), - Transport::mmap => crate::ipc::connect_mmap(), - }; - ClientGuard::new(Rpc::new(socket)) -} - -impl ClientGuard { - pub fn new(t: T) -> Self { - Self(ManuallyDrop::new(t)) - } -} - -impl Drop for ClientGuard { - fn drop(&mut self) { - let socket = unsafe { ManuallyDrop::take(&mut self.0).to_socket() }; - if !std::thread::panicking() && std::any::TypeId::of::() == std::any::TypeId::of::() - { - let mut x = CLIENTS.borrow_mut(); - x.push(socket); - } - } -} diff --git a/src/ipc/mod.rs b/src/ipc/mod.rs index 104ec9c1c..5c9d46cd7 100644 --- a/src/ipc/mod.rs +++ b/src/ipc/mod.rs @@ -1,55 +1,43 @@ -pub mod client; -mod packet; -pub mod server; pub mod transport; -use self::server::RpcHandler; +use self::transport::ClientSocket; +use self::transport::ServerSocket; +use crate::gucs::internal::{Transport, TRANSPORT}; +use crate::ipc::transport::Packet; use crate::prelude::*; +use crate::utils::cells::PgRefCell; use serde::{Deserialize, Serialize}; -use service::prelude::ServiceError; -use thiserror::Error; +use service::index::IndexOptions; +use service::index::IndexStat; +use service::index::SearchOptions; +use service::prelude::*; -#[derive(Debug, Clone, Error)] +#[derive(Debug, Clone)] pub enum ConnectionError { - #[error("\ -IPC connection is closed unexpected. -ADVICE: The error is raisen by background worker errors. \ -Please check the full PostgreSQL log to get more information. Please read `https://docs.pgvecto.rs/admin/configuration.html`.\ -")] - Unexpected, - #[error(transparent)] - Service(#[from] ServiceError), - #[error(transparent)] - Grace(#[from] GraceError), + ClosedConnection, + BadSerialization, + BadDeserialization, } -impl FriendlyError for ConnectionError {} - -#[derive(Debug, Clone, Error, Serialize, Deserialize)] -#[error("Client performs a graceful shutdown.")] -pub struct GraceError; - -impl FriendlyError for GraceError {} - -pub fn listen_unix() -> impl Iterator { +pub fn listen_unix() -> impl Iterator { std::iter::from_fn(move || { let socket = self::transport::ServerSocket::Unix(self::transport::unix::accept()); - Some(self::server::RpcHandler::new(socket)) + Some(self::ServerRpcHandler::new(socket)) }) } -pub fn listen_mmap() -> impl Iterator { +pub fn listen_mmap() -> impl Iterator { std::iter::from_fn(move || { let socket = self::transport::ServerSocket::Mmap(self::transport::mmap::accept()); - Some(self::server::RpcHandler::new(socket)) + Some(self::ServerRpcHandler::new(socket)) }) } -pub fn connect_unix() -> self::transport::ClientSocket { +pub fn connect_unix() -> ClientSocket { self::transport::ClientSocket::Unix(self::transport::unix::connect()) } -pub fn connect_mmap() -> self::transport::ClientSocket { +pub fn connect_mmap() -> ClientSocket { self::transport::ClientSocket::Mmap(self::transport::mmap::connect()) } @@ -57,3 +45,288 @@ pub fn init() { self::transport::mmap::init(); self::transport::unix::init(); } + +impl Drop for ClientRpc { + fn drop(&mut self) { + let socket = self.socket.take(); + if let Some(socket) = socket { + if !std::thread::panicking() { + let mut x = CLIENTS.borrow_mut(); + x.push(socket); + } + } + } +} + +pub struct ClientRpc { + pub socket: Option, +} + +impl ClientRpc { + fn new(socket: ClientSocket) -> Self { + Self { + socket: Some(socket), + } + } + fn _ok(&mut self, packet: U) -> Result<(), ConnectionError> { + self.socket.as_mut().unwrap().ok(packet) + } + fn _recv(&mut self) -> Result { + self.socket.as_mut().unwrap().recv() + } +} + +static CLIENTS: PgRefCell> = unsafe { PgRefCell::new(Vec::new()) }; + +pub fn client() -> Option { + if !crate::bgworker::is_started() { + return None; + } + let mut x = CLIENTS.borrow_mut(); + if let Some(socket) = x.pop() { + return Some(ClientRpc::new(socket)); + } + let socket = match TRANSPORT.get() { + Transport::unix => connect_unix(), + Transport::mmap => connect_mmap(), + }; + Some(ClientRpc::new(socket)) +} + +pub struct ServerRpcHandler { + socket: ServerSocket, +} + +impl ServerRpcHandler { + pub(super) fn new(socket: ServerSocket) -> Self { + Self { socket } + } +} + +macro_rules! define_packets { + (unary $name:ident($($p_name:ident: $p_ty:ty),*) -> $r:ty;) => { + paste::paste! { + #[derive(Debug, Serialize, Deserialize)] + pub struct [] { + pub result: Result<$r, [< $name:camel Error >]>, + } + } + }; + (stream $name:ident($($p_name:ident: $p_ty:ty),*) -> $r:ty;) => { + paste::paste! { + #[derive(Debug, Serialize, Deserialize)] + pub struct [] { + pub result: Result<(), [< $name:camel Error >]>, + } + + #[derive(Debug, Serialize, Deserialize)] + pub enum [] { + Next {}, + Leave {}, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct [] { + pub p: Option<$r>, + } + + #[derive(Debug, Serialize, Deserialize)] + pub struct [] {} + } + }; +} + +macro_rules! define_client_stuffs { + (unary $name:ident($($p_name:ident:$p_ty:ty),*) -> $r:ty;) => { + paste::paste! { + impl ClientRpc { + pub fn $name(&mut self, $($p_name:$p_ty),*) -> Result<$r, [< $name:camel Error >]> { + let packet = PacketRpc::[< $name:camel >] { $($p_name),* }; + check_connection(self._ok(packet)); + let [] { result } = check_connection(self._recv()); + result + } + } + } + }; + (stream $name:ident($($p_name:ident:$p_ty:ty),*) -> $r:ty;) => { + paste::paste! { + impl ClientRpc { + pub fn $name(mut self, $($p_name:$p_ty),*) -> Result<[], (Self, [< $name:camel Error >])> { + let packet = PacketRpc::[<$name:camel>] { $($p_name),* }; + check_connection(self._ok(packet)); + let [] { result } = check_connection(self._recv()); + if let Err(e) = result { + Err((self, e)) + } else { + Ok([] { + socket: self.socket.take() + }) + } + } + } + + pub struct [] { + socket: Option, + } + + impl [] { + fn _ok(&mut self, packet: U) -> Result<(), ConnectionError> { + self.socket.as_mut().unwrap().ok(packet) + } + fn _recv(&mut self) -> Result { + self.socket.as_mut().unwrap().recv() + } + } + + impl [] { + pub fn next(&mut self) -> Option<$r> { + let packet = []::Next {}; + check_connection(self._ok(packet)); + let [] { p } = check_connection(self._recv()); + p + } + pub fn leave(mut self) -> ClientRpc { + let packet = []::Leave {}; + check_connection(self._ok(packet)); + let [] {} = check_connection(self._recv()); + ClientRpc { socket: self.socket.take() } + } + } + } + }; +} + +macro_rules! define_server_stuffs { + (unary $name:ident($($p_name:ident:$p_ty:ty),*) -> $r:ty;) => { + paste::paste! { + pub struct [] { + socket: ServerSocket, + } + + impl [] { + pub fn leave(mut self, result: Result<$r, [<$name:camel Error>]>) -> Result { + let packet = [] { result }; + self.socket.ok(packet)?; + Ok(ServerRpcHandler { + socket: self.socket, + }) + } + } + } + }; + (stream $name:ident($($p_name:ident:$p_ty:ty),*) -> $r:ty;) => { + paste::paste! { + pub struct [] { + socket: ServerSocket, + } + + impl [] { + pub fn error_ok(mut self) -> Result<[], ConnectionError> { + self.socket.ok([] { result: Ok(()) })?; + Ok([] { + socket: self.socket, + }) + } + pub fn error_err(mut self, err: [<$name:camel Error>]) -> Result { + self.socket.ok([] { result: Err(err) })?; + Ok(ServerRpcHandler { + socket: self.socket, + }) + } + } + + pub struct [] { + socket: ServerSocket, + } + + impl [] { + pub fn handle(mut self) -> Result<[], ConnectionError> { + Ok(match self.socket.recv::<[]>()? { + []::Next {} => []::Next { + x: [] { + socket: self.socket, + }, + }, + []::Leave {} => { + self.socket.ok([] {})?; + []::Leave { + x: ServerRpcHandler { + socket: self.socket, + }, + } + } + }) + } + } + + pub enum [] { + Next { x: [] }, + Leave { x: ServerRpcHandler }, + } + + pub struct [] { + socket: ServerSocket, + } + + impl [] { + pub fn leave(mut self, p: Option<$r>) -> Result<[], ConnectionError> { + let packet = [] { p }; + self.socket.ok(packet)?; + Ok([] { + socket: self.socket, + }) + } + } + } + }; +} + +macro_rules! defines { + ( + $($kind:ident $name:ident($($p_name:ident:$p_ty:ty),*) -> $r:ty;)* + ) => { + $(define_packets!($kind $name($($p_name:$p_ty),*) -> $r;);)* + $(define_client_stuffs!($kind $name($($p_name:$p_ty),*) -> $r;);)* + $(define_server_stuffs!($kind $name($($p_name:$p_ty),*) -> $r;);)* + + paste::paste! { + #[derive(Debug, Serialize, Deserialize)] + pub enum PacketRpc { + $([<$name:camel>]{$($p_name:$p_ty),*},)* + } + + impl ServerRpcHandler { + pub fn handle(mut self) -> Result { + Ok(match self.socket.recv::()? { + $(PacketRpc::[<$name:camel>] { $($p_name),* } => ServerRpcHandle::[<$name:camel>] { + $($p_name),*, + x: [] { + socket: self.socket, + }, + },)* + }) + } + } + + pub enum ServerRpcHandle { + $([<$name:camel>] { + $($p_name:$p_ty),*, + x: [< Server $name:camel >], + }),* + } + } + }; +} + +defines! { + unary create(handle: Handle, options: IndexOptions) -> (); + unary drop(handle: Handle) -> (); + unary flush(handle: Handle) -> (); + unary insert(handle: Handle, vector: DynamicVector, pointer: Pointer) -> (); + unary delete(handle: Handle, pointer: Pointer) -> (); + stream basic(handle: Handle, vector: DynamicVector, opts: SearchOptions) -> Pointer; + stream vbase(handle: Handle, vector: DynamicVector, opts: SearchOptions) -> Pointer; + stream list(handle: Handle) -> Pointer; + unary stat(handle: Handle) -> IndexStat; +} diff --git a/src/ipc/packet/basic.rs b/src/ipc/packet/basic.rs deleted file mode 100644 index 3cb478101..000000000 --- a/src/ipc/packet/basic.rs +++ /dev/null @@ -1,19 +0,0 @@ -use serde::{Deserialize, Serialize}; -use service::prelude::*; - -#[derive(Debug, Serialize, Deserialize)] -pub struct BasicErrorPacket {} - -#[derive(Debug, Serialize, Deserialize)] -pub enum BasicPacket { - Next {}, - Leave {}, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct BasicNextPacket { - pub p: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct BasicLeavePacket {} diff --git a/src/ipc/packet/create.rs b/src/ipc/packet/create.rs deleted file mode 100644 index edb1afdf9..000000000 --- a/src/ipc/packet/create.rs +++ /dev/null @@ -1,6 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub enum CreatePacket { - Leave {}, -} diff --git a/src/ipc/packet/delete.rs b/src/ipc/packet/delete.rs deleted file mode 100644 index f950a4c84..000000000 --- a/src/ipc/packet/delete.rs +++ /dev/null @@ -1,6 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub enum DeletePacket { - Leave {}, -} diff --git a/src/ipc/packet/drop.rs b/src/ipc/packet/drop.rs deleted file mode 100644 index eb6ef0d69..000000000 --- a/src/ipc/packet/drop.rs +++ /dev/null @@ -1,6 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub enum DropPacket { - Leave {}, -} diff --git a/src/ipc/packet/flush.rs b/src/ipc/packet/flush.rs deleted file mode 100644 index f39543d0c..000000000 --- a/src/ipc/packet/flush.rs +++ /dev/null @@ -1,6 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub enum FlushPacket { - Leave {}, -} diff --git a/src/ipc/packet/insert.rs b/src/ipc/packet/insert.rs deleted file mode 100644 index b056d6c27..000000000 --- a/src/ipc/packet/insert.rs +++ /dev/null @@ -1,6 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub enum InsertPacket { - Leave {}, -} diff --git a/src/ipc/packet/list.rs b/src/ipc/packet/list.rs deleted file mode 100644 index 093ed5c97..000000000 --- a/src/ipc/packet/list.rs +++ /dev/null @@ -1,19 +0,0 @@ -use serde::{Deserialize, Serialize}; -use service::prelude::*; - -#[derive(Debug, Serialize, Deserialize)] -pub struct ListErrorPacket {} - -#[derive(Debug, Serialize, Deserialize)] -pub enum ListPacket { - Next {}, - Leave {}, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ListNextPacket { - pub p: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct ListLeavePacket {} diff --git a/src/ipc/packet/mod.rs b/src/ipc/packet/mod.rs deleted file mode 100644 index d9cb6a0a8..000000000 --- a/src/ipc/packet/mod.rs +++ /dev/null @@ -1,58 +0,0 @@ -pub mod basic; -pub mod create; -pub mod delete; -pub mod drop; -pub mod flush; -pub mod insert; -pub mod list; -pub mod stat; -pub mod upgrade; -pub mod vbase; - -use serde::{Deserialize, Serialize}; -use service::index::IndexOptions; -use service::index::SearchOptions; -use service::prelude::*; - -#[derive(Debug, Serialize, Deserialize)] -pub enum RpcPacket { - // transaction - Flush { - handle: Handle, - }, - Drop { - handle: Handle, - }, - Create { - handle: Handle, - options: IndexOptions, - }, - // instance - Insert { - handle: Handle, - vector: DynamicVector, - pointer: Pointer, - }, - Delete { - handle: Handle, - pointer: Pointer, - }, - Stat { - handle: Handle, - }, - Basic { - handle: Handle, - vector: DynamicVector, - opts: SearchOptions, - }, - Vbase { - handle: Handle, - vector: DynamicVector, - opts: SearchOptions, - }, - List { - handle: Handle, - }, - // admin - Upgrade {}, -} diff --git a/src/ipc/packet/stat.rs b/src/ipc/packet/stat.rs deleted file mode 100644 index 236602192..000000000 --- a/src/ipc/packet/stat.rs +++ /dev/null @@ -1,7 +0,0 @@ -use serde::{Deserialize, Serialize}; -use service::index::IndexStat; - -#[derive(Debug, Serialize, Deserialize)] -pub enum StatPacket { - Leave { result: IndexStat }, -} diff --git a/src/ipc/packet/upgrade.rs b/src/ipc/packet/upgrade.rs deleted file mode 100644 index 0a7ab9966..000000000 --- a/src/ipc/packet/upgrade.rs +++ /dev/null @@ -1,6 +0,0 @@ -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Serialize, Deserialize)] -pub enum UpgradePacket { - Leave {}, -} diff --git a/src/ipc/packet/vbase.rs b/src/ipc/packet/vbase.rs deleted file mode 100644 index 4663a9c8c..000000000 --- a/src/ipc/packet/vbase.rs +++ /dev/null @@ -1,19 +0,0 @@ -use serde::{Deserialize, Serialize}; -use service::prelude::*; - -#[derive(Debug, Serialize, Deserialize)] -pub struct VbaseErrorPacket {} - -#[derive(Debug, Serialize, Deserialize)] -pub enum VbasePacket { - Next {}, - Leave {}, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct VbaseNextPacket { - pub p: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct VbaseLeavePacket {} diff --git a/src/ipc/server/mod.rs b/src/ipc/server/mod.rs deleted file mode 100644 index a77d3c117..000000000 --- a/src/ipc/server/mod.rs +++ /dev/null @@ -1,449 +0,0 @@ -use super::packet::*; -use super::transport::ServerSocket; -use super::ConnectionError; -use service::index::IndexOptions; -use service::index::IndexStat; -use service::index::SearchOptions; -use service::prelude::*; - -pub struct RpcHandler { - socket: ServerSocket, -} - -impl RpcHandler { - pub(super) fn new(socket: ServerSocket) -> Self { - Self { socket } - } - pub fn handle(mut self) -> Result { - Ok(match self.socket.recv::()? { - RpcPacket::Flush { handle } => RpcHandle::Flush { - handle, - x: Flush { - socket: self.socket, - }, - }, - RpcPacket::Drop { handle } => RpcHandle::Drop { - handle, - x: Drop { - socket: self.socket, - }, - }, - RpcPacket::Create { handle, options } => RpcHandle::Create { - handle, - options, - x: Create { - socket: self.socket, - }, - }, - RpcPacket::Insert { - handle, - vector, - pointer, - } => RpcHandle::Insert { - handle, - vector, - pointer, - x: Insert { - socket: self.socket, - }, - }, - RpcPacket::Delete { handle, pointer } => RpcHandle::Delete { - handle, - pointer, - x: Delete { - socket: self.socket, - }, - }, - RpcPacket::Basic { - handle, - vector, - opts, - } => RpcHandle::Basic { - handle, - vector, - opts, - x: Basic { - socket: self.socket, - }, - }, - RpcPacket::Stat { handle } => RpcHandle::Stat { - handle, - x: Stat { - socket: self.socket, - }, - }, - RpcPacket::Vbase { - handle, - vector, - opts, - } => RpcHandle::Vbase { - handle, - vector, - opts, - x: Vbase { - socket: self.socket, - }, - }, - RpcPacket::List { handle } => RpcHandle::List { - handle, - x: List { - socket: self.socket, - }, - }, - RpcPacket::Upgrade {} => RpcHandle::Upgrade { - x: Upgrade { - socket: self.socket, - }, - }, - }) - } -} - -pub enum RpcHandle { - Flush { - handle: Handle, - x: Flush, - }, - Drop { - handle: Handle, - x: Drop, - }, - Create { - handle: Handle, - options: IndexOptions, - x: Create, - }, - Basic { - handle: Handle, - vector: DynamicVector, - opts: SearchOptions, - x: Basic, - }, - Insert { - handle: Handle, - vector: DynamicVector, - pointer: Pointer, - x: Insert, - }, - Delete { - handle: Handle, - pointer: Pointer, - x: Delete, - }, - Stat { - handle: Handle, - x: Stat, - }, - Vbase { - handle: Handle, - vector: DynamicVector, - opts: SearchOptions, - x: Vbase, - }, - List { - handle: Handle, - x: List, - }, - Upgrade { - x: Upgrade, - }, -} - -pub struct Flush { - socket: ServerSocket, -} - -impl Flush { - pub fn leave(mut self) -> Result { - let packet = flush::FlushPacket::Leave {}; - self.socket.ok(packet)?; - Ok(RpcHandler { - socket: self.socket, - }) - } - #[allow(dead_code)] - pub fn reset(mut self, err: ServiceError) -> Result { - self.socket.err(err) - } -} - -pub struct Drop { - socket: ServerSocket, -} - -impl Drop { - pub fn leave(mut self) -> Result { - let packet = drop::DropPacket::Leave {}; - self.socket.ok(packet)?; - Ok(RpcHandler { - socket: self.socket, - }) - } - #[allow(dead_code)] - pub fn reset(mut self, err: ServiceError) -> Result { - self.socket.err(err) - } -} - -pub struct Create { - socket: ServerSocket, -} - -impl Create { - pub fn leave(mut self) -> Result { - let packet = create::CreatePacket::Leave {}; - self.socket.ok(packet)?; - Ok(RpcHandler { - socket: self.socket, - }) - } - pub fn reset(mut self, err: ServiceError) -> Result { - self.socket.err(err) - } -} - -pub struct Insert { - socket: ServerSocket, -} - -impl Insert { - pub fn leave(mut self) -> Result { - let packet = insert::InsertPacket::Leave {}; - self.socket.ok(packet)?; - Ok(RpcHandler { - socket: self.socket, - }) - } - pub fn reset(mut self, err: ServiceError) -> Result { - self.socket.err(err) - } -} - -pub struct Delete { - socket: ServerSocket, -} - -impl Delete { - pub fn leave(mut self) -> Result { - let packet = delete::DeletePacket::Leave {}; - self.socket.ok(packet)?; - Ok(RpcHandler { - socket: self.socket, - }) - } - pub fn reset(mut self, err: ServiceError) -> Result { - self.socket.err(err) - } -} - -pub struct Basic { - socket: ServerSocket, -} - -impl Basic { - pub fn error(mut self) -> Result { - self.socket.ok(basic::BasicErrorPacket {})?; - Ok(BasicHandler { - socket: self.socket, - }) - } - pub fn reset(mut self, err: ServiceError) -> Result { - self.socket.err(err) - } -} - -pub struct BasicHandler { - socket: ServerSocket, -} - -impl BasicHandler { - pub fn handle(mut self) -> Result { - Ok(match self.socket.recv::()? { - basic::BasicPacket::Next {} => BasicHandle::Next { - x: BasicNext { - socket: self.socket, - }, - }, - basic::BasicPacket::Leave {} => { - self.socket.ok(basic::BasicLeavePacket {})?; - BasicHandle::Leave { - x: RpcHandler { - socket: self.socket, - }, - } - } - }) - } -} - -pub enum BasicHandle { - Next { x: BasicNext }, - Leave { x: RpcHandler }, -} - -pub struct BasicNext { - socket: ServerSocket, -} - -impl BasicNext { - pub fn leave(mut self, p: Option) -> Result { - let packet = basic::BasicNextPacket { p }; - self.socket.ok(packet)?; - Ok(BasicHandler { - socket: self.socket, - }) - } -} - -pub struct Stat { - socket: ServerSocket, -} - -impl Stat { - pub fn leave(mut self, result: IndexStat) -> Result { - let packet = stat::StatPacket::Leave { result }; - self.socket.ok(packet)?; - Ok(RpcHandler { - socket: self.socket, - }) - } - pub fn reset(mut self, err: ServiceError) -> Result { - self.socket.err(err) - } -} - -pub struct Vbase { - socket: ServerSocket, -} - -impl Vbase { - pub fn error(mut self) -> Result { - self.socket.ok(vbase::VbaseErrorPacket {})?; - Ok(VbaseHandler { - socket: self.socket, - }) - } - pub fn reset(mut self, err: ServiceError) -> Result { - self.socket.err(err) - } -} - -pub struct VbaseHandler { - socket: ServerSocket, -} - -impl VbaseHandler { - pub fn handle(mut self) -> Result { - Ok(match self.socket.recv::()? { - vbase::VbasePacket::Next {} => VbaseHandle::Next { - x: VbaseNext { - socket: self.socket, - }, - }, - vbase::VbasePacket::Leave {} => { - self.socket.ok(vbase::VbaseLeavePacket {})?; - VbaseHandle::Leave { - x: RpcHandler { - socket: self.socket, - }, - } - } - }) - } -} - -pub enum VbaseHandle { - Next { x: VbaseNext }, - Leave { x: RpcHandler }, -} - -pub struct VbaseNext { - socket: ServerSocket, -} - -impl VbaseNext { - pub fn leave(mut self, p: Option) -> Result { - let packet = vbase::VbaseNextPacket { p }; - self.socket.ok(packet)?; - Ok(VbaseHandler { - socket: self.socket, - }) - } -} - -pub struct List { - socket: ServerSocket, -} - -impl List { - pub fn error(mut self) -> Result { - self.socket.ok(list::ListErrorPacket {})?; - Ok(ListHandler { - socket: self.socket, - }) - } - pub fn reset(mut self, err: ServiceError) -> Result { - self.socket.err(err) - } -} - -pub struct ListHandler { - socket: ServerSocket, -} - -impl ListHandler { - pub fn handle(mut self) -> Result { - Ok(match self.socket.recv::()? { - list::ListPacket::Next {} => ListHandle::Next { - x: ListNext { - socket: self.socket, - }, - }, - list::ListPacket::Leave {} => { - self.socket.ok(list::ListLeavePacket {})?; - ListHandle::Leave { - x: RpcHandler { - socket: self.socket, - }, - } - } - }) - } -} - -pub enum ListHandle { - Next { x: ListNext }, - Leave { x: RpcHandler }, -} - -pub struct ListNext { - socket: ServerSocket, -} - -impl ListNext { - pub fn leave(mut self, p: Option) -> Result { - let packet = list::ListNextPacket { p }; - self.socket.ok(packet)?; - Ok(ListHandler { - socket: self.socket, - }) - } -} - -pub struct Upgrade { - socket: ServerSocket, -} - -impl Upgrade { - pub fn leave(mut self) -> Result { - let packet = upgrade::UpgradePacket::Leave {}; - self.socket.ok(packet)?; - Ok(RpcHandler { - socket: self.socket, - }) - } - #[allow(dead_code)] - pub fn reset(mut self, err: ServiceError) -> Result { - self.socket.err(err) - } -} diff --git a/src/ipc/transport/mmap.rs b/src/ipc/transport/mmap.rs index ea234620f..b7b4b107b 100644 --- a/src/ipc/transport/mmap.rs +++ b/src/ipc/transport/mmap.rs @@ -1,43 +1,58 @@ use super::ConnectionError; -use crate::utils::file_socket::FileSocket; -use crate::utils::os::{futex_wait, futex_wake, memfd_create, mmap_populate}; use rustix::fd::{AsFd, OwnedFd}; use rustix::fs::FlockOperation; +use send_fd::SendFd; use std::cell::UnsafeCell; use std::io::ErrorKind; use std::sync::atomic::{AtomicU32, Ordering}; use std::sync::OnceLock; +use std::time::Duration; const BUFFER_SIZE: usize = 512 * 1024; const SPIN_LIMIT: usize = 8; +const TIMEOUT: Duration = Duration::from_secs(15); -static CHANNEL: OnceLock = OnceLock::new(); +static CHANNEL: OnceLock = OnceLock::new(); pub fn init() { - CHANNEL.set(FileSocket::new().unwrap()).ok().unwrap(); + CHANNEL.set(SendFd::new().unwrap()).ok().unwrap(); } pub fn accept() -> Socket { let memfd = CHANNEL.get().unwrap().recv().unwrap(); rustix::fs::fcntl_lock(&memfd, FlockOperation::NonBlockingLockShared).unwrap(); - let addr = unsafe { mmap_populate(BUFFER_SIZE, &memfd).unwrap() }; + let memmap = unsafe { + memmap2::MmapOptions::new() + .len(BUFFER_SIZE) + .populate() + .map_mut(&memfd) + .unwrap() + }; Socket { is_server: true, - addr: addr as _, + addr: memmap.as_ptr().cast(), memfd, + _memmap: memmap, } } pub fn connect() -> Socket { - let memfd = memfd_create().unwrap(); + let memfd = memfd::memfd_create().unwrap(); rustix::fs::ftruncate(&memfd, BUFFER_SIZE as u64).unwrap(); rustix::fs::fcntl_lock(&memfd, FlockOperation::NonBlockingLockShared).unwrap(); CHANNEL.get().unwrap().send(memfd.as_fd()).unwrap(); - let addr = unsafe { mmap_populate(BUFFER_SIZE, &memfd).unwrap() }; + let memmap = unsafe { + memmap2::MmapOptions::new() + .len(BUFFER_SIZE) + .populate() + .map_mut(&memfd) + .unwrap() + }; Socket { is_server: false, - addr: addr as _, + addr: memmap.as_ptr().cast(), memfd, + _memmap: memmap, } } @@ -45,6 +60,7 @@ pub struct Socket { is_server: bool, addr: *const Channel, memfd: OwnedFd, + _memmap: memmap2::MmapMut, } unsafe impl Send for Socket {} @@ -123,17 +139,13 @@ impl Channel { { break; } - unsafe { - futex_wait(&self.futex, Y); - } + interprocess_atomic_wait::wait(&self.futex, Y, TIMEOUT); } Y => { if !test() { - return Err(ConnectionError::Unexpected); - } - unsafe { - futex_wait(&self.futex, Y); + return Err(ConnectionError::ClosedConnection); } + interprocess_atomic_wait::wait(&self.futex, Y, TIMEOUT); } _ => unsafe { std::hint::unreachable_unchecked() }, } @@ -154,9 +166,7 @@ impl Channel { (*self.bytes.get())[0..data.len()].copy_from_slice(data); } if X == self.futex.swap(T, Ordering::Release) { - unsafe { - futex_wake(&self.futex); - } + interprocess_atomic_wait::wake(&self.futex); } } unsafe fn server_recv(&self, test: impl Fn() -> bool) -> Result, ConnectionError> { @@ -182,17 +192,13 @@ impl Channel { { break; } - unsafe { - futex_wait(&self.futex, Y); - } + interprocess_atomic_wait::wait(&self.futex, Y, TIMEOUT); } Y => { if !test() { - return Err(ConnectionError::Unexpected); - } - unsafe { - futex_wait(&self.futex, Y); + return Err(ConnectionError::ClosedConnection); } + interprocess_atomic_wait::wait(&self.futex, Y, TIMEOUT); } _ => unsafe { std::hint::unreachable_unchecked() }, } @@ -213,9 +219,7 @@ impl Channel { (*self.bytes.get())[0..data.len()].copy_from_slice(data); } if X == self.futex.swap(T, Ordering::Release) { - unsafe { - futex_wake(&self.futex); - } + interprocess_atomic_wait::wake(&self.futex); } } } diff --git a/src/ipc/transport/mod.rs b/src/ipc/transport/mod.rs index d7004a1e9..3c6178f02 100644 --- a/src/ipc/transport/mod.rs +++ b/src/ipc/transport/mod.rs @@ -1,23 +1,21 @@ pub mod mmap; pub mod unix; -use super::{ConnectionError, GraceError}; +use super::ConnectionError; use serde::{Deserialize, Serialize}; -use service::prelude::ServiceError; -use std::fmt::Debug; -pub trait Bincode: Debug { - fn serialize(&self) -> Vec; - fn deserialize(_: &[u8]) -> Self; +pub trait Packet: Sized { + fn serialize(&self) -> Option>; + fn deserialize(_: &[u8]) -> Option; } -impl Deserialize<'a>> Bincode for T { - fn serialize(&self) -> Vec { - bincode::serialize(self).unwrap() +impl Deserialize<'a>> Packet for T { + fn serialize(&self) -> Option> { + bincode::serialize(self).ok() } - fn deserialize(bytes: &[u8]) -> Self { - bincode::deserialize(bytes).unwrap() + fn deserialize(bytes: &[u8]) -> Option { + bincode::deserialize(bytes).ok() } } @@ -32,66 +30,39 @@ pub enum ClientSocket { } impl ServerSocket { - pub fn ok(&mut self, packet: T) -> Result<(), ConnectionError> { - let mut buffer = vec![0u8]; - buffer.extend(packet.serialize()); + pub fn ok(&mut self, packet: T) -> Result<(), ConnectionError> { + let buffer = packet + .serialize() + .ok_or(ConnectionError::BadSerialization)?; match self { Self::Unix(x) => x.send(&buffer), Self::Mmap(x) => x.send(&buffer), } } - pub fn err(&mut self, packet: ServiceError) -> Result { - let mut buffer = vec![1u8]; - buffer.extend(Bincode::serialize(&packet)); - match self { - Self::Unix(x) => x.send(&buffer)?, - Self::Mmap(x) => x.send(&buffer)?, - } - Err(ConnectionError::Service(packet)) - } - pub fn recv(&mut self) -> Result { + pub fn recv(&mut self) -> Result { let buffer = match self { Self::Unix(x) => x.recv()?, Self::Mmap(x) => x.recv()?, }; - let c = &buffer[1..]; - match buffer[0] { - 0u8 => Ok(T::deserialize(c)), - 1u8 => Err(ConnectionError::Grace(bincode::deserialize(c).unwrap())), - _ => unreachable!(), - } + T::deserialize(&buffer).ok_or(ConnectionError::BadDeserialization) } } impl ClientSocket { - pub fn ok(&mut self, packet: T) -> Result<(), ConnectionError> { - let mut buffer = vec![0u8]; - buffer.extend(packet.serialize()); + pub fn ok(&mut self, packet: T) -> Result<(), ConnectionError> { + let buffer = packet + .serialize() + .ok_or(ConnectionError::BadSerialization)?; match self { Self::Unix(x) => x.send(&buffer), Self::Mmap(x) => x.send(&buffer), } } - #[allow(unused)] - pub fn err(&mut self, packet: GraceError) -> Result { - let mut buffer = vec![1u8]; - buffer.extend(Bincode::serialize(&packet)); - match self { - Self::Unix(x) => x.send(&buffer)?, - Self::Mmap(x) => x.send(&buffer)?, - } - Err(ConnectionError::Grace(packet)) - } - pub fn recv(&mut self) -> Result { + pub fn recv(&mut self) -> Result { let buffer = match self { Self::Unix(x) => x.recv()?, Self::Mmap(x) => x.recv()?, }; - let c = &buffer[1..]; - match buffer[0] { - 0u8 => Ok(T::deserialize(c)), - 1u8 => Err(ConnectionError::Service(bincode::deserialize(c).unwrap())), - _ => unreachable!(), - } + T::deserialize(&buffer).ok_or(ConnectionError::BadDeserialization) } } diff --git a/src/ipc/transport/unix.rs b/src/ipc/transport/unix.rs index a15bc2c54..7fd9adbac 100644 --- a/src/ipc/transport/unix.rs +++ b/src/ipc/transport/unix.rs @@ -1,15 +1,15 @@ use super::ConnectionError; -use crate::utils::file_socket::FileSocket; use byteorder::{ReadBytesExt, WriteBytesExt}; use rustix::fd::AsFd; +use send_fd::SendFd; use std::io::{Read, Write}; use std::os::unix::net::UnixStream; use std::sync::OnceLock; -static CHANNEL: OnceLock = OnceLock::new(); +static CHANNEL: OnceLock = OnceLock::new(); pub fn init() { - CHANNEL.set(FileSocket::new().unwrap()).ok().unwrap(); + CHANNEL.set(SendFd::new().unwrap()).ok().unwrap(); } pub fn accept() -> Socket { @@ -32,7 +32,7 @@ macro_rules! resolve_closed { ($t: expr) => { match $t { Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => { - return Err(ConnectionError::Unexpected) + return Err(ConnectionError::ClosedConnection) } Err(e) => panic!("{}", e), Ok(e) => e, diff --git a/src/lib.rs b/src/lib.rs index 56452c95e..7ca3d35f2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,7 +21,7 @@ pgrx::extension_sql_file!("./sql/finalize.sql", finalize); unsafe extern "C" fn _PG_init() { use crate::prelude::*; if unsafe { pgrx::pg_sys::IsUnderPostmaster } { - SessionError::BadInit.friendly(); + bad_init(); } unsafe { detect::initialize(); diff --git a/src/prelude/error.rs b/src/prelude/error.rs index e06fa0ec9..e28e05bc5 100644 --- a/src/prelude/error.rs +++ b/src/prelude/error.rs @@ -1,75 +1,143 @@ -use service::prelude::ServiceError; -use std::fmt::Display; -use thiserror::Error; +use crate::ipc::{ClientRpc, ConnectionError}; +use pgrx::error; +use std::num::NonZeroU16; -pub trait FriendlyError: Display { - fn friendly(&self) -> ! { - panic!("pgvecto.rs: {}", self); +pub fn bad_init() -> ! { + error!("\ +pgvecto.rs: pgvecto.rs must be loaded via shared_preload_libraries. +ADVICE: If you encounter this error for your first use of pgvecto.rs, \ +please read `https://docs.pgvecto.rs/getting-started/installation.html`. \ +You should edit `shared_preload_libraries` in `postgresql.conf` to include `vectors.so`, \ +or simply run the command `psql -U postgres -c 'ALTER SYSTEM SET shared_preload_libraries = \"vectors.so\"'`."); +} + +pub fn check_type_dimensions(dimensions: Option) -> NonZeroU16 { + match dimensions { + None => { + error!( + "\ +pgvecto.rs: Modifier of the type is invalid. +ADVICE: Check if modifier of the type is an integer among 1 and 65535." + ) + } + Some(x) => x, + } +} + +pub fn check_value_dimensions(dimensions: usize) -> NonZeroU16 { + match u16::try_from(dimensions) + .and_then(NonZeroU16::try_from) + .ok() + { + None => { + error!( + "\ +pgvecto.rs: Dimensions of the vector is invalid. +ADVICE: Check if dimensions of the vector are among 1 and 65535." + ) + } + Some(x) => x, } } -impl FriendlyError for ServiceError {} +pub fn bad_literal(hint: &str) -> ! { + error!( + "\ +pgvecto.rs: Bad literal. +INFORMATION: hint = {hint}" + ); +} -pub trait Friendly { - fn friendly(self) -> T; +#[inline(always)] +pub fn check_matched_dimensions(left_dimensions: usize, right_dimensions: usize) -> usize { + if left_dimensions != right_dimensions { + error!( + "\ +pgvecto.rs: Operands of the operator differs in dimensions or scalar type. +INFORMATION: left_dimensions = {left_dimensions}, right_dimensions = {right_dimensions}", + ) + } + left_dimensions } -impl Friendly for Result { - fn friendly(self) -> T { - match self { - Ok(x) => x, - Err(e) => e.friendly(), - } +#[inline(always)] +pub fn check_column_dimensions(dimensions: Option) -> NonZeroU16 { + match dimensions { + None => error!( + "\ +pgvecto.rs: Dimensions type modifier of a vector column is needed for building the index.", + ), + Some(x) => x, } } -#[must_use] -#[derive(Debug, Error)] -#[rustfmt::skip] -pub enum SessionError { - #[error("\ -pgvecto.rs must be loaded via shared_preload_libraries. -ADVICE: If you encounter this error for your first use of pgvecto.rs, \ -please read `https://docs.pgvecto.rs/getting-started/installation.html`. \ -You should edit `shared_preload_libraries` in `postgresql.conf` to include `vectors.so`, \ -or simply run the command `psql -U postgres -c 'ALTER SYSTEM SET shared_preload_libraries = \"vectors.so\"'`.\ -")] - BadInit, - #[error("\ -Bad literal. -INFORMATION: hint = {hint}\ -")] - BadLiteral { - hint: String, - }, - #[error("\ -Dimensions type modifier of a vector column is needed for building the index.\ -")] - BadOption1, - #[error("\ -Indexes can only be built on built-in distance functions. +pub fn bad_opclass() -> ! { + error!( + "\ +pgvecto.rs: Indexes can only be built on built-in distance functions. ADVICE: If you want pgvecto.rs to support more distance functions, \ -visit `https://github.com/tensorchord/pgvecto.rs/issues` and contribute your ideas.\ -")] - BadOptions2, - #[error("\ -Modifier of the type is invalid. -ADVICE: Check if modifier of the type is an integer among 1 and 65535.\ -")] - BadTypeDimensions, - #[error("\ -Dimensions of the vector is invalid. -ADVICE: Check if dimensions of the vector are among 1 and 65535.\ -")] - BadValueDimensions, - #[error("\ -Operands of the operator differs in dimensions or scalar type. -INFORMATION: left_dimensions = {left_dimensions}, right_dimensions = {right_dimensions}\ -")] - Unmatched { - left_dimensions: u16, - right_dimensions: u16, - }, +visit `https://github.com/tensorchord/pgvecto.rs/issues` and contribute your ideas." + ); } -impl FriendlyError for SessionError {} +pub fn bad_service_not_exist() -> ! { + error!( + "\ +pgvecto.rs: The index is not existing in the background worker. +ADVICE: Drop or rebuild the index.\ + " + ); +} + +pub fn check_connection(result: Result) -> T { + match result { + Err(_) => error!( + "\ +pgvecto.rs: Indexes can only be built on built-in distance functions. +ADVICE: If you want pgvecto.rs to support more distance functions, \ +visit `https://github.com/tensorchord/pgvecto.rs/issues` and contribute your ideas." + ), + Ok(x) => x, + } +} + +pub fn check_client(option: Option) -> ClientRpc { + match option { + None => error!( + "\ +pgvecto.rs: The extension is upgraded so all index files are outdated. +ADVICE: Delete all index files. Please read `https://docs.pgvecto.rs/admin/upgrading.html`" + ), + Some(x) => x, + } +} + +pub fn bad_service_upgrade() -> ! { + error!( + "\ +pgvecto.rs: The extension is upgraded so this index is outdated. +ADVICE: Rebuild the index. Please read `https://docs.pgvecto.rs/admin/upgrading.html`." + ) +} + +pub fn bad_service_exists() -> ! { + error!( + "\ +pgvecto.rs: The index is already existing in the background worker." + ) +} + +pub fn bad_service_invalid_index_options(reason: &str) -> ! { + error!( + "\ +pgvecto.rs: The given index option is invalid. +INFORMATION: reason = {reason:?}" + ) +} + +pub fn bad_service_invalid_vector() -> ! { + error!( + "\ +pgvecto.rs: The dimension of a vector does not matched that in a vector index column." + ) +} diff --git a/src/prelude/mod.rs b/src/prelude/mod.rs index 617b2d329..af364e499 100644 --- a/src/prelude/mod.rs +++ b/src/prelude/mod.rs @@ -1,5 +1,5 @@ mod error; mod sys; -pub use error::{Friendly, FriendlyError, SessionError}; +pub use error::*; pub use sys::{FromSys, IntoSys}; diff --git a/src/sql/bootstrap.sql b/src/sql/bootstrap.sql index 488b2fd76..d858eddc9 100644 --- a/src/sql/bootstrap.sql +++ b/src/sql/bootstrap.sql @@ -4,6 +4,7 @@ CREATE TYPE vector; CREATE TYPE vecf16; +CREATE TYPE svector; CREATE TYPE vector_index_stat; -- bootstrap end diff --git a/src/sql/finalize.sql b/src/sql/finalize.sql index ddbd39129..ac79f9b24 100644 --- a/src/sql/finalize.sql +++ b/src/sql/finalize.sql @@ -7,6 +7,7 @@ CREATE TYPE vector ( OUTPUT = _vectors_vecf32_out, RECEIVE = _vectors_vecf32_recv, SEND = _vectors_vecf32_send, + SUBSCRIPT = _vectors_vecf32_subscript, TYPMOD_IN = _vectors_typmod_in, TYPMOD_OUT = _vectors_typmod_out, STORAGE = EXTERNAL, @@ -14,18 +15,12 @@ CREATE TYPE vector ( ALIGNMENT = double ); -DO $$ -BEGIN - IF current_setting('server_version_num')::int >= 140000 THEN - ALTER TYPE vector SET (SUBSCRIPT = _vectors_vecf32_subscript); - END IF; -END $$; - CREATE TYPE vecf16 ( INPUT = _vectors_vecf16_in, OUTPUT = _vectors_vecf16_out, RECEIVE = _vectors_vecf16_recv, SEND = _vectors_vecf16_send, + SUBSCRIPT = _vectors_vecf16_subscript, TYPMOD_IN = _vectors_typmod_in, TYPMOD_OUT = _vectors_typmod_out, STORAGE = EXTERNAL, @@ -33,12 +28,18 @@ CREATE TYPE vecf16 ( ALIGNMENT = double ); -DO $$ -BEGIN - IF current_setting('server_version_num')::int >= 140000 THEN - ALTER TYPE vecf16 SET (SUBSCRIPT = _vectors_vecf16_subscript); - END IF; -END $$; +CREATE TYPE svector ( + INPUT = _vectors_svecf32_in, + OUTPUT = _vectors_svecf32_out, + RECEIVE = _vectors_svecf32_recv, + SEND = _vectors_svecf32_send, + SUBSCRIPT = _vectors_svecf32_subscript, + TYPMOD_IN = _vectors_typmod_in, + TYPMOD_OUT = _vectors_typmod_out, + STORAGE = EXTERNAL, + INTERNALLENGTH = VARIABLE, + ALIGNMENT = double +); CREATE TYPE vector_index_stat AS ( idx_status TEXT, @@ -67,6 +68,13 @@ CREATE OPERATOR + ( COMMUTATOR = + ); +CREATE OPERATOR + ( + PROCEDURE = _vectors_svecf32_operator_add, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = + +); + CREATE OPERATOR - ( PROCEDURE = _vectors_vecf32_operator_minus, LEFTARG = vector, @@ -79,6 +87,12 @@ CREATE OPERATOR - ( RIGHTARG = vecf16 ); +CREATE OPERATOR - ( + PROCEDURE = _vectors_svecf32_operator_minus, + LEFTARG = svector, + RIGHTARG = svector +); + CREATE OPERATOR = ( PROCEDURE = _vectors_vecf32_operator_eq, LEFTARG = vector, @@ -99,6 +113,16 @@ CREATE OPERATOR = ( JOIN = eqjoinsel ); +CREATE OPERATOR = ( + PROCEDURE = _vectors_svecf32_operator_eq, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = =, + NEGATOR = <>, + RESTRICT = eqsel, + JOIN = eqjoinsel +); + CREATE OPERATOR <> ( PROCEDURE = _vectors_vecf32_operator_neq, LEFTARG = vector, @@ -119,6 +143,16 @@ CREATE OPERATOR <> ( JOIN = eqjoinsel ); +CREATE OPERATOR <> ( + PROCEDURE = _vectors_svecf32_operator_neq, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = <>, + NEGATOR = =, + RESTRICT = eqsel, + JOIN = eqjoinsel +); + CREATE OPERATOR < ( PROCEDURE = _vectors_vecf32_operator_lt, LEFTARG = vector, @@ -139,6 +173,16 @@ CREATE OPERATOR < ( JOIN = scalarltjoinsel ); +CREATE OPERATOR < ( + PROCEDURE = _vectors_svecf32_operator_lt, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = >, + NEGATOR = >=, + RESTRICT = scalarltsel, + JOIN = scalarltjoinsel +); + CREATE OPERATOR > ( PROCEDURE = _vectors_vecf32_operator_gt, LEFTARG = vector, @@ -159,6 +203,17 @@ CREATE OPERATOR > ( JOIN = scalargtjoinsel ); +CREATE OPERATOR > ( + PROCEDURE = _vectors_svecf32_operator_gt, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = <, + NEGATOR = <=, + RESTRICT = scalargtsel, + JOIN = scalargtjoinsel +); + + CREATE OPERATOR <= ( PROCEDURE = _vectors_vecf32_operator_lte, LEFTARG = vector, @@ -179,6 +234,16 @@ CREATE OPERATOR <= ( JOIN = scalarltjoinsel ); +CREATE OPERATOR <= ( + PROCEDURE = _vectors_svecf32_operator_lte, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = >=, + NEGATOR = >, + RESTRICT = scalarltsel, + JOIN = scalarltjoinsel +); + CREATE OPERATOR >= ( PROCEDURE = _vectors_vecf32_operator_gte, LEFTARG = vector, @@ -199,6 +264,16 @@ CREATE OPERATOR >= ( JOIN = scalargtjoinsel ); +CREATE OPERATOR >= ( + PROCEDURE = _vectors_svecf32_operator_gte, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = <=, + NEGATOR = <, + RESTRICT = scalargtsel, + JOIN = scalargtjoinsel +); + CREATE OPERATOR <-> ( PROCEDURE = _vectors_vecf32_operator_l2, LEFTARG = vector, @@ -213,6 +288,14 @@ CREATE OPERATOR <-> ( COMMUTATOR = <-> ); +CREATE OPERATOR <-> ( + PROCEDURE = _vectors_svecf32_operator_l2, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = <-> +); + + CREATE OPERATOR <#> ( PROCEDURE = _vectors_vecf32_operator_dot, LEFTARG = vector, @@ -227,6 +310,13 @@ CREATE OPERATOR <#> ( COMMUTATOR = <#> ); +CREATE OPERATOR <#> ( + PROCEDURE = _vectors_svecf32_operator_dot, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = <#> +); + CREATE OPERATOR <=> ( PROCEDURE = _vectors_vecf32_operator_cosine, LEFTARG = vector, @@ -241,10 +331,20 @@ CREATE OPERATOR <=> ( COMMUTATOR = <=> ); +CREATE OPERATOR <=> ( + PROCEDURE = _vectors_svecf32_operator_cosine, + LEFTARG = svector, + RIGHTARG = svector, + COMMUTATOR = <=> +); + -- List of functions CREATE FUNCTION pgvectors_upgrade() RETURNS void -IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_pgvectors_upgrade_wrapper'; +STRICT LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_pgvectors_upgrade_wrapper'; + +CREATE FUNCTION to_svector(dims INT, indices INT[], vals real[]) RETURNS svector +IMMUTABLE STRICT PARALLEL SAFE LANGUAGE c AS 'MODULE_PATHNAME', '_vectors_to_svector_wrapper'; -- List of casts @@ -257,6 +357,15 @@ CREATE CAST (vector AS real[]) CREATE CAST (vector AS vecf16) WITH FUNCTION _vectors_cast_vecf32_to_vecf16(vector, integer, boolean); +CREATE CAST (vecf16 AS vector) + WITH FUNCTION _vectors_cast_vecf16_to_vecf32(vecf16, integer, boolean); + +CREATE CAST (vector AS svector) + WITH FUNCTION _vectors_cast_vecf32_to_svecf32(vector, integer, boolean); + +CREATE CAST (svector AS vector) + WITH FUNCTION _vectors_cast_svecf32_to_vecf32(svector, integer, boolean); + -- List of access methods CREATE ACCESS METHOD vectors TYPE INDEX HANDLER _vectors_amhandler; @@ -302,6 +411,18 @@ CREATE OPERATOR CLASS vecf16_cos_ops FOR TYPE vecf16 USING vectors AS OPERATOR 1 <=> (vecf16, vecf16) FOR ORDER BY float_ops; +CREATE OPERATOR CLASS svector_l2_ops + FOR TYPE svector USING vectors AS + OPERATOR 1 <-> (svector, svector) FOR ORDER BY float_ops; + +CREATE OPERATOR CLASS svector_dot_ops + FOR TYPE svector USING vectors AS + OPERATOR 1 <#> (svector, svector) FOR ORDER BY float_ops; + +CREATE OPERATOR CLASS svector_cos_ops + FOR TYPE svector USING vectors AS + OPERATOR 1 <=> (svector, svector) FOR ORDER BY float_ops; + -- List of views CREATE VIEW pg_vector_index_stat AS diff --git a/src/utils/mod.rs b/src/utils/mod.rs index bcad5f77c..967246145 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,4 +1,2 @@ pub mod cells; -pub mod file_socket; -pub mod os; pub mod parse; diff --git a/src/utils/os.rs b/src/utils/os.rs deleted file mode 100644 index 431efc0f6..000000000 --- a/src/utils/os.rs +++ /dev/null @@ -1,190 +0,0 @@ -use rustix::fd::{AsFd, OwnedFd}; -use rustix::mm::{MapFlags, ProtFlags}; -use std::sync::atomic::AtomicU32; - -#[cfg(target_os = "linux")] -pub unsafe fn futex_wait(futex: &AtomicU32, value: u32) { - const FUTEX_TIMEOUT: libc::timespec = libc::timespec { - tv_sec: 15, - tv_nsec: 0, - }; - unsafe { - libc::syscall( - libc::SYS_futex, - futex.as_ptr(), - libc::FUTEX_WAIT, - value, - &FUTEX_TIMEOUT, - ); - } -} - -#[cfg(target_os = "linux")] -pub unsafe fn futex_wake(futex: &AtomicU32) { - unsafe { - libc::syscall(libc::SYS_futex, futex.as_ptr(), libc::FUTEX_WAKE, i32::MAX); - } -} - -#[cfg(target_os = "linux")] -pub fn memfd_create() -> std::io::Result { - if detect::linux::detect_memfd() { - use rustix::fs::MemfdFlags; - Ok(rustix::fs::memfd_create( - format!(".memfd.VECTORS.{:x}", std::process::id()), - MemfdFlags::empty(), - )?) - } else { - use rustix::fs::Mode; - use rustix::fs::OFlags; - // POSIX fcntl locking do not support shmem, so we use a regular file here. - // reference: https://man7.org/linux/man-pages/man3/fcntl.3p.html - let name = format!( - ".shm.VECTORS.{:x}.{:x}", - std::process::id(), - rand::random::() - ); - let fd = rustix::fs::open( - &name, - OFlags::RDWR | OFlags::CREATE | OFlags::EXCL, - Mode::RUSR | Mode::WUSR, - )?; - rustix::fs::unlink(&name)?; - Ok(fd) - } -} - -#[cfg(target_os = "linux")] -pub unsafe fn mmap_populate(len: usize, fd: impl AsFd) -> std::io::Result<*mut libc::c_void> { - use std::ptr::null_mut; - unsafe { - Ok(rustix::mm::mmap( - null_mut(), - len, - ProtFlags::READ | ProtFlags::WRITE, - MapFlags::SHARED | MapFlags::POPULATE, - fd, - 0, - )?) - } -} - -#[cfg(target_os = "macos")] -pub unsafe fn futex_wait(futex: &AtomicU32, value: u32) { - const ULOCK_TIMEOUT: u32 = 15_000_000; - unsafe { - ulock_sys::__ulock_wait( - ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED, - futex.as_ptr().cast(), - value as _, - ULOCK_TIMEOUT, - ); - } -} - -#[cfg(target_os = "macos")] -pub unsafe fn futex_wake(futex: &AtomicU32) { - unsafe { - ulock_sys::__ulock_wake( - ulock_sys::darwin19::UL_COMPARE_AND_WAIT_SHARED, - futex.as_ptr().cast(), - 0, - ); - } -} - -#[cfg(target_os = "macos")] -pub fn memfd_create() -> std::io::Result { - use rustix::fs::Mode; - use rustix::fs::OFlags; - // POSIX fcntl locking do not support shmem, so we use a regular file here. - // reference: https://man7.org/linux/man-pages/man3/fcntl.3p.html - let name = format!( - ".shm.VECTORS.{:x}.{:x}", - std::process::id(), - rand::random::() - ); - let fd = rustix::fs::open( - &name, - OFlags::RDWR | OFlags::CREATE | OFlags::EXCL, - Mode::RUSR | Mode::WUSR, - )?; - rustix::fs::unlink(&name)?; - Ok(fd) -} - -#[cfg(target_os = "macos")] -pub unsafe fn mmap_populate(len: usize, fd: impl AsFd) -> std::io::Result<*mut libc::c_void> { - use std::ptr::null_mut; - unsafe { - Ok(rustix::mm::mmap( - null_mut(), - len, - ProtFlags::READ | ProtFlags::WRITE, - MapFlags::SHARED, - fd, - 0, - )?) - } -} - -#[cfg(target_os = "freebsd")] -pub unsafe fn futex_wait(futex: &AtomicU32, value: u32) { - let ptr: *const AtomicU32 = futex; - unsafe { - libc::_umtx_op( - ptr as *mut libc::c_void, - libc::UMTX_OP_WAIT_UINT, - value as libc::c_ulong, - core::ptr::null_mut(), - core::ptr::null_mut(), - ); - }; -} - -#[cfg(target_os = "freebsd")] -pub unsafe fn futex_wake(futex: &AtomicU32) { - let ptr: *const AtomicU32 = futex; - unsafe { - libc::_umtx_op( - ptr as *mut libc::c_void, - libc::UMTX_OP_WAKE, - i32::MAX as libc::c_ulong, - core::ptr::null_mut(), - core::ptr::null_mut(), - ); - }; -} - -#[cfg(target_os = "freebsd")] -pub fn memfd_create() -> std::io::Result { - use rustix::fs::Mode; - use rustix::fs::OFlags; - let name = format!( - ".shm.VECTORS.{:x}.{:x}", - std::process::id(), - rand::random::() - ); - let fd = rustix::fs::open( - &name, - OFlags::RDWR | OFlags::CREATE | OFlags::EXCL, - Mode::RUSR | Mode::WUSR, - )?; - rustix::fs::unlink(&name)?; - Ok(fd) -} - -#[cfg(target_os = "freebsd")] -pub unsafe fn mmap_populate(len: usize, fd: impl AsFd) -> std::io::Result<*mut libc::c_void> { - use std::ptr::null_mut; - unsafe { - Ok(rustix::mm::mmap( - null_mut(), - len, - ProtFlags::READ | ProtFlags::WRITE, - MapFlags::SHARED, - fd, - 0, - )?) - } -} diff --git a/tests/sqllogictest/error.slt b/tests/sqllogictest/error.slt index 890028302..f88cff095 100644 --- a/tests/sqllogictest/error.slt +++ b/tests/sqllogictest/error.slt @@ -7,8 +7,8 @@ CREATE TABLE t (val vector(3)); statement ok CREATE INDEX ON t USING vectors (val vector_l2_ops); -statement error The given vector is invalid for input. +statement error The dimension of a vector does not matched that in a vector index column. INSERT INTO t (val) VALUES ('[0, 1, 2, 3]'); -statement error The given vector is invalid for input. +statement error The dimension of a vector does not matched that in a vector index column. SELECT * FROM t ORDER BY val <-> '[0, 1, 2, 3]'; diff --git a/tests/sqllogictest/index.slt b/tests/sqllogictest/index.slt index e2b279ec9..e718cfa1b 100644 --- a/tests/sqllogictest/index.slt +++ b/tests/sqllogictest/index.slt @@ -37,4 +37,8 @@ SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '[0.5,0.5,0.5]' limit 10) query I SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '[0.5,0.5,0.5]' limit 10) t2; ---- -10 \ No newline at end of file +10 + +statement ok +---- +DROP TABLE t; \ No newline at end of file diff --git a/tests/sqllogictest/sparse.slt b/tests/sqllogictest/sparse.slt new file mode 100644 index 000000000..a29c62463 --- /dev/null +++ b/tests/sqllogictest/sparse.slt @@ -0,0 +1,50 @@ +statement ok +SET search_path TO pg_temp, vectors; + +statement ok +CREATE TABLE t (val svector(6)); + +statement ok +INSERT INTO t (val) SELECT ARRAY[0, random(), 0, 0, random(), random()]::real[]::vector::svector FROM generate_series(1, 1000); + +statement ok +CREATE INDEX ON t USING vectors (val svector_l2_ops) +WITH (options = "[indexing.hnsw]"); + +statement ok +CREATE INDEX ON t USING vectors (val svector_dot_ops) +WITH (options = "[indexing.hnsw]"); + +statement ok +CREATE INDEX ON t USING vectors (val svector_cos_ops) +WITH (options = "[indexing.ivf]"); + + +query I +SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <-> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2; +---- +10 + +query I +SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <=> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2; +---- +10 + +query I +SELECT COUNT(1) FROM (SELECT 1 FROM t ORDER BY val <#> '[0.5,0.5,0.5,0.5,0.5,0.5]'::svector limit 10) t2; +---- +10 + +statement ok +DROP TABLE t; + +query I +SELECT to_svector(5, '{1,2}', '{1,2}'); +---- +[0, 1, 2, 0, 0] + +statement error Lengths of index and value are not matched. +SELECT to_svector(5, '{1,2,3}', '{1,2}'); + +statement error Duplicated index. +SELECT to_svector(5, '{1,1}', '{1,2}'); diff --git a/tests/sqllogictest/svector_binary.slt b/tests/sqllogictest/svector_binary.slt new file mode 100644 index 000000000..27f86b4e6 --- /dev/null +++ b/tests/sqllogictest/svector_binary.slt @@ -0,0 +1,37 @@ +statement ok +SET search_path TO pg_temp, vectors; + +statement ok +CREATE TABLE t (id bigserial, val svector); + +statement ok +INSERT INTO t (val) SELECT NULL FROM generate_series(1, 1000); + +statement ok +INSERT INTO t (val) SELECT ARRAY[random()]::real[]::vector::svector FROM generate_series(1, 1000); + +statement ok +INSERT INTO t (val) SELECT ARRAY[random(), random()]::real[]::vector::svector FROM generate_series(1, 1000); + +statement ok +INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[]::vector::svector FROM generate_series(1, 1000); + +statement ok +COPY t TO '/tmp/data.bin' WITH (FORMAT binary); + +statement ok +CREATE TABLE t2 (id bigserial, val svector); + +statement ok +COPY t2 FROM '/tmp/data.bin' WITH (FORMAT binary); + +query I +SELECT SUM(((t.val = t2.val) OR (t.val IS NULL and t2.val IS NULL))::int) FROM t FULL OUTER JOIN t2 ON t.id = t2.id; +---- +4000 + +statement ok +DROP TABLE t; + +statement ok +DROP TABLE t2; diff --git a/tests/sqllogictest/svector_storage.slt b/tests/sqllogictest/svector_storage.slt new file mode 100644 index 000000000..91c51d1be --- /dev/null +++ b/tests/sqllogictest/svector_storage.slt @@ -0,0 +1,20 @@ +statement ok +SET search_path TO pg_temp, vectors; + +statement ok +CREATE TABLE t (val svector); + +statement ok +INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[]::vector::svector FROM generate_series(1, 1000); + +statement ok +ALTER TABLE t ALTER COLUMN val SET STORAGE PLAIN; + +statement ok +ALTER TABLE t ALTER COLUMN val SET STORAGE EXTERNAL; + +statement ok +ALTER TABLE t ALTER COLUMN val SET STORAGE EXTENDED; + +statement ok +ALTER TABLE t ALTER COLUMN val SET STORAGE MAIN; diff --git a/tests/sqllogictest/svector_subscript.slt b/tests/sqllogictest/svector_subscript.slt new file mode 100644 index 000000000..ad683b75a --- /dev/null +++ b/tests/sqllogictest/svector_subscript.slt @@ -0,0 +1,88 @@ +statement ok +SET search_path TO pg_temp, vectors; + +query I +SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[3:6]; +---- +[3, 4, 5] + +query I +SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:4]; +---- +[0, 1, 2, 3] + +query I +SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[5:]; +---- +[5, 6, 7] + +query I +SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[1:8]; +---- +[1, 2, 3, 4, 5, 6, 7] + +statement error type svector does only support one subscript +SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[3:3][1:1]; + +statement error type svector does only support slice fetch +SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[3]; + +query I +SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[5:4]; +---- +NULL + +query I +SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[9:]; +---- +NULL + +query I +SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:0]; +---- +NULL + +query I +SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:-1]; +---- +NULL + +query I +SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[NULL:NULL]; +---- +NULL + +query I +SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[NULL:8]; +---- +NULL + +query I +SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[1:NULL]; +---- +NULL + +query I +SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[NULL:]; +---- +NULL + +query I +SELECT ('[0, 1, 2, 3, 4, 5, 6, 7]'::svector)[:NULL]; +---- +NULL + +query I +SELECT ('[0, 0, 2, 0, 4, 0, 0, 7]'::svector)[3:7]; +---- +[0, 4, 0, 0] + +query I +SELECT ('[0, 0, 2, 0, 4, 0, 0, 7]'::svector)[5:7]; +---- +[0, 0] + +query I +SELECT ('[0, 0, 0, 0, 0, 0, 0, 0]'::svector)[5:7]; +---- +[0, 0] \ No newline at end of file diff --git a/tests/sqllogictest/test.sh b/tests/sqllogictest/test.sh index 160befdf6..a92660c53 100755 --- a/tests/sqllogictest/test.sh +++ b/tests/sqllogictest/test.sh @@ -2,7 +2,3 @@ set -e sqllogictest -d $USER $(dirname $0)/*.slt - -if [ "$(psql -tAqX -c "SHOW server_version_num")" -ge 140000 ]; then - sqllogictest -d $USER $(dirname $0)/pg14/*.slt -fi diff --git a/tests/sqllogictest/vecf16_binary.slt b/tests/sqllogictest/vecf16_binary.slt index 2824a5877..de8f82767 100644 --- a/tests/sqllogictest/vecf16_binary.slt +++ b/tests/sqllogictest/vecf16_binary.slt @@ -11,10 +11,10 @@ statement ok INSERT INTO t (val) SELECT ARRAY[random()]::real[]::vector::vecf16 FROM generate_series(1, 1000); statement ok -INSERT INTO t (val) SELECT ARRAY[random(), random()]::real[]::real[]::vector::vecf16 FROM generate_series(1, 1000); +INSERT INTO t (val) SELECT ARRAY[random(), random()]::real[]::vector::vecf16 FROM generate_series(1, 1000); statement ok -INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[]::real[]::vector::vecf16 FROM generate_series(1, 1000); +INSERT INTO t (val) SELECT ARRAY[random(), random(), random()]::real[]::vector::vecf16 FROM generate_series(1, 1000); statement ok COPY t TO '/tmp/data.bin' WITH (FORMAT binary); diff --git a/tests/sqllogictest/pg14/vecf16_subscript.slt b/tests/sqllogictest/vecf16_subscript.slt similarity index 100% rename from tests/sqllogictest/pg14/vecf16_subscript.slt rename to tests/sqllogictest/vecf16_subscript.slt diff --git a/tests/sqllogictest/pg14/vector_subscript.slt b/tests/sqllogictest/vector_subscript.slt similarity index 100% rename from tests/sqllogictest/pg14/vector_subscript.slt rename to tests/sqllogictest/vector_subscript.slt