Skip to content

Commit

Permalink
Add support for querying Redshift Serverless clusters (#32785)
Browse files Browse the repository at this point in the history
  • Loading branch information
ivica-k committed Jul 24, 2023
1 parent d05e42e commit 8012c9f
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 1 deletion.
10 changes: 10 additions & 0 deletions airflow/providers/amazon/aws/hooks/redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def execute_query(
with_event: bool = False,
wait_for_completion: bool = True,
poll_interval: int = 10,
workgroup_name: str | None = None,
) -> str:
"""
Execute a statement against Amazon Redshift.
Expand All @@ -74,6 +75,9 @@ def execute_query(
:param with_event: indicates whether to send an event to EventBridge
:param wait_for_completion: indicates whether to wait for a result, if True wait, if False don't wait
:param poll_interval: how often in seconds to check the query status
:param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
`cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
:returns statement_id: str, the UUID of the statement
"""
Expand All @@ -85,6 +89,7 @@ def execute_query(
"WithEvent": with_event,
"SecretArn": secret_arn,
"StatementName": statement_name,
"WorkgroupName": workgroup_name,
}
if isinstance(sql, list):
kwargs["Sqls"] = sql
Expand All @@ -95,6 +100,9 @@ def execute_query(

statement_id = resp["Id"]

if bool(cluster_identifier) is bool(workgroup_name):
raise ValueError("Either 'cluster_identifier' or 'workgroup_name' must be specified.")

if wait_for_completion:
self.wait_for_results(statement_id, poll_interval=poll_interval)

Expand Down Expand Up @@ -127,6 +135,7 @@ def get_table_primary_key(
database: str,
schema: str | None = "public",
cluster_identifier: str | None = None,
workgroup_name: str | None = None,
db_user: str | None = None,
secret_arn: str | None = None,
statement_name: str | None = None,
Expand Down Expand Up @@ -168,6 +177,7 @@ def get_table_primary_key(
sql=sql,
database=database,
cluster_identifier=cluster_identifier,
workgroup_name=workgroup_name,
db_user=db_user,
secret_arn=secret_arn,
statement_name=statement_name,
Expand Down
7 changes: 7 additions & 0 deletions airflow/providers/amazon/aws/operators/redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class RedshiftDataOperator(BaseOperator):
if False (default) will return statement ID
:param aws_conn_id: aws connection to use
:param region: aws region to use
:param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
`cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
"""

template_fields = (
Expand All @@ -62,6 +65,7 @@ class RedshiftDataOperator(BaseOperator):
"statement_name",
"aws_conn_id",
"region",
"workgroup_name",
)
template_ext = (".sql",)
template_fields_renderers = {"sql": "sql"}
Expand All @@ -82,12 +86,14 @@ def __init__(
return_sql_result: bool = False,
aws_conn_id: str = "aws_default",
region: str | None = None,
workgroup_name: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.database = database
self.sql = sql
self.cluster_identifier = cluster_identifier
self.workgroup_name = workgroup_name
self.db_user = db_user
self.parameters = parameters
self.secret_arn = secret_arn
Expand Down Expand Up @@ -119,6 +125,7 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
database=self.database,
sql=self.sql,
cluster_identifier=self.cluster_identifier,
workgroup_name=self.workgroup_name,
db_user=self.db_user,
parameters=self.parameters,
secret_arn=self.secret_arn,
Expand Down
67 changes: 66 additions & 1 deletion tests/providers/amazon/aws/hooks/test_redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import logging
from unittest import mock

import pytest

from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

SQL = "sql"
Expand All @@ -39,22 +41,50 @@ def test_conn_attribute(self):
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_without_waiting(self, mock_conn):
    """With wait_for_completion=False the hook must fire the statement once and never poll its status."""
    cluster_id = "cluster_identifier"
    mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}

    RedshiftDataHook().execute_query(
        database=DATABASE,
        cluster_identifier=cluster_id,
        sql=SQL,
        wait_for_completion=False,
    )

    # The boto3 call carries exactly the supplied identifiers; None-valued
    # optional kwargs are trimmed by the hook before the call.
    expected_api_kwargs = {
        "Database": DATABASE,
        "ClusterIdentifier": cluster_id,
        "Sql": SQL,
        "WithEvent": False,
    }
    mock_conn.execute_statement.assert_called_once_with(**expected_api_kwargs)
    # No polling may happen when the caller opted out of waiting.
    mock_conn.describe_statement.assert_not_called()

@pytest.mark.parametrize(
    "cluster_identifier, workgroup_name",
    [
        (None, None),
        ("some_cluster", "some_workgroup"),
    ],
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_requires_either_cluster_identifier_or_workgroup_name(
    self, mock_conn, cluster_identifier, workgroup_name
):
    """Exactly one of cluster_identifier / workgroup_name must be given.

    Covers both failure modes: neither provided (None, None) and both
    provided ("some_cluster", "some_workgroup"); each must raise ValueError.

    Note: the parametrized values must NOT be reassigned locally — doing so
    would shadow the parameters and silently skip the (None, None) case.
    """
    mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}

    hook = RedshiftDataHook()
    with pytest.raises(ValueError):
        hook.execute_query(
            database=DATABASE,
            cluster_identifier=cluster_identifier,
            workgroup_name=workgroup_name,
            sql=SQL,
            wait_for_completion=False,
        )

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_with_all_parameters_cluster_identifier(self, mock_conn):
cluster_identifier = "cluster_identifier"
db_user = "db_user"
secret_arn = "secret_arn"
Expand Down Expand Up @@ -88,6 +118,41 @@ def test_execute_with_all_parameters(self, mock_conn):
Id=STATEMENT_ID,
)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_with_all_parameters_workgroup_name(self, mock_conn):
    """Serverless path: every optional kwarg plus workgroup_name is forwarded to the Data API."""
    workgroup = "workgroup_name"
    user = "db_user"
    secret = "secret_arn"
    name = "statement_name"
    params = [{"name": "id", "value": "1"}]

    mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
    # A FINISHED status lets the default wait_for_completion loop exit after one poll.
    mock_conn.describe_statement.return_value = {"Status": "FINISHED"}

    RedshiftDataHook().execute_query(
        sql=SQL,
        database=DATABASE,
        workgroup_name=workgroup,
        db_user=user,
        secret_arn=secret,
        statement_name=name,
        parameters=params,
    )

    # WorkgroupName replaces ClusterIdentifier on the serverless code path.
    expected_api_kwargs = {
        "Database": DATABASE,
        "Sql": SQL,
        "WorkgroupName": workgroup,
        "DbUser": user,
        "SecretArn": secret,
        "StatementName": name,
        "Parameters": params,
        "WithEvent": False,
    }
    mock_conn.execute_statement.assert_called_once_with(**expected_api_kwargs)
    mock_conn.describe_statement.assert_called_once_with(Id=STATEMENT_ID)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_batch_execute(self, mock_conn):
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
Expand Down
42 changes: 42 additions & 0 deletions tests/providers/amazon/aws/operators/test_redshift_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class TestRedshiftDataOperator:
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_execute(self, mock_exec_query):
cluster_identifier = "cluster_identifier"
workgroup_name = None
db_user = "db_user"
secret_arn = "secret_arn"
statement_name = "statement_name"
Expand All @@ -57,6 +58,46 @@ def test_execute(self, mock_exec_query):
sql=SQL,
database=DATABASE,
cluster_identifier=cluster_identifier,
workgroup_name=workgroup_name,
db_user=db_user,
secret_arn=secret_arn,
statement_name=statement_name,
parameters=parameters,
with_event=False,
wait_for_completion=wait_for_completion,
poll_interval=poll_interval,
)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_execute_with_workgroup_name(self, mock_exec_query):
cluster_identifier = None
workgroup_name = "workgroup_name"
db_user = "db_user"
secret_arn = "secret_arn"
statement_name = "statement_name"
parameters = [{"name": "id", "value": "1"}]
poll_interval = 5
wait_for_completion = True

operator = RedshiftDataOperator(
aws_conn_id=CONN_ID,
task_id=TASK_ID,
sql=SQL,
database=DATABASE,
workgroup_name=workgroup_name,
db_user=db_user,
secret_arn=secret_arn,
statement_name=statement_name,
parameters=parameters,
wait_for_completion=True,
poll_interval=poll_interval,
)
operator.execute(None)
mock_exec_query.assert_called_once_with(
sql=SQL,
database=DATABASE,
cluster_identifier=cluster_identifier,
workgroup_name=workgroup_name,
db_user=db_user,
secret_arn=secret_arn,
statement_name=statement_name,
Expand Down Expand Up @@ -85,6 +126,7 @@ def test_on_kill_with_query(self, mock_conn):
operator = RedshiftDataOperator(
aws_conn_id=CONN_ID,
task_id=TASK_ID,
cluster_identifier="cluster_identifier",
sql=SQL,
database=DATABASE,
wait_for_completion=False,
Expand Down

0 comments on commit 8012c9f

Please sign in to comment.