Skip to content

Commit

Permalink
Change: Use pydantic type validation
Browse files (browse the repository at this point in the history)
  • Loading branch information
martinRenou committed Nov 27, 2023
1 parent: 8462f1a; commit: e68a8c1
Show file tree
Hide file tree
Showing 12 changed files with 164 additions and 70 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def read_requirements(filename):
"packaging>=20.0",
"pandas",
"pathos",
"pydantic",
"schema",
"PyYAML~=6.0",
"jsonschema",
Expand Down
8 changes: 5 additions & 3 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@
to_string,
check_and_get_run_experiment_config,
resolve_value_from_config,
validate_call_inputs,
)
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
Expand Down Expand Up @@ -130,6 +131,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz"
JOB_CLASS_NAME = "training-job"

@validate_call_inputs
def __init__(
self,
role: str = None,
Expand All @@ -150,7 +152,7 @@ def __init__(
model_uri: Optional[str] = None,
model_channel_name: Union[str, PipelineVariable] = "model",
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
encrypt_inter_container_traffic: Union[bool, PipelineVariable] = None,
encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None,
use_spot_instances: Union[bool, PipelineVariable] = False,
max_wait: Optional[Union[int, PipelineVariable]] = None,
checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
Expand All @@ -159,7 +161,7 @@ def __init__(
debugger_hook_config: Optional[Union[bool, DebuggerHookConfig]] = None,
tensorboard_output_config: Optional[TensorBoardOutputConfig] = None,
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
enable_network_isolation: Union[bool, PipelineVariable] = None,
enable_network_isolation: Optional[Union[bool, PipelineVariable]] = None,
profiler_config: Optional[ProfilerConfig] = None,
disable_profiler: bool = None,
environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
Expand Down Expand Up @@ -2656,7 +2658,7 @@ def __init__(
model_uri: Optional[str] = None,
model_channel_name: Union[str, PipelineVariable] = "model",
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
encrypt_inter_container_traffic: Union[bool, PipelineVariable] = None,
encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None,
use_spot_instances: Union[bool, PipelineVariable] = False,
max_wait: Optional[Union[int, PipelineVariable]] = None,
checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
Expand Down
20 changes: 12 additions & 8 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from __future__ import absolute_import

import logging
from typing import Union, Optional, Dict
from numbers import Number
from typing import Union, Optional, Dict, List

from packaging.version import Version

Expand All @@ -32,6 +33,7 @@
from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.utils import validate_call_inputs

logger = logging.getLogger("sagemaker")

Expand All @@ -44,13 +46,14 @@ class PyTorch(Framework):
LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"

@validate_call_inputs
def __init__(
self,
entry_point: Union[str, PipelineVariable],
framework_version: Optional[str] = None,
py_version: Optional[str] = None,
source_dir: Optional[Union[str, PipelineVariable]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
distribution: Optional[Dict] = None,
compiler_config: Optional[TrainingCompilerConfig] = None,
Expand Down Expand Up @@ -354,14 +357,15 @@ def hyperparameters(self):

return hyperparameters

@validate_call_inputs
def create_model(
self,
model_server_workers=None,
role=None,
vpc_config_override=VPC_CONFIG_DEFAULT,
entry_point=None,
source_dir=None,
dependencies=None,
model_server_workers: Optional[int] = None,
role: Optional[str] = None,
vpc_config_override: Optional[Dict[str, List[str]]] = VPC_CONFIG_DEFAULT,
entry_point: Optional[str] = None,
source_dir: Optional[str] = None,
dependencies: Optional[List[str]] = None,
**kwargs,
):
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.
Expand Down
34 changes: 24 additions & 10 deletions src/sagemaker/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import sagemaker
from sagemaker import image_uris, ModelMetrics
from sagemaker.deserializers import NumpyDeserializer
from sagemaker.deserializers import BaseDeserializer, NumpyDeserializer
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.fw_utils import (
model_code_key_prefix,
Expand All @@ -31,8 +31,10 @@
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.pytorch import defaults
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer
from sagemaker.utils import to_string
from sagemaker.serializers import BaseSerializer, NumpySerializer
from sagemaker.session import Session
from sagemaker.serverless import ServerlessInferenceConfig
from sagemaker.utils import to_string, validate_call_inputs
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable

Expand All @@ -46,12 +48,13 @@ class PyTorchPredictor(Predictor):
multidimensional tensors for PyTorch inference.
"""

@validate_call_inputs
def __init__(
self,
endpoint_name,
sagemaker_session=None,
serializer=NumpySerializer(),
deserializer=NumpyDeserializer(),
endpoint_name: str,
sagemaker_session: Optional[Session] = None,
serializer: BaseSerializer = NumpySerializer(),
deserializer: BaseDeserializer = NumpyDeserializer(),
):
"""Initialize an ``PyTorchPredictor``.
Expand Down Expand Up @@ -82,12 +85,13 @@ class PyTorchModel(FrameworkModel):
_framework_name = "pytorch"
_LOWEST_MMS_VERSION = "1.2"

@validate_call_inputs
def __init__(
self,
model_data: Union[str, PipelineVariable],
role: Optional[str] = None,
entry_point: Optional[str] = None,
framework_version: str = "1.3",
framework_version: Optional[str] = "1.3",
py_version: Optional[str] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
predictor_cls: callable = PyTorchPredictor,
Expand Down Expand Up @@ -150,6 +154,7 @@ def __init__(

self.model_server_workers = model_server_workers

@validate_call_inputs
def register(
self,
content_types: List[Union[str, PipelineVariable]] = None,
Expand Down Expand Up @@ -264,8 +269,12 @@ def register(
skip_model_validation=skip_model_validation,
)

@validate_call_inputs
def prepare_container_def(
self, instance_type=None, accelerator_type=None, serverless_inference_config=None
self,
instance_type: Optional[str] = None,
accelerator_type: Optional[str] = None,
serverless_inference_config: Optional[ServerlessInferenceConfig] = None
):
"""A container definition with framework configuration set in model environment variables.
Expand Down Expand Up @@ -311,8 +320,13 @@ def prepare_container_def(
deploy_image, self.repacked_model_data or self.model_data, deploy_env
)

@validate_call_inputs
def serving_image_uri(
self, region_name, instance_type, accelerator_type=None, serverless_inference_config=None
self,
region_name: str,
instance_type: str,
accelerator_type: Optional[str] = None,
serverless_inference_config: Optional[ServerlessInferenceConfig] = None
):
"""Create a URI for the serving image.
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/pytorch/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@
from sagemaker.processing import FrameworkProcessor
from sagemaker.pytorch.estimator import PyTorch
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.utils import validate_call_inputs


class PyTorchProcessor(FrameworkProcessor):
"""Handles Amazon SageMaker processing tasks for jobs using PyTorch containers."""

estimator_cls = PyTorch

@validate_call_inputs
def __init__(
self,
framework_version: str, # New arg
Expand Down
25 changes: 15 additions & 10 deletions src/sagemaker/sklearn/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from __future__ import absolute_import

import logging
from typing import Union, Optional, Dict
from numbers import Number
from typing import Union, Optional, Dict, List

from sagemaker import image_uris
from sagemaker.deprecations import renamed_kwargs
Expand All @@ -29,6 +30,7 @@
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow import is_pipeline_variable
from sagemaker.utils import validate_call_inputs

logger = logging.getLogger("sagemaker")

Expand All @@ -38,13 +40,14 @@ class SKLearn(Framework):

_framework_name = defaults.SKLEARN_NAME

@validate_call_inputs
def __init__(
self,
entry_point: Union[str, PipelineVariable],
framework_version: Optional[str] = None,
py_version: str = "py3",
py_version: Optional[str] = "py3",
source_dir: Optional[Union[str, PipelineVariable]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
image_uri_region: Optional[str] = None,
**kwargs
Expand Down Expand Up @@ -166,14 +169,15 @@ def __init__(
instance_type=instance_type,
)

@validate_call_inputs
def create_model(
self,
model_server_workers=None,
role=None,
vpc_config_override=VPC_CONFIG_DEFAULT,
entry_point=None,
source_dir=None,
dependencies=None,
model_server_workers: Optional[int] = None,
role: Optional[str] = None,
vpc_config_override: Optional[Union[str, dict[str, List[str]]]] = VPC_CONFIG_DEFAULT,
entry_point: Optional[str] = None,
source_dir: Optional[str] = None,
dependencies: Optional[List[str]] = None,
**kwargs
):
"""Create a SageMaker ``SKLearnModel`` object that can be deployed to an ``Endpoint``.
Expand Down Expand Up @@ -233,7 +237,8 @@ def create_model(
)

@classmethod
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
@validate_call_inputs
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name: Optional[str] = None):
"""Convert the job description to init params that can be handled by the class constructor.
Args:
Expand Down
27 changes: 17 additions & 10 deletions src/sagemaker/sklearn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,17 @@

import sagemaker
from sagemaker import image_uris, ModelMetrics
from sagemaker.deserializers import NumpyDeserializer
from sagemaker.deserializers import BaseDeserializer, NumpyDeserializer
from sagemaker.drift_check_baselines import DriftCheckBaselines
from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer
from sagemaker.serializers import BaseSerializer, NumpySerializer
from sagemaker.serverless import ServerlessInferenceConfig
from sagemaker.session import Session
from sagemaker.sklearn import defaults
from sagemaker.utils import to_string
from sagemaker.utils import to_string, validate_call_inputs
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable

Expand All @@ -40,12 +42,13 @@ class SKLearnPredictor(Predictor):
multidimensional tensors for Scikit-learn inference.
"""

@validate_call_inputs
def __init__(
self,
endpoint_name,
sagemaker_session=None,
serializer=NumpySerializer(),
deserializer=NumpyDeserializer(),
endpoint_name: str,
sagemaker_session: Optional[Session] = None,
serializer: BaseSerializer = NumpySerializer(),
deserializer: BaseDeserializer = NumpyDeserializer(),
):
"""Initialize an ``SKLearnPredictor``.
Expand Down Expand Up @@ -75,13 +78,14 @@ class SKLearnModel(FrameworkModel):

_framework_name = defaults.SKLEARN_NAME

@validate_call_inputs
def __init__(
self,
model_data: Union[str, PipelineVariable],
role: Optional[str] = None,
entry_point: Optional[str] = None,
framework_version: Optional[str] = None,
py_version: str = "py3",
py_version: Optional[str] = "py3",
image_uri: Optional[Union[str, PipelineVariable]] = None,
predictor_cls: callable = SKLearnPredictor,
model_server_workers: Optional[Union[int, PipelineVariable]] = None,
Expand Down Expand Up @@ -143,6 +147,7 @@ def __init__(

self.model_server_workers = model_server_workers

@validate_call_inputs
def register(
self,
content_types: List[Union[str, PipelineVariable]] = None,
Expand Down Expand Up @@ -257,8 +262,9 @@ def register(
skip_model_validation=skip_model_validation,
)

@validate_call_inputs
def prepare_container_def(
self, instance_type=None, accelerator_type=None, serverless_inference_config=None
self, instance_type: Optional[str] = None, accelerator_type: Optional[str] = None, serverless_inference_config: Optional[ServerlessInferenceConfig] = None
):
"""Container definition with framework configuration set in model environment variables.
Expand Down Expand Up @@ -300,7 +306,8 @@ def prepare_container_def(
)
return sagemaker.container_def(deploy_image, model_data_uri, deploy_env)

def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None):
@validate_call_inputs
def serving_image_uri(self, region_name: str, instance_type: Optional[str] = None, serverless_inference_config: Optional[ServerlessInferenceConfig] = None):
"""Create a URI for the serving image.
Args:
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/sklearn/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,13 @@
from sagemaker.processing import ScriptProcessor
from sagemaker.sklearn import defaults
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.utils import validate_call_inputs


class SKLearnProcessor(ScriptProcessor):
"""Handles Amazon SageMaker processing tasks for jobs using scikit-learn."""

@validate_call_inputs
def __init__(
self,
framework_version: str, # New arg
Expand Down
Loading

0 comments on commit e68a8c1

Please sign in to comment.