
Merge pull request #4566 from reyoung/feature/grad_reg_mechanism_cont2
Complete register gradient for compile time
wangkuiyi committed Oct 6, 2017
2 parents f8b5d54 + 803b7b6 · commit 4c96008
Showing 18 changed files with 271 additions and 491 deletions.
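In short, this PR moves gradient-op construction to compile time: a forward operator now registers a grad-op desc maker that emits the gradient's OpDescBind, and the runtime grad_op_builder is deleted. As a quick orientation before the per-file diffs, the sketch below contrasts the two registration styles for rowwise_add; both snippets are copied verbatim from backward_test.cc later in this diff.

    // Old style (still used by mul, sigmoid, etc. in the test file): forward
    // and gradient ops are bound by one macro, and the gradient op's
    // inputs/outputs were assembled at run time by grad_op_builder.
    REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad,
                f::NOP);

    // New style: the forward op carries a maker that describes its gradient
    // at compile time; the gradient op is registered like any other operator.
    REGISTER_OPERATOR(rowwise_add, f::NOP, f::RowWiseAddOpMaker,
                      f::RowWiseAddGradMaker);
    REGISTER_OPERATOR(rowwise_add_grad, f::NOP);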
4 changes: 1 addition & 3 deletions paddle/framework/CMakeLists.txt
@@ -26,10 +26,8 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
 cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
 
-cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc)
-cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info)
+cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator)
 cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
-cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry sum_op)
 
 py_proto_compile(framework_py_proto SRCS framework.proto)
 # Generate an empty __init__.py to make framework_py_proto as a valid python module.
32 changes: 31 additions & 1 deletion paddle/framework/backward.cc
@@ -13,6 +13,7 @@
 limitations under the License. */
 
 #include "paddle/framework/backward.h"
+#include "paddle/operators/net_op.h"
 
 #include <list>
 #include <memory>
@@ -24,6 +25,35 @@
 namespace paddle {
 namespace framework {
 
+static inline std::unique_ptr<OperatorBase> CreateGradOp(
+    const OperatorBase& op) {
+  OpDescBind op_desc;
+  op_desc.SetInputMap(op.Inputs());
+  op_desc.SetOutputMap(op.Outputs());
+  op_desc.SetType(op.Type());
+  op_desc.SetAttrMap(op.Attrs());
+  auto& info = OpInfoMap::Instance().Get(op.Type());
+  auto grad_descs = info.GradOpMaker()(op_desc);
+  std::vector<std::unique_ptr<OperatorBase>> grad_ops;
+  grad_ops.reserve(grad_descs.size());
+  std::transform(grad_descs.begin(), grad_descs.end(),
+                 std::back_inserter(grad_ops),
+                 [](const std::unique_ptr<OpDescBind>& grad_desc) {
+                   return OpRegistry::CreateOp(*grad_desc);
+                 });
+  PADDLE_ENFORCE(!grad_ops.empty());
+  if (grad_ops.size() == 1) {
+    return std::move(grad_ops[0]);
+  } else {
+    auto net_op = new operators::NetOp();
+    for (auto& grad_op : grad_ops) {
+      net_op->AppendOp(std::move(grad_op));
+    }
+    net_op->CompleteAddOp();
+    return std::unique_ptr<OperatorBase>(net_op);
+  }
+}
+
 template <typename Map, typename T>
 static void ForEachVarName(const Map& names, T callback) {
   for (auto& name : names) {
@@ -171,7 +201,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
       net->InsertOp(pos.first + 1, std::move(pos.second));
     }
   } else {
-    std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp));
+    std::unique_ptr<OperatorBase> grad_op(CreateGradOp(forwardOp));
 
     ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
                                           const std::string& grad_input) {
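The CreateGradOp helper added above is the new bridge from a runtime operator back to the compile-time registry: it snapshots the op into an OpDescBind, invokes the GradOpMaker registered for that op type, and instantiates each returned desc, wrapping multiple gradient ops in a NetOp so callers still receive a single OperatorBase. A minimal caller-side sketch follows; the variable names are illustrative assumptions, not code from this diff.

    // 'fwd' is any operator whose type was registered with a grad-op maker.
    std::unique_ptr<OperatorBase> grad_op = CreateGradOp(*fwd);
    // One gradient desc  -> grad_op is that single operator.
    // Several grad descs -> grad_op is a NetOp running them in order.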
59 changes: 23 additions & 36 deletions paddle/framework/backward_test.cc
@@ -21,24 +21,34 @@
 namespace paddle {
 namespace framework {
 
-using OperatorBase = framework::OperatorBase;
-using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
-using OpProto = framework::OpProto;
-using OpAttrChecker = framework::OpAttrChecker;
-using Scope = framework::Scope;
 using DeviceContext = platform::DeviceContext;
 
 class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
  public:
  RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "Input X of Add").NotInGradient();
-    AddInput("b", "Bias of Add").NotInGradient();
-    AddOutput("Out", "Out of Add").NotInGradient();
+    AddInput("X", "Input X of Add");
+    AddInput("b", "Bias of Add");
+    AddOutput("Out", "Out of Add");
     AddComment("Add Op");
   }
 };
 
+class RowWiseAddGradMaker : public SingleGradOpDescMaker {
+ public:
+  using SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<OpDescBind> Apply() const override {
+    auto grad_op = new OpDescBind();
+    grad_op->SetInput(GradVarName("Out"), OutputGrad("Out"));
+    grad_op->SetOutput(GradVarName("X"), InputGrad("X"));
+    grad_op->SetOutput(GradVarName("b"), InputGrad("b"));
+    grad_op->SetType("rowwise_add_grad");
+    return std::unique_ptr<OpDescBind>(grad_op);
+  }
+};
+
 class MulOpMaker : public OpProtoAndCheckerMaker {
  public:
  MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
@@ -137,10 +147,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
  SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "the input tensors of sum operator.")
-        .AsDuplicable()
-        .NotInGradient();
-    AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
+    AddInput("X", "the input tensors of sum operator.").AsDuplicable();
+    AddOutput("Out", "the output tensor of sum operator.");
     AddComment("");
   }
 };
@@ -151,8 +159,9 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
 namespace f = paddle::framework;
 namespace ops = paddle::operators;
 using EnforceNotMet = paddle::platform::EnforceNotMet;
-REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad,
-            f::NOP);
+REGISTER_OPERATOR(rowwise_add, f::NOP, f::RowWiseAddOpMaker,
+                  f::RowWiseAddGradMaker);
+REGISTER_OPERATOR(rowwise_add_grad, f::NOP);
 REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
 REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
 REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
@@ -162,17 +171,6 @@ REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
 REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
             f::NOP);
 
-TEST(Backward, simple_op_grad) {
-  auto fwd = f::OpRegistry::CreateOp(
-      "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
-  ASSERT_NE(fwd, nullptr);
-  auto gop = f::OpRegistry::CreateGradOp(*fwd);
-  ASSERT_EQ(1UL, gop->Inputs().size());
-  ASSERT_EQ("rowwise_add_grad", gop->Type());
-  ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X")));
-  ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b")));
-}
-
 TEST(Backward, simple_op_not_need_grad) {
   auto fwd = f::OpRegistry::CreateOp(
       "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
@@ -289,17 +287,6 @@ TEST(Backward, net_shared_weight) {
   ASSERT_EQ("sum", bwd_net->ops_[2]->Type());
 }
 
-TEST(Backward, op_register_grad_not_for_network) {
-  auto fwd =
-      f::OpRegistry::CreateOp("fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {"b"}}},
-                              {{"mul_result", {"mul_out"}},
-                               {"add_result", {"add_out"}},
-                               {"Out", {"out1"}}},
-                              {{"temporary_index", std::vector<int>{0, 1}}});
-
-  ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
-}
-
 TEST(Backward, op_all_input_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp(
       "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
1 change: 0 additions & 1 deletion paddle/framework/framework.proto
@@ -66,7 +66,6 @@ message OpProto {
 
   optional bool duplicable = 3 [ default = false ];
   optional bool intermediate = 4 [ default = false ];
-  optional bool not_in_gradient = 5 [ default = false ];
 }
 
 // AttrProto describes the C++ type Attribute.
97 changes: 0 additions & 97 deletions paddle/framework/grad_op_builder.cc

This file was deleted.

28 changes: 0 additions & 28 deletions paddle/framework/grad_op_builder.h

This file was deleted.
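A note on these two deletions: roughly speaking, grad_op_builder constructed gradient ops at run time from the forward op's OpProto, consulting the not_in_gradient flag (removed from framework.proto above) to decide which variables to carry over. With each operator's grad-op maker now listing the gradient's inputs and outputs explicitly, as RowWiseAddGradMaker does in backward_test.cc, both the builder and the proto flag become dead code; this is also why the simple_op_grad and op_register_grad_not_for_network tests, which exercised OpRegistry::CreateGradOp directly, are dropped.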

(The remaining 12 changed files are not shown here.)
