Skip to content

Commit

Permalink
Rnn make stepnet member (#3469)
Browse files Browse the repository at this point in the history
* make stepnet member

* add pybind support

* fix Inputs Outputs

* remove unique_ptr
  • Loading branch information
Superjomn committed Aug 15, 2017
1 parent 80de7e5 commit 0079fa3
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 296 deletions.
29 changes: 29 additions & 0 deletions paddle/framework/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor_py.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/recurrent_op.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "paddle/string/to_string.h"
Expand Down Expand Up @@ -241,13 +242,41 @@ All parameter, weight, gradient are variables in Paddle.
const std::shared_ptr<operators::NetOp> &net) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(net));
})
.def("add_op",
[](operators::NetOp &self,
const std::shared_ptr<operators::RecurrentOp> &rnn) -> void {
self.AddOp(std::static_pointer_cast<OperatorBase>(rnn));
})
.def("complete_add_op", &operators::NetOp::CompleteAddOp)
.def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
self->CompleteAddOp();
});

ExposeOperator(net);

// recurrent_op
py::class_<operators::RecurrentOp, std::shared_ptr<operators::RecurrentOp>>
rnn(m, "RecurrentOp");

rnn.def_static(
"create",
[](py::bytes protobin) -> std::shared_ptr<operators::RecurrentOp> {
OpDesc desc;
PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
"Cannot parse user input to OpDesc");
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
return std::dynamic_pointer_cast<operators::RecurrentOp>(rnn_op);
})
.def("set_stepnet",
[](operators::RecurrentOp &self,
const std::shared_ptr<operators::NetOp> &net) -> void {
self.set_stepnet(net);
});
ExposeOperator(rnn);

m.def("unique_integer", UniqueIntegerGenerator);

m.def("is_compile_gpu", IsCompileGPU);
Expand Down
1 change: 0 additions & 1 deletion paddle/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,5 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu)

op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor op_registry operator net_op)
cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op)
op_library(uniform_random_op
SRCS uniform_random_op.cc uniform_random_op.cu)
38 changes: 13 additions & 25 deletions paddle/operators/recurrent_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const {
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
true /*infer_shape_mode*/);
InitMemories(step_scopes[0], true /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net");

for (size_t i = 0; i < seq_len_; i++) {
if (i > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
true /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->InferShape(*step_scopes[i]);
(*stepnet_)->InferShape(*step_scopes[i]);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
true /*infer_shape_mode*/);
Expand All @@ -56,15 +54,14 @@ void RecurrentAlgorithm::Run(const Scope& scope,
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
false /*infer_shape_mode*/);
InitMemories(step_scopes[0], false /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);

for (size_t step_id = 0; step_id < seq_len_; step_id++) {
// create output alias variables
if (step_id > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
false /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
false /*infer_shape_mode*/);
Expand All @@ -78,18 +75,16 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
auto step_scopes = step_scopes_var->GetMutable<std::vector<Scope*>>();

// Now all variables in scope must be created outside of op.
auto net_var = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net_var != nullptr, "no stepnet called %s in scope",
arg_->step_net);
auto net_op = net_var->GetMutable<NetOp>();
PADDLE_ENFORCE(!net_op->Outputs().empty(), "net_op has no outputs");
PADDLE_ENFORCE_NOT_NULL(stepnet_);
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "net_op has no outputs");

if (seq_len_ > step_scopes->size()) {
for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
auto& step_scope = scope.NewScope();

// create step net's temp inputs
for (auto& input : net_op->Inputs()) {
for (auto& input : (*stepnet_)->Inputs()) {
// the weight are located in parent scope
for (auto& var_name : input.second) {
if (!step_scope.FindVar(var_name)) {
Expand All @@ -98,7 +93,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
}
}
// create stepnet's outputs
for (const auto& output : net_op->Outputs()) {
for (const auto& output : (*stepnet_)->Outputs()) {
for (auto& var_name : output.second) {
step_scope.NewVar(var_name);
}
Expand Down Expand Up @@ -140,9 +135,8 @@ RecurrentOp::RecurrentOp(const std::string& type,
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
rnn::InitArgument(kArgName, arg.get(), *this);
alg_.Init(std::move(arg));
rnn::InitArgument(kArgName, &arg_, *this);
alg_.Init(&arg_, &stepnet_);
}

class RecurrentAlgorithmProtoAndCheckerMaker
Expand All @@ -158,7 +152,6 @@ class RecurrentAlgorithmProtoAndCheckerMaker
.AsDuplicable();
AddInput(name.boot_memories, "variables to initialize memories.")
.AsDuplicable();
AddInput(name.step_net, "network shared by all steps.");

AddOutput(name.outlinks, "the outputs that need to concated for all steps.")
.AsDuplicable();
Expand All @@ -180,14 +173,12 @@ void RecurrentGradientAlgorithm::Run(
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
false /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
false /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->Run(*step_scopes[step_id], dev_ctx);
(*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
}
LinkBootMemoryGradients(step_scopes[0], false);
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
Expand Down Expand Up @@ -219,14 +210,12 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
true /*infer_shape_mode*/);
Variable* net = scope.FindVar(arg_->step_net);
PADDLE_ENFORCE(net != nullptr, "failed to get step net");
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
true /*infer_shape_mode*/);
}
net->GetMutable<NetOp>()->InferShape(*step_scopes[step_id]);
(*stepnet_)->InferShape(*step_scopes[step_id]);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
true /*infer_shape_mode*/);
Expand All @@ -238,9 +227,8 @@ RecurrentGradientOp::RecurrentGradientOp(
const framework::OperatorBase::VarNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {
std::unique_ptr<rnn::Argument> arg(new rnn::Argument());
rnn::InitArgument(kArgName, arg.get(), *this);
alg_.Init(std::move(arg));
rnn::InitArgument(kArgName, &arg_, *this);
alg_.Init(&arg_, &stepnet_);
}

} // namespace operators
Expand Down
29 changes: 25 additions & 4 deletions paddle/operators/recurrent_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once

#include "paddle/framework/operator.h"
#include "paddle/operators/net_op.h"
#include "paddle/operators/rnn/recurrent_op_utils.h"

namespace paddle {
Expand All @@ -33,7 +34,11 @@ class RecurrentAlgorithm {
void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const;

void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
arg_ = arg;
stepnet_ = stepnet;
}

/**
* InferShape must be called before Run.
Expand All @@ -58,7 +63,8 @@ class RecurrentAlgorithm {
void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;

private:
std::unique_ptr<rnn::Argument> arg_;
std::shared_ptr<NetOp>* stepnet_;
rnn::Argument* arg_;
mutable size_t seq_len_;
};

Expand All @@ -74,7 +80,11 @@ class RecurrentGradientAlgorithm {
* operator.
*/
public:
void Init(std::unique_ptr<rnn::Argument> arg) { arg_ = std::move(arg); }
void Init(rnn::Argument* arg, std::shared_ptr<NetOp>* stepnet) {
PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before.");
arg_ = std::move(arg);
stepnet_ = stepnet;
}

void Run(const framework::Scope& scope,
const platform::DeviceContext& dev_ctx) const;
Expand All @@ -95,8 +105,9 @@ class RecurrentGradientAlgorithm {
}

private:
std::unique_ptr<rnn::Argument> arg_;
rnn::Argument* arg_;
mutable size_t seq_len_;
std::shared_ptr<NetOp>* stepnet_;
};

class RecurrentOp final : public framework::OperatorBase {
Expand All @@ -115,10 +126,15 @@ class RecurrentOp final : public framework::OperatorBase {
alg_.Run(scope, dev_ctx);
}

void set_stepnet(std::shared_ptr<NetOp> net) { stepnet_ = net; }
const NetOp* stepnet() const { return stepnet_.get(); }

static const rnn::ArgumentName kArgName;

private:
RecurrentAlgorithm alg_;
rnn::Argument arg_;
std::shared_ptr<NetOp> stepnet_;
};

class RecurrentGradientOp final : public framework::OperatorBase {
Expand All @@ -141,8 +157,13 @@ class RecurrentGradientOp final : public framework::OperatorBase {

static const rnn::ArgumentName kArgName;

void set_stepnet(const std::shared_ptr<NetOp>& net) { stepnet_ = net; }
const NetOp* stepnet() const { return stepnet_.get(); }

private:
RecurrentGradientAlgorithm alg_;
std::shared_ptr<NetOp> stepnet_;
rnn::Argument arg_;
};

} // namespace operators
Expand Down
Loading

0 comments on commit 0079fa3

Please sign in to comment.