From c6ace5594870e67ffe034b268d31eda26d71bc79 Mon Sep 17 00:00:00 2001 From: Kevin Zhang Date: Fri, 1 Apr 2022 12:54:54 -0700 Subject: [PATCH 1/3] Update snowflake source Signed-off-by: Kevin Zhang --- protos/feast/core/DataSource.proto | 3 +++ .../infra/offline_stores/snowflake_source.py | 25 +++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/protos/feast/core/DataSource.proto b/protos/feast/core/DataSource.proto index 8fe84274a1..07b2f978b8 100644 --- a/protos/feast/core/DataSource.proto +++ b/protos/feast/core/DataSource.proto @@ -161,6 +161,9 @@ message DataSource { // Snowflake schema name string database = 4; + + // Snowflake warehouse name + string warehouse = 5; } // Defines configuration for custom third-party data sources. diff --git a/sdk/python/feast/infra/offline_stores/snowflake_source.py b/sdk/python/feast/infra/offline_stores/snowflake_source.py index 40868ef64d..5ee7999fbc 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake_source.py +++ b/sdk/python/feast/infra/offline_stores/snowflake_source.py @@ -16,6 +16,7 @@ class SnowflakeSource(DataSource): def __init__( self, database: Optional[str] = None, + warehouse: Optional[str] = None, schema: Optional[str] = None, table: Optional[str] = None, query: Optional[str] = None, @@ -33,6 +34,7 @@ def __init__( Args: database (optional): Snowflake database where the features are stored. + warehouse (optional): Snowflake warehouse (compute cluster) used to query the features. schema (optional): Snowflake schema in which the table is located. table (optional): Snowflake table where the features are stored. 
event_timestamp_column (optional): Event timestamp column used for point in @@ -55,7 +57,7 @@ def __init__( _schema = "PUBLIC" if (database and table and not schema) else schema self.snowflake_options = SnowflakeOptions( - database=database, schema=_schema, table=table, query=query + database=database, schema=_schema, table=table, query=query, warehouse=warehouse ) # If no name, use the table as the default name @@ -152,6 +154,11 @@ def query(self): """Returns the snowflake options of this snowflake source.""" return self.snowflake_options.query + @property + def warehouse(self): + """Returns the warehouse of this snowflake source.""" + return self.snowflake_options.warehouse + def to_proto(self) -> DataSourceProto: """ Converts a SnowflakeSource object to its protobuf representation. @@ -239,11 +246,13 @@ def __init__( schema: Optional[str], table: Optional[str], query: Optional[str], + warehouse: Optional[str], ): self._database = database self._schema = schema self._table = table self._query = query + self._warehouse = warehouse @property def query(self): @@ -285,6 +294,16 @@ def table(self, table): """Sets the table ref of this snowflake table.""" self._table = table + @property + def warehouse(self): + """Returns the warehouse name of this snowflake table.""" + return self._warehouse + + @table.setter + def warehouse(self, warehouse): + """Sets the warehouse name of this snowflake table.""" + self._warehouse = warehouse + @classmethod def from_proto(cls, snowflake_options_proto: DataSourceProto.SnowflakeOptions): """ @@ -301,6 +320,7 @@ def from_proto(cls, snowflake_options_proto: DataSourceProto.SnowflakeOptions): schema=snowflake_options_proto.schema, table=snowflake_options_proto.table, query=snowflake_options_proto.query, + warehouse=snowflake_options_proto.warehouse, ) return snowflake_options @@ -317,6 +337,7 @@ def to_proto(self) -> DataSourceProto.SnowflakeOptions: schema=self.schema, table=self.table, query=self.query, + warehouse=self.warehouse, ) 
return snowflake_options_proto @@ -329,7 +350,7 @@ class SavedDatasetSnowflakeStorage(SavedDatasetStorage): def __init__(self, table_ref: str): self.snowflake_options = SnowflakeOptions( - database=None, schema=None, table=table_ref, query=None + database=None, schema=None, table=table_ref, query=None, warehouse=None ) @staticmethod From b96a3ab2653c72904c885f3fc287137c07fb7b27 Mon Sep 17 00:00:00 2001 From: Kevin Zhang Date: Mon, 4 Apr 2022 16:02:08 -0700 Subject: [PATCH 2/3] Fix snowflake Signed-off-by: Kevin Zhang --- sdk/python/feast/infra/offline_stores/snowflake.py | 6 ++++++ sdk/python/feast/infra/offline_stores/snowflake_source.py | 2 ++ sdk/python/feast/templates/snowflake/bootstrap.py | 6 +++++- sdk/python/feast/templates/snowflake/driver_repo.py | 1 + .../feature_repos/universal/data_sources/snowflake.py | 1 + 5 files changed, 15 insertions(+), 1 deletion(-) diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py index cc346251a8..968055fcee 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake.py +++ b/sdk/python/feast/infra/offline_stores/snowflake.py @@ -128,6 +128,9 @@ def pull_latest_from_table_or_query( + '"' ) + if data_source.snowflake_options.warehouse: + config.offline_store.warehouse = data_source.snowflake_options.warehouse + snowflake_conn = get_snowflake_conn(config.offline_store) query = f""" @@ -173,6 +176,9 @@ def pull_all_from_table_or_query( + '"' ) + if data_source.snowflake_options.warehouse: + config.offline_store.warehouse = data_source.snowflake_options.warehouse + snowflake_conn = get_snowflake_conn(config.offline_store) start_date = start_date.astimezone(tz=utc) diff --git a/sdk/python/feast/infra/offline_stores/snowflake_source.py b/sdk/python/feast/infra/offline_stores/snowflake_source.py index 5ee7999fbc..c15e5fe1a3 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake_source.py +++ b/sdk/python/feast/infra/offline_stores/snowflake_source.py @@ -101,6 
+101,7 @@ def from_proto(data_source: DataSourceProto): database=data_source.snowflake_options.database, schema=data_source.snowflake_options.schema, table=data_source.snowflake_options.table, + warehouse=data_source.snowflake_options.warehouse, event_timestamp_column=data_source.event_timestamp_column, created_timestamp_column=data_source.created_timestamp_column, date_partition_column=data_source.date_partition_column, @@ -126,6 +127,7 @@ def __eq__(self, other): and self.snowflake_options.schema == other.snowflake_options.schema and self.snowflake_options.table == other.snowflake_options.table and self.snowflake_options.query == other.snowflake_options.query + and self.snowflake_options.warehouse == other.snowflake_options.warehouse and self.event_timestamp_column == other.event_timestamp_column and self.created_timestamp_column == other.created_timestamp_column and self.field_mapping == other.field_mapping diff --git a/sdk/python/feast/templates/snowflake/bootstrap.py b/sdk/python/feast/templates/snowflake/bootstrap.py index 3712651a5d..6acbcab77d 100644 --- a/sdk/python/feast/templates/snowflake/bootstrap.py +++ b/sdk/python/feast/templates/snowflake/bootstrap.py @@ -68,7 +68,7 @@ def bootstrap(): repo_path = pathlib.Path(__file__).parent.absolute() config_file = repo_path / "feature_store.yaml" - + driver_file = repo_path / "driver_repo.py" replace_str_in_file( config_file, "SNOWFLAKE_DEPLOYMENT_URL", snowflake_deployment_url ) @@ -78,6 +78,10 @@ def bootstrap(): replace_str_in_file(config_file, "SNOWFLAKE_WAREHOUSE", snowflake_warehouse) replace_str_in_file(config_file, "SNOWFLAKE_DATABASE", snowflake_database) + replace_str_in_file(driver_file, "SNOWFLAKE_WAREHOUSE", snowflake_warehouse) + + + def replace_str_in_file(file_path, match_str, sub_str): with open(file_path, "r") as f: diff --git a/sdk/python/feast/templates/snowflake/driver_repo.py b/sdk/python/feast/templates/snowflake/driver_repo.py index a63c6cb503..0ecdad7f05 100644 --- 
a/sdk/python/feast/templates/snowflake/driver_repo.py +++ b/sdk/python/feast/templates/snowflake/driver_repo.py @@ -24,6 +24,7 @@ # The Snowflake table where features can be found database=yaml.safe_load(open("feature_store.yaml"))["offline_store"]["database"], table=f"{project_name}_feast_driver_hourly_stats", + warehouse="SNOWFLAKE_WAREHOUSE", # The event timestamp is used for point-in-time joins and for ensuring only # features within the TTL are returned event_timestamp_column="event_timestamp", diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py index f76656f5b7..05cdea82f0 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py @@ -57,6 +57,7 @@ def create_data_source( created_timestamp_column=created_timestamp_column, date_partition_column="", field_mapping=field_mapping or {"ts_1": "ts"}, + warehouse=self.offline_store_config.warehouse, ) def create_saved_dataset_destination(self) -> SavedDatasetSnowflakeStorage: From 9ff0ac10d51987e8771eb775a18fac2a42fdb9af Mon Sep 17 00:00:00 2001 From: Kevin Zhang Date: Mon, 4 Apr 2022 16:04:24 -0700 Subject: [PATCH 3/3] Fix Signed-off-by: Kevin Zhang --- sdk/python/feast/infra/offline_stores/snowflake_source.py | 8 ++++++-- sdk/python/feast/templates/snowflake/bootstrap.py | 2 -- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/sdk/python/feast/infra/offline_stores/snowflake_source.py b/sdk/python/feast/infra/offline_stores/snowflake_source.py index c15e5fe1a3..f094d2b329 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake_source.py +++ b/sdk/python/feast/infra/offline_stores/snowflake_source.py @@ -57,7 +57,11 @@ def __init__( _schema = "PUBLIC" if (database and table and not schema) else schema self.snowflake_options = SnowflakeOptions( - database=database, 
schema=_schema, table=table, query=query, warehouse=warehouse + database=database, + schema=_schema, + table=table, + query=query, + warehouse=warehouse, ) # If no name, use the table as the default name @@ -301,7 +305,7 @@ def warehouse(self): """Returns the warehouse name of this snowflake table.""" return self._warehouse - @table.setter + @warehouse.setter def warehouse(self, warehouse): """Sets the warehouse name of this snowflake table.""" self._warehouse = warehouse diff --git a/sdk/python/feast/templates/snowflake/bootstrap.py b/sdk/python/feast/templates/snowflake/bootstrap.py index 6acbcab77d..194ba08c08 100644 --- a/sdk/python/feast/templates/snowflake/bootstrap.py +++ b/sdk/python/feast/templates/snowflake/bootstrap.py @@ -81,8 +81,6 @@ def bootstrap(): replace_str_in_file(driver_file, "SNOWFLAKE_WAREHOUSE", snowflake_warehouse) - - def replace_str_in_file(file_path, match_str, sub_str): with open(file_path, "r") as f: contents = f.read()