From c84eaaaa656a122cf10a4bc1ae3898a1954890c0 Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Fri, 25 Aug 2023 15:13:37 +0200 Subject: [PATCH] Fix Azure Batch Hook instantation The Hook instantiation for Azure Batch has been done in the constructor, which is wrong. This has been detected when #33716 added example dag and it started to fail provider imports as connection has beeen missing to instantiate it. The hook instantiation is now moved to cached property. --- airflow/providers/microsoft/azure/operators/batch.py | 6 +++++- .../providers/microsoft/azure/operators/test_azure_batch.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/airflow/providers/microsoft/azure/operators/batch.py b/airflow/providers/microsoft/azure/operators/batch.py index 63b925a98199c..e26f56dd6e62b 100644 --- a/airflow/providers/microsoft/azure/operators/batch.py +++ b/airflow/providers/microsoft/azure/operators/batch.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +from functools import cached_property from typing import TYPE_CHECKING, Any, Sequence from azure.batch import models as batch_models @@ -176,7 +177,10 @@ def __init__( self.timeout = timeout self.should_delete_job = should_delete_job self.should_delete_pool = should_delete_pool - self.hook = self.get_hook() + + @cached_property + def hook(self): + return self.get_hook() def _check_inputs(self) -> Any: if not self.os_family and not self.vm_publisher: diff --git a/tests/providers/microsoft/azure/operators/test_azure_batch.py b/tests/providers/microsoft/azure/operators/test_azure_batch.py index 0e3947732b98b..e920f7c1d9ed4 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_batch.py +++ b/tests/providers/microsoft/azure/operators/test_azure_batch.py @@ -162,6 +162,7 @@ def setup_method(self, method, mock_batch, mock_hook): self.batch_client = mock_batch.return_value self.mock_instance = mock_hook.return_value assert self.batch_client == self.operator.hook.connection + assert self.batch_client == self.operator2_pass.hook.connection @mock.patch.object(AzureBatchHook, "wait_for_all_node_state") def test_execute_without_failures(self, wait_mock):