Skip to content

Commit

Permalink
feat: component with default label
Browse files Browse the repository at this point in the history
  • Loading branch information
elliotzh committed Oct 19, 2022
1 parent 86a6d37 commit 6f025e6
Show file tree
Hide file tree
Showing 5 changed files with 716 additions and 9 deletions.
2 changes: 2 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/constants/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@
"AzureFile": "https://{}.file.{}",
}

DEFAULT_LABEL_NAME = "default"
DEFAULT_COMPONENT_VERSION = "azureml_default"
ANONYMOUS_COMPONENT_NAME = "azureml_anonymous"
GIT_PATH_PREFIX = "git+"
SCHEMA_VALIDATION_ERROR_TEMPLATE = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
from azure.ai.ml._utils._azureml_polling import AzureMLPolling
from azure.ai.ml._utils._endpoint_utils import polling_wait
from azure.ai.ml._utils._logger_utils import OpsLogger
from azure.ai.ml.constants._common import AzureMLResourceType, LROConfigurations
from azure.ai.ml.constants._common import AzureMLResourceType, LROConfigurations, DEFAULT_LABEL_NAME, \
DEFAULT_COMPONENT_VERSION
from azure.ai.ml.entities import Component, ValidationResult
from azure.ai.ml.entities._assets import Code
from azure.ai.ml.exceptions import ComponentException, ErrorCategory, ErrorTarget, ValidationException
Expand Down Expand Up @@ -172,17 +173,16 @@ def get(self, name: str, version: Optional[str] = None, label: Optional[str] = N
error_category=ErrorCategory.USER_ERROR,
)

if not version and not label:
label = DEFAULT_LABEL_NAME

if label == DEFAULT_LABEL_NAME:
label = None
version = DEFAULT_COMPONENT_VERSION

if label:
return _resolve_label_to_asset(self, name, label)

if not version:
msg = "Must provide either version or label."
raise ValidationException(
message=msg,
target=ErrorTarget.COMPONENT,
no_personal_data_message=msg,
error_category=ErrorCategory.USER_ERROR,
)
result = (
self._version_operation.get(
name=name,
Expand Down
25 changes: 25 additions & 0 deletions sdk/ml/azure-ai-ml/tests/component/e2etests/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,3 +850,28 @@ def test_create_pipeline_component_from_job(self, client: MLClient, randstr: Cal
component = PipelineComponent(name=name, source_job_id=job.id)
rest_component = client.components.create_or_update(component)
assert rest_component.name == name

def test_component_with_default_label(
self,
client: MLClient,
randstr: Callable[[str], str],
) -> None:
yaml_path: str = "./tests/test_configs/components/helloworld_component.yml"
component_name = randstr("component_name")

create_component(client, component_name, path=yaml_path)

target_component = client.components.get(component_name, label="latest")

for default_component in [
client.components.get(component_name),
client.components.get(component_name, label="default"),
]:
expected_component_dict = target_component._to_dict()
default_component_dict = default_component._to_dict()
assert pydash.omit(default_component_dict, "id") == pydash.omit(expected_component_dict, "id")

assert default_component.id.endswith(f"/components/{component_name}/labels/default")

node = default_component()
assert node._to_rest_object()["componentId"] == default_component.id
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ def test_get(self, mock_component_operation: ComponentOperations) -> None:
assert "version='1'" in create_call_args_str
mock_component_entity._from_rest_object.assert_called_once()

def test_get_default(self, mock_component_operation: ComponentOperations) -> None:
with patch("azure.ai.ml.operations._component_operations.Component") as mock_component_entity:
mock_component_operation.get("mock_component")

mock_component_operation._version_operation.get.assert_called_once()
create_call_args_str = str(mock_component_operation._version_operation.get.call_args)
assert "name='mock_component'" in create_call_args_str
mock_component_entity._from_rest_object.assert_called_once()

def test_archive_version(self, mock_component_operation: ComponentOperations):
name = "random_name"
component = Mock(ComponentVersionData(properties=Mock(ComponentVersionDetails())))
Expand Down
Loading

0 comments on commit 6f025e6

Please sign in to comment.