refactor rnn infershape #4553

Merged · 4 commits · Oct 2, 2017
74 changes: 35 additions & 39 deletions paddle/operators/recurrent_op.cc
@@ -30,36 +30,39 @@ using LoDTensor = framework::LoDTensor;
 
 void RecurrentAlgorithm::Run(const Scope& scope,
                              const platform::DeviceContext& dev_ctx) const {
-  auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
-                     false /*infer_shape_mode*/);
-  InitMemories(step_scopes[0], false /*infer_shape_mode*/);
+  auto* input0 = scope.FindVar(arg_->inlinks[0]);
+  PADDLE_ENFORCE_NOT_NULL(input0);
+  size_t seq_len = input0->GetMutable<LoDTensor>()->dims()[0];
+  PADDLE_ENFORCE_GT(seq_len, 0);
 
-  for (size_t step_id = 0; step_id < seq_len_; step_id++) {
-    // create output alias variables
+  CreateScopes(scope, seq_len);
+  auto& step_scopes = GetStepScopes(scope);
+  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len);
+  InitMemories(step_scopes[0]);
+
+  for (size_t step_id = 0; step_id < seq_len; step_id++) {
     if (step_id > 0) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1,
-                        false /*infer_shape_mode*/);
+      rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1);
     }
     (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
-                     false /*infer_shape_mode*/);
+  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len);
 }
 
-void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
+void RecurrentAlgorithm::CreateScopes(const Scope& scope,
+                                      size_t seq_len) const {
   // TODO(superjom) Only two scopes are needed for inference, this case will be
   // supported later.
-  auto step_scopes_var = scope.FindVar(arg_->step_scopes);
+  auto* step_scopes_var = scope.FindVar(arg_->step_scopes);
   PADDLE_ENFORCE(step_scopes_var != nullptr, "");
-  auto step_scopes = step_scopes_var->GetMutable<std::vector<Scope*>>();
+  auto* step_scopes = step_scopes_var->GetMutable<std::vector<Scope*>>();
 
   // Now all variables in scope must be created outside of op.
   PADDLE_ENFORCE_NOT_NULL(stepnet_);
   PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs");
 
-  if (seq_len_ > step_scopes->size()) {
-    for (size_t i = step_scopes->size(); i < seq_len_; ++i) {
+  if (seq_len > step_scopes->size()) {
+    for (size_t i = step_scopes->size(); i < seq_len; ++i) {
       auto& step_scope = scope.NewScope();
 
       // create step net's temp inputs
@@ -82,21 +85,17 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
   }
 }
 
-void RecurrentAlgorithm::InitMemories(Scope* step_scope,
-                                      bool infer_shape_mode) const {
+void RecurrentAlgorithm::InitMemories(Scope* step_scope) const {
   for (auto& attr : arg_->memories) {
     auto* pre_mem = step_scope->NewVar(attr.pre_var)->GetMutable<LoDTensor>();
     PADDLE_ENFORCE(step_scope->FindVar(attr.boot_var) != nullptr,
                    "memory [%s]'s boot variable [%s] not exists", attr.var,
                    attr.boot_var);
     auto* boot_mem =
         step_scope->FindVar(attr.boot_var)->GetMutable<LoDTensor>();
-    if (infer_shape_mode) {
-      pre_mem->Resize(boot_mem->dims());
-      PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2);
-    } else {
-      pre_mem->ShareDataWith<float>(*boot_mem);
-    }
+    pre_mem->Resize(boot_mem->dims());
+    PADDLE_ENFORCE_EQ(pre_mem->dims().size(), 2);
+    pre_mem->ShareDataWith<float>(*boot_mem);
   }
 }
 
@@ -146,23 +145,23 @@ class RecurrentAlgorithmProtoAndCheckerMaker
 
 void RecurrentGradientAlgorithm::Run(
     const Scope& scope, const platform::DeviceContext& dev_ctx) const {
-  auto step_scopes = GetStepScopes(scope);
-  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
-                     false /*infer_shape_mode*/);
-  for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
-    if (static_cast<size_t>(step_id) != seq_len_ - 1) {
-      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
-                        false /*infer_shape_mode*/);
+  auto* input0 = scope.FindVar(arg_->inlinks[0]);
+  PADDLE_ENFORCE_NOT_NULL(input0);
+  size_t seq_len = input0->GetMutable<LoDTensor>()->dims()[0];
+  auto& step_scopes = GetStepScopes(scope);
+  rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len);
+  for (int step_id = seq_len - 1; step_id >= 0; --step_id) {
+    if (step_id != seq_len - 1) {
+      rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1);
     }
     (*stepnet_)->Run(*step_scopes[step_id], dev_ctx);
   }
-  LinkBootMemoryGradients(step_scopes[0], false);
-  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
-                     false /*infer_shape_mode*/);
+  rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len);
+  LinkBootMemoryGradients(step_scopes[0]);
 }
 
 void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
-    Scope* step_scope, bool infer_shape_mode) const {
+    Scope* step_scope) const {
   for (auto& attr : arg_->memories) {
    PADDLE_ENFORCE(step_scope->FindVar(attr.var) != nullptr,
                   "memory variable [%s] does not exists", attr.var);
@@ -171,11 +170,8 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
     auto* mem_grad = step_scope->NewVar(attr.var)->GetMutable<LoDTensor>();
     auto* boot_mem_grad =
         step_scope->NewVar(attr.boot_var)->GetMutable<LoDTensor>();
-    if (infer_shape_mode) {
-      boot_mem_grad->Resize(mem_grad->dims());
-    } else {
-      boot_mem_grad->ShareDataWith<float>(*mem_grad);
-    }
+    boot_mem_grad->Resize(mem_grad->dims());
+    boot_mem_grad->ShareDataWith<float>(*mem_grad);
   }
 }
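The core of this refactor: `Run` no longer reads a cached `seq_len_` member that a separate infer-shape pass had to populate. It derives the sequence length from dimension 0 of the first inlink on every call, and the `infer_shape_mode` flag disappears from every helper. A minimal standalone sketch of that pattern, with toy `Scope`/`Tensor` stand-ins rather than Paddle's real types:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-ins for the framework types; hypothetical, for illustration only.
struct Tensor {
  std::vector<size_t> dims;  // dims[0] holds the sequence length
};

struct Scope {
  const Tensor* input0;  // pretend lookup of the first inlink
};

class RecurrentSketch {
 public:
  // Derive seq_len from the input on every call instead of caching it in a
  // mutable member: Run stays genuinely const and needs no prior infer-shape
  // pass to seed per-call state.
  void Run(const Scope& scope) const {
    assert(scope.input0 != nullptr);         // PADDLE_ENFORCE_NOT_NULL
    size_t seq_len = scope.input0->dims[0];  // read fresh, never stored
    assert(seq_len > 0);                     // PADDLE_ENFORCE_GT
    for (size_t step_id = 0; step_id < seq_len; ++step_id) {
      // ... run the step net for step_id ...
    }
  }
};

int main() {
  Tensor in{{3, 4}};  // a 3-step sequence, 4 values per step
  Scope scope{&in};
  RecurrentSketch{}.Run(scope);  // any seq_len works, nothing is mutated
  return 0;
}
```

The gradient path repeats the same derivation, so forward and backward always agree on the step count without sharing mutable state.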
10 changes: 4 additions & 6 deletions paddle/operators/recurrent_op.h
@@ -48,20 +48,19 @@ class RecurrentAlgorithm {
    * NOTE the scopes are reused in both the forward and backward, so just
    * create once and expand its size if more steps need.
    */
-  void CreateScopes(const framework::Scope& scope) const;
+  void CreateScopes(const framework::Scope& scope, size_t seq_len) const;
 
   const std::vector<framework::Scope*>& GetStepScopes(
       const framework::Scope& scope) const {
     return *scope.FindVar(arg_->step_scopes)
         ->GetMutable<std::vector<framework::Scope*>>();
   }
 
-  void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const;
+  void InitMemories(framework::Scope* step_scopes) const;
 
  private:
   std::unique_ptr<framework::OperatorBase>* stepnet_;
   rnn::Argument* arg_;
-  mutable size_t seq_len_;
 };
 
 class RecurrentGradientAlgorithm {
@@ -86,8 +85,7 @@ class RecurrentGradientAlgorithm {
   void Run(const framework::Scope& scope,
            const platform::DeviceContext& dev_ctx) const;
 
-  void LinkBootMemoryGradients(framework::Scope* step_scopes,
-                               bool infer_shape_mode) const;
+  void LinkBootMemoryGradients(framework::Scope* step_scopes) const;
 
  protected:
   inline const std::vector<framework::Scope*>& GetStepScopes(
@@ -98,7 +96,6 @@
 
  private:
   rnn::Argument* arg_;
-  mutable size_t seq_len_;
   std::unique_ptr<framework::OperatorBase>* stepnet_;
 };
 
@@ -123,6 +120,7 @@ class RecurrentOp : public framework::OperatorBase {
   void set_stepnet(std::unique_ptr<OperatorBase> net) {
     stepnet_ = std::move(net);
   }
+
   const OperatorBase& stepnet() const { return *stepnet_; }
 
   static const rnn::ArgumentName kArgName;
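The header-side payoff is the deletion of `mutable size_t seq_len_` from both algorithm classes. A `mutable` member assigned inside a const `Run` is a hidden write to shared state; a small sketch of the hazard the refactor removes (simplified names, not the framework's API):

```cpp
#include <cstddef>

// Before: Run is const in the signature but writes shared state through a
// mutable member. Two threads running different sequences race on seq_len_.
class CachedLen {
 public:
  void Run(size_t input_len) const { seq_len_ = input_len; }

 private:
  mutable size_t seq_len_ = 0;
};

// After: the length is a local of each call, so there is nothing to race on
// and no stale value carries over between batches.
class StatelessLen {
 public:
  void Run(size_t input_len) const {
    size_t seq_len = input_len;  // lives on this call's stack only
    (void)seq_len;               // ... use it to drive the step loop ...
  }
};

int main() {
  CachedLen{}.Run(5);
  StatelessLen{}.Run(5);
  return 0;
}
```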
71 changes: 31 additions & 40 deletions paddle/operators/rnn/recurrent_op_utils.cc
@@ -25,7 +25,7 @@ using LoDTensor = framework::LoDTensor;
 
 void SegmentInputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<std::string>& inlinks,
-                   const size_t seq_len, bool infer_shape_mode) {
+                   const size_t seq_len) {
   PADDLE_ENFORCE(!inlinks.empty(), "no in links are provided.");
   for (size_t i = 0; i < inlinks.size(); ++i) {
     // global inputs
@@ -41,51 +41,45 @@ void SegmentInputs(const std::vector<Scope*>& step_scopes,
     for (size_t j = 0; j < seq_len; j++) {
       Tensor* step_input =
           step_scopes[j]->NewVar(inlinks[i])->GetMutable<Tensor>();
-      if (!infer_shape_mode) {
-        // The input of operators of each step is Tensor here.
-        // Maybe need to modify Slice function.
-        *step_input = input->Slice<float>(j, j + 1);
-      }
+      // The input of operators of each step is Tensor here.
+      // Maybe need to modify Slice function.
+      *step_input = input->Slice<float>(j, j + 1);
       step_input->Resize(step_dims);
     }
   }
 }
 
 void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<std::string>& outlinks,
-                   const size_t seq_len, bool infer_shape_mode) {
+                   const size_t seq_len) {
   for (size_t i = 0; i < outlinks.size(); i++) {
-    auto output_var = step_scopes[0]->parent().FindVar(outlinks[i]);
+    auto* output_var = step_scopes[0]->parent().FindVar(outlinks[i]);
     PADDLE_ENFORCE_NOT_NULL(output_var, "output link [%s] is not in scope.",
                             outlinks[i]);
     LoDTensor* output = output_var->GetMutable<LoDTensor>();
 
-    if (infer_shape_mode) {
-      auto step_scope_var = step_scopes[0]->FindVar(outlinks[i]);
-      PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]);
-      f::DDim step_dims =
-          step_scope_var->template GetMutable<LoDTensor>()->dims();
-      std::vector<int64_t> dims_vec = vectorize(step_dims);
-      dims_vec.insert(dims_vec.begin(), seq_len);
-      output->Resize(f::make_ddim(dims_vec));
-    } else {
-      output->mutable_data<float>(platform::CPUPlace());
-      for (size_t j = 0; j < seq_len; j++) {
-        LoDTensor* step_output =
-            step_scopes[j]->FindVar(outlinks[i])->GetMutable<LoDTensor>();
-        // TODO(luotao02) data type and platform::DeviceContext() should set
-        // correctly
-        (output->Slice<float>(j, j + 1))
-            .CopyFrom<float>(*step_output, platform::CPUPlace());
-      }
-    }
+    auto* step_scope_var = step_scopes[0]->FindVar(outlinks[i]);
+    PADDLE_ENFORCE_NOT_NULL(step_scope_var, "%s not in scope", outlinks[i]);
+    f::DDim step_dims =
+        step_scope_var->template GetMutable<LoDTensor>()->dims();
+    std::vector<int64_t> dims_vec = vectorize(step_dims);
+    dims_vec.insert(dims_vec.begin(), seq_len);
+    output->Resize(f::make_ddim(dims_vec));
+    output->mutable_data<float>(platform::CPUPlace());
+    for (size_t j = 0; j < seq_len; j++) {
+      LoDTensor* step_output =
+          step_scopes[j]->FindVar(outlinks[i])->GetMutable<LoDTensor>();
+      // TODO(luotao02) data type and platform::DeviceContext() should set
+      // correctly
+      (output->Slice<float>(j, j + 1))
+          .CopyFrom<float>(*step_output, platform::CPUPlace());
    }
   }
 }
 
 void LinkMemories(const std::vector<Scope*>& scopes,
                   const std::vector<rnn::MemoryAttr>& memories,
-                  const size_t step_id, const int offset,
-                  bool infer_shape_mode) {
+                  const size_t step_id, const int offset) {
   PADDLE_ENFORCE_LT(step_id, scopes.size(),
                     "step [%d] is out of range of step scopes' size [%d]",
                     step_id, scopes.size());
@@ -95,16 +89,13 @@ void LinkMemories(const std::vector<Scope*>& scopes,
       step_id + offset, scopes.size(),
       "offset [%d] is out of range, it must be less than (%d - %d)", offset,
       scopes.size(), step_id);
-  auto scope = scopes[step_id];
-  auto linked_scope = scopes[step_id + offset];
+  auto* scope = scopes[step_id];
+  auto* linked_scope = scopes[step_id + offset];
   for (auto& attr : memories) {
-    auto mem = scope->FindVar(attr.pre_var)->GetMutable<LoDTensor>();
-    auto linked_mem = linked_scope->FindVar(attr.var)->GetMutable<LoDTensor>();
-    if (infer_shape_mode) {
-      mem->Resize(linked_mem->dims());
-    } else {
-      mem->ShareDataWith<float>(*linked_mem);
-    }
+    auto* mem = scope->FindVar(attr.pre_var)->GetMutable<LoDTensor>();
+    auto* linked_mem = linked_scope->FindVar(attr.var)->GetMutable<LoDTensor>();
+    mem->Resize(linked_mem->dims());
+    mem->ShareDataWith<float>(*linked_mem);
   }
 }
 
@@ -115,11 +106,11 @@ void InitArgument(const ArgumentName& name, Argument* arg,
   arg->inlinks = op.Inputs(name.inlinks);
   arg->outlinks = op.Outputs(name.outlinks);
 
-  auto boot_memories =
+  auto& boot_memories =
       is_grad ? op.Outputs(name.boot_memories) : op.Inputs(name.boot_memories);
   // attributes
-  auto memories = op.Attr<std::vector<std::string>>(name.memories);
-  auto pre_memories = op.Attr<std::vector<std::string>>(name.pre_memories);
+  auto& memories = op.Attr<std::vector<std::string>>(name.memories);
+  auto& pre_memories = op.Attr<std::vector<std::string>>(name.pre_memories);
 
   PADDLE_ENFORCE(memories.size() == boot_memories.size(),
                  "the size of memories, boot_memories don't match:%d,%d",
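With the `infer_shape_mode` branches gone, each helper resizes and binds data in a single pass; `LinkMemories` in particular now always does `Resize` followed by `ShareDataWith`. A hedged sketch of that contract with a toy shared-buffer tensor (hypothetical names, not the framework's):

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Toy tensor whose storage can be shared, mimicking Resize + ShareDataWith.
struct ToyTensor {
  std::vector<int> dims;
  std::shared_ptr<std::vector<float>> data;

  void Resize(const std::vector<int>& d) { dims = d; }
  void ShareDataWith(const ToyTensor& other) { data = other.data; }
};

// Bind step_id's pre-memory to (step_id + offset)'s memory: offset = -1
// walks forward in time (the previous step feeds this one), offset = +1
// walks backward for the gradient pass.
void LinkStepMemories(std::vector<ToyTensor>& pre_mem,
                      const std::vector<ToyTensor>& mem, size_t step_id,
                      int offset) {
  const int linked = static_cast<int>(step_id) + offset;
  assert(linked >= 0 && linked < static_cast<int>(mem.size()));
  ToyTensor& dst = pre_mem[step_id];
  const ToyTensor& src = mem[linked];
  dst.Resize(src.dims);    // shape always follows the linked step
  dst.ShareDataWith(src);  // no copy: both tensors view one buffer
}

int main() {
  std::vector<ToyTensor> mem(3), pre_mem(3);
  mem[0] = {{2, 4}, std::make_shared<std::vector<float>>(8, 1.0f)};
  LinkStepMemories(pre_mem, mem, /*step_id=*/1, /*offset=*/-1);
  assert(pre_mem[1].data == mem[0].data);  // aliased, not copied
  return 0;
}
```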
6 changes: 3 additions & 3 deletions paddle/operators/rnn/recurrent_op_utils.h
@@ -64,18 +64,18 @@ struct ArgumentName {
  */
 void SegmentInputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<std::string>& inlinks,
-                   const size_t seq_len, bool infer_shape_mode);
+                   const size_t seq_len);
 
 /**
  * Process outputs of step nets and merge to variables.
  */
 void ConcatOutputs(const std::vector<Scope*>& step_scopes,
                    const std::vector<std::string>& outlinks,
-                   const size_t seq_len, bool infer_shape_mode);
+                   const size_t seq_len);
 
 void LinkMemories(const std::vector<Scope*>& step_scopes,
                   const std::vector<MemoryAttr>& memories, const size_t step_id,
-                  const int offset, bool infer_shape_mode);
+                  const int offset);
 
 void InitArgument(const ArgumentName& name, Argument* arg,
                   const framework::OperatorBase& op, bool is_grad = false);
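The `ConcatOutputs` declaration above ("process outputs of step nets and merge to variables") is the inverse of `SegmentInputs`: per-step [1, d] slices are stacked into a [seq_len, d] output. A toy sketch of that merge over flat buffers (the function name and types here are illustrative, not the framework's):

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Merge seq_len per-step rows (each `width` floats) into one row-major
// [seq_len x width] buffer, mirroring ConcatOutputs' Resize followed by one
// CopyFrom per step slice.
std::vector<float> ConcatRows(const std::vector<std::vector<float>>& steps,
                              size_t width) {
  std::vector<float> output(steps.size() * width);  // Resize({seq_len, width})
  for (size_t j = 0; j < steps.size(); ++j) {
    assert(steps[j].size() == width);
    // Copy step j into the j-th slice of the merged output.
    std::copy(steps[j].begin(), steps[j].end(), output.begin() + j * width);
  }
  return output;
}

int main() {
  auto out = ConcatRows({{1, 2}, {3, 4}, {5, 6}}, 2);
  assert(out.size() == 6 && out[2] == 3.0f);
  return 0;
}
```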
5 changes: 3 additions & 2 deletions paddle/operators/sum_op.cc
@@ -22,14 +22,15 @@ class SumOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(framework::InferShapeContextBase* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null");
     auto x_dims = ctx->GetInputsDim("X");
-    PADDLE_ENFORCE(!x_dims.empty(), "Input(X) of SumOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of SumOp should not be null.");
 
-    auto in_dim = x_dims[0];
     size_t N = x_dims.size();
     PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
+
+    auto in_dim = x_dims[0];
     for (size_t i = 1; i < N; i++) {
       auto dim = x_dims[i];
       PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape");
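The sum_op change is a check-ordering fix: validate that the inputs exist before asking for their dimensions, and only read `x_dims[0]` after enforcing that there is more than one input. A standalone sketch of the same validation order (plain C++, with a hypothetical `ShapeCtx` stand-in for the infer-shape context):

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

using Dim = std::vector<int>;

// Hypothetical stand-in for the infer-shape context; not Paddle's API.
struct ShapeCtx {
  std::vector<Dim> input_dims;
  bool has_output;
  Dim out_dim;
};

void SumInferShape(ShapeCtx& ctx) {
  // 1. Existence checks come first, before touching any element.
  if (ctx.input_dims.empty())
    throw std::runtime_error("Inputs(X) should not be null");
  if (!ctx.has_output)
    throw std::runtime_error("Output(Out) should not be null");

  // 2. Arity check before indexing input_dims[0].
  const size_t n = ctx.input_dims.size();
  if (n <= 1) throw std::runtime_error("Input tensors count should > 1");

  // 3. Only now is reading input_dims[0] known to be safe.
  const Dim& in_dim = ctx.input_dims[0];
  for (size_t i = 1; i < n; ++i) {
    if (ctx.input_dims[i] != in_dim)
      throw std::runtime_error("Input tensors must have same shape");
  }
  ctx.out_dim = in_dim;  // elementwise sum keeps the common shape
}

int main() {
  ShapeCtx ctx{{{2, 3}, {2, 3}}, true, {}};
  SumInferShape(ctx);
  assert(ctx.out_dim == (Dim{2, 3}));
  return 0;
}
```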