Skip to content

Commit

Permalink
Use PT_COMPILE_ONLY_MODE during warmup (HabanaAI#227)
Browse files Browse the repository at this point in the history
With PT_COMPILE_ONLY_MODE flag, graphs can be compiled without
performing synLaunch. The flag has been added to the warmup phase to
decrease its execution time.
  • Loading branch information
mfylcek authored and zhouyu5 committed Sep 13, 2024
1 parent 679716f commit 8d7a23e
Showing 1 changed file with 66 additions and 59 deletions.
125 changes: 66 additions & 59 deletions vllm/worker/habana_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Optional, Set, Tuple, Type, TypeVar, Union)

import habana_frameworks.torch as htorch
import habana_frameworks.torch.internal.bridge_config as bc
import torch

from vllm.attention import AttentionMetadata, get_attn_backend
Expand Down Expand Up @@ -1402,67 +1403,73 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
self.profiler.start('internal', 'warmup')
start_mem = HabanaMemoryProfiler.current_device_memory_usage()
start_time = time.perf_counter()
self.warmup_all_buckets(self.prompt_buckets, True, kv_caches)
self.warmup_all_buckets(self.decode_buckets, False, kv_caches)

if not self.enforce_eager and htorch.utils.internal.is_lazy():
assert self.mem_margin is not None, \
("HabanaWorker.determine_num_available_blocks needs "
"to be called before warming up the model.")
free_mem = HabanaMemoryProfiler.current_free_device_memory()
graph_free_mem = free_mem - self.mem_margin
graph_free_mem = align_workers(graph_free_mem,
torch.distributed.ReduceOp.MIN)
prompt_graph_mem_ratio = float(
os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5'))
prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
decode_available_memory = graph_free_mem - prompt_available_memory
msg = (f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
"of free device memory for HPUGraphs, "
f"{format_bytes(prompt_available_memory)} for prompt and "
f"{format_bytes(decode_available_memory)} for decode "
f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})")
logger.info(msg)
prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY',
'min_tokens')
decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY',
'max_bs')
mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
self.warmup_graphs(
prompt_strategy, self.prompt_buckets, True, kv_caches,
prompt_available_memory)
mem_post_decode, decode_batch_seq, decode_captured_all = \
self.warmup_graphs(
decode_strategy, self.decode_buckets, False, kv_caches,
decode_available_memory)

# Not all prompt buckets were captured, but all decode buckets were
# captured and we have some free graph-allocated space left.
# Let's try to use it for capturing more prompt buckets.
if mem_post_decode + mem_post_prompt < graph_free_mem \
and not prompt_captured_all \
and decode_captured_all:
mem_post_prompt, _, prompt_captured_all = self.warmup_graphs(

with bc.env_setting("PT_COMPILE_ONLY_MODE", True):
self.warmup_all_buckets(self.prompt_buckets, True, kv_caches)
self.warmup_all_buckets(self.decode_buckets, False, kv_caches)

if not self.enforce_eager and htorch.utils.internal.is_lazy():
assert self.mem_margin is not None, \
("HabanaWorker.determine_num_available_blocks needs "
"to be called before warming up the model.")
free_mem = HabanaMemoryProfiler.current_free_device_memory()
graph_free_mem = free_mem - self.mem_margin
graph_free_mem = align_workers(graph_free_mem,
torch.distributed.ReduceOp.MIN)
prompt_graph_mem_ratio = float(
os.environ.get('VLLM_GRAPH_PROMPT_RATIO', '0.5'))
prompt_available_memory = (prompt_graph_mem_ratio *
graph_free_mem)
decode_available_memory = (graph_free_mem -
prompt_available_memory)
msg = (
f"Using {format_bytes(graph_free_mem)}"
f"/{format_bytes(free_mem)} "
"of free device memory for HPUGraphs, "
f"{format_bytes(prompt_available_memory)} for prompt and "
f"{format_bytes(decode_available_memory)} for decode "
f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})")
logger.info(msg)
prompt_strategy = os.environ.get('VLLM_GRAPH_PROMPT_STRATEGY',
'min_tokens')
decode_strategy = os.environ.get('VLLM_GRAPH_DECODE_STRATEGY',
'max_bs')
mem_post_prompt, prompt_batch_seq, prompt_captured_all = \
self.warmup_graphs(
prompt_strategy, self.prompt_buckets, True, kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_prompt, prompt_batch_seq)

# Not all decode buckets were captured, but all prompt buckets were
# captured and we have some free graph-allocated space left.
# Let's try to use it for capturing more decode buckets.
if mem_post_decode + mem_post_prompt < graph_free_mem \
and not decode_captured_all \
and prompt_captured_all:
mem_post_decode, _, _ = self.warmup_graphs(
prompt_available_memory)
mem_post_decode, decode_batch_seq, decode_captured_all = \
self.warmup_graphs(
decode_strategy, self.decode_buckets, False, kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_decode, decode_batch_seq)

self.log_graph_warmup_summary(self.prompt_buckets, True,
mem_post_prompt)
self.log_graph_warmup_summary(self.decode_buckets, False,
mem_post_decode)
decode_available_memory)

# Not all prompt buckets were captured, but all decode buckets
# were captured and we have some free graph-allocated space
# left. Let's try to use it for capturing more prompt buckets.
if (mem_post_decode + mem_post_prompt < graph_free_mem
and not prompt_captured_all and decode_captured_all):
mem_post_prompt, _, prompt_captured_all = (
self.warmup_graphs(
prompt_strategy, self.prompt_buckets, True,
kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_prompt, prompt_batch_seq))

# Not all decode buckets were captured, but all prompt buckets
# were captured and we have some free graph-allocated space
# left. Let's try to use it for capturing more decode buckets.
if mem_post_decode + mem_post_prompt < graph_free_mem \
and not decode_captured_all \
and prompt_captured_all:
mem_post_decode, _, _ = self.warmup_graphs(
decode_strategy, self.decode_buckets, False, kv_caches,
graph_free_mem - mem_post_prompt - mem_post_decode,
mem_post_decode, decode_batch_seq)

self.log_graph_warmup_summary(self.prompt_buckets, True,
mem_post_prompt)
self.log_graph_warmup_summary(self.decode_buckets, False,
mem_post_decode)

end_time = time.perf_counter()
end_mem = HabanaMemoryProfiler.current_device_memory_usage()
Expand Down

0 comments on commit 8d7a23e

Please sign in to comment.