From 12936191ed5b2b9caf0ce5090e9bc79afcd9ccff Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Wed, 17 Apr 2024 15:27:02 +0800 Subject: [PATCH 1/8] [XPU] llama add xpu support --- llm/run_pretrain.py | 7 ++ paddlenlp/transformers/llama/modeling.py | 107 +++++++++++++++++++---- 2 files changed, 97 insertions(+), 17 deletions(-) diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py index d0df32321e18..b79677ef1a45 100644 --- a/llm/run_pretrain.py +++ b/llm/run_pretrain.py @@ -46,6 +46,7 @@ ) from paddlenlp.utils.batch_sampler import DistributedBatchSampler from paddlenlp.utils.log import logger +from paddlenlp.utils.tools import get_env_device def add_start_docstrings(*docstr): @@ -483,6 +484,12 @@ def main(): config.num_attention_heads % config.sep_parallel_degree == 0 ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}" + if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1: + from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401 + + LinearConfig.enable_accumulate_steps_opt() + LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps) + print("Final pre-training config:", config) # Set the dtype for loading model diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index d70e63ffa484..00f444096a27 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -413,6 +413,10 @@ def forward(self, hidden_states): if self.config.use_fused_rms_norm: if get_env_device() == "npu": return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0] + elif get_env_device() == "xpu": + import paddle_xpu_nn + + return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon) if paddle.in_dynamic_mode(): @@ -582,12 +586,33 @@ def __init__(self, config): ColumnParallelLinear = MC2ColumnSeqParallelLinear RowParallelLinear = MC2RowSeqParallelLinear + elif get_env_device() == "xpu": + from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401 + XPUColumnSequenceParallelLinear, + XPURowSequenceParallelLinear, + ) + + ColumnParallelLinear = XPUColumnSequenceParallelLinear + RowParallelLinear = XPURowSequenceParallelLinear else: ColumnParallelLinear = ColumnSequenceParallelLinear RowParallelLinear = RowSequenceParallelLinear else: - ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear - RowParallelLinear = fleet.meta_parallel.RowParallelLinear + if get_env_device() == "xpu": + import paddle_xpu # noqa: F821 + + ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear + RowParallelLinear = paddle_xpu.layers.nn.RowParallelLinear + else: + ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear + RowParallelLinear = fleet.meta_parallel.RowParallelLinear + + if get_env_device() == "xpu": + import paddle_xpu # noqa: F821 + + Linear = paddle_xpu.layers.nn.Linear + else: + Linear = nn.Linear if config.tensor_parallel_degree > 1: if config.fuse_attention_ffn: @@ -619,15 +644,24 @@ def __init__(self, config): ) else: if config.fuse_attention_ffn: - self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False) + self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False) else: - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) - 
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) + self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False) + self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False) def forward(self, x): if self.fuse_attention_ffn: + # FIXME(yangjianbang): use paddle's native swiglu + if get_env_device() == "xpu": + import paddle_xpu_nn # noqa: F821 + + out = self.gate_up_fused_proj(x) + out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True) + out = self.down_proj(out) + return out + x = swiglu(self.gate_up_fused_proj(x)) else: x = swiglu(self.gate_proj(x), self.up_proj(x)) @@ -689,7 +723,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): self.use_fused_rope = config.use_fused_rope if self.use_fused_rope and get_env_device() != "npu": - if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None: + if ( + "gpu" not in paddle.device.get_device() + or "xpu" not in paddle.device.get_device() + or fused_rotary_position_embedding is None + ): warnings.warn( "Enable fuse rope in the config, but fuse rope is not available. " "Will disable fuse rope. Try using latest gpu version of Paddle." @@ -705,12 +743,33 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): ColumnParallelLinear = MC2ColumnSeqParallelLinear RowParallelLinear = MC2RowSeqParallelLinear + elif get_env_device() == "xpu": + from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401 + XPUColumnSequenceParallelLinear, + XPURowSequenceParallelLinear, + ) + + ColumnParallelLinear = XPUColumnSequenceParallelLinear + RowParallelLinear = XPURowSequenceParallelLinear else: ColumnParallelLinear = ColumnSequenceParallelLinear RowParallelLinear = RowSequenceParallelLinear else: - ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear - RowParallelLinear = fleet.meta_parallel.RowParallelLinear + if get_env_device() == "xpu": + import paddle_xpu # noqa: F821 + + ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear # noqa: F821 + RowParallelLinear = paddle_xpu.layers.nn.RowParallelLinear # noqa: F821 + else: + ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear + RowParallelLinear = fleet.meta_parallel.RowParallelLinear + + if get_env_device() == "xpu": + import paddle_xpu # noqa: F821 + + Linear = paddle_xpu.layers.nn.Linear + else: + Linear = nn.Linear if config.tensor_parallel_degree > 1: if self.fuse_attention_qkv: @@ -741,12 +800,12 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): gather_output=False, ) else: - self.k_proj = nn.Linear( + self.k_proj = Linear( self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=False, ) - self.v_proj = nn.Linear( + self.v_proj = Linear( self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=False, @@ -754,23 +813,23 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): else: if self.fuse_attention_qkv: - self.qkv_proj = nn.Linear( + self.qkv_proj = Linear( self.hidden_size, self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim, bias_attr=False, ) else: - self.q_proj = nn.Linear( + self.q_proj = Linear( self.hidden_size, self.hidden_size, bias_attr=False, ) - self.k_proj = nn.Linear( + self.k_proj = Linear( self.hidden_size, 
self.config.num_key_value_heads * self.head_dim, bias_attr=False, ) - self.v_proj = nn.Linear( + self.v_proj = Linear( self.hidden_size, self.config.num_key_value_heads * self.head_dim, bias_attr=False, @@ -784,7 +843,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): input_is_parallel=True, ) else: - self.o_proj = nn.Linear( + self.o_proj = Linear( self.hidden_size, self.hidden_size, bias_attr=False, @@ -1428,6 +1487,11 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16") expanded_attn_mask = expanded_attn_mask.astype("float16") expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype) + elif get_env_device() == "xpu": + x = paddle.to_tensor(0.0, dtype=dtype) + y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype) + expanded_attn_mask = expanded_attn_mask.astype(dtype) + expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype) else: expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype) return expanded_attn_mask @@ -1708,6 +1772,10 @@ def __init__(self, config: LlamaConfig): self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False if self.weight.is_distributed: self.weight.split_axis = 1 + if get_env_device() == "xpu": + import paddle_xpu + + self.xpu_parallel_matmul = paddle_xpu.layers.nn.parallel_matmul() def forward(self, hidden_states, tensor_parallel_output=None): if self.config.sequence_parallel: @@ -1721,7 +1789,12 @@ def forward(self, hidden_states, tensor_parallel_output=None): if tensor_parallel_output is None: tensor_parallel_output = self.config.tensor_parallel_output - logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output) + if get_env_device() == "xpu": + logits = self.xpu_parallel_matmul( + hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training + ) + else: + logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output) return logits From e388ed6b894a987c0955a6b947430b52efda957b Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Fri, 19 Apr 2024 11:26:47 +0800 Subject: [PATCH 2/8] fix --- paddlenlp/transformers/llama/modeling.py | 46 ++++++++++++++---------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 00f444096a27..74071a001041 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -414,7 +414,7 @@ def forward(self, hidden_states): if get_env_device() == "npu": return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0] elif get_env_device() == "xpu": - import paddle_xpu_nn + import paddle_xpu_nn # noqa: F821 return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon) @@ -599,18 +599,23 @@ def __init__(self, config): RowParallelLinear = RowSequenceParallelLinear else: if get_env_device() == "xpu": - import paddle_xpu # noqa: F821 + from paddle_xpu.layers.nn import ( # noqa: F401 + ColumnParallelLinear as XPUColumnParallelLinear, + ) + from paddle_xpu.layers.nn import ( # noqa: F401 + RowParallelLinear as XPURowParallelLinear, + ) - ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear - RowParallelLinear = 
paddle_xpu.layers.nn.RowParallelLinear + ColumnParallelLinear = XPUColumnParallelLinear + RowParallelLinear = XPURowParallelLinear else: ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear RowParallelLinear = fleet.meta_parallel.RowParallelLinear if get_env_device() == "xpu": - import paddle_xpu # noqa: F821 + from paddle_xpu.layers.nn import Linear as XPULinear # noqa: F401 - Linear = paddle_xpu.layers.nn.Linear + Linear = XPULinear else: Linear = nn.Linear @@ -722,12 +727,8 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): ) self.use_fused_rope = config.use_fused_rope - if self.use_fused_rope and get_env_device() != "npu": - if ( - "gpu" not in paddle.device.get_device() - or "xpu" not in paddle.device.get_device() - or fused_rotary_position_embedding is None - ): + if self.use_fused_rope and get_env_device() not in ["npu", "xpu"]: + if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None: warnings.warn( "Enable fuse rope in the config, but fuse rope is not available. " "Will disable fuse rope. Try using latest gpu version of Paddle." @@ -756,18 +757,23 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): RowParallelLinear = RowSequenceParallelLinear else: if get_env_device() == "xpu": - import paddle_xpu # noqa: F821 + from paddle_xpu.layers.nn import ( # noqa: F401 + ColumnParallelLinear as XPUColumnParallelLinear, + ) + from paddle_xpu.layers.nn import ( # noqa: F401 + RowParallelLinear as XPURowParallelLinear, + ) - ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear # noqa: F821 - RowParallelLinear = paddle_xpu.layers.nn.RowParallelLinear # noqa: F821 + ColumnParallelLinear = XPUColumnParallelLinear + RowParallelLinear = XPURowParallelLinear else: ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear RowParallelLinear = fleet.meta_parallel.RowParallelLinear if get_env_device() == "xpu": - import paddle_xpu # noqa: F821 + from paddle_xpu.layers.nn import Linear as XPULinear # noqa: F401 - Linear = paddle_xpu.layers.nn.Linear + Linear = XPULinear else: Linear = nn.Linear @@ -1773,9 +1779,11 @@ def __init__(self, config: LlamaConfig): if self.weight.is_distributed: self.weight.split_axis = 1 if get_env_device() == "xpu": - import paddle_xpu + from paddle_xpu.layers.nn import ( # noqa: F401 + parallel_matmul as xpu_parallel_matmul, + ) - self.xpu_parallel_matmul = paddle_xpu.layers.nn.parallel_matmul() + self.xpu_parallel_matmul = xpu_parallel_matmul() def forward(self, hidden_states, tensor_parallel_output=None): if self.config.sequence_parallel: From 41421f4464b345e4b2ae56f44d6d2d98dbb01d32 Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Mon, 22 Apr 2024 14:47:24 +0800 Subject: [PATCH 3/8] use try import --- llm/run_pretrain.py | 9 +- paddlenlp/transformers/llama/modeling.py | 119 ++++++++++++++--------- 2 files changed, 81 insertions(+), 47 deletions(-) diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py index b79677ef1a45..8263d80b2974 100644 --- a/llm/run_pretrain.py +++ b/llm/run_pretrain.py @@ -485,10 +485,13 @@ def main(): ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}" if get_env_device() == "xpu" and training_args.gradient_accumulation_steps > 1: - from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401 + try: + from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401 - LinearConfig.enable_accumulate_steps_opt() - 
LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps) + LinearConfig.enable_accumulate_steps_opt() + LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps) + except ImportError: + pass print("Final pre-training config:", config) diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 74071a001041..3d0bac45d985 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -414,9 +414,12 @@ def forward(self, hidden_states): if get_env_device() == "npu": return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0] elif get_env_device() == "xpu": - import paddle_xpu_nn # noqa: F821 + try: + import paddle_xpu_nn # noqa: F821 - return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] + return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] + except ImportError: + pass return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon) if paddle.in_dynamic_mode(): @@ -587,35 +590,46 @@ def __init__(self, config): ColumnParallelLinear = MC2ColumnSeqParallelLinear RowParallelLinear = MC2RowSeqParallelLinear elif get_env_device() == "xpu": - from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401 - XPUColumnSequenceParallelLinear, - XPURowSequenceParallelLinear, - ) + try: + from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401 + XPUColumnSequenceParallelLinear, + XPURowSequenceParallelLinear, + ) - ColumnParallelLinear = XPUColumnSequenceParallelLinear - RowParallelLinear = XPURowSequenceParallelLinear + ColumnParallelLinear = XPUColumnSequenceParallelLinear + RowParallelLinear = XPURowSequenceParallelLinear + except ImportError: + ColumnParallelLinear = ColumnSequenceParallelLinear + RowParallelLinear = RowSequenceParallelLinear else: ColumnParallelLinear = ColumnSequenceParallelLinear RowParallelLinear = RowSequenceParallelLinear else: if get_env_device() == "xpu": - from paddle_xpu.layers.nn import ( # noqa: F401 - ColumnParallelLinear as XPUColumnParallelLinear, - ) - from paddle_xpu.layers.nn import ( # noqa: F401 - RowParallelLinear as XPURowParallelLinear, - ) + try: + from paddle_xpu.layers.nn import ( # noqa: F401 + ColumnParallelLinear as XPUColumnParallelLinear, + ) + from paddle_xpu.layers.nn import ( # noqa: F401 + RowParallelLinear as XPURowParallelLinear, + ) - ColumnParallelLinear = XPUColumnParallelLinear - RowParallelLinear = XPURowParallelLinear + ColumnParallelLinear = XPUColumnParallelLinear + RowParallelLinear = XPURowParallelLinear + except ImportError: + ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear + RowParallelLinear = fleet.meta_parallel.RowParallelLinear else: ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear RowParallelLinear = fleet.meta_parallel.RowParallelLinear if get_env_device() == "xpu": - from paddle_xpu.layers.nn import Linear as XPULinear # noqa: F401 + try: + from paddle_xpu.layers.nn import Linear as XPULinear # noqa: F401 - Linear = XPULinear + Linear = XPULinear + except ImportError: + Linear = nn.Linear else: Linear = nn.Linear @@ -660,12 +674,15 @@ def forward(self, x): if self.fuse_attention_ffn: # FIXME(yangjianbang): use paddle's native swiglu if get_env_device() == "xpu": - import paddle_xpu_nn # noqa: F821 + try: + import paddle_xpu_nn # noqa: F821 - out = self.gate_up_fused_proj(x) - out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True) - out = 
self.down_proj(out) - return out + out = self.gate_up_fused_proj(x) + out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True) + out = self.down_proj(out) + return out + except ImportError: + pass x = swiglu(self.gate_up_fused_proj(x)) else: @@ -745,35 +762,46 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): ColumnParallelLinear = MC2ColumnSeqParallelLinear RowParallelLinear = MC2RowSeqParallelLinear elif get_env_device() == "xpu": - from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401 - XPUColumnSequenceParallelLinear, - XPURowSequenceParallelLinear, - ) + try: + from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401 + XPUColumnSequenceParallelLinear, + XPURowSequenceParallelLinear, + ) - ColumnParallelLinear = XPUColumnSequenceParallelLinear - RowParallelLinear = XPURowSequenceParallelLinear + ColumnParallelLinear = XPUColumnSequenceParallelLinear + RowParallelLinear = XPURowSequenceParallelLinear + except ImportError: + ColumnParallelLinear = ColumnSequenceParallelLinear + RowParallelLinear = RowSequenceParallelLinear else: ColumnParallelLinear = ColumnSequenceParallelLinear RowParallelLinear = RowSequenceParallelLinear else: if get_env_device() == "xpu": - from paddle_xpu.layers.nn import ( # noqa: F401 - ColumnParallelLinear as XPUColumnParallelLinear, - ) - from paddle_xpu.layers.nn import ( # noqa: F401 - RowParallelLinear as XPURowParallelLinear, - ) + try: + from paddle_xpu.layers.nn import ( # noqa: F401 + ColumnParallelLinear as XPUColumnParallelLinear, + ) + from paddle_xpu.layers.nn import ( # noqa: F401 + RowParallelLinear as XPURowParallelLinear, + ) - ColumnParallelLinear = XPUColumnParallelLinear - RowParallelLinear = XPURowParallelLinear + ColumnParallelLinear = XPUColumnParallelLinear + RowParallelLinear = XPURowParallelLinear + except ImportError: + ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear + RowParallelLinear = fleet.meta_parallel.RowParallelLinear else: ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear RowParallelLinear = fleet.meta_parallel.RowParallelLinear if get_env_device() == "xpu": - from paddle_xpu.layers.nn import Linear as XPULinear # noqa: F401 + try: + from paddle_xpu.layers.nn import Linear as XPULinear # noqa: F401 - Linear = XPULinear + Linear = XPULinear + except: + Linear = nn.Linear else: Linear = nn.Linear @@ -1779,11 +1807,14 @@ def __init__(self, config: LlamaConfig): if self.weight.is_distributed: self.weight.split_axis = 1 if get_env_device() == "xpu": - from paddle_xpu.layers.nn import ( # noqa: F401 - parallel_matmul as xpu_parallel_matmul, - ) + try: + from paddle_xpu.layers.nn import ( # noqa: F401 + parallel_matmul as xpu_parallel_matmul, + ) - self.xpu_parallel_matmul = xpu_parallel_matmul() + self.xpu_parallel_matmul = xpu_parallel_matmul() + except ImportError: + self.xpu_parallel_matmul = None def forward(self, hidden_states, tensor_parallel_output=None): if self.config.sequence_parallel: @@ -1797,7 +1828,7 @@ def forward(self, hidden_states, tensor_parallel_output=None): if tensor_parallel_output is None: tensor_parallel_output = self.config.tensor_parallel_output - if get_env_device() == "xpu": + if get_env_device() == "xpu" and self.xpu_parallel_matmul is not None: logits = self.xpu_parallel_matmul( hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output, training=self.training ) From e9a4b871d6127ba7bfbfc92b43c96219a0731b20 Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Mon, 22 Apr 2024 16:02:25 +0800 Subject: [PATCH 4/8] 
fix --- paddlenlp/transformers/llama/modeling.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 3d0bac45d985..71255445057d 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -682,7 +682,9 @@ def forward(self, x): out = self.down_proj(out) return out except ImportError: - pass + gate_out, up_out = paddle.chunk(self.gate_up_fused_proj(x), chunks=2, axis=-1) + out = self.down_proj(F.silu(gate_out) * up_out) + return out x = swiglu(self.gate_up_fused_proj(x)) else: From 2a8c6399dbe1aadb1447f758ff1176ecb301e3ed Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Wed, 24 Apr 2024 18:08:15 +0800 Subject: [PATCH 5/8] refine --- llm/run_pretrain.py | 1 + paddlenlp/transformers/linear_utils.py | 55 +++++++++++ paddlenlp/transformers/llama/modeling.py | 114 +++-------------------- 3 files changed, 69 insertions(+), 101 deletions(-) create mode 100644 paddlenlp/transformers/linear_utils.py diff --git a/llm/run_pretrain.py b/llm/run_pretrain.py index 8263d80b2974..7196f52eea6d 100644 --- a/llm/run_pretrain.py +++ b/llm/run_pretrain.py @@ -491,6 +491,7 @@ def main(): LinearConfig.enable_accumulate_steps_opt() LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps) except ImportError: + # It's OK, not use accumulate_steps optimization pass print("Final pre-training config:", config) diff --git a/paddlenlp/transformers/linear_utils.py b/paddlenlp/transformers/linear_utils.py new file mode 100644 index 000000000000..d318258ca457 --- /dev/null +++ b/paddlenlp/transformers/linear_utils.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +from paddle.distributed.fleet.meta_parallel import ( + ColumnParallelLinear, + RowParallelLinear, +) +from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + ColumnSequenceParallelLinear, + RowSequenceParallelLinear, +) +from paddle.nn import Linear + +from paddlenlp.utils.tools import get_env_device + +if get_env_device() == "npu" and int(os.getenv("FLAGS_NPU_MC2", 0)): + from paddlenlp.transformers.mc2_seqence_parallel_linear import ( + MC2ColumnSeqParallelLinear, + MC2RowSeqParallelLinear, + ) + + ColumnSequenceParallelLinear = MC2ColumnSeqParallelLinear # noqa: F811 + RowSequenceParallelLinear = MC2RowSeqParallelLinear # noqa: F811 + +if get_env_device() == "xpu": + try: + from paddle_xpu.layers.nn import ColumnParallelLinear as XPUColumnParallelLinear + from paddle_xpu.layers.nn import Linear as XPULinear # noqa: F401 + from paddle_xpu.layers.nn import RowParallelLinear as XPURowParallelLinear + from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401 + XPUColumnSequenceParallelLinear, + XPURowSequenceParallelLinear, + ) + + Linear = XPULinear # noqa: F811 + ColumnParallelLinear = XPUColumnParallelLinear # noqa: F811 + RowParallelLinear = XPURowParallelLinear # noqa: F811 + ColumnSequenceParallelLinear = XPUColumnSequenceParallelLinear # noqa: F811 + RowSequenceParallelLinear = XPURowSequenceParallelLinear # noqa: F811 + except ImportError: + # It's OK, just use paddle's Linear layers + pass diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index 71255445057d..29312c19869d 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -70,6 +70,8 @@ def swiglu(x, y=None): from paddlenlp.utils.log import logger from paddlenlp.utils.tools import get_env_device +from .. import linear_utils +from ..linear_utils import Linear from ..segment_parallel_utils import ReshardLayer from .configuration import ( LLAMA_PRETRAINED_INIT_CONFIGURATION, @@ -419,7 +421,9 @@ def forward(self, hidden_states): return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0] except ImportError: - pass + raise NotImplementedError( + f"Implementation of fused_rms_norm is not available on {get_env_device()}. 
Please install paddle_xpu to use this feature" + ) return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon) if paddle.in_dynamic_mode(): @@ -581,57 +585,11 @@ def __init__(self, config): self.fuse_attention_ffn = config.fuse_attention_ffn if config.sequence_parallel: - if is_mc2_valid and int(os.getenv("FLAGS_NPU_MC2", 0)): - from paddlenlp.transformers.mc2_seqence_parallel_linear import ( - MC2ColumnSeqParallelLinear, - MC2RowSeqParallelLinear, - ) - - ColumnParallelLinear = MC2ColumnSeqParallelLinear - RowParallelLinear = MC2RowSeqParallelLinear - elif get_env_device() == "xpu": - try: - from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401 - XPUColumnSequenceParallelLinear, - XPURowSequenceParallelLinear, - ) - - ColumnParallelLinear = XPUColumnSequenceParallelLinear - RowParallelLinear = XPURowSequenceParallelLinear - except ImportError: - ColumnParallelLinear = ColumnSequenceParallelLinear - RowParallelLinear = RowSequenceParallelLinear - else: - ColumnParallelLinear = ColumnSequenceParallelLinear - RowParallelLinear = RowSequenceParallelLinear - else: - if get_env_device() == "xpu": - try: - from paddle_xpu.layers.nn import ( # noqa: F401 - ColumnParallelLinear as XPUColumnParallelLinear, - ) - from paddle_xpu.layers.nn import ( # noqa: F401 - RowParallelLinear as XPURowParallelLinear, - ) - - ColumnParallelLinear = XPUColumnParallelLinear - RowParallelLinear = XPURowParallelLinear - except ImportError: - ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear - RowParallelLinear = fleet.meta_parallel.RowParallelLinear - else: - ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear - RowParallelLinear = fleet.meta_parallel.RowParallelLinear - - if get_env_device() == "xpu": - try: - from paddle_xpu.layers.nn import Linear as XPULinear # noqa: F401 - - Linear = XPULinear - except ImportError: - Linear = nn.Linear + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear else: - Linear = nn.Linear + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear if config.tensor_parallel_degree > 1: if config.fuse_attention_ffn: @@ -755,57 +713,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False): self.use_fused_rope = False if config.sequence_parallel: - if is_mc2_valid and int(os.getenv("FLAGS_NPU_MC2", 0)): - from paddlenlp.transformers.mc2_seqence_parallel_linear import ( - MC2ColumnSeqParallelLinear, - MC2RowSeqParallelLinear, - ) - - ColumnParallelLinear = MC2ColumnSeqParallelLinear - RowParallelLinear = MC2RowSeqParallelLinear - elif get_env_device() == "xpu": - try: - from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401 - XPUColumnSequenceParallelLinear, - XPURowSequenceParallelLinear, - ) - - ColumnParallelLinear = XPUColumnSequenceParallelLinear - RowParallelLinear = XPURowSequenceParallelLinear - except ImportError: - ColumnParallelLinear = ColumnSequenceParallelLinear - RowParallelLinear = RowSequenceParallelLinear - else: - ColumnParallelLinear = ColumnSequenceParallelLinear - RowParallelLinear = RowSequenceParallelLinear - else: - if get_env_device() == "xpu": - try: - from paddle_xpu.layers.nn import ( # noqa: F401 - ColumnParallelLinear as XPUColumnParallelLinear, - ) - from paddle_xpu.layers.nn import ( # noqa: F401 - RowParallelLinear as XPURowParallelLinear, - ) - - ColumnParallelLinear = XPUColumnParallelLinear - RowParallelLinear = XPURowParallelLinear - except 
ImportError: - ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear - RowParallelLinear = fleet.meta_parallel.RowParallelLinear - else: - ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear - RowParallelLinear = fleet.meta_parallel.RowParallelLinear - - if get_env_device() == "xpu": - try: - from paddle_xpu.layers.nn import Linear as XPULinear # noqa: F401 - - Linear = XPULinear - except: - Linear = nn.Linear + ColumnParallelLinear = linear_utils.ColumnSequenceParallelLinear + RowParallelLinear = linear_utils.RowSequenceParallelLinear else: - Linear = nn.Linear + ColumnParallelLinear = linear_utils.ColumnParallelLinear + RowParallelLinear = linear_utils.RowParallelLinear if config.tensor_parallel_degree > 1: if self.fuse_attention_qkv: From 40c23a5b30cff6f26f65bc44cc3472417d8296a2 Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Thu, 25 Apr 2024 20:15:18 +0800 Subject: [PATCH 6/8] refine --- paddlenlp/transformers/linear_utils.py | 48 ++++++++++++++---------- paddlenlp/transformers/llama/modeling.py | 7 +--- 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/paddlenlp/transformers/linear_utils.py b/paddlenlp/transformers/linear_utils.py index 8e96effc4e4a..e18c68d3fab0 100644 --- a/paddlenlp/transformers/linear_utils.py +++ b/paddlenlp/transformers/linear_utils.py @@ -12,38 +12,46 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle.distributed.fleet.meta_parallel import ( - ColumnParallelLinear, - RowParallelLinear, -) -from paddle.distributed.fleet.utils.sequence_parallel_utils import ( - ColumnSequenceParallelLinear, - RowSequenceParallelLinear, -) -from paddle.nn import Linear +""" +This file is used for replacing Paddle's native Linear implementations with vendors' customized implementations +""" + +import paddle.distributed.fleet.meta_parallel as mpu +from paddle import nn +from paddle.distributed.fleet.utils import sequence_parallel_utils from paddlenlp.utils.tools import get_env_device from .mc2_parallel_linear import MC2ColumnSeqParallelLinear, MC2RowSeqParallelLinear -if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None: - ColumnSequenceParallelLinear = MC2ColumnSeqParallelLinear # noqa: F811 - RowSequenceParallelLinear = MC2RowSeqParallelLinear # noqa: F811 +Linear = nn.Linear +ColumnParallelLinear = mpu.ColumnParallelLinear +RowParallelLinear = mpu.RowParallelLinear +ColumnSequenceParallelLinear = sequence_parallel_utils.ColumnSequenceParallelLinear +RowSequenceParallelLinear = sequence_parallel_utils.RowSequenceParallelLinear + +if get_env_device() == "npu": + if MC2ColumnSeqParallelLinear is not None and MC2RowSeqParallelLinear is not None: + ColumnSequenceParallelLinear = MC2ColumnSeqParallelLinear + RowSequenceParallelLinear = MC2RowSeqParallelLinear elif get_env_device() == "xpu": try: from paddle_xpu.layers.nn import ColumnParallelLinear as XPUColumnParallelLinear - from paddle_xpu.layers.nn import Linear as XPULinear # noqa: F401 + from paddle_xpu.layers.nn import Linear as XPULinear from paddle_xpu.layers.nn import RowParallelLinear as XPURowParallelLinear - from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401 + from paddle_xpu.layers.nn.sequence_parallel import ( XPUColumnSequenceParallelLinear, XPURowSequenceParallelLinear, ) - Linear = XPULinear # noqa: F811 - ColumnParallelLinear = XPUColumnParallelLinear # noqa: F811 - RowParallelLinear = XPURowParallelLinear # noqa: F811 - ColumnSequenceParallelLinear = 
XPUColumnSequenceParallelLinear # noqa: F811 - RowSequenceParallelLinear = XPURowSequenceParallelLinear # noqa: F811 + Linear = XPULinear + ColumnParallelLinear = XPUColumnParallelLinear + RowParallelLinear = XPURowParallelLinear + ColumnSequenceParallelLinear = XPUColumnSequenceParallelLinear + RowSequenceParallelLinear = XPURowSequenceParallelLinear except ImportError: - # It's OK, just use paddle's Linear layers + # If paddle_xpu is not installed, just use Paddle's native Linear implementations pass +else: + # By default, use Paddle's native Linear implementations + pass diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index aee1313ba8df..e3c68e78f994 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -1427,12 +1427,7 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values else: expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) # Convert bool attention_mask to float attention mask, which will be added to attention_scores later - if get_env_device() == "npu": - x = paddle.to_tensor(0.0, dtype="float16") - y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16") - expanded_attn_mask = expanded_attn_mask.astype("float16") - expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype) - elif get_env_device() == "xpu": + if get_env_device() in ["npu", "xpu"]: x = paddle.to_tensor(0.0, dtype=dtype) y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype) expanded_attn_mask = expanded_attn_mask.astype(dtype) From a3935fdea4ddcb481094182e43be42181964cd36 Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Thu, 25 Apr 2024 20:19:51 +0800 Subject: [PATCH 7/8] refine --- paddlenlp/transformers/linear_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paddlenlp/transformers/linear_utils.py b/paddlenlp/transformers/linear_utils.py index e18c68d3fab0..de1a0f886b79 100644 --- a/paddlenlp/transformers/linear_utils.py +++ b/paddlenlp/transformers/linear_utils.py @@ -20,10 +20,12 @@ from paddle import nn from paddle.distributed.fleet.utils import sequence_parallel_utils +from paddlenlp.transformers.mc2_parallel_linear import ( + MC2ColumnSeqParallelLinear, + MC2RowSeqParallelLinear, +) from paddlenlp.utils.tools import get_env_device -from .mc2_parallel_linear import MC2ColumnSeqParallelLinear, MC2RowSeqParallelLinear - Linear = nn.Linear ColumnParallelLinear = mpu.ColumnParallelLinear RowParallelLinear = mpu.RowParallelLinear From 6e0316a7ac9d8a174d8dd880d14da3bf4110015b Mon Sep 17 00:00:00 2001 From: Jianbang Yang Date: Fri, 26 Apr 2024 15:44:43 +0800 Subject: [PATCH 8/8] refine --- paddlenlp/transformers/llama/modeling.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/paddlenlp/transformers/llama/modeling.py b/paddlenlp/transformers/llama/modeling.py index e3c68e78f994..aee1313ba8df 100755 --- a/paddlenlp/transformers/llama/modeling.py +++ b/paddlenlp/transformers/llama/modeling.py @@ -1427,7 +1427,12 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values else: expanded_attn_mask = _make_causal_mask(input_shape, past_key_values_length=past_key_values_length) # Convert bool attention_mask to float attention mask, which will be added to attention_scores later - if get_env_device() in ["npu", "xpu"]: + if get_env_device() == "npu": + x = paddle.to_tensor(0.0, dtype="float16") + y = paddle.to_tensor(paddle.finfo(dtype).min, 
dtype="float16") + expanded_attn_mask = expanded_attn_mask.astype("float16") + expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype) + elif get_env_device() == "xpu": x = paddle.to_tensor(0.0, dtype=dtype) y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype) expanded_attn_mask = expanded_attn_mask.astype(dtype)