Skip to content

Commit

Permalink
[PIR] Adaptation of TestNoBackwardAPIStatic.test_unique* (#62794)
Browse files Browse the repository at this point in the history
  • Loading branch information
gouzil committed Mar 17, 2024
1 parent 5128e0d commit a2dd9e4
Showing 1 changed file with 47 additions and 26 deletions.
73 changes: 47 additions & 26 deletions test/legacy_test/test_zero_dim_no_backward_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,37 +487,58 @@ def test_one_hot_label(self):
self.assertEqual(res[0].shape, (4,))
self.assertEqual(res[0][2], 1)

@test_with_pir_api
def test_unique_consecutive(self):
x = paddle.rand([])
y, inverse, counts = paddle.unique_consecutive(
x, return_inverse=True, return_counts=True
)
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.rand([])
y, inverse, counts = paddle.unique_consecutive(
x, return_inverse=True, return_counts=True
)

prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[y, inverse, counts])
self.assertEqual(y, x)
self.assertEqual(inverse, 0)
self.assertEqual(counts, 1)
self.assertEqual(res[0].shape, (1,))
self.assertEqual(res[1].shape, (1,))
self.assertEqual(res[2].shape, (1,))
(
x_res,
y_res,
inverse_res,
counts_res,
) = paddle.static.Executor().run(
main_program, fetch_list=[x, y, inverse, counts]
)
self.assertEqual(x_res, y_res)
self.assertEqual(inverse_res, 0)
self.assertEqual(counts_res, 1)
self.assertEqual(y_res.shape, (1,))
self.assertEqual(inverse_res.shape, (1,))
self.assertEqual(counts_res.shape, (1,))

@test_with_pir_api
def test_unique(self):
x = paddle.rand([])
y, index, inverse, counts = paddle.unique(
x, return_index=True, return_inverse=True, return_counts=True
)
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.rand([])
y, index, inverse, counts = paddle.unique(
x, return_index=True, return_inverse=True, return_counts=True
)

prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[y, index, inverse, counts])
self.assertEqual(y, x)
self.assertEqual(index, 0)
self.assertEqual(inverse, 0)
self.assertEqual(counts, 1)
self.assertEqual(res[0].shape, (1,))
self.assertEqual(res[1].shape, (1,))
self.assertEqual(res[2].shape, (1,))
self.assertEqual(res[3].shape, (1,))
(
x_res,
y_res,
index_res,
inverse_res,
counts_res,
) = paddle.static.Executor().run(
main_program, fetch_list=[x, y, index, inverse, counts]
)
self.assertEqual(x_res, y_res)
self.assertEqual(index_res, 0)
self.assertEqual(inverse_res, 0)
self.assertEqual(counts_res, 1)
self.assertEqual(y_res.shape, (1,))
self.assertEqual(index_res.shape, (1,))
self.assertEqual(inverse_res.shape, (1,))
self.assertEqual(counts_res.shape, (1,))

@test_with_pir_api
def test_static_matrix_rank(self):
Expand Down

0 comments on commit a2dd9e4

Please sign in to comment.