jit.save/load support method with parameters. (#34070)
* jit.save/load support method with parameters.

* add unittest and warning

* polish warning message.
hbwx24 committed Jul 14, 2021
1 parent 52c1a95 commit 1b37763
Showing 2 changed files with 133 additions and 15 deletions.
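
A minimal usage sketch of the behavior this commit adds, adapted from the new test cases in test_jit_save_load.py below. The class, method, and path names are illustrative, not part of the commit; paddle.jit.to_static, paddle.jit.save, and paddle.jit.load are used the same way as in the tests.

import paddle

paddle.disable_static()  # run in dygraph mode, as the new unit tests do

class LinearNet(paddle.nn.Layer):
    def __init__(self):
        super(LinearNet, self).__init__()
        self._linear = paddle.nn.Linear(5, 6)

    def forward(self, x):
        return paddle.tanh(x)

    def another_forward(self, x):
        # a non-forward method that uses the layer's parameters
        return self._linear(x)

layer = LinearNet()
inps = paddle.rand([3, 5])
origin = layer.another_forward(inps)

# Saving a bound method now also saves the parameters it uses.
func = paddle.jit.to_static(
    layer.another_forward, [paddle.static.InputSpec(shape=[-1, 5])])
paddle.jit.save(func, 'saved_method/func')

# Loading returns a TranslatedLayer whose output matches the original method.
loaded = paddle.jit.load('saved_method/func')
out = loaded(inps)  # expected to equal origin, as the new tests assert
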
55 changes: 40 additions & 15 deletions python/paddle/fluid/dygraph/jit.py
@@ -35,7 +35,7 @@
from paddle.fluid.dygraph.io import TranslatedLayer, INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX
from paddle.fluid.dygraph.layers import Layer
from paddle.fluid.executor import Executor, scope_guard
from paddle.fluid.framework import Block, ParamBase, Program, Variable
from paddle.fluid.framework import Block, ParamBase, Program, Variable, Parameter
from paddle.fluid.framework import _current_expected_place, _dygraph_guard, _dygraph_tracer
from paddle.fluid.framework import dygraph_only, in_dygraph_mode
from paddle.fluid.wrapped_decorator import wrap_decorator
@@ -659,6 +659,10 @@ def fun(inputs):
raise TypeError(
"The input of paddle.jit.save should be 'Layer' or 'Function', but received input type is %s."
% type(layer))
elif inspect.isfunction(layer) or isinstance(layer, StaticFunction):
warnings.warn(
'What you save is a function, and `jit.save` will generate the name of the model file according to `path` you specify. When loading these files with `jit.load`, you get a `TranslatedLayer` whose inference result is the same as the inference result of the function you saved.'
)

# NOTE(chenweihang): If the input layer be wrapped by DataParallel,
# the args and kwargs of forward method will can't be parsed by
@@ -741,12 +745,38 @@ def fun(inputs):
else:
continue

else:
# When layer is a function
if isinstance(attr_func, StaticFunction):
concrete_program = attr_func.concrete_program_specify_input_spec(
inner_input_spec)
else:
if inner_input_spec:
inner_input_spec = pack_sequence_as(input_spec,
inner_input_spec)
static_function = declarative(
attr_func, input_spec=inner_input_spec)
concrete_program = static_function.concrete_program

if static_function._class_instance is None:
warnings.warn(
'`jit.save` will only save the `Program`, not the parameters. If you have to save the parameters, please make sure that {} is a member function of `paddle.nn.Layer` and the saved parameters are in `state_dict`'.
format(layer))

dygraph_state_dict = None
if isinstance(inner_layer, Layer):
dygraph_state_dict = inner_layer.state_dict()
elif isinstance(attr_func, StaticFunction):
if attr_func._class_instance:
dygraph_state_dict = attr_func._class_instance.state_dict()

if dygraph_state_dict:
# NOTE(chenweihang): we maintain the mapping of variable name to
# structured name, the buffer variable (non-persistable)
# saved to inference program may not need by dygraph Layer,
# we only record the state_dict variable's structured name
state_names_dict = dict()
for structured_name, var in six.iteritems(inner_layer.state_dict()):
for structured_name, var in six.iteritems(dygraph_state_dict):
state_names_dict[var.name] = structured_name

# 3. share parameters from Layer to scope & record var info
@@ -767,18 +797,6 @@ def fun(inputs):
if isinstance(param_or_buffer, ParamBase):
extra_info_dict['trainable'] = param_or_buffer.trainable
extra_var_info[param_or_buffer.name] = extra_info_dict
else:
# When layer is a function
if isinstance(attr_func, StaticFunction):
concrete_program = attr_func.concrete_program_specify_input_spec(
inner_input_spec)
else:
if inner_input_spec:
inner_input_spec = pack_sequence_as(input_spec,
inner_input_spec)
static_function = declarative(
attr_func, input_spec=inner_input_spec)
concrete_program = static_function.concrete_program

# 4. build input & output of save_inference_model
# NOTE(chenweihang): [ Get input variables name ]
@@ -840,7 +858,14 @@ def fun(inputs):
# but we can save these information in `jit.save` without changing the original
# storage to improve user experience. So we save extra information into
# file `***.pdiparams.info`
if isinstance(layer, Layer) and extra_var_info:

# "layer" can only be Layer or function or StaticFunction.

contain_parameter = False
for var in concrete_program.main_program.list_vars():
contain_parameter |= isinstance(var, Parameter)

if (isinstance(layer, Layer) or contain_parameter) and extra_var_info:
with scope_guard(scope):
extra_var_info_path = path + INFER_PARAMS_INFO_SUFFIX
with open(extra_var_info_path, 'wb') as f:
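The contain_parameter check above means the extra variable info is now written even when jit.save receives a bound method or StaticFunction rather than a Layer, as long as the traced program contains parameters. Below is a self-contained sketch of the resulting files on disk; the '.pdmodel', '.pdiparams', and '.pdiparams.info' suffixes assume Paddle's usual INFER_MODEL_SUFFIX / INFER_PARAMS_SUFFIX / INFER_PARAMS_INFO_SUFFIX defaults, and the network and path names are illustrative.

import os
import paddle

class SmallNet(paddle.nn.Layer):
    def __init__(self):
        super(SmallNet, self).__init__()
        self._linear = paddle.nn.Linear(5, 6)

    def forward(self, x):
        return paddle.tanh(x)

    def infer(self, x):
        # bound method that uses the layer's parameters
        return self._linear(x)

net = SmallNet()
func = paddle.jit.to_static(net.infer, [paddle.static.InputSpec(shape=[-1, 5])])
path = 'saved_method_info/func'
paddle.jit.save(func, path)

# With this commit, the parameters file and the extra info file are written
# alongside the model file when a saved bound method uses parameters.
for suffix in ('.pdmodel', '.pdiparams', '.pdiparams.info'):
    print(suffix, os.path.exists(path + suffix))
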
93 changes: 93 additions & 0 deletions python/paddle/fluid/tests/unittests/test_jit_save_load.py
@@ -1227,6 +1227,99 @@ def fun(inputs):
self.assertTrue((load_result - origin).abs().max() < 1e-10)


class TestJitSaveLoadFunctionWithParamCase1(unittest.TestCase):
def setUp(self):
paddle.disable_static()

def test_jit_save_load_function(self):
class LinearNet(paddle.nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = paddle.nn.Linear(5, 6)

def forward(self, x):
return paddle.tanh(x)

def anothor_forward(self, x):
return self._linear(x)

layer = LinearNet()

inps = paddle.rand([3, 5])
origin = layer.anothor_forward(inps)

func = paddle.jit.to_static(
layer.anothor_forward, [paddle.static.InputSpec(shape=[-1, 5])])
path = 'test_jit_save_load_function_with_params_case1/func'
paddle.jit.save(func, path)
load_func = paddle.jit.load(path)

load_result = load_func(inps)
self.assertTrue(np.array_equal(load_result.numpy(), origin.numpy()))


class TestJitSaveLoadFunctionWithParamCase2(unittest.TestCase):
def setUp(self):
paddle.disable_static()

def test_jit_save_load_function(self):
class LinearNet(paddle.nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = paddle.nn.Linear(5, 6)

def forward(self, x):
return paddle.tanh(x)

@paddle.jit.to_static(input_spec=[InputSpec(shape=[-1, 5])])
def anothor_forward(self, x):
return self._linear(x)

layer = LinearNet()

inps = paddle.rand([3, 5])

path = 'test_jit_save_load_function_with_params_case2/func'
paddle.jit.save(layer.anothor_forward, path)
origin_result = layer.anothor_forward(inps)
load_func = paddle.jit.load(path)

load_result = load_func(inps)

self.assertTrue(
np.array_equal(origin_result.numpy(), load_result.numpy()))


class TestJitSaveLoadFunctionWithParamCase3(unittest.TestCase):
def setUp(self):
paddle.disable_static()

def test_jit_save_load_function(self):
class LinearNet(paddle.nn.Layer):
def __init__(self):
super(LinearNet, self).__init__()
self._linear = paddle.nn.Linear(5, 6)

def forward(self, x):
return paddle.tanh(x)

@paddle.jit.to_static
def anothor_forward(self, x):
return self._linear(x)

layer = LinearNet()

inps = paddle.rand([3, 5])
origin = layer.anothor_forward(inps)

path = 'test_jit_save_load_function_with_params_case3/func'
paddle.jit.save(layer.anothor_forward, path)
load_func = paddle.jit.load(path)

load_result = load_func(inps)
self.assertTrue(np.array_equal(load_result.numpy(), origin.numpy()))


class TestJitSaveLoadDataParallel(unittest.TestCase):
def verify_inference_correctness(self, layer, path):
layer.eval()
