From 2925239aae085d0760bd05b4413b2b35a314b8e3 Mon Sep 17 00:00:00 2001 From: coco <69197635+cocoshe@users.noreply.github.com> Date: Wed, 15 Nov 2023 11:29:47 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PIR=20API=20adaptor=20No.145=E3=80=91m?= =?UTF-8?q?asked=5Fselect=20(#58929)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/paddle/tensor/search.py | 6 ++---- test/legacy_test/test_masked_select_op.py | 16 ++++++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/python/paddle/tensor/search.py b/python/paddle/tensor/search.py index 8e944d7ca0055..0a9c6f8b74ee6 100755 --- a/python/paddle/tensor/search.py +++ b/python/paddle/tensor/search.py @@ -913,12 +913,9 @@ def masked_select(x, mask, name=None): >>> print(out.numpy()) [1. 5. 6. 9.] """ - - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.masked_select(x, mask) - else: - helper = LayerHelper("masked_select", **locals()) check_variable_and_dtype( x, 'x', @@ -928,6 +925,7 @@ def masked_select(x, mask, name=None): check_variable_and_dtype( mask, 'mask', ['bool'], 'paddle.tensor.search.masked_select' ) + helper = LayerHelper("masked_select", **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op( type='masked_select', diff --git a/test/legacy_test/test_masked_select_op.py b/test/legacy_test/test_masked_select_op.py index 954f3ffd1d9b6..48325d9bc283b 100644 --- a/test/legacy_test/test_masked_select_op.py +++ b/test/legacy_test/test_masked_select_op.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core +from paddle.pir_utils import test_with_pir_api def np_masked_select(x, mask): @@ -42,10 +43,10 @@ def setUp(self): self.outputs = {'Y': out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Y') + self.check_grad(['X'], 'Y', check_pir=True) def init(self): self.shape = (50, 3) @@ -77,10 +78,10 @@ def setUp(self): self.outputs = {'Y': out} def test_check_output(self): - self.check_output() + self.check_output(check_pir=True) def test_check_grad(self): - self.check_grad(['X'], 'Y') + self.check_grad(['X'], 'Y', check_pir=True) def init(self): self.shape = (50, 3) @@ -114,10 +115,12 @@ def setUp(self): self.outputs = {'Y': convert_float_to_uint16(out)} def test_check_output(self): - self.check_output_with_place(core.CUDAPlace(0)) + self.check_output_with_place(core.CUDAPlace(0), check_pir=True) def test_check_grad(self): - self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Y') + self.check_grad_with_place( + core.CUDAPlace(0), ['X'], 'Y', check_pir=True + ) def init(self): self.shape = (50, 3) @@ -146,6 +149,7 @@ def test_imperative_mode(self): np.testing.assert_allclose(out.numpy(), np_out, rtol=1e-05) paddle.enable_static() + @test_with_pir_api def test_static_mode(self): shape = [8, 9, 6] x = paddle.static.data(shape=shape, dtype='float32', name='x')