fix bug for OFA (PaddlePaddle#464)
* fix bugs for ernie
ceci3 committed Nov 21, 2020
1 parent c6fdcc3 commit 9f43bbc
Showing 6 changed files with 354 additions and 84 deletions.
38 changes: 22 additions & 16 deletions demo/one_shot/ofa_train.py
@@ -14,14 +14,14 @@

import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.dygraph.nn as nn
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import ReLU
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig
from paddleslim.nas.ofa import supernet


class Model(fluid.dygraph.Layer):
class Model(nn.Layer):
def __init__(self):
super(Model, self).__init__()
with supernet(
@@ -50,18 +50,20 @@ def forward(self, inputs, label, depth=None):

for idx, layer in enumerate(models):
if idx == 6:
inputs = fluid.layers.flatten(inputs, 1)
inputs = paddle.flatten(inputs, 1)
inputs = layer(inputs)

inputs = fluid.layers.softmax(inputs)
inputs = F.softmax(inputs)
return inputs


def test_ofa():

model = Model()
teacher_model = Model()

default_run_config = {
'train_batch_size': 256,
'eval_batch_size': 64,
'n_epochs': [[1], [2, 3], [4, 5]],
'init_learning_rate': [[0.001], [0.003, 0.001], [0.003, 0.001]],
'dynamic_batch_size': [1, 1, 1],
@@ -72,42 +72,46 @@ def test_ofa():

default_distill_config = {
'lambda_distill': 0.01,
'teacher_model': Model,
'teacher_model': teacher_model,
'mapping_layers': ['models.0.fn']
}
distill_config = DistillConfig(**default_distill_config)

fluid.enable_dygraph()
model = Model()
ofa_model = OFA(model, run_config, distill_config=distill_config)

train_reader = paddle.fluid.io.batch(
paddle.dataset.mnist.train(), batch_size=256, drop_last=True)
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend='cv2', transform=transform)
train_loader = paddle.io.DataLoader(
train_dataset,
places=place,
feed_list=[image, label],
drop_last=True,
batch_size=64)

start_epoch = 0
for idx in range(len(run_config.n_epochs)):
cur_idx = run_config.n_epochs[idx]
for ph_idx in range(len(cur_idx)):
cur_lr = run_config.init_learning_rate[idx][ph_idx]
adam = fluid.optimizer.Adam(
adam = paddle.optimizer.Adam(
learning_rate=cur_lr,
parameter_list=(ofa_model.parameters() + ofa_model.netAs_param))
for epoch_id in range(start_epoch,
run_config.n_epochs[idx][ph_idx]):
for batch_id, data in enumerate(train_reader()):
for batch_id, data in enumerate(train_loader()):
dy_x_data = np.array(
[x[0].reshape(1, 28, 28)
for x in data]).astype('float32')
y_data = np.array(
[x[1] for x in data]).astype('int64').reshape(-1, 1)

img = fluid.dygraph.to_variable(dy_x_data)
label = fluid.dygraph.to_variable(y_data)
img = paddle.dygraph.to_variable(dy_x_data)
label = paddle.dygraph.to_variable(y_data)
label.stop_gradient = True

for model_no in range(run_config.dynamic_batch_size[idx]):
output, _ = ofa_model(img, label)
loss = fluid.layers.reduce_mean(output)
loss = F.mean(output)
dis_loss = ofa_model.calc_distill_loss()
loss += dis_loss
loss.backward()
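
Note on the migration above: the demo is moved from the `paddle.fluid` API to paddle 2.x dygraph calls. Below is a minimal, self-contained sketch of that 2.x style using a plain MNIST classifier rather than the OFA model; the names `Net`, `net`, and `loader` are illustrative and not part of this commit.

# Sketch of the paddle 2.x-style calls the demo switches to
# (fluid.io.batch -> paddle.io.DataLoader,
#  fluid.optimizer.Adam(parameter_list=...) -> paddle.optimizer.Adam(parameters=...),
#  fluid.layers.reduce_mean -> paddle.mean, fluid.layers.flatten -> paddle.flatten).
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.vision.datasets import MNIST
from paddle.vision.transforms import ToTensor


class Net(nn.Layer):
    """Illustrative stand-in for the demo's Model; not the committed code."""

    def __init__(self):
        super(Net, self).__init__()
        self.fc = nn.Linear(784, 10)

    def forward(self, x):
        x = paddle.flatten(x, 1)  # replaces fluid.layers.flatten(x, 1)
        return self.fc(x)


net = Net()
loader = paddle.io.DataLoader(
    MNIST(mode='train', transform=ToTensor()),
    batch_size=64,
    shuffle=True,
    drop_last=True)
adam = paddle.optimizer.Adam(learning_rate=0.001, parameters=net.parameters())

for img, label in loader:
    logits = net(img)
    # per-sample losses averaged with paddle.mean, replacing fluid.layers.reduce_mean
    loss = paddle.mean(F.cross_entropy(logits, label, reduction='none'))
    loss.backward()
    adam.step()
    adam.clear_grad()
    break  # one step is enough for the sketch
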
1 change: 1 addition & 0 deletions paddleslim/nas/__init__.py
@@ -19,6 +19,7 @@
from .rl_nas import *
from ..nas import darts
from .darts import *
from .ofa import *

__all__ = []
__all__ += sa_nas.__all__
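
Effect of the added re-export (a sketch, assuming `OFA` is a public name in `paddleslim.nas.ofa`, as the demo's imports suggest): the OFA classes become reachable from the parent `paddleslim.nas` package as well.

# Both import styles work once `from .ofa import *` is added to paddleslim/nas/__init__.py:
from paddleslim.nas.ofa import OFA, RunConfig, DistillConfig  # as used in the demo above
from paddleslim.nas import OFA                                # also available after this change
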
88 changes: 81 additions & 7 deletions paddleslim/nas/ofa/convert_super.py
@@ -16,17 +16,16 @@
import decorator
import logging
import paddle
import paddle.fluid as fluid
from paddle.fluid import framework
from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm
import numbers
from paddle.fluid.dygraph.nn import Conv2D, Conv2DTranspose, Linear, BatchNorm, InstanceNorm, LayerNorm, Embedding
from .layers import *
from ...common import get_logger

_logger = get_logger(__name__, level=logging.INFO)

__all__ = ['supernet']

WEIGHT_LAYER = ['conv', 'linear']
WEIGHT_LAYER = ['conv', 'linear', 'embedding']


### TODO: add decorator
@@ -45,7 +44,7 @@ def convert(self, model):
cur_channel = None
for idx, layer in enumerate(model):
cls_name = layer.__class__.__name__.lower()
if 'conv' in cls_name or 'linear' in cls_name:
if 'conv' in cls_name or 'linear' in cls_name or 'embedding' in cls_name:
weight_layer_count += 1
last_weight_layer_idx = idx
if first_weight_layer_idx == -1:
@@ -63,7 +62,7 @@

new_attr_name = [
'_stride', '_dilation', '_groups', '_param_attr',
'_bias_attr', '_use_cudnn', '_act', '_dtype'
'_bias_attr', '_use_cudnn', '_act', '_dtype', '_padding'
]

new_attr_dict = dict()
@@ -179,6 +178,8 @@ def convert(self, model):
layer._parameters['weight'].shape[0])
elif self.context.channel:
new_attr_dict['num_channels'] = max(cur_channel)
else:
new_attr_dict['num_channels'] = attr_dict['_num_channels']

for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
@@ -196,7 +197,8 @@

new_attr_name = [
'_stride', '_dilation', '_groups', '_param_attr',
'_bias_attr', '_use_cudnn', '_act', '_dtype', '_output_size'
'_padding', '_bias_attr', '_use_cudnn', '_act', '_dtype',
'_output_size'
]
assert attr_dict[
'_filter_size'] != None, "Conv2DTranspose only support filter size != None now"
@@ -371,6 +373,8 @@ def convert(self, model):
layer._parameters['scale'].shape[0])
elif self.context.channel:
new_attr_dict['num_channels'] = max(cur_channel)
else:
new_attr_dict['num_channels'] = attr_dict['_num_channels']

for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]
@@ -380,6 +384,76 @@ def convert(self, model):
layer = SuperInstanceNorm(**new_attr_dict)
model[idx] = layer

elif isinstance(layer, LayerNorm) and (
getattr(self.context, 'expand', None) != None or
getattr(self.context, 'channel', None) != None):
### TODO(ceci3): fix when normalized_shape != last_dim_of_input
if idx > last_weight_layer_idx:
continue

attr_dict = layer.__dict__
new_attr_name = [
'_scale', '_shift', '_param_attr', '_bias_attr', '_act',
'_dtype', '_epsilon'
]
new_attr_dict = dict()
if self.context.expand:
new_attr_dict[
'normalized_shape'] = self.context.expand * int(
attr_dict['_normalized_shape'][0])
elif self.context.channel:
new_attr_dict['normalized_shape'] = max(cur_channel)
else:
new_attr_dict['normalized_shape'] = attr_dict[
'_normalized_shape']

for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]

del layer, attr_dict
layer = SuperLayerNorm(**new_attr_dict)
model[idx] = layer

elif isinstance(layer, Embedding) and (
getattr(self.context, 'expand', None) != None or
getattr(self.context, 'channel', None) != None):
attr_dict = layer.__dict__
key = attr_dict['_full_name']
new_attr_name = [
'_is_sparse', '_is_distributed', '_padding_idx',
'_param_attr', '_dtype'
]

new_attr_dict = dict()
new_attr_dict['candidate_config'] = dict()
bef_size = attr_dict['_size']
if self.context.expand:
new_attr_dict['size'] = [
bef_size[0], self.context.expand * bef_size[1]
]
new_attr_dict['candidate_config'].update({
'expand_ratio': self.context.expand_ratio
})

elif self.context.channel:
cur_channel = self.context.channel[0]
self.context.channel = self.context.channel[1:]
new_attr_dict['size'] = [bef_size[0], max(cur_channel)]
new_attr_dict['candidate_config'].update({
'channel': cur_channel
})
pre_channel = cur_channel
else:
new_attr_dict['size'] = bef_size

for attr in new_attr_name:
new_attr_dict[attr[1:]] = attr_dict[attr]

del layer, attr_dict

layer = Block(SuperEmbedding(**new_attr_dict), key=key)
model[idx] = layer

return model


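
The new `SuperLayerNorm` and `SuperEmbedding` branches size their super layers the same way the existing Conv2D/Linear branches do. A small sketch of that arithmetic, with illustrative numbers and assuming `expand` holds the largest configured expand ratio (not code from the commit):

# Sizing rules added by this diff when converting LayerNorm and Embedding layers.
expand = 2

# LayerNorm: the single-dim normalized_shape is scaled by the expand ratio.
normalized_shape = 768
super_layernorm_shape = expand * normalized_shape        # -> 1536

# Embedding: only the embedding dim (size[1]) is scaled; the vocab size (size[0]) is kept.
vocab_size, emb_dim = 30522, 768
super_embedding_size = [vocab_size, expand * emb_dim]    # -> [30522, 1536]

# In channel-search mode, max(cur_channel) is used in place of the expanded width,
# and when neither expand nor channel is set, the original shape is kept.
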
[Diffs for the remaining 3 changed files are not shown.]
