Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Wang committed Jun 4, 2018
1 parent 889457e commit b2eb8f9
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 16 deletions.
6 changes: 3 additions & 3 deletions nnvm/python/nnvm/compiler/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,9 +331,9 @@ def build(graph, target=None, shape=None, dtype="float32",
if params is None:
params = {}
params.update(init_var)
if not build_extra:
return graph, libmod, params
return graph, libmod, params, extra_lib
if build_extra:
return graph, libmod, params, extra_lib
return graph, libmod, params


def _run_graph(graph, params):
Expand Down
3 changes: 3 additions & 0 deletions nnvm/python/nnvm/compiler/graph_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ def split_last_op(graph):
"""
graph_idx = graph.index
last_op_node = graph_idx.nodes[-1]
if last_op_node["op"] == "null":
raise RuntimeError("split_last_op doesn't support sast operator "
"to be null.")
last_op_func = getattr(sym, last_op_node["op"])
if "attrs" in last_op_node:
last_op_attr = last_op_node["attrs"]
Expand Down
10 changes: 3 additions & 7 deletions nnvm/src/top/vision/nms.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,13 @@ bool NMSShape(const NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 2U) << "Inputs: [data, valid_count]";
TShape dshape = in_attrs->at(0);
TShape vshape = in_attrs->at(1);
CHECK_EQ(dshape.ndim(), 3U) << "Provided: " << dshape;
CHECK_EQ(vshape.ndim(), 1U) << "Provided: " << vshape;
CHECK_EQ(dshape.ndim(), 3U) << "Input data should be 3-D.";
CHECK_EQ(vshape.ndim(), 1U) << "Input valid count should be 1-D.";
CHECK_EQ(dshape[2], 6U) << "Data input should have shape "
"(batch_size, num_anchors, 6).";
CHECK_EQ(dshape[0], vshape[0]) << "batch_size mismatch.";
TShape oshape = TShape(3);
oshape[0] = dshape[0];
oshape[1] = dshape[1];
oshape[2] = 6; // [id, prob, xmin, ymin, xmax, ymax]
out_attrs->clear();
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, oshape);
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_attrs, 0, dshape);
return true;
}

Expand Down
12 changes: 6 additions & 6 deletions nnvm/src/top/vision/ssd/mutibox_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ bool MultiBoxDetectionShape(const NodeAttrs& attrs,
TShape cshape = in_attrs->at(0);
TShape lshape = in_attrs->at(1);
TShape ashape = in_attrs->at(2);
CHECK_EQ(cshape.ndim(), 3U) << "Provided: " << cshape;
CHECK_EQ(lshape.ndim(), 2U) << "Provided: " << lshape;
CHECK_EQ(ashape.ndim(), 3U) << "Provided: " << ashape;
CHECK_EQ(cshape[2], ashape[1]) << "Number of anchors mismatch";
CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc";
CHECK_GT(ashape[1], 0U) << "Number of anchors must > 0";
CHECK_EQ(cshape.ndim(), 3U) << "Class probability should be 3-D.";
CHECK_EQ(lshape.ndim(), 2U) << "Location prediction should be 2-D.";
CHECK_EQ(ashape.ndim(), 3U) << "Anchor should be 3-D.";
CHECK_EQ(cshape[2], ashape[1]) << "Number of anchors mismatch.";
CHECK_EQ(cshape[2] * 4, lshape[1]) << "# anchors mismatch with # loc.";
CHECK_GT(ashape[1], 0U) << "Number of anchors must > 0.";
CHECK_EQ(ashape[2], 4U);
TShape oshape = TShape(3);
oshape[0] = cshape[0];
Expand Down

0 comments on commit b2eb8f9

Please sign in to comment.