Skip to content

Commit

Permalink
add ut
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbo9674 committed Jun 13, 2022
1 parent 30f77a0 commit fe21e76
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/paddle/fluid/dygraph/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,7 @@ def to_static_state_dict(self,
include_sublayers=include_sublayers,
structured_name_prefix=structured_name_prefix,
include_non_persistable_buffer=True,
use_hook=True)
use_hook=use_hook)

def state_dict(self,
destination=None,
Expand Down Expand Up @@ -1433,7 +1433,7 @@ def state_dict(self,
include_sublayers=include_sublayers,
structured_name_prefix=structured_name_prefix,
include_non_persistable_buffer=False,
use_hook=True)
use_hook=use_hook)

@framework.deprecate_stat_dict
def set_state_dict(self, state_dict, use_structured_name=True):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,37 @@ def test_skip_BatchNorm_Layer_norm(self):
self.assertEqual((param.dtype == paddle.float32), True)


class TestStateDictHookForAMP(unittest.TestCase):

def test_state_dict_hook(self):

def func_isinstance():
paddle.seed(100)
model = paddle.nn.Linear(2, 4)
model = paddle.amp.decorate(models=model,
level='O2',
save_dtype='float32')
param_value_ori = {}
for param in model.parameters():
param_value_ori[param.name] = param.numpy()

state_dict = model.state_dict()
for key, value in state_dict.items():
state_dict[key] = value.cast("float16")
model.set_state_dict(state_dict)

param_value_now = {}
for param in model.parameters():
param_value_now[param.name] = param.numpy()

for key in param_value_ori.keys():
print(np.equal(param_value_ori[key], param_value_now[key]))

with _test_eager_guard():
func_isinstance()
func_isinstance()


class TestPureFp16SaveLoad(unittest.TestCase):

def test_save_dtype_exception(self):
Expand Down

0 comments on commit fe21e76

Please sign in to comment.