Delete tables #1916

Merged · 3 commits · Sep 30, 2021
139 changes: 81 additions & 58 deletions sdk/python/feast/infra/offline_stores/bigquery.py
@@ -1,6 +1,7 @@
import contextlib
import uuid
from datetime import date, datetime, timedelta
from typing import Dict, List, Optional, Union
from typing import Callable, ContextManager, Dict, Iterator, List, Optional, Union

import numpy as np
import pandas as pd
@@ -122,38 +123,47 @@ def get_historical_features(
client, client.project, config.offline_store.dataset
)

entity_schema = _upload_entity_df_and_get_entity_schema(
client=client, table_name=table_reference, entity_df=entity_df,
)
@contextlib.contextmanager
def query_generator() -> Iterator[str]:
entity_schema = _upload_entity_df_and_get_entity_schema(
client=client, table_name=table_reference, entity_df=entity_df,
)

entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
entity_schema
)
entity_df_event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
entity_schema
)

expected_join_keys = offline_utils.get_expected_join_keys(
project, feature_views, registry
)
expected_join_keys = offline_utils.get_expected_join_keys(
project, feature_views, registry
)

offline_utils.assert_expected_columns_in_entity_df(
entity_schema, expected_join_keys, entity_df_event_timestamp_col
)
offline_utils.assert_expected_columns_in_entity_df(
entity_schema, expected_join_keys, entity_df_event_timestamp_col
)

# Build a query context containing all information required to template the BigQuery SQL query
query_context = offline_utils.get_feature_view_query_context(
feature_refs, feature_views, registry, project,
)
# Build a query context containing all information required to template the BigQuery SQL query
query_context = offline_utils.get_feature_view_query_context(
feature_refs, feature_views, registry, project,
)

# Generate the BigQuery SQL query from the query context
query = offline_utils.build_point_in_time_query(
query_context,
left_table_query_string=table_reference,
entity_df_event_timestamp_col=entity_df_event_timestamp_col,
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
full_feature_names=full_feature_names,
)
# Generate the BigQuery SQL query from the query context
query = offline_utils.build_point_in_time_query(
query_context,
left_table_query_string=table_reference,
entity_df_event_timestamp_col=entity_df_event_timestamp_col,
query_template=MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN,
full_feature_names=full_feature_names,
)

try:
yield query
finally:
# Asynchronously clean up the uploaded Bigquery table, which will expire
# if cleanup fails
client.delete_table(table=table_reference, not_found_ok=True)

return BigQueryRetrievalJob(
query=query,
query=query_generator,
client=client,
config=config,
full_feature_names=full_feature_names,
@@ -166,13 +176,22 @@ def get_historical_features(
class BigQueryRetrievalJob(RetrievalJob):
def __init__(
self,
query: str,
query: Union[str, Callable[[], ContextManager[str]]],
client: bigquery.Client,
config: RepoConfig,
full_feature_names: bool,
on_demand_feature_views: Optional[List[OnDemandFeatureView]],
):
self.query = query
if not isinstance(query, str):
self._query_generator = query
else:

@contextlib.contextmanager
def query_generator() -> Iterator[str]:
assert isinstance(query, str)
yield query

self._query_generator = query_generator
self.client = client
self.config = config
self._full_feature_names = full_feature_names
@@ -187,15 +206,16 @@ def on_demand_feature_views(self) -> Optional[List[OnDemandFeatureView]]:
return self._on_demand_feature_views

def _to_df_internal(self) -> pd.DataFrame:
# TODO: Ideally only start this job when the user runs "get_historical_features", not when they run to_df()
df = self.client.query(self.query).to_dataframe(create_bqstorage_client=True)
return df
with self._query_generator() as query:
df = self.client.query(query).to_dataframe(create_bqstorage_client=True)
return df

def to_sql(self) -> str:
"""
Returns the SQL query that will be executed in BigQuery to build the historical feature table.
"""
return self.query
with self._query_generator() as query:
Collaborator (Author):

Note: technically the query generated here will not be the exact query used to build the historical feature table, since the randomly generated table name will be different. I couldn't think of a better way to handle to_sql given that we switched to a context manager - open to suggestions.

Member:

Hmm, in that case I think we should either return a SQL string with the table name as a parameter that can be filled in, or accept the table name as a param for the to_sql method. As it stands, returning a full SQL string that is not expected to execute correctly seems like a bug.
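
A possible shape for that suggestion, as a hypothetical sketch only (the class, placeholder name, and query template below are illustrative and not part of this PR or of Feast): to_sql would keep the entity table name as a parameter, so the returned SQL matches what would actually run once a concrete table is supplied.

```python
# Hypothetical sketch of the suggestion above -- not code from this PR.
from typing import Optional


class ParameterizedSqlJob:
    """Illustrative only: exposes SQL with the entity table name as a parameter."""

    def __init__(self, query_template: str):
        # query_template is assumed to contain an {entity_table} placeholder.
        self._query_template = query_template

    def to_sql(self, entity_table: Optional[str] = None) -> str:
        # Without a concrete table, return the template itself so the caller sees
        # the placeholder rather than a randomly named table that may not exist.
        if entity_table is None:
            return self._query_template
        return self._query_template.format(entity_table=entity_table)


# Usage (hypothetical):
job = ParameterizedSqlJob("SELECT * FROM {entity_table} LIMIT 10")
job.to_sql()                          # 'SELECT * FROM {entity_table} LIMIT 10'
job.to_sql("project.dataset.tmp_x")   # 'SELECT * FROM project.dataset.tmp_x LIMIT 10'
```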

return query

def to_bigquery(
self,
@@ -215,36 +235,39 @@ def to_bigquery(
Returns:
Returns the destination table name or returns None if job_config.dry_run is True.
"""
with self._query_generator() as query:
if not job_config:
today = date.today().strftime("%Y%m%d")
rand_id = str(uuid.uuid4())[:7]
path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
job_config = bigquery.QueryJobConfig(destination=path)

if not job_config.dry_run and self.on_demand_feature_views is not None:
job = _write_pyarrow_table_to_bq(
self.client, self.to_arrow(), job_config.destination
)
job.result()
print(f"Done writing to '{job_config.destination}'.")
return str(job_config.destination)

bq_job = self.client.query(query, job_config=job_config)

if job_config.dry_run:
print(
"This query will process {} bytes.".format(
bq_job.total_bytes_processed
)
)
return None

block_until_done(client=self.client, bq_job=bq_job, timeout=timeout)

if not job_config:
today = date.today().strftime("%Y%m%d")
rand_id = str(uuid.uuid4())[:7]
path = f"{self.client.project}.{self.config.offline_store.dataset}.historical_{today}_{rand_id}"
job_config = bigquery.QueryJobConfig(destination=path)

if not job_config.dry_run and self.on_demand_feature_views is not None:
job = _write_pyarrow_table_to_bq(
self.client, self.to_arrow(), job_config.destination
)
job.result()
print(f"Done writing to '{job_config.destination}'.")
return str(job_config.destination)

bq_job = self.client.query(self.query, job_config=job_config)

if job_config.dry_run:
print(
"This query will process {} bytes.".format(bq_job.total_bytes_processed)
)
return None

block_until_done(client=self.client, bq_job=bq_job, timeout=timeout)

print(f"Done writing to '{job_config.destination}'.")
return str(job_config.destination)

def _to_arrow_internal(self) -> pyarrow.Table:
return self.client.query(self.query).to_arrow()
with self._query_generator() as query:
return self.client.query(query).to_arrow()


def block_until_done(
@@ -325,13 +348,13 @@ def _upload_entity_df_and_get_entity_schema(
limited_entity_df = (
client.query(f"SELECT * FROM {table_name} LIMIT 1").result().to_dataframe()
)

entity_schema = dict(zip(limited_entity_df.columns, limited_entity_df.dtypes))
elif isinstance(entity_df, pd.DataFrame):
# Drop the index so that we dont have unnecessary columns
entity_df.reset_index(drop=True, inplace=True)
job = _write_df_to_bq(client, entity_df, table_name)
block_until_done(client, job)

entity_schema = dict(zip(entity_df.columns, entity_df.dtypes))
else:
raise InvalidEntityType(type(entity_df))
21 changes: 11 additions & 10 deletions sdk/python/feast/infra/offline_stores/redshift.py
@@ -153,16 +153,17 @@ def query_generator() -> Iterator[str]:
full_feature_names=full_feature_names,
)

yield query

# Clean up the uploaded Redshift table
aws_utils.execute_redshift_statement(
redshift_client,
config.offline_store.cluster_id,
config.offline_store.database,
config.offline_store.user,
f"DROP TABLE {table_name}",
)
try:
yield query
finally:
# Always clean up the uploaded Redshift table
aws_utils.execute_redshift_statement(
redshift_client,
config.offline_store.cluster_id,
config.offline_store.database,
config.offline_store.user,
f"DROP TABLE IF EXISTS {table_name}",
)

return RedshiftRetrievalJob(
query=query_generator,
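
For readers skimming both diffs, here is a minimal, generic sketch of the shared pattern (standard library only; the table name, query string, and drop_table callback are stand-ins, not Feast or warehouse APIs): a generator-based context manager yields the query and drops the temporary table in a finally block, so cleanup runs even if the consuming code raises.

```python
# Minimal sketch of the cleanup-on-exit pattern -- not Feast code.
import contextlib
from typing import Callable, Iterator


def make_query_generator(table_name: str, drop_table: Callable[[str], None]):
    # Returns a zero-argument callable producing a context manager that yields
    # the query string, mirroring the Callable[[], ContextManager[str]] shape
    # accepted by the retrieval jobs above.
    @contextlib.contextmanager
    def query_generator() -> Iterator[str]:
        query = f"SELECT * FROM {table_name}"  # stand-in for the templated point-in-time query
        try:
            yield query
        finally:
            # Runs whether the consumer exits normally or raises.
            drop_table(table_name)

    return query_generator


# Usage: the temporary table is dropped as soon as the query has been consumed.
gen = make_query_generator("tmp_entity_df", lambda t: print(f"DROP TABLE IF EXISTS {t}"))
with gen() as query:
    print(query)
```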