Skip to content

Commit

Permalink
Some cleanups
Browse files Browse the repository at this point in the history
Signed-off-by: Terence <terencelimxp@gmail.com>
  • Loading branch information
terryyylim committed Oct 15, 2020
1 parent bd88ec0 commit 2c266af
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 108 deletions.
45 changes: 2 additions & 43 deletions sdk/python/feast/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import multiprocessing
import shutil
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Dict, List, Optional, Union

import grpc
import pandas as pd
Expand Down Expand Up @@ -69,14 +69,12 @@
_write_non_partitioned_table_from_source,
_write_partitioned_table_from_source,
)
from feast.online_response import OnlineResponse
from feast.online_response import OnlineResponse, _infer_online_entity_rows
from feast.serving.ServingService_pb2 import (
GetFeastServingInfoRequest,
GetOnlineFeaturesRequestV2,
)
from feast.serving.ServingService_pb2_grpc import ServingServiceStub
from feast.type_map import _python_value_to_proto_value, python_type_to_feast_value_type
from feast.types.Value_pb2 import Value as Value

_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -769,42 +767,3 @@ def get_online_features(

response = OnlineResponse(response)
return response


def _infer_online_entity_rows(
entity_rows: List[Dict[str, Any]]
) -> List[GetOnlineFeaturesRequestV2.EntityRow]:
"""
Builds a list of EntityRow protos from Python native type format passed by user.
Args:
entity_rows: A list of dictionaries where each key-value is an entity-name, entity-value pair.
Returns:
A list of EntityRow protos parsed from args.
"""

entity_rows_dicts = cast(List[Dict[str, Any]], entity_rows)
entity_row_list = []
entity_type_map = dict()

for entity in entity_rows_dicts:
fields = {}
for key, value in entity.items():
# Allow for feast.types.Value
if isinstance(value, Value):
proto_value = value
else:
# Infer the specific type for this row
current_dtype = python_type_to_feast_value_type(name=key, value=value)

if key not in entity_type_map:
entity_type_map[key] = current_dtype
else:
if current_dtype != entity_type_map[key]:
raise TypeError(
f"Input entity {key} has mixed types, {current_dtype} and {entity_type_map[key]}. That is not allowed. "
)
proto_value = _python_value_to_proto_value(current_dtype, value)
fields[key] = proto_value
entity_row_list.append(GetOnlineFeaturesRequestV2.EntityRow(fields=fields))
return entity_row_list
53 changes: 50 additions & 3 deletions sdk/python/feast/online_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List
from typing import Any, Dict, List, cast

from feast.serving.ServingService_pb2 import GetOnlineFeaturesResponse
from feast.type_map import feast_value_type_to_python_type
from feast.serving.ServingService_pb2 import (
GetOnlineFeaturesRequestV2,
GetOnlineFeaturesResponse,
)
from feast.type_map import (
_python_value_to_proto_value,
feast_value_type_to_python_type,
python_type_to_feast_value_type,
)
from feast.types.Value_pb2 import Value as Value


class OnlineResponse:
Expand Down Expand Up @@ -52,3 +60,42 @@ def to_dict(self) -> Dict[str, Any]:
features_dict[feature].append(native_type_value)

return features_dict


def _infer_online_entity_rows(
entity_rows: List[Dict[str, Any]]
) -> List[GetOnlineFeaturesRequestV2.EntityRow]:
"""
Builds a list of EntityRow protos from Python native type format passed by user.
Args:
entity_rows: A list of dictionaries where each key-value is an entity-name, entity-value pair.
Returns:
A list of EntityRow protos parsed from args.
"""

entity_rows_dicts = cast(List[Dict[str, Any]], entity_rows)
entity_row_list = []
entity_type_map = dict()

for entity in entity_rows_dicts:
fields = {}
for key, value in entity.items():
# Allow for feast.types.Value
if isinstance(value, Value):
proto_value = value
else:
# Infer the specific type for this row
current_dtype = python_type_to_feast_value_type(name=key, value=value)

if key not in entity_type_map:
entity_type_map[key] = current_dtype
else:
if current_dtype != entity_type_map[key]:
raise TypeError(
f"Input entity {key} has mixed types, {current_dtype} and {entity_type_map[key]}. That is not allowed. "
)
proto_value = _python_value_to_proto_value(current_dtype, value)
fields[key] = proto_value
entity_row_list.append(GetOnlineFeaturesRequestV2.EntityRow(fields=fields))
return entity_row_list
121 changes: 59 additions & 62 deletions sdk/python/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,29 @@ def non_partitioned_df(self):
}
)

@pytest.fixture
def get_online_features_fields_statuses(self):
ROW_COUNT = 100
fields_statuses_tuple_list = []
for row_number in range(0, ROW_COUNT):
fields_statuses_tuple_list.append(
(
{
"driver_id": ValueProto.Value(int64_val=row_number),
"driver:age": ValueProto.Value(int64_val=1),
"driver:rating": ValueProto.Value(string_val="9"),
"driver:null_value": ValueProto.Value(),
},
{
"driver_id": GetOnlineFeaturesResponse.FieldStatus.PRESENT,
"driver:age": GetOnlineFeaturesResponse.FieldStatus.PRESENT,
"driver:rating": GetOnlineFeaturesResponse.FieldStatus.PRESENT,
"driver:null_value": GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE,
},
)
)
return fields_statuses_tuple_list

@pytest.mark.parametrize(
"mocked_client",
[lazy_fixture("mock_client"), lazy_fixture("secure_mock_client")],
Expand Down Expand Up @@ -549,19 +572,15 @@ def test_ingest_csv(self, mocked_client, mocker, tmp_path):
"secure_mock_client_with_auth",
],
)
def test_get_online_features(self, mocked_client, auth_metadata, mocker):
def test_get_online_features(
self, mocked_client, auth_metadata, mocker, get_online_features_fields_statuses
):
ROW_COUNT = 100

mocked_client._serving_service_stub = Serving.ServingServiceStub(
grpc.insecure_channel("")
)

def int_val(x):
return ValueProto.Value(int64_val=x)

def string_val(x):
return ValueProto.Value(string_val=x)

request = GetOnlineFeaturesRequestV2(project="driver_project")
request.features.extend(
[
Expand All @@ -573,28 +592,18 @@ def string_val(x):

receive_response = GetOnlineFeaturesResponse()
entity_rows = []
for row_number in range(1, ROW_COUNT + 1):
for row_number in range(0, ROW_COUNT):
fields = get_online_features_fields_statuses[row_number][0]
statuses = get_online_features_fields_statuses[row_number][1]
request.entity_rows.append(
GetOnlineFeaturesRequestV2.EntityRow(
fields={"driver_id": int_val(row_number)}
fields={"driver_id": ValueProto.Value(int64_val=row_number)}
)
)
entity_rows.append({"driver_id": int_val(row_number)})
field_values = GetOnlineFeaturesResponse.FieldValues(
fields={
"driver_id": int_val(row_number),
"driver:age": int_val(1),
"driver:rating": string_val("9"),
"driver:null_value": ValueProto.Value(),
},
statuses={
"driver_id": GetOnlineFeaturesResponse.FieldStatus.PRESENT,
"driver:age": GetOnlineFeaturesResponse.FieldStatus.PRESENT,
"driver:rating": GetOnlineFeaturesResponse.FieldStatus.PRESENT,
"driver:null_value": GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE,
},
entity_rows.append({"driver_id": ValueProto.Value(int64_val=row_number)})
receive_response.field_values.append(
GetOnlineFeaturesResponse.FieldValues(fields=fields, statuses=statuses)
)
receive_response.field_values.append(field_values)

mocker.patch.object(
mocked_client._serving_service_stub,
Expand All @@ -610,16 +619,16 @@ def string_val(x):
request, metadata=auth_metadata
)

got_fields = got_response.field_values[0].fields
got_statuses = got_response.field_values[0].statuses
got_fields = got_response.field_values[1].fields
got_statuses = got_response.field_values[1].statuses
assert (
got_fields["driver_id"] == int_val(1)
got_fields["driver_id"] == ValueProto.Value(int64_val=1)
and got_statuses["driver_id"]
== GetOnlineFeaturesResponse.FieldStatus.PRESENT
and got_fields["driver:age"] == int_val(1)
and got_fields["driver:age"] == ValueProto.Value(int64_val=1)
and got_statuses["driver:age"]
== GetOnlineFeaturesResponse.FieldStatus.PRESENT
and got_fields["driver:rating"] == string_val("9")
and got_fields["driver:rating"] == ValueProto.Value(string_val="9")
and got_statuses["driver:rating"]
== GetOnlineFeaturesResponse.FieldStatus.PRESENT
and got_fields["driver:null_value"] == ValueProto.Value()
Expand All @@ -643,20 +652,14 @@ def string_val(x):
],
)
def test_get_online_features_multi_entities(
self, mocked_client, auth_metadata, mocker
self, mocked_client, auth_metadata, mocker, get_online_features_fields_statuses
):
ROW_COUNT = 100

mocked_client._serving_service_stub = Serving.ServingServiceStub(
grpc.insecure_channel("")
)

def int_val(x):
return ValueProto.Value(int64_val=x)

def string_val(x):
return ValueProto.Value(string_val=x)

request = GetOnlineFeaturesRequestV2(project="driver_project")
request.features.extend(
[
Expand All @@ -668,35 +671,29 @@ def string_val(x):

receive_response = GetOnlineFeaturesResponse()
entity_rows = []
for row_number in range(1, ROW_COUNT + 1):
for row_number in range(0, ROW_COUNT):
fields = get_online_features_fields_statuses[row_number][0]
fields["driver_id2"] = ValueProto.Value(int64_val=1)
statuses = get_online_features_fields_statuses[row_number][1]
statuses["driver_id2"] = GetOnlineFeaturesResponse.FieldStatus.PRESENT

request.entity_rows.append(
GetOnlineFeaturesRequestV2.EntityRow(
fields={
"driver_id": int_val(row_number),
"driver_id2": int_val(row_number),
"driver_id": ValueProto.Value(int64_val=row_number),
"driver_id2": ValueProto.Value(int64_val=row_number),
}
)
)
entity_rows.append(
{"driver_id": int_val(row_number), "driver_id2": int_val(row_number)}
{
"driver_id": ValueProto.Value(int64_val=row_number),
"driver_id2": ValueProto.Value(int64_val=row_number),
}
)
field_values = GetOnlineFeaturesResponse.FieldValues(
fields={
"driver_id": int_val(row_number),
"driver_id2": int_val(row_number),
"driver:age": int_val(1),
"driver:rating": string_val("9"),
"driver:null_value": ValueProto.Value(),
},
statuses={
"driver_id": GetOnlineFeaturesResponse.FieldStatus.PRESENT,
"driver_id2": GetOnlineFeaturesResponse.FieldStatus.PRESENT,
"driver:age": GetOnlineFeaturesResponse.FieldStatus.PRESENT,
"driver:rating": GetOnlineFeaturesResponse.FieldStatus.PRESENT,
"driver:null_value": GetOnlineFeaturesResponse.FieldStatus.NULL_VALUE,
},
receive_response.field_values.append(
GetOnlineFeaturesResponse.FieldValues(fields=fields, statuses=statuses)
)
receive_response.field_values.append(field_values)

mocker.patch.object(
mocked_client._serving_service_stub,
Expand All @@ -712,19 +709,19 @@ def string_val(x):
request, metadata=auth_metadata
)

got_fields = got_response.field_values[0].fields
got_statuses = got_response.field_values[0].statuses
got_fields = got_response.field_values[1].fields
got_statuses = got_response.field_values[1].statuses
assert (
got_fields["driver_id"] == int_val(1)
got_fields["driver_id"] == ValueProto.Value(int64_val=1)
and got_statuses["driver_id"]
== GetOnlineFeaturesResponse.FieldStatus.PRESENT
and got_fields["driver_id2"] == int_val(1)
and got_fields["driver_id2"] == ValueProto.Value(int64_val=1)
and got_statuses["driver_id2"]
== GetOnlineFeaturesResponse.FieldStatus.PRESENT
and got_fields["driver:age"] == int_val(1)
and got_fields["driver:age"] == ValueProto.Value(int64_val=1)
and got_statuses["driver:age"]
== GetOnlineFeaturesResponse.FieldStatus.PRESENT
and got_fields["driver:rating"] == string_val("9")
and got_fields["driver:rating"] == ValueProto.Value(string_val="9")
and got_statuses["driver:rating"]
== GetOnlineFeaturesResponse.FieldStatus.PRESENT
and got_fields["driver:null_value"] == ValueProto.Value()
Expand Down

0 comments on commit 2c266af

Please sign in to comment.