
Commit

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into dev/bf16_op_2
zhangbo9674 committed Feb 9, 2022
2 parents a33373e + 87f4a68 commit 9e9bd1a
Showing 22 changed files with 1,482 additions and 226 deletions.
7 changes: 6 additions & 1 deletion paddle/fluid/framework/ir/CMakeLists.txt
@@ -62,7 +62,6 @@ pass_library(graph_to_program_pass base)
pass_library(graph_viz_pass base)
pass_library(lock_free_optimize_pass base DEPS string_helper)
pass_library(fc_fuse_pass inference)
-pass_library(map_matmul_to_mul_pass inference)
pass_library(attention_lstm_fuse_pass inference)
pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference)
@@ -98,8 +97,14 @@ pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(layer_norm_fuse_pass inference)
pass_library(add_support_int8_pass inference)
pass_library(matmul_scale_fuse_pass inference)
+pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)

+if(WITH_TENSORRT)
+  pass_library(trt_map_matmul_to_mul_pass inference)
+endif()

if(WITH_GPU OR WITH_ROCM)
pass_library(cudnn_placement_pass base DEPS placement_pass_base)
pass_library(embedding_eltwise_layernorm_fuse_pass inference)
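Note for downstream users: this commit renames the registered passes (for example, map_matmul_to_mul_pass becomes gpu_cpu_map_matmul_to_mul_pass, and a TensorRT-only variant is now built behind WITH_TENSORRT), so any code that toggles these passes by name must switch to the new names. A minimal sketch of what that looks like, assuming the paddle_infer C++ API of contemporary Paddle releases; the model path is a placeholder:

#include "paddle_inference_api.h"

int main() {
  // Placeholder model directory; substitute a real exported model.
  paddle_infer::Config config("./model_dir");
  // The old name "map_matmul_to_mul_pass" no longer exists after this commit;
  // refer to the pass by its new gpu_cpu_-prefixed name instead.
  config.pass_builder()->DeletePass("gpu_cpu_map_matmul_to_mul_pass");
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}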
paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.cc (renamed from paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc)
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

-#include "paddle/fluid/framework/ir/map_matmul_to_mul_pass.h"
+#include "paddle/fluid/framework/ir/gpu_cpu_map_matmul_to_mul_pass.h"

#include <cmath>
#include <string>
@@ -28,7 +28,7 @@ namespace ir {

class Node;

-MapMatmul2MulPass::MapMatmul2MulPass() {
+GpuCpuMapMatmul2MulPass::GpuCpuMapMatmul2MulPass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
@@ -68,7 +68,7 @@ MapMatmul2MulPass::MapMatmul2MulPass() {
.End();
}
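The constructor bodies are unchanged by the rename: each registers OpCompat constraints describing which inputs, outputs, and attribute values the pass may rewrite. A condensed sketch of the builder chain elided above; the hunk is truncated, so treat the exact attribute set as an assumption rather than a quote of this file:

GpuCpuMapMatmul2MulPass::GpuCpuMapMatmul2MulPass() {
  AddOpCompat(OpCompat("matmul"))
      .AddInput("X")  // left operand must be a tensor
      .IsTensor()
      .End()
      .AddInput("Y")  // right operand must be a tensor
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      // Assumed constraints: only plain, untransposed matmuls are mapped.
      .AddAttr("transpose_X")
      .IsBoolEQ(false)
      .End()
      .AddAttr("transpose_Y")
      .IsBoolEQ(false)
      .End();
}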

-MapMatmulV2ToMulPass::MapMatmulV2ToMulPass() {
+GpuCpuMapMatmulV2ToMulPass::GpuCpuMapMatmulV2ToMulPass() {
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
@@ -104,7 +104,7 @@ MapMatmulV2ToMulPass::MapMatmulV2ToMulPass() {
.End();
}

-MapMatmulV2ToMatmulPass::MapMatmulV2ToMatmulPass() {
+GpuCpuMapMatmulV2ToMatmulPass::GpuCpuMapMatmulV2ToMatmulPass() {
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
@@ -143,7 +143,7 @@ MapMatmulV2ToMatmulPass::MapMatmulV2ToMatmulPass() {
.End();
}

-Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
+GpuCpuFlatten2MatmulFusePass::GpuCpuFlatten2MatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
@@ -197,7 +197,7 @@ Flatten2MatmulFusePass::Flatten2MatmulFusePass() {
.End();
}

-Squeeze2MatmulFusePass::Squeeze2MatmulFusePass() {
+GpuCpuSqueeze2MatmulFusePass::GpuCpuSqueeze2MatmulFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
@@ -251,10 +251,10 @@ Squeeze2MatmulFusePass::Squeeze2MatmulFusePass() {
.End();
}

-void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-std::string name_scope = "map_matmul_to_mul_pass";
+std::string name_scope = "gpu_cpu_map_matmul_to_mul_pass";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
@@ -264,7 +264,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
-VLOG(4) << "map matmul to mul";
+VLOG(4) << "gpu_cpu map matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_pattern);
@@ -286,7 +286,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {

if (flag) {
if (!IsCompat(subgraph, g)) {
-LOG(WARNING) << "MapMatmul2MulPass in op compat failed.";
+LOG(WARNING) << "GpuCpuMapMatmul2MulPass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
@@ -311,7 +311,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
++found_count;

if (!IsCompat(desc)) {
-LOG(WARNING) << "MapMatmul2MulPass in out mul op compat failed.";
+LOG(WARNING) << "GpuCpuMapMatmul2MulPass in out mul op compat failed.";
return;
}
}
@@ -321,10 +321,10 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}
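Every ApplyImpl renamed in this file follows the same skeleton: validate the graph, register a pattern with GraphPatternDetector, and run a handler lambda on each match. A stripped-down sketch of that flow, reusing only the calls visible in this diff (pattern construction elided):

void GpuCpuMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init("gpu_cpu_map_matmul_to_mul_pass", graph);

  GraphPatternDetector gpd;
  // ... build the matmul pattern on the detector (elided) ...
  int found_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    // Check compatibility, build the replacement mul op, relink the
    // in/out variables, remove the matched nodes, and count the rewrite.
    ++found_count;
  };
  gpd(graph, handler);  // invoke the handler on every match in the graph
  AddStatis(found_count);
}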

-void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-std::string name_scope = "map_matmul_v2_to_mul_pass";
+std::string name_scope = "gpu_cpu_map_matmul_v2_to_mul_pass";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
@@ -335,7 +335,7 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
-VLOG(3) << "map matmul_v2 to mul";
+VLOG(3) << "gpu_cpu map matmul_v2 to mul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
matmul_v2_weight_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
@@ -360,7 +360,7 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {

if (flag) {
if (!IsCompat(subgraph, g)) {
-LOG(WARNING) << "MapMatmulV2ToMulPass in op compat failed.";
+LOG(WARNING) << "GpuCpuMapMatmulV2ToMulPass in op compat failed.";
return;
}
OpDesc desc(matmul_v2_op->Op()->Block());
@@ -386,7 +386,8 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
++found_count;

if (!IsCompat(desc)) {
-LOG(WARNING) << "MapMatmulV2ToMulPass in out mul op compat failed.";
+LOG(WARNING)
+    << "GpuCpuMapMatmulV2ToMulPass in out mul op compat failed.";
return;
}
}
@@ -396,10 +397,10 @@ void MapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}

-void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-std::string name_scope = "map_matmul_v2_to_matmul_pass";
+std::string name_scope = "gpu_cpu_map_matmul_v2_to_matmul_pass";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
@@ -409,15 +410,15 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
-VLOG(4) << "map matmul_v2 to matmul";
+VLOG(4) << "gpu_cpu map matmul_v2 to matmul";
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op, matmul_v2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_out, matmul_v2_out, matmul_v2_pattern);
if (!IsCompat(subgraph, g)) {
-LOG(WARNING) << "MapMatmulV2ToMatmulPass in op compat failed.";
+LOG(WARNING) << "GpuCpuMapMatmulV2ToMatmulPass in op compat failed.";
return;
}

@@ -463,7 +464,8 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
++found_count;

if (!IsCompat(desc)) {
-LOG(WARNING) << "MapMatmulV2ToMatmulPass in out matmul op compat failed.";
+LOG(WARNING)
+    << "GpuCpuMapMatmulV2ToMatmulPass in out matmul op compat failed.";
return;
}
};
@@ -472,10 +474,10 @@ void MapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}

-void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-std::string name_scope = "squeeze2_matmul_fuse_pass";
+std::string name_scope = "gpu_cpu_squeeze2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
@@ -485,7 +487,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
-VLOG(4) << "fuse squeeze2+matmul to mul";
+VLOG(4) << "gpu_cpu fuse squeeze2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_in_x, squeeze2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(squeeze2_op, squeeze2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
@@ -518,7 +520,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {

if (flag) {
if (!IsCompat(subgraph, g)) {
-LOG(WARNING) << "Squeeze2MatmulFusePass in op compat failed.";
+LOG(WARNING) << "GpuCpuSqueeze2MatmulFusePass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
@@ -542,7 +544,8 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
GraphSafeRemoveNodes(graph, {squeeze2_op, matmul_in_x, matmul_op});
++found_count;
if (!IsCompat(desc)) {
-LOG(WARNING) << "Squeeze2MatmulFusePass in out mul op compat failed.";
+LOG(WARNING)
+    << "GpuCpuSqueeze2MatmulFusePass in out mul op compat failed.";
return;
}
}
@@ -552,7 +555,7 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}

-Reshape2MatmulFusePass::Reshape2MatmulFusePass() {
+GpuCpuReshape2MatmulFusePass::GpuCpuReshape2MatmulFusePass() {
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
@@ -614,10 +617,10 @@ Reshape2MatmulFusePass::Reshape2MatmulFusePass() {
.End();
}

-void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-std::string name_scope = "reshape2_matmul_fuse_pass";
+std::string name_scope = "gpu_cpu_reshape2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
@@ -627,7 +630,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
-VLOG(4) << "fuse reshape2+matmul to mul";
+VLOG(4) << "gpu_cpu fuse reshape2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(reshape2_in_x, reshape2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
@@ -662,7 +665,7 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {

if (flag) {
if (!IsCompat(subgraph, g)) {
-LOG(WARNING) << "Reshape2MatmulFusePass in op compat failed.";
+LOG(WARNING) << "GpuCpuReshape2MatmulFusePass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
@@ -680,7 +683,8 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
matmul_op->Op()->GetAttr("out_threshold"));
}
if (!IsCompat(desc)) {
-LOG(WARNING) << "Reshape2MatmulFusePass in out mul op compat failed.";
+LOG(WARNING)
+    << "GpuCpuReshape2MatmulFusePass in out mul op compat failed.";
return;
}
auto mul_node = g->CreateOpNode(&desc);
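The elided middle of each handler performs the actual rewrite before the CreateOpNode call above: it fills an OpDesc for the replacement mul and splices it into the graph. A sketch of that step, under the assumption that mul exposes its standard X/Y/Out interface and x_num_col_dims/y_num_col_dims attributes; the attribute values shown are illustrative:

OpDesc desc(matmul_op->Op()->Block());
desc.SetType("mul");
desc.SetInput("X", {matmul_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("x_num_col_dims", 1);  // assumed: flatten x to 2-D at dim 1
desc.SetAttr("y_num_col_dims", 1);
auto mul_node = g->CreateOpNode(&desc);  // the OpDesc is copied into the node
IR_NODE_LINK_TO(matmul_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {matmul_op});  // drop the replaced matmul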
@@ -696,10 +700,10 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
AddStatis(found_count);
}

-void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
+void GpuCpuFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
-std::string name_scope = "flatten2_matmul_fuse_pass";
+std::string name_scope = "gpu_cpu_flatten2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);

GraphPatternDetector gpd;
@@ -709,7 +713,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
-VLOG(4) << "fuse flatten2+matmul to mul";
+VLOG(4) << "gpu_cpu fuse flatten2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
@@ -749,7 +753,7 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {

if (pattern_found) {
if (!IsCompat(subgraph, g)) {
-LOG(WARNING) << "Flatten2MatmulFusePass in op compat failed.";
+LOG(WARNING) << "GpuCpuFlatten2MatmulFusePass in op compat failed.";
return;
}
OpDesc desc(matmul_op->Op()->Block());
Expand All @@ -774,7 +778,8 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
++found_count;

if (!IsCompat(desc)) {
-LOG(WARNING) << "Flatten2MatmulFusePass in out mul op compat failed.";
+LOG(WARNING)
+    << "GpuCpuFlatten2MatmulFusePass in out mul op compat failed.";
return;
}
}
@@ -788,50 +793,51 @@ void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
} // namespace framework
} // namespace paddle

-REGISTER_PASS(map_matmul_to_mul_pass, paddle::framework::ir::MapMatmul2MulPass);
-REGISTER_PASS_CAPABILITY(map_matmul_to_mul_pass)
+REGISTER_PASS(gpu_cpu_map_matmul_to_mul_pass,
+              paddle::framework::ir::GpuCpuMapMatmul2MulPass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_to_mul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("mul", 0));
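Each renamed pass is re-registered under its new name together with a capability guard that pins the op versions it understands, as in the block above. The general pattern, written here with a deliberately hypothetical pass name and class so it is not mistaken for part of this commit:

REGISTER_PASS(my_example_fuse_pass,  // hypothetical pass name
              paddle::framework::ir::MyExampleFusePass);
REGISTER_PASS_CAPABILITY(my_example_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .LE("matmul", 1)   // accept matmul op definitions at version <= 1
            .EQ("mul", 0));    // require mul at version 0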

-REGISTER_PASS(map_matmul_v2_to_mul_pass,
-              paddle::framework::ir::MapMatmulV2ToMulPass);
-REGISTER_PASS_CAPABILITY(map_matmul_v2_to_mul_pass)
+REGISTER_PASS(gpu_cpu_map_matmul_v2_to_mul_pass,
+              paddle::framework::ir::GpuCpuMapMatmulV2ToMulPass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_v2_to_mul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("mul", 0));

-REGISTER_PASS(map_matmul_v2_to_matmul_pass,
-              paddle::framework::ir::MapMatmulV2ToMatmulPass);
-REGISTER_PASS_CAPABILITY(map_matmul_v2_to_matmul_pass)
+REGISTER_PASS(gpu_cpu_map_matmul_v2_to_matmul_pass,
+              paddle::framework::ir::GpuCpuMapMatmulV2ToMatmulPass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_map_matmul_v2_to_matmul_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.LE("matmul", 1));

-REGISTER_PASS(squeeze2_matmul_fuse_pass,
-              paddle::framework::ir::Squeeze2MatmulFusePass);
-REGISTER_PASS_CAPABILITY(squeeze2_matmul_fuse_pass)
+REGISTER_PASS(gpu_cpu_squeeze2_matmul_fuse_pass,
+              paddle::framework::ir::GpuCpuSqueeze2MatmulFusePass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_squeeze2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("squeeze2", 0)
.EQ("mul", 0));

-REGISTER_PASS(reshape2_matmul_fuse_pass,
-              paddle::framework::ir::Reshape2MatmulFusePass);
-REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass)
+REGISTER_PASS(gpu_cpu_reshape2_matmul_fuse_pass,
+              paddle::framework::ir::GpuCpuReshape2MatmulFusePass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_reshape2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("reshape2", 0)
.EQ("mul", 0));

-REGISTER_PASS(flatten2_matmul_fuse_pass,
-              paddle::framework::ir::Flatten2MatmulFusePass);
-REGISTER_PASS_CAPABILITY(flatten2_matmul_fuse_pass)
+REGISTER_PASS(gpu_cpu_flatten2_matmul_fuse_pass,
+              paddle::framework::ir::GpuCpuFlatten2MatmulFusePass);
+REGISTER_PASS_CAPABILITY(gpu_cpu_flatten2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)

1 comment on commit 9e9bd1a

@paddle-bot-old commented:


Congratulations! Your pull request passed all required CI checks. You can now ask the reviewer(s) to approve and merge. 🎉
