Skip to content

Commit

Permalink
chore: Snowflake code cleanup (#3118)
Browse files Browse the repository at this point in the history
Signed-off-by: Miles Adkins <miles.adkins@snowflake.com>

Signed-off-by: Miles Adkins <miles.adkins@snowflake.com>
  • Loading branch information
sfc-gh-madkins authored Aug 23, 2022
1 parent b4d0f6d commit 3910a9c
Show file tree
Hide file tree
Showing 10 changed files with 139 additions and 165 deletions.
89 changes: 41 additions & 48 deletions sdk/python/feast/infra/materialization/snowflake_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from feast.infra.utils.snowflake.snowflake_utils import (
_run_snowflake_field_mapping,
assert_snowflake_feature_names,
execute_snowflake_statement,
get_snowflake_conn,
get_snowflake_materialization_config,
package_snowpark_zip,
)
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
Expand Down Expand Up @@ -125,30 +125,32 @@ def update(
)
click.echo()

conn_config = get_snowflake_materialization_config(self.repo_config)

stage_context = f'"{conn_config.database}"."{conn_config.schema_}"'
stage_context = f'"{self.repo_config.batch_engine.database}"."{self.repo_config.batch_engine.schema_}"'
stage_path = f'{stage_context}."feast_{project}"'
with get_snowflake_conn(conn_config) as conn:
cur = conn.cursor()
with get_snowflake_conn(self.repo_config.batch_engine) as conn:
query = f"SHOW STAGES IN {stage_context}"
cursor = execute_snowflake_statement(conn, query)
stage_list = pd.DataFrame(
cursor.fetchall(),
columns=[column.name for column in cursor.description],
)

# if the stage already exists,
# assumes that the materialization functions have been deployed
cur.execute(f"SHOW STAGES IN {stage_context}")
stage_list = pd.DataFrame(
cur.fetchall(), columns=[column.name for column in cur.description]
)
if f"feast_{project}" in stage_list["name"].tolist():
click.echo(
f"Materialization functions for {Style.BRIGHT + Fore.GREEN}{project}{Style.RESET_ALL} already exists"
)
click.echo()
return None

cur.execute(f"CREATE STAGE {stage_path}")
query = f"CREATE STAGE {stage_path}"
execute_snowflake_statement(conn, query)

copy_path, zip_path = package_snowpark_zip(project)
cur.execute(f"PUT file://{zip_path} @{stage_path}")
query = f"PUT file://{zip_path} @{stage_path}"
execute_snowflake_statement(conn, query)

shutil.rmtree(copy_path)

# Execute snowflake python udf creation functions
Expand All @@ -159,8 +161,8 @@ def update(
sqlCommands = sqlFile.split(";")
for command in sqlCommands:
command = command.replace("STAGE_HOLDER", f"{stage_path}")
command = command.replace("PROJECT_NAME", f"{project}")
cur.execute(command)
query = command.replace("PROJECT_NAME", f"{project}")
execute_snowflake_statement(conn, query)

return None

Expand All @@ -170,15 +172,11 @@ def teardown_infra(
fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]],
entities: Sequence[Entity],
):
conn_config = get_snowflake_materialization_config(self.repo_config)

stage_path = (
f'"{conn_config.database}"."{conn_config.schema_}"."feast_{project}"'
)
with get_snowflake_conn(conn_config) as conn:
cur = conn.cursor()

cur.execute(f"DROP STAGE IF EXISTS {stage_path}")
stage_path = f'"{self.repo_config.batch_engine.database}"."{self.repo_config.batch_engine.schema_}"."feast_{project}"'
with get_snowflake_conn(self.repo_config.batch_engine) as conn:
query = f"DROP STAGE IF EXISTS {stage_path}"
execute_snowflake_statement(conn, query)

# Execute snowflake python udf deletion functions
sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/snowpark/snowflake_python_udfs_deletion.sql"
Expand All @@ -187,8 +185,8 @@ def teardown_infra(

sqlCommands = sqlFile.split(";")
for command in sqlCommands:
command = command.replace("PROJECT_NAME", f"{project}")
cur.execute(command)
query = command.replace("PROJECT_NAME", f"{project}")
execute_snowflake_statement(conn, query)

return None

Expand Down Expand Up @@ -239,8 +237,6 @@ def _materialize_one(
feature_view, FeatureView
), "Snowflake can only materialize FeatureView & BatchFeatureView feature view types."

repo_config = self.repo_config

entities = []
for entity_name in feature_view.entities:
entities.append(registry.get_entity(entity_name, project))
Expand All @@ -256,7 +252,7 @@ def _materialize_one(

try:
offline_job = self.offline_store.pull_latest_from_table_or_query(
config=repo_config,
config=self.repo_config,
data_source=feature_view.batch_source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
Expand All @@ -274,22 +270,22 @@ def _materialize_one(
)

fv_to_proto_sql = self.generate_snowflake_materialization_query(
repo_config,
self.repo_config,
fv_latest_mapped_values_sql,
feature_view,
project,
)

if repo_config.online_store.type == "snowflake.online":
if self.repo_config.online_store.type == "snowflake.online":
self.materialize_to_snowflake_online_store(
repo_config,
self.repo_config,
fv_to_proto_sql,
feature_view,
project,
)
else:
self.materialize_to_external_online_store(
repo_config,
self.repo_config,
fv_to_proto_sql,
feature_view,
tqdm_builder,
Expand Down Expand Up @@ -373,9 +369,7 @@ def materialize_to_snowflake_online_store(
) -> None:
assert_snowflake_feature_names(feature_view)

conn_config = get_snowflake_materialization_config(repo_config)

online_table = f"""{repo_config.online_store.database}"."{repo_config.online_store.schema_}"."[online-transient] {project}_{feature_view.name}"""
online_table = f"""{repo_config .online_store.database}"."{repo_config.online_store.schema_}"."[online-transient] {project}_{feature_view.name}"""

feature_names_str = '", "'.join(
[feature.name for feature in feature_view.features]
Expand All @@ -386,7 +380,7 @@ def materialize_to_snowflake_online_store(
else:
fv_created_str = None

fv_to_online = f"""
query = f"""
MERGE INTO "{online_table}" online_table
USING (
SELECT
Expand Down Expand Up @@ -420,14 +414,12 @@ def materialize_to_snowflake_online_store(
)
"""

with get_snowflake_conn(conn_config) as conn:
cur = conn.cursor()
cur.execute(fv_to_online)
with get_snowflake_conn(repo_config.batch_engine) as conn:
query_id = execute_snowflake_statement(conn, query).sfqid

query_id = cur.sfqid
click.echo(
f"Snowflake Query ID: {Style.BRIGHT + Fore.GREEN}{query_id}{Style.RESET_ALL}"
)
click.echo(
f"Snowflake Query ID: {Style.BRIGHT + Fore.GREEN}{query_id}{Style.RESET_ALL}"
)
return None

def materialize_to_external_online_store(
Expand All @@ -437,15 +429,16 @@ def materialize_to_external_online_store(
feature_view: Union[StreamFeatureView, FeatureView],
tqdm_builder: Callable[[int], tqdm],
) -> None:
conn_config = get_snowflake_materialization_config(repo_config)

feature_names = [feature.name for feature in feature_view.features]

with get_snowflake_conn(conn_config) as conn:
cur = conn.cursor()
cur.execute(materialization_sql)
for i, df in enumerate(cur.fetch_pandas_batches()):
click.echo(f"Snowflake: Processing ResultSet Batch #{i+1}")
with get_snowflake_conn(repo_config.batch_engine) as conn:
query = materialization_sql
cursor = execute_snowflake_statement(conn, query)
for i, df in enumerate(cursor.fetch_pandas_batches()):
click.echo(
f"Snowflake: Processing Materialization ResultSet Batch #{i+1}"
)

entity_keys = (
df["entity_key"].apply(EntityKeyProto.FromString).to_numpy()
Expand Down
21 changes: 11 additions & 10 deletions sdk/python/feast/infra/offline_stores/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,21 +485,22 @@ def to_remote_storage(self) -> List[str]:
table = f"temporary_{uuid.uuid4().hex}"
self.to_snowflake(table)

copy_into_query = f"""copy into '{self.config.offline_store.blob_export_location}/{table}' from "{self.config.offline_store.database}"."{self.config.offline_store.schema_}"."{table}"\n
storage_integration = {self.config.offline_store.storage_integration_name}\n
file_format = (TYPE = PARQUET)\n
DETAILED_OUTPUT = TRUE\n
HEADER = TRUE;\n
query = f"""
COPY INTO '{self.config.offline_store.blob_export_location}/{table}' FROM "{self.config.offline_store.database}"."{self.config.offline_store.schema_}"."{table}"\n
STORAGE_INTEGRATION = {self.config.offline_store.storage_integration_name}\n
FILE_FORMAT = (TYPE = PARQUET)
DETAILED_OUTPUT = TRUE
HEADER = TRUE
"""
cursor = execute_snowflake_statement(self.snowflake_conn, query)

cursor = execute_snowflake_statement(self.snowflake_conn, copy_into_query)
all_rows = (
cursor.fetchall()
) # This may need pagination at some point in the future.
file_name_column_index = [
idx for idx, rm in enumerate(cursor.description) if rm.name == "FILE_NAME"
][0]
return [f"{self.export_path}/{row[file_name_column_index]}" for row in all_rows]
return [
f"{self.export_path}/{row[file_name_column_index]}"
for row in cursor.fetchall()
]


def _get_entity_schema(
Expand Down
57 changes: 27 additions & 30 deletions sdk/python/feast/infra/offline_stores/snowflake_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ def get_table_column_names_and_types(
Args:
config: A RepoConfig describing the feature repo
"""

from feast.infra.offline_stores.snowflake import SnowflakeOfflineStoreConfig
from feast.infra.utils.snowflake.snowflake_utils import (
execute_snowflake_statement,
Expand All @@ -217,23 +216,26 @@ def get_table_column_names_and_types(

assert isinstance(config.offline_store, SnowflakeOfflineStoreConfig)

snowflake_conn = get_snowflake_conn(config.offline_store)

query = f"SELECT * FROM {self.get_table_query_string()} LIMIT 5"

result_cur = execute_snowflake_statement(snowflake_conn, query)
with get_snowflake_conn(config.offline_store) as conn:
query = f"SELECT * FROM {self.get_table_query_string()} LIMIT 5"
cursor = execute_snowflake_statement(conn, query)

metadata = [
{
"column_name": column.name,
"type_code": column.type_code,
"precision": column.precision,
"scale": column.scale,
"is_nullable": column.is_nullable,
"snowflake_type": None,
}
for column in cursor.description
]

metadata = [
{
"column_name": column.name,
"type_code": column.type_code,
"precision": column.precision,
"scale": column.scale,
"is_nullable": column.is_nullable,
"snowflake_type": None,
}
for column in result_cur.description
]
if cursor.fetch_pandas_all().empty:
raise DataSourceNotFoundException(
"The following source:\n" + query + "\n ... is empty"
)

for row in metadata:
if row["type_code"] == 0:
Expand All @@ -244,12 +246,12 @@ def get_table_column_names_and_types(
row["snowflake_type"] = "NUMBER64"
else:
column = row["column_name"]
query = f'SELECT MAX("{column}") AS "{column}" FROM {self.get_table_query_string()}'

result = execute_snowflake_statement(
snowflake_conn, query
).fetch_pandas_all()

with get_snowflake_conn(config.offline_store) as conn:
query = f'SELECT MAX("{column}") AS "{column}" FROM {self.get_table_query_string()}'
result = execute_snowflake_statement(
conn, query
).fetch_pandas_all()
if (
result.dtypes[column].name
in python_int_to_snowflake_type_map
Expand Down Expand Up @@ -277,14 +279,9 @@ def get_table_column_names_and_types(
f"The following Snowflake Column is not supported: {row['column_name']} (type_code: {row['type_code']})"
)

if not result_cur.fetch_pandas_all().empty:
return [
(column["column_name"], column["snowflake_type"]) for column in metadata
]
else:
raise DataSourceNotFoundException(
"The following source:\n" + query + "\n ... is empty"
)
return [
(column["column_name"], column["snowflake_type"]) for column in metadata
]


snowflake_type_code_map = {
Expand Down
Loading

0 comments on commit 3910a9c

Please sign in to comment.