diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py index be2a6a653cc6f..af8718a2121b1 100644 --- a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py @@ -869,21 +869,11 @@ def test_save_load_layer(self): layer2 = LinearNet() layer1.eval() layer2.eval() + origin_layer = (layer1, layer2) origin = (layer1(inps), layer2(inps)) path = "test_save_load_layer_/layer.pdmodel" - paddle.save((layer1, layer2), path) - - # static - paddle.enable_static() with self.assertRaises(ValueError): - paddle.load(path) - # dygraph - paddle.disable_static() - - loaded_layer = paddle.load(path) - loaded_result = [l(inps) for l in loaded_layer] - for i in range(len(origin)): - self.assertTrue((origin[i] - loaded_result[i]).abs().max() < 1e-10) + paddle.save(origin_layer, path) if __name__ == '__main__': diff --git a/python/paddle/framework/io.py b/python/paddle/framework/io.py index 1705db50d391a..01145e8563cf3 100644 --- a/python/paddle/framework/io.py +++ b/python/paddle/framework/io.py @@ -232,7 +232,7 @@ def _pickle_save(obj, f, protocol): raise ValueError("Expected 1<'protocol'<5, but received protocol={}". format(protocol)) - def reudce_varbase(self): + def reduce_varbase(self): data = self.numpy() name = self.name @@ -243,16 +243,32 @@ def reduce_LoDTensor(self): return (eval, ('data', {'data': data})) + def reduce_Layer(self): + raise ValueError( + "paddle do not support saving `paddle.nn.Layer` object.") + + dispatch_table_layer = dict() + + def create_layer_dispatch_table(layer): + dispatch_table_layer[layer.__class__] = reduce_Layer + return layer + + _parse_every_object(obj, lambda v: isinstance(v, core.Layer), + create_layer_dispatch_table) + def add_dispatch_table(): # This is not a good method, because the pickle module has been modified. - pickle.dispatch_table[core.VarBase] = reudce_varbase - pickle.dispatch_table[ParamBase] = reudce_varbase + pickle.dispatch_table[core.VarBase] = reduce_varbase + pickle.dispatch_table[ParamBase] = reduce_varbase pickle.dispatch_table[core.LoDTensor] = reduce_LoDTensor + pickle.dispatch_table.update(dispatch_table_layer) def pop_dispatch_table(): pickle.dispatch_table.pop(core.VarBase) pickle.dispatch_table.pop(core.LoDTensor) pickle.dispatch_table.pop(ParamBase) + for k in dispatch_table_layer: + pickle.dispatch_table.pop(k) # When value of dict is lager than 4GB ,there is a Bug on 'MAC python3' if sys.platform == 'darwin' and sys.version_info.major == 3: @@ -272,10 +288,10 @@ def pop_dispatch_table(): pickler = pickle.Pickler(f, protocol) pickler.dispatch_table = copyreg.dispatch_table.copy() - pickler.dispatch_table[core.VarBase] = reudce_varbase + pickler.dispatch_table[core.VarBase] = reduce_varbase pickler.dispatch_table[core.LoDTensor] = reduce_LoDTensor - pickler.dispatch_table[ParamBase] = reudce_varbase - + pickler.dispatch_table[ParamBase] = reduce_varbase + pickler.dispatch_table.update(dispatch_table_layer) pickler.dump(obj) @@ -496,7 +512,7 @@ def save(obj, path, protocol=4, **configs): Save an object to the specified path. .. note:: - Now supports saving ``state_dict`` of Layer/Optimizer, Layer, Tensor and nested structure containing Tensor, Program. + Now supports saving ``state_dict`` of Layer/Optimizer, Tensor and nested structure containing Tensor, Program. .. note:: Different from ``paddle.jit.save``, since the save result of ``paddle.save`` is a single file, @@ -690,7 +706,7 @@ def load(path, **configs): Load an object can be used in paddle from specified path. .. note:: - Now supports loading ``state_dict`` of Layer/Optimizer, Layer, Tensor and nested structure containing Tensor, Program. + Now supports loading ``state_dict`` of Layer/Optimizer, Tensor and nested structure containing Tensor, Program. .. note:: In order to use the model parameters saved by paddle more efficiently,