From 8b3a97a9774fa6a2cefcf64c237a36212a158ab1 Mon Sep 17 00:00:00 2001
From: Achal Shah
Date: Tue, 20 Jul 2021 07:30:39 -0700
Subject: [PATCH] Refactor data source classes to fix import issues (#1723)

* Refactor data source classes to fix import issues

Signed-off-by: Achal Shah

* make format and lint

Signed-off-by: Achal Shah

* remove unneeded __init__ files

Signed-off-by: Achal Shah
---
 sdk/python/feast/__init__.py                  |  10 +-
 sdk/python/feast/data_source.py               |  22 +-
 .../feast/infra/offline_stores/bigquery.py    | 207 +---------------
 .../infra/offline_stores/bigquery_source.py   | 206 ++++++++++++++++
 sdk/python/feast/infra/offline_stores/file.py | 203 +---------------
 .../feast/infra/offline_stores/file_source.py | 205 ++++++++++++++++
 .../feast/infra/offline_stores/redshift.py    | 219 +----------------
 .../infra/offline_stores/redshift_source.py   | 220 ++++++++++++++++++
 sdk/python/usage_tests/test_usage.py          |   2 +-
 9 files changed, 664 insertions(+), 630 deletions(-)
 create mode 100644 sdk/python/feast/infra/offline_stores/bigquery_source.py
 create mode 100644 sdk/python/feast/infra/offline_stores/file_source.py
 create mode 100644 sdk/python/feast/infra/offline_stores/redshift_source.py

diff --git a/sdk/python/feast/__init__.py b/sdk/python/feast/__init__.py
index 43d7aa939b..430fd9f715 100644
--- a/sdk/python/feast/__init__.py
+++ b/sdk/python/feast/__init__.py
@@ -2,9 +2,9 @@
 
 from pkg_resources import DistributionNotFound, get_distribution
 
-from feast.infra.offline_stores.bigquery import BigQuerySource
-from feast.infra.offline_stores.file import FileSource
-from feast.infra.offline_stores.redshift import RedshiftSource
+from feast.infra.offline_stores.bigquery_source import BigQuerySource
+from feast.infra.offline_stores.file_source import FileSource
+from feast.infra.offline_stores.redshift_source import RedshiftSource
 
 from .client import Client
 from .data_source import KafkaSource, KinesisSource, SourceType
@@ -29,12 +29,10 @@
     pass
 
 __all__ = [
-    "BigQuerySource",
     "Client",
     "Entity",
     "KafkaSource",
     "KinesisSource",
-    "RedshiftSource",
     "Feature",
     "FeatureStore",
     "FeatureTable",
@@ -42,5 +40,7 @@
     "RepoConfig",
     "SourceType",
     "ValueType",
+    "BigQuerySource",
     "FileSource",
+    "RedshiftSource",
 ]
diff --git a/sdk/python/feast/data_source.py b/sdk/python/feast/data_source.py
index 6886fa0c26..dcad135d8d 100644
--- a/sdk/python/feast/data_source.py
+++ b/sdk/python/feast/data_source.py
@@ -317,17 +317,17 @@ def from_proto(data_source: DataSourceProto):
             return cls.from_proto(data_source)
 
         if data_source.file_options.file_format and data_source.file_options.file_url:
-            from feast.infra.offline_stores.file import FileSource
+            from feast.infra.offline_stores.file_source import FileSource
 
             data_source_obj = FileSource.from_proto(data_source)
         elif (
             data_source.bigquery_options.table_ref or data_source.bigquery_options.query
         ):
-            from feast.infra.offline_stores.bigquery import BigQuerySource
+            from feast.infra.offline_stores.bigquery_source import BigQuerySource
 
             data_source_obj = BigQuerySource.from_proto(data_source)
         elif data_source.redshift_options.table or data_source.redshift_options.query:
-            from feast.infra.offline_stores.redshift import RedshiftSource
+            from feast.infra.offline_stores.redshift_source import RedshiftSource
 
             data_source_obj = RedshiftSource.from_proto(data_source)
         elif (
@@ -378,6 +378,14 @@ def get_table_column_names_and_types(
 
 
 class KafkaSource(DataSource):
+    def validate(self, config: RepoConfig):
+        pass
+
+    def get_table_column_names_and_types(
+        self, config: RepoConfig
+    ) -> 
Iterable[Tuple[str, str]]: + pass + def __init__( self, event_timestamp_column: str, @@ -463,6 +471,14 @@ def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: class KinesisSource(DataSource): + def validate(self, config: RepoConfig): + pass + + def get_table_column_names_and_types( + self, config: RepoConfig + ) -> Iterable[Tuple[str, str]]: + pass + @staticmethod def from_proto(data_source: DataSourceProto): return KinesisSource( diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py index ab334a347f..16dc8e950c 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery.py +++ b/sdk/python/feast/infra/offline_stores/bigquery.py @@ -2,7 +2,7 @@ import uuid from dataclasses import asdict, dataclass from datetime import date, datetime, timedelta -from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import List, Optional, Set, Union import pandas import pyarrow @@ -12,12 +12,11 @@ from pydantic.typing import Literal from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed -from feast import errors, type_map +from feast import errors from feast.data_source import DataSource from feast.errors import ( BigQueryJobCancelled, BigQueryJobStillRunning, - DataSourceNotFoundException, FeastProviderLoginError, ) from feast.feature_view import FeatureView @@ -26,10 +25,10 @@ DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL, _get_requested_feature_views_to_features_dict, ) -from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto from feast.registry import Registry from feast.repo_config import FeastConfigBaseModel, RepoConfig -from feast.value_type import ValueType + +from .bigquery_source import BigQuerySource try: from google.api_core.exceptions import NotFound @@ -707,201 +706,3 @@ def _get_bigquery_client(project: Optional[str] = None): ) USING ({{featureview.name}}__entity_row_unique_id) {% endfor %} """ - - -class BigQuerySource(DataSource): - def __init__( - self, - event_timestamp_column: Optional[str] = "", - table_ref: Optional[str] = None, - created_timestamp_column: Optional[str] = "", - field_mapping: Optional[Dict[str, str]] = None, - date_partition_column: Optional[str] = "", - query: Optional[str] = None, - ): - self._bigquery_options = BigQueryOptions(table_ref=table_ref, query=query) - - super().__init__( - event_timestamp_column, - created_timestamp_column, - field_mapping, - date_partition_column, - ) - - def __eq__(self, other): - if not isinstance(other, BigQuerySource): - raise TypeError( - "Comparisons should only involve BigQuerySource class objects." 
- ) - - return ( - self.bigquery_options.table_ref == other.bigquery_options.table_ref - and self.bigquery_options.query == other.bigquery_options.query - and self.event_timestamp_column == other.event_timestamp_column - and self.created_timestamp_column == other.created_timestamp_column - and self.field_mapping == other.field_mapping - ) - - @property - def table_ref(self): - return self._bigquery_options.table_ref - - @property - def query(self): - return self._bigquery_options.query - - @property - def bigquery_options(self): - """ - Returns the bigquery options of this data source - """ - return self._bigquery_options - - @bigquery_options.setter - def bigquery_options(self, bigquery_options): - """ - Sets the bigquery options of this data source - """ - self._bigquery_options = bigquery_options - - @staticmethod - def from_proto(data_source: DataSourceProto): - - assert data_source.HasField("bigquery_options") - - return BigQuerySource( - field_mapping=dict(data_source.field_mapping), - table_ref=data_source.bigquery_options.table_ref, - event_timestamp_column=data_source.event_timestamp_column, - created_timestamp_column=data_source.created_timestamp_column, - date_partition_column=data_source.date_partition_column, - query=data_source.bigquery_options.query, - ) - - def to_proto(self) -> DataSourceProto: - data_source_proto = DataSourceProto( - type=DataSourceProto.BATCH_BIGQUERY, - field_mapping=self.field_mapping, - bigquery_options=self.bigquery_options.to_proto(), - ) - - data_source_proto.event_timestamp_column = self.event_timestamp_column - data_source_proto.created_timestamp_column = self.created_timestamp_column - data_source_proto.date_partition_column = self.date_partition_column - - return data_source_proto - - def validate(self, config: RepoConfig): - if not self.query: - from google.api_core.exceptions import NotFound - from google.cloud import bigquery - - client = bigquery.Client() - try: - client.get_table(self.table_ref) - except NotFound: - raise DataSourceNotFoundException(self.table_ref) - - def get_table_query_string(self) -> str: - """Returns a string that can directly be used to reference this table in SQL""" - if self.table_ref: - return f"`{self.table_ref}`" - else: - return f"({self.query})" - - @staticmethod - def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: - return type_map.bq_to_feast_value_type - - def get_table_column_names_and_types( - self, config: RepoConfig - ) -> Iterable[Tuple[str, str]]: - from google.cloud import bigquery - - client = bigquery.Client() - if self.table_ref is not None: - table_schema = client.get_table(self.table_ref).schema - if not isinstance(table_schema[0], bigquery.schema.SchemaField): - raise TypeError("Could not parse BigQuery table schema.") - - name_type_pairs = [(field.name, field.field_type) for field in table_schema] - else: - bq_columns_query = f"SELECT * FROM ({self.query}) LIMIT 1" - queryRes = client.query(bq_columns_query).result() - name_type_pairs = [ - (schema_field.name, schema_field.field_type) - for schema_field in queryRes.schema - ] - - return name_type_pairs - - -class BigQueryOptions: - """ - DataSource BigQuery options used to source features from BigQuery query - """ - - def __init__(self, table_ref: Optional[str], query: Optional[str]): - self._table_ref = table_ref - self._query = query - - @property - def query(self): - """ - Returns the BigQuery SQL query referenced by this source - """ - return self._query - - @query.setter - def query(self, query): - """ - Sets the 
BigQuery SQL query referenced by this source - """ - self._query = query - - @property - def table_ref(self): - """ - Returns the table ref of this BQ table - """ - return self._table_ref - - @table_ref.setter - def table_ref(self, table_ref): - """ - Sets the table ref of this BQ table - """ - self._table_ref = table_ref - - @classmethod - def from_proto(cls, bigquery_options_proto: DataSourceProto.BigQueryOptions): - """ - Creates a BigQueryOptions from a protobuf representation of a BigQuery option - - Args: - bigquery_options_proto: A protobuf representation of a DataSource - - Returns: - Returns a BigQueryOptions object based on the bigquery_options protobuf - """ - - bigquery_options = cls( - table_ref=bigquery_options_proto.table_ref, - query=bigquery_options_proto.query, - ) - - return bigquery_options - - def to_proto(self) -> DataSourceProto.BigQueryOptions: - """ - Converts an BigQueryOptionsProto object to its protobuf representation. - - Returns: - BigQueryOptionsProto protobuf - """ - - bigquery_options_proto = DataSourceProto.BigQueryOptions( - table_ref=self.table_ref, query=self.query, - ) - - return bigquery_options_proto diff --git a/sdk/python/feast/infra/offline_stores/bigquery_source.py b/sdk/python/feast/infra/offline_stores/bigquery_source.py new file mode 100644 index 0000000000..a5c1afa3e0 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/bigquery_source.py @@ -0,0 +1,206 @@ +from typing import Callable, Dict, Iterable, Optional, Tuple + +from feast import type_map +from feast.data_source import DataSource +from feast.errors import DataSourceNotFoundException +from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto +from feast.repo_config import RepoConfig +from feast.value_type import ValueType + + +class BigQuerySource(DataSource): + def __init__( + self, + event_timestamp_column: Optional[str] = "", + table_ref: Optional[str] = None, + created_timestamp_column: Optional[str] = "", + field_mapping: Optional[Dict[str, str]] = None, + date_partition_column: Optional[str] = "", + query: Optional[str] = None, + ): + self._bigquery_options = BigQueryOptions(table_ref=table_ref, query=query) + + super().__init__( + event_timestamp_column, + created_timestamp_column, + field_mapping, + date_partition_column, + ) + + def __eq__(self, other): + if not isinstance(other, BigQuerySource): + raise TypeError( + "Comparisons should only involve BigQuerySource class objects." 
+ ) + + return ( + self.bigquery_options.table_ref == other.bigquery_options.table_ref + and self.bigquery_options.query == other.bigquery_options.query + and self.event_timestamp_column == other.event_timestamp_column + and self.created_timestamp_column == other.created_timestamp_column + and self.field_mapping == other.field_mapping + ) + + @property + def table_ref(self): + return self._bigquery_options.table_ref + + @property + def query(self): + return self._bigquery_options.query + + @property + def bigquery_options(self): + """ + Returns the bigquery options of this data source + """ + return self._bigquery_options + + @bigquery_options.setter + def bigquery_options(self, bigquery_options): + """ + Sets the bigquery options of this data source + """ + self._bigquery_options = bigquery_options + + @staticmethod + def from_proto(data_source: DataSourceProto): + + assert data_source.HasField("bigquery_options") + + return BigQuerySource( + field_mapping=dict(data_source.field_mapping), + table_ref=data_source.bigquery_options.table_ref, + event_timestamp_column=data_source.event_timestamp_column, + created_timestamp_column=data_source.created_timestamp_column, + date_partition_column=data_source.date_partition_column, + query=data_source.bigquery_options.query, + ) + + def to_proto(self) -> DataSourceProto: + data_source_proto = DataSourceProto( + type=DataSourceProto.BATCH_BIGQUERY, + field_mapping=self.field_mapping, + bigquery_options=self.bigquery_options.to_proto(), + ) + + data_source_proto.event_timestamp_column = self.event_timestamp_column + data_source_proto.created_timestamp_column = self.created_timestamp_column + data_source_proto.date_partition_column = self.date_partition_column + + return data_source_proto + + def validate(self, config: RepoConfig): + if not self.query: + from google.api_core.exceptions import NotFound + from google.cloud import bigquery + + client = bigquery.Client() + try: + client.get_table(self.table_ref) + except NotFound: + raise DataSourceNotFoundException(self.table_ref) + + def get_table_query_string(self) -> str: + """Returns a string that can directly be used to reference this table in SQL""" + if self.table_ref: + return f"`{self.table_ref}`" + else: + return f"({self.query})" + + @staticmethod + def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: + return type_map.bq_to_feast_value_type + + def get_table_column_names_and_types( + self, config: RepoConfig + ) -> Iterable[Tuple[str, str]]: + from google.cloud import bigquery + + client = bigquery.Client() + if self.table_ref is not None: + table_schema = client.get_table(self.table_ref).schema + if not isinstance(table_schema[0], bigquery.schema.SchemaField): + raise TypeError("Could not parse BigQuery table schema.") + + name_type_pairs = [(field.name, field.field_type) for field in table_schema] + else: + bq_columns_query = f"SELECT * FROM ({self.query}) LIMIT 1" + queryRes = client.query(bq_columns_query).result() + name_type_pairs = [ + (schema_field.name, schema_field.field_type) + for schema_field in queryRes.schema + ] + + return name_type_pairs + + +class BigQueryOptions: + """ + DataSource BigQuery options used to source features from BigQuery query + """ + + def __init__(self, table_ref: Optional[str], query: Optional[str]): + self._table_ref = table_ref + self._query = query + + @property + def query(self): + """ + Returns the BigQuery SQL query referenced by this source + """ + return self._query + + @query.setter + def query(self, query): + """ + Sets the 
BigQuery SQL query referenced by this source + """ + self._query = query + + @property + def table_ref(self): + """ + Returns the table ref of this BQ table + """ + return self._table_ref + + @table_ref.setter + def table_ref(self, table_ref): + """ + Sets the table ref of this BQ table + """ + self._table_ref = table_ref + + @classmethod + def from_proto(cls, bigquery_options_proto: DataSourceProto.BigQueryOptions): + """ + Creates a BigQueryOptions from a protobuf representation of a BigQuery option + + Args: + bigquery_options_proto: A protobuf representation of a DataSource + + Returns: + Returns a BigQueryOptions object based on the bigquery_options protobuf + """ + + bigquery_options = cls( + table_ref=bigquery_options_proto.table_ref, + query=bigquery_options_proto.query, + ) + + return bigquery_options + + def to_proto(self) -> DataSourceProto.BigQueryOptions: + """ + Converts an BigQueryOptionsProto object to its protobuf representation. + + Returns: + BigQueryOptionsProto protobuf + """ + + bigquery_options_proto = DataSourceProto.BigQueryOptions( + table_ref=self.table_ref, query=self.query, + ) + + return bigquery_options_proto diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py index e6f95ee162..74153acaee 100644 --- a/sdk/python/feast/infra/offline_stores/file.py +++ b/sdk/python/feast/infra/offline_stores/file.py @@ -1,14 +1,12 @@ from datetime import datetime -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Union import pandas as pd import pyarrow import pytz -from pyarrow.parquet import ParquetFile from pydantic.typing import Literal -from feast import type_map -from feast.data_format import FileFormat +from feast import FileSource from feast.data_source import DataSource from feast.errors import FeastJoinKeysDuringMaterialization from feast.feature_view import FeatureView @@ -18,10 +16,8 @@ _get_requested_feature_views_to_features_dict, _run_field_mapping, ) -from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto from feast.registry import Registry from feast.repo_config import FeastConfigBaseModel, RepoConfig -from feast.value_type import ValueType class FileOfflineStoreConfig(FeastConfigBaseModel): @@ -270,198 +266,3 @@ def evaluate_offline_job(): return last_values_df[columns_to_extract] return FileRetrievalJob(evaluation_function=evaluate_offline_job) - - -class FileSource(DataSource): - def __init__( - self, - event_timestamp_column: Optional[str] = "", - file_url: Optional[str] = None, - path: Optional[str] = None, - file_format: FileFormat = None, - created_timestamp_column: Optional[str] = "", - field_mapping: Optional[Dict[str, str]] = None, - date_partition_column: Optional[str] = "", - ): - """Create a FileSource from a file containing feature data. Only Parquet format supported. - - Args: - - path: File path to file containing feature data. Must contain an event_timestamp column, entity columns and - feature columns. - event_timestamp_column: Event timestamp column used for point in time joins of feature values. - created_timestamp_column (optional): Timestamp column when row was created, used for deduplicating rows. - file_url: [Deprecated] Please see path - file_format (optional): Explicitly set the file format. Allows Feast to bypass inferring the file format. - field_mapping: A dictionary mapping of column names in this data source to feature names in a feature table - or view. 
Only used for feature columns, not entities or timestamp columns. - - Examples: - >>> FileSource(path="/data/my_features.parquet", event_timestamp_column="datetime") - """ - if path is None and file_url is None: - raise ValueError( - 'No "path" argument provided. Please set "path" to the location of your file source.' - ) - if file_url: - from warnings import warn - - warn( - 'Argument "file_url" is being deprecated. Please use the "path" argument.' - ) - else: - file_url = path - - self._file_options = FileOptions(file_format=file_format, file_url=file_url) - - super().__init__( - event_timestamp_column, - created_timestamp_column, - field_mapping, - date_partition_column, - ) - - def __eq__(self, other): - if not isinstance(other, FileSource): - raise TypeError("Comparisons should only involve FileSource class objects.") - - return ( - self.file_options.file_url == other.file_options.file_url - and self.file_options.file_format == other.file_options.file_format - and self.event_timestamp_column == other.event_timestamp_column - and self.created_timestamp_column == other.created_timestamp_column - and self.field_mapping == other.field_mapping - ) - - @property - def file_options(self): - """ - Returns the file options of this data source - """ - return self._file_options - - @file_options.setter - def file_options(self, file_options): - """ - Sets the file options of this data source - """ - self._file_options = file_options - - @property - def path(self): - """ - Returns the file path of this feature data source - """ - return self._file_options.file_url - - @staticmethod - def from_proto(data_source: DataSourceProto): - return FileSource( - field_mapping=dict(data_source.field_mapping), - file_format=FileFormat.from_proto(data_source.file_options.file_format), - path=data_source.file_options.file_url, - event_timestamp_column=data_source.event_timestamp_column, - created_timestamp_column=data_source.created_timestamp_column, - date_partition_column=data_source.date_partition_column, - ) - - def to_proto(self) -> DataSourceProto: - data_source_proto = DataSourceProto( - type=DataSourceProto.BATCH_FILE, - field_mapping=self.field_mapping, - file_options=self.file_options.to_proto(), - ) - - data_source_proto.event_timestamp_column = self.event_timestamp_column - data_source_proto.created_timestamp_column = self.created_timestamp_column - data_source_proto.date_partition_column = self.date_partition_column - - return data_source_proto - - def validate(self, config: RepoConfig): - # TODO: validate a FileSource - pass - - @staticmethod - def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: - return type_map.pa_to_feast_value_type - - def get_table_column_names_and_types( - self, config: RepoConfig - ) -> Iterable[Tuple[str, str]]: - schema = ParquetFile(self.path).schema_arrow - return zip(schema.names, map(str, schema.types)) - - -class FileOptions: - """ - DataSource File options used to source features from a file - """ - - def __init__( - self, file_format: Optional[FileFormat], file_url: Optional[str], - ): - self._file_format = file_format - self._file_url = file_url - - @property - def file_format(self): - """ - Returns the file format of this file - """ - return self._file_format - - @file_format.setter - def file_format(self, file_format): - """ - Sets the file format of this file - """ - self._file_format = file_format - - @property - def file_url(self): - """ - Returns the file url of this file - """ - return self._file_url - - @file_url.setter - def 
file_url(self, file_url): - """ - Sets the file url of this file - """ - self._file_url = file_url - - @classmethod - def from_proto(cls, file_options_proto: DataSourceProto.FileOptions): - """ - Creates a FileOptions from a protobuf representation of a file option - - args: - file_options_proto: a protobuf representation of a datasource - - Returns: - Returns a FileOptions object based on the file_options protobuf - """ - file_options = cls( - file_format=FileFormat.from_proto(file_options_proto.file_format), - file_url=file_options_proto.file_url, - ) - return file_options - - def to_proto(self) -> DataSourceProto.FileOptions: - """ - Converts an FileOptionsProto object to its protobuf representation. - - Returns: - FileOptionsProto protobuf - """ - - file_options_proto = DataSourceProto.FileOptions( - file_format=( - None if self.file_format is None else self.file_format.to_proto() - ), - file_url=self.file_url, - ) - - return file_options_proto diff --git a/sdk/python/feast/infra/offline_stores/file_source.py b/sdk/python/feast/infra/offline_stores/file_source.py new file mode 100644 index 0000000000..cf20c78a8d --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/file_source.py @@ -0,0 +1,205 @@ +from typing import Callable, Dict, Iterable, Optional, Tuple + +from pyarrow.parquet import ParquetFile + +from feast import type_map +from feast.data_format import FileFormat +from feast.data_source import DataSource +from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto +from feast.repo_config import RepoConfig +from feast.value_type import ValueType + + +class FileSource(DataSource): + def __init__( + self, + event_timestamp_column: Optional[str] = "", + file_url: Optional[str] = None, + path: Optional[str] = None, + file_format: FileFormat = None, + created_timestamp_column: Optional[str] = "", + field_mapping: Optional[Dict[str, str]] = None, + date_partition_column: Optional[str] = "", + ): + """Create a FileSource from a file containing feature data. Only Parquet format supported. + + Args: + + path: File path to file containing feature data. Must contain an event_timestamp column, entity columns and + feature columns. + event_timestamp_column: Event timestamp column used for point in time joins of feature values. + created_timestamp_column (optional): Timestamp column when row was created, used for deduplicating rows. + file_url: [Deprecated] Please see path + file_format (optional): Explicitly set the file format. Allows Feast to bypass inferring the file format. + field_mapping: A dictionary mapping of column names in this data source to feature names in a feature table + or view. Only used for feature columns, not entities or timestamp columns. + + Examples: + >>> FileSource(path="/data/my_features.parquet", event_timestamp_column="datetime") + """ + if path is None and file_url is None: + raise ValueError( + 'No "path" argument provided. Please set "path" to the location of your file source.' + ) + if file_url: + from warnings import warn + + warn( + 'Argument "file_url" is being deprecated. Please use the "path" argument.' 
+ ) + else: + file_url = path + + self._file_options = FileOptions(file_format=file_format, file_url=file_url) + + super().__init__( + event_timestamp_column, + created_timestamp_column, + field_mapping, + date_partition_column, + ) + + def __eq__(self, other): + if not isinstance(other, FileSource): + raise TypeError("Comparisons should only involve FileSource class objects.") + + return ( + self.file_options.file_url == other.file_options.file_url + and self.file_options.file_format == other.file_options.file_format + and self.event_timestamp_column == other.event_timestamp_column + and self.created_timestamp_column == other.created_timestamp_column + and self.field_mapping == other.field_mapping + ) + + @property + def file_options(self): + """ + Returns the file options of this data source + """ + return self._file_options + + @file_options.setter + def file_options(self, file_options): + """ + Sets the file options of this data source + """ + self._file_options = file_options + + @property + def path(self): + """ + Returns the file path of this feature data source + """ + return self._file_options.file_url + + @staticmethod + def from_proto(data_source: DataSourceProto): + return FileSource( + field_mapping=dict(data_source.field_mapping), + file_format=FileFormat.from_proto(data_source.file_options.file_format), + path=data_source.file_options.file_url, + event_timestamp_column=data_source.event_timestamp_column, + created_timestamp_column=data_source.created_timestamp_column, + date_partition_column=data_source.date_partition_column, + ) + + def to_proto(self) -> DataSourceProto: + data_source_proto = DataSourceProto( + type=DataSourceProto.BATCH_FILE, + field_mapping=self.field_mapping, + file_options=self.file_options.to_proto(), + ) + + data_source_proto.event_timestamp_column = self.event_timestamp_column + data_source_proto.created_timestamp_column = self.created_timestamp_column + data_source_proto.date_partition_column = self.date_partition_column + + return data_source_proto + + def validate(self, config: RepoConfig): + # TODO: validate a FileSource + pass + + @staticmethod + def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: + return type_map.pa_to_feast_value_type + + def get_table_column_names_and_types( + self, config: RepoConfig + ) -> Iterable[Tuple[str, str]]: + schema = ParquetFile(self.path).schema_arrow + return zip(schema.names, map(str, schema.types)) + + +class FileOptions: + """ + DataSource File options used to source features from a file + """ + + def __init__( + self, file_format: Optional[FileFormat], file_url: Optional[str], + ): + self._file_format = file_format + self._file_url = file_url + + @property + def file_format(self): + """ + Returns the file format of this file + """ + return self._file_format + + @file_format.setter + def file_format(self, file_format): + """ + Sets the file format of this file + """ + self._file_format = file_format + + @property + def file_url(self): + """ + Returns the file url of this file + """ + return self._file_url + + @file_url.setter + def file_url(self, file_url): + """ + Sets the file url of this file + """ + self._file_url = file_url + + @classmethod + def from_proto(cls, file_options_proto: DataSourceProto.FileOptions): + """ + Creates a FileOptions from a protobuf representation of a file option + + args: + file_options_proto: a protobuf representation of a datasource + + Returns: + Returns a FileOptions object based on the file_options protobuf + """ + file_options = cls( + 
file_format=FileFormat.from_proto(file_options_proto.file_format), + file_url=file_options_proto.file_url, + ) + return file_options + + def to_proto(self) -> DataSourceProto.FileOptions: + """ + Converts an FileOptionsProto object to its protobuf representation. + + Returns: + FileOptionsProto protobuf + """ + + file_options_proto = DataSourceProto.FileOptions( + file_format=( + None if self.file_format is None else self.file_format.to_proto() + ), + file_url=self.file_url, + ) + + return file_options_proto diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py index ae085c28b9..9204ff00be 100644 --- a/sdk/python/feast/infra/offline_stores/redshift.py +++ b/sdk/python/feast/infra/offline_stores/redshift.py @@ -1,22 +1,19 @@ import uuid from datetime import datetime -from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import List, Optional, Union import pandas as pd import pyarrow as pa from pydantic import StrictStr from pydantic.typing import Literal -from feast import type_map +from feast import RedshiftSource from feast.data_source import DataSource -from feast.errors import DataSourceNotFoundException, RedshiftCredentialsError from feast.feature_view import FeatureView from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob from feast.infra.utils import aws_utils -from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto from feast.registry import Registry from feast.repo_config import FeastConfigBaseModel, RepoConfig -from feast.value_type import ValueType class RedshiftOfflineStoreConfig(FeastConfigBaseModel): @@ -175,215 +172,3 @@ def to_redshift(self, table_name: str) -> None: self._config.offline_store.user, f'CREATE TABLE "{table_name}" AS ({self.query})', ) - - -class RedshiftSource(DataSource): - def __init__( - self, - event_timestamp_column: Optional[str] = "", - table: Optional[str] = None, - created_timestamp_column: Optional[str] = "", - field_mapping: Optional[Dict[str, str]] = None, - date_partition_column: Optional[str] = "", - query: Optional[str] = None, - ): - super().__init__( - event_timestamp_column, - created_timestamp_column, - field_mapping, - date_partition_column, - ) - - self._redshift_options = RedshiftOptions(table=table, query=query) - - @staticmethod - def from_proto(data_source: DataSourceProto): - return RedshiftSource( - field_mapping=dict(data_source.field_mapping), - table=data_source.redshift_options.table, - event_timestamp_column=data_source.event_timestamp_column, - created_timestamp_column=data_source.created_timestamp_column, - date_partition_column=data_source.date_partition_column, - query=data_source.redshift_options.query, - ) - - def __eq__(self, other): - if not isinstance(other, RedshiftSource): - raise TypeError( - "Comparisons should only involve RedshiftSource class objects." 
- ) - - return ( - self.redshift_options.table == other.redshift_options.table - and self.redshift_options.query == other.redshift_options.query - and self.event_timestamp_column == other.event_timestamp_column - and self.created_timestamp_column == other.created_timestamp_column - and self.field_mapping == other.field_mapping - ) - - @property - def table(self): - return self._redshift_options.table - - @property - def query(self): - return self._redshift_options.query - - @property - def redshift_options(self): - """ - Returns the Redshift options of this data source - """ - return self._redshift_options - - @redshift_options.setter - def redshift_options(self, _redshift_options): - """ - Sets the Redshift options of this data source - """ - self._redshift_options = _redshift_options - - def to_proto(self) -> DataSourceProto: - data_source_proto = DataSourceProto( - type=DataSourceProto.BATCH_REDSHIFT, - field_mapping=self.field_mapping, - redshift_options=self.redshift_options.to_proto(), - ) - - data_source_proto.event_timestamp_column = self.event_timestamp_column - data_source_proto.created_timestamp_column = self.created_timestamp_column - data_source_proto.date_partition_column = self.date_partition_column - - return data_source_proto - - def validate(self, config: RepoConfig): - # As long as the query gets successfully executed, or the table exists, - # the data source is validated. We don't need the results though. - # TODO: uncomment this - # self.get_table_column_names_and_types(config) - print("Validate", self.get_table_column_names_and_types(config)) - - def get_table_query_string(self) -> str: - """Returns a string that can directly be used to reference this table in SQL""" - if self.table: - return f'"{self.table}"' - else: - return f"({self.query})" - - @staticmethod - def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: - return type_map.redshift_to_feast_value_type - - def get_table_column_names_and_types( - self, config: RepoConfig - ) -> Iterable[Tuple[str, str]]: - from botocore.exceptions import ClientError - - from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig - from feast.infra.utils import aws_utils - - assert isinstance(config.offline_store, RedshiftOfflineStoreConfig) - - client = aws_utils.get_redshift_data_client(config.offline_store.region) - - if self.table is not None: - try: - table = client.describe_table( - ClusterIdentifier=config.offline_store.cluster_id, - Database=config.offline_store.database, - DbUser=config.offline_store.user, - Table=self.table, - ) - except ClientError as e: - if e.response["Error"]["Code"] == "ValidationException": - raise RedshiftCredentialsError() from e - raise - - # The API returns valid JSON with empty column list when the table doesn't exist - if len(table["ColumnList"]) == 0: - raise DataSourceNotFoundException(self.table) - - columns = table["ColumnList"] - else: - statement_id = aws_utils.execute_redshift_statement( - client, - config.offline_store.cluster_id, - config.offline_store.database, - config.offline_store.user, - f"SELECT * FROM ({self.query}) LIMIT 1", - ) - columns = aws_utils.get_redshift_statement_result(client, statement_id)[ - "ColumnMetadata" - ] - - return [(column["name"], column["typeName"].upper()) for column in columns] - - -class RedshiftOptions: - """ - DataSource Redshift options used to source features from Redshift query - """ - - def __init__(self, table: Optional[str], query: Optional[str]): - self._table = table - self._query = query - - 
@property - def query(self): - """ - Returns the Redshift SQL query referenced by this source - """ - return self._query - - @query.setter - def query(self, query): - """ - Sets the Redshift SQL query referenced by this source - """ - self._query = query - - @property - def table(self): - """ - Returns the table name of this Redshift table - """ - return self._table - - @table.setter - def table(self, table_name): - """ - Sets the table ref of this Redshift table - """ - self._table = table_name - - @classmethod - def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions): - """ - Creates a RedshiftOptions from a protobuf representation of a Redshift option - - Args: - redshift_options_proto: A protobuf representation of a DataSource - - Returns: - Returns a RedshiftOptions object based on the redshift_options protobuf - """ - - redshift_options = cls( - table=redshift_options_proto.table, query=redshift_options_proto.query, - ) - - return redshift_options - - def to_proto(self) -> DataSourceProto.RedshiftOptions: - """ - Converts an RedshiftOptionsProto object to its protobuf representation. - - Returns: - RedshiftOptionsProto protobuf - """ - - redshift_options_proto = DataSourceProto.RedshiftOptions( - table=self.table, query=self.query, - ) - - return redshift_options_proto diff --git a/sdk/python/feast/infra/offline_stores/redshift_source.py b/sdk/python/feast/infra/offline_stores/redshift_source.py new file mode 100644 index 0000000000..81fe35fc18 --- /dev/null +++ b/sdk/python/feast/infra/offline_stores/redshift_source.py @@ -0,0 +1,220 @@ +from typing import Callable, Dict, Iterable, Optional, Tuple + +from feast import type_map +from feast.data_source import DataSource +from feast.errors import DataSourceNotFoundException, RedshiftCredentialsError +from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto +from feast.repo_config import RepoConfig +from feast.value_type import ValueType + + +class RedshiftSource(DataSource): + def __init__( + self, + event_timestamp_column: Optional[str] = "", + table: Optional[str] = None, + created_timestamp_column: Optional[str] = "", + field_mapping: Optional[Dict[str, str]] = None, + date_partition_column: Optional[str] = "", + query: Optional[str] = None, + ): + super().__init__( + event_timestamp_column, + created_timestamp_column, + field_mapping, + date_partition_column, + ) + + self._redshift_options = RedshiftOptions(table=table, query=query) + + @staticmethod + def from_proto(data_source: DataSourceProto): + return RedshiftSource( + field_mapping=dict(data_source.field_mapping), + table=data_source.redshift_options.table, + event_timestamp_column=data_source.event_timestamp_column, + created_timestamp_column=data_source.created_timestamp_column, + date_partition_column=data_source.date_partition_column, + query=data_source.redshift_options.query, + ) + + def __eq__(self, other): + if not isinstance(other, RedshiftSource): + raise TypeError( + "Comparisons should only involve RedshiftSource class objects." 
+ ) + + return ( + self.redshift_options.table == other.redshift_options.table + and self.redshift_options.query == other.redshift_options.query + and self.event_timestamp_column == other.event_timestamp_column + and self.created_timestamp_column == other.created_timestamp_column + and self.field_mapping == other.field_mapping + ) + + @property + def table(self): + return self._redshift_options.table + + @property + def query(self): + return self._redshift_options.query + + @property + def redshift_options(self): + """ + Returns the Redshift options of this data source + """ + return self._redshift_options + + @redshift_options.setter + def redshift_options(self, _redshift_options): + """ + Sets the Redshift options of this data source + """ + self._redshift_options = _redshift_options + + def to_proto(self) -> DataSourceProto: + data_source_proto = DataSourceProto( + type=DataSourceProto.BATCH_REDSHIFT, + field_mapping=self.field_mapping, + redshift_options=self.redshift_options.to_proto(), + ) + + data_source_proto.event_timestamp_column = self.event_timestamp_column + data_source_proto.created_timestamp_column = self.created_timestamp_column + data_source_proto.date_partition_column = self.date_partition_column + + return data_source_proto + + def validate(self, config: RepoConfig): + # As long as the query gets successfully executed, or the table exists, + # the data source is validated. We don't need the results though. + # TODO: uncomment this + # self.get_table_column_names_and_types(config) + print("Validate", self.get_table_column_names_and_types(config)) + + def get_table_query_string(self) -> str: + """Returns a string that can directly be used to reference this table in SQL""" + if self.table: + return f'"{self.table}"' + else: + return f"({self.query})" + + @staticmethod + def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]: + return type_map.redshift_to_feast_value_type + + def get_table_column_names_and_types( + self, config: RepoConfig + ) -> Iterable[Tuple[str, str]]: + from botocore.exceptions import ClientError + + from feast.infra.offline_stores.redshift import RedshiftOfflineStoreConfig + from feast.infra.utils import aws_utils + + assert isinstance(config.offline_store, RedshiftOfflineStoreConfig) + + client = aws_utils.get_redshift_data_client(config.offline_store.region) + + if self.table is not None: + try: + table = client.describe_table( + ClusterIdentifier=config.offline_store.cluster_id, + Database=config.offline_store.database, + DbUser=config.offline_store.user, + Table=self.table, + ) + except ClientError as e: + if e.response["Error"]["Code"] == "ValidationException": + raise RedshiftCredentialsError() from e + raise + + # The API returns valid JSON with empty column list when the table doesn't exist + if len(table["ColumnList"]) == 0: + raise DataSourceNotFoundException(self.table) + + columns = table["ColumnList"] + else: + statement_id = aws_utils.execute_redshift_statement( + client, + config.offline_store.cluster_id, + config.offline_store.database, + config.offline_store.user, + f"SELECT * FROM ({self.query}) LIMIT 1", + ) + columns = aws_utils.get_redshift_statement_result(client, statement_id)[ + "ColumnMetadata" + ] + + return [(column["name"], column["typeName"].upper()) for column in columns] + + +class RedshiftOptions: + """ + DataSource Redshift options used to source features from Redshift query + """ + + def __init__(self, table: Optional[str], query: Optional[str]): + self._table = table + self._query = query + + 
@property
+    def query(self):
+        """
+        Returns the Redshift SQL query referenced by this source
+        """
+        return self._query
+
+    @query.setter
+    def query(self, query):
+        """
+        Sets the Redshift SQL query referenced by this source
+        """
+        self._query = query
+
+    @property
+    def table(self):
+        """
+        Returns the table name of this Redshift table
+        """
+        return self._table
+
+    @table.setter
+    def table(self, table_name):
+        """
+        Sets the table name of this Redshift table
+        """
+        self._table = table_name
+
+    @classmethod
+    def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
+        """
+        Creates a RedshiftOptions from a protobuf representation of a Redshift option
+
+        Args:
+            redshift_options_proto: A protobuf representation of a DataSource
+
+        Returns:
+            Returns a RedshiftOptions object based on the redshift_options protobuf
+        """
+
+        redshift_options = cls(
+            table=redshift_options_proto.table, query=redshift_options_proto.query,
+        )
+
+        return redshift_options
+
+    def to_proto(self) -> DataSourceProto.RedshiftOptions:
+        """
+        Converts a RedshiftOptions object to its protobuf representation.
+
+        Returns:
+            RedshiftOptionsProto protobuf
+        """
+
+        redshift_options_proto = DataSourceProto.RedshiftOptions(
+            table=self.table, query=self.query,
+        )
+
+        return redshift_options_proto
diff --git a/sdk/python/usage_tests/test_usage.py b/sdk/python/usage_tests/test_usage.py
index e6b7760fab..3e571a2120 100644
--- a/sdk/python/usage_tests/test_usage.py
+++ b/sdk/python/usage_tests/test_usage.py
@@ -187,7 +187,7 @@ def test_exception_usage_off():
     assert rows.total_rows == 0
 
 
-@retry(wait=wait_exponential(multiplier=1, min=1, max=10), stop=stop_after_attempt(5))
+@retry(wait=wait_exponential(multiplier=1, min=1, max=10), stop=stop_after_attempt(7))
 def ensure_bigquery_usage_id_with_retry(usage_id):
     rows = read_bigquery_usage_id(usage_id)
     if rows.total_rows != 1:
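
Usage note: this refactor only moves BigQuerySource, FileSource, and RedshiftSource into
dedicated modules (bigquery_source.py, file_source.py, redshift_source.py) and re-exports
them from feast/__init__.py, so user code that imports from the package root keeps working.
A minimal sketch of the stable import path, using the constructor arguments shown in this
patch (the file path, table names, and timestamp column below are hypothetical, not taken
from the patch):

    from feast import BigQuerySource, FileSource, RedshiftSource

    # File-based source: "path" supersedes the deprecated "file_url" argument.
    file_source = FileSource(
        path="/data/driver_stats.parquet",  # hypothetical file
        event_timestamp_column="event_timestamp",
    )

    # BigQuery source: set either table_ref or query.
    bq_source = BigQuerySource(
        table_ref="my_project.my_dataset.driver_stats",  # hypothetical table
        event_timestamp_column="event_timestamp",
    )

    # Redshift source: set either table or query.
    redshift_source = RedshiftSource(
        table="driver_stats",  # hypothetical table
        event_timestamp_column="event_timestamp",
    )

    # Deep imports still work but now point at the new modules, e.g.:
    # from feast.infra.offline_stores.bigquery_source import BigQuerySource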