Skip to content

Commit

Permalink
[ML] Add Jobservice subtypes for Ssh, VsCode, TensorBoard, and Jupyte…
Browse files Browse the repository at this point in the history
…rLab (Azure#28397)

* [ML] WIP: Polymorphic Jobservice - TODO: Pending Cleanup/Testing/UTs/E2E
  • Loading branch information
TonyJ1 authored Jan 31, 2023
1 parent 9326f4c commit 0a5a4a2
Show file tree
Hide file tree
Showing 17 changed files with 1,066 additions and 415 deletions.
2 changes: 1 addition & 1 deletion sdk/ml/azure-ai-ml/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
## 1.4.0 (Unreleased)

### Features Added
-
- Add dedicated classes for each type of job service. The classes added are `JupyterLabJobService`, `SshJobService`, `TensorBoardJobService`, and `VsCodeJobService`, each with a few properties specific to its type.

### Bugs Fixed
- Fixed an issue where the ordering of `.amlignore` and `.gitignore` files are not respected
Expand Down
24 changes: 22 additions & 2 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/base_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,34 @@
from azure.ai.ml.constants._common import AzureMLResourceType

from .creation_context import CreationContextSchema
from .services import JobServiceSchema
from .services import (
JobServiceSchema,
SshJobServiceSchema,
VsCodeJobServiceSchema,
TensorBoardJobServiceSchema,
JupyterLabJobServiceSchema,
)

module_logger = logging.getLogger(__name__)


class BaseJobSchema(ResourceSchema):
creation_context = NestedField(CreationContextSchema, dump_only=True)
services = fields.Dict(keys=fields.Str(), values=NestedField(JobServiceSchema))
services = fields.Dict(
keys=fields.Str(),
values=UnionField(
[
NestedField(SshJobServiceSchema),
NestedField(TensorBoardJobServiceSchema),
NestedField(VsCodeJobServiceSchema),
NestedField(JupyterLabJobServiceSchema),
# JobServiceSchema should be the last in the list.
# To support types not set by users like Custom, Tracking, Studio.
NestedField(JobServiceSchema),
],
is_strict=True,
),
)
name = fields.Str()
id = ArmStr(azureml_type=AzureMLResourceType.JOB, dump_only=True, required=False)
display_name = fields.Str(required=False)
Expand Down
80 changes: 72 additions & 8 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/job/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@

from marshmallow import fields, post_load

from azure.ai.ml.entities._job.job_service import JobService
from azure.ai.ml.entities._job.job_service import (
JobService,
SshJobService,
JupyterLabJobService,
VsCodeJobService,
TensorBoardJobService,
)
from azure.ai.ml.constants._job.job import JobServiceTypeNames
from azure.ai.ml._schema.core.fields import StringTransformedEnum, UnionField

Expand All @@ -15,7 +21,20 @@
module_logger = logging.getLogger(__name__)


class JobServiceSchema(PathAwareSchema):
class JobServiceBaseSchema(PathAwareSchema):
    """Common field declarations inherited by the job service schemas.

    This class declares fields only; it defines no ``post_load`` hook, so
    turning loaded data into an entity object is left to subclasses.
    """

    port = fields.Int()
    # dump_only fields are emitted when serializing but are not accepted on load.
    endpoint = fields.Str(dump_only=True)
    status = fields.Str(dump_only=True)
    nodes = fields.Str()
    error_message = fields.Str(dump_only=True)
    properties = fields.Dict()


class JobServiceSchema(JobServiceBaseSchema):
"""This is to support transformation of job services passed as dict type and
internal job services like Custom, Tracking, Studio set by the system.
"""

job_service_type = UnionField(
[
StringTransformedEnum(
Expand All @@ -25,13 +44,58 @@ class JobServiceSchema(PathAwareSchema):
fields.Str(),
]
)
port = fields.Int()
endpoint = fields.Str(dump_only=True)
status = fields.Str(dump_only=True)
nodes = fields.Str()
error_message = fields.Str(dump_only=True)
properties = fields.Dict()

@post_load
def make(self, data, **kwargs): # pylint: disable=unused-argument,no-self-use
data.pop("job_service_type", None)
return JobService(**data)


class TensorBoardJobServiceSchema(JobServiceBaseSchema):
    """Schema for a job service whose type is the TensorBoard entity name.

    Adds a ``log_dir`` string field on top of the inherited base fields and
    loads into a ``TensorBoardJobService`` object.
    """

    # Only the TENSOR_BOARD entity name is accepted for this schema.
    job_service_type = StringTransformedEnum(
        allowed_values=JobServiceTypeNames.EntityNames.TENSOR_BOARD,
        pass_original=True,
    )
    log_dir = fields.Str()

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument,no-self-use
        """Build a ``TensorBoardJobService`` from the loaded field values."""
        # job_service_type is dropped rather than forwarded to the constructor.
        # NOTE(review): presumably the entity class fixes its own type - confirm.
        data.pop("job_service_type", None)
        return TensorBoardJobService(**data)


class SshJobServiceSchema(JobServiceBaseSchema):
    """Schema for a job service whose type is the SSH entity name.

    Adds an ``ssh_public_keys`` string field on top of the inherited base
    fields and loads into an ``SshJobService`` object.
    """

    # Only the SSH entity name is accepted for this schema.
    job_service_type = StringTransformedEnum(
        allowed_values=JobServiceTypeNames.EntityNames.SSH,
        pass_original=True,
    )
    ssh_public_keys = fields.Str()

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument,no-self-use
        """Build an ``SshJobService`` from the loaded field values."""
        # job_service_type is dropped rather than forwarded to the constructor.
        # NOTE(review): presumably the entity class fixes its own type - confirm.
        data.pop("job_service_type", None)
        return SshJobService(**data)


class VsCodeJobServiceSchema(JobServiceBaseSchema):
    """Schema for a job service whose type is the VS Code entity name.

    Declares no extra fields beyond the inherited base fields and loads into
    a ``VsCodeJobService`` object.
    """

    # Only the VS_CODE entity name is accepted for this schema.
    job_service_type = StringTransformedEnum(
        allowed_values=JobServiceTypeNames.EntityNames.VS_CODE,
        pass_original=True,
    )

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument,no-self-use
        """Build a ``VsCodeJobService`` from the loaded field values."""
        # job_service_type is dropped rather than forwarded to the constructor.
        # NOTE(review): presumably the entity class fixes its own type - confirm.
        data.pop("job_service_type", None)
        return VsCodeJobService(**data)


class JupyterLabJobServiceSchema(JobServiceBaseSchema):
    """Schema for a job service whose type is the JupyterLab entity name.

    Declares no extra fields beyond the inherited base fields and loads into
    a ``JupyterLabJobService`` object.
    """

    # Only the JUPYTER_LAB entity name is accepted for this schema.
    job_service_type = StringTransformedEnum(
        allowed_values=JobServiceTypeNames.EntityNames.JUPYTER_LAB,
        pass_original=True,
    )

    @post_load
    def make(self, data, **kwargs):  # pylint: disable=unused-argument,no-self-use
        """Build a ``JupyterLabJobService`` from the loaded field values."""
        # job_service_type is dropped rather than forwarded to the constructor.
        # NOTE(review): presumably the entity class fixes its own type - confirm.
        data.pop("job_service_type", None)
        return JupyterLabJobService(**data)
24 changes: 22 additions & 2 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/pipeline/component_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,13 @@
from ..job import ParameterizedCommandSchema, ParameterizedParallelSchema, ParameterizedSparkSchema
from ..job.job_limits import CommandJobLimitsSchema
from ..job.parameterized_spark import SparkEntryClassSchema, SparkEntryFileSchema
from ..job.services import JobServiceSchema
from ..job.services import (
JobServiceSchema,
SshJobServiceSchema,
JupyterLabJobServiceSchema,
VsCodeJobServiceSchema,
TensorBoardJobServiceSchema,
)

module_logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -152,7 +158,21 @@ class CommandSchema(BaseNodeSchema, ParameterizedCommandSchema):
ArmVersionedStr(azureml_type=AzureMLResourceType.ENVIRONMENT, allow_default_version=True),
],
)
services = fields.Dict(keys=fields.Str(), values=NestedField(JobServiceSchema))
services = fields.Dict(
keys=fields.Str(),
values=UnionField(
[
NestedField(SshJobServiceSchema),
NestedField(JupyterLabJobServiceSchema),
NestedField(TensorBoardJobServiceSchema),
NestedField(VsCodeJobServiceSchema),
# JobServiceSchema should be the last in the list.
# To support types not set by users like Custom, Tracking, Studio.
NestedField(JobServiceSchema),
],
is_strict=True,
),
)
identity = UnionField(
[
NestedField(ManagedIdentitySchema),
Expand Down
6 changes: 5 additions & 1 deletion sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@
from ._job.job import Job
from ._job.job_limits import CommandJobLimits
from ._job.job_resource_configuration import JobResourceConfiguration
from ._job.job_service import JobService
from ._job.job_service import JobService, SshJobService,JupyterLabJobService, TensorBoardJobService, VsCodeJobService
from ._job.parallel.parallel_task import ParallelTask
from ._job.parallel.retry_settings import RetrySettings
from ._job.parameterized_command import ParameterizedCommand
Expand Down Expand Up @@ -147,6 +147,10 @@
"ResourceConfiguration",
"JobResourceConfiguration",
"JobService",
"SshJobService",
"TensorBoardJobService",
"VsCodeJobService",
"JupyterLabJobService",
"SparkResourceConfiguration",
"ParameterizedCommand",
"InputPort",
Expand Down
28 changes: 21 additions & 7 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,14 @@
)
from azure.ai.ml.entities._job.job_limits import CommandJobLimits
from azure.ai.ml.entities._job.job_resource_configuration import JobResourceConfiguration
from azure.ai.ml.entities._job.job_service import JobService
from azure.ai.ml.entities._job.job_service import (
JobServiceBase,
JobService,
JupyterLabJobService,
SshJobService,
TensorBoardJobService,
VsCodeJobService,
)
from azure.ai.ml.entities._job.sweep.early_termination_policy import EarlyTerminationPolicy
from azure.ai.ml.entities._job.sweep.objective import Objective
from azure.ai.ml.entities._job.sweep.search_space import (
Expand Down Expand Up @@ -121,7 +128,8 @@ class Command(BaseNode):
:type identity: Union[ManagedIdentity, AmlToken, UserIdentity]
:param services: Interactive services for the node. This is an experimental parameter, and may change at any time.
Please see https://aka.ms/azuremlexperimental for more information.
:type services: Dict[str, JobService]
:type services:
Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]]
:raises ~azure.ai.ml.exceptions.ValidationException: Raised if Command cannot be successfully validated.
Details will be provided in the error message.
"""
Expand Down Expand Up @@ -154,7 +162,9 @@ def __init__(
environment: Optional[Union[Environment, str]] = None,
environment_variables: Optional[Dict] = None,
resources: Optional[JobResourceConfiguration] = None,
services: Optional[Dict[str, JobService]] = None,
services: Optional[
Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]]
] = None,
**kwargs,
):
# validate init params are valid type
Expand Down Expand Up @@ -548,7 +558,7 @@ def _from_rest_object_to_init_params(cls, obj: dict) -> Dict:
# RestJobService, so we need to convert it back. Here we convert the dict to a
# dummy rest object which may work as a RestJobService instead.
services[service_name] = from_rest_dict_to_dummy_rest_object(service)
obj["services"] = JobService._from_rest_job_services(services)
obj["services"] = JobServiceBase._from_rest_job_services(services)

# handle limits
if "limits" in obj and obj["limits"]:
Expand All @@ -574,7 +584,7 @@ def _load_from_rest_job(cls, obj: JobBase) -> "Command":
properties=rest_command_job.properties,
command=rest_command_job.command,
experiment_name=rest_command_job.experiment_name,
services=JobService._from_rest_job_services(rest_command_job.services),
services=JobServiceBase._from_rest_job_services(rest_command_job.services),
status=rest_command_job.status,
creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
code=rest_command_job.code_id,
Expand Down Expand Up @@ -662,7 +672,9 @@ def __call__(self, *args, **kwargs) -> "Command":
)


def _resolve_job_services(services: dict) -> Dict[str, JobService]:
def _resolve_job_services(
services: dict,
) -> Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]]:
"""Resolve normal dict to dict[str, JobService]"""
if services is None:
return None
Expand All @@ -680,7 +692,9 @@ def _resolve_job_services(services: dict) -> Dict[str, JobService]:
for name, service in services.items():
if isinstance(service, dict):
service = load_from_dict(JobServiceSchema, service, context={BASE_PATH_CONTEXT_KEY: "."})
elif not isinstance(service, JobService):
elif not isinstance(
service, (JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService)
):
msg = f"Service value for key {name!r} must be a dict or JobService object, got {type(service)} instead."
raise ValidationException(
message=msg,
Expand Down
13 changes: 10 additions & 3 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_builders/command_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@
)
from azure.ai.ml.entities._inputs_outputs import Input, Output
from azure.ai.ml.entities._job.distribution import MpiDistribution, PyTorchDistribution, TensorFlowDistribution
from azure.ai.ml.entities._job.job_service import JobService
from azure.ai.ml.entities._job.job_service import (
JobService,
JupyterLabJobService,
SshJobService,
TensorBoardJobService,
VsCodeJobService,
)
from azure.ai.ml.entities._job.pipeline._component_translatable import ComponentTranslatableMixin
from azure.ai.ml.entities._job.sweep.search_space import SweepDistribution
from azure.ai.ml.exceptions import ErrorTarget, ValidationErrorType, ValidationException
Expand Down Expand Up @@ -124,7 +130,9 @@ def command(
code: Optional[Union[str, os.PathLike]] = None,
identity: Optional[Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]] = None,
is_deterministic: bool = True,
services: Optional[Dict[str, JobService]] = None,
services: Optional[
Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]]
] = None,
**kwargs,
) -> Command:
"""Create a Command object which can be used inside dsl.pipeline as a
Expand Down Expand Up @@ -215,7 +223,6 @@ def command(
is_deterministic=is_deterministic,
**kwargs,
)

command_obj = Command(
component=component,
name=name,
Expand Down
17 changes: 13 additions & 4 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_job/command_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,14 @@
validate_inputs_for_command,
)
from azure.ai.ml.entities._job.distribution import DistributionConfiguration
from azure.ai.ml.entities._job.job_service import JobService
from azure.ai.ml.entities._job.job_service import (
JobServiceBase,
JobService,
JupyterLabJobService,
SshJobService,
TensorBoardJobService,
VsCodeJobService,
)
from azure.ai.ml.entities._system_data import SystemData
from azure.ai.ml.entities._util import load_from_dict
from azure.ai.ml.exceptions import ErrorCategory, ErrorTarget, ValidationErrorType, ValidationException
Expand Down Expand Up @@ -101,7 +108,9 @@ def __init__(
identity: Optional[
Union[ManagedIdentityConfiguration, AmlTokenConfiguration, UserIdentityConfiguration]
] = None,
services: Optional[Dict[str, JobService]] = None,
services: Optional[
Dict[str, Union[JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService]]
] = None,
**kwargs,
):
kwargs[TYPE] = JobType.COMMAND
Expand Down Expand Up @@ -163,7 +172,7 @@ def _to_rest_object(self) -> JobBase:
environment_variables=self.environment_variables,
resources=resources._to_rest_object() if resources else None,
limits=self.limits._to_rest_object() if self.limits else None,
services=JobService._to_rest_job_services(self.services),
services=JobServiceBase._to_rest_job_services(self.services),
)
result = JobBase(properties=properties)
result.name = self.name
Expand All @@ -186,7 +195,7 @@ def _load_from_rest(cls, obj: JobBase) -> "CommandJob":
properties=rest_command_job.properties,
command=rest_command_job.command,
experiment_name=rest_command_job.experiment_name,
services=JobService._from_rest_job_services(rest_command_job.services),
services=JobServiceBase._from_rest_job_services(rest_command_job.services),
status=rest_command_job.status,
creation_context=SystemData._from_rest_object(obj.system_data) if obj.system_data else None,
code=rest_command_job.code_id,
Expand Down
Loading

0 comments on commit 0a5a4a2

Please sign in to comment.