Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scalar Quantization #1110

Merged
merged 6 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions integration/test_collection_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from integration.conftest import _sanitize_collection_name
from weaviate.collections.classes.config import (
_BQConfig,
_SQConfig,
_CollectionConfig,
_CollectionConfigSimple,
_PQConfig,
Expand Down Expand Up @@ -481,6 +482,23 @@ def test_hnsw_with_bq(collection_factory: CollectionFactory) -> None:
assert isinstance(config.vector_index_config.quantizer, _BQConfig)


def test_hnsw_with_sq(collection_factory: CollectionFactory) -> None:
    """Create an HNSW collection with an SQ quantizer and check that it is read back as `_SQConfig`.

    The test is skipped when the connected Weaviate version is lower than 1.26.0.
    """
    sq_quantizer = Configure.VectorIndex.Quantizer.sq(rescore_limit=10, training_limit=1000000)
    hnsw_index = Configure.VectorIndex.hnsw(
        vector_cache_max_objects=5,
        quantizer=sq_quantizer,
    )
    collection = collection_factory(vector_index_config=hnsw_index)

    # The server version is read from the collection's connection, so the collection is created first.
    server_version = collection._connection._weaviate_version
    if server_version.is_lower_than(1, 26, 0):
        pytest.skip("SQ+HNSW is not supported in Weaviate versions lower than 1.26.0")

    config = collection.config.get()
    assert config.vector_index_type == VectorIndexType.HNSW
    assert config.vector_index_config is not None
    assert isinstance(config.vector_index_config, _VectorIndexConfigHNSW)
    assert isinstance(config.vector_index_config.quantizer, _SQConfig)


def test_update_flat(collection_factory: CollectionFactory) -> None:
collection = collection_factory(
vector_index_config=Configure.VectorIndex.flat(
Expand Down
13 changes: 13 additions & 0 deletions test/collection/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,6 +976,19 @@ def test_vector_config_hnsw_bq() -> None:
assert vi_dict["bq"]["rescoreLimit"] == 123


def test_vector_config_hnsw_sq() -> None:
    """Check that an HNSW config with an SQ quantizer serializes the SQ limits under the "sq" key."""
    sq_quantizer = Configure.VectorIndex.Quantizer.sq(rescore_limit=123, training_limit=5012)
    hnsw_index = Configure.VectorIndex.hnsw(ef_construction=128, quantizer=sq_quantizer)

    serialized = hnsw_index._to_dict()

    assert serialized["efConstruction"] == 128
    sq_section = serialized["sq"]
    assert sq_section["rescoreLimit"] == 123
    assert sq_section["trainingLimit"] == 5012


def test_vector_config_flat_pq() -> None:
vector_index = Configure.VectorIndex.flat(
distance_metric=VectorDistances.DOT,
Expand Down
70 changes: 68 additions & 2 deletions weaviate/collections/classes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,16 @@ def quantizer_name() -> str:
return "bq"


# Settings for the scalar quantization (SQ) quantizer supplied when a vector index is created.
# Field names are camelCase (`rescoreLimit`, `trainingLimit`); the `sq()` factory maps its
# snake_case arguments onto them.
class _SQConfigCreate(_QuantizerConfigCreate):
    # Every setting is optional but declared without a default value.
    cache: Optional[bool]
    rescoreLimit: Optional[int]
    trainingLimit: Optional[int]

    @staticmethod
    def quantizer_name() -> str:
        """Return the name identifying this quantizer: `"sq"`."""
        return "sq"


class _PQConfigUpdate(_QuantizerConfigUpdate):
bitCompression: Optional[bool] = Field(default=None)
centroids: Optional[int]
Expand All @@ -299,6 +309,16 @@ def quantizer_name() -> str:
return "bq"


# Settings for the scalar quantization (SQ) quantizer supplied when a vector index is updated.
# Unlike the create-time model it carries an `enabled` flag and has no `cache` field.
class _SQConfigUpdate(_QuantizerConfigUpdate):
    # Every setting is optional but declared without a default value.
    enabled: Optional[bool]
    rescoreLimit: Optional[int]
    trainingLimit: Optional[int]

    @staticmethod
    def quantizer_name() -> str:
        """Return the name identifying this quantizer: `"sq"`."""
        return "sq"


class _ShardingConfigCreate(_ConfigCreateModel):
virtualPerPhysical: Optional[int]
desiredCount: Optional[int]
Expand Down Expand Up @@ -1111,19 +1131,29 @@ class _BQConfig(_ConfigBase):
rescore_limit: int


# Scalar quantization (SQ) settings as stored on a vector index configuration's `quantizer` field.
# NOTE(review): `__get_quantizer_config` fills these fields with `dict.get(...)`, so
# `rescore_limit` and `training_limit` can be None at runtime despite the `int` annotations —
# confirm whether they should be `Optional[int]`.
@dataclass
class _SQConfig(_ConfigBase):
    cache: Optional[bool]
    rescore_limit: int
    training_limit: int


# Public aliases of the underscore-prefixed quantizer config classes.
BQConfig = _BQConfig
SQConfig = _SQConfig


# Vector index configuration base that carries the active quantizer, if any.
@dataclass
class _VectorIndexConfig(_ConfigBase):
    # Declared once: the earlier `Union[PQConfig, BQConfig]` annotation was superseded by this one.
    quantizer: Optional[Union[PQConfig, BQConfig, SQConfig]]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the config, moving the quantizer settings under a quantizer-specific key.

        When `quantizer` is a PQ, BQ or SQ config, the `"quantizer"` entry of the parent's
        `to_dict()` output is removed and re-inserted under `"pq"`, `"bq"` or `"sq"` respectively,
        with `"enabled": True` added. For any other value the parent's output is returned as is.

        Returns:
            The dictionary form of this configuration.
        """
        out = super().to_dict()
        # Checked in the same order as the original if/elif chain; the first match wins.
        quantizer_keys = (("pq", _PQConfig), ("bq", _BQConfig), ("sq", _SQConfig))
        for key, quantizer_type in quantizer_keys:
            if isinstance(self.quantizer, quantizer_type):
                out[key] = {**out.pop("quantizer"), "enabled": True}
                break
        return out


Expand Down Expand Up @@ -1614,6 +1644,25 @@ def bq(
rescoreLimit=rescore_limit,
)

@staticmethod
def sq(
    cache: Optional[bool] = None,
    rescore_limit: Optional[int] = None,
    training_limit: Optional[int] = None,
) -> _SQConfigCreate:
    """Create a `_SQConfigCreate` object to be used when defining the scalar quantization (SQ) configuration of Weaviate.

    Use this method when defining the `quantizer` argument in the `vector_index` configuration.

    Arguments:
        See [the docs](https://weaviate.io/developers/weaviate/concepts/vector-index#binary-quantization) for a more detailed view!
    """  # noqa: D417 (missing argument descriptions in the docstring)
    # NOTE(review): the docs link above targets the binary-quantization anchor — confirm whether an
    # SQ-specific section should be linked instead.
    # The snake_case arguments map onto the camelCase fields of the config model.
    sq_config = _SQConfigCreate(
        cache=cache,
        rescoreLimit=rescore_limit,
        trainingLimit=training_limit,
    )
    return sq_config


class _VectorIndex:
Quantizer = _VectorIndexQuantizer
Expand Down Expand Up @@ -1867,6 +1916,23 @@ def bq(rescore_limit: Optional[int] = None) -> _BQConfigUpdate:
""" # noqa: D417 (missing argument descriptions in the docstring)
return _BQConfigUpdate(rescoreLimit=rescore_limit)

@staticmethod
def sq(
    rescore_limit: Optional[int] = None,
    training_limit: Optional[int] = None,
    enabled: bool = True,
) -> _SQConfigUpdate:
    """Create a `_SQConfigUpdate` object to be used when updating the scalar quantization (SQ) configuration of Weaviate.

    Use this method when defining the `quantizer` argument in the `vector_index` configuration in `collection.update()`.

    Arguments:
        See [the docs](https://weaviate.io/developers/weaviate/concepts/vector-index#hnsw-with-compression) for a more detailed view!
    """  # noqa: D417 (missing argument descriptions in the docstring)
    # The snake_case arguments map onto the camelCase fields of the update model;
    # `enabled` defaults to True.
    sq_update = _SQConfigUpdate(
        enabled=enabled,
        rescoreLimit=rescore_limit,
        trainingLimit=training_limit,
    )
    return sq_update


class _VectorIndexUpdate:
Quantizer = _VectorIndexQuantizerUpdate
Expand All @@ -1879,7 +1945,7 @@ def hnsw(
ef: Optional[int] = None,
flat_search_cutoff: Optional[int] = None,
vector_cache_max_objects: Optional[int] = None,
quantizer: Optional[Union[_PQConfigUpdate, _BQConfigUpdate]] = None,
quantizer: Optional[Union[_PQConfigUpdate, _BQConfigUpdate, _SQConfigUpdate]] = None,
) -> _VectorIndexConfigHNSWUpdate:
"""Create an `_VectorIndexConfigHNSWUpdate` object to update the configuration of the HNSW vector index.

Expand Down
14 changes: 12 additions & 2 deletions weaviate/collections/classes/config_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from weaviate.collections.classes.config import (
_BQConfig,
_SQConfig,
_CollectionConfig,
_CollectionConfigSimple,
_NamedVectorConfig,
Expand Down Expand Up @@ -99,14 +100,23 @@ def __get_vector_index_type(schema: Dict[str, Any]) -> Optional[VectorIndexType]
return None


def __get_quantizer_config(config: Dict[str, Any]) -> Optional[Union[_PQConfig, _BQConfig]]:
quantizer: Optional[Union[_PQConfig, _BQConfig]] = None
def __get_quantizer_config(
config: Dict[str, Any]
) -> Optional[Union[_PQConfig, _BQConfig, _SQConfig]]:
quantizer: Optional[Union[_PQConfig, _BQConfig, _SQConfig]] = None
if "bq" in config and config["bq"]["enabled"]:
# values are not present for bq+hnsw
quantizer = _BQConfig(
cache=config["bq"].get("cache"),
rescore_limit=config["bq"].get("rescoreLimit"),
)
elif "sq" in config and config["sq"]["enabled"]:
# values are not present for bq+hnsw
quantizer = _SQConfig(
cache=config["sq"].get("cache"),
rescore_limit=config["sq"].get("rescoreLimit"),
training_limit=config["sq"].get("trainingLimit"),
)
elif "pq" in config and config["pq"].get("enabled"):
quantizer = _PQConfig(
internal_bit_compression=config["pq"].get("bitCompression"),
Expand Down