Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AutoML -> AutoMLV2 mapper #4500

Merged
merged 2 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion src/sagemaker/automl/automlv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class for SageMaker AutoML V2 Jobs."""
from __future__ import absolute_import

from __future__ import absolute_import, annotations

import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

from sagemaker import Model, PipelineModel, s3
from sagemaker.automl.automl import AutoML
from sagemaker.automl.candidate_estimator import CandidateEstimator
from sagemaker.config import (
AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH,
Expand Down Expand Up @@ -727,6 +729,46 @@ def __init__(
self._auto_ml_job_desc = None
self._best_candidate = None

@classmethod
def from_auto_ml(cls, auto_ml: AutoML) -> AutoMLV2:
"""Create an AutoMLV2 object from an AutoML object.

This method maps AutoML properties into an AutoMLV2 object,
so you can create AutoMLV2 jobs from the existing AutoML objects.

Args:
auto_ml (sagemaker.automl.automl.AutoML): An AutoML object from which
an AutoMLV2 object will be created.
"""
auto_ml_v2 = AutoMLV2(
problem_config=AutoMLTabularConfig(
target_attribute_name=auto_ml.target_attribute_name,
feature_specification_s3_uri=auto_ml.feature_specification_s3_uri,
generate_candidate_definitions_only=auto_ml.generate_candidate_definitions_only,
mode=auto_ml.mode,
problem_type=auto_ml.problem_type,
sample_weight_attribute_name=auto_ml.sample_weight_attribute_name,
max_candidates=auto_ml.max_candidate,
max_runtime_per_training_job_in_seconds=auto_ml.max_runtime_per_training_job_in_seconds, # noqa E501 # pylint: disable=c0301
max_total_job_runtime_in_seconds=auto_ml.total_job_runtime_in_seconds,
),
base_job_name=auto_ml.base_job_name,
output_path=auto_ml.output_path,
output_kms_key=auto_ml.output_kms_key,
job_objective=auto_ml.job_objective,
validation_fraction=auto_ml.validation_fraction,
auto_generate_endpoint_name=auto_ml.auto_generate_endpoint_name,
endpoint_name=auto_ml.endpoint_name,
role=auto_ml.role,
volume_kms_key=auto_ml.volume_kms_key,
encrypt_inter_container_traffic=auto_ml.encrypt_inter_container_traffic,
vpc_config=auto_ml.vpc_config,
tags=auto_ml.tags,
sagemaker_session=auto_ml.sagemaker_session,
)
auto_ml_v2._best_candidate = auto_ml._best_candidate
return auto_ml_v2

def fit(
self,
inputs: Optional[
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/sagemaker/automl/test_auto_ml_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
CandidateEstimator,
LocalAutoMLDataChannel,
PipelineModel,
AutoML,
)
from sagemaker.predictor import Predictor
from sagemaker.session_settings import SessionSettings
Expand Down Expand Up @@ -1100,3 +1101,29 @@ def without_user_input(sess):
expected__with_user_input__with_default_bucket_only="s3://test",
)
assert actual == expected


def test_automl_v1_to_automl_v2_mapping():
auto_ml = AutoML(
role=ROLE,
target_attribute_name=TARGET_ATTRIBUTE_NAME,
sample_weight_attribute_name=SAMPLE_WEIGHT_ATTRIBUTE_NAME,
output_kms_key=OUTPUT_KMS_KEY,
output_path=OUTPUT_PATH,
max_candidates=MAX_CANDIDATES,
base_job_name=BASE_JOB_NAME,
)

auto_ml_v2 = AutoMLV2.from_auto_ml(auto_ml=auto_ml)

assert isinstance(auto_ml_v2.problem_config, AutoMLTabularConfig)
assert auto_ml_v2.role == auto_ml.role
assert auto_ml_v2.problem_config.target_attribute_name == auto_ml.target_attribute_name
assert (
auto_ml_v2.problem_config.sample_weight_attribute_name
== auto_ml.sample_weight_attribute_name
)
assert auto_ml_v2.output_kms_key == auto_ml.output_kms_key
assert auto_ml_v2.output_path == auto_ml.output_path
assert auto_ml_v2.problem_config.max_candidates == auto_ml.max_candidate
assert auto_ml_v2.base_job_name == auto_ml.base_job_name