
Automatic SParsity Helper #33132

Merged — 10 commits, merged into PaddlePaddle:develop on Jun 10, 2021

Conversation

@mingxu1067 (Collaborator) commented on May 26, 2021

PR types

New features

PR changes

APIs

Describe

Automatic SParsity functions

Added model pruning functions.
Added optimizer decorator and sparse minimization functions.

The main design concept behind this ASP module is ease of use: only a few lines need to be inserted into the original training script to enable ASP, ideally two. There are only two things to do to enable ASP training, each mapping to one line of code:

  1. Add masking operations to ensure the sparsity of parameters.
  2. Prune your well-trained model into the 2:4 sparse pattern for FP16 or the 1:2 sparse pattern for FP32.

The ASP module can also work together with the AMP module without any extra steps.
Here is a full example of enabling ASP training:

import paddle.fluid as fluid
from paddle.fluid.contrib import sparsity

main_program = fluid.Program()
startup_program = fluid.Program()

place = fluid.CUDAPlace(0)

with fluid.program_guard(main_program, startup_program):
    input_data = fluid.layers.data(name='data', shape=[None, 128])
    label = fluid.layers.data(name='label', shape=[None, 10])
    hidden = fluid.layers.fc(input=input_data, num_flatten_dims=-1, size=32, act=None)
    prob = fluid.layers.fc(input=hidden, num_flatten_dims=-1, size=10, act=None)
    loss = fluid.layers.mean(fluid.layers.square_error_cost(prob, label))
    
    optimizer = fluid.optimizer.SGD(learning_rate=0.1)
    optimizer = fluid.contrib.mixed_precision.decorator.decorate(optimizer)
    # Call sparsity.decorate() to wrap minimize() of the optimizer, which
    # will insert the necessary masking operations for the ASP workflow.
    optimizer = sparsity.decorate(optimizer)
    optimizer.minimize(loss, startup_program)

exe = fluid.Executor(place)
exe.run(startup_program)

# Must call `exe.run(startup_program)` first before calling `sparsity.prune_model`
sparsity.prune_model(place, main_program, func_name=sparsity.MaskAlgo.MASK_2D_BEST)

# Start training with ASP workflow.
for epoch in range(total_epochs):
    exe.run(main_program)

If there are layers that should stay dense, the module also provides the set_excluded_layers and reset_excluded_layers functions to exclude them from the ASP workflow. Here is a full example:

import paddle.fluid as fluid
from paddle.fluid.contrib import sparsity

main_program = fluid.Program()
startup_program = fluid.Program()

place = fluid.CUDAPlace(0)



with fluid.program_guard(main_program, startup_program):
    input_data = fluid.layers.data(name='data', shape=[None, 128])
    label = fluid.layers.data(name='label', shape=[None, 10])
    hidden = fluid.layers.fc(input=input_data, num_flatten_dims=-1, size=32, act=None, name="need_sparse")
    hidden = fluid.layers.fc(input=hidden, num_flatten_dims=-1, size=32, act=None, name="need_dense")
    prob = fluid.layers.fc(input=hidden, num_flatten_dims=-1, size=10, act=None)
    loss = fluid.layers.mean(fluid.layers.square_error_cost(prob, label))

    # Set up excluded layers to keep out of the ASP workflow.
    # Please note, excluded_layers must be set before calling `optimizer.minimize()`.
    sparsity.set_excluded_layers(main_program, ["need_dense"])

    optimizer = fluid.optimizer.SGD(learning_rate=0.1)
    optimizer = fluid.contrib.mixed_precision.decorator.decorate(optimizer)
    # Call sparsity.decorate() to wrap minimize() of the optimizer, which
    # will insert the necessary masking operations for the ASP workflow.
    optimizer = sparsity.decorate(optimizer)
    optimizer.minimize(loss, startup_program)

exe = fluid.Executor(place)
exe.run(startup_program)

# Must call `exe.run(startup_program)` first before calling `sparsity.prune_model`
sparsity.prune_model(place, main_program, func_name=sparsity.MaskAlgo.MASK_2D_BEST)

# Start training with ASP workflow.
for epoch in range(total_epochs):
    exe.run(main_program)

1. Added sparse mask generating functions, including 1D and 2D.
2. Added sparse pattern checking functions, including 1D and 2D.
1. Added model pruning functions.
2. Added optimizer decorator and sparse minimization functions.
1. Used isinstance() to check objects' types.
2. Made some functions private.
3. Replaced "import *" with individual imports.
@paddle-bot-old commented:

Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.


class ProgramASPInfo(object):
    def __init__(self):
        self.__mask_vars = {}

Contributor:

Could you add some comments explaining what these member variables store?

2. pruning well-trained models into 2:4 sparse pattern on FP16 or 1:2 sparse pattern on FP32 for fine-tuning.
"""

MASKE_APPENDDED_NAME = '_asp_mask'
Contributor:

Just to confirm: should this really be MASKE_ (in MASKE_APPENDDED_NAME)?

@@ -0,0 +1,89 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Contributor:

There are quite a lot of sparsity unit tests; how about creating a sparsity subdirectory for them?

return param_name + ASPHelper.MASKE_APPENDDED_NAME

@staticmethod
def get_vars(main_program):
Contributor:

The function name should be more specific about what it refers to, e.g. get_excluded_parameters. Also, why is this function defined as a staticmethod? The whole class has no __init__; is it never instantiated?

Collaborator Author:

The function has been renamed to get_not_ASP_relevant_vars. It is designed as a static method because it does not depend on any internal members of ASPHelper; given only a Program object it can find all the information it needs, so it was defined as a static function.

fc = fluid.layers.fc(input=input_data, num_flatten_dims=-1, size=32, act=None)

for param in main_program.global_block().all_parameters():
    ASPHelper.is_supported_layer(main_program, param.name)
Contributor:

The usage feels a bit hard to follow: the function's argument is a parameter name, yet the function name says it checks whether a layer is supported.

Collaborator Author:

The function name was chosen with future dynamic-graph support in mind. In dynamic graphs each Layer has its own class type that can be used for the check, so it will not rely on parameter names; in static graphs, however, layers are currently identified by parameter name, which is why it looks like this.

# fc_0.w_0 -> True
# fc_0.b_0 -> False
"""
if ASPHelper.MASKE_APPENDDED_NAME in param_name:
Contributor:

Please explain the logic of this function.

Collaborator Author (@mingxu1067) commented May 31, 2021:

The check has three main parts, as sketched below:

  1. Any variable whose name contains MASKE_APPENDDED_NAME is treated as a mask variable that ASP itself added -> not supported.
  2. If the parameter name was previously registered via excluded_layers, it is also not supported.
  3. Only parameter names matching fc[0-9].w0, linear[0-9].w0, or conv[0-9]*.w0 are supported. (This part is less general; as discussed in the meeting, it can be adjusted later according to PaddlePaddle's future naming rules.)
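
A minimal sketch of those three checks follows; the MASK_SUFFIX constant, the exact regex patterns, and the function signature are illustrative only, not the merged implementation.

import re

MASK_SUFFIX = '_asp_mask'  # mirrors ASPHelper.MASKE_APPENDDED_NAME

SUPPORTED_WEIGHT_PATTERNS = [
    re.compile(r'fc_[0-9]+\.w_[0-9]+'),
    re.compile(r'linear_[0-9]+\.w_[0-9]+'),
    re.compile(r'conv[0-9]*.*\.w_[0-9]+'),
]

def is_supported_layer(param_name, excluded_param_names):
    # 1. Names containing the mask suffix belong to mask variables that
    #    ASP itself created -> not supported.
    if MASK_SUFFIX in param_name:
        return False
    # 2. Parameters registered via set_excluded_layers -> not supported.
    if param_name in excluded_param_names:
        return False
    # 3. Only fc/linear/conv weight parameters match the supported patterns.
    return any(p.search(param_name) for p in SUPPORTED_WEIGHT_PATTERNS)

# is_supported_layer('fc_0.w_0', []) -> True
# is_supported_layer('fc_0.b_0', []) -> False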

return asp_info.masks.copy()


class ASPOptimizerWrapper(object):
Contributor:

This could simply be called ASPOptimizer.


class ASPOptimizerWrapper(object):
    r"""
    ASPOptimizerWrapper is a wrapper to decorate `minimize` of ASPHelper
Contributor:

This explanation feels insufficient.

list: operators from :attr:`optimizer`.minimize(:attr:`loss`).
list: pairs of parameters and their gradients.
"""
return ASPHelper.minimize(
Contributor:

Wouldn't it be better to follow the same pattern as AMP and implement some of these functions directly as member functions of ASPOptimizer? @wzzju, please review this as well.

Collaborator Author (@mingxu1067) commented May 31, 2021:

This part does follow the AMP design. Some functions are intentionally not implemented inside ASPOptimizerWrapper, for a few reasons:

  1. To help users migrating from PyTorch to Paddle get started quickly, the design follows APEX's design for PyTorch and provides a global helper.
  2. ASPOptimizerWrapper essentially leaves minimize, append_gradient, and similar functions unchanged; instead, it masks the supported layer parameters at the end of the static graph. Functionally this differs somewhat from an Optimizer, so it is not made part of the optimizer itself.

That is also why the name carries the "Wrapper" suffix: to remind users that this is not really an optimizer but only a wrapper around the minimize function. It is also not exposed publicly; users essentially only interact with ASPHelper.

@@ -0,0 +1,520 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Contributor:

I took a rough look at the code structure and have a few suggestions:
1) Sort out which interfaces need to be public; prefix interfaces that should not be exposed with an underscore.
2) Provide a complete example that shows the whole workflow.
3) Do users only need the decorate and prune_model interfaces? If so, I suggest not placing them inside the ASPHelper class but exposing them as global functions instead (see the sketch below); there seems to be no need to make the entire ASPHelper a public API.
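
For reference, a sketch of how that public surface would look to a user; the call forms below simply mirror the examples in the PR description.

from paddle.fluid.contrib import sparsity

# Optional: keep selected layers dense.
sparsity.set_excluded_layers(main_program, ["need_dense"])

# Wrap minimize() of an existing optimizer with the ASP masking logic.
optimizer = sparsity.decorate(optimizer)
optimizer.minimize(loss, startup_program)

# After exe.run(startup_program), prune the trained weights in place.
sparsity.prune_model(place, main_program, func_name=sparsity.MaskAlgo.MASK_2D_BEST)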

@paddle-bot-old (bot) commented on Jun 3, 2021:

Sorry to inform you that 691c662's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

1. Hid ASPHelper and only exposed set/reset_support_layers, decorate,
   and prune_model.
2. Renamed ASPOptimizerWrapper to OptimizerWithSparsityGuarantee to keep
   naming consistent with AMP.
@Xreki (Contributor) left a comment:

LGTM

]


def set_excluded_layers(main_program, param_names):
Contributor:

These two interfaces feel a bit like AutoMixedPrecisionLists, which lets you set custom_white_list, custom_black_list, and custom_black_varnames, where the white/black lists specify op types and custom_black_varnames specifies var names. The interface design here could perhaps be refined later with reference to the AMP API.

optimizer = fluid.optimizer.SGD(learning_rate=0.1)
optimizer = sparsity.decorate(optimizer) # Need to be called before `fleet.distributed_optimizer`
optimizer = fleet.distributed_optimizer(optimizer)
optimizer.minimize(loss, startup_program)
Contributor:

It is not ideal to mix single-machine usage and distributed usage in one example like this; either split them, or guard the distributed part with an `if is_distributed` flag (see the sketch below).
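
A minimal sketch of that suggestion, assuming a user-defined `is_distributed` flag (not a Paddle API); everything else mirrors the quoted snippet above.

optimizer = fluid.optimizer.SGD(learning_rate=0.1)
# sparsity.decorate() still needs to be called before fleet.distributed_optimizer().
optimizer = sparsity.decorate(optimizer)
if is_distributed:  # hypothetical flag chosen by the user
    optimizer = fleet.distributed_optimizer(optimizer)
optimizer.minimize(loss, startup_program)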

exe.run(startup_program)

# Must call `exe.run(startup_program)` first before calling `sparsity.prune_model`
sparsity.prune_model(place, main_program, func_name=sparsity.MaskAlgo.MASK_2D_BEST)
Contributor:

The complete example from the PR description could be placed here.

@Xreki Xreki merged commit 8061442 into PaddlePaddle:develop Jun 10, 2021
@mingxu1067 mingxu1067 deleted the automatic_sparsity_functions branch October 15, 2021 06:11