Skip to content

Commit

Permalink
feat: Add AWS Redshift Serverless support (feast-dev#3595)
Browse files Browse the repository at this point in the history
* Rebase master

Signed-off-by: Chris Barcroft <chrisabarcroft@gmail.com>

* Pass optional types to satisfy mypy

Signed-off-by: Chris Barcroft <chrisabarcroft@gmail.com>

* Remove redundant import

Signed-off-by: Chris Barcroft <chrisabarcroft@gmail.com>

* Regenerate python requirements

Signed-off-by: Chris Barcroft <christopher.barcroft@nordstrom.com>

* Fix casing error on DbUser Redshift kwarg

Signed-off-by: Chris Barcroft <christopher.barcroft@nordstrom.com>

---------

Signed-off-by: Chris Barcroft <chrisabarcroft@gmail.com>
Signed-off-by: Chris Barcroft <christopher.barcroft@nordstrom.com>
Co-authored-by: Chris Barcroft <christopher.barcroft@nordstrom.com>
  • Loading branch information
cbarcroft and Chris Barcroft authored Apr 21, 2023
1 parent 7da0580 commit 58ce148
Show file tree
Hide file tree
Showing 11 changed files with 235 additions and 120 deletions.
22 changes: 22 additions & 0 deletions docs/reference/offline-stores/redshift.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,25 @@ While the following trust relationship is necessary to make sure that Redshift,
]
}
```


## Redshift Serverless

In order to use [AWS Redshift Serverless](https://aws.amazon.com/redshift/redshift-serverless/), specify a workgroup instead of a cluster_id and user.

{% code title="feature_store.yaml" %}
```yaml
project: my_feature_repo
registry: data/registry.db
provider: aws
offline_store:
type: redshift
region: us-west-2
workgroup: feast-workgroup
database: feast-database
s3_staging_location: s3://feast-bucket/redshift
iam_role: arn:aws:iam::123456789012:role/redshift_s3_access_role
```
{% endcode %}

Please note that the IAM policies above will need to use the [redshift-serverless](https://aws.permissions.cloud/iam/redshift-serverless) permissions, rather than the standard [redshift](https://aws.permissions.cloud/iam/redshift) ones.
49 changes: 42 additions & 7 deletions sdk/python/feast/infra/offline_stores/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pyarrow
import pyarrow as pa
from dateutil import parser
from pydantic import StrictStr
from pydantic import StrictStr, root_validator
from pydantic.typing import Literal
from pytz import utc

Expand Down Expand Up @@ -51,15 +51,18 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel):
type: Literal["redshift"] = "redshift"
""" Offline store type selector"""

cluster_id: StrictStr
""" Redshift cluster identifier """
cluster_id: Optional[StrictStr]
""" Redshift cluster identifier, for provisioned clusters """

user: Optional[StrictStr]
""" Redshift user name, only required for provisioned clusters """

workgroup: Optional[StrictStr]
""" Redshift workgroup identifier, for serverless """

region: StrictStr
""" Redshift cluster's AWS region """

user: StrictStr
""" Redshift user name """

database: StrictStr
""" Redshift database name """

Expand All @@ -69,6 +72,26 @@ class RedshiftOfflineStoreConfig(FeastConfigBaseModel):
iam_role: StrictStr
""" IAM Role for Redshift, granting it access to S3 """

@root_validator
def require_cluster_and_user_or_workgroup(cls, values):
    """Validate that exactly one Redshift connection mode is configured.

    Provisioned Redshift clusters: Require cluster_id and user, ignore workgroup
    Serverless Redshift: Require workgroup, ignore cluster_id and user
    """
    cluster_id = values.get("cluster_id")
    user = values.get("user")
    workgroup = values.get("workgroup")

    # The two modes are mutually exclusive: reject an ambiguous config first.
    if cluster_id and workgroup:
        raise ValueError("cannot specify both cluster_id and workgroup")

    # Neither a complete provisioned config (cluster_id + user) nor a
    # serverless config (workgroup) was supplied.
    provisioned_ok = bool(cluster_id) and bool(user)
    if not provisioned_ok and not workgroup:
        raise ValueError(
            "please specify either cluster_id & user if using provisioned clusters, or workgroup if using serverless"
        )

    return values


class RedshiftOfflineStore(OfflineStore):
@staticmethod
Expand Down Expand Up @@ -248,6 +271,7 @@ def query_generator() -> Iterator[str]:
aws_utils.execute_redshift_statement(
redshift_client,
config.offline_store.cluster_id,
config.offline_store.workgroup,
config.offline_store.database,
config.offline_store.user,
f"DROP TABLE IF EXISTS {table_name}",
Expand Down Expand Up @@ -294,6 +318,7 @@ def write_logged_features(
table=data,
redshift_data_client=redshift_client,
cluster_id=config.offline_store.cluster_id,
workgroup=config.offline_store.workgroup,
database=config.offline_store.database,
user=config.offline_store.user,
s3_resource=s3_resource,
Expand Down Expand Up @@ -336,8 +361,10 @@ def offline_write_batch(
table=table,
redshift_data_client=redshift_client,
cluster_id=config.offline_store.cluster_id,
workgroup=config.offline_store.workgroup,
database=redshift_options.database
or config.offline_store.database, # Users can define database in the source if needed but it's not required.
# Users can define database in the source if needed but it's not required.
or config.offline_store.database,
user=config.offline_store.user,
s3_resource=s3_resource,
s3_path=f"{config.offline_store.s3_staging_location}/push/{uuid.uuid4()}.parquet",
Expand Down Expand Up @@ -405,6 +432,7 @@ def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
return aws_utils.unload_redshift_query_to_df(
self._redshift_client,
self._config.offline_store.cluster_id,
self._config.offline_store.workgroup,
self._config.offline_store.database,
self._config.offline_store.user,
self._s3_resource,
Expand All @@ -419,6 +447,7 @@ def _to_arrow_internal(self, timeout: Optional[int] = None) -> pa.Table:
return aws_utils.unload_redshift_query_to_pa(
self._redshift_client,
self._config.offline_store.cluster_id,
self._config.offline_store.workgroup,
self._config.offline_store.database,
self._config.offline_store.user,
self._s3_resource,
Expand All @@ -439,6 +468,7 @@ def to_s3(self) -> str:
aws_utils.execute_redshift_query_and_unload_to_s3(
self._redshift_client,
self._config.offline_store.cluster_id,
self._config.offline_store.workgroup,
self._config.offline_store.database,
self._config.offline_store.user,
self._s3_path,
Expand All @@ -455,6 +485,7 @@ def to_redshift(self, table_name: str) -> None:
aws_utils.upload_df_to_redshift(
self._redshift_client,
self._config.offline_store.cluster_id,
self._config.offline_store.workgroup,
self._config.offline_store.database,
self._config.offline_store.user,
self._s3_resource,
Expand All @@ -471,6 +502,7 @@ def to_redshift(self, table_name: str) -> None:
aws_utils.execute_redshift_statement(
self._redshift_client,
self._config.offline_store.cluster_id,
self._config.offline_store.workgroup,
self._config.offline_store.database,
self._config.offline_store.user,
query,
Expand Down Expand Up @@ -509,6 +541,7 @@ def _upload_entity_df(
aws_utils.upload_df_to_redshift(
redshift_client,
config.offline_store.cluster_id,
config.offline_store.workgroup,
config.offline_store.database,
config.offline_store.user,
s3_resource,
Expand All @@ -522,6 +555,7 @@ def _upload_entity_df(
aws_utils.execute_redshift_statement(
redshift_client,
config.offline_store.cluster_id,
config.offline_store.workgroup,
config.offline_store.database,
config.offline_store.user,
f"CREATE TABLE {table_name} AS ({entity_df})",
Expand Down Expand Up @@ -577,6 +611,7 @@ def _get_entity_df_event_timestamp_range(
statement_id = aws_utils.execute_redshift_statement(
redshift_client,
config.offline_store.cluster_id,
config.offline_store.workgroup,
config.offline_store.database,
config.offline_store.user,
f"SELECT MIN({entity_df_event_timestamp_col}) AS min, MAX({entity_df_event_timestamp_col}) AS max "
Expand Down
27 changes: 20 additions & 7 deletions sdk/python/feast/infra/offline_stores/redshift_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,18 +207,30 @@ def get_table_column_names_and_types(
if self.table:
try:
paginator = client.get_paginator("describe_table")
response_iterator = paginator.paginate(
ClusterIdentifier=config.offline_store.cluster_id,
Database=(

paginator_kwargs = {
"Database": (
self.database
if self.database
else config.offline_store.database
),
DbUser=config.offline_store.user,
Table=self.table,
Schema=self.schema,
)
"Table": self.table,
"Schema": self.schema,
}

if config.offline_store.cluster_id:
# Provisioned cluster
paginator_kwargs[
"ClusterIdentifier"
] = config.offline_store.cluster_id
paginator_kwargs["DbUser"] = config.offline_store.user
elif config.offline_store.workgroup:
# Redshift serverless
paginator_kwargs["WorkgroupName"] = config.offline_store.workgroup

response_iterator = paginator.paginate(**paginator_kwargs)
table = response_iterator.build_full_result()

except ClientError as e:
if e.response["Error"]["Code"] == "ValidationException":
raise RedshiftCredentialsError() from e
Expand All @@ -233,6 +245,7 @@ def get_table_column_names_and_types(
statement_id = aws_utils.execute_redshift_statement(
client,
config.offline_store.cluster_id,
config.offline_store.workgroup,
self.database if self.database else config.offline_store.database,
config.offline_store.user,
f"SELECT * FROM ({self.query}) LIMIT 1",
Expand Down
Loading

0 comments on commit 58ce148

Please sign in to comment.