-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add grpc unit test run, expand testing of VectorFactory (#326)
## Problem There is a small number of unit tests that should only get run when grpc dependencies are installed. These were previously omitted from CI by mistake. ## Solution Add a step to run these grpc steps. Make it conditional on the `use_grpc` test matrix param. ## Type of Change - [x] Infrastructure change (CI configs, etc)
- Loading branch information
Showing
13 changed files
with
449 additions
and
237 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from ..utils.constants import REQUIRED_VECTOR_FIELDS, OPTIONAL_VECTOR_FIELDS | ||
|
||
class VectorDictionaryMissingKeysError(ValueError): | ||
def __init__(self, item): | ||
message = f"Vector dictionary is missing required fields: {list(REQUIRED_VECTOR_FIELDS - set(item.keys()))}" | ||
super().__init__(message) | ||
|
||
class VectorDictionaryExcessKeysError(ValueError): | ||
def __init__(self, item): | ||
invalid_keys = list(set(item.keys()) - (REQUIRED_VECTOR_FIELDS | OPTIONAL_VECTOR_FIELDS)) | ||
message = f"Found excess keys in the vector dictionary: {invalid_keys}. The allowed keys are: {list(REQUIRED_VECTOR_FIELDS | OPTIONAL_VECTOR_FIELDS)}" | ||
super().__init__(message) | ||
|
||
class VectorTupleLengthError(ValueError): | ||
def __init__(self, item): | ||
message = f"Found a tuple of length {len(item)} which is not supported. Vectors can be represented as tuples either the form (id, values, metadata) or (id, values). To pass sparse values please use either dicts or Vector objects as inputs." | ||
super().__init__(message) | ||
|
||
class SparseValuesTypeError(ValueError, TypeError): | ||
def __init__(self): | ||
message = "Found unexpected data in column `sparse_values`. Expected format is `'sparse_values': {'indices': List[int], 'values': List[float]}`." | ||
super().__init__(message) | ||
|
||
class SparseValuesMissingKeysError(ValueError): | ||
def __init__(self, sparse_values_dict): | ||
message = f"Missing required keys in data in column `sparse_values`. Expected format is `'sparse_values': {{'indices': List[int], 'values': List[float]}}`. Found keys {list(sparse_values_dict.keys())}" | ||
super().__init__(message) | ||
|
||
class SparseValuesDictionaryExpectedError(ValueError, TypeError): | ||
def __init__(self, sparse_values_dict): | ||
message = f"Column `sparse_values` is expected to be a dictionary, found {type(sparse_values_dict)}" | ||
super().__init__(message) | ||
|
||
class MetadataDictionaryExpectedError(ValueError, TypeError): | ||
def __init__(self, item): | ||
message = f"Column `metadata` is expected to be a dictionary, found {type(item['metadata'])}" | ||
super().__init__(message) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
import numbers | ||
|
||
from collections.abc import Mapping | ||
from typing import Union, Dict | ||
|
||
from ..utils import convert_to_list | ||
|
||
from .errors import ( | ||
SparseValuesTypeError, | ||
SparseValuesMissingKeysError, | ||
SparseValuesDictionaryExpectedError | ||
) | ||
|
||
from pinecone.core.client.models import ( | ||
SparseValues | ||
) | ||
|
||
class SparseValuesFactory: | ||
@staticmethod | ||
def build(input: Union[Dict, SparseValues]) -> SparseValues: | ||
if input is None: | ||
return input | ||
if isinstance(input, SparseValues): | ||
return input | ||
if not isinstance(input, Mapping): | ||
raise SparseValuesDictionaryExpectedError(input) | ||
if not {"indices", "values"}.issubset(input): | ||
raise SparseValuesMissingKeysError(input) | ||
|
||
indices = SparseValuesFactory._convert_to_list(input.get("indices"), int) | ||
values = SparseValuesFactory._convert_to_list(input.get("values"), float) | ||
|
||
if len(indices) != len(values): | ||
raise ValueError("Sparse values indices and values must have the same length") | ||
|
||
try: | ||
return SparseValues(indices=indices, values=values) | ||
except TypeError as e: | ||
raise SparseValuesTypeError() from e | ||
|
||
@staticmethod | ||
def _convert_to_list(input, expected_type): | ||
try: | ||
converted = convert_to_list(input) | ||
except TypeError as e: | ||
raise SparseValuesTypeError() from e | ||
|
||
SparseValuesFactory._validate_list_items_type(converted, expected_type) | ||
return converted | ||
|
||
@staticmethod | ||
def _validate_list_items_type(input, expected_type): | ||
if len(input) > 0 and not isinstance(input[0], expected_type): | ||
raise SparseValuesTypeError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import numbers | ||
|
||
from collections.abc import Mapping | ||
from typing import Union, Dict | ||
|
||
from ..utils import convert_to_list | ||
|
||
from ..data import ( | ||
SparseValuesTypeError, | ||
SparseValuesMissingKeysError, | ||
SparseValuesDictionaryExpectedError | ||
) | ||
|
||
from pinecone.core.grpc.protos.vector_service_pb2 import ( | ||
SparseValues as GRPCSparseValues, | ||
) | ||
from pinecone import ( | ||
SparseValues as NonGRPCSparseValues | ||
) | ||
|
||
class SparseValuesFactory: | ||
@staticmethod | ||
def build(input: Union[Dict, GRPCSparseValues, NonGRPCSparseValues]) -> GRPCSparseValues: | ||
if input is None: | ||
return input | ||
if isinstance(input, GRPCSparseValues): | ||
return input | ||
if isinstance(input, NonGRPCSparseValues): | ||
return GRPCSparseValues(indices=input.indices, values=input.values) | ||
if not isinstance(input, Mapping): | ||
raise SparseValuesDictionaryExpectedError(input) | ||
if not {"indices", "values"}.issubset(input): | ||
raise SparseValuesMissingKeysError(input) | ||
|
||
indices = SparseValuesFactory._convert_to_list(input.get("indices"), int) | ||
values = SparseValuesFactory._convert_to_list(input.get("values"), float) | ||
|
||
if len(indices) != len(values): | ||
raise ValueError("Sparse values indices and values must have the same length") | ||
|
||
try: | ||
return GRPCSparseValues(indices=indices, values=values) | ||
except TypeError as e: | ||
raise SparseValuesTypeError() from e | ||
|
||
@staticmethod | ||
def _convert_to_list(input, expected_type): | ||
try: | ||
converted = convert_to_list(input) | ||
except TypeError as e: | ||
raise SparseValuesTypeError() from e | ||
|
||
SparseValuesFactory._validate_list_items_type(converted, expected_type) | ||
return converted | ||
|
||
@staticmethod | ||
def _validate_list_items_type(input, expected_type): | ||
if len(input) > 0 and not isinstance(input[0], expected_type): | ||
raise SparseValuesTypeError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,18 @@ | ||
from ..exceptions import ListConversionException | ||
|
||
def convert_to_list(obj): | ||
class_name = obj.__class__.__name__ | ||
|
||
if class_name == 'list': | ||
return obj | ||
elif hasattr(obj, 'tolist') and callable(getattr(obj, 'tolist')): | ||
return obj.tolist() | ||
elif obj is None or isinstance(obj, str) or isinstance(obj, dict): | ||
# The string and dictionary classes in python can be passed to list() | ||
# but they're not going to yield sensible results for our use case. | ||
raise ListConversionException(f"Expected a list or list-like data structure, but got: {obj}") | ||
else: | ||
return list(obj) | ||
try: | ||
return list(obj) | ||
except Exception as e: | ||
raise ListConversionException(f"Expected a list or list-like data structure, but got: {obj}") from e |
Oops, something went wrong.