[CHERRY-PICK] Added caching to oneDNN FC and op+unsqueeze2 and op+reshape2 fuse passes #47690

Merged 9 commits on Nov 8, 2022
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -219,6 +219,8 @@ if(WITH_MKLDNN)
   pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
+  pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn)
+  pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn)
   pass_library(cpu_quantize_placement_pass base DIR mkldnn)
   pass_library(cpu_quantize_pass inference DIR mkldnn)
   pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
38 changes: 38 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -958,6 +958,44 @@ PDNode *patterns::OperatorActivation::operator()(
   return activation_out;
 }
 
+PDNode *patterns::OperatorUnsqueeze2::operator()(
+    const std::string &operator_type, const int num_of_operator_outs) {
+  auto *preceding_op = pattern->NewNode(preceding_op_repr())
+                           ->assert_is_op(operator_type)
+                           ->assert_has_n_outputs(num_of_operator_outs);
+  auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
+                               ->AsIntermediate()
+                               ->assert_is_op_output(operator_type, "Out")
+                               ->assert_is_op_input("unsqueeze2");
+  auto *unsqueeze2_op =
+      pattern->NewNode(unsqueeze2_op_repr())->assert_is_op("unsqueeze2");
+  auto *unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr())
+                             ->AsOutput()
+                             ->assert_is_op_output("unsqueeze2");
+  preceding_op->LinksTo({preceding_op_out});
+  unsqueeze2_op->LinksFrom({preceding_op_out}).LinksTo({unsqueeze2_out});
+  return unsqueeze2_out;
+}
+
+PDNode *patterns::OperatorReshape2::operator()(const std::string &operator_type,
+                                               const int num_of_operator_outs) {
+  auto *preceding_op = pattern->NewNode(preceding_op_repr())
+                           ->assert_is_op(operator_type)
+                           ->assert_has_n_outputs(num_of_operator_outs);
+  auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
+                               ->AsIntermediate()
+                               ->assert_is_op_output(operator_type, "Out")
+                               ->assert_is_op_input("reshape2");
+  auto *reshape2_op =
+      pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
+  auto *reshape2_out = pattern->NewNode(reshape2_out_repr())
+                           ->AsOutput()
+                           ->assert_is_op_output("reshape2");
+  preceding_op->LinksTo({preceding_op_out});
+  reshape2_op->LinksFrom({preceding_op_out}).LinksTo({reshape2_out});
+  return reshape2_out;
+}
+
 PDNode *patterns::SeqConvEltAddRelu::operator()(
     paddle::framework::ir::PDNode *seqconv_input) {
   // Create Operators
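For orientation, here is a minimal sketch of how a fuse pass might drive the new `OperatorUnsqueeze2` matcher, modeled on the `GraphPatternDetector` usage in `fc_act_mkldnn_fuse_pass.cc` later in this diff. The `"transpose2"` operator type, scope name, and handler body are illustrative assumptions, not code from this PR:

```cpp
// Sketch only: consuming the OperatorUnsqueeze2 pattern from a pass body,
// where `graph` is the ir::Graph* being transformed.
GraphPatternDetector gpd;
patterns::OperatorUnsqueeze2 unsqueeze2_pattern(
    gpd.mutable_pattern(), "operator_unsqueeze2_onednn_fuse_pass");
// Match: a "transpose2" op with exactly 1 output feeding an unsqueeze2 op.
unsqueeze2_pattern("transpose2", 1);

auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                   Graph *g) {
  GET_IR_NODE_FROM_SUBGRAPH(preceding_op, preceding_op, unsqueeze2_pattern);
  GET_IR_NODE_FROM_SUBGRAPH(unsqueeze2_op, unsqueeze2_op, unsqueeze2_pattern);
  GET_IR_NODE_FROM_SUBGRAPH(unsqueeze2_out, unsqueeze2_out, unsqueeze2_pattern);
  // Fold unsqueeze2 into the preceding op here: copy its "axes" attribute,
  // point the op's output at unsqueeze2_out, and erase the dead nodes.
};
gpd(graph, handler);
```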
26 changes: 26 additions & 0 deletions paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -539,6 +539,32 @@ struct OperatorActivation : public PatternBase {
   PATTERN_DECL_NODE(activation_out);
 };
 
+struct OperatorUnsqueeze2 : public PatternBase {
+  OperatorUnsqueeze2(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "operator_unsqueeze2") {}
+
+  PDNode* operator()(const std::string& operator_type,
+                     const int num_of_outputs);
+
+  PATTERN_DECL_NODE(preceding_op);
+  PATTERN_DECL_NODE(preceding_op_out);
+  PATTERN_DECL_NODE(unsqueeze2_op);
+  PATTERN_DECL_NODE(unsqueeze2_out);
+};
+
+struct OperatorReshape2 : public PatternBase {
+  OperatorReshape2(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "operator_reshape2") {}
+
+  PDNode* operator()(const std::string& operator_type,
+                     const int num_of_outputs);
+
+  PATTERN_DECL_NODE(preceding_op);
+  PATTERN_DECL_NODE(preceding_op_out);
+  PATTERN_DECL_NODE(reshape2_op);
+  PATTERN_DECL_NODE(reshape2_out);
+};
+
 // SEQCONV with Elementwise_Add ReLU
 // op: seqconv + elementwise_add + relu
 // named nodes:
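Schematically, both new matchers cover the same two-op shape; summarized below in the style of the pattern comments already used in this header (descriptive note, not text from the PR):

```cpp
// op: <operator_type> + unsqueeze2 (or reshape2)
// named nodes:
// preceding_op, preceding_op_out,
// unsqueeze2_op, unsqueeze2_out (or reshape2_op, reshape2_out)
```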
52 changes: 34 additions & 18 deletions paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc
@@ -1,4 +1,4 @@
-// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -14,9 +14,8 @@
 
 #include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
 
-#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 #include "paddle/fluid/string/pretty_log.h"
 
 namespace paddle {
@@ -26,20 +25,20 @@ namespace ir {
 using string::PrettyLogDetail;
 
 void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
-  std::vector<std::string> act_types = {
-      "gelu", "tanh", "sigmoid", "mish", "hard_swish"};
+  auto act_types = paddle::platform::GetSupportedActivations();
 
-  for (std::string act_type : act_types) FuseFCAct(graph, act_type);
+  for (auto act_type : act_types) FuseFCAct(graph, act_type);
 }
 
 void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
                                     const std::string &act_type) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-  FusePassBase::Init("fc_act", graph);
+  FusePassBase::Init("fc_" + act_type + "_mkldnn_fuse_pass", graph);
 
   GraphPatternDetector gpd;
-  patterns::OperatorActivation fc_act_pattern(gpd.mutable_pattern(), "fc_act");
+  patterns::OperatorActivation fc_act_pattern(
+      gpd.mutable_pattern(), "fc_" + act_type + "_mkldnn_fuse_pass");
   fc_act_pattern("fc", act_type);
 
   int found_fc_act_count = 0;
@@ -62,15 +61,23 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
                                         "is used."));
     }
 
+    auto attr_map = paddle::platform::GetAttributeMap(act_type);
+    for (const auto &attr : attr_map) {
+      if (act_op->HasAttr(attr.first)) {
+        fc_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
+      }
+    }
+
     if (act_type == "gelu" && act_op->HasAttr("approximate")) {
-      bool approximate = PADDLE_GET_CONST(bool, act_op->GetAttr("approximate"));
-      std::string type = approximate ? "_tanh" : "_erf";
-      fc_op->SetAttr("activation_type", act_type + type);
+      std::string gelu_act_type =
+          PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
+                                                                 : "gelu_erf";
+      fc_op->SetAttr("fuse_activation", gelu_act_type);
     } else {
-      fc_op->SetAttr("activation_type", act_type);
+      fc_op->SetAttr("fuse_activation", act_type);
     }
-    fc_op->SetAttr("use_mkldnn", true);
 
+    fc_op->SetAttr("use_mkldnn", true);
     fc_op->SetOutput("Out", {act_out->Name()});
 
     IR_OP_VAR_LINK(fc, act_out);
@@ -80,7 +87,8 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
 
   gpd(graph, handler);
   AddStatis(found_fc_act_count);
-  if (!Has("disable_logs") || !Get<bool>("disable_logs"))
+  if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
+      found_fc_act_count > 0)
     PrettyLogDetail(
         "--- fused %d fc with %s activation", found_fc_act_count, act_type);
 }
@@ -95,8 +103,16 @@ REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("fc", 0)
-            .LE("gelu", 0)
-            .LE("sigmoid", 0)
-            .LE("mish", 1)
+            .EQ("abs", 0)
+            .LE("clip", 1)
+            .EQ("gelu", 0)
+            .EQ("hard_sigmoid", 0)
             .LE("hard_swish", 0)
-            .LE("tanh", 0));
+            .LE("leaky_relu", 1)
+            .LE("mish", 1)
+            .EQ("relu", 0)
+            .EQ("relu6", 0)
+            .EQ("sigmoid", 0)
+            .EQ("sqrt", 0)
+            .EQ("swish", 0)
+            .EQ("tanh", 0));
11 changes: 2 additions & 9 deletions paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
@@ -23,21 +23,14 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
-/*
- * \brief Fuse the FC and activation operators into single OneDNN's
- * FC with post-op.
- *
- * \note Currently only GeLU, hardswish, sigmoid, mish and tanh are supported
- * as an activation function.
- */
 class FuseFCActOneDNNPass : public FusePassBase {
  public:
   virtual ~FuseFCActOneDNNPass() {}
 
  protected:
-  void ApplyImpl(ir::Graph *graph) const override;
+  void ApplyImpl(Graph *graph) const override;
 
-  void FuseFCAct(ir::Graph *graph, const std::string &act_types) const;
+  void FuseFCAct(Graph *graph, const std::string &act_types) const;
 };
 
 }  // namespace ir
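As a usage note, the pass declared here is fetched by name at inference time; a minimal sketch of applying it to a graph, following the standard Paddle pass-registry idiom used by the pass testers (the surrounding setup is assumed for illustration):

```cpp
#include "paddle/fluid/framework/ir/pass.h"

// Sketch (assumed setup): `graph` is a std::unique_ptr<paddle::framework::ir::Graph>.
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
    "fc_act_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release()));
```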
28 changes: 14 additions & 14 deletions paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc
@@ -78,9 +78,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("gelu_tanh"), 0);
     }
   }
@@ -113,9 +113,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluErf) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("gelu_erf"), 0);
     }
   }
@@ -146,9 +146,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("gelu"), 0);
     }
   }
@@ -179,9 +179,9 @@ TEST(FuseFCActOneDNNPass, FuseWithTanh) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("tanh"), 0);
     }
   }
@@ -213,9 +213,9 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("sigmoid"), 0);
     }
   }
@@ -246,9 +246,9 @@ TEST(FuseFCActOneDNNPass, FuseWithMish) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("mish"), 0);
     }
   }
@@ -280,9 +280,9 @@ TEST(FuseFCActOneDNNPass, FuseWithHardSwish) {
       const auto* op = node->Op();
       ASSERT_TRUE(op->HasAttr("use_mkldnn"));
       EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
-      ASSERT_TRUE(op->HasAttr("activation_type"));
+      ASSERT_TRUE(op->HasAttr("fuse_activation"));
       auto act_type =
-          PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
+          PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
       EXPECT_EQ(act_type.compare("hard_swish"), 0);
     }
   }