-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* init Infershape * add static InferShape interface * refactor add-op infershape * add AttrReader * add all maker's infershape * add all InferShape * add python infer api * add VarDesc interface * add python VarDesc and OpDesc interface * update python code * use infershape function to do shape inference * clean code * do not use pointer * refine code of op_proto_maker * add get_dims to VarDesc * refine the code * remove the dependency from operator to op registry * remove OpProtoAndCheckerMaker from operator * restore complete_add_op * add shape_infer_impl.h * code optimization * remove const return value * add fake BlockDesc class * optimize code * remove infer function in op_info * move InferShapeContextImpl to operator.h * optimize the interface of InferShapeContextBase * add temperary interface of new infershape * change add_op, clip_op, conv2d_op and activation_op * change all operators InferShape * fix SetDim * update cos_sim_op * update crop_op * update lookup_table_op * allocate tensor when call GetDim in InferShapeContext * update modified_huber_loss_op * update rowwise_add_op * update mean_op * update sequence_avg_pool_op * typo * remove old InferShape interface * can compile * fix or unit test * clean code * clean code * remove const before InferShapeContext * change InferenceContextBase to pointer * rename RunTime to Runtime, code clean
- Loading branch information
1 parent
8635103
commit 9a9d50a
Showing
46 changed files
with
975 additions
and
864 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
|
||
#include "paddle/framework/ddim.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
class InferShapeContextBase { | ||
public: | ||
virtual ~InferShapeContextBase() {} | ||
virtual bool HasInput(const std::string &name) const = 0; | ||
virtual bool HasOutput(const std::string &name) const = 0; | ||
virtual framework::DDim GetInputDim(const std::string &name) const = 0; | ||
std::vector<framework::DDim> GetInputsDim(const std::string &name) const { | ||
const std::vector<std::string> &names = Inputs(name); | ||
return GetDims(names); | ||
} | ||
virtual void SetInputDim(const std::string &name, | ||
const framework::DDim &dim) = 0; | ||
void SetInputsDim(const std::string &name, | ||
const std::vector<framework::DDim> &dims) { | ||
auto &names = Inputs(name); | ||
SetDims(names, dims); | ||
} | ||
virtual framework::DDim GetOutputDim(const std::string &name) const = 0; | ||
std::vector<framework::DDim> GetOutputsDim(const std::string &name) const { | ||
const std::vector<std::string> &names = Outputs(name); | ||
return GetDims(names); | ||
} | ||
virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0; | ||
void SetOutputsDim(const std::string &name, | ||
const std::vector<framework::DDim> &dims) { | ||
auto &names = Outputs(name); | ||
SetDims(names, dims); | ||
} | ||
virtual AttrReader Attrs() const = 0; | ||
virtual const std::vector<std::string> &Inputs( | ||
const std::string &name) const = 0; | ||
virtual const std::vector<std::string> &Outputs( | ||
const std::string &name) const = 0; | ||
// TODO(qiao) implement this function | ||
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0, | ||
size_t j = 0) const {} | ||
|
||
protected: | ||
virtual framework::DDim GetDim(const std::string &name) const = 0; | ||
virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0; | ||
std::vector<framework::DDim> GetDims( | ||
const std::vector<std::string> &names) const { | ||
std::vector<framework::DDim> ret; | ||
ret.reserve(names.size()); | ||
std::transform( | ||
names.begin(), names.end(), std::back_inserter(ret), | ||
[this](const std::string &name) { return this->GetDim(name); }); | ||
return ret; | ||
} | ||
void SetDims(const std::vector<std::string> &names, | ||
const std::vector<framework::DDim> &dims) { | ||
size_t length = names.size(); | ||
PADDLE_ENFORCE_EQ(length, dims.size()); | ||
for (size_t i = 0; i < length; ++i) { | ||
SetDim(names[i], dims[i]); | ||
} | ||
} | ||
}; | ||
|
||
} // namespace framework | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.