Skip to content

Commit

Permalink
feature: support local mode in SageMaker Studio (aws#1300) (aws#4347)
Browse files Browse the repository at this point in the history
* feature: support local mode in SageMaker Studio

* chore: fix typo

* chore: fix formatting

* chore: revert changes for docker compose logs

* chore: black-format

* change: Use predetermined dns-allow-listed-hostname for Studio Local Support

* add support for CodeEditor and JupyterLab

---------

Co-authored-by: Erick Benitez-Ramos <141277478+benieric@users.noreply.github.com>
Co-authored-by: Mufaddal Rohawala <mufi@amazon.com>
  • Loading branch information
3 people authored Dec 28, 2023
1 parent c797f2d commit 210f334
Show file tree
Hide file tree
Showing 4 changed files with 313 additions and 30 deletions.
81 changes: 54 additions & 27 deletions src/sagemaker/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import sagemaker.utils

# Prefix of the generated container host names ("<prefix>-<index>-<suffix>")
# used when not running in Studio.
CONTAINER_PREFIX = "algo"
# Single fixed host name used instead of the generated names when running in Studio.
STUDIO_HOST_NAME = "sagemaker-local"
DOCKER_COMPOSE_FILENAME = "docker-compose.yaml"
# Environment variable that is set to DOCKER_COMPOSE_HTTP_TIMEOUT when it is
# not already present in the process environment.
DOCKER_COMPOSE_HTTP_TIMEOUT_ENV = "COMPOSE_HTTP_TIMEOUT"
DOCKER_COMPOSE_HTTP_TIMEOUT = "120"
Expand All @@ -50,6 +51,7 @@
REGION_ENV_NAME = "AWS_REGION"
TRAINING_JOB_NAME_ENV_NAME = "TRAINING_JOB_NAME"
S3_ENDPOINT_URL_ENV_NAME = "S3_ENDPOINT_URL"
# Name of the variable added (as "<name>=True") to the container environment
# when the job is launched from Studio.
SM_STUDIO_LOCAL_MODE = "SM_STUDIO_LOCAL_MODE"

# SELinux Enabled
SELINUX_ENABLED = os.environ.get("SAGEMAKER_LOCAL_SELINUX_ENABLED", "False").lower() in [
Expand Down Expand Up @@ -107,10 +109,30 @@ def __init__(
# Since we are using a single docker network, Generate a random suffix to attach to the
# container names. This way multiple jobs can run in parallel.
suffix = "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(5))
self.hosts = [
"{}-{}-{}".format(CONTAINER_PREFIX, i, suffix)
for i in range(1, self.instance_count + 1)
]
self.is_studio = sagemaker.local.utils.check_for_studio()
if self.is_studio:
if self.instance_count > 1:
raise NotImplementedError(
"Multi instance Local Mode execution is "
"currently not supported in SageMaker Studio."
)
# For studio use-case, directories need to be created in `~/tmp`, rather than /tmp
home = os.path.expanduser("~")
root_dir = os.path.join(home, "tmp")
if not os.path.isdir(root_dir):
os.mkdir(root_dir)
if self.sagemaker_session.config:
self.sagemaker_session.config["local"]["container_root"] = root_dir
else:
self.sagemaker_session.config = {"local": {"container_root": root_dir}}
# Studio only supports single instance run
self.hosts = [STUDIO_HOST_NAME]
else:
self.hosts = [
"{}-{}-{}".format(CONTAINER_PREFIX, i, suffix)
for i in range(1, self.instance_count + 1)
]

self.container_root = None
self.container = None

Expand Down Expand Up @@ -201,22 +223,17 @@ def process(
self._generate_compose_file(
"process", additional_volumes=volumes, additional_env_vars=environment
)
compose_command = self._compose()

if _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image):
_pull_image(self.image)

compose_command = self._compose()
process = subprocess.Popen(
compose_command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)

try:
_stream_output(process)
except RuntimeError as e:
# _stream_output() doesn't have the command line. We will handle the exception
# which contains the exit code and append the command line to it.
msg = f"Failed to run: {compose_command}"
raise RuntimeError(msg) from e
finally:
# Uploading processing outputs back to Amazon S3.
self._upload_processing_outputs(data_dir, processing_output_config)
Expand Down Expand Up @@ -283,22 +300,17 @@ def train(self, input_data_config, output_data_config, hyperparameters, environm
compose_data = self._generate_compose_file(
"train", additional_volumes=volumes, additional_env_vars=training_env_vars
)
compose_command = self._compose()

if _ecr_login_if_needed(self.sagemaker_session.boto_session, self.image):
_pull_image(self.image)

compose_command = self._compose()
process = subprocess.Popen(
compose_command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)

try:
_stream_output(process)
except RuntimeError as e:
# _stream_output() doesn't have the command line. We will handle the exception
# which contains the exit code and append the command line to it.
msg = "Failed to run: %s, %s" % (compose_command, str(e))
raise RuntimeError(msg)
finally:
artifacts = self.retrieve_artifacts(compose_data, output_data_config, job_name)

Expand Down Expand Up @@ -347,6 +359,7 @@ def serve(self, model_dir, environment):
self._generate_compose_file(
"serve", additional_env_vars=environment, additional_volumes=volumes
)

compose_command = self._compose()

self.container = _HostingContainer(compose_command)
Expand Down Expand Up @@ -710,6 +723,9 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en
additional_env_var_list = ["{}={}".format(k, v) for k, v in additional_env_vars.items()]
environment.extend(additional_env_var_list)

if self.is_studio:
environment.extend([f"{SM_STUDIO_LOCAL_MODE}=True"])

if os.environ.get(DOCKER_COMPOSE_HTTP_TIMEOUT_ENV) is None:
os.environ[DOCKER_COMPOSE_HTTP_TIMEOUT_ENV] = DOCKER_COMPOSE_HTTP_TIMEOUT

Expand All @@ -723,12 +739,19 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en
for h in self.hosts
}

content = {
# Use version 2.3 as a minimum so that we can specify the runtime
"version": "2.3",
"services": services,
"networks": {"sagemaker-local": {"name": "sagemaker-local"}},
}
if self.is_studio:
content = {
# Use version 2.3 as a minimum so that we can specify the runtime
"version": "2.3",
"services": services,
}
else:
content = {
# Use version 2.3 as a minimum so that we can specify the runtime
"version": "2.3",
"services": services,
"networks": {"sagemaker-local": {"name": "sagemaker-local"}},
}

docker_compose_path = os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME)

Expand Down Expand Up @@ -810,7 +833,6 @@ def _create_docker_host(
"tty": True,
"volumes": [v.map for v in optml_volumes],
"environment": environment,
"networks": {"sagemaker-local": {"aliases": [host]}},
}

is_train_with_entrypoint = False
Expand All @@ -827,14 +849,19 @@ def _create_docker_host(
if self.container_arguments:
host_config["entrypoint"] = host_config["entrypoint"] + self.container_arguments

if self.is_studio:
host_config["network_mode"] = "sagemaker"
else:
host_config["networks"] = {"sagemaker-local": {"aliases": [host]}}

# for GPU support pass in nvidia as the runtime, this is equivalent
# to setting --runtime=nvidia in the docker commandline.
if self.instance_type == "local_gpu":
host_config["deploy"] = {
"resources": {"reservations": {"devices": [{"capabilities": ["gpu"]}]}}
}

if command == "serve":
if not self.is_studio and command == "serve":
serving_port = (
sagemaker.utils.get_config_value(
"local.serving_port", self.sagemaker_session.config
Expand Down Expand Up @@ -910,7 +937,7 @@ def __init__(self, command):
"""Creates a new threaded hosting container.
Args:
command:
command (dict): docker compose command
"""
Thread.__init__(self)
self.command = command
Expand Down Expand Up @@ -987,8 +1014,8 @@ def _stream_output(process):
sys.stdout.write(stdout)
exit_code = process.poll()

if exit_code != 0:
raise RuntimeError("Process exited with code: %s" % exit_code)
if exit_code not in [0, 130]:
raise RuntimeError(f"Failed to run: {process.args}. Process exited with code: {exit_code}")

return exit_code

Expand Down
28 changes: 28 additions & 0 deletions src/sagemaker/local/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

logger = logging.getLogger(__name__)

# Studio "AppType" values for which check_for_studio() returns True; any other
# non-empty AppType makes it raise NotImplementedError.
STUDIO_APP_TYPES = ["KernelGateway", "CodeEditor", "JupyterLab"]


def copy_directory_structure(destination_directory, relative_path):
"""Creates intermediate directory structure for relative_path.
Expand Down Expand Up @@ -216,3 +218,29 @@ def get_using_dot_notation(dictionary, keys):
return get_using_dot_notation(inner_dict, rest)
except (KeyError, IndexError, TypeError):
raise ValueError(f"{keys} does not exist in input dictionary.")


def check_for_studio(metadata_path="/opt/ml/metadata/resource-metadata.json"):
    """Determine whether the run environment is SageMaker Studio.

    The decision is based on the ``AppType`` entry of the resource metadata
    JSON file. A missing file, or a file without a (non-empty) ``AppType``,
    is treated as "not Studio".

    Args:
        metadata_path (str): Location of the resource metadata JSON file.
            Defaults to ``/opt/ml/metadata/resource-metadata.json``.

    Returns:
        bool: True if the metadata file declares an ``AppType`` listed in
            ``STUDIO_APP_TYPES``; False if the file does not exist or
            declares no ``AppType``.

    Raises:
        NotImplementedError: If the metadata file declares an ``AppType``
            that is not listed in ``STUDIO_APP_TYPES``.
    """
    if not os.path.exists(metadata_path):
        return False

    # JSON is UTF-8 by specification; do not rely on the locale's default encoding.
    with open(metadata_path, "r", encoding="utf-8") as handle:
        metadata = json.load(handle)

    app_type = metadata.get("AppType")
    if not app_type:
        # if no apptype, case of classic notebooks
        return False

    # Only the app types listed in STUDIO_APP_TYPES support Local Mode.
    if app_type not in STUDIO_APP_TYPES:
        raise NotImplementedError(
            f"AppType {app_type} in Studio does not support Local Mode."
        )
    return True
Loading

0 comments on commit 210f334

Please sign in to comment.