Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactoring InferShape #3946

Merged
merged 61 commits into from
Sep 26, 2017
Merged
Show file tree
Hide file tree
Changes from 51 commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
125a528
init Infershape
jacquesqiao Sep 6, 2017
8ab2d86
add static InferShape interface
jacquesqiao Sep 6, 2017
1b7e6e3
refactor add-op infershape
jacquesqiao Sep 6, 2017
9e1ba61
add AttrReader
jacquesqiao Sep 6, 2017
fcaea1f
add all maker's infershape
jacquesqiao Sep 7, 2017
676c7fd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 7, 2017
703d6ce
add all InferShape
jacquesqiao Sep 7, 2017
570ebb2
add python infer api
jacquesqiao Sep 7, 2017
80a1c63
Merge branch 'fix-cpu-build' of https://github.com/jacquesqiao/Paddle…
jacquesqiao Sep 7, 2017
41996d2
add VarDesc interface
jacquesqiao Sep 8, 2017
fc3b55c
add python VarDesc and OpDesc interface
jacquesqiao Sep 8, 2017
4e7058e
update python code
jacquesqiao Sep 11, 2017
267f0e3
use infershape function to do shape inference
jacquesqiao Sep 12, 2017
129599d
clean code
jacquesqiao Sep 12, 2017
92964d6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 12, 2017
53eb75a
do not use pointer
jacquesqiao Sep 12, 2017
309765c
refine code of op_proto_maker
jacquesqiao Sep 13, 2017
93903fb
add get_dims to VarDesc
jacquesqiao Sep 13, 2017
872a570
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 13, 2017
fc3c095
refine the code
jacquesqiao Sep 14, 2017
8304c74
remove the dependency from operator to op registry
jacquesqiao Sep 15, 2017
59fa374
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 15, 2017
afdfeb9
remove OpProtoAndCheckerMaker from operator
jacquesqiao Sep 15, 2017
228ddf8
restore complete_add_op
jacquesqiao Sep 15, 2017
175abe6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 16, 2017
060677f
add shape_infer_impl.h
jacquesqiao Sep 19, 2017
91134b2
code optimization
jacquesqiao Sep 19, 2017
827cec7
remove const return value
jacquesqiao Sep 20, 2017
28b5d0a
Merge branch 'add_op_proto_maker' of https://github.com/jacquesqiao/P…
jacquesqiao Sep 20, 2017
feb9b1d
Merge branch 'add_op_proto_maker' of https://github.com/jacquesqiao/P…
jacquesqiao Sep 20, 2017
df61245
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 20, 2017
fb39fb3
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 20, 2017
b421e91
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 21, 2017
5fa1188
add fake BlockDesc class
jacquesqiao Sep 21, 2017
509c40a
optimize code
jacquesqiao Sep 22, 2017
4364314
remove infer function in op_info
jacquesqiao Sep 22, 2017
a3436ba
move InferShapeContextImpl to operator.h
jacquesqiao Sep 22, 2017
35ea282
optimize the interface of InferShapeContextBase
jacquesqiao Sep 22, 2017
7053581
add temperary interface of new infershape
jacquesqiao Sep 23, 2017
9acdb74
change add_op, clip_op, conv2d_op and activation_op
jacquesqiao Sep 23, 2017
80c785f
change all operators InferShape
jacquesqiao Sep 24, 2017
d9c46c2
fix SetDim
jacquesqiao Sep 24, 2017
c8e2aa8
update cos_sim_op
jacquesqiao Sep 25, 2017
3fbfe5d
update crop_op
jacquesqiao Sep 25, 2017
e6ec26f
update lookup_table_op
jacquesqiao Sep 25, 2017
c778cec
allocate tensor when call GetDim in InferShapeContext
jacquesqiao Sep 25, 2017
5605a30
update modified_huber_loss_op
jacquesqiao Sep 25, 2017
40f2b53
update rowwise_add_op
jacquesqiao Sep 25, 2017
794d5df
update mean_op
jacquesqiao Sep 25, 2017
dc0d153
update sequence_avg_pool_op
jacquesqiao Sep 25, 2017
4e55aae
typo
jacquesqiao Sep 25, 2017
32c1d29
remove old InferShape interface
jacquesqiao Sep 25, 2017
a522653
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 25, 2017
cdf9bfc
can compile
jacquesqiao Sep 25, 2017
e22d7c4
fix or unit test
jacquesqiao Sep 25, 2017
d2d500e
clean code
jacquesqiao Sep 26, 2017
997cbc5
clean code
jacquesqiao Sep 26, 2017
50cec66
remove const before InferShapeContext
jacquesqiao Sep 26, 2017
01bfcdd
change InferenceContextBase to pointer
jacquesqiao Sep 26, 2017
017b27c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
jacquesqiao Sep 26, 2017
f3bd1ad
rename RunTime to Runtime, code clean
jacquesqiao Sep 26, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry op_info)

cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator)
cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry op_info)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)

py_proto_compile(framework_py_proto SRCS framework.proto)
Expand Down
15 changes: 15 additions & 0 deletions paddle/framework/attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,21 @@ AttrType AttrTypeID();

Attribute GetAttrValue(const OpDesc::Attr& attr_desc);

class AttrReader {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that the purpose of this class is to add a method Get, who can find and read an entry in AttributeMap.

How about we change the definition of AttributeMap from the current one

typedef std::unordered_map<std::string, Attribute> AttributeMap;

into

class AttributeMap : public std::unordered_map<std::string, Attribute> {
 public:
  template <typename T> inline const T& Get(...) const { ... }
}

so could we define Get as a method of type AttributeMap; instead of adding a new type AttrReader?

Copy link
Member Author

@jacquesqiao jacquesqiao Sep 26, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the problem is if we use a Class that inherit from map, then we can't directly use the list_initialization to init this AttributeMap,
like scale_op:

AttributeMap  attrs = {{"scale", Attr<AttrType>("scale")}};

try to find a better way to do this in next PR

public:
explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {}

template <typename T>
inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name);
return boost::get<T>(attrs_.at(name));
}

private:
const AttributeMap& attrs_;
};

// check whether a value(attribute) fit a certain limit
template <typename T>
class GreaterThanChecker {
Expand Down
15 changes: 15 additions & 0 deletions paddle/framework/ddim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,21 @@ std::vector<int64_t> vectorize(const DDim& ddim) {
return result;
}

std::string debug_str(const DDim& ddim) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is debug_str used?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I use it a lot when debug the code

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed since we can use cout << ddim to print the debug string.

auto ret = vectorize(ddim);
std::stringstream ss;
ss << "[";
auto size = ddim.size();
for (size_t i = 0; i < size; ++i) {
ss << ret[i];
if (i < size - 1) {
ss << ", ";
}
}
ss << "]";
return ss.str();
}

struct ProductVisitor : public boost::static_visitor<int64_t> {
template <int D>
int64_t operator()(const Dim<D>& dim) {
Expand Down
2 changes: 2 additions & 0 deletions paddle/framework/ddim.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ void set(DDim& dim, int idx, int val);

std::vector<int64_t> vectorize(const DDim& ddim);

std::string debug_str(const DDim& ddim);

int64_t product(const DDim& ddim);

/**
Expand Down
1 change: 1 addition & 0 deletions paddle/framework/op_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <unordered_map>

#include "paddle/framework/attribute.h"
#include "paddle/framework/shape_inference.h"

namespace paddle {
namespace framework {
Expand Down
1 change: 1 addition & 0 deletions paddle/framework/op_proto_maker.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Expand Down
2 changes: 2 additions & 0 deletions paddle/framework/op_proto_maker.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Expand All @@ -15,6 +16,7 @@ limitations under the License. */

#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/shape_inference.h"

namespace paddle {
namespace framework {
Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/op_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,4 @@ TEST(OpRegistry, CustomChecker) {
op->Run(scope, dev_ctx);
int test_attr = op->Attr<int>("test_attr");
ASSERT_EQ(test_attr, 4);
}
}
11 changes: 9 additions & 2 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/operator.h"
#include <algorithm>
#include "paddle/framework/op_registry.h"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change these lines. If it is not necessary, please leave them unchanged.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

paddle/framework/op_registry.h is not used in this file, keep #include <algorithm>


namespace paddle {
namespace framework {
Expand All @@ -33,6 +31,15 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
}
#endif

const Tensor* GetTensorFromVar(const Variable* var) {
if (var->IsType<LoDTensor>()) {
return &var->Get<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input must be LoDTensor or Tensor.");
return &var->Get<Tensor>();
}

std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL,
Expand Down
171 changes: 160 additions & 11 deletions paddle/framework/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
Expand Down Expand Up @@ -56,6 +57,8 @@ class OperatorBase;
class InferShapeContext;
class ExecutionContext;

extern const Tensor* GetTensorFromVar(const Variable* var);

/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
Expand Down Expand Up @@ -262,15 +265,6 @@ class InferShapeContext {
return res;
}

const Tensor* GetTensorFromVar(const Variable* var) const {
if (var->IsType<LoDTensor>()) {
return &var->Get<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input(%s) must be LoDTensor or Tensor.");
return &var->Get<Tensor>();
}

private:
const OperatorBase& op_;
const Scope& scope_;
Expand Down Expand Up @@ -347,6 +341,154 @@ template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const;

class BlockDesc {
public:
explicit BlockDesc(const std::map<std::string, VarDesc*>& var_descs)
: var_descs_(var_descs) {}
~BlockDesc() {}

VarDesc* GetVar(const std::string& name) const {
PADDLE_ENFORCE(var_descs_.count(name) == 1, "%s must be in Block", name);
return var_descs_.at(name);
}

bool HasVar(const std::string& name) const {
return var_descs_.count(name) > 0;
}

private:
const std::map<std::string, VarDesc*>& var_descs_;
};

class CompileTimeInferShapeContext : public InferShapeContextBase {
public:
CompileTimeInferShapeContext(const OperatorBase& op,
const BlockDesc& block_desc)
: op_(op), block_desc_(block_desc) {}

bool HasInput(const std::string& name) const {
return block_desc_.HasVar(name);
}

bool HasOutput(const std::string& name) const {
return block_desc_.HasVar(name);
}

DDim GetInputDim(const std::string& name) const {
return GetDim(op_.Input(name));
}

void SetInputDim(const std::string& name, const DDim& dim) const {
SetDim(op_.Input(name), dim);
}

DDim GetOutputDim(const std::string& name) const {
return GetDim(op_.Output(name));
}

void SetOutputDim(const std::string& name, const DDim& dim) const {
Copy link
Collaborator

@reyoung reyoung Sep 25, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Set methods should change private data members, so they cannot be const

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

SetDim(op_.Output(name), dim);
}

AttrReader Attrs() const { return AttrReader(op_.Attrs()); }

const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name);
}

const std::vector<std::string>& Outputs(const std::string& name) const {
return op_.Outputs(name);
}

private:
DDim GetDim(const std::string& name) const {
VarDesc* desc = block_desc_.GetVar(name);
std::vector<int64_t> dim;
int length = desc->lod_tensor().dims().size();
dim.reserve(length);
std::copy(desc->lod_tensor().dims().begin(),
desc->lod_tensor().dims().end(), std::back_inserter(dim));
return make_ddim(dim);
}

void SetDim(const std::string& name, const DDim& dim) const {
VarDesc* desc = block_desc_.GetVar(name);
auto tensor = desc->mutable_lod_tensor();
tensor->clear_dims();
for (int i = 0; i < dim.size(); ++i) {
tensor->add_dims(static_cast<int>(dim[i]));
}
}

const OperatorBase& op_;
const BlockDesc& block_desc_;
};

class RunTimeInferShapeContext : public InferShapeContextBase {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

RunTime-> Runtime

runtime is a word

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

public:
RunTimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}

bool HasInput(const std::string& name) const {
auto ipt = op_.Input(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}

bool HasOutput(const std::string& name) const {
auto ipt = op_.Output(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}

DDim GetInputDim(const std::string& name) const {
return GetDim(op_.Input(name));
}

void SetInputDim(const std::string& name, const DDim& dim) const {
SetDim(op_.Input(name), dim);
}

DDim GetOutputDim(const std::string& name) const {
return GetDim(op_.Output(name));
}

void SetOutputDim(const std::string& name, const DDim& dim) const {
SetDim(op_.Output(name), dim);
}

AttrReader Attrs() const { return AttrReader(op_.Attrs()); }

const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name);
}

const std::vector<std::string>& Outputs(const std::string& name) const {
return op_.Outputs(name);
}

private:
Tensor* GetTensor(const std::string& name) const {
Tensor* t;
auto* var = scope_.FindVar(name);
if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) {
t = var->GetMutable<LoDTensor>();
} else {
t = const_cast<Tensor*>(GetTensorFromVar(scope_.FindVar(name)));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const_cast should not be used.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

}
return t;
}

DDim GetDim(const std::string& name) const { return GetTensor(name)->dims(); }

void SetDim(const std::string& name, const DDim& dim) const {
GetTensor(name)->Resize(dim);
}

const OperatorBase& op_;
const Scope& scope_;
};

class OpKernel {
public:
/**
Expand Down Expand Up @@ -390,8 +532,14 @@ class OperatorWithKernel : public OperatorBase {
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}

// runtime infershape
void InferShape(const Scope& scope) const override {
InferShape(InferShapeContext(*this, scope));
InferShape(RunTimeInferShapeContext(*this, scope));
}

// compile time infershape
void InferShape(const BlockDesc& block_desc) const {
InferShape(CompileTimeInferShapeContext(*this, block_desc));
}

void Run(const Scope& scope,
Expand All @@ -413,7 +561,8 @@ class OperatorWithKernel : public OperatorBase {
}

protected:
virtual void InferShape(const InferShapeContext& ctx) const = 0;
virtual void InferShape(const InferShapeContext& ctx) const {}
virtual void InferShape(const InferShapeContextBase& ctx) const {}
Copy link
Collaborator

@reyoung reyoung Sep 25, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why there are two interfaces?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old one is for compiling and test, has been removed after transform all the old operators.

};

} // namespace framework
Expand Down
1 change: 1 addition & 0 deletions paddle/framework/operator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/framework/operator.h"
#include "gtest/gtest.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
Expand Down
Loading