diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py
index 3e1317626a..e1e53d4128 100644
--- a/sdk/python/feast/infra/offline_stores/bigquery.py
+++ b/sdk/python/feast/infra/offline_stores/bigquery.py
@@ -1,6 +1,7 @@
+import contextlib
 import uuid
 from datetime import date, datetime, timedelta
-from typing import Dict, List, Optional, Union
+from typing import Callable, ContextManager, Dict, Iterator, List, Optional, Union
 
 import numpy as np
 import pandas as pd
@@ -122,38 +123,47 @@ def get_historical_features(
             client, client.project, config.offline_store.dataset
         )
 
-        entity_schema = _upload_entity_df_and_get_entity_schema(
-            client=client, table_name=table_reference, entity_df=entity_df,
-        )
+        @contextlib.contextmanager
+        def query_generator() -> Iterator[str]:
+            entity_schema = _upload_entity_df_and_get_entity_schema(
+                client=client, table_name=table_reference, entity_df=entity_df,
+            )
 
-        entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
-            entity_schema
-        )
+            entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
+                entity_schema
+            )
 
-        expected_join_keys = offline_utils.get_expected_join_keys(
-            project, feature_views, registry
-        )
+            expected_join_keys = offline_utils.get_expected_join_keys(
+                project, feature_views, registry
+            )
 
-        offline_utils.assert_expected_columns_in_entity_df(
-            entity_schema, expected_join_keys, entity_df_event_timestamp_col
-        )
+            offline_utils.assert_expected_columns_in_entity_df(
+                entity_schema, expected_join_keys, entity_df_event_timestamp_col
+            )
 
-        # Build a query context containing all information required to template the BigQuery SQL query
-        query_context = offline_utils.get_feature_view_query_context(
-            feature_refs, feature_views, registry, project,
-        )
+            # Build a query context containing all information required to template the BigQuery SQL query
+            query_context = offline_utils.get_feature_view_query_context(
+                feature_refs, feature_views, registry, project,
+            )
 
-        # Generate the BigQuery SQL query from the query context
-        query = offline_utils.build_point_in_time_query(
-            query_context,
-            left_table_query_string=table_reference,
-            entity_df_event_timestamp_col=entity_df_event_timestamp_col,
-            query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
-            full_feature_names=full_feature_names,
-        )
+            # Generate the BigQuery SQL query from the query context
+            query = offline_utils.build_point_in_time_query(
+                query_context,
+                left_table_query_string=table_reference,
+                entity_df_event_timestamp_col=entity_df_event_timestamp_col,
+                query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
+                full_feature_names=full_feature_names,
+            )
+
+            try:
+                yield query
+            finally:
+                # Asynchronously clean up the uploaded Bigquery table, which will expire
+                # if cleanup fails
+                client.delete_table(table=table_reference, not_found_ok=True)
 
         return BigQueryRetrievalJob(
-            query=query,
+            query=query_generator,
             client=client,
             config=config,
             full_feature_names=full_feature_names,
@@ -166,13 +176,22 @@ def get_historical_features(
 class BigQueryRetrievalJob(RetrievalJob):
     def __init__(
         self,
-        query: str,
+        query: Union[str, Callable[[], ContextManager[str]]],
         client: bigquery.Client,
         config: RepoConfig,
         full_feature_names: bool,
         on_demand_feature_views: Optional[List[OnDemandFeatureView]],
     ):
-        self.query = query
+        if not isinstance(query, str):
+            self._query_generator = query
+        else:
+
+            @contextlib.contextmanager
+            def query_generator() -> Iterator[str]:
+                assert isinstance(query, str)
+                yield query
+
+            self._query_generator = query_generator
         self.client = client
         self.config = config
         self._full_feature_names = full_feature_names
@@ -187,15 +206,16 @@ def on_demand_feature_views(self) -> Optional[List[OnDemandFeatureView]]:
         return self._on_demand_feature_views
 
     def _to_df_internal(self) -> pd.DataFrame:
-        # TODO: Ideally only start this job when the user runs "get_historical_features", not when they run to_df()
-        df = self.client.query(self.query).to_dataframe(create_bqstorage_client=True)
-        return df
+        with self._query_generator() as query:
+            df = self.client.query(query).to_dataframe(create_bqstorage_client=True)
+            return df
 
     def to_sql(self) -> str:
         """
         Returns the SQL query that will be executed in BigQuery to build the historical feature table.
         """
-        return self.query
+        with self._query_generator() as query:
+            return query
 
     def to_bigquery(
         self,
@@ -215,36 +235,39 @@ def to_bigquery(
         Returns:
             Returns the destination table name or returns None if job_config.dry_run is True.
         """
+        with self._query_generator() as query:
+            if not job_config:
+                today = date.today().strftime("%Y%m%d")
+                rand_id = str(uuid.uuid4())[:7]
+                path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
+                job_config = bigquery.QueryJobConfig(destination=path)
+
+            if not job_config.dry_run and self.on_demand_feature_views is not None:
+                job = _write_pyarrow_table_to_bq(
+                    self.client, self.to_arrow(), job_config.destination
+                )
+                job.result()
+                print(f"Done writing to '{job_config.destination}'.")
+                return str(job_config.destination)
+
+            bq_job = self.client.query(query, job_config=job_config)
+
+            if job_config.dry_run:
+                print(
+                    "This query will process {} bytes.".format(
+                        bq_job.total_bytes_processed
+                    )
+                )
+                return None
+
+            block_until_done(client=self.client, bq_job=bq_job, timeout=timeout)
 
-        if not job_config:
-            today = date.today().strftime("%Y%m%d")
-            rand_id = str(uuid.uuid4())[:7]
-            path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
-            job_config = bigquery.QueryJobConfig(destination=path)
-
-        if not job_config.dry_run and self.on_demand_feature_views is not None:
-            job = _write_pyarrow_table_to_bq(
-                self.client, self.to_arrow(), job_config.destination
-            )
-            job.result()
             print(f"Done writing to '{job_config.destination}'.")
             return str(job_config.destination)
 
-        bq_job = self.client.query(self.query, job_config=job_config)
-
-        if job_config.dry_run:
-            print(
-                "This query will process {} bytes.".format(bq_job.total_bytes_processed)
-            )
-            return None
-
-        block_until_done(client=self.client, bq_job=bq_job, timeout=timeout)
-
-        print(f"Done writing to '{job_config.destination}'.")
-        return str(job_config.destination)
-
     def _to_arrow_internal(self) -> pyarrow.Table:
-        return self.client.query(self.query).to_arrow()
+        with self._query_generator() as query:
+            return self.client.query(query).to_arrow()
 
 
 def block_until_done(
@@ -325,13 +348,13 @@ def _upload_entity_df_and_get_entity_schema(
         limited_entity_df = (
            client.query(f"SELECT * FROM {table_name} LIMIT 1").result().to_dataframe()
        )
+        entity_schema = dict(zip(limited_entity_df.columns, limited_entity_df.dtypes))
    elif isinstance(entity_df, pd.DataFrame):
        # Drop the index so that we dont have unnecessary columns
        entity_df.reset_index(drop=True, inplace=True)
 
        job = _write_df_to_bq(client, entity_df, table_name)
        block_until_done(client, job)
-
        entity_schema = dict(zip(entity_df.columns, entity_df.dtypes))
    else:
        raise InvalidEntityType(type(entity_df))
diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py
index 62a5933777..1b17ae15b0 100644
--- a/sdk/python/feast/infra/offline_stores/redshift.py
+++ b/sdk/python/feast/infra/offline_stores/redshift.py
@@ -153,16 +153,17 @@ def query_generator() -> Iterator[str]:
                 full_feature_names=full_feature_names,
             )
 
-            yield query
-
-            # Clean up the uploaded Redshift table
-            aws_utils.execute_redshift_statement(
-                redshift_client,
-                config.offline_store.cluster_id,
-                config.offline_store.database,
-                config.offline_store.user,
-                f"DROP TABLE {table_name}",
-            )
+            try:
+                yield query
+            finally:
+                # Always clean up the uploaded Redshift table
+                aws_utils.execute_redshift_statement(
+                    redshift_client,
+                    config.offline_store.cluster_id,
+                    config.offline_store.database,
+                    config.offline_store.user,
+                    f"DROP TABLE IF EXISTS {table_name}",
+                )
 
         return RedshiftRetrievalJob(
             query=query_generator,
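Reviewer note: with this change both offline stores hand the retrieval job a callable that returns a context manager instead of a raw SQL string, so the temporary entity table is cleaned up even if the caller raises or returns early while consuming the query. The standalone sketch below illustrates the pattern outside of Feast; make_query_generator, build_query, drop_table, and RetrievalJobSketch are illustrative names, not Feast APIs.

import contextlib
from typing import Callable, ContextManager, Iterator, Union


def make_query_generator(
    build_query: Callable[[], str], drop_table: Callable[[], None]
) -> Callable[[], ContextManager[str]]:
    """Wrap query construction so cleanup always runs after the query is consumed."""

    @contextlib.contextmanager
    def query_generator() -> Iterator[str]:
        query = build_query()  # stands in for uploading the entity df and rendering the SQL
        try:
            yield query  # caller executes the query while the temp table still exists
        finally:
            drop_table()  # runs on success, on exception, and on early return

    return query_generator


class RetrievalJobSketch:
    """Accepts either a plain SQL string or a query generator, mirroring the new __init__."""

    def __init__(self, query: Union[str, Callable[[], ContextManager[str]]]):
        if not isinstance(query, str):
            self._query_generator = query
        else:

            @contextlib.contextmanager
            def query_generator() -> Iterator[str]:
                yield query  # nothing to clean up for a pre-built string

            self._query_generator = query_generator

    def to_sql(self) -> str:
        with self._query_generator() as query:
            return query


if __name__ == "__main__":
    gen = make_query_generator(lambda: "SELECT 1", lambda: print("temp table dropped"))
    # Prints "temp table dropped" (cleanup fires as the with-block exits), then "SELECT 1".
    print(RetrievalJobSketch(gen).to_sql())

Because the cleanup sits in a finally block, it runs on normal exit, on exception, and on early return from the with block, which is what the try/finally added to both query_generator implementations is meant to guarantee.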