Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2Stat]Support nonlocal mechanism in IF ast transformer #43666

Merged
merged 5 commits into from
Jun 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 59 additions & 11 deletions python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from paddle.fluid.layers import cast, control_flow, logical_and, logical_not, logical_or, nn
from paddle.fluid.layers.control_flow import cond, while_loop, less_than, increment
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_VAR_NAME
from paddle.fluid.dygraph.dygraph_to_static.utils import UndefinedVar


def convert_while_loop(cond, body, loop_vars):
Expand Down Expand Up @@ -188,25 +189,27 @@ def _run_py_logical_not(x):
return not x


def convert_ifelse(pred, true_fn, false_fn, true_args, false_args):
def convert_ifelse(pred, true_fn, false_fn, get_args, set_args,
return_name_ids):
"""
A function representation of a Python ``if/else`` statement.

Args:
pred(bool|Tensor): A boolean Tensor which determines whether to return the result of ``true_fn`` or ``false_fn`` .
true_fn(callable): A callable to be performed if ``pred`` is true.
false_fn(callable): A callable to be performed if ``pred`` is false.
true_args(tuple): Parameters of ``true_fn``.
false_args(tuple): Parameters of ``false_fn``.
get_args(callable): Get all arguments that needed in true_fn and false_fn.
set_args(callable): Update arguments that modified in trure_fn and false_fn.

Returns:
``true_fn(true_args)`` if the predicate ``pred`` is true else ``false_fn(false_args)`` .
``true_fn()`` if the predicate ``pred`` is true else ``false_fn()`` .

"""
if isinstance(pred, Variable):
out = _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args)
out = _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
return_name_ids)
else:
out = _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args)
out = _run_py_ifelse(pred, true_fn, false_fn)

return _remove_no_value_return_var(out)

Expand Down Expand Up @@ -244,14 +247,59 @@ def _remove_no_value_return_var(out):
return out


def _run_paddle_cond(pred, true_fn, false_fn, true_args, false_args):
def _check_no_undefined_var(outs, names, branch_name):
if names is None: return
if not isinstance(outs, (list, tuple)):
outs = [outs]
for var, name in zip(list(outs), names):
if isinstance(var, UndefinedVar):
raise ValueError(
"Required '{}' must be initialized both in if-else branch, but found it not initialized in '{}'."
.format(name, branch_name))


def _run_paddle_cond(pred, true_fn, false_fn, get_args, set_args,
return_name_ids):
"""
Paddle cond API will evaluate both ture_fn and false_fn codes.
"""
pred = cast_bool_if_necessary(pred)
return control_flow.cond(pred, lambda: true_fn(*true_args),
lambda: false_fn(*false_args))
init_args = get_args()

def new_true_fn():
set_args(init_args)
outs = true_fn()
_check_no_undefined_var(outs, return_name_ids, 'if_body')
return outs

def new_false_fn():
set_args(init_args)
outs = false_fn()
_check_no_undefined_var(outs, return_name_ids, 'else_body')
return outs

cond_outs = control_flow.cond(pred, new_true_fn, new_false_fn)
# IfExpr's return_name_ids maybe None
if return_name_ids is None:
return cond_outs

# recover args state
num_outs = len(return_name_ids)
num_args = 1 if not isinstance(init_args, tuple) else len(init_args)
assert num_outs <= num_args

if num_args == 1:
final_outs = cond_outs
else:
cond_outs = (cond_outs, ) if num_outs == 1 else cond_outs
final_outs = cond_outs + init_args[num_outs:]

set_args(final_outs)
return final_outs


def _run_py_ifelse(pred, true_fn, false_fn, true_args, false_args):
return true_fn(*true_args) if pred else false_fn(*false_args)
def _run_py_ifelse(pred, true_fn, false_fn):
return true_fn() if pred else false_fn()


def convert_len(var):
Expand Down
Loading