Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring InferShape #3946

Merged
merged 61 commits into from
Sep 26, 2017
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
125a528
init Infershape
jacquesqiao Sep 6, 2017
8ab2d86
add static InferShape interface
jacquesqiao Sep 6, 2017
1b7e6e3
refactor add-op infershape
jacquesqiao Sep 6, 2017
9e1ba61
add AttrReader
jacquesqiao Sep 6, 2017
fcaea1f
add all maker's infershape
jacquesqiao Sep 7, 2017
676c7fd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 7, 2017
703d6ce
add all InferShape
jacquesqiao Sep 7, 2017
570ebb2
add python infer api
jacquesqiao Sep 7, 2017
80a1c63
Merge branch 'fix-cpu-build' of https://github.com/jacquesqiao/Paddle…
jacquesqiao Sep 7, 2017
41996d2
add VarDesc interface
jacquesqiao Sep 8, 2017
fc3b55c
add python VarDesc and OpDesc interface
jacquesqiao Sep 8, 2017
4e7058e
update python code
jacquesqiao Sep 11, 2017
267f0e3
use infershape function to do shape inference
jacquesqiao Sep 12, 2017
129599d
clean code
jacquesqiao Sep 12, 2017
92964d6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 12, 2017
53eb75a
do not use pointer
jacquesqiao Sep 12, 2017
309765c
refine code of op_proto_maker
jacquesqiao Sep 13, 2017
93903fb
add get_dims to VarDesc
jacquesqiao Sep 13, 2017
872a570
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 13, 2017
fc3c095
refine the code
jacquesqiao Sep 14, 2017
8304c74
remove the dependency from operator to op registry
jacquesqiao Sep 15, 2017
59fa374
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 15, 2017
afdfeb9
remove OpProtoAndCheckerMaker from operator
jacquesqiao Sep 15, 2017
228ddf8
restore complete_add_op
jacquesqiao Sep 15, 2017
175abe6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 16, 2017
060677f
add shape_infer_impl.h
jacquesqiao Sep 19, 2017
91134b2
code optimization
jacquesqiao Sep 19, 2017
827cec7
remove const return value
jacquesqiao Sep 20, 2017
28b5d0a
Merge branch 'add_op_proto_maker' of https://github.com/jacquesqiao/P…
jacquesqiao Sep 20, 2017
feb9b1d
Merge branch 'add_op_proto_maker' of https://github.com/jacquesqiao/P…
jacquesqiao Sep 20, 2017
df61245
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 20, 2017
fb39fb3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 20, 2017
b421e91
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 21, 2017
5fa1188
add fake BlockDesc class
jacquesqiao Sep 21, 2017
509c40a
optimize code
jacquesqiao Sep 22, 2017
4364314
remove infer function in op_info
jacquesqiao Sep 22, 2017
a3436ba
move InferShapeContextImpl to operator.h
jacquesqiao Sep 22, 2017
35ea282
optimize the interface of InferShapeContextBase
jacquesqiao Sep 22, 2017
7053581
add temperary interface of new infershape
jacquesqiao Sep 23, 2017
9acdb74
change add_op, clip_op, conv2d_op and activation_op
jacquesqiao Sep 23, 2017
80c785f
change all operators InferShape
jacquesqiao Sep 24, 2017
d9c46c2
fix SetDim
jacquesqiao Sep 24, 2017
c8e2aa8
update cos_sim_op
jacquesqiao Sep 25, 2017
3fbfe5d
update crop_op
jacquesqiao Sep 25, 2017
e6ec26f
update lookup_table_op
jacquesqiao Sep 25, 2017
c778cec
allocate tensor when call GetDim in InferShapeContext
jacquesqiao Sep 25, 2017
5605a30
update modified_huber_loss_op
jacquesqiao Sep 25, 2017
40f2b53
update rowwise_add_op
jacquesqiao Sep 25, 2017
794d5df
update mean_op
jacquesqiao Sep 25, 2017
dc0d153
update sequence_avg_pool_op
jacquesqiao Sep 25, 2017
4e55aae
typo
jacquesqiao Sep 25, 2017
32c1d29
remove old InferShape interface
jacquesqiao Sep 25, 2017
a522653
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 25, 2017
cdf9bfc
can compile
jacquesqiao Sep 25, 2017
e22d7c4
fix or unit test
jacquesqiao Sep 25, 2017
d2d500e
clean code
jacquesqiao Sep 26, 2017
997cbc5
clean code
jacquesqiao Sep 26, 2017
50cec66
remove const before InferShapeContext
jacquesqiao Sep 26, 2017
01bfcdd
change InferenceContextBase to pointer
jacquesqiao Sep 26, 2017
017b27c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 26, 2017
f3bd1ad
rename RunTime to Runtime, code clean
jacquesqiao Sep 26, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library(framework_proto SRCS framework.proto)

cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry op_info)

cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator)
cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry op_info)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)

py_proto_compile(framework_py_proto SRCS framework.proto)
Expand Down
15 changes: 15 additions & 0 deletions paddle/framework/attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ AttrType AttrTypeID();

Attribute GetAttrValue(const OpDesc::Attr& attr_desc);

class AttrReader {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the purpose of this class is to add a method Get, who can find and read an entry in AttributeMap.

How about we change the definition of AttributeMap from the current one

typedef std::unordered_map<std::string, Attribute> AttributeMap;

into

class AttributeMap : public std::unordered_map<std::string, Attribute> {
 public:
  template <typename T> inline const T& Get(...) const { ... }
}

so could we define Get as a method of type AttributeMap; instead of adding a new type AttrReader?

Copy link
Member Author

@jacquesqiao jacquesqiao Sep 26, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the problem is if we use a Class that inherit from map, then we can't directly use the list_initialization to init this AttributeMap,
like scale_op:

AttributeMap  attrs = {{"scale", Attr<AttrType>("scale")}};

try to find a better way to do this in next PR

public:
explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {}

template <typename T>
inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name);
return boost::get<T>(attrs_.at(name));
}

private:
const AttributeMap& attrs_;
};

// check whether a value(attribute) fit a certain limit
template <typename T>
class GreaterThanChecker {
Expand Down
2 changes: 2 additions & 0 deletions paddle/framework/op_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <unordered_map>

#include "paddle/framework/attribute.h"
#include "paddle/framework/shape_inference.h"

namespace paddle {
namespace framework {
Expand All @@ -34,6 +35,7 @@ struct OpInfo {
std::string grad_op_type_;
OpProto* proto_;
OpAttrChecker* checker_;
ShapeInferenceFn shapeInferFn_;

bool HasOpProtoAndChecker() const {
return proto_ != nullptr && checker_ != nullptr;
Expand Down
59 changes: 59 additions & 0 deletions paddle/framework/op_proto_maker.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/op_proto_maker.h"

namespace paddle {
namespace framework {

void OpProtoAndCheckerMaker::Validate() {
validated_ = true;
CheckNoDuplicatedInOutAttrs();
}

OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput(
const std::string& name, const std::string& comment) {
auto* input = proto_->add_inputs();
input->set_name(name);
input->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{input};
}

OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
const std::string& name, const std::string& comment) {
auto* output = proto_->add_outputs();
output->set_name(name);
output->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{output};
}

void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
std::unordered_set<std::string> names;
auto checker = [&](const std::string& name) {
PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
names.insert(name);
};
for (auto& attr : proto_->attrs()) {
checker(attr.name());
}
for (auto& input : proto_->inputs()) {
checker(input.name());
}
for (auto& output : proto_->outputs()) {
checker(output.name());
}
}

} // namespace framework
} // namespace paddle
105 changes: 105 additions & 0 deletions paddle/framework/op_proto_maker.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/shape_inference.h"

namespace paddle {
namespace framework {

// this class not only make proto but also init attribute checkers.
class OpProtoAndCheckerMaker {
public:
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
: proto_(proto), op_checker_(op_checker) {}

virtual ~OpProtoAndCheckerMaker() {
PADDLE_ENFORCE(validated_, "should call Validate after build");
}

void Validate();

protected:
struct VariableBuilder {
OpProto::Var* var_;

VariableBuilder& AsDuplicable() {
var_->set_duplicable(true);
return *this;
}

VariableBuilder& AsIntermediate() {
var_->set_intermediate(true);
return *this;
}

VariableBuilder& NotInGradient() {
var_->set_not_in_gradient(true);
return *this;
}
};

VariableBuilder AddInput(const std::string& name, const std::string& comment);

VariableBuilder AddOutput(const std::string& name,
const std::string& comment);

template <typename T>
TypedAttrChecker<T>& AddAttr(const std::string& name,
const std::string& comment,
bool generated = false) {
auto* attr = proto_->add_attrs();
attr->set_name(name);
attr->set_comment(comment);
attr->set_generated(generated);
attr->set_type(AttrTypeID<T>());
return op_checker_->AddAttrChecker<T>(name);
}

void AddComment(const std::string& comment) { proto_->set_comment(comment); }

void SetShapeInferenceFn(ShapeInferenceFn fn) { shape_infer_fn_ = fn; }

void SetGradShapeInferenceFn(ShapeInferenceFn fn) {
grad_shape_infer_fn_ = fn;
}

public:
const ShapeInferenceFn GetShapeInferenceFn() const { return shape_infer_fn_; }
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

return const ShapeInferenceFN& or ShapeInferenceFN

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed


const ShapeInferenceFn GetGradShapeInferenceFn() const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

return grad_shape_infer_fn_;
}

private:
void CheckNoDuplicatedInOutAttrs();

OpProto* proto_;
OpAttrChecker* op_checker_;
ShapeInferenceFn shape_infer_fn_{nullptr};
ShapeInferenceFn grad_shape_infer_fn_{nullptr};
bool validated_{false};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we need a flag of validate? Cause the interface is always invoked by OpMaker, so the validate process is done before any real run call happens, right?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these code is removed

};

class NOPMaker : public OpProtoAndCheckerMaker {
public:
NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {}
};

} // namespace framework
} // namespace paddle
29 changes: 25 additions & 4 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@ limitations under the License. */
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_proto_maker.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/shape_inference_impl.h"

namespace paddle {
namespace framework {
Expand All @@ -34,10 +36,12 @@ class OpRegistry {
public:
template <typename OpType, typename ProtoMakerType, typename GradOpType>
static void RegisterOp(const std::string& op_type,
const std::string& grad_op_type) {
const std::string& grad_op_type,
const ShapeInferenceFn fn) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const ShapeInferenceFN& or ShapeInferenceFN

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type),
"'%s' is registered more than once.", op_type);
OpInfo op_info;
ShapeInferenceFn grad_op_inferer = nullptr;
op_info.creator_ = [](
const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs) {
Expand All @@ -51,18 +55,21 @@ class OpRegistry {
auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
maker.Validate();
op_info.proto_->set_type(op_type);
op_info.shapeInferFn_ = maker.GetShapeInferenceFn();
grad_op_inferer = maker.GetGradShapeInferenceFn();
PADDLE_ENFORCE(
op_info.proto_->IsInitialized(),
"Fail to initialize %s's OpProto, because %s is not initialized",
op_type, op_info.proto_->InitializationErrorString());
} else {
op_info.proto_ = nullptr;
op_info.checker_ = nullptr;
op_info.shapeInferFn_ = fn;
}
OpInfoMap::Instance().Insert(op_type, op_info);
// register gradient op
if (!grad_op_type.empty()) {
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "");
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "", grad_op_inferer);
}
}

Expand All @@ -74,6 +81,20 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);

static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op);

// compile time InferShape
static void InferShape(const OpDesc& op_desc,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. We might need to change this PR according to the recently merged design of ProgramDesc.

std::map<std::string, VarDesc*>& var_descs) {
auto& info = OpInfoMap::Instance().Get(op_desc.type());
auto op = OpRegistry::CreateOp(op_desc);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is very strange that when checking compile-time infer shape, it actually creates a runtime operator.

Maybe add a FIXME comment here, I actually did not figure out what's the better way, but there should be a better way.

Copy link
Member Author

@jacquesqiao jacquesqiao Sep 26, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed now, Infershape is implemented inside Operator, so do not need to create an operator now

info.shapeInferFn_(CompileTimeInferShapeContext(op, var_descs));
}

// runtime InferShape
static void InferShape(const OperatorBase& op, const Scope& scope) {
auto& info = OpInfoMap::Instance().Get(op.Type());
info.shapeInferFn_(RunTimeInferShapeContext(op, scope));
}
};

class Registrar {
Expand All @@ -94,8 +115,8 @@ class OpRegistrar : public Registrar {
public:
explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
OpRegistrar(const char* op_type, const char* grad_op_type) {
OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type,
grad_op_type);
OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(
op_type, grad_op_type, nullptr);
}
};

Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,4 @@ TEST(OpRegistry, CustomChecker) {
op->Run(scope, dev_ctx);
int test_attr = op->Attr<int>("test_attr");
ASSERT_EQ(test_attr, 4);
}
}
40 changes: 0 additions & 40 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/operator.h"
#include <algorithm>
#include "paddle/framework/op_registry.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change these lines. If it is not necessary, please leave them unchanged.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

paddle/framework/op_registry.h is not used in this file, keep #include <algorithm>


namespace paddle {
namespace framework {
Expand Down Expand Up @@ -228,43 +226,5 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
return res;
}

void OpProtoAndCheckerMaker::Validate() {
validated_ = true;
CheckNoDuplicatedInOutAttrs();
}

OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddInput(
const std::string& name, const std::string& comment) {
auto* input = proto_->add_inputs();
input->set_name(name);
input->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{input};
}

OpProtoAndCheckerMaker::VariableBuilder OpProtoAndCheckerMaker::AddOutput(
const std::string& name, const std::string& comment) {
auto* output = proto_->add_outputs();
output->set_name(name);
output->set_comment(comment);
return OpProtoAndCheckerMaker::VariableBuilder{output};
}

void OpProtoAndCheckerMaker::CheckNoDuplicatedInOutAttrs() {
std::unordered_set<std::string> names;
auto checker = [&](const std::string& name) {
PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
names.insert(name);
};
for (auto& attr : proto_->attrs()) {
checker(attr.name());
}
for (auto& input : proto_->inputs()) {
checker(input.name());
}
for (auto& output : proto_->outputs()) {
checker(output.name());
}
}

} // namespace framework
} // namespace paddle
Loading