From 02144bcaa61f66aaef3dbacfee4c9b6ba3a53a0d Mon Sep 17 00:00:00 2001 From: whs Date: Tue, 4 Feb 2020 13:28:13 +0800 Subject: [PATCH] Add one-shot NAS API and mnasnet based search space. (#17) --- demo/one_shot/train.py | 207 ++++++++++++++ docs/docs/api/one_shot_api.md | 154 +++++++++++ docs/docs/tutorials/one_shot_nas_demo.md | 102 +++++++ docs/mkdocs.yml | 3 + paddleslim/nas/__init__.py | 10 +- paddleslim/nas/one_shot/__init__.py | 22 ++ paddleslim/nas/one_shot/one_shot_nas.py | 114 ++++++++ paddleslim/nas/one_shot/super_mnasnet.py | 257 ++++++++++++++++++ paddleslim/nas/search_space/__init__.py | 1 - .../nas/search_space/search_space_base.py | 5 + 10 files changed, 871 insertions(+), 4 deletions(-) create mode 100644 demo/one_shot/train.py create mode 100644 docs/docs/api/one_shot_api.md create mode 100644 docs/docs/tutorials/one_shot_nas_demo.md create mode 100644 paddleslim/nas/one_shot/__init__.py create mode 100644 paddleslim/nas/one_shot/one_shot_nas.py create mode 100644 paddleslim/nas/one_shot/super_mnasnet.py diff --git a/demo/one_shot/train.py b/demo/one_shot/train.py new file mode 100644 index 0000000000000..e9904e45143f8 --- /dev/null +++ b/demo/one_shot/train.py @@ -0,0 +1,207 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function +import argparse +import ast +import numpy as np +from PIL import Image +import os +import paddle +import paddle.fluid as fluid +from paddle.fluid.optimizer import AdamOptimizer +from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear +from paddle.fluid.dygraph.base import to_variable + +from paddleslim.nas.one_shot import SuperMnasnet +from paddleslim.nas.one_shot import OneShotSearch + + +def parse_args(): + parser = argparse.ArgumentParser("Training for Mnist.") + parser.add_argument( + "--use_data_parallel", + type=ast.literal_eval, + default=False, + help="The flag indicating whether to use data parallel mode to train the model." 
+ ) + parser.add_argument("-e", "--epoch", default=5, type=int, help="set epoch") + parser.add_argument("--ce", action="store_true", help="run ce") + args = parser.parse_args() + return args + + +class SimpleImgConv(fluid.dygraph.Layer): + def __init__(self, + num_channels, + num_filters, + filter_size, + conv_stride=1, + conv_padding=0, + conv_dilation=1, + conv_groups=1, + act=None, + use_cudnn=False, + param_attr=None, + bias_attr=None): + super(SimpleImgConv, self).__init__() + + self._conv2d = Conv2D( + num_channels=num_channels, + num_filters=num_filters, + filter_size=filter_size, + stride=conv_stride, + padding=conv_padding, + dilation=conv_dilation, + groups=conv_groups, + param_attr=None, + bias_attr=None, + act=act, + use_cudnn=use_cudnn) + + def forward(self, inputs): + x = self._conv2d(inputs) + return x + + +class MNIST(fluid.dygraph.Layer): + def __init__(self): + super(MNIST, self).__init__() + + self._simple_img_conv_pool_1 = SimpleImgConv(1, 20, 2, act="relu") + self.arch = SuperMnasnet( + name_scope="super_net", input_channels=20, out_channels=20) + self._simple_img_conv_pool_2 = SimpleImgConv(20, 50, 2, act="relu") + + self.pool_2_shape = 50 * 13 * 13 + SIZE = 10 + scale = (2.0 / (self.pool_2_shape**2 * SIZE))**0.5 + self._fc = Linear( + self.pool_2_shape, + 10, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=scale)), + act="softmax") + + def forward(self, inputs, label=None, tokens=None): + x = self._simple_img_conv_pool_1(inputs) + + x = self.arch(x, tokens=tokens) # addddddd + x = self._simple_img_conv_pool_2(x) + x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape]) + x = self._fc(x) + if label is not None: + acc = fluid.layers.accuracy(input=x, label=label) + return x, acc + else: + return x + + +def test_mnist(model, tokens=None): + acc_set = [] + avg_loss_set = [] + batch_size = 64 + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size, drop_last=True) + for batch_id, data in enumerate(test_reader()): + dy_x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(batch_size, 1) + + img = to_variable(dy_x_data) + label = to_variable(y_data) + label.stop_gradient = True + prediction, acc = model.forward(img, label, tokens=tokens) + loss = fluid.layers.cross_entropy(input=prediction, label=label) + avg_loss = fluid.layers.mean(loss) + acc_set.append(float(acc.numpy())) + avg_loss_set.append(float(avg_loss.numpy())) + if batch_id % 100 == 0: + print("Test - batch_id: {}".format(batch_id)) + # get test acc and loss + acc_val_mean = np.array(acc_set).mean() + avg_loss_val_mean = np.array(avg_loss_set).mean() + + return acc_val_mean + + +def train_mnist(args, model, tokens=None): + epoch_num = args.epoch + BATCH_SIZE = 64 + + adam = AdamOptimizer( + learning_rate=0.001, parameter_list=model.parameters()) + + train_reader = paddle.batch( + paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True) + if args.use_data_parallel: + train_reader = fluid.contrib.reader.distributed_batch_reader( + train_reader) + + for epoch in range(epoch_num): + for batch_id, data in enumerate(train_reader()): + dy_x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array( + [x[1] for x in data]).astype('int64').reshape(-1, 1) + + img = to_variable(dy_x_data) + label = to_variable(y_data) + label.stop_gradient = True + + cost, acc = model.forward(img, label, 
tokens=tokens) + + loss = fluid.layers.cross_entropy(cost, label) + avg_loss = fluid.layers.mean(loss) + + if args.use_data_parallel: + avg_loss = model.scale_loss(avg_loss) + avg_loss.backward() + model.apply_collective_grads() + else: + avg_loss.backward() + + adam.minimize(avg_loss) + # save checkpoint + model.clear_gradients() + if batch_id % 1 == 0: + print("Loss at epoch {} step {}: {:}".format(epoch, batch_id, + avg_loss.numpy())) + + model.eval() + test_acc = test_mnist(model, tokens=tokens) + model.train() + print("Loss at epoch {} , acc is: {}".format(epoch, test_acc)) + + save_parameters = (not args.use_data_parallel) or ( + args.use_data_parallel and + fluid.dygraph.parallel.Env().local_rank == 0) + if save_parameters: + fluid.save_dygraph(model.state_dict(), "save_temp") + print("checkpoint saved") + + +if __name__ == '__main__': + args = parse_args() + place = fluid.CPUPlace() + with fluid.dygraph.guard(place): + model = MNIST() + # step 1: training super net + #train_mnist(args, model) + # step 2: search + best_tokens = OneShotSearch(model, test_mnist) + # step 3: final training + # train_mnist(args, model, best_tokens) diff --git a/docs/docs/api/one_shot_api.md b/docs/docs/api/one_shot_api.md new file mode 100644 index 0000000000000..133d6321ab640 --- /dev/null +++ b/docs/docs/api/one_shot_api.md @@ -0,0 +1,154 @@ + +## OneShotSearch +paddleslim.nas.one_shot.OneShotSearch(model, eval_func, strategy='sa', search_steps=100)[代码]() + +: 从超级网络中搜索出一个最佳的子网络。 + +**参数:** + +- **model(fluid.dygraph.layer):** 通过在`OneShotSuperNet`前后添加若该模块构建的动态图模块。因为`OneShotSuperNet`是一个超网络,所以`model`也是一个超网络。换句话说,在`model`模块的子模块中,至少有一个是`OneShotSuperNet`的实例。该方法从`model`超网络中搜索得到一个最佳的子网络。超网络`model`需要先被训练,具体细节请参考[OneShotSuperNet]()。 + +- **eval_func:** 用于评估子网络性能的回调函数。该回调函数需要接受`model`为参数,并调用`model`的`forward`方法进行性能评估。 + +- **strategy(str):** 搜索策略的名称。默认为'sa', 当前仅支持'sa'. + +- **search_steps(int):** 搜索轮次数。默认为100。 + +**返回:** + +- **best_tokens:** 表示最佳子网络的编码信息(tokens)。 + +**示例代码:** + +请参考[one-shot NAS示例]() + + +## OneShotSuperNet + +用于`OneShot`搜索策略的超级网络的基类,所有超级网络的实现要继承该类。 + +paddleslim.nas.one_shot.OneShotSuperNet(name_scope) + +: 构造方法。 + +**参数:** + +- **name_scope:(str) **超级网络的命名空间。 + +**返回:** + +- **super_net:** 一个`OneShotSuperNet`实例。 + +init_tokens() + +: 获得当前超级网络的初始化子网络的编码,主要用于搜索。 + +**返回:** + +- **tokens(list):** 一个子网络的编码。 + +range_table() + +: 超级网络中各个子网络由一组整型数字编码表示,该方法返回编码每个位置的取值范围。 + +**返回:** + +- **range_table(tuple):** 子网络编码每一位的取值范围。`range_table`格式为`(min_values, max_values)`,其中,`min_values`为一个整型数组,表示每个编码位置可选取的最小值;`max_values`表示每个编码位置可选取的最大值。 + +_forward_impl(input, tokens) + +: 前向计算函数。`OneShotSuperNet`的子类需要实现该函数。 + +**参数:** + +- **input(Variable):** 超级网络的输入。 + +- **tokens(list):** 执行前向计算所用的子网络的编码。默认为`None`,即随机选取一个子网络执行前向。 + +**返回:** + +- **output(Variable):** 前向计算的输出 + +forward(self, input, tokens=None) + +: 执行前向计算。 + +**参数:** + +- **input(Variable):** 超级网络的输入。 + +- **tokens(list):** 执行前向计算所用的子网络的编码。默认为`None`,即随机选取一个子网络执行前向。 + +**返回:** + +- **output(Variable):** 前向计算的输出 + + +_random_tokens() + +: 随机选取一个子网络,并返回其编码。 + +**返回:** + +- **tokens(list):** 一个子网络的编码。 + +## SuperMnasnet + +在[Mnasnet](https://arxiv.org/abs/1807.11626)基础上修改得到的超级网络, 该类继承自`OneShotSuperNet`. 
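+
+下面是一个最小示意(仅作演示,输入的batch大小与空间尺寸为假设值,需在动态图模式下运行),展示如何单独构造`SuperMnasnet`、查看其子网络编码并用指定`tokens`执行前向计算;由其实现可知,编码长度为36,其中取值不小于6的编码位表示跳过对应的候选层:
+
+```
+import numpy as np
+import paddle.fluid as fluid
+from paddleslim.nas.one_shot import SuperMnasnet
+
+with fluid.dygraph.guard():
+    net = SuperMnasnet(
+        name_scope="super_net", input_channels=20, out_channels=20)
+    tokens = net.init_tokens()  # 长度为36的初始编码
+    min_values, max_values = net.range_table()  # 每一位编码的取值范围
+    data = np.random.random((1, 20, 32, 32)).astype('float32')
+    x = fluid.dygraph.to_variable(data)
+    out = net(x, tokens=tokens)  # 执行指定tokens所对应子网络的前向计算
+```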
+ +paddleslim.nas.one_shot.SuperMnasnet(name_scope, input_channels=3, out_channels=1280, repeat_times=[6, 6, 6, 6, 6, 6], stride=[1, 1, 1, 1, 2, 1], channels=[16, 24, 40, 80, 96, 192, 320], use_auxhead=False) + +: 构造函数。 + +**参数:** + +- **name_scope(str):** 命名空间。 + +- **input_channels(str):** 当前超级网络的输入的特征图的通道数量。 + +- **out_channels(str):** 当前超级网络的输出的特征图的通道数量。 + +- **repeat_times(list):** 每种`block`重复的次数。 + +- **stride(list):** 一种`block`重复堆叠成`repeat_block`,`stride`表示每个`repeat_block`的下采样比例。 + +- **channels(list):** channels[i]和channels[i+1]分别表示第i个`repeat_block`的输入特征图的通道数和输出特征图的通道数。 + +- **use_auxhead(bool):** 是否使用辅助特征图。如果设置为`True`,则`SuperMnasnet`除了返回输出特征图,还还返回辅助特征图。默认为False. + +**返回:** + +- **instance(SuperMnasnet):** 一个`SuperMnasnet`实例 + +**示例:** +``` +import paddle +import paddle.fluid as fluid +class MNIST(fluid.dygraph.Layer): + def __init__(self): + super(MNIST, self).__init__() + self.arch = SuperMnasnet( + name_scope="super_net", input_channels=20, out_channels=20) + self.pool_2_shape = 50 * 13 * 13 + SIZE = 10 + scale = (2.0 / (self.pool_2_shape**2 * SIZE))**0.5 + self._fc = Linear( + self.pool_2_shape, + 10, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=scale)), + act="softmax") + + def forward(self, inputs, label=None, tokens=None): + + x = self.arch(inputs, tokens=tokens) + x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape]) + x = self._fc(x) + if label is not None: + acc = fluid.layers.accuracy(input=x, label=label) + return x, acc + else: + return x + +``` diff --git a/docs/docs/tutorials/one_shot_nas_demo.md b/docs/docs/tutorials/one_shot_nas_demo.md new file mode 100644 index 0000000000000..252b924f0ce52 --- /dev/null +++ b/docs/docs/tutorials/one_shot_nas_demo.md @@ -0,0 +1,102 @@ +# One Shot NAS 示例 + +>该示例依赖Paddle1.7.0或Paddle develop版本。 + +该示例使用MNIST数据,介绍了如何使用PaddleSlim的OneShotNAS接口搜索出一个分类网络。OneShotNAS仅支持动态图,所以该示例完全使用Paddle动态图模式。 + +## 关键代码介绍 + +One-shot网络结构搜索策略包含以下步骤: + +1. 定义超网络 +2. 训练超网络 +3. 基于超网络搜索子网络 +4. 
训练最佳子网络 + +以下按序介绍各个步骤的关键代码。 + +### 定义超级网络 + +按照动态图教程,定义一个分类网络模块,该模块包含4个子模块:`_simple_img_conv_pool_1`,`_simple_img_conv_pool_2`,`super_net`和`fc`,其中`super_net`为`SuperMnasnet`的一个实例。 + +在前向计算过程中,输入图像先后经过子模块`_simple_img_conv_pool_1`、`super_net`、`_simple_img_conv_pool_2`和`fc`的前向计算。 + +代码如下所示: +``` +class MNIST(fluid.dygraph.Layer): + def __init__(self): + super(MNIST, self).__init__() + + self._simple_img_conv_pool_1 = SimpleImgConv(1, 20, 2, act="relu") + self.arch = SuperMnasnet( + name_scope="super_net", input_channels=20, out_channels=20) + self._simple_img_conv_pool_2 = SimpleImgConv(20, 50, 2, act="relu") + + self.pool_2_shape = 50 * 13 * 13 + SIZE = 10 + scale = (2.0 / (self.pool_2_shape**2 * SIZE))**0.5 + self._fc = Linear( + self.pool_2_shape, + 10, + param_attr=fluid.param_attr.ParamAttr( + initializer=fluid.initializer.NormalInitializer( + loc=0.0, scale=scale)), + act="softmax") + + def forward(self, inputs, label=None, tokens=None): + x = self._simple_img_conv_pool_1(inputs) + + x = self.arch(x, tokens=tokens) # addddddd + x = self._simple_img_conv_pool_2(x) + x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape]) + x = self._fc(x) + if label is not None: + acc = fluid.layers.accuracy(input=x, label=label) + return x, acc + else: + return x + +``` + +动态图模块MNIST的forward函数接受一个参数`tokens`,用于指定在前向计算中使用的子网络,如果`tokens`为None,则随机选取一个子网络进行前向计算。 + +### 训练超级网络 + +网络训练的逻辑定义在`train_mnist`函数中,将`tokens`参数设置为None,进行超网络训练,即在每个batch选取一个超网络进行训练。 + +代码如下所示: + +``` +with fluid.dygraph.guard(place): + model = MNIST() + train_mnist(args, model) +``` + +### 搜索最佳子网络 +使用PaddleSlim提供的`OneShotSearch`接口搜索最佳子网络。传入已定义且训练好的超网络实例`model`和一个用于评估子网络的回调函数`test_mnist`. + +代码如下: + +``` +best_tokens = OneShotSearch(model, test_mnist) +``` + +### 训练最佳子网络 + +获得最佳的子网络的编码`best_tokens`后,调用之前定义的`train_mnist`方法进行子网络的训练。代码如下: + +``` +train_mnist(args, model, best_tokens) +``` + +## 启动示例 + +执行以下代码运行示例: + +``` +python train.py +``` + +执行`python train.py --help`查看更多可配置选项。 + +## FAQ diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 5e8e6e6feb614..d5f725a8c3331 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -9,6 +9,7 @@ nav: - 量化训练: tutorials/quant_aware_demo.md - Embedding量化: tutorials/quant_embedding_demo.md - SA搜索: tutorials/nas_demo.md + - One-shot搜索: tutorials/one_shot_nas_demo.md - 搜索空间: search_space.md - 知识蒸馏: tutorials/distillation_demo.md - API: @@ -17,6 +18,8 @@ nav: - 模型分析: api/analysis_api.md - 知识蒸馏: api/single_distiller_api.md - SA搜索: api/nas_api.md + - One-shot搜索: api/one_shot_api.md + - 搜索空间: search_space.md - 硬件延时评估表: table_latency.md - 算法原理: algo/algo.md diff --git a/paddleslim/nas/__init__.py b/paddleslim/nas/__init__.py index c86051a867676..cf8c75ee5cc4c 100644 --- a/paddleslim/nas/__init__.py +++ b/paddleslim/nas/__init__.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from __future__ import absolute_import +from ..nas import search_space from .search_space import * -from .sa_nas import SANAS +from ..nas import sa_nas +from .sa_nas import * -__all__ = ['SANAS'] +__all__ = [] +__all__ += sa_nas.__all__ +__all__ += search_space.__all__ diff --git a/paddleslim/nas/one_shot/__init__.py b/paddleslim/nas/one_shot/__init__.py new file mode 100644 index 0000000000000..e8dfbe57b7455 --- /dev/null +++ b/paddleslim/nas/one_shot/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from ..one_shot import one_shot_nas
+from .one_shot_nas import *
+from ..one_shot import super_mnasnet
+from .super_mnasnet import *
+__all__ = []
+__all__ += one_shot_nas.__all__
+__all__ += super_mnasnet.__all__
diff --git a/paddleslim/nas/one_shot/one_shot_nas.py b/paddleslim/nas/one_shot/one_shot_nas.py
new file mode 100644
index 0000000000000..20e0d64046a07
--- /dev/null
+++ b/paddleslim/nas/one_shot/one_shot_nas.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle.fluid as fluid
+from ...common import SAController
+
+__all__ = ['OneShotSuperNet', 'OneShotSearch']
+
+
+def OneShotSearch(model, eval_func, strategy='sa', search_steps=100):
+    """
+    Search for the best tokens, which represent a sub-network.
+    Args:
+        model(fluid.dygraph.Layer): A dynamic graph module whose sub-modules should contain
+                                    at least one instance of `OneShotSuperNet`.
+        eval_func(function): A callback function which accepts model and tokens as arguments.
+        strategy(str): The name of the strategy used to search. Default: 'sa'.
+        search_steps(int): The total steps for searching.
+    Returns:
+        tokens(list): The best tokens searched.
+    """
+    super_net = None
+    for layer in model.sublayers(include_sublayers=False):
+        print("layer: {}".format(layer))
+        if isinstance(layer, OneShotSuperNet):
+            super_net = layer
+            break
+    assert super_net is not None
+    controller = None
+    if strategy == "sa":
+        controller = SAController(
+            range_table=super_net.range_table(),
+            init_tokens=super_net.init_tokens())
+    assert controller is not None, "Unsupported searching strategy."
+    for i in range(search_steps):
+        tokens = controller.next_tokens()
+        reward = eval_func(model, tokens)
+        controller.update(tokens, reward, i)
+    return controller.best_tokens()
+
+
+class OneShotSuperNet(fluid.dygraph.Layer):
+    """
+    The base class of super nets used in the one-shot searching strategy.
+    A super net is a dygraph layer.
+
+    Args:
+        name_scope(str): The name scope of the super net.
+    """
+
+    def __init__(self, name_scope):
+        super(OneShotSuperNet, self).__init__(name_scope)
+
+    def init_tokens(self):
+        """Get the initial tokens of the search space.
+        Return:
+            tokens(list): The initial tokens, which is a list of integers.
+        """
+        raise NotImplementedError('Abstract method.')
+
+    def range_table(self):
+        """Get the range table of the current search space.
+ Return: + range_table(tuple): The maximum value and minimum value in each position of tokens + with format `(min_values, max_values)`. The `min_values` is + a list of integers indicating the minimum values while `max_values` + indicating the maximum values. + """ + raise NotImplementedError('Abstract method.') + + def _forward_impl(self, *inputs, **kwargs): + """ + Defines the computation performed at every call. + Should be overridden by all subclasses. + Args: + inputs(tuple): unpacked tuple arguments + kwargs(dict): unpacked dict arguments + """ + raise NotImplementedError('Abstract method.') + + def forward(self, input, tokens=None): + """ + Defines the computation performed at every call. + Args: + input(variable): The input of super net. + tokens(list): The tokens used to generate a sub-network. + None means computing in super net training mode. + Otherwise, it will execute the sub-network generated by tokens. + The `tokens` should be set in searching stage and final training stage. + Default: None. + Returns: + output(varaible): The output of super net. + """ + if tokens == None: + tokens = self._random_tokens() + return self._forward_impl(input, tokens=tokens) + + def _random_tokens(self): + tokens = [] + for min_v, max_v in zip(self.range_table()[0], self.range_table()[1]): + tokens.append(np.random.randint(min_v, max_v)) + return tokens diff --git a/paddleslim/nas/one_shot/super_mnasnet.py b/paddleslim/nas/one_shot/super_mnasnet.py new file mode 100644 index 0000000000000..852b40383af52 --- /dev/null +++ b/paddleslim/nas/one_shot/super_mnasnet.py @@ -0,0 +1,257 @@ +import paddle +from paddle import fluid +from paddle.fluid.layer_helper import LayerHelper +import numpy as np +from one_shot_nas import OneShotSuperNet + +__all__ = ['SuperMnasnet'] + + +class DConvBlock(fluid.dygraph.Layer): + def __init__(self, + name_scope, + in_channels, + channels, + expansion, + stride, + kernel_size=3, + padding=1): + super(DConvBlock, self).__init__(name_scope) + self.expansion = expansion + self.in_channels = in_channels + self.channels = channels + self.stride = stride + self.flops = 0 + self.flops_calculated = False + self.expand = fluid.dygraph.Conv2D( + in_channels, + num_filters=in_channels * expansion, + filter_size=1, + stride=1, + padding=0, + act=None, + bias_attr=False) + self.expand_bn = fluid.dygraph.BatchNorm( + num_channels=in_channels * expansion, act='relu6') + + self.dconv = fluid.dygraph.Conv2D( + in_channels * expansion, + num_filters=in_channels * expansion, + filter_size=kernel_size, + stride=stride, + padding=padding, + act=None, + bias_attr=False, + groups=in_channels * expansion, + use_cudnn=False) + self.dconv_bn = fluid.dygraph.BatchNorm( + num_channels=in_channels * expansion, act='relu6') + + self.project = fluid.dygraph.Conv2D( + in_channels * expansion, + num_filters=channels, + filter_size=1, + stride=1, + padding=0, + act=None, + bias_attr=False) + self.project_bn = fluid.dygraph.BatchNorm( + num_channels=channels, act=None) + + self.shortcut = fluid.dygraph.Conv2D( + in_channels, + num_filters=channels, + filter_size=1, + stride=1, + padding=0, + act=None, + bias_attr=False) + self.shortcut_bn = fluid.dygraph.BatchNorm( + num_channels=channels, act=None) + + def get_flops(self, input, output, op): + if not self.flops_calculated: + flops = input.shape[1] * output.shape[1] * ( + op._filter_size**2) * output.shape[2] * output.shape[3] + if op._groups: + flops /= op._groups + self.flops += flops + + def forward(self, inputs): + expand_x = 
self.expand_bn(self.expand(inputs)) + self.get_flops(inputs, expand_x, self.expand) + dconv_x = self.dconv_bn(self.dconv(expand_x)) + self.get_flops(expand_x, dconv_x, self.dconv) + proj_x = self.project_bn(self.project(dconv_x)) + self.get_flops(dconv_x, proj_x, self.project) + if self.in_channels != self.channels and self.stride == 1: + shortcut = self.shortcut_bn(self.shortcut(inputs)) + self.get_flops(inputs, shortcut, self.shortcut) + elif self.stride == 1: + shortcut = inputs + self.flops_calculated = True + if self.stride == 1: + out = fluid.layers.elementwise_add(x=proj_x, y=shortcut) + return out + return proj_x + + +class SearchBlock(fluid.dygraph.Layer): + def __init__(self, + name_scope, + in_channels, + channels, + stride, + kernel_size=3, + padding=1): + super(SearchBlock, self).__init__(name_scope) + self._stride = stride + self.block_list = [] + self.flops = [0 for i in range(10)] + self.flops_calculated = [False if i < 6 else True for i in range(10)] + kernels = [3, 5, 7] + expansions = [3, 6] + for k in kernels: + for e in expansions: + self.block_list.append( + DConvBlock(self.full_name(), in_channels, channels, e, + stride, k, (k - 1) // 2)) + self.add_sublayer("expansion_{}_kernel_{}".format(e, k), + self.block_list[-1]) + + def forward(self, inputs, arch): + if arch >= 6: + return inputs + out = self.block_list[arch](inputs) + if not self.flops_calculated[arch]: + self.flops[arch] = self.block_list[arch].flops + self.flops_calculated[arch] = True + return out + + +class AuxiliaryHead(fluid.dygraph.Layer): + def __init__(self, name_scope, num_classes): + super(AuxiliaryHead, self).__init__(name_scope) + + self.pool1 = fluid.dygraph.Pool2D( + 5, 'avg', pool_stride=3, pool_padding=0) + self.conv1 = fluid.dygraph.Conv2D(128, 1, bias_attr=False) + self.bn1 = fluid.dygraph.BatchNorm(128, act='relu6') + self.conv2 = fluid.dygraph.Conv2D(768, 2, bias_attr=False) + self.bn2 = fluid.dygraph.BatchNorm(768, act='relu6') + self.classifier = fluid.dygraph.FC(num_classes, act='softmax') + self.layer_helper = LayerHelper(self.full_name(), act='relu6') + + def forward(self, inputs): #pylint: disable=arguments-differ + inputs = self.layer_helper.append_activation(inputs) + inputs = self.pool1(inputs) + inputs = self.conv1(inputs) + inputs = self.bn1(inputs) + inputs = self.conv2(inputs) + inputs = self.bn2(inputs) + inputs = self.classifier(inputs) + return inputs + + +class SuperMnasnet(OneShotSuperNet): + def __init__(self, + name_scope, + input_channels=3, + out_channels=1280, + repeat_times=[6, 6, 6, 6, 6, 6], + stride=[1, 1, 1, 1, 2, 1], + channels=[16, 24, 40, 80, 96, 192, 320], + use_auxhead=False): + super(SuperMnasnet, self).__init__(name_scope) + self.flops = 0 + self.repeat_times = repeat_times + self.flops_calculated = False + self.last_tokens = None + self._conv = fluid.dygraph.Conv2D( + input_channels, 32, 3, 1, 1, act=None, bias_attr=False) + self._bn = fluid.dygraph.BatchNorm(32, act='relu6') + self._sep_conv = fluid.dygraph.Conv2D( + 32, + 32, + 3, + 1, + 1, + groups=32, + act=None, + use_cudnn=False, + bias_attr=False) + self._sep_conv_bn = fluid.dygraph.BatchNorm(32, act='relu6') + self._sep_project = fluid.dygraph.Conv2D( + 32, 16, 1, 1, 0, act=None, bias_attr=False) + self._sep_project_bn = fluid.dygraph.BatchNorm(16, act='relu6') + + self._final_conv = fluid.dygraph.Conv2D( + 320, out_channels, 1, 1, 0, act=None, bias_attr=False) + self._final_bn = fluid.dygraph.BatchNorm(out_channels, act='relu6') + self.stride = stride + self.block_list = [] + self.use_auxhead = 
use_auxhead + + for _iter, _stride in enumerate(self.stride): + repeat_block = [] + for _ind in range(self.repeat_times[_iter]): + if _ind == 0: + block = SearchBlock(self.full_name(), channels[_iter], + channels[_iter + 1], _stride) + else: + block = SearchBlock(self.full_name(), channels[_iter + 1], + channels[_iter + 1], 1) + self.add_sublayer("block_{}_{}".format(_iter, _ind), block) + repeat_block.append(block) + self.block_list.append(repeat_block) + if self.use_auxhead: + self.auxhead = AuxiliaryHead(self.full_name(), 10) + + def init_tokens(self): + return [ + 3, 3, 6, 6, 6, 6, 3, 3, 3, 6, 6, 6, 3, 3, 3, 3, 6, 6, 3, 3, 3, 6, + 6, 6, 3, 3, 3, 6, 6, 6, 3, 6, 6, 6, 6, 6 + ] + + def range_table(self): + max_v = [ + 6, 6, 10, 10, 10, 10, 6, 6, 6, 10, 10, 10, 6, 6, 6, 6, 10, 10, 6, + 6, 6, 10, 10, 10, 6, 6, 6, 10, 10, 10, 6, 10, 10, 10, 10, 10 + ] + return (len(max_v) * [0], max_v) + + def get_flops(self, input, output, op): + if not self.flops_calculated: + flops = input.shape[1] * output.shape[1] * ( + op._filter_size**2) * output.shape[2] * output.shape[3] + if op._groups: + flops /= op._groups + self.flops += flops + + def _forward_impl(self, inputs, tokens=None): + if isinstance(tokens, np.ndarray) and not (tokens == self.last_tokens).all()\ + or not isinstance(tokens, np.ndarray) and not tokens == self.last_tokens: + self.flops_calculated = False + self.flops = 0 + self.last_tokens = tokens + x = self._bn(self._conv(inputs)) + self.get_flops(inputs, x, self._conv) + sep_x = self._sep_conv_bn(self._sep_conv(x)) + self.get_flops(x, sep_x, self._sep_conv) + proj_x = self._sep_project_bn(self._sep_project(sep_x)) + self.get_flops(sep_x, proj_x, self._sep_project) + x = proj_x + for ind in range(len(self.block_list)): + for b_ind, block in enumerate(self.block_list[ind]): + x = fluid.layers.dropout(block(x, tokens[ind * 6 + b_ind]), 0.) + if not self.flops_calculated: + self.flops += block.flops[tokens[ind * 6 + b_ind]] + if ind == len(self.block_list) * 2 // 3 - 1 and self.use_auxhead: + fc_aux = self.auxhead(x) + final_x = self._final_bn(self._final_conv(x)) + self.get_flops(x, final_x, self._final_conv) + # x = self.global_pooling(final_x) + self.flops_calculated = True + if self.use_auxhead: + return final_x, fc_aux + return final_x diff --git a/paddleslim/nas/search_space/__init__.py b/paddleslim/nas/search_space/__init__.py index 9556c61917406..bd8a3d3141dd3 100644 --- a/paddleslim/nas/search_space/__init__.py +++ b/paddleslim/nas/search_space/__init__.py @@ -21,7 +21,6 @@ from .search_space_registry import SEARCHSPACE from .search_space_factory import SearchSpaceFactory from .search_space_base import SearchSpaceBase - __all__ = [ 'MobileNetV1Space', 'MobileNetV2Space', 'ResNetSpace', 'MobileNetV1BlockSpace', 'MobileNetV2BlockSpace', 'ResNetBlockSpace', diff --git a/paddleslim/nas/search_space/search_space_base.py b/paddleslim/nas/search_space/search_space_base.py index 9dee1431d34af..af4d4a1d6a25c 100644 --- a/paddleslim/nas/search_space/search_space_base.py +++ b/paddleslim/nas/search_space/search_space_base.py @@ -19,6 +19,7 @@ _logger = get_logger(__name__, level=logging.INFO) + class SearchSpaceBase(object): """Controller for Neural Architecture Search. """ @@ -56,3 +57,7 @@ def token2arch(self, tokens): model arch """ raise NotImplementedError('Abstract method.') + + def super_net(self): + """This function is just used in one shot NAS strategy. Return a super graph.""" + raise NotImplementedError('Abstract method.')
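
For reference, below is a minimal, illustrative sketch of how a subclass of `OneShotSuperNet` provides `init_tokens`, `range_table` and `_forward_impl`. It is not part of this patch; the class name, channel count and the tiny two-position search space are assumptions made only for the example:

```
import paddle.fluid as fluid
from paddleslim.nas.one_shot import OneShotSuperNet


class TinySuperNet(OneShotSuperNet):
    """A toy super net: each of two positions picks one kernel size from (3, 5, 7)."""

    def __init__(self, name_scope, channels=16):
        super(TinySuperNet, self).__init__(name_scope)
        self._kernels = [3, 5, 7]
        self._branches = []
        for i in range(2):
            convs = []
            for k in self._kernels:
                conv = fluid.dygraph.Conv2D(
                    channels, channels, k, padding=(k - 1) // 2, act='relu')
                self.add_sublayer("conv_{}_{}".format(i, k), conv)
                convs.append(conv)
            self._branches.append(convs)

    def init_tokens(self):
        # Start the search from the 3x3 kernel at every position.
        return [0, 0]

    def range_table(self):
        # Each position chooses one of the three kernels: values in [0, 3).
        return ([0, 0], [len(self._kernels), len(self._kernels)])

    def _forward_impl(self, input, tokens=None):
        x = input
        for i, t in enumerate(tokens):
            x = self._branches[i][t](x)
        return x
```

Calling `forward` with `tokens=None` executes a randomly sampled sub-network, which is how the super net is trained; afterwards the model containing it can be passed to `OneShotSearch` together with an evaluation callback to search over the tokens.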