diff --git a/src/sagemaker/automl/automlv2.py b/src/sagemaker/automl/automlv2.py index c855414f0b..8b34f54a95 100644 --- a/src/sagemaker/automl/automlv2.py +++ b/src/sagemaker/automl/automlv2.py @@ -11,13 +11,15 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """A class for SageMaker AutoML V2 Jobs.""" -from __future__ import absolute_import + +from __future__ import absolute_import, annotations import logging from dataclasses import dataclass from typing import Dict, List, Optional, Union from sagemaker import Model, PipelineModel, s3 +from sagemaker.automl.automl import AutoML from sagemaker.automl.candidate_estimator import CandidateEstimator from sagemaker.config import ( AUTO_ML_INTER_CONTAINER_ENCRYPTION_PATH, @@ -727,6 +729,46 @@ def __init__( self._auto_ml_job_desc = None self._best_candidate = None + @classmethod + def from_auto_ml(cls, auto_ml: AutoML) -> AutoMLV2: + """Create an AutoMLV2 object from an AutoML object. + + This method maps AutoML properties into an AutoMLV2 object, + so you can create AutoMLV2 jobs from the existing AutoML objects. + + Args: + auto_ml (sagemaker.automl.automl.AutoML): An AutoML object from which + an AutoMLV2 object will be created. + """ + auto_ml_v2 = AutoMLV2( + problem_config=AutoMLTabularConfig( + target_attribute_name=auto_ml.target_attribute_name, + feature_specification_s3_uri=auto_ml.feature_specification_s3_uri, + generate_candidate_definitions_only=auto_ml.generate_candidate_definitions_only, + mode=auto_ml.mode, + problem_type=auto_ml.problem_type, + sample_weight_attribute_name=auto_ml.sample_weight_attribute_name, + max_candidates=auto_ml.max_candidate, + max_runtime_per_training_job_in_seconds=auto_ml.max_runtime_per_training_job_in_seconds, # noqa E501 # pylint: disable=c0301 + max_total_job_runtime_in_seconds=auto_ml.total_job_runtime_in_seconds, + ), + base_job_name=auto_ml.base_job_name, + output_path=auto_ml.output_path, + output_kms_key=auto_ml.output_kms_key, + job_objective=auto_ml.job_objective, + validation_fraction=auto_ml.validation_fraction, + auto_generate_endpoint_name=auto_ml.auto_generate_endpoint_name, + endpoint_name=auto_ml.endpoint_name, + role=auto_ml.role, + volume_kms_key=auto_ml.volume_kms_key, + encrypt_inter_container_traffic=auto_ml.encrypt_inter_container_traffic, + vpc_config=auto_ml.vpc_config, + tags=auto_ml.tags, + sagemaker_session=auto_ml.sagemaker_session, + ) + auto_ml_v2._best_candidate = auto_ml._best_candidate + return auto_ml_v2 + def fit( self, inputs: Optional[ diff --git a/tests/unit/sagemaker/automl/test_auto_ml_v2.py b/tests/unit/sagemaker/automl/test_auto_ml_v2.py index 94d87b0a8e..3b1bfa76ed 100644 --- a/tests/unit/sagemaker/automl/test_auto_ml_v2.py +++ b/tests/unit/sagemaker/automl/test_auto_ml_v2.py @@ -24,6 +24,7 @@ CandidateEstimator, LocalAutoMLDataChannel, PipelineModel, + AutoML, ) from sagemaker.predictor import Predictor from sagemaker.session_settings import SessionSettings @@ -1100,3 +1101,29 @@ def without_user_input(sess): expected__with_user_input__with_default_bucket_only="s3://test", ) assert actual == expected + + +def test_automl_v1_to_automl_v2_mapping(): + auto_ml = AutoML( + role=ROLE, + target_attribute_name=TARGET_ATTRIBUTE_NAME, + sample_weight_attribute_name=SAMPLE_WEIGHT_ATTRIBUTE_NAME, + output_kms_key=OUTPUT_KMS_KEY, + output_path=OUTPUT_PATH, + max_candidates=MAX_CANDIDATES, + base_job_name=BASE_JOB_NAME, + ) + + auto_ml_v2 = AutoMLV2.from_auto_ml(auto_ml=auto_ml) + + assert isinstance(auto_ml_v2.problem_config, AutoMLTabularConfig) + assert auto_ml_v2.role == auto_ml.role + assert auto_ml_v2.problem_config.target_attribute_name == auto_ml.target_attribute_name + assert ( + auto_ml_v2.problem_config.sample_weight_attribute_name + == auto_ml.sample_weight_attribute_name + ) + assert auto_ml_v2.output_kms_key == auto_ml.output_kms_key + assert auto_ml_v2.output_path == auto_ml.output_path + assert auto_ml_v2.problem_config.max_candidates == auto_ml.max_candidate + assert auto_ml_v2.base_job_name == auto_ml.base_job_name