Skip to content

Commit

Permalink
fix codestyle
Browse files Browse the repository at this point in the history
  • Loading branch information
cocoshe committed Oct 21, 2023
1 parent 9c651d9 commit 13d04ad
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 19 deletions.
16 changes: 6 additions & 10 deletions python/paddle/nn/functional/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,13 @@ def pairwise_distance(x, y, p=2.0, epsilon=1e-6, keepdim=False, name=None):

return out


def pdist(
x, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary", name=None
):
r'''
Computes the p-norm distance between every pair of row vectors in the input.
Computes the p-norm distance between every pair of row vectors in the input.
Args:
x (Tensor): A tensor with shape :math:`N \times M`.
p (float, optional): The value for the p-norm distance to calculate between each vector pair. Default: :math:`2.0`.
Expand All @@ -132,7 +133,7 @@ def pdist(
Examples:
.. code-block:: python
>>> import paddle
>>> a = paddle.randn([4, 5])
>>> a
Expand All @@ -145,16 +146,11 @@ def pdist(
>>> pdist_out
Tensor(shape=[6], dtype=float32, place=Place(gpu:0), stop_gradient=True,
[1.85331142, 2.58652687, 2.98273396, 1.61549115, 2.28762150, 2.85576940])
'''

x_shape = list(x.shape)
assert len(x_shape) == 2, (
"The x must be 2-dimensional"
)
assert len(x_shape) == 2, "The x must be 2-dimensional"
d = paddle.cdist(x, x, p, compute_mode)
mask = ~paddle.tril(paddle.ones(d.shape, dtype='bool'))
return paddle.masked_select(d, mask)



15 changes: 6 additions & 9 deletions test/legacy_test/test_pdist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,20 @@
# limitations under the License.

import unittest

import numpy as np

import paddle


def ref_pdist(x, p=2.0):
dist = np.linalg.norm(x[..., None, :] - x[None, :, :], ord=p, axis=-1)
res = []
rows, cols = dist.shape
for i in range(rows):
for j in range(cols):
if i >= j: continue
if i >= j:
continue
res.append(dist[i][j])
return np.array(res)

Expand Down Expand Up @@ -107,7 +111,6 @@ class TestpdistAPICase9(TestpdistAPI):
def init_input(self):
self.x = np.random.rand(500, 100).astype('float64')


def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
Expand All @@ -116,9 +119,7 @@ def test_static_api(self):
out1 = paddle.pdist(x, self.p, "donot_use_mm_for_euclid_dist")
out2 = paddle.pdist(x, self.p, "use_mm_for_euclid_dist")
exe = paddle.static.Executor(self.place)
res = exe.run(
feed={'x': self.x}, fetch_list=[out0, out1, out2]
)
res = exe.run(feed={'x': self.x}, fetch_list=[out0, out1, out2])
out_ref = ref_pdist(self.x, self.p)
np.testing.assert_allclose(out_ref, res[0])
np.testing.assert_allclose(out_ref, res[1])
Expand All @@ -140,7 +141,3 @@ def test_dygraph_api(self):
if __name__ == '__main__':
paddle.enable_static()
unittest.main()




0 comments on commit 13d04ad

Please sign in to comment.