adds support for querying Redshift Serverless clusters #32785

Merged (1 commit) on Jul 24, 2023
10 changes: 10 additions & 0 deletions airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -60,6 +60,7 @@ def execute_query(
with_event: bool = False,
wait_for_completion: bool = True,
poll_interval: int = 10,
workgroup_name: str | None = None,
) -> str:
"""
Execute a statement against Amazon Redshift.
@@ -74,6 +75,9 @@ def execute_query(
:param with_event: indicates whether to send an event to EventBridge
:param wait_for_completion: whether to wait for the query to complete; if False, return immediately after submission
:param poll_interval: how often in seconds to check the query status
:param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
`cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html

:returns statement_id: str, the UUID of the statement
"""
@@ -85,6 +89,7 @@ def execute_query(
"WithEvent": with_event,
"SecretArn": secret_arn,
"StatementName": statement_name,
"WorkgroupName": workgroup_name,
}
if isinstance(sql, list):
kwargs["Sqls"] = sql
@@ -95,6 +100,9 @@ def execute_query(

statement_id = resp["Id"]

if bool(cluster_identifier) is bool(workgroup_name):
raise ValueError("Either 'cluster_identifier' or 'workgroup_name' must be specified.")

if wait_for_completion:
self.wait_for_results(statement_id, poll_interval=poll_interval)
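The new guard makes cluster_identifier and workgroup_name mutually exclusive: bool(cluster_identifier) is bool(workgroup_name) is True both when neither is set and when both are, so exactly one must be provided. (Note that, as merged, the check runs only after execute_statement has already returned a statement id.) A minimal usage sketch against a Serverless workgroup; the database and workgroup names here are hypothetical:

from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

hook = RedshiftDataHook()  # uses the aws_default connection
statement_id = hook.execute_query(
    database="dev",                 # hypothetical database name
    sql="SELECT 1;",
    workgroup_name="my-workgroup",  # Serverless target; cluster_identifier stays None
    wait_for_completion=True,       # polls the statement status until the query finishes
)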

@@ -127,6 +135,7 @@ def get_table_primary_key(
database: str,
schema: str | None = "public",
cluster_identifier: str | None = None,
workgroup_name: str | None = None,
db_user: str | None = None,
secret_arn: str | None = None,
statement_name: str | None = None,
@@ -168,6 +177,7 @@ def get_table_primary_key(
sql=sql,
database=database,
cluster_identifier=cluster_identifier,
workgroup_name=workgroup_name,
db_user=db_user,
secret_arn=secret_arn,
statement_name=statement_name,
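Because get_table_primary_key simply forwards the new argument to execute_query, primary-key lookups work against Serverless the same way. A short sketch under the same hypothetical names (the table name is also an assumption):

from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

hook = RedshiftDataHook()
pk_columns = hook.get_table_primary_key(
    table="orders",                 # hypothetical table
    database="dev",
    workgroup_name="my-workgroup",  # passed through to execute_query
)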
7 changes: 7 additions & 0 deletions airflow/providers/amazon/aws/operators/redshift_data.py
@@ -51,6 +51,9 @@ class RedshiftDataOperator(BaseOperator):
if False (default) will return statement ID
:param aws_conn_id: aws connection to use
:param region: aws region to use
:param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
`cluster_identifier`. Specify this parameter to query Redshift Serverless. More info
https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
"""

template_fields = (
@@ -62,6 +65,7 @@ class RedshiftDataOperator(BaseOperator):
"statement_name",
"aws_conn_id",
"region",
"workgroup_name",
)
template_ext = (".sql",)
template_fields_renderers = {"sql": "sql"}
@@ -82,12 +86,14 @@ def __init__(
return_sql_result: bool = False,
aws_conn_id: str = "aws_default",
region: str | None = None,
workgroup_name: str | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.database = database
self.sql = sql
self.cluster_identifier = cluster_identifier
self.workgroup_name = workgroup_name
self.db_user = db_user
self.parameters = parameters
self.secret_arn = secret_arn
@@ -119,6 +125,7 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
database=self.database,
sql=self.sql,
cluster_identifier=self.cluster_identifier,
workgroup_name=self.workgroup_name,
db_user=self.db_user,
parameters=self.parameters,
secret_arn=self.secret_arn,
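With the operator forwarding workgroup_name (and templating it, per the template_fields change above), a DAG task can target Serverless directly. A minimal sketch; the DAG id, schedule, and resource names are hypothetical:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

with DAG(dag_id="example_redshift_serverless", start_date=datetime(2023, 7, 1), schedule=None):
    RedshiftDataOperator(
        task_id="count_tables",
        database="dev",                 # hypothetical database
        sql="SELECT count(*) FROM pg_tables;",
        workgroup_name="my-workgroup",  # Serverless workgroup; omit cluster_identifier
        wait_for_completion=True,
    )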
67 changes: 66 additions & 1 deletion tests/providers/amazon/aws/hooks/test_redshift_data.py
@@ -20,6 +20,8 @@
import logging
from unittest import mock

import pytest

from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

SQL = "sql"
@@ -39,22 +41,50 @@ def test_conn_attribute(self):
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_without_waiting(self, mock_conn):
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
cluster_identifier = "cluster_identifier"

hook = RedshiftDataHook()
hook.execute_query(
database=DATABASE,
cluster_identifier=cluster_identifier,
sql=SQL,
wait_for_completion=False,
)
mock_conn.execute_statement.assert_called_once_with(
Database=DATABASE,
ClusterIdentifier=cluster_identifier,
Sql=SQL,
WithEvent=False,
)
mock_conn.describe_statement.assert_not_called()

@pytest.mark.parametrize(
"cluster_identifier, workgroup_name",
[
(None, None),
("some_cluster", "some_workgroup"),
],
)
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_requires_either_cluster_identifier_or_workgroup_name(
self, mock_conn, cluster_identifier, workgroup_name
):
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
cluster_identifier = "cluster_identifier"
workgroup_name = "workgroup_name"

with pytest.raises(ValueError):
hook = RedshiftDataHook()
hook.execute_query(
database=DATABASE,
cluster_identifier=cluster_identifier,
workgroup_name=workgroup_name,
sql=SQL,
wait_for_completion=False,
)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_with_all_parameters_cluster_identifier(self, mock_conn):
cluster_identifier = "cluster_identifier"
db_user = "db_user"
secret_arn = "secret_arn"
@@ -88,6 +118,41 @@ def test_execute_with_all_parameters(self, mock_conn):
Id=STATEMENT_ID,
)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_execute_with_all_parameters_workgroup_name(self, mock_conn):
workgroup_name = "workgroup_name"
db_user = "db_user"
secret_arn = "secret_arn"
statement_name = "statement_name"
parameters = [{"name": "id", "value": "1"}]
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
mock_conn.describe_statement.return_value = {"Status": "FINISHED"}

hook = RedshiftDataHook()
hook.execute_query(
sql=SQL,
database=DATABASE,
workgroup_name=workgroup_name,
db_user=db_user,
secret_arn=secret_arn,
statement_name=statement_name,
parameters=parameters,
)

mock_conn.execute_statement.assert_called_once_with(
Database=DATABASE,
Sql=SQL,
WorkgroupName=workgroup_name,
DbUser=db_user,
SecretArn=secret_arn,
StatementName=statement_name,
Parameters=parameters,
WithEvent=False,
)
mock_conn.describe_statement.assert_called_once_with(
Id=STATEMENT_ID,
)
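The assertion mirrors what the hook hands to boto3: the Redshift Data API's ExecuteStatement accepts WorkgroupName in place of ClusterIdentifier when targeting Serverless. For reference, roughly the same call made directly against boto3 (all names hypothetical):

import boto3

client = boto3.client("redshift-data")
resp = client.execute_statement(
    Database="dev",
    Sql="SELECT 1;",
    WorkgroupName="my-workgroup",  # instead of ClusterIdentifier
)
status = client.describe_statement(Id=resp["Id"])["Status"]  # e.g. "FINISHED"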

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
def test_batch_execute(self, mock_conn):
mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID}
42 changes: 42 additions & 0 deletions tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -32,6 +32,7 @@ class TestRedshiftDataOperator:
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_execute(self, mock_exec_query):
cluster_identifier = "cluster_identifier"
workgroup_name = None
db_user = "db_user"
secret_arn = "secret_arn"
statement_name = "statement_name"
@@ -57,6 +58,46 @@ def test_execute(self, mock_exec_query):
sql=SQL,
database=DATABASE,
cluster_identifier=cluster_identifier,
workgroup_name=workgroup_name,
db_user=db_user,
secret_arn=secret_arn,
statement_name=statement_name,
parameters=parameters,
with_event=False,
wait_for_completion=wait_for_completion,
poll_interval=poll_interval,
)

@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
def test_execute_with_workgroup_name(self, mock_exec_query):
cluster_identifier = None
workgroup_name = "workgroup_name"
db_user = "db_user"
secret_arn = "secret_arn"
statement_name = "statement_name"
parameters = [{"name": "id", "value": "1"}]
poll_interval = 5
wait_for_completion = True

operator = RedshiftDataOperator(
aws_conn_id=CONN_ID,
task_id=TASK_ID,
sql=SQL,
database=DATABASE,
workgroup_name=workgroup_name,
db_user=db_user,
secret_arn=secret_arn,
statement_name=statement_name,
parameters=parameters,
wait_for_completion=True,
poll_interval=poll_interval,
)
operator.execute(None)
mock_exec_query.assert_called_once_with(
sql=SQL,
database=DATABASE,
cluster_identifier=cluster_identifier,
workgroup_name=workgroup_name,
db_user=db_user,
secret_arn=secret_arn,
statement_name=statement_name,
@@ -85,6 +126,7 @@ def test_on_kill_with_query(self, mock_conn):
operator = RedshiftDataOperator(
aws_conn_id=CONN_ID,
task_id=TASK_ID,
cluster_identifier="cluster_identifier",
sql=SQL,
database=DATABASE,
wait_for_completion=False,