From ddea3712e9a5320a089787570a784b43ef45b316 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 27 Aug 2016 09:45:16 -0700 Subject: [PATCH] [PASS] Add gradient pass (#28) --- nnvm/example/src/operator.cc | 77 +++++++++++++-- nnvm/include/dmlc/base.h | 13 +++ nnvm/include/dmlc/json.h | 9 +- nnvm/include/dmlc/parameter.h | 3 +- nnvm/include/dmlc/registry.h | 5 +- nnvm/include/nnvm/c_api.h | 17 ++++ nnvm/include/nnvm/op.h | 6 +- nnvm/include/nnvm/op_attr_types.h | 14 +++ nnvm/include/nnvm/pass_functions.h | 31 +++++++ nnvm/python/nnvm/graph.py | 23 ++++- nnvm/src/c_api/c_api_graph.cc | 11 +++ nnvm/src/pass/gradient.cc | 144 +++++++++++++++++++++++++++++ nnvm/tests/python/test_gradient.py | 24 +++++ 13 files changed, 357 insertions(+), 20 deletions(-) create mode 100644 nnvm/src/pass/gradient.cc create mode 100644 nnvm/tests/python/test_gradient.py diff --git a/nnvm/example/src/operator.cc b/nnvm/example/src/operator.cc index c0c729cdb4dd7..674893a555491 100644 --- a/nnvm/example/src/operator.cc +++ b/nnvm/example/src/operator.cc @@ -15,6 +15,10 @@ using nnvm::FMutateInputs; using nnvm::FInferShape; using nnvm::FInferType; using nnvm::FInplaceOption; +using nnvm::Node; +using nnvm::NodePtr; +using nnvm::NodeEntry; +using nnvm::FGradient; using nnvm::NodeAttrs; using nnvm::TShape; using nnvm::array_view; @@ -37,6 +41,17 @@ inline std::vector > InplaceIn0Out0(const NodeAttrs& attrs) return {{0, 0}}; } +// quick helper to make node +inline NodeEntry MakeNode(const char* op_name, + std::string node_name, + std::vector inputs) { + NodePtr p = Node::Create(); + p->op = nnvm::Op::Get(op_name); + p->attrs.name = std::move(node_name); + p->inputs = std::move(inputs); + return NodeEntry{p, 0, 0}; +} + // simple demonstration of reshape. NNVM_REGISTER_OP(reshape) .describe("reshape source to target shape") @@ -84,21 +99,67 @@ NNVM_REGISTER_OP(cast) return true; }); +NNVM_REGISTER_OP(exp) +.describe("take exponential") +.set_num_inputs(1) +.attr("FInferShape", SameShape) +.attr( + "FGradient", [](const NodePtr& n, + const std::vector& ograds) { + return std::vector{ + MakeNode("mul", n->attrs.name + "_grad", + {ograds[0], NodeEntry{n, 0, 0}}) + }; + }); + +NNVM_REGISTER_OP(identity) +.describe("identity function") +.set_num_inputs(1) +.attr("FInferShape", SameShape) +.attr( + "FGradient", [](const NodePtr& n, + const std::vector& ograds) { + return std::vector{ograds[0]}; + }); NNVM_REGISTER_OP(add) .describe("add two data together") .set_num_inputs(2) .attr("FInferShape", SameShape) -.attr("FInplaceOption", InplaceIn0Out0); +.attr("FInplaceOption", InplaceIn0Out0) +.attr( + "FGradient", [](const NodePtr& n, + const std::vector& ograds){ + return std::vector{ograds[0], ograds[0]}; + }); -NNVM_REGISTER_OP(__add_symbol__) -.describe("Alias of add") -.set_num_inputs(2); +NNVM_REGISTER_OP(mul) +.describe("multiply two data together") +.set_num_inputs(2) +.attr("FInferShape", SameShape) +.attr("FInplaceOption", InplaceIn0Out0) +.attr( + "FGradient", [](const NodePtr& n, + const std::vector& ograds){ + return std::vector{ + MakeNode("mul", n->attrs.name + "_grad_0", + {ograds[0], n->inputs[1]}), + MakeNode("mul", n->attrs.name + "_grad_1", + {ograds[0], n->inputs[0]}) + }; + }); -NNVM_REGISTER_OP(exp) -.describe("take exponential") -.set_num_inputs(1) -.attr("FInferShape", SameShape); +NNVM_REGISTER_OP(__ewise_sum__) +.describe("elementwise sum") +.set_num_inputs(nnvm::kVarg); + +NNVM_REGISTER_OP(__zero__) +.describe("set output to zero") +.set_num_inputs(0); + +NNVM_REGISTER_OP(__one__) +.describe("set output to one") +.set_num_inputs(0); NNVM_REGISTER_OP(cross_device_copy) .describe("Copy data across device.") diff --git a/nnvm/include/dmlc/base.h b/nnvm/include/dmlc/base.h index 5b34fd6b4e345..9eca4135f1191 100644 --- a/nnvm/include/dmlc/base.h +++ b/nnvm/include/dmlc/base.h @@ -58,6 +58,11 @@ __cplusplus >= 201103L || defined(_MSC_VER)) #endif +/*! \brief strict CXX11 support */ +#ifndef DMLC_STRICT_CXX11 +#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L || defined(_MSC_VER)) +#endif + /// check if g++ is before 4.6 #if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__) #if __GNUC__ == 4 && __GNUC_MINOR__ < 6 @@ -69,6 +74,7 @@ #endif #endif + /*! * \brief Enable std::thread related modules, * Used to disable some module in mingw compile. @@ -82,6 +88,13 @@ #define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER)) #endif +/*! \brief helper macro to supress unused warning */ +#if defined(__GNUC__) +#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused)) +#else +#define DMLC_ATTRIBUTE_UNUSED +#endif + /*! \brief helper macro to generate string concat */ #define DMLC_STR_CONCAT_(__x, __y) __x##__y #define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y) diff --git a/nnvm/include/dmlc/json.h b/nnvm/include/dmlc/json.h index 2daa0aaa017f6..1934aee6a2ce4 100644 --- a/nnvm/include/dmlc/json.h +++ b/nnvm/include/dmlc/json.h @@ -25,7 +25,9 @@ #include #include #include +#if DMLC_STRICT_CXX11 #include "./any.h" +#endif // DMLC_STRICT_CXX11 #endif // DMLC_USE_CXX11 namespace dmlc { @@ -320,7 +322,8 @@ class JSONObjectReadHelper { }; #define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \ - static ::dmlc::json::AnyJSONManager& __make_AnyJSONType ## _ ## KeyName ## __ + static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager& \ + __make_AnyJSONType ## _ ## KeyName ## __ /*! * \def DMLC_JSON_ENABLE_ANY @@ -475,7 +478,7 @@ struct Handler { } }; -#if DMLC_USE_CXX11 +#if DMLC_STRICT_CXX11 // Manager to store json serialization strategy. class AnyJSONManager { public: @@ -561,7 +564,7 @@ struct Handler { CHECK(!reader->NextArrayItem()) << "invalid any json format"; } }; -#endif // DMLC_USE_CXX11 +#endif // DMLC_STRICT_CXX11 } // namespace json diff --git a/nnvm/include/dmlc/parameter.h b/nnvm/include/dmlc/parameter.h index 4ff99f860cc33..2fbab2a44e32f 100644 --- a/nnvm/include/dmlc/parameter.h +++ b/nnvm/include/dmlc/parameter.h @@ -251,7 +251,8 @@ struct Parameter { static ::dmlc::parameter::ParamManagerSingleton inst(#PType); \ return &inst.manager; \ } \ - static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \ + static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \ + __make__ ## PType ## ParamManager__ = \ (*PType::__MANAGER__()) \ //! \endcond diff --git a/nnvm/include/dmlc/registry.h b/nnvm/include/dmlc/registry.h index 67fbc43ded682..380b31cd3d61e 100644 --- a/nnvm/include/dmlc/registry.h +++ b/nnvm/include/dmlc/registry.h @@ -216,7 +216,7 @@ class FunctionRegEntryBase { * \sa FactoryRegistryEntryBase */ #define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \ - static EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \ + static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \ ::dmlc::Registry::Get()->__REGISTER__(#Name) \ /*! @@ -272,6 +272,7 @@ class FunctionRegEntryBase { */ #define DMLC_REGISTRY_LINK_TAG(UniqueTag) \ int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \ - static int __reg_file_tag_ ## UniqueTag ## __ = __dmlc_registry_file_tag_ ## UniqueTag ## __(); + static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \ + __dmlc_registry_file_tag_ ## UniqueTag ## __(); } // namespace dmlc #endif // DMLC_REGISTRY_H_ diff --git a/nnvm/include/nnvm/c_api.h b/nnvm/include/nnvm/c_api.h index 1c6943f9c681c..f924701115940 100644 --- a/nnvm/include/nnvm/c_api.h +++ b/nnvm/include/nnvm/c_api.h @@ -260,6 +260,7 @@ NNVM_DLL int NNGraphFree(GraphHandle handle); * \return 0 when success, -1 when failure happens */ NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol); + /*! * \brief Get Set a attribute in json format. * This feature allows pass graph attributes back and forth in reasonable speed. @@ -273,6 +274,7 @@ NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol); NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle, const char* key, const char* json_value); + /*! * \brief Get a serialized attrirbute from graph. * This feature allows pass graph attributes back and forth in reasonable speed. @@ -289,6 +291,21 @@ NNVM_DLL int NNGraphGetJSONAttr(SymbolHandle handle, const char* key, const char** json_out, int *success); + +/*! + * \brief Set a attribute whose type is std::vector in c++ + * This feature allows pass List of symbolic variables for gradient request. + * + * \note This is beta feature only used for test purpos + * + * \param handle The graph handle. + * \param key The key to the attribute. + * \param list The symbol whose outputs represents the list of NodeEntry to be passed. + * \return 0 when success, -1 when failure happens + */ +NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle, + const char* key, + SymbolHandle list); /*! * \brief Apply pass on the src graph. * \param src The source graph handle. diff --git a/nnvm/include/nnvm/op.h b/nnvm/include/nnvm/op.h index 721e8e736e09b..e49bc9ae6643a 100644 --- a/nnvm/include/nnvm/op.h +++ b/nnvm/include/nnvm/op.h @@ -279,10 +279,8 @@ class OpMap { }; // internal macros to make -#define NNVM_STR_CONCAT_(__x, __y) __x##__y -#define NNVM_STR_CONCAT(__x, __y) NNVM_STR_CONCAT_(__x, __y) #define NNVM_REGISTER_VAR_DEF(OpName) \ - static ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName + static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName /*! * \def NNVM_REGISTER_OP @@ -300,7 +298,7 @@ class OpMap { * \endcode */ #define NNVM_REGISTER_OP(OpName) \ - NNVM_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ + DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName) // implementations of template functions after this. diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index 675b93a6c9d2f..d3129f978cc2e 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -11,6 +11,7 @@ #include #include #include "./base.h" +#include "./node.h" #include "./tuple.h" namespace nnvm { @@ -107,6 +108,19 @@ using TIsBackwardOp = bool; using FInplaceOption = std::function< std::vector > (const NodeAttrs& attrs)>; +/*! + * \brief Get the gradient node of the op node + * This function generates the backward graph of the node + * \param nodeptr The node to take gradient + * \param out_grads Gradient of current node's outputs + * \return gradients of the inputs + * + * \note Register under "FGradient" + */ +using FGradient = std::function( + const NodePtr& nodeptr, + const std::vector& out_grads)>; + } // namespace nnvm #endif // NNVM_OP_ATTR_TYPES_H_ diff --git a/nnvm/include/nnvm/pass_functions.h b/nnvm/include/nnvm/pass_functions.h index 8cca33e97cfde..a2cca949ff5b2 100644 --- a/nnvm/include/nnvm/pass_functions.h +++ b/nnvm/include/nnvm/pass_functions.h @@ -109,6 +109,37 @@ inline Graph PlaceDevice(Graph graph, return ApplyPass(std::move(graph), {"PlaceDevice"}); } +/*! + * \brief Get the gradient graph whose outputs are gradients of xs wrt to ys. + * \param graph source graph + * \param ys The entries we want to take gradient from. + * \param xs The input to take gradient with respect to. + * \param ys_out_grad The symbol for additional gradient to be propagate back to y. + * \param aggregate_fun aggregation function applied to aggregate the inputs + * \param mirror_fun Optional mirror function to do mirror optimization and save memory. + * \return A new graph, whose outputs corresponds to inputs of xs. + */ +inline Graph Gradient( + Graph graph, + std::vector ys, + std::vector xs, + std::vector ys_out_grad, + std::function&& inputs)> aggregate_fun = nullptr, + std::function mirror_fun = nullptr) { + graph.attrs["grad_ys"] = std::make_shared(std::move(ys)); + + graph.attrs["grad_xs"] = std::make_shared(std::move(xs)); + graph.attrs["grad_ys_out_grad"] = std::make_shared(std::move(ys_out_grad)); + if (aggregate_fun != nullptr) { + graph.attrs["grad_aggregate_fun"] = std::make_shared(aggregate_fun); + } + if (mirror_fun != nullptr) { + graph.attrs["grad_mirror_fun"] = std::make_shared(mirror_fun); + } + + return ApplyPass(std::move(graph), {"Gradient"}); +} + } // namespace pass } // namespace nnvm #endif // NNVM_PASS_FUNCTIONS_H_ diff --git a/nnvm/python/nnvm/graph.py b/nnvm/python/nnvm/graph.py index 3f184928a9ec8..e3e857eecbf3b 100644 --- a/nnvm/python/nnvm/graph.py +++ b/nnvm/python/nnvm/graph.py @@ -10,7 +10,7 @@ from ._base import c_array, c_str, nn_uint, py_str, string_types from ._base import GraphHandle, SymbolHandle from ._base import check_call -from .symbol import Symbol +from .symbol import Symbol, Group as _Group class Graph(object): @@ -56,8 +56,27 @@ def json_attr(self, key): else: return None + def _set_symbol_list_attr(self, key, value): + """Set the attribute of the graph. + + Parameters + ---------- + key : string + The key of the attribute + value : value + The any type that can be dumped to json + type_name : string + The typename registered on c++ side. + """ + if isinstance(value, list): + value = _Group(value) + if not isinstance(value, Symbol): + raise ValueError("value need to be grouped symbol") + check_call(_LIB.NNGraphSetNodeEntryListAttr_( + self.handle, c_str(key), value.handle)) + def _set_json_attr(self, key, value, type_name=None): - """Set the attribute of the symbol. + """Set the attribute of the graph. Parameters ---------- diff --git a/nnvm/src/c_api/c_api_graph.cc b/nnvm/src/c_api/c_api_graph.cc index c3de70618b1fe..d3dd1d3e49aaa 100644 --- a/nnvm/src/c_api/c_api_graph.cc +++ b/nnvm/src/c_api/c_api_graph.cc @@ -35,6 +35,17 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) { API_END_HANDLE_ERROR(delete s); } +int NNGraphSetNodeEntryListAttr_(GraphHandle handle, + const char* key, + SymbolHandle list) { + API_BEGIN(); + Symbol* s = static_cast(list); + Graph* g = static_cast(handle); + g->attrs[std::string(key)] + = std::make_shared(s->outputs); + API_END(); +} + int NNGraphSetJSONAttr(GraphHandle handle, const char* key, const char* json_value) { diff --git a/nnvm/src/pass/gradient.cc b/nnvm/src/pass/gradient.cc new file mode 100644 index 0000000000000..530dd29f41c66 --- /dev/null +++ b/nnvm/src/pass/gradient.cc @@ -0,0 +1,144 @@ +/*! + * Copyright (c) 2016 by Contributors + * \file gradients.cc + * \brief Passes that takes gradient of the graph + * This code code was modified based on mxnet codebase by Min Lin + */ +#include +#include +#include +#include + +namespace nnvm { +namespace pass { +namespace { + +// default aggregate gradient function +// require operator __zero__ and __ewise_sum__ to be presented. +NodeEntry DefaultAggregateGradient(std::vector&& v) { + if (v.size() == 1) { + return std::move(v[0]); + } else if (v.size() == 0) { + NodePtr zero_node = Node::Create(); + zero_node->op = Op::Get("__zero__"); + return NodeEntry{zero_node, 0, 0}; + } else { + NodePtr sum_node = Node::Create(); + sum_node->op = Op::Get("__ewise_sum__"); + sum_node->inputs = std::move(v); + return NodeEntry{sum_node, 0, 0}; + } +} + +// helper entry +struct GradEntry { + NodeEntry sum{nullptr, 0, 0}; + std::vector grads; +}; + +Graph Gradient(Graph src) { + using nnvm::FGradient; + using MirrorFun = std::function; + + CHECK_NE(src.attrs.count("grad_ys"), 0) + << "Gradient require grad_ys to be presented."; + CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0) + << "Gradient require grad_ys_out_grad to be presented."; + CHECK_NE(src.attrs.count("grad_xs"), 0) + << "Gradient require grad_xs to be presented."; + const std::vector& ys = + src.GetAttr >("grad_ys"); + const std::vector& ys_out_grad = + src.GetAttr >("grad_ys_out_grad"); + const std::vector& xs = + src.GetAttr >("grad_xs"); + using AggFun = std::function&& inputs)>; + AggFun agg_fun = DefaultAggregateGradient; + if (src.attrs.count("grad_aggregate_fun") != 0) { + agg_fun = src.GetAttr("grad_aggregate_fun"); + } + MirrorFun mirror_fun = nullptr; + if (src.attrs.count("grad_mirror_fun") != 0) { + mirror_fun = src.GetAttr("grad_mirror_fun"); + } + + // topo sort + std::vector topo_order; + std::unordered_map > output_grads; + DFSVisit(ys, [&](const NodePtr& node) { + if (output_grads.count(node.get()) == 0) { + output_grads[node.get()].resize(node->num_outputs()); + } + topo_order.push_back(node); + }); + + CHECK_EQ(ys.size(), ys_out_grad.size()); + for (size_t i = 0; i < ys.size(); ++i) { + output_grads[ys[i].node.get()][ys[i].index].grads = { ys_out_grad[i] }; + } + + // construct mirror reduece memory strategy if needed + std::unordered_map mirror_map; + if (mirror_fun != nullptr) { + for (const NodePtr& n : topo_order) { + if (mirror_fun(*n)) { + NodePtr new_node = Node::Create(); + *new_node = *n; + new_node->attrs.name += "_mirror"; + for (auto& e : new_node->inputs) { + e.node = mirror_map.at(e.node.get()); + } + for (auto& n : new_node->control_deps) { + n = mirror_map.at(n.get()); + } + mirror_map[n.get()] = std::move(new_node); + } else { + mirror_map[n.get()] = n; + } + } + } + + // traverse backward + static auto& grad_fun_map = Op::GetAttr("FGradient"); + std::vector out_agg_grads; + for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) { + const NodePtr& ptr = *rit; + if (ptr->is_variable()) continue; + out_agg_grads.clear(); + for (GradEntry& e : output_grads.at(ptr.get())) { + e.sum = agg_fun(std::move(e.grads)); + out_agg_grads.push_back(e.sum); + } + std::vector input_grads = grad_fun_map[ptr->op] + (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads); + auto git = input_grads.begin(); + for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { + output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git)); + } + } + // take out the xs' grads + Graph ret; + ret.outputs.reserve(xs.size()); + for (const NodeEntry& e : xs) { + GradEntry& entry = output_grads[e.node.get()][e.index]; + // aggregate sum if there haven't been + if (entry.sum.node.get() == nullptr) { + entry.sum = agg_fun(std::move(entry.grads)); + } + ret.outputs.emplace_back(std::move(entry.sum)); + } + return ret; +} + +// register pass +NNVM_REGISTER_PASS(Gradient) +.describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]") +.set_body(Gradient) +.set_change_graph(true) +.depend_graph_attr("grad_ys") +.depend_graph_attr("grad_xs") +.depend_graph_attr("grad_ys_out_grad"); + +} // namespace +} // namespace pass +} // namespace nnvm diff --git a/nnvm/tests/python/test_gradient.py b/nnvm/tests/python/test_gradient.py new file mode 100644 index 0000000000000..5acdb1207f206 --- /dev/null +++ b/nnvm/tests/python/test_gradient.py @@ -0,0 +1,24 @@ +import json +import nnvm.symbol as sym +import nnvm.graph as graph + +def grad(ys, xs, ys_grads): + g = graph.create(ys) + g._set_symbol_list_attr('grad_ys', ys) + g._set_symbol_list_attr('grad_xs', xs) + g._set_symbol_list_attr('grad_ys_out_grad', ys_grads) + return g.apply('Gradient') + +def test_graph_gradient(): + x0 = sym.Variable('x0') + x1 = sym.Variable('x1') + yg = sym.Variable('yg') + y = sym.exp(sym.mul(x0, x1)) + grad_graph = grad(y, [x0], yg) + print("Original graph") + print(y.debug_str()) + print("Gradient graph") + print grad_graph.symbol.debug_str() + +if __name__ == "__main__": + test_graph_gradient()