-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 7th No.38】为 Paddle 代码转换工具新增 API 转换规则(第5组) #487
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 写法上注意优化下,目前有些过于hard code
- 转换后的代码注意简洁下,尽可能一行对一行,不要加太多其他的内容及判断等等,这会导致转换后的代码比较丑
- 丰富下测试case,以下四种情况的测试case必须全部包含:
- 传入所有参数且全部指定关键字
- 传入所有参数且全部不指定关键字
- 改变关键字顺序
- 默认参数均不指定
@@ -1450,6 +1450,14 @@ def generate_code(self, kwargs): | |||
return GenericMatcher.generate_code(self, kwargs) | |||
|
|||
|
|||
class ScatterReduceMatcher(BaseMatcher): | |||
def generate_code(self, kwargs): | |||
reduce_mapping = {'"""sum"""': '"add"', '"""prod"""': '"multiply"'} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个可能得用辅助函数形式写,因为传入的reduce可能是个变量
例如:
reduce_type = 'sum'
torch.scatter_reduce(reduce=reduce_type, ...)
paconvert/api_matcher.py
Outdated
class CartesianProdMatcher(BaseMatcher): | ||
def get_paddle_nodes(self, args, kwargs): | ||
new_args = self.parse_args(args) | ||
code = "paddle.cartesian_prod([ {}".format(new_args[0]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 如果输入的是变量形式,应该无法支持,参考 Chain_MatmulMatcher
- new_args返回的就是一个list,是不是直接format到字符串,目前这样写代码有点丑
paconvert/api_matcher.py
Outdated
return "{}.cast(paddle.float64).pow({})".format( | ||
self.paddleClass, kwargs["exponent"] | ||
self.write_aux_code() | ||
_from_dtype = kwargs["from_"][3:-3] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这么写代码太硬了,一定要这样写死切片吗,如果输入类型是变量,会出错吧,例如:
dtype1=torch.float32
dtype2=torch.float64
torch.can_cast(dtype1, dtype2)
paconvert/api_matcher.py
Outdated
|
||
|
||
class PositiveMatcher(BaseMatcher): | ||
def generate_code(self, kwargs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个这么多行,还是用辅助函数来写吧,不然很容易破坏原代码的结构
paconvert/api_matcher.py
Outdated
) | ||
else: | ||
if "out" not in kwargs: | ||
return "paddle.pow({}.cast(paddle.float64), {}.cast(paddle.float64) if isinstance({}, paddle.Tensor) else {})".format( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
exponent也必须cast吗,只用cast input就行吧,这个应该自带类型提升
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
应该不行,paddle.pow(a, b)
中,a
和 b
都是 tensor 的时候数据类型必须一致,否则会报错
@inaomIIsfarell CI未通过且出现冲突,需要修改 |
PR Docs
PaddlePaddle/docs#6885
PR APIs
ps: 如果有朋友用
vscode
修改paconvert/api_mapping.json
和paconvert/api_alias_mapping.json
,并用pre-commit
commit 时,出现 json 文件diff 全覆盖的情况,可能 是 pre-commit 的问题