Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Paddle Inference] add constant folding pass #45494

Merged
merged 53 commits into from
Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
d43724e
first commit
zhoutianzi666 Jul 24, 2022
be08d0f
commit
zhoutianzi666 Jul 25, 2022
f20477b
commit
zhoutianzi666 Jul 25, 2022
a6e7297
commit
zhoutianzi666 Jul 27, 2022
ba3b70b
commit
zhoutianzi666 Jul 28, 2022
695c947
commit
zhoutianzi666 Jul 29, 2022
cd594be
commit
zhoutianzi666 Jul 29, 2022
4b5c900
commit
zhoutianzi666 Aug 1, 2022
ab0b526
commit
zhoutianzi666 Aug 1, 2022
11481bf
commit
zhoutianzi666 Aug 2, 2022
f897c8e
commit
zhoutianzi666 Aug 2, 2022
473d2ac
commit
zhoutianzi666 Aug 3, 2022
6b8dc57
support fill_constant
zhoutianzi666 Aug 8, 2022
c2b1dd2
support fill_constant
zhoutianzi666 Aug 8, 2022
36ab05a
support fill_constant
zhoutianzi666 Aug 8, 2022
9c01f94
support fill_constant
zhoutianzi666 Aug 8, 2022
53e9a87
add variable in node.h
zhoutianzi666 Aug 9, 2022
dbb6595
more common folding
zhoutianzi666 Aug 9, 2022
041321e
remove debug code
zhoutianzi666 Aug 9, 2022
ed8d4a9
Merge branch 'develop' into constant_folding
zhoutianzi666 Aug 9, 2022
5b63ea4
remove debug code
zhoutianzi666 Aug 9, 2022
6819c24
only unsqueezze2 folding
zhoutianzi666 Aug 17, 2022
e7b9859
Merge branch 'develop' into constant_folding
zhoutianzi666 Aug 17, 2022
b423ef7
merge develop
zhoutianzi666 Aug 22, 2022
858f0d3
merge develop
zhoutianzi666 Aug 22, 2022
3f52b8c
merge develop
zhoutianzi666 Aug 22, 2022
9a8c34a
Merge branch 'develop' into constant_folding
zhoutianzi666 Aug 23, 2022
ad8a90d
clean code
zhoutianzi666 Aug 24, 2022
5b07d1d
Merge branch 'constant_folding' of https://github.com/zhoutianzi666/P…
zhoutianzi666 Aug 24, 2022
26398a4
clean code
zhoutianzi666 Aug 24, 2022
a7d8ffe
clean code
zhoutianzi666 Aug 24, 2022
b9c3935
clean code
zhoutianzi666 Aug 24, 2022
5272e3f
clean code
zhoutianzi666 Aug 24, 2022
abe7009
clean code
zhoutianzi666 Aug 25, 2022
9baaf5c
clean code
zhoutianzi666 Aug 25, 2022
42a5485
clean code
zhoutianzi666 Aug 25, 2022
adcd21b
clean code
zhoutianzi666 Aug 25, 2022
7540b9d
clean code
zhoutianzi666 Aug 25, 2022
8efa96e
clean code
zhoutianzi666 Aug 25, 2022
a3dc3d1
clean code
zhoutianzi666 Aug 25, 2022
fc03d5c
clean code
zhoutianzi666 Aug 25, 2022
3383a75
clean code
zhoutianzi666 Aug 25, 2022
e2585a5
clean code
zhoutianzi666 Aug 25, 2022
72cbfc7
clean code
zhoutianzi666 Aug 25, 2022
7204b4c
clean code
zhoutianzi666 Aug 26, 2022
2da81dc
clean code
zhoutianzi666 Aug 26, 2022
d73c0cd
clean code
zhoutianzi666 Aug 26, 2022
317f31b
clean code
zhoutianzi666 Aug 28, 2022
9f6b47b
clean code
zhoutianzi666 Aug 28, 2022
43de7e8
clean code
zhoutianzi666 Aug 28, 2022
1758ed7
clean code
zhoutianzi666 Aug 28, 2022
20b3695
clean code
zhoutianzi666 Aug 28, 2022
8228590
clean code
zhoutianzi666 Aug 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ pass_library(delete_dropout_op_pass inference)
pass_library(delete_c_identity_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference)
pass_library(delete_fill_constant_op_pass inference)
pass_library(constant_folding_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base)
Expand Down
159 changes: 159 additions & 0 deletions paddle/fluid/framework/ir/constant_folding_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/constant_folding_pass.h"

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

#include "glog/logging.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
namespace ir {
class Node;
} // namespace ir
} // namespace framework
} // namespace paddle

/*
 * When an op's inputs and outputs are fully determined before any data is fed
 * to the model, the op can be removed from the model. The ConstantFolding pass
 * removes all such ops.
 */

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

// Pattern holder for the constant folding pass. It adds no pattern nodes of
// its own; it only forwards its arguments to PatternBase together with the
// fixed identifier "constant_folding_pass".
struct ConstantFolding : PatternBase {
  ConstantFolding(PDPattern *pattern, const std::string &name_scope)
      : PatternBase(pattern, name_scope, "constant_folding_pass") {}
};
} // namespace patterns

// The pass needs no construction-time setup.
ConstantFoldingPass::ConstantFoldingPass() = default;

// Folds every op whose inputs are all persistable variables.
//
// For each candidate op (visited in the order returned by
// TopologyVarientSort) the pass:
//   1. copies the op's input tensors from the parameter scope into a
//      temporary scope,
//   2. runs the op once on CPU inside that temporary scope,
//   3. stores every still-consumed output tensor in the parameter scope and
//      marks its VarDesc persistable with the computed shape,
//   4. removes the op node, its exclusively-owned input nodes and its
//      unconsumed output nodes from the graph.
//
// An op is skipped when it is blacklisted, when any input is not persistable,
// or when any of its input/output names is carried by more than one variable
// node in the graph.
//
// NOTE(review): an op without any inputs passes the "all inputs persistable"
// test vacuously and is therefore folded unless it is blacklisted - confirm
// that this is intended for every such op type.
//
// @param graph  graph to rewrite in place; must not be null.
void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  FusePassBase::Init("constant_folding", graph);
  auto *scope = param_scope();
  // Review feedback recorded at this point: add a nullptr check (done below).
  PADDLE_ENFORCE_NOT_NULL(
      scope,
      platform::errors::Fatal(
          "scope must not be null when applying constant folding."));

  // fill_constant is deliberately not folded for now (Paddle-TRT); feed is
  // excluded as well.
  const std::vector<std::string> blacklist{"fill_constant", "feed"};

  // NOTE(review): SortKind value 0 is presumably the plain topological order;
  // confirm against graph_helper.h.
  auto op_node_sorted = framework::ir::TopologyVarientSort(
      *graph, static_cast<framework::ir::SortKind>(0));
  for (auto *op_node : op_node_sorted) {
    if (!op_node->IsOp()) continue;
    if (std::find(blacklist.begin(), blacklist.end(), op_node->Name()) !=
        blacklist.end())
      continue;

    // Maps each input/output name of this op to the number of variable nodes
    // in the whole graph that carry the same name.
    std::map<std::string, int> name_count;
    bool foldable = true;

    for (auto *in_node : op_node->inputs) {
      // A variable node without a VarDesc (presumably a control-dependency
      // variable - TODO confirm) cannot be folded and must not be
      // dereferenced.
      if (in_node->Var() == nullptr || !in_node->Var()->Persistable()) {
        foldable = false;
        break;
      }
      name_count[in_node->Name()] = 0;
    }
    if (!foldable) continue;

    for (auto *out_node : op_node->outputs) {
      if (out_node->Var() == nullptr) {
        foldable = false;
        break;
      }
      name_count[out_node->Name()] = 0;
    }
    if (!foldable) continue;

    // Refuse to fold when another variable node shares a name with one of
    // this op's inputs/outputs. A single sweep over the graph is enough, and
    // it can stop at the first duplicate.
    for (auto *node : graph->Nodes()) {
      if (!node->IsVar()) continue;
      auto found = name_count.find(node->Name());
      if (found != name_count.end() && ++found->second > 1) {
        foldable = false;
        break;
      }
    }
    if (!foldable) continue;

    // Owned by a smart pointer so that the scope is released even when
    // CreateOp/Run or one of the enforce checks below throws.
    std::unique_ptr<framework::Scope> local_scope(new framework::Scope());
    std::unordered_set<const paddle::framework::ir::Node *> remove_nodes;

    for (auto *in_node : op_node->inputs) {
      const std::string &in_name = in_node->Name();
      auto *global_var = scope->FindVar(in_name);
      PADDLE_ENFORCE_NOT_NULL(
          global_var,
          platform::errors::PreconditionNotMet(
              "persistable input %s of op %s is not found in the scope.",
              in_name,
              op_node->Name()));
      auto *global_persis_x_tensor = global_var->GetMutable<LoDTensor>();
      auto *local_x_tensor =
          local_scope->Var(in_name)->GetMutable<LoDTensor>();
      local_x_tensor->Resize(global_persis_x_tensor->dims());
      *local_x_tensor = *global_persis_x_tensor;

      // This persistable input node is exclusive to the op, so it can be
      // removed together with it.
      if (in_node->outputs.size() == 1) remove_nodes.emplace(in_node);
    }

    std::unique_ptr<OperatorBase> op =
        paddle::framework::OpRegistry::CreateOp(*op_node->Op());
    remove_nodes.emplace(op_node);

    for (auto *out_node : op_node->outputs) {
      local_scope->Var(out_node->Var()->Name())->GetMutable<LoDTensor>();
      // An output nobody consumes is useless: remove it instead of making it
      // persistable.
      if (out_node->outputs.empty()) remove_nodes.emplace(out_node);
    }

    op->Run(*local_scope, platform::CPUPlace());

    for (auto *out_node : op_node->outputs) {
      // Unconsumed outputs were scheduled for removal above.
      if (out_node->outputs.empty()) continue;

      auto *out_desc = out_node->Var();
      const auto out_name = out_desc->Name();
      auto *local_out_tensor =
          local_scope->FindVar(out_name)->GetMutable<LoDTensor>();

      // Record the shape the op actually produced in the VarDesc.
      const auto &out_dims = local_out_tensor->dims();
      std::vector<int64_t> out_shape;
      out_shape.reserve(out_dims.size());
      for (int64_t i = 0; i < out_dims.size(); i++) {
        out_shape.push_back(out_dims[i]);
      }
      out_desc->SetShape(out_shape);
      out_desc->SetPersistable(true);

      auto *global_out_tensor = scope->Var(out_name)->GetMutable<LoDTensor>();
      *global_out_tensor = *local_out_tensor;
    }

    GraphSafeRemoveNodes(graph, remove_nodes);
  }
}

} // namespace ir
} // namespace framework
} // namespace paddle

// Registers ConstantFoldingPass under the name "constant_folding_pass", the
// same string the pass strategies and DeletePass() calls refer to.
REGISTER_PASS(constant_folding_pass,
paddle::framework::ir::ConstantFoldingPass);
37 changes: 37 additions & 0 deletions paddle/fluid/framework/ir/constant_folding_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"

namespace paddle {

namespace framework {
namespace ir {

class Graph;

// Graph pass that pre-computes ops whose inputs are all persistable and
// removes them from the graph (see constant_folding_pass.cc).
class ConstantFoldingPass : public FusePassBase {
 public:
  ConstantFoldingPass();
  virtual ~ConstantFoldingPass() = default;

 protected:
  // Rewrites `graph` in place.
  void ApplyImpl(ir::Graph* graph) const override;
};

} // namespace ir
} // namespace framework
} // namespace paddle
7 changes: 5 additions & 2 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ const std::vector<std::string> kTRTSubgraphPasses({
// "yolo_box_fuse_pass", //
"dense_fc_to_sparse_pass", //
"dense_multihead_matmul_to_sparse_pass", //
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
"constant_folding_pass",
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
// guaranteed at least v7
// cudnn8.0 has memory leak problem in conv + eltwise + act, so we
Expand Down Expand Up @@ -213,6 +214,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add_fuse_pass", //
#endif //
"transpose_flatten_concat_fuse_pass", //
"constant_folding_pass",
// following pass should be located in the last, since it will
// work on all fused ops.
"runtime_context_cache_pass"
Expand Down Expand Up @@ -276,6 +278,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
"conv_transpose_bn_fuse_pass", //
"conv_transpose_eltwiseadd_bn_fuse_pass", //
"is_test_pass", //
"constant_folding_pass",
// following pass should be located in the last, since
// it will work on all fused ops.
"runtime_context_cache_pass"});
Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/inference/tests/api/analyzer_dam_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,9 +169,16 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots,
input_slots->push_back(std::move(response_mask_tensor));
}

/*
 * This model marks an output tensor as persistable, which is unusual, so
 * constant_folding_pass is disabled for this test.
 */

// Prepares the analysis config used by the DAM tester: points it at the
// combined model/param files under FLAGS_infer_model, switches to named
// inputs, removes constant_folding_pass from the pass list and then enables
// IR optimization.
void SetConfig(AnalysisConfig *cfg) {
  const auto model_file = FLAGS_infer_model + "/__model__";
  const auto params_file = FLAGS_infer_model + "/param";
  cfg->SetModel(model_file, params_file);
  cfg->SwitchSpecifyInputNames();
  // The pass is deleted before IR optimization is switched on, matching the
  // original statement order.
  cfg->pass_builder()->DeletePass("constant_folding_pass");
  cfg->SwitchIrOptim(true);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
namespace paddle {
namespace inference {

/*
 * This model marks an intermediate tensor as persistable, which is unusual,
 * so constant_folding_pass is disabled for this test.
 */

using paddle::PaddleTensor;

#ifdef PADDLE_WITH_MKLDNN
Expand All @@ -25,6 +30,8 @@ void SetInt8Config(AnalysisConfig *cfg,
cfg->SetModel(FLAGS_infer_model);
cfg->EnableMKLDNN();
cfg->EnableMkldnnQuantizer();
auto pass_builder = cfg->pass_builder();
pass_builder->DeletePass("constant_folding_pass");
auto warmup_data = std::make_shared<std::vector<PaddleTensor>>(data);
cfg->mkldnn_quantizer_config()->SetWarmupData(warmup_data);
cfg->mkldnn_quantizer_config()->SetWarmupBatchSize(FLAGS_batch_size);
Expand Down
20 changes: 16 additions & 4 deletions paddle/fluid/inference/tests/api/analyzer_ernie_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@
namespace paddle {
namespace inference {

/*
 * This model marks an intermediate tensor as persistable, which is unusual,
 * so constant_folding_pass is disabled for these tests.
 */

using paddle::PaddleTensor;

void profile(bool use_mkldnn = false, bool use_gpu = false) {
AnalysisConfig config;

SetConfig(&config, use_mkldnn, use_gpu);

auto pass_builder = config.pass_builder();
pass_builder->DeletePass("constant_folding_pass");
std::vector<std::vector<PaddleTensor>> outputs;
std::vector<std::vector<PaddleTensor>> inputs;
LoadInputData(&inputs);
Expand All @@ -48,6 +54,9 @@ TEST(Analyzer_Ernie, fuse_statis) {
AnalysisConfig cfg;
SetConfig(&cfg);

auto pass_builder = cfg.pass_builder();
pass_builder->DeletePass("constant_folding_pass");

int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_statis = GetFuseStatis(
Expand All @@ -70,7 +79,8 @@ void compare(bool use_mkldnn = false) {

AnalysisConfig cfg;
SetConfig(&cfg, use_mkldnn, false);

auto pass_builder = cfg.pass_builder();
pass_builder->DeletePass("constant_folding_pass");
CompareNativeAndAnalysis(
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), inputs);
}
Expand All @@ -84,7 +94,8 @@ TEST(Analyzer_ernie, compare_mkldnn) { compare(true /* use_mkldnn */); }
TEST(Analyzer_Ernie, compare_determine) {
AnalysisConfig cfg;
SetConfig(&cfg);

auto pass_builder = cfg.pass_builder();
pass_builder->DeletePass("constant_folding_pass");
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);
CompareDeterministic(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
Expand All @@ -95,7 +106,8 @@ TEST(Analyzer_Ernie, compare_determine) {
TEST(Analyzer_Ernie, compare_results) {
AnalysisConfig cfg;
SetConfig(&cfg);

auto pass_builder = cfg.pass_builder();
pass_builder->DeletePass("constant_folding_pass");
std::vector<std::vector<PaddleTensor>> input_slots_all;
LoadInputData(&input_slots_all);

Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/inference/tests/api/analyzer_save_model_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,19 @@ int GetNumOps(const AnalysisConfig &cfg) {
return num_ops;
}

/*
 * This model marks an output tensor as persistable, which is unusual, so
 * constant_folding_pass is disabled for this test.
 */

TEST(Analyzer, save_model) {
AnalysisConfig cfg;
SetConfig(&cfg);
cfg.SetModel(FLAGS_infer_model + "/__model__", FLAGS_infer_model + "/param");

auto pass_builder = cfg.pass_builder();
pass_builder->DeletePass("constant_folding_pass");

// ensure the path being unique
std::string optimModelPath = FLAGS_infer_model + "/only_for_save_model_test";
MKDIR(optimModelPath.c_str());
Expand All @@ -49,6 +58,8 @@ TEST(Analyzer, save_model) {

AnalysisConfig cfg3;
SetConfig(&cfg3);
auto pass_builder3 = cfg3.pass_builder();
pass_builder3->DeletePass("constant_folding_pass");
cfg3.SetModel(optimModelPath + "/model", optimModelPath + "/params");
int fused_num_ops = GetNumOps(cfg3);
CHECK_LE(fused_num_ops, origin_num_ops);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ TEST(Analyzer_seq_pool1_fuse_statis, fuse_statis) {
EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 0);
EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2);
LOG(INFO) << "num_ops: " << num_ops;
EXPECT_EQ(num_ops, 185);
EXPECT_EQ(num_ops, 183);
}

} // namespace seq_pool1_tester
Expand Down