Skip to content

Commit

Permalink
Feature: Sagemaker Model Quality Job Definition (#7916)
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshLevyMN authored Aug 2, 2024
1 parent ec76952 commit 210a2a3
Show file tree
Hide file tree
Showing 3 changed files with 830 additions and 0 deletions.
179 changes: 179 additions & 0 deletions moto/sagemaker/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,12 @@
"limit_default": 100,
"unique_attribute": "arn",
},
"list_model_quality_job_definitions": {
"input_token": "next_token",
"limit_key": "max_results",
"limit_default": 100,
"unique_attribute": "arn",
},
}

METRIC_INFO_TYPE = Dict[str, Union[str, int, float, datetime]]
Expand Down Expand Up @@ -2315,6 +2321,76 @@ def summary(self) -> Dict[str, Any]:
}


class ModelQualityJobDefinition(BaseObject):
def __init__(
self,
job_definition_name: str,
model_quality_baseline_config: Optional[Dict[str, Any]],
model_quality_app_specification: Dict[str, Any],
model_quality_job_input: Dict[str, Any],
model_quality_job_output_config: Dict[str, Any],
job_resources: Dict[str, Any],
network_config: Optional[Dict[str, Any]],
role_arn: str,
stopping_condition: Optional[Dict[str, Any]],
tags: Optional[List[Dict[str, str]]],
region_name: str,
account_id: str,
):
self.region_name = region_name
self.account_id = account_id
self.job_definition_name = job_definition_name
if (
job_definition_name
in sagemaker_backends[account_id][region_name].model_quality_job_definitions
):
raise ResourceInUseException(
message=f"Resource Already Exists: Model Quality Job Definition with name {job_definition_name} already exists. Choose a different name."
)
self.model_quality_baseline_config = model_quality_baseline_config
self.model_quality_app_specification = model_quality_app_specification
self.model_quality_job_input = model_quality_job_input
self.model_quality_job_output_config = model_quality_job_output_config
self.job_resources = job_resources
self.network_config = network_config
self.role_arn = role_arn
self.stopping_condition = stopping_condition
self.tags = tags or []
self.arn = arn_formatter(
"model-quality-job-definition",
self.job_definition_name,
account_id,
region_name,
)
self.creation_time = str(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
self.endpoint_name = self.model_quality_job_input["EndpointInput"][
"EndpointName"
]

def describe(self) -> Dict[str, Any]:
return {
"JobDefinitionArn": self.arn,
"JobDefinitionName": self.job_definition_name,
"CreationTime": self.creation_time,
"ModelQualityBaselineConfig": self.model_quality_baseline_config,
"ModelQualityAppSpecification": self.model_quality_app_specification,
"ModelQualityJobInput": self.model_quality_job_input,
"ModelQualityJobOutputConfig": self.model_quality_job_output_config,
"JobResources": self.job_resources,
"NetworkConfig": self.network_config,
"RoleArn": self.role_arn,
"StoppingCondition": self.stopping_condition,
}

def summary(self) -> Dict[str, Any]:
return {
"MonitoringJobDefinitionName": self.job_definition_name,
"MonitoringJobDefinitionArn": self.arn,
"CreationTime": self.creation_time,
"EndpointName": self.endpoint_name,
}


class VpcConfig(BaseObject):
def __init__(self, security_group_ids: List[str], subnets: List[str]):
self.security_group_ids = security_group_ids
Expand Down Expand Up @@ -2716,6 +2792,7 @@ def __init__(self, region_name: str, account_id: str):
str, ModelExplainabilityJobDefinition
] = {}
self.hyper_parameter_tuning_jobs: Dict[str, HyperParameterTuningJob] = {}
self.model_quality_job_definitions: Dict[str, ModelQualityJobDefinition] = {}

@staticmethod
def default_vpc_endpoint_service(
Expand Down Expand Up @@ -2851,6 +2928,7 @@ def _get_resource_from_arn(self, arn: str) -> Any:
"domain": self.domains,
"model-explainability-job-definition": self.model_explainability_job_definitions,
"hyper-parameter-tuning-job": self.hyper_parameter_tuning_jobs,
"model-quality-job-definition": self.model_quality_job_definitions,
}
target_resource, target_name = arn.split(":")[-1].split("/")
try:
Expand Down Expand Up @@ -5258,6 +5336,107 @@ def delete_hyper_parameter_tuning_job(
del self.hyper_parameter_tuning_jobs[hyper_parameter_tuning_job_name]
return

def create_model_quality_job_definition(
self,
job_definition_name: str,
model_quality_baseline_config: Optional[Dict[str, Any]],
model_quality_app_specification: Dict[str, Any],
model_quality_job_input: Dict[str, Any],
model_quality_job_output_config: Dict[str, Any],
job_resources: Dict[str, Any],
network_config: Optional[Dict[str, Any]],
role_arn: str,
stopping_condition: Optional[Dict[str, Any]],
tags: Optional[List[Dict[str, str]]],
) -> str:
model_quality_job_definition = ModelQualityJobDefinition(
job_definition_name=job_definition_name,
model_quality_baseline_config=model_quality_baseline_config,
model_quality_app_specification=model_quality_app_specification,
model_quality_job_input=model_quality_job_input,
model_quality_job_output_config=model_quality_job_output_config,
job_resources=job_resources,
network_config=network_config,
role_arn=role_arn,
stopping_condition=stopping_condition,
region_name=self.region_name,
account_id=self.account_id,
tags=tags,
)
self.model_quality_job_definitions[job_definition_name] = (
model_quality_job_definition
)
return model_quality_job_definition.arn

def describe_model_quality_job_definition(
self, job_definition_name: str
) -> Dict[str, Any]:
if job_definition_name not in self.model_quality_job_definitions:
raise ResourceNotFound(
message=f"Could not find model quality job definition '{job_definition_name}'."
)
return self.model_quality_job_definitions[job_definition_name].describe()

@paginate(pagination_model=PAGINATION_MODEL)
def list_model_quality_job_definitions(
self,
endpoint_name: Optional[str],
sort_by: Optional[str],
sort_order: Optional[str],
name_contains: Optional[str],
creation_time_before: Optional[str],
creation_time_after: Optional[str],
) -> List[ModelQualityJobDefinition]:
model_quality_job_definitions = list(
self.model_quality_job_definitions.values()
)
if endpoint_name:
model_quality_job_definitions = [
i
for i in model_quality_job_definitions
if endpoint_name == i.endpoint_name
]
if name_contains:
model_quality_job_definitions = [
i
for i in model_quality_job_definitions
if name_contains in i.job_definition_name
]
if creation_time_before:
model_quality_job_definitions = [
i
for i in model_quality_job_definitions
if i.creation_time < str(creation_time_before)
]
if creation_time_after:
model_quality_job_definitions = [
i
for i in model_quality_job_definitions
if i.creation_time > str(creation_time_after)
]
reverse = sort_order == "Descending"
if sort_by == "Name":
model_quality_job_definitions = sorted(
model_quality_job_definitions,
key=lambda x: x.job_definition_name,
reverse=reverse,
)
if sort_by == "CreationTime" or sort_by is None:
model_quality_job_definitions = sorted(
model_quality_job_definitions,
key=lambda x: x.creation_time,
reverse=reverse,
)
return model_quality_job_definitions

def delete_model_quality_job_definition(self, job_definition_name: str) -> None:
if job_definition_name not in self.model_quality_job_definitions:
raise ResourceNotFound(
message=f"Could not find model quality job definition '{job_definition_name}'."
)
del self.model_quality_job_definitions[job_definition_name]
return


class FakeExperiment(BaseObject):
def __init__(
Expand Down
67 changes: 67 additions & 0 deletions moto/sagemaker/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,3 +1409,70 @@ def delete_hyper_parameter_tuning_job(self) -> str:
hyper_parameter_tuning_job_name=hyper_parameter_tuning_job_name,
)
return json.dumps(dict())

def create_model_quality_job_definition(self) -> str:
job_definition_name = self._get_param("JobDefinitionName")
model_quality_baseline_config = self._get_param("ModelQualityBaselineConfig")
model_quality_app_specification = self._get_param(
"ModelQualityAppSpecification"
)
model_quality_job_input = self._get_param("ModelQualityJobInput")
model_quality_job_output_config = self._get_param("ModelQualityJobOutputConfig")
job_resources = self._get_param("JobResources")
network_config = self._get_param("NetworkConfig")
role_arn = self._get_param("RoleArn")
stopping_condition = self._get_param("StoppingCondition")
tags = self._get_param("Tags")
job_definition_arn = self.sagemaker_backend.create_model_quality_job_definition(
job_definition_name=job_definition_name,
model_quality_baseline_config=model_quality_baseline_config,
model_quality_app_specification=model_quality_app_specification,
model_quality_job_input=model_quality_job_input,
model_quality_job_output_config=model_quality_job_output_config,
job_resources=job_resources,
network_config=network_config,
role_arn=role_arn,
stopping_condition=stopping_condition,
tags=tags,
)
return json.dumps(dict(JobDefinitionArn=job_definition_arn))

def describe_model_quality_job_definition(self) -> str:
job_definition_name = self._get_param("JobDefinitionName")
description = self.sagemaker_backend.describe_model_quality_job_definition(
job_definition_name=job_definition_name,
)
return json.dumps(description)

def list_model_quality_job_definitions(self) -> str:
endpoint_name = self._get_param("EndpointName")
sort_by = self._get_param("SortBy")
sort_order = self._get_param("SortOrder")
next_token = self._get_param("NextToken")
max_results = self._get_param("MaxResults")
name_contains = self._get_param("NameContains")
creation_time_before = self._get_param("CreationTimeBefore")
creation_time_after = self._get_param("CreationTimeAfter")
job_definitions, next_token = (
self.sagemaker_backend.list_model_quality_job_definitions(
endpoint_name=endpoint_name,
sort_by=sort_by,
sort_order=sort_order,
next_token=next_token,
max_results=max_results,
name_contains=name_contains,
creation_time_before=creation_time_before,
creation_time_after=creation_time_after,
)
)
job_definition_summaries = [x.summary() for x in job_definitions]
return json.dumps(
dict(JobDefinitionSummaries=job_definition_summaries, NextToken=next_token)
)

def delete_model_quality_job_definition(self) -> str:
job_definition_name = self._get_param("JobDefinitionName")
self.sagemaker_backend.delete_model_quality_job_definition(
job_definition_name=job_definition_name,
)
return json.dumps(dict())
Loading

0 comments on commit 210a2a3

Please sign in to comment.