-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Refactoring InferShape #3946
Refactoring InferShape #3946
Changes from 51 commits
125a528
8ab2d86
1b7e6e3
9e1ba61
fcaea1f
676c7fd
703d6ce
570ebb2
80a1c63
41996d2
fc3b55c
4e7058e
267f0e3
129599d
92964d6
53eb75a
309765c
93903fb
872a570
fc3c095
8304c74
59fa374
afdfeb9
228ddf8
175abe6
060677f
91134b2
827cec7
28b5d0a
feb9b1d
df61245
fb39fb3
b421e91
5fa1188
509c40a
4364314
a3436ba
35ea282
7053581
9acdb74
80c785f
d9c46c2
c8e2aa8
3fbfe5d
e6ec26f
c778cec
5605a30
40f2b53
794d5df
dc0d153
4e55aae
32c1d29
a522653
cdf9bfc
e22d7c4
d2d500e
997cbc5
50cec66
01bfcdd
017b27c
f3bd1ad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -195,6 +195,21 @@ std::vector<int64_t> vectorize(const DDim& ddim) { | |
return result; | ||
} | ||
|
||
std::string debug_str(const DDim& ddim) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I use it a lot when debug the code There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removed since we can use cout << ddim to print the debug string. |
||
auto ret = vectorize(ddim); | ||
std::stringstream ss; | ||
ss << "["; | ||
auto size = ddim.size(); | ||
for (size_t i = 0; i < size; ++i) { | ||
ss << ret[i]; | ||
if (i < size - 1) { | ||
ss << ", "; | ||
} | ||
} | ||
ss << "]"; | ||
return ss.str(); | ||
} | ||
|
||
struct ProductVisitor : public boost::static_visitor<int64_t> { | ||
template <int D> | ||
int64_t operator()(const Dim<D>& dim) { | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and | |
limitations under the License. */ | ||
|
||
#include "paddle/framework/operator.h" | ||
#include <algorithm> | ||
#include "paddle/framework/op_registry.h" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why change these lines. If it is not necessary, please leave them unchanged. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
namespace paddle { | ||
namespace framework { | ||
|
@@ -33,6 +31,15 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { | |
} | ||
#endif | ||
|
||
const Tensor* GetTensorFromVar(const Variable* var) { | ||
if (var->IsType<LoDTensor>()) { | ||
return &var->Get<LoDTensor>(); | ||
} | ||
PADDLE_ENFORCE(var->IsType<Tensor>(), | ||
"The Input must be LoDTensor or Tensor."); | ||
return &var->Get<Tensor>(); | ||
} | ||
|
||
std::string OperatorBase::Input(const std::string& name) const { | ||
auto& ins = Inputs(name); | ||
PADDLE_ENFORCE_LE(ins.size(), 1UL, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,6 +24,7 @@ limitations under the License. */ | |
#include "paddle/framework/framework.pb.h" | ||
#include "paddle/framework/lod_tensor.h" | ||
#include "paddle/framework/scope.h" | ||
#include "paddle/framework/shape_inference.h" | ||
#include "paddle/framework/tensor.h" | ||
#include "paddle/platform/device_context.h" | ||
#include "paddle/platform/place.h" | ||
|
@@ -56,6 +57,8 @@ class OperatorBase; | |
class InferShapeContext; | ||
class ExecutionContext; | ||
|
||
extern const Tensor* GetTensorFromVar(const Variable* var); | ||
|
||
/** | ||
* OperatorBase has the basic element that Net will call to do computation. | ||
* Only CreateOperator from OpRegistry will new Operator directly. User | ||
|
@@ -262,15 +265,6 @@ class InferShapeContext { | |
return res; | ||
} | ||
|
||
const Tensor* GetTensorFromVar(const Variable* var) const { | ||
if (var->IsType<LoDTensor>()) { | ||
return &var->Get<LoDTensor>(); | ||
} | ||
PADDLE_ENFORCE(var->IsType<Tensor>(), | ||
"The Input(%s) must be LoDTensor or Tensor."); | ||
return &var->Get<Tensor>(); | ||
} | ||
|
||
private: | ||
const OperatorBase& op_; | ||
const Scope& scope_; | ||
|
@@ -347,6 +341,154 @@ template <> | |
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>( | ||
const std::string& name) const; | ||
|
||
class BlockDesc { | ||
public: | ||
explicit BlockDesc(const std::map<std::string, VarDesc*>& var_descs) | ||
: var_descs_(var_descs) {} | ||
~BlockDesc() {} | ||
|
||
VarDesc* GetVar(const std::string& name) const { | ||
PADDLE_ENFORCE(var_descs_.count(name) == 1, "%s must be in Block", name); | ||
return var_descs_.at(name); | ||
} | ||
|
||
bool HasVar(const std::string& name) const { | ||
return var_descs_.count(name) > 0; | ||
} | ||
|
||
private: | ||
const std::map<std::string, VarDesc*>& var_descs_; | ||
}; | ||
|
||
class CompileTimeInferShapeContext : public InferShapeContextBase { | ||
public: | ||
CompileTimeInferShapeContext(const OperatorBase& op, | ||
const BlockDesc& block_desc) | ||
: op_(op), block_desc_(block_desc) {} | ||
|
||
bool HasInput(const std::string& name) const { | ||
return block_desc_.HasVar(name); | ||
} | ||
|
||
bool HasOutput(const std::string& name) const { | ||
return block_desc_.HasVar(name); | ||
} | ||
|
||
DDim GetInputDim(const std::string& name) const { | ||
return GetDim(op_.Input(name)); | ||
} | ||
|
||
void SetInputDim(const std::string& name, const DDim& dim) const { | ||
SetDim(op_.Input(name), dim); | ||
} | ||
|
||
DDim GetOutputDim(const std::string& name) const { | ||
return GetDim(op_.Output(name)); | ||
} | ||
|
||
void SetOutputDim(const std::string& name, const DDim& dim) const { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Set methods should change private data members, so they cannot be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
SetDim(op_.Output(name), dim); | ||
} | ||
|
||
AttrReader Attrs() const { return AttrReader(op_.Attrs()); } | ||
|
||
const std::vector<std::string>& Inputs(const std::string& name) const { | ||
return op_.Inputs(name); | ||
} | ||
|
||
const std::vector<std::string>& Outputs(const std::string& name) const { | ||
return op_.Outputs(name); | ||
} | ||
|
||
private: | ||
DDim GetDim(const std::string& name) const { | ||
VarDesc* desc = block_desc_.GetVar(name); | ||
std::vector<int64_t> dim; | ||
int length = desc->lod_tensor().dims().size(); | ||
dim.reserve(length); | ||
std::copy(desc->lod_tensor().dims().begin(), | ||
desc->lod_tensor().dims().end(), std::back_inserter(dim)); | ||
return make_ddim(dim); | ||
} | ||
|
||
void SetDim(const std::string& name, const DDim& dim) const { | ||
VarDesc* desc = block_desc_.GetVar(name); | ||
auto tensor = desc->mutable_lod_tensor(); | ||
tensor->clear_dims(); | ||
for (int i = 0; i < dim.size(); ++i) { | ||
tensor->add_dims(static_cast<int>(dim[i])); | ||
} | ||
} | ||
|
||
const OperatorBase& op_; | ||
const BlockDesc& block_desc_; | ||
}; | ||
|
||
class RunTimeInferShapeContext : public InferShapeContextBase { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. RunTime-> Runtime runtime is a word There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
public: | ||
RunTimeInferShapeContext(const OperatorBase& op, const Scope& scope) | ||
: op_(op), scope_(scope) {} | ||
|
||
bool HasInput(const std::string& name) const { | ||
auto ipt = op_.Input(name); | ||
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); | ||
return var != nullptr; | ||
} | ||
|
||
bool HasOutput(const std::string& name) const { | ||
auto ipt = op_.Output(name); | ||
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); | ||
return var != nullptr; | ||
} | ||
|
||
DDim GetInputDim(const std::string& name) const { | ||
return GetDim(op_.Input(name)); | ||
} | ||
|
||
void SetInputDim(const std::string& name, const DDim& dim) const { | ||
SetDim(op_.Input(name), dim); | ||
} | ||
|
||
DDim GetOutputDim(const std::string& name) const { | ||
return GetDim(op_.Output(name)); | ||
} | ||
|
||
void SetOutputDim(const std::string& name, const DDim& dim) const { | ||
SetDim(op_.Output(name), dim); | ||
} | ||
|
||
AttrReader Attrs() const { return AttrReader(op_.Attrs()); } | ||
|
||
const std::vector<std::string>& Inputs(const std::string& name) const { | ||
return op_.Inputs(name); | ||
} | ||
|
||
const std::vector<std::string>& Outputs(const std::string& name) const { | ||
return op_.Outputs(name); | ||
} | ||
|
||
private: | ||
Tensor* GetTensor(const std::string& name) const { | ||
Tensor* t; | ||
auto* var = scope_.FindVar(name); | ||
if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) { | ||
t = var->GetMutable<LoDTensor>(); | ||
} else { | ||
t = const_cast<Tensor*>(GetTensorFromVar(scope_.FindVar(name))); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. const_cast should not be used. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removed |
||
} | ||
return t; | ||
} | ||
|
||
DDim GetDim(const std::string& name) const { return GetTensor(name)->dims(); } | ||
|
||
void SetDim(const std::string& name, const DDim& dim) const { | ||
GetTensor(name)->Resize(dim); | ||
} | ||
|
||
const OperatorBase& op_; | ||
const Scope& scope_; | ||
}; | ||
|
||
class OpKernel { | ||
public: | ||
/** | ||
|
@@ -390,8 +532,14 @@ class OperatorWithKernel : public OperatorBase { | |
const VariableNameMap& outputs, const AttributeMap& attrs) | ||
: OperatorBase(type, inputs, outputs, attrs) {} | ||
|
||
// runtime infershape | ||
void InferShape(const Scope& scope) const override { | ||
InferShape(InferShapeContext(*this, scope)); | ||
InferShape(RunTimeInferShapeContext(*this, scope)); | ||
} | ||
|
||
// compile time infershape | ||
void InferShape(const BlockDesc& block_desc) const { | ||
InferShape(CompileTimeInferShapeContext(*this, block_desc)); | ||
} | ||
|
||
void Run(const Scope& scope, | ||
|
@@ -413,7 +561,8 @@ class OperatorWithKernel : public OperatorBase { | |
} | ||
|
||
protected: | ||
virtual void InferShape(const InferShapeContext& ctx) const = 0; | ||
virtual void InferShape(const InferShapeContext& ctx) const {} | ||
virtual void InferShape(const InferShapeContextBase& ctx) const {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why there are two interfaces? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The old one is for compiling and test, has been removed after transform all the old operators. |
||
}; | ||
|
||
} // namespace framework | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It seems that the purpose of this class is to add a method
Get
, who can find and read an entry in AttributeMap.How about we change the definition of AttributeMap from the current one
typedef std::unordered_map<std::string, Attribute> AttributeMap;
into
so could we define
Get
as a method of typeAttributeMap
; instead of adding a new typeAttrReader
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the problem is if we use a Class that inherit from map, then we can't directly use the list_initialization to init this AttributeMap,
like scale_op:
try to find a better way to do this in next PR