Add one-shot NAS API and mnasnet based search space. (PaddlePaddle#17)
wanghaoshuang committed Feb 4, 2020
1 parent 1664a75 commit 02144bc
Showing 10 changed files with 871 additions and 4 deletions.
207 changes: 207 additions & 0 deletions demo/one_shot/train.py
@@ -0,0 +1,207 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import argparse
import ast
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.dygraph.nn import Conv2D, Linear
from paddle.fluid.dygraph.base import to_variable

from paddleslim.nas.one_shot import SuperMnasnet
from paddleslim.nas.one_shot import OneShotSearch


def parse_args():
    parser = argparse.ArgumentParser("Training for MNIST.")
parser.add_argument(
"--use_data_parallel",
type=ast.literal_eval,
default=False,
help="The flag indicating whether to use data parallel mode to train the model."
)
parser.add_argument("-e", "--epoch", default=5, type=int, help="set epoch")
parser.add_argument("--ce", action="store_true", help="run ce")
args = parser.parse_args()
return args


class SimpleImgConv(fluid.dygraph.Layer):
def __init__(self,
num_channels,
num_filters,
filter_size,
conv_stride=1,
conv_padding=0,
conv_dilation=1,
conv_groups=1,
act=None,
use_cudnn=False,
param_attr=None,
bias_attr=None):
super(SimpleImgConv, self).__init__()

self._conv2d = Conv2D(
num_channels=num_channels,
num_filters=num_filters,
filter_size=filter_size,
stride=conv_stride,
padding=conv_padding,
dilation=conv_dilation,
groups=conv_groups,
            param_attr=param_attr,
            bias_attr=bias_attr,
act=act,
use_cudnn=use_cudnn)

def forward(self, inputs):
x = self._conv2d(inputs)
return x


class MNIST(fluid.dygraph.Layer):
def __init__(self):
super(MNIST, self).__init__()

self._simple_img_conv_pool_1 = SimpleImgConv(1, 20, 2, act="relu")
self.arch = SuperMnasnet(
name_scope="super_net", input_channels=20, out_channels=20)
self._simple_img_conv_pool_2 = SimpleImgConv(20, 50, 2, act="relu")

self.pool_2_shape = 50 * 13 * 13
SIZE = 10
scale = (2.0 / (self.pool_2_shape**2 * SIZE))**0.5
self._fc = Linear(
self.pool_2_shape,
10,
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.NormalInitializer(
loc=0.0, scale=scale)),
act="softmax")

def forward(self, inputs, label=None, tokens=None):
x = self._simple_img_conv_pool_1(inputs)

        x = self.arch(x, tokens=tokens)  # one-shot supernet block
x = self._simple_img_conv_pool_2(x)
x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape])
x = self._fc(x)
if label is not None:
acc = fluid.layers.accuracy(input=x, label=label)
return x, acc
else:
return x


def test_mnist(model, tokens=None):
acc_set = []
avg_loss_set = []
batch_size = 64
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size, drop_last=True)
for batch_id, data in enumerate(test_reader()):
dy_x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(batch_size, 1)

img = to_variable(dy_x_data)
label = to_variable(y_data)
label.stop_gradient = True
        prediction, acc = model(img, label, tokens=tokens)
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
acc_set.append(float(acc.numpy()))
avg_loss_set.append(float(avg_loss.numpy()))
if batch_id % 100 == 0:
print("Test - batch_id: {}".format(batch_id))
# get test acc and loss
acc_val_mean = np.array(acc_set).mean()
avg_loss_val_mean = np.array(avg_loss_set).mean()

return acc_val_mean


def train_mnist(args, model, tokens=None):
epoch_num = args.epoch
BATCH_SIZE = 64

adam = AdamOptimizer(
learning_rate=0.001, parameter_list=model.parameters())

train_reader = paddle.batch(
paddle.dataset.mnist.train(), batch_size=BATCH_SIZE, drop_last=True)
if args.use_data_parallel:
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)

for epoch in range(epoch_num):
for batch_id, data in enumerate(train_reader()):
dy_x_data = np.array([x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)

img = to_variable(dy_x_data)
label = to_variable(y_data)
label.stop_gradient = True

            cost, acc = model(img, label, tokens=tokens)

loss = fluid.layers.cross_entropy(cost, label)
avg_loss = fluid.layers.mean(loss)

if args.use_data_parallel:
avg_loss = model.scale_loss(avg_loss)
avg_loss.backward()
model.apply_collective_grads()
else:
avg_loss.backward()

adam.minimize(avg_loss)
            # clear gradients after each optimizer step
            model.clear_gradients()
            if batch_id % 100 == 0:
                print("Loss at epoch {} step {}: {}".format(
                    epoch, batch_id, avg_loss.numpy()))

model.eval()
test_acc = test_mnist(model, tokens=tokens)
model.train()
print("Loss at epoch {} , acc is: {}".format(epoch, test_acc))

save_parameters = (not args.use_data_parallel) or (
args.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters:
fluid.save_dygraph(model.state_dict(), "save_temp")
print("checkpoint saved")


if __name__ == '__main__':
args = parse_args()
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
model = MNIST()
        # step 1: train the supernet
        train_mnist(args, model)
        # step 2: search for the best sub-network
        best_tokens = OneShotSearch(model, test_mnist)
        # step 3: final training with the searched tokens
        train_mnist(args, model, best_tokens)
154 changes: 154 additions & 0 deletions docs/docs/api/one_shot_api.md
@@ -0,0 +1,154 @@

## OneShotSearch
paddleslim.nas.one_shot.OneShotSearch(model, eval_func, strategy='sa', search_steps=100)[source]()

: Searches a supernet for the best sub-network.

**Parameters:**

- **model(fluid.dygraph.Layer):** A dygraph module built by adding modules before and/or after an `OneShotSuperNet`. Since `OneShotSuperNet` is a supernet, `model` is a supernet as well; in other words, at least one submodule of `model` is an instance of `OneShotSuperNet`. This function searches the supernet `model` for its best sub-network. The supernet `model` must be trained first; see [OneShotSuperNet]() for details.

- **eval_func:** A callback used to evaluate the performance of a sub-network. It receives `model` as an argument and calls the `forward` method of `model` to run the evaluation.

- **strategy(str):** Name of the search strategy. Default: 'sa'; currently only 'sa' (simulated annealing) is supported.

- **search_steps(int):** Number of search steps. Default: 100.

**Returns:**

- **best_tokens:** The encoding (tokens) of the best sub-network found.

**Example:**

See the [one-shot NAS demo]().
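
For illustration, a minimal evaluation callback and search call might look like the sketch below (assuming a trained supernet `model`; `eval_batches()` is a hypothetical iterator over dygraph image/label variables, and the demo in this commit uses its `test_mnist` function in this role):

```python
import numpy as np

def eval_func(model, tokens=None):
    # Score the sub-network selected by `tokens`; higher is better.
    accs = []
    for img, label in eval_batches():  # hypothetical data iterator
        _, acc = model(img, label, tokens=tokens)
        accs.append(float(acc.numpy()))
    return np.mean(accs)

best_tokens = OneShotSearch(model, eval_func)
```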


## OneShotSuperNet

The base class of supernets used by the `OneShot` search strategy; every supernet implementation should inherit from it.

paddleslim.nas.one_shot.OneShotSuperNet(name_scope)

: Constructor.

**Parameters:**

- **name_scope(str):** Name scope of the supernet.

**Returns:**

- **super_net:** An `OneShotSuperNet` instance.

init_tokens()

: Returns the encoding of this supernet's initial sub-network, mainly used to initialize the search.

**Returns:**

- **tokens(list<int>):** The encoding of a sub-network.

range_table()

: Each sub-network of the supernet is encoded as a list of integers; this method returns the range of valid values at each position of the encoding.

**Returns:**

- **range_table(tuple):** The value range of each position in the sub-network encoding, in the format `(min_values, max_values)`, where `min_values` is a list of integers giving the minimum value each position may take, and `max_values` gives the corresponding maximum values.
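
As an illustration (the values here are made up, not taken from the library), a supernet whose sub-networks are encoded by three tokens could implement:

```python
def range_table(self):
    # (min_values, max_values): bounds for each of the three
    # token positions of this hypothetical supernet.
    return ([0, 0, 0], [6, 6, 6])
```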

_forward_impl(input, tokens)

: The forward function. Subclasses of `OneShotSuperNet` must implement this method.

**Parameters:**

- **input(Variable):** The input of the supernet.

- **tokens(list<int>):** The encoding of the sub-network to use for the forward pass. Default: `None`, meaning a sub-network is sampled at random.

**Returns:**

- **output(Variable):** The output of the forward pass.

forward(self, input, tokens=None)

: Runs the forward pass.

**Parameters:**

- **input(Variable):** The input of the supernet.

- **tokens(list<int>):** The encoding of the sub-network to use for the forward pass. Default: `None`, meaning a sub-network is sampled at random.

**Returns:**

- **output(Variable):** The output of the forward pass.


_random_tokens()

: Samples a sub-network at random and returns its encoding.

**Returns:**

- **tokens(list<int>):** The encoding of a sub-network.
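
Putting these methods together, a minimal custom supernet might be sketched as follows. This is illustrative only: the fixed 20-channel feature maps and the exclusive upper bound assumed in `range_table` are assumptions, and `SuperMnasnet` below is the implementation that actually ships with this commit.

```python
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Conv2D
from paddleslim.nas.one_shot import OneShotSuperNet

class TinySuperNet(OneShotSuperNet):
    """A one-token supernet choosing between a 3x3 and a 5x5 conv."""

    def __init__(self, name_scope):
        super(TinySuperNet, self).__init__(name_scope)
        self._conv3 = Conv2D(20, 20, 3, padding=1, act="relu")
        self._conv5 = Conv2D(20, 20, 5, padding=2, act="relu")

    def init_tokens(self):
        # Start the search from the 3x3 branch.
        return [0]

    def range_table(self):
        # One token position; value 0 or 1 selects a branch.
        return ([0], [2])

    def _forward_impl(self, input, tokens):
        # `forward` samples or receives `tokens` and delegates here.
        branch = self._conv3 if tokens[0] == 0 else self._conv5
        return branch(input)
```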

## SuperMnasnet

A supernet derived from [Mnasnet](https://arxiv.org/abs/1807.11626); this class inherits from `OneShotSuperNet`.

paddleslim.nas.one_shot.SuperMnasnet(name_scope, input_channels=3, out_channels=1280, repeat_times=[6, 6, 6, 6, 6, 6], stride=[1, 1, 1, 1, 2, 1], channels=[16, 24, 40, 80, 96, 192, 320], use_auxhead=False)

: Constructor.

**Parameters:**

- **name_scope(str):** Name scope.

- **input_channels(int):** Number of channels of the supernet's input feature map.

- **out_channels(int):** Number of channels of the supernet's output feature map.

- **repeat_times(list):** The number of times each kind of `block` is repeated.

- **stride(list):** Blocks of one kind are stacked into a `repeat_block`; `stride` gives the downsampling ratio of each `repeat_block`.

- **channels(list):** `channels[i]` and `channels[i+1]` are the numbers of input and output channels of the i-th `repeat_block`, respectively.

- **use_auxhead(bool):** Whether to produce an auxiliary feature map. If `True`, `SuperMnasnet` returns an auxiliary feature map in addition to its output feature map. Default: False.

**Returns:**

- **instance(SuperMnasnet):** A `SuperMnasnet` instance.

**Example:**
```python
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.nn import Linear
from paddleslim.nas.one_shot import SuperMnasnet

class MNIST(fluid.dygraph.Layer):
    def __init__(self):
        super(MNIST, self).__init__()
        self.arch = SuperMnasnet(
            name_scope="super_net", input_channels=20, out_channels=20)
        self.pool_2_shape = 50 * 13 * 13
        SIZE = 10
        scale = (2.0 / (self.pool_2_shape**2 * SIZE))**0.5
        self._fc = Linear(
            self.pool_2_shape,
            10,
            param_attr=fluid.param_attr.ParamAttr(
                initializer=fluid.initializer.NormalInitializer(
                    loc=0.0, scale=scale)),
            act="softmax")

    def forward(self, inputs, label=None, tokens=None):
        x = self.arch(inputs, tokens=tokens)
        x = fluid.layers.reshape(x, shape=[-1, self.pool_2_shape])
        x = self._fc(x)
        if label is not None:
            acc = fluid.layers.accuracy(input=x, label=label)
            return x, acc
        else:
            return x
```
