-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactoring InferShape #3946
Refactoring InferShape #3946
Changes from 31 commits
125a528
8ab2d86
1b7e6e3
9e1ba61
fcaea1f
676c7fd
703d6ce
570ebb2
80a1c63
41996d2
fc3b55c
4e7058e
267f0e3
129599d
92964d6
53eb75a
309765c
93903fb
872a570
fc3c095
8304c74
59fa374
afdfeb9
228ddf8
175abe6
060677f
91134b2
827cec7
28b5d0a
feb9b1d
df61245
fb39fb3
b421e91
5fa1188
509c40a
4364314
a3436ba
35ea282
7053581
9acdb74
80c785f
d9c46c2
c8e2aa8
3fbfe5d
e6ec26f
c778cec
5605a30
40f2b53
794d5df
dc0d153
4e55aae
32c1d29
a522653
cdf9bfc
e22d7c4
d2d500e
997cbc5
50cec66
01bfcdd
017b27c
f3bd1ad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
@@ -15,6 +16,7 @@ limitations under the License. */ | |
|
||
#include "paddle/framework/attribute.h" | ||
#include "paddle/framework/framework.pb.h" | ||
#include "paddle/framework/shape_inference.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
@@ -70,11 +72,26 @@ class OpProtoAndCheckerMaker { | |
|
||
void AddComment(const std::string& comment) { proto_->set_comment(comment); } | ||
|
||
// Stores `fn` in shape_infer_fn_, the value handed back by
// GetShapeInferenceFn().
void SetShapeInferenceFn(ShapeInferenceFn fn) {
  this->shape_infer_fn_ = fn;
}
|
||
// Stores `fn` in grad_shape_infer_fn_, the value handed back by
// GetGradShapeInferenceFn().
void SetGradShapeInferenceFn(ShapeInferenceFn fn) {
  this->grad_shape_infer_fn_ = fn;
}
|
||
public: | ||
// Returns a copy of the inference function stored by SetShapeInferenceFn().
// The result is empty when none was set (shape_infer_fn_ defaults to nullptr).
// Returned as a plain value: a top-level const on a by-value return type only
// prevents callers from moving out of the returned temporary.
ShapeInferenceFn GetShapeInferenceFn() const { return shape_infer_fn_; }
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removed |
||
|
||
const ShapeInferenceFn GetGradShapeInferenceFn() const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removed |
||
return grad_shape_infer_fn_; | ||
} | ||
|
||
private: | ||
void CheckNoDuplicatedInOutAttrs(); | ||
|
||
OpProto* proto_; | ||
OpAttrChecker* op_checker_; | ||
ShapeInferenceFn shape_infer_fn_{nullptr}; | ||
ShapeInferenceFn grad_shape_infer_fn_{nullptr}; | ||
bool validated_{false}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need a `validated_` flag? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This code has been removed. |
||
}; | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,6 +27,7 @@ limitations under the License. */ | |
#include "paddle/framework/op_proto_maker.h" | ||
#include "paddle/framework/operator.h" | ||
#include "paddle/framework/scope.h" | ||
#include "paddle/framework/shape_inference_impl.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
@@ -35,10 +36,12 @@ class OpRegistry { | |
public: | ||
template <typename OpType, typename ProtoMakerType, typename GradOpType> | ||
static void RegisterOp(const std::string& op_type, | ||
const std::string& grad_op_type) { | ||
const std::string& grad_op_type, | ||
const ShapeInferenceFn fn) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removed |
||
PADDLE_ENFORCE(!OpInfoMap::Instance().Has(op_type), | ||
"'%s' is registered more than once.", op_type); | ||
OpInfo op_info; | ||
ShapeInferenceFn grad_op_inferer = nullptr; | ||
op_info.creator_ = []( | ||
const std::string& type, const VariableNameMap& inputs, | ||
const VariableNameMap& outputs, const AttributeMap& attrs) { | ||
|
@@ -52,18 +55,21 @@ class OpRegistry { | |
auto maker = ProtoMakerType(op_info.proto_, op_info.checker_); | ||
maker.Validate(); | ||
op_info.proto_->set_type(op_type); | ||
op_info.shapeInferFn_ = maker.GetShapeInferenceFn(); | ||
grad_op_inferer = maker.GetGradShapeInferenceFn(); | ||
PADDLE_ENFORCE( | ||
op_info.proto_->IsInitialized(), | ||
"Fail to initialize %s's OpProto, because %s is not initialized", | ||
op_type, op_info.proto_->InitializationErrorString()); | ||
} else { | ||
op_info.proto_ = nullptr; | ||
op_info.checker_ = nullptr; | ||
op_info.shapeInferFn_ = fn; | ||
} | ||
OpInfoMap::Instance().Insert(op_type, op_info); | ||
// register gradient op | ||
if (!grad_op_type.empty()) { | ||
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, ""); | ||
RegisterOp<GradOpType, NOPMaker, NOP>(grad_op_type, "", grad_op_inferer); | ||
} | ||
} | ||
|
||
|
@@ -75,6 +81,20 @@ class OpRegistry { | |
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc); | ||
|
||
static std::unique_ptr<OperatorBase> CreateGradOp(const OperatorBase& op); | ||
|
||
// compile time InferShape | ||
static void InferShape(const OpDesc& op_desc, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see. We might need to change this PR according to the recently merged design of |
||
std::map<std::string, VarDesc*>& var_descs) { | ||
auto& info = OpInfoMap::Instance().Get(op_desc.type()); | ||
auto op = OpRegistry::CreateOp(op_desc); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is very strange that when checking Maybe add a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removed now, Infershape is implemented inside Operator, so do not need to create an operator now |
||
info.shapeInferFn_(CompileTimeInferShapeContext(op, var_descs)); | ||
} | ||
|
||
// runtime InferShape | ||
static void InferShape(const OperatorBase& op, const Scope& scope) { | ||
auto& info = OpInfoMap::Instance().Get(op.Type()); | ||
info.shapeInferFn_(RunTimeInferShapeContext(op, scope)); | ||
} | ||
}; | ||
|
||
class Registrar { | ||
|
@@ -95,8 +115,8 @@ class OpRegistrar : public Registrar { | |
public: | ||
explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); } | ||
OpRegistrar(const char* op_type, const char* grad_op_type) { | ||
OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>(op_type, | ||
grad_op_type); | ||
OpRegistry::RegisterOp<OpType, ProtoMakerType, GradOpType>( | ||
op_type, grad_op_type, nullptr); | ||
} | ||
}; | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and | |
limitations under the License. */ | ||
|
||
#include "paddle/framework/operator.h" | ||
#include <algorithm> | ||
#include "paddle/framework/op_registry.h" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why change these lines? If the change is not necessary, please leave them unchanged. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
namespace paddle { | ||
namespace framework { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "paddle/framework/ddim.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
class InferShapeContextBase; | ||
|
||
using ShapeInferenceFn = | ||
std::function<void(const framework::InferShapeContextBase& ctx)>; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's better to define with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
|
||
class InferShapeContextBase { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为什么需要一个基类? 2、另外,在定义op的ShapeInference时候,只需要暴露CompileTime的InferShapeContext,不需要暴露基类。 SetShapeInferenceFn([](const framework::InferShapeContextBase &ctx) {
auto dim0 = ctx.get_input_dim("X");
auto dim1 = ctx.get_input_dim("Y");
PADDLE_ENFORCE_EQ(dim0.size(), 2,
"input X should be a tensor with 2 dims, a matrix");
PADDLE_ENFORCE_EQ(dim1.size(), 2,
"input Y should be a tensor with 2 dims, a matrix");
PADDLE_ENFORCE_EQ(dim0[1], dim1[0],
"First matrix's width must be equal "
"with second matrix's height.");
ctx.set_output_dim("Out", {dim0[0], dim1[1]});
});
} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
并不是这样的,compile time的infershape和runtime的infershape并没有太多本质区别,都要做:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 经过讨论,这个部分已经统一意见 |
||
public: | ||
virtual ~InferShapeContextBase() {} | ||
virtual framework::DDim get_input_dim(const std::string& name) const = 0; | ||
virtual void set_input_dim(const std::string& name, | ||
const framework::DDim& dim) const = 0; | ||
virtual framework::DDim get_output_dim(const std::string& name) const = 0; | ||
virtual void set_output_dim(const std::string& name, | ||
const DDim& dim) const = 0; | ||
virtual AttrReader attrs() const = 0; | ||
|
||
protected: | ||
virtual framework::DDim get_dim(const std::string& name) const = 0; | ||
virtual void set_dim(const std::string& name, | ||
const framework::DDim& dim) const = 0; | ||
}; | ||
|
||
inline void NonFn(const framework::InferShapeContextBase& ctx){}; | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "paddle/framework/ddim.h" | ||
#include "paddle/framework/operator.h" | ||
#include "paddle/framework/shape_inference.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
class CompileTimeInferShapeContext : public InferShapeContextBase { | ||
public: | ||
CompileTimeInferShapeContext(std::unique_ptr<OperatorBase>& op, | ||
std::map<std::string, VarDesc*>& var_descs) | ||
: op_(std::move(op)), var_descs_(var_descs) {} | ||
|
||
DDim get_input_dim(const std::string& name) const { | ||
return get_dim(op_->Input(name)); | ||
} | ||
|
||
void set_input_dim(const std::string& name, const DDim& dim) const { | ||
set_dim(op_->Input(name), dim); | ||
} | ||
|
||
DDim get_output_dim(const std::string& name) const { | ||
return get_dim(op_->Output(name)); | ||
} | ||
|
||
void set_output_dim(const std::string& name, const DDim& dim) const { | ||
set_dim(op_->Output(name), dim); | ||
} | ||
|
||
AttrReader attrs() const { return AttrReader(op_->Attrs()); } | ||
|
||
private: | ||
DDim get_dim(const std::string& name) const { | ||
VarDesc* desc = var_descs_.at(name); | ||
std::vector<int64_t> dim; | ||
int length = desc->lod_tensor().dims().size(); | ||
dim.reserve(length); | ||
std::copy(desc->lod_tensor().dims().begin(), | ||
desc->lod_tensor().dims().end(), std::back_inserter(dim)); | ||
return make_ddim(dim); | ||
} | ||
|
||
void set_dim(const std::string& name, const DDim& dim) const { | ||
VarDesc* desc = var_descs_.at(name); | ||
auto tensor = desc->mutable_lod_tensor(); | ||
tensor->clear_dims(); | ||
for (int i = 0; i < dim.size(); ++i) { | ||
tensor->add_dims(static_cast<int>(dim[i])); | ||
} | ||
} | ||
|
||
std::unique_ptr<OperatorBase> op_; | ||
std::map<std::string, VarDesc*>& var_descs_; | ||
}; | ||
|
||
class RunTimeInferShapeContext : public InferShapeContextBase { | ||
public: | ||
RunTimeInferShapeContext(const OperatorBase& op, const Scope& scope) | ||
: op_(op), scope_(scope) {} | ||
|
||
DDim get_input_dim(const std::string& name) const { | ||
return get_dim(op_.Input(name)); | ||
} | ||
|
||
void set_input_dim(const std::string& name, const DDim& dim) const { | ||
set_dim(op_.Input(name), dim); | ||
} | ||
|
||
DDim get_output_dim(const std::string& name) const { | ||
return get_dim(op_.Output(name)); | ||
} | ||
|
||
void set_output_dim(const std::string& name, const DDim& dim) const { | ||
set_dim(op_.Output(name), dim); | ||
} | ||
|
||
AttrReader attrs() const { return AttrReader(op_.Attrs()); } | ||
|
||
private: | ||
DDim get_dim(const std::string& name) const { | ||
Tensor* t = scope_.FindVar(op_.Input(name))->GetMutable<Tensor>(); | ||
return t->dims(); | ||
} | ||
|
||
void set_dim(const std::string& name, const DDim& dim) const { | ||
Tensor* t = scope_.FindVar(name)->GetMutable<Tensor>(); | ||
t->Resize(dim); | ||
} | ||
|
||
const OperatorBase& op_; | ||
const Scope& scope_; | ||
}; | ||
|
||
} // namespace framework | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that the purpose of this class is to add a method
`Get`
, which can find and read an entry in AttributeMap. How about we change the definition of AttributeMap from the current one
typedef std::unordered_map<std::string, Attribute> AttributeMap;
into
so could we define
`Get`
as a method of type `AttributeMap`
, instead of adding a new type `AttrReader`
? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The problem is that if we use a class that inherits from map, we can't directly use list initialization to initialize this AttributeMap,
like scale_op:
try to find a better way to do this in next PR