forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 61
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use PT_COMPILE_ONLY_MODE during warmup #227
Merged
kzawora-intel
merged 2 commits into
habana_main
from
private/mfylcek/warmup_compile_only
Sep 6, 2024
Merged
Use PT_COMPILE_ONLY_MODE during warmup #227
kzawora-intel
merged 2 commits into
habana_main
from
private/mfylcek/warmup_compile_only
Sep 6, 2024
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
please fix formatting with format.sh |
kzawora-intel
added a commit
that referenced
this pull request
Sep 6, 2024
This PR fixes crashes observed on older Synapse builds introduced with #227. Setting PT_COMPILE_ONLY_MODE is not supported in current or older public Synapse builds, but we should not crash because of it, rather we should advise user to use the latest build. Previous behavior: ``` ... INFO 09-06 17:08:37 habana_executor.py:85] # HPU blocks: 10761, # CPU blocks: 910 INFO 09-06 17:08:37 habana_worker.py:201] Initializing cache engine took 47.29 GiB of device memory (54.34 GiB/94.62 GiB used) and -159.6 MiB of host memory (414.9 GiB/1007 GiB used) [rank0]: Traceback (most recent call last): [rank0]: File "/software/users/kzawora/vllm-utils/vllm_hpu_simple_test.py", line 9, in <module> [rank0]: llm = LLM(model="facebook/opt-125m") [rank0]: File "/software/users/kzawora/vllm-fork/vllm/entrypoints/llm.py", line 155, in __init__ [rank0]: self.llm_engine = LLMEngine.from_engine_args( [rank0]: File "/software/users/kzawora/vllm-fork/vllm/engine/llm_engine.py", line 456, in from_engine_args [rank0]: engine = cls( [rank0]: File "/software/users/kzawora/vllm-fork/vllm/engine/llm_engine.py", line 266, in __init__ [rank0]: self._initialize_kv_caches() [rank0]: File "/software/users/kzawora/vllm-fork/vllm/engine/llm_engine.py", line 378, in _initialize_kv_caches [rank0]: self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) [rank0]: File "/software/users/kzawora/vllm-fork/vllm/executor/habana_executor.py", line 89, in initialize_cache [rank0]: self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) [rank0]: File "/software/users/kzawora/vllm-fork/vllm/worker/habana_worker.py", line 202, in initialize_cache [rank0]: self._warm_up_model() [rank0]: File "/software/users/kzawora/vllm-fork/vllm/worker/habana_worker.py", line 220, in _warm_up_model [rank0]: self.model_runner.warmup_model(self.hpu_cache[0]) [rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context [rank0]: return func(*args, **kwargs) [rank0]: File "/software/users/kzawora/vllm-fork/vllm/worker/habana_model_runner.py", line 1412, in warmup_model [rank0]: with compile_only_mode_context(): [rank0]: File "/usr/lib/python3.10/contextlib.py", line 135, in __enter__ [rank0]: return next(self.gen) [rank0]: File "/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/internal/bridge_config.py", line 20, in env_setting [rank0]: get_func = globals()['get_' + var.lower()] [rank0]: KeyError: 'get_pt_compile_only_mode' inc shutdown inc shutdown inc shutdown inc shutdown ``` Current behavior: ``` ... INFO 09-06 17:06:42 habana_executor.py:85] # HPU blocks: 10761, # CPU blocks: 910 INFO 09-06 17:06:43 habana_worker.py:201] Initializing cache engine took 47.29 GiB of device memory (54.34 GiB/94.62 GiB used) and -143.7 MiB of host memory (415 GiB/1007 GiB used) WARNING 09-06 17:06:43 habana_model_runner.py:1419] Cannot use PT_COMPILE_ONLY_MODE. Warmup time will be negatively impacted. Please update Gaudi Software Suite. INFO 09-06 17:06:43 habana_model_runner.py:1336] [Warmup][Prompt][1/23] batch_size:2 seq_len:1024 free_mem:40.28 GiB ... ```
zhouyu5
pushed a commit
to zhouyu5/vllm-fork
that referenced
this pull request
Sep 13, 2024
With PT_COMPILE_ONLY_MODE flag, graphs can be compiled without performing synLaunch. The flag has been added to the warmup phase to decrease its execution time.
zhouyu5
pushed a commit
to zhouyu5/vllm-fork
that referenced
this pull request
Sep 13, 2024
This PR fixes crashes observed on older Synapse builds introduced with HabanaAI#227. Setting PT_COMPILE_ONLY_MODE is not supported in current or older public Synapse builds, but we should not crash because of it, rather we should advise user to use the latest build. Previous behavior: ``` ... INFO 09-06 17:08:37 habana_executor.py:85] # HPU blocks: 10761, # CPU blocks: 910 INFO 09-06 17:08:37 habana_worker.py:201] Initializing cache engine took 47.29 GiB of device memory (54.34 GiB/94.62 GiB used) and -159.6 MiB of host memory (414.9 GiB/1007 GiB used) [rank0]: Traceback (most recent call last): [rank0]: File "/software/users/kzawora/vllm-utils/vllm_hpu_simple_test.py", line 9, in <module> [rank0]: llm = LLM(model="facebook/opt-125m") [rank0]: File "/software/users/kzawora/vllm-fork/vllm/entrypoints/llm.py", line 155, in __init__ [rank0]: self.llm_engine = LLMEngine.from_engine_args( [rank0]: File "/software/users/kzawora/vllm-fork/vllm/engine/llm_engine.py", line 456, in from_engine_args [rank0]: engine = cls( [rank0]: File "/software/users/kzawora/vllm-fork/vllm/engine/llm_engine.py", line 266, in __init__ [rank0]: self._initialize_kv_caches() [rank0]: File "/software/users/kzawora/vllm-fork/vllm/engine/llm_engine.py", line 378, in _initialize_kv_caches [rank0]: self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) [rank0]: File "/software/users/kzawora/vllm-fork/vllm/executor/habana_executor.py", line 89, in initialize_cache [rank0]: self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) [rank0]: File "/software/users/kzawora/vllm-fork/vllm/worker/habana_worker.py", line 202, in initialize_cache [rank0]: self._warm_up_model() [rank0]: File "/software/users/kzawora/vllm-fork/vllm/worker/habana_worker.py", line 220, in _warm_up_model [rank0]: self.model_runner.warmup_model(self.hpu_cache[0]) [rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context [rank0]: return func(*args, **kwargs) [rank0]: File "/software/users/kzawora/vllm-fork/vllm/worker/habana_model_runner.py", line 1412, in warmup_model [rank0]: with compile_only_mode_context(): [rank0]: File "/usr/lib/python3.10/contextlib.py", line 135, in __enter__ [rank0]: return next(self.gen) [rank0]: File "/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/internal/bridge_config.py", line 20, in env_setting [rank0]: get_func = globals()['get_' + var.lower()] [rank0]: KeyError: 'get_pt_compile_only_mode' inc shutdown inc shutdown inc shutdown inc shutdown ``` Current behavior: ``` ... INFO 09-06 17:06:42 habana_executor.py:85] # HPU blocks: 10761, # CPU blocks: 910 INFO 09-06 17:06:43 habana_worker.py:201] Initializing cache engine took 47.29 GiB of device memory (54.34 GiB/94.62 GiB used) and -143.7 MiB of host memory (415 GiB/1007 GiB used) WARNING 09-06 17:06:43 habana_model_runner.py:1419] Cannot use PT_COMPILE_ONLY_MODE. Warmup time will be negatively impacted. Please update Gaudi Software Suite. INFO 09-06 17:06:43 habana_model_runner.py:1336] [Warmup][Prompt][1/23] batch_size:2 seq_len:1024 free_mem:40.28 GiB ... ```
zhouyu5
pushed a commit
to zhouyu5/vllm-fork
that referenced
this pull request
Sep 20, 2024
With PT_COMPILE_ONLY_MODE flag, graphs can be compiled without performing synLaunch. The flag has been added to the warmup phase to decrease its execution time.
zhouyu5
pushed a commit
to zhouyu5/vllm-fork
that referenced
this pull request
Sep 20, 2024
This PR fixes crashes observed on older Synapse builds introduced with HabanaAI#227. Setting PT_COMPILE_ONLY_MODE is not supported in current or older public Synapse builds, but we should not crash because of it, rather we should advise user to use the latest build. Previous behavior: ``` ... INFO 09-06 17:08:37 habana_executor.py:85] # HPU blocks: 10761, # CPU blocks: 910 INFO 09-06 17:08:37 habana_worker.py:201] Initializing cache engine took 47.29 GiB of device memory (54.34 GiB/94.62 GiB used) and -159.6 MiB of host memory (414.9 GiB/1007 GiB used) [rank0]: Traceback (most recent call last): [rank0]: File "/software/users/kzawora/vllm-utils/vllm_hpu_simple_test.py", line 9, in <module> [rank0]: llm = LLM(model="facebook/opt-125m") [rank0]: File "/software/users/kzawora/vllm-fork/vllm/entrypoints/llm.py", line 155, in __init__ [rank0]: self.llm_engine = LLMEngine.from_engine_args( [rank0]: File "/software/users/kzawora/vllm-fork/vllm/engine/llm_engine.py", line 456, in from_engine_args [rank0]: engine = cls( [rank0]: File "/software/users/kzawora/vllm-fork/vllm/engine/llm_engine.py", line 266, in __init__ [rank0]: self._initialize_kv_caches() [rank0]: File "/software/users/kzawora/vllm-fork/vllm/engine/llm_engine.py", line 378, in _initialize_kv_caches [rank0]: self.model_executor.initialize_cache(num_gpu_blocks, num_cpu_blocks) [rank0]: File "/software/users/kzawora/vllm-fork/vllm/executor/habana_executor.py", line 89, in initialize_cache [rank0]: self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) [rank0]: File "/software/users/kzawora/vllm-fork/vllm/worker/habana_worker.py", line 202, in initialize_cache [rank0]: self._warm_up_model() [rank0]: File "/software/users/kzawora/vllm-fork/vllm/worker/habana_worker.py", line 220, in _warm_up_model [rank0]: self.model_runner.warmup_model(self.hpu_cache[0]) [rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context [rank0]: return func(*args, **kwargs) [rank0]: File "/software/users/kzawora/vllm-fork/vllm/worker/habana_model_runner.py", line 1412, in warmup_model [rank0]: with compile_only_mode_context(): [rank0]: File "/usr/lib/python3.10/contextlib.py", line 135, in __enter__ [rank0]: return next(self.gen) [rank0]: File "/usr/local/lib/python3.10/dist-packages/habana_frameworks/torch/internal/bridge_config.py", line 20, in env_setting [rank0]: get_func = globals()['get_' + var.lower()] [rank0]: KeyError: 'get_pt_compile_only_mode' inc shutdown inc shutdown inc shutdown inc shutdown ``` Current behavior: ``` ... INFO 09-06 17:06:42 habana_executor.py:85] # HPU blocks: 10761, # CPU blocks: 910 INFO 09-06 17:06:43 habana_worker.py:201] Initializing cache engine took 47.29 GiB of device memory (54.34 GiB/94.62 GiB used) and -143.7 MiB of host memory (415 GiB/1007 GiB used) WARNING 09-06 17:06:43 habana_model_runner.py:1419] Cannot use PT_COMPILE_ONLY_MODE. Warmup time will be negatively impacted. Please update Gaudi Software Suite. INFO 09-06 17:06:43 habana_model_runner.py:1336] [Warmup][Prompt][1/23] batch_size:2 seq_len:1024 free_mem:40.28 GiB ... ```
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
With PT_COMPILE_ONLY_MODE flag, graphs can be compiled without performing synLaunch. The flag has been added to the warmup phase to decrease its execution time.