Skip to content

Commit

Permalink
fix: Pgvector patch (#4103)
Browse files Browse the repository at this point in the history
  • Loading branch information
HaoXuAI authored Apr 16, 2024
1 parent 504e40e commit 5c4a9c5
Show file tree
Hide file tree
Showing 8 changed files with 108 additions and 30 deletions.
19 changes: 19 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,25 @@ test-python-universal-postgres-online:
not test_snowflake" \
sdk/python/tests

test-python-universal-pgvector-online:
PYTHONPATH='.' \
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.pgvector_repo_configuration \
PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.postgres \
python -m pytest -n 8 --integration \
-k "not test_universal_cli and \
not test_go_feature_server and \
not test_feature_logging and \
not test_reorder_columns and \
not test_logged_features_validation and \
not test_lambda_materialization_consistency and \
not test_offline_write and \
not test_push_features_to_offline_store and \
not gcs_registry and \
not s3_registry and \
not test_universal_types and \
not test_snowflake" \
sdk/python/tests

test-python-universal-mysql-online:
PYTHONPATH='.' \
FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.mysql_repo_configuration \
Expand Down
28 changes: 28 additions & 0 deletions docs/reference/online-stores/postgres.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ online_store:
sslkey_path: /path/to/client-key.pem
sslcert_path: /path/to/client-cert.pem
sslrootcert_path: /path/to/server-ca.pem
pgvector_enabled: false
vector_len: 512
```
{% endcode %}
Expand Down Expand Up @@ -60,3 +62,29 @@ Below is a matrix indicating which functionality is supported by the Postgres on
| collocated by entity key | no |
To compare this set of functionality against other online stores, please see the full [functionality matrix](overview.md#functionality-matrix).
## PGVector
The Postgres online store supports the use of [PGVector](https://pgvector.dev/) for storing feature values.
To enable PGVector, set `pgvector_enabled: true` in the online store configuration.
The `vector_len` parameter can be used to specify the length of the vector. The default value is 512.

Then you can use `retrieve_online_documents` to retrieve the top k closest vectors to a query vector.

{% code title="python" %}
```python
from feast import FeatureStore
from feast.infra.online_stores.postgres import retrieve_online_documents
feature_store = FeatureStore(repo_path=".")
query_vector = [0.1, 0.2, 0.3, 0.4, 0.5]
top_k = 5
feature_values = retrieve_online_documents(
feature_store=feature_store,
feature_view_name="document_fv:embedding_float",
query_vector=query_vector,
top_k=top_k,
)
```
{% endcode %}
9 changes: 7 additions & 2 deletions sdk/python/feast/infra/key_encoding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,13 @@ def serialize_entity_key(
return b"".join(output)


def get_val_str(val):
accept_value_types = ["float_list_val", "double_list_val", "int_list_val"]
def get_list_val_str(val):
accept_value_types = [
"float_list_val",
"double_list_val",
"int32_list_val",
"int64_list_val",
]
for accept_type in accept_value_types:
if val.HasField(accept_type):
return str(getattr(val, accept_type).val)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from tests.integration.feature_repos.integration_test_repo_config import (
IntegrationTestRepoConfig,
)
from tests.integration.feature_repos.universal.online_store.postgres import (
PGVectorOnlineStoreCreator,
)

FULL_REPO_CONFIGS = [
IntegrationTestRepoConfig(
online_store="pgvector", online_store_creator=PGVectorOnlineStoreCreator
),
]
51 changes: 31 additions & 20 deletions sdk/python/feast/infra/online_stores/contrib/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from collections import defaultdict
from datetime import datetime
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple

import psycopg2
import pytz
Expand All @@ -12,7 +12,7 @@

from feast import Entity
from feast.feature_view import FeatureView
from feast.infra.key_encoding_utils import get_val_str, serialize_entity_key
from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool
from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig
Expand Down Expand Up @@ -74,19 +74,18 @@ def online_write_batch(
created_ts = _to_naive_utc(created_ts)

for feature_name, val in values.items():
val_str: Union[str, bytes]
vector_val = None
if (
"pgvector_enabled" in config.online_config
and config.online_config["pgvector_enabled"]
"pgvector_enabled" in config.online_store
and config.online_store.pgvector_enabled
):
val_str = get_val_str(val)
else:
val_str = val.SerializeToString()
vector_val = get_list_val_str(val)
insert_values.append(
(
entity_key_bin,
feature_name,
val_str,
val.SerializeToString(),
vector_val,
timestamp,
created_ts,
)
Expand All @@ -100,11 +99,12 @@ def online_write_batch(
sql.SQL(
"""
INSERT INTO {}
(entity_key, feature_name, value, event_ts, created_ts)
(entity_key, feature_name, value, vector_value, event_ts, created_ts)
VALUES %s
ON CONFLICT (entity_key, feature_name) DO
UPDATE SET
value = EXCLUDED.value,
vector_value = EXCLUDED.vector_value,
event_ts = EXCLUDED.event_ts,
created_ts = EXCLUDED.created_ts;
""",
Expand Down Expand Up @@ -226,20 +226,23 @@ def update(

for table in tables_to_keep:
table_name = _table_id(project, table)
value_type = "BYTEA"
if (
"pgvector_enabled" in config.online_config
and config.online_config["pgvector_enabled"]
"pgvector_enabled" in config.online_store
and config.online_store.pgvector_enabled
):
value_type = f'vector({config.online_config["vector_len"]})'
vector_value_type = f"vector({config.online_store.vector_len})"
else:
# keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility
vector_value_type = "BYTEA"
cur.execute(
sql.SQL(
"""
CREATE TABLE IF NOT EXISTS {}
(
entity_key BYTEA,
feature_name TEXT,
value {},
value BYTEA,
vector_value {} NULL,
event_ts TIMESTAMPTZ,
created_ts TIMESTAMPTZ,
PRIMARY KEY(entity_key, feature_name)
Expand All @@ -248,7 +251,7 @@ def update(
"""
).format(
sql.Identifier(table_name),
sql.SQL(value_type),
sql.SQL(vector_value_type),
sql.Identifier(f"{table_name}_ek"),
sql.Identifier(table_name),
)
Expand Down Expand Up @@ -294,6 +297,14 @@ def retrieve_online_documents(
"""
project = config.project

if (
"pgvector_enabled" not in config.online_store
or not config.online_store.pgvector_enabled
):
raise ValueError(
"pgvector is not enabled in the online store configuration"
)

# Convert the embedding to a string to be used in postgres vector search
query_embedding_str = f"[{','.join(str(el) for el in embedding)}]"

Expand All @@ -311,8 +322,8 @@ def retrieve_online_documents(
SELECT
entity_key,
feature_name,
value,
value <-> %s as distance,
vector_value,
vector_value <-> %s as distance,
event_ts FROM {table_name}
WHERE feature_name = {feature_name}
ORDER BY distance
Expand All @@ -327,13 +338,13 @@ def retrieve_online_documents(
)
rows = cur.fetchall()

for entity_key, feature_name, value, distance, event_ts in rows:
for entity_key, feature_name, vector_value, distance, event_ts in rows:
# TODO Deserialize entity_key to return the entity in response
# entity_key_proto = EntityKeyProto()
# entity_key_proto_bin = bytes(entity_key)

# TODO Convert to List[float] for value type proto
feature_value_proto = ValueProto(string_val=value)
feature_value_proto = ValueProto(string_val=vector_value)

distance_value_proto = ValueProto(float_val=distance)
result.append((event_ts, feature_value_proto, distance_value_proto))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,11 @@
IntegrationTestRepoConfig,
)
from tests.integration.feature_repos.universal.online_store.postgres import (
PGVectorOnlineStoreCreator,
PostgresOnlineStoreCreator,
)

FULL_REPO_CONFIGS = [
IntegrationTestRepoConfig(
online_store="postgres", online_store_creator=PostgresOnlineStoreCreator
),
IntegrationTestRepoConfig(
online_store="pgvector", online_store_creator=PGVectorOnlineStoreCreator
),
]

AVAILABLE_ONLINE_STORES = {"pgvector": PGVectorOnlineStoreCreator}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
CREATE EXTENSION IF NOT EXISTS vector;
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from typing import Dict

from testcontainers.core.container import DockerContainer
Expand Down Expand Up @@ -37,12 +38,17 @@ def teardown(self):
class PGVectorOnlineStoreCreator(OnlineStoreCreator):
def __init__(self, project_name: str, **kwargs):
super().__init__(project_name)
script_directory = os.path.dirname(os.path.abspath(__file__))
self.container = (
DockerContainer("pgvector/pgvector:pg16")
.with_env("POSTGRES_USER", "root")
.with_env("POSTGRES_PASSWORD", "test")
.with_env("POSTGRES_DB", "test")
.with_exposed_ports(5432)
.with_volume_mapping(
os.path.join(script_directory, "init.sql"),
"/docker-entrypoint-initdb.d/init.sql",
)
)

def create_online_store(self) -> Dict[str, str]:
Expand All @@ -51,8 +57,10 @@ def create_online_store(self) -> Dict[str, str]:
wait_for_logs(
container=self.container, predicate=log_string_to_wait_for, timeout=10
)
command = "psql -h localhost -p 5432 -U root -d test -c 'CREATE EXTENSION IF NOT EXISTS vector;'"
self.container.exec(command)
init_log_string_to_wait_for = "PostgreSQL init process complete"
wait_for_logs(
container=self.container, predicate=init_log_string_to_wait_for, timeout=10
)
return {
"host": "localhost",
"type": "postgres",
Expand Down

0 comments on commit 5c4a9c5

Please sign in to comment.