Skip to content

Commit

Permalink
Add AutoML -> AutoMLV2 mapper (aws#4500)
Browse files Browse the repository at this point in the history
Co-authored-by: liujiaor <128006184+liujiaorr@users.noreply.github.com>
  • Loading branch information
repushko and liujiaorr authored Mar 14, 2024
1 parent fada4bf commit 8d22789
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 1 deletion.
44 changes: 43 additions & 1 deletion src/sagemaker/automl/automlv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class for SageMaker AutoML V2 Jobs."""
from __future__ import absolute_import

from __future__ import absolute_import, annotations

import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from sagemaker import Model, PipelineModel, s3
from sagemaker.automl.automl import AutoML
from sagemaker.automl.candidate_estimator import CandidateEstimator
from sagemaker.config import (
AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH,
Expand Down Expand Up @@ -727,6 +729,46 @@ def __init__(
self._auto_ml_job_desc = None
self._best_candidate = None

@classmethod
def from_auto_ml(cls, auto_ml: AutoML) -> AutoMLV2:
    """Create an AutoMLV2 object from an AutoML object.

    This method maps AutoML properties into an AutoMLV2 object,
    so you can create AutoMLV2 jobs from the existing AutoML objects.
    The AutoML settings that describe the tabular problem are wrapped in an
    ``AutoMLTabularConfig``; the remaining job-level settings are passed
    straight to the constructor.

    Args:
        auto_ml (sagemaker.automl.automl.AutoML): An AutoML object from which
            an AutoMLV2 object will be created.

    Returns:
        AutoMLV2: A new object configured from ``auto_ml``. It is built through
        ``cls``, so calling this on a subclass yields an instance of that
        subclass. The source object's cached ``_best_candidate`` is copied over.
    """
    # Problem-specific settings live on the tabular problem config in V2.
    problem_config = AutoMLTabularConfig(
        target_attribute_name=auto_ml.target_attribute_name,
        feature_specification_s3_uri=auto_ml.feature_specification_s3_uri,
        generate_candidate_definitions_only=auto_ml.generate_candidate_definitions_only,
        mode=auto_ml.mode,
        problem_type=auto_ml.problem_type,
        sample_weight_attribute_name=auto_ml.sample_weight_attribute_name,
        # AutoML keeps this value under the singular name ``max_candidate``.
        max_candidates=auto_ml.max_candidate,
        max_runtime_per_training_job_in_seconds=auto_ml.max_runtime_per_training_job_in_seconds,  # noqa E501 # pylint: disable=c0301
        max_total_job_runtime_in_seconds=auto_ml.total_job_runtime_in_seconds,
    )

    # Use ``cls`` rather than a hard-coded class name so that subclasses
    # get an instance of their own type from this alternate constructor.
    auto_ml_v2 = cls(
        problem_config=problem_config,
        base_job_name=auto_ml.base_job_name,
        output_path=auto_ml.output_path,
        output_kms_key=auto_ml.output_kms_key,
        job_objective=auto_ml.job_objective,
        validation_fraction=auto_ml.validation_fraction,
        auto_generate_endpoint_name=auto_ml.auto_generate_endpoint_name,
        endpoint_name=auto_ml.endpoint_name,
        role=auto_ml.role,
        volume_kms_key=auto_ml.volume_kms_key,
        encrypt_inter_container_traffic=auto_ml.encrypt_inter_container_traffic,
        vpc_config=auto_ml.vpc_config,
        tags=auto_ml.tags,
        sagemaker_session=auto_ml.sagemaker_session,
    )
    # Carry over any best candidate already cached on the source object.
    auto_ml_v2._best_candidate = auto_ml._best_candidate
    return auto_ml_v2

def fit(
self,
inputs: Optional[
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/sagemaker/automl/test_auto_ml_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
CandidateEstimator,
LocalAutoMLDataChannel,
PipelineModel,
AutoML,
)
from sagemaker.predictor import Predictor
from sagemaker.session_settings import SessionSettings
Expand Down Expand Up @@ -1100,3 +1101,29 @@ def without_user_input(sess):
expected__with_user_input__with_default_bucket_only="s3://test",
)
assert actual == expected


def test_automl_v1_to_automl_v2_mapping():
    """Check that AutoMLV2.from_auto_ml copies the AutoML settings across."""
    v1_job = AutoML(
        role=ROLE,
        target_attribute_name=TARGET_ATTRIBUTE_NAME,
        sample_weight_attribute_name=SAMPLE_WEIGHT_ATTRIBUTE_NAME,
        output_kms_key=OUTPUT_KMS_KEY,
        output_path=OUTPUT_PATH,
        max_candidates=MAX_CANDIDATES,
        base_job_name=BASE_JOB_NAME,
    )

    v2_job = AutoMLV2.from_auto_ml(auto_ml=v1_job)
    tabular_config = v2_job.problem_config

    # The V1 problem settings must end up on a tabular problem config.
    assert isinstance(tabular_config, AutoMLTabularConfig)
    assert v2_job.role == v1_job.role
    assert tabular_config.target_attribute_name == v1_job.target_attribute_name
    assert tabular_config.sample_weight_attribute_name == v1_job.sample_weight_attribute_name
    assert v2_job.output_kms_key == v1_job.output_kms_key
    assert v2_job.output_path == v1_job.output_path
    # AutoML stores the candidate limit under the singular ``max_candidate``.
    assert tabular_config.max_candidates == v1_job.max_candidate
    assert v2_job.base_job_name == v1_job.base_job_name

0 comments on commit 8d22789

Please sign in to comment.