Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: divide pipeline related tests #26886

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_utils/_arm_id_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,49 @@ def get_datastore_arm_id(datastore_name: str = None, operation_scope: OperationS
)


class AMLLabelledArmId(object):
    """Parser for labelled arm id: e.g. /subscription/.../code/my-code/labels/default.

    After construction every attribute below exists; attributes that could not be
    extracted from ``arm_id`` are left as ``None``:

    * ``subscription_id``, ``resource_group_name``, ``workspace_name``, ``asset_type``
      -- populated only for workspace-scoped ids.
    * ``asset_name``, ``asset_label`` -- populated for any id that parsed.
    * ``is_registry_id`` -- ``True`` when the registry fallback matched, else ``None``.

    :param arm_id: The labelled ARM id. When falsy, nothing is parsed and no error is raised.
    :type arm_id: str
    :raises ~azure.ai.ml.exceptions.ValidationException: Raised if the ARM id is incorrectly formatted.
    """

    # Raw string, and the provider-name dot is escaped so it only matches a literal ".".
    REGEX_PATTERN = (
        r"^/?subscriptions/([^/]+)/resourceGroups/([^/]+)"
        r"/providers/Microsoft\.MachineLearningServices/workspaces/([^/]+)"
        r"/([^/]+)/([^/]+)/labels/([^/]+)"
    )

    def __init__(self, arm_id=None):
        # Define every attribute up front so that attribute reads never raise
        # AttributeError for an empty id or for a registry id.
        self.is_registry_id = None
        self.subscription_id = None
        self.resource_group_name = None
        self.workspace_name = None
        self.asset_type = None
        self.asset_name = None
        self.asset_label = None

        if not arm_id:
            return

        workspace_match = re.match(AMLLabelledArmId.REGEX_PATTERN, arm_id)
        if workspace_match:
            (
                self.subscription_id,
                self.resource_group_name,
                self.workspace_name,
                self.asset_type,
                self.asset_name,
                self.asset_label,
            ) = workspace_match.groups()
            return

        # Fall back to the registry pattern defined elsewhere in this module.
        # NOTE(review): going by its name, REGISTRY_VERSION_PATTERN targets *versioned*
        # registry ids, so group(4) may be a version rather than a label -- confirm.
        registry_match = re.match(REGISTRY_VERSION_PATTERN, arm_id)
        if registry_match:
            self.asset_name = registry_match.group(3)
            self.asset_label = registry_match.group(4)
            self.is_registry_id = True
            return

        msg = "Invalid AzureML ARM labelled Id {}"
        raise ValidationException(
            message=msg.format(arm_id),
            no_personal_data_message=msg.format("[arm_id]"),
            error_type=ValidationErrorType.INVALID_VALUE,
            error_category=ErrorCategory.USER_ERROR,
            target=ErrorTarget.ARM_RESOURCE,
        )


class AMLNamedArmId:
"""Parser for named arm id (no version): e.g.
/subscription/.../compute/cpu-cluster.
Expand Down
4 changes: 4 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
# ARM resource id templates; each "{}" is filled positionally via str.format.
NAMED_RESOURCE_ID_FORMAT = "/subscriptions/{}/resourceGroups/{}/providers/{}/workspaces/{}/{}/{}"
LEVEL_ONE_NAMED_RESOURCE_ID_FORMAT = "/subscriptions/{}/resourceGroups/{}/providers/{}/{}/{}"
VERSIONED_RESOURCE_ID_FORMAT = "/subscriptions/{}/resourceGroups/{}/providers/{}/workspaces/{}/{}/{}/versions/{}"
# Same layout as the versioned template but ending in "/labels/{}" instead of "/versions/{}".
LABELLED_RESOURCE_ID_FORMAT = "/subscriptions/{}/resourceGroups/{}/providers/{}/workspaces/{}/{}/{}/labels/{}"
DATASTORE_RESOURCE_ID = (
"/subscriptions/{}/resourceGroups/{}/providers/Microsoft.MachineLearningServices/workspaces/{}/datastores/{}"
)
Expand All @@ -37,6 +38,7 @@
)
ASSET_ID_FORMAT = "azureml://locations/{}/workspaces/{}/{}/{}/versions/{}"
# Short "name:version" form of a versioned asset reference.
VERSIONED_RESOURCE_NAME = "{}:{}"
# Short "name@label" form of a labelled asset reference.
LABELLED_RESOURCE_NAME = "{}@{}"
PYTHON = "python"
AML_TOKEN_YAML = "aml_token"
AAD_TOKEN_YAML = "aad_token"
Expand Down Expand Up @@ -132,6 +134,8 @@
"AzureFile": "https://{}.file.{}",
}

# Label name that denotes the default version of an asset.
DEFAULT_LABEL_NAME = "default"
# Version string used for a component when the default label is requested.
DEFAULT_COMPONENT_VERSION = "azureml_default"
ANONYMOUS_COMPONENT_NAME = "azureml_anonymous"
GIT_PATH_PREFIX = "git+"
SCHEMA_VALIDATION_ERROR_TEMPLATE = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
from azure.ai.ml._utils._azureml_polling import AzureMLPolling
from azure.ai.ml._utils._endpoint_utils import polling_wait
from azure.ai.ml._utils._logger_utils import OpsLogger
from azure.ai.ml.constants._common import AzureMLResourceType, LROConfigurations
from azure.ai.ml.constants._common import AzureMLResourceType, LROConfigurations, DEFAULT_LABEL_NAME, \
DEFAULT_COMPONENT_VERSION
from azure.ai.ml.entities import Component, ValidationResult
from azure.ai.ml.entities._assets import Code
from azure.ai.ml.exceptions import ComponentException, ErrorCategory, ErrorTarget, ValidationException
Expand Down Expand Up @@ -172,17 +173,16 @@ def get(self, name: str, version: Optional[str] = None, label: Optional[str] = N
error_category=ErrorCategory.USER_ERROR,
)

if not version and not label:
label = DEFAULT_LABEL_NAME

if label == DEFAULT_LABEL_NAME:
label = None
version = DEFAULT_COMPONENT_VERSION

if label:
return _resolve_label_to_asset(self, name, label)

if not version:
msg = "Must provide either version or label."
raise ValidationException(
message=msg,
target=ErrorTarget.COMPONENT,
no_personal_data_message=msg,
error_category=ErrorCategory.USER_ERROR,
)
result = (
self._version_operation.get(
name=name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from azure.ai.ml._utils._arm_id_utils import (
AMLNamedArmId,
AMLVersionedArmId,
AMLLabelledArmId,
get_arm_id_with_version,
is_ARM_id_for_resource,
is_registry_id_for_resource,
Expand All @@ -23,7 +24,6 @@
)
from azure.ai.ml._utils._asset_utils import _resolve_label_to_asset
from azure.ai.ml._utils._storage_utils import AzureMLDatastorePathUri
from azure.ai.ml._utils.utils import is_private_preview_enabled # pylint: disable=unused-import
from azure.ai.ml.constants._common import (
ARM_ID_PREFIX,
AZUREML_RESOURCE_PROVIDER,
Expand All @@ -36,7 +36,10 @@
NAMED_RESOURCE_ID_FORMAT,
VERSIONED_RESOURCE_ID_FORMAT,
VERSIONED_RESOURCE_NAME,
LABELLED_RESOURCE_NAME,
AzureMLResourceType,
LABELLED_RESOURCE_ID_FORMAT,
DEFAULT_LABEL_NAME,
)
from azure.ai.ml.entities import Component
from azure.ai.ml.entities._assets import Code, Data, Environment, Model
Expand Down Expand Up @@ -149,6 +152,19 @@ def get_asset_arm_id(
"CLI and SDK. Learn more at aka.ms/curatedenv"
)
return f"azureml:{asset}"

name, label = parse_name_label(asset)
# TODO: remove this condition after label is fully supported for all versioned resources
if label == DEFAULT_LABEL_NAME and azureml_type == AzureMLResourceType.COMPONENT:
return LABELLED_RESOURCE_ID_FORMAT.format(
self._operation_scope.subscription_id,
self._operation_scope.resource_group_name,
AZUREML_RESOURCE_PROVIDER,
self._operation_scope.workspace_name,
azureml_type,
name,
label,
)
name, version = self._resolve_name_version_from_name_label(asset, azureml_type)
if not version:
name, version = parse_prefixed_name_version(asset)
Expand Down Expand Up @@ -396,6 +412,12 @@ def resolve_azureml_id(self, arm_id: str = None, **kwargs) -> str:
return VERSIONED_RESOURCE_NAME.format(arm_id_obj.asset_name, arm_id_obj.asset_version)
except ValidationException:
pass # fall back to named arm id
try:
arm_id_obj = AMLLabelledArmId(arm_id)
if self._match(arm_id_obj):
return LABELLED_RESOURCE_NAME.format(arm_id_obj.asset_name, arm_id_obj.asset_label)
except ValidationException:
pass # fall back to named arm id
try:
arm_id_obj = AMLNamedArmId(arm_id)
if self._match(arm_id_obj):
Expand Down
25 changes: 25 additions & 0 deletions sdk/ml/azure-ai-ml/tests/component/e2etests/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,3 +850,28 @@ def test_create_pipeline_component_from_job(self, client: MLClient, randstr: Cal
component = PipelineComponent(name=name, source_job_id=job.id)
rest_component = client.components.create_or_update(component)
assert rest_component.name == name

# Creates a component from the hello-world YAML under a random name, then checks that
# fetching it with no version/label and with label="default" gives the same content as
# fetching it with label="latest".
def test_component_with_default_label(
self,
client: MLClient,
randstr: Callable[[str], str],
) -> None:
yaml_path: str = "./tests/test_configs/components/helloworld_component.yml"
component_name = randstr("component_name")

create_component(client, component_name, path=yaml_path)

# Reference object that the default lookups are compared against.
target_component = client.components.get(component_name, label="latest")

# Two ways of asking for the default: omit version/label, or pass label="default".
for default_component in [
client.components.get(component_name),
client.components.get(component_name, label="default"),
]:
expected_component_dict = target_component._to_dict()
default_component_dict = default_component._to_dict()
# "id" is excluded from the comparison; it is asserted separately below.
assert pydash.omit(default_component_dict, "id") == pydash.omit(expected_component_dict, "id")

# The returned id must be the labelled form ending in /labels/default.
assert default_component.id.endswith(f"/components/{component_name}/labels/default")

# A node built by calling the component must carry that same id as its componentId.
node = default_component()
assert node._to_rest_object()["componentId"] == default_component.id
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ def test_get(self, mock_component_operation: ComponentOperations) -> None:
assert "version='1'" in create_call_args_str
mock_component_entity._from_rest_object.assert_called_once()

# Calls get() with only a component name (no version, no label) and checks that the
# version operation's get is invoked exactly once with that name.
def test_get_default(self, mock_component_operation: ComponentOperations) -> None:
# Component is patched so the conversion from the REST object can be observed.
with patch("azure.ai.ml.operations._component_operations.Component") as mock_component_entity:
mock_component_operation.get("mock_component")

mock_component_operation._version_operation.get.assert_called_once()
# The call arguments are inspected through their string form.
create_call_args_str = str(mock_component_operation._version_operation.get.call_args)
assert "name='mock_component'" in create_call_args_str
mock_component_entity._from_rest_object.assert_called_once()

def test_archive_version(self, mock_component_operation: ComponentOperations):
name = "random_name"
component = Mock(ComponentVersionData(properties=Mock(ComponentVersionDetails())))
Expand Down
2 changes: 1 addition & 1 deletion sdk/ml/azure-ai-ml/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def enable_pipeline_private_preview_features(mocker: MockFixture):

# Fixture that patches is_private_preview_enabled to return False.
# NOTE(review): this hunk is reported as "1 addition & 1 deletion", so the two patch
# calls below are presumably the removed and the added line of the diff, in that order --
# confirm which one is present in the actual file.
@pytest.fixture()
def enable_environment_id_arm_expansion(mocker: MockFixture):
mocker.patch("azure.ai.ml.operations._operation_orchestrator.is_private_preview_enabled", return_value=False)
mocker.patch("azure.ai.ml._utils.utils.is_private_preview_enabled", return_value=False)


@pytest.fixture(autouse=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from azure.ai.ml.entities._datastore._on_prem import HdfsDatastore
from azure.ai.ml.entities._credentials import NoneCredentialConfiguration
from azure.ai.ml.entities._datastore.datastore import Datastore
from azure.core.paging import ItemPaged
from azure.mgmt.storage import StorageManagementClient

from devtools_testutils import AzureRecordedTestCase, is_live

Expand Down
22 changes: 22 additions & 0 deletions sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_dsl_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2410,3 +2410,25 @@ def pipeline_with_group(group: ParamClass):
}
assert actual_job["inputs"] == expected_job_inputs
assert actual_job["jobs"]["microsoft_samples_command_component_basic_inputs"]["inputs"] == expected_node_inputs

# Registers a component under a random name, fetches it back with no version/label, uses
# it as a node in a DSL pipeline, and checks that the submitted job references the
# component as "<name>@default".
def test_dsl_pipeline_with_default_component(
self,
client: MLClient,
randstr: Callable[[str], str],
) -> None:
yaml_path: str = "./tests/test_configs/components/helloworld_component.yml"
component_name = randstr("component_name")
# The component name from the YAML is overridden with the random name.
component: Component = load_component(source=yaml_path, params_override=[{"name": component_name}])
client.components.create_or_update(component)

# No version or label is passed here.
default_component_func = client.components.get(component_name)

@dsl.pipeline()
def pipeline_with_default_component():
# NOTE(review): job_input is not defined inside this test; presumably it is
# defined at module level elsewhere in the file -- verify.
node1 = default_component_func(component_in_path=job_input)
node1.compute = "cpu-cluster"

# component from client.components.get
pipeline_job = client.jobs.create_or_update(pipeline_with_default_component())
created_pipeline_job: PipelineJob = client.jobs.get(pipeline_job.name)
# The node's component is expected in the "name@label" short form.
assert created_pipeline_job.jobs["node1"].component == f"{component_name}@default"
Loading