Feature: Sagemaker Hyper Parameter Tuning Job (#7917)
JoshLevyMN authored Aug 2, 2024
1 parent 890d974 commit ac09274
Showing 3 changed files with 1,403 additions and 0 deletions.
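The diffs below wire four SageMaker actions into moto's mock backend: CreateHyperParameterTuningJob, DescribeHyperParameterTuningJob, ListHyperParameterTuningJobs, and DeleteHyperParameterTuningJob. A minimal round-trip sketch of the new mock, assuming moto 5.x's mock_aws decorator; the region, job name, and config values are illustrative and not taken from this commit:

import boto3
from moto import mock_aws

@mock_aws
def test_tuning_job_round_trip():
    # Assumes dummy AWS credentials in the environment, per moto's docs.
    client = boto3.client("sagemaker", region_name="us-east-1")
    # Strategy and ResourceLimits are the members the request model requires
    # inside HyperParameterTuningJobConfig.
    client.create_hyper_parameter_tuning_job(
        HyperParameterTuningJobName="my-tuning-job",
        HyperParameterTuningJobConfig={
            "Strategy": "Bayesian",
            "ResourceLimits": {"MaxParallelTrainingJobs": 1},
        },
    )
    resp = client.describe_hyper_parameter_tuning_job(
        HyperParameterTuningJobName="my-tuning-job"
    )
    # The mock hard-codes every job's status as already finished.
    assert resp["HyperParameterTuningJobStatus"] == "Completed"
    client.delete_hyper_parameter_tuning_job(
        HyperParameterTuningJobName="my-tuning-job"
    )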
259 changes: 259 additions & 0 deletions moto/sagemaker/models.py
@@ -121,6 +121,12 @@
"limit_default": 100,
"unique_attribute": "arn",
},
"list_hyper_parameter_tuning_jobs": {
"input_token": "next_token",
"limit_key": "max_results",
"limit_default": 100,
"unique_attribute": "arn",
},
}

METRIC_INFO_TYPE = Dict[str, Union[str, int, float, datetime]]
@@ -2176,6 +2182,139 @@ def summary(self) -> Dict[str, Any]:
}


class HyperParameterTuningJob(BaseObject):
def __init__(
self,
hyper_parameter_tuning_job_name: str,
hyper_parameter_tuning_job_config: Dict[str, Any],
region_name: str,
account_id: str,
training_job_definition: Optional[Dict[str, Any]],
training_job_definitions: Optional[List[Dict[str, Any]]],
warm_start_config: Optional[Dict[str, Any]],
tags: Optional[List[Dict[str, str]]],
autotune: Optional[Dict[str, Any]],
):
self.hyper_parameter_tuning_job_name = hyper_parameter_tuning_job_name
if (
hyper_parameter_tuning_job_name
in sagemaker_backends[account_id][region_name].hyper_parameter_tuning_jobs
):
raise ResourceInUseException(
message=f"Resource Already Exists: Hyper Parameter Tuning Job with name {hyper_parameter_tuning_job_name} already exists. Choose a different name."
)
self.arn = arn_formatter(
"hyper-parameter-tuning-job",
self.hyper_parameter_tuning_job_name,
account_id,
region_name,
)
self.hyper_parameter_tuning_job_config = hyper_parameter_tuning_job_config
self.region_name = region_name
self.account_id = account_id
self.training_job_definition = training_job_definition
self.training_job_definitions = training_job_definitions
self.hyper_parameter_tuning_job_status = "Completed"
self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.last_modified_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.hyper_parameter_tuning_end_time = datetime.now().strftime(
"%Y-%m-%d %H:%M:%S"
)
self.training_job_status_counters = {
"Completed": 1,
"InProgress": 0,
"RetryableError": 0,
"NonRetryableError": 0,
"Stopped": 0,
}
self.objective_status_counters = {
"Succeeded": 1,
"Pending": 0,
"Failed": 0,
}
self.best_training_job = {
"TrainingJobDefinitionName": "string",
"TrainingJobName": "FakeTrainingJobName",
"TrainingJobArn": "FakeTrainingJobArn",
"TuningJobName": "FakeTuningJobName",
"CreationTime": str(datetime(2024, 1, 1)),
"TrainingStartTime": str(datetime(2024, 1, 1)),
"TrainingEndTime": str(datetime(2024, 1, 1)),
"TrainingJobStatus": "Completed",
"TunedHyperParameters": {"string": "TunedHyperParameters"},
"FailureReason": "string",
"FinalHyperParameterTuningJobObjectiveMetric": {
"Type": "Maximize",
"MetricName": "Accuracy",
"Value": 1,
},
"ObjectiveStatus": "Succeeded",
}
self.overall_best_training_job = {
"TrainingJobDefinitionName": "FakeTrainingJobDefinitionName",
"TrainingJobName": "FakeTrainingJobName",
"TrainingJobArn": "FakeTrainingJobArn",
"TuningJobName": "FakeTuningJobName",
"CreationTime": str(datetime(2024, 1, 1)),
"TrainingStartTime": str(datetime(2024, 1, 1)),
"TrainingEndTime": str(datetime(2024, 1, 1)),
"TrainingJobStatus": "Completed",
"TunedHyperParameters": {"string": "FakeTunedHyperParameters"},
"FailureReason": "FakeFailureReason",
"FinalHyperParameterTuningJobObjectiveMetric": {
"Type": "Maximize",
"MetricName": "Acccuracy",
"Value": 1,
},
"ObjectiveStatus": "Succeeded",
}
self.warm_start_config = warm_start_config
self.failure_reason = ""
self.tuning_job_completion_details = {
"NumberOfTrainingJobsObjectiveNotImproving": 123,
"ConvergenceDetectedTime": str(datetime(2024, 1, 1)),
}
self.consumed_resources = {"RuntimeInSeconds": 123}
self.tags = tags
self.autotune = autotune

def describe(self) -> Dict[str, Any]:
return {
"HyperParameterTuningJobName": self.hyper_parameter_tuning_job_name,
"HyperParameterTuningJobArn": self.arn,
"HyperParameterTuningJobConfig": self.hyper_parameter_tuning_job_config,
"TrainingJobDefinition": self.training_job_definition,
"TrainingJobDefinitions": self.training_job_definitions,
"HyperParameterTuningJobStatus": self.hyper_parameter_tuning_job_status,
"CreationTime": self.creation_time,
"HyperParameterTuningEndTime": self.hyper_parameter_tuning_end_time,
"LastModifiedTime": self.last_modified_time,
"TrainingJobStatusCounters": self.training_job_status_counters,
"ObjectiveStatusCounters": self.objective_status_counters,
"BestTrainingJob": self.best_training_job,
"OverallBestTrainingJob": self.OverallBestTrainingJob,
"WarmStartConfig": self.warm_start_config,
"Autotune": self.autotune,
"FailureReason": self.failure_reason,
"TuningJobCompletionDetails": self.tuning_job_completion_details,
"ConsumedResources": self.consumed_resources,
}

def summary(self) -> Dict[str, Any]:
return {
"HyperParameterTuningJobName": self.hyper_parameter_tuning_job_name,
"HyperParameterTuningJobArn": self.arn,
"HyperParameterTuningJobStatus": self.hyper_parameter_tuning_job_status,
"Strategy": self.hyper_parameter_tuning_job_config["Strategy"],
"CreationTime": self.creation_time,
"HyperParameterTuningEndTime": self.hyper_parameter_tuning_end_time,
"LastModifiedTime": self.last_modified_time,
"TrainingJobStatusCounters": self.training_job_status_counters,
"ObjectiveStatusCounters": self.objective_status_counters,
"ResourceLimits": self.hyper_parameter_tuning_job_config["ResourceLimits"],
}


class VpcConfig(BaseObject):
def __init__(self, security_group_ids: List[str], subnets: List[str]):
self.security_group_ids = security_group_ids
@@ -2576,6 +2715,7 @@ def __init__(self, region_name: str, account_id: str):
self.model_explainability_job_definitions: Dict[
str, ModelExplainabilityJobDefinition
] = {}
self.hyper_parameter_tuning_jobs: Dict[str, HyperParameterTuningJob] = {}

@staticmethod
def default_vpc_endpoint_service(
@@ -2710,6 +2850,7 @@ def _get_resource_from_arn(self, arn: str) -> Any:
"compilation-job": self.compilation_jobs,
"domain": self.domains,
"model-explainability-job-definition": self.model_explainability_job_definitions,
"hyper-parameter-tuning-job": self.hyper_parameter_tuning_jobs,
}
target_resource, target_name = arn.split(":")[-1].split("/")
try:
@@ -4999,6 +5140,124 @@ def delete_model_explainability_job_definition(
del self.model_explainability_job_definitions[job_definition_name]
return

def create_hyper_parameter_tuning_job(
self,
hyper_parameter_tuning_job_name: str,
hyper_parameter_tuning_job_config: Dict[str, Any],
training_job_definition: Optional[Dict[str, Any]],
training_job_definitions: Optional[List[Dict[str, Any]]],
warm_start_config: Optional[Dict[str, Any]],
tags: Optional[List[Dict[str, str]]],
autotune: Optional[Dict[str, Any]],
) -> str:
hyper_parameter_tuning_job = HyperParameterTuningJob(
hyper_parameter_tuning_job_name=hyper_parameter_tuning_job_name,
hyper_parameter_tuning_job_config=hyper_parameter_tuning_job_config,
region_name=self.region_name,
account_id=self.account_id,
training_job_definition=training_job_definition,
training_job_definitions=training_job_definitions,
warm_start_config=warm_start_config,
tags=tags,
autotune=autotune,
)

self.hyper_parameter_tuning_jobs[hyper_parameter_tuning_job_name] = (
hyper_parameter_tuning_job
)
return hyper_parameter_tuning_job.arn

def describe_hyper_parameter_tuning_job(
self, hyper_parameter_tuning_job_name: str
) -> Dict[str, Any]:
if hyper_parameter_tuning_job_name not in self.hyper_parameter_tuning_jobs:
raise ResourceNotFound(
message=f"Could not find hyper parameter tuning job '{hyper_parameter_tuning_job_name}'."
)
return self.hyper_parameter_tuning_jobs[
hyper_parameter_tuning_job_name
].describe()

@paginate(pagination_model=PAGINATION_MODEL)
def list_hyper_parameter_tuning_jobs(
self,
sort_by: Optional[str],
sort_order: Optional[str],
name_contains: Optional[str],
creation_time_after: Optional[str],
creation_time_before: Optional[str],
last_modified_time_after: Optional[str],
last_modified_time_before: Optional[str],
status_equals: Optional[str],
) -> List[HyperParameterTuningJob]:
hyper_parameter_tuning_jobs = list(self.hyper_parameter_tuning_jobs.values())
if name_contains:
hyper_parameter_tuning_jobs = [
i
for i in hyper_parameter_tuning_jobs
if name_contains in i.hyper_parameter_tuning_job_name
]
if status_equals:
hyper_parameter_tuning_jobs = [
i
for i in hyper_parameter_tuning_jobs
if status_equals == i.hyper_parameter_tuning_job_status
]
if creation_time_before:
hyper_parameter_tuning_jobs = [
i
for i in hyper_parameter_tuning_jobs
if i.creation_time < str(creation_time_before)
]
if creation_time_after:
hyper_parameter_tuning_jobs = [
i
for i in hyper_parameter_tuning_jobs
if i.creation_time > str(creation_time_after)
]
if last_modified_time_before:
hyper_parameter_tuning_jobs = [
i
for i in hyper_parameter_tuning_jobs
if i.last_modified_time < str(last_modified_time_before)
]
if last_modified_time_after:
hyper_parameter_tuning_jobs = [
i
for i in hyper_parameter_tuning_jobs
if i.last_modified_time > str(last_modified_time_after)
]
reverse = sort_order == "Descending"
if sort_by == "Name":
hyper_parameter_tuning_jobs = sorted(
hyper_parameter_tuning_jobs,
key=lambda x: x.hyper_parameter_tuning_job_name,
reverse=reverse,
)
elif sort_by == "Status":
hyper_parameter_tuning_jobs = sorted(
hyper_parameter_tuning_jobs,
key=lambda x: x.hyper_parameter_tuning_job_status,
reverse=reverse,
)
else:
hyper_parameter_tuning_jobs = sorted(
hyper_parameter_tuning_jobs,
key=lambda x: x.creation_time,
reverse=reverse,
)
return hyper_parameter_tuning_jobs

def delete_hyper_parameter_tuning_job(
self, hyper_parameter_tuning_job_name: str
) -> None:
if hyper_parameter_tuning_job_name not in self.hyper_parameter_tuning_jobs:
raise ResourceNotFound(
message=f"Could not find hyper parameter tuning job '{hyper_parameter_tuning_job_name}'."
)
del self.hyper_parameter_tuning_jobs[hyper_parameter_tuning_job_name]
return


class FakeExperiment(BaseObject):
def __init__(
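Before the response layer, a hedged sketch of the list filtering and sorting the backend methods above implement; the job names are illustrative:

import boto3
from moto import mock_aws

@mock_aws
def test_list_filters_and_sorting():
    client = boto3.client("sagemaker", region_name="us-east-1")
    for name in ("alpha-job", "beta-job"):
        client.create_hyper_parameter_tuning_job(
            HyperParameterTuningJobName=name,
            HyperParameterTuningJobConfig={
                "Strategy": "Bayesian",
                "ResourceLimits": {"MaxParallelTrainingJobs": 1},
            },
        )
    # NameContains is a substring match; StatusEquals compares against the
    # mock's hard-coded "Completed" status; SortBy/SortOrder pick the key
    # and direction, so both jobs pass and come back name-descending.
    resp = client.list_hyper_parameter_tuning_jobs(
        NameContains="job",
        StatusEquals="Completed",
        SortBy="Name",
        SortOrder="Descending",
    )
    names = [
        s["HyperParameterTuningJobName"]
        for s in resp["HyperParameterTuningJobSummaries"]
    ]
    assert names == ["beta-job", "alpha-job"]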
76 changes: 76 additions & 0 deletions moto/sagemaker/responses.py
@@ -1333,3 +1333,79 @@ def delete_model_explainability_job_definition(self) -> str:
job_definition_name=job_definition_name,
)
return json.dumps(dict())

def create_hyper_parameter_tuning_job(self) -> str:
hyper_parameter_tuning_job_name = self._get_param("HyperParameterTuningJobName")
hyper_parameter_tuning_job_config = self._get_param(
"HyperParameterTuningJobConfig"
)
training_job_definition = self._get_param("TrainingJobDefinition")
training_job_definitions = self._get_param("TrainingJobDefinitions")
warm_start_config = self._get_param("WarmStartConfig")
tags = self._get_param("Tags")
autotune = self._get_param("Autotune")
hyper_parameter_tuning_job_arn = (
self.sagemaker_backend.create_hyper_parameter_tuning_job(
hyper_parameter_tuning_job_name=hyper_parameter_tuning_job_name,
hyper_parameter_tuning_job_config=hyper_parameter_tuning_job_config,
training_job_definition=training_job_definition,
training_job_definitions=training_job_definitions,
warm_start_config=warm_start_config,
tags=tags,
autotune=autotune,
)
)
return json.dumps(
dict(HyperParameterTuningJobArn=hyper_parameter_tuning_job_arn)
)

def describe_hyper_parameter_tuning_job(self) -> str:
hyper_parameter_tuning_job_name = self._get_param("HyperParameterTuningJobName")
hyper_parameter_tuning_job_description = (
self.sagemaker_backend.describe_hyper_parameter_tuning_job(
hyper_parameter_tuning_job_name=hyper_parameter_tuning_job_name,
)
)
return json.dumps(hyper_parameter_tuning_job_description)

def list_hyper_parameter_tuning_jobs(self) -> str:
next_token = self._get_param("NextToken")
max_results = self._get_param("MaxResults")
sort_by = self._get_param("SortBy")
sort_order = self._get_param("SortOrder")
name_contains = self._get_param("NameContains")
creation_time_after = self._get_param("CreationTimeAfter")
creation_time_before = self._get_param("CreationTimeBefore")
last_modified_time_after = self._get_param("LastModifiedTimeAfter")
last_modified_time_before = self._get_param("LastModifiedTimeBefore")
status_equals = self._get_param("StatusEquals")
hyper_parameter_tuning_jobs, next_token = (
self.sagemaker_backend.list_hyper_parameter_tuning_jobs(
next_token=next_token,
max_results=max_results,
sort_by=sort_by,
sort_order=sort_order,
name_contains=name_contains,
creation_time_after=creation_time_after,
creation_time_before=creation_time_before,
last_modified_time_after=last_modified_time_after,
last_modified_time_before=last_modified_time_before,
status_equals=status_equals,
)
)
hyper_parameter_tuning_job_summaries = [
job.summary() for job in hyper_parameter_tuning_jobs
]
return json.dumps(
dict(
HyperParameterTuningJobSummaries=hyper_parameter_tuning_job_summaries,
NextToken=next_token,
)
)

def delete_hyper_parameter_tuning_job(self) -> str:
hyper_parameter_tuning_job_name = self._get_param("HyperParameterTuningJobName")
self.sagemaker_backend.delete_hyper_parameter_tuning_job(
hyper_parameter_tuning_job_name=hyper_parameter_tuning_job_name,
)
return json.dumps(dict())
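Finally, a sketch of the pagination path: the handler above threads NextToken and MaxResults into the @paginate decorator registered in models.py (default limit 100), so a client can page manually. The names are again illustrative:

import boto3
from moto import mock_aws

@mock_aws
def test_list_pagination():
    client = boto3.client("sagemaker", region_name="us-east-1")
    for i in range(3):
        client.create_hyper_parameter_tuning_job(
            HyperParameterTuningJobName=f"job-{i}",
            HyperParameterTuningJobConfig={
                "Strategy": "Bayesian",
                "ResourceLimits": {"MaxParallelTrainingJobs": 1},
            },
        )
    # Page one summary at a time until the backend stops returning a token.
    summaries = []
    kwargs = {"MaxResults": 1}
    while True:
        page = client.list_hyper_parameter_tuning_jobs(**kwargs)
        summaries.extend(page["HyperParameterTuningJobSummaries"])
        token = page.get("NextToken")
        if not token:
            break
        kwargs["NextToken"] = token
    assert len(summaries) == 3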
