From f5b4c3e57a03502c021911e2b409c8c7262439a2 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Mon, 29 Aug 2016 10:03:26 -0700 Subject: [PATCH] [NODE] Move op inside node attribute (#30) --- nnvm/example/src/operator.cc | 2 +- nnvm/include/nnvm/node.h | 33 +++++++++++++++++------------- nnvm/include/nnvm/pass_functions.h | 1 + nnvm/src/core/graph.cc | 6 +++--- nnvm/src/core/symbolic.cc | 31 ++++++++++++++-------------- nnvm/src/pass/gradient.cc | 6 +++--- nnvm/src/pass/infer_shape_type.cc | 6 +++--- nnvm/src/pass/order_mutation.cc | 9 ++++---- nnvm/src/pass/place_device.cc | 2 +- nnvm/src/pass/plan_memory.cc | 4 ++-- nnvm/src/pass/saveload_json.cc | 12 +++++------ 11 files changed, 58 insertions(+), 54 deletions(-) diff --git a/nnvm/example/src/operator.cc b/nnvm/example/src/operator.cc index 674893a55549..93a1b3db29c1 100644 --- a/nnvm/example/src/operator.cc +++ b/nnvm/example/src/operator.cc @@ -46,7 +46,7 @@ inline NodeEntry MakeNode(const char* op_name, std::string node_name, std::vector inputs) { NodePtr p = Node::Create(); - p->op = nnvm::Op::Get(op_name); + p->attrs.op = nnvm::Op::Get(op_name); p->attrs.name = std::move(node_name); p->inputs = std::move(inputs); return NodeEntry{p, 0, 0}; diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index 470d4d576381..a7daa0a6e887 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -46,6 +46,11 @@ struct NodeEntry { * Usually are additional parameters like axis, */ struct NodeAttrs { + /*! + * \brief The operator this node uses. + * For place holder variable, op == nullptr. + */ + const Op *op{nullptr}; /*! \brief name of the node */ std::string name; /*! \brief Vector representation of positional attributes */ @@ -65,11 +70,8 @@ struct NodeAttrs { */ class Node { public: - /*! - * \brief The operator this node uses. - * For place holder variable, op == nullptr. - */ - const Op *op{nullptr}; + /*! \brief The attributes in the node. */ + NodeAttrs attrs; /*! \brief inputs to this node */ std::vector inputs; /*! @@ -77,10 +79,10 @@ class Node { * Gives operation must be performed before this operation. */ std::vector control_deps; - /*! \brief The attributes in the node. */ - NodeAttrs attrs; /*! \brief destructor of node */ ~Node(); + /*! \return operator in this node */ + inline const Op* op() const; /*! * \brief return whether node is placeholder variable. * This is equivalent to op == nullptr @@ -99,25 +101,28 @@ class Node { }; // implementation of functions. +inline const Op* Node::op() const { + return this->attrs.op; +} inline bool Node::is_variable() const { - return this->op == nullptr; + return this->op() == nullptr; } inline uint32_t Node::num_outputs() const { if (is_variable()) return 1; - if (this->op->get_num_outputs == nullptr) { - return this->op->num_outputs; + if (this->op()->get_num_outputs == nullptr) { + return this->op()->num_outputs; } else { - return this->op->get_num_outputs(this->attrs); + return this->op()->get_num_outputs(this->attrs); } } inline uint32_t Node::num_inputs() const { if (is_variable()) return 1; - if (this->op->get_num_inputs == nullptr) { - return this->op->num_inputs; + if (this->op()->get_num_inputs == nullptr) { + return this->op()->num_inputs; } else { - return this->op->get_num_inputs(this->attrs); + return this->op()->get_num_inputs(this->attrs); } } diff --git a/nnvm/include/nnvm/pass_functions.h b/nnvm/include/nnvm/pass_functions.h index a2cca949ff5b..b7068822d9ee 100644 --- a/nnvm/include/nnvm/pass_functions.h +++ b/nnvm/include/nnvm/pass_functions.h @@ -12,6 +12,7 @@ #include #include +#include #include "./base.h" #include "./pass.h" #include "./graph_attr_types.h" diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index 2e0160072612..c500995eab70 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -66,9 +66,9 @@ IndexedGraph::IndexedGraph(const Graph &g) { for (size_t nid = 0; nid < nodes_.size(); ++nid) { nodes_[nid].inputs = array_view( iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]); - if (nodes_[nid].source->op != nullptr && - fmutate_inputs.count(nodes_[nid].source->op)) { - for (uint32_t i : fmutate_inputs[nodes_[nid].source->op](nodes_[nid].source->attrs)) { + if (nodes_[nid].source->op() != nullptr && + fmutate_inputs.count(nodes_[nid].source->op())) { + for (uint32_t i : fmutate_inputs[nodes_[nid].source->op()](nodes_[nid].source->attrs)) { mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id); } } diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc index d595880aed1c..97b35648ee7d 100644 --- a/nnvm/src/core/symbolic.cc +++ b/nnvm/src/core/symbolic.cc @@ -20,7 +20,7 @@ struct VariableParam { NodePtr CreateVariableNode(const std::string& name) { NodePtr n = Node::Create(); - n->op = nullptr; + n->attrs.op = nullptr; n->attrs.name = name; n->attrs.parsed = VariableParam(); return n; @@ -37,8 +37,8 @@ inline void UpdateNodeVersion(Node *n) { e.version = nnvm::get(e.node->attrs.parsed).version; } } - if (fmutate_inputs.count(n->op) != 0) { - for (uint32_t i : fmutate_inputs[n->op](n->attrs)) { + if (fmutate_inputs.count(n->op()) != 0) { + for (uint32_t i : fmutate_inputs[n->op()](n->attrs)) { NodeEntry& e = n->inputs[i]; CHECK(e.node->is_variable()) << "Mutation target can only be Variable"; @@ -96,7 +96,6 @@ Symbol Symbol::Copy() const { // use DFSVisit to copy all the nodes DFSVisit(this->outputs, [&old_new](const NodePtr& node) { NodePtr np = Node::Create(); - np->op = node->op; np->attrs = node->attrs; old_new[node.get()] = std::move(np); }); @@ -123,7 +122,7 @@ void Symbol::Print(std::ostream &os) const { if (outputs[0].node->is_variable()) { os << "Variable:" << outputs[0].node->attrs.name << '\n'; } else { - os << "AtomicFunctor "<< " Op:" << outputs[0].node->op->name << '\n'; + os << "AtomicFunctor "<< " Op:" << outputs[0].node->op()->name << '\n'; } } else { // use DFSVisit to copy all the nodes @@ -137,7 +136,7 @@ void Symbol::Print(std::ostream &os) const { os << "Variable:" << node->attrs.name << '\n'; } else { os << "--------------------\n"; - os << "Op:" << node->op->name << ", Name=" << node->attrs.name << '\n' + os << "Op:" << node->op()->name << ", Name=" << node->attrs.name << '\n' << "Inputs:\n"; for (size_t i = 0; i < node->inputs.size(); ++i) { const NodeEntry& e = node->inputs[i]; @@ -196,8 +195,8 @@ std::vector Symbol::ListInputNames(ListInputOption option) const { DFSVisit(this->outputs, [&ret, &mutable_set, &vlist](const NodePtr &node) { if (node->is_variable()) { vlist.push_back(node.get()); - } else if (fmutate_inputs.count(node->op)) { - for (uint32_t i : fmutate_inputs[node->op](node->attrs)){ + } else if (fmutate_inputs.count(node->op())) { + for (uint32_t i : fmutate_inputs[node->op()](node->attrs)){ mutable_set.insert(node->inputs[i].node.get()); } } @@ -221,7 +220,7 @@ std::vector Symbol::ListOutputNames() const { } else { const std::string& hname = head.node->attrs.name; std::string rname; - FListOutputNames fn = flist_ouputs.get(head.node->op, nullptr); + FListOutputNames fn = flist_ouputs.get(head.node->op(), nullptr); if (fn != nullptr) { rname = fn(head.node->attrs)[head.index]; } else { @@ -278,10 +277,10 @@ void Symbol::Compose(const array_view& args, } // switch to keyword argument matching if (args.size() != n_req) { - FListInputNames fn = flist_inputs.get(n->op, nullptr); + FListInputNames fn = flist_inputs.get(n->op(), nullptr); auto arg_names = (fn == nullptr) ? std::vector{"data"} : fn(n->attrs); if (arg_names.size() != n_req) { - LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op->name; + LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name; } size_t nmatched = 0; for (size_t i = args.size(); i < n_req; ++i) { @@ -422,8 +421,8 @@ void Symbol::SetAttrs(const std::vector >& a node->attrs.dict[kv.first] = kv.second; } } - if (node->op != nullptr && node->op->attr_parser != nullptr) { - node->op->attr_parser(&(node->attrs)); + if (node->op() != nullptr && node->op()->attr_parser != nullptr) { + node->op()->attr_parser(&(node->attrs)); } } @@ -461,10 +460,10 @@ Symbol Symbol::CreateFunctor(const Op* op, std::unordered_map attrs) { Symbol s; NodePtr n = Node::Create(); - n->op = op; + n->attrs.op = op; n->attrs.dict = std::move(attrs); - if (n->op->attr_parser != nullptr) { - n->op->attr_parser(&(n->attrs)); + if (n->op()->attr_parser != nullptr) { + n->op()->attr_parser(&(n->attrs)); } s.outputs.emplace_back(NodeEntry{std::move(n), 0, 0}); return s; diff --git a/nnvm/src/pass/gradient.cc b/nnvm/src/pass/gradient.cc index 530dd29f41c6..a64cf0c11303 100644 --- a/nnvm/src/pass/gradient.cc +++ b/nnvm/src/pass/gradient.cc @@ -20,11 +20,11 @@ NodeEntry DefaultAggregateGradient(std::vector&& v) { return std::move(v[0]); } else if (v.size() == 0) { NodePtr zero_node = Node::Create(); - zero_node->op = Op::Get("__zero__"); + zero_node->attrs.op = Op::Get("__zero__"); return NodeEntry{zero_node, 0, 0}; } else { NodePtr sum_node = Node::Create(); - sum_node->op = Op::Get("__ewise_sum__"); + sum_node->attrs.op = Op::Get("__ewise_sum__"); sum_node->inputs = std::move(v); return NodeEntry{sum_node, 0, 0}; } @@ -109,7 +109,7 @@ Graph Gradient(Graph src) { e.sum = agg_fun(std::move(e.grads)); out_agg_grads.push_back(e.sum); } - std::vector input_grads = grad_fun_map[ptr->op] + std::vector input_grads = grad_fun_map[ptr->op()] (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads); auto git = input_grads.begin(); for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { diff --git a/nnvm/src/pass/infer_shape_type.cc b/nnvm/src/pass/infer_shape_type.cc index 5978ecdb79f2..a5cc8c13751f 100644 --- a/nnvm/src/pass/infer_shape_type.cc +++ b/nnvm/src/pass/infer_shape_type.cc @@ -65,7 +65,7 @@ Graph InferAttr(Graph &&ret, } continue; } - if (finfer_shape.count(inode.source->op)) { + if (finfer_shape.count(inode.source->op())) { ishape.resize(num_inputs, def_value); for (uint32_t i = 0; i < ishape.size(); ++i) { ishape[i] = rshape[idx.entry_id(inode.inputs[i])]; @@ -75,14 +75,14 @@ Graph InferAttr(Graph &&ret, oshape[i] = rshape[idx.entry_id(nid, i)]; } num_unknown += - !(finfer_shape[inode.source->op](inode.source->attrs, &ishape, &oshape)); + !(finfer_shape[inode.source->op()](inode.source->attrs, &ishape, &oshape)); for (uint32_t i = 0; i < num_inputs; ++i) { rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; } for (uint32_t i = 0; i < num_outputs; ++i) { rshape[idx.entry_id(nid, i)] = oshape[i]; } - } else if (is_backward.get(inode.source->op, false)) { + } else if (is_backward.get(inode.source->op(), false)) { // backward operator inference. CHECK_GE(inode.control_deps.size(), 1) << "BackwardOp need to have control_deps to its forward op"; diff --git a/nnvm/src/pass/order_mutation.cc b/nnvm/src/pass/order_mutation.cc index 3bcfd9922d53..e91d114ea101 100644 --- a/nnvm/src/pass/order_mutation.cc +++ b/nnvm/src/pass/order_mutation.cc @@ -43,8 +43,8 @@ Graph OrderMutation(const Graph& src) { auto prepare = [&version_hist, &old_new] (const NodePtr& n) { static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); std::vector mutate_inputs; - if (!n->is_variable() && fmutate_inputs.count(n->op)) { - mutate_inputs = fmutate_inputs[n->op](n->attrs); + if (!n->is_variable() && fmutate_inputs.count(n->op())) { + mutate_inputs = fmutate_inputs[n->op()](n->attrs); } std::sort(mutate_inputs.begin(), mutate_inputs.end()); @@ -67,7 +67,6 @@ Graph OrderMutation(const Graph& src) { } if (need_repl) { NodePtr np = Node::Create(); - np->op = n->op; np->attrs = n->attrs; old_new[n.get()] = std::move(np); } @@ -101,8 +100,8 @@ Graph OrderMutation(const Graph& src) { // add control deps static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); std::vector mutate_inputs; - if (fmutate_inputs.count(kv.first->op)) { - mutate_inputs = fmutate_inputs[kv.first->op](kv.first->attrs); + if (fmutate_inputs.count(kv.first->op())) { + mutate_inputs = fmutate_inputs[kv.first->op()](kv.first->attrs); } std::sort(mutate_inputs.begin(), mutate_inputs.end()); diff --git a/nnvm/src/pass/place_device.cc b/nnvm/src/pass/place_device.cc index 6a6e877a9f87..402f2cff784c 100644 --- a/nnvm/src/pass/place_device.cc +++ b/nnvm/src/pass/place_device.cc @@ -109,9 +109,9 @@ Graph PlaceDevice(Graph src) { NodeEntry{it->second, 0, 0}); } else { NodePtr copy_node = Node::Create(); - copy_node->op = copy_op; std::ostringstream os; os << inode.source->inputs[i].node->attrs.name << "_" << e.index <<"_copy"; + copy_node->attrs.op = copy_op; copy_node->attrs.name = os.str(); copy_node->inputs.push_back(inode.source->inputs[i]); copy_map[copy_key] = copy_node; diff --git a/nnvm/src/pass/plan_memory.cc b/nnvm/src/pass/plan_memory.cc index 14a88d217de8..34b05d5d6c94 100644 --- a/nnvm/src/pass/plan_memory.cc +++ b/nnvm/src/pass/plan_memory.cc @@ -168,8 +168,8 @@ Graph PlanMemory(Graph ret) { const auto& inode = idx[nid]; if (inode.source->is_variable()) continue; // check inplace option - if (finplace_option.count(inode.source->op) != 0) { - auto inplace_pairs = finplace_option[inode.source->op](inode.source->attrs); + if (finplace_option.count(inode.source->op()) != 0) { + auto inplace_pairs = finplace_option[inode.source->op()](inode.source->attrs); for (auto& kv : inplace_pairs) { uint32_t eid_out = idx.entry_id(nid, kv.second); uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]); diff --git a/nnvm/src/pass/saveload_json.cc b/nnvm/src/pass/saveload_json.cc index 984a2c9905c4..681daed7a1dd 100644 --- a/nnvm/src/pass/saveload_json.cc +++ b/nnvm/src/pass/saveload_json.cc @@ -68,8 +68,8 @@ struct JSONNode { // function to save JSON node. void Save(dmlc::JSONWriter *writer) const { writer->BeginObject(); - if (node->op != nullptr) { - writer->WriteObjectKeyValue("op", node->op->name); + if (node->op() != nullptr) { + writer->WriteObjectKeyValue("op", node->op()->name); } else { std::string json_null = "null"; writer->WriteObjectKeyValue("op", json_null); @@ -108,10 +108,10 @@ struct JSONNode { if (op_type_str != "null") { try { - node->op = Op::Get(op_type_str); + node->attrs.op = Op::Get(op_type_str); // rebuild attribute parser - if (node->op->attr_parser != nullptr) { - node->op->attr_parser(&(node->attrs)); + if (node->op()->attr_parser != nullptr) { + node->op()->attr_parser(&(node->attrs)); } } catch (const dmlc::Error &err) { std::ostringstream os; @@ -120,7 +120,7 @@ struct JSONNode { throw dmlc::Error(os.str()); } } else { - node->op = nullptr; + node->attrs.op = nullptr; } } };