Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

'jit.save/load' support save/load function without parameters. #32430

Merged
merged 7 commits into from
Apr 27, 2021

Conversation

hbwx24
Copy link
Contributor

@hbwx24 hbwx24 commented Apr 21, 2021

PR types

Function optimization

PR changes

APIs

Describe

paddle.jit.save/load支持保存function,只保存了program,没有保存参数。因此没有.pdparam.info文件。
示例如下:

import paddle
from paddle.static import InputSpec


def example1():
    @paddle.jit.to_static
    def fun(inputs):
        return paddle.tanh(inputs)

    path = 'test_jit_save_load_function_1/func'
    inps = paddle.rand([3, 6])
    origin = fun(inps)

    paddle.jit.save(fun, path)
    load_func = paddle.jit.load(path)

    load_result = load_func(inps)
    print((load_result - origin).abs().max() < 1e-10)


def example2():
    @paddle.jit.to_static(input_spec=[
        InputSpec(shape=[None, 6], dtype='float32', name='x'),
    ])
    def fun(inputs):
        return paddle.nn.functional.relu(inputs)

    path = 'test_jit_save_load_function_2/func'
    inps = paddle.rand([3, 6])
    origin = fun(inps)

    paddle.jit.save(fun, path)
    load_func = paddle.jit.load(path)
    load_result = load_func(inps)
    print((load_result - origin).abs().max() < 1e-10)


def example3():
    def fun(inputs):
        return paddle.tanh(inputs)

    path = 'test_jit_save_load_function_3/func'
    inps = paddle.rand([3, 6])
    origin = fun(inps)

    paddle.jit.save(fun,
                    path,
                    input_spec=[
                        InputSpec(shape=[None, 6], dtype='float32', name='x'),
                    ])
    load_func = paddle.jit.load(path)

    load_result = load_func(inps)
    print((load_result - origin).abs().max() < 1e-10)


example1()
example2()
example3()

以下情况是不支持的:

def fun(inputs):
    l=paddle.nn.Linear(2, 3)
    y=l(inputs)
    return paddle.nn.functional.relu(y)

文档:红框内为本次修改的内容
http://10.136.157.23:8090/documentation/docs/en/api/paddle/fluid/dygraph/jit/save_en.html
image
image

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@@ -617,9 +617,9 @@ def train(layer, loader, loss_fn, opt):
raise RuntimeError(
"The paddle.jit.save doesn't work when setting ProgramTranslator.enable to False."
)
if not isinstance(layer, Layer):
if not (isinstance(layer, Layer) or isinstance(layer, StaticFunction)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里我们要支持的是不是function,而不是StaticFunction,当时输入是普通function的时候,像layer一样根据input_spec对其进行转换,现在这样写,会导致存function的情况下,input_spec这个参数就没什么意义了吧

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thx.

Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

还有两点问题:

  1. 这个改动,文档也需要配合修改一下,参数类型和描述要涵盖function
  2. 之前实现的时候可能条件卡得比较严,对于layer来讲,如果layer没有参数,是不是也不需要存pdparams和pdparams.info

@hbwx24
Copy link
Contributor Author

hbwx24 commented Apr 25, 2021

还有两点问题:

  1. 这个改动,文档也需要配合修改一下,参数类型和描述要涵盖function
  2. 之前实现的时候可能条件卡得比较严,对于layer来讲,如果layer没有参数,是不是也不需要存pdparams和pdparams.info

Done,thx.

@@ -506,7 +507,7 @@ def _build_load_path_and_config(path, config):
@switch_to_static_graph
def save(layer, path, input_spec=None, **configs):
"""
Saves input Layer as ``paddle.jit.TranslatedLayer``
Saves input Layer or function as ``paddle.jit.TranslatedLayer``
Copy link
Member

@zhhsplendid zhhsplendid Apr 26, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since function cannot save parameters, you should give a Note or Warning and had better explain the reason for users.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thx.

# 5. save inference model
from paddle.fluid.io import save_inference_model

# construct new save_inference_model arguments
model_path = dirname
# NOTE(chenweihang): because prefix contains model and params filename,
# so we don't support set model_filename & params_filename
if 'forward' == attr_func:
if 'forward' == attr_func or not isinstance(layer, Layer):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When not isinstance(layer, Layer), that means it is a function, should we also save function's name to model_filename and params_filename ?

zhhsplendid
zhhsplendid previously approved these changes Apr 26, 2021
Copy link
Member

@zhhsplendid zhhsplendid left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

TCChenlong
TCChenlong previously approved these changes Apr 26, 2021
Copy link
Contributor

@TCChenlong TCChenlong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@hbwx24 hbwx24 dismissed stale reviews from TCChenlong and zhhsplendid via be7c13e April 26, 2021 11:35
Copy link
Contributor

@TCChenlong TCChenlong left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@chenwhql chenwhql merged commit 0372f1d into PaddlePaddle:develop Apr 27, 2021
hbwx24 added a commit to hbwx24/Paddle that referenced this pull request Apr 27, 2021
…ePaddle#32430)

* jit.save/load support function.

* delete unnittest test_jit_load_model_incomplete.

* edit code according to CI

* Modify the documentation.

* add note to doc.
lanxianghit pushed a commit that referenced this pull request Apr 29, 2021
… (#32613)

* jit.save/load support function.

* delete unnittest test_jit_load_model_incomplete.

* edit code according to CI

* Modify the documentation.

* add note to doc.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants