diff --git a/airflow/providers/microsoft/azure/operators/batch.py b/airflow/providers/microsoft/azure/operators/batch.py index 63b925a98199c..e26f56dd6e62b 100644 --- a/airflow/providers/microsoft/azure/operators/batch.py +++ b/airflow/providers/microsoft/azure/operators/batch.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +from functools import cached_property from typing import TYPE_CHECKING, Any, Sequence from azure.batch import models as batch_models @@ -176,7 +177,10 @@ def __init__( self.timeout = timeout self.should_delete_job = should_delete_job self.should_delete_pool = should_delete_pool - self.hook = self.get_hook() + + @cached_property + def hook(self): + return self.get_hook() def _check_inputs(self) -> Any: if not self.os_family and not self.vm_publisher: diff --git a/tests/providers/microsoft/azure/operators/test_azure_batch.py b/tests/providers/microsoft/azure/operators/test_azure_batch.py index 0e3947732b98b..e920f7c1d9ed4 100644 --- a/tests/providers/microsoft/azure/operators/test_azure_batch.py +++ b/tests/providers/microsoft/azure/operators/test_azure_batch.py @@ -162,6 +162,7 @@ def setup_method(self, method, mock_batch, mock_hook): self.batch_client = mock_batch.return_value self.mock_instance = mock_hook.return_value assert self.batch_client == self.operator.hook.connection + assert self.batch_client == self.operator2_pass.hook.connection @mock.patch.object(AzureBatchHook, "wait_for_all_node_state") def test_execute_without_failures(self, wait_mock):