Skip to content

Commit

Permalink
[Dy2Stat]Enhance nonlocal machanism while nonlocal vars is empty (#43848
Browse files Browse the repository at this point in the history
)
  • Loading branch information
Aurelius84 committed Jun 27, 2022
1 parent e6e1c5e commit 40a7731
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,8 @@ def transform_if_else(node, root):
if ARGS_NAME in nonlocal_names:
nonlocal_names.remove(ARGS_NAME)

nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names)]
nonlocal_stmt_node = [create_nonlocal_stmt_node(nonlocal_names)
] if nonlocal_names else []

empty_arg_node = gast.arguments(args=[],
posonlyargs=[],
Expand Down Expand Up @@ -557,8 +558,20 @@ def create_get_args_node(names):
def get_args_0():
nonlocal x, y
return x, y
"""

def empty_node():
func_def = """
def {func_name}():
return
""".format(func_name=unique_name.generate(GET_ARGS_FUNC_PREFIX))
return gast.parse(textwrap.dedent(func_def)).body[0]

assert isinstance(names, (list, tuple))
if not names:
return empty_node()

template = """
def {func_name}():
nonlocal {vars}
Expand All @@ -578,7 +591,19 @@ def set_args_0(__args):
nonlocal x, y
x, y = __args
"""

def empty_node():
func_def = """
def {func_name}({args}):
pass
""".format(func_name=unique_name.generate(SET_ARGS_FUNC_PREFIX),
args=ARGS_NAME)
return gast.parse(textwrap.dedent(func_def)).body[0]

assert isinstance(names, (list, tuple))
if not names:
return empty_node()

template = """
def {func_name}({args}):
nonlocal {vars}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,18 @@ def loss_fn(x, lable):
return loss


def dyfunc_empty_nonlocal(x):
flag = True
if flag:
print("It's a test for empty nonlocal stmt")

if paddle.mean(x) < 0:
x + 1

out = x * 2
return out


def dyfunc_with_if_else(x_v, label=None):
if fluid.layers.mean(x_v).numpy()[0] > 5:
x_v = x_v - 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,13 @@ def setUp(self):
self.dyfunc = dyfunc_with_if_else3


class TestDygraphIfElse4(TestDygraphIfElse):

def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = dyfunc_empty_nonlocal


class TestDygraphIfElseWithListGenerator(TestDygraphIfElse):

def setUp(self):
Expand Down

0 comments on commit 40a7731

Please sign in to comment.