Feast SDK integration for historical feature retrieval using Spark #1054

Merged
3 commits, merged Oct 15, 2020
122 changes: 119 additions & 3 deletions sdk/python/feast/client.py
@@ -14,6 +14,8 @@
import logging
import multiprocessing
import shutil
import uuid
from itertools import groupby
from typing import Any, Dict, List, Optional, Union

import grpc
@@ -30,6 +32,8 @@
CONFIG_SERVING_ENABLE_SSL_KEY,
CONFIG_SERVING_SERVER_SSL_CERT_KEY,
CONFIG_SERVING_URL_KEY,
CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT,
CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION,
FEAST_DEFAULT_OPTIONS,
)
from feast.core.CoreService_pb2 import (
@@ -70,6 +74,11 @@
_write_partitioned_table_from_source,
)
from feast.online_response import OnlineResponse, _infer_online_entity_rows
from feast.pyspark.abc import RetrievalJob
from feast.pyspark.launcher import (
start_historical_feature_retrieval_job,
start_historical_feature_retrieval_spark_session,
)
from feast.serving.ServingService_pb2 import (
GetFeastServingInfoRequest,
GetOnlineFeaturesRequestV2,
@@ -723,7 +732,6 @@ def get_online_features(
) -> OnlineResponse:
"""
Retrieves the latest online feature data from Feast Serving.

Args:
feature_refs: List of feature references that will be returned for each entity.
Each feature reference should have the following format:
Expand All @@ -733,12 +741,10 @@ def get_online_features(
entity_rows: A list of dictionaries where each key-value is an entity-name, entity-value pair.
project: Optionally specify the project override. If specified, uses the given project for retrieval.
Overrides the projects specified in the Feature References, if those are also specified.

Returns:
GetOnlineFeaturesResponse containing the feature data in records.
Each EntityRow provided will yield one record, which contains
data fields with data value and field status metadata (if included).

Examples:
>>> from feast import Client
>>>
@@ -767,3 +773,113 @@

response = OnlineResponse(response)
return response

def get_historical_features(
self,
feature_refs: List[str],
entity_source: Union[FileSource, BigQuerySource],
project: str = None,
) -> RetrievalJob:
"""
Launch a historical feature retrieval job.

Args:
feature_refs: List of feature references that will be returned for each entity.
Each feature reference should have the following format:
"feature_table:feature" where "feature_table" & "feature" refer to
the feature and feature table names respectively.
entity_source (Union[FileSource, BigQuerySource]): Source for the entity rows.
The user needs to make sure that the source is accessible from the Spark cluster
that will be used for the retrieval job.
project: Specifies the project that contains the feature tables
which the requested features belong to.

Returns:
A retrieval job object that can be used to monitor the retrieval
progress asynchronously, and to materialize the results.

Examples:
>>> from feast import Client
>>> from datetime import datetime
>>> feast_client = Client(core_url="localhost:6565")
>>> feature_refs = ["bookings:bookings_7d", "bookings:booking_14d"]
>>> entity_source = FileSource("event_timestamp", "parquet", "gs://some-bucket/customer")
>>> feature_retrieval_job = feast_client.get_historical_features(
>>> feature_refs, entity_source, project="my_project")
>>> output_file_uri = feature_retrieval_job.get_output_file_uri()
"gs://some-bucket/output/
"""
feature_tables = self._get_feature_tables_from_feature_refs(
feature_refs, project
)
output_location = self._config.get(
CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION
)
output_format = self._config.get(CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT)
job_id = f"historical-feature-{str(uuid.uuid4())}"

return start_historical_feature_retrieval_job(
self, entity_source, feature_tables, output_format, output_location, job_id
)

def get_historical_features_df(
self,
feature_refs: List[str],
entity_source: Union[FileSource, BigQuerySource],
project: str = None,
):
"""
Launch a historical feature retrieval job.

Args:
feature_refs: List of feature references that will be returned for each entity.
Each feature reference should have the following format:
"feature_table:feature" where "feature_table" & "feature" refer to
the feature and feature table names respectively.
entity_source (Union[FileSource, BigQuerySource]): Source for the entity rows.
The user needs to make sure that the source is accessible from the Spark cluster
that will be used for the retrieval job.
project: Specifies the project that contains the feature tables
which the requested features belong to.

Returns:
The historical feature retrieval result as a Spark dataframe.

Examples:
>>> from feast import Client
>>> from datetime import datetime
>>> from pyspark.sql import SparkSession
>>> spark = SparkSession.builder.getOrCreate()
>>> feast_client = Client(core_url="localhost:6565")
>>> feature_refs = ["bookings:bookings_7d", "bookings:booking_14d"]
>>> entity_source = FileSource("event_timestamp", "parquet", "gs://some-bucket/customer")
>>> df = feast_client.get_historical_features_df(
>>> feature_refs, entity_source, project="my_project")
"""
feature_tables = self._get_feature_tables_from_feature_refs(
feature_refs, project
)
return start_historical_feature_retrieval_spark_session(
self, entity_source, feature_tables
)

def _get_feature_tables_from_feature_refs(
self, feature_refs: List[str], project: Optional[str]
):
# itertools.groupby only groups consecutive elements, so sort the refs by
# feature table name first to get a single group per table.
feature_refs_grouped_by_table = [
(feature_table_name, list(grouped_feature_refs))
for feature_table_name, grouped_feature_refs in groupby(
sorted(feature_refs), lambda x: x.split(":")[0]
)
]

feature_tables = []
for feature_table_name, grouped_feature_refs in feature_refs_grouped_by_table:
feature_table = self.get_feature_table(feature_table_name, project)
feature_names = [f.split(":")[-1] for f in grouped_feature_refs]
feature_table.features = [
f for f in feature_table.features if f.name in feature_names
]
feature_tables.append(feature_table)
return feature_tables
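
Taken together, the two new methods give an asynchronous path (a RetrievalJob handle) and a synchronous path (a Spark dataframe). Below is a minimal usage sketch, assuming the new Spark and output options can be passed to Client as constructor keyword arguments in the same way core_url is, and assuming the FileSource import path; bucket paths, the master URL and the project name are placeholders.

from feast import Client
from feast.data_source import FileSource  # assumed import path for the entity source class

client = Client(
    core_url="localhost:6565",
    spark_launcher="standalone",                        # CONFIG_SPARK_LAUNCHER: standalone, dataproc or emr
    spark_standalone_master="local[*]",                 # CONFIG_SPARK_STANDALONE_MASTER
    historical_feature_output_format="parquet",
    historical_feature_output_location="gs://some-bucket/output",
)

feature_refs = ["bookings:bookings_7d", "bookings:booking_14d"]
# Constructor arguments mirror the docstring examples above.
entity_source = FileSource("event_timestamp", "parquet", "gs://some-bucket/customer")

# Asynchronous: submit the Spark job, block only when the output uri is needed.
job = client.get_historical_features(feature_refs, entity_source, project="my_project")
output_uri = job.get_output_file_uri(timeout_sec=600)

# Synchronous: run the retrieval in a Spark session and work with the dataframe directly.
df = client.get_historical_features_df(feature_refs, entity_source, project="my_project")
df.show()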
14 changes: 14 additions & 0 deletions sdk/python/feast/constants.py
@@ -64,6 +64,20 @@ class AuthProvider(Enum):
CONFIG_TIMEOUT_KEY = "timeout"
CONFIG_MAX_WAIT_INTERVAL_KEY = "max_wait_interval"

# Spark Job Config
CONFIG_SPARK_LAUNCHER = "spark_launcher" # standalone, dataproc, emr

CONFIG_SPARK_STANDALONE_MASTER = "spark_standalone_master"

CONFIG_SPARK_DATAPROC_CLUSTER_NAME = "dataproc_cluster_name"
CONFIG_SPARK_DATAPROC_PROJECT = "dataproc_project"
CONFIG_SPARK_DATAPROC_REGION = "dataproc_region"
CONFIG_SPARK_DATAPROC_STAGING_LOCATION = "dataproc_staging_location"

CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT = "historical_feature_output_format"
CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION = "historical_feature_output_location"


# Configuration option default values
FEAST_DEFAULT_OPTIONS = {
# Default Feast project to use
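
For reference, a sketch of values these keys might take when targeting Dataproc; which keys are required depends on the selected launcher, and the cluster name, project, region and bucket paths below are placeholders rather than shipped defaults.

spark_options = {
    "spark_launcher": "dataproc",                            # CONFIG_SPARK_LAUNCHER: standalone, dataproc or emr
    "dataproc_cluster_name": "feast-retrieval",              # CONFIG_SPARK_DATAPROC_CLUSTER_NAME
    "dataproc_project": "my-gcp-project",                    # CONFIG_SPARK_DATAPROC_PROJECT
    "dataproc_region": "us-central1",                        # CONFIG_SPARK_DATAPROC_REGION
    "dataproc_staging_location": "gs://my-bucket/staging",   # CONFIG_SPARK_DATAPROC_STAGING_LOCATION
    "historical_feature_output_format": "parquet",           # CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT
    "historical_feature_output_location": "gs://my-bucket/retrieval-output",
}

# One way to supply them, assuming Client forwards keyword arguments to its
# Config in the same way it does for core_url / serving_url:
# client = Client(core_url="localhost:6565", **spark_options)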
193 changes: 193 additions & 0 deletions sdk/python/feast/pyspark/abc.py
@@ -0,0 +1,193 @@
import abc
from typing import Dict, List


class SparkJobFailure(Exception):
"""
Job submission failed, encountered error during execution, or timeout
"""

pass


class SparkJob(abc.ABC):
"""
Base class for all spark jobs
"""

@abc.abstractmethod
def get_id(self) -> str:
"""
Getter for the job id. The job id must be unique for each spark job submission.

Returns:
str: Job id.
"""
raise NotImplementedError


class RetrievalJob(SparkJob):
"""
Container for the historical feature retrieval job result
"""

@abc.abstractmethod
def get_output_file_uri(self, timeout_sec=None):
"""
Get the uri to the result file. This method blocks until the job
succeeds, or raises an exception if the job fails or does not finish within the timeout.

Args:
timeout_sec (int):
Maximum number of seconds to wait for the job to finish. If "timeout_sec"
is exceeded or the job fails, an exception will be raised.

Raises:
SparkJobFailure:
The spark job submission failed, encountered error during execution,
or timeout.

Returns:
str: file uri to the result file.
"""
raise NotImplementedError


class IngestionJob(SparkJob):
pass


class JobLauncher(abc.ABC):
"""
Submits spark jobs to a spark cluster. Currently supports only historical feature retrieval jobs.
"""

@abc.abstractmethod
def historical_feature_retrieval(
self,
pyspark_script: str,
entity_source_conf: Dict,
feature_tables_sources_conf: List[Dict],
feature_tables_conf: List[Dict],
destination_conf: Dict,
job_id: str,
**kwargs,
) -> RetrievalJob:
"""
Submits a historical feature retrieval job to a Spark cluster.

Args:
pyspark_script (str): Local file path to the pyspark script for historical feature
retrieval.
entity_source_conf (Dict): Entity data source configuration.
feature_tables_sources_conf (List[Dict]): List of feature tables data sources configurations.
feature_tables_conf (List[Dict]): List of feature table specifications.
The order of the feature tables must correspond to that of feature_tables_sources_conf.
destination_conf (Dict): Retrieval job output destination.
job_id (str): A job id that is unique for each job submission.

Raises:
SparkJobFailure: The spark job submission failed, encountered error
during execution, or timeout.

Examples:
>>> # Entity source from file
>>> entity_source_conf = {
"file": {
"format": "parquet",
"path": "gs://some-gcs-bucket/customer",
"event_timestamp_column": "event_timestamp",
"options": {
"mergeSchema": "true"
} # Optional. Options to be passed to Spark while reading the dataframe from source.
"field_mapping": {
"id": "customer_id"
} # Optional. Map the columns, where the key is the original column name and the value is the new column name.

}
}

>>> # Entity source from BigQuery
>>> entity_source_conf = {
"bq": {
"project": "gcp_project_id",
"dataset": "bq_dataset",
"table": "customer",
"event_timestamp_column": "event_timestamp",
}
}

>>> feature_table_sources_conf = [
{
"bq": {
"project": "gcp_project_id",
"dataset": "bq_dataset",
"table": "customer_transactions",
"event_timestamp_column": "event_timestamp",
"created_timestamp_column": "created_timestamp" # This field is mandatory for feature tables.
}
},

{
"file": {
"format": "parquet",
"path": "gs://some-gcs-bucket/customer_profile",
"event_timestamp_column": "event_timestamp",
"created_timestamp_column": "created_timestamp",
"options": {
"mergeSchema": "true"
}
}
},
]


>>> feature_tables_conf = [
{
"name": "customer_transactions",
"entities": [
{
"name": "customer
"type": "int32"
}
],
"features": [
{
"name": "total_transactions"
"type": "double"
},
{
"name": "total_discounts"
"type": "double"
}
],
"max_age": 86400 # In seconds.
},

{
"name": "customer_profile",
"entities": [
{
"name": "customer
"type": "int32"
}
],
"features": [
{
"name": "is_vip"
"type": "bool"
}
],

}
]

>>> destination_conf = {
"format": "parquet",
"path": "gs://some-gcs-bucket/retrieval_output"
}

Returns:
RetrievalJob: a job object that can be used to check the status of the retrieval and to obtain the output file uri once the job succeeds.
"""
raise NotImplementedError
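
To make the launcher contract concrete, here is a minimal sketch of an implementation for a standalone Spark cluster. It is not part of this PR: the spark-submit flag names handed to the pyspark script are assumptions about that script's CLI, and error handling is reduced to the minimum the interface requires.

import json
import subprocess
from typing import Dict, List

from feast.pyspark.abc import JobLauncher, RetrievalJob, SparkJobFailure


class StandaloneRetrievalJob(RetrievalJob):
    def __init__(self, job_id: str, process: subprocess.Popen, output_file_uri: str):
        self._job_id = job_id
        self._process = process
        self._output_file_uri = output_file_uri

    def get_id(self) -> str:
        return self._job_id

    def get_output_file_uri(self, timeout_sec=None):
        # Block until spark-submit exits (or the timeout is hit), surfacing
        # failures as SparkJobFailure as the interface requires.
        try:
            returncode = self._process.wait(timeout=timeout_sec)
        except subprocess.TimeoutExpired:
            raise SparkJobFailure("Timed out waiting for the retrieval job to finish")
        if returncode != 0:
            raise SparkJobFailure(f"Spark job failed with exit code {returncode}")
        return self._output_file_uri


class StandaloneClusterLauncher(JobLauncher):
    def __init__(self, master_url: str):
        self.master_url = master_url

    def historical_feature_retrieval(
        self,
        pyspark_script: str,
        entity_source_conf: Dict,
        feature_tables_sources_conf: List[Dict],
        feature_tables_conf: List[Dict],
        destination_conf: Dict,
        job_id: str,
        **kwargs,
    ) -> RetrievalJob:
        # Serialize the configuration dicts and hand them to the retrieval
        # script; the flag names are illustrative, not the script's real CLI.
        cmd = [
            "spark-submit",
            "--master", self.master_url,
            pyspark_script,
            "--entity-source", json.dumps(entity_source_conf),
            "--feature-tables-sources", json.dumps(feature_tables_sources_conf),
            "--feature-tables", json.dumps(feature_tables_conf),
            "--destination", json.dumps(destination_conf),
        ]
        process = subprocess.Popen(cmd)
        return StandaloneRetrievalJob(job_id, process, destination_conf["path"])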