Commit

fix
dynamicheart committed Apr 20, 2024
1 parent 1293619 commit e388ed6
Showing 1 changed file with 27 additions and 19 deletions.
46 changes: 27 additions & 19 deletions paddlenlp/transformers/llama/modeling.py
@@ -414,7 +414,7 @@ def forward(self, hidden_states):
if get_env_device() == "npu":
return core.eager._run_custom_op("rms_norm_npu", hidden_states, self.weight, self.variance_epsilon)[0]
elif get_env_device() == "xpu":
import paddle_xpu_nn
import paddle_xpu_nn # noqa: F821

                 return paddle_xpu_nn.xpu_rms_norm(hidden_states, self.weight, self.variance_epsilon)[0]
             return rms_norm_fused(hidden_states, self.weight, self.variance_epsilon)
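
For reference, a minimal non-fused sketch of what the dispatched kernels in this hunk compute, assuming the standard RMSNorm definition this module uses elsewhere; numpy stands in for paddle and the function name is illustrative only.

import numpy as np

def rms_norm_reference(hidden_states: np.ndarray, weight: np.ndarray, variance_epsilon: float) -> np.ndarray:
    # Normalize each vector by its root-mean-square over the last axis, then scale by the learned weight.
    variance = np.mean(np.square(hidden_states), axis=-1, keepdims=True)
    return hidden_states / np.sqrt(variance + variance_epsilon) * weight
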
@@ -599,18 +599,23 @@ def __init__(self, config):
             RowParallelLinear = RowSequenceParallelLinear
         else:
             if get_env_device() == "xpu":
-                import paddle_xpu  # noqa: F821
+                from paddle_xpu.layers.nn import (  # noqa: F401
+                    ColumnParallelLinear as XPUColumnParallelLinear,
+                )
+                from paddle_xpu.layers.nn import (  # noqa: F401
+                    RowParallelLinear as XPURowParallelLinear,
+                )

-                ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear
-                RowParallelLinear = paddle_xpu.layers.nn.RowParallelLinear
+                ColumnParallelLinear = XPUColumnParallelLinear
+                RowParallelLinear = XPURowParallelLinear
             else:
                 ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
                 RowParallelLinear = fleet.meta_parallel.RowParallelLinear

         if get_env_device() == "xpu":
-            import paddle_xpu  # noqa: F821
+            from paddle_xpu.layers.nn import Linear as XPULinear  # noqa: F401

-            Linear = paddle_xpu.layers.nn.Linear
+            Linear = XPULinear
         else:
             Linear = nn.Linear
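
The pattern in this hunk (and in the matching hunk below) is to import the XPU layer classes lazily under explicit aliases and then rebind the generic names, rather than reaching into paddle_xpu via attribute access. A rough self-contained sketch of the same idea, using a hypothetical helper name and guarding on whether the optional paddle_xpu package is installed:

import importlib.util

import paddle.nn as nn

def resolve_linear_cls(device_kind: str):
    # Hypothetical helper: prefer the XPU Linear only when running on XPU and
    # paddle_xpu is importable; otherwise fall back to paddle.nn.Linear.
    if device_kind == "xpu" and importlib.util.find_spec("paddle_xpu") is not None:
        from paddle_xpu.layers.nn import Linear as XPULinear  # noqa: F401
        return XPULinear
    return nn.Linear
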

@@ -722,12 +727,8 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
         )

         self.use_fused_rope = config.use_fused_rope
-        if self.use_fused_rope and get_env_device() != "npu":
-            if (
-                "gpu" not in paddle.device.get_device()
-                or "xpu" not in paddle.device.get_device()
-                or fused_rotary_position_embedding is None
-            ):
+        if self.use_fused_rope and get_env_device() not in ["npu", "xpu"]:
+            if "gpu" not in paddle.device.get_device() or fused_rotary_position_embedding is None:
                 warnings.warn(
                     "Enable fuse rope in the config, but fuse rope is not available. "
                     "Will disable fuse rope. Try using latest gpu version of Paddle."
@@ -756,18 +757,23 @@ def __init__(self, config: LlamaConfig, layerwise_recompute: bool = False):
             RowParallelLinear = RowSequenceParallelLinear
         else:
             if get_env_device() == "xpu":
-                import paddle_xpu  # noqa: F821
+                from paddle_xpu.layers.nn import (  # noqa: F401
+                    ColumnParallelLinear as XPUColumnParallelLinear,
+                )
+                from paddle_xpu.layers.nn import (  # noqa: F401
+                    RowParallelLinear as XPURowParallelLinear,
+                )

-                ColumnParallelLinear = paddle_xpu.layers.nn.ColumnParallelLinear  # noqa: F821
-                RowParallelLinear = paddle_xpu.layers.nn.RowParallelLinear  # noqa: F821
+                ColumnParallelLinear = XPUColumnParallelLinear
+                RowParallelLinear = XPURowParallelLinear
             else:
                 ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
                 RowParallelLinear = fleet.meta_parallel.RowParallelLinear

         if get_env_device() == "xpu":
-            import paddle_xpu  # noqa: F821
+            from paddle_xpu.layers.nn import Linear as XPULinear  # noqa: F401

-            Linear = paddle_xpu.layers.nn.Linear
+            Linear = XPULinear
         else:
             Linear = nn.Linear

@@ -1773,9 +1779,11 @@ def __init__(self, config: LlamaConfig):
             if self.weight.is_distributed:
                 self.weight.split_axis = 1
         if get_env_device() == "xpu":
-            import paddle_xpu
+            from paddle_xpu.layers.nn import (  # noqa: F401
+                parallel_matmul as xpu_parallel_matmul,
+            )

-            self.xpu_parallel_matmul = paddle_xpu.layers.nn.parallel_matmul()
+            self.xpu_parallel_matmul = xpu_parallel_matmul()

     def forward(self, hidden_states, tensor_parallel_output=None):
         if self.config.sequence_parallel:
