Use PT_COMPILE_ONLY_MODE during warmup #227

Merged: 2 commits, Sep 6, 2024

Changes from all commits
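The change wraps the whole model warmup path, including HPU graph capture, in a PT_COMPILE_ONLY_MODE bridge setting. A minimal sketch of the pattern, assuming only the bc.env_setting context manager and the warmup calls visible in the diff below (the run_warmup wrapper and its model_runner argument are illustrative, not part of the change):

import habana_frameworks.torch.internal.bridge_config as bc

def run_warmup(model_runner, kv_caches):
    # Toggle PT_COMPILE_ONLY_MODE only for the duration of warmup; the
    # context manager is expected to restore the previous value when the
    # block exits, so execution after warmup is unaffected.
    with bc.env_setting("PT_COMPILE_ONLY_MODE", True):
        model_runner.warmup_all_buckets(model_runner.prompt_buckets, True,
                                        kv_caches)
        model_runner.warmup_all_buckets(model_runner.decode_buckets, False,
                                        kv_caches)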
125 changes: 66 additions & 59 deletions vllm/worker/habana_model_runner.py
@@ -15,6 +15,7 @@
                     Optional, Set, Tuple, Type, TypeVar, Union)
 
 import habana_frameworks.torch as htorch
+import habana_frameworks.torch.internal.bridge_config as bc
 import torch
 
 from vllm.attention import AttentionMetadata, get_attn_backend
@@ -1261,67 +1262,73 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
         self.profiler.start('internal', 'warmup')
         start_mem = HabanaMemoryProfiler.current_device_memory_usage()
         start_time = time.perf_counter()
-        self.warmup_all_buckets(self.prompt_buckets, True, kv_caches)
-        self.warmup_all_buckets(self.decode_buckets, False, kv_caches)
-
-        if not self.enforce_eager and htorch.utils.internal.is_lazy():
-            assert self.mem_margin is not None, \
-                ("HabanaWorker.determine_num_available_blocks needs "
-                 "to be called before warming up the model.")
-            free_mem = HabanaMemoryProfiler.current_free_device_memory()
-            graph_free_mem = free_mem - self.mem_margin
-            graph_free_mem = align_workers(graph_free_mem,
-                                           torch.distributed.ReduceOp.MIN)
-            prompt_graph_mem_ratio = float(
-                os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5'))
-            prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
-            decode_available_memory = graph_free_mem - prompt_available_memory
-            msg = (f"Using {format_bytes(graph_free_mem)}"
-                   f"/{format_bytes(free_mem)} "
-                   "of free device memory for HPUGraphs, "
-                   f"{format_bytes(prompt_available_memory)} for prompt and "
-                   f"{format_bytes(decode_available_memory)} for decode "
-                   f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})")
-            logger.info(msg)
-            prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY',
-                                             'min_tokens')
-            decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY',
-                                             'max_bs')
-            mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
-                self.warmup_graphs(
-                prompt_strategy, self.prompt_buckets, True, kv_caches,
-                prompt_available_memory)
-            mem_post_decode, decode_batch_seq, decode_captured_all = \
-                self.warmup_graphs(
-                decode_strategy, self.decode_buckets, False, kv_caches,
-                decode_available_memory)
-
-            # Not all prompt buckets were captured, but all decode buckets were
-            # captured and we have some free graph-allocated space left.
-            # Let's try to use it for capturing more prompt buckets.
-            if mem_post_decode + mem_post_prompt < graph_free_mem \
-                and not prompt_captured_all \
-                    and decode_captured_all:
-                mem_post_prompt, _, prompt_captured_all = self.warmup_graphs(
-                    prompt_strategy, self.prompt_buckets, True, kv_caches,
-                    graph_free_mem - mem_post_prompt - mem_post_decode,
-                    mem_post_prompt, prompt_batch_seq)
-
-            # Not all decode buckets were captured, but all prompt buckets were
-            # captured and we have some free graph-allocated space left.
-            # Let's try to use it for capturing more decode buckets.
-            if mem_post_decode + mem_post_prompt < graph_free_mem \
-                and not decode_captured_all \
-                    and prompt_captured_all:
-                mem_post_decode, _, _ = self.warmup_graphs(
-                    decode_strategy, self.decode_buckets, False, kv_caches,
-                    graph_free_mem - mem_post_prompt - mem_post_decode,
-                    mem_post_decode, decode_batch_seq)
-
-            self.log_graph_warmup_summary(self.prompt_buckets, True,
-                                          mem_post_prompt)
-            self.log_graph_warmup_summary(self.decode_buckets, False,
-                                          mem_post_decode)
+
+        with bc.env_setting("PT_COMPILE_ONLY_MODE", True):
+            self.warmup_all_buckets(self.prompt_buckets, True, kv_caches)
+            self.warmup_all_buckets(self.decode_buckets, False, kv_caches)
+
+            if not self.enforce_eager and htorch.utils.internal.is_lazy():
+                assert self.mem_margin is not None, \
+                    ("HabanaWorker.determine_num_available_blocks needs "
+                     "to be called before warming up the model.")
+                free_mem = HabanaMemoryProfiler.current_free_device_memory()
+                graph_free_mem = free_mem - self.mem_margin
+                graph_free_mem = align_workers(graph_free_mem,
+                                               torch.distributed.ReduceOp.MIN)
+                prompt_graph_mem_ratio = float(
+                    os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5'))
+                prompt_available_memory = (prompt_graph_mem_ratio *
+                                           graph_free_mem)
+                decode_available_memory = (graph_free_mem -
+                                           prompt_available_memory)
+                msg = (
+                    f"Using {format_bytes(graph_free_mem)}"
+                    f"/{format_bytes(free_mem)} "
+                    "of free device memory for HPUGraphs, "
+                    f"{format_bytes(prompt_available_memory)} for prompt and "
+                    f"{format_bytes(decode_available_memory)} for decode "
+                    f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})")
+                logger.info(msg)
+                prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY',
+                                                 'min_tokens')
+                decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY',
+                                                 'max_bs')
+                mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
+                    self.warmup_graphs(
+                    prompt_strategy, self.prompt_buckets, True, kv_caches,
+                    prompt_available_memory)
+                mem_post_decode, decode_batch_seq, decode_captured_all = \
+                    self.warmup_graphs(
+                    decode_strategy, self.decode_buckets, False, kv_caches,
+                    decode_available_memory)
+
+                # Not all prompt buckets were captured, but all decode buckets
+                # were captured and we have some free graph-allocated space
+                # left. Let's try to use it for capturing more prompt buckets.
+                if (mem_post_decode + mem_post_prompt < graph_free_mem
+                        and not prompt_captured_all and decode_captured_all):
+                    mem_post_prompt, _, prompt_captured_all = (
+                        self.warmup_graphs(
+                            prompt_strategy, self.prompt_buckets, True,
+                            kv_caches,
+                            graph_free_mem - mem_post_prompt - mem_post_decode,
+                            mem_post_prompt, prompt_batch_seq))
+
+                # Not all decode buckets were captured, but all prompt buckets
+                # were captured and we have some free graph-allocated space
+                # left. Let's try to use it for capturing more decode buckets.
+                if mem_post_decode + mem_post_prompt < graph_free_mem \
+                    and not decode_captured_all \
+                        and prompt_captured_all:
+                    mem_post_decode, _, _ = self.warmup_graphs(
+                        decode_strategy, self.decode_buckets, False, kv_caches,
+                        graph_free_mem - mem_post_prompt - mem_post_decode,
+                        mem_post_decode, decode_batch_seq)
+
+                self.log_graph_warmup_summary(self.prompt_buckets, True,
+                                              mem_post_prompt)
+                self.log_graph_warmup_summary(self.decode_buckets, False,
+                                              mem_post_decode)
 
         end_time = time.perf_counter()
         end_mem = HabanaMemoryProfiler.current_device_memory_usage()
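For reference, the graph-memory budgeting in the hunk above splits the usable free memory between prompt and decode graphs via VLLM_GRAPH_PROMPT_RATIO, and any budget left after both capture passes can be spent on whichever side was not fully captured. A small worked sketch with made-up numbers (all values and the print call are illustrative only):

GiB = 1024 ** 3

# Hypothetical post-warmup numbers, for illustration only.
free_mem = 40 * GiB              # free HPU memory reported by the profiler
mem_margin = 4 * GiB             # margin reserved by HabanaWorker
graph_free_mem = free_mem - mem_margin                    # 36 GiB for HPUGraphs

prompt_graph_mem_ratio = 0.5     # VLLM_GRAPH_PROMPT_RATIO default
prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem    # 18 GiB
decode_available_memory = graph_free_mem - prompt_available_memory   # 18 GiB

# Suppose the first capture pass spent the whole prompt budget without
# finishing all prompt buckets, while decode finished within 10 GiB.
mem_post_prompt, prompt_captured_all = 18 * GiB, False
mem_post_decode, decode_captured_all = 10 * GiB, True

# Same condition as in the diff: leftover budget goes to the unfinished side.
if (mem_post_decode + mem_post_prompt < graph_free_mem
        and not prompt_captured_all and decode_captured_all):
    leftover = graph_free_mem - mem_post_prompt - mem_post_decode    # 8 GiB
    print(f"re-running prompt graph capture with {leftover / GiB:.0f} GiB")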