-
Notifications
You must be signed in to change notification settings - Fork 55
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 7th No.38】为 Paddle 代码转换工具新增 API 转换规则(第5组) #496
Conversation
Thanks for your contribution! |
参考下这个:#495 (comment) |
@inaomIIsfarell CI未通过,请先自查问题 |
报错问题在本地无法复现,
# This file is generated by PaConvert ToolKit, please Don't edit it!
import paddle
def can_cast(from_, to):
can_cast_dict = {
'bfloat16': {
'bfloat16': True,
'float16': True,
'float32': True,
'float64': True,
'complex64': True,
'complex128': True,
'uint8': False,
'int8': False,
'int16': False,
'int32': False,
'int64': False,
'bool': False
},
'float16': {
'bfloat16': True,
'float16': True,
'float32': True,
'float64': True,
'complex64': True,
'complex128': True,
'uint8': False,
'int8': False,
'int16': False,
'int32': False,
'int64': False,
'bool': False,
},
'float32': {
'bfloat16': True,
'float16': True,
'float32': True,
'float64': True,
'complex64': True,
'complex128': True,
'uint8': False,
'int8': False,
'int16': False,
'int32': False,
'int64': False,
'bool': False,
},
'float64': {
'bfloat16': True,
'float16': True,
'float32': True,
'float64': True,
'complex64': True,
'complex128': True,
'uint8': False,
'int8': False,
'int16': False,
'int32': False,
'int64': False,
'bool': False,
},
'complex64': {
'bfloat16': False,
'float16': False,
'float32': False,
'float64': False,
'complex64': True,
'complex128': True,
'uint8': False,
'int8': False,
'int16': False,
'int32': False,
'int64': False,
'bool': False,
},
'complex128': {
'bfloat16': False,
'float16': False,
'float32': False,
'float64': False,
'complex64': True,
'complex128': True,
'uint8': False,
'int8': False,
'int16': False,
'int32': False,
'int64': False,
'bool': False,
},
'uint8': {
'bfloat16': True,
'float16': True,
'float32': True,
'float64': True,
'complex64': True,
'complex128': True,
'uint8': True,
'int8': True,
'int16': True,
'int32': True,
'int64': True,
'bool': False,
},
'int8': {
'bfloat16': True,
'float16': True,
'float32': True,
'float64': True,
'complex64': True,
'complex128': True,
'uint8': True,
'int8': True,
'int16': True,
'int32': True,
'int64': True,
'bool': False,
},
'int16': {
'bfloat16': True,
'float16': True,
'float32': True,
'float64': True,
'complex64': True,
'complex128': True,
'uint8': True,
'int8': True,
'int16': True,
'int32': True,
'int64': True,
'bool': False,
},
'int32': {
'bfloat16': True,
'float16': True,
'float32': True,
'float64': True,
'complex64': True,
'complex128': True,
'uint8': True,
'int8': True,
'int16': True,
'int32': True,
'int64': True,
'bool': False,
},
'int64': {
'bfloat16': True,
'float16': True,
'float32': True,
'float64': True,
'complex64': True,
'complex128': True,
'uint8': True,
'int8': True,
'int16': True,
'int32': True,
'int64': True,
'bool': False,
},
'bool': {
'bfloat16': True,
'float16': True,
'float32': True,
'float64': True,
'complex64': True,
'complex128': True,
'uint8': True,
'int8': True,
'int16': True,
'int32': True,
'int64': True,
'bool': True,
}
}
return can_cast_dict[from_][to]
setattr(paddle, 'can_cast', can_cast) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个定位到问题了:是由于CI中torch的版本不够新,目前已经更新了版本
下面这个问题更新一下,然后重新提交一遍触发CI重跑
paconvert/api_matcher.py
Outdated
} | ||
} | ||
return can_cast_dict[from_][to] | ||
setattr(paddle, 'can_cast', can_cast) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个既然最后调用的是:
paddle_aux.can_cast
这里就不需要去设置一个paddle.can_cast
了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- 已去除setattr
PR-CI-GPU-UnitTest
中出现 comment 中相同的报错问题PR-CI-UnitTest
中出现我未修改内容部分的报错
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ptal @zhwesky2010
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
补充:本地删除本地project文件夹后,首次运行会报错,但如果project文件夹已经生成则本地可以正常运行
paconvert/api_matcher.py
Outdated
@@ -3531,6 +3562,13 @@ def generate_code(self, kwargs): | |||
return code | |||
|
|||
|
|||
class CartesianProdMatcher(BaseMatcher): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
参考一下CreateMatcher吧,这个应该有多种输入的用法
paconvert/api_matcher.py
Outdated
|
||
def generate_code(self, kwargs): | ||
self.write_aux_code() | ||
if "input" in kwargs and kwargs["input"] is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两个判断不是一个意思吗,只需要判断 if "input" in kwargs 就行
paconvert/api_matcher.py
Outdated
""" | ||
def get_exponent(exponent): | ||
return exponent.cast(paddle.float64) if isinstance(exponent, paddle.Tensor) else exponent | ||
setattr(paddle, "get_exponent", get_exponent) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不要这样setattr,所有写法都改一下
paconvert/api_matcher.py
Outdated
def generate_aux_code(self): | ||
CODE_TEMPLATE = textwrap.dedent( | ||
""" | ||
def get_exponent(exponent): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
命名规范一点,这个应该是 cast_exponent ?
return CODE_TEMPLATE | ||
|
||
def generate_code(self, kwargs): | ||
self.write_aux_code() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kwargs["input"] is not None
没必要判断,没有这种用法。input不可能输入None,out有可能按默认的就是None
paconvert/api_matcher.py
Outdated
) | ||
if "out" in kwargs and kwargs["out"] is not None: | ||
code = "paddle.assign({}, {})".format(pow_expression, kwargs["out"]) | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个分支也是没必要,你可以把pow_expression命名为code
paconvert/api_matcher.py
Outdated
def generate_code(self, kwargs): | ||
self.write_aux_code() | ||
if "input" in kwargs: | ||
pow_expression = "paddle.pow({}.cast(paddle.float64), paddle_aux.cast_exponent({}))".format( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
code = ...
obj.run(pytorch_code, ["result"]) | ||
|
||
|
||
def test_case_3(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
测一下 可变参数的用法
c = (a, b)
*c
这种
关键字参数的用法
tensors = (a, b)
要与Matcher的分支一一对应,如果没有这种用法,在Matcher里可以删掉这个分支了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
torch对这个api似乎不支持指定关键字参数的用法
paconvert/api_matcher.py
Outdated
@@ -3534,6 +3553,25 @@ def generate_code(self, kwargs): | |||
return code | |||
|
|||
|
|||
class CartesianProdMatcher(BaseMatcher): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
那你这个能否复用 ScalableVarMatcher
呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
复用 ScalableVarMatcher
时,在tests/test_cartesian_prod.py
中test_case_2()
中报错ValueError: Expect a 1D vector, but got shape []
,但是代码能正常转换,我当时写的时候没定位到这个错误,这才写了一个新的 Matcher 。您看我是复用 ScalableVarMatcher
还是接着用自己写的这个Matcher呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
复用
ScalableVarMatcher
时,在tests/test_cartesian_prod.py
中test_case_2()
中报错ValueError: Expect a 1D vector, but got shape []
,但是代码能正常转换,我当时写的时候没定位到这个错误,这才写了一个新的 Matcher 。您看我是复用ScalableVarMatcher
还是接着用自己写的这个Matcher呢
可以定位一下吧,看是否优化下原Matcher,尽可能复用
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是需要改动一下,p1 中红框圈起来的地方有些问题,我在其他使用了这个Matcher的api里测试了一下,如果torch中只传入形参且形参不是tuple或list时,paddle的api中的参数会从列表变成一个值,同上边的p2
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@inaomIIsfarell 那你这个跟之前的ScalableVarMatcher逻辑还是有点区别,之前的逻辑是支持以下用法的
api(3, 4, 5)
api(3)
api([3, 4, 5])
你这个API仅支持前两种用法,不支持为list的用法,因此还是重新写个简化版的ScalableVarMatcher吧,不要辅助函数
|
||
from apibase import APIBase | ||
|
||
obj = APIBase("torch.block_diag") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
有 可变参数、关键字参数 的用法吗?有的话可以加一下
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
指定关键字参数时torch报错,应该是不支持指定关键字参数的
@inaomIIsfarell 那你这个跟之前的ScalableVarMatcher逻辑还是有点区别,之前的逻辑是支持以下用法的
你这个API仅支持前1、2、5用法,不支持3、4的用法,因此还是重新写个简化版的ScalableVarMatcher吧,不需要辅助函数。类似于这样:
|
paconvert/api_matcher.py
Outdated
|
||
class CartesianProdMatcher(BaseMatcher): | ||
def get_paddle_nodes(self, args, kwargs): | ||
if len(args) > 1 or (len(args) == 1 and isinstance(args[0], ast.Constant)): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个输入不可能是常数,按我给出的示例,分支的分类有问题吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我在改
@@ -16,7 +16,7 @@ | |||
|
|||
from apibase import APIBase | |||
|
|||
obj = APIBase("torch.Tensor.float_power") | |||
obj = APIBase("torch.Tensor.float_power", is_aux_api=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加了辅助函数,需要新增一些对应的case吧
tests/test_Tensor_permute.py
Outdated
@@ -16,7 +16,7 @@ | |||
|
|||
from apibase import APIBase | |||
|
|||
obj = APIBase("torch.Tensor.permute") | |||
obj = APIBase("torch.Tensor.permute", is_aux_api=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个需要加is_aux_api吗
tests/test_Tensor_tile.py
Outdated
@@ -16,7 +16,7 @@ | |||
|
|||
from apibase import APIBase | |||
|
|||
obj = APIBase("torch.Tensor.tile") | |||
obj = APIBase("torch.Tensor.tile", is_aux_api=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个需要加is_aux_api吗
@zhwesky2010 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR Docs
PaddlePaddle/docs#6885
PR APIs
原pr:#487