Refactoring InferShape (#3946)
* init InferShape

* add static InferShape interface

* refactor add-op infershape

* add AttrReader

* add all maker's infershape

* add all InferShape

* add python infer api

* add VarDesc interface

* add python VarDesc and OpDesc interface

* update python code

* use infershape function to do shape inference

* clean code

* do not use pointer

* refine code of op_proto_maker

* add get_dims to VarDesc

* refine the code

* remove the dependency from operator to op registry

* remove OpProtoAndCheckerMaker from operator

* restore complete_add_op

* add shape_infer_impl.h

* code optimization

* remove const return value

* add fake BlockDesc class

* optimize code

* remove infer function in op_info

* move InferShapeContextImpl to operator.h

* optimize the interface of InferShapeContextBase

* add temporary interface of new infershape

* change add_op, clip_op, conv2d_op and activation_op

* change all operators InferShape

* fix SetDim

* update cos_sim_op

* update crop_op

* update lookup_table_op

* allocate tensor when call GetDim in InferShapeContext

* update modified_huber_loss_op

* update rowwise_add_op

* update mean_op

* update sequence_avg_pool_op

* typo

* remove old InferShape interface

* can compile

* fix or unit test

* clean code

* clean code

* remove const before InferShapeContext

* change InferenceContextBase to pointer

* rename RunTime to Runtime, code clean
jacquesqiao committed Sep 26, 2017
1 parent 8635103 commit 9a9d50a
Showing 46 changed files with 975 additions and 864 deletions.
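
At the heart of the refactor, operators no longer implement shape inference against a concrete const InferShapeContext& tied to the runtime Scope; they override a pure-virtual InferShape(InferShapeContextBase*) instead, and the framework supplies a context implementation (at runtime the new RuntimeInferShapeContext backed by the Scope, later a compile-time one backed by VarDesc/BlockDesc). A minimal sketch of the new operator-side pattern, using a hypothetical MyOp that is not part of this commit:

#include "paddle/framework/operator.h"  // OperatorWithKernel, InferShapeContextBase

// As with the operator files in this commit, this sketch assumes it lives
// inside namespace paddle::operators. MyOp simply copies the shape (and LoD)
// of input "X" to output "Out".
class MyOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContextBase* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MyOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of MyOp should not be null.");
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

The per-file diffs below show this same migration applied across the framework and the operators.
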
15 changes: 15 additions & 0 deletions paddle/framework/attribute.h
@@ -45,6 +45,21 @@ inline AttrType AttrTypeID() {

Attribute GetAttrValue(const OpDesc::Attr& attr_desc);

class AttrReader {
public:
explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {}

template <typename T>
inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name);
return boost::get<T>(attrs_.at(name));
}

private:
const AttributeMap& attrs_;
};

// check whether a value(attribute) fit a certain limit
template <typename T>
class GreaterThanChecker {
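
AttrReader gives shape-inference code read-only, type-checked access to an operator's attributes without handing out the whole AttributeMap. A rough usage sketch; the attribute names and types here are made up for illustration, and int and float are assumed to be alternatives of the Attribute variant:

#include "paddle/framework/attribute.h"

namespace paddle {

// Illustrative only: build an attribute map by hand, then read it back
// through AttrReader.
void AttrReaderSketch() {
  framework::AttributeMap attrs;
  attrs["axis"] = 1;      // hypothetical int attribute
  attrs["scale"] = 0.5f;  // hypothetical float attribute

  framework::AttrReader reader(attrs);
  int axis = reader.Get<int>("axis");
  float scale = reader.Get<float>("scale");
  // reader.Get<int>("missing") would trip the PADDLE_ENFORCE check above.
  (void)axis;
  (void)scale;
}

}  // namespace paddle

RuntimeInferShapeContext (in operator.h below) returns such a reader from Attrs(), so operator shape-inference code can write ctx->Attrs().Get<int>("axis").
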
2 changes: 1 addition & 1 deletion paddle/framework/op_registry_test.cc
@@ -174,4 +174,4 @@ TEST(OpRegistry, CustomChecker) {
op->Run(scope, dev_ctx);
int test_attr = op->Attr<int>("test_attr");
ASSERT_EQ(test_attr, 4);
}
}
19 changes: 18 additions & 1 deletion paddle/framework/operator.cc
@@ -14,7 +14,6 @@ limitations under the License. */

#include "paddle/framework/operator.h"
#include <algorithm>
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace framework {
@@ -33,6 +32,24 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
}
#endif

const Tensor* GetTensorFromVar(const Variable* var) {
if (var->IsType<LoDTensor>()) {
return &var->Get<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input must be LoDTensor or Tensor.");
return &var->Get<Tensor>();
}

Tensor* GetTensorFromVar(Variable* var) {
if (var->IsType<LoDTensor>()) {
return var->GetMutable<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input must be LoDTensor or Tensor.");
return var->GetMutable<Tensor>();
}

std::string OperatorBase::Input(const std::string& name) const {
auto& ins = Inputs(name);
PADDLE_ENFORCE_LE(ins.size(), 1UL,
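
GetTensorFromVar is promoted from a private helper of InferShapeContext (removed in operator.h below) to a pair of free functions, so both the execution path and the new RuntimeInferShapeContext can resolve a Variable to its underlying tensor whether it holds a Tensor or a LoDTensor. A small sketch; the variable names "X" and "Out" and the helper itself are illustrative, and both variables are assumed to already exist in the scope with one of those two types:

#include "paddle/framework/operator.h"  // GetTensorFromVar, Scope, Tensor

namespace paddle {

// Give "Out" the same shape as "X", without caring which concrete tensor
// type either variable holds.
void CopyShape(const framework::Scope& scope) {
  const framework::Tensor* in = framework::GetTensorFromVar(scope.FindVar("X"));
  framework::Tensor* out = framework::GetTensorFromVar(scope.FindVar("Out"));
  out->Resize(in->dims());
}

}  // namespace paddle
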
91 changes: 80 additions & 11 deletions paddle/framework/operator.h
@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
@@ -56,6 +57,9 @@ class OperatorBase;
class InferShapeContext;
class ExecutionContext;

extern const Tensor* GetTensorFromVar(const Variable* var);
extern Tensor* GetTensorFromVar(Variable* var);

/**
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
@@ -262,15 +266,6 @@ class InferShapeContext {
return res;
}

const Tensor* GetTensorFromVar(const Variable* var) const {
if (var->IsType<LoDTensor>()) {
return &var->Get<LoDTensor>();
}
PADDLE_ENFORCE(var->IsType<Tensor>(),
"The Input(%s) must be LoDTensor or Tensor.");
return &var->Get<Tensor>();
}

void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const {
PADDLE_ENFORCE_LT(i, InputSize(in));
@@ -340,6 +335,78 @@ class ExecutionContext : public InferShapeContext {
const platform::DeviceContext& device_context_;
};

class RuntimeInferShapeContext : public InferShapeContextBase {
public:
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}

bool HasInput(const std::string& name) const {
auto ipt = op_.Input(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}

bool HasOutput(const std::string& name) const {
auto ipt = op_.Output(name);
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
}

DDim GetInputDim(const std::string& name) const {
return GetDim(op_.Input(name));
}

void SetInputDim(const std::string& name, const DDim& dim) {
SetDim(op_.Input(name), dim);
}

DDim GetOutputDim(const std::string& name) const {
return GetDim(op_.Output(name));
}

void SetOutputDim(const std::string& name, const DDim& dim) {
SetDim(op_.Output(name), dim);
}

AttrReader Attrs() const { return AttrReader(op_.Attrs()); }

const std::vector<std::string>& Inputs(const std::string& name) const {
return op_.Inputs(name);
}

const std::vector<std::string>& Outputs(const std::string& name) const {
return op_.Outputs(name);
}

private:
template <bool Allocate>
Tensor* GetTensor(const std::string& name) const {
Tensor* t = nullptr;
auto* var = scope_.FindVar(name);
if (!var->IsType<LoDTensor>() && !var->IsType<Tensor>()) {
if (Allocate) {
t = var->GetMutable<LoDTensor>();
} else {
PADDLE_THROW("Variable(%s) should be tensor", name);
}
} else {
t = GetTensorFromVar(scope_.FindVar(name));
}
return t;
}

DDim GetDim(const std::string& name) const {
return GetTensor<false>(name)->dims();
}

void SetDim(const std::string& name, const DDim& dim) {
GetTensor<true>(name)->Resize(dim);
}

const OperatorBase& op_;
const Scope& scope_;
};

class OpKernel {
public:
/**
@@ -383,8 +450,10 @@ class OperatorWithKernel : public OperatorBase {
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}

// runtime infershape
void InferShape(const Scope& scope) const override {
InferShape(InferShapeContext(*this, scope));
auto c = RuntimeInferShapeContext(*this, scope);
InferShape(&c);
}

void Run(const Scope& scope,
@@ -406,7 +475,7 @@
}

protected:
virtual void InferShape(const InferShapeContext& ctx) const = 0;
virtual void InferShape(InferShapeContextBase* ctx) const = 0;
};

} // namespace framework
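
With this diff, OperatorWithKernel keeps the old public entry point InferShape(const Scope&) but turns it into a thin adapter: it wraps the scope in a RuntimeInferShapeContext and forwards to the operator's InferShape(InferShapeContextBase*) override. A sketch of the runtime call path for an already-constructed operator (the helper function itself is illustrative):

#include "paddle/framework/operator.h"

namespace paddle {

// Runtime shape inference: op.InferShape(scope) builds a
// RuntimeInferShapeContext(op, scope) internally and dispatches to the
// operator's InferShape(InferShapeContextBase*) override. Inside that
// context, SetDim() allocates a LoDTensor for a variable that has no tensor
// yet (GetTensor<true>), while GetDim() on such a variable throws.
void RuntimeShapeInference(const framework::OperatorBase& op,
                           const framework::Scope& scope) {
  op.InferShape(scope);
}

}  // namespace paddle
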
3 changes: 2 additions & 1 deletion paddle/framework/operator_test.cc
@@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/framework/operator.h"
#include "gtest/gtest.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
@@ -114,7 +115,7 @@ class OpWithKernelTest : public OperatorWithKernel {
using OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext& ctx) const override {}
void InferShape(framework::InferShapeContextBase* ctx) const override {}
};

template <typename T1, typename T2>
82 changes: 82 additions & 0 deletions paddle/framework/shape_inference.h
@@ -0,0 +1,82 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/ddim.h"

namespace paddle {
namespace framework {

class InferShapeContextBase {
public:
virtual ~InferShapeContextBase() {}
virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0;
virtual framework::DDim GetInputDim(const std::string &name) const = 0;
std::vector<framework::DDim> GetInputsDim(const std::string &name) const {
const std::vector<std::string> &names = Inputs(name);
return GetDims(names);
}
virtual void SetInputDim(const std::string &name,
const framework::DDim &dim) = 0;
void SetInputsDim(const std::string &name,
const std::vector<framework::DDim> &dims) {
auto &names = Inputs(name);
SetDims(names, dims);
}
virtual framework::DDim GetOutputDim(const std::string &name) const = 0;
std::vector<framework::DDim> GetOutputsDim(const std::string &name) const {
const std::vector<std::string> &names = Outputs(name);
return GetDims(names);
}
virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0;
void SetOutputsDim(const std::string &name,
const std::vector<framework::DDim> &dims) {
auto &names = Outputs(name);
SetDims(names, dims);
}
virtual AttrReader Attrs() const = 0;
virtual const std::vector<std::string> &Inputs(
const std::string &name) const = 0;
virtual const std::vector<std::string> &Outputs(
const std::string &name) const = 0;
// TODO(qiao) implement this function
void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
size_t j = 0) const {}

protected:
virtual framework::DDim GetDim(const std::string &name) const = 0;
virtual void SetDim(const std::string &name, const framework::DDim &dim) = 0;
std::vector<framework::DDim> GetDims(
const std::vector<std::string> &names) const {
std::vector<framework::DDim> ret;
ret.reserve(names.size());
std::transform(
names.begin(), names.end(), std::back_inserter(ret),
[this](const std::string &name) { return this->GetDim(name); });
return ret;
}
void SetDims(const std::vector<std::string> &names,
const std::vector<framework::DDim> &dims) {
size_t length = names.size();
PADDLE_ENFORCE_EQ(length, dims.size());
for (size_t i = 0; i < length; ++i) {
SetDim(names[i], dims[i]);
}
}
};

} // namespace framework
} // namespace paddle
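
Beyond the per-name virtuals, the base context provides GetInputsDim/SetInputsDim and GetOutputsDim/SetOutputsDim for duplicable slots, built on top of the protected GetDim/SetDim. A sketch of a hypothetical sum-like operator (not part of this commit) that relies on them:

#include <vector>
#include "paddle/framework/operator.h"

// As with the operator files in this commit, this sketch assumes it lives
// inside namespace paddle::operators. All tensors in the duplicable input
// slot "X" must share one shape, which becomes the shape of output "Out".
class SumLikeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContextBase* ctx) const override {
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of SumLikeOp should not be null.");
    std::vector<framework::DDim> in_dims = ctx->GetInputsDim("X");
    PADDLE_ENFORCE(!in_dims.empty(),
                   "Input(X) of SumLikeOp should not be empty.");
    for (size_t i = 1; i < in_dims.size(); ++i) {
      PADDLE_ENFORCE(in_dims[i] == in_dims[0],
                     "All tensors in Input(X) must have the same shape.");
    }
    ctx->SetOutputDim("Out", in_dims[0]);
  }
};
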
28 changes: 13 additions & 15 deletions paddle/operators/accuracy_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,23 @@ class AccuracyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(
ctx.InputVar("Inference"),
"Input(Inference) of AccuracyOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) of AccuracyOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(
ctx.OutputVar("Accuracy"),
"Output(Accuracy) of AccuracyOp should not be null.");
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Inference"),
"Input(Inference) of AccuracyOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"),
"Input(Label) of AccuracyOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Accuracy"),
"Output(Accuracy) of AccuracyOp should not be null.");

auto *inference = ctx.Input<framework::Tensor>("Inference");
auto *label = ctx.Input<framework::Tensor>("Label");
auto inference_dim = ctx->GetInputDim("Inference");
auto label_dim = ctx->GetInputDim("Label");

PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label must be a vector");
PADDLE_ENFORCE_EQ(inference->dims()[0], label->dims()[0],
PADDLE_ENFORCE_EQ(label_dim.size(), 1, "label must be a vector");
PADDLE_ENFORCE_EQ(inference_dim[0], label_dim[0],
"inference size must be the same as label size");

ctx.Output<framework::Tensor>("Accuracy")->Resize({1});
ctx.ShareLoD("Inference", /*->*/ "Accuracy");
ctx->SetOutputDim("Accuracy", {1});
ctx->ShareLoD("Inference", /*->*/ "Accuracy");
}
};

12 changes: 5 additions & 7 deletions paddle/operators/activation_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@ class ActivationOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::Tensor>("Y")->Resize(
ctx.Input<framework::Tensor>("X")->dims());
ctx.ShareLoD("X", /*->*/ "Y");
void InferShape(framework::InferShapeContextBase *ctx) const override {
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Y");
}
};

@@ -34,9 +33,8 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::Tensor>(framework::GradVarName("X"))
->Resize(ctx.Input<framework::Tensor>("Y")->dims());
void InferShape(framework::InferShapeContextBase *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y"));
}
};

(Diffs for the remaining changed files are not shown.)
