fix: Move func and args serialization of function step to step level (a…
qidewenwhen authored Dec 14, 2023
1 parent bb0f1cc commit 9269053
Showing 10 changed files with 328 additions and 85 deletions.
41 changes: 37 additions & 4 deletions src/sagemaker/remote_function/core/pipeline_variables.py
@@ -77,6 +77,17 @@ class _ExecutionVariable:
name: str


@dataclass
class _S3BaseUriIdentifier:
"""Identifies that the class refers to function step s3 base uri.
The s3_base_uri = s3_root_uri + pipeline_name.
This identifier is resolved in function step runtime by SDK.
"""

NAME = "S3_BASE_URI"


@dataclass
class _DelayedReturn:
"""Delayed return from a function."""
@@ -155,6 +166,7 @@ def __init__(
hmac_key: str,
parameter_resolver: _ParameterResolver,
execution_variable_resolver: _ExecutionVariableResolver,
s3_base_uri: str,
**settings,
):
"""Resolve delayed return.
@@ -164,8 +176,12 @@
hmac_key: key used to generate and verify the hmac-sha256 hash of the serialized function and arguments.
parameter_resolver: resolver used to resolve pipeline parameters.
execution_variable_resolver: resolver used to resolve execution variables.
s3_base_uri (str): the s3 base uri of the function step to which
the serialized artifacts will be uploaded.
The s3_base_uri = s3_root_uri + pipeline_name.
**settings: settings to pass to the deserialization function.
"""
self._s3_base_uri = s3_base_uri
self._parameter_resolver = parameter_resolver
self._execution_variable_resolver = execution_variable_resolver
# different delayed returns can have the same uri, so we need to dedupe
@@ -205,6 +221,8 @@ def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn):
uri.append(self._parameter_resolver.resolve(component))
elif isinstance(component, _ExecutionVariable):
uri.append(self._execution_variable_resolver.resolve(component))
elif isinstance(component, _S3BaseUriIdentifier):
uri.append(self._s3_base_uri)
else:
uri.append(component)
return s3_path_join(*uri)
@@ -219,7 +237,12 @@ def _retrieve_child_item(delayed_return: _DelayedReturn, deserialized_obj: Any):


def resolve_pipeline_variables(
context: Context, func_args: Tuple, func_kwargs: Dict, hmac_key: str, **settings
context: Context,
func_args: Tuple,
func_kwargs: Dict,
hmac_key: str,
s3_base_uri: str,
**settings,
):
"""Resolve pipeline variables.
@@ -228,6 +251,8 @@
func_args: function args.
func_kwargs: function kwargs.
hmac_key: key used to generate and verify the hmac-sha256 hash of the serialized function and arguments.
s3_base_uri: the s3 base uri of the function step to which the serialized artifacts
will be uploaded. The s3_base_uri = s3_root_uri + pipeline_name.
**settings: settings to pass to the deserialization function.
"""

@@ -251,6 +276,7 @@
hmac_key=hmac_key,
parameter_resolver=parameter_resolver,
execution_variable_resolver=execution_variable_resolver,
s3_base_uri=s3_base_uri,
**settings,
)

@@ -289,11 +315,10 @@ def resolve_pipeline_variables(
return resolved_func_args, resolved_func_kwargs


def convert_pipeline_variables_to_pickleable(s3_base_uri: str, func_args: Tuple, func_kwargs: Dict):
def convert_pipeline_variables_to_pickleable(func_args: Tuple, func_kwargs: Dict):
"""Convert pipeline variables to pickleable.
Args:
s3_base_uri: s3 base uri where artifacts are stored.
func_args: function args.
func_kwargs: function kwargs.
"""
@@ -304,11 +329,19 @@ def convert_pipeline_variables_to_pickleable(s3_base_uri: str, func_args: Tuple,

from sagemaker.workflow.function_step import DelayedReturn

# Notes:
# 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown
# when defining function steps. Once the args are serialized at the step level,
# it is hard to update the s3_base_uri at pipeline compile time.
# Thus we set a placeholder, _S3BaseUriIdentifier, and let the runtime job resolve it.
# 2. The s3_root_uri is unknown because the pipeline's sagemaker_session is not
# passed in when defining function steps, yet the default s3_root_uri
# should be retrieved from that session.
def convert(arg):
if isinstance(arg, DelayedReturn):
return _DelayedReturn(
uri=[
s3_base_uri,
_S3BaseUriIdentifier(),
ExecutionVariables.PIPELINE_EXECUTION_ID._pickleable,
arg._step.name,
"results",
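A minimal standalone sketch of the placeholder mechanism described in the notes above (the class name mirrors the diff, but this is a self-contained illustration, not the SDK's actual resolver):

from dataclasses import dataclass

@dataclass
class _S3BaseUriIdentifier:
    NAME = "S3_BASE_URI"

def _resolve_uri(components, s3_base_uri):
    # Swap the marker for the s3_base_uri that only the runtime job knows.
    return "/".join(
        s3_base_uri if isinstance(c, _S3BaseUriIdentifier) else c
        for c in components
    )

# Compile time: s3_root_uri and pipeline_name may be unknown, so a marker is pickled.
uri_components = [_S3BaseUriIdentifier(), "execution-id", "step-a", "results"]

# Runtime: the job swaps in the real base uri and joins the full path.
print(_resolve_uri(uri_components, "s3://bucket/my-pipeline"))
# -> s3://bucket/my-pipeline/execution-id/step-a/results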
59 changes: 37 additions & 22 deletions src/sagemaker/remote_function/core/serialization.py
@@ -161,17 +161,13 @@ def serialize_func_to_s3(
Raises:
SerializationError: when the function cannot be serialized to bytes.
"""
bytes_to_upload = CloudpickleSerializer.serialize(func)

_upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)

sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)

_upload_bytes_to_s3(
_MetaData(sha256_hash).to_json(),
f"{s3_uri}/metadata.json",
s3_kms_key,
sagemaker_session,
_upload_payload_and_metadata_to_s3(
bytes_to_upload=CloudpickleSerializer.serialize(func),
hmac_key=hmac_key,
s3_uri=s3_uri,
sagemaker_session=sagemaker_session,
s3_kms_key=s3_kms_key,
)


@@ -220,17 +216,12 @@ def serialize_obj_to_s3(
SerializationError: when the object cannot be serialized to bytes.
"""

bytes_to_upload = CloudpickleSerializer.serialize(obj)

_upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)

sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)

_upload_bytes_to_s3(
_MetaData(sha256_hash).to_json(),
f"{s3_uri}/metadata.json",
s3_kms_key,
sagemaker_session,
_upload_payload_and_metadata_to_s3(
bytes_to_upload=CloudpickleSerializer.serialize(obj),
hmac_key=hmac_key,
s3_uri=s3_uri,
sagemaker_session=sagemaker_session,
s3_kms_key=s3_kms_key,
)


@@ -318,8 +309,32 @@ def serialize_exception_to_s3(
"""
pickling_support.install()

bytes_to_upload = CloudpickleSerializer.serialize(exc)
_upload_payload_and_metadata_to_s3(
bytes_to_upload=CloudpickleSerializer.serialize(exc),
hmac_key=hmac_key,
s3_uri=s3_uri,
sagemaker_session=sagemaker_session,
s3_kms_key=s3_kms_key,
)


def _upload_payload_and_metadata_to_s3(
bytes_to_upload: Union[bytes, io.BytesIO],
hmac_key: str,
s3_uri: str,
sagemaker_session: Session,
s3_kms_key,
):
"""Uploads serialized payload and metadata to s3.
Args:
bytes_to_upload (Union[bytes, io.BytesIO]): Serialized payload to upload.
hmac_key (str): Key used to calculate the hmac-sha256 hash of the serialized payload.
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
sagemaker_session (sagemaker.session.Session):
The underlying Boto3 session which AWS service calls are delegated to.
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
"""
_upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)

sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
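The three serializers above now share one payload-plus-metadata upload path. A hedged local-filesystem sketch of that convention, where pickle stands in for cloudpickle and the metadata key name is an assumption:

import hashlib
import hmac
import json
import pickle
from pathlib import Path

def upload_payload_and_metadata(obj, hmac_key: str, base_dir: str):
    # Serialize once, then write payload.pkl plus a metadata.json carrying
    # the HMAC-SHA256 of the payload for later integrity verification.
    payload = pickle.dumps(obj)
    digest = hmac.new(hmac_key.encode(), payload, hashlib.sha256).hexdigest()
    root = Path(base_dir)
    root.mkdir(parents=True, exist_ok=True)
    (root / "payload.pkl").write_bytes(payload)
    (root / "metadata.json").write_text(json.dumps({"sha256_hash": digest}))

upload_payload_and_metadata({"answer": 42}, hmac_key="secret", base_dir="/tmp/step-results")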
42 changes: 42 additions & 0 deletions src/sagemaker/remote_function/core/stored_function.py
@@ -14,6 +14,7 @@
from __future__ import absolute_import

import os
from dataclasses import dataclass
from typing import Any


@@ -36,6 +37,14 @@
JSON_RESULTS_FILE = "results.json"


@dataclass
class _SerializedData:
"""Data class to store serialized function and arguments"""

func: bytes
args: bytes


class StoredFunction:
"""Class representing a remote function stored in S3."""

@@ -105,6 +114,38 @@ def save(self, func, *args, **kwargs):
s3_kms_key=self.s3_kms_key,
)

def save_pipeline_step_function(self, serialized_data):
"""Upload serialized function and arguments to s3.
Args:
serialized_data (_SerializedData): The serialized function
and function arguments of a function step.
"""

logger.info(
"Uploading serialized function code to %s",
s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
)
serialization._upload_payload_and_metadata_to_s3(
bytes_to_upload=serialized_data.func,
hmac_key=self.hmac_key,
s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
sagemaker_session=self.sagemaker_session,
s3_kms_key=self.s3_kms_key,
)

logger.info(
"Uploading serialized function arguments to %s",
s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
)
serialization._upload_payload_and_metadata_to_s3(
bytes_to_upload=serialized_data.args,
hmac_key=self.hmac_key,
s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
sagemaker_session=self.sagemaker_session,
s3_kms_key=self.s3_kms_key,
)

def load_and_invoke(self) -> Any:
"""Load and deserialize the function and the arguments and then execute it."""

@@ -134,6 +175,7 @@ def load_and_invoke(self) -> Any:
args,
kwargs,
hmac_key=self.hmac_key,
s3_base_uri=self.s3_base_uri,
sagemaker_session=self.sagemaker_session,
)

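A sketch of the layout save_pipeline_step_function produces: the pre-serialized function and argument bytes land in separate sub-folders under the step's upload path. Local folders and the literal names "function" and "arguments" are stand-ins for the SDK's S3 paths and its FUNCTION_FOLDER/ARGUMENTS_FOLDER constants:

import pickle
from dataclasses import dataclass
from pathlib import Path

@dataclass
class _SerializedData:
    func: bytes
    args: bytes

def save_pipeline_step_function(data: _SerializedData, upload_path: str):
    # One payload per artifact kind, mirroring the two uploads in the diff.
    for folder, payload in (("function", data.func), ("arguments", data.args)):
        target = Path(upload_path) / folder
        target.mkdir(parents=True, exist_ok=True)
        (target / "payload.pkl").write_bytes(payload)

data = _SerializedData(func=pickle.dumps(len), args=pickle.dumps((("hello",), {})))
save_pipeline_step_function(data, "/tmp/step-a")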
16 changes: 5 additions & 11 deletions src/sagemaker/remote_function/job.py
@@ -52,11 +52,8 @@
from sagemaker.utils import name_from_base, _tmpdir, resolve_value_from_config
from sagemaker.s3 import s3_path_join, S3Uploader
from sagemaker import vpc_utils
from sagemaker.remote_function.core.stored_function import StoredFunction
from sagemaker.remote_function.core.pipeline_variables import (
Context,
convert_pipeline_variables_to_pickleable,
)
from sagemaker.remote_function.core.stored_function import StoredFunction, _SerializedData
from sagemaker.remote_function.core.pipeline_variables import Context
from sagemaker.remote_function.runtime_environment.runtime_environment_manager import (
RuntimeEnvironmentManager,
_DependencySettings,
@@ -695,6 +692,7 @@ def compile(
func_args: tuple,
func_kwargs: dict,
run_info=None,
serialized_data: _SerializedData = None,
) -> dict:
"""Build the artifacts and generate the training job request."""
from sagemaker.workflow.properties import Properties
@@ -732,12 +730,8 @@ def compile(
func_step_s3_dir=step_compilation_context.pipeline_build_time,
),
)
converted_func_args, converted_func_kwargs = convert_pipeline_variables_to_pickleable(
s3_base_uri=s3_base_uri,
func_args=func_args,
func_kwargs=func_kwargs,
)
stored_function.save(func, *converted_func_args, **converted_func_kwargs)

stored_function.save_pipeline_step_function(serialized_data)

stopping_condition = {
"MaxRuntimeInSeconds": job_settings.max_runtime_in_seconds,
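With this change, compile() no longer converts and pickles the function and its arguments; it forwards bytes produced at step-definition time, so repeated compilations of the same pipeline reuse the same artifacts. A schematic sketch with hypothetical stand-ins for the SDK classes:

class MiniStoredFunction:
    # Hypothetical stand-in for StoredFunction, reduced to the new entry point.
    def save_pipeline_step_function(self, serialized_data):
        print("uploading pre-serialized artifacts:", sorted(serialized_data))

def compile_step(stored_function, serialized_data):
    # Before this commit: convert_pipeline_variables_to_pickleable(...) and
    # stored_function.save(func, *args, **kwargs) ran here on every compile.
    # After: the step-level bytes are uploaded as-is.
    stored_function.save_pipeline_step_function(serialized_data)

compile_step(MiniStoredFunction(), {"func": b"...", "args": b"..."})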
21 changes: 21 additions & 0 deletions src/sagemaker/workflow/function_step.py
@@ -83,6 +83,11 @@ def __init__(
func_kwargs (dict): keyword arguments of the python function.
**kwargs: Additional arguments to be passed to the `step` decorator.
"""
from sagemaker.remote_function.core.pipeline_variables import (
convert_pipeline_variables_to_pickleable,
)
from sagemaker.remote_function.core.serialization import CloudpickleSerializer
from sagemaker.remote_function.core.stored_function import _SerializedData

super(_FunctionStep, self).__init__(
name, StepTypeEnum.TRAINING, display_name, description, depends_on, retry_policies
@@ -96,6 +101,21 @@

self.__job_settings = None

(
self._converted_func_args,
self._converted_func_kwargs,
) = convert_pipeline_variables_to_pickleable(
func_args=self._func_args,
func_kwargs=self._func_kwargs,
)

self._serialized_data = _SerializedData(
func=CloudpickleSerializer.serialize(self._func),
args=CloudpickleSerializer.serialize(
(self._converted_func_args, self._converted_func_kwargs)
),
)

@property
def func(self):
"""The python function to run as a pipeline step."""
@@ -185,6 +205,7 @@ def arguments(self) -> RequestType:
func=self.func,
func_args=self.func_args,
func_kwargs=self.func_kwargs,
serialized_data=self._serialized_data,
)
# Continue to pop job name if not explicitly opted-in via config
request_dict = trim_request_dict(request_dict, "TrainingJobName", step_compilation_context)
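The net effect of the commit is visible in __init__ above: serialization moves from pipeline compile time to step definition time, so the function and its converted arguments are pickled exactly once, when the step object is created. A toy illustration of that timing, with standard pickle in place of cloudpickle and a minimal class that is not the SDK's _FunctionStep:

import pickle

class MiniFunctionStep:
    def __init__(self, func, func_args):
        # Func and args are fixed the moment the step is constructed,
        # not re-serialized on every pipeline compile.
        self._serialized_func = pickle.dumps(func)
        self._serialized_args = pickle.dumps(func_args)

def double(x):
    return 2 * x

step = MiniFunctionStep(double, (21,))
restored_func = pickle.loads(step._serialized_func)
restored_args = pickle.loads(step._serialized_args)
print(restored_func(*restored_args))  # -> 42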
