Commit 844bb83

implement to_remote_storage method

Signed-off-by: niklasvm <niklasvm@gmail.com>
niklasvm committed Sep 3, 2022
1 parent b4ef834 commit 844bb83

Showing 2 changed files with 80 additions and 1 deletion.
@@ -1,4 +1,6 @@
import os
import tempfile
import uuid
import warnings
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -13,6 +15,7 @@
from pyspark import SparkConf
from pyspark.sql import SparkSession
from pytz import utc
from feast.infra.utils import aws_utils

from feast import FeatureView, OnDemandFeatureView
from feast.data_source import DataSource
@@ -46,6 +49,12 @@ class SparkOfflineStoreConfig(FeastConfigBaseModel):
""" Configuration overlay for the spark session """
# sparksession is not serializable and we dont want to pass it around as an argument

staging_location: Optional[StrictStr] = None
""" Remote path for batch materialization jobs"""

region: Optional[StrictStr] = None
""" AWS Region if applicable for s3-based staging locations"""


class SparkOfflineStore(OfflineStore):
    @staticmethod
@@ -105,6 +114,7 @@ def pull_latest_from_table_or_query(
        return SparkRetrievalJob(
            spark_session=spark_session,
            query=query,
            config=config,
            full_feature_names=False,
            on_demand_feature_views=None,
        )
@@ -129,6 +139,7 @@ def get_historical_features(
"Some functionality may still be unstable so functionality can change in the future.",
RuntimeWarning,
)

spark_session = get_spark_session_or_start_new_with_repoconfig(
store_config=config.offline_store
)
@@ -192,6 +203,7 @@ def get_historical_features(
                min_event_timestamp=entity_df_event_timestamp_range[0],
                max_event_timestamp=entity_df_event_timestamp_range[1],
            ),
            config=config,
        )

    @staticmethod
@@ -286,7 +298,10 @@ def pull_all_from_table_or_query(
"""

return SparkRetrievalJob(
spark_session=spark_session, query=query, full_feature_names=False
spark_session=spark_session,
query=query,
full_feature_names=False,
config=config,
)


@@ -296,6 +311,7 @@ def __init__(
        spark_session: SparkSession,
        query: str,
        full_feature_names: bool,
        config: RepoConfig,
        on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
        metadata: Optional[RetrievalMetadata] = None,
    ):
@@ -305,6 +321,7 @@ def __init__(
        self._full_feature_names = full_feature_names
        self._on_demand_feature_views = on_demand_feature_views or []
        self._metadata = metadata
        self._config = config

    @property
    def full_feature_names(self) -> bool:
@@ -342,6 +359,53 @@ def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
            raise ValueError("Cannot persist, table_name is not defined")
        self.to_spark_df().createOrReplaceTempView(table_name)

    def supports_remote_storage_export(self) -> bool:
        return self._config.offline_store.staging_location is not None

    def to_remote_storage(self) -> List[str]:
        """Currently only works for local and s3-based staging locations"""
        if not self.supports_remote_storage_export():
            raise NotImplementedError(
                "to_remote_storage requires a staging_location to be set in SparkOfflineStoreConfig"
            )

        sdf: pyspark.sql.DataFrame = self.to_spark_df()
        staging_location = self._config.offline_store.staging_location

        if staging_location.startswith("file://"):
            # strip the scheme first: os.path.abspath does not understand file:// URIs
            local_file_staging_location = os.path.abspath(
                staging_location[len("file://"):]
            )

            # write to a unique folder under the staging location
            output_uri = os.path.join(local_file_staging_location, str(uuid.uuid4()))
            sdf.write.parquet(output_uri)

            return _list_files_in_folder(output_uri)
        elif staging_location.startswith("s3://"):
            # Spark's Hadoop S3 connector expects the s3a:// scheme
            spark_compatible_s3_staging_location = staging_location.replace(
                "s3://", "s3a://"
            )

            # write to a unique folder under the staging location
            output_folder = str(uuid.uuid4())
            output_uri = os.path.join(
                spark_compatible_s3_staging_location, output_folder
            )
            sdf.write.parquet(output_uri)

            # list the written files with boto3, using the original s3:// scheme
            return aws_utils.list_s3_files(
                self._config.offline_store.region,
                os.path.join(staging_location, output_folder),
            )
        else:
            raise NotImplementedError(
                "to_remote_storage is only implemented for file:// and s3:// uri schemes"
            )
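
A sketch of how a caller might consume the new export path; `job` stands in for any SparkRetrievalJob returned by the store (hypothetical variable):

    if job.supports_remote_storage_export():
        # returns the list of parquet files written under staging_location
        files = job.to_remote_storage()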

    @property
    def metadata(self) -> Optional[RetrievalMetadata]:
        """
@@ -444,6 +508,17 @@ def _format_datetime(t: datetime) -> str:
    return dt


def _list_files_in_folder(folder):
    """List full filenames in a folder"""
    files = []
    for file in os.listdir(folder):
        filename = os.path.join(folder, file)
        if os.path.isfile(filename):
            files.append(filename)

    return files
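
Since Spark writes a parquet dataset as a folder of part files (plus a _SUCCESS marker), the returned list can be filtered before reading it back; a sketch with a hypothetical output folder:

    import pandas as pd

    # hypothetical folder produced by a local to_remote_storage() run
    files = _list_files_in_folder("/tmp/feast-staging/1b9d6bcd-bbfd-4b2d-9b5d-ab8dfbbd4bed")
    df = pd.concat(
        [pd.read_parquet(f) for f in files if f.endswith(".parquet")],
        ignore_index=True,
    )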


def _cast_data_frame(
    df_new: pyspark.sql.DataFrame, df_existing: pyspark.sql.DataFrame
) -> pyspark.sql.DataFrame:

@@ -58,6 +58,10 @@ def create_offline_store_config(self):
        self.spark_offline_store_config = SparkOfflineStoreConfig()
        self.spark_offline_store_config.type = "spark"
        self.spark_offline_store_config.spark_conf = self.spark_conf
        # use tempfile.mkdtemp() for the path: str(tempfile.TemporaryDirectory())
        # would yield the object's repr, and the directory would be deleted as
        # soon as the object is garbage-collected
        self.spark_offline_store_config.staging_location = (
            "file://" + tempfile.mkdtemp()
        )
        self.spark_offline_store_config.region = "eu-west-1"
        return self.spark_offline_store_config

    def create_data_source(