Skip to content

Commit

Permalink
fix (#60634)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo9674 committed Jan 10, 2024
1 parent 97e4bdf commit 233d3d7
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 32 deletions.
14 changes: 3 additions & 11 deletions python/paddle/static/nn/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,6 @@ def case(pred_fn_pairs, default=None, name=None):
... print(res_1, res_2)
[[1. 1.]] [3 3 3]
'''
helper = LayerHelper('case', **locals())

def _case_check_args(pred_fn_pairs, default):
'''
Expand Down Expand Up @@ -899,16 +898,9 @@ def _case_check_args(pred_fn_pairs, default):
)
pred, fn = pred_fn

if not isinstance(pred, Variable):
raise TypeError(
_error_message(
"The pred's type",
"pred_fn_pairs",
"case",
"boolean Variable",
type(pred),
)
)
check_variable_and_dtype(
pred, 'pred', ['bool'], 'paddle.static.nn.case'
)

if not callable(fn):
raise TypeError(
Expand Down
57 changes: 36 additions & 21 deletions test/legacy_test/test_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
from paddle.base import core
from paddle.base.backward import append_backward
from paddle.base.framework import Program, program_guard
from paddle.pir_utils import test_with_pir_api

paddle.enable_static()


class TestAPICase(unittest.TestCase):
@test_with_pir_api
def test_return_single_var(self):
def fn_1():
return paddle.tensor.fill_constant(
Expand All @@ -43,9 +45,9 @@ def fn_3():
shape=[4, 3], dtype='int32', value=3
)

main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.tensor.fill_constant(
shape=[1], dtype='float32', value=0.3
)
Expand Down Expand Up @@ -100,6 +102,7 @@ def fn_3():
np.testing.assert_allclose(res[3], 2, rtol=1e-05)
np.testing.assert_allclose(res[4], 2, rtol=1e-05)

@test_with_pir_api
def test_0d_tensor(self):
def fn_1():
return paddle.full(shape=[], dtype='int32', fill_value=1)
Expand All @@ -110,9 +113,9 @@ def fn_2():
def fn_3():
return paddle.full(shape=[], dtype='int32', fill_value=3)

main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.full(shape=[], dtype='float32', fill_value=0.3)
y = paddle.full(shape=[], dtype='float32', fill_value=0.1)
z = paddle.full(shape=[], dtype='float32', fill_value=0.2)
Expand Down Expand Up @@ -166,18 +169,20 @@ def fn_3():
np.testing.assert_allclose(res[4], 2, rtol=1e-05)
self.assertEqual(res[4].shape, ())

# Todo(zhangbo): grad_list can not find dx in oir mode
# @test_with_pir_api
def test_0d_tensor_backward(self):
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.full(shape=[], dtype='float32', fill_value=-2.0)
x.stop_gradient = False
pred = paddle.full(shape=[], dtype='bool', fill_value=0)
# pred is False, so out = -x
out = paddle.static.nn.case(
pred_fn_pairs=[(pred, lambda: x)], default=lambda: -x
)
append_backward(out)
grad_list = append_backward(out)

place = (
base.CUDAPlace(0)
Expand All @@ -186,7 +191,14 @@ def test_0d_tensor_backward(self):
)
exe = base.Executor(place)

res = exe.run(main_program, fetch_list=[out.name, x.grad_name])
if paddle.framework.in_pir_mode():
for p, g in grad_list:
if p.is_same(x):
dx = g
res = exe.run(main_program, fetch_list=[out, dx])
else:
res = exe.run(main_program, fetch_list=[out.name, x.grad_name])

np.testing.assert_allclose(
np.asarray(res[0]), np.array(2.0), rtol=1e-05
)
Expand Down Expand Up @@ -252,6 +264,7 @@ def fn_3():

paddle.enable_static()

@test_with_pir_api
def test_return_var_tuple(self):
def fn_1():
return paddle.tensor.fill_constant(
Expand All @@ -269,14 +282,14 @@ def fn_2():

def fn_3():
return paddle.tensor.fill_constant(
shape=[5], dtype='int32', value=5
shape=[5, 6], dtype='int32', value=5
), paddle.tensor.fill_constant(
shape=[5, 6], dtype='float32', value=6
)

main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.tensor.fill_constant(shape=[1], dtype='float32', value=1)
y = paddle.tensor.fill_constant(shape=[1], dtype='float32', value=1)
z = paddle.tensor.fill_constant(shape=[1], dtype='float32', value=3)
Expand Down Expand Up @@ -305,6 +318,7 @@ def fn_3():


class TestAPICase_Nested(unittest.TestCase):
@test_with_pir_api
def test_nested_case(self):
def fn_1(x=1):
var_5 = paddle.tensor.fill_constant(
Expand Down Expand Up @@ -383,9 +397,9 @@ def fn_3():
)
return out

main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.tensor.fill_constant(
shape=[1], dtype='float32', value=0.3
)
Expand Down Expand Up @@ -423,6 +437,7 @@ def fn_3():
np.testing.assert_allclose(res[1], 2, rtol=1e-05)
np.testing.assert_allclose(res[2], 3, rtol=1e-05)

@test_with_pir_api
def test_nested_0d_tensor(self):
def fn_1(x=1):
var_5 = paddle.full(shape=[], dtype='int32', fill_value=5)
Expand Down Expand Up @@ -489,9 +504,9 @@ def fn_3():
)
return out

main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.full(shape=[], dtype='float32', fill_value=0.3)
y = paddle.full(shape=[], dtype='float32', fill_value=0.1)
z = paddle.full(shape=[], dtype='float32', fill_value=0.2)
Expand Down

0 comments on commit 233d3d7

Please sign in to comment.