From 9787b8439137ffbcc838fc928898bf73615c0754 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Tue, 14 Jun 2022 03:09:15 +0000 Subject: [PATCH] Add if/elif UT --- .../dygraph_to_static/ifelse_simple_func.py | 16 +++++++++++++++- .../dygraph_to_static/test_program_translator.py | 10 ++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py index 0b600979bea86..60ab0a7f4f5fd 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/ifelse_simple_func.py @@ -106,10 +106,24 @@ def dyfunc_with_if_else_early_return1(): a = paddle.zeros([2, 2]) b = paddle.zeros([3, 3]) return a, b - a = paddle.ones([2, 2]) + a = paddle.zeros([2, 2]) + 1 return a +def dyfunc_with_if_else_early_return2(): + x = paddle.to_tensor([10]) + if x == 0: + a = paddle.zeros([2, 2]) + b = paddle.zeros([3, 3]) + return a, b + elif x == 1: + c = paddle.zeros([2, 2]) + 1 + d = paddle.zeros([2, 2]) + 1 + return c, d + e = paddle.zeros([2, 2]) + 3 + return e + + def dyfunc_with_if_else_with_list_geneator(x): if 10 > 5: y = paddle.add_n( diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py index df252652896b3..cf8be6640300e 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py @@ -29,7 +29,7 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code import paddle.jit.dy2static as _jst -from ifelse_simple_func import dyfunc_with_if_else, dyfunc_with_if_else_early_return1 +from ifelse_simple_func import dyfunc_with_if_else, dyfunc_with_if_else_early_return1, dyfunc_with_if_else_early_return2 np.random.seed(0) @@ -337,11 +337,17 @@ def test_raise_error(self): class TestIfElseEarlyReturn(unittest.TestCase): def test_ifelse_early_return1(self): - answer = np.ones([2, 2]) + answer = np.zeros([2, 2]) + 1 static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return1) out = static_func() self.assertTrue(np.allclose(answer, out.numpy())) + def test_ifelse_early_return2(self): + answer = np.zeros([2, 2]) + 3 + static_func = paddle.jit.to_static(dyfunc_with_if_else_early_return2) + out = static_func() + self.assertTrue(np.allclose(answer, out.numpy())) + class TestRemoveCommentInDy2St(unittest.TestCase):