diff --git a/llm/docs/inference.md b/llm/docs/inference.md
index 2f30764f19c1..a20e3a32d614 100644
--- a/llm/docs/inference.md
+++ b/llm/docs/inference.md
@@ -83,7 +83,10 @@ PaddleNLP 针对于Transformer 系列编写了高性能自定义算子，提升
 
 ```shell
 git clone https://github.com/PaddlePaddle/PaddleNLP
+# Install the custom operators for GPU devices
 cd ./paddlenlp/csrc && python setup_cuda.py install
+# Install the custom operators for XPU devices
+cd ./paddlenlp/csrc/xpu/src && sh cmake_build.sh
 ```
 
 ### 2.3 关闭BlockAttention的高性能推理
@@ -163,6 +166,9 @@ python predictor.py --model_name_or_path ./inference --inference_model --quant_
 # 动态图模型推理命令参考
 python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --block_attn
 
+# Reference command for dynamic-graph inference on XPU devices
+python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --block_attn --device xpu
+
 # Weight Only Int8 动态图推理参考
 python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --quant_type weight_only_int8 --block_attn
@@ -179,6 +185,9 @@ python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_
 # 动转静命令参考
 python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --block_attn
 
+# Reference command for dynamic-to-static export on XPU devices
+python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --block_attn --device xpu
+
 # Weight Only Int8 动转静命令参考
 python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --quant_type weight_only_int8 --block_attn
@@ -194,6 +203,9 @@ python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --infere
 # 静态图推理命令参考
 python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --block_attn
 
+# Reference command for static-graph inference on XPU devices
+python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --block_attn --device xpu
+
 # Weight Only Int8 静态图推理命令参考
 python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --quant_type weight_only_int8 --block_attn
diff --git a/llm/predictor.py b/llm/predictor.py
index 5b893d80bb86..b63df4cd8ca1 100644
--- a/llm/predictor.py
+++ b/llm/predictor.py
@@ -650,6 +650,11 @@ def _create_predictor(self, predictor_args: PredictorArgument):
         if predictor_args.device in paddle.device.get_all_custom_device_type():
             device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
             config.enable_custom_device(predictor_args.device, device_id)
+        elif predictor_args.device == "xpu":
+            raise ValueError(
+                "XPU static-graph inference requires block attention: export the model with --block_attn and run the predictor with --block_attn as well. "
+                "See https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/inference.md"
+            )
         else:
             device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
             config.enable_use_gpu(100, device_id)
@@ -1076,6 +1081,16 @@ def _create_predictor(self, predictor_args: PredictorArgument):
         if predictor_args.device in paddle.device.get_all_custom_device_type():
             device_id = int(os.environ.get("FLAGS_selected_{}s".format(predictor_args.device), 0))
             config.enable_custom_device(predictor_args.device, device_id)
+        elif predictor_args.device == "xpu":
+            config.enable_xpu()
+            device_id = int(os.environ.get("FLAGS_selected_xpus", 0))
+            config.set_xpu_device_id(device_id)
+            xpu_config = paddle.inference.XpuConfig()
+            xpu_config.device_id = device_id
+            xpu_config.l3_size = 63 * 1024 * 1024
+            xpu_config.l3_autotune_size = 63 * 1024 * 1024
+            config.set_xpu_config(xpu_config)
+            config.enable_new_executor()
         else:
             device_id = int(os.environ.get("FLAGS_selected_gpus", 0))
             config.enable_use_gpu(100, device_id)
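For readers skimming the patch, the XPU branch added to `_create_predictor` above amounts to the following standalone configuration sketch. It only restates calls that already appear in the diff (`enable_xpu`, `set_xpu_device_id`, `paddle.inference.XpuConfig`, `set_xpu_config`, `enable_new_executor`); the model and parameter file paths are placeholders.

```python
import os

import paddle.inference as paddle_infer

# Minimal sketch of the XPU config assembled by the new branch in predictor.py.
# "./inference/model.pdmodel" / "./inference/model.pdiparams" are placeholder
# paths for an exported static-graph model.
config = paddle_infer.Config("./inference/model.pdmodel", "./inference/model.pdiparams")
config.enable_xpu()

device_id = int(os.environ.get("FLAGS_selected_xpus", 0))
config.set_xpu_device_id(device_id)

xpu_config = paddle_infer.XpuConfig()
xpu_config.device_id = device_id
xpu_config.l3_size = 63 * 1024 * 1024  # reserve 63 MB of on-chip L3 cache
xpu_config.l3_autotune_size = 63 * 1024 * 1024  # portion of L3 the runtime may autotune
config.set_xpu_config(xpu_config)
config.enable_new_executor()

predictor = paddle_infer.create_predictor(config)
```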
diff --git a/paddlenlp/experimental/transformers/bloom/modeling.py b/paddlenlp/experimental/transformers/bloom/modeling.py
index 659826fe6f1b..9cb8f8133fb0 100644
--- a/paddlenlp/experimental/transformers/bloom/modeling.py
+++ b/paddlenlp/experimental/transformers/bloom/modeling.py
@@ -19,7 +19,6 @@
 from paddle import Tensor, nn
 from paddle.distributed import fleet
 from paddle.nn.quant import weight_quantize
-from paddlenlp_ops import get_padding_offset, get_padding_offset_v2
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
     FusedBlockMultiTransformer,
@@ -219,6 +218,7 @@ def set_input_embeddings(self, new_embeddings: Tensor):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset
         ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )
@@ -592,6 +592,7 @@ def set_transformer_block(self, transformer_config):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset_v2
         ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )
diff --git a/paddlenlp/experimental/transformers/chatglm/modeling.py b/paddlenlp/experimental/transformers/chatglm/modeling.py
index 5309ccf1d042..59eed00f4aa6 100644
--- a/paddlenlp/experimental/transformers/chatglm/modeling.py
+++ b/paddlenlp/experimental/transformers/chatglm/modeling.py
@@ -18,7 +18,6 @@
 from paddle import nn
 from paddle.distributed import fleet
 from paddle.nn.quant import weight_quantize
-from paddlenlp_ops import get_padding_offset
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
     FusedMultiTransformerConfig,
@@ -273,6 +272,7 @@ def __init__(self, config: ChatGLMConfig):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset
         ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )
diff --git a/paddlenlp/experimental/transformers/chatglm_v2/modeling.py b/paddlenlp/experimental/transformers/chatglm_v2/modeling.py
index 75dd08396398..86f5114e28c0 100644
--- a/paddlenlp/experimental/transformers/chatglm_v2/modeling.py
+++ b/paddlenlp/experimental/transformers/chatglm_v2/modeling.py
@@ -19,7 +19,6 @@
 import paddle.distributed.fleet as fleet
 import paddle.nn as nn
 from paddle.nn.quant import weight_quantize
-from paddlenlp_ops import get_padding_offset
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
     FusedMultiTransformerBase,
@@ -202,6 +201,7 @@ def set_input_embeddings(self, value):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset
         ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )
diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py
index 3d7010599377..05d452789135 100644
--- a/paddlenlp/experimental/transformers/fused_transformer_layers.py
+++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -15,7 +15,7 @@
 
 import paddle
 import paddle.distributed as dist
-from paddle.framework import LayerHelper, in_dynamic_mode
+from paddle.framework import LayerHelper, in_dynamic_mode, core
 from paddle.incubate.nn.functional import (
     fused_layer_norm,
     fused_rms_norm,
@@ -28,24 +28,25 @@
 from paddlenlp.utils.import_utils import is_paddlenlp_ops_available
 from paddlenlp.utils.log import logger
+from paddlenlp_ops import rebuild_padding_v2
 
-if is_paddlenlp_ops_available():
+
+if not is_paddlenlp_ops_available():
+    logger.warning(
+        "The paddlenlp_ops package is not installed. You can read the docs and install it by hand; "
+        "refer to https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
+    )
+
+if core.is_compiled_with_cuda():
     from paddlenlp_ops import (
         dequant_int8,
         encode_rotary_qk,
         qkv_transpose_split,
         quant_int8,
         rebuild_padding,
-        rebuild_padding_v2,
         transpose_remove_padding,
         write_cache_kv,
     )
-else:
-    logger.warning(
-        "The paddlenlp_ops package is not installed. you can read the docs and install it by hand, "
-        "you can refer to: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md"
-    )
-
 
 __all__ = [
     "FusedMultiTransformerConfig",
@@ -1348,6 +1349,9 @@ def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layer
 class FusedBlockMultiTransformer(FusedMultiTransformerBase):
     def __init__(self, config: FusedMultiTransformerConfig):
         super().__init__(config)
+        if not core.is_compiled_with_cuda():
+            self.cache_k_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype='float32')
+            self.cache_v_per_batch_maxs = paddle.full(shape=[10, 6], fill_value=0, dtype='float32')
 
     def compute_attn(
         self,
@@ -1375,43 +1379,80 @@
         v_quant_scales = self.cache_v_scales
         k_dequant_scales = self.cache_k_out_scales
         v_dequant_scales = self.cache_v_out_scales
-
-        fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
-            qkv_out,
-            caches[2 * i],
-            caches[2 * i + 1],
-            kwargs.get("seq_lens_encoder", None),
-            kwargs.get("seq_lens_decoder", None),
-            kwargs.get("seq_lens_this_time", None),
-            kwargs.get("padding_offsets", None),
-            kwargs.get("cum_offsets", None),
-            kwargs.get("cu_seqlens_q", None),
-            kwargs.get("cu_seqlens_k", None),
-            kwargs.get("block_tables", None),
-            pre_caches[2 * i] if pre_caches is not None else None,  # pre_key_cache
-            pre_caches[2 * i + 1] if pre_caches is not None else None,  # pre_value_cache
-            k_quant_scales[i] if k_quant_scales is not None else None,
-            v_quant_scales[i] if v_quant_scales is not None else None,
-            k_dequant_scales[i] if k_dequant_scales is not None else None,
-            v_dequant_scales[i] if v_dequant_scales is not None else None,
-            None,  # qkv_out_scales
-            None,  # qkv_bias
-            None,  # out_shifts
-            None,  # out_smooths
-            kwargs.get("max_enc_len_this_time", None),
-            kwargs.get("max_dec_len_this_time", None),
-            rotary_embs,
-            attn_mask,
-            kwargs.get("tgt_mask", None),
-            kwargs.get("max_input_length", -1),
-            kwargs.get("block_size", 64),
-            self.use_neox_rotary_style,
-            self.config.use_dynamic_cachekv_quant,
-            quant_round_type=self.config.quant_round_type,
-            quant_max_bound=self.config.quant_max_bound,
-            quant_min_bound=self.config.quant_min_bound,
-        )[0]
-
+        if not core.is_compiled_with_cuda():
+            fmha_out = paddle.incubate.nn.functional.block_multihead_attention_xpu(
+                qkv_out,
+                caches[2 * i],
+                caches[2 * i + 1],
+                kwargs.get("seq_lens_encoder", None),
+                kwargs.get("seq_lens_decoder", None),
+                kwargs.get("seq_lens_this_time", None),
+                kwargs.get("padding_offsets", None),
+                kwargs.get("cum_offsets", None),
+                kwargs.get("cu_seqlens_q", None),
+                kwargs.get("cu_seqlens_k", None),
+                kwargs.get("block_tables", None),
+                self.cache_k_per_batch_maxs,
+                self.cache_v_per_batch_maxs,
+                pre_caches[2 * i] if pre_caches is not None else None,  # pre_key_cache
+                pre_caches[2 * i + 1] if pre_caches is not None else None,  # pre_value_cache
+                k_quant_scales[i] if k_quant_scales is not None else None,
+                v_quant_scales[i] if v_quant_scales is not None else None,
+                k_dequant_scales[i] if k_dequant_scales is not None else None,
+                v_dequant_scales[i] if v_dequant_scales is not None else None,
+                None,  # qkv_out_scales
+                None,  # qkv_bias
+                None,  # out_shifts
+                None,  # out_smooths
+                kwargs.get("max_enc_len_this_time", None),
+                kwargs.get("max_dec_len_this_time", None),
+                rotary_embs,
+                attn_mask,
+                kwargs.get("tgt_mask", None),
+                kwargs.get("max_input_length", -1),
+                kwargs.get("block_size", 64),
+                self.use_neox_rotary_style,
+                self.config.use_dynamic_cachekv_quant,
+                quant_round_type=self.config.quant_round_type,
+                quant_max_bound=self.config.quant_max_bound,
+                quant_min_bound=self.config.quant_min_bound,
+            )[0]
+        else:
+            fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
+                qkv_out,
+                caches[2 * i],
+                caches[2 * i + 1],
+                kwargs.get("seq_lens_encoder", None),
+                kwargs.get("seq_lens_decoder", None),
+                kwargs.get("seq_lens_this_time", None),
+                kwargs.get("padding_offsets", None),
+                kwargs.get("cum_offsets", None),
+                kwargs.get("cu_seqlens_q", None),
+                kwargs.get("cu_seqlens_k", None),
+                kwargs.get("block_tables", None),
+                pre_caches[2 * i] if pre_caches is not None else None,  # pre_key_cache
+                pre_caches[2 * i + 1] if pre_caches is not None else None,  # pre_value_cache
+                k_quant_scales[i] if k_quant_scales is not None else None,
+                v_quant_scales[i] if v_quant_scales is not None else None,
+                k_dequant_scales[i] if k_dequant_scales is not None else None,
+                v_dequant_scales[i] if v_dequant_scales is not None else None,
+                None,  # qkv_out_scales
+                None,  # qkv_bias
+                None,  # out_shifts
+                None,  # out_smooths
+                kwargs.get("max_enc_len_this_time", None),
+                kwargs.get("max_dec_len_this_time", None),
+                rotary_embs,
+                attn_mask,
+                kwargs.get("tgt_mask", None),
+                kwargs.get("max_input_length", -1),
+                kwargs.get("block_size", 64),
+                self.use_neox_rotary_style,
+                self.config.use_dynamic_cachekv_quant,
+                quant_round_type=self.config.quant_round_type,
+                quant_max_bound=self.config.quant_max_bound,
+                quant_min_bound=self.config.quant_min_bound,
+            )[0]
         out_linear_out = self.compute_out_linear(fmha_out, i)
 
         return out_linear_out
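Both branches of `compute_attn` above pass the same long argument list; the only differences are the kernel (`block_multihead_attention` vs. `block_multihead_attention_xpu`) and the two per-batch cache-max tensors that the XPU kernel additionally takes after `block_tables`. A minimal sketch of that dispatch, using only names that appear in the diff (`select_block_attention` itself is a hypothetical helper, not part of the patch):

```python
import paddle
from paddle.framework import core


def select_block_attention():
    """Hypothetical helper: restates the kernel dispatch done inline in compute_attn.

    Returns the block-attention function for this build and whether the caller
    must insert the per-batch cache-max tensors after the block_tables argument.
    """
    if core.is_compiled_with_cuda():
        return paddle.incubate.nn.functional.block_multihead_attention, False
    # Non-CUDA builds are assumed to be XPU builds, mirroring the patch's logic.
    return paddle.incubate.nn.functional.block_multihead_attention_xpu, True


attn_fn, needs_per_batch_maxs = select_block_attention()
print(attn_fn.__name__, "needs per-batch cache-max tensors:", needs_per_batch_maxs)
```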
diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py
index 1aa94f524331..91311ab1a7d2 100644
--- a/paddlenlp/experimental/transformers/generation_utils.py
+++ b/paddlenlp/experimental/transformers/generation_utils.py
@@ -17,17 +17,6 @@
 import paddle
 import paddle.nn.functional as F
-from paddlenlp_ops import (
-    get_token_penalty_multi_scores,
-    get_token_penalty_multi_scores_v2,
-    save_output,
-    save_with_output,
-    set_stop_value_multi_ends,
-    set_stop_value_multi_ends_v2,
-    set_value_by_flags_and_idx,
-    set_value_by_flags_and_idx_v2,
-    update_inputs,
-)
 
 from paddlenlp.generation import GenerationMixin, LogitsProcessor, LogitsProcessorList
@@ -208,6 +197,7 @@ def update_model_kwargs_for_generation(self, cache, just_decoder, next_tokens, e
         model_kwargs["stop_flags"] = paddle.logical_or(model_kwargs["stop_flags"], length_cond)
         if cache is None:
             next_tokens = paddle.where(just_decoder, paddle.full_like(next_tokens, -1), next_tokens)
+            from paddlenlp_ops import set_stop_value_multi_ends
             next_tokens, model_kwargs["stop_flags"] = set_stop_value_multi_ends(
                 next_tokens, model_kwargs["stop_flags"], eos_token_id, 2
             )  # multi ends
@@ -305,6 +295,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
                 )  # not update when continue decode
             else:
                 step_idx = model_kwargs["step_idx"]
+            from paddlenlp_ops import set_value_by_flags_and_idx
             model_kwargs["stop_flags"] = set_value_by_flags_and_idx(
                 model_kwargs["pre_ids"],
                 model_kwargs["tgt_ids"],
@@ -316,6 +307,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
             logits = paddle.cast(logits, paddle.float32)
             logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori)
+            from paddlenlp_ops import get_token_penalty_multi_scores
             logits = get_token_penalty_multi_scores(
                 model_kwargs["pre_ids"],
                 logits,
@@ -347,6 +339,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
             else:
                 model_kwargs["all_input_ids"] = paddle.concat([model_kwargs["all_input_ids"], next_tokens], axis=1)
 
+            from paddlenlp_ops import save_with_output
             save_with_output(
                 next_tokens,
                 batch_idx,
@@ -635,6 +628,7 @@ def _post_process_(
             model_kwargs,
         ):
             step_idx = model_kwargs["step_idx"]
+            from paddlenlp_ops import set_value_by_flags_and_idx_v2
             set_value_by_flags_and_idx_v2(
                 model_kwargs["pre_ids"],
                 model_kwargs["input_ids"],
@@ -648,6 +642,7 @@ def _post_process_(
             logits = paddle.cast(outputs, paddle.float32)
 
             # pre-process distribution
+            from paddlenlp_ops import get_token_penalty_multi_scores_v2
             logits = get_token_penalty_multi_scores_v2(
                 model_kwargs["pre_ids"],
                 logits,
@@ -673,11 +668,13 @@ def _post_process_(
             paddle.assign(step_idx, model_kwargs["step_idx"])
             length_cond = paddle.greater_equal(step_idx, model_kwargs["max_dec_len"])
             stop_flags = paddle.logical_or(model_kwargs["stop_flags"], length_cond)
+            from paddlenlp_ops import set_stop_value_multi_ends_v2
             set_stop_value_multi_ends_v2(
                 next_tokens, stop_flags, model_kwargs["seq_lens_this_time"], eos_token_id, model_kwargs["next_tokens"]
             )  # multi ends
             paddle.assign(stop_flags, model_kwargs["stop_flags"])
             # update inputs
+            from paddlenlp_ops import update_inputs
             update_inputs(
                 stop_flags,
                 model_kwargs["not_need_stop"],
@@ -689,6 +686,7 @@ def _post_process_(
                 next_tokens,
                 model_kwargs["is_block_step"],
             )
+            from paddlenlp_ops import save_output
             save_output(next_tokens, model_kwargs["not_need_stop"], self.config.tensor_parallel_rank)
 
             return next_tokens
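One pattern worth noting in this file (and in the modeling files that follow): each `paddlenlp_ops` kernel is now imported inside the function that uses it rather than at module level, so importing the module no longer fails on builds that ship only a subset of the custom operators (for example the XPU build); a missing kernel only surfaces when the code path that needs it actually runs. A minimal illustration of the same idea, where `some_custom_kernel` is a hypothetical stand-in for ops such as `get_token_penalty_multi_scores_v2`:

```python
def apply_custom_penalty(pre_ids, logits, penalty_score):
    # Deferred import: resolved on the first call, not when this module is
    # imported, so builds without this particular kernel can still import
    # the module. "some_custom_kernel" is a hypothetical op name.
    from paddlenlp_ops import some_custom_kernel

    return some_custom_kernel(pre_ids, logits, penalty_score)
```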
diff --git a/paddlenlp/experimental/transformers/gpt/modeling.py b/paddlenlp/experimental/transformers/gpt/modeling.py
index 6627c9e42abb..70b822180652 100644
--- a/paddlenlp/experimental/transformers/gpt/modeling.py
+++ b/paddlenlp/experimental/transformers/gpt/modeling.py
@@ -17,7 +17,6 @@
 from paddle import nn
 from paddle.distributed import fleet
 from paddle.nn.quant import weight_quantize
-from paddlenlp_ops import get_padding_offset
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
     FusedMultiTransformerBase,
@@ -203,6 +202,7 @@ def set_input_embeddings(self, value):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset
         ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )
diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py
index 537f0e70d40a..cbc59f82935b 100644
--- a/paddlenlp/experimental/transformers/llama/modeling.py
+++ b/paddlenlp/experimental/transformers/llama/modeling.py
@@ -22,11 +22,6 @@
 from paddle import nn
 from paddle.distributed import fleet
 from paddle.nn.quant import weight_quantize
-from paddlenlp_ops import (
-    fused_get_rotary_embedding,
-    get_padding_offset,
-    get_padding_offset_v2,
-)
 
 from paddlenlp.experimental.model_utils import (
     ActScalesLoader,
@@ -47,7 +42,6 @@
     GenerationInferenceModel,
 )
 from paddlenlp.transformers import LlamaConfig, LlamaPretrainedModel
-from paddlenlp.transformers.conversion_utils import split_param_func
 from paddlenlp.transformers.llama.modeling import LlamaLMHead
 from paddlenlp.transformers.model_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -350,6 +344,7 @@ def set_input_embeddings(self, value):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset
         ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )
@@ -436,6 +431,7 @@ def forward(
         theta = 10000.0
         if not is_decoder and pre_caches is not None:
             position_offset = 128
+        from paddlenlp_ops import fused_get_rotary_embedding
         new_rope = fused_get_rotary_embedding(
             input_ids, position_ids, self.head_dim_shape_tensor, position_offset, theta, True
         )
@@ -827,6 +823,7 @@ def set_transformer_block(self, transformer_config):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset_v2
         ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )
diff --git a/paddlenlp/experimental/transformers/opt/modeling.py b/paddlenlp/experimental/transformers/opt/modeling.py
index afcb1331b52c..6c176c71b51d 100644
--- a/paddlenlp/experimental/transformers/opt/modeling.py
+++ b/paddlenlp/experimental/transformers/opt/modeling.py
@@ -18,7 +18,6 @@
 import numpy as np
 import paddle
 import paddle.nn as nn
-from paddlenlp_ops import get_padding_offset
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
     FusedMultiTransformerBase,
@@ -147,6 +146,7 @@ def set_input_embeddings(self, value):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset
         ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )
diff --git a/paddlenlp/experimental/transformers/qwen/modeling.py b/paddlenlp/experimental/transformers/qwen/modeling.py
index c032a85e7ce2..5e0e755a2aeb 100644
--- a/paddlenlp/experimental/transformers/qwen/modeling.py
+++ b/paddlenlp/experimental/transformers/qwen/modeling.py
@@ -18,7 +18,6 @@
 import paddle
 from paddle import nn
 from paddle.nn.quant import weight_quantize
-from paddlenlp_ops import fused_get_rotary_embedding, get_padding_offset
 
 from paddlenlp.experimental.transformers.fused_transformer_layers import (
     FusedMultiTransformerBase,
@@ -239,6 +238,7 @@ def set_state_dict(self, state_dict):
     def remove_padding(self, input_ids, seq_lens_this_time):
         cum_offsets_now = paddle.cumsum(paddle.max(seq_lens_this_time) - seq_lens_this_time)
         token_num = paddle.sum(seq_lens_this_time)
+        from paddlenlp_ops import get_padding_offset
         ids_remove_padding, cum_offsets, padding_offset = get_padding_offset(
             input_ids, cum_offsets_now, token_num, seq_lens_this_time
         )
@@ -324,6 +324,7 @@ def forward(
         if not is_decoder and pre_caches is not None:
             position_offset = 128
+        from paddlenlp_ops import fused_get_rotary_embedding
         new_rope = fused_get_rotary_embedding(
             input_ids, position_ids, self.head_dim_shape_tensor, position_offset, theta, True
         )
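Because the kernels above are now imported lazily, a missing op only shows up at run time. An optional smoke test (not part of the patch) can list which of the deferred kernels the installed `paddlenlp_ops` build actually provides, for example after building the XPU operators with `cmake_build.sh`; the op names below are taken from the imports deferred in this diff.

```python
import importlib

# Kernels this patch imports lazily; trim the list to the models you run.
DEFERRED_OPS = [
    "get_padding_offset",
    "get_padding_offset_v2",
    "fused_get_rotary_embedding",
    "set_stop_value_multi_ends_v2",
    "update_inputs",
    "save_output",
]

ops = importlib.import_module("paddlenlp_ops")
missing = [name for name in DEFERRED_OPS if not hasattr(ops, name)]
print("missing paddlenlp_ops kernels:", missing or "none")
```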