Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Update SparkSource to have proper comparable that inspects SparkOptions #3819

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,18 @@ def get_table_query_string(self) -> str:

return f"`{tmp_table_name}`"

# Note: Python requires redefining hash in child classes that override __eq__
def __hash__(self):

return super().__hash__()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this also include self.spark_options?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It doesn't necessarily need to as the two pieces of logic are used for different things. Take a look at the redshift_source.py example for how it is implemented elsewhere.


def __eq__(self, other):
if not isinstance(other, SparkSource):
raise TypeError(
"Comparisons should only involve SparkSource class objects."
)
return super().__eq__(other) and self.spark_options == other.spark_options


class SparkOptions:
allowed_formats = [format.value for format in SparkSourceFormat]
Expand Down Expand Up @@ -282,6 +294,19 @@ def to_proto(self) -> DataSourceProto.SparkOptions:

return spark_options_proto

def __eq__(self, other: object) -> bool:
if not isinstance(other, SparkOptions):
raise TypeError(
"Comparisons should only involve SparkOptions class objects."
)

return (
self.table == other.table
and self.query == other.query
and self.path == other.path
and self.file_format == other.file_format
)


class SavedDatasetSparkStorage(SavedDatasetStorage):
_proto_attr_name = "spark_storage"
Expand Down
38 changes: 38 additions & 0 deletions sdk/python/tests/unit/test_data_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
)
from feast.field import Field
from feast.infra.offline_stores.bigquery_source import BigQuerySource
from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
SparkSource,
)
from feast.infra.offline_stores.file_source import FileSource
from feast.infra.offline_stores.redshift_source import RedshiftSource
from feast.infra.offline_stores.snowflake_source import SnowflakeSource
Expand Down Expand Up @@ -233,3 +236,38 @@ def test_redshift_fully_qualified_table_name(source_kwargs, expected_name):
)

assert redshift_source.redshift_options.fully_qualified_table_name == expected_name


@pytest.mark.parameterize(
"test_data,are_equal",
[
(
SparkSource(
name="name", table="table", query="query", file_format="file_format"
),
True,
),
(SparkSource(table="table", query="query", file_format="file_format"), False),
(
SparkSource(
name="name", table="table", query="query", file_format="file_format1"
),
False,
),
(
SparkSource(
name="name", table="table", query="query1", file_format="file_format"
),
True,
),
],
)
def test_spark_source_equality(test_data, are_equal):
default = SparkSource(
name="name", table="table1", query="query", file_format="file_format"
)
if are_equal:
assert default == test_data
else:
assert default != test_data

Loading