Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

set_state_dict not use state_dict hook #43407

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions python/paddle/fluid/dygraph/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1327,14 +1327,16 @@ def _state_dict_impl(self,
destination=None,
include_sublayers=True,
structured_name_prefix="",
include_non_persistable_buffer=False):
include_non_persistable_buffer=False,
use_hook=True):
"""
Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict

Parameters:
destination(dict, optional) : If provided, all the parameters and persistable buffers will be set to this dict. Default: None
include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
include_non_persistable_buffer(bool, optional): If true, include non persistable buffers of current layer and its sub-layers, it is used in pure fp16 and jit.save. Default: False
use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True
"""

if destination is None:
Expand All @@ -1358,25 +1360,28 @@ def _state_dict_impl(self,
layer_item._state_dict_impl(
destination_temp, include_sublayers,
structured_name_prefix + layer_name + ".",
include_non_persistable_buffer))
include_non_persistable_buffer, use_hook))
destination = destination_temp
for state_dict_hook in self._state_dict_hooks.values():
hook_result = state_dict_hook(destination)
if hook_result is not None:
destination = hook_result
if use_hook:
for state_dict_hook in self._state_dict_hooks.values():
hook_result = state_dict_hook(destination)
if hook_result is not None:
destination = hook_result

return destination

def to_static_state_dict(self,
destination=None,
include_sublayers=True,
structured_name_prefix=""):
structured_name_prefix="",
use_hook=True):
'''
Get all parameters and buffers of current layer and its sub-layers. And set them into a dict

Parameters:
destination(dict, optional) : If provided, all the parameters and persistable buffers will be set to this dict. Default: None
include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True

Returns:
dict: a dict contains all the parameters and persistable buffers.
Expand All @@ -1396,18 +1401,21 @@ def to_static_state_dict(self,
destination=destination,
include_sublayers=include_sublayers,
structured_name_prefix=structured_name_prefix,
include_non_persistable_buffer=True)
include_non_persistable_buffer=True,
use_hook=use_hook)

def state_dict(self,
destination=None,
include_sublayers=True,
structured_name_prefix=""):
structured_name_prefix="",
use_hook=True):
'''
Get all parameters and persistable buffers of current layer and its sub-layers. And set them into a dict

Parameters:
destination(dict, optional) : If provided, all the parameters and persistable buffers will be set to this dict. Default: None
include_sublayers(bool, optional) : If true, also include the parameters and persistable buffers from sublayers. Default: True
use_hook(bool, optional) : If true, the operations contained in _state_dict_hooks will be appended to the destination. Default: True

Returns:
dict: a dict contains all the parameters and persistable buffers.
Expand All @@ -1427,7 +1435,8 @@ def state_dict(self,
destination=destination,
include_sublayers=include_sublayers,
structured_name_prefix=structured_name_prefix,
include_non_persistable_buffer=False)
include_non_persistable_buffer=False,
use_hook=use_hook)

@framework.deprecate_stat_dict
def set_state_dict(self, state_dict, use_structured_name=True):
Expand Down Expand Up @@ -1478,7 +1487,7 @@ def _check_match(key, param):
return param, state

matched_param_state = []
for key, param in self.state_dict().items():
for key, param in self.state_dict(use_hook=False).items():
key_name = key if use_structured_name else param.name
try:
match_res = _check_match(key_name, param)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,37 @@ def test_skip_BatchNorm_Layer_norm(self):
self.assertEqual((param.dtype == paddle.float32), True)


class TestStateDictHookForAMP(unittest.TestCase):
    """Round-trip a state dict through ``set_state_dict`` on an AMP-decorated layer."""

    def test_state_dict_hook(self):
        """Reload a float16-cast copy of the state dict and check the parameters.

        The scenario runs twice: once inside ``_test_eager_guard()`` and once
        outside it.
        """

        def func_isinstance():
            paddle.seed(100)
            model = paddle.nn.Linear(2, 4)
            # NOTE(review): presumably `save_dtype='float32'` is what installs
            # the state_dict hook this test targets -- confirm against
            # paddle.amp.decorate.
            model = paddle.amp.decorate(models=model,
                                        level='O2',
                                        save_dtype='float32')

            # Snapshot every parameter before the state dict is reloaded.
            param_value_ori = {
                param.name: param.numpy()
                for param in model.parameters()
            }

            # Cast every entry to float16 and load the result back in.
            state_dict = model.state_dict()
            for key, value in state_dict.items():
                state_dict[key] = value.cast("float16")
            model.set_state_dict(state_dict)

            param_value_now = {
                param.name: param.numpy()
                for param in model.parameters()
            }

            # Assert instead of printing: a printed comparison can never fail
            # the test, so a regression would go unnoticed.
            self.assertEqual(param_value_ori.keys(), param_value_now.keys())
            for key, expected in param_value_ori.items():
                np.testing.assert_array_equal(expected, param_value_now[key])

        with _test_eager_guard():
            func_isinstance()
        func_isinstance()


class TestPureFp16SaveLoad(unittest.TestCase):

def test_save_dtype_exception(self):
Expand Down