Complete register gradient for compile time #4566

Merged
4 changes: 1 addition & 3 deletions paddle/framework/CMakeLists.txt
@@ -26,10 +26,8 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc)
 cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
 
-cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc)
-cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info)
+cc_library(op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator)
 cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
-cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry sum_op)
 
 py_proto_compile(framework_py_proto SRCS framework.proto)
 # Generate an empty __init__.py to make framework_py_proto as a valid python module.
32 changes: 31 additions & 1 deletion paddle/framework/backward.cc
@@ -13,6 +13,7 @@
 limitations under the License. */
 
 #include "paddle/framework/backward.h"
+#include "paddle/operators/net_op.h"
 
 #include <list>
 #include <memory>
@@ -24,6 +25,35 @@
 namespace paddle {
 namespace framework {
 
+static inline std::unique_ptr<OperatorBase> CreateGradOp(
+    const OperatorBase& op) {
+  OpDescBind op_desc;
+  op_desc.SetInputMap(op.Inputs());
+  op_desc.SetOutputMap(op.Outputs());
+  op_desc.SetType(op.Type());
+  op_desc.SetAttrMap(op.Attrs());
+  auto& info = OpInfoMap::Instance().Get(op.Type());
+  auto grad_descs = info.GradOpMaker()(op_desc);
+  std::vector<std::unique_ptr<OperatorBase>> grad_ops;
+  grad_ops.reserve(grad_descs.size());
+  std::transform(grad_descs.begin(), grad_descs.end(),
+                 std::back_inserter(grad_ops),
+                 [](const std::unique_ptr<OpDescBind>& grad_desc) {
+                   return OpRegistry::CreateOp(*grad_desc);
+                 });
+  PADDLE_ENFORCE(!grad_ops.empty());
+  if (grad_ops.size() == 1) {
+    return std::move(grad_ops[0]);
+  } else {
+    auto net_op = new operators::NetOp();
+    for (auto& grad_op : grad_ops) {
+      net_op->AppendOp(std::move(grad_op));
+    }
+    net_op->CompleteAddOp();
+    return std::unique_ptr<OperatorBase>(net_op);
+  }
+}
+
 template <typename Map, typename T>
 static void ForEachVarName(const Map& names, T callback) {
   for (auto& name : names) {
@@ -171,7 +201,7 @@ static std::unique_ptr<OperatorBase> BackwardRecursive(
       net->InsertOp(pos.first + 1, std::move(pos.second));
     }
   } else {
-    std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp));
+    std::unique_ptr<OperatorBase> grad_op(CreateGradOp(forwardOp));
 
     ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
                        const std::string& grad_input) {
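For context, here is a minimal sketch (not part of the diff) of the compile-time path the new CreateGradOp builds on: gradient op descriptions are derived from an OpDescBind through the GradOpMaker registered in OpInfoMap, so no forward OperatorBase has to exist when gradients are described. The header paths, the return type of the maker, and the free function GradDescsFor are assumptions made for illustration.

// Hypothetical illustration; header names are inferred from the diff and the
// CMake targets above and may differ in the real tree.
#include <memory>
#include <vector>
#include "paddle/framework/op_desc.h"  // OpDescBind (assumed path)
#include "paddle/framework/op_info.h"  // OpInfoMap

namespace f = paddle::framework;

// Given only a forward op *description*, produce its gradient descriptions.
// This mirrors the first half of CreateGradOp, stopping before any
// OperatorBase is instantiated.
std::vector<std::unique_ptr<f::OpDescBind>> GradDescsFor(
    const f::OpDescBind& fwd_desc) {
  auto& info = f::OpInfoMap::Instance().Get(fwd_desc.Type());
  return info.GradOpMaker()(fwd_desc);  // e.g. invokes a SingleGradOpDescMaker
}

The runtime path in BackwardRecursive now funnels through the same maker: CreateGradOp first round-trips the OperatorBase into an OpDescBind (SetInputMap/SetOutputMap/SetType/SetAttrMap), which is why the old grad_op_builder could be deleted below.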
59 changes: 23 additions & 36 deletions paddle/framework/backward_test.cc
@@ -21,24 +21,34 @@
 namespace paddle {
 namespace framework {
 
-using OperatorBase = framework::OperatorBase;
-using OpProtoAndCheckerMaker = framework::OpProtoAndCheckerMaker;
-using OpProto = framework::OpProto;
-using OpAttrChecker = framework::OpAttrChecker;
-using Scope = framework::Scope;
 using DeviceContext = platform::DeviceContext;
 
 class RowWiseAddOpMaker : public OpProtoAndCheckerMaker {
  public:
   RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "Input X of Add").NotInGradient();
-    AddInput("b", "Bias of Add").NotInGradient();
-    AddOutput("Out", "Out of Add").NotInGradient();
+    AddInput("X", "Input X of Add");
+    AddInput("b", "Bias of Add");
+    AddOutput("Out", "Out of Add");
     AddComment("Add Op");
   }
 };
 
+class RowWiseAddGradMaker : public SingleGradOpDescMaker {
+ public:
+  using SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<OpDescBind> Apply() const override {
+    auto grad_op = new OpDescBind();
+    grad_op->SetInput(GradVarName("Out"), OutputGrad("Out"));
+    grad_op->SetOutput(GradVarName("X"), InputGrad("X"));
+    grad_op->SetOutput(GradVarName("b"), InputGrad("b"));
+    grad_op->SetType("rowwise_add_grad");
+    return std::unique_ptr<OpDescBind>(grad_op);
+  }
+};
+
 class MulOpMaker : public OpProtoAndCheckerMaker {
  public:
   MulOpMaker(OpProto *proto, OpAttrChecker *op_checker)
@@ -137,10 +147,8 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "the input tensors of sum operator.")
-        .AsDuplicable()
-        .NotInGradient();
-    AddOutput("Out", "the output tensor of sum operator.").NotInGradient();
+    AddInput("X", "the input tensors of sum operator.").AsDuplicable();
+    AddOutput("Out", "the output tensor of sum operator.");
     AddComment("");
   }
 };
@@ -151,8 +159,9 @@ class SumOpMaker : public framework::OpProtoAndCheckerMaker {
 namespace f = paddle::framework;
 namespace ops = paddle::operators;
 using EnforceNotMet = paddle::platform::EnforceNotMet;
-REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad,
-            f::NOP);
+REGISTER_OPERATOR(rowwise_add, f::NOP, f::RowWiseAddOpMaker,
+                  f::RowWiseAddGradMaker);
+REGISTER_OPERATOR(rowwise_add_grad, f::NOP);
 REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP);
 REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP);
 REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker);
@@ -162,17 +171,6 @@ REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker);
 REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad,
             f::NOP);
 
-TEST(Backward, simple_op_grad) {
-  auto fwd = f::OpRegistry::CreateOp(
-      "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
-  ASSERT_NE(fwd, nullptr);
-  auto gop = f::OpRegistry::CreateGradOp(*fwd);
-  ASSERT_EQ(1UL, gop->Inputs().size());
-  ASSERT_EQ("rowwise_add_grad", gop->Type());
-  ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X")));
-  ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b")));
-}
-
 TEST(Backward, simple_op_not_need_grad) {
   auto fwd = f::OpRegistry::CreateOp(
       "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
@@ -289,17 +287,6 @@ TEST(Backward, net_shared_weight) {
   ASSERT_EQ("sum", bwd_net->ops_[2]->Type());
 }
 
-TEST(Backward, op_register_grad_not_for_network) {
-  auto fwd =
-      f::OpRegistry::CreateOp("fc", {{"X", {"x"}}, {"W", {"w"}}, {"b", {"b"}}},
-                              {{"mul_result", {"mul_out"}},
-                               {"add_result", {"add_out"}},
-                               {"Out", {"out1"}}},
-                              {{"temporary_index", std::vector<int>{0, 1}}});
-
-  ASSERT_THROW(f::OpRegistry::CreateGradOp(*fwd), EnforceNotMet);
-}
-
 TEST(Backward, op_all_input_are_not_need) {
   auto fwd = f::OpRegistry::CreateOp(
       "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {});
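As a usage note, the registration pattern exercised in the test above generalizes to any operator. A minimal sketch follows, using the REGISTER_OPERATOR macro and SingleGradOpDescMaker base that the diff introduces; the op name "my_op" and the classes MyOpMaker/MyOpGradMaker are hypothetical illustrations, not part of the PR.

namespace f = paddle::framework;

// Hypothetical gradient-description maker for a one-input, one-output op.
class MyOpGradMaker : public f::SingleGradOpDescMaker {
 public:
  using f::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<f::OpDescBind> Apply() const override {
    auto* grad_op = new f::OpDescBind();
    grad_op->SetType("my_op_grad");
    // Wire d(Out) in and d(X) out. Variables a maker never mentions simply
    // do not appear in the gradient op -- this is what replaces the old
    // per-variable NotInGradient() flags.
    grad_op->SetInput(f::GradVarName("Out"), OutputGrad("Out"));
    grad_op->SetOutput(f::GradVarName("X"), InputGrad("X"));
    return std::unique_ptr<f::OpDescBind>(grad_op);
  }
};

// Forward and backward ops are registered separately under the new scheme;
// the grad op needs no ProtoMaker of its own. MyOpMaker is assumed to be an
// ordinary OpProtoAndCheckerMaker like RowWiseAddOpMaker above.
REGISTER_OPERATOR(my_op, f::NOP, f::MyOpMaker, MyOpGradMaker);
REGISTER_OPERATOR(my_op_grad, f::NOP);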
1 change: 0 additions & 1 deletion paddle/framework/framework.proto
@@ -66,7 +66,6 @@ message OpProto {
 
     optional bool duplicable = 3 [ default = false ];
     optional bool intermediate = 4 [ default = false ];
-    optional bool not_in_gradient = 5 [ default = false ];
   }
 
 // AttrProto describes the C++ type Attribute.
Expand Down
97 changes: 0 additions & 97 deletions paddle/framework/grad_op_builder.cc

This file was deleted.

28 changes: 0 additions & 28 deletions paddle/framework/grad_op_builder.h

This file was deleted.
