Pydantic config p2 #213

Draft
wants to merge 10 commits into main
114 changes: 55 additions & 59 deletions backends/exllamav2/model.py
@@ -32,6 +32,7 @@

from ruamel.yaml import YAML

from backends.exllamav2.types import DraftModelInstanceConfig, ModelInstanceConfig
from common.health import HealthManager

from backends.exllamav2.grammar import (
@@ -43,6 +44,7 @@
hardware_supports_flash_attn,
supports_paged_attn,
)
from common.tabby_config import config
from common.concurrency import iterate_in_threadpool
from common.gen_logging import (
log_generation_params,
@@ -103,7 +105,12 @@ class ExllamaV2Container:
load_condition: asyncio.Condition = asyncio.Condition()

@classmethod
async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
async def create(
cls,
model: ModelInstanceConfig,
draft: DraftModelInstanceConfig,
quiet=False,
):
"""
Primary asynchronous initializer for model container.

@@ -117,8 +124,15 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):

# Initialize config
self.config = ExLlamaV2Config()
self.model_dir = model_directory
self.config.model_dir = str(model_directory.resolve())

model_path = pathlib.Path(config.model.model_dir)
model_path = model_path / model.model_name
model_path = model_path.resolve()
if not model_path.exists():
raise FileNotFoundError(f"Model path {model_path} does not exist.")

self.model_dir = model_path
self.config.model_dir = str(model_path)

# Make the max seq len 4096 before preparing the config
# This is a better default than 2048
@@ -130,35 +144,23 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
self.config.arch_compat_overrides()

# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft"), {})
draft_model_name = draft_args.get("draft_model_name")
enable_draft = draft_args and draft_model_name

# Always disable draft if params are incorrectly configured
if draft_args and draft_model_name is None:
logger.warning(
"Draft model is disabled because a model name "
"wasn't provided. Please check your config.yml!"
)
enable_draft = False

if enable_draft:
if draft.draft_model_name:
self.draft_config = ExLlamaV2Config()
self.draft_config.no_flash_attn = self.config.no_flash_attn
draft_model_path = pathlib.Path(
unwrap(draft_args.get("draft_model_dir"), "models")

draft_model_path = (
config.draft_model.draft_model_dir / draft.draft_model_name
)
draft_model_path = draft_model_path / draft_model_name

self.draft_model_dir = draft_model_path
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()

# Create the hf_config
self.hf_config = await HuggingFaceConfig.from_file(model_directory)
self.hf_config = await HuggingFaceConfig.from_file(model_path)

# Load generation config overrides
generation_config_path = model_directory / "generation_config.json"
generation_config_path = model_path / "generation_config.json"
if generation_config_path.exists():
try:
self.generation_config = await GenerationConfig.from_file(
@@ -171,18 +173,20 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
)

# Apply a model's config overrides while respecting user settings
kwargs = await self.set_model_overrides(**kwargs)

# FIXME: THIS IS BROKEN!!!
# kwargs do not exist now
# should be investigated after the models have pydantic stuff
# kwargs = await self.set_model_overrides(**kwargs)

# MARK: User configuration

# Get cache mode
self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
self.cache_mode = model.cache_mode

# Turn off GPU split if the user is using 1 GPU
gpu_count = torch.cuda.device_count()
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
use_tp = unwrap(kwargs.get("tensor_parallel"), False)
gpu_split = kwargs.get("gpu_split")
gpu_split_auto = model.gpu_split_auto
gpu_device_list = list(range(0, gpu_count))

# Set GPU split options
@@ -191,16 +195,16 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
logger.info("Disabling GPU split because one GPU is in use.")
else:
# Set tensor parallel
if use_tp:
if model.tensor_parallel:
self.use_tp = True

# TP has its own autosplit loader
self.gpu_split_auto = False

# Enable manual GPU split if provided
if gpu_split:
if model.gpu_split:
self.gpu_split_auto = False
self.gpu_split = gpu_split
self.gpu_split = model.gpu_split

gpu_device_list = [
device_idx
@@ -211,9 +215,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
# Otherwise fallback to autosplit settings
self.gpu_split_auto = gpu_split_auto

autosplit_reserve_megabytes = unwrap(
kwargs.get("autosplit_reserve"), [96]
)
autosplit_reserve_megabytes = model.autosplit_reserve

# Reserve VRAM for each GPU
self.autosplit_reserve = [
@@ -225,37 +227,34 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
self.config.max_output_len = 16

# Then override the base_seq_len if present
override_base_seq_len = kwargs.get("override_base_seq_len")
if override_base_seq_len:
self.config.max_seq_len = override_base_seq_len
if model.override_base_seq_len:
self.config.max_seq_len = model.override_base_seq_len

# Grab the base model's sequence length before overrides for
# rope calculations
base_seq_len = self.config.max_seq_len

# Set the target seq len if present
target_max_seq_len = kwargs.get("max_seq_len")
target_max_seq_len = model.max_seq_len
if target_max_seq_len:
self.config.max_seq_len = target_max_seq_len

# Set the rope scale
self.config.scale_pos_emb = unwrap(
kwargs.get("rope_scale"), self.config.scale_pos_emb
)
self.config.scale_pos_emb = unwrap(model.rope_scale, self.config.scale_pos_emb)

# Sets rope alpha value.
# Automatically calculate if unset or defined as an "auto" literal.
rope_alpha = unwrap(kwargs.get("rope_alpha"), "auto")
rope_alpha = unwrap(model.rope_alpha, "auto")
if rope_alpha == "auto":
self.config.scale_alpha_value = self.calculate_rope_alpha(base_seq_len)
else:
self.config.scale_alpha_value = rope_alpha

# Enable fasttensors loading if present
self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
self.config.fasttensors = config.model.fasttensors

# Set max batch size to the config override
self.max_batch_size = unwrap(kwargs.get("max_batch_size"))
self.max_batch_size = model.max_batch_size

# Check whether the user's configuration supports flash/paged attention
# Also check if exl2 has disabled flash attention
@@ -272,7 +271,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
# Set k/v cache size
# cache_size is only relevant when paged mode is enabled
if self.paged:
cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len)
cache_size = unwrap(model.cache_size, self.config.max_seq_len)

if cache_size < self.config.max_seq_len:
logger.warning(
Expand Down Expand Up @@ -314,7 +313,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):

# Try to set prompt template
self.prompt_template = await self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
model.prompt_template, model.model_name
)

# Catch all for template lookup errors
@@ -329,29 +328,26 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
)

# Set num of experts per token if provided
num_experts_override = kwargs.get("num_experts_per_token")
if num_experts_override:
self.config.num_experts_per_token = kwargs.get("num_experts_per_token")
if model.num_experts_per_token:
self.config.num_experts_per_token = model.num_experts_per_token

# Make sure chunk size is >= 16 and <= max seq length
user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048)
user_chunk_size = unwrap(model.chunk_size, 2048)
chunk_size = sorted((16, user_chunk_size, self.config.max_seq_len))[1]
self.config.max_input_len = chunk_size
self.config.max_attention_size = chunk_size**2

# Set user-configured draft model values
if enable_draft:
# Fetch from the updated kwargs
draft_args = unwrap(kwargs.get("draft"), {})
if draft.draft_model_name:

self.draft_config.max_seq_len = self.config.max_seq_len

self.draft_config.scale_pos_emb = unwrap(
draft_args.get("draft_rope_scale"), 1.0
draft.draft_rope_scale, 1.0
)

# Set draft rope alpha. Follows same behavior as model rope alpha.
draft_rope_alpha = unwrap(draft_args.get("draft_rope_alpha"), "auto")
draft_rope_alpha = unwrap(draft.draft_rope_alpha, "auto")
if draft_rope_alpha == "auto":
self.draft_config.scale_alpha_value = self.calculate_rope_alpha(
self.draft_config.max_seq_len
@@ -360,7 +356,7 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
self.draft_config.scale_alpha_value = draft_rope_alpha

# Set draft cache mode
self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")
self.draft_cache_mode = draft.draft_cache_mode

if chunk_size:
self.draft_config.max_input_len = chunk_size
@@ -524,7 +520,7 @@ def progress(loaded_modules: int, total_modules: int)
async for _ in self.load_gen(progress_callback):
pass

async def load_gen(self, progress_callback=None, **kwargs):
async def load_gen(self, progress_callback=None, skip_wait=False):
"""Loads a model and streams progress via a generator."""

# Indicate that model load has started
@@ -534,7 +530,7 @@ async def load_gen(self, progress_callback=None, **kwargs):
self.model_is_loading = True

# Wait for existing generation jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))
await self.wait_for_jobs(skip_wait)

# Streaming gen for model load progress
model_load_generator = self.load_model_sync(progress_callback)
@@ -1130,19 +1126,19 @@ async def generate_gen(
grammar_handler = ExLlamaV2Grammar()

# Add JSON schema filter if it exists
json_schema = unwrap(kwargs.get("json_schema"))
json_schema = kwargs.get("json_schema")
if json_schema:
grammar_handler.add_json_schema_filter(
json_schema, self.model, self.tokenizer
)

# Add regex filter if it exists
regex_pattern = unwrap(kwargs.get("regex_pattern"))
regex_pattern = kwargs.get("regex_pattern")
if regex_pattern:
grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)

# Add EBNF filter if it exists
grammar_string = unwrap(kwargs.get("grammar_string"))
grammar_string = kwargs.get("grammar_string")
if grammar_string:
grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)

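The new `create()` signature takes `ModelInstanceConfig` and `DraftModelInstanceConfig` objects imported from `backends/exllamav2/types`, but their definitions are not part of this file's diff. The sketch below is only an illustration of what those Pydantic models might look like, inferred from the attributes this file accesses; the types and defaults shown are assumptions, not the actual contents of `types.py`.

```python
# Hypothetical sketch only: field names mirror the attribute accesses in
# model.py above; the real definitions live in backends/exllamav2/types.py
# and may differ in types, defaults, and validation.
from typing import List, Optional, Union

from pydantic import BaseModel


class ModelInstanceConfig(BaseModel):
    model_name: str
    max_seq_len: Optional[int] = None
    override_base_seq_len: Optional[int] = None
    cache_size: Optional[int] = None
    cache_mode: str = "FP16"
    chunk_size: Optional[int] = 2048
    max_batch_size: Optional[int] = None
    prompt_template: Optional[str] = None
    num_experts_per_token: Optional[int] = None
    gpu_split_auto: bool = True
    gpu_split: Optional[List[float]] = None
    autosplit_reserve: List[int] = [96]
    tensor_parallel: bool = False
    rope_scale: Optional[float] = None
    rope_alpha: Optional[Union[float, str]] = None


class DraftModelInstanceConfig(BaseModel):
    draft_model_name: Optional[str] = None
    draft_rope_scale: Optional[float] = None
    draft_rope_alpha: Optional[Union[float, str]] = None
    draft_cache_mode: str = "FP16"
```

With config objects along these lines, a caller would construct them and pass them to the new initializer instead of a model directory plus `**kwargs` (the model name below is made up for the example):

```python
# Inside an async function; the draft model stays disabled when no name is set.
model_cfg = ModelInstanceConfig(model_name="my-model-exl2", max_seq_len=8192)
draft_cfg = DraftModelInstanceConfig()

container = await ExllamaV2Container.create(model_cfg, draft_cfg)
```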