Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 5th No.42】 为Paddle代码转换工具新增API转换规则 (第1组 编号1-20) #318

Merged
merged 25 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 54 additions & 4 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -1473,7 +1473,9 @@
},
"torch.Tensor.is_inference": {},
"torch.Tensor.is_meta": {},
"torch.Tensor.is_pinned": {},
"torch.Tensor.is_pinned": {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.Tensor的API,都需要配置min_input_args

"Matcher": "Is_PinnedMatcher"
},
"torch.Tensor.is_quantized": {},
"torch.Tensor.is_set_to": {},
"torch.Tensor.is_shared": {},
Expand Down Expand Up @@ -2395,7 +2397,25 @@
]
},
"torch.Tensor.qscheme": {},
"torch.Tensor.quantile": {},
"torch.Tensor.quantile": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.quantile",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"min_input_args" 配置了吗

"min_input_args": 1,
"args_list": [
"q",
"dim",
"keepdim",
"*",
"interpolation",
"out"
],
"kwargs_change": {
"dim": "axis"
},
"unsupport_args": [
"interpolation "
]
},
"torch.Tensor.rad2deg": {
"Matcher": "UnchangeMatcher"
},
Expand Down Expand Up @@ -2937,7 +2957,10 @@
"paddle_api": "paddle.Tensor.tanh_"
},
"torch.Tensor.tensor_split": {},
"torch.Tensor.tile": {},
"torch.Tensor.tile": {
"Matcher": "TensorTileMatcher",
"paddle_api": "paddle.Tensor.tile"
},
"torch.Tensor.to": {
"Matcher": "TensorToMatcher"
},
Expand All @@ -2946,7 +2969,14 @@
"paddle_api": "paddle.Tensor.to_dense"
},
"torch.Tensor.to_mkldnn": {},
"torch.Tensor.to_sparse": {},
"torch.Tensor.to_sparse": {
"Matcher": "GenericMatcher",
"min_input_args": 0,
"paddle_api": "paddle.Tensor.to_sparse_coo",
"args_list": [
"sparse_dim"
]
},
"torch.Tensor.tolist": {
"Matcher": "UnchangeMatcher"
},
Expand Down Expand Up @@ -7422,6 +7452,26 @@
"dim"
]
},
"torch.nanquantile": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nanquantile",
"args_list": [
"input",
"q",
"dim",
"keepdim",
"*",
"interpolation",
"out"
],
"kwargs_change": {
"input": "x",
"dim": "axis"
},
"unsupport_args": [
"interpolation "
]
},
"torch.nansum": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.nansum",
Expand Down
31 changes: 31 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,29 @@ def generate_code(self, kwargs):
return code


class TensorTileMatcher(BaseMatcher):
def get_paddle_class_nodes(self, func, args, kwargs):
self.parse_func(func)
kwargs = self.parse_kwargs(kwargs)
if kwargs is None:
return None

if "dims" in kwargs:
kwargs = {"repeat_times": kwargs.pop("dims")}
else:
if len(args) > 1 or (len(args) == 1 and isinstance(args[0], ast.Constant)):
perm = self.parse_args(args)
elif isinstance(args[0], ast.Starred):
perm = astor.to_source(args[0].value).strip("\n")
else:
perm = self.parse_args(args)[0]

kwargs = {"repeat_times": str(perm).replace("'", "")}

code = "{}.tile({})".format(self.paddleClass, self.kwargs_to_str(kwargs))
return ast.parse(code).body


class TensorNew_Matcher(BaseMatcher):
def get_paddle_class_nodes(self, func, args, kwargs):
self.parse_func(func)
Expand Down Expand Up @@ -4151,3 +4174,11 @@ def generate_code(self, kwargs):
return ast.parse(
"paddle.utils.cpp_extension.setup({})".format(self.kwargs_to_str(kwargs))
)


class Is_PinnedMatcher(BaseMatcher):
def generate_code(self, kwargs):

code = f"'pinned' in str({self.paddleClass}.place)"

return code
35 changes: 35 additions & 0 deletions tests/test_Tensor_is_pinned.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.is_pinned")


def test_case_1():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加个GPU的单测,使用if torch.cuda.is_available(): 判断下

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个comment不修改吗

pytorch_code = textwrap.dedent(
"""
import torch
if torch.cuda.is_available():
x = torch.randn(4,4).cuda()
Copy link
Collaborator

@zhwesky2010 zhwesky2010 Nov 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用.pin_memory(),测试更好点

result = x.is_pinned()
else:
x = torch.randn(4,4)
result = x.is_pinned()
"""
)
obj.run(pytorch_code, ["result"])
132 changes: 132 additions & 0 deletions tests/test_Tensor_quantile.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.quantile")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64)
result = x.quantile(0.6)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

第一个参数,指定下关键字吧,这四种情况的用例必须实现:

全部指定关键字、全部不指定关键字、改变关键字顺序、默认参数均不指定

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这些comment不修改吗

"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64)
result = x.quantile(q=0.6)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64)
k = 0.6
result = x.quantile(k)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64)
k = 0.6
result = x.quantile(q=k)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64)
result = x.quantile(0.6, dim=None)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_6():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64)
result = x.quantile(0.6, dim=None, keepdim=False)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_7():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64)
result = x.quantile(q=0.6, dim=None, keepdim=False)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_8():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64)
result = x.quantile(0.6, None, False)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_9():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64)
result = x.quantile(q=0.6, keepdim=False, dim=None)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_10():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]],dtype=torch.float64)
result = x.quantile(0.6, dim=1, keepdim=True)
"""
)
obj.run(pytorch_code, ["result"])
Loading