Skip to content

Commit

Permalink
fix PubSubAsyncHook in PubsubPullTrigger to use gcp_conn_id (apache#4…
Browse files Browse the repository at this point in the history
  • Loading branch information
gopidesupavan authored and joaopamaral committed Oct 21, 2024
1 parent a912fef commit 9f4cf3b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 1 deletion.
10 changes: 9 additions & 1 deletion airflow/providers/google/cloud/triggers/pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import annotations

import asyncio
from functools import cached_property
from typing import Any, AsyncIterator, Sequence

from google.cloud.pubsub_v1.types import ReceivedMessage
Expand Down Expand Up @@ -67,7 +68,6 @@ def __init__(
self.poke_interval = poke_interval
self.gcp_conn_id = gcp_conn_id
self.impersonation_chain = impersonation_chain
self.hook = PubSubAsyncHook()

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize PubsubPullTrigger arguments and classpath."""
Expand Down Expand Up @@ -113,3 +113,11 @@ async def message_acknowledgement(self, pulled_messages):
messages=pulled_messages,
)
self.log.info("Acknowledged ack_ids from subscription %s", self.subscription)

@cached_property
def hook(self) -> PubSubAsyncHook:
return PubSubAsyncHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
project_id=self.project_id,
)
20 changes: 20 additions & 0 deletions tests/providers/google/cloud/triggers/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,23 @@ async def test_async_pubsub_pull_trigger_return_event(self, mock_pull):
response = await trigger.run().asend(None)

assert response == expected_event

@mock.patch("airflow.providers.google.cloud.triggers.pubsub.PubSubAsyncHook")
def test_hook(self, mock_async_hook):
trigger = PubsubPullTrigger(
project_id=PROJECT_ID,
subscription="subscription",
max_messages=MAX_MESSAGES,
ack_messages=False,
poke_interval=TEST_POLL_INTERVAL,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
)
async_hook_actual = trigger.hook

mock_async_hook.assert_called_once_with(
gcp_conn_id=trigger.gcp_conn_id,
impersonation_chain=trigger.impersonation_chain,
project_id=trigger.project_id,
)
assert async_hook_actual == mock_async_hook.return_value

0 comments on commit 9f4cf3b

Please sign in to comment.