diff --git a/setup.py b/setup.py
index cbdc5cdfc6..7db13c56ad 100644
--- a/setup.py
+++ b/setup.py
@@ -58,6 +58,7 @@ def read_requirements(filename):
     "packaging>=20.0",
     "pandas",
     "pathos",
+    "pydantic>=2.0",
     "schema",
     "PyYAML~=6.0",
     "jsonschema",
diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py
index 00b01d2156..a722f1d76e 100644
--- a/src/sagemaker/estimator.py
+++ b/src/sagemaker/estimator.py
@@ -98,6 +98,7 @@
     to_string,
     check_and_get_run_experiment_config,
     resolve_value_from_config,
+    validate_call_inputs,
 )
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.entities import PipelineVariable
@@ -130,6 +131,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):  # pylint: disable=too-man
     CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz"
     JOB_CLASS_NAME = "training-job"
 
+    @validate_call_inputs
     def __init__(
         self,
         role: str = None,
@@ -150,7 +152,7 @@ def __init__(
         model_uri: Optional[str] = None,
         model_channel_name: Union[str, PipelineVariable] = "model",
         metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
-        encrypt_inter_container_traffic: Union[bool, PipelineVariable] = None,
+        encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None,
         use_spot_instances: Union[bool, PipelineVariable] = False,
         max_wait: Optional[Union[int, PipelineVariable]] = None,
         checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
@@ -159,7 +161,7 @@ def __init__(
         debugger_hook_config: Optional[Union[bool, DebuggerHookConfig]] = None,
         tensorboard_output_config: Optional[TensorBoardOutputConfig] = None,
         enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
-        enable_network_isolation: Union[bool, PipelineVariable] = None,
+        enable_network_isolation: Optional[Union[bool, PipelineVariable]] = None,
         profiler_config: Optional[ProfilerConfig] = None,
         disable_profiler: bool = None,
         environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
@@ -2656,7 +2658,7 @@ def __init__(
         model_uri: Optional[str] = None,
         model_channel_name: Union[str, PipelineVariable] = "model",
         metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
-        encrypt_inter_container_traffic: Union[bool, PipelineVariable] = None,
+        encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None,
         use_spot_instances: Union[bool, PipelineVariable] = False,
         max_wait: Optional[Union[int, PipelineVariable]] = None,
         checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py
index d127a2a2d6..feb40669e8 100644
--- a/src/sagemaker/pytorch/estimator.py
+++ b/src/sagemaker/pytorch/estimator.py
@@ -14,7 +14,8 @@
 from __future__ import absolute_import
 
 import logging
-from typing import Union, Optional, Dict
+from numbers import Number
+from typing import Union, Optional, Dict, List
 
 from packaging.version import Version
 
@@ -32,6 +33,7 @@
 from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig
 from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
 from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.utils import validate_call_inputs
 
 logger = logging.getLogger("sagemaker")
 
@@ -44,13 +46,14 @@ class PyTorch(Framework):
     LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
     INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"
 
+    @validate_call_inputs
     def __init__(
         self,
         entry_point: Union[str, PipelineVariable],
         framework_version: Optional[str] = None,
         py_version: Optional[str] = None,
         source_dir: Optional[Union[str, PipelineVariable]] = None,
-        hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+        hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None,
         image_uri: Optional[Union[str, PipelineVariable]] = None,
         distribution: Optional[Dict] = None,
         compiler_config: Optional[TrainingCompilerConfig] = None,
@@ -354,14 +357,15 @@ def hyperparameters(self):
 
         return hyperparameters
 
+    @validate_call_inputs
     def create_model(
         self,
-        model_server_workers=None,
-        role=None,
-        vpc_config_override=VPC_CONFIG_DEFAULT,
-        entry_point=None,
-        source_dir=None,
-        dependencies=None,
+        model_server_workers: Optional[int] = None,
+        role: Optional[str] = None,
+        vpc_config_override: Optional[Dict[str, List[str]]] = VPC_CONFIG_DEFAULT,
+        entry_point: Optional[str] = None,
+        source_dir: Optional[str] = None,
+        dependencies: Optional[List[str]] = None,
         **kwargs,
     ):
         """Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.
diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py
index f7ab96e128..4d11aaa7ba 100644
--- a/src/sagemaker/pytorch/model.py
+++ b/src/sagemaker/pytorch/model.py
@@ -20,7 +20,7 @@
 import sagemaker
 from sagemaker import image_uris, ModelMetrics
-from sagemaker.deserializers import NumpyDeserializer
+from sagemaker.deserializers import BaseDeserializer, NumpyDeserializer
 from sagemaker.drift_check_baselines import DriftCheckBaselines
 from sagemaker.fw_utils import (
     model_code_key_prefix,
@@ -31,8 +31,10 @@
 from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
 from sagemaker.pytorch import defaults
 from sagemaker.predictor import Predictor
-from sagemaker.serializers import NumpySerializer
-from sagemaker.utils import to_string
+from sagemaker.serializers import BaseSerializer, NumpySerializer
+from sagemaker.session import Session
+from sagemaker.serverless import ServerlessInferenceConfig
+from sagemaker.utils import to_string, validate_call_inputs
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.entities import PipelineVariable
 
@@ -46,12 +48,13 @@ class PyTorchPredictor(Predictor):
     multidimensional tensors for PyTorch inference.
     """
 
+    @validate_call_inputs
     def __init__(
         self,
-        endpoint_name,
-        sagemaker_session=None,
-        serializer=NumpySerializer(),
-        deserializer=NumpyDeserializer(),
+        endpoint_name: str,
+        sagemaker_session: Optional[Session] = None,
+        serializer: BaseSerializer = NumpySerializer(),
+        deserializer: BaseDeserializer = NumpyDeserializer(),
     ):
         """Initialize an ``PyTorchPredictor``.
 
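Not part of the patch, but useful context for the decorator now applied to these constructors: `@validate_call_inputs` is a thin wrapper over `pydantic.validate_call` (see the `sagemaker/utils.py` hunk further down). A minimal standalone sketch of what it buys a signature like `PyTorchPredictor.__init__`, assuming pydantic v2; `FakeSession` and `make_predictor` are hypothetical stand-ins, not SageMaker APIs:

```python
# Minimal sketch, not part of the patch. Assumes pydantic v2.
from typing import Optional

from pydantic import ConfigDict, ValidationError, validate_call


class FakeSession:
    """Hypothetical stand-in for sagemaker.session.Session."""


@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def make_predictor(endpoint_name: str, sagemaker_session: Optional[FakeSession] = None) -> str:
    # arbitrary_types_allowed lets plain classes such as FakeSession appear in the
    # annotations; pydantic falls back to an isinstance() check for them.
    return endpoint_name


print(make_predictor("my-endpoint"))   # passes validation
try:
    make_predictor(endpoint_name=42)   # an int is not a str
except ValidationError as err:
    print(err.errors()[0]["type"])     # 'string_type'
```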
@@ -82,12 +85,13 @@ class PyTorchModel(FrameworkModel):
     _framework_name = "pytorch"
     _LOWEST_MMS_VERSION = "1.2"
 
+    @validate_call_inputs
     def __init__(
         self,
         model_data: Union[str, PipelineVariable],
         role: Optional[str] = None,
         entry_point: Optional[str] = None,
-        framework_version: str = "1.3",
+        framework_version: Optional[str] = "1.3",
         py_version: Optional[str] = None,
         image_uri: Optional[Union[str, PipelineVariable]] = None,
         predictor_cls: callable = PyTorchPredictor,
@@ -150,6 +154,7 @@ def __init__(
 
         self.model_server_workers = model_server_workers
 
+    @validate_call_inputs
     def register(
         self,
         content_types: List[Union[str, PipelineVariable]] = None,
@@ -264,8 +269,12 @@ def register(
             skip_model_validation=skip_model_validation,
         )
 
+    @validate_call_inputs
     def prepare_container_def(
-        self, instance_type=None, accelerator_type=None, serverless_inference_config=None
+        self,
+        instance_type: Optional[str] = None,
+        accelerator_type: Optional[str] = None,
+        serverless_inference_config: Optional[ServerlessInferenceConfig] = None
     ):
         """A container definition with framework configuration set in model environment variables.
 
@@ -311,8 +320,13 @@ def prepare_container_def(
             deploy_image, self.repacked_model_data or self.model_data, deploy_env
         )
 
+    @validate_call_inputs
     def serving_image_uri(
-        self, region_name, instance_type, accelerator_type=None, serverless_inference_config=None
+        self,
+        region_name: str,
+        instance_type: str,
+        accelerator_type: Optional[str] = None,
+        serverless_inference_config: Optional[ServerlessInferenceConfig] = None
     ):
         """Create a URI for the serving image.
 
diff --git a/src/sagemaker/pytorch/processing.py b/src/sagemaker/pytorch/processing.py
index 70fc96497e..bf8bdb837f 100644
--- a/src/sagemaker/pytorch/processing.py
+++ b/src/sagemaker/pytorch/processing.py
@@ -24,6 +24,7 @@
 from sagemaker.processing import FrameworkProcessor
 from sagemaker.pytorch.estimator import PyTorch
 from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.utils import validate_call_inputs
 
 
 class PyTorchProcessor(FrameworkProcessor):
@@ -31,6 +32,7 @@ class PyTorchProcessor(FrameworkProcessor):
 
     estimator_cls = PyTorch
 
+    @validate_call_inputs
     def __init__(
         self,
         framework_version: str,  # New arg
diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py
index 9f4b25f214..57b59f44bf 100644
--- a/src/sagemaker/sklearn/estimator.py
+++ b/src/sagemaker/sklearn/estimator.py
@@ -14,7 +14,8 @@
 from __future__ import absolute_import
 
 import logging
-from typing import Union, Optional, Dict
+from numbers import Number
+from typing import Union, Optional, Dict, List
 
 from sagemaker import image_uris
 from sagemaker.deprecations import renamed_kwargs
@@ -29,6 +30,7 @@
 from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
 from sagemaker.workflow.entities import PipelineVariable
 from sagemaker.workflow import is_pipeline_variable
+from sagemaker.utils import validate_call_inputs
 
 logger = logging.getLogger("sagemaker")
 
@@ -38,13 +40,14 @@ class SKLearn(Framework):
 
     _framework_name = defaults.SKLEARN_NAME
 
+    @validate_call_inputs
     def __init__(
         self,
         entry_point: Union[str, PipelineVariable],
         framework_version: Optional[str] = None,
-        py_version: str = "py3",
+        py_version: Optional[str] = "py3",
         source_dir: Optional[Union[str, PipelineVariable]] = None,
-        hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
+        hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None,
         image_uri: Optional[Union[str, PipelineVariable]] = None,
         image_uri_region: Optional[str] = None,
         **kwargs
@@ -166,14 +169,15 @@ def __init__(
             instance_type=instance_type,
         )
 
+    @validate_call_inputs
     def create_model(
         self,
-        model_server_workers=None,
-        role=None,
-        vpc_config_override=VPC_CONFIG_DEFAULT,
-        entry_point=None,
-        source_dir=None,
-        dependencies=None,
+        model_server_workers: Optional[int] = None,
+        role: Optional[str] = None,
+        vpc_config_override: Optional[Union[str, Dict[str, List[str]]]] = VPC_CONFIG_DEFAULT,
+        entry_point: Optional[str] = None,
+        source_dir: Optional[str] = None,
+        dependencies: Optional[List[str]] = None,
         **kwargs
     ):
         """Create a SageMaker ``SKLearnModel`` object that can be deployed to an ``Endpoint``.
@@ -233,7 +237,8 @@ def create_model(
         )
 
     @classmethod
-    def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
+    @validate_call_inputs
+    def _prepare_init_params_from_job_description(cls, job_details, model_channel_name: Optional[str] = None):
         """Convert the job description to init params that can be handled by the class constructor.
 
         Args:
diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py
index 42f2614fa3..d953c1ffd9 100644
--- a/src/sagemaker/sklearn/model.py
+++ b/src/sagemaker/sklearn/model.py
@@ -18,15 +18,17 @@
 import sagemaker
 from sagemaker import image_uris, ModelMetrics
-from sagemaker.deserializers import NumpyDeserializer
+from sagemaker.deserializers import BaseDeserializer, NumpyDeserializer
 from sagemaker.drift_check_baselines import DriftCheckBaselines
 from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args
 from sagemaker.metadata_properties import MetadataProperties
 from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
 from sagemaker.predictor import Predictor
-from sagemaker.serializers import NumpySerializer
+from sagemaker.serializers import BaseSerializer, NumpySerializer
+from sagemaker.serverless import ServerlessInferenceConfig
+from sagemaker.session import Session
 from sagemaker.sklearn import defaults
-from sagemaker.utils import to_string
+from sagemaker.utils import to_string, validate_call_inputs
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.entities import PipelineVariable
 
@@ -40,12 +42,13 @@ class SKLearnPredictor(Predictor):
     multidimensional tensors for Scikit-learn inference.
     """
 
+    @validate_call_inputs
     def __init__(
         self,
-        endpoint_name,
-        sagemaker_session=None,
-        serializer=NumpySerializer(),
-        deserializer=NumpyDeserializer(),
+        endpoint_name: str,
+        sagemaker_session: Optional[Session] = None,
+        serializer: BaseSerializer = NumpySerializer(),
+        deserializer: BaseDeserializer = NumpyDeserializer(),
     ):
         """Initialize an ``SKLearnPredictor``.
 
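For context on why several annotations in this patch are widened to `Optional[...]` (`py_version`, `framework_version`, `encrypt_inter_container_traffic`, `enable_network_isolation`): pydantic v2 does not treat a `None` default as implicitly optional, so once calls are validated, `None` must be allowed explicitly by the annotation. A minimal sketch, assuming pydantic v2 (`strict` and `relaxed` are toy functions, not SageMaker APIs):

```python
# Minimal sketch, not part of the patch. Assumes pydantic v2.
from typing import Optional

from pydantic import ValidationError, validate_call


@validate_call
def strict(py_version: str = "py3") -> str:
    return py_version or "unset"


@validate_call
def relaxed(py_version: Optional[str] = "py3") -> str:
    return py_version or "unset"


try:
    strict(py_version=None)      # rejected: None is not a str
except ValidationError:
    print("strict: None rejected")

print(relaxed(py_version=None))  # accepted, prints "unset"
```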
@@ -75,13 +78,14 @@ class SKLearnModel(FrameworkModel):
 
     _framework_name = defaults.SKLEARN_NAME
 
+    @validate_call_inputs
     def __init__(
         self,
         model_data: Union[str, PipelineVariable],
         role: Optional[str] = None,
         entry_point: Optional[str] = None,
         framework_version: Optional[str] = None,
-        py_version: str = "py3",
+        py_version: Optional[str] = "py3",
         image_uri: Optional[Union[str, PipelineVariable]] = None,
         predictor_cls: callable = SKLearnPredictor,
         model_server_workers: Optional[Union[int, PipelineVariable]] = None,
@@ -143,6 +147,7 @@ def __init__(
 
         self.model_server_workers = model_server_workers
 
+    @validate_call_inputs
     def register(
         self,
         content_types: List[Union[str, PipelineVariable]] = None,
@@ -257,8 +262,9 @@ def register(
             skip_model_validation=skip_model_validation,
         )
 
+    @validate_call_inputs
     def prepare_container_def(
-        self, instance_type=None, accelerator_type=None, serverless_inference_config=None
+        self, instance_type: Optional[str] = None, accelerator_type: Optional[str] = None, serverless_inference_config: Optional[ServerlessInferenceConfig] = None
     ):
         """Container definition with framework configuration set in model environment variables.
 
@@ -300,7 +306,8 @@ def prepare_container_def(
         )
         return sagemaker.container_def(deploy_image, model_data_uri, deploy_env)
 
-    def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None):
+    @validate_call_inputs
+    def serving_image_uri(self, region_name: str, instance_type: Optional[str] = None, serverless_inference_config: Optional[ServerlessInferenceConfig] = None):
         """Create a URI for the serving image.
 
         Args:
diff --git a/src/sagemaker/sklearn/processing.py b/src/sagemaker/sklearn/processing.py
index 86d0df9113..0f0e179d83 100644
--- a/src/sagemaker/sklearn/processing.py
+++ b/src/sagemaker/sklearn/processing.py
@@ -24,11 +24,13 @@
 from sagemaker.processing import ScriptProcessor
 from sagemaker.sklearn import defaults
 from sagemaker.workflow.entities import PipelineVariable
+from sagemaker.utils import validate_call_inputs
 
 
 class SKLearnProcessor(ScriptProcessor):
     """Handles Amazon SageMaker processing tasks for jobs using scikit-learn."""
 
+    @validate_call_inputs
     def __init__(
         self,
         framework_version: str,  # New arg
diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py
index 293b61f835..ef33847c61 100644
--- a/src/sagemaker/spark/processing.py
+++ b/src/sagemaker/spark/processing.py
@@ -41,6 +41,7 @@
 from sagemaker.session import Session
 from sagemaker.network import NetworkConfig
 from sagemaker.spark import defaults
+from sagemaker.utils import validate_call_inputs
 from sagemaker.workflow import is_pipeline_variable
 from sagemaker.workflow.pipeline_context import runnable_by_pipeline
 
@@ -685,6 +686,7 @@ def _handle_script_dependencies(self, inputs, submit_files, file_type):
 class PySparkProcessor(_SparkProcessorBase):
     """Handles Amazon SageMaker processing tasks for jobs using PySpark."""
 
+    @validate_call_inputs
     def __init__(
         self,
         role: str = None,
@@ -775,18 +777,19 @@ def __init__(
             network_config=network_config,
         )
 
+    @validate_call_inputs
     def get_run_args(
         self,
-        submit_app,
-        submit_py_files=None,
-        submit_jars=None,
-        submit_files=None,
-        inputs=None,
-        outputs=None,
-        arguments=None,
-        job_name=None,
-        configuration=None,
-        spark_event_logs_s3_uri=None,
+        submit_app: str,
+        submit_py_files: Optional[List[str]] = None,
+        submit_jars: Optional[List[str]] = None,
+        submit_files: Optional[List[str]] = None,
+        inputs: Optional[List[ProcessingInput]] = None,
+        outputs: Optional[List[ProcessingOutput]] = None,
+        arguments: Optional[List[str]] = None,
+        job_name: Optional[str] = None,
+        configuration: Optional[Union[List[dict], dict]] = None,
+        spark_event_logs_s3_uri: Optional[str] = None,
     ):
         """Returns a RunArgs object.
 
@@ -822,9 +825,6 @@ def get_run_args(
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
 
-        if not submit_app:
-            raise ValueError("submit_app is required")
-
         extended_inputs, extended_outputs = self._extend_processing_args(
             inputs=inputs,
             outputs=outputs,
@@ -842,6 +842,7 @@ def get_run_args(
             arguments=arguments,
         )
 
+    @validate_call_inputs
     @runnable_by_pipeline
     def run(
         self,
@@ -962,6 +963,7 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
 class SparkJarProcessor(_SparkProcessorBase):
     """Handles Amazon SageMaker processing tasks for jobs using Spark with Java or Scala Jars."""
 
+    @validate_call_inputs
     def __init__(
         self,
         role: str = None,
@@ -1052,18 +1054,19 @@ def __init__(
             network_config=network_config,
         )
 
+    @validate_call_inputs
     def get_run_args(
         self,
-        submit_app,
-        submit_class=None,
-        submit_jars=None,
-        submit_files=None,
-        inputs=None,
-        outputs=None,
-        arguments=None,
-        job_name=None,
-        configuration=None,
-        spark_event_logs_s3_uri=None,
+        submit_app: str,
+        submit_class: Optional[str] = None,
+        submit_jars: Optional[List[str]] = None,
+        submit_files: Optional[List[str]] = None,
+        inputs: Optional[List[ProcessingInput]] = None,
+        outputs: Optional[List[ProcessingOutput]] = None,
+        arguments: Optional[List[str]] = None,
+        job_name: Optional[str] = None,
+        configuration: Optional[Union[List[dict], dict]] = None,
+        spark_event_logs_s3_uri: Optional[str] = None,
     ):
         """Returns a RunArgs object.
 
@@ -1099,9 +1102,6 @@ def get_run_args(
         """
         self._current_job_name = self._generate_current_job_name(job_name=job_name)
 
-        if not submit_app:
-            raise ValueError("submit_app is required")
-
         extended_inputs, extended_outputs = self._extend_processing_args(
             inputs=inputs,
             outputs=outputs,
@@ -1119,6 +1119,7 @@ def get_run_args(
             arguments=arguments,
         )
 
+    @validate_call_inputs
     @runnable_by_pipeline
     def run(
         self,
@@ -1316,15 +1317,16 @@ class SparkConfigUtils:
     ]
 
     @staticmethod
-    def validate_configuration(configuration: Dict):
+    @validate_call_inputs
+    def validate_configuration(configuration: Union[Dict, List]):
         """Validates the user-provided Hadoop/Spark/Hive configuration.
 
         This ensures that the list or dictionary the user provides will serialize to
         JSON matching the schema of EMR's application configuration
 
         Args:
-            configuration (Dict): A dict that contains the configuration overrides to
-                the default values. For more information, please visit:
+            configuration (Dict or List): A dict or a list of dicts that contains the configuration
+                overrides to the default values. For more information, please visit:
                 https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
         """
         emr_configure_apps_url = (
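For context on the `if not submit_app` guards removed above: with the decorator applied, a missing or non-string `submit_app` is rejected by pydantic before the method body runs. The toy `get_run_args` below is a standalone stand-in for the real method, assuming pydantic v2; note that an empty string still satisfies the `str` annotation, whereas the old guard also rejected it.

```python
# Minimal sketch, not part of the patch. Assumes pydantic v2.
from pydantic import ValidationError, validate_call


@validate_call
def get_run_args(submit_app: str) -> str:
    return submit_app


try:
    get_run_args(None)             # None is not a str
except ValidationError:
    print("None rejected")

try:
    get_run_args()                 # missing required argument
except ValidationError:
    print("missing argument rejected")

print(repr(get_run_args("")))      # '' -- still passes the type check
```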
diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py
index 75a1a9b246..2b7d70771e 100644
--- a/src/sagemaker/utils.py
+++ b/src/sagemaker/utils.py
@@ -22,6 +22,7 @@
 import random
 import re
 import shutil
+import sys
 import tarfile
 import tempfile
 import time
@@ -31,6 +32,8 @@
 import uuid
 from datetime import datetime
 
+from pydantic import validate_call, ConfigDict
+
 from importlib import import_module
 import botocore
 from botocore.utils import merge_dicts
@@ -1449,3 +1452,21 @@ def get_instance_type_family(instance_type: str) -> str:
     if match is not None:
         instance_type_family = match[1]
     return instance_type_family
+
+
+def validate_call_inputs(
+    __func: Optional[callable] = None,
+    config: Optional[ConfigDict] = None,
+    validate_return: bool = False,
+):
+    """Decorator that validates a function's input types using pydantic.
+
+    This calls ``pydantic.validate_call`` under the hood, with ``arbitrary_types_allowed``
+    enabled by default. See the pydantic documentation for more information.
+    """
+    if config is None:
+        config = ConfigDict()
+
+    config.setdefault("arbitrary_types_allowed", True)
+
+    return validate_call(__func, config=config, validate_return=validate_return)
diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py
index d450238854..e76cf5c2d6 100644
--- a/tests/unit/test_pytorch.py
+++ b/tests/unit/test_pytorch.py
@@ -19,11 +19,14 @@
 from mock import ANY, MagicMock, Mock, patch
 from packaging.version import Version
 
+from pydantic import ValidationError
+
 from sagemaker import image_uris
 from sagemaker.pytorch import defaults
 from sagemaker.pytorch import PyTorch, PyTorchPredictor, PyTorchModel
 from sagemaker.instance_group import InstanceGroup
 from sagemaker.session_settings import SessionSettings
+from sagemaker.session import Session
 
 DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
 SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
@@ -66,8 +69,10 @@ def fixture_sagemaker_session():
     boto_mock = Mock(name="boto_session", region_name=REGION)
     session = Mock(
         name="sagemaker_session",
+        spec=Session,
         boto_session=boto_mock,
         boto_region_name=REGION,
+        sagemaker_client=Mock(),
         config=None,
         local_mode=False,
         s3_resource=None,
@@ -223,6 +228,13 @@ def test_create_model(
 
     name_from_base.assert_called_with(base_job_name)
 
+    with pytest.raises(ValidationError):
+        PyTorch(entry_point=3, py_version="py3", framework_version=pytorch_inference_version)
+
+    with pytest.raises(ValidationError):
+        PyTorch(
+            entry_point="", py_version="py3", framework_version=pytorch_inference_version, role=5
+        )
 
 
 def test_create_model_with_optional_params(
     sagemaker_session, pytorch_inference_version, pytorch_inference_py_version
diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py
index 9745c4ea26..c2923160a2 100644
--- a/tests/unit/test_sklearn.py
+++ b/tests/unit/test_sklearn.py
@@ -20,7 +20,10 @@
 from mock import Mock
 from mock import patch
 
+from pydantic import ValidationError
+
 from sagemaker.fw_utils import UploadedCode
+from sagemaker.session import Session
 from sagemaker.session_settings import SessionSettings
 from sagemaker.sklearn import SKLearn, SKLearnModel, SKLearnPredictor
 
@@ -60,7 +63,9 @@ def sagemaker_session():
     boto_mock = Mock(name="boto_session", region_name=REGION)
     session = Mock(
+        spec=Session,
         name="sagemaker_session",
+        sagemaker_client=Mock(),
         boto_session=boto_mock,
         boto_region_name=REGION,
         config=None,
@@ -171,6 +176,17 @@ def test_training_image_uri(sagemaker_session, sklearn_version):
     assert _get_full_cpu_image_uri(sklearn_version) == sklearn.training_image_uri()
 
 
+def test_ctor_wrong_parameters(sagemaker_session, sklearn_version):
+    with pytest.raises(ValidationError):
+        SKLearn()
+
+    with pytest.raises(ValidationError):
+        SKLearn(entry_point=3)
+
+    with pytest.raises(ValidationError):
+        SKLearn(entry_point="", image_uri_region=3)
+
+
 def test_create_model(sagemaker_session, sklearn_version):
     source_dir = "s3://mybucket/source"
 
@@ -186,6 +202,17 @@ def test_create_model(sagemaker_session, sklearn_version):
     assert model_values["Image"] == image_uri
 
 
+def test_create_model_wrong_parameters(sagemaker_session, sklearn_version):
+    with pytest.raises(ValidationError):
+        SKLearnModel(model_data=1)
+
+    with pytest.raises(ValidationError):
+        SKLearnModel(model_data="", role=2)
+
+    with pytest.raises(ValidationError):
+        SKLearnModel(model_data="", entry_point=3)
+
+
 @patch("sagemaker.model.FrameworkModel._upload_code")
 def test_create_model_with_network_isolation(upload, sagemaker_session, sklearn_version):
     source_dir = "s3://mybucket/source"
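The new unit tests above rely on constructor misuse surfacing as `pydantic.ValidationError`. A standalone sketch of that pattern, run against a hypothetical `FakeEstimator` rather than the real `SKLearn`/`SKLearnModel` classes (assumes pydantic v2 and pytest):

```python
# Minimal sketch, not part of the patch. Assumes pydantic v2 and pytest.
from typing import Optional

import pytest
from pydantic import ConfigDict, ValidationError, validate_call


class FakeEstimator:
    @validate_call(config=ConfigDict(arbitrary_types_allowed=True))
    def __init__(self, entry_point: str, image_uri_region: Optional[str] = None):
        self.entry_point = entry_point
        self.image_uri_region = image_uri_region


def test_ctor_wrong_parameters():
    with pytest.raises(ValidationError):
        FakeEstimator()                    # missing required entry_point

    with pytest.raises(ValidationError):
        FakeEstimator(entry_point=3)       # an int is not a str

    with pytest.raises(ValidationError):
        FakeEstimator(entry_point="", image_uri_region=3)
```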