Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FC + elementwise_add (Residual connection) #40834

Closed
wants to merge 50 commits into from
Closed
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
cb0bcbf
Change tensor name to match activation
Silv3S Feb 25, 2022
3e96cf3
Merge branch 'PaddlePaddle:develop' into residual
Silv3S Feb 28, 2022
9e4cbaa
declare fc_eltwise_add pass
Silv3S Feb 28, 2022
04f376c
merge conv_eltwise refactor PR
Silv3S Mar 1, 2022
e12f39e
first compilable draft
Silv3S Mar 1, 2022
6c0b1b1
Merge branch 'PaddlePaddle:develop' into residual
Silv3S Mar 1, 2022
5616fd0
unittest feedback tools
Silv3S Mar 3, 2022
df721dd
Fuse pass tester
Silv3S Mar 4, 2022
a3a7e73
Move IsReachable() to shared file
Silv3S Mar 7, 2022
dbf13d6
100% coverage of fuse_pass_tester.cc
Silv3S Mar 8, 2022
da2486e
register pass
Silv3S Mar 9, 2022
d654065
Add bias node
Silv3S Mar 10, 2022
a825073
Improve unit tests / remove bias node from pattern
Silv3S Mar 10, 2022
2cfdf8f
Merge branch 'develop' into residual
Silv3S Mar 11, 2022
6085296
improve fc_eltwiseadd_unittest
Silv3S Mar 11, 2022
9752b48
cancel eltwise_add fuse if act is already fused
Silv3S Mar 14, 2022
d4334a2
Add elementwise_input scale
Silv3S Mar 14, 2022
62bf136
Residual MVP
Silv3S Mar 16, 2022
3c30373
Merge branch 'PaddlePaddle:develop' into residual
Silv3S Mar 16, 2022
960ce54
Add new FC attrs
Silv3S Mar 16, 2022
7c25aea
Add more test cases
Silv3S Mar 17, 2022
829a50a
Add missing op attrs
Silv3S Mar 21, 2022
dbd80b0
Merge branch 'PaddlePaddle:develop' into residual
Silv3S Mar 21, 2022
0673cfe
Adapt code to new Elementwise pattern
Silv3S Mar 21, 2022
c039ba3
reuse existing fcpattern
Silv3S Mar 21, 2022
5d9a8d5
improve code style
Silv3S Mar 23, 2022
f88c5a6
remove unused arguments
Silv3S Mar 23, 2022
eacfbce
fix typo
Silv3S Mar 23, 2022
c10c603
remove whitespace
Silv3S Mar 23, 2022
33fd226
remove int8 related code
Silv3S Mar 25, 2022
c9c0415
Remove attributes from base ops
Silv3S Mar 28, 2022
da2ecf2
style
Silv3S Mar 28, 2022
f3bc7fd
style check
Silv3S Mar 28, 2022
22a7ae7
Merge branch 'PaddlePaddle:develop' into residual
Silv3S Mar 28, 2022
8b074b3
Remove input from base op
Silv3S Mar 28, 2022
4e9931f
Set attribute during fuse
Silv3S Mar 29, 2022
fb92131
ut timeout
Silv3S Mar 29, 2022
4bd6e57
download and test model
Silv3S Mar 31, 2022
d67f551
DRY
Silv3S Apr 1, 2022
12a7068
Merge branch 'develop' into residual
Silv3S Apr 4, 2022
7fd091f
apply feedback from review
Silv3S Apr 4, 2022
2b16224
Style check
Silv3S Apr 4, 2022
902da8a
fix typo
Silv3S Apr 4, 2022
0313fed
cosmetic changes
Silv3S Apr 5, 2022
cbe267f
explicitly set residual as output
Silv3S Apr 8, 2022
d79515a
Merge branch 'PaddlePaddle:develop' into residual
Silv3S Apr 8, 2022
acc5db2
VIT-OCR accuracy check
Silv3S Apr 13, 2022
5b4625c
Merge branch 'PaddlePaddle:develop' into residual
Silv3S Apr 13, 2022
be67374
trigger CI
Silv3S Apr 13, 2022
bd4f21d
Merge branch 'residual' of https://github.com/Silv3S/Paddle into resi…
Silv3S Apr 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ if(WITH_MKLDNN)
pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_concat_relu_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(fc_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(scale_matmul_fuse_pass inference DIR mkldnn)
pass_library(cpu_bfloat16_placement_pass inference DIR mkldnn)
pass_library(cpu_bfloat16_pass inference DIR mkldnn)
Expand Down Expand Up @@ -203,6 +204,7 @@ if (WITH_MKLDNN)
cc_test(test_conv_activation_mkldnn_fuse_pass SRCS mkldnn/conv_activation_mkldnn_fuse_pass_tester.cc DEPS conv_activation_mkldnn_fuse_pass)
cc_test(test_conv_concat_relu_mkldnn_fuse_pass SRCS mkldnn/conv_concat_relu_mkldnn_fuse_pass_tester.cc DEPS conv_concat_relu_mkldnn_fuse_pass)
cc_test(test_conv_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/conv_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS conv_elementwise_add_mkldnn_fuse_pass pass_test_util)
cc_test(test_fc_elementwise_add_mkldnn_fuse_pass SRCS mkldnn/fc_elementwise_add_mkldnn_fuse_pass_tester.cc DEPS fc_elementwise_add_mkldnn_fuse_pass pass_test_util)
cc_test(test_fc_act_mkldnn_fuse_pass SRCS mkldnn/fc_act_mkldnn_fuse_pass_tester.cc DEPS fc_act_mkldnn_fuse_pass pass_test_util)
cc_test(test_batch_norm_act_fuse_pass SRCS mkldnn/batch_norm_act_fuse_pass_tester.cc DEPS batch_norm_act_fuse_pass pass_test_util)
set(TEST_CONV_BN_PASS_DEPS conv_bn_fuse_pass graph_to_program_pass conv_op conv_transpose_op math_function im2col vol2col batch_norm_op gelu_op activation_op elementwise_add_op concat_and_split naive_executor device_context eigen_function)
Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/framework/ir/fc_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ FCFusePass::FCFusePass() {
.End()
.AddAttr("activation_type")
.IsStringIn({"relu", ""})
.End()
.AddAttr("fuse_residual_connection")
.IsOptional()
.End()
.AddAttr("Scale_in_eltwise")
.IsOptional()
.End();
}

Expand Down
49 changes: 49 additions & 0 deletions paddle/fluid/framework/ir/graph_traits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <list>
#include <map>
#include <unordered_set>

#include "paddle/fluid/framework/ir/graph_traits.h"

namespace paddle {
Expand All @@ -23,6 +26,52 @@ namespace ir {
//
class Node;

// Returns true iff a directed path (following `outputs` edges) leads from
// `from` to `to` in `graph`. A node is considered reachable from itself.
//
// Implemented as a plain BFS. If a dequeued node turns out not to belong to
// `graph` the search aborts and reports false (this preserves the historic
// defensive behavior of the conv_elementwise_add fuse pass this helper was
// extracted from).
bool IsReachable(ir::Graph *graph, Node *from, Node *to) {
  // Membership probe: graph->Nodes() only exposes iteration here, so this
  // is a linear scan. NOTE(review): O(V) per dequeued node — acceptable for
  // fuse-pass-sized graphs, but worth revisiting if it shows up in profiles.
  auto find_node = [](ir::Graph *graph, const Node *node) -> Node * {
    for (auto n : graph->Nodes()) {
      if (n == node) {
        return n;
      }
    }
    return nullptr;
  };

  if (from == to) {
    return true;
  }

  // Set of already-discovered nodes. The previous implementation ran a full
  // GraphTraits::DFS over the whole graph just to pre-fill a
  // std::map<Node*, bool> with `false` — pure overhead, since a node absent
  // from the set behaves exactly like a `false` entry.
  std::unordered_set<const Node *> visited{from};

  std::list<Node *> queue{from};

  while (!queue.empty()) {
    auto *cur = find_node(graph, queue.front());
    queue.pop_front();

    // Abort if the current node is not owned by the graph.
    if (!cur) return false;

    for (auto *n : cur->outputs) {
      if (n == to) {
        return true;
      }

      // insert().second is true only on first discovery, so each node is
      // enqueued at most once.
      if (visited.insert(n).second) {
        queue.push_back(n);
      }
    }
  }
  return false;
}

NodesDFSIterator::NodesDFSIterator(const std::vector<Node *> &source) {
  // Seed the DFS work stack with every starting node, in order.
  for (Node *node : source) {
    stack_.push(node);
  }
}
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/graph_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ namespace ir {
class Graph;
class Node;

bool IsReachable(ir::Graph *graph, Node *from, Node *to);

template <typename IteratorT>
class iterator_range {
IteratorT begin_, end_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
#include "paddle/fluid/framework/ir/mkldnn/conv_elementwise_add_mkldnn_fuse_pass.h"

#include <functional>
#include <list>
#include <map>
#include <memory>
#include <tuple>

Expand All @@ -28,52 +26,6 @@ namespace paddle {
namespace framework {
namespace ir {

// BFS reachability test: returns true iff `to` can be reached from `from`
// by following `outputs` edges (a node is reachable from itself).
// NOTE(review): this PR removes this copy from the conv_elementwise_add
// fuse pass; the identical implementation is moved to graph_traits.cc and
// declared in graph_traits.h so other passes can share it.
bool IsReachable(ir::Graph* graph, Node* from, Node* to) {
  // Linear scan confirming `node` is owned by `graph`; nullptr otherwise.
  auto find_node = [](ir::Graph* graph, const Node* node) -> Node* {
    for (auto n : graph->Nodes()) {
      if (n == node) {
        return n;
      }
    }

    return nullptr;
  };

  // Trivial case: every node reaches itself.
  if (from == to) {
    return true;
  }

  std::map<Node*, bool> visited;

  // Pre-mark every node in the graph as unvisited.
  for (auto& node : GraphTraits::DFS(*graph)) {
    visited[&node] = false;
  }

  visited[from] = true;

  // BFS frontier.
  std::list<Node*> queue;
  queue.push_back(from);

  while (!queue.empty()) {
    auto cur = find_node(graph, queue.front());
    queue.pop_front();

    // Abort the whole search if a dequeued node is not part of the graph.
    if (!cur) return false;

    for (auto n : cur->outputs) {
      if (n == to) {
        return true;
      }

      // Enqueue each successor at most once.
      if (!visited[n]) {
        visited[n] = true;
        queue.push_back(n);
      }
    }
  }
  return false;
}

template <typename T>
paddle::optional<T> HasAttribute(const Node& op, const std::string& attr) {
if (op.Op()->HasAttr(attr))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ namespace ir {

using GraphWithStats = std::pair<ir::Graph*, int>;

bool IsReachable(ir::Graph* graph, Node* from, Node* to);

class ResidualConnectionMKLDNNFusePass : public FusePassBase {
private:
GraphWithStats FuseConvAsX(const std::string& name_scope,
Expand Down
Loading