feat: Add snowflake online store #2902

Merged · 7 commits · Jul 3, 2022
232 changes: 232 additions & 0 deletions sdk/python/feast/infra/online_stores/snowflake.py
@@ -0,0 +1,232 @@
import itertools
import os
from binascii import hexlify
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import pandas as pd
import pytz
from pydantic import Field
from pydantic.schema import Literal

from feast import Entity, FeatureView
from feast.infra.key_encoding_utils import serialize_entity_key
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.utils.snowflake_utils import get_snowflake_conn, write_pandas_binary
from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
from feast.protos.feast.types.Value_pb2 import Value as ValueProto
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.usage import log_exceptions_and_usage


class SnowflakeOnlineStoreConfig(FeastConfigBaseModel):
""" Online store config for Snowflake """

type: Literal["snowflake.online"] = "snowflake.online"
""" Online store type selector"""

config_path: Optional[str] = (
Path(os.environ["HOME"]) / ".snowsql/config"
).__str__()
""" Snowflake config path -- absolute path required (Can't use ~)"""

account: Optional[str] = None
""" Snowflake deployment identifier -- drop .snowflakecomputing.com"""

user: Optional[str] = None
""" Snowflake user name """

password: Optional[str] = None
""" Snowflake password """

role: Optional[str] = None
""" Snowflake role name"""

warehouse: Optional[str] = None
""" Snowflake warehouse name """

database: Optional[str] = None
""" Snowflake database name """

schema_: Optional[str] = Field("PUBLIC", alias="schema")
""" Snowflake schema name """

class Config:
allow_population_by_field_name = True
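
# Illustrative feature_store.yaml section for this config (placeholder values,
# not taken from this PR); the "schema" key populates schema_ through the
# pydantic alias above:
#
#   online_store:
#     type: snowflake.online
#     account: my_account
#     user: my_user
#     password: my_password
#     role: MY_ROLE
#     warehouse: MY_WAREHOUSE
#     database: MY_DATABASE
#     schema: PUBLIC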


class SnowflakeOnlineStore(OnlineStore):
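    """
    Online store that keeps feature values in one Snowflake table per feature
    view, named "<project>_<feature view name>".

    Values are stored as serialized ValueProto bytes and keyed by a binary
    "entity_feature_key" built from the serialized entity key plus the feature
    name.
    """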
@log_exceptions_and_usage(online_store="snowflake")
def online_write_batch(
self,
config: RepoConfig,
table: FeatureView,
data: List[
Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]
],
progress: Optional[Callable[[int], Any]],
) -> None:
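        """
        Write a batch of feature rows into the feature view's online table.

        Each (entity key, feature name) pair becomes one row; the rows are
        bulk-loaded with write_pandas_binary and the table is then rewritten
        with INSERT OVERWRITE so only the newest row per key remains.
        """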
assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)

dfs = [None] * len(data)
for i, (entity_key, values, timestamp, created_ts) in enumerate(data):

df = pd.DataFrame(
columns=[
"entity_feature_key",
"entity_key",
"feature_name",
"value",
"event_ts",
"created_ts",
],
index=range(0, len(values)),
)

timestamp = _to_naive_utc(timestamp)
if created_ts is not None:
created_ts = _to_naive_utc(created_ts)

for j, (feature_name, val) in enumerate(values.items()):
df.loc[j, "entity_feature_key"] = serialize_entity_key(
entity_key
) + bytes(feature_name, encoding="utf-8")
df.loc[j, "entity_key"] = serialize_entity_key(entity_key)
df.loc[j, "feature_name"] = feature_name
df.loc[j, "value"] = val.SerializeToString()
df.loc[j, "event_ts"] = timestamp
df.loc[j, "created_ts"] = created_ts

dfs[i] = df
if progress:
progress(1)

if dfs:
agg_df = pd.concat(dfs)

with get_snowflake_conn(config.online_store, autocommit=False) as conn:
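                # Stage the concatenated batch, then collapse duplicates so the
                # table keeps a single newest row per (entity_key, feature_name).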
write_pandas_binary(conn, agg_df, f"{config.project}_{table.name}")

query = f"""
INSERT OVERWRITE INTO "{config.online_store.database}"."{config.online_store.schema_}"."{config.project}_{table.name}"
SELECT
"entity_feature_key",
"entity_key",
"feature_name",
"value",
"event_ts",
"created_ts"
FROM
(SELECT
*,
ROW_NUMBER() OVER(PARTITION BY "entity_key","feature_name" ORDER BY "event_ts" DESC, "created_ts" DESC) AS "_feast_row"
FROM
"{config.online_store.database}"."{config.online_store.schema_}"."{config.project}_{table.name}")
WHERE
"_feast_row" = 1;
"""

conn.cursor().execute(query)

return None

@log_exceptions_and_usage(online_store="snowflake")
def online_read(
self,
config: RepoConfig,
table: FeatureView,
entity_keys: List[EntityKeyProto],
requested_features: List[str],
) -> List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]]:
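        """
        Read the latest stored values for the requested features of each
        entity key, deserializing them back into ValueProto objects.
        """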
assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)

result: List[Tuple[Optional[datetime], Optional[Dict[str, ValueProto]]]] = []

with get_snowflake_conn(config.online_store) as conn:
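            # The IN (...) list is built from hex literals wrapped in TO_BINARY so
            # the binary "entity_feature_key" values can be matched directly in SQL.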

df = (
conn.cursor()
.execute(
f"""
SELECT
"entity_key", "feature_name", "value", "event_ts"
FROM
"{config.online_store.database}"."{config.online_store.schema_}"."{config.project}_{table.name}"
WHERE
"entity_feature_key" IN ({','.join([('TO_BINARY('+hexlify(serialize_entity_key(combo[0])+bytes(combo[1], encoding='utf-8')).__str__()[1:]+")") for combo in itertools.product(entity_keys,requested_features)])})
""",
)
.fetch_pandas_all()
)

for entity_key in entity_keys:
entity_key_bin = serialize_entity_key(entity_key)
res = {}
res_ts = None
for index, row in df[df["entity_key"] == entity_key_bin].iterrows():
val = ValueProto()
val.ParseFromString(row["value"])
res[row["feature_name"]] = val
res_ts = row["event_ts"].to_pydatetime()

if not res:
result.append((None, None))
else:
result.append((res_ts, res))
return result

@log_exceptions_and_usage(online_store="snowflake")
def update(
self,
config: RepoConfig,
tables_to_delete: Sequence[FeatureView],
tables_to_keep: Sequence[FeatureView],
entities_to_delete: Sequence[Entity],
entities_to_keep: Sequence[Entity],
partial: bool,
):
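        """
        Create the online table for every feature view to keep and drop the
        tables of feature views scheduled for deletion.
        """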
assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)

with get_snowflake_conn(config.online_store) as conn:

for table in tables_to_keep:

conn.cursor().execute(
f"""CREATE TABLE IF NOT EXISTS "{config.online_store.database}"."{config.online_store.schema_}"."{config.project}_{table.name}" (
"entity_feature_key" BINARY,
"entity_key" BINARY,
"feature_name" VARCHAR,
"value" BINARY,
"event_ts" TIMESTAMP,
"created_ts" TIMESTAMP
)"""
)

for table in tables_to_delete:

conn.cursor().execute(
f'DROP TABLE IF EXISTS "{config.online_store.database}"."{config.online_store.schema_}"."{config.project}_{table.name}"'
)

def teardown(
self,
config: RepoConfig,
tables: Sequence[FeatureView],
entities: Sequence[Entity],
):
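        """Drop the online tables of the given feature views."""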
assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)

with get_snowflake_conn(config.online_store) as conn:

for table in tables:
query = f'DROP TABLE IF EXISTS "{config.online_store.database}"."{config.online_store.schema_}"."{config.project}_{table.name}"'
conn.cursor().execute(query)


def _to_naive_utc(ts: datetime):
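    """Convert a timezone-aware datetime to naive UTC; leave naive values unchanged."""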
if ts.tzinfo is None:
return ts
else:
return ts.astimezone(pytz.utc).replace(tzinfo=None)