diff --git a/.readthedocs.yml b/.readthedocs.yml index dea27e20b3..75499aa5dd 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -7,6 +7,6 @@ formats: - pdf python: - version: 3.7 + version: "3.8" install: - requirements: sdk/python/docs/requirements.txt \ No newline at end of file diff --git a/java/common/src/main/java/feast/common/logging/entry/MessageAuditLogEntry.java b/java/common/src/main/java/feast/common/logging/entry/MessageAuditLogEntry.java index 6e5072f66e..8ad428a3a3 100644 --- a/java/common/src/main/java/feast/common/logging/entry/MessageAuditLogEntry.java +++ b/java/common/src/main/java/feast/common/logging/entry/MessageAuditLogEntry.java @@ -19,16 +19,13 @@ import com.google.auto.value.AutoValue; import com.google.gson.Gson; import com.google.gson.GsonBuilder; -import com.google.gson.JsonElement; import com.google.gson.JsonParser; -import com.google.gson.JsonSerializationContext; import com.google.gson.JsonSerializer; import com.google.protobuf.Empty; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; import com.google.protobuf.util.JsonFormat; import io.grpc.Status.Code; -import java.lang.reflect.Type; import java.util.UUID; /** MessageAuditLogEntry records the handling of a Protobuf message by a service call. */ @@ -103,20 +100,17 @@ public String toJSON() { new GsonBuilder() .registerTypeAdapter( Message.class, - new JsonSerializer() { - @Override - public JsonElement serialize( - Message message, Type type, JsonSerializationContext context) { - try { - String messageJSON = JsonFormat.printer().print(message); - return new JsonParser().parse(messageJSON); - } catch (InvalidProtocolBufferException e) { - - throw new RuntimeException( - "Unexpected exception converting Protobuf to JSON", e); - } - } - }) + (JsonSerializer) + (message, type, context) -> { + try { + String messageJSON = JsonFormat.printer().print(message); + return new JsonParser().parse(messageJSON); + } catch (InvalidProtocolBufferException e) { + + throw new RuntimeException( + "Unexpected exception converting Protobuf to JSON", e); + } + }) .create(); return gson.toJson(this); } diff --git a/java/common/src/main/java/feast/common/logging/interceptors/GrpcMessageInterceptor.java b/java/common/src/main/java/feast/common/logging/interceptors/GrpcMessageInterceptor.java index 661642a89a..e34fefd115 100644 --- a/java/common/src/main/java/feast/common/logging/interceptors/GrpcMessageInterceptor.java +++ b/java/common/src/main/java/feast/common/logging/interceptors/GrpcMessageInterceptor.java @@ -38,7 +38,7 @@ * GrpcMessageInterceptor assumes that all service calls are unary (ie single request/response). */ public class GrpcMessageInterceptor implements ServerInterceptor { - private LoggingProperties loggingProperties; + private final LoggingProperties loggingProperties; /** * Construct GrpcMessageIntercetor. @@ -78,7 +78,7 @@ public Listener interceptCall( // Register forwarding call to intercept outgoing response and log to audit log call = - new SimpleForwardingServerCall(call) { + new SimpleForwardingServerCall<>(call) { @Override public void sendMessage(RespT message) { // 2. 
Track the response & Log entry to audit logger @@ -97,7 +97,7 @@ public void close(Status status, Metadata trailers) { }; ServerCall.Listener listener = next.startCall(call, headers); - return new SimpleForwardingServerCallListener(listener) { + return new SimpleForwardingServerCallListener<>(listener) { @Override // Register listener to intercept incoming request messages and log to audit log public void onMessage(ReqT message) { diff --git a/java/sdk/src/main/java/dev/feast/RequestUtil.java b/java/sdk/src/main/java/dev/feast/RequestUtil.java index fc13c45311..da2c0dc42e 100644 --- a/java/sdk/src/main/java/dev/feast/RequestUtil.java +++ b/java/sdk/src/main/java/dev/feast/RequestUtil.java @@ -35,9 +35,7 @@ public static List createFeatureRefs(List featureRef } List featureRefs = - featureRefStrings.stream() - .map(refStr -> parseFeatureRef(refStr)) - .collect(Collectors.toList()); + featureRefStrings.stream().map(RequestUtil::parseFeatureRef).collect(Collectors.toList()); return featureRefs; } diff --git a/java/serving/src/main/java/feast/serving/registry/LocalRegistryFile.java b/java/serving/src/main/java/feast/serving/registry/LocalRegistryFile.java index b0d6b10bfc..1da45813ee 100644 --- a/java/serving/src/main/java/feast/serving/registry/LocalRegistryFile.java +++ b/java/serving/src/main/java/feast/serving/registry/LocalRegistryFile.java @@ -24,7 +24,7 @@ import java.util.Optional; public class LocalRegistryFile implements RegistryFile { - private RegistryProto.Registry cachedRegistry; + private final RegistryProto.Registry cachedRegistry; public LocalRegistryFile(String path) { try { diff --git a/java/serving/src/main/java/feast/serving/registry/Registry.java b/java/serving/src/main/java/feast/serving/registry/Registry.java index 37fae3d8dc..bc953174ea 100644 --- a/java/serving/src/main/java/feast/serving/registry/Registry.java +++ b/java/serving/src/main/java/feast/serving/registry/Registry.java @@ -17,6 +17,9 @@ package feast.serving.registry; import feast.proto.core.*; +import feast.proto.core.FeatureServiceProto.FeatureService; +import feast.proto.core.FeatureViewProto.FeatureView; +import feast.proto.core.OnDemandFeatureViewProto.OnDemandFeatureView; import feast.proto.serving.ServingAPIProto; import feast.serving.exception.SpecRetrievalException; import java.util.List; @@ -26,16 +29,16 @@ public class Registry { private final RegistryProto.Registry registry; - private Map featureViewNameToSpec; + private final Map featureViewNameToSpec; private Map onDemandFeatureViewNameToSpec; - private Map featureServiceNameToSpec; + private final Map featureServiceNameToSpec; Registry(RegistryProto.Registry registry) { this.registry = registry; List featureViewSpecs = registry.getFeatureViewsList().stream() - .map(fv -> fv.getSpec()) + .map(FeatureView::getSpec) .collect(Collectors.toList()); this.featureViewNameToSpec = featureViewSpecs.stream() @@ -43,7 +46,7 @@ public class Registry { Collectors.toMap(FeatureViewProto.FeatureViewSpec::getName, Function.identity())); List onDemandFeatureViewSpecs = registry.getOnDemandFeatureViewsList().stream() - .map(odfv -> odfv.getSpec()) + .map(OnDemandFeatureView::getSpec) .collect(Collectors.toList()); this.onDemandFeatureViewNameToSpec = onDemandFeatureViewSpecs.stream() @@ -53,7 +56,7 @@ public class Registry { Function.identity())); this.featureServiceNameToSpec = registry.getFeatureServicesList().stream() - .map(fs -> fs.getSpec()) + .map(FeatureService::getSpec) .collect( Collectors.toMap( FeatureServiceProto.FeatureServiceSpec::getName, 
Function.identity())); diff --git a/java/serving/src/main/java/feast/serving/service/OnlineTransformationService.java b/java/serving/src/main/java/feast/serving/service/OnlineTransformationService.java index 365432b84e..ae83635b86 100644 --- a/java/serving/src/main/java/feast/serving/service/OnlineTransformationService.java +++ b/java/serving/src/main/java/feast/serving/service/OnlineTransformationService.java @@ -239,8 +239,7 @@ public void processTransformFeaturesResponse( } catch (IOException e) { log.info(e.toString()); throw Status.INTERNAL - .withDescription( - "Unable to correctly process transform features response: " + e.toString()) + .withDescription("Unable to correctly process transform features response: " + e) .asRuntimeException(); } } @@ -249,11 +248,10 @@ public void processTransformFeaturesResponse( public ValueType serializeValuesIntoArrowIPC(List>> values) { // In order to be serialized correctly, the data must be packaged in a VectorSchemaRoot. // We first construct all the columns. - Map columnNameToColumn = new HashMap(); BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); - List columnFields = new ArrayList(); - List columns = new ArrayList(); + List columnFields = new ArrayList<>(); + List columns = new ArrayList<>(); for (Pair> columnEntry : values) { // The Python FTS does not expect full feature names, so we extract the feature name. @@ -316,8 +314,7 @@ public ValueType serializeValuesIntoArrowIPC(List None: + ingested_stream_df = self._ingest_stream_data() + transformed_df = self._construct_transformation_plan(ingested_stream_df) + online_store_query = self._write_to_online_store(transformed_df) + return online_store_query + + def _ingest_stream_data(self) -> StreamTable: + """Only supports json and avro formats currently.""" + if self.format == "json": + if not isinstance( + self.data_source.kafka_options.message_format, JsonFormat + ): + raise ValueError("kafka source message format is not jsonformat") + stream_df = ( + self.spark.readStream.format("kafka") + .option( + "kafka.bootstrap.servers", + self.data_source.kafka_options.bootstrap_servers, + ) + .option("subscribe", self.data_source.kafka_options.topic) + .option("startingOffsets", "latest") # Query start + .load() + .selectExpr("CAST(value AS STRING)") + .select( + from_json( + col("value"), + self.data_source.kafka_options.message_format.schema_json, + ).alias("table") + ) + .select("table.*") + ) + else: + if not isinstance( + self.data_source.kafka_options.message_format, AvroFormat + ): + raise ValueError("kafka source message format is not avro format") + stream_df = ( + self.spark.readStream.format("kafka") + .option( + "kafka.bootstrap.servers", + self.data_source.kafka_options.bootstrap_servers, + ) + .option("subscribe", self.data_source.kafka_options.topic) + .option("startingOffsets", "latest") # Query start + .load() + .selectExpr("CAST(value AS STRING)") + .select( + from_avro( + col("value"), + self.data_source.kafka_options.message_format.schema_json, + ).alias("table") + ) + .select("table.*") + ) + return stream_df + + def _construct_transformation_plan(self, df: StreamTable) -> StreamTable: + return self.sfv.udf.__call__(df) if self.sfv.udf else df + + def _write_to_online_store(self, df: StreamTable): + # Validation occurs at the fs.write_to_online_store() phase against the stream feature view schema. 
+ def batch_write(row: DataFrame, batch_id: int): + pd_row = row.toPandas() + self.write_function( + pd_row, input_timestamp="event_timestamp", output_timestamp="" + ) + + query = ( + df.writeStream.outputMode("update") + .option("checkpointLocation", "/tmp/checkpoint/") + .trigger(processingTime=self.processing_time) + .foreachBatch(batch_write) + .start() + ) + + query.awaitTermination(timeout=self.query_timeout) + return query diff --git a/sdk/python/feast/infra/contrib/stream_processor.py b/sdk/python/feast/infra/contrib/stream_processor.py new file mode 100644 index 0000000000..cb44b99cd6 --- /dev/null +++ b/sdk/python/feast/infra/contrib/stream_processor.py @@ -0,0 +1,87 @@ +from abc import ABC +from typing import Callable + +import pandas as pd +from pyspark.sql import DataFrame + +from feast.data_source import DataSource +from feast.importer import import_class +from feast.repo_config import FeastConfigBaseModel +from feast.stream_feature_view import StreamFeatureView + +STREAM_PROCESSOR_CLASS_FOR_TYPE = { + ("spark", "kafka"): "feast.infra.contrib.spark_kafka_processor.SparkKafkaProcessor", +} + +# TODO: support more types other than just Spark. +StreamTable = DataFrame + + +class ProcessorConfig(FeastConfigBaseModel): + # Processor mode (spark, etc) + mode: str + # Ingestion source (kafka, kinesis, etc) + source: str + + +class StreamProcessor(ABC): + """ + A StreamProcessor can ingest and transform data for a specific stream feature view, + and persist that data to the online store. + + Attributes: + sfv: The stream feature view on which the stream processor operates. + data_source: The stream data source from which data will be ingested. + """ + + sfv: StreamFeatureView + data_source: DataSource + + def __init__(self, sfv: StreamFeatureView, data_source: DataSource): + self.sfv = sfv + self.data_source = data_source + + def ingest_stream_feature_view(self) -> None: + """ + Ingests data from the stream source attached to the stream feature view; transforms the data + and then persists it to the online store. + """ + pass + + def _ingest_stream_data(self) -> StreamTable: + """ + Ingests data into a StreamTable. + """ + pass + + def _construct_transformation_plan(self, table: StreamTable) -> StreamTable: + """ + Applies transformations on top of StreamTable object. Since stream engines use lazy + evaluation, the StreamTable will not be materialized until it is actually evaluated. + For example: df.collect() in spark or tbl.execute() in Flink. + """ + pass + + def _write_to_online_store(self, table: StreamTable) -> None: + """ + Returns query for persisting data to the online store. + """ + pass + + +def get_stream_processor_object( + config: ProcessorConfig, + sfv: StreamFeatureView, + write_function: Callable[[pd.DataFrame, str, str], None], +): + """ + Returns a stream processor object based on the config mode and stream source type. The write function is a + function that wraps the feature store "write_to_online_store" capability. 
+ """ + if config.mode == "spark" and config.source == "kafka": + stream_processor = STREAM_PROCESSOR_CLASS_FOR_TYPE[("spark", "kafka")] + module_name, class_name = stream_processor.rsplit(".", 1) + cls = import_class(module_name, class_name, "Processor") + return cls(sfv=sfv, config=config, write_function=write_function,) + else: + raise ValueError("other processors besides spark-kafka not supported") diff --git a/sdk/python/feast/infra/registry_stores/sql.py b/sdk/python/feast/infra/registry_stores/sql.py index bf719c7c51..915cbd4e28 100644 --- a/sdk/python/feast/infra/registry_stores/sql.py +++ b/sdk/python/feast/infra/registry_stores/sql.py @@ -167,6 +167,15 @@ ) +feast_metadata = Table( + "feast_metadata", + metadata, + Column("metadata_key", String(50), primary_key=True), + Column("metadata_value", String(50), nullable=False), + Column("last_updated_timestamp", BigInteger, nullable=False), +) + + class SqlRegistry(BaseRegistry): def __init__( self, registry_config: Optional[RegistryConfig], repo_path: Optional[Path] @@ -688,6 +697,7 @@ def _apply_object( } insert_stmt = insert(table).values(values,) conn.execute(insert_stmt) + self._set_last_updated_metadata(update_datetime, project) def _delete_object(self, table, name, project, id_field_name, not_found_exception): @@ -699,6 +709,7 @@ def _delete_object(self, table, name, project, id_field_name, not_found_exceptio if rows.rowcount < 1 and not_found_exception: raise not_found_exception(name, project) self._set_last_updated_metadata(datetime.utcnow(), project) + return rows.rowcount def _get_object( diff --git a/sdk/python/feast/stream_feature_view.py b/sdk/python/feast/stream_feature_view.py index 12d7f9b74b..214ab083ab 100644 --- a/sdk/python/feast/stream_feature_view.py +++ b/sdk/python/feast/stream_feature_view.py @@ -1,3 +1,4 @@ +import copy import functools import warnings from datetime import timedelta @@ -9,7 +10,7 @@ from feast import utils from feast.aggregation import Aggregation -from feast.data_source import DataSource, KafkaSource +from feast.data_source import DataSource, KafkaSource, PushSource from feast.entity import Entity from feast.feature_view import FeatureView from feast.field import Field @@ -39,6 +40,26 @@ class StreamFeatureView(FeatureView): """ NOTE: Stream Feature Views are not yet fully implemented and exist to allow users to register their stream sources and schemas with Feast. + + Attributes: + name: str. The unique name of the stream feature view. + entities: Union[List[Entity], List[str]]. List of entities or entity join keys. + ttl: timedelta. The amount of time this group of features lives. A ttl of 0 indicates that + this group of features lives forever. Note that large ttl's or a ttl of 0 + can result in extremely computationally intensive queries. + tags: Dict[str, str]. A dictionary of key-value pairs to store arbitrary metadata. + online: bool. Defines whether this stream feature view is used in online feature retrieval. + description: str. A human-readable description. + owner: The owner of the on demand feature view, typically the email of the primary + maintainer. + schema: List[Field] The schema of the feature view, including feature, timestamp, and entity + columns. If not specified, can be inferred from the underlying data source. + source: DataSource. The stream source of data where this group of features + is stored. + aggregations (optional): List[Aggregation]. List of aggregations registered with the stream feature view. + mode(optional): str. The mode of execution. 
+ timestamp_field (optional): Must be specified if aggregations are specified. Defines the timestamp column on which to aggregate windows. + udf (optional): MethodType The user defined transformation function. This transformation function should have all of the corresponding imports imported within the function. """ def __init__( @@ -54,8 +75,8 @@ def __init__( schema: Optional[List[Field]] = None, source: Optional[DataSource] = None, aggregations: Optional[List[Aggregation]] = None, - mode: Optional[str] = "spark", # Mode of ingestion/transformation - timestamp_field: Optional[str] = "", # Timestamp for aggregation + mode: Optional[str] = "spark", + timestamp_field: Optional[str] = "", udf: Optional[MethodType] = None, ): warnings.warn( @@ -63,9 +84,10 @@ def __init__( "Some functionality may still be unstable so functionality can change in the future.", RuntimeWarning, ) + if source is None: - raise ValueError("Stream Feature views need a source specified") - # source uses the batch_source of the kafkasource in feature_view + raise ValueError("Stream Feature views need a source to be specified") + if ( type(source).__name__ not in SUPPORTED_STREAM_SOURCES and source.to_proto().type != DataSourceProto.SourceType.CUSTOM_SOURCE @@ -74,18 +96,26 @@ def __init__( f"Stream feature views need a stream source, expected one of {SUPPORTED_STREAM_SOURCES} " f"or CUSTOM_SOURCE, got {type(source).__name__}: {source.name} instead " ) + + if aggregations and not timestamp_field: + raise ValueError( + "aggregations must have a timestamp field associated with them to perform the aggregations" + ) + self.aggregations = aggregations or [] - self.mode = mode - self.timestamp_field = timestamp_field + self.mode = mode or "" + self.timestamp_field = timestamp_field or "" self.udf = udf _batch_source = None - if isinstance(source, KafkaSource): + if isinstance(source, KafkaSource) or isinstance(source, PushSource): _batch_source = source.batch_source if source.batch_source else None - + _ttl = ttl + if not _ttl: + _ttl = timedelta(days=0) super().__init__( name=name, entities=entities, - ttl=ttl, + ttl=_ttl, batch_source=_batch_source, stream_source=source, tags=tags, @@ -102,7 +132,10 @@ def __eq__(self, other): if not super().__eq__(other): return False - + if not self.udf: + return not other.udf + if not other.udf: + return False if ( self.mode != other.mode or self.timestamp_field != other.timestamp_field @@ -113,13 +146,14 @@ def __eq__(self, other): return True - def __hash__(self): + def __hash__(self) -> int: return super().__hash__() def to_proto(self): meta = StreamFeatureViewMetaProto(materialization_intervals=[]) if self.created_timestamp: meta.created_timestamp.FromDatetime(self.created_timestamp) + if self.last_updated_timestamp: meta.last_updated_timestamp.FromDatetime(self.last_updated_timestamp) @@ -134,6 +168,7 @@ def to_proto(self): ttl_duration = Duration() ttl_duration.FromTimedelta(self.ttl) + batch_source_proto = None if self.batch_source: batch_source_proto = self.batch_source.to_proto() batch_source_proto.data_source_class_type = f"{self.batch_source.__class__.__module__}.{self.batch_source.__class__.__name__}" @@ -143,23 +178,24 @@ def to_proto(self): stream_source_proto = self.stream_source.to_proto() stream_source_proto.data_source_class_type = f"{self.stream_source.__class__.__module__}.{self.stream_source.__class__.__name__}" + udf_proto = None + if self.udf: + udf_proto = UserDefinedFunctionProto( + name=self.udf.__name__, body=dill.dumps(self.udf, recurse=True), + ) spec = 
StreamFeatureViewSpecProto( name=self.name, entities=self.entities, entity_columns=[field.to_proto() for field in self.entity_columns], features=[field.to_proto() for field in self.schema], - user_defined_function=UserDefinedFunctionProto( - name=self.udf.__name__, body=dill.dumps(self.udf, recurse=True), - ) - if self.udf - else None, + user_defined_function=udf_proto, description=self.description, tags=self.tags, owner=self.owner, - ttl=(ttl_duration if ttl_duration is not None else None), + ttl=ttl_duration, online=self.online, batch_source=batch_source_proto or None, - stream_source=stream_source_proto, + stream_source=stream_source_proto or None, timestamp_field=self.timestamp_field, aggregations=[agg.to_proto() for agg in self.aggregations], mode=self.mode, @@ -239,6 +275,25 @@ def from_proto(cls, sfv_proto): return sfv_feature_view + def __copy__(self): + fv = StreamFeatureView( + name=self.name, + schema=self.schema, + entities=self.entities, + ttl=self.ttl, + tags=self.tags, + online=self.online, + description=self.description, + owner=self.owner, + aggregations=self.aggregations, + mode=self.mode, + timestamp_field=self.timestamp_field, + sources=self.sources, + udf=self.udf, + ) + fv.projection = copy.copy(self.projection) + return fv + def stream_feature_view( *, @@ -251,11 +306,13 @@ def stream_feature_view( schema: Optional[List[Field]] = None, source: Optional[DataSource] = None, aggregations: Optional[List[Aggregation]] = None, - mode: Optional[str] = "spark", # Mode of ingestion/transformation - timestamp_field: Optional[str] = "", # Timestamp for aggregation + mode: Optional[str] = "spark", + timestamp_field: Optional[str] = "", ): """ Creates an StreamFeatureView object with the given user function as udf. + Please make sure that the udf contains all non-built in imports within the function to ensure that the execution + of a deserialized function does not miss imports. """ def mainify(obj): diff --git a/sdk/python/tests/integration/registration/test_stream_feature_view_apply.py b/sdk/python/tests/integration/registration/test_stream_feature_view_apply.py index e19641f291..29cd2f1c26 100644 --- a/sdk/python/tests/integration/registration/test_stream_feature_view_apply.py +++ b/sdk/python/tests/integration/registration/test_stream_feature_view_apply.py @@ -70,3 +70,71 @@ def simple_sfv(df): assert features["test_key"] == [1001] assert "dummy_field" in features assert features["dummy_field"] == [None] + + +@pytest.mark.integration +def test_stream_feature_view_udf(environment) -> None: + """ + Test apply of StreamFeatureView udfs are serialized correctly and usable. 
+ """ + fs = environment.feature_store + + # Create Feature Views + entity = Entity(name="driver_entity", join_keys=["test_key"]) + + stream_source = KafkaSource( + name="kafka", + timestamp_field="event_timestamp", + bootstrap_servers="", + message_format=AvroFormat(""), + topic="topic", + batch_source=FileSource(path="test_path", timestamp_field="event_timestamp"), + watermark=timedelta(days=1), + ) + + @stream_feature_view( + entities=[entity], + ttl=timedelta(days=30), + owner="test@example.com", + online=True, + schema=[Field(name="dummy_field", dtype=Float32)], + description="desc", + aggregations=[ + Aggregation( + column="dummy_field", function="max", time_window=timedelta(days=1), + ), + Aggregation( + column="dummy_field2", function="count", time_window=timedelta(days=24), + ), + ], + timestamp_field="event_timestamp", + mode="spark", + source=stream_source, + tags={}, + ) + def pandas_view(pandas_df): + import pandas as pd + + assert type(pandas_df) == pd.DataFrame + df = pandas_df.transform(lambda x: x + 10, axis=1) + df.insert(2, "C", [20.2, 230.0, 34.0], True) + return df + + import pandas as pd + + df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]}) + + fs.apply([entity, pandas_view]) + stream_feature_views = fs.list_stream_feature_views() + assert len(stream_feature_views) == 1 + assert stream_feature_views[0].name == "pandas_view" + assert stream_feature_views[0] == pandas_view + + sfv = stream_feature_views[0] + + new_df = sfv.udf(df) + + expected_df = pd.DataFrame( + {"A": [11, 12, 13], "B": [20, 30, 40], "C": [20.2, 230.0, 34.0]} + ) + assert new_df.equals(expected_df) diff --git a/sdk/python/tests/unit/test_feature_views.py b/sdk/python/tests/unit/test_feature_views.py index 904260dfe6..64b23edd2c 100644 --- a/sdk/python/tests/unit/test_feature_views.py +++ b/sdk/python/tests/unit/test_feature_views.py @@ -9,7 +9,7 @@ from feast.entity import Entity from feast.field import Field from feast.infra.offline_stores.file_source import FileSource -from feast.stream_feature_view import StreamFeatureView +from feast.stream_feature_view import StreamFeatureView, stream_feature_view from feast.types import Float32 @@ -129,3 +129,75 @@ def test_stream_feature_view_serialization(): new_sfv = StreamFeatureView.from_proto(sfv_proto=sfv_proto) assert new_sfv == sfv + + +def test_stream_feature_view_udfs(): + entity = Entity(name="driver_entity", join_keys=["test_key"]) + stream_source = KafkaSource( + name="kafka", + timestamp_field="event_timestamp", + bootstrap_servers="", + message_format=AvroFormat(""), + topic="topic", + batch_source=FileSource(path="some path"), + ) + + @stream_feature_view( + entities=[entity], + ttl=timedelta(days=30), + owner="test@example.com", + online=True, + schema=[Field(name="dummy_field", dtype=Float32)], + description="desc", + aggregations=[ + Aggregation( + column="dummy_field", function="max", time_window=timedelta(days=1), + ) + ], + timestamp_field="event_timestamp", + source=stream_source, + ) + def pandas_udf(pandas_df): + import pandas as pd + + assert type(pandas_df) == pd.DataFrame + df = pandas_df.transform(lambda x: x + 10, axis=1) + return df + + import pandas as pd + + df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]}) + sfv = pandas_udf + sfv_proto = sfv.to_proto() + new_sfv = StreamFeatureView.from_proto(sfv_proto) + new_df = new_sfv.udf(df) + + expected_df = pd.DataFrame({"A": [11, 12, 13], "B": [20, 30, 40]}) + + assert new_df.equals(expected_df) + + +def test_stream_feature_view_initialization_with_optional_fields_omitted(): 
+ entity = Entity(name="driver_entity", join_keys=["test_key"]) + stream_source = KafkaSource( + name="kafka", + timestamp_field="event_timestamp", + bootstrap_servers="", + message_format=AvroFormat(""), + topic="topic", + batch_source=FileSource(path="some path"), + ) + + sfv = StreamFeatureView( + name="test kafka stream feature view", + entities=[entity], + schema=[], + description="desc", + timestamp_field="event_timestamp", + source=stream_source, + tags={}, + ) + sfv_proto = sfv.to_proto() + + new_sfv = StreamFeatureView.from_proto(sfv_proto=sfv_proto) + assert new_sfv == sfv diff --git a/setup.py b/setup.py index f92db4acec..c261507c4a 100644 --- a/setup.py +++ b/setup.py @@ -178,6 +178,13 @@ + HBASE_REQUIRED ) + +# rtd builds fail because of mysql not being installed in their environment. +# We can add mysql there, but it's not strictly needed. This will be faster for builds. +DOCS_REQUIRED = CI_REQUIRED +for _r in MYSQL_REQUIRED: + DOCS_REQUIRED.remove(_r) + DEV_REQUIRED = ["mypy-protobuf==3.1", "grpcio-testing==1.*"] + CI_REQUIRED # Get git repo root directory @@ -480,6 +487,7 @@ def copy_extensions_to_source(self): "ge": GE_REQUIRED, "hbase": HBASE_REQUIRED, "go": GO_REQUIRED, + "docs": DOCS_REQUIRED, }, include_package_data=True, license="Apache",
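A note on the new docs extra in `setup.py`: as the hunk above shows, `DOCS_REQUIRED = CI_REQUIRED` binds a second name to the same list object, so the subsequent `remove()` calls also strip the MySQL pins from `CI_REQUIRED` — and from `DEV_REQUIRED`, which is built from it a few lines later. If that side effect is not intended, a sketch of a non-mutating alternative:

```python
# Build the docs extra from a copy so the MySQL pins stay in CI_REQUIRED
# (and therefore in DEV_REQUIRED, which is derived from it afterwards).
DOCS_REQUIRED = [r for r in CI_REQUIRED if r not in MYSQL_REQUIRED]
```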
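For readers of the new `stream_feature_view` decorator and its docstring guidance, a minimal usage sketch mirroring the unit tests above; the entity, topic, schema, and file path below are illustrative only, and the udf keeps its non-built-in imports inside the function body so the dill-serialized code can be restored elsewhere:

```python
from datetime import timedelta

from feast import Entity, Field, FileSource, KafkaSource
from feast.aggregation import Aggregation
from feast.data_format import AvroFormat
from feast.stream_feature_view import stream_feature_view
from feast.types import Float32

driver = Entity(name="driver", join_keys=["driver_id"])

driver_stats_stream = KafkaSource(
    name="driver_stats_stream",
    timestamp_field="event_timestamp",
    bootstrap_servers="localhost:9092",
    message_format=AvroFormat(""),  # schema string omitted, as in the tests
    topic="driver_stats",
    batch_source=FileSource(path="driver_stats.parquet", timestamp_field="event_timestamp"),
)


@stream_feature_view(
    entities=[driver],
    ttl=timedelta(days=30),
    schema=[Field(name="conv_rate", dtype=Float32)],
    aggregations=[
        Aggregation(column="conv_rate", function="max", time_window=timedelta(hours=1)),
    ],
    timestamp_field="event_timestamp",  # now required whenever aggregations are set
    mode="spark",
    source=driver_stats_stream,
)
def driver_stats_view(df):
    # Imports stay inside the udf so the dill-serialized body is self-contained.
    import pandas as pd

    df["conv_rate"] = pd.to_numeric(df["conv_rate"])
    return df
```

Registration then goes through the usual `fs.apply([driver, driver_stats_view])`, exactly as the new integration test does.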
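The factory in the new `stream_processor.py` then turns a `ProcessorConfig`, a stream feature view, and a write callback into a concrete processor. The `SparkKafkaProcessor` constructor is truncated in this diff, so treat the bare `ProcessorConfig` below as an assumption — the Spark implementation may require extra fields (for example a `SparkSession`) on a subclassed config. A hedged wiring sketch, assuming the feature store's `write_to_online_store(feature_view_name, df)` entry point that the module docstring refers to:

```python
import pandas as pd

from feast import FeatureStore
from feast.infra.contrib.stream_processor import (
    ProcessorConfig,
    get_stream_processor_object,
)

store = FeatureStore(repo_path=".")
sfv = store.list_stream_feature_views()[0]


def write_to_online(df: pd.DataFrame, input_timestamp: str, output_timestamp: str) -> None:
    # Matches the Callable[[pd.DataFrame, str, str], None] contract; the Spark
    # processor invokes this once per micro-batch with a pandas DataFrame.
    store.write_to_online_store(feature_view_name=sfv.name, df=df)


config = ProcessorConfig(mode="spark", source="kafka")  # assumed sufficient; see note above
processor = get_stream_processor_object(config, sfv, write_to_online)
processor.ingest_stream_feature_view()  # returns after awaitTermination(query_timeout)
```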
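One detail worth flagging in the new `StreamFeatureView.__copy__`: it forwards `sources=self.sources`, but the constructor defined earlier in the same file only accepts a single `source` keyword, so `copy.copy()` on a stream feature view would presumably fail at runtime; the intended call is likely something along the lines of `source=self.stream_source`.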