【Hackathon 5th No.17】 Add pdist API to Paddle -part #57869

Merged
merged 26 commits into from
Dec 14, 2023
Changes from 10 commits
Commits
3e19c90
add pdist api
cocoshe Oct 3, 2023
51b908a
move pdist to nn.functional, expose paddle.pdist api
cocoshe Oct 7, 2023
ced239c
clean
cocoshe Oct 7, 2023
9c651d9
clean
cocoshe Oct 7, 2023
13d04ad
fix codestyle
cocoshe Oct 21, 2023
8bcbe47
fix conflict
cocoshe Nov 11, 2023
212bdf7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
cocoshe Nov 24, 2023
7bbc22e
remove compute_mode
cocoshe Nov 24, 2023
20f82de
for api name rules
cocoshe Nov 24, 2023
b06dd06
Merge branch 'develop' into pdist_coco_dev
cocoshe Nov 29, 2023
2e259a6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
cocoshe Nov 30, 2023
e9dfa5a
Update test_pdist.py
cocoshe Nov 30, 2023
cddec3b
Merge branch 'pdist_coco_dev' of https://github.com/cocoshe/Paddle in…
cocoshe Nov 30, 2023
23d5f7e
add seed
cocoshe Nov 30, 2023
3a9fbfb
fix code sample
cocoshe Dec 1, 2023
e41f9dc
Merge branch 'pdist_coco_dev' of https://github.com/cocoshe/Paddle in…
cocoshe Dec 5, 2023
008ae5b
fix doc
cocoshe Dec 5, 2023
114453a
Update distance.py
cocoshe Dec 5, 2023
171339d
gpu0 to cpu in api doc
cocoshe Dec 9, 2023
8a281b4
gpu0 to cpu in api doc
cocoshe Dec 9, 2023
a9d053c
remove pdist in nn.functional __all__ list
cocoshe Dec 12, 2023
b3816df
Merge branch 'develop' into pdist_coco_dev
cocoshe Dec 12, 2023
3f7b607
Update __init__.py
cocoshe Dec 12, 2023
e3e5abf
fix en doc
cocoshe Dec 13, 2023
8172f28
Update python/paddle/nn/functional/distance.py
cocoshe Dec 13, 2023
0146a37
Update python/paddle/nn/functional/distance.py
cocoshe Dec 13, 2023
5 changes: 5 additions & 0 deletions python/paddle/__init__.py
@@ -533,6 +533,10 @@
    flops,
)

from .nn.functional.distance import (  # noqa: F401
    pdist,
)

import paddle.text  # noqa: F401
import paddle.vision  # noqa: F401

@@ -711,6 +715,7 @@
    'sin_',
    'dist',
    'cdist',
    'pdist',
    'unbind',
    'meshgrid',
    'arange',
4 changes: 3 additions & 1 deletion python/paddle/nn/functional/__init__.py
@@ -78,7 +78,8 @@
    conv3d,
    conv3d_transpose,
)
from .distance import pairwise_distance

from .distance import pairwise_distance, pdist
from .extension import (
    diag_embed,  # noqa: F401
    gather_tree,
@@ -162,6 +163,7 @@
    'conv3d',
    'conv3d_transpose',
    'pairwise_distance',
    'pdist',
Contributor

Which path do we recommend users to use? If paddle.pdist is recommended, it cannot be added to this __all__ list. If paddle.nn.functional.pdist is recommended, it cannot be added to the __all__ list in python/paddle/__init__.py.

Contributor Author

Thanks for your review. I prefer the paddle.pdist path, so I have removed this line.
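
For readers unfamiliar with why the `__all__` list matters here, the following is a minimal sketch of the general Python behaviour: a name left out of `__all__` is not picked up by a star import, yet it stays reachable by explicit attribute access. The module name `demo_functional` and its two functions are hypothetical stand-ins, not Paddle code, and the sketch says nothing about how Paddle's own tooling reads `__all__`.

```python
import sys
import types

# Hypothetical stand-in module; "demo_functional" is illustrative only.
mod = types.ModuleType("demo_functional")
exec(
    "def pairwise_distance():\n"
    "    return 'pairwise_distance'\n"
    "def pdist():\n"
    "    return 'pdist'\n"
    "__all__ = ['pairwise_distance']\n",
    mod.__dict__,
)
sys.modules["demo_functional"] = mod

# Run a star import into a fresh namespace and see which names arrive.
namespace = {}
exec("from demo_functional import *", namespace)
star_imported = sorted(k for k in namespace if not k.startswith("__"))

print(star_imported)   # only the names listed in __all__
print(mod.pdist())     # still callable through the module object
```

In this sketch, dropping a name from `__all__` changes what the module advertises as public without making the function unavailable, which matches the choice discussed above of keeping a single recommended path.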

    'elu',
    'elu_',
    'gelu',
37 changes: 37 additions & 0 deletions python/paddle/nn/functional/distance.py
@@ -106,3 +106,40 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None):
    )

    return out


def pdist(x, p=2.0, name=None):
    r'''
    Computes the p-norm distance between every pair of row vectors in the input.

    Args:
        x (Tensor): A tensor with shape :math:`N \times M`.
        p (float, optional): The value for the p-norm distance to calculate between each vector pair. Default: :math:`2.0`.
        name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.

    Returns:
        Tensor with shape :math:`N(N-1)/2`. The dtype is the same as the input tensor.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> a = paddle.randn([4, 5])
Contributor

@zxcd Nov 30, 2023

Set a seed in the doc example; otherwise PR-CI-Static-Check will not pass. For reference:

>>> paddle.seed(2023)

Contributor Author

I have added it now, but it does not seem to have any effect?

Contributor

After adding the seed, your printed results will change as well. For the output in this example, you can refer to the content of the error message.

Contributor Author

Got it, I will fix it shortly.
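
The point of the thread above is that a documentation example with random input can only be checked against fixed printed output if the generator is seeded. A minimal sketch of that idea follows, using Python's standard `random` module as a stand-in for `paddle.seed` and `paddle.randn`; the helper name `sample_matrix` is hypothetical.

```python
import random


def sample_matrix(seed, rows=2, cols=3):
    # Stand-in for seeding and then drawing a random tensor: a private
    # generator is seeded, then rows * cols Gaussian values are drawn.
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


first = sample_matrix(2023)
second = sample_matrix(2023)
other = sample_matrix(2024)

print(first == second)  # compares two draws made with the same seed
print(first == other)   # compares draws made with different seeds
```

Under these assumptions, the same seed reproduces the same values, so the output shown in a doc example has to be regenerated whenever a seed is added or changed.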

            >>> a
            Tensor(shape=[4, 5], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            [[-0.33173719, -0.93648648, -0.01741328, -0.94435263,  2.22178721],
             [-0.65466857,  0.10307083,  0.08741203, -0.91078597,  0.72589827],
             [ 0.06907391, -0.27584535,  1.35355449, -0.69688839,  0.18408430],
             [-0.00939178, -0.32901841, -1.06503606,  0.81856263,  0.16791444]])
            >>> pdist_out = paddle.pdist(a)
            >>> print(pdist_out)
            Tensor(shape=[6], dtype=float32, place=Place(gpu:0), stop_gradient=True,
            [1.85331142, 2.58652687, 2.98273396, 1.61549115, 2.28762150, 2.85576940])

    '''

    x_shape = list(x.shape)
    assert len(x_shape) == 2, "The x must be 2-dimensional"
    d = paddle.linalg.norm(x[..., None, :] - x[..., None, :, :], p=p, axis=-1)
    mask = ~paddle.tril(paddle.ones(d.shape, dtype='bool'))
    return paddle.masked_select(d, mask)
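
To make the output layout concrete, here is a pure-Python sketch of the same computation: it visits every pair (i, j) with i < j in row-major order, which is the order the strict upper-triangle mask selects, and yields N(N-1)/2 values. The function name `pdist_reference` is hypothetical, and the handling of p = 0 (count of nonzero differences) and p = inf (maximum absolute difference) follows the usual vector-norm conventions rather than anything stated in this PR.

```python
import math


def pdist_reference(rows, p=2.0):
    """Pure-Python sketch: p-norm distance for every row pair (i, j), i < j."""
    n = len(rows)
    out = []
    for i in range(n):
        for j in range(i + 1, n):  # strict upper triangle, row-major order
            diffs = [abs(a - b) for a, b in zip(rows[i], rows[j])]
            if p == 0:
                dist = float(sum(d != 0 for d in diffs))  # count of nonzeros
            elif p == float("inf"):
                dist = max(diffs)  # largest absolute difference
            else:
                dist = sum(d ** p for d in diffs) ** (1.0 / p)
            out.append(dist)
    return out


points = [[0.0, 0.0], [3.0, 4.0], [6.0, 8.0]]
result = pdist_reference(points)
print(len(result), result)
```

For three rows the result has 3 * 2 / 2 = 3 entries, ordered as the pairs (0, 1), (0, 2), (1, 2).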
151 changes: 151 additions & 0 deletions test/legacy_test/test_pdist.py
@@ -0,0 +1,151 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle


def ref_pdist(x, p=2.0):
    dist = np.linalg.norm(x[..., None, :] - x[None, :, :], ord=p, axis=-1)
    res = []
    rows, cols = dist.shape
    for i in range(rows):
        for j in range(cols):
            if i >= j:
                continue
            res.append(dist[i][j])
    return np.array(res)


class TestPdistAPI(unittest.TestCase):
    def setUp(self):
        self.x = np.random.rand(10, 20).astype('float32')
        self.p = 2.0
        self.init_input()
        self.place = (
            paddle.CUDAPlace(0)
            if paddle.is_compiled_with_cuda()
            else paddle.CPUPlace()
        )

    def init_input(self):
        pass

    def test_static_api(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data('x', self.x.shape, dtype=self.x.dtype)
            out = paddle.pdist(
                x,
                self.p,
            )
            exe = paddle.static.Executor(self.place)
            res = exe.run(feed={'x': self.x}, fetch_list=[out])
        out_ref = ref_pdist(self.x, self.p)
        np.testing.assert_allclose(out_ref, res[0], rtol=1e-5, atol=1e-5)
Contributor

rtol and atol should not be needed here, should they?

Contributor Author

>>> x = np.random.rand(3, 4).astype('float64')
>>> t_x = paddle.to_tensor(x)
>>> x == t_x.numpy()
array([[ True,  True,  True,  True],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True]])
>>> np_norm = np.linalg.norm(x, axis=-1)
>>> pd_norm = paddle.linalg.norm(t_x, axis=-1)
>>> np_norm == pd_norm.numpy()
array([ True,  True,  True])


>>> x = np.random.rand(5, 6).astype('float64')
>>> t_x = paddle.to_tensor(x)
>>> np_norm = np.linalg.norm(x, axis=-1)
>>> pd_norm = paddle.linalg.norm(t_x, axis=-1)
>>> np_norm == pd_norm.numpy()
array([ True,  True,  True,  True,  True])


>>> x = np.random.rand(5, 6).astype('float64')
>>> t_x = paddle.to_tensor(x)
>>> np_norm = np.linalg.norm(x, axis=-1)
>>> pd_norm = paddle.linalg.norm(t_x, axis=-1)
>>> np_norm == pd_norm.numpy()
array([ True,  True,  True,  True, False])
>>> 
>>> x = np.random.rand(5, 6).astype('float64')
>>> t_x = paddle.to_tensor(x)
>>> np_norm = np.linalg.norm(x, axis=-1)
>>> pd_norm = paddle.linalg.norm(t_x, axis=-1)
>>> np_norm == pd_norm.numpy()
array([ True,  True,  True,  True, False])
>>> 
>>> x = np.random.rand(5, 6).astype('float64')
>>> t_x = paddle.to_tensor(x)
>>> np_norm = np.linalg.norm(x, axis=-1)
>>> pd_norm = paddle.linalg.norm(t_x, axis=-1)
>>> np_norm == pd_norm.numpy()
array([ True,  True,  True,  True, False])
>>> 
>>> x = np.random.rand(5, 6).astype('float64')
>>> t_x = paddle.to_tensor(x)
>>> np_norm = np.linalg.norm(x, axis=-1)
>>> pd_norm = paddle.linalg.norm(t_x, axis=-1)
>>> np_norm == pd_norm.numpy()
array([False, False, False,  True,  True])
>>> 
>>> x = np.random.rand(5, 6).astype('float64')
>>> t_x = paddle.to_tensor(x)
>>> np_norm = np.linalg.norm(x, axis=-1)
>>> pd_norm = paddle.linalg.norm(t_x, axis=-1)
>>> np_norm == pd_norm.numpy()
array([ True,  True, False, False,  True])
>>> 

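Without rtol and atol the precision check fails. I ran some experiments and found that norm does not appear to be aligned with numpy. The cdist unit test also relaxes the tolerance: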

np.testing.assert_allclose(out_ref, res[0], rtol=1e-5, atol=1e-5)
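
One plausible reason for last-bit mismatches like those in the transcript above is that floating-point addition is not associative, so two correct implementations that accumulate in a different order can disagree in the final bits. The sketch below illustrates this with plain Python floats; it is an assumption about the kind of effect involved, not a diagnosis of what paddle.linalg.norm does internally.

```python
import math

values = [0.1, 0.2, 0.3]

# Same three numbers, two accumulation orders.
left_to_right = (values[0] + values[1]) + values[2]
right_to_left = values[0] + (values[1] + values[2])

exact_match = left_to_right == right_to_left
close_match = math.isclose(
    left_to_right, right_to_left, rel_tol=1e-5, abs_tol=1e-5
)

print(exact_match)  # bitwise equality of the two sums
print(close_match)  # equality within the tolerances used in the test
```

Here the exact comparison fails while the tolerance-based one passes, which is the situation that comparing with rtol and atol is meant to handle.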


    def test_dygraph_api(self):
        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.x)
        out = paddle.pdist(
            x,
            self.p,
        )
        out_ref = ref_pdist(self.x, self.p)
        np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-5, atol=1e-5)
        paddle.enable_static()


class TestPdistAPICase1(TestPdistAPI):
    def init_input(self):
        self.p = 0


class TestPdistAPICase2(TestPdistAPI):
    def init_input(self):
        self.p = 1.0


class TestPdistAPICase3(TestPdistAPI):
    def init_input(self):
        self.p = 3.0


class TestPdistAPICase4(TestPdistAPI):
    def init_input(self):
        self.p = 1.5


class TestPdistAPICase5(TestPdistAPI):
    def init_input(self):
        self.p = 2.5


class TestPdistAPICase6(TestPdistAPI):
    def init_input(self):
        self.p = float('inf')


class TestPdistAPICase7(TestPdistAPI):
    def init_input(self):
        self.x = np.random.rand(50, 20).astype('float64')


class TestPdistAPICase8(TestPdistAPI):
Contributor

Consider giving the test cases meaningful names.

    def init_input(self):
        self.x = np.random.rand(500, 100).astype('float64')

    def test_static_api(self):
        paddle.enable_static()
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data('x', self.x.shape, dtype=self.x.dtype)
            out0 = paddle.pdist(
                x,
                self.p,
            )
            exe = paddle.static.Executor(self.place)
            res = exe.run(feed={'x': self.x}, fetch_list=[out0])
        out_ref = ref_pdist(self.x, self.p)
        np.testing.assert_allclose(out_ref, res[0])

    def test_dygraph_api(self):
        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.x)
        out0 = paddle.pdist(
            x,
            self.p,
        )
        out_ref = ref_pdist(self.x, self.p)
        np.testing.assert_allclose(out_ref, out0.numpy())
        paddle.enable_static()


class TestPdistShapeError(unittest.TestCase):
    def test_error(self):
        with self.assertRaises(AssertionError):
            self.x = np.random.rand(50, 10, 20).astype('float64')
            self.p = 2.0
            x = paddle.to_tensor(self.x)
            out0 = paddle.pdist(
                x,
                self.p,
            )


if __name__ == '__main__':
    paddle.enable_static()
    unittest.main()