Skip to content

Commit

Permalink
Merge pull request #3538 from reyoung/feature/remove_shared_ptr
Browse files Browse the repository at this point in the history
Feature/remove shared ptr
  • Loading branch information
reyoung committed Aug 17, 2017
2 parents 812a64c + 16d0215 commit 1808f88
Show file tree
Hide file tree
Showing 11 changed files with 131 additions and 160 deletions.
42 changes: 20 additions & 22 deletions paddle/framework/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include "paddle/framework/backward.h"

#include <list>
#include <memory>

#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
Expand Down Expand Up @@ -43,11 +45,11 @@ static bool AllInSet(
return all_in_set;
}

static std::shared_ptr<OperatorBase> NOP() {
auto net_op = std::make_shared<operators::NetOp>();
static std::unique_ptr<OperatorBase> NOP() {
auto net_op = new operators::NetOp();
net_op->SetType("@NOP@");
net_op->CompleteAddOp();
return net_op;
return std::unique_ptr<OperatorBase>(net_op);
}

// Get backward operator from a forward operator, a recursive implementation.
Expand All @@ -62,11 +64,7 @@ static std::shared_ptr<OperatorBase> NOP() {
// operator, in a complex situation, it maybe a NetOp.
//
// See Backward.h for details
static std::shared_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id);

std::shared_ptr<OperatorBase> BackwardRecursive(
static std::unique_ptr<OperatorBase> BackwardRecursive(
const OperatorBase& forwardOp,
std::unordered_set<std::string>& no_grad_names, size_t& uniq_id) {
// If all input gradients of forwarding operator do not need to calculate,
Expand All @@ -91,7 +89,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
}

// Returned gradient network
auto net = std::make_shared<operators::NetOp>();
auto net = std::unique_ptr<operators::NetOp>(new operators::NetOp());

if (forwardOp.IsNetOp()) {
// Because forwardOp is a net op, it can static_cast.
Expand All @@ -105,14 +103,14 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// reversely travel forwardNet and collect all duplicate outputs.
for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend();
++it, ++local_op_id) {
auto fwd = *it;
auto& fwd = *it;
auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
net->AddOp(bwd);
ForEachVarName(bwd->Outputs(),
[&dup_output_ops, local_op_id](const std::string& out) {
dup_output_ops[out].emplace_back(local_op_id);
return false;
});
net->AddOp(std::move(bwd));
}
// Get unique ID for this method.
auto uid = uniq_id++;
Expand All @@ -122,7 +120,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// to handle this case. For each duplicate output, rename it to an alias
// (original name with a offset), append an `add` op for its operator,
// and finally sum all the alias variable to the final output variable y.
using Pos = std::pair<size_t, std::shared_ptr<OperatorBase>>;
using Pos = std::pair<size_t, std::unique_ptr<OperatorBase>>;
std::list<Pos> insert_position;
for (auto& dup_output_op : dup_output_ops) {
const std::string& name = dup_output_op.first;
Expand Down Expand Up @@ -150,13 +148,13 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
[](const Pos& l, const Pos& r) { return l.first > r.first; });

for (auto& pos : insert_position) {
net->InsertOp(pos.first + 1, pos.second);
net->InsertOp(pos.first + 1, std::move(pos.second));
}
} else {
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
std::unique_ptr<OperatorBase> grad_op(OpRegistry::CreateGradOp(forwardOp));

ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net,
grad_op](const std::string& grad_input) {
ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, &grad_op](
const std::string& grad_input) {
if (no_grad_names.count(grad_input)) {
// +1 for \0
std::string prefix = grad_input.substr(
Expand Down Expand Up @@ -190,23 +188,23 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
const auto& stepnet_op =
*static_cast<const OperatorBase*>(&rnnop.stepnet());
// create stepnet's gradient op
auto grad_stepnet = BackwardRecursive(stepnet_op, no_grad_names, uniq_id);
rnn_grad_op->set_stepnet(
std::static_pointer_cast<operators::NetOp>(grad_stepnet));
BackwardRecursive(stepnet_op, no_grad_names, uniq_id));
}

if (net->ops_.empty()) { // Current no aux op is added to network
return grad_op;
}
net->AddOp(grad_op);
net->AddOp(std::move(grad_op));
}
net->SetType("@GENERATED_BACKWARD@");
net->CompleteAddOp();
return net;
} // namespace framework
return std::unique_ptr<OperatorBase>(
static_cast<OperatorBase*>(net.release()));
}

// See header for comments
std::shared_ptr<OperatorBase> Backward(
std::unique_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string> no_grad_names;
Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/backward.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace framework {

// Create the backward operator from a forward operator.
// TODO(yuyang18): Add more API reference comment.
extern std::shared_ptr<OperatorBase> Backward(
extern std::unique_ptr<OperatorBase> Backward(
const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars);
} // namespace framework
Expand Down
3 changes: 1 addition & 2 deletions paddle/framework/backward_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,7 @@ TEST(Backward, simple_op_not_need_grad) {
auto no_input_gop = f::Backward(*fwd, {"x", "b"});
ASSERT_NE(no_input_gop, nullptr);
ASSERT_TRUE(no_input_gop->IsNetOp());
ASSERT_EQ(0UL,
std::static_pointer_cast<ops::NetOp>(no_input_gop)->ops_.size());
ASSERT_EQ(0UL, static_cast<ops::NetOp *>(no_input_gop.get())->ops_.size());
}

TEST(Backward, net_fc_backward_normal) {
Expand Down
11 changes: 5 additions & 6 deletions paddle/framework/op_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ limitations under the License. */
namespace paddle {
namespace framework {

std::shared_ptr<OperatorBase> OpRegistry::CreateOp(const std::string& type,
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const std::string& type,
const VarNameMap& inputs,
const VarNameMap& outputs,
AttributeMap attrs) {
Expand All @@ -28,10 +28,10 @@ std::shared_ptr<OperatorBase> OpRegistry::CreateOp(const std::string& type,
"Operator '%s' has not been registered.", type);
it->second.checker_->Check(attrs);
auto op = it->second.creator_(type, inputs, outputs, attrs);
return std::shared_ptr<OperatorBase>(op);
return std::unique_ptr<OperatorBase>(op);
}

std::shared_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
VarNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
VarNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
AttributeMap attrs;
Expand All @@ -55,10 +55,9 @@ OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap(
return ret_val;
}

std::shared_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) {
std::unique_ptr<OperatorBase> OpRegistry::CreateGradOp(const OperatorBase& op) {
PADDLE_ENFORCE(!op.IsNetOp(), "Use framework::Backward to get backward ops");
std::shared_ptr<OperatorBase> grad_op(BuildGradOp(&op));
return grad_op;
return std::unique_ptr<OperatorBase>(BuildGradOp(&op));
}

} // namespace framework
Expand Down
6 changes: 3 additions & 3 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,17 @@ class OpRegistry {
}
}

static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
static std::unique_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameMap& inputs,
const VarNameMap& outputs,
AttributeMap attrs);

static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);

static VarNameMap ConvertOpDescVarsToVarNameMap(
const google::protobuf::RepeatedPtrField<OpDesc::Var>& op_desc_vars);

static std::shared_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);
static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);

static std::unordered_map<std::string, const OpInfo>& op_info_map() {
static std::unordered_map<std::string, const OpInfo> op_info_map_;
Expand Down
6 changes: 2 additions & 4 deletions paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,7 @@ TEST(OpRegistry, CreateOp) {
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(scale);

std::shared_ptr<paddle::framework::OperatorBase> op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
Expand Down Expand Up @@ -118,8 +117,7 @@ TEST(OpRegistry, DefaultValue) {

ASSERT_TRUE(op_desc.IsInitialized());

std::shared_ptr<paddle::framework::OperatorBase> op =
paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
Expand Down
141 changes: 56 additions & 85 deletions paddle/framework/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,29 +48,6 @@ namespace framework {

using Tensor = framework::Tensor;

template <typename ClassType>
void ExposeOperator(ClassType &m) {
m.def("infer_shape", &ClassType::type::InferShape)
.def("run", &ClassType::type::Run)
.def("type",
[](const typename ClassType::type &op) -> std::string {
return op.Type();
})
.def("outputs",
[](const typename ClassType::type &op)
-> std::map<std::string, std::vector<std::string>> {
return op.Outputs();
})
.def("inputs",
[](const typename ClassType::type &op) { return op.Inputs(); })
.def("__str__", &ClassType::type::DebugString)
.def("no_intermediate_outputs",
[](const typename ClassType::type &op) {
return op.OutputVars(false);
})
.def("support_gpu", &ClassType::type::SupportGPU);
}

static size_t UniqueIntegerGenerator() {
static std::atomic<size_t> generator;
return generator.fetch_add(1);
Expand Down Expand Up @@ -207,75 +184,69 @@ All parameter, weight, gradient are variables in Paddle.
.def(py::init<>())
.def("__str__", string::to_string<const platform::CPUPlace &>);

py::class_<OperatorBase, std::shared_ptr<OperatorBase>> operator_base(
m, "Operator");

operator_base.def_static("create", [](py::bytes protobin) {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
return OpRegistry::CreateOp(desc);
});

operator_base.def("backward",
[](const OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) {
return Backward(forwardOp, no_grad_vars);
});

ExposeOperator(operator_base);

py::class_<operators::NetOp, std::shared_ptr<operators::NetOp>> net(m, "Net");

net.def_static("create",
[]() -> std::shared_ptr<operators::NetOp> {
auto retv = std::make_shared<operators::NetOp>();
retv->SetType("plain_net");
return retv;
})
.def("add_op", &operators::NetOp::AddOp)
.def("add_op",
[](operators::NetOp &self,
const std::shared_ptr<operators::NetOp> &net) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(net));
})
.def("add_op",
[](operators::NetOp &self,
const std::shared_ptr<operators::RecurrentOp> &rnn) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(rnn));
py::class_<OperatorBase>(m, "Operator")
.def_static("create",
[](py::bytes protobin) {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
return OpRegistry::CreateOp(desc);
})
.def("backward",
[](const OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) {
return Backward(forwardOp, no_grad_vars).release();
})
.def("infer_shape", &OperatorBase::InferShape)
.def("run", &OperatorBase::Run)
.def("type",
[](const OperatorBase &op) -> std::string { return op.Type(); })
.def("outputs",
[](const OperatorBase &op)
-> std::map<std::string, std::vector<std::string>> {
return op.Outputs();
})
.def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
.def("__str__", &OperatorBase::DebugString)
.def("no_intermediate_outputs",
[](const OperatorBase &op) { return op.OutputVars(false); })
.def("support_gpu", &OperatorBase::SupportGPU);

py::class_<operators::NetOp, OperatorBase>(m, "Net")
.def_static("create",
[]() -> operators::NetOp * {
auto *retv = new operators::NetOp;
retv->SetType("plain_net");
return retv;
})
.def("add_op", [](operators::NetOp &self,
const OperatorBase &op) { self.AddOp(op); })
.def("complete_add_op", &operators::NetOp::CompleteAddOp)
.def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
self->CompleteAddOp();
});

ExposeOperator(net);

// recurrent_op
py::class_<operators::RecurrentOp, std::shared_ptr<operators::RecurrentOp>>
rnn(m, "RecurrentOp");

rnn.def_static(
"create",
[](py::bytes protobin) -> std::shared_ptr<operators::RecurrentOp> {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
return std::dynamic_pointer_cast<operators::RecurrentOp>(rnn_op);
})
.def("set_stepnet",
[](operators::RecurrentOp &self,
const std::shared_ptr<operators::NetOp> &net) -> void {
self.set_stepnet(net);
});
ExposeOperator(rnn);
py::class_<operators::RecurrentOp, OperatorBase>(m, "RecurrentOp")
.def_static(
"create",
[](py::bytes protobin) -> operators::RecurrentOp * {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
return static_cast<operators::RecurrentOp *>(rnn_op.release());
})
.def("set_stepnet", [](operators::RecurrentOp &self,
const operators::NetOp &net) -> void {
self.set_stepnet(net.Clone());
});

m.def("unique_integer", UniqueIntegerGenerator);

Expand Down
Loading

0 comments on commit 1808f88

Please sign in to comment.