Skip to content

Commit

Permalink
[PASS] Add gradient pass (apache#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and sergei-mironov committed Aug 8, 2018
1 parent e2925e3 commit ddea371
Show file tree
Hide file tree
Showing 13 changed files with 357 additions and 20 deletions.
77 changes: 69 additions & 8 deletions nnvm/example/src/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ using nnvm::FMutateInputs;
using nnvm::FInferShape;
using nnvm::FInferType;
using nnvm::FInplaceOption;
using nnvm::Node;
using nnvm::NodePtr;
using nnvm::NodeEntry;
using nnvm::FGradient;
using nnvm::NodeAttrs;
using nnvm::TShape;
using nnvm::array_view;
Expand All @@ -37,6 +41,17 @@ inline std::vector<std::pair<int, int> > InplaceIn0Out0(const NodeAttrs& attrs)
return {{0, 0}};
}

/*!
 * \brief Convenience helper that creates one operator node.
 * \param op_name Name of the operator, resolved through nnvm::Op::Get.
 * \param node_name Name stored in the new node's attrs.
 * \param inputs Entries installed as the new node's inputs.
 * \return NodeEntry{node, 0, 0} wrapping the freshly created node.
 */
inline NodeEntry MakeNode(const char* op_name,
                          std::string node_name,
                          std::vector<NodeEntry> inputs) {
  NodePtr node = Node::Create();
  node->inputs = std::move(inputs);
  node->attrs.name = std::move(node_name);
  node->op = nnvm::Op::Get(op_name);
  return NodeEntry{node, 0, 0};
}

// simple demonstration of reshape.
NNVM_REGISTER_OP(reshape)
.describe("reshape source to target shape")
Expand Down Expand Up @@ -84,21 +99,67 @@ NNVM_REGISTER_OP(cast)
return true;
});

// exp: since d/dx exp(x) = exp(x), the input gradient is a "mul" node of the
// incoming output gradient and NodeEntry{node, 0, 0} (the exp node itself).
NNVM_REGISTER_OP(exp)
.describe("take exponential")
.set_num_inputs(1)
.attr<FInferShape>("FInferShape", SameShape)
.attr<FGradient>(
    "FGradient",
    [](const NodePtr& node, const std::vector<NodeEntry>& out_grads) {
      std::vector<NodeEntry> in_grads;
      in_grads.push_back(
          MakeNode("mul", node->attrs.name + "_grad",
                   {out_grads[0], NodeEntry{node, 0, 0}}));
      return in_grads;
    });

// identity: the gradient of the single input is the first output gradient,
// forwarded unchanged.
NNVM_REGISTER_OP(identity)
.describe("identity function")
.set_num_inputs(1)
.attr<FInferShape>("FInferShape", SameShape)
.attr<FGradient>(
    "FGradient",
    [](const NodePtr& /*node*/, const std::vector<NodeEntry>& out_grads) {
      return std::vector<NodeEntry>(1, out_grads[0]);
    });

// add: takes two inputs; its FGradient returns the first output gradient
// once for each of the two inputs.
// NOTE(review): the two FInplaceOption lines below (one terminated by ';',
// one not) look like the before/after pair of a diff hunk; presumably only
// the second one belongs in the file -- confirm.
NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0);
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.attr<FGradient>(
"FGradient", [](const NodePtr& n,
const std::vector<NodeEntry>& ograds){
return std::vector<NodeEntry>{ograds[0], ograds[0]};
});

// __add_symbol__: described as an alias of add; only the description and the
// input count (2) are set here, no shape or gradient attributes.
NNVM_REGISTER_OP(__add_symbol__)
.describe("Alias of add")
.set_num_inputs(2);
// mul: product rule. The gradient for input 0 is the output gradient times
// input 1, and the gradient for input 1 is the output gradient times input 0.
NNVM_REGISTER_OP(mul)
.describe("multiply two data together")
.set_num_inputs(2)
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
.attr<FGradient>(
    "FGradient",
    [](const NodePtr& node, const std::vector<NodeEntry>& out_grads) {
      std::vector<NodeEntry> in_grads;
      in_grads.push_back(
          MakeNode("mul", node->attrs.name + "_grad_0",
                   {out_grads[0], node->inputs[1]}));
      in_grads.push_back(
          MakeNode("mul", node->attrs.name + "_grad_1",
                   {out_grads[0], node->inputs[0]}));
      return in_grads;
    });

// NOTE(review): exp is also registered earlier in this listing, there with an
// FGradient attribute; this shorter form looks like the pre-change side of
// the diff -- confirm which one the file should keep.
NNVM_REGISTER_OP(exp)
.describe("take exponential")
.set_num_inputs(1)
.attr<FInferShape>("FInferShape", SameShape);
// __ewise_sum__: accepts a variable number of inputs (nnvm::kVarg).
// Only the description and input count are registered here.
NNVM_REGISTER_OP(__ewise_sum__)
.describe("elementwise sum")
.set_num_inputs(nnvm::kVarg);

// __zero__: takes no inputs; described as setting its output to zero.
NNVM_REGISTER_OP(__zero__)
.describe("set output to zero")
.set_num_inputs(0);

// __one__: takes no inputs; described as setting its output to one.
NNVM_REGISTER_OP(__one__)
.describe("set output to one")
.set_num_inputs(0);

NNVM_REGISTER_OP(cross_device_copy)
.describe("Copy data across device.")
Expand Down
13 changes: 13 additions & 0 deletions nnvm/include/dmlc/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@
__cplusplus >= 201103L || defined(_MSC_VER))
#endif

/*!
* \brief strict CXX11 support
* Evaluates to true when __cplusplus reports C++11 or later, or when
* _MSC_VER is defined. The #ifndef guard lets a build define it beforehand
* to override this default.
*/
#ifndef DMLC_STRICT_CXX11
#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L || defined(_MSC_VER))
#endif

/// check if g++ is before 4.6
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
#if __GNUC__ == 4 && __GNUC_MINOR__ < 6
Expand All @@ -69,6 +74,7 @@
#endif
#endif


/*!
* \brief Enable std::thread related modules,
* Used to disable some module in mingw compile.
Expand All @@ -82,6 +88,13 @@
#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER))
#endif

/*!
* \brief helper macro to suppress unused warning
* Expands to __attribute__((unused)) when __GNUC__ is defined and to
* nothing otherwise.
*/
#if defined(__GNUC__)
#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define DMLC_ATTRIBUTE_UNUSED
#endif

/*! \brief helper macro to generate string concat */
#define DMLC_STR_CONCAT_(__x, __y) __x##__y
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)
Expand Down
9 changes: 6 additions & 3 deletions nnvm/include/dmlc/json.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#if DMLC_STRICT_CXX11
#include "./any.h"
#endif // DMLC_STRICT_CXX11
#endif // DMLC_USE_CXX11

namespace dmlc {
Expand Down Expand Up @@ -320,7 +322,8 @@ class JSONObjectReadHelper {
};

#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \
static ::dmlc::json::AnyJSONManager& __make_AnyJSONType ## _ ## KeyName ## __
static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager& \
__make_AnyJSONType ## _ ## KeyName ## __

/*!
* \def DMLC_JSON_ENABLE_ANY
Expand Down Expand Up @@ -475,7 +478,7 @@ struct Handler {
}
};

#if DMLC_USE_CXX11
#if DMLC_STRICT_CXX11
// Manager to store json serialization strategy.
class AnyJSONManager {
public:
Expand Down Expand Up @@ -561,7 +564,7 @@ struct Handler<any> {
CHECK(!reader->NextArrayItem()) << "invalid any json format";
}
};
#endif // DMLC_USE_CXX11
#endif // DMLC_STRICT_CXX11

} // namespace json

Expand Down
3 changes: 2 additions & 1 deletion nnvm/include/dmlc/parameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,8 @@ struct Parameter {
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
return &inst.manager; \
} \
static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \
static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \
__make__ ## PType ## ParamManager__ = \
(*PType::__MANAGER__()) \

//! \endcond
Expand Down
5 changes: 3 additions & 2 deletions nnvm/include/dmlc/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ class FunctionRegEntryBase {
* \sa FactoryRegistryEntryBase
*/
#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \
static EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \

/*!
Expand Down Expand Up @@ -272,6 +272,7 @@ class FunctionRegEntryBase {
*/
#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \
int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \
static int __reg_file_tag_ ## UniqueTag ## __ = __dmlc_registry_file_tag_ ## UniqueTag ## __();
static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \
__dmlc_registry_file_tag_ ## UniqueTag ## __();
} // namespace dmlc
#endif // DMLC_REGISTRY_H_
17 changes: 17 additions & 0 deletions nnvm/include/nnvm/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ NNVM_DLL int NNGraphFree(GraphHandle handle);
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);

/*!
* \brief Set an attribute in JSON format.
* This feature allows passing graph attributes back and forth at reasonable speed.
Expand All @@ -273,6 +274,7 @@ NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle,
const char* key,
const char* json_value);

/*!
* \brief Get a serialized attribute from graph.
* This feature allows passing graph attributes back and forth at reasonable speed.
Expand All @@ -289,6 +291,21 @@ NNVM_DLL int NNGraphGetJSONAttr(SymbolHandle handle,
const char* key,
const char** json_out,
int *success);

/*!
* \brief Set an attribute whose type is std::vector<NodeEntry> in C++.
* This feature allows passing a list of symbolic variables for a gradient request.
*
* \note This is a beta feature only used for test purposes.
*
* \param handle The graph handle.
* \param key The key to the attribute.
* \param list The symbol whose outputs represent the list of NodeEntry to be passed.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
const char* key,
SymbolHandle list);
/*!
* \brief Apply pass on the src graph.
* \param src The source graph handle.
Expand Down
6 changes: 2 additions & 4 deletions nnvm/include/nnvm/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -279,10 +279,8 @@ class OpMap {
};

// internal macros to make
#define NNVM_STR_CONCAT_(__x, __y) __x##__y
#define NNVM_STR_CONCAT(__x, __y) NNVM_STR_CONCAT_(__x, __y)
#define NNVM_REGISTER_VAR_DEF(OpName) \
static ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName

/*!
* \def NNVM_REGISTER_OP
Expand All @@ -300,7 +298,7 @@ class OpMap {
* \endcode
*/
#define NNVM_REGISTER_OP(OpName) \
NNVM_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)

// implementations of template functions after this.
Expand Down
14 changes: 14 additions & 0 deletions nnvm/include/nnvm/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <utility>
#include <functional>
#include "./base.h"
#include "./node.h"
#include "./tuple.h"

namespace nnvm {
Expand Down Expand Up @@ -107,6 +108,19 @@ using TIsBackwardOp = bool;
using FInplaceOption = std::function<
std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;

/*!
* \brief Get the gradient node of the op node.
* This function generates the backward graph of the node: given the
* gradients flowing into the node's outputs, it returns entries that
* represent the gradients of the node's inputs.
* \param nodeptr The node to take the gradient of.
* \param out_grads Gradients of the current node's outputs.
* \return Gradients of the inputs.
*
* \note Register under "FGradient"
*/
using FGradient = std::function<std::vector<NodeEntry>(
const NodePtr& nodeptr,
const std::vector<NodeEntry>& out_grads)>;

} // namespace nnvm

#endif // NNVM_OP_ATTR_TYPES_H_
31 changes: 31 additions & 0 deletions nnvm/include/nnvm/pass_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,37 @@ inline Graph PlaceDevice(Graph graph,
return ApplyPass(std::move(graph), {"PlaceDevice"});
}

/*!
 * \brief Get the gradient graph whose outputs are the gradients of ys with respect to xs.
 *
 * All arguments are handed to the "Gradient" pass through graph attributes
 * ("grad_ys", "grad_xs", "grad_ys_out_grad", and, when provided,
 * "grad_aggregate_fun" and "grad_mirror_fun").
 *
 * \param graph source graph
 * \param ys The entries we want to take gradient from.
 * \param xs The inputs to take gradient with respect to.
 * \param ys_out_grad The symbols for additional gradients to be propagated back to ys.
 * \param aggregate_fun Aggregation function applied to aggregate the inputs;
 *        the attribute is only set when this is non-null.
 * \param mirror_fun Optional mirror function to do mirror optimization and save memory;
 *        the attribute is only set when this is non-null.
 * \return A new graph, whose outputs correspond to the entries of xs.
 */
inline Graph Gradient(
    Graph graph,
    std::vector<NodeEntry> ys,
    std::vector<NodeEntry> xs,
    std::vector<NodeEntry> ys_out_grad,
    std::function<NodeEntry(std::vector<NodeEntry>&& inputs)> aggregate_fun = nullptr,
    std::function<int(const Node& node)> mirror_fun = nullptr) {
  graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));
  graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs));
  graph.attrs["grad_ys_out_grad"] = std::make_shared<any>(std::move(ys_out_grad));
  // The callbacks are by-value sink parameters: move them into the attribute
  // instead of copying the std::function objects a second time.
  if (aggregate_fun != nullptr) {
    graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(std::move(aggregate_fun));
  }
  if (mirror_fun != nullptr) {
    graph.attrs["grad_mirror_fun"] = std::make_shared<any>(std::move(mirror_fun));
  }

  return ApplyPass(std::move(graph), {"Gradient"});
}

} // namespace pass
} // namespace nnvm
#endif // NNVM_PASS_FUNCTIONS_H_
23 changes: 21 additions & 2 deletions nnvm/python/nnvm/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ._base import c_array, c_str, nn_uint, py_str, string_types
from ._base import GraphHandle, SymbolHandle
from ._base import check_call
from .symbol import Symbol
from .symbol import Symbol, Group as _Group


class Graph(object):
Expand Down Expand Up @@ -56,8 +56,27 @@ def json_attr(self, key):
else:
return None

def _set_symbol_list_attr(self, key, value):
    """Set a graph attribute from a symbol via NNGraphSetNodeEntryListAttr_.

    Parameters
    ----------
    key : string
        The key of the attribute

    value : Symbol or list
        A list is first combined into a single Symbol with Group.
        Any value that is still not a Symbol raises ValueError.
    """
    if isinstance(value, list):
        value = _Group(value)
    if not isinstance(value, Symbol):
        raise ValueError("value need to be grouped symbol")
    check_call(_LIB.NNGraphSetNodeEntryListAttr_(
        self.handle, c_str(key), value.handle))

def _set_json_attr(self, key, value, type_name=None):
"""Set the attribute of the symbol.
"""Set the attribute of the graph.
Parameters
----------
Expand Down
11 changes: 11 additions & 0 deletions nnvm/src/c_api/c_api_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,17 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) {
API_END_HANDLE_ERROR(delete s);
}

// Stores a copy of the symbol's outputs (wrapped in a shared any) as the
// graph attribute named by key.
int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
                                 const char* key,
                                 SymbolHandle list) {
  API_BEGIN();
  Graph* graph = static_cast<Graph*>(handle);
  Symbol* sym = static_cast<Symbol*>(list);
  graph->attrs[std::string(key)] = std::make_shared<any>(sym->outputs);
  API_END();
}

int NNGraphSetJSONAttr(GraphHandle handle,
const char* key,
const char* json_value) {
Expand Down
Loading

0 comments on commit ddea371

Please sign in to comment.