-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add paddle.diag API, diag_v2 OP and CUDA kernel #26414
Conversation
Thanks for your contribution! |
5318f3e
to
d1eda82
Compare
b3a5b98
to
967de62
Compare
@@ -914,3 +913,92 @@ def meshgrid(*args, **kwargs): | |||
type='meshgrid', inputs={'X': list(args)}, outputs={'Out': out}) | |||
|
|||
return out | |||
|
|||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在旧的API加一下deprecated标注
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -914,3 +913,92 @@ def meshgrid(*args, **kwargs): | |||
type='meshgrid', inputs={'X': list(args)}, outputs={'Out': out}) | |||
|
|||
return out | |||
|
|||
|
|||
def diag(x, offset=0, padding_value=0, name=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在旧的API加一下deprecated标注
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
class TestDiagV2Op(OpTest): | ||
def setUp(self): | ||
self.op_type = "diag_v2" | ||
self.x = np.random.rand(10, 10) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
random.rand.seed 固定一下
paddle/fluid/operators/diag_v2_op.h
Outdated
set_padding_value(dev_ctx, out, static_cast<T>(padding_value)); | ||
|
||
auto x_length = x_dims[0]; | ||
const int x_stride = ComputeStride(0, x_dims); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
如果是是使用const,建议使用const &
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
New features
PR changes
OPs,APIs
Describe
支持如下功能: