Skip to content

Commit

Permalink
Add *.py files to model weights if trust_remote_code is provided (#635)
Browse files Browse the repository at this point in the history
* Add *.py files to model weights if trust_remote_code is provided

* Add to azure

* add test

* Add additional tests
  • Loading branch information
dmchoiboi authored Oct 14, 2024
1 parent 5a69175 commit 89b9ddd
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
from .chat_completion import * # noqa: F403
from .completion import * # noqa: F403
from .model_endpoints import * # noqa: F403
from .vllm import * # noqa: F403
Original file line number Diff line number Diff line change
Expand Up @@ -665,23 +665,41 @@ async def create_text_generation_inference_bundle(
).model_bundle_id

def load_model_weights_sub_commands(
self, framework, framework_image_tag, checkpoint_path, final_weights_folder
self,
framework,
framework_image_tag,
checkpoint_path,
final_weights_folder,
trust_remote_code: bool = False,
):
if checkpoint_path.startswith("s3://"):
return self.load_model_weights_sub_commands_s3(
framework, framework_image_tag, checkpoint_path, final_weights_folder
framework,
framework_image_tag,
checkpoint_path,
final_weights_folder,
trust_remote_code,
)
elif checkpoint_path.startswith("azure://") or "blob.core.windows.net" in checkpoint_path:
return self.load_model_weights_sub_commands_abs(
framework, framework_image_tag, checkpoint_path, final_weights_folder
framework,
framework_image_tag,
checkpoint_path,
final_weights_folder,
trust_remote_code,
)
else:
raise ObjectHasInvalidValueException(
f"Only S3 and Azure Blob Storage paths are supported. Given checkpoint path: {checkpoint_path}."
)

def load_model_weights_sub_commands_s3(
self, framework, framework_image_tag, checkpoint_path, final_weights_folder
self,
framework,
framework_image_tag,
checkpoint_path,
final_weights_folder,
trust_remote_code: bool,
):
subcommands = []
s5cmd = "s5cmd"
Expand All @@ -700,14 +718,23 @@ def load_model_weights_sub_commands_s3(
validate_checkpoint_files(checkpoint_files)

# filter to configs ('*.model' and '*.json') and weights ('*.safetensors')
# For models that are not supported by transformers directly, we need to include '*.py' and '*.bin'
# to load the model. Only set this flag if "trust_remote_code" is set to True
file_selection_str = '--include "*.model" --include "*.json" --include "*.safetensors" --exclude "optimizer*"'
if trust_remote_code:
file_selection_str += ' --include "*.py"'
subcommands.append(
f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
)
return subcommands

def load_model_weights_sub_commands_abs(
self, framework, framework_image_tag, checkpoint_path, final_weights_folder
self,
framework,
framework_image_tag,
checkpoint_path,
final_weights_folder,
trust_remote_code: bool,
):
subcommands = []

Expand All @@ -729,9 +756,8 @@ def load_model_weights_sub_commands_abs(
]
)
else:
file_selection_str = (
'--include-pattern "*.model;*.json;*.safetensors" --exclude-pattern "optimizer*"'
)
additional_pattern = ";*.py" if trust_remote_code else ""
file_selection_str = f'--include-pattern "*.model;*.json;*.safetensors{additional_pattern}" --exclude-pattern "optimizer*"'
subcommands.append(
f"azcopy copy --recursive {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}"
)
Expand Down Expand Up @@ -861,6 +887,8 @@ def _create_vllm_bundle_command(
subcommands = []

checkpoint_path = get_checkpoint_path(model_name, checkpoint_path)
additional_args = infer_addition_engine_args_from_model_name(model_name)

# added as workaround since transformers doesn't support mistral yet, vllm expects "mistral" in model weights folder
if "mistral" in model_name:
final_weights_folder = "mistral_files"
Expand All @@ -871,6 +899,7 @@ def _create_vllm_bundle_command(
framework_image_tag,
checkpoint_path,
final_weights_folder,
trust_remote_code=additional_args.trust_remote_code or False,
)

if multinode and not is_worker:
Expand Down Expand Up @@ -905,8 +934,6 @@ def _create_vllm_bundle_command(
if hmi_config.sensitive_log_mode: # pragma: no cover
vllm_cmd += " --disable-log-requests"

additional_args = infer_addition_engine_args_from_model_name(model_name)

for field in VLLMModelConfig.model_fields.keys():
config_value = getattr(additional_args, field, None)
if config_value is not None:
Expand Down
7 changes: 4 additions & 3 deletions model-engine/model_engine_server/inference/vllm/vllm_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ async def dummy_receive() -> MutableMapping[str, Any]:
)


async def download_model(checkpoint_path: str, target_dir: str) -> None:
s5cmd = f"./s5cmd --numworkers 512 sync --concurrency 10 --include '*.model' --include '*.json' --include '*.bin' --include '*.safetensors' --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {target_dir}"
async def download_model(checkpoint_path: str, target_dir: str, trust_remote_code: bool) -> None:
additional_include = "--include '*.py'" if trust_remote_code else ""
s5cmd = f"./s5cmd --numworkers 512 sync --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' {additional_include} --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {target_dir}"
env = os.environ.copy()
env["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default")
# Need to override these env vars so s5cmd uses AWS_PROFILE
Expand Down Expand Up @@ -319,11 +320,11 @@ async def handle_batch_job(request: CreateBatchCompletionsEngineRequest) -> None
metrics_gateway = DatadogInferenceMonitoringMetricsGateway()

model = get_model_name(request.model_cfg)

if request.model_cfg.checkpoint_path:
await download_model(
checkpoint_path=request.model_cfg.checkpoint_path,
target_dir=MODEL_WEIGHTS_FOLDER,
trust_remote_code=request.model_cfg.trust_remote_code or False,
)

content = load_batch_content(request)
Expand Down
22 changes: 22 additions & 0 deletions model-engine/tests/unit/domain/test_llm_use_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,16 @@ def test_load_model_weights_sub_commands(
]
assert expected_result == subcommands

trust_remote_code = True
subcommands = llm_bundle_use_case.load_model_weights_sub_commands(
framework, framework_image_tag, checkpoint_path, final_weights_folder, trust_remote_code
)

expected_result = [
'./s5cmd --numworkers 512 cp --concurrency 10 --include "*.model" --include "*.json" --include "*.safetensors" --exclude "optimizer*" --include "*.py" s3://fake-checkpoint/* test_folder',
]
assert expected_result == subcommands

framework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE
framework_image_tag = "1.0.0"
checkpoint_path = "s3://fake-checkpoint"
Expand Down Expand Up @@ -555,6 +565,18 @@ def test_load_model_weights_sub_commands(
]
assert expected_result == subcommands

trust_remote_code = True
subcommands = llm_bundle_use_case.load_model_weights_sub_commands(
framework, framework_image_tag, checkpoint_path, final_weights_folder, trust_remote_code
)

expected_result = [
"export AZCOPY_AUTO_LOGIN_TYPE=WORKLOAD",
"curl -L https://aka.ms/downloadazcopy-v10-linux | tar --strip-components=1 -C /usr/local/bin --no-same-owner --exclude=*.txt -xzvf - && chmod 755 /usr/local/bin/azcopy",
'azcopy copy --recursive --include-pattern "*.model;*.json;*.safetensors;*.py" --exclude-pattern "optimizer*" azure://fake-checkpoint/* test_folder',
]
assert expected_result == subcommands


@pytest.mark.asyncio
async def test_create_model_endpoint_trt_llm_use_case_success(
Expand Down

0 comments on commit 89b9ddd

Please sign in to comment.