Skip to content

Commit

Permalink
feature: Accept user-defined env variables for the entry-point
Browse files Browse the repository at this point in the history
  • Loading branch information
martinRenou authored and akrishna1995 committed Dec 26, 2023
1 parent c797f2d commit 9a82013
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,8 +766,8 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:

def _script_mode_env_vars(self):
"""Returns a mapping of environment variables for script mode execution"""
script_name = None
dir_name = None
script_name = self.env.get(SCRIPT_PARAM_NAME.upper(), "")
dir_name = self.env.get(DIR_PARAM_NAME.upper(), "")
if self.uploaded_code:
script_name = self.uploaded_code.script_name
if self.repacked_model_data or self.enable_network_isolation():
Expand All @@ -783,8 +783,8 @@ def _script_mode_env_vars(self):
else "file://" + self.source_dir
)
return {
SCRIPT_PARAM_NAME.upper(): script_name or str(),
DIR_PARAM_NAME.upper(): dir_name or str(),
SCRIPT_PARAM_NAME.upper(): script_name,
DIR_PARAM_NAME.upper(): dir_name,
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): to_string(self.container_log_level),
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -718,3 +718,31 @@ def test_register_hf_pytorch_model_auto_infer_framework(
sagemaker_session.create_model_package_from_containers.assert_called_with(
**expected_create_model_package_request
)


def test_accept_user_defined_environment_variables(
sagemaker_session,
huggingface_training_compiler_version,
huggingface_training_compiler_pytorch_version,
huggingface_training_compiler_pytorch_py_version,
):
program = "inference.py"
directory = "/opt/ml/model/code"

hf_model = HuggingFaceModel(
model_data="s3://some/data.tar.gz",
role=ROLE,
transformers_version=huggingface_training_compiler_version,
pytorch_version=huggingface_training_compiler_pytorch_version,
py_version=huggingface_training_compiler_pytorch_py_version,
sagemaker_session=sagemaker_session,
env={
"SAGEMAKER_PROGRAM": program,
"SAGEMAKER_SUBMIT_DIRECTORY": directory,
},
)

container_env = hf_model.prepare_container_def("ml.m4.xlarge")["Environment"]

assert container_env["SAGEMAKER_PROGRAM"] == program
assert container_env["SAGEMAKER_SUBMIT_DIRECTORY"] == directory

0 comments on commit 9a82013

Please sign in to comment.