diff --git a/test/legacy_test/test_zero_dim_no_backward_api.py b/test/legacy_test/test_zero_dim_no_backward_api.py index b3ecbe4849271..998426fe2c71f 100644 --- a/test/legacy_test/test_zero_dim_no_backward_api.py +++ b/test/legacy_test/test_zero_dim_no_backward_api.py @@ -487,37 +487,58 @@ def test_one_hot_label(self): self.assertEqual(res[0].shape, (4,)) self.assertEqual(res[0][2], 1) + @test_with_pir_api def test_unique_consecutive(self): - x = paddle.rand([]) - y, inverse, counts = paddle.unique_consecutive( - x, return_inverse=True, return_counts=True - ) + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.rand([]) + y, inverse, counts = paddle.unique_consecutive( + x, return_inverse=True, return_counts=True + ) - prog = paddle.static.default_main_program() - res = self.exe.run(prog, fetch_list=[y, inverse, counts]) - self.assertEqual(y, x) - self.assertEqual(inverse, 0) - self.assertEqual(counts, 1) - self.assertEqual(res[0].shape, (1,)) - self.assertEqual(res[1].shape, (1,)) - self.assertEqual(res[2].shape, (1,)) + ( + x_res, + y_res, + inverse_res, + counts_res, + ) = paddle.static.Executor().run( + main_program, fetch_list=[x, y, inverse, counts] + ) + self.assertEqual(x_res, y_res) + self.assertEqual(inverse_res, 0) + self.assertEqual(counts_res, 1) + self.assertEqual(y_res.shape, (1,)) + self.assertEqual(inverse_res.shape, (1,)) + self.assertEqual(counts_res.shape, (1,)) + @test_with_pir_api def test_unique(self): - x = paddle.rand([]) - y, index, inverse, counts = paddle.unique( - x, return_index=True, return_inverse=True, return_counts=True - ) + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.rand([]) + y, index, inverse, counts = paddle.unique( + x, return_index=True, return_inverse=True, return_counts=True + ) - prog = paddle.static.default_main_program() - res = self.exe.run(prog, fetch_list=[y, index, inverse, counts]) - self.assertEqual(y, x) - self.assertEqual(index, 0) - self.assertEqual(inverse, 0) - self.assertEqual(counts, 1) - self.assertEqual(res[0].shape, (1,)) - self.assertEqual(res[1].shape, (1,)) - self.assertEqual(res[2].shape, (1,)) - self.assertEqual(res[3].shape, (1,)) + ( + x_res, + y_res, + index_res, + inverse_res, + counts_res, + ) = paddle.static.Executor().run( + main_program, fetch_list=[x, y, index, inverse, counts] + ) + self.assertEqual(x_res, y_res) + self.assertEqual(index_res, 0) + self.assertEqual(inverse_res, 0) + self.assertEqual(counts_res, 1) + self.assertEqual(y_res.shape, (1,)) + self.assertEqual(index_res.shape, (1,)) + self.assertEqual(inverse_res.shape, (1,)) + self.assertEqual(counts_res.shape, (1,)) @test_with_pir_api def test_static_matrix_rank(self):