Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FC + elementwise_add (Residual connection) #40834

Closed
wants to merge 50 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
cb0bcbf
Change tensor name to match activation
Silv3S Feb 25, 2022
3e96cf3
Merge branch 'PaddlePaddle:develop' into residual
Silv3S Feb 28, 2022
9e4cbaa
declare fc_eltwise_add pass
Silv3S Feb 28, 2022
04f376c
merge conv_eltwise refactor PR
Silv3S Mar 1, 2022
e12f39e
first compilable draft
Silv3S Mar 1, 2022
6c0b1b1
Merge branch 'PaddlePaddle:develop' into residual
Silv3S Mar 1, 2022
5616fd0
unittest feedback tools
Silv3S Mar 3, 2022
df721dd
Fuse pass tester
Silv3S Mar 4, 2022
a3a7e73
Move IsReachable() to shared file
Silv3S Mar 7, 2022
dbf13d6
100% coverage of fuse_pass_tester.cc
Silv3S Mar 8, 2022
da2486e
register pass
Silv3S Mar 9, 2022
d654065
Add bias node
Silv3S Mar 10, 2022
a825073
Improve unit tests / remove bias node from pattern
Silv3S Mar 10, 2022
2cfdf8f
Merge branch 'develop' into residual
Silv3S Mar 11, 2022
6085296
improve fc_eltwiseadd_unittest
Silv3S Mar 11, 2022
9752b48
cancel eltwise_add fuse if act is already fused
Silv3S Mar 14, 2022
d4334a2
Add elementwise_input scale
Silv3S Mar 14, 2022
62bf136
Residual MVP
Silv3S Mar 16, 2022
3c30373
Merge branch 'PaddlePaddle:develop' into residual
Silv3S Mar 16, 2022
960ce54
Add new FC attrs
Silv3S Mar 16, 2022
7c25aea
Add more test cases
Silv3S Mar 17, 2022
829a50a
Add missing op attrs
Silv3S Mar 21, 2022
dbd80b0
Merge branch 'PaddlePaddle:develop' into residual
Silv3S Mar 21, 2022
0673cfe
Adapt code to new Elementwise pattern
Silv3S Mar 21, 2022
c039ba3
reuse existing fcpattern
Silv3S Mar 21, 2022
5d9a8d5
improve code style
Silv3S Mar 23, 2022
f88c5a6
remove unused arguments
Silv3S Mar 23, 2022
eacfbce
fix typo
Silv3S Mar 23, 2022
c10c603
remove whitespace
Silv3S Mar 23, 2022
33fd226
remove int8 related code
Silv3S Mar 25, 2022
c9c0415
Remove attributes from base ops
Silv3S Mar 28, 2022
da2ecf2
style
Silv3S Mar 28, 2022
f3bc7fd
style check
Silv3S Mar 28, 2022
22a7ae7
Merge branch 'PaddlePaddle:develop' into residual
Silv3S Mar 28, 2022
8b074b3
Remove input from base op
Silv3S Mar 28, 2022
4e9931f
Set attribute during fuse
Silv3S Mar 29, 2022
fb92131
ut timeout
Silv3S Mar 29, 2022
4bd6e57
download and test model
Silv3S Mar 31, 2022
d67f551
DRY
Silv3S Apr 1, 2022
12a7068
Merge branch 'develop' into residual
Silv3S Apr 4, 2022
7fd091f
apply feedback from review
Silv3S Apr 4, 2022
2b16224
Style check
Silv3S Apr 4, 2022
902da8a
fix typo
Silv3S Apr 4, 2022
0313fed
cosmetic changes
Silv3S Apr 5, 2022
cbe267f
explicitly set residual as output
Silv3S Apr 8, 2022
d79515a
Merge branch 'PaddlePaddle:develop' into residual
Silv3S Apr 8, 2022
acc5db2
VIT-OCR accuracy check
Silv3S Apr 13, 2022
5b4625c
Merge branch 'PaddlePaddle:develop' into residual
Silv3S Apr 13, 2022
be67374
trigger CI
Silv3S Apr 13, 2022
bd4f21d
Merge branch 'residual' of https://github.com/Silv3S/Paddle into resi…
Silv3S Apr 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ if(WITH_MKLDNN)
pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(fc_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(scale_matmul_fuse_pass inference DIR mkldnn)
pass_library(cpu_bfloat16_placement_pass inference DIR mkldnn)
pass_library(cpu_bfloat16_pass inference DIR mkldnn)
Expand Down Expand Up @@ -207,6 +208,7 @@ if (WITH_MKLDNN)
cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass pass_test_util)
cc_test(test_fc_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS fc_elementwise_add_mkldnn_fuse_pass pass_test_util)
cc_test(test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass pass_test_util)
cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass pass_test_util)
set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context eigen_function)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/fc_elementwise_add_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_traits.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

// Declares the input/output/attribute constraints that candidate fc and
// elementwise_add ops must satisfy; they are enforced later through the
// IsCompat() check inside FuseFC before any graph mutation happens.
FCResidualConnectionMKLDNNFusePass::FCResidualConnectionMKLDNNFusePass() {
// fc must be the bias-carrying three-input form with a single tensor output.
AddOpCompat(OpCompat("fc"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
// in_num_col_dims flattens the input to 2-D; any value >= 1 is accepted.
.AddAttr("in_num_col_dims")
.IsNumGE(1)
.End();

// elementwise_add is only fused for the simple broadcast axes {-1, 0, 1}.
AddOpCompat(OpCompat("elementwise_add"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("axis")
.IsIntIn({-1, 0, 1})
.End();
}

// Detects an fc -> elementwise_add subgraph (the add realizes a residual
// connection) and folds the addition into the fc op itself.  `fc_as_x`
// selects whether the fc output is matched as the X or the Y operand of
// elementwise_add; ApplyImpl invokes this once per orientation.
// Returns the graph paired with the updated running fusion count.
GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC(
const std::string& name_scope, const GraphWithStats& graph_with_stats,
bool fc_as_x) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
// fc matched in its bias-carrying form: Input/W/Bias -> fc -> output.
patterns::FCMKLDNN fc_pattern{pattern, name_scope};
bool fc_has_bias = true;
auto fc_output = fc_pattern(
gpd.mutable_pattern()->NewNode("fc")->AsInput()->assert_is_op_input(
"fc", "Input"),
fc_has_bias);

// elementwise_add taking the fc output on one side and the residual tensor
// on the other; which side is which is controlled by fc_as_x.
patterns::ResidualElementwise elementwise_pattern{pattern, name_scope,
fc_as_x};
elementwise_pattern(
fc_output, pattern->NewNode(elementwise_pattern.residual_data_repr()),
"elementwise_add", fc_as_x);
// The fc output is absorbed by the fused op, so the pattern only matches
// when no other op consumes it.
fc_output->AsIntermediate();

int found_fc_count = 0;

auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_IR_NODE_FROM_SUBGRAPH(fc_op, fc, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_input, input, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_weights, weights, fc_pattern);
GET_IR_NODE_FROM_SUBGRAPH(fc_output, output, fc_pattern);

GET_IR_NODE_FROM_SUBGRAPH(elementwise_op, elementwise_op,
elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(residual_data, residual_data,
elementwise_pattern);
GET_IR_NODE_FROM_SUBGRAPH(elementwise_out, elementwise_out,
elementwise_pattern);

// Both ops must opt in to oneDNN fusion for this pair.
if (FindFuseOption(*fc_op, *elementwise_op) != FUSE_MKLDNN) return;
// NOTE(review): requires a path between residual_data and fc_output —
// presumably guards against fusing disconnected subgraphs / creating a
// cycle; confirm against IsReachable()'s contract.
if (!IsReachable(g, residual_data, fc_output)) return;
// An fc that already has a fused activation cannot also take the
// residual post-op; skip it.
if (HasFusedActivation(fc_op)) return;

if (!IsCompat(subgraph, g)) {
LOG(WARNING)
<< "op compat for fc_elementwise_add_mkldnn_fuse_pass failed.";
return;
}

// Rewrite the fc op: the residual tensor is registered under the
// "ResidualData" slot (deliberately as an output — see the PR history),
// fc writes directly to the add's output, and the kernel is told to
// apply the residual connection.
fc_op->Op()->SetOutput("ResidualData", {residual_data->Name()});
fc_op->Op()->SetOutput("Out", {elementwise_out->Name()});
fc_op->Op()->SetAttr("fuse_residual_connection", true);

// The intermediate fc output and the add op are now dead.
GraphSafeRemoveNodes(g, {fc_output, elementwise_op});

// Wire the residual tensor into fc and fc into the former add output.
IR_NODE_LINK_TO(residual_data, fc_op);
IR_NODE_LINK_TO(fc_op, elementwise_out);

found_fc_count++;
};

gpd(graph_with_stats.first, handler);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
std::string fusionMode = fc_as_x ? "x" : "y";
msg_ss << "--- Fused " << found_fc_count << " fc (as " << fusionMode
<< ") + elementwise_add patterns";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}

// Accumulate this sweep's count onto the running total.
return std::make_pair(graph_with_stats.first,
found_fc_count + graph_with_stats.second);
}

// Pass entry point: run the residual fusion twice so that the fc output is
// tried both as the X and as the Y operand of elementwise_add, then record
// the total number of fused patterns in the pass statistics.
void FCResidualConnectionMKLDNNFusePass::ApplyImpl(ir::Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);
  auto stats = FuseFC(name_scope_, std::make_pair(graph, 0), /*fc_as_x=*/true);
  stats = FuseFC(name_scope_, stats, /*fc_as_x=*/false);
  AddStatis(stats.second);
}
} // namespace ir
} // namespace framework
} // namespace paddle

// Register the pass under its lookup name and pin it to the op definitions
// it was written against (fc <= version 0, elementwise_add <= version 1).
REGISTER_PASS(fc_elementwise_add_mkldnn_fuse_pass,
paddle::framework::ir::FCResidualConnectionMKLDNNFusePass);
REGISTER_PASS_CAPABILITY(fc_elementwise_add_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("fc", 0)
.LE("elementwise_add", 1));
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

using GraphWithStats = std::pair<ir::Graph*, int>;

class FCResidualConnectionMKLDNNFusePass : public FusePassBase {
private:
GraphWithStats FuseFC(const std::string& name_scope,
const GraphWithStats& graph_with_stats,
bool fc_as_x) const;

public:
FCResidualConnectionMKLDNNFusePass();
virtual ~FCResidualConnectionMKLDNNFusePass() {}

protected:
void ApplyImpl(ir::Graph* graph) const;

static bool HasFusedActivation(Node* fc_node) {
return !(
fc_node->Op()->GetAttrIfExists<std::string>("activation_type").empty());
}

const std::string name_scope_{"fc_elementwise_add_mkldnn_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
Loading