enable full passthrough of vllm engine args, with backwards compatibility for existing differences with engine agnostic configs
siddvenk committed Dec 19, 2024
1 parent 3aebeb5 commit 7ac8965
Showing 5 changed files with 116 additions and 149 deletions.
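With this change, any option whose name matches a field of vllm.EngineArgs is forwarded to the engine unchanged, while the existing DJL-specific names (quantize, tensor_parallel_degree, max_rolling_batch_prefill_tokens, cpu_offload_gb_per_gpu, dtype, and so on) keep their previous behavior. The following is a minimal sketch of the intended usage, assuming vLLM 0.6.3.post1 as pinned in setup.py below; the model id and swap_space value are illustrative and not taken from this commit:

from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties

# Illustrative properties; in LMI these typically arrive as "option.*" keys in
# serving.properties with the prefix stripped before reaching the handler.
properties = {
    "model_id": "mistralai/Mistral-7B-v0.1",  # example model id
    "engine": "Python",
    "tensor_parallel_degree": "1",
    # Not an LMI-defined config: kept as a pydantic extra (extra='allow') and
    # forwarded because it matches a vllm.EngineArgs field name.
    "swap_space": 8,
}

configs = VllmRbProperties(**properties)
engine_args = configs.get_engine_args()
print(engine_args.swap_space)            # 8, passed through **additional_vllm_engine_args
print(engine_args.tensor_parallel_size)  # 1, mapped from the DJL tensor_parallel_degree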
engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py
@@ -11,66 +11,43 @@
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import ast
from enum import Enum
import dataclasses
from typing import Optional, Any, Mapping, Tuple, Dict

from pydantic import field_validator, model_validator
from pydantic import field_validator, model_validator, ConfigDict
from vllm import EngineArgs

from djl_python.properties_manager.properties import Properties

DTYPE_MAPPER = {
"fp32": "float32",
"fp16": "float16",
"bf16": "bfloat16",
"auto": "auto"
}


class VllmRbProperties(Properties):
engine: Optional[str] = None
dtype: Optional[str] = "auto"
load_format: Optional[str] = "auto"
# The following configs have different names in DJL compared to vLLM
quantize: Optional[str] = None
tensor_parallel_degree: int = 1
pipeline_parallel_degree: int = 1
max_rolling_batch_prefill_tokens: Optional[int] = None
# Adjustable max model length for certain 32k or longer models
max_model_len: Optional[int] = None
enforce_eager: Optional[bool] = False
# TODO: this default may change with different vLLM versions
# TODO: try to get good default from vLLM to prevent revisiting
# TODO: last time check: vllm 0.3.1
gpu_memory_utilization: Optional[float] = 0.9
enable_lora: Optional[bool] = False
cpu_offload_gb_per_gpu: Optional[int] = 0
# The following configs have different defaults, or additional processing in DJL compared to vLLM
dtype: str = "auto"
max_loras: Optional[int] = 4
max_lora_rank: Optional[int] = 16
fully_sharded_loras: bool = False
lora_extra_vocab_size: int = 256
long_lora_scaling_factors: Optional[Tuple[float, ...]] = None
lora_dtype: Optional[str] = 'auto'
max_cpu_loras: Optional[int] = None
limit_mm_per_prompt: Optional[Mapping[str, int]] = None

# Neuron vLLM properties
device: Optional[str] = None
preloaded_model: Optional[Any] = None
generation_config: Optional[Any] = None

max_logprobs: Optional[int] = 20
enable_chunked_prefill: Optional[bool] = None
cpu_offload_gb_per_gpu: Optional[int] = 0
enable_prefix_caching: Optional[bool] = False
disable_sliding_window: Optional[bool] = False
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
use_v2_block_manager: bool = False
tokenizer_mode: str = 'auto'

# Speculative decoding configuration.
speculative_model: Optional[str] = None
speculative_model_quantization: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None
num_speculative_tokens: Optional[int] = None
speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None
spec_decoding_acceptance_method: str = 'rejection_sampler'
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
qlora_adapter_name_or_path: Optional[str] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None
# Allow generic vLLM engine args to be passed in and forwarded to the vLLM engine
model_config = ConfigDict(extra='allow')

@field_validator('engine')
def validate_engine(cls, engine):
@@ -117,18 +94,91 @@ def validate_limit_mm_per_prompt(cls, val) -> Mapping[str, int]:
out_dict[key] = parsed_value
return out_dict

@model_validator(mode='after')
def validate_speculative_model(self):
if self.speculative_model is not None and not self.use_v2_block_manager:
raise ValueError(
"Speculative decoding requires usage of the V2 block manager. Enable it with option.use_v2_block_manager=true."
)
return self

@model_validator(mode='after')
def validate_pipeline_parallel(self):
if self.pipeline_parallel_degree != 1:
raise ValueError(
"Pipeline parallelism is not supported in vLLM's LLMEngine used in rolling_batch implementation"
)
return self

def handle_lmi_vllm_config_conflicts(self, additional_vllm_engine_args):

def djl_config_conflicts_with_vllm_config(lmi_config_name,
vllm_config_name) -> bool:
# TODO: We may be able to refactor this to raise the ValueError directly from this method.
# The errors differ slightly depending on the specific configs, so for now we keep
# the exceptions separate in favor of better, more specific client errors
lmi_config_val = self.__getattribute__(lmi_config_name)
vllm_config_val = additional_vllm_engine_args.get(vllm_config_name)
if vllm_config_val is not None and lmi_config_val is not None:
return lmi_config_val != vllm_config_val
return False

if djl_config_conflicts_with_vllm_config("quantize", "quantization"):
raise ValueError(
"Both the DJL quantize config, and vllm quantization configs have been set with conflicting values."
"Only set the DJL quantize config")
if djl_config_conflicts_with_vllm_config("tensor_parallel_degree",
"tensor_parallel_size"):
raise ValueError(
"Both the DJL tensor_parallel_degree and vllm tensor_parallel_size configs have been set with conflicting values."
"Only set the DJL tensor_parallel_degree config")
if djl_config_conflicts_with_vllm_config("pipeline_parallel_degree",
"pipeline_parallel_size"):
raise ValueError(
"Both the DJL pipeline_parallel_degree and vllm pipeline_parallel_size configs have been set with conflicting values."
"Only set the DJL pipeline_parallel_degree config")
if djl_config_conflicts_with_vllm_config(
"max_rolling_batch_prefill_tokens", "max_num_batched_tokens"):
raise ValueError(
"Both the DJL max_rolling_batch_prefill_tokens and vllm max_num_batched_tokens configs have been set with conflicting values."
"Only set one of these configurations")
if djl_config_conflicts_with_vllm_config("cpu_offload_gb_per_gpu",
"cpu_offload_gb"):
raise ValueError(
"Both the DJL cpu_offload_gb_per_gpu and vllm cpu_offload_gb configs have been set with conflicting values."
"Only set one of these configurations")

def get_engine_args(self) -> EngineArgs:
additional_vllm_engine_args = self.get_additional_vllm_engine_args()
self.handle_lmi_vllm_config_conflicts(additional_vllm_engine_args)
max_model_len = additional_vllm_engine_args.pop("max_model_len", None)
if self.device == 'neuron':
return EngineArgs(
model=self.model_id_or_path,
preloaded_model=self.preloaded_model,
tensor_parallel_size=self.tensor_parallel_degree,
pipeline_parallel_size=self.pipeline_parallel_degree,
dtype=DTYPE_MAPPER[self.dtype],
max_num_seqs=self.max_rolling_batch_size,
block_size=max_model_len,
max_model_len=max_model_len,
trust_remote_code=self.trust_remote_code,
revision=self.revision,
device=self.device,
generation_config=self.generation_config,
**additional_vllm_engine_args,
)
return EngineArgs(
model=self.model_id_or_path,
tensor_parallel_size=self.tensor_parallel_degree,
pipeline_parallel_size=self.pipeline_parallel_degree,
dtype=DTYPE_MAPPER[self.dtype],
max_model_len=max_model_len,
quantization=self.quantize,
max_num_batched_tokens=self.max_rolling_batch_prefill_tokens,
max_loras=self.max_loras,
long_lora_scaling_factors=self.long_lora_scaling_factors,
cpu_offload_gb=self.cpu_offload_gb_per_gpu,
limit_mm_per_prompt=self.limit_mm_per_prompt,
**additional_vllm_engine_args,
)

def get_additional_vllm_engine_args(self) -> Dict[str, Any]:
all_engine_args = EngineArgs.__annotations__
return {
arg: val
for arg, val in self.__pydantic_extra__.items()
if arg in all_engine_args
}
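Two behaviors of the methods above are worth calling out, shown here as a hedged sketch (the model id and config values are illustrative, not from this commit): extras that do not match a vllm.EngineArgs annotation are silently dropped by get_additional_vllm_engine_args, and setting both the DJL name and the vLLM name for the same setting to conflicting values fails fast in handle_lmi_vllm_config_conflicts.

from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties

# An unknown key that is not a vLLM EngineArgs field is ignored rather than forwarded.
configs = VllmRbProperties(model_id="facebook/opt-125m", engine="Python",
                           not_a_vllm_arg="ignored")
assert "not_a_vllm_arg" not in configs.get_additional_vllm_engine_args()

# Conflicting DJL and vLLM names for the same setting raise a client error.
conflicting = VllmRbProperties(model_id="facebook/opt-125m", engine="Python",
                               quantize="awq", quantization="gptq")
try:
    conflicting.get_engine_args()
except ValueError as err:
    print(err)  # asks the caller to only set the DJL quantize config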
engines/python/setup/djl_python/rolling_batch/rolling_batch_vllm_utils.py
@@ -235,75 +235,6 @@ def get_lora_request(lora_name: str, lora_requests: dict) -> dict:
return lora_requests[lora_name]


def get_engine_args_from_config(config: VllmRbProperties) -> EngineArgs:
if config.device == "neuron":
return EngineArgs(model=config.model_id_or_path,
preloaded_model=config.preloaded_model,
tensor_parallel_size=config.tensor_parallel_degree,
dtype=DTYPE_MAPPER[config.dtype],
seed=0,
max_model_len=config.max_model_len,
max_num_seqs=config.max_rolling_batch_size,
block_size=config.max_model_len,
trust_remote_code=config.trust_remote_code,
revision=config.revision,
device=config.device,
generation_config=config.generation_config)
else:
return EngineArgs(
model=config.model_id_or_path,
tensor_parallel_size=config.tensor_parallel_degree,
pipeline_parallel_size=config.pipeline_parallel_degree,
dtype=DTYPE_MAPPER[config.dtype],
seed=0,
max_model_len=config.max_model_len,
enforce_eager=config.enforce_eager,
gpu_memory_utilization=config.gpu_memory_utilization,
max_num_batched_tokens=config.max_rolling_batch_prefill_tokens,
trust_remote_code=config.trust_remote_code,
load_format=config.load_format,
quantization=config.quantize,
enable_lora=config.enable_lora,
max_loras=config.max_loras,
max_lora_rank=config.max_lora_rank,
fully_sharded_loras=config.fully_sharded_loras,
lora_extra_vocab_size=config.lora_extra_vocab_size,
long_lora_scaling_factors=config.long_lora_scaling_factors,
lora_dtype=config.lora_dtype,
max_cpu_loras=config.max_cpu_loras,
revision=config.revision,
max_logprobs=config.max_logprobs,
enable_chunked_prefill=config.enable_chunked_prefill,
cpu_offload_gb=config.cpu_offload_gb_per_gpu,
enable_prefix_caching=config.enable_prefix_caching,
disable_sliding_window=config.disable_sliding_window,
max_num_seqs=config.max_rolling_batch_size,
use_v2_block_manager=config.use_v2_block_manager,
speculative_model=config.speculative_model,
speculative_model_quantization=config.
speculative_model_quantization,
speculative_draft_tensor_parallel_size=config.
speculative_draft_tensor_parallel_size,
num_speculative_tokens=config.num_speculative_tokens,
speculative_max_model_len=config.speculative_max_model_len,
speculative_disable_by_batch_size=config.
speculative_disable_by_batch_size,
ngram_prompt_lookup_max=config.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=config.ngram_prompt_lookup_min,
spec_decoding_acceptance_method=config.
spec_decoding_acceptance_method,
typical_acceptance_sampler_posterior_threshold=config.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=config.
typical_acceptance_sampler_posterior_alpha,
qlora_adapter_name_or_path=config.qlora_adapter_name_or_path,
disable_logprobs_during_spec_decoding=config.
disable_logprobs_during_spec_decoding,
limit_mm_per_prompt=config.limit_mm_per_prompt,
tokenizer_mode=config.tokenizer_mode,
)


def get_multi_modal_data(request: Request) -> Optional[dict]:
parameters = request.parameters
images = parameters.pop("images", None)
engines/python/setup/djl_python/rolling_batch/vllm_rolling_batch.py
@@ -19,7 +19,7 @@
from djl_python.rolling_batch.rolling_batch import RollingBatch, stop_on_any_exception, filter_unused_generation_params
from djl_python.rolling_batch.rolling_batch_vllm_utils import (
update_request_cache_with_output, create_lora_request, get_lora_request,
get_engine_args_from_config, get_prompt_inputs)
get_prompt_inputs)
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
from typing import List, Optional

@@ -47,7 +47,7 @@ def __init__(self, model_id_or_path: str, properties: dict,
"""
self.vllm_configs = VllmRbProperties(**properties)
super().__init__(self.vllm_configs)
args = get_engine_args_from_config(self.vllm_configs)
args = self.vllm_configs.get_engine_args()
self.engine = LLMEngine.from_engine_args(args)
self.request_cache = OrderedDict()
self.lora_id_counter = AtomicCounter(0)
44 changes: 15 additions & 29 deletions engines/python/setup/djl_python/tests/test_properties_manager.py
@@ -9,7 +9,7 @@
TnXMemoryLayout, TnXDtypeName, TnXModelLoaders)
from djl_python.properties_manager.trt_properties import TensorRtLlmProperties
from djl_python.properties_manager.hf_properties import HuggingFaceProperties
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties
from djl_python.properties_manager.vllm_rb_properties import VllmRbProperties, DTYPE_MAPPER
from djl_python.properties_manager.sd_inf2_properties import StableDiffusionNeuronXProperties
from djl_python.properties_manager.lmi_dist_rb_properties import LmiDistRbProperties
from djl_python.properties_manager.scheduler_rb_properties import SchedulerRbProperties
Expand Down Expand Up @@ -425,33 +425,19 @@ def test_vllm_properties(self):

def test_vllm_valid(properties):
vllm_configs = VllmRbProperties(**properties)
self.assertEqual(vllm_configs.model_id_or_path,
properties['model_id'])
self.assertEqual(vllm_configs.engine, properties['engine'])
self.assertEqual(
vllm_configs.max_rolling_batch_prefill_tokens,
int(properties['max_rolling_batch_prefill_tokens']))
self.assertEqual(vllm_configs.dtype, properties['dtype'])
self.assertEqual(vllm_configs.load_format,
properties['load_format'])
self.assertEqual(vllm_configs.quantize, properties['quantize'])
engine_args = vllm_configs.get_engine_args()
self.assertEqual(vllm_configs.model_id_or_path, engine_args.model)
self.assertEqual(vllm_configs.max_rolling_batch_prefill_tokens,
engine_args.max_num_batched_tokens)
self.assertEqual(vllm_configs.tensor_parallel_degree,
int(properties['tensor_parallel_degree']))
self.assertEqual(vllm_configs.max_model_len,
int(properties['max_model_len']))
self.assertEqual(vllm_configs.enforce_eager,
bool(properties['enforce_eager']))
self.assertEqual(vllm_configs.enable_lora,
bool(properties['enable_lora']))
self.assertEqual(vllm_configs.gpu_memory_utilization,
float(properties['gpu_memory_utilization']))

def test_enforce_eager(properties):
properties.pop('enforce_eager')
properties.pop('quantize')
self.assertTrue("enforce_eager" not in properties)
vllm_props = VllmRbProperties(**properties)
self.assertTrue(vllm_props.enforce_eager is False)
engine_args.tensor_parallel_size)
self.assertEqual(vllm_configs.pipeline_parallel_degree,
engine_args.pipeline_parallel_size)
self.assertEqual(vllm_configs.quantize, engine_args.quantization)
self.assertEqual(DTYPE_MAPPER[vllm_configs.dtype],
engine_args.dtype)
self.assertEqual(vllm_configs.cpu_offload_gb_per_gpu,
engine_args.cpu_offload_gb)

def test_long_lora_scaling_factors(properties):
properties['long_lora_scaling_factors'] = "3.0"
@@ -494,10 +480,10 @@ def test_invalid_long_lora_scaling_factors(properties):
'enforce_eager': "True",
'enable_lora': "true",
"gpu_memory_utilization": "0.85",
'load_format': 'pt'
'load_format': 'pt',
'cpu_offload_gb_per_gpu': '3',
}
test_vllm_valid(properties.copy())
test_enforce_eager(properties.copy())
test_long_lora_scaling_factors(properties.copy())
test_invalid_long_lora_scaling_factors(properties.copy())

2 changes: 1 addition & 1 deletion engines/python/setup/setup.py
@@ -58,7 +58,7 @@ def run(self):
test_requirements = [
'numpy<2', 'requests', 'Pillow', 'transformers', 'torch', 'einops',
'accelerate', 'sentencepiece', 'protobuf', "peft", 'yapf',
'pydantic>=2.0', "objgraph"
'pydantic>=2.0', "objgraph", 'vllm==0.6.3.post1'
]

setup(name='djl_python',
