forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
238a503
commit 42520a2
Showing
12 changed files
with
731 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# pylint: disable=invalid-name, no-member, import-error, no-name-in-module, global-variable-undefined, bare-except | ||
"""Helper utility for downloading""" | ||
from __future__ import print_function | ||
from __future__ import absolute_import as _abs | ||
|
||
import os | ||
import sys | ||
import time | ||
import urllib | ||
import requests | ||
|
||
if sys.version_info >= (3,): | ||
import urllib.request as urllib2 | ||
else: | ||
import urllib2 | ||
|
||
def _download_progress(count, block_size, total_size): | ||
"""Show the download progress. | ||
""" | ||
global start_time | ||
if count == 0: | ||
start_time = time.time() | ||
return | ||
duration = time.time() - start_time | ||
progress_size = int(count * block_size) | ||
speed = int(progress_size / (1024 * duration)) | ||
percent = int(count * block_size * 100 / total_size) | ||
sys.stdout.write("\r...%d%%, %d MB, %d KB/s, %d seconds passed" % | ||
(percent, progress_size / (1024 * 1024), speed, duration)) | ||
sys.stdout.flush() | ||
|
||
def download(url, path, overwrite=False, size_compare=False): | ||
"""Downloads the file from the internet. | ||
Set the input options correctly to overwrite or do the size comparison | ||
Parameters | ||
---------- | ||
url : str | ||
Download url. | ||
path : str | ||
Local file path to save downloaded file | ||
overwrite : bool, optional | ||
Whether to overwrite existing file | ||
size_compare : bool, optional | ||
Whether to do size compare to check downloaded file. | ||
""" | ||
if os.path.isfile(path) and not overwrite: | ||
if size_compare: | ||
file_size = os.path.getsize(path) | ||
res_head = requests.head(url) | ||
res_get = requests.get(url, stream=True) | ||
if 'Content-Length' not in res_head.headers: | ||
res_get = urllib2.urlopen(url) | ||
url_file_size = int(res_get.headers['Content-Length']) | ||
if url_file_size != file_size: | ||
print("exist file got corrupted, downloading %s file freshly..." % path) | ||
download(url, path, True, False) | ||
return | ||
print('File {} exists, skip.'.format(path)) | ||
return | ||
print('Downloading from url {} to {}'.format(url, path)) | ||
try: | ||
urllib.request.urlretrieve(url, path, reporthook=_download_progress) | ||
print('') | ||
except: | ||
urllib.urlretrieve(url, path, reporthook=_download_progress) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
/*! | ||
* Copyright (c) 2017 by Contributors | ||
* \file nms.cc | ||
* \brief Property def of SSD non-maximum suppression operator. | ||
*/ | ||
|
||
#include <tvm/expr.h> | ||
#include <tvm/packed_func_ext.h> | ||
#include <nnvm/op.h> | ||
#include <nnvm/top/nn.h> | ||
#include <nnvm/op_attr_types.h> | ||
#include <nnvm/compiler/op_attr_types.h> | ||
#include "../op_common.h" | ||
#include "../elemwise_op_common.h" | ||
|
||
namespace nnvm { | ||
namespace top { | ||
using compiler::FTVMCompute; | ||
using tvm::Tensor; | ||
using tvm::Array; | ||
|
||
DMLC_REGISTER_PARAMETER(NMSParam); | ||
|
||
bool NMSShape(const NodeAttrs& attrs, | ||
std::vector<TShape> *in_attrs, | ||
std::vector<TShape> *out_attrs) { | ||
CHECK_EQ(in_attrs->size(), 2U) << "Inputs: [data, valid_count]"; | ||
TShape dshape = in_attrs->at(0); | ||
TShape vshape = in_attrs->at(1); | ||
CHECK_EQ(dshape.ndim(), 3U) << "Input data should be 3-D."; | ||
CHECK_EQ(vshape.ndim(), 1U) << "Input valid count should be 1-D."; | ||
CHECK_EQ(dshape[2], 6U) << "Data input should have shape " | ||
"(batch_size, num_anchors, 6)."; | ||
CHECK_EQ(dshape[0], vshape[0]) << "batch_size mismatch."; | ||
out_attrs->clear(); | ||
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, dshape); | ||
return true; | ||
} | ||
|
||
inline bool NMSInferType(const NodeAttrs &attrs, | ||
std::vector<int> *in_attrs, | ||
std::vector<int> *out_attrs) { | ||
DTYPE_ASSIGN(out_attrs->at(0), in_attrs->at(0)); | ||
return true; | ||
} | ||
|
||
inline bool NMSInferLayout(const NodeAttrs& attrs, | ||
std::vector<Layout> *ilayouts, | ||
const std::vector<Layout> *last_ilayouts, | ||
std::vector<Layout> *olayouts) { | ||
static const Layout kNCHW("NCHW"); | ||
CHECK_EQ(ilayouts->size(), 2U); | ||
CHECK_EQ(olayouts->size(), 1U); | ||
NNVM_ASSIGN_LAYOUT(*ilayouts, 0, kNCHW); | ||
NNVM_ASSIGN_LAYOUT(*ilayouts, 1, kNCHW); | ||
return true; | ||
} | ||
|
||
NNVM_REGISTER_OP(nms) | ||
.describe(R"doc("Non-maximum suppression." | ||
)doc" NNVM_ADD_FILELINE) | ||
.set_num_inputs(2) | ||
.set_num_outputs(1) | ||
.set_attr_parser(ParamParser<NMSParam>) | ||
.set_attr<FGetAttrDict>("FGetAttrDict", | ||
ParamGetAttrDict<NMSParam>) | ||
.add_arguments(NMSParam::__FIELDS__()) | ||
.add_argument("data", "Tensor", "Input data.") | ||
.add_argument("valid_count", "Tensor", "Number of valid anchor boxes.") | ||
.set_attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) { | ||
return std::vector<std::string>{"data", "valid_count"}; | ||
}) | ||
.set_attr<FInferShape>("FInferShape", NMSShape) | ||
.set_attr<FInferType>("FInferType", NMSInferType) | ||
.set_attr<FCorrectLayout>("FCorrectLayout", NMSInferLayout) | ||
.set_support_level(4); | ||
|
||
} // namespace top | ||
} // namespace nnvm | ||
|
Oops, something went wrong.