Skip to content

Commit

Permalink
[PASS] Add order mutation (#7)
Browse files Browse the repository at this point in the history
* [PASS] Add order mutation

* A few benchmarks on compose speed
  • Loading branch information
tqchen committed May 29, 2018
1 parent d99d550 commit 9a956a8
Show file tree
Hide file tree
Showing 9 changed files with 296 additions and 63 deletions.
8 changes: 4 additions & 4 deletions nnvm/include/nnvm/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ class Symbol {
* \param kwargs keyword arguments for the symbol
* \param name name of returned symbol.
*/
void Compose(const std::vector<Symbol>& args,
const std::unordered_map<std::string, Symbol>& kwargs,
void Compose(const array_view<const Symbol*>& args,
const std::unordered_map<std::string, const Symbol*>& kwargs,
const std::string& name);
/*!
* \brief Apply the symbol as a function, compose with arguments
Expand All @@ -84,8 +84,8 @@ class Symbol {
* \param name name of returned symbol.
* \return a new Symbol which is the composition of current symbol with its arguments
*/
Symbol operator () (const std::vector<Symbol>& args,
const std::unordered_map<std::string, Symbol>& kwargs,
Symbol operator () (const array_view<const Symbol*>& args,
const std::unordered_map<std::string, const Symbol*>& kwargs,
const std::string& name) const;
/*!
* \brief Add control flow dependencies to operators involved in symbols.
Expand Down
3 changes: 3 additions & 0 deletions nnvm/src/c_api/c_api_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <nnvm/c_api.h>
#include <nnvm/symbolic.h>
#include <vector>
#include <string>

Expand All @@ -36,6 +37,8 @@ struct NNAPIThreadLocalEntry {
std::vector<const char *> ret_vec_charp;
/*! \brief result holder for returning handles */
std::vector<void *> ret_handles;
/*! \brief argument holder to hold symbol */
std::unordered_map<std::string, const nnvm::Symbol*> kwarg_symbol;
};

/*! \brief Thread local store that can be used to hold return values. */
Expand Down
26 changes: 15 additions & 11 deletions nnvm/src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,22 +217,26 @@ int NNSymbolCompose(SymbolHandle sym,
const char** keys,
SymbolHandle* args) {
API_BEGIN();
std::string s_name;
if (name != nullptr) s_name = name;

NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
std::string& s_name = ret->ret_str;
std::unordered_map<std::string, const Symbol*>& kwargs
= ret->kwarg_symbol;
if (name != nullptr) {
s_name = name;
} else {
s_name.clear();
}
Symbol* s = static_cast<Symbol*>(sym);
if (keys == nullptr && num_args != 0) {
std::vector<Symbol> pos_args;
for (nn_uint i = 0; i < num_args; ++i) {
pos_args.push_back(*((Symbol*)args[i])); // NOLINT(*)
}
s->Compose(pos_args, {}, s_name);
kwargs.clear();
array_view<const Symbol*> parg(
(Symbol**)args, (Symbol**)args + num_args); // NOLINT(*)
s->Compose(parg, kwargs, s_name);
} else {
std::unordered_map<std::string, Symbol> kwargs;
for (nn_uint i = 0; i < num_args; ++i) {
kwargs[keys[i]] = *((Symbol*)args[i]); // NOLINT(*)
kwargs[keys[i]] = (Symbol*)args[i]; // NOLINT(*)
}
s->Compose({}, kwargs, s_name);
s->Compose(array_view<const Symbol*>(), kwargs, s_name);
}
API_END();
}
53 changes: 34 additions & 19 deletions nnvm/src/core/symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ inline void UpdateNodeVersion(Node *n) {
CHECK(e.node->is_variable())
<< "Mutation target can only be Variable";
// increase the version of the variable.
++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
e.version = ++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
}
}
}
Expand Down Expand Up @@ -98,14 +98,20 @@ Symbol Symbol::Copy() const {
std::unordered_map<Node*, std::shared_ptr<Node> > old_new;
// use DFSVisit to copy all the nodes
DFSVisit(this->outputs, [&old_new](const std::shared_ptr<Node>& node) {
old_new[node.get()] = std::make_shared<Node>(*node);
std::shared_ptr<Node> np = Node::Create();
np->op = node->op;
np->attrs = node->attrs;
old_new[node.get()] = std::move(np);
});
// connect nodes of new graph
for (const auto &kv : old_new) {
for (const NodeEntry& e : kv.first->inputs) {
Node *ptr = e.node.get();
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version});
}
for (const std::shared_ptr<Node>& p : kv.first->control_deps) {
kv.second->control_deps.emplace_back(old_new[p.get()]);
}
}
// set the head
Symbol ret;
Expand All @@ -120,7 +126,7 @@ void Symbol::Print(std::ostream &os) const {
os << "AtomicFunctor "<< " Op:" << outputs[0].node->op->name << '\n';
} else {
// use DFSVisit to copy all the nodes
os << "Outputs:\n";
os << "Symbol Outputs:\n";
for (size_t i = 0; i < outputs.size(); ++i) {
os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name
<< '(' << outputs[i].index << ")\n";
Expand All @@ -129,7 +135,8 @@ void Symbol::Print(std::ostream &os) const {
if (node->is_variable()) {
os << "Variable:" << node->attrs.name << '\n';
} else {
os << "Name: " << node->attrs.name << " Op:" << node->op->name << '\n'
os << "--------------------\n";
os << "Op:" << node->op->name << ", Name=" << node->attrs.name << '\n'
<< "Inputs:\n";
for (size_t i = 0; i < node->inputs.size(); ++i) {
const NodeEntry& e = node->inputs[i];
Expand All @@ -141,9 +148,17 @@ void Symbol::Print(std::ostream &os) const {
os << '\n';
}
}
os << "Attrs:\n";
for (auto &kv : node->attrs.dict) {
os << '\t' << kv.first << '=' << kv.second << '\n';
if (!node->attrs.dict.empty()) {
os << "Attrs:\n";
for (auto &kv : node->attrs.dict) {
os << '\t' << kv.first << '=' << kv.second << '\n';
}
}
if (node->control_deps.size() != 0) {
os << "Control deps:\n";
for (size_t i = 0; i < node->control_deps.size(); ++i) {
os << "\tcdep[" << i << "]=" << node->control_deps[i]->attrs.name << '\n';
}
}
}
});
Expand Down Expand Up @@ -203,8 +218,8 @@ std::vector<std::string> Symbol::ListOutputs() const {
}

// compositional logic
void Symbol::Compose(const std::vector<Symbol>& args,
const std::unordered_map<std::string, Symbol>& kwargs,
void Symbol::Compose(const array_view<const Symbol*>& args,
const std::unordered_map<std::string, const Symbol*>& kwargs,
const std::string& name) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");

Expand All @@ -213,11 +228,11 @@ void Symbol::Compose(const std::vector<Symbol>& args,
CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed";
// parameter check.
for (size_t i = 0; i < args.size(); ++i) {
CHECK_EQ(args[i].outputs.size(), 1)
CHECK_EQ(args[i]->outputs.size(), 1)
<< "Argument " << i << " is a tuple, single value is required";
}
for (const auto& kv : kwargs) {
CHECK_EQ(kv.second.outputs.size(), 1)
CHECK_EQ(kv.second->outputs.size(), 1)
<< "Keyword Argument " << kv.first << " is a tuple, single value is required";
}
// assign new name
Expand All @@ -234,7 +249,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
<< "Incorrect number of arguments, requires " << n_req
<< ", provided " << args.size();
for (size_t i = 0; i < args.size(); ++i) {
n->inputs[i] = args[i].outputs[0];
n->inputs[i] = args[i]->outputs[0];
}
// switch to keyword argument matching
if (args.size() != n_req) {
Expand All @@ -247,7 +262,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
for (size_t i = args.size(); i < n_req; ++i) {
auto it = kwargs.find(arg_names[i]);
if (it != kwargs.end() && it->first == arg_names[i]) {
n->inputs[i] = it->second.outputs[0];
n->inputs[i] = it->second->outputs[0];
++nmatched;
} else {
n->inputs[i] = NodeEntry{
Expand All @@ -266,8 +281,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
} else {
CHECK_EQ(kwargs.size(), 0) << "Variable length function do not accept kwargs";
n->inputs.reserve(args.size());
for (const Symbol& s : args) {
n->inputs.push_back(s.outputs[0]);
for (const Symbol* s : args) {
n->inputs.push_back(s->outputs[0]);
}
}
UpdateNodeVersion(n);
Expand All @@ -283,13 +298,13 @@ void Symbol::Compose(const std::vector<Symbol>& args,
(const std::shared_ptr<Node> &node) {
if (node->is_variable()) {
if (arg_counter < args.size()) {
replace_map[node.get()] = &(args[arg_counter].outputs[0]);
replace_map[node.get()] = &(args[arg_counter]->outputs[0]);
++arg_counter;
} else {
// match kwargs
auto kit = kwargs.find(node->attrs.name);
if (kit != kwargs.end()) {
replace_map[node.get()] = &(kit->second.outputs[0]);
replace_map[node.get()] = &(kit->second->outputs[0]);
++nmatched;
}
}
Expand Down Expand Up @@ -334,8 +349,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
}
}

Symbol Symbol::operator () (const std::vector<Symbol>& args,
const std::unordered_map<std::string, Symbol>& kwargs,
Symbol Symbol::operator () (const array_view<const Symbol*>& args,
const std::unordered_map<std::string, const Symbol*>& kwargs,
const std::string& name) const {
Symbol s = this->Copy();
s.Compose(args, kwargs, name);
Expand Down
142 changes: 142 additions & 0 deletions nnvm/src/pass/order_mutation.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
/*!
* Copyright (c) 2016 by Contributors
* \file order_mutation.cc
* \brief Add control flow dependencies between nodes
* To correctly order mutation and read to resolve
* write after read problem and read after write problems.
*/
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>

namespace nnvm {

/*!
 * \brief Look up a node in a map, falling back to a default when it is absent.
 * \param map The map keyed by node pointer.
 * \param key The node pointer to look up.
 * \param def The value to return when key has no entry in map.
 * \return A copy of the mapped value if key is present, otherwise a copy of def.
 * \tparam T The mapped value type.
 */
template<typename T>
inline T get_with_default(const std::unordered_map<Node*, T> &map,
                          Node* key,
                          const T& def) {
  const auto found = map.find(key);
  return found == map.end() ? def : found->second;
}

/*!
 * \brief Add control dependencies that order reads and mutations of variables.
 *
 *  A variable input whose NodeEntry::version is non-zero marks the variable as
 *  mutated somewhere in the graph. For every such variable a history of all
 *  nodes that take it as input is built, sorted by version with mutating uses
 *  placed before reading uses of the same version. Using that history:
 *   - a mutating use gets control dependencies on the reads that precede it,
 *     or, when there are none, on the previous mutating use;
 *   - a reading use gets a control dependency on the mutating use that
 *     produced the version it reads.
 *
 *  Nodes that need new dependencies (directly or because one of their inputs
 *  or control deps was replaced) are re-created; all other nodes are shared
 *  with the source graph.
 *
 * \param src The source graph.
 * \return src itself when no variable is mutated, otherwise a new graph with
 *         the additional control dependencies.
 */
Graph OrderMutation(const Graph& src) {
  // Find the variables that are referenced with a non-zero version at least
  // once; only those need a version history.
  std::unordered_map<Node*, std::vector<NodeEntry> > version_hist;
  DFSVisit(src.outputs, [&version_hist](const std::shared_ptr<Node>& n) {
      for (const NodeEntry& e : n->inputs) {
        if (e.node->is_variable()) {
          if (e.version != 0 && version_hist.count(e.node.get()) == 0) {
            version_hist[e.node.get()] = std::vector<NodeEntry>{};
          }
        }
      }
    });
  // no mutation happens, everything is fine.
  if (version_hist.size() == 0) return src;
  // start preparing for remapping the nodes.
  std::unordered_map<Node*, std::shared_ptr<Node> > old_new;
  auto prepare = [&version_hist, &old_new] (const std::shared_ptr<Node>& n) {
    static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
    bool need_repl = false;
    for (size_t i = 0; i < n->inputs.size(); ++i) {
      const NodeEntry& e = n->inputs[i];
      if (e.node->is_variable()) {
        if (e.version != 0) need_repl = true;
        auto it = version_hist.find(e.node.get());
        if (it != version_hist.end()) {
          std::vector<NodeEntry>& vec = it->second;
          uint32_t is_mutate =
              fmutate_inputs.count(n->op) ? fmutate_inputs[n->op](n->attrs, i) : 0;
          // History entry: the using node, with the index field reused to
          // carry the mutate flag, and the version of the variable it uses.
          vec.emplace_back(NodeEntry{n, is_mutate, e.version});
        }
      } else {
        // an input that was itself replaced forces this node to be replaced.
        if (old_new.count(e.node.get()) != 0) need_repl = true;
      }
    }
    for (const std::shared_ptr<Node>& p : n->control_deps) {
      if (old_new.count(p.get()) != 0) need_repl = true;
    }
    if (need_repl) {
      // create the replacement; inputs and control deps are filled in later.
      std::shared_ptr<Node> np = Node::Create();
      np->op = n->op;
      np->attrs = n->attrs;
      old_new[n.get()] = std::move(np);
    }
  };
  DFSVisit(src.outputs, prepare);
  // comparator of history entry:
  // ascending version; within one version, mutating uses (larger index)
  // come before reading uses (index == 0).
  auto comparator = [](const NodeEntry& a, const NodeEntry &b) {
    if (a.version < b.version) return true;
    if (a.version > b.version) return false;
    return a.index > b.index;
  };

  for (auto &kv : version_hist) {
    std::sort(kv.second.begin(), kv.second.end(), comparator);
  }
  // copy the nodes, as well as add control deps
  for (auto &kv : old_new) {
    // copy the nodes
    for (const NodeEntry& e : kv.first->inputs) {
      auto it = old_new.find(e.node.get());
      if (it != old_new.end()) {
        kv.second->inputs.emplace_back(NodeEntry{it->second, e.index, e.version});
      } else {
        kv.second->inputs.push_back(e);
      }
    }
    for (const std::shared_ptr<Node>& p : kv.first->control_deps) {
      kv.second->control_deps.emplace_back(
          get_with_default(old_new, p.get(), p));
    }
    // add control deps
    static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
    for (size_t i = 0; i < kv.first->inputs.size(); ++i) {
      const NodeEntry& e = kv.first->inputs[i];
      if (e.node->is_variable() && version_hist.count(e.node.get()) != 0) {
        FMutateInput fmutate = fmutate_inputs.get(kv.first->op, nullptr);
        uint32_t is_mutate = (fmutate == nullptr) ? 0 : fmutate(kv.first->attrs, i);
        std::vector<NodeEntry>& vec = version_hist.at(e.node.get());

        // First history entry of this version (mutating use first, if any).
        auto it = std::lower_bound(vec.begin(), vec.end(),
                                   NodeEntry{nullptr, 1, e.version},
                                   comparator);
        if (is_mutate != 0) {
          // Remember where the scan starts so that we can tell whether the
          // loop below actually moved to an earlier entry.
          const auto first = it;
          int read_dep = 0;
          while (it != vec.begin()) {
            --it;
            if (it->index != 0) break;
            ++read_dep;
            // depend on previous read
            kv.second->control_deps.push_back(
                get_with_default(old_new, it->node.get(), it->node));
          }
          // Only add the dependency when the scan stopped on an earlier
          // mutating use. If the scan never moved, `it` still refers to the
          // entry found by lower_bound (this version's own mutating use), and
          // depending on it would make the node depend on itself.
          if (read_dep == 0 && it != first && it->index != 0) {
            // depend on last write
            kv.second->control_deps.push_back(
                get_with_default(old_new, it->node.get(), it->node));
          }
        } else {
          // depend on last write
          if (it->index != 0) {
            kv.second->control_deps.push_back(
                get_with_default(old_new, it->node.get(), it->node));
          }
        }
      }
    }
  }
  // Re-point the graph outputs at the replacement nodes where they exist.
  Graph ret;
  for (const NodeEntry &e : src.outputs) {
    ret.outputs.emplace_back(NodeEntry{
        get_with_default(old_new, e.node.get(), e.node), e.index, e.version});
  }
  return ret;
}

// Register OrderMutation as a pass; it is flagged as changing the graph.
NNVM_REGISTER_PASS(OrderMutation)
.describe("Return a new graph that adds control dependencies, "
          "to order the mutation and reads if mutation exists.")
.set_body(OrderMutation)
.set_change_graph(true);

} // namespace nnvm
2 changes: 1 addition & 1 deletion nnvm/src/pass/saveload_json.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*!
* Copyright (c) 2016 by Contributors
* \file saveload_json.cc
* \brief Passes that defines save and load graph to/from JSON file.
* \brief Save and load graph to/from JSON file.
*/
#include <nnvm/pass.h>
#include <dmlc/json.h>
Expand Down
Loading

0 comments on commit 9a956a8

Please sign in to comment.