From 0a6c80ee66365e45f6698f2acff44e5aaf270139 Mon Sep 17 00:00:00 2001 From: WeiXin Date: Thu, 10 Jun 2021 08:00:02 +0000 Subject: [PATCH 1/2] Save all the information of 'ParamBase' in 'Layer'. --- python/paddle/fluid/framework.py | 12 ++++++ .../tests/unittests/test_paddle_save_load.py | 5 ++- python/paddle/framework/io.py | 43 ++++++++++++++++--- 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index bffeaf2c6c973..fd4f2bbbf67c3 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -5535,6 +5535,18 @@ def _copy_to(self, device, blocking): core.varbase_copy(self, new_param, device, blocking) return new_param + def __reduce__(self): + value = self.numpy() + state = (self.name, self.persistable, self.stop_gradient) + return ParamBase, (self.shape, self.dtype), (self.__dict__, value, + state) + + def __setstate__(self, state): + self.__dict__.update(state[0]) + t = self.value().get_tensor() + t.set(state[1], _current_expected_place()) + self.name, self.persistable, self.stop_gradient = state[2] + __repr__ = __str__ diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py index be2a6a653cc6f..e38a9e10fd16e 100644 --- a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py @@ -869,9 +869,10 @@ def test_save_load_layer(self): layer2 = LinearNet() layer1.eval() layer2.eval() + origin_layer = (layer1, layer2) origin = (layer1(inps), layer2(inps)) path = "test_save_load_layer_/layer.pdmodel" - paddle.save((layer1, layer2), path) + paddle.save(origin_layer, path) # static paddle.enable_static() @@ -884,6 +885,8 @@ def test_save_load_layer(self): loaded_result = [l(inps) for l in loaded_layer] for i in range(len(origin)): self.assertTrue((origin[i] - loaded_result[i]).abs().max() < 1e-10) + for k, v in origin_layer[i]._linear.weight.__dict__.items(): + self.assertTrue(v == loaded_layer[i]._linear.weight.__dict__[k]) if __name__ == '__main__': diff --git a/python/paddle/framework/io.py b/python/paddle/framework/io.py index 1705db50d391a..8f4e130f1eea1 100644 --- a/python/paddle/framework/io.py +++ b/python/paddle/framework/io.py @@ -232,9 +232,13 @@ def _pickle_save(obj, f, protocol): raise ValueError("Expected 1<'protocol'<5, but received protocol={}". format(protocol)) - def reudce_varbase(self): + list_params = set() + + def reduce_varbase(self): data = self.numpy() name = self.name + if name in list_params: + return self.__reduce__() return (tuple, ((name, data), )) @@ -243,16 +247,43 @@ def reduce_LoDTensor(self): return (eval, ('data', {'data': data})) + def reduce_Layer(self): + is_param_or_layer = lambda v: isinstance(v, ParamBase) or isinstance(v, core.Layer) + + def collect_params(param_or_layer): + if isinstance(param_or_layer, ParamBase): + list_params.add(param_or_layer.name) + else: + # param_or_layer is layer + _parse_every_object(param_or_layer.__dict__, is_param_or_layer, + collect_params) + return param_or_layer + + _parse_every_object(self.__dict__, is_param_or_layer, collect_params) + return self.__reduce_ex__(protocol) + + dispatch_table_layer = dict() + + def create_layer_dispatch_table(layer): + dispatch_table_layer[layer.__class__] = reduce_Layer + return layer + + _parse_every_object(obj, lambda v: isinstance(v, core.Layer), + create_layer_dispatch_table) + def add_dispatch_table(): # This is not a good method, because the pickle module has been modified. - pickle.dispatch_table[core.VarBase] = reudce_varbase - pickle.dispatch_table[ParamBase] = reudce_varbase + pickle.dispatch_table[core.VarBase] = reduce_varbase + pickle.dispatch_table[ParamBase] = reduce_varbase pickle.dispatch_table[core.LoDTensor] = reduce_LoDTensor + pickle.dispatch_table.update(dispatch_table_layer) def pop_dispatch_table(): pickle.dispatch_table.pop(core.VarBase) pickle.dispatch_table.pop(core.LoDTensor) pickle.dispatch_table.pop(ParamBase) + for k in dispatch_table_layer: + pickle.dispatch_table.pop(k) # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3' if sys.platform == 'darwin' and sys.version_info.major == 3: @@ -272,10 +303,10 @@ def pop_dispatch_table(): pickler = pickle.Pickler(f, protocol) pickler.dispatch_table = copyreg.dispatch_table.copy() - pickler.dispatch_table[core.VarBase] = reudce_varbase + pickler.dispatch_table[core.VarBase] = reduce_varbase pickler.dispatch_table[core.LoDTensor] = reduce_LoDTensor - pickler.dispatch_table[ParamBase] = reudce_varbase - + pickler.dispatch_table[ParamBase] = reduce_varbase + pickler.dispatch_table.update(dispatch_table_layer) pickler.dump(obj) From 99a95a047db403e3908e506bdb740aff8207b1dd Mon Sep 17 00:00:00 2001 From: WeiXin Date: Thu, 10 Jun 2021 12:37:04 +0000 Subject: [PATCH 2/2] edit unittest --- .../paddle/fluid/tests/unittests/test_paddle_save_load.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py index e38a9e10fd16e..a92da5e287f10 100644 --- a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py @@ -874,13 +874,6 @@ def test_save_load_layer(self): path = "test_save_load_layer_/layer.pdmodel" paddle.save(origin_layer, path) - # static - paddle.enable_static() - with self.assertRaises(ValueError): - paddle.load(path) - # dygraph - paddle.disable_static() - loaded_layer = paddle.load(path) loaded_result = [l(inps) for l in loaded_layer] for i in range(len(origin)):