diff --git a/paconvert/api_mapping.json b/paconvert/api_mapping.json index 5ee4b58cf..ab7693276 100644 --- a/paconvert/api_mapping.json +++ b/paconvert/api_mapping.json @@ -1473,7 +1473,10 @@ }, "torch.Tensor.is_inference": {}, "torch.Tensor.is_meta": {}, - "torch.Tensor.is_pinned": {}, + "torch.Tensor.is_pinned": { + "Matcher": "Is_PinnedMatcher", + "min_input_args": 0 + }, "torch.Tensor.is_quantized": {}, "torch.Tensor.is_set_to": {}, "torch.Tensor.is_shared": {}, @@ -2321,7 +2324,8 @@ }, "torch.Tensor.pin_memory": { "Matcher": "GenericMatcher", - "paddle_api": "paddle.Tensor.pin_memory" + "paddle_api": "paddle.Tensor.pin_memory", + "min_input_args": 0 }, "torch.Tensor.pinverse": { "Matcher": "TensorFunc2PaddleFunc", @@ -2395,7 +2399,25 @@ ] }, "torch.Tensor.qscheme": {}, - "torch.Tensor.quantile": {}, + "torch.Tensor.quantile": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.Tensor.quantile", + "min_input_args": 1, + "args_list": [ + "q", + "dim", + "keepdim", + "*", + "interpolation", + "out" + ], + "kwargs_change": { + "dim": "axis" + }, + "unsupport_args": [ + "interpolation " + ] + }, "torch.Tensor.rad2deg": { "Matcher": "UnchangeMatcher" }, @@ -2937,7 +2959,10 @@ "paddle_api": "paddle.Tensor.tanh_" }, "torch.Tensor.tensor_split": {}, - "torch.Tensor.tile": {}, + "torch.Tensor.tile": { + "Matcher": "TensorTileMatcher", + "paddle_api": "paddle.Tensor.tile" + }, "torch.Tensor.to": { "Matcher": "TensorToMatcher" }, @@ -2946,7 +2971,14 @@ "paddle_api": "paddle.Tensor.to_dense" }, "torch.Tensor.to_mkldnn": {}, - "torch.Tensor.to_sparse": {}, + "torch.Tensor.to_sparse": { + "Matcher": "GenericMatcher", + "min_input_args": 0, + "paddle_api": "paddle.Tensor.to_sparse_coo", + "args_list": [ + "sparse_dim" + ] + }, "torch.Tensor.tolist": { "Matcher": "UnchangeMatcher" }, @@ -7422,6 +7454,26 @@ "dim" ] }, + "torch.nanquantile": { + "Matcher": "GenericMatcher", + "paddle_api": "paddle.nanquantile", + "args_list": [ + "input", + "q", + "dim", + "keepdim", + "*", + "interpolation", + "out" + ], + "kwargs_change": { + "input": "x", + "dim": "axis" + }, + "unsupport_args": [ + "interpolation " + ] + }, "torch.nansum": { "Matcher": "GenericMatcher", "paddle_api": "paddle.nansum", diff --git a/paconvert/api_matcher.py b/paconvert/api_matcher.py index 81c8a94ca..5af785028 100644 --- a/paconvert/api_matcher.py +++ b/paconvert/api_matcher.py @@ -1121,6 +1121,29 @@ def generate_code(self, kwargs): return code +class TensorTileMatcher(BaseMatcher): + def get_paddle_class_nodes(self, func, args, kwargs): + self.parse_func(func) + kwargs = self.parse_kwargs(kwargs) + if kwargs is None: + return None + + if "dims" in kwargs: + kwargs = {"repeat_times": kwargs.pop("dims")} + else: + if len(args) > 1 or (len(args) == 1 and isinstance(args[0], ast.Constant)): + perm = self.parse_args(args) + elif isinstance(args[0], ast.Starred): + perm = astor.to_source(args[0].value).strip("\n") + else: + perm = self.parse_args(args)[0] + + kwargs = {"repeat_times": str(perm).replace("'", "")} + + code = "{}.tile({})".format(self.paddleClass, self.kwargs_to_str(kwargs)) + return ast.parse(code).body + + class TensorNew_Matcher(BaseMatcher): def get_paddle_class_nodes(self, func, args, kwargs): self.parse_func(func) @@ -4151,3 +4174,11 @@ def generate_code(self, kwargs): return ast.parse( "paddle.utils.cpp_extension.setup({})".format(self.kwargs_to_str(kwargs)) ) + + +class Is_PinnedMatcher(BaseMatcher): + def generate_code(self, kwargs): + + code = f"'pinned' in str({self.paddleClass}.place)" + + return code diff --git a/tests/test_Tensor_is_pinned.py b/tests/test_Tensor_is_pinned.py new file mode 100644 index 000000000..73f43e4ec --- /dev/null +++ b/tests/test_Tensor_is_pinned.py @@ -0,0 +1,34 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.is_pinned") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + if torch.cuda.is_available(): + x = torch.randn(4,4).pin_memory() + else: + x = torch.randn(4,4) + result = x.is_pinned() + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_quantile.py b/tests/test_Tensor_quantile.py new file mode 100644 index 000000000..b0ab9c358 --- /dev/null +++ b/tests/test_Tensor_quantile.py @@ -0,0 +1,132 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.quantile") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64) + result = x.quantile(0.6) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64) + result = x.quantile(q=0.6) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64) + k = 0.6 + result = x.quantile(k) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64) + k = 0.6 + result = x.quantile(q=k) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64) + result = x.quantile(0.6, dim=None) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64) + result = x.quantile(0.6, dim=None, keepdim=False) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64) + result = x.quantile(q=0.6, dim=None, keepdim=False) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64) + result = x.quantile(0.6, None, False) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([0., 1., 2., 3.],dtype=torch.float64) + result = x.quantile(q=0.6, keepdim=False, dim=None) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_10(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]],dtype=torch.float64) + result = x.quantile(0.6, dim=1, keepdim=True) + """ + ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_tile.py b/tests/test_Tensor_tile.py index af524bffe..e8f7de589 100644 --- a/tests/test_Tensor_tile.py +++ b/tests/test_Tensor_tile.py @@ -23,13 +23,101 @@ def test_case_1(): pytorch_code = textwrap.dedent( """ import torch - a = torch.Tensor([[1.,2.], [3.,4.]]) - result = a.tile((1,)) + x = torch.tensor([1., 2., 3., 4.]) + result = x.tile(1) """ ) - obj.run( - pytorch_code, - ["result"], - unsupport=True, - reason="Paddle not support this api convert now", + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[1., 2.], [ 3., 4.]]) + result = x.tile(2, 1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.Tensor([[1., 2.], [3., 4.]]) + result = x.tile((2, 1)) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.Tensor([[1., 2.], [3., 4.]]) + result = x.tile([2, 1]) + """ + ) + obj.run(pytorch_code, ["result"]) + + +# the only corner case, input a variable which is Constant, has no solution +def _test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.Tensor([1., 2., 3., 4.]) + dims = 1 + result = x.tile(dims) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.Tensor([[1., 2.], [3., 4.]]) + dims = (2, 1) + result = x.tile(dims) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.Tensor([[1., 2.], [3., 4.]]) + dims = (2, 1) + result = x.tile(*dims) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.Tensor([[1., 2.], [3., 4.]]) + dims = (2, 1) + result = x.tile(dims=dims) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.Tensor([[1., 2.], [3., 4.]]) + result = x.tile(dims=(2, 1)) + """ ) + obj.run(pytorch_code, ["result"]) diff --git a/tests/test_Tensor_to_sparse.py b/tests/test_Tensor_to_sparse.py new file mode 100644 index 000000000..62b356274 --- /dev/null +++ b/tests/test_Tensor_to_sparse.py @@ -0,0 +1,49 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.Tensor.to_sparse") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.Tensor([[1.,2.], [3.,4.]]) + b = a.to_sparse(1) + result = b.to_dense() + """ + ) + obj.run( + pytorch_code, + ["result"], + ) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + a = torch.Tensor([[1.,2.], [3.,4.]]) + b = a.to_sparse(sparse_dim = 1) + result = b.to_dense() + """ + ) + obj.run( + pytorch_code, + ["result"], + ) diff --git a/tests/test_nanquantile.py b/tests/test_nanquantile.py new file mode 100644 index 000000000..0c1c30c13 --- /dev/null +++ b/tests/test_nanquantile.py @@ -0,0 +1,125 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import textwrap + +from apibase import APIBase + +obj = APIBase("torch.nanquantile") + + +def test_case_1(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([1.02, 2.21, 3.333, 30], dtype=torch.float64) + result = torch.nanquantile(x, 0.5) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_2(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([float('nan'), 2.21, 3.333, 30], dtype=torch.float64) + result = torch.nanquantile(x, 0.6, dim=0) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_3(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([float('nan'), 1.02, 2.21, 3.333,30, float('nan')], dtype=torch.float64) + result = torch.nanquantile(x, q=0.3, dim=-1) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_4(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[float('nan'), 1.02, 2.21, 3.333,30, float('nan')]], dtype=torch.float64) + result = torch.nanquantile(x, q=0.3, dim=1, keepdim=True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_5(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[float('nan'), 1.02, 2.21, 3.333,30, float('nan')]], dtype=torch.float64) + result = torch.nanquantile(x, q=0.3, dim=1, keepdim=False) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[float('nan'), 1.02, 2.21, 3.333,30, float('nan')]], dtype=torch.float64) + result = torch.nanquantile(input=x, q=0.3, dim=1, keepdim=False) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0]], dtype=torch.float64) + result = torch.nanquantile(x, 0.3, 1, True) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_8(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([float('nan'), 1.02, 2.21, 3.333,30, float('nan')], dtype=torch.float64) + result = torch.tensor([1], dtype=torch.float64) + torch.nanquantile(x, 0.3, -1, True, out=result) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_9(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[0]], dtype=torch.float64) + result = torch.nanquantile(x=x, q=0.3, dim=1, keepdim=True, interpolation='higher') + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle not support this parameter", + ) diff --git a/tests/test_quantile.py b/tests/test_quantile.py index 16a8c9a4b..d334f2981 100644 --- a/tests/test_quantile.py +++ b/tests/test_quantile.py @@ -68,3 +68,30 @@ def test_case_5(): """ ) obj.run(pytorch_code, ["result"]) + + +def test_case_6(): + pytorch_code = textwrap.dedent( + """ + import torch + result = torch.tensor([], dtype=torch.float64) + torch.quantile(torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]],dtype=torch.float64), 0.6, dim=1, keepdim=True, out=result) + """ + ) + obj.run(pytorch_code, ["result"]) + + +def test_case_7(): + pytorch_code = textwrap.dedent( + """ + import torch + x = torch.tensor([[ 0.0795, -1.2117, 0.9765], [ 1.1707, 0.6706, 0.4884]], dtype=torch.float64) + result = torch.quantile(x=x, q=0.3, dim=1, keepdim=True, interpolation='higher') + """ + ) + obj.run( + pytorch_code, + ["result"], + unsupport=True, + reason="Paddle not support this parameter", + )