From db7d129ec2b0d0eff4b9208749b6dca82697e27c Mon Sep 17 00:00:00 2001 From: Wangzheee <634486483@qq.com> Date: Wed, 9 Feb 2022 14:35:24 +0800 Subject: [PATCH] [Paddle-Inference] rebuild matmul pass: trt and gpu_cpu (#39369) * rebuild matmul pass: trt and gpu_cpu * rebuild matmul pass: trt and gpu_cpu * rebuild matmul pass: trt and gpu_cpu * rebuild matmul pass: trt and gpu_cpu --- paddle/fluid/framework/ir/CMakeLists.txt | 7 +- ...s.cc => gpu_cpu_map_matmul_to_mul_pass.cc} | 114 +-- ...ass.h => gpu_cpu_map_matmul_to_mul_pass.h} | 38 +- .../ir/trt_map_matmul_to_mul_pass.cc | 842 ++++++++++++++++++ .../framework/ir/trt_map_matmul_to_mul_pass.h | 130 +++ .../inference/api/paddle_pass_builder.cc | 36 +- .../fluid/inference/tensorrt/convert/fc_op.cc | 122 ++- .../quantization/quant2_int8_mkldnn_pass.py | 12 +- .../test_flatten2_matmul_fuse_pass.py | 2 +- .../inference/test_map_matmul_to_mul_pass.py | 2 +- .../test_map_matmul_v2_to_matmul_pass.py | 2 +- .../test_map_matmul_v2_to_mul_pass.py | 5 +- ...n_matmul_v2_transpose_reshape_fuse_pass.py | 2 +- .../test_reshape2_matmul_fuse_pass.py | 2 +- .../test_squeeze2_matmul_fuse_pass.py | 2 +- 15 files changed, 1178 insertions(+), 140 deletions(-) rename paddle/fluid/framework/ir/{map_matmul_to_mul_pass.cc => gpu_cpu_map_matmul_to_mul_pass.cc} (87%) rename paddle/fluid/framework/ir/{map_matmul_to_mul_pass.h => gpu_cpu_map_matmul_to_mul_pass.h} (79%) create mode 100644 paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc create mode 100644 paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.h diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 323e743087ffb..829f43effb6d2 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -62,7 +62,6 @@ pass_library(graph_to_program_pass base) pass_library(graph_viz_pass base) pass_library(lock_free_optimize_pass base DEPS string_helper) pass_library(fc_fuse_pass inference) -pass_library(map_matmul_to_mul_pass inference) pass_library(attention_lstm_fuse_pass inference) pass_library(fc_lstm_fuse_pass inference) pass_library(embedding_fc_lstm_fuse_pass inference) @@ -98,8 +97,14 @@ pass_library(unsqueeze2_eltwise_fuse_pass inference) pass_library(layer_norm_fuse_pass inference) pass_library(add_support_int8_pass inference) pass_library(matmul_scale_fuse_pass inference) +pass_library(gpu_cpu_map_matmul_to_mul_pass inference) pass_library(generate_pass DEPS pass_desc_proto) target_link_libraries(generate_pass pass_desc_proto) + +if(WITH_TENSORRT) + pass_library(trt_map_matmul_to_mul_pass inference) +endif() + if(WITH_GPU OR WITH_ROCM) pass_library(cudnn_placement_pass base DEPS placement_pass_base) pass_library(embedding_eltwise_layernorm_fuse_pass inference) diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc similarity index 87% rename from paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc rename to paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc index 734f8957ad09e..1759d18761da3 100644 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc +++ b/paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/fluid/framework/ir/map_matmul_to_mul_pass.h" +#include "paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.h" #include #include @@ -28,7 +28,7 @@ namespace ir { class Node; -MapMatmul2MulPass::MapMatmul2MulPass() { +GpuCpuMapMatmul2MulPass::GpuCpuMapMatmul2MulPass() { AddOpCompat(OpCompat("matmul")) .AddInput("X") .IsTensor() @@ -68,7 +68,7 @@ MapMatmul2MulPass::MapMatmul2MulPass() { .End(); } -MapMatmulV2ToMulPass::MapMatmulV2ToMulPass() { +GpuCpuMapMatmulV2ToMulPass::GpuCpuMapMatmulV2ToMulPass() { AddOpCompat(OpCompat("matmul_v2")) .AddInput("X") .IsTensor() @@ -104,7 +104,7 @@ MapMatmulV2ToMulPass::MapMatmulV2ToMulPass() { .End(); } -MapMatmulV2ToMatmulPass::MapMatmulV2ToMatmulPass() { +GpuCpuMapMatmulV2ToMatmulPass::GpuCpuMapMatmulV2ToMatmulPass() { AddOpCompat(OpCompat("matmul_v2")) .AddInput("X") .IsTensor() @@ -143,7 +143,7 @@ MapMatmulV2ToMatmulPass::MapMatmulV2ToMatmulPass() { .End(); } -Flatten2MatmulFusePass::Flatten2MatmulFusePass() { +GpuCpuFlatten2MatmulFusePass::GpuCpuFlatten2MatmulFusePass() { AddOpCompat(OpCompat("matmul")) .AddInput("X") .IsTensor() @@ -197,7 +197,7 @@ Flatten2MatmulFusePass::Flatten2MatmulFusePass() { .End(); } -Squeeze2MatmulFusePass::Squeeze2MatmulFusePass() { +GpuCpuSqueeze2MatmulFusePass::GpuCpuSqueeze2MatmulFusePass() { AddOpCompat(OpCompat("matmul")) .AddInput("X") .IsTensor() @@ -251,10 +251,10 @@ Squeeze2MatmulFusePass::Squeeze2MatmulFusePass() { .End(); } -void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { +void GpuCpuMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); - std::string name_scope = "map_matmul_to_mul_pass"; + std::string name_scope = "gpu_cpu_map_matmul_to_mul_pass"; FusePassBase::Init(name_scope, graph); GraphPatternDetector gpd; @@ -264,7 +264,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { int found_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - VLOG(4) << "map matmul to mul"; + VLOG(4) << "gpu_cpu map matmul to mul"; GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern); @@ -286,7 +286,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { if (flag) { if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "MapMatmul2MulPass in op compat failed."; + LOG(WARNING) << "GpuCpuMapMatmul2MulPass in op compat failed."; return; } OpDesc desc(matmul_op->Op()->Block()); @@ -311,7 +311,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { ++found_count; if (!IsCompat(desc)) { - LOG(WARNING) << "MapMatmul2MulPass in out mul op compat failed."; + LOG(WARNING) << "GpuCpuMapMatmul2MulPass in out mul op compat failed."; return; } } @@ -321,10 +321,10 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_count); } -void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const { +void GpuCpuMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); - std::string name_scope = "map_matmul_v2_to_mul_pass"; + std::string name_scope = "gpu_cpu_map_matmul_v2_to_mul_pass"; FusePassBase::Init(name_scope, graph); GraphPatternDetector gpd; @@ -335,7 +335,7 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const { int found_count = 0; auto handler = [&](const 
GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - VLOG(3) << "map matmul_v2 to mul"; + VLOG(3) << "gpu_cpu map matmul_v2 to mul"; GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x, matmul_v2_weight_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y, @@ -360,7 +360,7 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const { if (flag) { if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "MapMatmulV2ToMulPass in op compat failed."; + LOG(WARNING) << "GpuCpuMapMatmulV2ToMulPass in op compat failed."; return; } OpDesc desc(matmul_v2_op->Op()->Block()); @@ -386,7 +386,8 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const { ++found_count; if (!IsCompat(desc)) { - LOG(WARNING) << "MapMatmulV2ToMulPass in out mul op compat failed."; + LOG(WARNING) + << "GpuCpuMapMatmulV2ToMulPass in out mul op compat failed."; return; } } @@ -396,10 +397,10 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_count); } -void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const { +void GpuCpuMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); - std::string name_scope = "map_matmul_v2_to_matmul_pass"; + std::string name_scope = "gpu_cpu_map_matmul_v2_to_matmul_pass"; FusePassBase::Init(name_scope, graph); GraphPatternDetector gpd; @@ -409,7 +410,7 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const { int found_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - VLOG(4) << "map matmul_v2 to matmul"; + VLOG(4) << "gpu_cpu map matmul_v2 to matmul"; GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x, matmul_v2_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y, @@ -417,7 +418,7 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const { GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op, matmul_v2_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_out, matmul_v2_out, matmul_v2_pattern); if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "MapMatmulV2ToMatmulPass in op compat failed."; + LOG(WARNING) << "GpuCpuMapMatmulV2ToMatmulPass in op compat failed."; return; } @@ -463,7 +464,8 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const { ++found_count; if (!IsCompat(desc)) { - LOG(WARNING) << "MapMatmulV2ToMatmulPass in out matmul op compat failed."; + LOG(WARNING) + << "GpuCpuMapMatmulV2ToMatmulPass in out matmul op compat failed."; return; } }; @@ -472,10 +474,10 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_count); } -void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { +void GpuCpuSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); - std::string name_scope = "squeeze2_matmul_fuse_pass"; + std::string name_scope = "gpu_cpu_squeeze2_matmul_fuse_pass"; FusePassBase::Init(name_scope, graph); GraphPatternDetector gpd; @@ -485,7 +487,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { int found_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - VLOG(4) << "fuse squeeze2+matmul to mul"; + VLOG(4) << "gpu_cpu fuse squeeze2+matmul to mul"; GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, 
fuse_pattern); @@ -518,7 +520,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { if (flag) { if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "Squeeze2MatmulFusePass in op compat failed."; + LOG(WARNING) << "GpuCpuSqueeze2MatmulFusePass in op compat failed."; return; } OpDesc desc(matmul_op->Op()->Block()); @@ -542,7 +544,8 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { GraphSafeRemoveNodes(graph, {squeeze2_op, matmul_in_x, matmul_op}); ++found_count; if (!IsCompat(desc)) { - LOG(WARNING) << "Squeeze2MatmulFusePass in out mul op compat failed."; + LOG(WARNING) + << "GpuCpuSqueeze2MatmulFusePass in out mul op compat failed."; return; } } @@ -552,7 +555,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_count); } -Reshape2MatmulFusePass::Reshape2MatmulFusePass() { +GpuCpuReshape2MatmulFusePass::GpuCpuReshape2MatmulFusePass() { AddOpCompat(OpCompat("reshape2")) .AddInput("X") .IsTensor() @@ -614,10 +617,10 @@ Reshape2MatmulFusePass::Reshape2MatmulFusePass() { .End(); } -void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { +void GpuCpuReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); - std::string name_scope = "reshape2_matmul_fuse_pass"; + std::string name_scope = "gpu_cpu_reshape2_matmul_fuse_pass"; FusePassBase::Init(name_scope, graph); GraphPatternDetector gpd; @@ -627,7 +630,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { int found_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - VLOG(4) << "fuse reshape2+matmul to mul"; + VLOG(4) << "gpu_cpu fuse reshape2+matmul to mul"; GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern); @@ -662,7 +665,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { if (flag) { if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "Reshape2MatmulFusePass in op compat failed."; + LOG(WARNING) << "GpuCpuReshape2MatmulFusePass in op compat failed."; return; } OpDesc desc(matmul_op->Op()->Block()); @@ -680,7 +683,8 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { matmul_op->Op()->GetAttr("out_threshold")); } if (!IsCompat(desc)) { - LOG(WARNING) << "Reshape2MatmulFusePass in out mul op compat failed."; + LOG(WARNING) + << "GpuCpuReshape2MatmulFusePass in out mul op compat failed."; return; } auto mul_node = g->CreateOpNode(&desc); @@ -696,10 +700,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { AddStatis(found_count); } -void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { +void GpuCpuFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); - std::string name_scope = "flatten2_matmul_fuse_pass"; + std::string name_scope = "gpu_cpu_flatten2_matmul_fuse_pass"; FusePassBase::Init(name_scope, graph); GraphPatternDetector gpd; @@ -709,7 +713,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { int found_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) { - VLOG(4) << "fuse flatten2+matmul to mul"; + VLOG(4) << "gpu_cpu fuse flatten2+matmul to mul"; GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern); 
GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern); GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern); @@ -749,7 +753,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { if (pattern_found) { if (!IsCompat(subgraph, g)) { - LOG(WARNING) << "Flatten2MatmulFusePass in op compat failed."; + LOG(WARNING) << "GpuCpuFlatten2MatmulFusePass in op compat failed."; return; } OpDesc desc(matmul_op->Op()->Block()); @@ -774,7 +778,8 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { ++found_count; if (!IsCompat(desc)) { - LOG(WARNING) << "Flatten2MatmulFusePass in out mul op compat failed."; + LOG(WARNING) + << "GpuCpuFlatten2MatmulFusePass in out mul op compat failed."; return; } } @@ -788,50 +793,51 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { } // namespace framework } // namespace paddle -REGISTER_PASS(map_matmul_to_mul_pass, paddle::framework::ir::MapMatmul2MulPass); -REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass) +REGISTER_PASS(gpu_cpu_map_matmul_to_mul_pass, + paddle::framework::ir::GpuCpuMapMatmul2MulPass); +REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_to_mul_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .LE("matmul", 1) .EQ("mul", 0)); -REGISTER_PASS(map_matmul_v2_to_mul_pass, - paddle::framework::ir::MapMatmulV2ToMulPass); -REGISTER_PASS_CAPABILITY(map_matmul_v2_to_mul_pass) +REGISTER_PASS(gpu_cpu_map_matmul_v2_to_mul_pass, + paddle::framework::ir::GpuCpuMapMatmulV2ToMulPass); +REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_v2_to_mul_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .EQ("matmul_v2", 0) .EQ("mul", 0)); -REGISTER_PASS(map_matmul_v2_to_matmul_pass, - paddle::framework::ir::MapMatmulV2ToMatmulPass); -REGISTER_PASS_CAPABILITY(map_matmul_v2_to_matmul_pass) +REGISTER_PASS(gpu_cpu_map_matmul_v2_to_matmul_pass, + paddle::framework::ir::GpuCpuMapMatmulV2ToMatmulPass); +REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_v2_to_matmul_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .EQ("matmul_v2", 0) .LE("matmul", 1)); -REGISTER_PASS(squeeze2_matmul_fuse_pass, - paddle::framework::ir::Squeeze2MatmulFusePass); -REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass) +REGISTER_PASS(gpu_cpu_squeeze2_matmul_fuse_pass, + paddle::framework::ir::GpuCpuSqueeze2MatmulFusePass); +REGISTER_PASS_CAPABILITY(gpu_cpu_squeeze2_matmul_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .LE("matmul", 1) .EQ("squeeze2", 0) .EQ("mul", 0)); -REGISTER_PASS(reshape2_matmul_fuse_pass, - paddle::framework::ir::Reshape2MatmulFusePass); -REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass) +REGISTER_PASS(gpu_cpu_reshape2_matmul_fuse_pass, + paddle::framework::ir::GpuCpuReshape2MatmulFusePass); +REGISTER_PASS_CAPABILITY(gpu_cpu_reshape2_matmul_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .LE("matmul", 1) .EQ("reshape2", 0) .EQ("mul", 0)); -REGISTER_PASS(flatten2_matmul_fuse_pass, - paddle::framework::ir::Flatten2MatmulFusePass); -REGISTER_PASS_CAPABILITY(flatten2_matmul_fuse_pass) +REGISTER_PASS(gpu_cpu_flatten2_matmul_fuse_pass, + paddle::framework::ir::GpuCpuFlatten2MatmulFusePass); +REGISTER_PASS_CAPABILITY(gpu_cpu_flatten2_matmul_fuse_pass) .AddCombination( paddle::framework::compatible::OpVersionComparatorCombination() .LE("matmul", 1) diff --git a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h 
b/paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.h similarity index 79% rename from paddle/fluid/framework/ir/map_matmul_to_mul_pass.h rename to paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.h index a924cd8ddf92c..e4ea1bc9c94be 100644 --- a/paddle/fluid/framework/ir/map_matmul_to_mul_pass.h +++ b/paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.h @@ -37,22 +37,22 @@ namespace ir { */ class Graph; -class MapMatmul2MulPass : public FusePassBase { +class GpuCpuMapMatmul2MulPass : public FusePassBase { public: - MapMatmul2MulPass(); - virtual ~MapMatmul2MulPass() {} + GpuCpuMapMatmul2MulPass(); + virtual ~GpuCpuMapMatmul2MulPass() {} protected: void ApplyImpl(Graph* graph) const override; }; /* - * Map matmul_v2 to mul, the same as MapMatmul2MulPass. + * Map matmul_v2 to mul, the same as GpuCpuMapMatmul2MulPass. */ -class MapMatmulV2ToMulPass : public FusePassBase { +class GpuCpuMapMatmulV2ToMulPass : public FusePassBase { public: - MapMatmulV2ToMulPass(); - virtual ~MapMatmulV2ToMulPass() {} + GpuCpuMapMatmulV2ToMulPass(); + virtual ~GpuCpuMapMatmulV2ToMulPass() {} protected: void ApplyImpl(Graph* graph) const override; @@ -61,10 +61,10 @@ class MapMatmulV2ToMulPass : public FusePassBase { /* * Map matmul_v2 to matmul, not supoort broadcast. */ -class MapMatmulV2ToMatmulPass : public FusePassBase { +class GpuCpuMapMatmulV2ToMatmulPass : public FusePassBase { public: - MapMatmulV2ToMatmulPass(); - virtual ~MapMatmulV2ToMatmulPass() {} + GpuCpuMapMatmulV2ToMatmulPass(); + virtual ~GpuCpuMapMatmulV2ToMatmulPass() {} protected: void ApplyImpl(Graph* graph) const override; @@ -89,10 +89,10 @@ class MapMatmulV2ToMatmulPass : public FusePassBase { * the above passes to reduce the impact on other models. */ -class Squeeze2MatmulFusePass : public FusePassBase { +class GpuCpuSqueeze2MatmulFusePass : public FusePassBase { public: - Squeeze2MatmulFusePass(); - virtual ~Squeeze2MatmulFusePass() {} + GpuCpuSqueeze2MatmulFusePass(); + virtual ~GpuCpuSqueeze2MatmulFusePass() {} protected: void ApplyImpl(Graph* graph) const override; @@ -119,19 +119,19 @@ class Squeeze2MatmulFusePass : public FusePassBase { * the above passes to reduce the impact on other models. */ -class Reshape2MatmulFusePass : public FusePassBase { +class GpuCpuReshape2MatmulFusePass : public FusePassBase { public: - Reshape2MatmulFusePass(); - virtual ~Reshape2MatmulFusePass() {} + GpuCpuReshape2MatmulFusePass(); + virtual ~GpuCpuReshape2MatmulFusePass() {} protected: void ApplyImpl(Graph* graph) const override; }; -class Flatten2MatmulFusePass : public FusePassBase { +class GpuCpuFlatten2MatmulFusePass : public FusePassBase { public: - Flatten2MatmulFusePass(); - virtual ~Flatten2MatmulFusePass() {} + GpuCpuFlatten2MatmulFusePass(); + virtual ~GpuCpuFlatten2MatmulFusePass() {} protected: void ApplyImpl(Graph* graph) const override; diff --git a/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc b/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc new file mode 100644 index 0000000000000..6c73965a80943 --- /dev/null +++ b/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.cc @@ -0,0 +1,842 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.h" + +#include +#include +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/op_proto_maker.h" + +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +class Node; + +TrtMapMatmul2MulPass::TrtMapMatmul2MulPass() { + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGE(0.99f) + .IsNumLE(1.01f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsType<bool>() + .End(); + + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumGE(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); +} + +TrtMapMatmulV2ToMulPass::TrtMapMatmulV2ToMulPass() { + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsBoolEQ(false) + .End() + .AddAttr("trans_y") + .IsType<bool>() + .End(); + + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumGE(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); +} + +TrtMapMatmulV2ToMatmulPass::TrtMapMatmulV2ToMatmulPass() { + AddOpCompat(OpCompat("matmul_v2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("trans_x") + .IsType<bool>() + .End() + .AddAttr("trans_y") + .IsType<bool>() + .End(); + + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumEQ(1.0f) + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("transpose_X") + .IsType<bool>() + .End() + .AddAttr("transpose_Y") + .IsType<bool>() + .End(); +} + +TrtFlatten2MatmulFusePass::TrtFlatten2MatmulFusePass() { + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGE(0.99f) + .IsNumLE(1.01f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsBoolEQ(false) + .End(); + + AddOpCompat(OpCompat("flatten2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("axis") + .IsNumEQ(1) + .End(); + + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumGE(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); +} +
+TrtSqueeze2MatmulFusePass::TrtSqueeze2MatmulFusePass() { + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGE(0.99f) + .IsNumLE(1.01f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsBoolEQ(false) + .End(); + + AddOpCompat(OpCompat("squeeze2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("axes") + .IsType<std::vector<int>>() + .End(); + + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumEQ(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); +} + +void TrtMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + std::string name_scope = "trt_map_matmul_to_mul_pass"; + FusePassBase::Init(name_scope, graph); + + GraphPatternDetector gpd; + patterns::Matmul matmul_pattern(gpd.mutable_pattern(), name_scope); + matmul_pattern(); + + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "trt map matmul to mul"; + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, matmul_pattern); + bool flag = true; + + bool transpose_X = + BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X")); + float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha")); + flag = flag && !transpose_X && std::abs(alpha - 1.0) < 1e-5; + + std::vector<int64_t> x_shape = matmul_in_x->Var()->GetShape(); + std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + flag = flag && x_rank >= 2 && y_rank == 2; + + if (flag) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "TrtMapMatmul2MulPass in op compat failed."; + return; + } + OpDesc desc(matmul_op->Op()->Block()); + desc.SetType("mul"); + desc.SetInput("X", {matmul_in_x->Name()}); + desc.SetInput("Y", {matmul_in_y->Name()}); + desc.SetOutput("Out", {matmul_out->Name()}); + desc.SetAttr("x_num_col_dims", static_cast<int>(x_rank - 1)); + desc.SetAttr("y_num_col_dims", 1); + desc.SetAttr("transpose_Y", matmul_op->Op()->GetAttr("transpose_Y")); + if (matmul_op->Op()->HasAttr("enable_int8")) { + desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); + desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); + desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("out_threshold", + matmul_op->Op()->GetAttr("out_threshold")); + } + auto mul_node = g->CreateOpNode(&desc); + IR_NODE_LINK_TO(matmul_in_x, mul_node); + IR_NODE_LINK_TO(matmul_in_y, mul_node); + IR_NODE_LINK_TO(mul_node, matmul_out); + GraphSafeRemoveNodes(graph, {matmul_op}); + ++found_count; + + if (!IsCompat(desc)) { + LOG(WARNING) << "TrtMapMatmul2MulPass in out mul op compat failed."; + return; + } + } + }; + + gpd(graph, handler); + AddStatis(found_count); +} + +void TrtMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); +
std::string name_scope = "trt_map_matmul_v2_to_mul_pass"; + FusePassBase::Init(name_scope, graph); + + GraphPatternDetector gpd; + patterns::MatmulV2Weight matmul_v2_weight_pattern(gpd.mutable_pattern(), + name_scope); + matmul_v2_weight_pattern(); + + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(3) << "trt map matmul_v2 to mul"; + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x, + matmul_v2_weight_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y, + matmul_v2_weight_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op, + matmul_v2_weight_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_out, matmul_v2_out, + matmul_v2_weight_pattern); + + bool flag = true; + bool trans_x = + BOOST_GET_CONST(bool, matmul_v2_op->Op()->GetAttr("trans_x")); + flag = flag && !trans_x; + + std::vector x_shape = matmul_v2_in_x->Var()->GetShape(); + std::vector y_shape = matmul_v2_in_y->Var()->GetShape(); + size_t x_rank = x_shape.size(); + size_t y_rank = y_shape.size(); + flag = flag && x_rank >= 2 && y_rank == 2; + + if (flag) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "TrtMapMatmulV2ToMulPass in op compat failed."; + return; + } + OpDesc desc(matmul_v2_op->Op()->Block()); + desc.SetType("mul"); + desc.SetInput("X", {matmul_v2_in_x->Name()}); + desc.SetInput("Y", {matmul_v2_in_y->Name()}); + desc.SetOutput("Out", {matmul_v2_out->Name()}); + desc.SetAttr("x_num_col_dims", static_cast(x_rank - 1)); + desc.SetAttr("y_num_col_dims", 1); + desc.SetAttr("transpose_Y", matmul_v2_op->Op()->GetAttr("trans_y")); + if (matmul_v2_op->Op()->HasAttr("enable_int8")) { + desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8")); + desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale")); + desc.SetAttr("weight_scale", + matmul_v2_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("out_threshold", + matmul_v2_op->Op()->GetAttr("out_threshold")); + } + auto mul_node = g->CreateOpNode(&desc); + IR_NODE_LINK_TO(matmul_v2_in_x, mul_node); + IR_NODE_LINK_TO(matmul_v2_in_y, mul_node); + IR_NODE_LINK_TO(mul_node, matmul_v2_out); + GraphSafeRemoveNodes(graph, {matmul_v2_op}); + ++found_count; + + if (!IsCompat(desc)) { + LOG(WARNING) << "TrtMapMatmulV2ToMulPass in out mul op compat failed."; + return; + } + } + }; + + gpd(graph, handler); + AddStatis(found_count); +} + +void TrtMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + std::string name_scope = "trt_map_matmul_v2_to_matmul_pass"; + FusePassBase::Init(name_scope, graph); + + GraphPatternDetector gpd; + patterns::MatmulV2 matmul_v2_pattern(gpd.mutable_pattern(), name_scope); + matmul_v2_pattern(); + + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "trt map matmul_v2 to matmul"; + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x, + matmul_v2_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y, + matmul_v2_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op, matmul_v2_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_out, matmul_v2_out, matmul_v2_pattern); + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "TrtMapMatmulV2ToMatmulPass in op compat failed."; + return; + } + + std::vector x_shape = matmul_v2_in_x->Var()->GetShape(); + std::vector y_shape = matmul_v2_in_y->Var()->GetShape(); + if (x_shape.size() != y_shape.size()) { + 
LOG(WARNING) + << "matmul op not support broadcast, please check inputs'shape. "; + return; + } + uint64_t dims = 2; + for (size_t i = 0; i < x_shape.size() - dims; ++i) { + if (x_shape[i] != y_shape[i] && (x_shape[i] == 1 || y_shape[i] == 1)) { + LOG(WARNING) << "matmul op not support broadcast, please check " + "inputs'shape[i]. "; + return; + } + } + + OpDesc desc(matmul_v2_op->Op()->Block()); + desc.SetType("matmul"); + desc.SetInput("X", {matmul_v2_in_x->Name()}); + desc.SetInput("Y", {matmul_v2_in_y->Name()}); + desc.SetOutput("Out", {matmul_v2_out->Name()}); + desc.SetAttr("transpose_X", matmul_v2_op->Op()->GetAttr("trans_x")); + desc.SetAttr("transpose_Y", matmul_v2_op->Op()->GetAttr("trans_y")); + desc.SetAttr("alpha", 1.0f); + if (matmul_v2_op->Op()->HasAttr("use_mkldnn")) { + desc.SetAttr("use_mkldnn", matmul_v2_op->Op()->GetAttr("use_mkldnn")); + } + if (matmul_v2_op->Op()->HasAttr("enable_int8")) { + desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8")); + desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale")); + desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("out_threshold", + matmul_v2_op->Op()->GetAttr("out_threshold")); + } + auto matmul_node = g->CreateOpNode(&desc); + IR_NODE_LINK_TO(matmul_v2_in_x, matmul_node); + IR_NODE_LINK_TO(matmul_v2_in_y, matmul_node); + IR_NODE_LINK_TO(matmul_node, matmul_v2_out); + GraphSafeRemoveNodes(graph, {matmul_v2_op}); + ++found_count; + + if (!IsCompat(desc)) { + LOG(WARNING) + << "TrtMapMatmulV2ToMatmulPass in out matmul op compat failed."; + return; + } + }; + + gpd(graph, handler); + AddStatis(found_count); +} + +void TrtSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + std::string name_scope = "trt_squeeze2_matmul_fuse_pass"; + FusePassBase::Init(name_scope, graph); + + GraphPatternDetector gpd; + patterns::Squeeze2Matmul fuse_pattern(gpd.mutable_pattern(), name_scope); + fuse_pattern(); + + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "trt fuse squeeze2+matmul to mul"; + GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, fuse_pattern); + bool flag = true; + + size_t squeeze2_in_x_rank = (squeeze2_in_x->Var()->GetShape()).size(); + std::vector squeeze2_op_axes = + BOOST_GET_CONST(std::vector, squeeze2_op->Op()->GetAttr("axes")); + flag = flag && squeeze2_in_x_rank == 4 && + squeeze2_op_axes == std::vector{2, 3} && + (matmul_in_x->outputs).size() == 1; + + bool transpose_X = + BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X")); + bool transpose_Y = + BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y")); + float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha")); + size_t matmul_in_x_rank = (matmul_in_x->Var()->GetShape()).size(); + size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size(); + flag = flag && !transpose_X && !transpose_Y && + std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 && + matmul_in_y_rank == 2; + + std::vector& next_ops = matmul_out->outputs; + flag = flag && next_ops.size() 
== 1 && + next_ops[0]->Name() == "elementwise_add"; + + if (flag) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "TrtSqueeze2MatmulFusePass in op compat failed."; + return; + } + OpDesc desc(matmul_op->Op()->Block()); + desc.SetType("mul"); + desc.SetInput("X", {squeeze2_in_x->Name()}); + desc.SetInput("Y", {matmul_in_y->Name()}); + desc.SetOutput("Out", {matmul_out->Name()}); + desc.SetAttr("x_num_col_dims", 1); + desc.SetAttr("y_num_col_dims", 1); + if (matmul_op->Op()->HasAttr("enable_int8")) { + desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); + desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); + desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("out_threshold", + matmul_op->Op()->GetAttr("out_threshold")); + } + auto mul_node = g->CreateOpNode(&desc); + IR_NODE_LINK_TO(squeeze2_in_x, mul_node); + IR_NODE_LINK_TO(matmul_in_y, mul_node); + IR_NODE_LINK_TO(mul_node, matmul_out); + GraphSafeRemoveNodes(graph, {squeeze2_op, matmul_in_x, matmul_op}); + ++found_count; + if (!IsCompat(desc)) { + LOG(WARNING) + << "TrtSqueeze2MatmulFusePass in out mul op compat failed."; + return; + } + } + }; + + gpd(graph, handler); + AddStatis(found_count); +} + +TrtReshape2MatmulFusePass::TrtReshape2MatmulFusePass() { + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsTensor() + .IsOptional() + .End() + .AddInput("ShapeTensor") + .IsTensor() + .IsOptional() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddAttr("shape") // ints + .IsType>() + .End(); + + AddOpCompat(OpCompat("matmul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("alpha") + .IsNumGT(0.99999f) + .IsNumLT(1.00001f) + .End() + .AddAttr("transpose_X") + .IsBoolEQ(false) + .End() + .AddAttr("transpose_Y") + .IsBoolEQ(false) + .End(); + + AddOpCompat(OpCompat("mul")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Y") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("x_num_col_dims") + .IsNumEQ(1) + .End() + .AddAttr("y_num_col_dims") + .IsNumEQ(1) + .End(); +} + +void TrtReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + std::string name_scope = "trt_reshape2_matmul_fuse_pass"; + FusePassBase::Init(name_scope, graph); + + GraphPatternDetector gpd; + patterns::Reshape2Matmul fuse_pattern(gpd.mutable_pattern(), name_scope); + fuse_pattern(); + + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "trt fuse reshape2+matmul to mul"; + GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, fuse_pattern); + bool flag = true; + + size_t reshape2_in_nums = reshape2_op->inputs.size(); + auto reshape2_in_x_shape = reshape2_in_x->Var()->GetShape(); + size_t reshape2_in_x_rank = reshape2_in_x_shape.size(); + std::vector reshape2_op_shape = + BOOST_GET_CONST(std::vector, reshape2_op->Op()->GetAttr("shape")); + flag = flag && reshape2_in_nums == 1 && 
reshape2_in_x_rank == 4 && + reshape2_in_x_shape[2] == 1 && reshape2_in_x_shape[3] == 1 && + reshape2_op_shape.size() == 2 && (matmul_in_x->outputs).size() == 1; + + bool transpose_X = + BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X")); + bool transpose_Y = + BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y")); + float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha")); + size_t matmul_in_x_rank = (matmul_in_x->Var()->GetShape()).size(); + size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size(); + flag = flag && !transpose_X && !transpose_Y && + std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 && + matmul_in_y_rank == 2; + + std::vector& next_ops = matmul_out->outputs; + flag = flag && next_ops.size() == 1 && + next_ops[0]->Name() == "elementwise_add"; + + if (flag) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "TrtReshape2MatmulFusePass in op compat failed."; + return; + } + OpDesc desc(matmul_op->Op()->Block()); + desc.SetType("mul"); + desc.SetInput("X", {reshape2_in_x->Name()}); + desc.SetInput("Y", {matmul_in_y->Name()}); + desc.SetOutput("Out", {matmul_out->Name()}); + desc.SetAttr("x_num_col_dims", 1); + desc.SetAttr("y_num_col_dims", 1); + if (matmul_op->Op()->HasAttr("enable_int8")) { + desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); + desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); + desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("out_threshold", + matmul_op->Op()->GetAttr("out_threshold")); + } + if (!IsCompat(desc)) { + LOG(WARNING) + << "TrtReshape2MatmulFusePass in out mul op compat failed."; + return; + } + auto mul_node = g->CreateOpNode(&desc); + IR_NODE_LINK_TO(reshape2_in_x, mul_node); + IR_NODE_LINK_TO(matmul_in_y, mul_node); + IR_NODE_LINK_TO(mul_node, matmul_out); + GraphSafeRemoveNodes(graph, {reshape2_op, matmul_in_x, matmul_op}); + ++found_count; + } + }; + + gpd(graph, handler); + AddStatis(found_count); +} + +void TrtFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + std::string name_scope = "trt_flatten2_matmul_fuse_pass"; + FusePassBase::Init(name_scope, graph); + + GraphPatternDetector gpd; + patterns::Flatten2Matmul fuse_pattern(gpd.mutable_pattern(), name_scope); + fuse_pattern(); + + int found_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + VLOG(4) << "trt fuse flatten2+matmul to mul"; + GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, fuse_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, fuse_pattern); + bool pattern_found = true; + + size_t flatten2_in_nums = flatten2_op->inputs.size(); + auto flatten2_in_x_shape = flatten2_in_x->Var()->GetShape(); + size_t flatten2_in_x_rank = flatten2_in_x_shape.size(); + int flatten2_axis = + BOOST_GET_CONST(int, flatten2_op->Op()->GetAttr("axis")); + // only convert matmul to mul when the flatten2 has a single input + // and the rank of input is 4 and the size of the output of matmul + // is 1. 
+ pattern_found = pattern_found && flatten2_in_nums == 1 && + flatten2_in_x_rank == 4 && + (matmul_in_x->outputs).size() == 1; + + bool transpose_X = + BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X")); + bool transpose_Y = + BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y")); + float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha")); + size_t matmul_in_x_rank = (matmul_in_x->Var()->GetShape()).size(); + size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size(); + pattern_found = pattern_found && !transpose_X && !transpose_Y && + std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 && + matmul_in_y_rank == 2; + + std::vector& next_ops = matmul_out->outputs; + // we further require the matmul op is followed by one elementwise + // add op. + pattern_found = pattern_found && next_ops.size() == 1 && + next_ops[0]->Name() == "elementwise_add"; + + if (pattern_found) { + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "TrtFlatten2MatmulFusePass in op compat failed."; + return; + } + OpDesc desc(matmul_op->Op()->Block()); + desc.SetType("mul"); + desc.SetInput("X", {flatten2_in_x->Name()}); + desc.SetInput("Y", {matmul_in_y->Name()}); + desc.SetOutput("Out", {matmul_out->Name()}); + desc.SetAttr("x_num_col_dims", flatten2_axis); + desc.SetAttr("y_num_col_dims", 1); + if (matmul_op->Op()->HasAttr("enable_int8")) { + desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8")); + desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale")); + desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale")); + desc.SetAttr("out_threshold", + matmul_op->Op()->GetAttr("out_threshold")); + } + auto mul_node = g->CreateOpNode(&desc); + IR_NODE_LINK_TO(flatten2_in_x, mul_node); + IR_NODE_LINK_TO(matmul_in_y, mul_node); + IR_NODE_LINK_TO(mul_node, matmul_out); + GraphSafeRemoveNodes(graph, {flatten2_op, matmul_in_x, matmul_op}); + ++found_count; + + if (!IsCompat(desc)) { + LOG(WARNING) + << "TrtFlatten2MatmulFusePass in out mul op compat failed."; + return; + } + } + }; + + gpd(graph, handler); + AddStatis(found_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(trt_map_matmul_to_mul_pass, + paddle::framework::ir::TrtMapMatmul2MulPass); +REGISTER_PASS_CAPABILITY(trt_map_matmul_to_mul_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("matmul", 1) + .EQ("mul", 0)); + +REGISTER_PASS(trt_map_matmul_v2_to_mul_pass, + paddle::framework::ir::TrtMapMatmulV2ToMulPass); +REGISTER_PASS_CAPABILITY(trt_map_matmul_v2_to_mul_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("matmul_v2", 0) + .EQ("mul", 0)); + +REGISTER_PASS(trt_map_matmul_v2_to_matmul_pass, + paddle::framework::ir::TrtMapMatmulV2ToMatmulPass); +REGISTER_PASS_CAPABILITY(trt_map_matmul_v2_to_matmul_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("matmul_v2", 0) + .LE("matmul", 1)); + +REGISTER_PASS(trt_squeeze2_matmul_fuse_pass, + paddle::framework::ir::TrtSqueeze2MatmulFusePass); +REGISTER_PASS_CAPABILITY(trt_squeeze2_matmul_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("matmul", 1) + .EQ("squeeze2", 0) + .EQ("mul", 0)); + +REGISTER_PASS(trt_reshape2_matmul_fuse_pass, + paddle::framework::ir::TrtReshape2MatmulFusePass); +REGISTER_PASS_CAPABILITY(trt_reshape2_matmul_fuse_pass) + .AddCombination( + 
paddle::framework::compatible::OpVersionComparatorCombination() + .LE("matmul", 1) + .EQ("reshape2", 0) + .EQ("mul", 0)); + +REGISTER_PASS(trt_flatten2_matmul_fuse_pass, + paddle::framework::ir::TrtFlatten2MatmulFusePass); +REGISTER_PASS_CAPABILITY(trt_flatten2_matmul_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .LE("matmul", 1) + .EQ("flatten2", 0) + .EQ("mul", 0)); diff --git a/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.h b/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.h new file mode 100644 index 0000000000000..c382837f10a6c --- /dev/null +++ b/paddle/fluid/framework/ir/trt_map_matmul_to_mul_pass.h @@ -0,0 +1,130 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +class Graph; + +class TrtMapMatmul2MulPass : public FusePassBase { + public: + TrtMapMatmul2MulPass(); + virtual ~TrtMapMatmul2MulPass() {} + + protected: + void ApplyImpl(Graph* graph) const override; +}; + +/* + * Map matmul_v2 to mul, the same as TrtMapMatmul2MulPass. + */ +class TrtMapMatmulV2ToMulPass : public FusePassBase { + public: + TrtMapMatmulV2ToMulPass(); + virtual ~TrtMapMatmulV2ToMulPass() {} + + protected: + void ApplyImpl(Graph* graph) const override; +}; + +/* + * Map matmul_v2 to matmul, not supoort broadcast. + */ +class TrtMapMatmulV2ToMatmulPass : public FusePassBase { + public: + TrtMapMatmulV2ToMatmulPass(); + virtual ~TrtMapMatmulV2ToMatmulPass() {} + + protected: + void ApplyImpl(Graph* graph) const override; +}; + +/* + * Fuse squeeze2+matmul to mul, so the optimization can use fc_fuse_pass. + * The squeeze2 op must satisfy the following conditions: + * 1. the rank of input X is 4 + * 2. the axis attr is [2, 3] + * 3. the next op is only matmul + * + * The matmul op must satisfy the following conditions: + * 1. the transpose_X and transpose_Y attrs are false + * 2. the alpha attr is 1.0 + * 3. the rank of input X and Y is 2 + * 4. the next op of matmul is only elementwise_add + * + * Notice: + * the rank of input activation is obtained from var_desc, + * it maybe change in runtime. Therefore, the pass considers + * the above passes to reduce the impact on other models. + */ + +class TrtSqueeze2MatmulFusePass : public FusePassBase { + public: + TrtSqueeze2MatmulFusePass(); + virtual ~TrtSqueeze2MatmulFusePass() {} + + protected: + void ApplyImpl(Graph* graph) const override; +}; + +/* + * Fuse reshape2+matmul to mul, so the optimization can use fc_fuse_pass. + * The reshape2 op must satisfy the following conditions: + * 1. reshape2 has one input node, which means it don't + * have Shape or ShapeTensor input + * 2. 
the rank of input X is 4 and the last two dims of input X is 1 + * 3. the rank of shape attr is 2 + * 4. the next op is only matmul + * + * The matmul op must satisfy the following conditions: + * 1. the transpose_X and transpose_Y attrs are false + * 2. the alpha attr is 1.0 + * 3. the rank of input X and Y is 2 + * 4. the next op of matmul is only elementwise_add + * + * Notice: + * the shape and rank of input activation is obtained from var_desc, + * they maybe change in runtime. Therefore, the pass considers + * the above passes to reduce the impact on other models. + */ + +class TrtReshape2MatmulFusePass : public FusePassBase { + public: + TrtReshape2MatmulFusePass(); + virtual ~TrtReshape2MatmulFusePass() {} + + protected: + void ApplyImpl(Graph* graph) const override; +}; + +class TrtFlatten2MatmulFusePass : public FusePassBase { + public: + TrtFlatten2MatmulFusePass(); + virtual ~TrtFlatten2MatmulFusePass() {} + + protected: + void ApplyImpl(Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 57f90c7cc4a88..66b27b2903a70 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -90,12 +90,12 @@ const std::vector kTRTSubgraphPasses({ "skip_layernorm_fuse_pass", // "conv_bn_fuse_pass", // "unsqueeze2_eltwise_fuse_pass", // - "squeeze2_matmul_fuse_pass", // - "reshape2_matmul_fuse_pass", // - "flatten2_matmul_fuse_pass", // - "map_matmul_v2_to_mul_pass", // - "map_matmul_v2_to_matmul_pass", // - "map_matmul_to_mul_pass", // + "trt_squeeze2_matmul_fuse_pass", // + "trt_reshape2_matmul_fuse_pass", // + "trt_flatten2_matmul_fuse_pass", // + "trt_map_matmul_v2_to_mul_pass", // + "trt_map_matmul_v2_to_matmul_pass", // + "trt_map_matmul_to_mul_pass", // "fc_fuse_pass", // "conv_elementwise_add_fuse_pass", // "add_support_int8_pass", @@ -140,12 +140,12 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { "conv_eltwiseadd_bn_fuse_pass", // "embedding_eltwise_layernorm_fuse_pass", // "multihead_matmul_fuse_pass_v2", // - "squeeze2_matmul_fuse_pass", // - "reshape2_matmul_fuse_pass", // - "flatten2_matmul_fuse_pass", // - "map_matmul_v2_to_mul_pass", // - "map_matmul_v2_to_matmul_pass", // - "map_matmul_to_mul_pass", // + "gpu_cpu_squeeze2_matmul_fuse_pass", // + "gpu_cpu_reshape2_matmul_fuse_pass", // + "gpu_cpu_flatten2_matmul_fuse_pass", // + "gpu_cpu_map_matmul_v2_to_mul_pass", // + "gpu_cpu_map_matmul_v2_to_matmul_pass", // + "gpu_cpu_map_matmul_to_mul_pass", // "fc_fuse_pass", // "fc_elementwise_layernorm_fuse_pass", // #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be @@ -202,14 +202,14 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { "fc_gru_fuse_pass", // "mul_gru_fuse_pass", // "seq_concat_fc_fuse_pass", // - "squeeze2_matmul_fuse_pass", // - "reshape2_matmul_fuse_pass", // - "flatten2_matmul_fuse_pass", // + "gpu_cpu_squeeze2_matmul_fuse_pass", // + "gpu_cpu_reshape2_matmul_fuse_pass", // + "gpu_cpu_flatten2_matmul_fuse_pass", // "matmul_v2_scale_fuse_pass", // - "map_matmul_v2_to_mul_pass", // - "map_matmul_v2_to_matmul_pass", // + "gpu_cpu_map_matmul_v2_to_mul_pass", // + "gpu_cpu_map_matmul_v2_to_matmul_pass", // "matmul_scale_fuse_pass", // - "map_matmul_to_mul_pass", // + "gpu_cpu_map_matmul_to_mul_pass", // "fc_fuse_pass", // "repeated_fc_relu_fuse_pass", // "squared_mat_sub_fuse_pass", // diff --git 
a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index b4455259b7669..bdea14c9e9f89 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -67,13 +67,7 @@ class FcOpConverter : public OpConverter { nvinfer1::Dims x_dim, int x_num_col_dims) { // add shuffle after fc nvinfer1::Dims reshape_after_fc_dim; - if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 && - x_dim.d[2] == 1 && x_dim.d[3] == 1 && x_num_col_dims == 1) { - // If use tensorrt'oss, the x_dim and x_num_col_dims need change - reshape_after_fc_dim.nbDims = 4; - } else { - reshape_after_fc_dim.nbDims = x_num_col_dims + 1; - } + reshape_after_fc_dim.nbDims = x_num_col_dims + 1; for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) { reshape_after_fc_dim.d[i] = 0; } @@ -141,7 +135,6 @@ class FcOpConverter : public OpConverter { "The fc's weight should be a matrix with 2 dims, but " "it's %d-dimensional.", Y_t->dims().size())); // a matrix - size_t n_output = Y_t->dims()[1]; int m = Y_t->dims()[0]; int n = Y_t->dims()[1]; auto tranpose_weight = [](const float* src, float* dst, int m, int n) { @@ -175,9 +168,10 @@ class FcOpConverter : public OpConverter { fc_layer_int8->getOutput(0), x_dim, x_num_col_dims); if (activation_type == "relu") { fc_after_reshape_int8->setName( - ("fc_op_int8_reshape_after_fc: Shuffle (Output: " + output_name + - ")") + ("int8_reshape_after_fc: Shuffle (Output: " + output_name + ")") .c_str()); + engine_->SetTensorDynamicRange(fc_after_reshape_int8->getOutput(0), + out_scale); nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER( engine_, Activation, *(fc_after_reshape_int8->getOutput(0)), nvinfer1::ActivationType::kRELU); @@ -200,8 +194,7 @@ class FcOpConverter : public OpConverter { fc_layer_float->getOutput(0), x_dim, x_num_col_dims); if (activation_type == "relu") { fc_after_reshape_float->setName( - ("fc_op_float_reshape_after_fc: Shuffle (Output: " + output_name + - ")") + ("float_reshape_after_fc: Shuffle (Output: " + output_name + ")") .c_str()); nvinfer1::IActivationLayer* relu_layer_float = TRT_ENGINE_ADD_LAYER( engine_, Activation, *(fc_after_reshape_float->getOutput(0)), @@ -215,14 +208,28 @@ class FcOpConverter : public OpConverter { } }; - std::vector weight_data_tmp; - weight_data_tmp.reserve(Y_t->numel()); - memcpy(weight_data_tmp.data(), weight_data, Y_t->numel() * sizeof(float)); - tranpose_weight(weight_data_tmp.data(), weight_data, m, n); + bool transpose_y = false; + if (op_desc.HasAttr("transpose_Y")) { + transpose_y = BOOST_GET_CONST(bool, op_desc.GetAttr("transpose_Y")); + } + int weight_w, weight_h; + if (!transpose_y) { + std::vector weight_data_tmp; + weight_data_tmp.reserve(Y_t->numel()); + memcpy(weight_data_tmp.data(), weight_data, Y_t->numel() * sizeof(float)); + tranpose_weight(weight_data_tmp.data(), weight_data, m, n); + weight_w = n; + weight_h = m; + } else { + weight_w = m; + weight_h = n; + } + size_t n_output = weight_w; TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT, static_cast(weight_data), static_cast(Y_t->numel())}; - weight.dims.assign({n, m}); + weight.dims.assign({weight_w, weight_h}); + float* bias_data = nullptr; int bias_num = 0; if (with_bias) { @@ -240,25 +247,72 @@ class FcOpConverter : public OpConverter { if (!engine_->with_dynamic_shape()) { x_num_col_dims--; } - // If use tensorrt'oss, the x_dim and x_num_col_dims need change + // If use tensorrt'oss, the x_dim and x_num_col_dims need change, 
and can + // not add Shuffle layer in ernie's multihead. if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 && - x_dim.d[2] == 1 && x_dim.d[3] == 1 && x_num_col_dims == 2) { - x_num_col_dims = 1; - } - PADDLE_ENFORCE_GT( - x_dim.nbDims, x_num_col_dims, - platform::errors::InvalidArgument( - "Params and input dims mismatch. Paddle-TRT FC " - "converter expects x_dim.nbDims > x_num_col_dims, but " - "x_dim.nbDims : %d, x_num_col_dims : %d.", - x_dim.nbDims, x_num_col_dims)); - auto* reshape_before_fc_layer = - reshape_before_fc(X, x_dim, x_num_col_dims, output_name); - auto* reshape_itensor = reshape_before_fc_layer->getOutput(0); - if (enable_int8) { - engine_->SetTensorDynamicRange(reshape_itensor, in_scale); + x_dim.d[3] == 1 && x_num_col_dims == 2) { + if (enable_int8) { + // add conv1x1 layer + nvinfer1::DimsHW nv_ksize(1, 1); + auto* fc_layer_int8 = + TRT_ENGINE_ADD_LAYER(engine_, Convolution, *X, n_output, nv_ksize, + weight.get(), bias.get()); + if (activation_type == "relu") { + fc_layer_int8->setName( + ("ernie_fc_op_int8: Convolution (Output: " + output_name + ")") + .c_str()); + PADDLE_ENFORCE_EQ( + op_desc.HasAttr("out_threshold"), true, + platform::errors::InvalidArgument( + "must have out threshold in fc layers in int8 mode")); + float out_scale = + BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); + engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0), + out_scale); + nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER( + engine_, Activation, *(fc_layer_int8->getOutput(0)), + nvinfer1::ActivationType::kRELU); + RreplenishLayerAndOutput(relu_layer_int8, "relu_after_ernie_fc_int8", + {output_name}, test_mode); + } else { + RreplenishLayerAndOutput(fc_layer_int8, + "ernie_fc_op_int8: Convolution", + {output_name}, test_mode); + } + } else { + // add fc layer + auto* fc_layer_float = TRT_ENGINE_ADD_LAYER( + engine_, FullyConnected, *X, n_output, weight.get(), bias.get()); + if (activation_type == "relu") { + fc_layer_float->setName( + ("ernie_fc_op_float: (Output: " + output_name + ")").c_str()); + nvinfer1::IActivationLayer* relu_layer_float = TRT_ENGINE_ADD_LAYER( + engine_, Activation, *(fc_layer_float->getOutput(0)), + nvinfer1::ActivationType::kRELU); + RreplenishLayerAndOutput(relu_layer_float, + "relu_after_ernie_fc_float", {output_name}, + test_mode); + } else { + RreplenishLayerAndOutput(fc_layer_float, "ernie_fc_op_float", + {output_name}, test_mode); + } + } + } else { // need reshape input before and after fc + PADDLE_ENFORCE_GT( + x_dim.nbDims, x_num_col_dims, + platform::errors::InvalidArgument( + "Params and input dims mismatch. 
diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
index 92a335a73dc85..f637b4cbd6c3e 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
@@ -410,16 +410,16 @@ def _optimize_fp32_graph(self, graph):
         graph = self._apply_pass(graph, 'multi_gru_fuse_pass')
         graph = self._apply_pass(graph, 'multi_gru_seq_fuse_pass')
         graph = self._apply_pass(graph, 'seq_concat_fc_fuse_pass')
-        graph = self._apply_pass(graph, 'squeeze2_matmul_fuse_pass')
-        graph = self._apply_pass(graph, 'reshape2_matmul_fuse_pass')
-        graph = self._apply_pass(graph, 'flatten2_matmul_fuse_pass')
+        graph = self._apply_pass(graph, 'gpu_cpu_squeeze2_matmul_fuse_pass')
+        graph = self._apply_pass(graph, 'gpu_cpu_reshape2_matmul_fuse_pass')
+        graph = self._apply_pass(graph, 'gpu_cpu_flatten2_matmul_fuse_pass')
         graph = self._apply_pass(graph, 'matmul_v2_scale_fuse_pass')
         graph = self._apply_pass(graph, 'squared_mat_sub_fuse_pass')
         graph = self._apply_pass(graph, 'is_test_pass')
-        graph = self._apply_pass(graph, 'map_matmul_v2_to_mul_pass')
-        graph = self._apply_pass(graph, 'map_matmul_v2_to_matmul_pass')
+        graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_v2_to_mul_pass')
+        graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_v2_to_matmul_pass')
         graph = self._apply_pass(graph, 'matmul_scale_fuse_pass')
-        graph = self._apply_pass(graph, 'map_matmul_to_mul_pass')
+        graph = self._apply_pass(graph, 'gpu_cpu_map_matmul_to_mul_pass')
         graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
         graph = self._apply_pass(graph, 'mkldnn_placement_pass',
                                  ['mkldnn_enabled_op_types'], [set()])
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_flatten2_matmul_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_flatten2_matmul_fuse_pass.py
index 6cd9ae970bb58..ec3bc0287323d 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_flatten2_matmul_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_flatten2_matmul_fuse_pass.py
@@ -174,7 +174,7 @@ def test(self):
             quant=False,
             max_examples=50,
             max_duration=1000,
-            passes=["flatten2_matmul_fuse_pass"])
+            passes=["gpu_cpu_flatten2_matmul_fuse_pass"])
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_map_matmul_to_mul_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_map_matmul_to_mul_pass.py
index 810603a4e4732..ce695ec2f01bf 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_map_matmul_to_mul_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_map_matmul_to_mul_pass.py
@@ -116,7 +116,7 @@ def test(self):
         self.run_and_statis(
             quant=False,
             max_examples=100,
-            passes=["map_matmul_to_mul_pass"],
+            passes=["gpu_cpu_map_matmul_to_mul_pass"],
             max_duration=180)
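Note: the quant2 MKL-DNN pipeline and the tests below only switch to the renamed gpu_cpu_* passes; the pass logic itself is unchanged. If external code maintains its own pass lists, a hypothetical migration helper (not a Paddle API; the names come straight from this patch) could apply the same renames:

```python
# Hypothetical helper: update a user-maintained pass list to the renamed passes.
RENAMED_PASSES = {
    "squeeze2_matmul_fuse_pass": "gpu_cpu_squeeze2_matmul_fuse_pass",
    "reshape2_matmul_fuse_pass": "gpu_cpu_reshape2_matmul_fuse_pass",
    "flatten2_matmul_fuse_pass": "gpu_cpu_flatten2_matmul_fuse_pass",
    "map_matmul_v2_to_mul_pass": "gpu_cpu_map_matmul_v2_to_mul_pass",
    "map_matmul_v2_to_matmul_pass": "gpu_cpu_map_matmul_v2_to_matmul_pass",
    "map_matmul_to_mul_pass": "gpu_cpu_map_matmul_to_mul_pass",
}

def migrate_pass_list(passes):
    """Return the pass list with old matmul-mapping pass names replaced."""
    return [RENAMED_PASSES.get(name, name) for name in passes]

print(migrate_pass_list(["is_test_pass", "map_matmul_to_mul_pass"]))
# ['is_test_pass', 'gpu_cpu_map_matmul_to_mul_pass']
```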
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_map_matmul_v2_to_matmul_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_map_matmul_v2_to_matmul_pass.py
index 915644f46e486..fac8b710c8ca4 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_map_matmul_v2_to_matmul_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_map_matmul_v2_to_matmul_pass.py
@@ -127,7 +127,7 @@ def test(self):
         self.run_and_statis(
             quant=False,
             max_examples=100,
-            passes=["map_matmul_v2_to_matmul_pass"])
+            passes=["gpu_cpu_map_matmul_v2_to_matmul_pass"])
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_map_matmul_v2_to_mul_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_map_matmul_v2_to_mul_pass.py
index cc2c1ab81bb2a..e8a37ebc7ea09 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_map_matmul_v2_to_mul_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_map_matmul_v2_to_mul_pass.py
@@ -110,8 +110,9 @@ def sample_program_config(self, draw):
 
     def test(self):
         self.run_and_statis(
-            quant=False, max_examples=100,
-            passes=["map_matmul_v2_to_mul_pass"])
+            quant=False,
+            max_examples=100,
+            passes=["gpu_cpu_map_matmul_v2_to_mul_pass"])
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_transpose_reshape_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_transpose_reshape_fuse_pass.py
index ffdc84b8bd9ff..3c6560b3b2911 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_transpose_reshape_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_v2_transpose_reshape_fuse_pass.py
@@ -132,7 +132,7 @@ def generate_input(type):
         return program_config
 
     def sample_predictor_configs(self, program_config):
-        # map_matmul_v2_to_matmul_pass will affect the type of final fused op
+        # gpu_cpu_map_matmul_v2_to_matmul_pass will affect the type of final fused op
         fused_op = "matmul_v2"
         input1_dim1 = program_config.inputs["input_data1"].shape[0]
         input2_dim1 = program_config.inputs["input_data2"].shape[0]
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_reshape2_matmul_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_reshape2_matmul_fuse_pass.py
index 951ec8e4e8ef4..6f311ab11fefd 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_reshape2_matmul_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_reshape2_matmul_fuse_pass.py
@@ -172,7 +172,7 @@ def test(self):
             quant=False,
             max_examples=50,
             max_duration=1000,
-            passes=["reshape2_matmul_fuse_pass"])
+            passes=["gpu_cpu_reshape2_matmul_fuse_pass"])
 
 
 if __name__ == "__main__":
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_squeeze2_matmul_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_squeeze2_matmul_fuse_pass.py
index 605dc4edbe8c6..9600ef7e0d109 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_squeeze2_matmul_fuse_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_squeeze2_matmul_fuse_pass.py
@@ -180,7 +180,7 @@ def test(self):
             quant=False,
             max_examples=50,
             max_duration=1000,
-            passes=["squeeze2_matmul_fuse_pass"])
+            passes=["gpu_cpu_squeeze2_matmul_fuse_pass"])
 
 
 if __name__ == "__main__":
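Note: all of the updated tests exercise the renamed passes through run_and_statis, so only the pass name changes. As a reminder of the invariant gpu_cpu_map_matmul_to_mul_pass depends on, a 2-D matmul with default attributes (alpha = 1, no transposes) computes exactly what mul computes, as this small numpy check illustrates (illustration only, not the pass implementation):

```python
import numpy as np

x = np.random.rand(3, 5).astype(np.float32)
y = np.random.rand(5, 7).astype(np.float32)

# matmul with transpose_X=False, transpose_Y=False and alpha=1.0 ...
matmul_out = np.matmul(x, y)
# ... is the same computation that mul performs on 2-D inputs (x_num_col_dims=1).
mul_out = x @ y

assert np.allclose(matmul_out, mul_out)
```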