diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py
index fddd42bd61dfa..a522d3e8c9688 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -60,6 +60,7 @@ def execute_query(
         with_event: bool = False,
         wait_for_completion: bool = True,
         poll_interval: int = 10,
+        workgroup_name: str | None = None,
     ) -> str:
         """
         Execute a statement against Amazon Redshift.
@@ -74,6 +75,9 @@ def execute_query(
         :param with_event: indicates whether to send an event to EventBridge
         :param wait_for_completion: indicates whether to wait for a result, if True wait, if False don't wait
         :param poll_interval: how often in seconds to check the query status
+        :param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
+            `cluster_identifier`. Specify this parameter to query Redshift Serverless. More info:
+            https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html

         :returns statement_id: str, the UUID of the statement
         """
@@ -85,6 +89,11 @@ def execute_query(
             "WithEvent": with_event,
             "SecretArn": secret_arn,
             "StatementName": statement_name,
+            "WorkgroupName": workgroup_name,
         }
+
+        if bool(cluster_identifier) is bool(workgroup_name):
+            raise ValueError("Either 'cluster_identifier' or 'workgroup_name' must be specified, but not both.")
+
         if isinstance(sql, list):
             kwargs["Sqls"] = sql
@@ -127,6 +136,7 @@ def get_table_primary_key(
         database: str,
         schema: str | None = "public",
         cluster_identifier: str | None = None,
+        workgroup_name: str | None = None,
         db_user: str | None = None,
         secret_arn: str | None = None,
         statement_name: str | None = None,
@@ -168,6 +178,7 @@ def get_table_primary_key(
             sql=sql,
             database=database,
             cluster_identifier=cluster_identifier,
+            workgroup_name=workgroup_name,
             db_user=db_user,
             secret_arn=secret_arn,
             statement_name=statement_name,
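For reviewers, a minimal sketch of how the new hook parameter is intended to be used against Redshift Serverless (illustrative only; the connection id, database, and workgroup names below are placeholders, not part of this change):

    from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

    hook = RedshiftDataHook(aws_conn_id="aws_default")
    # Serverless: pass workgroup_name and omit cluster_identifier.
    statement_id = hook.execute_query(
        database="dev",                            # placeholder database
        sql="SELECT 1;",
        workgroup_name="my-serverless-workgroup",  # placeholder workgroup
        wait_for_completion=True,
    )
    # Passing both workgroup_name and cluster_identifier, or neither, raises ValueError
    # before any call to the Redshift Data API is made.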
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index bf560c9973cec..126c585e3fcfb 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -51,6 +51,9 @@ class RedshiftDataOperator(BaseOperator):
         if False (default) will return statement ID
     :param aws_conn_id: aws connection to use
     :param region: aws region to use
+    :param workgroup_name: name of the Redshift Serverless workgroup. Mutually exclusive with
+        `cluster_identifier`. Specify this parameter to query Redshift Serverless. More info:
+        https://docs.aws.amazon.com/redshift/latest/mgmt/working-with-serverless.html
     """

     template_fields = (
@@ -62,6 +65,7 @@ class RedshiftDataOperator(BaseOperator):
         "statement_name",
         "aws_conn_id",
         "region",
+        "workgroup_name",
     )
     template_ext = (".sql",)
     template_fields_renderers = {"sql": "sql"}
@@ -82,12 +86,14 @@ def __init__(
         return_sql_result: bool = False,
         aws_conn_id: str = "aws_default",
         region: str | None = None,
+        workgroup_name: str | None = None,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
         self.database = database
         self.sql = sql
         self.cluster_identifier = cluster_identifier
+        self.workgroup_name = workgroup_name
         self.db_user = db_user
         self.parameters = parameters
         self.secret_arn = secret_arn
@@ -119,6 +125,7 @@ def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
             database=self.database,
             sql=self.sql,
             cluster_identifier=self.cluster_identifier,
+            workgroup_name=self.workgroup_name,
             db_user=self.db_user,
             parameters=self.parameters,
             secret_arn=self.secret_arn,
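Likewise, a hedged sketch of the operator change inside a DAG (the task id, database, and workgroup name are hypothetical):

    from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

    run_serverless_query = RedshiftDataOperator(
        task_id="run_serverless_query",            # hypothetical task id
        database="dev",                            # placeholder database
        sql="SELECT * FROM my_table LIMIT 10;",    # placeholder query
        workgroup_name="my-serverless-workgroup",  # mutually exclusive with cluster_identifier
        wait_for_completion=True,
        aws_conn_id="aws_default",
    )

Since "workgroup_name" is added to template_fields, it can also be rendered from Jinja, e.g. workgroup_name="{{ var.value.redshift_workgroup }}".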
"value": "1"}] + mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID} + mock_conn.describe_statement.return_value = {"Status": "FINISHED"} + + hook = RedshiftDataHook() + hook.execute_query( + sql=SQL, + database=DATABASE, + workgroup_name=workgroup_name, + db_user=db_user, + secret_arn=secret_arn, + statement_name=statement_name, + parameters=parameters, + ) + + mock_conn.execute_statement.assert_called_once_with( + Database=DATABASE, + Sql=SQL, + WorkgroupName=workgroup_name, + DbUser=db_user, + SecretArn=secret_arn, + StatementName=statement_name, + Parameters=parameters, + WithEvent=False, + ) + mock_conn.describe_statement.assert_called_once_with( + Id=STATEMENT_ID, + ) + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn") def test_batch_execute(self, mock_conn): mock_conn.execute_statement.return_value = {"Id": STATEMENT_ID} diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py index be77f96c307d4..e5a851fe737e5 100644 --- a/tests/providers/amazon/aws/operators/test_redshift_data.py +++ b/tests/providers/amazon/aws/operators/test_redshift_data.py @@ -32,6 +32,7 @@ class TestRedshiftDataOperator: @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") def test_execute(self, mock_exec_query): cluster_identifier = "cluster_identifier" + workgroup_name = None db_user = "db_user" secret_arn = "secret_arn" statement_name = "statement_name" @@ -57,6 +58,46 @@ def test_execute(self, mock_exec_query): sql=SQL, database=DATABASE, cluster_identifier=cluster_identifier, + workgroup_name=workgroup_name, + db_user=db_user, + secret_arn=secret_arn, + statement_name=statement_name, + parameters=parameters, + with_event=False, + wait_for_completion=wait_for_completion, + poll_interval=poll_interval, + ) + + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") + def test_execute_with_workgroup_name(self, mock_exec_query): + cluster_identifier = None + workgroup_name = "workgroup_name" + db_user = "db_user" + secret_arn = "secret_arn" + statement_name = "statement_name" + parameters = [{"name": "id", "value": "1"}] + poll_interval = 5 + wait_for_completion = True + + operator = RedshiftDataOperator( + aws_conn_id=CONN_ID, + task_id=TASK_ID, + sql=SQL, + database=DATABASE, + workgroup_name=workgroup_name, + db_user=db_user, + secret_arn=secret_arn, + statement_name=statement_name, + parameters=parameters, + wait_for_completion=True, + poll_interval=poll_interval, + ) + operator.execute(None) + mock_exec_query.assert_called_once_with( + sql=SQL, + database=DATABASE, + cluster_identifier=cluster_identifier, + workgroup_name=workgroup_name, db_user=db_user, secret_arn=secret_arn, statement_name=statement_name, @@ -85,6 +126,7 @@ def test_on_kill_with_query(self, mock_conn): operator = RedshiftDataOperator( aws_conn_id=CONN_ID, task_id=TASK_ID, + cluster_identifier="cluster_identifier", sql=SQL, database=DATABASE, wait_for_completion=False,