From 9f4cf3bdee743293cd3f8dcde81a2421ea2f526d Mon Sep 17 00:00:00 2001 From: GPK Date: Thu, 3 Oct 2024 21:36:09 +0100 Subject: [PATCH] fix PubSubAsyncHook in PubsubPullTrigger to use gcp_conn_id (#42671) --- .../providers/google/cloud/triggers/pubsub.py | 10 +++++++++- .../google/cloud/triggers/test_pubsub.py | 20 +++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/airflow/providers/google/cloud/triggers/pubsub.py b/airflow/providers/google/cloud/triggers/pubsub.py index db3fe409e942..e98603006f72 100644 --- a/airflow/providers/google/cloud/triggers/pubsub.py +++ b/airflow/providers/google/cloud/triggers/pubsub.py @@ -19,6 +19,7 @@ from __future__ import annotations import asyncio +from functools import cached_property from typing import Any, AsyncIterator, Sequence from google.cloud.pubsub_v1.types import ReceivedMessage @@ -67,7 +68,6 @@ def __init__( self.poke_interval = poke_interval self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain - self.hook = PubSubAsyncHook() def serialize(self) -> tuple[str, dict[str, Any]]: """Serialize PubsubPullTrigger arguments and classpath.""" @@ -113,3 +113,11 @@ async def message_acknowledgement(self, pulled_messages): messages=pulled_messages, ) self.log.info("Acknowledged ack_ids from subscription %s", self.subscription) + + @cached_property + def hook(self) -> PubSubAsyncHook: + return PubSubAsyncHook( + gcp_conn_id=self.gcp_conn_id, + impersonation_chain=self.impersonation_chain, + project_id=self.project_id, + ) diff --git a/tests/providers/google/cloud/triggers/test_pubsub.py b/tests/providers/google/cloud/triggers/test_pubsub.py index e1a4e178d291..60acd2b7d4c2 100644 --- a/tests/providers/google/cloud/triggers/test_pubsub.py +++ b/tests/providers/google/cloud/triggers/test_pubsub.py @@ -110,3 +110,23 @@ async def test_async_pubsub_pull_trigger_return_event(self, mock_pull): response = await trigger.run().asend(None) assert response == expected_event + + @mock.patch("airflow.providers.google.cloud.triggers.pubsub.PubSubAsyncHook") + def test_hook(self, mock_async_hook): + trigger = PubsubPullTrigger( + project_id=PROJECT_ID, + subscription="subscription", + max_messages=MAX_MESSAGES, + ack_messages=False, + poke_interval=TEST_POLL_INTERVAL, + gcp_conn_id=TEST_GCP_CONN_ID, + impersonation_chain=None, + ) + async_hook_actual = trigger.hook + + mock_async_hook.assert_called_once_with( + gcp_conn_id=trigger.gcp_conn_id, + impersonation_chain=trigger.impersonation_chain, + project_id=trigger.project_id, + ) + assert async_hook_actual == mock_async_hook.return_value