【prim】Maximum grad #51006

Merged
merged 51 commits into from
Mar 13, 2023
Changes from 50 commits
51 commits
0effa15
refresh
heyanru01 Feb 21, 2023
aaa3c64
refresh
heyanru01 Feb 21, 2023
eb3fb7f
compat
heyanru01 Feb 26, 2023
858bc20
register
heyanru01 Feb 27, 2023
71b6d09
testop
heyanru01 Feb 28, 2023
a13169b
testcinn
heyanru01 Feb 28, 2023
7386b74
fix
heyanru01 Feb 28, 2023
ff1a478
fix
heyanru01 Feb 28, 2023
01aa2d6
fix
heyanru01 Feb 28, 2023
a0d1593
fox
heyanru01 Mar 1, 2023
9273948
cast
heyanru01 Mar 1, 2023
a425e91
fix
heyanru01 Mar 1, 2023
4ea4faf
cast
heyanru01 Mar 1, 2023
44e92c8
fix
heyanru01 Mar 1, 2023
7567197
type
heyanru01 Mar 1, 2023
c3c2ac1
fix
heyanru01 Mar 2, 2023
aab1434
fix
heyanru01 Mar 2, 2023
853c4c2
fix
heyanru01 Mar 2, 2023
2a05ac5
out
heyanru01 Mar 2, 2023
7fc3a4d
fix
heyanru01 Mar 3, 2023
bbeed2c
cast
heyanru01 Mar 3, 2023
dc8e66f
fix
heyanru01 Mar 4, 2023
55efae6
fix
heyanru01 Mar 4, 2023
301feb2
fix
heyanru01 Mar 5, 2023
c060acc
broad
heyanru01 Mar 5, 2023
493e7bc
broad
heyanru01 Mar 5, 2023
b9e0733
broad
heyanru01 Mar 5, 2023
146c941
fix
heyanru01 Mar 5, 2023
778de9e
fix
heyanru01 Mar 5, 2023
ecab4cc
fix
heyanru01 Mar 5, 2023
ce76053
fix
heyanru01 Mar 5, 2023
3dc7b30
fix
heyanru01 Mar 5, 2023
22bfd33
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
heyanru01 Mar 6, 2023
2c66d99
broad
heyanru01 Mar 6, 2023
cef793f
broad
heyanru01 Mar 6, 2023
4c45353
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
heyanru01 Mar 6, 2023
d5cd3b8
numel
heyanru01 Mar 6, 2023
7d5b6c5
fix
heyanru01 Mar 6, 2023
14f9ff2
fix
heyanru01 Mar 7, 2023
7b1a77b
fix
heyanru01 Mar 8, 2023
3699a5e
fix
heyanru01 Mar 8, 2023
a7aba2e
fix
heyanru01 Mar 8, 2023
aff4d92
cinn
heyanru01 Mar 8, 2023
02fda8d
fix
heyanru01 Mar 9, 2023
323e10a
fix
heyanru01 Mar 9, 2023
4fa996b
fix
heyanru01 Mar 10, 2023
5f405d7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
heyanru01 Mar 10, 2023
517c6e6
fix
heyanru01 Mar 10, 2023
f88c3ed
fix
heyanru01 Mar 10, 2023
29ad702
fix
heyanru01 Mar 10, 2023
a352e34
fix
heyanru01 Mar 11, 2023
35 changes: 34 additions & 1 deletion paddle/fluid/operators/elementwise/elementwise_max_op.cc
@@ -15,6 +15,9 @@ limitations under the License. */
#include <string>

#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"

namespace paddle {
namespace framework {
@@ -68,6 +71,35 @@ class ElementwiseFMaxOpMaker : public ElementwiseOpMaker {
}
};

class ElementwiseMaxCompositeGradOpMaker
    : public prim::CompositeGradOpMakerBase {
  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

 public:
  void Apply() override {
    paddle::Tensor x = this->GetSingleForwardInput("X");
    paddle::Tensor y = this->GetSingleForwardInput("Y");
    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
    paddle::Tensor dx = this->GetSingleInputGrad("X");
    auto* dx_ptr = this->GetOutputPtr(&dx);
    std::string dx_name = this->GetOutputName(dx);
    paddle::Tensor dy = this->GetSingleInputGrad("Y");
    auto* dy_ptr = this->GetOutputPtr(&dy);
    std::string dy_name = this->GetOutputName(dy);
    int axis = static_cast<int>(this->Attr<int>("axis"));
    PADDLE_ENFORCE_EQ(
        axis,
        -1,
        phi::errors::InvalidArgument(
            "We only support axis = -1 in composite maximum_grad but we got: ",
            axis));
    VLOG(6) << "Runing maximum_grad composite func";
    prim::maximum_grad<prim::DescTensor>(x, y, out_grad, axis, dx_ptr, dy_ptr);
    this->RecoverOutputName(dx, dx_name);
    this->RecoverOutputName(dy, dy_name);
  }
};
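Note: the maker above does not compute the gradient itself; it lowers elementwise_max's backward pass to the rule implemented by prim::maximum_grad further down. As a rough orientation, a minimal NumPy sketch of that rule (illustrative only, not the Paddle API; it assumes same-shape inputs and skips the broadcast reduction handled later):

import numpy as np

def maximum_grad_reference(x, y, out_grad):
    # Gradient of out = maximum(x, y): the incoming gradient flows to x
    # where x > y and to y where x <= y, mirroring the greater_than /
    # less_equal masks used by the composite rule.
    dx = out_grad * (x > y).astype(out_grad.dtype)
    dy = out_grad * (x <= y).astype(out_grad.dtype)
    return dx, dy

x = np.array([1.0, 3.0, 2.0])
y = np.array([2.0, 1.0, 2.0])
dx, dy = maximum_grad_reference(x, y, np.ones_like(x))
# dx -> [0., 1., 0.], dy -> [1., 0., 1.]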

template <typename T>
class ElementwiseMaxGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
@@ -112,7 +144,8 @@ REGISTER_OPERATOR(elementwise_max,
ops::ElementwiseMaxOpMaker,
ops::ElementwiseOpInferVarType,
ops::ElementwiseMaxGradOpMaker<paddle::framework::OpDesc>,
ops::ElementwiseMaxGradOpMaker<paddle::imperative::OpBase>);
ops::ElementwiseMaxGradOpMaker<paddle::imperative::OpBase>,
ops::ElementwiseMaxCompositeGradOpMaker);

REGISTER_OPERATOR(elementwise_max_grad, ops::ElementwiseOpGrad);

2 changes: 2 additions & 0 deletions paddle/fluid/prim/api/api.yaml
@@ -31,3 +31,5 @@
- pad
- cumsum
- put_along_axis
- greater_than
- less_equal
paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -898,5 +898,51 @@ void erf_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
}
}

template <typename T>
void maximum_grad(const Tensor& x,
                  const Tensor& y,
                  const Tensor& out_grad,
                  int axis,
                  Tensor* x_grad,
                  Tensor* y_grad) {
  if (x_grad) {
    auto x_tmp = cast<T>(greater_than<T>(x, y), out_grad.dtype());
    auto dx_res = out_grad * x_tmp;
Review comment (Contributor):
Make sure the logic matches the kernel when x is equal to y.

Reply (Contributor Author):
The x = y case is tested in #51568.
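For reference, a small NumPy sketch (not the Paddle kernel itself) of what this mask choice implies at the tie point: when x == y, greater_than routes no gradient to x and less_equal routes all of it to y, so the total still equals out_grad and nothing is double-counted.

import numpy as np

x = np.array([2.0, 2.0])
y = np.array([2.0, 2.0])
out_grad = np.array([1.0, 1.0])

dx = out_grad * (x > y)   # [0., 0.] -- ties contribute nothing to x
dy = out_grad * (x <= y)  # [1., 1.] -- ties send the whole gradient to y
assert np.allclose(dx + dy, out_grad)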

    if (y.dims() != x.dims()) {
      // Maybe need reduce here
      auto reduce_dim = get_reduce_dims(x.dims(), y.dims());
      if (!reduce_dim.size()) {
        set_output<T>(dx_res, x_grad);
      } else {
        auto dx_reduce_res =
            dx_res.sum(phi::vectorize(reduce_dim), x.dtype(), false);
        auto dx_tmp = reshape<T>(dx_reduce_res, phi::vectorize(x.dims()));
        set_output<T>(dx_tmp, x_grad);
      }
    } else {
      set_output<T>(dx_res, x_grad);
    }
  }

  if (y_grad) {
    auto y_tmp = cast<T>(less_equal<T>(x, y), out_grad.dtype());
    auto dy_res = out_grad * y_tmp;
    if (x.dims() != y.dims()) {
      // Maybe need reduce here
      phi::DDim reduce_dim = get_reduce_dims(y.dims(), x.dims());
      if (!reduce_dim.size()) {
        set_output<T>(dy_res, y_grad);
      } else {
        auto dy_reduce_res =
            dy_res.sum(phi::vectorize(reduce_dim), y.dtype(), false);
        auto dy_tmp = reshape<T>(dy_reduce_res, phi::vectorize(y.dims()));
        set_output<T>(dy_tmp, y_grad);
      }
    } else {
      set_output<T>(dy_res, y_grad);
    }
  }
}
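The reduce branches above exist because x and y may broadcast against each other: the elementwise product has the larger operand's shape, so the gradient of the smaller operand must be summed over the broadcast axes (which get_reduce_dims computes) and reshaped back. A minimal NumPy sketch of that reduction, assuming a simple trailing-axis broadcast:

import numpy as np

x = np.random.rand(2, 3, 4)
y = np.random.rand(4)              # broadcasts against x over the last axis
out_grad = np.ones_like(x)

dy_full = out_grad * (x <= y)      # has x's shape: (2, 3, 4)
dy = dy_full.sum(axis=(0, 1))      # sum out the broadcast axes
dy = dy.reshape(y.shape)           # back to y's shape: (4,)
assert dy.shape == y.shape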

} // namespace prim
} // namespace paddle
2 changes: 2 additions & 0 deletions paddle/fluid/prim/utils/static/desc_tensor.h
@@ -39,6 +39,8 @@ class DescTensor : public phi::ExtendedTensor,
    return dims_;
  }

  int64_t numel() const override { return product(dims()); }

  DataType dtype() const override {
    return paddle::framework::TransToPhiDataType(desc_ptr_->GetDataType());
  }
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -788,6 +788,7 @@
param: [x, y]
kernel :
func : maximum_grad
composite : maximum_grad(x, y, out_grad, -1, x_grad, y_grad)

- backward_op : mean_all_grad
forward : mean_all(Tensor x) -> Tensor(out)
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -1219,6 +1219,7 @@ set(TEST_CINN_OPS
test_elementwise_mul_op
test_gather_nd_op
test_elementwise_pow_op
test_elementwise_max_op
test_transpose_op
test_reshape_op)

67 changes: 59 additions & 8 deletions python/paddle/fluid/tests/unittests/test_elementwise_max_op.py
@@ -25,6 +25,8 @@ class TestElementwiseOp(OpTest):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        self.enable_cinn = False
        # If x and y have the same value, the max() is not differentiable.
        # So we generate test data by the following method
        # to avoid them being too close to each other.
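Why the test keeps x and y apart (a small NumPy illustration, independent of Paddle): at a tie the finite-difference gradient lands halfway between the two one-sided values, so it cannot match the analytic rule, which assigns the whole gradient to one input.

import numpy as np

def f(x, y):
    return float(np.maximum(x, y).sum())

x = np.array([2.0])
y = np.array([2.0])
eps = 1e-3
num_grad_x = (f(x + eps, y) - f(x - eps, y)) / (2 * eps)
# num_grad_x == 0.5 at the tie, while the analytic rule gives 0 or 1,
# so gradient checks only pass when |x - y| stays away from 0.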
@@ -42,25 +44,58 @@ def test_check_output(self):

     def test_check_grad_normal(self):
         if hasattr(self, 'attrs'):
-            self.check_grad(['X', 'Y'], 'Out', check_eager=False)
+            if self.attrs['axis'] == -1:
+                self.check_grad(
+                    ['X', 'Y'], 'Out', check_eager=False, check_prim=True
+                )
+            else:
+                self.check_grad(['X', 'Y'], 'Out', check_eager=False)
         else:
-            self.check_grad(['X', 'Y'], 'Out', check_eager=True)
+            self.check_grad(
+                ['X', 'Y'], 'Out', check_eager=True, check_prim=True
+            )

     def test_check_grad_ingore_x(self):
-        self.check_grad(
-            ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")
-        )
+        if hasattr(self, 'attrs') and self.attrs['axis'] != -1:
+            self.check_grad(
+                ['Y'],
+                'Out',
+                max_relative_error=0.005,
+                no_grad_set=set("X"),
+            )
+        else:
+            self.check_grad(
+                ['Y'],
+                'Out',
+                max_relative_error=0.005,
+                no_grad_set=set("X"),
+                check_prim=True,
+            )

     def test_check_grad_ingore_y(self):
-        self.check_grad(
-            ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')
-        )
+        if hasattr(self, 'attrs') and self.attrs['axis'] != -1:
+            self.check_grad(
+                ['X'],
+                'Out',
+                max_relative_error=0.005,
+                no_grad_set=set('Y'),
+            )
+        else:
+            self.check_grad(
+                ['X'],
+                'Out',
+                max_relative_error=0.005,
+                no_grad_set=set('Y'),
+                check_prim=True,
+            )


class TestElementwiseMaxOp_ZeroDim1(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        self.enable_cinn = False
        x = np.random.uniform(0.1, 1, []).astype("float64")
        y = np.random.uniform(0.1, 1, []).astype("float64")
        self.inputs = {'X': x, 'Y': y}
@@ -71,6 +106,8 @@ class TestElementwiseMaxOp_ZeroDim2(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        self.enable_cinn = False
        x = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
        y = np.random.uniform(0.1, 1, []).astype("float64")
        self.inputs = {'X': x, 'Y': y}
@@ -81,6 +118,8 @@ class TestElementwiseMaxOp_ZeroDim3(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        self.enable_cinn = False
        x = np.random.uniform(0.1, 1, []).astype("float64")
        y = np.random.uniform(0.1, 1, [13, 17]).astype("float64")
        self.inputs = {'X': x, 'Y': y}
@@ -99,6 +138,8 @@ class TestElementwiseBF16Op(OpTest):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        self.enable_cinn = False
        self.dtype = np.uint16
        # If x and y have the same value, the max() is not differentiable.
        # So we generate test data by the following method
@@ -120,6 +161,7 @@ def test_check_output(self):

    def test_check_grad_normal(self):
        if hasattr(self, 'attrs'):
            # check_prim=False, bfloat16 is not supported in `less_equal`
            self.check_grad(['X', 'Y'], 'Out', check_eager=False)
        else:
            self.check_grad(['X', 'Y'], 'Out', check_eager=True)
@@ -138,6 +180,8 @@ class TestElementwiseMaxOp_scalar(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        self.enable_cinn = False
        x = np.random.random_integers(-5, 5, [2, 3, 20]).astype("float64")
        y = np.array([0.5]).astype("float64")
        self.inputs = {'X': x, 'Y': y}
@@ -148,6 +192,8 @@ class TestElementwiseMaxOp_Vector(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        self.enable_cinn = False
        x = np.random.random((100,)).astype("float64")
        sgn = np.random.choice([-1, 1], (100,)).astype("float64")
        y = x + sgn * np.random.uniform(0.1, 1, (100,)).astype("float64")
@@ -159,6 +205,7 @@ class TestElementwiseMaxOp_broadcast_0(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        x = np.random.uniform(0.5, 1, (100, 5, 2)).astype(np.float64)
        sgn = np.random.choice([-1, 1], (100,)).astype(np.float64)
        y = x[:, 0, 0] + sgn * np.random.uniform(1, 2, (100,)).astype(
@@ -178,6 +225,7 @@ class TestElementwiseMaxOp_broadcast_1(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        x = np.random.uniform(0.5, 1, (2, 100, 3)).astype(np.float64)
        sgn = np.random.choice([-1, 1], (100,)).astype(np.float64)
        y = x[0, :, 0] + sgn * np.random.uniform(1, 2, (100,)).astype(
@@ -197,6 +245,7 @@ class TestElementwiseMaxOp_broadcast_2(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        x = np.random.uniform(0.5, 1, (1, 3, 100)).astype(np.float64)
        sgn = np.random.choice([-1, 1], (100,)).astype(np.float64)
        y = x[0, 0, :] + sgn * np.random.uniform(1, 2, (100,)).astype(
@@ -215,6 +264,7 @@ class TestElementwiseMaxOp_broadcast_3(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        x = np.random.uniform(0.5, 1, (2, 50, 2, 1)).astype(np.float64)
        sgn = np.random.choice([-1, 1], (50, 2)).astype(np.float64)
        y = x[0, :, :, 0] + sgn * np.random.uniform(1, 2, (50, 2)).astype(
@@ -234,6 +284,7 @@ class TestElementwiseMaxOp_broadcast_4(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_max"
        self.python_api = paddle.maximum
        self.prim_op_type = "prim"
        x = np.random.uniform(0.5, 1, (2, 3, 4, 5)).astype(np.float64)
        sgn = np.random.choice([-1, 1], (2, 3, 1, 5)).astype(np.float64)
        y = x + sgn * np.random.uniform(1, 2, (2, 3, 1, 5)).astype(np.float64)