Skip to content

Commit

Permalink
SSD support in NNVM (apache#1214)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinthesun authored and tqchen committed Jun 14, 2018
1 parent 238a503 commit 42520a2
Show file tree
Hide file tree
Showing 12 changed files with 731 additions and 84 deletions.
49 changes: 49 additions & 0 deletions nnvm/include/nnvm/top/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,55 @@ struct LayoutTransformParam : public dmlc::Parameter<LayoutTransformParam> {
}
};

struct MultiBoxPriorParam : public dmlc::Parameter<MultiBoxPriorParam> {
Tuple<float> sizes;
Tuple<float> ratios;
Tuple<float> steps;
Tuple<float> offsets;
bool clip;

DMLC_DECLARE_PARAMETER(MultiBoxPriorParam) {
DMLC_DECLARE_FIELD(sizes).set_default(Tuple<float>({1.0}))
.describe("List of sizes of generated MultiBoxPriores.");
DMLC_DECLARE_FIELD(ratios).set_default(Tuple<float>({1.0}))
.describe("List of aspect ratios of generated MultiBoxPriores.");
DMLC_DECLARE_FIELD(steps).set_default(Tuple<float>({-1.0, -1.0}))
.describe("Priorbox step across y and x, -1 for auto calculation.");
DMLC_DECLARE_FIELD(offsets).set_default(Tuple<float>({0.5, 0.5}))
.describe("Priorbox center offsets, y and x respectively.");
DMLC_DECLARE_FIELD(clip).set_default(false)
.describe("Whether to clip out-of-boundary boxes.");
}
};

struct MultiBoxTransformLocParam : public dmlc::Parameter<MultiBoxTransformLocParam> {
bool clip;
float threshold;
Tuple<float> variances;
DMLC_DECLARE_PARAMETER(MultiBoxTransformLocParam) {
DMLC_DECLARE_FIELD(clip).set_default(true)
.describe("Clip out-of-boundary boxes.");
DMLC_DECLARE_FIELD(threshold).set_default(0.01)
.describe("Threshold to be a positive prediction.");
DMLC_DECLARE_FIELD(variances).set_default(Tuple<float>{0.1, 0.1, 0.2, 0.2})
.describe("Variances to be decoded from box regression output.");
}
};

struct NMSParam : public dmlc::Parameter<NMSParam> {
float nms_threshold;
bool force_suppress;
int nms_topk;
DMLC_DECLARE_PARAMETER(NMSParam) {
DMLC_DECLARE_FIELD(nms_threshold).set_default(0.5)
.describe("Non-maximum suppression threshold.");
DMLC_DECLARE_FIELD(force_suppress).set_default(false)
.describe("Suppress all detections regardless of class_id.");
DMLC_DECLARE_FIELD(nms_topk).set_default(-1)
.describe("Keep maximum top k detections before nms, -1 for no limit.");
}
};

} // namespace top
} // namespace nnvm

Expand Down
32 changes: 31 additions & 1 deletion nnvm/python/nnvm/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _batch_norm(inputs, attrs):
new_attrs['axis'] = attrs.get('axis', 1)
new_attrs['epsilon'] = attrs.get('eps', 0.001)
new_attrs['center'] = True
new_attrs['scale'] = True
new_attrs['scale'] = not _parse_bool_str(attrs, 'fix_gamma', default="False")
return _get_nnvm_op(op_name)(*inputs, **new_attrs)

def _concat(inputs, attrs):
Expand Down Expand Up @@ -195,6 +195,12 @@ def _split(inputs, attrs):
new_attrs['axis'] = attrs.get('axis', 1)
return _get_nnvm_op(op_name)(*inputs, **new_attrs)

def _softmax_activation(inputs, attrs):
op_name, new_attrs = 'softmax', {}
mode = attrs.get('mode', 'instance')
new_attrs['axis'] = 0 if mode == 'instance' else 1
return _get_nnvm_op(op_name)(inputs[0], **new_attrs)

def _softmax_output(inputs, attrs):
op_name, new_attrs = 'softmax', {}
if _parse_bool_str(attrs, 'multi_output'):
Expand All @@ -212,6 +218,25 @@ def _clip(inputs, attrs):
new_attrs['a_max'] = _required_attr(attrs, 'a_max')
return _get_nnvm_op(op_name)(*inputs, **new_attrs)

def _contrib_multibox_detection(inputs, attrs):
clip = _parse_bool_str(attrs, 'clip', default='True')
threshold = attrs.get('threshold') or 0.01
nms_threshold = attrs.get('nms_threshold') or 0.5
force_suppress = _parse_bool_str(attrs, 'force_suppress', default='False')
variances = tuple([float(x.strip()) for x in attrs.get('variances').strip('()').split(',')]) \
if attrs.get('variances') is not None else (0.1, 0.1, 0.2, 0.2)
nms_topk = attrs.get('nms_topk') or -1
new_attrs0 = {'clip': clip, 'threshold': float(threshold), 'variances': variances}
new_attrs1 = {'nms_threshold': float(nms_threshold), 'force_suppress': force_suppress,
'nms_topk': int(nms_topk)}
data, valid_count = _get_nnvm_op('multibox_transform_loc')(inputs[0], inputs[1],
inputs[2], **new_attrs0)
return _get_nnvm_op('nms')(data, valid_count, **new_attrs1)

def _elemwise_sum(inputs, _):
new_attrs = {'num_args':len(inputs)}
return _get_nnvm_op('elemwise_sum')(*inputs, **new_attrs)


_identity_list = ['__add_scalar__', '__add_symbol__', '__div_scalar__',
'__div_symbol__', '__mul_scalar__', '__mul_symbol__',
Expand All @@ -224,12 +249,15 @@ def _clip(inputs, attrs):
'relu', 'sigmoid', 'softmax', 'sum', 'tanh', 'transpose']

_convert_map = {
'_copy' : _rename('copy'),
'_div_scalar' : _rename('__div_scalar__'),
'_minus_scalar' : _rename('__sub_scalar__'),
'_mul_scalar' : _rename('__mul_scalar__'),
'_plus_scalar' : _rename('__add_scalar__'),
'_rdiv_scalar' : _rename('__rdiv_scalar__'),
'_rminus_scalar': _rename('__rsub_scalar__'),
'_contrib_MultiBoxPrior' : _rename('multibox_prior'),
'_contrib_MultiBoxDetection' : _contrib_multibox_detection,
'Activation' : _activations,
'BatchNorm' : _batch_norm,
'BatchNorm_v1' : _batch_norm,
Expand All @@ -248,7 +276,9 @@ def _clip(inputs, attrs):
'SliceChannel' : _split,
'split' : _split,
'Softmax' : _rename('softmax'),
'SoftmaxActivation' : _softmax_activation,
'SoftmaxOutput' : _softmax_output,
'add_n' : _elemwise_sum,
'concat' : _concat,
'max_axis' : _rename('max'),
'min_axis' : _rename('min'),
Expand Down
69 changes: 69 additions & 0 deletions nnvm/python/nnvm/testing/download.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# pylint: disable=invalid-name, no-member, import-error, no-name-in-module, global-variable-undefined, bare-except
"""Helper utility for downloading"""
from __future__ import print_function
from __future__ import absolute_import as _abs

import os
import sys
import time
import urllib
import requests

if sys.version_info >= (3,):
import urllib.request as urllib2
else:
import urllib2

def _download_progress(count, block_size, total_size):
"""Show the download progress.
"""
global start_time
if count == 0:
start_time = time.time()
return
duration = time.time() - start_time
progress_size = int(count * block_size)
speed = int(progress_size / (1024 * duration))
percent = int(count * block_size * 100 / total_size)
sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" %
(percent, progress_size / (1024 * 1024), speed, duration))
sys.stdout.flush()

def download(url, path, overwrite=False, size_compare=False):
"""Downloads the file from the internet.
Set the input options correctly to overwrite or do the size comparison
Parameters
----------
url : str
Download url.
path : str
Local file path to save downloaded file
overwrite : bool, optional
Whether to overwrite existing file
size_compare : bool, optional
Whether to do size compare to check downloaded file.
"""
if os.path.isfile(path) and not overwrite:
if size_compare:
file_size = os.path.getsize(path)
res_head = requests.head(url)
res_get = requests.get(url, stream=True)
if 'Content-Length' not in res_head.headers:
res_get = urllib2.urlopen(url)
url_file_size = int(res_get.headers['Content-Length'])
if url_file_size != file_size:
print("exist file got corrupted, downloading %s file freshly..." % path)
download(url, path, True, False)
return
print('File {} exists, skip.'.format(path))
return
print('Downloading from url {} to {}'.format(url, path))
try:
urllib.request.urlretrieve(url, path, reporthook=_download_progress)
print('')
except:
urllib.urlretrieve(url, path, reporthook=_download_progress)
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/testing/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def resnet(units, num_stages, filter_list, num_classes, image_shape,
num_unit = len(units)
assert num_unit == num_stages
data = sym.Variable(name='data')
data = sym.batch_norm(data=data, epsilon=2e-5, name='bn_data')
data = sym.batch_norm(data=data, epsilon=2e-5, scale=False, name='bn_data')
(_, height, _) = image_shape
if height <= 32: # such as cifar10
body = sym.conv2d(
Expand Down
15 changes: 15 additions & 0 deletions nnvm/python/nnvm/top/attr_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,21 @@ def get_int(self, key):
"""
return int(self[key])

def get_float_tuple(self, key):
"""Get tuple of float from attr dict
Parameters
----------
key : str
The attr key
Returns
-------
tuple : tuple of float
The result tuple
"""
return tuple(float(x) for x in self[key][1:-1].split(",") if x)

def get_float(self, key):
"""Get float from attr dict
Expand Down
62 changes: 60 additions & 2 deletions nnvm/python/nnvm/top/vision.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@

# pylint: disable=invalid-name, unused-argument
"""Definition of nn ops"""
from __future__ import absolute_import

import topi
import tvm
import topi
from . import registry as reg
from .registry import OpPattern

Expand Down Expand Up @@ -38,3 +37,62 @@ def schedule_region(attrs, outs, target):
return topi.generic.vision.schedule_region(outs)

reg.register_pattern("yolo2_region", OpPattern.OPAQUE)

# multibox_prior
@reg.register_schedule("multibox_prior")
def schedule_multibox_prior(_, outs, target):
"""Schedule definition of multibox_prior"""
with tvm.target.create(target):
return topi.generic.schedule_multibox_prior(outs)

@reg.register_compute("multibox_prior")
def compute_multibox_prior(attrs, inputs, _):
"""Compute definition of multibox_prior"""
sizes = attrs.get_float_tuple('sizes')
ratios = attrs.get_float_tuple('ratios')
steps = attrs.get_float_tuple('steps')
offsets = attrs.get_float_tuple('offsets')
clip = attrs.get_bool('clip')

return topi.vision.ssd.multibox_prior(inputs[0], sizes, ratios,
steps, offsets, clip)

reg.register_pattern("multibox_prior", OpPattern.OPAQUE)

# multibox_transform_loc
@reg.register_schedule("multibox_transform_loc")
def schedule_multibox_transform_loc(_, outs, target):
"""Schedule definition of multibox_detection"""
with tvm.target.create(target):
return topi.generic.schedule_multibox_transform_loc(outs)

@reg.register_compute("multibox_transform_loc")
def compute_multibox_transform_loc(attrs, inputs, _):
"""Compute definition of multibox_detection"""
clip = attrs.get_bool('clip')
threshold = attrs.get_float('threshold')
variance = attrs.get_float_tuple('variances')

return topi.vision.ssd.multibox_transform_loc(inputs[0], inputs[1], inputs[2],
clip, threshold, variance)

reg.register_pattern("multibox_detection", OpPattern.OPAQUE)

# non-maximum suppression
@reg.register_schedule("nms")
def schedule_nms(_, outs, target):
"""Schedule definition of nms"""
with tvm.target.create(target):
return topi.generic.schedule_nms(outs)

@reg.register_compute("nms")
def compute_nms(attrs, inputs, _):
"""Compute definition of nms"""
nms_threshold = attrs.get_float('nms_threshold')
force_suppress = attrs.get_bool('force_suppress')
nms_topk = attrs.get_int('nms_topk')

return topi.vision.nms(inputs[0], inputs[1], nms_threshold,
force_suppress, nms_topk)

reg.register_pattern("nms", OpPattern.OPAQUE)
80 changes: 80 additions & 0 deletions nnvm/src/top/vision/nms.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*!
* Copyright (c) 2017 by Contributors
* \file nms.cc
* \brief Property def of SSD non-maximum suppression operator.
*/

#include <tvm/expr.h>
#include <tvm/packed_func_ext.h>
#include <nnvm/op.h>
#include <nnvm/top/nn.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/compiler/op_attr_types.h>
#include "../op_common.h"
#include "../elemwise_op_common.h"

namespace nnvm {
namespace top {
using compiler::FTVMCompute;
using tvm::Tensor;
using tvm::Array;

DMLC_REGISTER_PARAMETER(NMSParam);

bool NMSShape(const NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U) << "Inputs: [data, valid_count]";
TShape dshape = in_attrs->at(0);
TShape vshape = in_attrs->at(1);
CHECK_EQ(dshape.ndim(), 3U) << "Input data should be 3-D.";
CHECK_EQ(vshape.ndim(), 1U) << "Input valid count should be 1-D.";
CHECK_EQ(dshape[2], 6U) << "Data input should have shape "
"(batch_size, num_anchors, 6).";
CHECK_EQ(dshape[0], vshape[0]) << "batch_size mismatch.";
out_attrs->clear();
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, dshape);
return true;
}

inline bool NMSInferType(const NodeAttrs &attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(0));
return true;
}

inline bool NMSInferLayout(const NodeAttrs& attrs,
std::vector<Layout> *ilayouts,
const std::vector<Layout> *last_ilayouts,
std::vector<Layout> *olayouts) {
static const Layout kNCHW("NCHW");
CHECK_EQ(ilayouts->size(), 2U);
CHECK_EQ(olayouts->size(), 1U);
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, kNCHW);
NNVM_ASSIGN_LAYOUT(*ilayouts, 1, kNCHW);
return true;
}

NNVM_REGISTER_OP(nms)
.describe(R"doc("Non-maximum suppression."
)doc" NNVM_ADD_FILELINE)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NMSParam>)
.set_attr<FGetAttrDict>("FGetAttrDict",
ParamGetAttrDict<NMSParam>)
.add_arguments(NMSParam::__FIELDS__())
.add_argument("data", "Tensor", "Input data.")
.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.")
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "valid_count"};
})
.set_attr<FInferShape>("FInferShape", NMSShape)
.set_attr<FInferType>("FInferType", NMSInferType)
.set_attr<FCorrectLayout>("FCorrectLayout", NMSInferLayout)
.set_support_level(4);

} // namespace top
} // namespace nnvm

Loading

0 comments on commit 42520a2

Please sign in to comment.