diff --git a/Makefile b/Makefile index 6fcf95dc7d..bf2d876b7f 100644 --- a/Makefile +++ b/Makefile @@ -216,6 +216,25 @@ test-python-universal-postgres-online: not test_snowflake" \ sdk/python/tests + test-python-universal-pgvector-online: + PYTHONPATH='.' \ + FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.pgvector_repo_configuration \ + PYTEST_PLUGINS=sdk.python.tests.integration.feature_repos.universal.online_store.postgres \ + python -m pytest -n 8 --integration \ + -k "not test_universal_cli and \ + not test_go_feature_server and \ + not test_feature_logging and \ + not test_reorder_columns and \ + not test_logged_features_validation and \ + not test_lambda_materialization_consistency and \ + not test_offline_write and \ + not test_push_features_to_offline_store and \ + not gcs_registry and \ + not s3_registry and \ + not test_universal_types and \ + not test_snowflake" \ + sdk/python/tests + test-python-universal-mysql-online: PYTHONPATH='.' \ FULL_REPO_CONFIGS_MODULE=sdk.python.feast.infra.online_stores.contrib.mysql_repo_configuration \ diff --git a/docs/reference/online-stores/postgres.md b/docs/reference/online-stores/postgres.md index 3885867dd2..277494868c 100644 --- a/docs/reference/online-stores/postgres.md +++ b/docs/reference/online-stores/postgres.md @@ -30,6 +30,8 @@ online_store: sslkey_path: /path/to/client-key.pem sslcert_path: /path/to/client-cert.pem sslrootcert_path: /path/to/server-ca.pem + pgvector_enabled: false + vector_len: 512 ``` {% endcode %} @@ -60,3 +62,29 @@ Below is a matrix indicating which functionality is supported by the Postgres on | collocated by entity key | no | To compare this set of functionality against other online stores, please see the full [functionality matrix](overview.md#functionality-matrix). + +## PGVector +The Postgres online store supports the use of [PGVector](https://pgvector.dev/) for storing feature values. 
+To enable PGVector, set `pgvector_enabled: true` in the online store configuration. +The `vector_len` parameter sets the dimension of the stored vectors, i.e. the `N` in the `vector(N)` column type used for the vector column. The default value is 512. + +Then you can use `retrieve_online_documents` to retrieve the top k closest vectors to a query vector. The query vector must have the same dimension as `vector_len` (the example below assumes `vector_len: 5`). + +{% code title="python" %} +```python +from feast import FeatureStore +from feast.infra.online_stores.postgres import retrieve_online_documents + +feature_store = FeatureStore(repo_path=".") + +query_vector = [0.1, 0.2, 0.3, 0.4, 0.5] +top_k = 5 + +feature_values = retrieve_online_documents( + feature_store=feature_store, + feature_view_name="document_fv:embedding_float", + query_vector=query_vector, + top_k=top_k, +) +``` +{% endcode %} diff --git a/sdk/python/feast/infra/key_encoding_utils.py b/sdk/python/feast/infra/key_encoding_utils.py index e50e438c3d..ca834f1917 100644 --- a/sdk/python/feast/infra/key_encoding_utils.py +++ b/sdk/python/feast/infra/key_encoding_utils.py @@ -74,8 +74,13 @@ def serialize_entity_key( return b"".join(output) -def get_val_str(val): - accept_value_types = ["float_list_val", "double_list_val", "int_list_val"] +def get_list_val_str(val): + accept_value_types = [ + "float_list_val", + "double_list_val", + "int32_list_val", + "int64_list_val", + ] for accept_type in accept_value_types: if val.HasField(accept_type): return str(getattr(val, accept_type).val) diff --git a/sdk/python/feast/infra/online_stores/contrib/pgvector_repo_configuration.py b/sdk/python/feast/infra/online_stores/contrib/pgvector_repo_configuration.py new file mode 100644 index 0000000000..26b0561315 --- /dev/null +++ b/sdk/python/feast/infra/online_stores/contrib/pgvector_repo_configuration.py @@ -0,0 +1,12 @@ +from tests.integration.feature_repos.integration_test_repo_config import ( + IntegrationTestRepoConfig, +) +from tests.integration.feature_repos.universal.online_store.postgres import ( + PGVectorOnlineStoreCreator, +) + +FULL_REPO_CONFIGS = [ + 
IntegrationTestRepoConfig( + online_store="pgvector", online_store_creator=PGVectorOnlineStoreCreator + ), +] diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres.py b/sdk/python/feast/infra/online_stores/contrib/postgres.py index 2dcb618783..2890f60746 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres.py @@ -2,7 +2,7 @@ import logging from collections import defaultdict from datetime import datetime -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple import psycopg2 import pytz @@ -12,7 +12,7 @@ from feast import Entity from feast.feature_view import FeatureView -from feast.infra.key_encoding_utils import get_val_str, serialize_entity_key +from feast.infra.key_encoding_utils import get_list_val_str, serialize_entity_key from feast.infra.online_stores.online_store import OnlineStore from feast.infra.utils.postgres.connection_utils import _get_conn, _get_connection_pool from feast.infra.utils.postgres.postgres_config import ConnectionType, PostgreSQLConfig @@ -74,19 +74,18 @@ def online_write_batch( created_ts = _to_naive_utc(created_ts) for feature_name, val in values.items(): - val_str: Union[str, bytes] + vector_val = None if ( - "pgvector_enabled" in config.online_config - and config.online_config["pgvector_enabled"] + "pgvector_enabled" in config.online_store + and config.online_store.pgvector_enabled ): - val_str = get_val_str(val) - else: - val_str = val.SerializeToString() + vector_val = get_list_val_str(val) insert_values.append( ( entity_key_bin, feature_name, - val_str, + val.SerializeToString(), + vector_val, timestamp, created_ts, ) @@ -100,11 +99,12 @@ def online_write_batch( sql.SQL( """ INSERT INTO {} - (entity_key, feature_name, value, event_ts, created_ts) + (entity_key, feature_name, value, vector_value, event_ts, created_ts) VALUES %s ON 
CONFLICT (entity_key, feature_name) DO UPDATE SET value = EXCLUDED.value, + vector_value = EXCLUDED.vector_value, event_ts = EXCLUDED.event_ts, created_ts = EXCLUDED.created_ts; """, @@ -226,12 +226,14 @@ def update( for table in tables_to_keep: table_name = _table_id(project, table) - value_type = "BYTEA" if ( - "pgvector_enabled" in config.online_config - and config.online_config["pgvector_enabled"] + "pgvector_enabled" in config.online_store + and config.online_store.pgvector_enabled ): - value_type = f'vector({config.online_config["vector_len"]})' + vector_value_type = f"vector({config.online_store.vector_len})" + else: + # keep the vector_value_type as BYTEA if pgvector is not enabled, to maintain compatibility + vector_value_type = "BYTEA" cur.execute( sql.SQL( """ @@ -239,7 +241,8 @@ def update( ( entity_key BYTEA, feature_name TEXT, - value {}, + value BYTEA, + vector_value {} NULL, event_ts TIMESTAMPTZ, created_ts TIMESTAMPTZ, PRIMARY KEY(entity_key, feature_name) @@ -248,7 +251,7 @@ def update( """ ).format( sql.Identifier(table_name), - sql.SQL(value_type), + sql.SQL(vector_value_type), sql.Identifier(f"{table_name}_ek"), sql.Identifier(table_name), ) @@ -294,6 +297,14 @@ def retrieve_online_documents( """ project = config.project + if ( + "pgvector_enabled" not in config.online_store + or not config.online_store.pgvector_enabled + ): + raise ValueError( + "pgvector is not enabled in the online store configuration" + ) + # Convert the embedding to a string to be used in postgres vector search query_embedding_str = f"[{','.join(str(el) for el in embedding)}]" @@ -311,8 +322,8 @@ def retrieve_online_documents( SELECT entity_key, feature_name, - value, - value <-> %s as distance, + vector_value, + vector_value <-> %s as distance, event_ts FROM {table_name} WHERE feature_name = {feature_name} ORDER BY distance @@ -327,13 +338,13 @@ def retrieve_online_documents( ) rows = cur.fetchall() - for entity_key, feature_name, value, distance, event_ts in rows: + for 
entity_key, feature_name, vector_value, distance, event_ts in rows: # TODO Deserialize entity_key to return the entity in response # entity_key_proto = EntityKeyProto() # entity_key_proto_bin = bytes(entity_key) # TODO Convert to List[float] for value type proto - feature_value_proto = ValueProto(string_val=value) + feature_value_proto = ValueProto(string_val=vector_value) distance_value_proto = ValueProto(float_val=distance) result.append((event_ts, feature_value_proto, distance_value_proto)) diff --git a/sdk/python/feast/infra/online_stores/contrib/postgres_repo_configuration.py b/sdk/python/feast/infra/online_stores/contrib/postgres_repo_configuration.py index 6e4ca3f950..ea975ec808 100644 --- a/sdk/python/feast/infra/online_stores/contrib/postgres_repo_configuration.py +++ b/sdk/python/feast/infra/online_stores/contrib/postgres_repo_configuration.py @@ -2,7 +2,6 @@ IntegrationTestRepoConfig, ) from tests.integration.feature_repos.universal.online_store.postgres import ( - PGVectorOnlineStoreCreator, PostgresOnlineStoreCreator, ) @@ -10,9 +9,4 @@ IntegrationTestRepoConfig( online_store="postgres", online_store_creator=PostgresOnlineStoreCreator ), - IntegrationTestRepoConfig( - online_store="pgvector", online_store_creator=PGVectorOnlineStoreCreator - ), ] - -AVAILABLE_ONLINE_STORES = {"pgvector": PGVectorOnlineStoreCreator} diff --git a/sdk/python/tests/integration/feature_repos/universal/online_store/init.sql b/sdk/python/tests/integration/feature_repos/universal/online_store/init.sql new file mode 100644 index 0000000000..64f04f61ad --- /dev/null +++ b/sdk/python/tests/integration/feature_repos/universal/online_store/init.sql @@ -0,0 +1 @@ +CREATE EXTENSION IF NOT EXISTS vector; \ No newline at end of file diff --git a/sdk/python/tests/integration/feature_repos/universal/online_store/postgres.py b/sdk/python/tests/integration/feature_repos/universal/online_store/postgres.py index 58e7af9c46..7b4156fffe 100644 --- 
a/sdk/python/tests/integration/feature_repos/universal/online_store/postgres.py +++ b/sdk/python/tests/integration/feature_repos/universal/online_store/postgres.py @@ -1,3 +1,4 @@ +import os from typing import Dict from testcontainers.core.container import DockerContainer @@ -37,12 +38,17 @@ def teardown(self): class PGVectorOnlineStoreCreator(OnlineStoreCreator): def __init__(self, project_name: str, **kwargs): super().__init__(project_name) + script_directory = os.path.dirname(os.path.abspath(__file__)) self.container = ( DockerContainer("pgvector/pgvector:pg16") .with_env("POSTGRES_USER", "root") .with_env("POSTGRES_PASSWORD", "test") .with_env("POSTGRES_DB", "test") .with_exposed_ports(5432) + .with_volume_mapping( + os.path.join(script_directory, "init.sql"), + "/docker-entrypoint-initdb.d/init.sql", + ) ) def create_online_store(self) -> Dict[str, str]: @@ -51,8 +57,10 @@ def create_online_store(self) -> Dict[str, str]: wait_for_logs( container=self.container, predicate=log_string_to_wait_for, timeout=10 ) - command = "psql -h localhost -p 5432 -U root -d test -c 'CREATE EXTENSION IF NOT EXISTS vector;'" - self.container.exec(command) + init_log_string_to_wait_for = "PostgreSQL init process complete" + wait_for_logs( + container=self.container, predicate=init_log_string_to_wait_for, timeout=10 + ) return { "host": "localhost", "type": "postgres",