diff --git a/sdk/python/feast/client.py b/sdk/python/feast/client.py index 22e5dbc28e..f31bca1d9d 100644 --- a/sdk/python/feast/client.py +++ b/sdk/python/feast/client.py @@ -14,6 +14,8 @@ import logging import multiprocessing import shutil +import uuid +from itertools import groupby from typing import Any, Dict, List, Optional, Union import grpc @@ -30,6 +32,8 @@ CONFIG_SERVING_ENABLE_SSL_KEY, CONFIG_SERVING_SERVER_SSL_CERT_KEY, CONFIG_SERVING_URL_KEY, + CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT, + CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION, FEAST_DEFAULT_OPTIONS, ) from feast.core.CoreService_pb2 import ( @@ -70,6 +74,11 @@ _write_partitioned_table_from_source, ) from feast.online_response import OnlineResponse, _infer_online_entity_rows +from feast.pyspark.abc import RetrievalJob +from feast.pyspark.launcher import ( + start_historical_feature_retrieval_job, + start_historical_feature_retrieval_spark_session, +) from feast.serving.ServingService_pb2 import ( GetFeastServingInfoRequest, GetOnlineFeaturesRequestV2, @@ -723,7 +732,6 @@ def get_online_features( ) -> OnlineResponse: """ Retrieves the latest online feature data from Feast Serving. - Args: feature_refs: List of feature references that will be returned for each entity. Each feature reference should have the following format: @@ -733,12 +741,10 @@ def get_online_features( entity_rows: A list of dictionaries where each key-value is an entity-name, entity-value pair. project: Optionally specify the the project override. If specified, uses given project for retrieval. Overrides the projects specified in Feature References if also are specified. - Returns: GetOnlineFeaturesResponse containing the feature data in records. Each EntityRow provided will yield one record, which contains data fields with data value and field status metadata (if included). - Examples: >>> from feast import Client >>> @@ -767,3 +773,113 @@ def get_online_features( response = OnlineResponse(response) return response + + def get_historical_features( + self, + feature_refs: List[str], + entity_source: Union[FileSource, BigQuerySource], + project: str = None, + ) -> RetrievalJob: + """ + Launch a historical feature retrieval job. + + Args: + feature_refs: List of feature references that will be returned for each entity. + Each feature reference should have the following format: + "feature_table:feature" where "feature_table" & "feature" refer to + the feature and feature table names respectively. + entity_source (Union[FileSource, BigQuerySource]): Source for the entity rows. + The user needs to make sure that the source is accessible from the Spark cluster + that will be used for the retrieval job. + project: Specifies the project that contains the feature tables + which the requested features belong to. + + Returns: + Returns a retrieval job object that can be used to monitor retrieval + progress asynchronously, and can be used to materialize the + results. 
+ + Examples: + >>> from feast import Client + >>> from datetime import datetime + >>> feast_client = Client(core_url="localhost:6565") + >>> feature_refs = ["bookings:bookings_7d", "bookings:booking_14d"] + >>> entity_source = FileSource("event_timestamp", "parquet", "gs://some-bucket/customer") + >>> feature_retrieval_job = feast_client.get_historical_features( + >>> feature_refs, entity_source, project="my_project") + >>> output_file_uri = feature_retrieval_job.get_output_file_uri() + "gs://some-bucket/output/" + """ + feature_tables = self._get_feature_tables_from_feature_refs( + feature_refs, project + ) + output_location = self._config.get( + CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION + ) + output_format = self._config.get(CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT) + job_id = f"historical-feature-{str(uuid.uuid4())}" + + return start_historical_feature_retrieval_job( + self, entity_source, feature_tables, output_format, output_location, job_id + ) + + def get_historical_features_df( + self, + feature_refs: List[str], + entity_source: Union[FileSource, BigQuerySource], + project: str = None, + ): + """ + Launch a historical feature retrieval job and return the result as a Spark dataframe. + + Args: + feature_refs: List of feature references that will be returned for each entity. + Each feature reference should have the following format: + "feature_table:feature" where "feature_table" & "feature" refer to + the feature table and feature names respectively. + entity_source (Union[FileSource, BigQuerySource]): Source for the entity rows. + The user needs to make sure that the source is accessible from the Spark cluster + that will be used for the retrieval job. + project: Specifies the project that contains the feature tables + which the requested features belong to. + + Returns: + The historical feature retrieval result as a Spark dataframe.
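For the new client entry points above, a minimal end-to-end sketch (not part of this diff): the option names mirror the new keys in constants.py, and it assumes Client() forwards keyword options into its Config; the FileSource argument order follows the test fixtures later in this diff.

from feast import Client, FileSource

# Illustrative sketch: option names mirror the new constants; whether Client()
# accepts them as keyword options is an assumption.
client = Client(
    core_url="localhost:6565",
    spark_launcher="standalone",                 # CONFIG_SPARK_LAUNCHER
    spark_standalone_master="local[*]",          # CONFIG_SPARK_STANDALONE_MASTER
    historical_feature_output_format="parquet",  # CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT
    historical_feature_output_location="file:///tmp/feast-historical",  # illustrative path
)

# FileSource positional arguments as used in the test fixtures below:
# event timestamp column, created timestamp column, format, file URI.
entity_source = FileSource(
    "event_timestamp", "created_timestamp", "parquet", "file:///tmp/entities"
)

job = client.get_historical_features(
    ["bookings:total_completed_bookings"], entity_source, project="default"
)
print(job.get_output_file_uri())  # blocks until the Spark job finishes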
+ + Examples: + >>> from feast import Client + >>> from datetime import datetime + >>> from pyspark.sql import SparkSession + >>> spark = SparkSession.builder.getOrCreate() + >>> feast_client = Client(core_url="localhost:6565") + >>> feature_refs = ["bookings:bookings_7d", "bookings:booking_14d"] + >>> entity_source = FileSource("event_timestamp", "parquet", "gs://some-bucket/customer") + >>> df = feast_client.get_historical_features( + >>> feature_refs, entity_source, project="my_project") + """ + feature_tables = self._get_feature_tables_from_feature_refs( + feature_refs, project + ) + return start_historical_feature_retrieval_spark_session( + self, entity_source, feature_tables + ) + + def _get_feature_tables_from_feature_refs( + self, feature_refs: List[str], project: Optional[str] + ): + feature_refs_grouped_by_table = [ + (feature_table_name, list(grouped_feature_refs)) + for feature_table_name, grouped_feature_refs in groupby( + feature_refs, lambda x: x.split(":")[0] + ) + ] + + feature_tables = [] + for feature_table_name, grouped_feature_refs in feature_refs_grouped_by_table: + feature_table = self.get_feature_table(feature_table_name, project) + feature_names = [f.split(":")[-1] for f in grouped_feature_refs] + feature_table.features = [ + f for f in feature_table.features if f.name in feature_names + ] + feature_tables.append(feature_table) + return feature_tables diff --git a/sdk/python/feast/constants.py b/sdk/python/feast/constants.py index 7ed0d273fe..0161c8e1e3 100644 --- a/sdk/python/feast/constants.py +++ b/sdk/python/feast/constants.py @@ -64,6 +64,20 @@ class AuthProvider(Enum): CONFIG_TIMEOUT_KEY = "timeout" CONFIG_MAX_WAIT_INTERVAL_KEY = "max_wait_interval" +# Spark Job Config +CONFIG_SPARK_LAUNCHER = "spark_launcher" # standalone, dataproc, emr + +CONFIG_SPARK_STANDALONE_MASTER = "spark_standalone_master" + +CONFIG_SPARK_DATAPROC_CLUSTER_NAME = "dataproc_cluster_name" +CONFIG_SPARK_DATAPROC_PROJECT = "dataproc_project" +CONFIG_SPARK_DATAPROC_REGION = "dataproc_region" +CONFIG_SPARK_DATAPROC_STAGING_LOCATION = "dataproc_staging_location" + +CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT = "historical_feature_output_format" +CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION = "historical_feature_output_location" + + # Configuration option default values FEAST_DEFAULT_OPTIONS = { # Default Feast project to use diff --git a/sdk/python/feast/pyspark/abc.py b/sdk/python/feast/pyspark/abc.py new file mode 100644 index 0000000000..0428a448fc --- /dev/null +++ b/sdk/python/feast/pyspark/abc.py @@ -0,0 +1,193 @@ +import abc +from typing import Dict, List + + +class SparkJobFailure(Exception): + """ + Job submission failed, encountered error during execution, or timeout + """ + + pass + + +class SparkJob(abc.ABC): + """ + Base class for all spark jobs + """ + + @abc.abstractmethod + def get_id(self) -> str: + """ + Getter for the job id. The job id must be unique for each spark job submission. + + Returns: + str: Job id. + """ + raise NotImplementedError + + +class RetrievalJob(SparkJob): + """ + Container for the historical feature retrieval job result + """ + + @abc.abstractmethod + def get_output_file_uri(self, timeout_sec=None): + """ + Get output file uri to the result file. This method will block until the + job succeeded, or if the job didn't execute successfully within timeout. + + Args: + timeout_sec (int): + Max no of seconds to wait until job is done. If "timeout_sec" + is exceeded or if the job fails, an exception will be raised. 
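The _get_feature_tables_from_feature_refs helper added above relies on itertools.groupby, which only groups consecutive elements; a small illustrative sketch (not from the diff) that sorts the refs by feature table name first so each table yields exactly one group:

from itertools import groupby

# groupby only merges *consecutive* keys, so refs for the same table must be
# adjacent; sorting by the table prefix (the part before ":") guarantees that.
feature_refs = [
    "transactions:total_transactions",
    "bookings:total_completed_bookings",
    "transactions:is_vip",
]
sorted_refs = sorted(feature_refs, key=lambda ref: ref.split(":")[0])
grouped = {
    table: [ref.split(":")[-1] for ref in refs]
    for table, refs in groupby(sorted_refs, key=lambda ref: ref.split(":")[0])
}
# grouped == {"bookings": ["total_completed_bookings"],
#             "transactions": ["total_transactions", "is_vip"]}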
+ + Raises: + SparkJobFailure: + The spark job submission failed, encountered error during execution, + or timeout. + + Returns: + str: file uri to the result file. + """ + raise NotImplementedError + + +class IngestionJob(SparkJob): + pass + + +class JobLauncher(abc.ABC): + """ + Submits spark jobs to a spark cluster. Currently supports only historical feature retrieval jobs. + """ + + @abc.abstractmethod + def historical_feature_retrieval( + self, + pyspark_script: str, + entity_source_conf: Dict, + feature_tables_sources_conf: List[Dict], + feature_tables_conf: List[Dict], + destination_conf: Dict, + job_id: str, + **kwargs, + ) -> RetrievalJob: + """ + Submits a historical feature retrieval job to a Spark cluster. + + Args: + pyspark_script (str): Local file path to the pyspark script for historical feature + retrieval. + entity_source_conf (Dict): Entity data source configuration. + feature_tables_sources_conf (List[Dict]): List of feature tables data sources configurations. + feature_tables_conf (List[Dict]): List of feature table specification. + The order of the feature table must correspond to that of feature_tables_sources. + destination_conf (Dict): Retrieval job output destination. + job_id (str): A job id that is unique for each job submission. + + Raises: + SparkJobFailure: The spark job submission failed, encountered error + during execution, or timeout. + + Examples: + >>> # Entity source from file + >>> entity_source_conf = { + "file": { + "format": "parquet", + "path": "gs://some-gcs-bucket/customer", + "event_timestamp_column": "event_timestamp", + "options": { + "mergeSchema": "true" + } # Optional. Options to be passed to Spark while reading the dataframe from source. + "field_mapping": { + "id": "customer_id" + } # Optional. Map the columns, where the key is the original column name and the value is the new column name. + + } + } + + >>> # Entity source from BigQuery + >>> entity_source_conf = { + "bq": { + "project": "gcp_project_id", + "dataset": "bq_dataset", + "table": "customer", + "event_timestamp_column": "event_timestamp", + } + } + + >>> feature_table_sources_conf = [ + { + "bq": { + "project": "gcp_project_id", + "dataset": "bq_dataset", + "table": "customer_transactions", + "event_timestamp_column": "event_timestamp", + "created_timestamp_column": "created_timestamp" # This field is mandatory for feature tables. + } + }, + + { + "file": { + "format": "parquet", + "path": "gs://some-gcs-bucket/customer_profile", + "event_timestamp_column": "event_timestamp", + "created_timestamp_column": "created_timestamp", + "options": { + "mergeSchema": "true" + } + } + }, + ] + + + >>> feature_tables_conf = [ + { + "name": "customer_transactions", + "entities": [ + { + "name": "customer + "type": "int32" + } + ], + "features": [ + { + "name": "total_transactions" + "type": "double" + }, + { + "name": "total_discounts" + "type": "double" + } + ], + "max_age": 86400 # In seconds. + }, + + { + "name": "customer_profile", + "entities": [ + { + "name": "customer + "type": "int32" + } + ], + "features": [ + { + "name": "is_vip" + "type": "bool" + } + ], + + } + ] + + >>> destination_conf = { + "format": "parquet", + "path": "gs://some-gcs-bucket/retrieval_output" + } + + Returns: + str: file uri to the result file. 
+ """ + raise NotImplementedError diff --git a/sdk/python/feast/pyspark/historical_feature_retrieval_job.py b/sdk/python/feast/pyspark/historical_feature_retrieval_job.py index 591323cbd4..e65c200896 100644 --- a/sdk/python/feast/pyspark/historical_feature_retrieval_job.py +++ b/sdk/python/feast/pyspark/historical_feature_retrieval_job.py @@ -2,7 +2,7 @@ import argparse import json from datetime import timedelta -from typing import Dict, List, NamedTuple, Optional +from typing import Any, Dict, List, NamedTuple, Optional from pyspark.sql import DataFrame, SparkSession, Window from pyspark.sql.functions import col, expr, monotonically_increasing_id, row_number @@ -104,7 +104,7 @@ def spark_read_options(self) -> Dict[str, str]: return self.options -class BQSource(Source): +class BigQuerySource(Source): """ Big query datasource, which depends on spark bigquery connector (https://github.com/GoogleCloudDataproc/spark-bigquery-connector). @@ -147,21 +147,21 @@ def spark_path(self) -> str: def _source_from_dict(dct: Dict) -> Source: if "file" in dct.keys(): return FileSource( - dct["format"], - dct["path"], - dct["event_timestamp_column"], - dct.get("created_timestamp_column"), - dct.get("field_mapping"), - dct.get("options"), + dct["file"]["format"], + dct["file"]["path"], + dct["file"]["event_timestamp_column"], + dct["file"].get("created_timestamp_column"), + dct["file"].get("field_mapping"), + dct["file"].get("options"), ) else: - return BQSource( - dct["project"], - dct["dataset"], - dct["table"], - dct.get("field_mapping", {}), - dct["event_timestamp_column"], - dct.get("created_timestamp_column"), + return BigQuerySource( + dct["bq"]["project"], + dct["bq"]["dataset"], + dct["bq"]["table"], + dct["bq"].get("field_mapping", {}), + dct["bq"]["event_timestamp_column"], + dct["bq"].get("created_timestamp_column"), ) @@ -610,19 +610,19 @@ def _read_and_verify_feature_table_df_from_source( def retrieve_historical_features( spark: SparkSession, - entity_source: Source, - feature_tables_sources: List[Source], - feature_tables: List[FeatureTable], + entity_source_conf: Dict, + feature_tables_sources_conf: List[Dict], + feature_tables_conf: List[Dict], ) -> DataFrame: - """Retrieve historical features based on given configurations. + """Retrieve historical features based on given configurations. The argument can be either Args: spark (SparkSession): Spark session. - entity_source (Source): Entity data source, which describe where and how to retrieve the Spark dataframe + entity_source_conf (Dict): Entity data source, which describe where and how to retrieve the Spark dataframe representing the entities. - feature_tables_sources (Source): List of feature tables data sources, which describe where and how to + feature_tables_sources_conf (Dict): List of feature tables data sources, which describe where and how to retrieve the feature table representing the feature tables. - feature_tables (List[FeatureTable]): List of feature table specification. The specification describes which + feature_tables_conf (List[Dict]): List of feature table specification. The specification describes which features should be present in the final join result, as well as the maximum age. The order of the feature table must correspond to that of feature_tables_sources. @@ -634,39 +634,45 @@ def retrieve_historical_features( SchemaError: If either the entity or feature table has missing columns or wrong column types. 
Example: - >>> entity_source = FileSource( - format="csv", - path="file:///some_dir/customer_driver_pairs.csv"), - options={"inferSchema": "true", "header": "true"}, - field_mapping={"id": "driver_id"} - ) - - >>> feature_tables_sources = [ - FileSource( - format="parquet", - path="gs://some_bucket/bookings.parquet"), - field_mapping={"id": "driver_id"} - ), - FileSource( - format="avro", - path="s3://some_bucket/transactions.avro"), - ) + >>> entity_source_conf = { + "format": "csv", + "path": "file:///some_dir/customer_driver_pairs.csv"), + "options": {"inferSchema": "true", "header": "true"}, + "field_mapping": {"id": "driver_id"} + } + + >>> feature_tables_sources_conf = [ + { + "format": "parquet", + "path": "gs://some_bucket/bookings.parquet"), + "field_mapping": {"id": "driver_id"} + }, + { + "format": "avro", + "path": "s3://some_bucket/transactions.avro"), + } ] - >>> feature_tables = [ - FeatureTable( - name="bookings", - entities=[Field("driver_id", "int32")], - features=[Field("completed_bookings", "int32")], - ), - FeatureTable( - name="transactions", - entities=[Field("customer_id", "int32")], - features=[Field("total_transactions", "double")], - max_age=172800 - ), + >>> feature_tables_conf = [ + { + "name": "bookings", + "entities": [{"name": "driver_id", "type": "int32"}], + "features": [{"name": "completed_bookings", "type": "int32"}], + }, + { + "name": "transactions", + "entities": [{"name": "customer_id", "type": "int32"}], + "features": [{"name": "total_transactions", "type": "double"}], + "max_age": 172800 + }, ] """ + feature_tables = [_feature_table_from_dict(dct) for dct in feature_tables_conf] + feature_tables_sources = [ + _source_from_dict(dct) for dct in feature_tables_sources_conf + ] + entity_source = _source_from_dict(entity_source_conf) + entity_df = _read_and_verify_entity_df_from_source(spark, entity_source) feature_table_dfs = [ @@ -722,14 +728,15 @@ def retrieve_historical_features( def start_job( spark: SparkSession, - entity_source: Source, - feature_tables_sources: List[Source], - feature_tables: List[FeatureTable], - destination: FileDestination, + entity_source_conf: Dict, + feature_tables_sources_conf: List[Dict], + feature_tables_conf: List[Dict], + destination_conf: Dict, ): result = retrieve_historical_features( - spark, entity_source, feature_tables_sources, feature_tables + spark, entity_source_conf, feature_tables_sources_conf, feature_tables_conf ) + destination = FileDestination(**destination_conf) result.write.format(destination.format).mode("overwrite").save(destination.path) @@ -752,14 +759,28 @@ def _get_args(): return parser.parse_args() +def _feature_table_from_dict(dct: Dict[str, Any]) -> FeatureTable: + return FeatureTable( + name=dct["name"], + entities=[Field(**e) for e in dct["entities"]], + features=[Field(**f) for f in dct["features"]], + max_age=dct.get("max_age"), + project=dct.get("project"), + ) + + if __name__ == "__main__": spark = SparkSession.builder.getOrCreate() args = _get_args() - feature_tables = [FeatureTable(**dct) for dct in json.loads(args.feature_tables)] - feature_tables_sources = [ - _source_from_dict(dct) for dct in json.loads(args.feature_tables_source) - ] - entity_source = _source_from_dict(json.loads(args.entity_source)) - destination = FileDestination(**json.loads(args.destination)) - start_job(spark, entity_source, feature_tables_sources, feature_tables, destination) + feature_tables_conf = json.loads(args.feature_tables) + feature_tables_sources_conf = json.loads(args.feature_tables_source) + 
entity_source_conf = json.loads(args.entity_source) + destination_conf = json.loads(args.destination) + start_job( + spark, + entity_source_conf, + feature_tables_sources_conf, + feature_tables_conf, + destination_conf, + ) spark.stop() diff --git a/sdk/python/feast/pyspark/launcher.py b/sdk/python/feast/pyspark/launcher.py new file mode 100644 index 0000000000..31f2971175 --- /dev/null +++ b/sdk/python/feast/pyspark/launcher.py @@ -0,0 +1,144 @@ +import pathlib +from typing import TYPE_CHECKING, List, Union + +from feast.config import Config +from feast.constants import ( + CONFIG_SPARK_DATAPROC_CLUSTER_NAME, + CONFIG_SPARK_DATAPROC_PROJECT, + CONFIG_SPARK_DATAPROC_REGION, + CONFIG_SPARK_DATAPROC_STAGING_LOCATION, + CONFIG_SPARK_LAUNCHER, + CONFIG_SPARK_STANDALONE_MASTER, +) +from feast.data_source import BigQuerySource, DataSource, FileSource +from feast.feature_table import FeatureTable +from feast.pyspark.abc import JobLauncher, RetrievalJob +from feast.value_type import ValueType + +if TYPE_CHECKING: + from feast.client import Client + + +def _standalone_launcher(config: Config) -> JobLauncher: + from feast.pyspark.launchers import standalone + + return standalone.StandaloneClusterLauncher( + config.get(CONFIG_SPARK_STANDALONE_MASTER) + ) + + +def _dataproc_launcher(config: Config) -> JobLauncher: + from feast.pyspark.launchers import gcloud + + return gcloud.DataprocClusterLauncher( + config.get(CONFIG_SPARK_DATAPROC_CLUSTER_NAME), + config.get(CONFIG_SPARK_DATAPROC_STAGING_LOCATION), + config.get(CONFIG_SPARK_DATAPROC_REGION), + config.get(CONFIG_SPARK_DATAPROC_PROJECT), + ) + + +_launchers = {"standalone": _standalone_launcher, "dataproc": _dataproc_launcher} + + +def resolve_launcher(config: Config) -> JobLauncher: + return _launchers[config.get(CONFIG_SPARK_LAUNCHER)](config) + + +_SOURCES = { + FileSource: ("file", "file_options", {"path": "file_url", "format": "file_format"}), + BigQuerySource: ("bq", "bigquery_options", {"table_ref": "table_ref"}), +} + + +def source_to_argument(source: DataSource): + common_properties = { + "field_mapping": dict(source.field_mapping), + "event_timestamp_column": source.event_timestamp_column, + "created_timestamp_column": source.created_timestamp_column, + "date_partition_column": source.date_partition_column, + } + + kind, option_field, extra_properties = _SOURCES[type(source)] + + properties = { + **common_properties, + **{ + k: getattr(getattr(source, option_field), ref) + for k, ref in extra_properties.items() + }, + } + + return {kind: properties} + + +def feature_table_to_argument(client: "Client", feature_table: FeatureTable): + return { + "features": [ + {"name": f.name, "type": ValueType(f.dtype).name} + for f in feature_table.features + ], + "project": "default", + "name": feature_table.name, + "entities": [ + {"name": n, "type": client.get_entity(n).value_type} + for n in feature_table.entities + ], + "max_age": feature_table.max_age.ToSeconds() if feature_table.max_age else None, + } + + +def start_historical_feature_retrieval_spark_session( + client: "Client", + entity_source: Union[FileSource, BigQuerySource], + feature_tables: List[FeatureTable], +): + from pyspark.sql import SparkSession + + from feast.pyspark.historical_feature_retrieval_job import ( + retrieve_historical_features, + ) + + spark_session = SparkSession.builder.getOrCreate() + return retrieve_historical_features( + spark=spark_session, + entity_source_conf=source_to_argument(entity_source), + feature_tables_sources_conf=[ + 
source_to_argument(feature_table.batch_source) + for feature_table in feature_tables + ], + feature_tables_conf=[ + feature_table_to_argument(client, feature_table) + for feature_table in feature_tables + ], + ) + + +def start_historical_feature_retrieval_job( + client: "Client", + entity_source: Union[FileSource, BigQuerySource], + feature_tables: List[FeatureTable], + output_format: str, + output_path: str, + job_id: str, +) -> RetrievalJob: + launcher = resolve_launcher(client._config) + retrieval_job_pyspark_script = str( + pathlib.Path(__file__).parent.absolute() + / "pyspark" + / "historical_feature_retrieval_job.py" + ) + return launcher.historical_feature_retrieval( + pyspark_script=retrieval_job_pyspark_script, + entity_source_conf=source_to_argument(entity_source), + feature_tables_sources_conf=[ + source_to_argument(feature_table.batch_source) + for feature_table in feature_tables + ], + feature_tables_conf=[ + feature_table_to_argument(client, feature_table) + for feature_table in feature_tables + ], + destination_conf={"format": output_format, "path": output_path}, + job_id=job_id, + ) diff --git a/sdk/python/feast/pyspark/launchers.py b/sdk/python/feast/pyspark/launchers.py deleted file mode 100644 index a258b871a5..0000000000 --- a/sdk/python/feast/pyspark/launchers.py +++ /dev/null @@ -1,401 +0,0 @@ -import abc -import json -import os -import subprocess -from typing import Dict, List -from urllib.parse import urlparse - - -class SparkJobFailure(Exception): - """ - Job submission failed, encountered error during execution, or timeout - """ - - pass - - -class RetrievalJob(abc.ABC): - """ - Container for the historical feature retrieval job result - """ - - @abc.abstractmethod - def get_id(self) -> str: - """ - Getter for the job id. The job id must be unique for each spark job submission. - - Returns: - str: Job id. - """ - raise NotImplementedError - - @abc.abstractmethod - def get_output_file_uri(self, timeout_sec=None): - """ - Get output file uri to the result file. This method will block until the - job succeeded, or if the job didn't execute successfully within timeout. - - Args: - timeout_sec (int): - Max no of seconds to wait until job is done. If "timeout_sec" - is exceeded or if the job fails, an exception will be raised. - - Raises: - SparkJobFailure: - The spark job submission failed, encountered error during execution, - or timeout. - - Returns: - str: file uri to the result file. - """ - raise NotImplementedError - - -class StandaloneClusterRetrievalJob(RetrievalJob): - """ - Historical feature retrieval job result for a standalone spark cluster - """ - - def __init__(self, job_id: str, process: subprocess.Popen, output_file_uri: str): - """ - This is the returned historical feature retrieval job result for StandaloneClusterLauncher. - - Args: - job_id (str): Historical feature retrieval job id. - process (subprocess.Popen): Pyspark driver process, spawned by the launcher. - output_file_uri (str): Uri to the historical feature retrieval job output file. 
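For reference on the new launcher.py above, an illustrative sketch of the dict that source_to_argument emits for a FileSource-backed batch source; the file_options attribute names follow the _SOURCES mapping, and all values here are made up:

file_source_argument = {
    "file": {
        "path": "gs://some-bucket/bookings",          # from file_options.file_url
        "format": "parquet",                          # from file_options.file_format
        "event_timestamp_column": "event_timestamp",
        "created_timestamp_column": "created_timestamp",
        "field_mapping": {"id": "driver_id"},
        "date_partition_column": "",
    }
}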
- """ - self.job_id = job_id - self._process = process - self._output_file_uri = output_file_uri - - def get_id(self) -> str: - return self.job_id - - def get_output_file_uri(self, timeout_sec: int = None): - with self._process as p: - try: - p.wait(timeout_sec) - except Exception: - p.kill() - raise SparkJobFailure("Timeout waiting for subprocess to return") - if self._process.returncode != 0: - stderr = "" if self._process.stderr is None else self._process.stderr.read() - stdout = "" if self._process.stdout is None else self._process.stdout.read() - - raise SparkJobFailure( - f"Non zero return code: {self._process.returncode}. stderr: {stderr} stdout: {stdout}" - ) - - -class DataprocRetrievalJob(RetrievalJob): - """ - Historical feature retrieval job result for a Dataproc cluster - """ - - def __init__(self, job_id, operation, output_file_uri): - """ - This is the returned historical feature retrieval job result for DataprocClusterLauncher. - - Args: - job_id (str): Historical feature retrieval job id. - operation (google.api.core.operation.Operation): A Future for the spark job result, - returned by the dataproc client. - output_file_uri (str): Uri to the historical feature retrieval job output file. - """ - self.job_id = job_id - self._operation = operation - self._output_file_uri = output_file_uri - - def get_id(self) -> str: - return self.job_id - - def get_output_file_uri(self, timeout_sec=None): - try: - self._operation.result(timeout_sec) - except Exception as err: - raise SparkJobFailure(err) - return self._output_file_uri - - -class JobLauncher(abc.ABC): - """ - Submits spark jobs to a spark cluster. Currently supports only historical feature retrieval jobs. - """ - - @abc.abstractmethod - def historical_feature_retrieval( - self, - pyspark_script: str, - entity_source_conf: Dict, - feature_tables_sources_conf: List[Dict], - feature_tables_conf: List[Dict], - destination_conf: Dict, - job_id: str, - **kwargs, - ) -> RetrievalJob: - """ - Submits a historical feature retrieval job to a Spark cluster. - - Args: - pyspark_script (str): Local file path to the pyspark script for historical feature - retrieval. - entity_source_conf (Dict): Entity data source configuration. - feature_tables_sources_conf (List[Dict]): List of feature tables data sources configurations. - feature_tables_conf (List[Dict]): List of feature table specification. - The order of the feature table must correspond to that of feature_tables_sources. - destination_conf (Dict): Retrieval job output destination. - job_id (str): A job id that is unique for each job submission. - - Raises: - SparkJobFailure: The spark job submission failed, encountered error - during execution, or timeout. - - Examples: - >>> # Entity source from file - >>> entity_source_conf = { - "file": { - "format": "parquet", - "path": "gs://some-gcs-bucket/customer", - "event_timestamp_column": "event_timestamp", - "options": { - "mergeSchema": "true" - } # Optional. Options to be passed to Spark while reading the dataframe from source. - "field_mapping": { - "id": "customer_id" - } # Optional. Map the columns, where the key is the original column name and the value is the new column name. 
- - } - } - - >>> # Entity source from BigQuery - >>> entity_source_conf = { - "bq": { - "project": "gcp_project_id", - "dataset": "bq_dataset", - "table": "customer", - "event_timestamp_column": "event_timestamp", - } - } - - >>> feature_table_sources_conf = [ - { - "bq": { - "project": "gcp_project_id", - "dataset": "bq_dataset", - "table": "customer_transactions", - "event_timestamp_column": "event_timestamp", - "created_timestamp_column": "created_timestamp" # This field is mandatory for feature tables. - } - }, - - { - "file": { - "format": "parquet", - "path": "gs://some-gcs-bucket/customer_profile", - "event_timestamp_column": "event_timestamp", - "created_timestamp_column": "created_timestamp", - "options": { - "mergeSchema": "true" - } - } - }, - ] - - - >>> feature_tables_conf = [ - { - "name": "customer_transactions", - "entities": [ - { - "name": "customer - "type": "int32" - } - ], - "features": [ - { - "name": "total_transactions" - "type": "double" - }, - { - "name": "total_discounts" - "type": "double" - } - ], - "max_age": 86400 # In seconds. - }, - - { - "name": "customer_profile", - "entities": [ - { - "name": "customer - "type": "int32" - } - ], - "features": [ - { - "name": "is_vip" - "type": "bool" - } - ], - - } - ] - - >>> destination_conf = { - "format": "parquet", - "path": "gs://some-gcs-bucket/retrieval_output" - } - - Returns: - str: file uri to the result file. - """ - raise NotImplementedError - - -class StandaloneClusterLauncher(JobLauncher): - """ - Submits jobs to a standalone Spark cluster in client mode. - """ - - def __init__(self, master_url: str, spark_home: str = None): - """ - This launcher executes the spark-submit script in a subprocess. The subprocess - will run until the Pyspark driver exits. - - Args: - master_url (str): - Spark cluster url. Must start with spark://. - spark_home (str): - Local file path to Spark installation directory. If not provided, - the environmental variable `SPARK_HOME` will be used instead. - """ - self.master_url = master_url - self.spark_home = spark_home if spark_home else os.getenv("SPARK_HOME") - - @property - def spark_submit_script_path(self): - return os.path.join(self.spark_home, "bin/spark-submit") - - def historical_feature_retrieval( - self, - pyspark_script: str, - entity_source_conf: Dict, - feature_tables_sources_conf: List[Dict], - feature_tables_conf: List[Dict], - destination_conf: Dict, - job_id: str, - **kwargs, - ) -> RetrievalJob: - - submission_cmd = [ - self.spark_submit_script_path, - "--master", - self.master_url, - "--name", - job_id, - pyspark_script, - "--feature-tables", - json.dumps(feature_tables_conf), - "--feature-tables-sources", - json.dumps(feature_tables_sources_conf), - "--entity-source", - json.dumps(entity_source_conf), - "--destination", - json.dumps(destination_conf), - ] - - process = subprocess.Popen(submission_cmd, shell=True) - output_file = destination_conf["path"] - return StandaloneClusterRetrievalJob(job_id, process, output_file) - - -class DataprocClusterLauncher(JobLauncher): - """ - Submits jobs to an existing Dataproc cluster. Depends on google-cloud-dataproc and - google-cloud-storage, which are optional dependencies that the user has to installed in - addition to the Feast SDK. - """ - - def __init__( - self, cluster_name: str, staging_location: str, region: str, project_id: str, - ): - """ - Initialize a dataproc job controller client, used internally for job submission and result - retrieval. - - Args: - cluster_name (str): - Dataproc cluster name. 
- staging_location (str): - GCS directory for the storage of files generated by the launcher, such as the pyspark scripts. - region (str): - Dataproc cluster region. - project_id (str: - GCP project id for the dataproc cluster. - """ - from google.cloud import dataproc_v1 - - self.cluster_name = cluster_name - - scheme, self.staging_bucket, self.remote_path, _, _, _ = urlparse( - staging_location - ) - if scheme != "gs": - raise ValueError( - "Only GCS staging location is supported for DataprocLauncher." - ) - self.project_id = project_id - self.region = region - self.job_client = dataproc_v1.JobControllerClient( - client_options={"api_endpoint": f"{region}-dataproc.googleapis.com:443"} - ) - - def _stage_files(self, pyspark_script: str, job_id: str) -> str: - from google.cloud import storage - - client = storage.Client() - bucket = client.get_bucket(self.staging_bucket) - blob_path = os.path.join( - self.remote_path, job_id, os.path.basename(pyspark_script), - ) - blob = bucket.blob(blob_path) - blob.upload_from_filename(pyspark_script) - - return f"gs://{self.staging_bucket}/{blob_path}" - - def historical_feature_retrieval( - self, - pyspark_script: str, - entity_source_conf: Dict, - feature_tables_sources_conf: List[Dict], - feature_tables_conf: List[Dict], - destination_conf: Dict, - job_id: str, - **kwargs, - ) -> RetrievalJob: - - pyspark_gcs = self._stage_files(pyspark_script, job_id) - job = { - "reference": {"job_id": job_id}, - "placement": {"cluster_name": self.cluster_name}, - "pyspark_job": { - "main_python_file_uri": pyspark_gcs, - "args": [ - "--feature-tables", - json.dumps(feature_tables_conf), - "--feature-tables-sources", - json.dumps(feature_tables_sources_conf), - "--entity-source", - json.dumps(entity_source_conf), - "--destination", - json.dumps(destination_conf), - ], - }, - } - operation = self.job_client.submit_job_as_operation( - request={"project_id": self.project_id, "region": self.region, "job": job} - ) - output_file = destination_conf["path"] - return DataprocRetrievalJob(job_id, operation, output_file) diff --git a/sdk/python/feast/pyspark/launchers/__init__.py b/sdk/python/feast/pyspark/launchers/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sdk/python/feast/pyspark/launchers/aws/__init__.py b/sdk/python/feast/pyspark/launchers/aws/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/sdk/python/feast/pyspark/launchers/gcloud/__init__.py b/sdk/python/feast/pyspark/launchers/gcloud/__init__.py new file mode 100644 index 0000000000..6f14542761 --- /dev/null +++ b/sdk/python/feast/pyspark/launchers/gcloud/__init__.py @@ -0,0 +1,3 @@ +from .dataproc import DataprocClusterLauncher, DataprocRetrievalJob + +__all__ = ["DataprocRetrievalJob", "DataprocClusterLauncher"] diff --git a/sdk/python/feast/pyspark/launchers/gcloud/dataproc.py b/sdk/python/feast/pyspark/launchers/gcloud/dataproc.py new file mode 100644 index 0000000000..efa0db413d --- /dev/null +++ b/sdk/python/feast/pyspark/launchers/gcloud/dataproc.py @@ -0,0 +1,125 @@ +import json +import os +from typing import Dict, List +from urllib.parse import urlparse + +from google.cloud import dataproc_v1, storage + +from feast.pyspark.abc import JobLauncher, RetrievalJob, SparkJobFailure + + +class DataprocRetrievalJob(RetrievalJob): + """ + Historical feature retrieval job result for a Dataproc cluster + """ + + def __init__(self, job_id, operation, output_file_uri): + """ + This is the returned historical feature retrieval job result for 
DataprocClusterLauncher. + + Args: + job_id (str): Historical feature retrieval job id. + operation (google.api.core.operation.Operation): A Future for the spark job result, + returned by the dataproc client. + output_file_uri (str): Uri to the historical feature retrieval job output file. + """ + self.job_id = job_id + self._operation = operation + self._output_file_uri = output_file_uri + + def get_id(self) -> str: + return self.job_id + + def get_output_file_uri(self, timeout_sec=None): + try: + self._operation.result(timeout_sec) + except Exception as err: + raise SparkJobFailure(err) + return self._output_file_uri + + +class DataprocClusterLauncher(JobLauncher): + """ + Submits jobs to an existing Dataproc cluster. Depends on google-cloud-dataproc and + google-cloud-storage, which are optional dependencies that the user has to installed in + addition to the Feast SDK. + """ + + def __init__( + self, cluster_name: str, staging_location: str, region: str, project_id: str, + ): + """ + Initialize a dataproc job controller client, used internally for job submission and result + retrieval. + + Args: + cluster_name (str): + Dataproc cluster name. + staging_location (str): + GCS directory for the storage of files generated by the launcher, such as the pyspark scripts. + region (str): + Dataproc cluster region. + project_id (str: + GCP project id for the dataproc cluster. + """ + + self.cluster_name = cluster_name + + scheme, self.staging_bucket, self.remote_path, _, _, _ = urlparse( + staging_location + ) + if scheme != "gs": + raise ValueError( + "Only GCS staging location is supported for DataprocLauncher." + ) + self.project_id = project_id + self.region = region + self.job_client = dataproc_v1.JobControllerClient( + client_options={"api_endpoint": f"{region}-dataproc.googleapis.com:443"} + ) + + def _stage_files(self, pyspark_script: str, job_id: str) -> str: + client = storage.Client() + bucket = client.get_bucket(self.staging_bucket) + blob_path = os.path.join( + self.remote_path, job_id, os.path.basename(pyspark_script), + ) + blob = bucket.blob(blob_path) + blob.upload_from_filename(pyspark_script) + + return f"gs://{self.staging_bucket}/{blob_path}" + + def historical_feature_retrieval( + self, + pyspark_script: str, + entity_source_conf: Dict, + feature_tables_sources_conf: List[Dict], + feature_tables_conf: List[Dict], + destination_conf: Dict, + job_id: str, + **kwargs, + ) -> RetrievalJob: + + pyspark_gcs = self._stage_files(pyspark_script, job_id) + job = { + "reference": {"job_id": job_id}, + "placement": {"cluster_name": self.cluster_name}, + "pyspark_job": { + "main_python_file_uri": pyspark_gcs, + "args": [ + "--feature-tables", + json.dumps(feature_tables_conf), + "--feature-tables-sources", + json.dumps(feature_tables_sources_conf), + "--entity-source", + json.dumps(entity_source_conf), + "--destination", + json.dumps(destination_conf), + ], + }, + } + operation = self.job_client.submit_job_as_operation( + request={"project_id": self.project_id, "region": self.region, "job": job} + ) + output_file = destination_conf["path"] + return DataprocRetrievalJob(job_id, operation, output_file) diff --git a/sdk/python/feast/pyspark/launchers/standalone/__init__.py b/sdk/python/feast/pyspark/launchers/standalone/__init__.py new file mode 100644 index 0000000000..1c44e5497f --- /dev/null +++ b/sdk/python/feast/pyspark/launchers/standalone/__init__.py @@ -0,0 +1,3 @@ +from .local import StandaloneClusterLauncher, StandaloneClusterRetrievalJob + +__all__ = 
["StandaloneClusterRetrievalJob", "StandaloneClusterLauncher"] diff --git a/sdk/python/feast/pyspark/launchers/standalone/local.py b/sdk/python/feast/pyspark/launchers/standalone/local.py new file mode 100644 index 0000000000..35bd00f751 --- /dev/null +++ b/sdk/python/feast/pyspark/launchers/standalone/local.py @@ -0,0 +1,100 @@ +import json +import os +import subprocess +from typing import Dict, List + +from feast.pyspark.abc import JobLauncher, RetrievalJob, SparkJobFailure + + +class StandaloneClusterRetrievalJob(RetrievalJob): + """ + Historical feature retrieval job result for a standalone spark cluster + """ + + def __init__(self, job_id: str, process: subprocess.Popen, output_file_uri: str): + """ + This is the returned historical feature retrieval job result for StandaloneClusterLauncher. + + Args: + job_id (str): Historical feature retrieval job id. + process (subprocess.Popen): Pyspark driver process, spawned by the launcher. + output_file_uri (str): Uri to the historical feature retrieval job output file. + """ + self.job_id = job_id + self._process = process + self._output_file_uri = output_file_uri + + def get_id(self) -> str: + return self.job_id + + def get_output_file_uri(self, timeout_sec: int = None): + with self._process as p: + try: + p.wait(timeout_sec) + except Exception: + p.kill() + raise SparkJobFailure("Timeout waiting for subprocess to return") + if self._process.returncode != 0: + stderr = "" if self._process.stderr is None else self._process.stderr.read() + stdout = "" if self._process.stdout is None else self._process.stdout.read() + + raise SparkJobFailure( + f"Non zero return code: {self._process.returncode}. stderr: {stderr} stdout: {stdout}" + ) + + +class StandaloneClusterLauncher(JobLauncher): + """ + Submits jobs to a standalone Spark cluster in client mode. + """ + + def __init__(self, master_url: str, spark_home: str = None): + """ + This launcher executes the spark-submit script in a subprocess. The subprocess + will run until the Pyspark driver exits. + + Args: + master_url (str): + Spark cluster url. Must start with spark://. + spark_home (str): + Local file path to Spark installation directory. If not provided, + the environmental variable `SPARK_HOME` will be used instead. 
+ """ + self.master_url = master_url + self.spark_home = spark_home if spark_home else os.getenv("SPARK_HOME") + + @property + def spark_submit_script_path(self): + return os.path.join(self.spark_home, "bin/spark-submit") + + def historical_feature_retrieval( + self, + pyspark_script: str, + entity_source_conf: Dict, + feature_tables_sources_conf: List[Dict], + feature_tables_conf: List[Dict], + destination_conf: Dict, + job_id: str, + **kwargs, + ) -> RetrievalJob: + + submission_cmd = [ + self.spark_submit_script_path, + "--master", + self.master_url, + "--name", + job_id, + pyspark_script, + "--feature-tables", + json.dumps(feature_tables_conf), + "--feature-tables-sources", + json.dumps(feature_tables_sources_conf), + "--entity-source", + json.dumps(entity_source_conf), + "--destination", + json.dumps(destination_conf), + ] + + process = subprocess.Popen(submission_cmd, shell=True) + output_file = destination_conf["path"] + return StandaloneClusterRetrievalJob(job_id, process, output_file) diff --git a/sdk/python/requirements-ci.txt b/sdk/python/requirements-ci.txt index 69c0be5a05..62b335e557 100644 --- a/sdk/python/requirements-ci.txt +++ b/sdk/python/requirements-ci.txt @@ -9,7 +9,7 @@ pytest-lazy-fixture==0.6.3 pytest-mock pytest-timeout pytest-ordering==0.6.* -pyspark==3.* +pyspark==2.4.2 pandas~=1.0.0 mock==2.0.0 pandavro==1.5.* diff --git a/sdk/python/requirements-dev.txt b/sdk/python/requirements-dev.txt index 113de1ab82..ac7813ec79 100644 --- a/sdk/python/requirements-dev.txt +++ b/sdk/python/requirements-dev.txt @@ -39,5 +39,5 @@ flake8 black==19.10b0 boto3 moto -pyspark==3.* -pyspark-stubs==3.* +pyspark==2.4.2 +pyspark-stubs==2.4.0.post9 diff --git a/sdk/python/tests/feast_core_server.py b/sdk/python/tests/feast_core_server.py index f66830d7a4..85f09175bf 100644 --- a/sdk/python/tests/feast_core_server.py +++ b/sdk/python/tests/feast_core_server.py @@ -11,7 +11,11 @@ ApplyEntityResponse, ApplyFeatureTableRequest, ApplyFeatureTableResponse, + GetEntityRequest, + GetEntityResponse, GetFeastCoreVersionResponse, + GetFeatureTableRequest, + GetFeatureTableResponse, ListEntitiesRequest, ListEntitiesResponse, ListFeatureTablesRequest, @@ -66,6 +70,14 @@ def __init__(self): def GetFeastCoreVersion(self, request, context): return GetFeastCoreVersionResponse(version="0.10.0") + def GetFeatureTable(self, request: GetFeatureTableRequest, context): + filtered_table = [ + table + for table in self._feature_tables.values() + if table.spec.name == request.name + ] + return GetFeatureTableResponse(table=filtered_table[0]) + def ListFeatureTables(self, request: ListFeatureTablesRequest, context): filtered_feature_table_response = list(self._feature_tables.values()) @@ -93,6 +105,14 @@ def ApplyFeatureTable(self, request: ApplyFeatureTableRequest, context): return ApplyFeatureTableResponse(table=applied_feature_table,) + def GetEntity(self, request: GetEntityRequest, context): + filtered_entities = [ + entity + for entity in self._entities.values() + if entity.spec.name == request.name + ] + return GetEntityResponse(entity=filtered_entities[0]) + def ListEntities(self, request: ListEntitiesRequest, context): filtered_entities_response = list(self._entities.values()) diff --git a/sdk/python/tests/test_as_of_join.py b/sdk/python/tests/test_as_of_join.py index b084b87a56..a6288855c7 100644 --- a/sdk/python/tests/test_as_of_join.py +++ b/sdk/python/tests/test_as_of_join.py @@ -18,7 +18,6 @@ from feast.pyspark.historical_feature_retrieval_job import ( FeatureTable, Field, - FileSource, SchemaError, 
as_of_join, join_entity_to_feature_tables, @@ -29,9 +28,7 @@ @pytest.yield_fixture(scope="module") def spark(pytestconfig): spark_session = ( - SparkSession.builder.appName("Batch Retrieval Test") - .master("local") - .getOrCreate() + SparkSession.builder.appName("As of join test").master("local").getOrCreate() ) yield spark_session spark_session.stop() @@ -583,37 +580,43 @@ def test_multiple_join( def test_historical_feature_retrieval(spark: SparkSession): test_data_dir = path.join(pathlib.Path(__file__).parent.absolute(), "data") - entity_source = FileSource( - format="csv", - path=f"file://{path.join(test_data_dir, 'customer_driver_pairs.csv')}", - event_timestamp_column="event_timestamp", - options={"inferSchema": "true", "header": "true"}, - ) - booking_source = FileSource( - format="csv", - path=f"file://{path.join(test_data_dir, 'bookings.csv')}", - event_timestamp_column="event_timestamp", - created_timestamp_column="created_timestamp", - options={"inferSchema": "true", "header": "true"}, - ) - transaction_source = FileSource( - format="csv", - path=f"file://{path.join(test_data_dir, 'transactions.csv')}", - event_timestamp_column="event_timestamp", - created_timestamp_column="created_timestamp", - options={"inferSchema": "true", "header": "true"}, - ) - booking_table = FeatureTable( - name="bookings", - entities=[Field("driver_id", "int32")], - features=[Field("completed_bookings", "int32")], - ) - transaction_table = FeatureTable( - name="transactions", - entities=[Field("customer_id", "int32")], - features=[Field("daily_transactions", "double")], - max_age=86400, - ) + entity_source = { + "file": { + "format": "csv", + "path": f"file://{path.join(test_data_dir, 'customer_driver_pairs.csv')}", + "event_timestamp_column": "event_timestamp", + "options": {"inferSchema": "true", "header": "true"}, + } + } + booking_source = { + "file": { + "format": "csv", + "path": f"file://{path.join(test_data_dir, 'bookings.csv')}", + "event_timestamp_column": "event_timestamp", + "created_timestamp_column": "created_timestamp", + "options": {"inferSchema": "true", "header": "true"}, + } + } + transaction_source = { + "file": { + "format": "csv", + "path": f"file://{path.join(test_data_dir, 'transactions.csv')}", + "event_timestamp_column": "event_timestamp", + "created_timestamp_column": "created_timestamp", + "options": {"inferSchema": "true", "header": "true"}, + } + } + booking_table = { + "name": "bookings", + "entities": [{"name": "driver_id", "type": "int32"}], + "features": [{"name": "completed_bookings", "type": "int32"}], + } + transaction_table = { + "name": "transactions", + "entities": [{"name": "customer_id", "type": "int32"}], + "features": [{"name": "daily_transactions", "type": "double"}], + "max_age": 86400, + } joined_df = retrieve_historical_features( spark, @@ -648,25 +651,29 @@ def test_historical_feature_retrieval(spark: SparkSession): def test_historical_feature_retrieval_with_mapping(spark: SparkSession): test_data_dir = path.join(pathlib.Path(__file__).parent.absolute(), "data") - entity_source = FileSource( - format="csv", - path=f"file://{path.join(test_data_dir, 'column_mapping_test_entity.csv')}", - event_timestamp_column="event_timestamp", - field_mapping={"id": "customer_id"}, - options={"inferSchema": "true", "header": "true"}, - ) - booking_source = FileSource( - format="csv", - path=f"file://{path.join(test_data_dir, 'column_mapping_test_feature.csv')}", - event_timestamp_column="datetime", - created_timestamp_column="created_datetime", - options={"inferSchema": 
"true", "header": "true"}, - ) - booking_table = FeatureTable( - name="bookings", - entities=[Field("customer_id", "int32")], - features=[Field("total_bookings", "int32")], - ) + entity_source = { + "file": { + "format": "csv", + "path": f"file://{path.join(test_data_dir, 'column_mapping_test_entity.csv')}", + "event_timestamp_column": "event_timestamp", + "field_mapping": {"id": "customer_id"}, + "options": {"inferSchema": "true", "header": "true"}, + } + } + booking_source = { + "file": { + "format": "csv", + "path": f"file://{path.join(test_data_dir, 'column_mapping_test_feature.csv')}", + "event_timestamp_column": "datetime", + "created_timestamp_column": "created_datetime", + "options": {"inferSchema": "true", "header": "true"}, + } + } + booking_table = { + "name": "bookings", + "entities": [{"name": "customer_id", "type": "int32"}], + "features": [{"name": "total_bookings", "type": "int32"}], + } joined_df = retrieve_historical_features( spark, entity_source, [booking_source], [booking_table], @@ -714,25 +721,29 @@ def test_large_historical_feature_retrieval( spark.sparkContext.parallelize(expected_join_data), expected_join_data_schema ) - entity_source = FileSource( - format="csv", - path=f"file://{large_entity_csv_file}", - event_timestamp_column="event_timestamp", - field_mapping={"id": "customer_id"}, - options={"inferSchema": "true", "header": "true"}, - ) - feature_source = FileSource( - format="csv", - path=f"file://{large_feature_csv_file}", - event_timestamp_column="event_timestamp", - created_timestamp_column="created_timestamp", - options={"inferSchema": "true", "header": "true"}, - ) - feature_table = FeatureTable( - name="feature", - entities=[Field("customer_id", "int32")], - features=[Field("total_bookings", "int32")], - ) + entity_source = { + "file": { + "format": "csv", + "path": f"file://{large_entity_csv_file}", + "event_timestamp_column": "event_timestamp", + "field_mapping": {"id": "customer_id"}, + "options": {"inferSchema": "true", "header": "true"}, + } + } + feature_source = { + "file": { + "format": "csv", + "path": f"file://{large_feature_csv_file}", + "event_timestamp_column": "event_timestamp", + "created_timestamp_column": "created_timestamp", + "options": {"inferSchema": "true", "header": "true"}, + } + } + feature_table = { + "name": "feature", + "entities": [{"name": "customer_id", "type": "int32"}], + "features": [{"name": "total_bookings", "type": "int32"}], + } joined_df = retrieve_historical_features( spark, entity_source, [feature_source], [feature_table] @@ -742,54 +753,64 @@ def test_large_historical_feature_retrieval( def test_historical_feature_retrieval_with_schema_errors(spark: SparkSession): test_data_dir = path.join(pathlib.Path(__file__).parent.absolute(), "data") - entity_source = FileSource( - format="csv", - path=f"file://{path.join(test_data_dir, 'customer_driver_pairs.csv')}", - event_timestamp_column="event_timestamp", - options={"inferSchema": "true", "header": "true"}, - ) - entity_source_missing_timestamp = FileSource( - format="csv", - path=f"file://{path.join(test_data_dir, 'customer_driver_pairs.csv')}", - event_timestamp_column="datetime", - options={"inferSchema": "true", "header": "true"}, - ) - entity_source_missing_entity = FileSource( - format="csv", - path=f"file://{path.join(test_data_dir, 'customers.csv')}", - event_timestamp_column="event_timestamp", - options={"inferSchema": "true", "header": "true"}, - ) - - booking_source = FileSource( - format="csv", - path=f"file://{path.join(test_data_dir, 'bookings.csv')}", 
- event_timestamp_column="event_timestamp", - created_timestamp_column="created_timestamp", - options={"inferSchema": "true", "header": "true"}, - ) - booking_source_missing_timestamp = FileSource( - format="csv", - path=f"file://{path.join(test_data_dir, 'bookings.csv')}", - event_timestamp_column="datetime", - created_timestamp_column="created_datetime", - options={"inferSchema": "true", "header": "true"}, - ) - booking_table = FeatureTable( - name="bookings", - entities=[Field("driver_id", "int32")], - features=[Field("completed_bookings", "int32")], - ) - booking_table_missing_features = FeatureTable( - name="bookings", - entities=[Field("driver_id", "int32")], - features=[Field("nonexist_feature", "int32")], - ) - booking_table_wrong_column_type = FeatureTable( - name="bookings", - entities=[Field("driver_id", "string")], - features=[Field("completed_bookings", "int32")], - ) + entity_source = { + "file": { + "format": "csv", + "path": f"file://{path.join(test_data_dir, 'customer_driver_pairs.csv')}", + "event_timestamp_column": "event_timestamp", + "options": {"inferSchema": "true", "header": "true"}, + } + } + entity_source_missing_timestamp = { + "file": { + "format": "csv", + "path": f"file://{path.join(test_data_dir, 'customer_driver_pairs.csv')}", + "event_timestamp_column": "datetime", + "options": {"inferSchema": "true", "header": "true"}, + } + } + entity_source_missing_entity = { + "file": { + "format": "csv", + "path": f"file://{path.join(test_data_dir, 'customers.csv')}", + "event_timestamp_column": "event_timestamp", + "options": {"inferSchema": "true", "header": "true"}, + } + } + + booking_source = { + "file": { + "format": "csv", + "path": f"file://{path.join(test_data_dir, 'bookings.csv')}", + "event_timestamp_column": "event_timestamp", + "created_timestamp_column": "created_timestamp", + "options": {"inferSchema": "true", "header": "true"}, + } + } + booking_source_missing_timestamp = { + "file": { + "format": "csv", + "path": f"file://{path.join(test_data_dir, 'bookings.csv')}", + "event_timestamp_column": "datetime", + "created_timestamp_column": "created_datetime", + "options": {"inferSchema": "true", "header": "true"}, + } + } + booking_table = { + "name": "bookings", + "entities": [{"name": "driver_id", "type": "int32"}], + "features": [{"name": "completed_bookings", "type": "int32"}], + } + booking_table_missing_features = { + "name": "bookings", + "entities": [{"name": "driver_id", "type": "int32"}], + "features": [{"name": "nonexist_feature", "type": "int32"}], + } + booking_table_wrong_column_type = { + "name": "bookings", + "entities": [{"name": "driver_id", "type": "string"}], + "features": [{"name": "completed_bookings", "type": "int32"}], + } with pytest.raises(SchemaError): retrieve_historical_features( diff --git a/sdk/python/tests/test_historical_feature_retrieval.py b/sdk/python/tests/test_historical_feature_retrieval.py new file mode 100644 index 0000000000..d9dbd9f4fc --- /dev/null +++ b/sdk/python/tests/test_historical_feature_retrieval.py @@ -0,0 +1,355 @@ +import os +import shutil +import socket +import tempfile +from concurrent import futures +from contextlib import closing +from datetime import datetime +from typing import List, Tuple + +import grpc +import pytest +from google.protobuf.duration_pb2 import Duration +from pyspark.sql import DataFrame, SparkSession +from pyspark.sql.types import ( + BooleanType, + DoubleType, + IntegerType, + StructField, + StructType, + TimestampType, +) + +from feast import Client, Entity, Feature, 
FeatureTable, FileSource, ValueType +from feast.core import CoreService_pb2_grpc as Core +from tests.feast_core_server import CoreServicer + + +def find_free_port(): + with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] + + +free_port = find_free_port() + + +def assert_dataframe_equal(left: DataFrame, right: DataFrame): + is_column_equal = set(left.columns) == set(right.columns) + + if not is_column_equal: + print(f"Column not equal. Left: {left.columns}, Right: {right.columns}") + assert is_column_equal + + is_content_equal = ( + left.exceptAll(right).count() == 0 and right.exceptAll(left).count() == 0 + ) + if not is_content_equal: + print("Rows are different.") + print("Left:") + left.show() + print("Right:") + right.show() + + assert is_content_equal + + +@pytest.yield_fixture(scope="module") +def spark(): + spark_session = ( + SparkSession.builder.appName("Historical Feature Retrieval Test") + .master("local") + .getOrCreate() + ) + yield spark_session + spark_session.stop() + + +@pytest.fixture() +def server(): + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + Core.add_CoreServiceServicer_to_server(CoreServicer(), server) + server.add_insecure_port(f"[::]:{free_port}") + server.start() + yield server + server.stop(0) + + +@pytest.fixture() +def client(server): + return Client(core_url=f"localhost:{free_port}") + + +@pytest.fixture() +def driver_entity(client): + return client.apply_entity(Entity("driver_id", "description", ValueType.INT32)) + + +@pytest.fixture() +def customer_entity(client): + return client.apply_entity(Entity("customer_id", "description", ValueType.INT32)) + + +def create_temp_parquet_file( + spark: SparkSession, filename, schema: StructType, data: List[Tuple] +) -> Tuple[str, str]: + temp_dir = tempfile.mkdtemp() + file_path = os.path.join(temp_dir, filename) + df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.write.parquet(file_path) + return temp_dir, f"file://{file_path}" + + +@pytest.fixture() +def transactions_feature_table(spark, client): + schema = StructType( + [ + StructField("customer_id", IntegerType()), + StructField("event_timestamp", TimestampType()), + StructField("created_timestamp", TimestampType()), + StructField("total_transactions", DoubleType()), + StructField("is_vip", BooleanType()), + ] + ) + df_data = [ + ( + 1001, + datetime(year=2020, month=9, day=1), + datetime(year=2020, month=9, day=1), + 50.0, + True, + ), + ( + 1001, + datetime(year=2020, month=9, day=1), + datetime(year=2020, month=9, day=2), + 100.0, + True, + ), + ( + 2001, + datetime(year=2020, month=9, day=1), + datetime(year=2020, month=9, day=1), + 400.0, + False, + ), + ( + 1001, + datetime(year=2020, month=9, day=2), + datetime(year=2020, month=9, day=1), + 200.0, + False, + ), + ( + 1001, + datetime(year=2020, month=9, day=4), + datetime(year=2020, month=9, day=1), + 300.0, + False, + ), + ] + temp_dir, file_uri = create_temp_parquet_file( + spark, "transactions", schema, df_data + ) + file_source = FileSource( + "event_timestamp", "created_timestamp", "parquet", file_uri + ) + features = [ + Feature("total_transactions", ValueType.DOUBLE), + Feature("is_vip", ValueType.BOOL), + ] + feature_table = FeatureTable( + "transactions", ["customer_id"], features, batch_source=file_source + ) + yield client.apply_feature_table(feature_table) + shutil.rmtree(temp_dir) + + +@pytest.fixture() +def 
+
+
+@pytest.fixture()
+def bookings_feature_table(spark, client):
+    schema = StructType(
+        [
+            StructField("driver_id", IntegerType()),
+            StructField("event_timestamp", TimestampType()),
+            StructField("created_timestamp", TimestampType()),
+            StructField("total_completed_bookings", IntegerType()),
+        ]
+    )
+    df_data = [
+        (
+            8001,
+            datetime(year=2020, month=9, day=1),
+            datetime(year=2020, month=9, day=1),
+            100,
+        ),
+        (
+            8001,
+            datetime(year=2020, month=9, day=2),
+            datetime(year=2020, month=9, day=2),
+            150,
+        ),
+        (
+            8002,
+            datetime(year=2020, month=9, day=2),
+            datetime(year=2020, month=9, day=2),
+            200,
+        ),
+    ]
+    temp_dir, file_uri = create_temp_parquet_file(spark, "bookings", schema, df_data)
+
+    file_source = FileSource(
+        "event_timestamp", "created_timestamp", "parquet", file_uri
+    )
+    features = [Feature("total_completed_bookings", ValueType.INT32)]
+    max_age = Duration()
+    max_age.FromSeconds(86400)
+    feature_table = FeatureTable(
+        "bookings", ["driver_id"], features, batch_source=file_source, max_age=max_age
+    )
+    yield client.apply_feature_table(feature_table)
+    shutil.rmtree(temp_dir)
+
+
+@pytest.fixture()
+def bookings_feature_table_with_mapping(spark, client):
+    schema = StructType(
+        [
+            StructField("id", IntegerType()),
+            StructField("datetime", TimestampType()),
+            StructField("created_datetime", TimestampType()),
+            StructField("total_completed_bookings", IntegerType()),
+        ]
+    )
+    df_data = [
+        (
+            8001,
+            datetime(year=2020, month=9, day=1),
+            datetime(year=2020, month=9, day=1),
+            100,
+        ),
+        (
+            8001,
+            datetime(year=2020, month=9, day=2),
+            datetime(year=2020, month=9, day=2),
+            150,
+        ),
+        (
+            8002,
+            datetime(year=2020, month=9, day=2),
+            datetime(year=2020, month=9, day=2),
+            200,
+        ),
+    ]
+    temp_dir, file_uri = create_temp_parquet_file(spark, "bookings", schema, df_data)
+
+    file_source = FileSource(
+        "datetime", "created_datetime", "parquet", file_uri, {"id": "driver_id"}
+    )
+    features = [Feature("total_completed_bookings", ValueType.INT32)]
+    max_age = Duration()
+    max_age.FromSeconds(86400)
+    feature_table = FeatureTable(
+        "bookings", ["driver_id"], features, batch_source=file_source, max_age=max_age
+    )
+    yield client.apply_feature_table(feature_table)
+    shutil.rmtree(temp_dir)
+
+
+def test_historical_feature_retrieval_from_local_spark_session(
+    spark,
+    client,
+    driver_entity,
+    customer_entity,
+    bookings_feature_table,
+    transactions_feature_table,
+):
+    schema = StructType(
+        [
+            StructField("customer_id", IntegerType()),
+            StructField("driver_id", IntegerType()),
+            StructField("event_timestamp", TimestampType()),
+        ]
+    )
+    df_data = [
+        (1001, 8001, datetime(year=2020, month=9, day=1),),
+        (2001, 8001, datetime(year=2020, month=9, day=2),),
+        (2001, 8002, datetime(year=2020, month=9, day=1),),
+        (1001, 8001, datetime(year=2020, month=9, day=2),),
+        (1001, 8001, datetime(year=2020, month=9, day=3),),
+        (1001, 8001, datetime(year=2020, month=9, day=4),),
+    ]
+    temp_dir, file_uri = create_temp_parquet_file(
+        spark, "customer_driver_pair", schema, df_data
+    )
+    customer_driver_pairs_source = FileSource(
+        "event_timestamp", "created_timestamp", "parquet", file_uri
+    )
+    joined_df = client.get_historical_features_df(
+        ["transactions:total_transactions", "bookings:total_completed_bookings"],
+        customer_driver_pairs_source,
+    )
+    expected_joined_df_schema = StructType(
+        [
+            StructField("customer_id", IntegerType()),
+            StructField("driver_id", IntegerType()),
+            StructField("event_timestamp", TimestampType()),
+            StructField("transactions__total_transactions", DoubleType()),
+            StructField("bookings__total_completed_bookings", IntegerType()),
+        ]
+    )
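+    # Editor's note (added comment, not part of the original patch): the expected
+    # values follow point-in-time join semantics. For each entity row, the most
+    # recent feature row at or before the entity's event_timestamp is selected,
+    # with ties on event_timestamp broken by created_timestamp. Bookings carry a
+    # one-day max_age, so entity rows with no booking inside that window get None.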
+ StructField("bookings__total_completed_bookings", IntegerType()), + ] + ) + expected_joined_df_data = [ + (1001, 8001, datetime(year=2020, month=9, day=1), 100.0, 100), + (2001, 8001, datetime(year=2020, month=9, day=2), 400.0, 150), + (2001, 8002, datetime(year=2020, month=9, day=1), 400.0, None), + (1001, 8001, datetime(year=2020, month=9, day=2), 200.0, 150), + (1001, 8001, datetime(year=2020, month=9, day=3), 200.0, 150), + (1001, 8001, datetime(year=2020, month=9, day=4), 300.0, None), + ] + expected_joined_df = spark.createDataFrame( + spark.sparkContext.parallelize(expected_joined_df_data), + expected_joined_df_schema, + ) + assert_dataframe_equal(joined_df, expected_joined_df) + shutil.rmtree(temp_dir) + + +def test_historical_feature_retrieval_with_field_mappings_from_local_spark_session( + spark, client, driver_entity, bookings_feature_table_with_mapping, +): + schema = StructType( + [ + StructField("driver_id", IntegerType()), + StructField("event_timestamp", TimestampType()), + ] + ) + df_data = [ + (8001, datetime(year=2020, month=9, day=1)), + (8001, datetime(year=2020, month=9, day=2)), + (8002, datetime(year=2020, month=9, day=1)), + ] + temp_dir, file_uri = create_temp_parquet_file(spark, "drivers", schema, df_data) + entity_source = FileSource( + "event_timestamp", "created_timestamp", "parquet", file_uri + ) + joined_df = client.get_historical_features_df( + ["bookings:total_completed_bookings"], entity_source, + ) + expected_joined_df_schema = StructType( + [ + StructField("driver_id", IntegerType()), + StructField("event_timestamp", TimestampType()), + StructField("bookings__total_completed_bookings", IntegerType()), + ] + ) + expected_joined_df_data = [ + (8001, datetime(year=2020, month=9, day=1), 100), + (8001, datetime(year=2020, month=9, day=2), 150), + (8002, datetime(year=2020, month=9, day=1), None), + ] + expected_joined_df = spark.createDataFrame( + spark.sparkContext.parallelize(expected_joined_df_data), + expected_joined_df_schema, + ) + assert_dataframe_equal(joined_df, expected_joined_df) + shutil.rmtree(temp_dir)