Skip to content

Commit

Permalink
Refactor data source classes to fix import issues (#1723)
Browse files Browse the repository at this point in the history
* Refactor data source classes to fix import issues

Signed-off-by: Achal Shah <achals@gmail.com>

* make format and lint

Signed-off-by: Achal Shah <achals@gmail.com>

* remove unneeded __init__ files

Signed-off-by: Achal Shah <achals@gmail.com>
  • Loading branch information
achals authored Jul 20, 2021
1 parent 8cfe914 commit 8b3a97a
Show file tree
Hide file tree
Showing 9 changed files with 664 additions and 630 deletions.
10 changes: 5 additions & 5 deletions sdk/python/feast/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from pkg_resources import DistributionNotFound, get_distribution

from feast.infra.offline_stores.bigquery import BigQuerySource
from feast.infra.offline_stores.file import FileSource
from feast.infra.offline_stores.redshift import RedshiftSource
from feast.infra.offline_stores.bigquery_source import BigQuerySource
from feast.infra.offline_stores.file_source import FileSource
from feast.infra.offline_stores.redshift_source import RedshiftSource

from .client import Client
from .data_source import KafkaSource, KinesisSource, SourceType
Expand All @@ -29,18 +29,18 @@
pass

__all__ = [
"BigQuerySource",
"Client",
"Entity",
"KafkaSource",
"KinesisSource",
"RedshiftSource",
"Feature",
"FeatureStore",
"FeatureTable",
"FeatureView",
"RepoConfig",
"SourceType",
"ValueType",
"BigQuerySource",
"FileSource",
"RedshiftSource",
]
22 changes: 19 additions & 3 deletions sdk/python/feast/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,17 +317,17 @@ def from_proto(data_source: DataSourceProto):
return cls.from_proto(data_source)

if data_source.file_options.file_format and data_source.file_options.file_url:
from feast.infra.offline_stores.file import FileSource
from feast.infra.offline_stores.file_source import FileSource

data_source_obj = FileSource.from_proto(data_source)
elif (
data_source.bigquery_options.table_ref or data_source.bigquery_options.query
):
from feast.infra.offline_stores.bigquery import BigQuerySource
from feast.infra.offline_stores.bigquery_source import BigQuerySource

data_source_obj = BigQuerySource.from_proto(data_source)
elif data_source.redshift_options.table or data_source.redshift_options.query:
from feast.infra.offline_stores.redshift import RedshiftSource
from feast.infra.offline_stores.redshift_source import RedshiftSource

data_source_obj = RedshiftSource.from_proto(data_source)
elif (
Expand Down Expand Up @@ -378,6 +378,14 @@ def get_table_column_names_and_types(


class KafkaSource(DataSource):
def validate(self, config: RepoConfig):
pass

def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
pass

def __init__(
self,
event_timestamp_column: str,
Expand Down Expand Up @@ -463,6 +471,14 @@ def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:


class KinesisSource(DataSource):
def validate(self, config: RepoConfig):
pass

def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
pass

@staticmethod
def from_proto(data_source: DataSourceProto):
return KinesisSource(
Expand Down
207 changes: 4 additions & 203 deletions sdk/python/feast/infra/offline_stores/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import uuid
from dataclasses import asdict, dataclass
from datetime import date, datetime, timedelta
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import List, Optional, Set, Union

import pandas
import pyarrow
Expand All @@ -12,12 +12,11 @@
from pydantic.typing import Literal
from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed

from feast import errors, type_map
from feast import errors
from feast.data_source import DataSource
from feast.errors import (
BigQueryJobCancelled,
BigQueryJobStillRunning,
DataSourceNotFoundException,
FeastProviderLoginError,
)
from feast.feature_view import FeatureView
Expand All @@ -26,10 +25,10 @@
DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
_get_requested_feature_views_to_features_dict,
)
from feast.protos.feast.core.DataSource_pb2 import DataSource as DataSourceProto
from feast.registry import Registry
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.value_type import ValueType

from .bigquery_source import BigQuerySource

try:
from google.api_core.exceptions import NotFound
Expand Down Expand Up @@ -707,201 +706,3 @@ def _get_bigquery_client(project: Optional[str] = None):
) USING ({{featureview.name}}__entity_row_unique_id)
{% endfor %}
"""


class BigQuerySource(DataSource):
    """A batch DataSource backed by either a BigQuery table or a SQL query.

    Exactly one of ``table_ref`` / ``query`` is expected in practice; both
    are simply stored on the attached :class:`BigQueryOptions` instance,
    while the common timestamp/mapping fields are handled by the base class.
    """

    def __init__(
        self,
        event_timestamp_column: Optional[str] = "",
        table_ref: Optional[str] = None,
        created_timestamp_column: Optional[str] = "",
        field_mapping: Optional[Dict[str, str]] = None,
        date_partition_column: Optional[str] = "",
        query: Optional[str] = None,
    ):
        # BigQuery-specific settings live in their own options object.
        self._bigquery_options = BigQueryOptions(table_ref=table_ref, query=query)

        super().__init__(
            event_timestamp_column,
            created_timestamp_column,
            field_mapping,
            date_partition_column,
        )

    def __eq__(self, other):
        if not isinstance(other, BigQuerySource):
            raise TypeError(
                "Comparisons should only involve BigQuerySource class objects."
            )

        # Two sources are equal when every configurable field matches.
        return all(
            (
                self.bigquery_options.table_ref == other.bigquery_options.table_ref,
                self.bigquery_options.query == other.bigquery_options.query,
                self.event_timestamp_column == other.event_timestamp_column,
                self.created_timestamp_column == other.created_timestamp_column,
                self.field_mapping == other.field_mapping,
            )
        )

    @property
    def table_ref(self):
        """The fully-qualified BigQuery table reference, if table-backed."""
        return self._bigquery_options.table_ref

    @property
    def query(self):
        """The SQL query text, if query-backed."""
        return self._bigquery_options.query

    @property
    def bigquery_options(self):
        """
        Returns the bigquery options of this data source
        """
        return self._bigquery_options

    @bigquery_options.setter
    def bigquery_options(self, bigquery_options):
        """
        Sets the bigquery options of this data source
        """
        self._bigquery_options = bigquery_options

    @staticmethod
    def from_proto(data_source: DataSourceProto):
        """Build a BigQuerySource from its DataSource protobuf form."""
        assert data_source.HasField("bigquery_options")

        return BigQuerySource(
            field_mapping=dict(data_source.field_mapping),
            table_ref=data_source.bigquery_options.table_ref,
            event_timestamp_column=data_source.event_timestamp_column,
            created_timestamp_column=data_source.created_timestamp_column,
            date_partition_column=data_source.date_partition_column,
            query=data_source.bigquery_options.query,
        )

    def to_proto(self) -> DataSourceProto:
        """Serialize this source to its DataSource protobuf form."""
        proto = DataSourceProto(
            type=DataSourceProto.BATCH_BIGQUERY,
            field_mapping=self.field_mapping,
            bigquery_options=self.bigquery_options.to_proto(),
        )
        proto.event_timestamp_column = self.event_timestamp_column
        proto.created_timestamp_column = self.created_timestamp_column
        proto.date_partition_column = self.date_partition_column
        return proto

    def validate(self, config: RepoConfig):
        """For table-backed sources, verify the table exists in BigQuery.

        Query-backed sources are not validated here.
        """
        if self.query:
            return
        from google.api_core.exceptions import NotFound
        from google.cloud import bigquery

        client = bigquery.Client()
        try:
            client.get_table(self.table_ref)
        except NotFound:
            raise DataSourceNotFoundException(self.table_ref)

    def get_table_query_string(self) -> str:
        """Returns a string that can directly be used to reference this table in SQL"""
        if self.table_ref:
            return f"`{self.table_ref}`"
        return f"({self.query})"

    @staticmethod
    def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
        """Mapper from BigQuery column types to Feast value types."""
        return type_map.bq_to_feast_value_type

    def get_table_column_names_and_types(
        self, config: RepoConfig
    ) -> Iterable[Tuple[str, str]]:
        """Return (column name, BigQuery type) pairs for the table or query."""
        from google.cloud import bigquery

        client = bigquery.Client()
        if self.table_ref is not None:
            schema = client.get_table(self.table_ref).schema
            if not isinstance(schema[0], bigquery.schema.SchemaField):
                raise TypeError("Could not parse BigQuery table schema.")
            return [(col.name, col.field_type) for col in schema]

        # Query-backed source: probe the result schema with a one-row SELECT.
        probe_sql = f"SELECT * FROM ({self.query}) LIMIT 1"
        probe_result = client.query(probe_sql).result()
        return [(col.name, col.field_type) for col in probe_result.schema]


class BigQueryOptions:
    """
    DataSource BigQuery options used to source features from BigQuery query
    """

    def __init__(self, table_ref: Optional[str], query: Optional[str]):
        # Plain value holders; exposed through the properties below.
        self._table_ref = table_ref
        self._query = query

    @property
    def query(self):
        """
        Returns the BigQuery SQL query referenced by this source
        """
        return self._query

    @query.setter
    def query(self, query):
        """
        Sets the BigQuery SQL query referenced by this source
        """
        self._query = query

    @property
    def table_ref(self):
        """
        Returns the table ref of this BQ table
        """
        return self._table_ref

    @table_ref.setter
    def table_ref(self, table_ref):
        """
        Sets the table ref of this BQ table
        """
        self._table_ref = table_ref

    @classmethod
    def from_proto(cls, bigquery_options_proto: DataSourceProto.BigQueryOptions):
        """
        Creates a BigQueryOptions from a protobuf representation of a BigQuery option

        Args:
            bigquery_options_proto: A protobuf representation of a DataSource

        Returns:
            Returns a BigQueryOptions object based on the bigquery_options protobuf
        """
        return cls(
            table_ref=bigquery_options_proto.table_ref,
            query=bigquery_options_proto.query,
        )

    def to_proto(self) -> DataSourceProto.BigQueryOptions:
        """
        Converts an BigQueryOptionsProto object to its protobuf representation.

        Returns:
            BigQueryOptionsProto protobuf
        """
        return DataSourceProto.BigQueryOptions(
            table_ref=self.table_ref, query=self.query,
        )
Loading

0 comments on commit 8b3a97a

Please sign in to comment.