Skip to content

Commit

Permalink
【PIR API adaptor No.145】masked_select (PaddlePaddle#58929)
Browse files Browse the repository at this point in the history
  • Loading branch information
cocoshe authored and SecretXV committed Nov 28, 2023
1 parent 4c4a737 commit 2925239
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
6 changes: 2 additions & 4 deletions python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,12 +913,9 @@ def masked_select(x, mask, name=None):
>>> print(out.numpy())
[1. 5. 6. 9.]
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.masked_select(x, mask)

else:
helper = LayerHelper("masked_select", **locals())
check_variable_and_dtype(
x,
'x',
Expand All @@ -928,6 +925,7 @@ def masked_select(x, mask, name=None):
check_variable_and_dtype(
mask, 'mask', ['bool'], 'paddle.tensor.search.masked_select'
)
helper = LayerHelper("masked_select", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='masked_select',
Expand Down
16 changes: 10 additions & 6 deletions test/legacy_test/test_masked_select_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import paddle
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


def np_masked_select(x, mask):
Expand All @@ -42,10 +43,10 @@ def setUp(self):
self.outputs = {'Y': out}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Y')
self.check_grad(['X'], 'Y', check_pir=True)

def init(self):
self.shape = (50, 3)
Expand Down Expand Up @@ -77,10 +78,10 @@ def setUp(self):
self.outputs = {'Y': out}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Y')
self.check_grad(['X'], 'Y', check_pir=True)

def init(self):
self.shape = (50, 3)
Expand Down Expand Up @@ -114,10 +115,12 @@ def setUp(self):
self.outputs = {'Y': convert_float_to_uint16(out)}

def test_check_output(self):
self.check_output_with_place(core.CUDAPlace(0))
self.check_output_with_place(core.CUDAPlace(0), check_pir=True)

def test_check_grad(self):
self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Y')
self.check_grad_with_place(
core.CUDAPlace(0), ['X'], 'Y', check_pir=True
)

def init(self):
self.shape = (50, 3)
Expand Down Expand Up @@ -146,6 +149,7 @@ def test_imperative_mode(self):
np.testing.assert_allclose(out.numpy(), np_out, rtol=1e-05)
paddle.enable_static()

@test_with_pir_api
def test_static_mode(self):
shape = [8, 9, 6]
x = paddle.static.data(shape=shape, dtype='float32', name='x')
Expand Down

0 comments on commit 2925239

Please sign in to comment.