Commit

xpu devices support llama-7b basic mode inference (turn on BlockAttention)
zhink committed Jun 12, 2024
1 parent 4609d07 commit ed186fc
Showing 11 changed files with 134 additions and 69 deletions.
12 changes: 12 additions & 0 deletions llm/docs/inference.md
@@ -83,7 +83,10 @@ PaddleNLP ships high-performance custom operators written for the Transformer series, improving

```shell
git clone https://github.com/PaddlePaddle/PaddleNLP
# Install custom operators for GPU devices
cd ./paddlenlp/csrc && python setup_cuda.py install
# Install custom operators for XPU devices
cd ./paddlenlp/csrc/xpu/src && sh cmake_build.sh
```
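
As a quick aside (not part of the commit): once either build step finishes, a minimal Python check can confirm that the custom-operator package is importable before running the inference commands below. This is a hypothetical sanity check, not something the docs require.

```python
# Hypothetical post-install check: the custom operators built by
# setup_cuda.py (GPU) or cmake_build.sh (XPU) provide the paddlenlp_ops package.
try:
    import paddlenlp_ops  # noqa: F401
    print("paddlenlp_ops custom operators are available")
except ImportError as err:
    print(f"paddlenlp_ops is missing ({err}); rebuild the custom operators first")
```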

### 2.3 High-performance inference without BlockAttention
@@ -163,6 +166,9 @@ python predictor.py --model_name_or_path ./inference --inference_model --quant_
# Reference command for dynamic-graph model inference
python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --block_attn

# Reference command for dynamic-graph model inference on XPU devices
python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --block_attn --device xpu

# Reference command for Weight Only Int8 dynamic-graph inference
python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --quant_type weight_only_int8 --block_attn

@@ -179,6 +185,9 @@ python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_
# Reference command for dynamic-to-static export
python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --block_attn

# Reference command for dynamic-to-static export on XPU devices
python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --block_attn --device xpu

# Reference command for Weight Only Int8 dynamic-to-static export
python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --quant_type weight_only_int8 --block_attn

@@ -194,6 +203,9 @@ python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --infere
# Reference command for static-graph inference
python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --block_attn

# Reference command for static-graph inference on XPU devices
python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --block_attn --device xpu

# Reference command for Weight Only Int8 static-graph inference
python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --quant_type weight_only_int8 --block_attn

15 changes: 15 additions & 0 deletions llm/predictor.py
@@ -650,6 +650,11 @@ def _create_predictor(self, predictor_args: PredictorArgument):
if predictor_args.device in paddle.device.get_all_custom_device_type():
device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
config.enable_custom_device(predictor_args.device, device_id)
elif predictor_args.device == "xpu":
raise ValueError(
"you should export xpu static model with --block_attn flag and use predictor with --block_attn too"
"https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/inference.md"
)
else:
device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
config.enable_use_gpu(100, device_id)
@@ -1076,6 +1081,16 @@ def _create_predictor(self, predictor_args: PredictorArgument):
if predictor_args.device in paddle.device.get_all_custom_device_type():
device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
config.enable_custom_device(predictor_args.device, device_id)
elif predictor_args.device == "xpu":
config.enable_xpu()
device_id = int(os.environ.get("FLAGS_selected_xpus", 0))
config.set_xpu_device_id(device_id)
xpu_config = paddle.inference.XpuConfig()
xpu_config.device_id = device_id
xpu_config.l3_size = 63*1024*1024
xpu_config.l3_autotune_size = 63*1024*1024
config.set_xpu_config(xpu_config)
config.enable_new_executor()
else:
device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
config.enable_use_gpu(100, device_id)
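
To make the new branch easier to read, here is a condensed sketch of the device-selection logic added above, factored into a standalone helper. It is an assumption-level illustration rather than the committed code; it only reuses calls that appear in this hunk, and the 63 MB L3 values mirror the diff.

```python
# Sketch of the XPU/GPU config branch added to _create_predictor (illustrative only).
import os

import paddle.inference as paddle_infer


def configure_device(config: "paddle_infer.Config", device: str) -> None:
    if device == "xpu":
        device_id = int(os.environ.get("FLAGS_selected_xpus", 0))
        config.enable_xpu()
        config.set_xpu_device_id(device_id)
        xpu_config = paddle_infer.XpuConfig()
        xpu_config.device_id = device_id
        xpu_config.l3_size = 63 * 1024 * 1024  # L3 cache budget used in the diff
        xpu_config.l3_autotune_size = 63 * 1024 * 1024  # autotuned portion of L3
        config.set_xpu_config(xpu_config)
        config.enable_new_executor()
    else:
        device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
        config.enable_use_gpu(100, device_id)  # 100 MB initial GPU memory pool
```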
3 changes: 2 additions & 1 deletion paddlenlp/experimental/transformers/bloom/modeling.py
@@ -19,7 +19,6 @@
from paddle import Tensor, nn
from paddle.distributed import fleet
from paddle.nn.quant import weight_quantize
from paddlenlp_ops import get_padding_offset, get_padding_offset_v2

from paddlenlp.experimental.transformers.fused_transformer_layers import (
FusedBlockMultiTransformer,
@@ -219,6 +218,7 @@ def set_input_embeddings(self, new_embeddings: Tensor):
def remove_padding(self, input_ids, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)
from paddlenlp_ops import get_padding_offset

ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
input_ids, cum_offsets_now, token_num, seq_lens_this_time
)
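
Note on the pattern here: the `get_padding_offset` import is moved inside `remove_padding`, so `paddlenlp_ops` is resolved only when the method actually runs and the model module stays importable on devices where the ops are not built. A minimal, hypothetical illustration of that lazy-import idea:

```python
# Hypothetical illustration of the lazy custom-op import used throughout this commit.
def lazy_get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time):
    try:
        from paddlenlp_ops import get_padding_offset  # resolved only at call time
    except ImportError as err:  # ops not built for this device
        raise RuntimeError("build paddlenlp_ops first (see csrc/README.md)") from err
    return get_padding_offset(input_ids, cum_offsets_now, token_num, seq_lens_this_time)
```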
@@ -592,6 +592,7 @@ def set_transformer_block(self, transformer_config):
def remove_padding(self, input_ids, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)
from paddlenlp_ops import get_padding_offset_v2

ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2(
input_ids, cum_offsets_now, token_num, seq_lens_this_time
)
2 changes: 1 addition & 1 deletion paddlenlp/experimental/transformers/chatglm/modeling.py
@@ -18,7 +18,6 @@
from paddle import nn
from paddle.distributed import fleet
from paddle.nn.quant import weight_quantize
from paddlenlp_ops import get_padding_offset

from paddlenlp.experimental.transformers.fused_transformer_layers import (
FusedMultiTransformerConfig,
@@ -273,6 +272,7 @@ def __init__(self, config: ChatGLMConfig):
def remove_padding(self, input_ids, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)
from paddlenlp_ops import get_padding_offset

ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
input_ids, cum_offsets_now, token_num, seq_lens_this_time
)
2 changes: 1 addition & 1 deletion paddlenlp/experimental/transformers/chatglm_v2/modeling.py
@@ -19,7 +19,6 @@
import paddle.distributed.fleet as fleet
import paddle.nn as nn
from paddle.nn.quant import weight_quantize
from paddlenlp_ops import get_padding_offset

from paddlenlp.experimental.transformers.fused_transformer_layers import (
FusedMultiTransformerBase,
@@ -202,6 +201,7 @@ def set_input_embeddings(self, value):
def remove_padding(self, input_ids, seq_lens_this_time):
cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
token_num = paddle.sum(seq_lens_this_time)
from paddlenlp_ops import get_padding_offset

ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
input_ids, cum_offsets_now, token_num, seq_lens_this_time
)
133 changes: 87 additions & 46 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -15,7 +15,7 @@

import paddle
import paddle.distributed as dist
from paddle.framework import LayerHelper, in_dynamic_mode
from paddle.framework import LayerHelper, in_dynamic_mode, core

from paddle.incubate.nn.functional import (
fused_layer_norm,
fused_rms_norm,
@@ -28,24 +28,25 @@

from paddlenlp.utils.import_utils import is_paddlenlp_ops_available
from paddlenlp.utils.log import logger
from paddlenlp_ops import rebuild_padding_v2


if is_paddlenlp_ops_available():

if not is_paddlenlp_ops_available():
logger.warning(

"The paddlenlp_ops package is not installed. you can read the docs and install it by hand, "
"you can refer to: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
)

if core.is_compiled_with_cuda():

from paddlenlp_ops import (
dequant_int8,
encode_rotary_qk,
qkv_transpose_split,
quant_int8,
rebuild_padding,
rebuild_padding_v2,
transpose_remove_padding,
write_cache_kv,
)
else:
logger.warning(
"The paddlenlp_ops package is not installed. you can read the docs and install it by hand, "
"you can refer to: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
)


__all__ = [
"FusedMultiTransformerConfig",
@@ -1348,6 +1349,9 @@ def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layer
class FusedBlockMultiTransformer(FusedMultiTransformerBase):
def __init__(self, config: FusedMultiTransformerConfig):
super().__init__(config)
if not core.is_compiled_with_cuda():
self.cache_k_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype='float32')
self.cache_v_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype='float32')


def compute_attn(
self,
@@ -1375,43 +1379,80 @@ def compute_attn(
v_quant_scales = self.cache_v_scales
k_dequant_scales = self.cache_k_out_scales
v_dequant_scales = self.cache_v_out_scales

fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
qkv_out,
caches[2 * i],
caches[2 * i + 1],
kwargs.get("seq_lens_encoder", None),
kwargs.get("seq_lens_decoder", None),
kwargs.get("seq_lens_this_time", None),
kwargs.get("padding_offsets", None),
kwargs.get("cum_offsets", None),
kwargs.get("cu_seqlens_q", None),
kwargs.get("cu_seqlens_k", None),
kwargs.get("block_tables", None),
pre_caches[2 * i] if pre_caches is not None else None, # pre_key_cache
pre_caches[2 * i + 1] if pre_caches is not None else None, # pre_value_cache
k_quant_scales[i] if k_quant_scales is not None else None,
v_quant_scales[i] if v_quant_scales is not None else None,
k_dequant_scales[i] if k_dequant_scales is not None else None,
v_dequant_scales[i] if v_dequant_scales is not None else None,
None, # qkv_out_scales
None, # qkv_bias
None, # out_shifts
None, # out_smooths
kwargs.get("max_enc_len_this_time", None),
kwargs.get("max_dec_len_this_time", None),
rotary_embs,
attn_mask,
kwargs.get("tgt_mask", None),
kwargs.get("max_input_length", -1),
kwargs.get("block_size", 64),
self.use_neox_rotary_style,
self.config.use_dynamic_cachekv_quant,
quant_round_type=self.config.quant_round_type,
quant_max_bound=self.config.quant_max_bound,
quant_min_bound=self.config.quant_min_bound,
)[0]

if not core.is_compiled_with_cuda():
fmha_out = paddle.incubate.nn.functional.block_multihead_attention_xpu(

qkv_out,
caches[2 * i],
caches[2 * i + 1],
kwargs.get("seq_lens_encoder", None),
kwargs.get("seq_lens_decoder", None),
kwargs.get("seq_lens_this_time", None),
kwargs.get("padding_offsets", None),
kwargs.get("cum_offsets", None),
kwargs.get("cu_seqlens_q", None),
kwargs.get("cu_seqlens_k", None),
kwargs.get("block_tables", None),
self.cache_k_per_batch_maxs,
self.cache_v_per_batch_maxs,
pre_caches[2 * i] if pre_caches is not None else None, # pre_key_cache
pre_caches[2 * i + 1] if pre_caches is not None else None, # pre_value_cache
k_quant_scales[i] if k_quant_scales is not None else None,
v_quant_scales[i] if v_quant_scales is not None else None,
k_dequant_scales[i] if k_dequant_scales is not None else None,
v_dequant_scales[i] if v_dequant_scales is not None else None,
None, # qkv_out_scales
None, # qkv_bias
None, # out_shifts
None, # out_smooths
kwargs.get("max_enc_len_this_time", None),
kwargs.get("max_dec_len_this_time", None),
rotary_embs,
attn_mask,
kwargs.get("tgt_mask", None),
kwargs.get("max_input_length", -1),
kwargs.get("block_size", 64),
self.use_neox_rotary_style,
self.config.use_dynamic_cachekv_quant,
quant_round_type=self.config.quant_round_type,
quant_max_bound=self.config.quant_max_bound,
quant_min_bound=self.config.quant_min_bound,
)[0]
else:
fmha_out = paddle.incubate.nn.functional.block_multihead_attention(

qkv_out,
caches[2 * i],
caches[2 * i + 1],
kwargs.get("seq_lens_encoder", None),
kwargs.get("seq_lens_decoder", None),
kwargs.get("seq_lens_this_time", None),
kwargs.get("padding_offsets", None),
kwargs.get("cum_offsets", None),
kwargs.get("cu_seqlens_q", None),
kwargs.get("cu_seqlens_k", None),
kwargs.get("block_tables", None),
pre_caches[2 * i] if pre_caches is not None else None, # pre_key_cache
pre_caches[2 * i + 1] if pre_caches is not None else None, # pre_value_cache
k_quant_scales[i] if k_quant_scales is not None else None,
v_quant_scales[i] if v_quant_scales is not None else None,
k_dequant_scales[i] if k_dequant_scales is not None else None,
v_dequant_scales[i] if v_dequant_scales is not None else None,
None, # qkv_out_scales
None, # qkv_bias
None, # out_shifts
None, # out_smooths
kwargs.get("max_enc_len_this_time", None),
kwargs.get("max_dec_len_this_time", None),
rotary_embs,
attn_mask,
kwargs.get("tgt_mask", None),
kwargs.get("max_input_length", -1),
kwargs.get("block_size", 64),
self.use_neox_rotary_style,
self.config.use_dynamic_cachekv_quant,
quant_round_type=self.config.quant_round_type,
quant_max_bound=self.config.quant_max_bound,
quant_min_bound=self.config.quant_min_bound,
)[0]
out_linear_out = self.compute_out_linear(fmha_out, i)

return out_linear_out
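
Reading the dispatch above: when Paddle is not compiled with CUDA, `compute_attn` calls `block_multihead_attention_xpu`, which additionally takes the `cache_k_per_batch_maxs` / `cache_v_per_batch_maxs` tensors created in `__init__`; the CUDA path keeps the original `block_multihead_attention` call unchanged. A compact, assumption-level sketch of that selection:

```python
# Compact sketch (not the committed code) of the kernel selection in compute_attn.
import paddle
from paddle.framework import core


def pick_block_attention_kernel():
    if core.is_compiled_with_cuda():
        # GPU build: original fused block attention, no extra cache-max tensors.
        return paddle.incubate.nn.functional.block_multihead_attention, ()
    # Non-CUDA (XPU) build: XPU variant plus the per-batch cache max tensors
    # allocated in FusedBlockMultiTransformer.__init__.
    extra = (
        paddle.full(shape=[10, 6], fill_value=0, dtype="float32"),  # cache_k_per_batch_maxs
        paddle.full(shape=[10, 6], fill_value=0, dtype="float32"),  # cache_v_per_batch_maxs
    )
    return paddle.incubate.nn.functional.block_multihead_attention_xpu, extra
```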
20 changes: 9 additions & 11 deletions paddlenlp/experimental/transformers/generation_utils.py
@@ -17,17 +17,6 @@

import paddle
import paddle.nn.functional as F
from paddlenlp_ops import (
get_token_penalty_multi_scores,
get_token_penalty_multi_scores_v2,
save_output,
save_with_output,
set_stop_value_multi_ends,
set_stop_value_multi_ends_v2,
set_value_by_flags_and_idx,
set_value_by_flags_and_idx_v2,
update_inputs,
)

from paddlenlp.generation import GenerationMixin, LogitsProcessor, LogitsProcessorList

@@ -208,6 +197,7 @@ def update_model_kwargs_for_generation(self, cache, just_decoder, next_tokens, e
model_kwargs["stop_flags"] = paddle.logical_or(model_kwargs["stop_flags"], length_cond)
if cache is None:
next_tokens = paddle.where(just_decoder, paddle.full_like(next_tokens, -1), next_tokens)
from paddlenlp_ops import set_stop_value_multi_ends

next_tokens, model_kwargs["stop_flags"] = set_stop_value_multi_ends(
next_tokens, model_kwargs["stop_flags"], eos_token_id, 2
) # multi ends
@@ -305,6 +295,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
) # not update when continue decode
else:
step_idx = model_kwargs["step_idx"]
from paddlenlp_ops import set_value_by_flags_and_idx

model_kwargs["stop_flags"] = set_value_by_flags_and_idx(
model_kwargs["pre_ids"],
model_kwargs["tgt_ids"],
@@ -316,6 +307,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
logits = paddle.cast(logits, paddle.float32)
logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori)

from paddlenlp_ops import get_token_penalty_multi_scores

logits = get_token_penalty_multi_scores(
model_kwargs["pre_ids"],
logits,
@@ -347,6 +339,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
else:
model_kwargs["all_input_ids"] = paddle.concat([model_kwargs["all_input_ids"], next_tokens], axis=1)

from paddlenlp_ops import save_with_output

save_with_output(
next_tokens,
batch_idx,
@@ -635,6 +628,7 @@ def _post_process_(
model_kwargs,
):
step_idx = model_kwargs["step_idx"]
from paddlenlp_ops import set_value_by_flags_and_idx_v2

set_value_by_flags_and_idx_v2(
model_kwargs["pre_ids"],
model_kwargs["input_ids"],
@@ -648,6 +642,7 @@ def _post_process_(
logits = paddle.cast(outputs, paddle.float32)

# pre-process distribution
from paddlenlp_ops import get_token_penalty_multi_scores_v2

logits = get_token_penalty_multi_scores_v2(
model_kwargs["pre_ids"],
logits,
@@ -673,11 +668,13 @@ def _post_process_(
paddle.assign(step_idx, model_kwargs["step_idx"])
length_cond = paddle.greater_equal(step_idx, model_kwargs["max_dec_len"])
stop_flags = paddle.logical_or(model_kwargs["stop_flags"], length_cond)
from paddlenlp_ops import set_stop_value_multi_ends_v2

set_stop_value_multi_ends_v2(
next_tokens, stop_flags, model_kwargs["seq_lens_this_time"], eos_token_id, model_kwargs["next_tokens"]
) # multi ends
paddle.assign(stop_flags, model_kwargs["stop_flags"])
# update inputs
from paddlenlp_ops import update_inputs

update_inputs(
stop_flags,
model_kwargs["not_need_stop"],
Expand All @@ -689,6 +686,7 @@ def _post_process_(
next_tokens,
model_kwargs["is_block_step"],
)
from paddlenlp_ops import save_output

save_output(next_tokens, model_kwargs["not_need_stop"], self.config.tensor_parallel_rank)
return next_tokens
