Commit

[XPU] llama add xpu support
dynamicheart committed Apr 17, 2024
1 parent 0790824 commit 160c79d
Showing 2 changed files with 103 additions and 20 deletions.
6 changes: 6 additions & 0 deletions llm/run_pretrain.py
@@ -483,6 +483,12 @@ def main():
config.num_attention_heads % config.sep_parallel_degree == 0
), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}"

if paddle.is_compiled_with_xpu() and training_args.gradient_accumulation_steps > 1:
from paddle_xpu.layers.nn.linear import LinearConfig # noqa: F401

LinearConfig.enable_accumulate_steps_opt()
LinearConfig.set_accumulate_steps(training_args.gradient_accumulation_steps)

print("Final pre-training config:", config)

# Set the dtype for loading model
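
For reference, the gate added to run_pretrain.py condensed into a standalone sketch; the helper name is illustrative, and paddle_xpu is assumed to ship only with XPU builds of Paddle:

import paddle


def maybe_enable_xpu_accumulate_opt(gradient_accumulation_steps: int) -> None:
    # Only enable paddle_xpu's accumulate-steps optimization on XPU builds, and
    # only when gradients are actually accumulated across more than one step.
    if paddle.is_compiled_with_xpu() and gradient_accumulation_steps > 1:
        from paddle_xpu.layers.nn.linear import LinearConfig  # noqa: F401

        LinearConfig.enable_accumulate_steps_opt()
        LinearConfig.set_accumulate_steps(gradient_accumulation_steps)
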
117 changes: 97 additions & 20 deletions paddlenlp/transformers/llama/modeling.py
@@ -174,7 +174,7 @@ def assign_kv_heads(num_kv_heads: int, num_gpus: int):
return assignment_list


def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
def parallel_matmul(matmul_op, x: Tensor, y: Tensor, tensor_parallel_output=True):
is_fleet_init = True
tensor_parallel_degree = 1
try:
@@ -192,15 +192,15 @@ def parallel_matmul(x: Tensor, y: Tensor, tensor_parallel_output=True):
if is_fleet_init and tensor_parallel_degree > 1 and y_is_distributed:
# if not running under distributed.launch, it will raise AttributeError: 'Fleet' object has no attribute '_hcg'
input_parallel = paddle.distributed.collective._c_identity(x, group=model_parallel_group)
logits = paddle.matmul(input_parallel, y, transpose_y=False)
logits = matmul_op(input_parallel, y, transpose_y=False)

if tensor_parallel_output:
return logits

return paddle.distributed.collective._c_concat(logits, group=model_parallel_group)

else:
logits = paddle.matmul(x, y, transpose_y=False)
logits = matmul_op(x, y, transpose_y=False)
return logits
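
A minimal call sketch for the new matmul_op parameter, run outside any tensor-parallel setup so the non-distributed branch executes; the tensors are made up for illustration, and XPU callers would inject their fused op in place of paddle.matmul:

import paddle

hidden = paddle.randn([2, 8, 64], dtype="float32")
weight = paddle.randn([64, 128], dtype="float32")

# The injected op must accept the same (x, y, transpose_y=...) calling convention
# that parallel_matmul uses internally.
logits = parallel_matmul(paddle.matmul, hidden, weight, tensor_parallel_output=True)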


@@ -413,6 +413,10 @@ def forward(self, hidden_states):
if self.config.use_fused_rms_norm:
if get_env_device() == "npu":
return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
elif get_env_device() == "xpu":
import paddle_xpu_nn

return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]

return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)

if paddle.in_dynamic_mode():
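
For comparison, a sketch of what the fused npu/xpu rms_norm kernels compute, written with stock Paddle ops; this mirrors the math, not the class's exact non-fused fallback path:

import paddle


def rms_norm_reference(hidden_states, weight, eps):
    # x / sqrt(mean(x^2) + eps) * weight, with the statistics taken in float32.
    variance = hidden_states.astype("float32").pow(2).mean(-1, keepdim=True)
    normed = paddle.rsqrt(variance + eps) * hidden_states.astype("float32")
    return (normed * weight.astype("float32")).astype(hidden_states.dtype)
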
@@ -582,12 +586,33 @@ def __init__(self, config):

ColumnParallelLinear = MC2ColumnSeqParallelLinear
RowParallelLinear = MC2RowSeqParallelLinear
elif get_env_device() == "xpu":
from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401
XPUColumnSequenceParallelLinear,
XPURowSequenceParallelLinear,
)

ColumnParallelLinear = XPUColumnSequenceParallelLinear
RowParallelLinear = XPURowSequenceParallelLinear
else:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
if get_env_device() == "xpu":
import paddle_xpu # noqa: F821

ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear # noqa: F821
RowParallelLinear = paddle_xpu.layers.nn.RowParallelLinear # noqa: F821
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear

if get_env_device() == "xpu":
import paddle_xpu # noqa: F821

Linear = paddle_xpu.layers.nn.Linear # noqa: F821
else:
Linear = nn.Linear

if config.tensor_parallel_degree > 1:
if config.fuse_attention_ffn:
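
Condensed, the plain-Linear dispatch this hunk introduces amounts to the sketch below; it assumes modeling.py's get_env_device helper and the paddle_xpu package referenced throughout this diff:

import paddle.nn as nn


def select_linear_cls():
    # XPU builds swap in paddle_xpu's Linear; every other device keeps nn.Linear.
    if get_env_device() == "xpu":
        import paddle_xpu  # noqa: F821

        return paddle_xpu.layers.nn.Linear
    return nn.Linear
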
@@ -619,15 +644,24 @@ def __init__(self, config):
)
else:
if config.fuse_attention_ffn:
self.gate_up_fused_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
self.gate_up_fused_proj = Linear(self.hidden_size, self.intermediate_size * 2, bias_attr=False)
else:
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.gate_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)
self.up_proj = Linear(self.hidden_size, self.intermediate_size, bias_attr=False)

self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias_attr=False)
self.down_proj = Linear(self.intermediate_size, self.hidden_size, bias_attr=False)

def forward(self, x):
if self.fuse_attention_ffn:
# FIXME(yangjianbang): use paddle's native swiglu
if get_env_device() == "xpu":
import paddle_xpu_nn # noqa: F821

out = self.gate_up_fused_proj(x)
out = paddle_xpu_nn.xpu_swiglu(out, axis=-1, turn=True)
out = self.down_proj(out)
return out

x = swiglu(self.gate_up_fused_proj(x))
else:
x = swiglu(self.gate_proj(x), self.up_proj(x))
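
The fused xpu_swiglu call above corresponds roughly to the stock-Paddle computation below; which half acts as the gate under turn=True is an assumption, since the kernel's layout is not spelled out in this diff:

import paddle
import paddle.nn.functional as F


def swiglu_reference(gate_up, axis=-1):
    # Split the fused gate_up projection in two and apply silu(gate) * up.
    gate, up = paddle.split(gate_up, num_or_sections=2, axis=axis)
    return F.silu(gate) * up
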
@@ -689,7 +723,11 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):

self.use_fused_rope = config.use_fused_rope
if self.use_fused_rope and get_env_device() != "npu":
if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
if (
"gpu" not in paddle.device.get_device()
or "xpu" not in paddle.device.get_device()
or fused_rotary_position_embedding is None
):
warnings.warn(
"Enable fuse rope in the config, but fuse rope is not available. "
"Will disable fuse rope. Try using latest gpu version of Paddle."
@@ -705,12 +743,33 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):

ColumnParallelLinear = MC2ColumnSeqParallelLinear
RowParallelLinear = MC2RowSeqParallelLinear
elif get_env_device() == "xpu":
from paddle_xpu.layers.nn.sequence_parallel import ( # noqa: F401
XPUColumnSequenceParallelLinear,
XPURowSequenceParallelLinear,
)

ColumnParallelLinear = XPUColumnSequenceParallelLinear
RowParallelLinear = XPURowSequenceParallelLinear
else:
ColumnParallelLinear = ColumnSequenceParallelLinear
RowParallelLinear = RowSequenceParallelLinear
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear
if get_env_device() == "xpu":
import paddle_xpu # noqa: F821

ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear # noqa: F821
RowParallelLinear = paddle_xpu.layers.nn.RowParallelLinear # noqa: F821
else:
ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
RowParallelLinear = fleet.meta_parallel.RowParallelLinear

if get_env_device() == "xpu":
import paddle_xpu # noqa: F821

Linear = paddle_xpu.layers.nn.Linear # noqa: F821
else:
Linear = nn.Linear

if config.tensor_parallel_degree > 1:
if self.fuse_attention_qkv:
@@ -741,36 +800,36 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
gather_output=False,
)
else:
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)

else:
if self.fuse_attention_qkv:
self.qkv_proj = nn.Linear(
self.qkv_proj = Linear(
self.hidden_size,
self.hidden_size + 2 * self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
else:
self.q_proj = nn.Linear(
self.q_proj = Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
)
self.k_proj = nn.Linear(
self.k_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
)
self.v_proj = nn.Linear(
self.v_proj = Linear(
self.hidden_size,
self.config.num_key_value_heads * self.head_dim,
bias_attr=False,
@@ -784,7 +843,7 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
input_is_parallel=True,
)
else:
self.o_proj = nn.Linear(
self.o_proj = Linear(
self.hidden_size,
self.hidden_size,
bias_attr=False,
@@ -1428,6 +1487,11 @@ def _prepare_decoder_attention_mask(attention_mask, input_shape, past_key_values
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float16")
expanded_attn_mask = expanded_attn_mask.astype("float16")
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
elif get_env_device() == "xpu":
x = paddle.to_tensor(0.0, dtype=dtype)
y = paddle.to_tensor(paddle.finfo(dtype).min, dtype=dtype)
expanded_attn_mask = expanded_attn_mask.astype(dtype)
expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
else:
expanded_attn_mask = paddle.where(expanded_attn_mask, 0.0, paddle.finfo(dtype).min).astype(dtype)
return expanded_attn_mask
Expand Down Expand Up @@ -1708,6 +1772,13 @@ def __init__(self, config: LlamaConfig):
self.weight.is_distributed = True if (vocab_size != config.vocab_size) else False
if self.weight.is_distributed:
self.weight.split_axis = 1
if paddle.is_compiled_with_xpu():
from paddle_xpu.layers.nn import xpu_matmul # noqa: F401

self._xpu_matmul = xpu_matmul()
self.matmul_op = self._xpu_matmul.forward
else:
self.matmul_op = paddle.matmul

def forward(self, hidden_states, tensor_parallel_output=None):
if self.config.sequence_parallel:
@@ -1721,7 +1792,13 @@ def forward(self, hidden_states, tensor_parallel_output=None):
if tensor_parallel_output is None:
tensor_parallel_output = self.config.tensor_parallel_output

logits = parallel_matmul(hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
matmul_op = self.matmul_op
if paddle.is_compiled_with_xpu():
from functools import partial

matmul_op = partial(matmul_op, training=self.training)

logits = parallel_matmul(matmul_op, hidden_states, self.weight, tensor_parallel_output=tensor_parallel_output)
return logits
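
Taken together, the LM-head changes amount to the dispatch sketched below; the names mirror the diff, but the standalone helper itself is illustrative:

from functools import partial

import paddle


def resolve_lm_head_matmul(training: bool):
    # XPU builds route the vocab projection through paddle_xpu's xpu_matmul
    # wrapper, which also wants the training flag; other builds use paddle.matmul.
    if paddle.is_compiled_with_xpu():
        from paddle_xpu.layers.nn import xpu_matmul  # noqa: F401

        return partial(xpu_matmul().forward, training=training)
    return paddle.matmul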

