Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the bug that ParamBase lose attributes when paddle.save(Layer) #33500

Merged
merged 2 commits into from
Jun 15, 2021

Conversation

hbwx24
Copy link
Contributor

@hbwx24 hbwx24 commented Jun 10, 2021

PR types

Bug fixes

PR changes

APIs

Describe

修复paddle.save(Layer)时参数属性丢失的问题:通过实现__reduce__的方式保存layer中的ParamBase的所有信息。

pickle原理简介:

def __reduce__(self):
    value = self.numpy()
    state = (self.name, self.persistable, self.stop_gradient)
    
    # (self.__dict__, value,state) 为__setstate__的输入,因此这个参数决定了保存ParamBase的信息量。
    return ParamBase, (self.shape, self.dtype), (self.__dict__, value,state)

def __setstate__(self, state):
    self.__dict__.update(state[0])
    t = self.value().get_tensor()
    t.set(state[1], _current_expected_place())
    self.name, self.persistable, self.stop_gradient = state[2]
  • pickle.dump:

    • 调用__reduce__(),把reduce的返回值(构造函数,构造函数的输入,setstate 的输入)保存下来。其中ParamBase保存的是module路径字符串("paddle.fluid.framework")和对象名("ParamBase")。
  • pickle.load:

    • 先调用ParamBase(self.shape, self.dtype)构建ParamBase对象,再调用ParamBase.__setstate__设置这个对象的属性。

pickle详情参考

兼容性问题:

如果 Layer、ParamBase、__setstate__改变将对加载模型造成影响。

  • 对于Layer和ParamBase。load时,pickle调用paddle内部的ParamBase类、Layer类来恢复出Layer、ParamBase。Linear、Conv等都继承自Layer。因此如果保存了Linear、Conv2D等,在load时都需要Linear类、Conv2D类才能恢复出保存的对象。而pickle是用过记录模块位置来找Linear类、Conv2D类的(例如:paddle.nn.Linear),一旦这个类的位置变了,或者构造函数变了都将导致pickle.load无法恢复保存的模型。由于paddle中有大量的Layer,如果这些类的位置改变或者构造函数改变,将无法恢复加载的模型。
  • 由于paddle2.1 ParamBase 没有__setstate__函数,所以无发加载2.1.1保存的layer。

结论如下:

  • paddle2.1 无法加载paddle2.1.1 保存的Layer对象。paddle2.1.1可以加载paddle2.1保存的Layer对象。
  • 对于非Layer对象,如state_dict等,2.1与2.1.1完全兼容,2.1可以加载2.1.1保存的非Layer对象。可以脱离框架加载pickle格式的模型文件。

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@@ -5535,6 +5535,18 @@ def _copy_to(self, device, blocking):
core.varbase_copy(self, new_param, device, blocking)
return new_param

def __reduce__(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果用户直接save(ParamBase),会和之前格式不同吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

用户直接save(ParamBase),和之前格式完全相同。

@chenwhql
Copy link
Contributor

这个主要是为了修复paddle.save(Layer)时参数属性丢失的问题吧,可以补充下

Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@hbwx24 hbwx24 merged commit 28521e0 into PaddlePaddle:develop Jun 15, 2021
hbwx24 added a commit to hbwx24/Paddle that referenced this pull request Jun 15, 2021
* Save all the information of 'ParamBase' in 'Layer'.

* edit unittest
@hbwx24 hbwx24 changed the title Save all the information of 'ParamBase' in 'Layer'. Fix the bug that ParamBase lose attributes when paddle.save(Layer) Jun 15, 2021
@jzhang533
Copy link
Contributor

考虑到这是一个

  • 临时性的修复:目前框架在重构,VarBase和ParamBase可能在重构之后都会消失。(by @chenwhql

  • 不完全的修复:
    不支持paddle.save(paddle.nn.Linear(3,4).weight, '/tmp/w') 的保存后再加载能够恢复属性。完整示例

因此会先revert此PR,给需要改功能的用户提供替代用法。

hbwx24 added a commit to hbwx24/Paddle that referenced this pull request Jul 9, 2021
* Save all the information of 'ParamBase' in 'Layer'.

* edit unittest
lanxianghit pushed a commit that referenced this pull request Jul 12, 2021
* Save all the information of 'ParamBase' in 'Layer'. (#33500)

* Save all the information of 'ParamBase' in 'Layer'.

* edit unittest

* delete the function of saving layer object. (#33697)

* delete the function of saving layer object.

* edit doc of paddle.save/load and polish error message
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants