diff --git a/sdk/python/feast/data_source.py b/sdk/python/feast/data_source.py index dcad135d8d..4b782a25d9 100644 --- a/sdk/python/feast/data_source.py +++ b/sdk/python/feast/data_source.py @@ -376,6 +376,10 @@ def get_table_column_names_and_types( """ raise NotImplementedError + def get_table_query_string(self) -> str: + """Returns a string that can directly be used to reference this table in SQL""" + raise NotImplementedError + class KafkaSource(DataSource): def validate(self, config: RepoConfig): diff --git a/sdk/python/feast/driver_test_data.py b/sdk/python/feast/driver_test_data.py index ea0921bf04..36603118b3 100644 --- a/sdk/python/feast/driver_test_data.py +++ b/sdk/python/feast/driver_test_data.py @@ -5,7 +5,9 @@ import pandas as pd from pytz import FixedOffset, timezone, utc -from feast.infra.provider import DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL +from feast.infra.offline_stores.offline_utils import ( + DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL, +) class EventTimestampType(Enum): diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py index bdcba93792..1202d4df49 100644 --- a/sdk/python/feast/errors.py +++ b/sdk/python/feast/errors.py @@ -185,3 +185,18 @@ def __init__(self): class RedshiftQueryError(Exception): def __init__(self, details): super().__init__(f"Redshift SQL Query failed to finish. Details: {details}") + + +class EntityTimestampInferenceException(Exception): + def __init__(self, expected_column_name: str): + super().__init__( + f"Please provide an entity_df with a column named {expected_column_name} representing the time of events." + ) + + +class InvalidEntityType(Exception): + def __init__(self, entity_type: type): + super().__init__( + f"The entity dataframe you have provided must be a Pandas DataFrame or a SQL query, " + f"but we found: {entity_type} " + ) diff --git a/sdk/python/feast/infra/aws.py b/sdk/python/feast/infra/aws.py index f182bbbcee..5318c9c81d 100644 --- a/sdk/python/feast/infra/aws.py +++ b/sdk/python/feast/infra/aws.py @@ -7,7 +7,7 @@ from feast import FeatureTable from feast.entity import Entity from feast.feature_view import FeatureView -from feast.infra.offline_stores.helpers import get_offline_store_from_config +from feast.infra.offline_stores.offline_utils import get_offline_store_from_config from feast.infra.online_stores.helpers import get_online_store_from_config from feast.infra.provider import ( Provider, diff --git a/sdk/python/feast/infra/gcp.py b/sdk/python/feast/infra/gcp.py index 2662a6e54f..2c679216ca 100644 --- a/sdk/python/feast/infra/gcp.py +++ b/sdk/python/feast/infra/gcp.py @@ -7,7 +7,7 @@ from feast import FeatureTable from feast.entity import Entity from feast.feature_view import FeatureView -from feast.infra.offline_stores.helpers import get_offline_store_from_config +from feast.infra.offline_stores.offline_utils import get_offline_store_from_config from feast.infra.online_stores.helpers import get_online_store_from_config from feast.infra.provider import ( Provider, diff --git a/sdk/python/feast/infra/local.py b/sdk/python/feast/infra/local.py index f677c84672..32a526dcb0 100644 --- a/sdk/python/feast/infra/local.py +++ b/sdk/python/feast/infra/local.py @@ -8,7 +8,7 @@ from feast import FeatureTable from feast.entity import Entity from feast.feature_view import FeatureView -from feast.infra.offline_stores.helpers import get_offline_store_from_config +from feast.infra.offline_stores.offline_utils import get_offline_store_from_config from feast.infra.online_stores.helpers import get_online_store_from_config from 
feast.infra.provider import ( Provider, diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py index 16dc8e950c..5fa0114133 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery.py +++ b/sdk/python/feast/infra/offline_stores/bigquery.py @@ -1,30 +1,24 @@ -import time import uuid -from dataclasses import asdict, dataclass from datetime import date, datetime, timedelta -from typing import List, Optional, Set, Union +from typing import Dict, List, Optional, Union +import numpy as np import pandas import pyarrow -from jinja2 import BaseLoader, Environment -from pandas import Timestamp from pydantic import StrictStr from pydantic.typing import Literal from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed -from feast import errors from feast.data_source import DataSource from feast.errors import ( BigQueryJobCancelled, BigQueryJobStillRunning, FeastProviderLoginError, + InvalidEntityType, ) from feast.feature_view import FeatureView +from feast.infra.offline_stores import offline_utils from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob -from feast.infra.provider import ( - DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL, - _get_requested_feature_views_to_features_dict, -) from feast.registry import Registry from feast.repo_config import FeastConfigBaseModel, RepoConfig @@ -34,7 +28,7 @@ from google.api_core.exceptions import NotFound from google.auth.exceptions import DefaultCredentialsError from google.cloud import bigquery - from google.cloud.bigquery import Client, Table + from google.cloud.bigquery import Client except ImportError as e: from feast.errors import FeastExtrasDependencyImportError @@ -108,134 +102,44 @@ def get_historical_features( assert isinstance(config.offline_store, BigQueryOfflineStoreConfig) client = _get_bigquery_client(project=config.offline_store.project_id) - expected_join_keys = _get_join_keys(project, feature_views, registry) assert isinstance(config.offline_store, BigQueryOfflineStoreConfig) - table = _upload_entity_df_into_bigquery( - client=client, - project=config.project, - dataset_name=config.offline_store.dataset, - dataset_project=client.project, - entity_df=entity_df, + table_reference = _get_table_reference_for_new_entity( + client, client.project, config.offline_store.dataset ) - entity_df_event_timestamp_col = _infer_event_timestamp_from_bigquery_query( - table.schema + entity_schema = _upload_entity_df_and_get_entity_schema( + client=client, table_name=table_reference, entity_df=entity_df, ) - _assert_expected_columns_in_bigquery( - expected_join_keys, entity_df_event_timestamp_col, table.schema, + + entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( + entity_schema ) - # Build a query context containing all information required to template the BigQuery SQL query - query_context = get_feature_view_query_context( - feature_refs, - feature_views, - registry, - project, - full_feature_names=full_feature_names, + expected_join_keys = offline_utils.get_expected_join_keys( + project, feature_views, registry ) - # Infer min and max timestamps from entity_df to limit data read in BigQuery SQL query - min_timestamp, max_timestamp = _get_entity_df_timestamp_bounds( - client, str(table.reference), entity_df_event_timestamp_col + offline_utils.assert_expected_columns_in_entity_df( + entity_schema, expected_join_keys, entity_df_event_timestamp_col + ) + + # Build a query context containing all information required to 
template the BigQuery SQL query + query_context = offline_utils.get_feature_view_query_context( + feature_refs, feature_views, registry, project, ) # Generate the BigQuery SQL query from the query context - query = build_point_in_time_query( + query = offline_utils.build_point_in_time_query( query_context, - min_timestamp=min_timestamp, - max_timestamp=max_timestamp, - left_table_query_string=str(table.reference), + left_table_query_string=table_reference, entity_df_event_timestamp_col=entity_df_event_timestamp_col, + query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN, full_feature_names=full_feature_names, ) - job = BigQueryRetrievalJob(query=query, client=client, config=config) - return job - - -def _assert_expected_columns_in_dataframe( - join_keys: Set[str], entity_df_event_timestamp_col: str, entity_df: pandas.DataFrame -): - entity_df_columns = set(entity_df.columns.values) - expected_columns = join_keys.copy() - expected_columns.add(entity_df_event_timestamp_col) - - missing_keys = expected_columns - entity_df_columns - - if len(missing_keys) != 0: - raise errors.FeastEntityDFMissingColumnsError(expected_columns, missing_keys) - - -def _assert_expected_columns_in_bigquery( - join_keys: Set[str], entity_df_event_timestamp_col: str, table_schema -): - entity_columns = set() - for schema_field in table_schema: - entity_columns.add(schema_field.name) - - expected_columns = join_keys.copy() - expected_columns.add(entity_df_event_timestamp_col) - - missing_keys = expected_columns - entity_columns - - if len(missing_keys) != 0: - raise errors.FeastEntityDFMissingColumnsError(expected_columns, missing_keys) - - -def _get_join_keys( - project: str, feature_views: List[FeatureView], registry: Registry -) -> Set[str]: - join_keys = set() - for feature_view in feature_views: - entities = feature_view.entities - for entity_name in entities: - entity = registry.get_entity(entity_name, project) - join_keys.add(entity.join_key) - return join_keys - - -def _infer_event_timestamp_from_bigquery_query(table_schema) -> str: - if any( - schema_field.name == DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL - for schema_field in table_schema - ): - return DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL - else: - datetime_columns = list( - filter( - lambda schema_field: schema_field.field_type == "TIMESTAMP", - table_schema, - ) - ) - if len(datetime_columns) == 1: - print( - f"Using {datetime_columns[0].name} as the event timestamp. To specify a column explicitly, please name it {DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL}." - ) - return datetime_columns[0].name - else: - raise ValueError( - f"Please provide an entity_df with a column named {DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL} representing the time of events." - ) - - -def _infer_event_timestamp_from_dataframe(entity_df: pandas.DataFrame) -> str: - if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in entity_df.columns: - return DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL - else: - datetime_columns = entity_df.select_dtypes( - include=["datetime", "datetimetz"] - ).columns - if len(datetime_columns) == 1: - print( - f"Using {datetime_columns[0]} as the event timestamp. To specify a column explicitly, please name it {DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL}." - ) - return datetime_columns[0] - else: - raise ValueError( - f"Please provide an entity_df with a column named {DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL} representing the time of events." 
- ) + return BigQueryRetrievalJob(query=query, client=client, config=config) class BigQueryRetrievalJob(RetrievalJob): @@ -340,24 +244,8 @@ def _wait_until_done(job_id): raise bq_job.exception() -@dataclass(frozen=True) -class FeatureViewQueryContext: - """Context object used to template a BigQuery point-in-time SQL query""" - - name: str - ttl: int - entities: List[str] - features: List[str] # feature reference format - table_ref: str - event_timestamp_column: str - created_timestamp_column: Optional[str] - query: str - table_subquery: str - entity_selections: List[str] - - -def _get_table_id_for_new_entity( - client: Client, project: str, dataset_name: str, dataset_project: str +def _get_table_reference_for_new_entity( + client: Client, dataset_project: str, dataset_name: str ) -> str: """Gets the table_id for the new entity to be uploaded.""" @@ -371,25 +259,24 @@ def _get_table_id_for_new_entity( # Only create the dataset if it does not exist client.create_dataset(dataset, exists_ok=True) - return f"{dataset_project}.{dataset_name}.entity_df_{project}_{int(time.time())}" + table_name = offline_utils.get_temp_entity_table_name() + return f"{dataset_project}.{dataset_name}.{table_name}" -def _upload_entity_df_into_bigquery( - client: Client, - project: str, - dataset_name: str, - dataset_project: str, - entity_df: Union[pandas.DataFrame, str], -) -> Table: - """Uploads a Pandas entity dataframe into a BigQuery table and returns the resulting table""" - table_id = _get_table_id_for_new_entity( - client, project, dataset_name, dataset_project - ) +def _upload_entity_df_and_get_entity_schema( + client: Client, table_name: str, entity_df: Union[pandas.DataFrame, str], +) -> Dict[str, np.dtype]: + """Uploads a Pandas entity dataframe into a BigQuery table and returns the resulting table""" if type(entity_df) is str: - job = client.query(f"CREATE TABLE {table_id} AS ({entity_df})") + job = client.query(f"CREATE TABLE {table_name} AS ({entity_df})") block_until_done(client, job) + + limited_entity_df = ( + client.query(f"SELECT * FROM {table_name} LIMIT 1").result().to_dataframe() + ) + entity_schema = dict(zip(limited_entity_df.columns, limited_entity_df.dtypes)) elif isinstance(entity_df, pandas.DataFrame): # Drop the index so that we dont have unnecessary columns entity_df.reset_index(drop=True, inplace=True) @@ -397,132 +284,20 @@ def _upload_entity_df_into_bigquery( # Upload the dataframe into BigQuery, creating a temporary table job_config = bigquery.LoadJobConfig() job = client.load_table_from_dataframe( - entity_df, table_id, job_config=job_config + entity_df, table_name, job_config=job_config ) block_until_done(client, job) + + entity_schema = dict(zip(entity_df.columns, entity_df.dtypes)) else: - raise ValueError( - f"The entity dataframe you have provided must be a Pandas DataFrame or BigQuery SQL query, " - f"but we found: {type(entity_df)} " - ) + raise InvalidEntityType(type(entity_df)) # Ensure that the table expires after some time - table = client.get_table(table=table_id) + table = client.get_table(table=table_name) table.expires = datetime.utcnow() + timedelta(minutes=30) client.update_table(table, ["expires"]) - return table - - -def _get_entity_df_timestamp_bounds( - client: Client, entity_df_bq_table: str, event_timestamp_col: str, -): - - boundary_df = ( - client.query( - f""" - SELECT - MIN({event_timestamp_col}) AS min_timestamp, - MAX({event_timestamp_col}) AS max_timestamp - FROM {entity_df_bq_table} - """ - ) - .result() - .to_dataframe() - ) - - min_timestamp = 
boundary_df.loc[0, "min_timestamp"] - max_timestamp = boundary_df.loc[0, "max_timestamp"] - return min_timestamp, max_timestamp - - -def get_feature_view_query_context( - feature_refs: List[str], - feature_views: List[FeatureView], - registry: Registry, - project: str, - full_feature_names: bool = False, -) -> List[FeatureViewQueryContext]: - """Build a query context containing all information required to template a BigQuery point-in-time SQL query""" - - feature_views_to_feature_map = _get_requested_feature_views_to_features_dict( - feature_refs, feature_views - ) - - query_context = [] - for feature_view, features in feature_views_to_feature_map.items(): - join_keys = [] - entity_selections = [] - reverse_field_mapping = { - v: k for k, v in feature_view.input.field_mapping.items() - } - for entity_name in feature_view.entities: - entity = registry.get_entity(entity_name, project) - join_keys.append(entity.join_key) - join_key_column = reverse_field_mapping.get( - entity.join_key, entity.join_key - ) - entity_selections.append(f"{join_key_column} AS {entity.join_key}") - - if isinstance(feature_view.ttl, timedelta): - ttl_seconds = int(feature_view.ttl.total_seconds()) - else: - ttl_seconds = 0 - - assert isinstance(feature_view.input, BigQuerySource) - - event_timestamp_column = feature_view.input.event_timestamp_column - created_timestamp_column = feature_view.input.created_timestamp_column - - context = FeatureViewQueryContext( - name=feature_view.name, - ttl=ttl_seconds, - entities=join_keys, - features=features, - table_ref=feature_view.input.table_ref, - event_timestamp_column=reverse_field_mapping.get( - event_timestamp_column, event_timestamp_column - ), - created_timestamp_column=reverse_field_mapping.get( - created_timestamp_column, created_timestamp_column - ), - # TODO: Make created column optional and not hardcoded - query=feature_view.input.query, - table_subquery=feature_view.input.get_table_query_string(), - entity_selections=entity_selections, - ) - query_context.append(context) - return query_context - - -def build_point_in_time_query( - feature_view_query_contexts: List[FeatureViewQueryContext], - min_timestamp: Timestamp, - max_timestamp: Timestamp, - left_table_query_string: str, - entity_df_event_timestamp_col: str, - full_feature_names: bool = False, -): - """Build point-in-time query between each feature view table and the entity dataframe""" - template = Environment(loader=BaseLoader()).from_string( - source=SINGLE_FEATURE_VIEW_POINT_IN_TIME_JOIN - ) - - # Add additional fields to dict - template_context = { - "min_timestamp": min_timestamp, - "max_timestamp": max_timestamp, - "left_table_query_string": left_table_query_string, - "entity_df_event_timestamp_col": entity_df_event_timestamp_col, - "unique_entity_keys": set( - [entity for fv in feature_view_query_contexts for entity in fv.entities] - ), - "featureviews": [asdict(context) for context in feature_view_query_contexts], - "full_feature_names": full_feature_names, - } - - query = template.render(template_context) - return query + return entity_schema def _get_bigquery_client(project: Optional[str] = None): @@ -550,21 +325,23 @@ def _get_bigquery_client(project: Optional[str] = None): # * Precompute ROW_NUMBER() so that it doesn't have to be recomputed for every query on entity_dataframe # * Create temporary tables instead of keeping all tables in memory -SINGLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """ +# Note: Keep this in sync with 
sdk/python/feast/infra/offline_stores/redshift.py:MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN + +MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """ /* Compute a deterministic hash for the `left_table_query_string` that will be used throughout all the logic as the field to GROUP BY the data */ WITH entity_dataframe AS ( SELECT *, - {{entity_df_event_timestamp_col}} AS entity_timestamp, + {{entity_df_event_timestamp_col}} AS entity_timestamp {% for featureview in featureviews %} - CONCAT( + ,CONCAT( {% for entity in featureview.entities %} CAST({{entity}} AS STRING), {% endfor %} CAST({{entity_df_event_timestamp_col}} AS STRING) - ) AS {{featureview.name}}__entity_row_unique_id, + ) AS {{featureview.name}}__entity_row_unique_id {% endfor %} FROM {{ left_table_query_string }} ), @@ -606,9 +383,9 @@ def _get_bigquery_client(project: Optional[str] = None): {{ feature }} as {% if full_feature_names %}{{ featureview.name }}__{{feature}}{% else %}{{ feature }}{% endif %}{% if loop.last %}{% else %}, {% endif %} {% endfor %} FROM {{ featureview.table_subquery }} - WHERE {{ featureview.event_timestamp_column }} <= '{{max_timestamp}}' + WHERE {{ featureview.event_timestamp_column }} <= (SELECT MAX(entity_timestamp) FROM entity_dataframe) {% if featureview.ttl == 0 %}{% else %} - AND {{ featureview.event_timestamp_column }} >= Timestamp_sub('{{min_timestamp}}', interval {{ featureview.ttl }} second) + AND {{ featureview.event_timestamp_column }} >= Timestamp_sub((SELECT MIN(entity_timestamp) FROM entity_dataframe), interval {{ featureview.ttl }} second) {% endif %} ), @@ -642,7 +419,7 @@ def _get_bigquery_client(project: Optional[str] = None): SELECT {{featureview.name}}__entity_row_unique_id, event_timestamp, - MAX(created_timestamp) as created_timestamp, + MAX(created_timestamp) as created_timestamp FROM {{ featureview.name }}__base GROUP BY {{featureview.name}}__entity_row_unique_id, event_timestamp ), @@ -698,9 +475,9 @@ def _get_bigquery_client(project: Optional[str] = None): {% for featureview in featureviews %} LEFT JOIN ( SELECT - {{featureview.name}}__entity_row_unique_id, + {{featureview.name}}__entity_row_unique_id {% for feature in featureview.features %} - {% if full_feature_names %}{{ featureview.name }}__{{feature}}{% else %}{{ feature }}{% endif %}, + ,{% if full_feature_names %}{{ featureview.name }}__{{feature}}{% else %}{{ feature }}{% endif %} {% endfor %} FROM {{ featureview.name }}__cleaned ) USING ({{featureview.name}}__entity_row_unique_id) diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py index 74153acaee..590ba7f3b7 100644 --- a/sdk/python/feast/infra/offline_stores/file.py +++ b/sdk/python/feast/infra/offline_stores/file.py @@ -11,8 +11,10 @@ from feast.errors import FeastJoinKeysDuringMaterialization from feast.feature_view import FeatureView from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob -from feast.infra.provider import ( +from feast.infra.offline_stores.offline_utils import ( DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL, +) +from feast.infra.provider import ( _get_requested_feature_views_to_features_dict, _run_field_mapping, ) diff --git a/sdk/python/feast/infra/offline_stores/helpers.py b/sdk/python/feast/infra/offline_stores/helpers.py deleted file mode 100644 index dff604c7ed..0000000000 --- a/sdk/python/feast/infra/offline_stores/helpers.py +++ /dev/null @@ -1,31 +0,0 @@ -import importlib -from typing import Any - -from feast import errors -from feast.infra.offline_stores.offline_store import 
OfflineStore - - -def get_offline_store_from_config(offline_store_config: Any,) -> OfflineStore: - """Get the offline store from offline store config""" - - module_name = offline_store_config.__module__ - qualified_name = type(offline_store_config).__name__ - store_class_name = qualified_name.replace("Config", "") - try: - module = importlib.import_module(module_name) - except Exception as e: - # The original exception can be anything - either module not found, - # or any other kind of error happening during the module import time. - # So we should include the original error as well in the stack trace. - raise errors.FeastModuleImportError(module_name, "OfflineStore") from e - - # Try getting the provider class definition - try: - offline_store_class = getattr(module, store_class_name) - except AttributeError: - # This can only be one type of error, when class_name attribute does not exist in the module - # So we don't have to include the original exception here - raise errors.FeastClassImportError( - module_name, store_class_name, class_type="OfflineStore" - ) from None - return offline_store_class() diff --git a/sdk/python/feast/infra/offline_stores/offline_utils.py b/sdk/python/feast/infra/offline_stores/offline_utils.py new file mode 100644 index 0000000000..304bdc8e91 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/offline_utils.py @@ -0,0 +1,198 @@ +import importlib +import uuid +from dataclasses import asdict, dataclass +from datetime import timedelta +from typing import Any, Dict, List, Optional, Set, Tuple + +import numpy as np +import pandas as pd +from jinja2 import BaseLoader, Environment +from pandas import Timestamp + +import feast +from feast.errors import ( + EntityTimestampInferenceException, + FeastClassImportError, + FeastEntityDFMissingColumnsError, + FeastModuleImportError, +) +from feast.infra.offline_stores.offline_store import OfflineStore +from feast.infra.provider import _get_requested_feature_views_to_features_dict +from feast.registry import Registry + +DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL = "event_timestamp" + + +def infer_event_timestamp_from_entity_df(entity_schema: Dict[str, np.dtype]) -> str: + if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in entity_schema.keys(): + return DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL + + datetime_columns = [ + column + for column, dtype in entity_schema.items() + if pd.core.dtypes.common.is_datetime64_any_dtype(dtype) + ] + + if len(datetime_columns) == 1: + print( + f"Using {datetime_columns[0]} as the event timestamp. To specify a column explicitly, please name it {DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL}." 
+ ) + return datetime_columns[0] + else: + raise EntityTimestampInferenceException(DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL) + + +def assert_expected_columns_in_entity_df( + entity_schema: Dict[str, np.dtype], + join_keys: Set[str], + entity_df_event_timestamp_col: str, +): + entity_columns = set(entity_schema.keys()) + expected_columns = join_keys | {entity_df_event_timestamp_col} + missing_keys = expected_columns - entity_columns + + if len(missing_keys) != 0: + raise FeastEntityDFMissingColumnsError(expected_columns, missing_keys) + + +def get_expected_join_keys( + project: str, feature_views: List["feast.FeatureView"], registry: Registry +) -> Set[str]: + join_keys = set() + for feature_view in feature_views: + entities = feature_view.entities + for entity_name in entities: + entity = registry.get_entity(entity_name, project) + join_keys.add(entity.join_key) + return join_keys + + +def get_entity_df_timestamp_bounds( + entity_df: pd.DataFrame, event_timestamp_col: str +) -> Tuple[Timestamp, Timestamp]: + event_timestamp_series = entity_df[event_timestamp_col] + return event_timestamp_series.min(), event_timestamp_series.max() + + +@dataclass(frozen=True) +class FeatureViewQueryContext: + """Context object used to template a BigQuery and Redshift point-in-time SQL query""" + + name: str + ttl: int + entities: List[str] + features: List[str] # feature reference format + event_timestamp_column: str + created_timestamp_column: Optional[str] + table_subquery: str + entity_selections: List[str] + + +def get_feature_view_query_context( + feature_refs: List[str], + feature_views: List["feast.FeatureView"], + registry: Registry, + project: str, +) -> List[FeatureViewQueryContext]: + """Build a query context containing all information required to template a BigQuery and Redshift point-in-time SQL query""" + + feature_views_to_feature_map = _get_requested_feature_views_to_features_dict( + feature_refs, feature_views + ) + + query_context = [] + for feature_view, features in feature_views_to_feature_map.items(): + join_keys = [] + entity_selections = [] + reverse_field_mapping = { + v: k for k, v in feature_view.input.field_mapping.items() + } + for entity_name in feature_view.entities: + entity = registry.get_entity(entity_name, project) + join_keys.append(entity.join_key) + join_key_column = reverse_field_mapping.get( + entity.join_key, entity.join_key + ) + entity_selections.append(f"{join_key_column} AS {entity.join_key}") + + if isinstance(feature_view.ttl, timedelta): + ttl_seconds = int(feature_view.ttl.total_seconds()) + else: + ttl_seconds = 0 + + event_timestamp_column = feature_view.input.event_timestamp_column + created_timestamp_column = feature_view.input.created_timestamp_column + + context = FeatureViewQueryContext( + name=feature_view.name, + ttl=ttl_seconds, + entities=join_keys, + features=features, + event_timestamp_column=reverse_field_mapping.get( + event_timestamp_column, event_timestamp_column + ), + created_timestamp_column=reverse_field_mapping.get( + created_timestamp_column, created_timestamp_column + ), + # TODO: Make created column optional and not hardcoded + table_subquery=feature_view.input.get_table_query_string(), + entity_selections=entity_selections, + ) + query_context.append(context) + return query_context + + +def build_point_in_time_query( + feature_view_query_contexts: List[FeatureViewQueryContext], + left_table_query_string: str, + entity_df_event_timestamp_col: str, + query_template: str, + full_feature_names: bool = False, +): + """Build point-in-time 
query between each feature view table and the entity dataframe for Bigquery and Redshift""" + template = Environment(loader=BaseLoader()).from_string(source=query_template) + + # Add additional fields to dict + template_context = { + "left_table_query_string": left_table_query_string, + "entity_df_event_timestamp_col": entity_df_event_timestamp_col, + "unique_entity_keys": set( + [entity for fv in feature_view_query_contexts for entity in fv.entities] + ), + "featureviews": [asdict(context) for context in feature_view_query_contexts], + "full_feature_names": full_feature_names, + } + + query = template.render(template_context) + return query + + +def get_temp_entity_table_name() -> str: + """Returns a random table name for uploading the entity dataframe""" + return "feast_entity_df_" + uuid.uuid4().hex + + +def get_offline_store_from_config(offline_store_config: Any,) -> OfflineStore: + """Get the offline store from offline store config""" + + module_name = offline_store_config.__module__ + qualified_name = type(offline_store_config).__name__ + store_class_name = qualified_name.replace("Config", "") + try: + module = importlib.import_module(module_name) + except Exception as e: + # The original exception can be anything - either module not found, + # or any other kind of error happening during the module import time. + # So we should include the original error as well in the stack trace. + raise FeastModuleImportError(module_name, "OfflineStore") from e + + # Try getting the provider class definition + try: + offline_store_class = getattr(module, store_class_name) + except AttributeError: + # This can only be one type of error, when class_name attribute does not exist in the module + # So we don't have to include the original exception here + raise FeastClassImportError( + module_name, store_class_name, class_type="OfflineStore" + ) from None + return offline_store_class() diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py index 9204ff00be..00fa6727e7 100644 --- a/sdk/python/feast/infra/offline_stores/redshift.py +++ b/sdk/python/feast/infra/offline_stores/redshift.py @@ -1,7 +1,8 @@ import uuid from datetime import datetime -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union +import numpy as np import pandas as pd import pyarrow as pa from pydantic import StrictStr @@ -9,7 +10,9 @@ from feast import RedshiftSource from feast.data_source import DataSource +from feast.errors import InvalidEntityType from feast.feature_view import FeatureView +from feast.infra.offline_stores import offline_utils from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob from feast.infra.utils import aws_utils from feast.registry import Registry @@ -103,11 +106,67 @@ def get_historical_features( project: str, full_feature_names: bool = False, ) -> RetrievalJob: - pass + assert isinstance(config.offline_store, RedshiftOfflineStoreConfig) + + redshift_client = aws_utils.get_redshift_data_client( + config.offline_store.region + ) + s3_resource = aws_utils.get_s3_resource(config.offline_store.region) + + table_name = offline_utils.get_temp_entity_table_name() + + entity_schema = _upload_entity_df_and_get_entity_schema( + entity_df, redshift_client, config, s3_resource, table_name + ) + + entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df( + entity_schema + ) + + expected_join_keys = offline_utils.get_expected_join_keys( + project, feature_views, registry + ) + 
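To make the shared helpers concrete, here is a minimal sketch (not part of this diff) of how `offline_utils.infer_event_timestamp_from_entity_df` and `offline_utils.assert_expected_columns_in_entity_df` behave on a toy entity dataframe, assuming the `offline_utils` module layout introduced above; the column names (`driver_id`, `e_ts`) are illustrative placeholders only.

```python
# Minimal sketch, assuming the offline_utils module added in this diff is importable.
import pandas as pd

from feast.infra.offline_stores import offline_utils

# Toy entity dataframe: one join key plus a datetime column that is NOT named
# "event_timestamp", so the helper has to fall back to dtype-based inference.
entity_df = pd.DataFrame(
    {
        "driver_id": [1001, 1002],
        "e_ts": pd.to_datetime(["2021-04-12 10:59:42", "2021-04-12 08:12:10"]),
    }
)
entity_schema = dict(zip(entity_df.columns, entity_df.dtypes))

# Returns "e_ts" (the only datetime64 column) and prints a hint suggesting the
# DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL naming convention.
timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(entity_schema)
assert timestamp_col == "e_ts"

# Passes silently here; raises FeastEntityDFMissingColumnsError if any expected
# join key or the timestamp column is missing from the entity schema.
offline_utils.assert_expected_columns_in_entity_df(
    entity_schema, {"driver_id"}, timestamp_col
)
```

Both the BigQuery and Redshift offline stores now run exactly this validation before templating their point-in-time join queries.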
+ offline_utils.assert_expected_columns_in_entity_df( + entity_schema, expected_join_keys, entity_df_event_timestamp_col + ) + + # Build a query context containing all information required to template the Redshift SQL query + query_context = offline_utils.get_feature_view_query_context( + feature_refs, feature_views, registry, project, + ) + + # Generate the Redshift SQL query from the query context + query = offline_utils.build_point_in_time_query( + query_context, + left_table_query_string=table_name, + entity_df_event_timestamp_col=entity_df_event_timestamp_col, + query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN, + full_feature_names=full_feature_names, + ) + + return RedshiftRetrievalJob( + query=query, + redshift_client=redshift_client, + s3_resource=s3_resource, + config=config, + drop_columns=["entity_timestamp"] + + [ + f"{feature_view.name}__entity_row_unique_id" + for feature_view in feature_views + ], + ) class RedshiftRetrievalJob(RetrievalJob): - def __init__(self, query: str, redshift_client, s3_resource, config: RepoConfig): + def __init__( + self, + query: str, + redshift_client, + s3_resource, + config: RepoConfig, + drop_columns: Optional[List[str]] = None, + ): """Initialize RedshiftRetrievalJob object. Args: @@ -115,6 +174,8 @@ def __init__(self, query: str, redshift_client, s3_resource, config: RepoConfig) redshift_client: boto3 redshift-data client s3_resource: boto3 s3 resource object config: Feast repo config + drop_columns: Optionally a list of columns to drop before unloading to S3. + This is a convenient field, since "SELECT ... EXCEPT col" isn't supported in Redshift. """ self.query = query self._redshift_client = redshift_client @@ -125,6 +186,7 @@ def __init__(self, query: str, redshift_client, s3_resource, config: RepoConfig) + "/unload/" + str(uuid.uuid4()) ) + self._drop_columns = drop_columns def to_df(self) -> pd.DataFrame: return aws_utils.unload_redshift_query_to_df( @@ -136,6 +198,7 @@ def to_df(self) -> pd.DataFrame: self._s3_path, self._config.offline_store.iam_role, self.query, + self._drop_columns, ) def to_arrow(self) -> pa.Table: @@ -148,6 +211,7 @@ def to_arrow(self) -> pa.Table: self._s3_path, self._config.offline_store.iam_role, self.query, + self._drop_columns, ) def to_s3(self) -> str: @@ -160,15 +224,227 @@ def to_s3(self) -> str: self._s3_path, self._config.offline_store.iam_role, self.query, + self._drop_columns, ) return self._s3_path def to_redshift(self, table_name: str) -> None: """ Save dataset as a new Redshift table """ + query = f'CREATE TABLE "{table_name}" AS ({self.query});\n' + if self._drop_columns is not None: + for column in self._drop_columns: + query += f"ALTER TABLE {table_name} DROP COLUMN {column};\n" + aws_utils.execute_redshift_statement( self._redshift_client, self._config.offline_store.cluster_id, self._config.offline_store.database, self._config.offline_store.user, - f'CREATE TABLE "{table_name}" AS ({self.query})', + query, + ) + + +def _upload_entity_df_and_get_entity_schema( + entity_df: Union[pd.DataFrame, str], + redshift_client, + config: RepoConfig, + s3_resource, + table_name: str, +) -> Dict[str, np.dtype]: + if isinstance(entity_df, pd.DataFrame): + # If the entity_df is a pandas dataframe, upload it to Redshift + # and construct the schema from the original entity_df dataframe + aws_utils.upload_df_to_redshift( + redshift_client, + config.offline_store.cluster_id, + config.offline_store.database, + config.offline_store.user, + s3_resource, + 
f"{config.offline_store.s3_staging_location}/entity_df/{table_name}.parquet", + config.offline_store.iam_role, + table_name, + entity_df, ) + return dict(zip(entity_df.columns, entity_df.dtypes)) + elif isinstance(entity_df, str): + # If the entity_df is a string (SQL query), create a Redshift table out of it, + # get pandas dataframe consisting of 1 row (LIMIT 1) and generate the schema out of it + aws_utils.execute_redshift_statement( + redshift_client, + config.offline_store.cluster_id, + config.offline_store.database, + config.offline_store.user, + f"CREATE TABLE {table_name} AS ({entity_df})", + ) + limited_entity_df = RedshiftRetrievalJob( + f"SELECT * FROM {table_name} LIMIT 1", redshift_client, s3_resource, config + ).to_df() + return dict(zip(limited_entity_df.columns, limited_entity_df.dtypes)) + else: + raise InvalidEntityType(type(entity_df)) + + +# This query is based on sdk/python/feast/infra/offline_stores/bigquery.py:MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN +# There are couple of changes from BigQuery: +# 1. Use VARCHAR instead of STRING type +# 2. Use "t - x * interval '1' second" instead of "Timestamp_sub(...)" +# 3. Replace `SELECT * EXCEPT (...)` with `SELECT *`, because `EXCEPT` is not supported by Redshift. +# Instead, we drop the column later after creating the table out of the query. +# We need to keep this query in sync with BigQuery. + +MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """ +/* + Compute a deterministic hash for the `left_table_query_string` that will be used throughout + all the logic as the field to GROUP BY the data +*/ +WITH entity_dataframe AS ( + SELECT *, + {{entity_df_event_timestamp_col}} AS entity_timestamp + {% for featureview in featureviews %} + ,CONCAT( + {% for entity in featureview.entities %} + CAST({{entity}} AS VARCHAR), + {% endfor %} + CAST({{entity_df_event_timestamp_col}} AS VARCHAR) + ) AS {{featureview.name}}__entity_row_unique_id + {% endfor %} + FROM {{ left_table_query_string }} +), + +{% for featureview in featureviews %} + +{{ featureview.name }}__entity_dataframe AS ( + SELECT + {{ featureview.entities | join(', ')}}, + entity_timestamp, + {{featureview.name}}__entity_row_unique_id + FROM entity_dataframe + GROUP BY {{ featureview.entities | join(', ')}}, entity_timestamp, {{featureview.name}}__entity_row_unique_id +), + +/* + This query template performs the point-in-time correctness join for a single feature set table + to the provided entity table. + + 1. We first join the current feature_view to the entity dataframe that has been passed. + This JOIN has the following logic: + - For each row of the entity dataframe, only keep the rows where the `event_timestamp_column` + is less than the one provided in the entity dataframe + - If there a TTL for the current feature_view, also keep the rows where the `event_timestamp_column` + is higher the the one provided minus the TTL + - For each row, Join on the entity key and retrieve the `entity_row_unique_id` that has been + computed previously + + The output of this CTE will contain all the necessary information and already filtered out most + of the data that is not relevant. 
+*/ + +{{ featureview.name }}__subquery AS ( + SELECT + {{ featureview.event_timestamp_column }} as event_timestamp, + {{ featureview.created_timestamp_column ~ ' as created_timestamp,' if featureview.created_timestamp_column else '' }} + {{ featureview.entity_selections | join(', ')}}, + {% for feature in featureview.features %} + {{ feature }} as {% if full_feature_names %}{{ featureview.name }}__{{feature}}{% else %}{{ feature }}{% endif %}{% if loop.last %}{% else %}, {% endif %} + {% endfor %} + FROM {{ featureview.table_subquery }} + WHERE {{ featureview.event_timestamp_column }} <= (SELECT MAX(entity_timestamp) FROM entity_dataframe) + {% if featureview.ttl == 0 %}{% else %} + AND {{ featureview.event_timestamp_column }} >= (SELECT MIN(entity_timestamp) FROM entity_dataframe) - {{ featureview.ttl }} * interval '1' second + {% endif %} +), + +{{ featureview.name }}__base AS ( + SELECT + subquery.*, + entity_dataframe.entity_timestamp, + entity_dataframe.{{featureview.name}}__entity_row_unique_id + FROM {{ featureview.name }}__subquery AS subquery + INNER JOIN {{ featureview.name }}__entity_dataframe AS entity_dataframe + ON TRUE + AND subquery.event_timestamp <= entity_dataframe.entity_timestamp + + {% if featureview.ttl == 0 %}{% else %} + AND subquery.event_timestamp >= entity_dataframe.entity_timestamp - {{ featureview.ttl }} * interval '1' second + {% endif %} + + {% for entity in featureview.entities %} + AND subquery.{{ entity }} = entity_dataframe.{{ entity }} + {% endfor %} +), + +/* + 2. If the `created_timestamp_column` has been set, we need to + deduplicate the data first. This is done by calculating the + `MAX(created_at_timestamp)` for each event_timestamp. + We then join the data on the next CTE +*/ +{% if featureview.created_timestamp_column %} +{{ featureview.name }}__dedup AS ( + SELECT + {{featureview.name}}__entity_row_unique_id, + event_timestamp, + MAX(created_timestamp) as created_timestamp + FROM {{ featureview.name }}__base + GROUP BY {{featureview.name}}__entity_row_unique_id, event_timestamp +), +{% endif %} + +/* + 3. The data has been filtered during the first CTE "*__base" + Thus we only need to compute the latest timestamp of each feature. +*/ +{{ featureview.name }}__latest AS ( + SELECT + {{featureview.name}}__entity_row_unique_id, + MAX(event_timestamp) AS event_timestamp + {% if featureview.created_timestamp_column %} + ,ANY_VALUE(created_timestamp) AS created_timestamp + {% endif %} + + FROM {{ featureview.name }}__base + {% if featureview.created_timestamp_column %} + INNER JOIN {{ featureview.name }}__dedup + USING ({{featureview.name}}__entity_row_unique_id, event_timestamp, created_timestamp) + {% endif %} + + GROUP BY {{featureview.name}}__entity_row_unique_id +), + +/* + 4. Once we know the latest value of each feature for a given timestamp, + we can join again the data back to the original "base" dataset +*/ +{{ featureview.name }}__cleaned AS ( + SELECT base.* + FROM {{ featureview.name }}__base as base + INNER JOIN {{ featureview.name }}__latest + USING( + {{featureview.name}}__entity_row_unique_id, + event_timestamp + {% if featureview.created_timestamp_column %} + ,created_timestamp + {% endif %} + ) +){% if loop.last %}{% else %}, {% endif %} + + +{% endfor %} +/* + Joins the outputs of multiple time travel joins to a single table. + The entity_dataframe dataset being our source of truth here. 
+ */ + +SELECT * +FROM entity_dataframe +{% for featureview in featureviews %} +LEFT JOIN ( + SELECT + {{featureview.name}}__entity_row_unique_id + {% for feature in featureview.features %} + ,{% if full_feature_names %}{{ featureview.name }}__{{feature}}{% else %}{{ feature }}{% endif %} + {% endfor %} + FROM {{ featureview.name }}__cleaned +) USING ({{featureview.name}}__entity_row_unique_id) +{% endfor %} +""" diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py index f778032c17..40b0659d20 100644 --- a/sdk/python/feast/infra/provider.py +++ b/sdk/python/feast/infra/provider.py @@ -18,8 +18,6 @@ from feast.repo_config import RepoConfig from feast.type_map import python_value_to_proto_value -DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL = "event_timestamp" - class Provider(abc.ABC): @abc.abstractmethod diff --git a/sdk/python/feast/infra/utils/aws_utils.py b/sdk/python/feast/infra/utils/aws_utils.py index 235f427b76..7e6c8849dd 100644 --- a/sdk/python/feast/infra/utils/aws_utils.py +++ b/sdk/python/feast/infra/utils/aws_utils.py @@ -1,7 +1,8 @@ +import contextlib import os import tempfile import uuid -from typing import Tuple +from typing import Generator, List, Optional, Tuple import pandas as pd import pyarrow as pa @@ -174,6 +175,9 @@ def upload_df_to_redshift( """ bucket, key = get_bucket_and_key(s3_path) + # Drop the index so that we dont have unnecessary columns + df.reset_index(drop=True, inplace=True) + # Convert Pandas DataFrame into PyArrow table and compile the Redshift table schema table = pa.Table.from_pandas(df) column_names, column_types = [], [] @@ -207,6 +211,49 @@ def upload_df_to_redshift( s3_resource.Object(bucket, key).delete() +@contextlib.contextmanager +def temporarily_upload_df_to_redshift( + redshift_data_client, + cluster_id: str, + database: str, + user: str, + s3_resource, + s3_path: str, + iam_role: str, + table_name: str, + df: pd.DataFrame, +) -> Generator[None, None, None]: + """Uploads a Pandas DataFrame to Redshift as a new table with cleanup logic. + + This is essentially the same as upload_df_to_redshift (check out its docstring for full details), + but unlike it this method is a generator and should be used with `with` block. For example: + + >>> with temporarily_upload_df_to_redshift(...): + >>> # Use `table_name` table in Redshift here + >>> # `table_name` will not exist at this point, since it's cleaned up by the `with` block + + """ + # Upload the dataframe to Redshift + upload_df_to_redshift( + redshift_data_client, + cluster_id, + database, + user, + s3_resource, + s3_path, + iam_role, + table_name, + df, + ) + + yield + + # Clean up the uploaded Redshift table + execute_redshift_statement( + redshift_data_client, cluster_id, database, user, f"DROP TABLE {table_name}", + ) + + def download_s3_directory(s3_resource, bucket: str, key: str, local_dir: str): """ Download the S3 directory to a local disk """ bucket_obj = s3_resource.Bucket(bucket) @@ -236,17 +283,31 @@ def execute_redshift_query_and_unload_to_s3( s3_path: str, iam_role: str, query: str, + drop_columns: Optional[List[str]] = None, ) -> None: - """ Unload Redshift Query results to S3 """ + """Unload Redshift Query results to S3 + + Args: + redshift_data_client: Redshift Data API Service client + cluster_id: Redshift Cluster Identifier + database: Redshift Database Name + user: Redshift username + s3_path: S3 directory where the unloaded data is written + iam_role: IAM Role for Redshift to assume during the UNLOAD command. 
+ The role must grant permission to write to the S3 location. + query: The SQL query to execute + drop_columns: Optionally a list of columns to drop before unloading to S3. + This is a convenient field, since "SELECT ... EXCEPT col" isn't supported in Redshift. + + """ # Run the query, unload the results to S3 unique_table_name = "_" + str(uuid.uuid4()).replace("-", "") - unload_query = f""" - CREATE TEMPORARY TABLE {unique_table_name} AS ({query}); - UNLOAD ('SELECT * FROM {unique_table_name}') TO '{s3_path}/' IAM_ROLE '{iam_role}' PARQUET - """ - execute_redshift_statement( - redshift_data_client, cluster_id, database, user, unload_query - ) + query = f"CREATE TEMPORARY TABLE {unique_table_name} AS ({query});\n" + if drop_columns is not None: + for column in drop_columns: + query += f"ALTER TABLE {unique_table_name} DROP COLUMN {column};\n" + query += f"UNLOAD ('SELECT * FROM {unique_table_name}') TO '{s3_path}/' IAM_ROLE '{iam_role}' PARQUET" + execute_redshift_statement(redshift_data_client, cluster_id, database, user, query) def unload_redshift_query_to_pa( @@ -258,12 +319,20 @@ def unload_redshift_query_to_pa( s3_path: str, iam_role: str, query: str, + drop_columns: Optional[List[str]] = None, ) -> pa.Table: """ Unload Redshift Query results to S3 and get the results in PyArrow Table format """ bucket, key = get_bucket_and_key(s3_path) execute_redshift_query_and_unload_to_s3( - redshift_data_client, cluster_id, database, user, s3_path, iam_role, query + redshift_data_client, + cluster_id, + database, + user, + s3_path, + iam_role, + query, + drop_columns, ) with tempfile.TemporaryDirectory() as temp_dir: @@ -281,6 +350,7 @@ def unload_redshift_query_to_df( s3_path: str, iam_role: str, query: str, + drop_columns: Optional[List[str]] = None, ) -> pd.DataFrame: """ Unload Redshift Query results to S3 and get the results in Pandas DataFrame format """ table = unload_redshift_query_to_pa( @@ -292,5 +362,6 @@ def unload_redshift_query_to_df( s3_path, iam_role, query, + drop_columns, ) return table.to_pandas() diff --git a/sdk/python/tests/integration/materialization/test_offline_online_store_consistency.py b/sdk/python/tests/integration/materialization/test_offline_online_store_consistency.py index a8efdfcfb5..e0acadf7c9 100644 --- a/sdk/python/tests/integration/materialization/test_offline_online_store_consistency.py +++ b/sdk/python/tests/integration/materialization/test_offline_online_store_consistency.py @@ -323,7 +323,6 @@ def check_offline_and_online_features( event_timestamp: datetime, expected_value: Optional[float], full_feature_names: bool, - check_offline_store: bool = True, ) -> None: # Check online store response_dict = fs.get_online_features( @@ -344,32 +343,28 @@ def check_offline_and_online_features( assert response_dict["value"][0] is None # Check offline store - if check_offline_store: - df = fs.get_historical_features( - entity_df=pd.DataFrame.from_dict( - {"driver_id": [driver_id], "event_timestamp": [event_timestamp]} - ), - features=[f"{fv.name}:value"], - full_feature_names=full_feature_names, - ).to_df() - - if full_feature_names: - if expected_value: - assert abs(df.to_dict()[f"{fv.name}__value"][0] - expected_value) < 1e-6 - else: - assert math.isnan(df.to_dict()[f"{fv.name}__value"][0]) + df = fs.get_historical_features( + entity_df=pd.DataFrame.from_dict( + {"driver_id": [driver_id], "event_timestamp": [event_timestamp]} + ), + features=[f"{fv.name}:value"], + full_feature_names=full_feature_names, + ).to_df() + + if full_feature_names: + if expected_value: + 
assert abs(df.to_dict()[f"{fv.name}__value"][0] - expected_value) < 1e-6 + else: + assert math.isnan(df.to_dict()[f"{fv.name}__value"][0]) + else: + if expected_value: + assert abs(df.to_dict()["value"][0] - expected_value) < 1e-6 else: - if expected_value: - assert abs(df.to_dict()["value"][0] - expected_value) < 1e-6 - else: - assert math.isnan(df.to_dict()["value"][0]) + assert math.isnan(df.to_dict()["value"][0]) def run_offline_online_store_consistency_test( - fs: FeatureStore, - fv: FeatureView, - full_feature_names: bool, - check_offline_store: bool = True, + fs: FeatureStore, fv: FeatureView, full_feature_names: bool, ) -> None: now = datetime.utcnow() # Run materialize() @@ -386,7 +381,6 @@ def run_offline_online_store_consistency_test( event_timestamp=end_date, expected_value=0.3, full_feature_names=full_feature_names, - check_offline_store=check_offline_store, ) check_offline_and_online_features( @@ -396,7 +390,6 @@ def run_offline_online_store_consistency_test( event_timestamp=end_date, expected_value=None, full_feature_names=full_feature_names, - check_offline_store=check_offline_store, ) # check prior value for materialize_incremental() @@ -407,7 +400,6 @@ def run_offline_online_store_consistency_test( event_timestamp=end_date, expected_value=4, full_feature_names=full_feature_names, - check_offline_store=check_offline_store, ) # run materialize_incremental() @@ -421,7 +413,6 @@ def run_offline_online_store_consistency_test( event_timestamp=now, expected_value=5, full_feature_names=full_feature_names, - check_offline_store=check_offline_store, ) @@ -460,8 +451,7 @@ def test_redshift_offline_online_store_consistency( source_type: str, full_feature_names: bool ): with prep_redshift_fs_and_fv(source_type) as (fs, fv): - # TODO: remove check_offline_store parameter once Redshift's get_historical_features is implemented - run_offline_online_store_consistency_test(fs, fv, full_feature_names, False) + run_offline_online_store_consistency_test(fs, fv, full_feature_names) @pytest.mark.parametrize("full_feature_names", [True, False]) diff --git a/sdk/python/tests/integration/offline_store/test_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_historical_retrieval.py index af043359e4..3786080837 100644 --- a/sdk/python/tests/integration/offline_store/test_historical_retrieval.py +++ b/sdk/python/tests/integration/offline_store/test_historical_retrieval.py @@ -14,18 +14,28 @@ from pytz import utc import feast.driver_test_data as driver_data -from feast import BigQuerySource, FeatureService, FileSource, RepoConfig, errors, utils +from feast import ( + BigQuerySource, + FeatureService, + FileSource, + RedshiftSource, + RepoConfig, + errors, + utils, +) from feast.entity import Entity from feast.errors import FeatureNameCollisionError from feast.feature import Feature from feast.feature_store import FeatureStore, _validate_feature_refs from feast.feature_view import FeatureView -from feast.infra.offline_stores.bigquery import ( - BigQueryOfflineStoreConfig, - _get_entity_df_timestamp_bounds, +from feast.infra.offline_stores.bigquery import BigQueryOfflineStoreConfig +from feast.infra.offline_stores.offline_utils import ( + DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL, ) +from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig +from feast.infra.online_stores.dynamodb import DynamoDBOnlineStoreConfig from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig -from feast.infra.provider import DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL +from 
feast.infra.utils import aws_utils from feast.value_type import ValueType np.random.seed(0) @@ -440,7 +450,7 @@ def test_historical_features_from_bigquery_sources( customer_source = BigQuerySource( table_ref=customer_table_id, event_timestamp_column="datetime", - created_timestamp_column="", + created_timestamp_column="created", ) customer_fv = create_customer_daily_profile_feature_view(customer_source) @@ -602,10 +612,10 @@ def test_historical_features_from_bigquery_sources( # Make sure that custom dataset name is being used from the offline_store config if provider_type == "gcp_custom_offline_config": - assertpy.assert_that(job_from_df.query).contains("foo.entity_df") + assertpy.assert_that(job_from_df.query).contains("foo.feast_entity_df") else: assertpy.assert_that(job_from_df.query).contains( - f"{bigquery_dataset}.entity_df" + f"{bigquery_dataset}.feast_entity_df" ) start_time = datetime.utcnow() @@ -638,28 +648,282 @@ def test_historical_features_from_bigquery_sources( @pytest.mark.integration -def test_timestamp_bound_inference_from_entity_df_using_bigquery(): - start_date = datetime.now().replace(microsecond=0, second=0, minute=0) - (_, _, _, entity_df, start_date) = generate_entities( - start_date, infer_event_timestamp_col=True +@pytest.mark.parametrize( + "provider_type", ["local", "aws"], +) +@pytest.mark.parametrize( + "infer_event_timestamp_col", [False, True], +) +@pytest.mark.parametrize( + "full_feature_names", [False, True], +) +def test_historical_features_from_redshift_sources( + provider_type, infer_event_timestamp_col, capsys, full_feature_names +): + client = aws_utils.get_redshift_data_client("us-west-2") + s3 = aws_utils.get_s3_resource("us-west-2") + + offline_store = RedshiftOfflineStoreConfig( + cluster_id="feast-integration-tests", + region="us-west-2", + user="admin", + database="feast", + s3_staging_location="s3://feast-integration-tests/redshift/tests/ingestion", + iam_role="arn:aws:iam::402087665549:role/redshift_s3_access_role", ) - table_id = f"foo.table_id_{int(time.time_ns())}_{random.randint(1000, 9999)}" - stage_orders_bigquery(entity_df, table_id) + start_date = datetime.now().replace(microsecond=0, second=0, minute=0) + ( + customer_entities, + driver_entities, + end_date, + orders_df, + start_date, + ) = generate_entities(start_date, infer_event_timestamp_col) - client = bigquery.Client() - table = client.get_table(table=table_id) + redshift_table_prefix = ( + f"test_hist_retrieval_{int(time.time_ns())}_{random.randint(1000, 9999)}" + ) + + # Stage orders_df to Redshift + table_name = f"{redshift_table_prefix}_orders" + entity_df_query = f"SELECT * FROM {table_name}" + orders_context = aws_utils.temporarily_upload_df_to_redshift( + client, + offline_store.cluster_id, + offline_store.database, + offline_store.user, + s3, + f"{offline_store.s3_staging_location}/copy/{table_name}.parquet", + offline_store.iam_role, + table_name, + orders_df, + ) - # Ensure that the table expires after some time - table.expires = datetime.utcnow() + timedelta(minutes=30) - client.update_table(table, ["expires"]) + # Stage driver_df to Redshift + driver_df = driver_data.create_driver_hourly_stats_df( + driver_entities, start_date, end_date + ) + driver_table_name = f"{redshift_table_prefix}_driver_hourly" + driver_context = aws_utils.temporarily_upload_df_to_redshift( + client, + offline_store.cluster_id, + offline_store.database, + offline_store.user, + s3, + f"{offline_store.s3_staging_location}/copy/{driver_table_name}.parquet", + offline_store.iam_role, + 
driver_table_name, + driver_df, + ) - min_timestamp, max_timestamp = _get_entity_df_timestamp_bounds( - client, str(table.reference), "e_ts" + # Stage customer_df to Redshift + customer_df = driver_data.create_customer_daily_profile_df( + customer_entities, start_date, end_date + ) + customer_table_name = f"{redshift_table_prefix}_customer_profile" + customer_context = aws_utils.temporarily_upload_df_to_redshift( + client, + offline_store.cluster_id, + offline_store.database, + offline_store.user, + s3, + f"{offline_store.s3_staging_location}/copy/{customer_table_name}.parquet", + offline_store.iam_role, + customer_table_name, + customer_df, ) - assert min_timestamp.astimezone("UTC") == min(entity_df["e_ts"]).astimezone("UTC") - assert max_timestamp.astimezone("UTC") == max(entity_df["e_ts"]).astimezone("UTC") + with orders_context, driver_context, customer_context, TemporaryDirectory() as temp_dir: + driver_source = RedshiftSource( + table=driver_table_name, + event_timestamp_column="datetime", + created_timestamp_column="created", + ) + driver_fv = create_driver_hourly_stats_feature_view(driver_source) + + customer_source = RedshiftSource( + table=customer_table_name, + event_timestamp_column="datetime", + created_timestamp_column="created", + ) + customer_fv = create_customer_daily_profile_feature_view(customer_source) + + driver = Entity(name="driver", join_key="driver_id", value_type=ValueType.INT64) + customer = Entity(name="customer_id", value_type=ValueType.INT64) + + if provider_type == "local": + store = FeatureStore( + config=RepoConfig( + registry=os.path.join(temp_dir, "registry.db"), + project="default", + provider="local", + online_store=SqliteOnlineStoreConfig( + path=os.path.join(temp_dir, "online_store.db"), + ), + offline_store=offline_store, + ) + ) + elif provider_type == "aws": + store = FeatureStore( + config=RepoConfig( + registry=os.path.join(temp_dir, "registry.db"), + project="".join( + random.choices(string.ascii_uppercase + string.digits, k=10) + ), + provider="aws", + online_store=DynamoDBOnlineStoreConfig(region="us-west-2"), + offline_store=offline_store, + ) + ) + else: + raise Exception("Invalid provider used as part of test configuration") + + store.apply([driver, customer, driver_fv, customer_fv]) + + event_timestamp = ( + DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL + if DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL in orders_df.columns + else "e_ts" + ) + expected_df = get_expected_training_df( + customer_df, + customer_fv, + driver_df, + driver_fv, + orders_df, + event_timestamp, + full_feature_names, + ) + + job_from_sql = store.get_historical_features( + entity_df=entity_df_query, + features=[ + "driver_stats:conv_rate", + "driver_stats:avg_daily_trips", + "customer_profile:current_balance", + "customer_profile:avg_passenger_count", + "customer_profile:lifetime_trip_count", + ], + full_feature_names=full_feature_names, + ) + + start_time = datetime.utcnow() + actual_df_from_sql_entities = job_from_sql.to_df() + end_time = datetime.utcnow() + with capsys.disabled(): + print( + str( + f"\nTime to execute job_from_sql.to_df() = '{(end_time - start_time)}'" + ) + ) + + assert sorted(expected_df.columns) == sorted( + actual_df_from_sql_entities.columns + ) + assert_frame_equal( + expected_df.sort_values( + by=[event_timestamp, "order_id", "driver_id", "customer_id"] + ).reset_index(drop=True), + actual_df_from_sql_entities[expected_df.columns] + .sort_values(by=[event_timestamp, "order_id", "driver_id", "customer_id"]) + .reset_index(drop=True), + check_dtype=False, + 
) + + table_from_sql_entities = job_from_sql.to_arrow() + assert_frame_equal( + actual_df_from_sql_entities.sort_values( + by=[event_timestamp, "order_id", "driver_id", "customer_id"] + ).reset_index(drop=True), + table_from_sql_entities.to_pandas() + .sort_values(by=[event_timestamp, "order_id", "driver_id", "customer_id"]) + .reset_index(drop=True), + ) + + timestamp_column = ( + "e_ts" + if infer_event_timestamp_col + else DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL + ) + + entity_df_query_with_invalid_join_key = ( + f"select order_id, driver_id, customer_id as customer, " + f"order_is_success, {timestamp_column} FROM {table_name}" + ) + # Rename the join key; this should now raise an error. + assertpy.assert_that(store.get_historical_features).raises( + errors.FeastEntityDFMissingColumnsError + ).when_called_with( + entity_df=entity_df_query_with_invalid_join_key, + features=[ + "driver_stats:conv_rate", + "driver_stats:avg_daily_trips", + "customer_profile:current_balance", + "customer_profile:avg_passenger_count", + "customer_profile:lifetime_trip_count", + ], + ) + + job_from_df = store.get_historical_features( + entity_df=orders_df, + features=[ + "driver_stats:conv_rate", + "driver_stats:avg_daily_trips", + "customer_profile:current_balance", + "customer_profile:avg_passenger_count", + "customer_profile:lifetime_trip_count", + ], + full_feature_names=full_feature_names, + ) + + # Rename the join key; this should now raise an error. + orders_df_with_invalid_join_key = orders_df.rename( + {"customer_id": "customer"}, axis="columns" + ) + assertpy.assert_that(store.get_historical_features).raises( + errors.FeastEntityDFMissingColumnsError + ).when_called_with( + entity_df=orders_df_with_invalid_join_key, + features=[ + "driver_stats:conv_rate", + "driver_stats:avg_daily_trips", + "customer_profile:current_balance", + "customer_profile:avg_passenger_count", + "customer_profile:lifetime_trip_count", + ], + ) + + start_time = datetime.utcnow() + actual_df_from_df_entities = job_from_df.to_df() + end_time = datetime.utcnow() + with capsys.disabled(): + print( + str( + f"Time to execute job_from_df.to_df() = '{(end_time - start_time)}'\n" + ) + ) + + assert sorted(expected_df.columns) == sorted(actual_df_from_df_entities.columns) + assert_frame_equal( + expected_df.sort_values( + by=[event_timestamp, "order_id", "driver_id", "customer_id"] + ).reset_index(drop=True), + actual_df_from_df_entities[expected_df.columns] + .sort_values(by=[event_timestamp, "order_id", "driver_id", "customer_id"]) + .reset_index(drop=True), + check_dtype=False, + ) + + table_from_df_entities = job_from_df.to_arrow() + assert_frame_equal( + actual_df_from_df_entities.sort_values( + by=[event_timestamp, "order_id", "driver_id", "customer_id"] + ).reset_index(drop=True), + table_from_df_entities.to_pandas() + .sort_values(by=[event_timestamp, "order_id", "driver_id", "customer_id"]) + .reset_index(drop=True), + ) def test_feature_name_collision_on_historical_retrieval():
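Pulling together the pieces exercised by the Redshift test above, the following is a condensed, hypothetical usage sketch of the new Redshift-backed historical retrieval path; the cluster, S3 bucket, IAM role, table, and feature names are placeholders and not values from this PR.

```python
# Usage sketch only; identifiers below are placeholders, not real resources.
from feast import FeatureStore, RepoConfig
from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig
from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig

offline_store = RedshiftOfflineStoreConfig(
    cluster_id="my-redshift-cluster",
    region="us-west-2",
    user="admin",
    database="feast",
    s3_staging_location="s3://my-bucket/redshift/staging",
    iam_role="arn:aws:iam::123456789012:role/redshift_s3_access_role",
)

store = FeatureStore(
    config=RepoConfig(
        registry="data/registry.db",
        project="demo",
        provider="local",
        online_store=SqliteOnlineStoreConfig(path="data/online_store.db"),
        offline_store=offline_store,
    )
)

# Entities and feature views are assumed to have been applied already via
# store.apply([...]), as in the test above.

# The entity dataframe may be either a pandas DataFrame or a SQL query string;
# anything else raises the new InvalidEntityType error.
training_df = store.get_historical_features(
    entity_df="SELECT * FROM my_orders_table",
    features=["driver_stats:conv_rate", "driver_stats:avg_daily_trips"],
    full_feature_names=False,
).to_df()
```

Under the hood this renders MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN, unloads the result to the configured S3 staging location, and uses the new drop_columns hook to strip entity_timestamp and the per-feature-view *__entity_row_unique_id helper columns before the dataframe is returned.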