# This is a combination of 2 commits.
# This is the 1st commit message:

merge

# This is the commit message #2:

Modify old CHECK macros
Lans1ot committed Aug 5, 2024
1 parent c535b95 commit 9840148
Showing 2 changed files with 56 additions and 13 deletions.
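Both files below apply the same migration: the old PD_CHECK / CHECK macros, which take only a boolean condition and a fixed message, are replaced by PADDLE_ENFORCE_EQ, which compares two values and attaches a typed error object (here phi::errors::Fatal) whose message can report the offending value. The sketch below is a minimal, compilable illustration of that before/after shape; PD_CHECK_SKETCH and ENFORCE_EQ_SKETCH are hypothetical stand-ins, not the PaddlePaddle implementation.

// Hypothetical stand-ins for PaddlePaddle's PD_CHECK and PADDLE_ENFORCE_EQ;
// they only mimic the shape of the real macros for illustration.
#include <cstdio>
#include <cstdlib>

// Old style: boolean condition plus a fixed message.
#define PD_CHECK_SKETCH(cond, msg)                       \
  do {                                                   \
    if (!(cond)) {                                       \
      std::fprintf(stderr, "Check failed: %s\n", (msg)); \
      std::abort();                                      \
    }                                                    \
  } while (0)

// New style: compare two values and emit a printf-style formatted error,
// mirroring PADDLE_ENFORCE_EQ(lhs, rhs, phi::errors::Fatal(fmt, ...)).
#define ENFORCE_EQ_SKETCH(lhs, rhs, fmt, ...)                           \
  do {                                                                  \
    if ((lhs) != (rhs)) {                                               \
      std::fprintf(stderr, "EnforceEQ failed: " fmt "\n", __VA_ARGS__); \
      std::abort();                                                     \
    }                                                                   \
  } while (0)

int main() {
  int r = 0;  // e.g. the return code of an XPU API call in the kernel file
  // Before this commit: PD_CHECK(r == 0, "scale failed");
  PD_CHECK_SKETCH(r == 0, "scale failed");
  // After this commit: the message also carries the value of `r`.
  ENFORCE_EQ_SKETCH(r, 0, "scale failed, related variable `r` is %d", r);

  bool stop_gradient = true;  // e.g. grad_slot.IsStopGradient() in the test file
  // Before: CHECK(grad_slot.IsStopGradient() == true);
  // After:  PADDLE_ENFORCE_EQ(..., true, phi::errors::Fatal("... should be true ..."));
  ENFORCE_EQ_SKETCH(stop_gradient, true, "`IsStopGradient()` should be %s", "true");
  return 0;
}

Compared with the old macros, a failed enforcement now reports which value was actually observed instead of only stating that a condition was false.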
30 changes: 24 additions & 6 deletions paddle/phi/kernels/fusion/xpu/weight_only_linear_kernel_xpu.cc
@@ -26,8 +26,11 @@ void WeightOnlyLinearKernel(const Context& dev_ctx,
const int32_t arch,
const int32_t group_size,
DenseTensor* out) {
PD_CHECK(weight_dtype == "int8",
"WeightOnlyLinearKernel xpu just support int8 weight only");
PADDLE_ENFORCE_EQ(
weight_dtype,
"int8",
phi::errors::Fatal(
"WeightOnlyLinearKernel xpu just support int8 weight only"));
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto xpu_ctx = static_cast<const phi::XPUContext*>(&dev_ctx);
dev_ctx.template Alloc<T>(out);
@@ -55,14 +58,21 @@ void WeightOnlyLinearKernel(const Context& dev_ctx,
false,
weight_dtype == "int8" ? 127.f : 7.f,
0.f);
PD_CHECK(r == 0, "scale failed");
PADDLE_ENFORCE_EQ(r,
0,
phi::errors::Fatal(
"scale failed, scale related variable `r` is %d",
r));
r = baidu::xpu::api::cast_v2<XPUType, float>(
xpu_ctx->x_context(),
reinterpret_cast<const XPUType*>(
max_value_fp16.data<phi::dtype::float16>()),
max_value.data<float>(),
max_value.numel());
PD_CHECK(r == 0, "cast_v2 failed");
PADDLE_ENFORCE_EQ(r,
0,
phi::errors::Fatal(
"cast_v2 failed, related variable `r` is %d", r));
} else if (weight_scale.dtype() == phi::DataType::FLOAT32) {
r = baidu::xpu::api::scale(xpu_ctx->x_context(),
weight_scale.data<float>(),
@@ -71,7 +81,10 @@ void WeightOnlyLinearKernel(const Context& dev_ctx,
false,
weight_dtype == "int8" ? 127.f : 7.f,
0.f);
PD_CHECK(r == 0, "scale failed");
PADDLE_ENFORCE_EQ(
r,
0,
phi::errors::Fatal("scale failed, related variable `r` is %d", r));
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Only support that weight scale as type float32 or float16."));
@@ -115,7 +128,12 @@ void WeightOnlyLinearKernel(const Context& dev_ctx,
: nullptr,
baidu::xpu::api::Activation_t::LINEAR,
max_value.data<float>());
PD_CHECK(r == 0, "baidu::xpu::api::gpt_fc_fusion failed.");
PADDLE_ENFORCE_EQ(
r,
0,
phi::errors::Fatal("baidu::xpu::api::gpt_fc_fusion failed, related "
"variable `r` is %d",
r));
} else if (weight_dtype == "int4") {
PD_THROW("only support int8 weight only now");
}
39 changes: 32 additions & 7 deletions test/cpp/eager/data_structure_tests/grad_node_info_test.cc
@@ -27,7 +27,10 @@ TEST(GradNodeInfo, GradSlotMeta) {
auto grad_slot = egr::GradSlotMeta();
VLOG(6) << "Set SetStopGradient";
grad_slot.SetStopGradient();
CHECK(grad_slot.IsStopGradient() == true);
PADDLE_ENFORCE_EQ(grad_slot.IsStopGradient(),
true,
phi::errors::Fatal("`grad_slot.IsStopGradient()` should be "
"true, please check related function"));
}

void TestGradNodeBase(bool is_remove_gradient_hook) {
@@ -80,8 +83,16 @@ void TestGradNodeBase(bool is_remove_gradient_hook) {
grad_test_node0->InputMeta()[1][0].GetTensorMeta().dtype,
meta.dtype,
phi::errors::InvalidArgument("Dtype of input tensor mismatch."));
CHECK(grad_test_node0->OutputMeta()[0][0].IsStopGradient());
CHECK(grad_test_node0->OutputMeta()[1][0].IsStopGradient());
PADDLE_ENFORCE_EQ(
grad_test_node0->OutputMeta()[0][0].IsStopGradient(),
true,
phi::errors::Fatal("`grad_test_node0->OutputMeta()[0][0].IsStopGradient()"
"` should be true, please check related function"));
PADDLE_ENFORCE_EQ(
grad_test_node0->OutputMeta()[1][0].IsStopGradient(),
true,
phi::errors::Fatal("`grad_test_node0->OutputMeta()[1][0].IsStopGradient()"
"` should be true, please check related function"));
PADDLE_ENFORCE_EQ(
grad_test_node0->OutputMeta()[0][0].GetTensorMeta().dtype,
meta.dtype,
@@ -99,7 +110,11 @@ void TestGradNodeBase(bool is_remove_gradient_hook) {
grad_test_node2->OutputMeta()[0].size(),
0UL,
phi::errors::InvalidArgument("Size of output not greater than 0."));
CHECK(grad_test_node2->OutputMeta()[0][0].IsStopGradient() == false);
PADDLE_ENFORCE_EQ(
grad_test_node2->OutputMeta()[0][0].IsStopGradient(),
false,
phi::errors::Fatal("`grad_test_node2->OutputMeta()[0][0].IsStopGradient()"
"` should be false, please check related function"));
PADDLE_ENFORCE_EQ(
grad_test_node2->OutputMeta()[0].size(),
1UL,
@@ -160,9 +175,15 @@ TEST(GradNodeInfo, Edge) {
auto auto_grad1 = std::make_shared<egr::AutogradMeta>();
VLOG(6) << "Test Construct Edge";
egr::Edge edge0 = egr::Edge();
CHECK(edge0.IsInitialized() == false);
PADDLE_ENFORCE_EQ(edge0.IsInitialized(),
false,
phi::errors::Fatal("`edge0.IsInitialized()` should be "
"false, please check related function"));
egr::Edge edge1 = egr::Edge(grad_test_node0, size_t(0), size_t(0));
CHECK(edge1.IsInitialized() == true);
PADDLE_ENFORCE_EQ(edge1.IsInitialized(),
true,
phi::errors::Fatal("`edge1.IsInitialized()` should be "
"true, please check related function"));
egr::Edge edge2 =
egr::Edge(grad_test_node0, std::make_pair(size_t(1), size_t(0)));
VLOG(6) << "Test Set Edge's Grad Node";
@@ -175,7 +196,11 @@ TEST(GradNodeInfo, Edge) {
2UL,
phi::errors::InvalidArgument("Size of input mismatch. Expected 2."));
std::vector<egr::AutogradMeta*> metas = {auto_grad1.get()};
CHECK(grad_node->InputMeta()[0][0].IsStopGradient() == true);
PADDLE_ENFORCE_EQ(
grad_node->InputMeta()[0][0].IsStopGradient(),
true,
phi::errors::Fatal("`grad_node->InputMeta()[0][0].IsStopGradient()` "
"should be true, please check related function"));
VLOG(6) << "Test Get/Set Edge Rank Info";
PADDLE_ENFORCE_EQ(
edge2.GetEdgeRankInfo().first,
