Commit 34efa2d

update unit tests
siddvenk committed Dec 20, 2024
1 parent a1fcce3 commit 34efa2d
Showing 2 changed files with 100 additions and 12 deletions.
@@ -34,11 +34,13 @@ class VllmRbProperties(Properties):
     tensor_parallel_degree: int = 1
     pipeline_parallel_degree: int = 1
     # The following configs have different names in DJL compared to vLLM, either is accepted
-    quantize: Optional[str] = Field(alias="quantization", default=None)
+    quantize: Optional[str] = Field(alias="quantization",
+                                    default=EngineArgs.quantization)
     max_rolling_batch_prefill_tokens: Optional[int] = Field(
-        alias="max_num_batched_tokens", default=None)
-    cpu_offload_gb_per_gpu: Optional[float] = Field(alias="cpu_offload_gb",
-                                                    default=None)
+        alias="max_num_batched_tokens",
+        default=EngineArgs.max_num_batched_tokens)
+    cpu_offload_gb_per_gpu: float = Field(alias="cpu_offload_gb",
+                                          default=EngineArgs.cpu_offload_gb)
     # The following configs have different defaults, or additional processing in DJL compared to vLLM
     dtype: str = "auto"
     max_loras: int = 4
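For context on the pattern above: DJL accepts either its own property name or the vLLM name for these configs via Pydantic field aliases, and this hunk moves the fallback defaults from hard-coded None to the defaults vLLM itself declares on EngineArgs. Below is a minimal runnable sketch of the alias mechanism, assuming Pydantic v2; VllmProps is an illustrative stand-in, not the real VllmRbProperties:

    from typing import Optional
    from pydantic import BaseModel, ConfigDict, Field

    class VllmProps(BaseModel):
        # populate_by_name=True lets callers use either the field name
        # ("quantize") or the alias ("quantization"), matching the
        # "either is accepted" comment in the diff.
        model_config = ConfigDict(populate_by_name=True)

        quantize: Optional[str] = Field(alias="quantization", default=None)

    # Both spellings populate the same field.
    assert VllmProps(quantize="awq").quantize == "awq"
    assert VllmProps(quantization="awq").quantize == "awq"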
@@ -117,6 +119,7 @@ def generate_vllm_engine_arg_dict(self,
             'revision': self.revision,
             'max_loras': self.max_loras,
             'enable_lora': self.enable_lora,
+            'trust_remote_code': self.trust_remote_code,
         }
         if self.quantize is not None:
             vllm_engine_args['quantization'] = self.quantize
@@ -127,10 +130,11 @@
             vllm_engine_args['cpu_offload_gb'] = self.cpu_offload_gb_per_gpu
         if self.device is not None:
             vllm_engine_args['device'] = self.device
-        if self.preloaded_model is not None:
+        if self.device == 'neuron':
             vllm_engine_args['preloaded_model'] = self.preloaded_model
-        if self.generation_config is not None:
             vllm_engine_args['generation_config'] = self.generation_config
+            vllm_engine_args['block_size'] = passthrough_vllm_engine_args.get(
+                "max_model_len")
         vllm_engine_args.update(passthrough_vllm_engine_args)
         return vllm_engine_args
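A note on the last two lines of this hunk: because update() runs after the DJL-derived entries are assembled, any passthrough vLLM argument wins on a key collision, so the block_size seeded from max_model_len here would still be overridden by an explicit passthrough block_size. A tiny illustration of that dict precedence, with made-up values:

    # DJL-derived args are built first...
    vllm_engine_args = {"device": "neuron", "block_size": 1024}
    # ...then user passthrough args are merged in, overriding on collision.
    passthrough_vllm_engine_args = {"block_size": 2048}
    vllm_engine_args.update(passthrough_vllm_engine_args)
    assert vllm_engine_args == {"device": "neuron", "block_size": 2048}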

engines/python/setup/djl_python/tests/test_properties_manager.py (96 changes: 90 additions & 6 deletions)
@@ -423,7 +423,7 @@ def test_hf_error_case(self, params):
             HuggingFaceProperties(**params)

     def test_vllm_properties(self):
-        # test with valid vllm properties
+
         def validate_vllm_config_and_engine_args_match(
                 vllm_config_value,
                 engine_arg_value,
@@ -435,7 +435,7 @@ def validate_vllm_config_and_engine_args_match(
         def test_vllm_default_properties():
             required_properties = {
                 "engine": "Python",
-                "model_id_or_path": "some_model",
+                "model_id": "some_model",
             }
             vllm_configs = VllmRbProperties(**required_properties)
             engine_args = vllm_configs.get_engine_args()
@@ -451,21 +451,101 @@ def test_vllm_default_properties():
                 vllm_configs.quantize, engine_args.quantization, None)
             validate_vllm_config_and_engine_args_match(
                 vllm_configs.max_rolling_batch_size, engine_args.max_num_seqs,
-                HuggingFaceProperties.max_rolling_batch_size)
+                32)
             validate_vllm_config_and_engine_args_match(vllm_configs.dtype,
                                                        engine_args.dtype,
                                                        'auto')
             validate_vllm_config_and_engine_args_match(vllm_configs.max_loras,
                                                        engine_args.max_loras,
                                                        4)
-            self.assertEqual(vllm_configs.cpu_offload_gb_per_gpu, None)
+            validate_vllm_config_and_engine_args_match(
+                vllm_configs.cpu_offload_gb_per_gpu,
+                engine_args.cpu_offload_gb, EngineArgs.cpu_offload_gb)
             self.assertEqual(
                 len(vllm_configs.get_additional_vllm_engine_args()), 0)

+        def test_invalid_pipeline_parallel():
+            properties = {
+                "engine": "Python",
+                "model_id": "some_model",
+                "tensor_parallel_degree": "4",
+                "pipeline_parallel_degree": "2",
+            }
+            with self.assertRaises(ValueError):
+                _ = VllmRbProperties(**properties)
+
+        def test_invalid_engine():
+            properties = {
+                "engine": "bad_engine",
+                "model_id": "some_model",
+            }
+            with self.assertRaises(ValueError):
+                _ = VllmRbProperties(**properties)
+
+        def test_aliases():
+            properties = {
+                "engine": "Python",
+                "model_id": "some_model",
+                "quantization": "awq",
+                "max_num_batched_tokens": "546",
+                "cpu_offload_gb": "7"
+            }
+            vllm_configs = VllmRbProperties(**properties)
+            engine_args = vllm_configs.get_engine_args()
+            validate_vllm_config_and_engine_args_match(
+                vllm_configs.quantize, engine_args.quantization, "awq")
+            validate_vllm_config_and_engine_args_match(
+                vllm_configs.max_rolling_batch_prefill_tokens,
+                engine_args.max_num_batched_tokens, 546)
+            validate_vllm_config_and_engine_args_match(
+                vllm_configs.cpu_offload_gb_per_gpu,
+                engine_args.cpu_offload_gb, 7)
+
+        def test_vllm_passthrough_properties():
+            properties = {
+                "engine": "Python",
+                "model_id": "some_model",
+                "tensor_parallel_degree": "4",
+                "pipeline_parallel_degree": "1",
+                "max_rolling_batch_size": "111",
+                "quantize": "awq",
+                "max_rolling_batch_prefill_tokens": "400",
+                "cpu_offload_gb_per_gpu": "8",
+                "dtype": "bf16",
+                "max_loras": "7",
+                "long_lora_scaling_factors": "1.1, 2.0",
+                "trust_remote_code": "true",
+                "max_model_len": "1024",
+                "enforce_eager": "true",
+                "enable_chunked_prefill": "False",
+                "gpu_memory_utilization": "0.4",
+            }
+            expected_engine_args = EngineArgs(model="some_model",
+                                              tensor_parallel_size=4,
+                                              pipeline_parallel_size=1,
+                                              max_num_seqs=111,
+                                              quantization='awq',
+                                              max_num_batched_tokens=400,
+                                              cpu_offload_gb=8,
+                                              dtype="bfloat16",
+                                              max_loras=7,
+                                              long_lora_scaling_factors=(1.1,
+                                                                         2.0),
+                                              trust_remote_code=True,
+                                              max_model_len=1024,
+                                              enforce_eager=True,
+                                              enable_chunked_prefill=False,
+                                              gpu_memory_utilization=0.4)
+            vllm_configs = VllmRbProperties(**properties)
+            computed_engine_args = vllm_configs.get_engine_args()
+            self.assertTrue(
+                len(vllm_configs.get_additional_vllm_engine_args()) > 0)
+            self.assertEqual(expected_engine_args, computed_engine_args)

         def test_long_lora_scaling_factors():
             properties = {
                 "engine": "Python",
-                "model_id_or_path": "some_model",
+                "model_id": "some_model",
                 'long_lora_scaling_factors': "3.0"
             }
             vllm_props = VllmRbProperties(**properties)
@@ -500,14 +580,18 @@ def test_long_lora_scaling_factors():
         def test_invalid_long_lora_scaling_factors():
             properties = {
                 "engine": "Python",
-                "model_id_or_path": "some_model",
+                "model_id": "some_model",
                 'long_lora_scaling_factors': "a,b"
             }
             vllm_props = VllmRbProperties(**properties)
             with self.assertRaises(ValueError):
                 vllm_props.get_engine_args()

         test_vllm_default_properties()
+        test_invalid_pipeline_parallel()
+        test_invalid_engine()
+        test_aliases()
+        test_vllm_passthrough_properties()
         test_long_lora_scaling_factors()
         test_invalid_long_lora_scaling_factors()
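The two long_lora_scaling_factors cases above depend on DJL parsing a comma-separated string into a tuple of floats and rejecting non-numeric input when get_engine_args() runs. A standalone sketch of that parsing behavior — parse_long_lora_scaling_factors is a hypothetical helper for illustration, not the actual DJL implementation:

    from typing import Optional, Tuple

    def parse_long_lora_scaling_factors(
            value: Optional[str]) -> Optional[Tuple[float, ...]]:
        """Parse "1.1, 2.0" into (1.1, 2.0); raise ValueError otherwise."""
        if value is None:
            return None
        try:
            return tuple(float(part) for part in value.split(","))
        except ValueError:
            raise ValueError(
                f"Invalid long_lora_scaling_factors: {value!r}")

    assert parse_long_lora_scaling_factors("3.0") == (3.0,)
    assert parse_long_lora_scaling_factors("1.1, 2.0") == (1.1, 2.0)
    # "a,b" raises ValueError, matching test_invalid_long_lora_scaling_factors.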
