Commit a739121
Downgrade pyspark dependencies, add more tests
Signed-off-by: Khor Shu Heng <khor.heng@gojek.com>
khorshuheng committed Oct 15, 2020
1 parent 27deb7b commit a739121
Showing 5 changed files with 116 additions and 22 deletions.
9 changes: 4 additions & 5 deletions sdk/python/feast/client.py
@@ -74,15 +74,15 @@
_write_partitioned_table_from_source,
)
from feast.online_response import OnlineResponse, _infer_online_entity_rows
-from feast.serving.ServingService_pb2 import (
-    GetFeastServingInfoRequest,
-    GetOnlineFeaturesRequestV2,
-)
from feast.pyspark.abc import RetrievalJob
from feast.pyspark.launcher import (
start_historical_feature_retrieval_job,
start_historical_feature_retrieval_spark_session,
)
+from feast.serving.ServingService_pb2 import (
+    GetFeastServingInfoRequest,
+    GetOnlineFeaturesRequestV2,
+)
from feast.serving.ServingService_pb2_grpc import ServingServiceStub

_logger = logging.getLogger(__name__)
@@ -816,7 +816,6 @@ def get_online_features(
response = OnlineResponse(response)
return response

-
def get_historical_features(
self,
feature_refs: List[str],
9 changes: 4 additions & 5 deletions sdk/python/feast/pyspark/launcher.py
@@ -1,4 +1,3 @@
-import os
import pathlib
from typing import TYPE_CHECKING, List, Union

@@ -124,10 +123,10 @@ def start_historical_feature_retrieval_job(
job_id: str,
) -> RetrievalJob:
launcher = resolve_launcher(client._config)
-    retrieval_job_pyspark_script = os.path.join(
-        pathlib.Path(__file__).parent.absolute(),
-        "pyspark",
-        "historical_feature_retrieval_job.py",
+    retrieval_job_pyspark_script = str(
+        pathlib.Path(__file__).parent.absolute()
+        / "pyspark"
+        / "historical_feature_retrieval_job.py"
)
return launcher.historical_feature_retrieval(
pyspark_script=retrieval_job_pyspark_script,
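
The launcher change above swaps os.path.join for pathlib's "/" operator plus str(). A minimal standalone sketch of the equivalence (names taken from the diff; illustrative, not Feast code):

    import os.path
    from pathlib import Path

    base = Path(__file__).parent.absolute()
    # pathlib composes segments with "/", and str() yields the same
    # platform-native path string that os.path.join would build.
    new_style = str(base / "pyspark" / "historical_feature_retrieval_job.py")
    old_style = os.path.join(str(base), "pyspark", "historical_feature_retrieval_job.py")
    assert new_style == old_style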
2 changes: 1 addition & 1 deletion sdk/python/requirements-ci.txt
@@ -9,7 +9,7 @@ pytest-lazy-fixture==0.6.3
pytest-mock
pytest-timeout
pytest-ordering==0.6.*
-pyspark==3.*
+pyspark==2.4.2
pandas~=1.0.0
mock==2.0.0
pandavro==1.5.*
4 changes: 2 additions & 2 deletions sdk/python/requirements-dev.txt
@@ -39,5 +39,5 @@ flake8
black==19.10b0
boto3
moto
-pyspark==3.*
-pyspark-stubs==3.*
+pyspark==2.4.2
+pyspark-stubs==2.4.0.post9
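
Both requirements files above now pin the same PySpark 2.4 line (runtime 2.4.2, stubs 2.4.0.post9). A quick local sanity check, assuming the pinned packages are installed:

    # Confirm the interpreter resolves the downgraded PySpark, not 3.x.
    import pyspark

    assert pyspark.__version__ == "2.4.2", pyspark.__version__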
114 changes: 105 additions & 9 deletions sdk/python/tests/test_historical_feature_retrieval.py
@@ -12,6 +12,7 @@
from google.protobuf.duration_pb2 import Duration
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import (
+    BooleanType,
DoubleType,
IntegerType,
StructField,
@@ -65,7 +66,7 @@ def spark():
spark_session.stop()


-@pytest.fixture(scope="module")
+@pytest.fixture()
def server():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
Core.add_CoreServiceServicer_to_server(CoreServicer(), server)
@@ -75,17 +76,17 @@ def server():
server.stop(0)


-@pytest.fixture(scope="module")
+@pytest.fixture()
def client(server):
return Client(core_url=f"localhost:{free_port}")


-@pytest.fixture(scope="module")
+@pytest.fixture()
def driver_entity(client):
return client.apply_entity(Entity("driver_id", "description", ValueType.INT32))


-@pytest.fixture(scope="module")
+@pytest.fixture()
def customer_entity(client):
return client.apply_entity(Entity("customer_id", "description", ValueType.INT32))

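
The fixtures in this hunk (and two more below) drop scope="module" in favor of the default function scope, so the gRPC core server, client, and entities are rebuilt for every test rather than shared across the module. A minimal sketch of the difference, as a plain pytest example (illustrative names, not from this test file):

    import pytest

    @pytest.fixture()  # default function scope: a fresh object per test
    def per_test_state():
        return []

    @pytest.fixture(scope="module")  # one object shared by all tests in the module
    def shared_state():
        return []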
@@ -100,14 +101,15 @@ def create_temp_parquet_file(
return temp_dir, f"file://{file_path}"


-@pytest.fixture(scope="module")
+@pytest.fixture()
def transactions_feature_table(spark, client):
schema = StructType(
[
StructField("customer_id", IntegerType()),
StructField("event_timestamp", TimestampType()),
StructField("created_timestamp", TimestampType()),
StructField("total_transactions", DoubleType()),
StructField("is_vip", BooleanType()),
]
)
df_data = [
@@ -116,30 +118,35 @@ def transactions_feature_table(spark, client):
datetime(year=2020, month=9, day=1),
datetime(year=2020, month=9, day=1),
50.0,
+            True,
),
(
1001,
datetime(year=2020, month=9, day=1),
datetime(year=2020, month=9, day=2),
100.0,
+            True,
),
(
2001,
datetime(year=2020, month=9, day=1),
datetime(year=2020, month=9, day=1),
400.0,
+            False,
),
(
1001,
datetime(year=2020, month=9, day=2),
datetime(year=2020, month=9, day=1),
200.0,
+            False,
),
(
1001,
datetime(year=2020, month=9, day=4),
datetime(year=2020, month=9, day=1),
300.0,
+            False,
),
]
temp_dir, file_uri = create_temp_parquet_file(
@@ -148,15 +155,18 @@ def transactions_feature_table(spark, client):
file_source = FileSource(
"event_timestamp", "created_timestamp", "parquet", file_uri
)
-    features = [Feature("total_transactions", ValueType.DOUBLE)]
+    features = [
+        Feature("total_transactions", ValueType.DOUBLE),
+        Feature("is_vip", ValueType.BOOL),
+    ]
feature_table = FeatureTable(
"transactions", ["customer_id"], features, batch_source=file_source
)
yield client.apply_feature_table(feature_table)
shutil.rmtree(temp_dir)


-@pytest.fixture(scope="module")
+@pytest.fixture()
def bookings_feature_table(spark, client):
schema = StructType(
[
@@ -201,6 +211,51 @@ def bookings_feature_table(spark, client):
shutil.rmtree(temp_dir)


+@pytest.fixture()
+def bookings_feature_table_with_mapping(spark, client):
+    schema = StructType(
+        [
+            StructField("id", IntegerType()),
+            StructField("datetime", TimestampType()),
+            StructField("created_datetime", TimestampType()),
+            StructField("total_completed_bookings", IntegerType()),
+        ]
+    )
+    df_data = [
+        (
+            8001,
+            datetime(year=2020, month=9, day=1),
+            datetime(year=2020, month=9, day=1),
+            100,
+        ),
+        (
+            8001,
+            datetime(year=2020, month=9, day=2),
+            datetime(year=2020, month=9, day=2),
+            150,
+        ),
+        (
+            8002,
+            datetime(year=2020, month=9, day=2),
+            datetime(year=2020, month=9, day=2),
+            200,
+        ),
+    ]
+    temp_dir, file_uri = create_temp_parquet_file(spark, "bookings", schema, df_data)
+
+    file_source = FileSource(
+        "datetime", "created_datetime", "parquet", file_uri, {"id": "driver_id"}
+    )
+    features = [Feature("total_completed_bookings", ValueType.INT32)]
+    max_age = Duration()
+    max_age.FromSeconds(86400)
+    feature_table = FeatureTable(
+        "bookings", ["driver_id"], features, batch_source=file_source, max_age=max_age
+    )
+    yield client.apply_feature_table(feature_table)
+    shutil.rmtree(temp_dir)
+
+
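In the new fixture above, the extra FileSource argument {"id": "driver_id"} is a field mapping from the source column name to the entity name used at join time. Its effect is roughly a column rename applied before the point-in-time join; a standalone sketch of that idea in plain PySpark (illustrative, not Feast's internal implementation; the path is hypothetical):

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local").getOrCreate()
    bookings = spark.read.parquet("/tmp/bookings")  # hypothetical source path
    # Align the source's "id" column with the "driver_id" entity column
    # so the entity rows can be joined against it.
    bookings = bookings.withColumnRenamed("id", "driver_id")
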
def test_historical_feature_retrieval_from_local_spark_session(
spark,
client,
@@ -227,12 +282,12 @@ def test_historical_feature_retrieval_from_local_spark_session(
temp_dir, file_uri = create_temp_parquet_file(
spark, "customer_driver_pair", schema, df_data
)
-    entity_source = FileSource(
+    customer_driver_pairs_source = FileSource(
"event_timestamp", "created_timestamp", "parquet", file_uri
)
joined_df = client.get_historical_features_df(
["transactions:total_transactions", "bookings:total_completed_bookings"],
-        entity_source,
+        customer_driver_pairs_source,
)
expected_joined_df_schema = StructType(
[
@@ -257,3 +312,44 @@
)
assert_dataframe_equal(joined_df, expected_joined_df)
shutil.rmtree(temp_dir)
+
+
+def test_historical_feature_retrieval_with_field_mappings_from_local_spark_session(
+    spark, client, driver_entity, bookings_feature_table_with_mapping,
+):
+    schema = StructType(
+        [
+            StructField("driver_id", IntegerType()),
+            StructField("event_timestamp", TimestampType()),
+        ]
+    )
+    df_data = [
+        (8001, datetime(year=2020, month=9, day=1)),
+        (8001, datetime(year=2020, month=9, day=2)),
+        (8002, datetime(year=2020, month=9, day=1)),
+    ]
+    temp_dir, file_uri = create_temp_parquet_file(spark, "drivers", schema, df_data)
+    entity_source = FileSource(
+        "event_timestamp", "created_timestamp", "parquet", file_uri
+    )
+    joined_df = client.get_historical_features_df(
+        ["bookings:total_completed_bookings"], entity_source,
+    )
+    expected_joined_df_schema = StructType(
+        [
+            StructField("driver_id", IntegerType()),
+            StructField("event_timestamp", TimestampType()),
+            StructField("bookings__total_completed_bookings", IntegerType()),
+        ]
+    )
+    expected_joined_df_data = [
+        (8001, datetime(year=2020, month=9, day=1), 100),
+        (8001, datetime(year=2020, month=9, day=2), 150),
+        (8002, datetime(year=2020, month=9, day=1), None),
+    ]
+    expected_joined_df = spark.createDataFrame(
+        spark.sparkContext.parallelize(expected_joined_df_data),
+        expected_joined_df_schema,
+    )
+    assert_dataframe_equal(joined_df, expected_joined_df)
+    shutil.rmtree(temp_dir)
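
The None for driver 8002 in the expected output follows from the point-in-time join the test exercises: the only booking for 8002 is dated September 2, after the entity row's September 1 timestamp, so no feature row qualifies even within the 86400-second max_age window. A rough sketch of that qualification rule (illustrative, not Feast's implementation; boundary handling is an assumption):

    from datetime import datetime, timedelta

    def qualifies(feature_ts: datetime, entity_ts: datetime, max_age_seconds: int) -> bool:
        # A feature row may join only if it is not from the future and
        # is no older than max_age relative to the entity timestamp.
        return entity_ts - timedelta(seconds=max_age_seconds) <= feature_ts <= entity_ts

    # Driver 8002: booking on Sep 2 vs entity row on Sep 1 -> no match.
    assert not qualifies(datetime(2020, 9, 2), datetime(2020, 9, 1), 86400)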
