[intel_hpu] initial commit for intel_hpu support #9273

Merged · 15 commits merged · Oct 31, 2024
104 changes: 104 additions & 0 deletions llm/intel_hpu/llama/README.md
@@ -0,0 +1,104 @@
## 🚣‍♂️ Running the llama2-7b model with PaddleNLP on Intel HPU 🚣
PaddleNLP has been deeply adapted and optimized for the llama2-7B model on Intel® Gaudi®2D ([learn about Gaudi](https://docs.habana.ai/en/latest/index.html)). Detailed installation steps are given below.

## 🚀 Quick Start 🚀

### (0) Before you begin, you need an Intel Gaudi machine that meets the following system requirements:

| Chip type | Card model | Driver version |
| --- | --- | --- |
| Gaudi | 225D | 1.17.0 |


### (1) Environment setup (this takes about 5–15 minutes)
1. Pull the Docker image
```
# Note: this image is a development environment only; it does not contain a pre-built PaddlePaddle package
docker pull vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest
```
2. Start a container, for example with the following command
```
docker run -it --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_MCA_btl_vader_single_copy_mechanism=none --cap-add=sys_nice --net=host --ipc=host vault.habana.ai/gaudi-docker/1.17.0/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest
```
3. Install Paddle
```
# paddlepaddle is the PaddlePaddle deep learning framework, which provides the basic compute capability
pip install paddlepaddle==0.0.0 -f https://www.paddlepaddle.org.cn/whl/linux/cpu-mkl/develop.html
```
4. Install PaddleCustomDevice
```
# PaddleCustomDevice is the custom-hardware integration layer for the PaddlePaddle framework; it provides the Intel HPU operator implementations.
git clone --recursive https://github.com/PaddlePaddle/PaddleCustomDevice
cd PaddleCustomDevice
git submodule sync
git submodule update --remote --init --recursive
cd backends/intel_hpu/
mkdir build && cd build
cmake ..
make -j8
pip install dist/paddle_intel_hpu-0.0.1-cp310-cp310-linux_x86_64.whl
```
5. Clone the PaddleNLP repository and install its dependencies (a quick sanity-check sketch follows this list)
```
# PaddleNLP is a natural language processing and large language model (LLM) library built on PaddlePaddle. It contains PaddlePaddle implementations of many large models, including llama2-7B. Cloning the whole repository is recommended so you can use PaddleNLP fully.
git clone https://github.com/PaddlePaddle/PaddleNLP.git
cd PaddleNLP
python -m pip install -r requirements.txt
python -m pip install -e .
```
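
After these steps, a quick import check can confirm the stack is wired up. This is a minimal sketch, assuming it is run inside the container prepared above; it is not part of the PR:
```python
# Minimal sanity check for the environment prepared above (run inside the container).
import paddle
import paddlenlp

paddle.set_device("intel_hpu")  # succeeds only if the paddle_intel_hpu plugin installed correctly
print("paddle:", paddle.__version__, "| paddlenlp:", paddlenlp.__version__)
print("current device:", paddle.device.get_device())
```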

### (2) Inference (this takes about 10–15 minutes)
1. Single-card inference

Run the following command to perform inference:
```bash
python inference_hpu.py
```

After a successful run, you can see the generated inference output; a sample follows:
```
[2024-10-25 02:42:42,220] [ INFO] - We are using <class 'paddlenlp.transformers.llama.tokenizer.LlamaTokenizer'> to load 'meta-llama/Llama-2-7b-chat'.
[2024-10-25 02:42:42,427] [ INFO] - We are using <class 'paddlenlp.transformers.llama.modeling.LlamaForCausalLM'> to load 'meta-llama/Llama-2-7b-chat'.
[2024-10-25 02:42:42,427] [ INFO] - Loading configuration file /root/.paddlenlp/models/meta-llama/Llama-2-7b-chat/config.json
[2024-10-25 02:42:42,428] [ INFO] - Loading weights file from cache at /root/.paddlenlp/models/meta-llama/Llama-2-7b-chat/model_state.pdparams
[2024-10-25 02:43:32,871] [ INFO] - Loaded weights file from disk, setting weights to model.
[2024-10-25 02:44:15,226] [ INFO] - All model checkpoint weights were used when initializing LlamaForCausalLM.

[2024-10-25 02:44:15,226] [ INFO] - All the weights of LlamaForCausalLM were initialized from the model checkpoint at meta-llama/Llama-2-7b-chat.
If your task is similar to the task the model of the checkpoint was trained on, you can already use LlamaForCausalLM for predictions without further training.
[2024-10-25 02:44:15,229] [ INFO] - Loading configuration file /root/.paddlenlp/models/meta-llama/Llama-2-7b-chat/generation_config.json

['myself. I am a 35 year old woman from the United States. I am a writer and artist, and I have been living in Japan for the past 5 years. I am originally from the Midwest, but I have lived in several different places around the world, including California, New York, and now Japan.\nI am passionate about many things, including art, writing, music, and travel. I love to explore new places and cultures, and I am always looking for new inspiration for my art and writing. I am also a big fan of Japanese culture, and I try to learn as much']
```
2. Multi-card inference

Run the following command to perform inference:
```bash
bash test_llama_2x.sh
```
After a successful run, you can see the generated inference output; a sample follows:
```
[2024-10-29 11:24:39,468] [ INFO] - We are using <class 'paddlenlp.transformers.llama.tokenizer.LlamaTokenizer'> to load 'meta-llama/Llama-2-7b-chat'.
[2024-10-29 11:24:40,705] [ INFO] distributed_strategy.py:214 - distributed strategy initialized
I1029 11:24:40.706755 14711 tcp_utils.cc:181] The server starts to listen on IP_ANY:59129
I1029 11:24:40.706897 14711 tcp_utils.cc:130] Successfully connected to 127.0.0.1:59129
[2024-10-29 11:24:42,740] [ INFO] topology.py:357 - Total 2 pipe comm group(s) create successfully!
[2024-10-29 11:24:52,064] [ INFO] topology.py:357 - Total 2 data comm group(s) create successfully!
[2024-10-29 11:24:52,064] [ INFO] topology.py:357 - Total 1 model comm group(s) create successfully!
[2024-10-29 11:24:52,065] [ INFO] topology.py:357 - Total 2 sharding comm group(s) create successfully!
[2024-10-29 11:24:52,065] [ INFO] topology.py:279 - HybridParallelInfo: rank_id: 0, mp_degree: 2, sharding_degree: 1, pp_degree: 1, dp_degree: 1, sep_degree: 1, mp_group: [0, 1], sharding_group: [0], pp_group: [0], dp_group: [0], sep:group: None, check/clip group: [0, 1]
[2024-10-29 11:24:52,067] [ INFO] - We are using <class 'paddlenlp.transformers.llama.modeling.LlamaForCausalLM'> to load 'meta-llama/Llama-2-7b-chat'.
[2024-10-29 11:24:52,067] [ INFO] - Loading configuration file /root/.paddlenlp/models/meta-llama/Llama-2-7b-chat/config.json
[2024-10-29 11:24:52,068] [ INFO] - Loading weights file from cache at /root/.paddlenlp/models/meta-llama/Llama-2-7b-chat/model_state.pdparams
[2024-10-29 11:25:43,202] [ INFO] - Starting to convert orignal state_dict to tensor parallel state_dict.
[2024-10-29 11:25:45,125] [ INFO] - Loaded weights file from disk, setting weights to model.
[2024-10-29 11:26:04,008] [ INFO] - All model checkpoint weights were used when initializing LlamaForCausalLM.
[2024-10-29 11:26:04,008] [ INFO] - All the weights of LlamaForCausalLM were initialized from the model checkpoint at meta-llama/Llama-2-7b-chat.
If your task is similar to the task the model of the checkpoint was trained on, you can already use LlamaForCausalLM for predictions without further training.
[2024-10-29 11:26:04,010] [ INFO] - Loading configuration file /root/.paddlenlp/models/meta-llama/Llama-2-7b-chat/generation_config.json

['myself\nHello everyone my name is [Your Name], and I am a new member of this community']
I1029 11:26:16.184163 14767 tcp_store.cc:293] receive shutdown event and so quit from MasterDaemon run loop
LAUNCH INFO 2024-10-29 11:26:17,186 Pod completed
LAUNCH INFO 2024-10-29 11:26:17,186 Exit code 0
```
40 changes: 40 additions & 0 deletions llm/intel_hpu/llama/inference_hpu.py
@@ -0,0 +1,40 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

# import os
# os.environ['ENABLE_EXPERIMENTAL_FLAGS'] = '1'
# os.environ['VISUALIZATION_MODE'] = '0'
# os.environ['GRAPH_VISUALIZATION'] = '1'
# os.environ["HABANA_LOGS"] = "logs"
# os.environ["LOG_LEVEL_ALL"] = "0"
# os.environ['GLOG_v'] = '10'


paddle.set_device("intel_hpu")
paddle.set_default_dtype("bfloat16")

model = "meta-llama/Llama-2-7b-chat"
tokenizer = AutoTokenizer.from_pretrained(model)
model = AutoModelForCausalLM.from_pretrained(model, dtype="bfloat16")

input_features = tokenizer("please introduce llm", return_tensors="pd")

with paddle.amp.auto_cast(dtype="bfloat16", custom_white_list={"elementwise_add", "rms_norm"}):
outputs = model.generate(**input_features, max_length=128)

print(tokenizer.batch_decode(outputs[0]))
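
As a usage note, the prompt and decoding can be adjusted by editing the script. Below is a small hypothetical variation (not part of the PR) that reuses the `tokenizer` and `model` objects defined above and strips special tokens from the decoded output; `skip_special_tokens` is assumed to be supported by the PaddleNLP tokenizer's `batch_decode`:
```python
# Hypothetical tweak to inference_hpu.py: different prompt, special tokens stripped.
prompt = "please introduce yourself"  # any prompt string works here
input_features = tokenizer(prompt, return_tensors="pd")

with paddle.amp.auto_cast(dtype="bfloat16", custom_white_list={"elementwise_add", "rms_norm"}):
    outputs = model.generate(**input_features, max_length=128)

# skip_special_tokens drops markers such as <s> from the decoded text
print(tokenizer.batch_decode(outputs[0], skip_special_tokens=True))
```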
48 changes: 48 additions & 0 deletions llm/intel_hpu/llama/test_llama.py
@@ -0,0 +1,48 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distributed import fleet

from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

paddle.set_device("intel_hpu")
paddle.set_default_dtype("bfloat16")

model = "meta-llama/Llama-2-7b-chat"
tokenizer = AutoTokenizer.from_pretrained(model)
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
"dp_degree": 1,
"mp_degree": 2,
"pp_degree": 1,
"sharding_degree": 1,
}
fleet.init(is_collective=True, strategy=strategy)
hcg = fleet.get_hybrid_communicate_group()
tensor_parallel_rank = hcg.get_model_parallel_rank()

model = AutoModelForCausalLM.from_pretrained(
model,
tensor_parallel_degree=2,
tensor_parallel_rank=tensor_parallel_rank,
dtype="bfloat16",
)
input_features = tokenizer("please introduce llm", return_tensors="pd")


with paddle.amp.auto_cast(dtype="bfloat16", custom_white_list={"elementwise_add", "rms_norm"}):
outputs = model.generate(**input_features, max_length=20)

print(tokenizer.batch_decode(outputs[0]))
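
The script above hard-codes 2-way tensor parallelism. For illustration, a hypothetical 4-way variant (an assumption, not part of the PR; it requires four visible HPUs) would only change the degree settings:
```python
# Hypothetical 4-way tensor-parallel variant of test_llama.py
# (assumption: four HPUs visible; the PR itself only exercises mp_degree=2).
import paddle
from paddle.distributed import fleet

from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer

paddle.set_device("intel_hpu")
paddle.set_default_dtype("bfloat16")

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 4, "pp_degree": 1, "sharding_degree": 1}
fleet.init(is_collective=True, strategy=strategy)
rank = fleet.get_hybrid_communicate_group().get_model_parallel_rank()

model_name = "meta-llama/Llama-2-7b-chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, tensor_parallel_degree=4, tensor_parallel_rank=rank, dtype="bfloat16"
)
# Launch with paddle.distributed.launch on four devices, mirroring test_llama_2x.sh.
```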
40 changes: 40 additions & 0 deletions llm/intel_hpu/llama/test_llama_2x.sh
@@ -0,0 +1,40 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -ex

# export LOG_LEVEL_ALL=0
export HABANA_LOGS=./logs

# export HCCL_COMM_ID=127.0.0.1:5555
# export INTEL_HPU_VISIBLE_DEVICES=0,1 # 3,4
export INTEL_HPU_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
export PADDLE_DISTRI_BACKEND=xccl
export PADDLE_XCCL_BACKEND=intel_hpu
# PYTHONPATH=../../:$PYTHONPATH \
export FLAGS_intel_hpu_runtime_debug=0

# export HABANA_PROFILE=1
# export HABANA_PROFILE_WRITE_HLTV_WITH_HOST=1

echo $INTEL_HPU_VISIBLE_DEVICES

# export GRAPH_VISUALIZATION=1
# export ENABLE_EXPERIMENTAL_FLAGS=1
# export VISUALIZATION_MODE=0

#GLOG_v=10
python -m paddle.distributed.launch --devices "3,5" test_llama.py 2>&1 | tee test_llama_2x.log


40 changes: 38 additions & 2 deletions paddlenlp/transformers/llama/fusion_ops.py
@@ -64,18 +64,32 @@
rotary_emb,
context_parallel_degree=-1,
):
if get_env_device() != "gcu":
if get_env_device() not in ["gcu", "intel_hpu"]:
assert past_key_value is None, "fuse rotary not support cache kv for now"
batch_size, seq_length, num_heads, head_dim = query_states.shape
_, kv_seq_len, num_key_value_heads, _ = key_states.shape
if context_parallel_degree > 1:
assert get_env_device() == "gpu", "context parallel only support cuda device for now"
kv_seq_len *= context_parallel_degree
if get_env_device() != "gcu":
if get_env_device() not in ["gcu", "intel_hpu"]:
cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
if get_env_device() == "npu":
query_states = core.eager._run_custom_op("fused_rope", query_states, cos, sin)[0]
key_states = core.eager._run_custom_op("fused_rope", key_states, cos, sin)[0]
elif get_env_device() == "intel_hpu":
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-3]
cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
cos = cos.squeeze().unsqueeze(0).unsqueeze(0)
sin = sin.squeeze().unsqueeze(0).unsqueeze(0)
query_states, _, _ = paddle.incubate.nn.functional.fused_rotary_position_embedding(
paddle.transpose(query_states, [0, 2, 1, 3]), None, None, sin=sin, cos=cos, position_ids=position_ids
)
key_states, _, _ = paddle.incubate.nn.functional.fused_rotary_position_embedding(
paddle.transpose(key_states, [0, 2, 1, 3]), None, None, sin=sin, cos=cos, position_ids=position_ids
)
query_states = paddle.transpose(query_states, [0, 2, 1, 3])
key_states = paddle.transpose(key_states, [0, 2, 1, 3])
elif get_env_device() == "gcu":
cos_sin = rotary_emb.get_fused_cos_sin(value_states, seq_len=kv_seq_len)
query_states, key_states = core.eager._run_custom_op(
@@ -132,6 +146,10 @@
return core.eager._run_custom_op("rms_norm_mlu", hidden_states, weight, variance_epsilon)[0]
elif get_env_device() == "gcu":
return core.eager._run_custom_op("rms_norm_gcu", hidden_states, weight, variance_epsilon)[0]
elif get_env_device() == "intel_hpu":
return paddle.incubate.nn.functional.fused_rms_norm(
hidden_states, weight, None, variance_epsilon, hidden_states.dim() - 1
)[0]
elif get_env_device() == "xpu":
try:
import paddle_xpu_nn # noqa: F821
@@ -205,6 +223,24 @@
attention_mask is None,
True,
)[0]
elif get_env_device() == "intel_hpu":
if config.context_parallel_degree > 1:
raise ValueError("Context parallel is not implemented for intel_hpu")
scaling_factor = query_states.shape[3] ** -0.5
attention_mask = attention_mask.astype(query_states.dtype)
attn_output = paddle.incubate.nn.functional.fused_dot_product_attention(
query_states,
key_states,
value_states,
attention_mask,
scaling_factor,
0.0,
False,
attention_mask is None,
None,
False,
)
attn_output = paddle.transpose(attn_output, [0, 2, 1, 3])
else:
if config.context_parallel_degree > 1:
attn_output = RingFlashAttention.apply(
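
For readers skimming the diff above, here is a condensed sketch of the device-dispatch pattern this PR extends in `fusion_ops.py`: each fused op checks `get_env_device()` and, on `intel_hpu`, routes to Paddle's fused kernel. The sketch is illustrative only; the generic fallback path and the import location of `get_env_device` are assumptions, not code from the PR.
```python
# Illustrative sketch of the dispatch pattern in fusion_ops.py; not code from the PR.
import paddle
from paddlenlp.utils.tools import get_env_device  # assumed location of the helper


def fusion_rms_norm_sketch(hidden_states, weight, variance_epsilon):
    if get_env_device() == "intel_hpu":
        # Route to the fused HPU kernel, normalizing over the last axis (as in the diff).
        return paddle.incubate.nn.functional.fused_rms_norm(
            hidden_states, weight, None, variance_epsilon, hidden_states.dim() - 1
        )[0]
    # Assumed generic fallback: reference RMSNorm computed in float32.
    variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
    normed = paddle.rsqrt(variance + variance_epsilon) * hidden_states.astype("float32")
    return (normed * weight).astype(hidden_states.dtype)
```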