-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 7th No.41】NO.41 为 Paddle 代码转换工具新增 API 转换规则(第 8 组) #495
【Hackathon 7th No.41】NO.41 为 Paddle 代码转换工具新增 API 转换规则(第 8 组) #495
Conversation
Thanks for your contribution! |
而且GPU那个测试还是通过的 |
你好,我发现写了辅助函数后,ci里面并不能检测到,但是我在本地是可以跑通的,请问这是怎么回事? @luotao1 |
@@ -751,6 +751,147 @@ def get_paddle_nodes(self, args, kwargs): | |||
for i in range(1, len(new_args)): | |||
code = "{}({}, {})".format(self.get_paddle_api(), code, new_args[i]) | |||
return ast.parse(code).body | |||
|
|||
|
|||
class StudentTMatcher(BaseMatcher): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是什么差异,需要重写Matcher。我也没有在映射文档中看到任何描述这两个API的差异,不能直接对上用GenericMatcher吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paconvert/api_matcher.py
Outdated
shape_list = list(range(x.ndim)) | ||
shape_list[-1], shape_list[-2] = shape_list[-2], shape_list[-1] | ||
y = x.transpose(perm=shape_list) | ||
return x @ y |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return x @ x.T 就行吧
paconvert/api_matcher.py
Outdated
x = x.tril(-1) + x.diagonal(axis1=-2, axis2=-1).exp().diag_embed() | ||
shape_list = list(range(x.ndim)) | ||
shape_list[-1], shape_list[-2] = shape_list[-2], shape_list[-1] | ||
y = x.transpose(perm=shape_list) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个用x.T就行
return API_TEMPLATE | ||
|
||
|
||
class LKJCholeskyMatcher(BaseMatcher): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是什么差异,需要重写Matcher。我也没有在映射文档中看到任何描述这两个API的差异,不能直接对上用GenericMatcher吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
有,输出的维度不一样,需要补上一个
paconvert/api_matcher.py
Outdated
|
||
|
||
class Is_InferenceMatcher(BaseMatcher): | ||
def generate_aux_code(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个应该不需要辅助函数吧:
torch.is_inference(x) -> x.stop_gradient
x.is_inference() -> x.stop_gradient
也不需要not吧,两者都为True时对应不需要梯度计算
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没错没错
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
没错没错
那你文档又为啥还是not的?
return code | ||
|
||
|
||
class DistributionsConstrainMatcher(BaseMatcher): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是什么差异,需要重写Matcher。我也没有在映射文档中看到任何描述这两个API的差异,不能直接对上用GenericMatcher吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
差异是一个是__call__,一个用的是call,是为了封装
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
差异是一个是__call__,一个用的是call,是为了封装
你在文档里需要写清楚差异
""" | ||
import torch | ||
input = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) | ||
result = input.cauchy_(median=0, sigma=1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
测一个case:全部不指定关键字?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
test_case2就是全都不指定关键字的
""" | ||
import torch | ||
input = torch.tensor([-0.6341, -1.4208, -1.0900, 0.5826]) | ||
result = input.random_(0, to=5, generator=None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
测一个case:全部指定关键字?关键字乱序?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@decade-afk 目前问题挺多的,你的映射文档和Matcher是很多diff对不齐的。先写清楚映射文档,再写Matcher,并保证两者是完全一致无diff的。 |
已经改了,请review |
@@ -5067,6 +5194,36 @@ def generate_code(self, kwargs): | |||
return code | |||
|
|||
|
|||
class RpcRemoteMatcher(BaseMatcher): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
因为rpc_async使用to_wait获取值,这个我改改文档
return code | ||
|
||
|
||
class DistributionsConstrainMatcher(BaseMatcher): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我改改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不包装一下,直接调用check会报错
def generate_aux_code(self): | ||
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
import paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
你用func包class,包这两层的意义呢,直接定义一个class不更简洁吗
paconvert/api_matcher.py
Outdated
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
import paddle | ||
def StudentT_Aux_Func(df, loc, scale): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
你用func包class,包这两层的意义呢,直接定一个class不更简洁吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已经改好了,docs那里也同步修改了。
73af39d
to
cc0062e
Compare
paconvert/api_matcher.py
Outdated
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
remote_obj = paddle.distributed.rpc.rpc_async({}) | ||
paddle_aux.rpc_remote(remote_obj) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
都已经用上辅助函数了,建议直接写成一行。
remote_obj = paddle.distributed.rpc.rpc_async({})
你直接放上面就行
好的好的,已经修改了,ci也跑通了的,请review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR Docs
PaddlePaddle/docs#6887
PR APIs