Skip to content

Commit

Permalink
feat: Enable other distance metrics for Vector DB and Update docs (#4170
Browse files Browse the repository at this point in the history
)

* updated PGVector docs

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* adding distance metric to arguments and defaulting to L2

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* linter

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* testing other distance metric

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* updated default

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* linter

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* fixed some copy

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* updated

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
  • Loading branch information
franciscojavierarceo authored May 7, 2024
1 parent 67bea4c commit ba9f4ef
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 8 deletions.
10 changes: 8 additions & 2 deletions docs/reference/online-stores/postgres.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,16 @@ To compare this set of functionality against other online stores, please see the
## PGVector
The Postgres online store supports the use of [PGVector](https://github.com/pgvector/pgvector) for storing feature values.
To enable PGVector, set `pgvector_enabled: true` in the online store configuration.
To enable PGVector, set `pgvector_enabled: true` in the online store configuration.

The `vector_len` parameter can be used to specify the length of the vector. The default value is 512.

Then you can use `retrieve_online_documents` to retrieve the top k closest vectors to a query vector.
Please make sure to follow the instructions in the repository, which, as the time of this writing, requires you to
run `CREATE EXTENSION vector;` in the database.


Then you can use `retrieve_online_documents` to retrieve the top k closest vectors to a query vector.
For the Retrieval Augmented Generation (RAG) use-case, you have to embed the query prior to passing the query vector.

{% code title="python" %}
```python
Expand Down
7 changes: 7 additions & 0 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,6 +1740,7 @@ def retrieve_online_documents(
feature: str,
query: Union[str, List[float]],
top_k: int,
distance_metric: str,
) -> OnlineResponse:
"""
Retrieves the top k closest document features. Note, embeddings are a subset of features.
Expand All @@ -1750,18 +1751,21 @@ def retrieve_online_documents(
references must have format "feature_view:feature", e.g, "document_fv:document_embeddings".
query: The query to retrieve the closest document features for.
top_k: The number of closest document features to retrieve.
distance_metric: The distance metric to use for retrieval.
"""
return self._retrieve_online_documents(
feature=feature,
query=query,
top_k=top_k,
distance_metric=distance_metric,
)

def _retrieve_online_documents(
self,
feature: str,
query: Union[str, List[float]],
top_k: int,
distance_metric: str = "L2",
):
if isinstance(query, str):
raise ValueError(
Expand All @@ -1783,6 +1787,7 @@ def _retrieve_online_documents(
requested_feature,
query,
top_k,
distance_metric,
)

# TODO Refactor to better way of populating result
Expand Down Expand Up @@ -2025,6 +2030,7 @@ def _retrieve_from_online_store(
requested_feature: str,
query: List[float],
top_k: int,
distance_metric: str,
) -> List[Tuple[Timestamp, "FieldStatus.ValueType", Value, Value, Value]]:
"""
Search and return document features from the online document store.
Expand All @@ -2035,6 +2041,7 @@ def _retrieve_from_online_store(
requested_feature=requested_feature,
query=query,
top_k=top_k,
distance_metric=distance_metric,
)

read_row_protos = []
Expand Down
18 changes: 17 additions & 1 deletion sdk/python/feast/infra/online_stores/contrib/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
from feast.repo_config import RepoConfig
from feast.usage import log_exceptions_and_usage

SUPPORTED_DISTANCE_METRICS_DICT = {
"cosine": "<=>",
"L1": "<+>",
"L2": "<->",
"inner_product": "<#>",
}


class PostgreSQLOnlineStoreConfig(PostgreSQLConfig):
type: Literal["postgres"] = "postgres"
Expand Down Expand Up @@ -276,6 +283,7 @@ def retrieve_online_documents(
requested_feature: str,
embedding: List[float],
top_k: int,
distance_metric: str = "L2",
) -> List[
Tuple[
Optional[datetime],
Expand All @@ -292,6 +300,7 @@ def retrieve_online_documents(
requested_feature: The requested feature as the column to search
embedding: The query embedding to search for
top_k: The number of items to return
distance_metric: The distance metric to use for the search.G
Returns:
List of tuples containing the event timestamp and the document feature
Expand All @@ -303,6 +312,12 @@ def retrieve_online_documents(
"pgvector is not enabled in the online store configuration"
)

if distance_metric not in SUPPORTED_DISTANCE_METRICS_DICT:
raise ValueError(
f"Distance metric {distance_metric} is not supported. Supported distance metrics are {SUPPORTED_DISTANCE_METRICS_DICT.keys()}"
)

distance_metric_sql = SUPPORTED_DISTANCE_METRICS_DICT[distance_metric]
# Convert the embedding to a string to be used in postgres vector search
query_embedding_str = f"[{','.join(str(el) for el in embedding)}]"

Expand All @@ -327,13 +342,14 @@ def retrieve_online_documents(
feature_name,
value,
vector_value,
vector_value <-> %s as distance,
vector_value {distance_metric_sql} %s as distance,
event_ts FROM {table_name}
WHERE feature_name = {feature_name}
ORDER BY distance
LIMIT {top_k};
"""
).format(
distance_metric_sql=distance_metric_sql,
table_name=sql.Identifier(table_name),
feature_name=sql.Literal(requested_feature),
top_k=sql.Literal(top_k),
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/feast/infra/online_stores/online_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def retrieve_online_documents(
table: The feature view whose feature values should be read.
requested_feature: The name of the feature whose embeddings should be used for retrieval.
embedding: The embeddings to use for retrieval.
top_k: The number of nearest neighbors to retrieve.
top_k: The number of documents to retrieve.
Returns:
object: A list of top k closest documents to the specified embedding. Each item in the list is a tuple
Expand Down
8 changes: 7 additions & 1 deletion sdk/python/feast/infra/passthrough_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,12 +196,18 @@ def retrieve_online_documents(
requested_feature: str,
query: List[float],
top_k: int,
distance_metric: str,
) -> List:
set_usage_attribute("provider", self.__class__.__name__)
result = []
if self.online_store:
result = self.online_store.retrieve_online_documents(
config, table, requested_feature, query, top_k
config,
table,
requested_feature,
query,
top_k,
distance_metric,
)
return result

Expand Down
5 changes: 3 additions & 2 deletions sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def retrieve_online_documents(
requested_feature: str,
query: List[float],
top_k: int,
distance_metric: str = "L2",
) -> List[
Tuple[
Optional[datetime],
Expand All @@ -312,14 +313,14 @@ def retrieve_online_documents(
]
]:
"""
Searches for the top-k nearest neighbors of the given document in the online document store.
Searches for the top-k most similar documents in the online document store.
Args:
config: The config for the current feature store.
table: The feature view whose embeddings should be searched.
requested_feature: the requested document feature name.
query: The query embedding to search for.
top_k: The number of nearest neighbors to return.
top_k: The number of documents to return.
Returns:
A list of dictionaries, where each dictionary contains the document feature.
Expand Down
1 change: 1 addition & 0 deletions sdk/python/tests/foo_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def retrieve_online_documents(
requested_feature: str,
query: List[float],
top_k: int,
distance_metric: str,
) -> List[
Tuple[
Optional[datetime],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,25 @@ def test_retrieve_online_documents(environment, fake_document_data):
fs.write_to_online_store("item_embeddings", df)

documents = fs.retrieve_online_documents(
feature="item_embeddings:embedding_float", query=[1.0, 2.0], top_k=2
feature="item_embeddings:embedding_float",
query=[1.0, 2.0],
top_k=2,
distance_metric="L2",
).to_dict()
assert len(documents["embedding_float"]) == 2

documents = fs.retrieve_online_documents(
feature="item_embeddings:embedding_float",
query=[1.0, 2.0],
top_k=2,
distance_metric="L1",
).to_dict()
assert len(documents["embedding_float"]) == 2

with pytest.raises(ValueError):
fs.retrieve_online_documents(
feature="item_embeddings:embedding_float",
query=[1.0, 2.0],
top_k=2,
distance_metric="wrong",
).to_dict()

0 comments on commit ba9f4ef

Please sign in to comment.