diff --git a/sdk/python/feast/job.py b/sdk/python/feast/job.py
index 21b08224ba..049cd3613c 100644
--- a/sdk/python/feast/job.py
+++ b/sdk/python/feast/job.py
@@ -1,4 +1,3 @@
-import tempfile
 import time
 from datetime import datetime, timedelta
 from typing import List
@@ -6,7 +5,6 @@
 
 import fastavro
 import pandas as pd
-from google.cloud import storage
 from google.protobuf.json_format import MessageToJson
 
 from feast.core.CoreService_pb2 import ListIngestionJobsRequest
@@ -23,6 +21,7 @@
 from feast.serving.ServingService_pb2 import Job as JobProto
 from feast.serving.ServingService_pb2_grpc import ServingServiceStub
 from feast.source import Source
+from feast.staging.staging_strategy import StagingStrategy
 
 # Maximum no of seconds to wait until the retrieval jobs status is DONE in Feast
 # Currently set to the maximum query execution time limit in BigQuery
@@ -47,8 +46,7 @@ def __init__(
         """
         self.job_proto = job_proto
         self.serving_stub = serving_stub
-        # TODO: abstract away GCP depedency
-        self.gcs_client = storage.Client(project=None)
+        self.staging_strategy = StagingStrategy()
 
     @property
     def id(self):
@@ -126,16 +124,7 @@ def result(self, timeout_sec: int = DEFAULT_TIMEOUT_SEC):
         """
         uris = self.get_avro_files(timeout_sec)
         for file_uri in uris:
-            if file_uri.scheme == "gs":
-                file_obj = tempfile.TemporaryFile()
-                self.gcs_client.download_blob_to_file(file_uri.geturl(), file_obj)
-            elif file_uri.scheme == "file":
-                file_obj = open(file_uri.path, "rb")
-            else:
-                raise Exception(
-                    f"Could not identify file URI {file_uri}. Only gs:// and file:// supported"
-                )
-
+            file_obj = self.staging_strategy.execute_file_download(file_uri)
             file_obj.seek(0)
             avro_reader = fastavro.reader(file_obj)
 
diff --git a/sdk/python/feast/loaders/file.py b/sdk/python/feast/loaders/file.py
index 52cc8ae7dc..1ced7186c8 100644
--- a/sdk/python/feast/loaders/file.py
+++ b/sdk/python/feast/loaders/file.py
@@ -13,18 +13,18 @@
 # limitations under the License.
 
 import os
-import re
 import shutil
 import tempfile
 import uuid
 from datetime import datetime
 from typing import List, Optional, Tuple, Union
-from urllib.parse import ParseResult, urlparse
+from urllib.parse import urlparse
 
 import pandas as pd
-from google.cloud import storage
 from pandavro import to_avro
 
+from feast.staging.staging_strategy import StagingStrategy
+
 
 def export_source_to_staging_location(
     source: Union[pd.DataFrame, str], staging_location_uri: str
@@ -58,6 +58,7 @@
             remote staging location.
     """
+    staging_strategy = StagingStrategy()
    uri = urlparse(staging_location_uri)
 
    # Prepare Avro file to be exported to staging location
@@ -66,26 +67,21 @@
         uri_path = None  # type: Optional[str]
         if uri.scheme == "file":
             uri_path = uri.path
-        # Remote gs staging location provided by serving
         dir_path, file_name, source_path = export_dataframe_to_local(
             df=source, dir_path=uri_path
         )
-    elif urlparse(source).scheme in ["", "file"]:
-        # Local file provided as a source
-        dir_path = None
-        file_name = os.path.basename(source)
-        source_path = os.path.abspath(
-            os.path.join(urlparse(source).netloc, urlparse(source).path)
-        )
-    elif urlparse(source).scheme == "gs":
-        # Google Cloud Storage path provided
-        input_source_uri = urlparse(source)
-        if "*" in source:
-            # Wildcard path
-            return _get_files(bucket=input_source_uri.hostname, uri=input_source_uri)
+    elif isinstance(source, str):
+        if urlparse(source).scheme in ["", "file"]:
+            # Local file provided as a source
+            dir_path = None
+            file_name = os.path.basename(source)
+            source_path = os.path.abspath(
+                os.path.join(urlparse(source).netloc, urlparse(source).path)
+            )
         else:
-            return [source]
+            # gs, s3 file provided as a source.
+            return staging_strategy.execute_get_source_files(source)
     else:
         raise Exception(
             f"Only string and DataFrame types are allowed as a "
@@ -93,20 +89,12 @@
         )
 
     # Push data to required staging location
-    if uri.scheme == "gs":
-        # Staging location is a Google Cloud Storage path
-        upload_file_to_gcs(
-            source_path, uri.hostname, str(uri.path).strip("/") + "/" + file_name
-        )
-    elif uri.scheme == "file":
-        # Staging location is a file path
-        # Used for end-to-end test
-        pass
-    else:
-        raise Exception(
-            f"Staging location {staging_location_uri} does not have a "
-            f"valid URI. Only gs:// and file:// uri scheme are supported."
-        )
+    staging_strategy.execute_file_upload(
+        uri.scheme,
+        source_path,
+        uri.hostname,
+        str(uri.path).strip("/") + "/" + file_name,
+    )
 
     # Clean up, remove local staging file
     if dir_path and isinstance(source, pd.DataFrame) and len(str(dir_path)) > 4:
@@ -160,70 +148,6 @@
     return dir_path, file_name, dest_path
 
 
-def upload_file_to_gcs(local_path: str, bucket: str, remote_path: str) -> None:
-    """
-    Upload a file from the local file system to Google Cloud Storage (GCS).
-
-    Args:
-        local_path (str):
-            Local filesystem path of file to upload.
-
-        bucket (str):
-            GCS bucket destination to upload to.
-
-        remote_path (str):
-            Path within GCS bucket to upload file to, includes file name.
-
-    Returns:
-        None:
-            None
-    """
-
-    storage_client = storage.Client(project=None)
-    bucket = storage_client.get_bucket(bucket)
-    blob = bucket.blob(remote_path)
-    blob.upload_from_filename(local_path)
-
-
-def _get_files(bucket: str, uri: ParseResult) -> List[str]:
-    """
-    List all available files within a Google storage bucket that matches a wild
-    card path.
-
-    Args:
-        bucket (str):
-            Google Storage bucket to reference.
-
-        uri (urllib.parse.ParseResult):
-            Wild card uri path containing the "*" character.
-            Example:
-                * gs://feast/staging_location/*
-                * gs://feast/staging_location/file_*.avro
-
-    Returns:
-        List[str]:
-            List of all available files matching the wildcard path.
-    """
-
-    storage_client = storage.Client(project=None)
-    bucket = storage_client.get_bucket(bucket)
-    path = uri.path
-
-    if "*" in path:
-        regex = re.compile(path.replace("*", ".*?").strip("/"))
-        blob_list = bucket.list_blobs(
-            prefix=path.strip("/").split("*")[0], delimiter="/"
-        )
-        # File path should not be in path (file path must be longer than path)
-        return [
-            f"{uri.scheme}://{uri.hostname}/{file}"
-            for file in [x.name for x in blob_list]
-            if re.match(regex, file) and file not in path
-        ]
-    else:
-        raise Exception(f"{path} is not a wildcard path")
-
-
 def _get_file_name() -> str:
     """
     Create a random file name.
diff --git a/sdk/python/feast/staging/__init__.py b/sdk/python/feast/staging/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/sdk/python/feast/staging/staging_strategy.py b/sdk/python/feast/staging/staging_strategy.py
new file mode 100644
index 0000000000..f87335f1d9
--- /dev/null
+++ b/sdk/python/feast/staging/staging_strategy.py
@@ -0,0 +1,162 @@
+import re
+from abc import ABC, ABCMeta, abstractmethod
+from enum import Enum
+from tempfile import TemporaryFile
+from typing import List
+from urllib.parse import ParseResult, urlparse
+
+import boto3
+from google.cloud import storage
+
+
+class PROTOCOL(Enum):
+    GS = "gs"
+    S3 = "s3"
+    LOCAL_FILE = "file"
+
+
+class StagingStrategy:
+    def __init__(self):
+        self._protocol_dict = dict()
+
+    def execute_file_download(self, file_uri: ParseResult) -> TemporaryFile:
+        protocol = self._get_staging_protocol(file_uri.scheme)
+        return protocol.download_file(file_uri)
+
+    def execute_get_source_files(self, source: str) -> List[str]:
+        uri = urlparse(source)
+        if "*" in uri.path:
+            protocol = self._get_staging_protocol(uri.scheme)
+            return protocol.list_files(bucket=uri.hostname, uri=uri)
+        elif PROTOCOL(uri.scheme) in [PROTOCOL.S3, PROTOCOL.GS]:
+            return [source]
+        else:
+            raise Exception(
+                f"Could not identify file protocol {uri.scheme}. Only gs:// and file:// and s3:// supported"
+            )
+
+    def execute_file_upload(
+        self, scheme: str, local_path: str, bucket: str, remote_path: str
+    ):
+        protocol = self._get_staging_protocol(scheme)
+        return protocol.upload_file(local_path, bucket, remote_path)
+
+    def _get_staging_protocol(self, protocol):
+        if protocol in self._protocol_dict:
+            return self._protocol_dict[protocol]
+        else:
+            if PROTOCOL(protocol) == PROTOCOL.GS:
+                self._protocol_dict[protocol] = GCSProtocol()
+            elif PROTOCOL(protocol) == PROTOCOL.S3:
+                self._protocol_dict[protocol] = S3Protocol()
+            elif PROTOCOL(protocol) == PROTOCOL.LOCAL_FILE:
+                self._protocol_dict[protocol] = LocalFSProtocol()
+            else:
+                raise Exception(
+                    f"Could not identify file protocol {protocol}. Only gs:// and file:// and s3:// supported"
+                )
+            return self._protocol_dict[protocol]
+
+
+class AbstractStagingProtocol(ABC):
+
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def __init__(self):
+        pass
+
+    @abstractmethod
+    def download_file(self, uri: ParseResult) -> TemporaryFile:
+        pass
+
+    @abstractmethod
+    def list_files(self, bucket: str, uri: ParseResult) -> List[str]:
+        pass
+
+    @abstractmethod
+    def upload_file(self, local_path: str, bucket: str, remote_path: str):
+        pass
+
+
+class GCSProtocol(AbstractStagingProtocol):
+    def __init__(self):
+        self.gcs_client = storage.Client(project=None)
+
+    def download_file(self, uri: ParseResult) -> TemporaryFile:
+        url = uri.geturl()
+        file_obj = TemporaryFile()
+        self.gcs_client.download_blob_to_file(url, file_obj)
+        return file_obj
+
+    def list_files(self, bucket: str, uri: ParseResult) -> List[str]:
+        bucket = self.gcs_client.get_bucket(bucket)
+        path = uri.path
+
+        if "*" in path:
+            regex = re.compile(path.replace("*", ".*?").strip("/"))
+            blob_list = bucket.list_blobs(
+                prefix=path.strip("/").split("*")[0], delimiter="/"
+            )
+            # File path should not be in path (file path must be longer than path)
+            return [
+                f"{uri.scheme}://{uri.hostname}/{file}"
+                for file in [x.name for x in blob_list]
+                if re.match(regex, file) and file not in path
+            ]
+        else:
+            raise Exception(f"{path} is not a wildcard path")
+
+    def upload_file(self, local_path: str, bucket: str, remote_path: str):
+        bucket = self.gcs_client.get_bucket(bucket)
+        blob = bucket.blob(remote_path)
+        blob.upload_from_filename(local_path)
+
+
+class S3Protocol(AbstractStagingProtocol):
+    def __init__(self):
+        self.s3_client = boto3.client("s3")
+
+    def download_file(self, uri: ParseResult) -> TemporaryFile:
+        url = uri.path[1:]  # removing leading / from the path
+        bucket = uri.hostname
+        file_obj = TemporaryFile()
+        self.s3_client.download_fileobj(bucket, url, file_obj)
+        return file_obj
+
+    def list_files(self, bucket: str, uri: ParseResult) -> List[str]:
+        path = uri.path
+
+        if "*" in path:
+            regex = re.compile(path.replace("*", ".*?").strip("/"))
+            blob_list = self.s3_client.list_objects(
+                Bucket=bucket, Prefix=path.strip("/").split("*")[0], Delimiter="/"
+            )
+            # File path should not be in path (file path must be longer than path)
+            return [
+                f"{uri.scheme}://{uri.hostname}/{file}"
+                for file in [x["Key"] for x in blob_list["Contents"]]
+                if re.match(regex, file) and file not in path
+            ]
+        else:
+            raise Exception(f"{path} is not a wildcard path")
+
+    def upload_file(self, local_path: str, bucket: str, remote_path: str):
+        with open(local_path, "rb") as file:
+            self.s3_client.upload_fileobj(file, bucket, remote_path)
+
+
+class LocalFSProtocol(AbstractStagingProtocol):
+    def __init__(self):
+        pass
+
+    def download_file(self, file_uri: ParseResult) -> TemporaryFile:
+        url = file_uri.path
+        file_obj = open(url, "rb")
+        return file_obj
+
+    def list_files(self, bucket: str, uri: ParseResult) -> List[str]:
+        raise NotImplementedError("list file not implemented for Local file")
+
+    def upload_file(self, local_path: str, bucket: str, remote_path: str):
+        pass  # For test cases
diff --git a/sdk/python/requirements-ci.txt b/sdk/python/requirements-ci.txt
index 45aff4788b..3d09606668 100644
--- a/sdk/python/requirements-ci.txt
+++ b/sdk/python/requirements-ci.txt
@@ -10,4 +10,5 @@ pytest-timeout
 pytest-ordering==0.6.*
 pandas==0.*
 mock==2.0.0
-pandavro==1.5.*
\ No newline at end of file
+pandavro==1.5.*
+moto
\ No newline at end of file
diff --git a/sdk/python/requirements-dev.txt b/sdk/python/requirements-dev.txt
index f24141fb49..ca341d001b 100644
--- a/sdk/python/requirements-dev.txt
+++ b/sdk/python/requirements-dev.txt
@@ -35,4 +35,6 @@ mypy
 mypy-protobuf
 pre-commit
 flake8
-black
\ No newline at end of file
+black
+boto3
+moto
\ No newline at end of file
diff --git a/sdk/python/setup.py b/sdk/python/setup.py
index 69ea44a187..0e8edf2b0e 100644
--- a/sdk/python/setup.py
+++ b/sdk/python/setup.py
@@ -46,6 +46,7 @@
     "numpy",
     "google",
     "confluent_kafka",
+    'boto3'
 ]
 
 # README file from Feast repo root directory
diff --git a/sdk/python/tests/loaders/__init__.py b/sdk/python/tests/loaders/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/sdk/python/tests/loaders/test_file.py b/sdk/python/tests/loaders/test_file.py
new file mode 100644
index 0000000000..d9f6b907b9
--- /dev/null
+++ b/sdk/python/tests/loaders/test_file.py
@@ -0,0 +1,77 @@
+import tempfile
+from unittest.mock import patch
+from urllib.parse import urlparse
+
+import boto3
+import fastavro
+import pandas as pd
+import pandavro
+from moto import mock_s3
+from pandas.testing import assert_frame_equal
+from pytest import fixture
+
+from feast.loaders.file import export_source_to_staging_location
+
+BUCKET = "test_bucket"
+FOLDER_NAME = "test_folder"
+FILE_NAME = "test.avro"
+
+LOCAL_FILE = "file://tmp/tmp"
+S3_LOCATION = f"s3://{BUCKET}/{FOLDER_NAME}"
+
+TEST_DATA_FRAME = pd.DataFrame(
+    {
+        "driver": [1001, 1002, 1003],
+        "transaction": [1001, 1002, 1003],
+        "driver_id": [1001, 1002, 1003],
+    }
+)
+
+
+@fixture
+def avro_data_path():
+    final_results = tempfile.mktemp()
+    pandavro.to_avro(file_path_or_buffer=final_results, df=TEST_DATA_FRAME)
+    return final_results
+
+
+@patch("feast.loaders.file._get_file_name", return_value=FILE_NAME)
+def test_export_source_to_staging_location_local_file_should_pass(get_file_name):
+    source = export_source_to_staging_location(TEST_DATA_FRAME, LOCAL_FILE)
+    assert source == [f"{LOCAL_FILE}/{FILE_NAME}"]
+    assert get_file_name.call_count == 1
+
+
+@mock_s3
+@patch("feast.loaders.file._get_file_name", return_value=FILE_NAME)
+def test_export_source_to_staging_location_dataframe_to_s3_should_pass(get_file_name):
+    s3_client = boto3.client("s3")
+    s3_client.create_bucket(Bucket=BUCKET)
+    source = export_source_to_staging_location(TEST_DATA_FRAME, S3_LOCATION)
+    file_obj = tempfile.TemporaryFile()
+    uri = urlparse(source[0])
+    s3_client.download_fileobj(uri.hostname, uri.path[1:], file_obj)
+    file_obj.seek(0)
+    avro_reader = fastavro.reader(file_obj)
+    retrived_df = pd.DataFrame.from_records([r for r in avro_reader])
+    assert_frame_equal(retrived_df, TEST_DATA_FRAME)
+    assert get_file_name.call_count == 1
+
+
+def test_export_source_to_staging_location_s3_file_as_source_should_pass():
+    source = export_source_to_staging_location(S3_LOCATION, None)
+    assert source == [S3_LOCATION]
+
+
+@mock_s3
+def test_export_source_to_staging_location_s3_wildcard_as_source_should_pass(
+    avro_data_path,
+):
+    s3_client = boto3.client("s3")
+    s3_client.create_bucket(Bucket=BUCKET)
+    with open(avro_data_path, "rb") as data:
+        s3_client.upload_fileobj(data, BUCKET, f"{FOLDER_NAME}/file1.avro")
+    with open(avro_data_path, "rb") as data:
+        s3_client.upload_fileobj(data, BUCKET, f"{FOLDER_NAME}/file2.avro")
+    sources = export_source_to_staging_location(f"{S3_LOCATION}/*", None)
+    assert sources == [f"{S3_LOCATION}/file1.avro", f"{S3_LOCATION}/file2.avro"]
diff --git a/sdk/python/tests/test_job.py b/sdk/python/tests/test_job.py
new file mode 100644
index 0000000000..97a0013714
--- /dev/null
+++ b/sdk/python/tests/test_job.py
@@ -0,0 +1,128 @@
+import tempfile
+
+import boto3
+import grpc
+import pandas as pd
+import pandavro
+import pytest
+from moto import mock_s3
+from pandas.testing import assert_frame_equal
+from pytest import fixture, raises
+
+import feast.serving.ServingService_pb2_grpc as Serving
+from feast.job import RetrievalJob, JobProto
+from feast.serving.ServingService_pb2 import DataFormat, GetJobResponse
+from feast.serving.ServingService_pb2 import Job as BatchRetrievalJob
+from feast.serving.ServingService_pb2 import JobStatus, JobType
+
+BUCKET = "test_bucket"
+
+TEST_DATA_FRAME = pd.DataFrame(
+    {
+        "driver": [1001, 1002, 1003],
+        "transaction": [1001, 1002, 1003],
+        "driver_id": [1001, 1002, 1003],
+    }
+)
+
+
+class TestRetrievalJob:
+    @fixture
+    def retrieve_job(self):
+
+        serving_service_stub = Serving.ServingServiceStub(grpc.insecure_channel(""))
+        job_proto = JobProto(
+            id="123",
+            type=JobType.JOB_TYPE_DOWNLOAD,
+            status=JobStatus.JOB_STATUS_RUNNING,
+        )
+        return RetrievalJob(job_proto, serving_service_stub)
+
+    @fixture
+    def avro_data_path(self):
+        final_results = tempfile.mktemp()
+        pandavro.to_avro(file_path_or_buffer=final_results, df=TEST_DATA_FRAME)
+        return final_results
+
+    def test_to_dataframe_local_file_staging_should_pass(
+        self, retrieve_job, avro_data_path, mocker
+    ):
+        mocker.patch.object(
+            retrieve_job.serving_stub,
+            "GetJob",
+            return_value=GetJobResponse(
+                job=BatchRetrievalJob(
+                    id="123",
+                    type=JobType.JOB_TYPE_DOWNLOAD,
+                    status=JobStatus.JOB_STATUS_DONE,
+                    file_uris=[f"file://{avro_data_path}"],
+                    data_format=DataFormat.DATA_FORMAT_AVRO,
+                )
+            ),
+        )
+        retrived_df = retrieve_job.to_dataframe()
+        assert_frame_equal(TEST_DATA_FRAME, retrived_df)
+
+    @mock_s3
+    def test_to_dataframe_s3_file_staging_should_pass(
+        self, retrieve_job, avro_data_path, mocker
+    ):
+        s3_client = boto3.client("s3")
+        target = "test_proj/test_features.avro"
+        s3_client.create_bucket(Bucket=BUCKET)
+        with open(avro_data_path, "rb") as data:
+            s3_client.upload_fileobj(data, BUCKET, target)
+
+        mocker.patch.object(
+            retrieve_job.serving_stub,
+            "GetJob",
+            return_value=GetJobResponse(
+                job=BatchRetrievalJob(
+                    id="123",
+                    type=JobType.JOB_TYPE_DOWNLOAD,
+                    status=JobStatus.JOB_STATUS_DONE,
+                    file_uris=[f"s3://{BUCKET}/{target}"],
+                    data_format=DataFormat.DATA_FORMAT_AVRO,
+                )
+            ),
+        )
+        retrived_df = retrieve_job.to_dataframe()
+        assert_frame_equal(TEST_DATA_FRAME, retrived_df)
+
+    @pytest.mark.parametrize(
+        "job_proto,exception",
+        [
+            (
+                GetJobResponse(
+                    job=BatchRetrievalJob(
+                        id="123",
+                        type=JobType.JOB_TYPE_DOWNLOAD,
+                        status=JobStatus.JOB_STATUS_DONE,
+                        data_format=DataFormat.DATA_FORMAT_AVRO,
+                        error="Testing job failure",
+                    )
+                ),
+                Exception,
+            ),
+            (
+                GetJobResponse(
+                    job=BatchRetrievalJob(
+                        id="123",
+                        type=JobType.JOB_TYPE_DOWNLOAD,
+                        status=JobStatus.JOB_STATUS_DONE,
+                        data_format=DataFormat.DATA_FORMAT_INVALID,
+                    )
+                ),
+                Exception,
+            ),
+        ],
+        ids=["when_retrieve_job_fails", "when_data_format_is_not_avro"],
+    )
+    def test_to_dataframe_s3_file_staging_should_raise(
+        self, retrieve_job, mocker, job_proto, exception
+    ):
+        mocker.patch.object(
+            retrieve_job.serving_stub, "GetJob", return_value=job_proto,
+        )
+        with raises(exception):
+            retrieve_job.to_dataframe()
diff --git a/serving/src/main/java/feast/serving/service/OnlineServingService.java b/serving/src/main/java/feast/serving/service/OnlineServingService.java
index 7339a19c1a..bb73e34f51 100644
--- a/serving/src/main/java/feast/serving/service/OnlineServingService.java
+++ b/serving/src/main/java/feast/serving/service/OnlineServingService.java
@@ -16,6 +16,7 @@
  */
 package feast.serving.service;
 
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Maps;
 import com.google.protobuf.Duration;
 import feast.proto.serving.ServingAPIProto.*;
@@ -75,6 +76,9 @@ public GetOnlineFeaturesResponse getOnlineFeatures(GetOnlineFeaturesRequest requ
     // feature set request.
     List<List<FeatureRow>> featureRows =
         retriever.getOnlineFeatures(entityRows, featureSetRequests);
+    if (scope != null) {
+      scope.span().log(ImmutableMap.of("event", "featureRows", "value", featureRows));
+    }
 
     // For each feature set request, read the feature rows returned by the retriever, and
     // populate the featureValuesMap with the feature values corresponding to that entity row.
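
For reference, a minimal sketch of how the `StagingStrategy` entry points introduced in `sdk/python/feast/staging/staging_strategy.py` are meant to be called. The bucket and object names are illustrative only (they are not part of the diff), and the S3 calls assume boto3 credentials and an existing bucket:

```python
from urllib.parse import urlparse

from feast.staging.staging_strategy import StagingStrategy

strategy = StagingStrategy()

# Upload a local Avro file to a staging bucket; the scheme ("gs", "s3", "file")
# selects the protocol handler. "my-bucket" and the paths are hypothetical.
strategy.execute_file_upload("s3", "/tmp/data.avro", "my-bucket", "staging/data.avro")

# Expand a wildcard source into the list of matching object URIs.
sources = strategy.execute_get_source_files("s3://my-bucket/staging/*.avro")

# Download a staged file; the method takes a parsed URI and returns a file-like object.
file_obj = strategy.execute_file_download(urlparse("s3://my-bucket/staging/data.avro"))
file_obj.seek(0)
```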