diff --git a/go/internal/feast/featurestore.go b/go/internal/feast/featurestore.go index b0fc987fb4..ad1f94a4ba 100644 --- a/go/internal/feast/featurestore.go +++ b/go/internal/feast/featurestore.go @@ -224,6 +224,14 @@ func (fs *FeatureStore) listAllViews() (map[string]*model.FeatureView, map[strin fvs[featureView.Base.Name] = featureView } + streamFeatureViews, err := fs.ListStreamFeatureViews() + if err != nil { + return nil, nil, err + } + for _, streamFeatureView := range streamFeatureViews { + fvs[streamFeatureView.Base.Name] = streamFeatureView + } + onDemandFeatureViews, err := fs.registry.ListOnDemandFeatureViews(fs.config.Project) if err != nil { return nil, nil, err @@ -242,6 +250,14 @@ func (fs *FeatureStore) ListFeatureViews() ([]*model.FeatureView, error) { return featureViews, nil } +func (fs *FeatureStore) ListStreamFeatureViews() ([]*model.FeatureView, error) { + streamFeatureViews, err := fs.registry.ListStreamFeatureViews(fs.config.Project) + if err != nil { + return streamFeatureViews, err + } + return streamFeatureViews, nil +} + func (fs *FeatureStore) ListEntities(hideDummyEntity bool) ([]*model.Entity, error) { allEntities, err := fs.registry.ListEntities(fs.config.Project) diff --git a/go/internal/feast/model/featureview.go b/go/internal/feast/model/featureview.go index ceb3736f99..b6fde78658 100644 --- a/go/internal/feast/model/featureview.go +++ b/go/internal/feast/model/featureview.go @@ -24,7 +24,24 @@ type FeatureView struct { func NewFeatureViewFromProto(proto *core.FeatureView) *FeatureView { featureView := &FeatureView{Base: NewBaseFeatureView(proto.Spec.Name, proto.Spec.Features), - Ttl: &(*proto.Spec.Ttl), + Ttl: proto.Spec.Ttl, + } + if len(proto.Spec.Entities) == 0 { + featureView.EntityNames = []string{DUMMY_ENTITY_NAME} + } else { + featureView.EntityNames = proto.Spec.Entities + } + entityColumns := make([]*Field, len(proto.Spec.EntityColumns)) + for i, entityColumn := range proto.Spec.EntityColumns { + entityColumns[i] = NewFieldFromProto(entityColumn) + } + featureView.EntityColumns = entityColumns + return featureView +} + +func NewFeatureViewFromStreamFeatureViewProto(proto *core.StreamFeatureView) *FeatureView { + featureView := &FeatureView{Base: NewBaseFeatureView(proto.Spec.Name, proto.Spec.Features), + Ttl: proto.Spec.Ttl, } if len(proto.Spec.Entities) == 0 { featureView.EntityNames = []string{DUMMY_ENTITY_NAME} diff --git a/go/internal/feast/registry/registry.go b/go/internal/feast/registry/registry.go index 38cf167a9f..c67a50a5a6 100644 --- a/go/internal/feast/registry/registry.go +++ b/go/internal/feast/registry/registry.go @@ -30,6 +30,7 @@ type Registry struct { cachedFeatureServices map[string]map[string]*core.FeatureService cachedEntities map[string]map[string]*core.Entity cachedFeatureViews map[string]map[string]*core.FeatureView + cachedStreamFeatureViews map[string]map[string]*core.StreamFeatureView cachedOnDemandFeatureViews map[string]map[string]*core.OnDemandFeatureView cachedRegistry *core.Registry cachedRegistryProtoLastUpdated time.Time @@ -106,10 +107,12 @@ func (r *Registry) load(registry *core.Registry) { r.cachedFeatureServices = make(map[string]map[string]*core.FeatureService) r.cachedEntities = make(map[string]map[string]*core.Entity) r.cachedFeatureViews = make(map[string]map[string]*core.FeatureView) + r.cachedStreamFeatureViews = make(map[string]map[string]*core.StreamFeatureView) r.cachedOnDemandFeatureViews = make(map[string]map[string]*core.OnDemandFeatureView) r.loadEntities(registry) 
r.loadFeatureServices(registry) r.loadFeatureViews(registry) + r.loadStreamFeatureViews(registry) r.loadOnDemandFeatureViews(registry) r.cachedRegistryProtoLastUpdated = time.Now() } @@ -144,6 +147,16 @@ func (r *Registry) loadFeatureViews(registry *core.Registry) { } } +func (r *Registry) loadStreamFeatureViews(registry *core.Registry) { + streamFeatureViews := registry.StreamFeatureViews + for _, streamFeatureView := range streamFeatureViews { + if _, ok := r.cachedStreamFeatureViews[streamFeatureView.Spec.Project]; !ok { + r.cachedStreamFeatureViews[streamFeatureView.Spec.Project] = make(map[string]*core.StreamFeatureView) + } + r.cachedStreamFeatureViews[streamFeatureView.Spec.Project][streamFeatureView.Spec.Name] = streamFeatureView + } +} + func (r *Registry) loadOnDemandFeatureViews(registry *core.Registry) { onDemandFeatureViews := registry.OnDemandFeatureViews for _, onDemandFeatureView := range onDemandFeatureViews { @@ -193,7 +206,26 @@ func (r *Registry) ListFeatureViews(project string) ([]*model.FeatureView, error } /* - Look up Feature Views inside project + Look up Stream Feature Views inside project + Returns empty list if project not found +*/ + +func (r *Registry) ListStreamFeatureViews(project string) ([]*model.FeatureView, error) { + if cachedStreamFeatureViews, ok := r.cachedStreamFeatureViews[project]; !ok { + return []*model.FeatureView{}, nil + } else { + streamFeatureViews := make([]*model.FeatureView, len(cachedStreamFeatureViews)) + index := 0 + for _, streamFeatureViewProto := range cachedStreamFeatureViews { + streamFeatureViews[index] = model.NewFeatureViewFromStreamFeatureViewProto(streamFeatureViewProto) + index += 1 + } + return streamFeatureViews, nil + } +} + +/* + Look up Feature Services inside project Returns empty list if project not found */ @@ -254,6 +286,18 @@ func (r *Registry) GetFeatureView(project, featureViewName string) (*model.Featu } } +func (r *Registry) GetStreamFeatureView(project, streamFeatureViewName string) (*model.FeatureView, error) { + if cachedStreamFeatureViews, ok := r.cachedStreamFeatureViews[project]; !ok { + return nil, fmt.Errorf("no cached stream feature views found for project %s", project) + } else { + if streamFeatureViewProto, ok := cachedStreamFeatureViews[streamFeatureViewName]; !ok { + return nil, fmt.Errorf("no cached stream feature view %s found for project %s", streamFeatureViewName, project) + } else { + return model.NewFeatureViewFromStreamFeatureViewProto(streamFeatureViewProto), nil + } + } +} + func (r *Registry) GetFeatureService(project, featureServiceName string) (*model.FeatureService, error) { if cachedFeatureServices, ok := r.cachedFeatureServices[project]; !ok { return nil, fmt.Errorf("no cached feature services found for project %s", project) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 7a5a8299eb..31fcbf7d42 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -265,6 +265,19 @@ def _list_feature_views( feature_views.append(fv) return feature_views + def _list_stream_feature_views( + self, allow_cache: bool = False, hide_dummy_entity: bool = True, + ) -> List[StreamFeatureView]: + stream_feature_views = [] + for sfv in self._registry.list_stream_feature_views( + self.project, allow_cache=allow_cache + ): + if hide_dummy_entity and sfv.entities[0] == DUMMY_ENTITY_NAME: + sfv.entities = [] + sfv.entity_columns = [] + stream_feature_views.append(sfv) + return stream_feature_views + @log_exceptions_and_usage def 
list_on_demand_feature_views( self, allow_cache: bool = False @@ -289,9 +302,7 @@ def list_stream_feature_views( Returns: A list of stream feature views. """ - return self._registry.list_stream_feature_views( - self.project, allow_cache=allow_cache - ) + return self._list_stream_feature_views(allow_cache) @log_exceptions_and_usage def list_data_sources(self, allow_cache: bool = False) -> List[DataSource]: @@ -558,6 +569,9 @@ def _make_inferences( update_feature_views_with_inferred_features_and_entities( views_to_update, entities + entities_to_update, self.config ) + update_feature_views_with_inferred_features_and_entities( + sfvs_to_update, entities + entities_to_update, self.config + ) # TODO(kevjumba): Update schema inferrence for sfv in sfvs_to_update: if not sfv.schema: @@ -574,6 +588,53 @@ def _make_inferences( for feature_service in feature_services_to_update: feature_service.infer_features(fvs_to_update=fvs_to_update_map) + def _get_feature_views_to_materialize( + self, feature_views: Optional[List[str]], + ) -> List[FeatureView]: + """ + Returns the list of feature views that should be materialized. + + If no feature views are specified, all feature views will be returned. + + Args: + feature_views: List of names of feature views to materialize. + + Raises: + FeatureViewNotFoundException: One of the specified feature views could not be found. + ValueError: One of the specified feature views is not configured for materialization. + """ + feature_views_to_materialize: List[FeatureView] = [] + + if feature_views is None: + feature_views_to_materialize = self._list_feature_views( + hide_dummy_entity=False + ) + feature_views_to_materialize = [ + fv for fv in feature_views_to_materialize if fv.online + ] + stream_feature_views_to_materialize = self._list_stream_feature_views( + hide_dummy_entity=False + ) + feature_views_to_materialize += [ + sfv for sfv in stream_feature_views_to_materialize if sfv.online + ] + else: + for name in feature_views: + try: + feature_view = self._get_feature_view(name, hide_dummy_entity=False) + except FeatureViewNotFoundException: + feature_view = self._get_stream_feature_view( + name, hide_dummy_entity=False + ) + + if not feature_view.online: + raise ValueError( + f"FeatureView {feature_view.name} is not configured to be served online." + ) + feature_views_to_materialize.append(feature_view) + + return feature_views_to_materialize + @log_exceptions_and_usage def _plan( self, desired_repo_contents: RepoContents @@ -873,8 +934,8 @@ def apply( self._get_provider().update_infra( project=self.project, - tables_to_delete=views_to_delete if not partial else [], - tables_to_keep=views_to_update, + tables_to_delete=views_to_delete + sfvs_to_delete if not partial else [], + tables_to_keep=views_to_update + sfvs_to_update, entities_to_delete=entities_to_delete if not partial else [], entities_to_keep=entities_to_update, partial=partial, @@ -1151,23 +1212,9 @@ def materialize_incremental( ... """ - feature_views_to_materialize: List[FeatureView] = [] - if feature_views is None: - feature_views_to_materialize = self._list_feature_views( - hide_dummy_entity=False - ) - feature_views_to_materialize = [ - fv for fv in feature_views_to_materialize if fv.online - ] - else: - for name in feature_views: - feature_view = self._get_feature_view(name, hide_dummy_entity=False) - if not feature_view.online: - raise ValueError( - f"FeatureView {feature_view.name} is not configured to be served online." 
- ) - feature_views_to_materialize.append(feature_view) - + feature_views_to_materialize = self._get_feature_views_to_materialize( + feature_views + ) _print_materialization_log( None, end_date, @@ -1258,23 +1305,9 @@ def materialize( f"The given start_date {start_date} is greater than the given end_date {end_date}." ) - feature_views_to_materialize: List[FeatureView] = [] - if feature_views is None: - feature_views_to_materialize = self._list_feature_views( - hide_dummy_entity=False - ) - feature_views_to_materialize = [ - fv for fv in feature_views_to_materialize if fv.online - ] - else: - for name in feature_views: - feature_view = self._get_feature_view(name, hide_dummy_entity=False) - if not feature_view.online: - raise ValueError( - f"FeatureView {feature_view.name} is not configured to be served online." - ) - feature_views_to_materialize.append(feature_view) - + feature_views_to_materialize = self._get_feature_views_to_materialize( + feature_views + ) _print_materialization_log( start_date, end_date, @@ -1327,6 +1360,7 @@ def push( from feast.data_source import PushSource all_fvs = self.list_feature_views(allow_cache=allow_registry_cache) + all_fvs += self.list_stream_feature_views(allow_cache=allow_registry_cache) fvs_with_push_sources = { fv diff --git a/sdk/python/feast/inference.py b/sdk/python/feast/inference.py index bf9af26b82..011a3b99b2 100644 --- a/sdk/python/feast/inference.py +++ b/sdk/python/feast/inference.py @@ -99,6 +99,10 @@ def update_feature_views_with_inferred_features_and_entities( other columns except designated timestamp columns are considered to be feature columns. If the feature view already has features, feature inference is skipped. + Note that this inference logic currently does not take any transformations (either a UDF or + aggregations) into account. For example, even if a stream feature view has a transformation, + this method assumes that the batch source contains transformed data with the correct final schema. + Args: fvs: The feature views to be updated. entities: A list containing entities associated with the feature views. 
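Taken together, the hunks above make stream feature views first-class for listing, materialization, and push. A minimal usage sketch of the resulting SDK surface (the view name `pushable_location_stats`, the push source `location_stats_push_source`, and the `location_id`/`temperature` columns come from the test fixtures later in this diff; the repo path, timestamps, and row values are illustrative only):

```python
from datetime import datetime, timedelta

import pandas as pd

from feast import FeatureStore

store = FeatureStore(repo_path=".")  # illustrative repo path

# Stream feature views are now returned by their own listing call...
for sfv in store.list_stream_feature_views():
    print(sfv.name, sfv.online)

# ...and can be materialized by name: _get_feature_views_to_materialize falls
# back to stream feature views when a name does not match a plain FeatureView.
store.materialize(
    start_date=datetime.utcnow() - timedelta(days=1),
    end_date=datetime.utcnow(),
    feature_views=["pushable_location_stats"],
)

# push() also resolves stream feature views backed by a PushSource now.
event_df = pd.DataFrame(
    {
        "location_id": [1],
        "temperature": [72.5],
        "event_timestamp": [datetime.utcnow()],
    }
)
store.push("location_stats_push_source", event_df)
```

The `hide_dummy_entity` handling in `_list_stream_feature_views` mirrors what `_list_feature_views` already does: a view registered without entities is stored against the dummy entity, which is stripped before the view is handed back to callers. A rough sketch of that behavior, assuming `DUMMY_ENTITY_NAME` from `feast.feature_view` (the standalone helper name here is hypothetical):

```python
from typing import List

from feast.feature_view import DUMMY_ENTITY_NAME
from feast.stream_feature_view import StreamFeatureView


def hide_dummy_entity(sfvs: List[StreamFeatureView]) -> List[StreamFeatureView]:
    """Strip the placeholder entity the registry stores for entity-less views."""
    for sfv in sfvs:
        if sfv.entities and sfv.entities[0] == DUMMY_ENTITY_NAME:
            sfv.entities = []
            sfv.entity_columns = []
    return sfvs
```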
diff --git a/sdk/python/feast/registry.py b/sdk/python/feast/registry.py index c8b00befc6..c721bd648a 100644 --- a/sdk/python/feast/registry.py +++ b/sdk/python/feast/registry.py @@ -1267,6 +1267,30 @@ def apply_materialization( self.commit() return + for idx, existing_stream_feature_view_proto in enumerate( + self.cached_registry_proto.stream_feature_views + ): + if ( + existing_stream_feature_view_proto.spec.name == feature_view.name + and existing_stream_feature_view_proto.spec.project == project + ): + existing_stream_feature_view = StreamFeatureView.from_proto( + existing_stream_feature_view_proto + ) + existing_stream_feature_view.materialization_intervals.append( + (start_date, end_date) + ) + existing_stream_feature_view.last_updated_timestamp = datetime.utcnow() + stream_feature_view_proto = existing_stream_feature_view.to_proto() + stream_feature_view_proto.spec.project = project + del self.cached_registry_proto.stream_feature_views[idx] + self.cached_registry_proto.stream_feature_views.append( + stream_feature_view_proto + ) + if commit: + self.commit() + return + raise FeatureViewNotFoundException(feature_view.name, project) def list_feature_views( diff --git a/sdk/python/feast/stream_feature_view.py b/sdk/python/feast/stream_feature_view.py index 3bd525596b..2122ff1d55 100644 --- a/sdk/python/feast/stream_feature_view.py +++ b/sdk/python/feast/stream_feature_view.py @@ -1,9 +1,9 @@ import copy import functools import warnings -from datetime import timedelta +from datetime import datetime, timedelta from types import MethodType -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import dill from google.protobuf.duration_pb2 import Duration @@ -42,26 +42,42 @@ class StreamFeatureView(FeatureView): schemas with Feast. Attributes: - name: str. The unique name of the stream feature view. - entities: Union[List[Entity], List[str]]. List of entities or entity join keys. - ttl: timedelta. The amount of time this group of features lives. A ttl of 0 indicates that + name: The unique name of the stream feature view. + entities: List of entities or entity join keys. + ttl: The amount of time this group of features lives. A ttl of 0 indicates that this group of features lives forever. Note that large ttl's or a ttl of 0 can result in extremely computationally intensive queries. - tags: Dict[str, str]. A dictionary of key-value pairs to store arbitrary metadata. - online: bool. Defines whether this stream feature view is used in online feature retrieval. - description: str. A human-readable description. + schema: The schema of the feature view, including feature, timestamp, and entity + columns. If not specified, can be inferred from the underlying data source. + source: DataSource. The stream source of data where this group of features is stored. + aggregations: List of aggregations registered with the stream feature view. + mode: The mode of execution. + timestamp_field: Must be specified if aggregations are specified. Defines the timestamp column on which to aggregate windows. + online: Defines whether this stream feature view is used in online feature retrieval. + description: A human-readable description. + tags: A dictionary of key-value pairs to store arbitrary metadata. owner: The owner of the on demand feature view, typically the email of the primary maintainer. - schema: List[Field] The schema of the feature view, including feature, timestamp, and entity - columns. If not specified, can be inferred from the underlying data source. 
- source: DataSource. The stream source of data where this group of features - is stored. - aggregations (optional): List[Aggregation]. List of aggregations registered with the stream feature view. - mode(optional): str. The mode of execution. - timestamp_field (optional): Must be specified if aggregations are specified. Defines the timestamp column on which to aggregate windows. - udf (optional): MethodType The user defined transformation function. This transformation function should have all of the corresponding imports imported within the function. + udf: The user defined transformation function. This transformation function should have all of the corresponding imports imported within the function. """ + name: str + entities: List[str] + ttl: Optional[timedelta] + source: DataSource + schema: List[Field] + entity_columns: List[Field] + features: List[Field] + online: bool + description: str + tags: Dict[str, str] + owner: str + aggregations: List[Aggregation] + mode: str + timestamp_field: str + materialization_intervals: List[Tuple[datetime, datetime]] + udf: Optional[MethodType] + def __init__( self, *, @@ -222,7 +238,7 @@ def from_proto(cls, sfv_proto): if sfv_proto.spec.HasField("user_defined_function") else None ) - sfv_feature_view = cls( + stream_feature_view = cls( name=sfv_proto.spec.name, description=sfv_proto.spec.description, tags=dict(sfv_proto.spec.tags), @@ -247,23 +263,27 @@ def from_proto(cls, sfv_proto): ) if batch_source: - sfv_feature_view.batch_source = batch_source + stream_feature_view.batch_source = batch_source if stream_source: - sfv_feature_view.stream_source = stream_source + stream_feature_view.stream_source = stream_source - sfv_feature_view.entities = list(sfv_proto.spec.entities) + stream_feature_view.entities = list(sfv_proto.spec.entities) - sfv_feature_view.features = [ + stream_feature_view.features = [ Field.from_proto(field_proto) for field_proto in sfv_proto.spec.features ] + stream_feature_view.entity_columns = [ + Field.from_proto(field_proto) + for field_proto in sfv_proto.spec.entity_columns + ] if sfv_proto.meta.HasField("created_timestamp"): - sfv_feature_view.created_timestamp = ( + stream_feature_view.created_timestamp = ( sfv_proto.meta.created_timestamp.ToDatetime() ) if sfv_proto.meta.HasField("last_updated_timestamp"): - sfv_feature_view.last_updated_timestamp = ( + stream_feature_view.last_updated_timestamp = ( sfv_proto.meta.last_updated_timestamp.ToDatetime() ) @@ -275,7 +295,7 @@ def from_proto(cls, sfv_proto): ) ) - return sfv_feature_view + return stream_feature_view def __copy__(self): fv = StreamFeatureView( @@ -290,7 +310,7 @@ def __copy__(self): aggregations=self.aggregations, mode=self.mode, timestamp_field=self.timestamp_field, - sources=self.sources, + source=self.source, udf=self.udf, ) fv.projection = copy.copy(self.projection) diff --git a/sdk/python/tests/integration/feature_repos/universal/feature_views.py b/sdk/python/tests/integration/feature_repos/universal/feature_views.py index b93ad987fa..3fee0b7001 100644 --- a/sdk/python/tests/integration/feature_repos/universal/feature_views.py +++ b/sdk/python/tests/integration/feature_repos/universal/feature_views.py @@ -11,6 +11,7 @@ Field, OnDemandFeatureView, PushSource, + StreamFeatureView, ValueType, ) from feast.data_source import DataSource, RequestSource @@ -297,7 +298,7 @@ def create_pushable_feature_view(batch_source: DataSource): push_source = PushSource( name="location_stats_push_source", batch_source=batch_source, ) - return FeatureView( + return 
StreamFeatureView( name="pushable_location_stats", entities=[location()], schema=[ diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index d05045e295..b01448e7cc 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -441,6 +441,82 @@ def test_online_retrieval_with_event_timestamps( ) +@pytest.mark.integration +@pytest.mark.universal_online_stores +@pytest.mark.goserver +@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v)) +def test_stream_feature_view_online_retrieval( + environment, universal_data_sources, feature_server_endpoint, full_feature_names +): + """ + Tests materialization and online retrieval for stream feature views. + + This test is separate from test_online_retrieval since combining feature views and + stream feature views into a single test resulted in test flakiness. This is tech + debt that should be resolved soon. + """ + # Set up feature store. + fs = environment.feature_store + entities, datasets, data_sources = universal_data_sources + feature_views = construct_universal_feature_views(data_sources) + pushable_feature_view = feature_views.pushed_locations + fs.apply([location(), pushable_feature_view]) + + # Materialize. + fs.materialize( + environment.start_date - timedelta(days=1), + environment.end_date + timedelta(days=1), + ) + + # Get online features by randomly sampling 10 entities that exist in the batch source. + sample_locations = datasets.location_df.sample(10)["location_id"] + entity_rows = [ + {"location_id": sample_location} for sample_location in sample_locations + ] + + feature_refs = [ + "pushable_location_stats:temperature", + ] + unprefixed_feature_refs = [f.rsplit(":", 1)[-1] for f in feature_refs if ":" in f] + + online_features_dict = get_online_features_dict( + environment=environment, + endpoint=feature_server_endpoint, + features=feature_refs, + entity_rows=entity_rows, + full_feature_names=full_feature_names, + ) + + # Check that the response has the expected set of keys. + keys = set(online_features_dict.keys()) + expected_keys = set( + f.replace(":", "__") if full_feature_names else f.split(":")[-1] + for f in feature_refs + ) | {"location_id"} + assert ( + keys == expected_keys + ), f"Response keys are different from expected: {keys - expected_keys} (extra) and {expected_keys - keys} (missing)" + + # Check that the feature values match. 
+ tc = unittest.TestCase() + for i, entity_row in enumerate(entity_rows): + df_features = get_latest_feature_values_from_location_df( + entity_row, datasets.location_df + ) + + assert df_features["location_id"] == online_features_dict["location_id"][i] + for unprefixed_feature_ref in unprefixed_feature_refs: + tc.assertAlmostEqual( + df_features[unprefixed_feature_ref], + online_features_dict[ + response_feature_name( + unprefixed_feature_ref, feature_refs, full_feature_names + ) + ][i], + delta=0.0001, + ) + + @pytest.mark.integration @pytest.mark.universal_online_stores @pytest.mark.goserver @@ -859,6 +935,10 @@ def get_latest_feature_values_for_location_df(entity_row, origin_df, destination } +def get_latest_feature_values_from_location_df(entity_row, location_df): + return get_latest_row(entity_row, location_df, "location_id", "location_id") + + def assert_feature_service_correctness( environment, endpoint, diff --git a/sdk/python/tests/integration/registration/test_stream_feature_view_apply.py b/sdk/python/tests/integration/registration/test_stream_feature_view_apply.py index 29cd2f1c26..adeb15317e 100644 --- a/sdk/python/tests/integration/registration/test_stream_feature_view_apply.py +++ b/sdk/python/tests/integration/registration/test_stream_feature_view_apply.py @@ -2,139 +2,147 @@ import pytest -from feast import Entity, Field, FileSource from feast.aggregation import Aggregation from feast.data_format import AvroFormat from feast.data_source import KafkaSource +from feast.entity import Entity +from feast.field import Field from feast.stream_feature_view import stream_feature_view from feast.types import Float32 +from tests.utils.cli_utils import CliRunner, get_example_repo +from tests.utils.data_source_utils import prep_file_source @pytest.mark.integration -def test_apply_stream_feature_view(environment) -> None: +def test_apply_stream_feature_view(simple_dataset_1) -> None: """ Test apply of StreamFeatureView. 
""" - fs = environment.feature_store - - # Create Feature Views - entity = Entity(name="driver_entity", join_keys=["test_key"]) - - stream_source = KafkaSource( - name="kafka", - timestamp_field="event_timestamp", - bootstrap_servers="", - message_format=AvroFormat(""), - topic="topic", - batch_source=FileSource(path="test_path", timestamp_field="event_timestamp"), - watermark=timedelta(days=1), - ) - - @stream_feature_view( - entities=[entity], - ttl=timedelta(days=30), - owner="test@example.com", - online=True, - schema=[Field(name="dummy_field", dtype=Float32)], - description="desc", - aggregations=[ - Aggregation( - column="dummy_field", function="max", time_window=timedelta(days=1), - ), - Aggregation( - column="dummy_field2", function="count", time_window=timedelta(days=24), - ), - ], - timestamp_field="event_timestamp", - mode="spark", - source=stream_source, - tags={}, - ) - def simple_sfv(df): - return df - - fs.apply([entity, simple_sfv]) - stream_feature_views = fs.list_stream_feature_views() - assert len(stream_feature_views) == 1 - assert stream_feature_views[0] == simple_sfv - - entities = fs.list_entities() - assert len(entities) == 1 - assert entities[0] == entity - - features = fs.get_online_features( - features=["simple_sfv:dummy_field"], entity_rows=[{"test_key": 1001}], - ).to_dict(include_event_timestamps=True) - - assert "test_key" in features - assert features["test_key"] == [1001] - assert "dummy_field" in features - assert features["dummy_field"] == [None] + runner = CliRunner() + with runner.local_repo( + get_example_repo("example_feature_repo_1.py"), "bigquery" + ) as fs, prep_file_source( + df=simple_dataset_1, timestamp_field="ts_1" + ) as file_source: + entity = Entity(name="driver_entity", join_keys=["test_key"]) + + stream_source = KafkaSource( + name="kafka", + timestamp_field="event_timestamp", + bootstrap_servers="", + message_format=AvroFormat(""), + topic="topic", + batch_source=file_source, + watermark=timedelta(days=1), + ) + + @stream_feature_view( + entities=[entity], + ttl=timedelta(days=30), + owner="test@example.com", + online=True, + schema=[Field(name="dummy_field", dtype=Float32)], + description="desc", + aggregations=[ + Aggregation( + column="dummy_field", function="max", time_window=timedelta(days=1), + ), + Aggregation( + column="dummy_field2", + function="count", + time_window=timedelta(days=24), + ), + ], + timestamp_field="event_timestamp", + mode="spark", + source=stream_source, + tags={}, + ) + def simple_sfv(df): + return df + + fs.apply([entity, simple_sfv]) + + stream_feature_views = fs.list_stream_feature_views() + assert len(stream_feature_views) == 1 + assert stream_feature_views[0] == simple_sfv + + features = fs.get_online_features( + features=["simple_sfv:dummy_field"], entity_rows=[{"test_key": 1001}], + ).to_dict(include_event_timestamps=True) + + assert "test_key" in features + assert features["test_key"] == [1001] + assert "dummy_field" in features + assert features["dummy_field"] == [None] @pytest.mark.integration -def test_stream_feature_view_udf(environment) -> None: +def test_stream_feature_view_udf(simple_dataset_1) -> None: """ Test apply of StreamFeatureView udfs are serialized correctly and usable. 
""" - fs = environment.feature_store - - # Create Feature Views - entity = Entity(name="driver_entity", join_keys=["test_key"]) - - stream_source = KafkaSource( - name="kafka", - timestamp_field="event_timestamp", - bootstrap_servers="", - message_format=AvroFormat(""), - topic="topic", - batch_source=FileSource(path="test_path", timestamp_field="event_timestamp"), - watermark=timedelta(days=1), - ) - - @stream_feature_view( - entities=[entity], - ttl=timedelta(days=30), - owner="test@example.com", - online=True, - schema=[Field(name="dummy_field", dtype=Float32)], - description="desc", - aggregations=[ - Aggregation( - column="dummy_field", function="max", time_window=timedelta(days=1), - ), - Aggregation( - column="dummy_field2", function="count", time_window=timedelta(days=24), - ), - ], - timestamp_field="event_timestamp", - mode="spark", - source=stream_source, - tags={}, - ) - def pandas_view(pandas_df): - import pandas as pd - - assert type(pandas_df) == pd.DataFrame - df = pandas_df.transform(lambda x: x + 10, axis=1) - df.insert(2, "C", [20.2, 230.0, 34.0], True) - return df + runner = CliRunner() + with runner.local_repo( + get_example_repo("example_feature_repo_1.py"), "bigquery" + ) as fs, prep_file_source( + df=simple_dataset_1, timestamp_field="ts_1" + ) as file_source: + entity = Entity(name="driver_entity", join_keys=["test_key"]) + + stream_source = KafkaSource( + name="kafka", + timestamp_field="event_timestamp", + bootstrap_servers="", + message_format=AvroFormat(""), + topic="topic", + batch_source=file_source, + watermark=timedelta(days=1), + ) + + @stream_feature_view( + entities=[entity], + ttl=timedelta(days=30), + owner="test@example.com", + online=True, + schema=[Field(name="dummy_field", dtype=Float32)], + description="desc", + aggregations=[ + Aggregation( + column="dummy_field", function="max", time_window=timedelta(days=1), + ), + Aggregation( + column="dummy_field2", + function="count", + time_window=timedelta(days=24), + ), + ], + timestamp_field="event_timestamp", + mode="spark", + source=stream_source, + tags={}, + ) + def pandas_view(pandas_df): + import pandas as pd + + assert type(pandas_df) == pd.DataFrame + df = pandas_df.transform(lambda x: x + 10, axis=1) + df.insert(2, "C", [20.2, 230.0, 34.0], True) + return df - import pandas as pd - - df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]}) + import pandas as pd - fs.apply([entity, pandas_view]) - stream_feature_views = fs.list_stream_feature_views() - assert len(stream_feature_views) == 1 - assert stream_feature_views[0].name == "pandas_view" - assert stream_feature_views[0] == pandas_view + fs.apply([entity, pandas_view]) - sfv = stream_feature_views[0] + stream_feature_views = fs.list_stream_feature_views() + assert len(stream_feature_views) == 1 + assert stream_feature_views[0] == pandas_view - new_df = sfv.udf(df) + sfv = stream_feature_views[0] - expected_df = pd.DataFrame( - {"A": [11, 12, 13], "B": [20, 30, 40], "C": [20.2, 230.0, 34.0]} - ) - assert new_df.equals(expected_df) + df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]}) + new_df = sfv.udf(df) + expected_df = pd.DataFrame( + {"A": [11, 12, 13], "B": [20, 30, 40], "C": [20.2, 230.0, 34.0]} + ) + assert new_df.equals(expected_df) diff --git a/sdk/python/tests/unit/test_feature_view.py b/sdk/python/tests/unit/test_feature_view.py deleted file mode 100644 index 1ef36081ec..0000000000 --- a/sdk/python/tests/unit/test_feature_view.py +++ /dev/null @@ -1,68 +0,0 @@ -# Copyright 2022 The Feast Authors -# -# Licensed under the Apache 
License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from feast.feature_view import FeatureView -from feast.field import Field -from feast.infra.offline_stores.file_source import FileSource -from feast.types import Float32 - - -def test_hash(): - file_source = FileSource(name="my-file-source", path="test.parquet") - feature_view_1 = FeatureView( - name="my-feature-view", - entities=[], - schema=[ - Field(name="feature1", dtype=Float32), - Field(name="feature2", dtype=Float32), - ], - source=file_source, - ) - feature_view_2 = FeatureView( - name="my-feature-view", - entities=[], - schema=[ - Field(name="feature1", dtype=Float32), - Field(name="feature2", dtype=Float32), - ], - source=file_source, - ) - feature_view_3 = FeatureView( - name="my-feature-view", - entities=[], - schema=[Field(name="feature1", dtype=Float32)], - source=file_source, - ) - feature_view_4 = FeatureView( - name="my-feature-view", - entities=[], - schema=[Field(name="feature1", dtype=Float32)], - source=file_source, - description="test", - ) - - s1 = {feature_view_1, feature_view_2} - assert len(s1) == 1 - - s2 = {feature_view_1, feature_view_3} - assert len(s2) == 2 - - s3 = {feature_view_3, feature_view_4} - assert len(s3) == 2 - - s4 = {feature_view_1, feature_view_2, feature_view_3, feature_view_4} - assert len(s4) == 3 - - -# TODO(felixwang9817): Add tests for proto conversion. -# TODO(felixwang9817): Add tests for field mapping logic. 
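One user-visible consequence of the `apply_materialization` hunk in registry.py above: materialization windows are now recorded for stream feature views as well, in the `materialization_intervals` attribute declared on `StreamFeatureView`. A sketch, assuming the public `get_stream_feature_view` accessor and the same illustrative store and view names as the earlier sketch:

```python
from datetime import datetime, timedelta

from feast import FeatureStore

store = FeatureStore(repo_path=".")  # illustrative repo path

end = datetime.utcnow()
start = end - timedelta(days=1)
store.materialize(start_date=start, end_date=end)

# The registry now appends the (start, end) window for stream feature views too.
sfv = store.get_stream_feature_view("pushable_location_stats")
print(sfv.materialization_intervals)
```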
diff --git a/sdk/python/tests/unit/test_feature_views.py b/sdk/python/tests/unit/test_feature_views.py index 64b23edd2c..a1d134a2f0 100644 --- a/sdk/python/tests/unit/test_feature_views.py +++ b/sdk/python/tests/unit/test_feature_views.py @@ -7,6 +7,7 @@ from feast.data_format import AvroFormat from feast.data_source import KafkaSource, PushSource from feast.entity import Entity +from feast.feature_view import FeatureView from feast.field import Field from feast.infra.offline_stores.file_source import FileSource from feast.stream_feature_view import StreamFeatureView, stream_feature_view @@ -201,3 +202,54 @@ def test_stream_feature_view_initialization_with_optional_fields_omitted(): new_sfv = StreamFeatureView.from_proto(sfv_proto=sfv_proto) assert new_sfv == sfv + + +def test_hash(): + file_source = FileSource(name="my-file-source", path="test.parquet") + feature_view_1 = FeatureView( + name="my-feature-view", + entities=[], + schema=[ + Field(name="feature1", dtype=Float32), + Field(name="feature2", dtype=Float32), + ], + source=file_source, + ) + feature_view_2 = FeatureView( + name="my-feature-view", + entities=[], + schema=[ + Field(name="feature1", dtype=Float32), + Field(name="feature2", dtype=Float32), + ], + source=file_source, + ) + feature_view_3 = FeatureView( + name="my-feature-view", + entities=[], + schema=[Field(name="feature1", dtype=Float32)], + source=file_source, + ) + feature_view_4 = FeatureView( + name="my-feature-view", + entities=[], + schema=[Field(name="feature1", dtype=Float32)], + source=file_source, + description="test", + ) + + s1 = {feature_view_1, feature_view_2} + assert len(s1) == 1 + + s2 = {feature_view_1, feature_view_3} + assert len(s2) == 2 + + s3 = {feature_view_3, feature_view_4} + assert len(s3) == 2 + + s4 = {feature_view_1, feature_view_2, feature_view_3, feature_view_4} + assert len(s4) == 3 + + +# TODO(felixwang9817): Add tests for proto conversion. +# TODO(felixwang9817): Add tests for field mapping logic.
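The relocated unit tests sit next to the stream feature view tests that already exercise proto round-trips. For reference, the explicit attribute declarations, the `entity_columns` handling in `from_proto`, and the `source=` fix in `__copy__` from stream_feature_view.py above are what keep a plain constructor round-trip intact; a sketch built from the same sources as those tests (field names and the file path are placeholders):

```python
from datetime import timedelta

from feast import Entity, Field, FileSource
from feast.data_format import AvroFormat
from feast.data_source import KafkaSource
from feast.stream_feature_view import StreamFeatureView
from feast.types import Float32

driver = Entity(name="driver_entity", join_keys=["test_key"])

stream_source = KafkaSource(
    name="kafka",
    timestamp_field="event_timestamp",
    bootstrap_servers="",
    message_format=AvroFormat(""),
    topic="topic",
    batch_source=FileSource(path="test_path", timestamp_field="event_timestamp"),
    watermark=timedelta(days=1),
)

sfv = StreamFeatureView(
    name="simple_sfv",
    entities=[driver],
    ttl=timedelta(days=30),
    schema=[Field(name="dummy_field", dtype=Float32)],
    timestamp_field="event_timestamp",
    mode="spark",
    source=stream_source,
)

# Entities, features, and entity columns are copied back out by from_proto,
# so an un-applied view compares equal after a serialization round-trip.
round_tripped = StreamFeatureView.from_proto(sfv.to_proto())
assert round_tripped == sfv
```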