Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2stat] Change the Global Switch Name of ProgramTranslator for API 2.0 #27203

Merged
46 changes: 26 additions & 20 deletions python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def __init__(self, function, input_spec=None):
self._function_spec = FunctionSpec(function, input_spec)
self._program_cache = ProgramCache()
self._descriptor_cache = weakref.WeakKeyDictionary()
# Note: Hold a reference to ProgramTranslator for switching `enable_declarative`.
# Note: Hold a reference to ProgramTranslator for switching `enable_to_static`.
self._program_trans = ProgramTranslator()

def __get__(self, instance, owner):
Expand Down Expand Up @@ -299,16 +299,17 @@ def __call__(self, *args, **kwargs):
"""

# 1. call dygraph function directly if not enable `declarative`
if not self._program_trans.enable_declarative:
if not self._program_trans.enable_to_static:
logging_utils.warn(
"The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable=False. "
"We will just return dygraph output.")
"The decorator '@paddle.jit.to_static' does NOT work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output. If you would like to get static graph output, please call API "
"ProgramTranslator.enable(True)")
return self._call_dygraph_function(*args, **kwargs)

if not in_dygraph_mode() and self._program_trans.enable_declarative:
if not in_dygraph_mode():
raise RuntimeError(
"Failed to run the callable object {} decorated by '@paddle.jit.to_static', "
"because it does NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
"because it is NOT in dynamic mode. Please disable the static mode to enter dynamic mode with the "
"following API: paddle.disable_static().".format(
self.dygraph_function))

Expand Down Expand Up @@ -723,15 +724,15 @@ def __init__(self):
return
self._initialized = True
self._program_cache = ProgramCache()
self.enable_declarative = True
self.enable_to_static = True

def enable(self, enable_declarative):
def enable(self, enable_to_static):
"""
Enable or disable the converting from imperative to declarative by
ProgramTranslator globally.

Args:
enable_declarative (bool): True or False to enable or disable declarative.
enable_to_static (bool): True or False to enable or disable declarative.

Returns:
None.
Expand Down Expand Up @@ -760,9 +761,9 @@ def func(x):
print(func(x).numpy()) # [[2. 2.]]

"""
check_type(enable_declarative, "enable_declarative", bool,
check_type(enable_to_static, "enable_to_static", bool,
"ProgramTranslator.enable")
self.enable_declarative = enable_declarative
self.enable_to_static = enable_to_static

def get_output(self, dygraph_func, *args, **kwargs):
"""
Expand Down Expand Up @@ -803,10 +804,12 @@ def func(x):
assert callable(
dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_output"
if not self.enable_declarative:
if not self.enable_to_static:
warnings.warn(
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. "
"We will just return dygraph output.")
"The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable to False. "
"We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output."
)
return dygraph_func(*args, **kwargs)
try:
function_spec = FunctionSpec(dygraph_func)
Expand Down Expand Up @@ -876,10 +879,11 @@ def func(x):
assert callable(
dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_func"
if not self.enable_declarative:
if not self.enable_to_static:
warnings.warn(
"The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable=False. We will "
"just return dygraph output.")
"The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable to False. We will "
"just return dygraph output. Please call ProgramTranslator.enable(True) if you would like to get static output."
)
return dygraph_func

static_func = convert_to_static(dygraph_func)
Expand Down Expand Up @@ -929,10 +933,12 @@ def func(x):
assert callable(
dygraph_func
), "Input dygraph_func is not a callable in ProgramTranslator.get_program"
if not self.enable_declarative:
if not self.enable_to_static:
warnings.warn(
"The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable=False."
"We will just return dygraph output.")
"The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable to False."
"We will just return dygraph output. "
"Please call ProgramTranslator.enable(True) if you would like to get static output."
)
return dygraph_func(*args, **kwargs)

function_spec = FunctionSpec(dygraph_func)
Expand Down
6 changes: 3 additions & 3 deletions python/paddle/fluid/dygraph/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def func(x):
# TODO: remove this decorator after we finalize training API
def __impl__(*args, **kwargs):
program_translator = ProgramTranslator()
if in_dygraph_mode() or not program_translator.enable_declarative:
if in_dygraph_mode() or not program_translator.enable_to_static:
warnings.warn(
"The decorator 'dygraph_to_static_func' doesn't work in "
"dygraph mode or set ProgramTranslator.enable to False. "
Expand Down Expand Up @@ -832,9 +832,9 @@ def train(layer, loader, loss_fn, opt):

# 1. input check
prog_translator = ProgramTranslator()
if not prog_translator.enable:
if not prog_translator.enable_to_static:
raise RuntimeError(
"The paddle.jit.save doesn't work when setting ProgramTranslator.enable=False."
"The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
)
if not isinstance(layer, Layer):
raise TypeError(
Expand Down
10 changes: 5 additions & 5 deletions python/paddle/hapi/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1680,7 +1680,7 @@ def get_inout_spec(all_vars, return_name=False):

# TODO:
# 1. Make it Unnecessary to run model before calling `save_inference_model` for users in dygraph.
# 2. Save correct shape of input, now the interface stores the shape that the user sent to
# 2. Save correct shape of input, now the interface stores the shape that the user sent to
# the inputs of the model in running.
# 3. Make it Unnecessary to add `@paddle.jit.to_static` for users in dynamic mode.
if fluid.in_dygraph_mode():
Expand All @@ -1689,9 +1689,9 @@ def get_inout_spec(all_vars, return_name=False):

# 1. input check
prog_translator = ProgramTranslator()
if not prog_translator.enable_declarative:
if not prog_translator.enable_to_static:
raise RuntimeError(
"save_inference_model doesn't work when setting ProgramTranslator.enable=False."
"save_inference_model doesn't work when setting ProgramTranslator.enable to False."
)
if not isinstance(layer, Layer):
raise TypeError(
Expand Down Expand Up @@ -1902,8 +1902,8 @@ def _verify_spec(self, specs, is_input=False):
assert isinstance(spec, Input)
if spec.name is None:
raise ValueError(
"Requires Input[{}].name != None, but receive `None` with {}.".
format(i, spec))
"Requires Input[{}].name != None, but receive `None` with {}."
.format(i, spec))

return out_specs

Expand Down