Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump azure-mgmt-containerinstance>=7.0.0,<9.0.0 #33696

Merged
merged 4 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 62 additions & 6 deletions airflow/providers/microsoft/azure/hooks/container_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@

import warnings
from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from azure.common.client_factory import get_client_from_auth_file, get_client_from_json_dict
from azure.common.credentials import ServicePrincipalCredentials
from azure.identity import DefaultAzureCredential
from azure.mgmt.containerinstance import ContainerInstanceManagementClient

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
from airflow.providers.microsoft.azure.hooks.base_azure import AzureBaseHook

if TYPE_CHECKING:
Expand Down Expand Up @@ -56,6 +59,59 @@ def __init__(self, azure_conn_id: str = default_conn_name) -> None:
def connection(self):
return self.get_conn()

def get_conn(self) -> Any:
"""
Authenticates the resource using the connection id passed during init.

:return: the authenticated client.
"""
conn = self.get_connection(self.conn_id)
tenant = conn.extra_dejson.get("tenantId")
if not tenant and conn.extra_dejson.get("extra__azure__tenantId"):
warnings.warn(
"`extra__azure__tenantId` is deprecated in azure connection extra, "
"please use `tenantId` instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
tenant = conn.extra_dejson.get("extra__azure__tenantId")
subscription_id = conn.extra_dejson.get("subscriptionId")
if not subscription_id and conn.extra_dejson.get("extra__azure__subscriptionId"):
warnings.warn(
"`extra__azure__subscriptionId` is deprecated in azure connection extra, "
"please use `subscriptionId` instead",
AirflowProviderDeprecationWarning,
stacklevel=2,
)
subscription_id = conn.extra_dejson.get("extra__azure__subscriptionId")

key_path = conn.extra_dejson.get("key_path")
if key_path:
if not key_path.endswith(".json"):
raise AirflowException("Unrecognised extension for key file.")
self.log.info("Getting connection using a JSON key file.")
return get_client_from_auth_file(client_class=self.sdk_client, auth_path=key_path)

key_json = conn.extra_dejson.get("key_json")
if key_json:
self.log.info("Getting connection using a JSON config.")
return get_client_from_json_dict(client_class=self.sdk_client, config_dict=key_json)

credential: ServicePrincipalCredentials | DefaultAzureCredential
if all([conn.login, conn.password, tenant]):
self.log.info("Getting connection using specific credentials and subscription_id.")
credential = ServicePrincipalCredentials(
client_id=conn.login, secret=conn.password, tenant=tenant
)
else:
self.log.info("Using DefaultAzureCredential as credential")
credential = DefaultAzureCredential()

return ContainerInstanceManagementClient(
credential=credential,
subscription_id=subscription_id,
)

def create_or_update(self, resource_group: str, name: str, container_group: ContainerGroup) -> None:
"""
Create a new container group.
Expand All @@ -64,7 +120,7 @@ def create_or_update(self, resource_group: str, name: str, container_group: Cont
:param name: the name of the container group
:param container_group: the properties of the container group
"""
self.connection.container_groups.create_or_update(resource_group, name, container_group)
self.connection.container_groups.begin_create_or_update(resource_group, name, container_group)

def get_state_exitcode_details(self, resource_group: str, name: str) -> tuple:
"""
Expand Down Expand Up @@ -109,7 +165,7 @@ def get_state(self, resource_group: str, name: str) -> ContainerGroup:
:param name: the name of the container group
:return: ContainerGroup
"""
return self.connection.container_groups.get(resource_group, name, raw=False)
return self.connection.container_groups.get(resource_group, name)

def get_logs(self, resource_group: str, name: str, tail: int = 1000) -> list:
"""
Expand All @@ -120,7 +176,7 @@ def get_logs(self, resource_group: str, name: str, tail: int = 1000) -> list:
:param tail: the size of the tail
:return: A list of log messages
"""
logs = self.connection.container.list_logs(resource_group, name, name, tail=tail)
logs = self.connection.containers.list_logs(resource_group, name, name, tail=tail)
return logs.content.splitlines(True)

def delete(self, resource_group: str, name: str) -> None:
Expand All @@ -130,7 +186,7 @@ def delete(self, resource_group: str, name: str) -> None:
:param resource_group: the name of the resource group
:param name: the name of the container group
"""
self.connection.container_groups.delete(resource_group, name)
self.connection.container_groups.begin_delete(resource_group, name)

def exists(self, resource_group: str, name: str) -> bool:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,11 @@ def __init__(
self,
*,
ci_conn_id: str,
registry_conn_id: str | None,
resource_group: str,
name: str,
image: str,
region: str,
registry_conn_id: str | None = None,
potiuk marked this conversation as resolved.
Show resolved Hide resolved
environment_variables: dict | None = None,
secured_variables: str | None = None,
volumes: list | None = None,
Expand Down Expand Up @@ -295,7 +295,6 @@ def _monitor_logging(self, resource_group: str, name: str) -> int:
try:
cg_state = self._ci_hook.get_state(resource_group, name)
instance_view = cg_state.containers[0].instance_view

# If there is no instance view, we show the provisioning state
if instance_view is not None:
c_state = instance_view.current_state
Expand Down
2 changes: 1 addition & 1 deletion airflow/providers/microsoft/azure/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ dependencies:
- azure-kusto-data>=4.1.0
# TODO: upgrade to newer versions of all the below libraries.
# See issue https://github.com/apache/airflow/issues/30199
- azure-mgmt-containerinstance>=1.5.0,<2.0
- azure-mgmt-containerinstance>=7.0.0,<9.0.0
- azure-mgmt-datafactory>=1.0.0,<2.0

integrations:
Expand Down
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@
"azure-identity>=1.3.1",
"azure-keyvault-secrets>=4.1.0",
"azure-kusto-data>=4.1.0",
"azure-mgmt-containerinstance>=1.5.0,<2.0",
"azure-mgmt-containerinstance>=7.0.0,<9.0.0",
"azure-mgmt-cosmosdb",
"azure-mgmt-datafactory>=1.0.0,<2.0",
"azure-mgmt-datalake-store>=0.5.0",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,17 @@ def setup_test_cases(self, create_mock_connection):
yield

@patch("azure.mgmt.containerinstance.models.ContainerGroup")
@patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.create_or_update")
@patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.begin_create_or_update")
def test_create_or_update(self, create_or_update_mock, container_group_mock):
self.hook.create_or_update("resource_group", "aci-test", container_group_mock)
create_or_update_mock.assert_called_once_with("resource_group", "aci-test", container_group_mock)

@patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.get")
def test_get_state(self, get_state_mock):
self.hook.get_state("resource_group", "aci-test")
get_state_mock.assert_called_once_with("resource_group", "aci-test", raw=False)
get_state_mock.assert_called_once_with("resource_group", "aci-test")

@patch("azure.mgmt.containerinstance.operations.ContainerOperations.list_logs")
@patch("azure.mgmt.containerinstance.operations.ContainersOperations.list_logs")
def test_get_logs(self, list_logs_mock):
expected_messages = ["log line 1\n", "log line 2\n", "log line 3\n"]
logs = Logs(content="".join(expected_messages))
Expand All @@ -72,7 +72,7 @@ def test_get_logs(self, list_logs_mock):

assert logs == expected_messages

@patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.delete")
@patch("azure.mgmt.containerinstance.operations.ContainerGroupsOperations.begin_delete")
def test_delete(self, delete_mock):
self.hook.delete("resource_group", "aci-test")
delete_mock.assert_called_once_with("resource_group", "aci-test")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from unittest.mock import MagicMock

import pytest
from azure.mgmt.containerinstance.models import ContainerState, Event
from azure.mgmt.containerinstance.models import (
Container,
ContainerGroup,
ContainerPropertiesInstanceView,
ContainerState,
Event,
)

from airflow.exceptions import AirflowException
from airflow.providers.microsoft.azure.operators.container_instances import AzureContainerInstancesOperator
Expand All @@ -35,10 +41,12 @@ def make_mock_cg(container_state, events=None):
"""
events = events or []
instance_view_dict = {"current_state": container_state, "events": events}
instance_view = namedtuple("InstanceView", instance_view_dict.keys())(*instance_view_dict.values())
instance_view = namedtuple("ContainerPropertiesInstanceView", instance_view_dict.keys())(
*instance_view_dict.values()
)

container_dict = {"instance_view": instance_view}
container = namedtuple("Container", container_dict.keys())(*container_dict.values())
container = namedtuple("Containers", container_dict.keys())(*container_dict.values())

container_g_dict = {"containers": [container]}
container_g = namedtuple("ContainerGroup", container_g_dict.keys())(*container_g_dict.values())
Expand All @@ -53,23 +61,42 @@ def make_mock_cg_with_missing_events(container_state):
This can happen, when the container group is provisioned, but not started.
"""
instance_view_dict = {"current_state": container_state, "events": None}
instance_view = namedtuple("InstanceView", instance_view_dict.keys())(*instance_view_dict.values())
instance_view = namedtuple("ContainerPropertiesInstanceView", instance_view_dict.keys())(
*instance_view_dict.values()
)

container_dict = {"instance_view": instance_view}
container = namedtuple("Container", container_dict.keys())(*container_dict.values())
container = namedtuple("Containers", container_dict.keys())(*container_dict.values())

container_g_dict = {"containers": [container]}
container_g = namedtuple("ContainerGroup", container_g_dict.keys())(*container_g_dict.values())
return container_g


def make_mock_container(state: str, exit_code: int, detail_status: str, events: Event | None = None):
container = Container(name="hello_world", image="test", resources="test")
container_prop = ContainerPropertiesInstanceView()
container_state = ContainerState()
container_state.state = state
container_state.exit_code = exit_code
container_state.detail_status = detail_status
container_prop.current_state = container_state
if events:
container_prop.events = events
container.instance_view = container_prop

cg = ContainerGroup(containers=[container], os_type="Linux")

return cg


class TestACIOperator:
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute(self, aci_mock):
expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
expected_cg = make_mock_cg(expected_c_state)
expected_cg = make_mock_container(state="Terminated", exit_code=0, detail_status="test")

aci_mock.return_value.get_state.return_value = expected_cg

aci_mock.return_value.exists.return_value = False

aci = AzureContainerInstancesOperator(
Expand Down Expand Up @@ -102,10 +129,10 @@ def test_execute(self, aci_mock):

@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute_with_failures(self, aci_mock):
expected_c_state = ContainerState(state="Terminated", exit_code=1, detail_status="test")
expected_cg = make_mock_cg(expected_c_state)

expected_cg = make_mock_container(state="Terminated", exit_code=1, detail_status="test")
aci_mock.return_value.get_state.return_value = expected_cg

aci_mock.return_value.exists.return_value = False

aci = AzureContainerInstancesOperator(
Expand All @@ -124,11 +151,11 @@ def test_execute_with_failures(self, aci_mock):

@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute_with_tags(self, aci_mock):
expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
expected_cg = make_mock_cg(expected_c_state)
tags = {"testKey": "testValue"}

expected_cg = make_mock_container(state="Terminated", exit_code=0, detail_status="test")
aci_mock.return_value.get_state.return_value = expected_cg
tags = {"testKey": "testValue"}

aci_mock.return_value.exists.return_value = False

aci = AzureContainerInstancesOperator(
Expand Down Expand Up @@ -163,13 +190,18 @@ def test_execute_with_tags(self, aci_mock):

@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute_with_messages_logs(self, aci_mock):
events = [Event(message="test"), Event(message="messages")]
expected_c_state1 = ContainerState(state="Succeeded", exit_code=0, detail_status="test")
expected_cg1 = make_mock_cg(expected_c_state1, events)
expected_c_state2 = ContainerState(state="Running", exit_code=0, detail_status="test")
expected_cg2 = make_mock_cg(expected_c_state2, events)
expected_c_state3 = ContainerState(state="Terminated", exit_code=0, detail_status="test")
expected_cg3 = make_mock_cg(expected_c_state3, events)
event1 = Event()
event1.message = "test"
event2 = Event()
event2.message = "messages"
events = [event1, event2]
expected_cg1 = make_mock_container(
state="Succeeded", exit_code=0, detail_status="test", events=events
)
expected_cg2 = make_mock_container(state="Running", exit_code=0, detail_status="test", events=events)
expected_cg3 = make_mock_container(
state="Terminated", exit_code=0, detail_status="test", events=events
)

aci_mock.return_value.get_state.side_effect = [expected_cg1, expected_cg2, expected_cg3]
aci_mock.return_value.get_logs.return_value = ["test", "logs"]
Expand Down Expand Up @@ -211,11 +243,11 @@ def test_name_checker(self):

@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute_with_ipaddress(self, aci_mock):
expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
expected_cg = make_mock_cg(expected_c_state)
ipaddress = MagicMock()

aci_mock.return_value.get_state.return_value = expected_cg
aci_mock.return_value.get_state.return_value = make_mock_container(
state="Terminated", exit_code=0, detail_status="test"
)
aci_mock.return_value.exists.return_value = False

aci = AzureContainerInstancesOperator(
Expand All @@ -236,10 +268,10 @@ def test_execute_with_ipaddress(self, aci_mock):

@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute_with_windows_os_and_diff_restart_policy(self, aci_mock):
expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
expected_cg = make_mock_cg(expected_c_state)

aci_mock.return_value.get_state.return_value = expected_cg
aci_mock.return_value.get_state.return_value = make_mock_container(
state="Terminated", exit_code=0, detail_status="test"
)
aci_mock.return_value.exists.return_value = False

aci = AzureContainerInstancesOperator(
Expand All @@ -262,10 +294,10 @@ def test_execute_with_windows_os_and_diff_restart_policy(self, aci_mock):

@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute_fails_with_incorrect_os_type(self, aci_mock):
expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
expected_cg = make_mock_cg(expected_c_state)

aci_mock.return_value.get_state.return_value = expected_cg
aci_mock.return_value.get_state.return_value = make_mock_container(
state="Terminated", exit_code=0, detail_status="test"
)
aci_mock.return_value.exists.return_value = False

with pytest.raises(AirflowException) as ctx:
Expand All @@ -288,10 +320,10 @@ def test_execute_fails_with_incorrect_os_type(self, aci_mock):

@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
def test_execute_fails_with_incorrect_restart_policy(self, aci_mock):
expected_c_state = ContainerState(state="Terminated", exit_code=0, detail_status="test")
expected_cg = make_mock_cg(expected_c_state)

aci_mock.return_value.get_state.return_value = expected_cg
aci_mock.return_value.get_state.return_value = make_mock_container(
state="Terminated", exit_code=0, detail_status="test"
)
aci_mock.return_value.exists.return_value = False

with pytest.raises(AirflowException) as ctx:
Expand All @@ -315,10 +347,8 @@ def test_execute_fails_with_incorrect_restart_policy(self, aci_mock):
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.sleep")
def test_execute_correct_sleep_cycle(self, sleep_mock, aci_mock):
expected_c_state1 = ContainerState(state="Running", exit_code=0, detail_status="test")
expected_cg1 = make_mock_cg(expected_c_state1)
expected_c_state2 = ContainerState(state="Terminated", exit_code=0, detail_status="test")
expected_cg2 = make_mock_cg(expected_c_state2)
expected_cg1 = make_mock_container(state="Running", exit_code=0, detail_status="test")
expected_cg2 = make_mock_container(state="Terminated", exit_code=0, detail_status="test")

aci_mock.return_value.get_state.side_effect = [expected_cg1, expected_cg1, expected_cg2]
aci_mock.return_value.exists.return_value = False
Expand All @@ -340,10 +370,8 @@ def test_execute_correct_sleep_cycle(self, sleep_mock, aci_mock):
@mock.patch("airflow.providers.microsoft.azure.operators.container_instances.AzureContainerInstanceHook")
@mock.patch("logging.Logger.exception")
def test_execute_with_missing_events(self, log_mock, aci_mock):
expected_c_state1 = ContainerState(state="Running", exit_code=0, detail_status="test")
expected_cg1 = make_mock_cg_with_missing_events(expected_c_state1)
expected_c_state2 = ContainerState(state="Terminated", exit_code=0, detail_status="test")
expected_cg2 = make_mock_cg(expected_c_state2)
expected_cg1 = make_mock_container(state="Running", exit_code=0, detail_status="test")
expected_cg2 = make_mock_container(state="Terminated", exit_code=0, detail_status="test")

aci_mock.return_value.get_state.side_effect = [expected_cg1, expected_cg2]
aci_mock.return_value.exists.return_value = False
Expand Down