diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 7affd59de162d..9e8c81c2985b7 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -293,11 +293,11 @@ function(op_library TARGET) # Define operators that don't need pybind here. foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "bitwise_op" "nccl_op" "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op") - - if ("${TARGET}" STREQUAL "${manual_pybind_op}") - set(pybind_flag 1) - endif() - endforeach() + + if ("${TARGET}" STREQUAL "${manual_pybind_op}") + set(pybind_flag 1) + endif() + endforeach() # The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h. # Note that it's enough to just adding one operator to pybind in a *_op.cc file. diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 66dfb81755f1c..948eaab40b4f6 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -139,7 +139,7 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass - fix_op_run_order_pass) + fix_op_run_order_pass fuse_gemm_epilogue_pass) if (WITH_CINN) set(IR_PASS_DEPS ${IR_PASS_DEPS} build_cinn_pass) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index c99200ec98aa8..fdf74d2f769fc 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -1,4 +1,5 @@ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -175,6 +176,11 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { !defined(_WIN32) && !defined(__APPLE__) AppendPassWithCheck(strategy_.enable_auto_fusion_, "fusion_group_pass"); #endif + +#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060) + AppendPassWithCheck(strategy_.fuse_gemm_epilogue_, + "fuse_gemm_epilogue_pass"); +#endif AppendPassWithCheck(strategy_.fuse_elewise_add_act_ops_, "fuse_elewise_add_act_pass"); // for single card training, fuse_all_reduce_ops is unnecessary. @@ -507,3 +513,6 @@ USE_PASS(mkldnn_placement_pass); !defined(_WIN32) && !defined(__APPLE__) USE_PASS(fusion_group_pass); #endif +#if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060) +USE_PASS(fuse_gemm_epilogue_pass); +#endif diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index 70a083dd70bc3..5eb584aaefa98 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -1,4 +1,5 @@ // Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -124,6 +125,8 @@ struct BuildStrategy { paddle::optional fuse_broadcast_ops_{paddle::none}; // replace batch_norm with sync_batch_norm. bool sync_batch_norm_{false}; + // Fuse GEMM+Epilogue via cublasLt epilogue. 
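+  // Takes effect only when Paddle is built with CUDA 11.6 or newer
+  // (CUDA_VERSION >= 11060); the pass is appended under that guard in
+  // build_strategy.cc.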
+ bool fuse_gemm_epilogue_{false}; // mkldnn_enabled_op_types specify the operator type list to // use MKLDNN acceleration. It is null in default, means diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index dad5358590cb1..3bf426c13bfda 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -158,6 +158,7 @@ endif() cc_library(fuse_bn_act_pass SRCS fuse_bn_act_pass.cc DEPS pass graph_pattern_detector ) cc_library(fuse_bn_add_act_pass SRCS fuse_bn_add_act_pass.cc DEPS pass graph_pattern_detector ) cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector ) +cc_library(fuse_gemm_epilogue_pass SRCS fuse_gemm_epilogue_pass.cc DEPS pass graph_pattern_detector ) cc_library(fuse_relu_depthwise_conv_pass SRCS fuse_relu_depthwise_conv_pass.cc DEPS pass graph_pattern_detector ) set(GLOB_PASS_LIB ${PASS_LIBRARY} CACHE INTERNAL "Global PASS library") diff --git a/paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.cc b/paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.cc new file mode 100644 index 0000000000000..f48224cbdc24f --- /dev/null +++ b/paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.cc @@ -0,0 +1,471 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
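+
+// This pass rewrites the matmul_v2 + elementwise_add (+ relu/gelu) pattern
+// emitted by paddle.nn.Linear into a single fused_gemm_epilogue op, and the
+// matching elementwise_add_grad + matmul_v2_grad (+ act_grad) pattern into
+// fused_gemm_epilogue_grad, so that the bias add and activation are executed
+// as cuBLASLt epilogues of the GEMM.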
+ +#include "paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.h" +#include +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +void FuseGemmEpiloguePass::ApplyImpl(ir::Graph *graph) const { + EpiloguePassActivationCache cache; + + graph = FuseLinearActFwd(graph, {"relu", "gelu"}, false, false, &cache); + graph = FuseLinearActFwd(graph, {"relu"}, true, true, &cache); + graph = FuseLinearActFwd(graph, {"gelu"}, true, false, &cache); + graph = FuseLinearFwd(graph, false); + graph = FuseLinearFwd(graph, true); + graph = FuseLinearActBwd(graph, {"relu_grad"}, true, &cache); + graph = FuseLinearActBwd(graph, {"gelu_grad"}, false, &cache); + graph = FuseLinearBwd(graph, false); + graph = FuseLinearBwd(graph, true); +} + +ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph, + bool is_training) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + const std::string scope_name("gemm_epilogue"); + FusePassBase::Init(scope_name, graph); + + GraphPatternDetector gpd; + auto *x = gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(scope_name, "x")) + ->AsInput() + ->assert_is_op_input("matmul_v2", "X"); + patterns::LinearAct linear_act_pattern(gpd.mutable_pattern(), "linear_act"); + + linear_act_pattern(x, {}, is_training, false); + + int found_linear_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "handle LinearAct fuse"; + + GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul, linear_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_w, matmul_w, linear_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, linear_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_add_op, ele_add, linear_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_bias, ele_bias, linear_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_out, elewise_add_out, linear_act_pattern); + + std::vector matmul_x_shape = subgraph.at(x)->Var()->GetShape(); + std::vector matmul_w_shape = matmul_w->Var()->GetShape(); + + // Note (Ming Huang): We only support matmul_v2 from paddle.nn.Linear + // currently. 
The conditions below are used to verify wether matmul_v2 + // is created by paddle.nn.Linear + auto matmul_op_desc = matmul_op->Op(); + if (!IsGemmFromLinear_(matmul_x_shape, matmul_w_shape, matmul_op_desc)) + return; + + OpDesc fused_gemm_epilogue_op_desc(matmul_op->Op()->Block()); + std::string activation = "none"; + fused_gemm_epilogue_op_desc.SetType("fused_gemm_epilogue"); + fused_gemm_epilogue_op_desc.SetInput("X", {subgraph.at(x)->Name()}); + fused_gemm_epilogue_op_desc.SetInput("Y", {matmul_w->Name()}); + fused_gemm_epilogue_op_desc.SetInput("Bias", {ele_bias->Name()}); + fused_gemm_epilogue_op_desc.SetOutput("Out", {ele_out->Name()}); + fused_gemm_epilogue_op_desc.SetAttr("activation", activation); + fused_gemm_epilogue_op_desc.SetAttr("op_role", + matmul_op_desc->GetAttr("op_role")); + auto gemm_epilogue_node = g->CreateOpNode(&fused_gemm_epilogue_op_desc); + + IR_NODE_LINK_TO(subgraph.at(x), gemm_epilogue_node); + IR_NODE_LINK_TO(matmul_w, gemm_epilogue_node); + IR_NODE_LINK_TO(ele_bias, gemm_epilogue_node); + IR_NODE_LINK_TO(gemm_epilogue_node, ele_out); + + GraphSafeRemoveNodes(g, {matmul_op, matmul_out, ele_add_op}); + + VLOG(4) << "\n\t " << subgraph.at(x)->Name() << " and " << matmul_w->Name() + << " -> " << matmul_op->Name() << " -> " << matmul_out->Name() + << "\n\t " << matmul_out->Name() << " and " << ele_bias->Name() + << " -> " << ele_add_op->Name() << " -> " << ele_out->Name() + << "\n\t " << ele_out->Name(); + found_linear_count++; + }; + + gpd(graph, handler); + + AddStatis(found_linear_count); + return graph; +} + +ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd( + ir::Graph *graph, const std::unordered_set &act_types, + bool is_training, bool is_act_grad_x_from_act, + EpiloguePassActivationCache *cache) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + + const std::string scope_name("gemm_epilogue"); + FusePassBase::Init(scope_name, graph); + + GraphPatternDetector gpd; + auto *x = gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(scope_name, "x")) + ->AsInput() + ->assert_is_op_input("matmul_v2", "X"); + patterns::LinearAct linear_act_pattern(gpd.mutable_pattern(), "linear_act"); + + linear_act_pattern(x, act_types, is_training, is_act_grad_x_from_act); + + int found_linear_act_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "handle LinearAct fuse"; + + GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul, linear_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_w, matmul_w, linear_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, linear_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_add_op, ele_add, linear_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_bias, ele_bias, linear_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_out, elewise_add_out, linear_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act_op, act, linear_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act_out, act_out, linear_act_pattern); + + std::vector matmul_x_shape = subgraph.at(x)->Var()->GetShape(); + std::vector matmul_w_shape = matmul_w->Var()->GetShape(); + + // Note (Ming Huang): We only support matmul_v2 from paddle.nn.Linear + // currently. 
The conditions below are used to verify wether matmul_v2 + // is created by paddle.nn.Linear + auto matmul_op_desc = matmul_op->Op(); + if (!IsGemmFromLinear_(matmul_x_shape, matmul_w_shape, matmul_op_desc)) + return; + + auto activation = act_op->Op()->Type(); + + OpDesc fused_gemm_epilogue_op_desc(matmul_op->Op()->Block()); + fused_gemm_epilogue_op_desc.SetType("fused_gemm_epilogue"); + fused_gemm_epilogue_op_desc.SetInput("X", {subgraph.at(x)->Name()}); + fused_gemm_epilogue_op_desc.SetInput("Y", {matmul_w->Name()}); + fused_gemm_epilogue_op_desc.SetInput("Bias", {ele_bias->Name()}); + fused_gemm_epilogue_op_desc.SetOutput("Out", {act_out->Name()}); + fused_gemm_epilogue_op_desc.SetAttr("activation", activation); + fused_gemm_epilogue_op_desc.SetAttr("op_role", + matmul_op_desc->GetAttr("op_role")); + + auto gemm_epilogue_node = g->CreateOpNode(&fused_gemm_epilogue_op_desc); + + IR_NODE_LINK_TO(subgraph.at(x), gemm_epilogue_node); + IR_NODE_LINK_TO(matmul_w, gemm_epilogue_node); + IR_NODE_LINK_TO(ele_bias, gemm_epilogue_node); + IR_NODE_LINK_TO(gemm_epilogue_node, act_out); + + // Only need to check weight.shape[1] for auxiliary pointer + // and mark it the act op is fused for backward epilogue fusion. + // That because cuBlasLt epilogue's restriction. + if (is_training) { + int divisor_of_n = activation == "relu" ? 128 : 8; + if (matmul_w_shape[1] % divisor_of_n) return; + + VarDesc reserve_space(patterns::PDNodeName(scope_name, "ReserveSpace")); + auto *reserve_space_node = g->CreateVarNode(&reserve_space); + + cache->InsertFusedActivation( + GetReserveSpaceCacheKey(act_out->Var()->Name(), g->GetBlockId()), + reserve_space_node); + + gemm_epilogue_node->Op()->SetOutput("ReserveSpace", + {reserve_space_node->Name()}); + + if (!is_act_grad_x_from_act) { + GET_IR_NODE_FROM_SUBGRAPH(act_grad_op, act_grad, linear_act_pattern); + act_grad_op->Op()->RenameInput(ele_out->Name(), + reserve_space_node->Name()); + IR_NODE_LINK_TO(reserve_space_node, act_grad_op); + } + IR_NODE_LINK_TO(gemm_epilogue_node, reserve_space_node); + } + + GraphSafeRemoveNodes(g, + {matmul_op, matmul_out, ele_add_op, ele_out, act_op}); + + VLOG(4) << "\n\t " << subgraph.at(x)->Name() << " and " << matmul_w->Name() + << " -> " << matmul_op->Name() << " -> " << matmul_out->Name() + << "\n\t " << matmul_out->Name() << " and " << ele_bias->Name() + << " -> " << ele_add_op->Name() << " -> " << ele_out->Name() + << "\n\t " << ele_out->Name() << " -> " << act_op->Name() << " -> " + << act_out->Name(); + found_linear_act_count++; + }; + + gpd(graph, handler); + + AddStatis(found_linear_act_count); + return graph; +} + +ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph, + bool without_x_gradient) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + const std::string scope_name("gemm_epilogue"); + FusePassBase::Init(scope_name, graph); + + GraphPatternDetector gpd; + auto *dout = + gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(scope_name, "dout")) + ->AsInput() + ->assert_is_op_input("elementwise_add_grad", GradVarName("Out")); + + patterns::ElewiseAddMatmulAct ele_add_matmul_act_pattern( + gpd.mutable_pattern(), "ele_add_matmul_act"); + ele_add_matmul_act_pattern(dout, {}, without_x_gradient, false); + + int found_ele_add_matmul_act_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "handle ElewiseAddMatmulAct fuse"; + + GET_IR_NODE_FROM_SUBGRAPH(ele_add_grad_op, ele_add_grad, + 
ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_grad_bias, ele_grad_bias, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_grad_dx, ele_grad_dx, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_grad_dbias, ele_grad_dbias, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_grad_op, matmul_grad, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_grad_x, matmul_grad_x, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_grad_w, matmul_grad_w, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_grad_dw, matmul_grad_dw, + ele_add_matmul_act_pattern); + + Node *matmul_grad_dx = nullptr; + if (!without_x_gradient) { + GET_IR_NODE_FROM_SUBGRAPH(matmul_grad_dx_ptr, matmul_grad_dx, + ele_add_matmul_act_pattern); + matmul_grad_dx = matmul_grad_dx_ptr; + } + + std::vector matmul_grad_x_shape = matmul_grad_x->Var()->GetShape(); + std::vector matmul_grad_w_shape = matmul_grad_w->Var()->GetShape(); + + // Note (Ming Huang): We only support matmul_v2_grad from paddle.nn.Linear + // currently. The conditions below are used to verify wether matmul_v2 + // is created by paddle.nn.Linear + auto matmul_grad_op_desc = matmul_grad_op->Op(); + if (!IsGemmFromLinear_(matmul_grad_x_shape, matmul_grad_w_shape, + matmul_grad_op_desc)) + return; + + OpDesc fused_gemm_epilogue_grad_op_desc(ele_add_grad_op->Op()->Block()); + std::string activation_grad = "none"; + fused_gemm_epilogue_grad_op_desc.SetType("fused_gemm_epilogue_grad"); + fused_gemm_epilogue_grad_op_desc.SetInput("DOut", + {subgraph.at(dout)->Name()}); + fused_gemm_epilogue_grad_op_desc.SetInput("X", {matmul_grad_x->Name()}); + fused_gemm_epilogue_grad_op_desc.SetInput("Y", {matmul_grad_w->Name()}); + if (matmul_grad_dx) { + fused_gemm_epilogue_grad_op_desc.SetOutput("DX", + {matmul_grad_dx->Name()}); + } + fused_gemm_epilogue_grad_op_desc.SetOutput("DY", {matmul_grad_dw->Name()}); + fused_gemm_epilogue_grad_op_desc.SetOutput("DBias", + {ele_grad_dbias->Name()}); + fused_gemm_epilogue_grad_op_desc.SetAttr("activation_grad", + activation_grad); + fused_gemm_epilogue_grad_op_desc.SetAttr( + "op_role", matmul_grad_op_desc->GetAttr("op_role")); + + auto gemm_epilogue_grad_node = + g->CreateOpNode(&fused_gemm_epilogue_grad_op_desc); + + IR_NODE_LINK_TO(subgraph.at(dout), gemm_epilogue_grad_node); + IR_NODE_LINK_TO(matmul_grad_x, gemm_epilogue_grad_node); + IR_NODE_LINK_TO(matmul_grad_w, gemm_epilogue_grad_node); + IR_NODE_LINK_TO(gemm_epilogue_grad_node, matmul_grad_dw); + IR_NODE_LINK_TO(gemm_epilogue_grad_node, ele_grad_dbias); + if (matmul_grad_dx) { + IR_NODE_LINK_TO(gemm_epilogue_grad_node, matmul_grad_dx); + } + + GraphSafeRemoveNodes(g, {ele_add_grad_op, ele_grad_dx, matmul_grad_op}); + + std::string matmul_grad_dx_name = + matmul_grad_dx != nullptr ? 
matmul_grad_dx->Name() : " "; + VLOG(4) << "\n\t " << subgraph.at(dout)->Name() << " and " + << ele_grad_bias->Name() << " -> " << ele_add_grad_op->Name() + << " -> " << ele_grad_dx->Name() << " and " + << ele_grad_dbias->Name() << "\n\t " << ele_grad_dx->Name() << ", " + << matmul_grad_x->Name() << " and " << matmul_grad_w->Name() + << " -> " << matmul_grad_op->Name() << " -> " + << matmul_grad_w->Name() << " and " << matmul_grad_dx_name; + found_ele_add_matmul_act_count++; + }; + + gpd(graph, handler); + + AddStatis(found_ele_add_matmul_act_count); + return graph; +} + +ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd( + ir::Graph *graph, const std::unordered_set &act_grad_types, + bool is_act_grad_x_from_act, EpiloguePassActivationCache *cache) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); + const std::string scope_name("gemm_epilogue"); + FusePassBase::Init(scope_name, graph); + + GraphPatternDetector gpd; + auto *dout = + gpd.mutable_pattern() + ->NewNode(patterns::PDNodeName(scope_name, "dout")) + ->AsInput() + ->assert_is_op_input("elementwise_add_grad", GradVarName("Out")); + + patterns::ElewiseAddMatmulAct ele_add_matmul_act_pattern( + gpd.mutable_pattern(), "ele_add_matmul_act"); + ele_add_matmul_act_pattern(dout, act_grad_types, false, + is_act_grad_x_from_act); + + int found_ele_add_matmul_act_count = 0; + + auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, + Graph *g) { + VLOG(4) << "handle ElewiseAddMatmulAct fuse"; + + GET_IR_NODE_FROM_SUBGRAPH(ele_add_grad_op, ele_add_grad, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_grad_bias, ele_grad_bias, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_grad_dx, ele_grad_dx, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(ele_grad_dbias, ele_grad_dbias, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_grad_op, matmul_grad, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_grad_x, matmul_grad_x, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_grad_w, matmul_grad_w, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_grad_dx, matmul_grad_dx, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(matmul_grad_dw, matmul_grad_dw, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act_grad_op, act_grad, + ele_add_matmul_act_pattern); + GET_IR_NODE_FROM_SUBGRAPH(act_grad_dx, act_grad_dx, + ele_add_matmul_act_pattern); + + auto key = + GetReserveSpaceCacheKey(matmul_grad_x->Var()->Name(), g->GetBlockId()); + if (!cache->HasFusedActivation(key)) { + return; + } + auto *reserve_space_node = cache->GetFusedActivationSpace(key); + + std::vector matmul_grad_x_shape = matmul_grad_x->Var()->GetShape(); + std::vector matmul_grad_w_shape = matmul_grad_w->Var()->GetShape(); + + // Note (Ming Huang): We only support matmul_v2_grad from paddle.nn.Linear + // currently. 
The conditions below are used to verify wether matmul_v2 + // is created by paddle.nn.Linear + auto matmul_grad_op_desc = matmul_grad_op->Op(); + if (!IsGemmFromLinear_(matmul_grad_x_shape, matmul_grad_w_shape, + matmul_grad_op_desc)) + return; + + auto activation_grad = act_grad_op->Op()->Type(); + + OpDesc fused_gemm_epilogue_grad_op_desc(ele_add_grad_op->Op()->Block()); + fused_gemm_epilogue_grad_op_desc.SetType("fused_gemm_epilogue_grad"); + fused_gemm_epilogue_grad_op_desc.SetInput("DOut", + {subgraph.at(dout)->Name()}); + fused_gemm_epilogue_grad_op_desc.SetInput("X", {matmul_grad_x->Name()}); + fused_gemm_epilogue_grad_op_desc.SetInput("Y", {matmul_grad_w->Name()}); + fused_gemm_epilogue_grad_op_desc.SetInput("ReserveSpace", + {reserve_space_node->Name()}); + fused_gemm_epilogue_grad_op_desc.SetOutput("DX", {act_grad_dx->Name()}); + fused_gemm_epilogue_grad_op_desc.SetOutput("DY", {matmul_grad_dw->Name()}); + fused_gemm_epilogue_grad_op_desc.SetOutput("DBias", + {ele_grad_dbias->Name()}); + fused_gemm_epilogue_grad_op_desc.SetAttr("activation_grad", + activation_grad); + fused_gemm_epilogue_grad_op_desc.SetAttr( + "op_role", matmul_grad_op_desc->GetAttr("op_role")); + + auto gemm_epilogue_grad_node = + g->CreateOpNode(&fused_gemm_epilogue_grad_op_desc); + + IR_NODE_LINK_TO(subgraph.at(dout), gemm_epilogue_grad_node); + IR_NODE_LINK_TO(matmul_grad_x, gemm_epilogue_grad_node); + IR_NODE_LINK_TO(matmul_grad_w, gemm_epilogue_grad_node); + IR_NODE_LINK_TO(gemm_epilogue_grad_node, act_grad_dx); + IR_NODE_LINK_TO(gemm_epilogue_grad_node, matmul_grad_dw); + IR_NODE_LINK_TO(gemm_epilogue_grad_node, ele_grad_dbias); + IR_NODE_LINK_TO(reserve_space_node, gemm_epilogue_grad_node); + + GraphSafeRemoveNodes(g, {ele_add_grad_op, ele_grad_dx, matmul_grad_op, + matmul_grad_dx, act_grad_op}); + + VLOG(4) << "\n\t " << subgraph.at(dout)->Name() << " and " + << ele_grad_bias->Name() << " -> " << ele_add_grad_op->Name() + << " -> " << ele_grad_dx->Name() << " and " + << ele_grad_dbias->Name() << "\n\t " << ele_grad_dx->Name() << ", " + << matmul_grad_x->Name() << " and " << matmul_grad_w->Name() + << " -> " << matmul_grad_op->Name() << " -> " + << matmul_grad_dx->Name() << " and " << matmul_grad_w->Name() + << "\n\t " << matmul_grad_dx->Name() << " -> " + << act_grad_op->Name() << " -> " << act_grad_dx->Name(); + found_ele_add_matmul_act_count++; + }; + + gpd(graph, handler); + + AddStatis(found_ele_add_matmul_act_count); + return graph; +} + +bool FuseGemmEpiloguePass::IsGemmFromLinear_( + const std::vector &x_shape, const std::vector &w_shape, + OpDesc *matmul_v2_op) const { + if (w_shape.size() != 2 || x_shape.size() < 2) return false; + for (auto attr_name : + {"fused_reshape_Out", "fused_reshape_X", "fused_reshape_Y", + "fused_transpose_Out", "fused_transpose_X", "fused_transpose_Y"}) { + if (matmul_v2_op->HasAttr(attr_name)) { + std::vector tmp_vec = + BOOST_GET_CONST(std::vector, matmul_v2_op->GetAttr(attr_name)); + if (tmp_vec.size() > 0) return false; + } + } + if (BOOST_GET_CONST(bool, matmul_v2_op->GetAttr("trans_x")) || + BOOST_GET_CONST(bool, matmul_v2_op->GetAttr("trans_y"))) + return false; + + return true; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(fuse_gemm_epilogue_pass, + paddle::framework::ir::FuseGemmEpiloguePass); diff --git a/paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.h b/paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.h new file mode 100644 index 0000000000000..575ffee73d60e --- /dev/null +++ 
b/paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.h @@ -0,0 +1,100 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * Fuse the ElewiseAdd and activation + */ +class Graph; +class Node; + +class EpiloguePassActivationCache { + public: + EpiloguePassActivationCache() {} + + EpiloguePassActivationCache(const EpiloguePassActivationCache &) = delete; + void operator=(const EpiloguePassActivationCache &) = delete; + + bool HasFusedActivation(const std::string &key) const { + return fused_activation_space_map_.count(key); + } + + ir::Node *GetFusedActivationSpace(const std::string &key) { + if (HasFusedActivation(key)) { + return fused_activation_space_map_.find(key)->second; + } + PADDLE_THROW(platform::errors::InvalidArgument( + "The key (%d) of EpiloguePassActivationCache does not exist.", key)); + } + + void InsertFusedActivation(const std::string &key, ir::Node *const value) { + if (!HasFusedActivation(key)) { + mtx.lock(); + fused_activation_space_map_.insert({key, value}); + mtx.unlock(); + } else { + PADDLE_THROW(platform::errors::AlreadyExists( + "The key (%d) of EpiloguePassActivationCache already exist.", key)); + } + } + + private: + std::unordered_map fused_activation_space_map_; + std::mutex mtx; +}; + +class FuseGemmEpiloguePass : public FusePassBase { + public: + virtual ~FuseGemmEpiloguePass() {} + + protected: + void ApplyImpl(ir::Graph *graph) const override; + + ir::Graph *FuseLinearFwd(ir::Graph *graph, bool is_training) const; + ir::Graph *FuseLinearActFwd(ir::Graph *graph, + const std::unordered_set &act_types, + bool is_training, bool is_act_grad_x_from_act, + EpiloguePassActivationCache *cache) const; + ir::Graph *FuseLinearBwd(ir::Graph *graph, bool without_x_gradient) const; + ir::Graph *FuseLinearActBwd( + ir::Graph *graph, const std::unordered_set &act_grad_types, + bool is_act_grad_x_from_act, EpiloguePassActivationCache *cache) const; + + private: + bool IsGemmFromLinear_(const std::vector &x_shape, + const std::vector &w_shape, + OpDesc *matmul_v2_op) const; + const std::string GetReserveSpaceCacheKey(const std::string var_name, + int block_id) const { + return std::to_string(block_id) + var_name; + } +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index e4c9dc72128f4..d7d866fa98bb5 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1461,31 +1461,6 @@ PDNode 
*patterns::BatchNormAddActGrad::operator()( return bn_grad; } -PDNode *patterns::ElewiseAddAct::operator()( - paddle::framework::ir::PDNode *ele_x_var, - std::unordered_set act_types) { - auto *ele_y_var = pattern->NewNode(ele_y_repr()) - ->assert_is_op_input("elementwise_add", "Y"); - - auto *ele_add = - pattern->NewNode(ele_add_repr())->assert_is_op("elementwise_add"); - - auto *ele_out_var = pattern->NewNode(elewise_add_out_repr()) - ->assert_is_op_output("elementwise_add", "Out"); - - ele_out_var->AsIntermediate()->assert_is_ops_input(act_types); - - auto *act = pattern->NewNode(act_repr())->assert_is_ops(act_types); - - auto *act_out_var = - pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out"); - - ele_add->LinksFrom({ele_x_var, ele_y_var}).LinksTo({ele_out_var}); - act->LinksFrom({ele_out_var}).LinksTo({act_out_var}); - - return act_out_var; -} - PDNode *patterns::ElewiseAddActInplaceGrad::operator()( paddle::framework::ir::PDNode *d_act_out_var, std::unordered_set act_types) { @@ -1526,6 +1501,159 @@ PDNode *patterns::ElewiseAddActInplaceGrad::operator()( return ele_add_grad; } +PDNode *patterns::ElewiseAddAct::operator()( + paddle::framework::ir::PDNode *ele_x_var, + std::unordered_set act_types) { + auto *ele_y_var = pattern->NewNode(ele_y_repr()) + ->assert_is_op_input("elementwise_add", "Y"); + + auto *ele_add = + pattern->NewNode(ele_add_repr())->assert_is_op("elementwise_add"); + + auto *ele_out_var = pattern->NewNode(elewise_add_out_repr()) + ->assert_is_op_output("elementwise_add", "Out"); + + ele_out_var->AsIntermediate()->assert_is_ops_input(act_types); + + auto *act = pattern->NewNode(act_repr())->assert_is_ops(act_types); + + auto *act_out_var = + pattern->NewNode(act_out_repr())->assert_is_ops_output(act_types, "Out"); + + ele_add->LinksFrom({ele_x_var, ele_y_var}).LinksTo({ele_out_var}); + act->LinksFrom({ele_out_var}).LinksTo({act_out_var}); + + return act_out_var; +} + +PDNode *patterns::LinearAct::operator()( + paddle::framework::ir::PDNode *linear_x_var, + const std::unordered_set &act_types, bool with_grad_link, + bool is_act_grad_x_from_act) { + auto *matmul_w_var = + pattern->NewNode(matmul_w_repr())->assert_is_op_input("matmul_v2", "Y"); + + auto *matmul = pattern->NewNode(matmul_repr())->assert_is_op("matmul_v2"); + + auto *matmul_out_var = pattern->NewNode(matmul_out_repr()) + ->assert_is_op_output("matmul_v2", "Out"); + + matmul_out_var->AsIntermediate()->assert_is_op_input("elementwise_add", "X"); + + auto *ele_bias_var = pattern->NewNode(ele_bias_repr()) + ->assert_is_op_input("elementwise_add", "Y"); + + auto *ele_add = + pattern->NewNode(ele_add_repr())->assert_is_op("elementwise_add"); + + auto *ele_out_var = pattern->NewNode(elewise_add_out_repr()) + ->assert_is_op_output("elementwise_add", "Out"); + + matmul->LinksFrom({linear_x_var, matmul_w_var}).LinksTo({matmul_out_var}); + ele_add->LinksFrom({matmul_out_var, ele_bias_var}).LinksTo({ele_out_var}); + + if (with_grad_link) { + matmul_out_var->assert_is_op_input("elementwise_add_grad", "X"); + auto *elementwise_add_grad_op = pattern->NewNode("elementwise_add_grad") + ->assert_is_op("elementwise_add_grad"); + elementwise_add_grad_op->LinksFrom({matmul_out_var}); + } + + if (act_types.size() > 0) { + ele_out_var->AsIntermediate()->assert_is_ops_input(act_types); + + auto *act = pattern->NewNode(act_repr())->assert_is_ops(act_types); + auto *act_out_var = pattern->NewNode(act_out_repr()) + ->assert_is_ops_output(act_types, "Out"); + + 
act->LinksFrom({ele_out_var}).LinksTo({act_out_var}); + + if (with_grad_link && !is_act_grad_x_from_act) { + std::unordered_set act_grad_types; + for (const auto &act : act_types) { + std::string act_grad(act); + act_grad.append("_grad"); + act_grad_types.insert(act_grad); + } + + ele_out_var->assert_is_ops_input(act_grad_types, "X"); + auto *act_grad_op = + pattern->NewNode(act_grad_repr())->assert_is_ops(act_grad_types); + act_grad_op->LinksFrom({ele_out_var}); + } + + return act_out_var; + } + + return ele_out_var; +} + +PDNode *patterns::ElewiseAddMatmulAct::operator()( + paddle::framework::ir::PDNode *dout_var, + const std::unordered_set &act_grad_types, + bool without_x_gradient, bool is_act_grad_x_from_act) { + auto *ele_grad_bias_var = + pattern->NewNode(ele_grad_bias_repr()) + ->assert_is_op_input("elementwise_add_grad", "Y"); + auto *ele_add_grad = pattern->NewNode(ele_add_grad_repr()) + ->assert_is_op("elementwise_add_grad"); + auto *ele_grad_dx_var = + pattern->NewNode(ele_grad_dx_repr()) + ->assert_is_op_output("elementwise_add_grad", GradVarName("X")); + auto *ele_grad_dbias_var = + pattern->NewNode(ele_grad_dbias_repr()) + ->assert_is_op_output("elementwise_add_grad", GradVarName("Y")); + ele_add_grad->LinksFrom({dout_var, ele_grad_bias_var}) + .LinksTo({ele_grad_dx_var, ele_grad_dbias_var}); + + ele_grad_dx_var->AsIntermediate()->assert_is_op_input("matmul_v2_grad", + GradVarName("Out")); + + auto *matmul_grad_x_var = pattern->NewNode(matmul_grad_x_repr()) + ->assert_is_op_input("matmul_v2_grad", "X"); + auto *matmul_grad_w_var = pattern->NewNode(matmul_grad_w_repr()) + ->assert_is_op_input("matmul_v2_grad", "Y"); + auto *matmul_grad = + pattern->NewNode(matmul_grad_repr())->assert_is_op("matmul_v2_grad"); + auto *matmul_grad_dx_var = + pattern->NewNode(matmul_grad_dx_repr()) + ->assert_is_op_output("matmul_v2_grad", GradVarName("X")); + auto *matmul_grad_dw_var = + pattern->NewNode(matmul_grad_dw_repr()) + ->assert_is_op_output("matmul_v2_grad", GradVarName("Y")); + matmul_grad->LinksFrom( + {ele_grad_dx_var, matmul_grad_x_var, matmul_grad_w_var}); + if (without_x_gradient) { + matmul_grad->LinksTo({matmul_grad_dw_var}); + } else { + matmul_grad->LinksTo({matmul_grad_dx_var, matmul_grad_dw_var}); + } + + if (!without_x_gradient && act_grad_types.size() > 0) { + matmul_grad_dx_var->AsIntermediate()->assert_is_ops_input( + act_grad_types, GradVarName("Out")); + + auto *act_grad = + pattern->NewNode(act_grad_repr())->assert_is_ops(act_grad_types); + auto *act_grad_dx_var = + pattern->NewNode(act_grad_dx_repr()) + ->assert_is_ops_output(act_grad_types, GradVarName("X")); + + auto *act_grad_x_var = matmul_grad_x_var; + if (!is_act_grad_x_from_act) { + auto *ele_out_var = pattern->NewNode(ele_out_repr()) + ->assert_is_ops_input(act_grad_types, "X"); + act_grad_x_var = ele_out_var; + } + + act_grad->LinksFrom({matmul_grad_dx_var, act_grad_x_var}) + .LinksTo({act_grad_dx_var}); + return act_grad; + } + + return matmul_grad; +} + // conv_type: conv2d, conv3d, conv2d_transpose PDNode *patterns::ConvBias::operator()( paddle::framework::ir::PDNode *conv_input, std::string conv_type) { diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h index d6400ed6945bf..0f21906d08d0e 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.h +++ b/paddle/fluid/framework/ir/graph_pattern_detector.h @@ -863,6 +863,65 @@ struct ElewiseAddActInplaceGrad : public PatternBase { PATTERN_DECL_NODE(ele_y); }; +// The following patterns 
are used to fuse linear and act (ReLu or GeLU) +// formula: act(F.linear(x)) +// op: matmul_v2 + elementwise_add + act +// named nodes: matmul, elementwise_add, act +// matmul_w, matmul_out +// ele_bias, elewise_add_out, act_out +struct LinearAct : public PatternBase { + LinearAct(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "linear_act") {} + + PDNode* operator()(PDNode* x, + const std::unordered_set& act_types, + bool with_grad_link, bool is_act_grad_x_from_act); + + // declare operator node's name + PATTERN_DECL_NODE(matmul); + PATTERN_DECL_NODE(ele_add); + PATTERN_DECL_NODE(act); + PATTERN_DECL_NODE(act_grad); + // declare variable node's name + PATTERN_DECL_NODE(matmul_w); + PATTERN_DECL_NODE(matmul_out); + PATTERN_DECL_NODE(elewise_add_out); + PATTERN_DECL_NODE(ele_bias); + PATTERN_DECL_NODE(act_out); +}; + +// The following patterns are used to fuse linear_grad and act_grad (ReLu or +// GeLU) +// formula: the backward of F.linear( act(x) ) +// op: elementwise_add_grad + matmul_v2_grad + act_grad +// named nodes: ele_add_grad, matmul_grad, act_grad +// ele_grad_bias, ele_grad_dx, ele_grad_dbias +// matmul_grad_x, matmul_grad_dx, matmul_grad_dx +// matmul_grad_dw, act_grad_dx +struct ElewiseAddMatmulAct : public PatternBase { + ElewiseAddMatmulAct(PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, "elewiseadd_matmul_act") {} + + PDNode* operator()(PDNode* x, + const std::unordered_set& act_grad_types, + bool without_x_gradient, bool is_act_grad_x_from_act); + + // declare operator node's name + PATTERN_DECL_NODE(ele_add_grad); + PATTERN_DECL_NODE(matmul_grad); + PATTERN_DECL_NODE(act_grad); + // declare variable node's name + PATTERN_DECL_NODE(ele_out); + PATTERN_DECL_NODE(ele_grad_bias); + PATTERN_DECL_NODE(ele_grad_dx); + PATTERN_DECL_NODE(ele_grad_dbias); + PATTERN_DECL_NODE(matmul_grad_x); + PATTERN_DECL_NODE(matmul_grad_w); + PATTERN_DECL_NODE(matmul_grad_dx); + PATTERN_DECL_NODE(matmul_grad_dw); + PATTERN_DECL_NODE(act_grad_dx); +}; + // Conv with Elementwise_add as bias // op: conv + elementwise_add // named nodes: diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 67287afa6ae50..80e7f5c001d4b 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -19,7 +19,8 @@ register_operators(EXCLUDES fused_attention_op fused_transformer_op fused_feedforward_op - resnet_unit_op) + resnet_unit_op + fused_gemm_epilogue_op) # fusion_gru_op does not have CUDA kernel op_library(fusion_gru_op) @@ -79,4 +80,8 @@ if (WITH_GPU OR WITH_ROCM) cc_test(test_cudnn_norm_conv SRCS cudnn_norm_conv_test.cc DEPS conv_op blas im2col vol2col depthwise_conv eigen_function tensor op_registry device_context generator memory) cc_test(test_cudnn_bn_add_relu SRCS cudnn_bn_add_relu_test.cc DEPS batch_norm_op fused_bn_add_activation_op tensor op_registry device_context generator memory) endif() + + if (CUDA_VERSION GREATER_EQUAL 11.6) + op_library(fused_gemm_epilogue_op) + endif() endif() diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc new file mode 100644 index 0000000000000..4c4e3661e6d6e --- /dev/null +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc @@ -0,0 +1,353 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace operators { +using Tensor = framework::Tensor; + +class FusedGemmEpilogueOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedGemmEpilogueOp"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "FusedGemmEpilogueOp"); + OP_INOUT_CHECK(ctx->HasInput("Bias"), "Output", "Bias", + "FusedGemmEpilogueOp"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", + "FusedGemmEpilogueOp"); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto bias_dims = ctx->GetInputDim("Bias"); + + auto trans_x = ctx->Attrs().Get("trans_x"); + auto trans_y = ctx->Attrs().Get("trans_y"); + + PADDLE_ENFORCE_EQ( + y_dims.size(), 2, + platform::errors::InvalidArgument( + "The Input tensor Y's dimension of FusedGemmEpilogueOp " + " should be 2, but got %d.", + y_dims.size())); + + PADDLE_ENFORCE_GE( + x_dims.size(), 2, + platform::errors::InvalidArgument( + "The Input tensor X's dimension of FusedGemmEpilogueOp " + " should be >= 2, but got %d.", + x_dims.size())); + + PADDLE_ENFORCE_EQ( + bias_dims.size(), 1, + platform::errors::InvalidArgument( + "The Input tensor bias's dimension of FusedGemmEpilogueOp " + " should be == 1, but got %d.", + bias_dims.size())); + + PADDLE_ENFORCE_EQ(bias_dims[0], trans_y ? y_dims[0] : y_dims[1], + platform::errors::InvalidArgument( + "The Input tensor bias's dimension 0" + " should be == Y[-1], but got bias's shape = [%s] " + "and Y's shape = [%s]", + bias_dims, y_dims)); + + auto x_mat_dims = + phi::flatten_to_2d(x_dims, trans_x ? 1 : x_dims.size() - 1); + + int K_from_x = trans_x ? x_mat_dims[0] : x_mat_dims[1]; + int K_from_y = trans_y ? y_dims[1] : y_dims[0]; + + PADDLE_ENFORCE_EQ( + K_from_x, K_from_y, + platform::errors::InvalidArgument( + "The last dimension of X should be equal with Y's first dimension." + "But received X[-1] = [%d], Y[0] = [%d].", + K_from_x, K_from_y)); + + auto activation = ctx->Attrs().Get("activation"); + + if ((activation != "relu") && (activation != "gelu") && + (activation != "none")) { + PADDLE_ENFORCE_EQ( + true, false, + platform::errors::InvalidArgument( + "The activation attribute of fused_gemm_epilogue op should be" + " one of {\"none\", \"relu\", \"gelu\"}. But received %s." + "But received activation=%s.", + activation)); + } + + if (activation == "none" && ctx->HasOutput("ReserveSpace")) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The ReserveSpace would not be used when activation = \"none\"")); + } + + // cublasLt's restriction for auxiliary. + if (ctx->HasOutput("ReserveSpace") && activation != "none") { + int min_size_of_n = activation == "relu" ? 128 : 8; + int N_size = trans_y ? 
y_dims[0] : y_dims[1]; + PADDLE_ENFORCE_EQ(N_size % min_size_of_n, 0, + platform::errors::InvalidArgument( + "The output dimension N (X(MxK) * Y(KxN) = C(MxN)) " + "should be multiple of %d when auxiliary_key given " + "and activation=%s, but got N = %d.", + min_size_of_n, activation, N_size)); + } + + std::vector out_dims; + out_dims.reserve(static_cast(x_dims.size())); + if (trans_x) { + for (int i = 1; i < x_dims.size(); ++i) out_dims.push_back(x_dims[i]); + } else { + for (int i = 0; i < x_dims.size() - 1; ++i) out_dims.push_back(x_dims[i]); + } + + if (trans_y) { + out_dims.push_back(y_dims[0]); + } else { + out_dims.push_back(y_dims[1]); + } + + ctx->SetOutputDim("Out", phi::make_ddim(out_dims)); + // Note (Ming Huang): Reserve space of relu is a bit-mask, + // which cannot pass nan_and_inf checking if shape is set. + if (activation == "gelu" && ctx->HasOutput("ReserveSpace")) { + ctx->SetOutputDim("ReserveSpace", phi::make_ddim(out_dims)); + } + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + } +}; + +class FusedGemmEpilogueOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "The input tensor X of Out = Act((X * Y) + Bias)."); + AddInput("Y", "The input tensor Y of Out = Act((X * Y) + Bias)."); + AddInput("Bias", "The input tensor bias of Out = Act((X * Y) + Bias)."); + + AddOutput("Out", "The output tensor Out of Out = Act((X * Y) + Bias)."); + AddOutput("ReserveSpace", + R"DOC(Reserve GPU space to place + auxiliary data pointer. It is used to pass auxiliary data pointer + for fused_gemm_epilogue op. If not given (empty string), the + auxiliary mode would not be enable.)DOC") + .AsDispensable() + .AsExtra(); + + AddAttr( + "trans_x", + R"DOC((bool, default false), Whether to transpose input tensor X + or not. The input tensor X coulbe be more than two dimension. When + set trans_x=true, it would fully reverse X. For instant: X with shpae + [d0, d1, d2, d3] -> [d3, d2, d1, d0].)DOC") + .SetDefault(false); + AddAttr( + "trans_y", + R"DOC((bool, default false), Whether to transpose input tensor Y + or not. The input tensor Y should be two dimension. When + set trans_y=true, it would transpose Y. For instant: Y with shpae + [d0, d1] -> [d1, d0].)DOC") + .SetDefault(false); + + AddAttr( + "activation", + R"DOC((string, default none), The activation function. It could be + one of {none, relu, gelu}. When none is given, Act would be null + operations)DOC") + .SetDefault("none"); + + AddComment(R"DOC( +FusedGemmEpilogue Operator +This operator is used to perform Activeation(Elementwise_add(Matmul(X, Y), bias)). +It is equal to paddle.nn.Linear + Activation (None, ReLU or GeLU). + +Note: +X could be more than two dimension and would be flatten to 2D for computing. 
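+For example (with trans_x=false and trans_y=false), X: [2, 3, 4, 8], Y: [8, 16]
+and Bias: [16] run as a (2*3*4)x8 by 8x16 GEMM and produce Out: [2, 3, 4, 16].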
+X with shape [d0, d1, d2, d3] -> X_2D with shape [d0*d1*d2, d3] +)DOC"); + } +}; + +class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasInput("DOut"), "Input", "DOut", + "FusedGemmEpilogueGradOp"); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedGemmEpilogueGradOp"); + OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "FusedGemmEpilogueGradOp"); + OP_INOUT_CHECK(ctx->HasOutput("DY"), "Output", "DY", "FusedGemmEpilogueOp"); + + auto dout_dims = ctx->GetInputDim("DOut"); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + PADDLE_ENFORCE_GE( + dout_dims.size(), 2, + platform::errors::InvalidArgument( + "The Input tensor DOut's dimension of FusedGemmEpilogueGradOp " + " should be >= 2, but got %d.", + dout_dims.size())); + + PADDLE_ENFORCE_EQ( + y_dims.size(), 2, + platform::errors::InvalidArgument( + "The Input tensor Y's dimension of FusedGemmEpilogueGradOp " + " should be 2, but got %d.", + y_dims.size())); + + PADDLE_ENFORCE_GE( + x_dims.size(), 2, + platform::errors::InvalidArgument( + "The Input tensor X's dimension of FusedGemmEpilogueGradOp " + " should be >= 2, but got %d.", + x_dims.size())); + + PADDLE_ENFORCE_EQ( + dout_dims.size(), x_dims.size(), + platform::errors::InvalidArgument( + "The Input tensor DOut's and X's dimension of " + "FusedGemmEpilogueGradOp " + " should be the same, but got DOut's dim = %d and X's = %d.", + dout_dims.size(), x_dims.size())); + + auto dout_mat_dims = phi::flatten_to_2d(dout_dims, dout_dims.size() - 1); + + auto x_mat_dims = phi::flatten_to_2d(x_dims, x_dims.size() - 1); + + PADDLE_ENFORCE_EQ( + dout_mat_dims[1], y_dims[1], + platform::errors::InvalidArgument( + "The last dimension of DOut should be equal with Y's last" + "dimension. But received DOut[-1] = [%d], Y[1] = [%d].", + dout_mat_dims[1], y_dims[1])); + + PADDLE_ENFORCE_EQ( + dout_mat_dims[0], x_mat_dims[0], + platform::errors::InvalidArgument( + "The first dimension of DOut should be equal with X's first" + "dimension. But received DOut[0] = [%d], Y[0] = [%d].", + dout_mat_dims[0], x_mat_dims[0])); + + auto activation_grad = ctx->Attrs().Get("activation_grad"); + if ((activation_grad != "relu_grad") && (activation_grad != "gelu_grad") && + (activation_grad != "none")) { + PADDLE_ENFORCE_EQ( + true, false, + platform::errors::InvalidArgument( + "The activation attribute of fused_gemm_epilogue op should be" + " one of {\"none\", \"relu\", \"gelu\"}. But received %s." + "But received activation=%s.", + activation_grad)); + } + + if (activation_grad != "none" && !ctx->HasInput("ReserveSpace")) { + PADDLE_ENFORCE_EQ(true, false, + platform::errors::InvalidArgument( + "The ReserveSpace should not be empty. 
" + "when activation_grad == {relu_grad, gelu_grad}.")); + } + + if (ctx->HasOutput("DX")) { + std::vector dx_dims; + dx_dims.reserve(static_cast(x_dims.size())); + for (int i = 0; i < x_dims.size(); ++i) { + dx_dims.push_back(x_dims[i]); + } + ctx->SetOutputDim("DX", phi::make_ddim(dx_dims)); + } + + std::vector dy_dims(y_dims.Get(), y_dims.Get() + y_dims.size()); + ctx->SetOutputDim("DY", phi::make_ddim(dy_dims)); + + if (ctx->HasOutput("DBias")) { + std::vector dbias_dims; + dbias_dims.push_back(y_dims[1]); + ctx->SetOutputDim("DBias", phi::make_ddim(dbias_dims)); + } + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + framework::LibraryType library = framework::LibraryType::kPlain; + framework::DataLayout layout = framework::DataLayout::kAnyLayout; + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut"); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + } +}; + +class FusedGemmEpilogueGradOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("DOut", + "The input grad tensor to Out of Out = (Act(X) * Y) + bias"); + AddInput("X", "The input tensor X of Out = (Act(X) * Y) + bias"); + AddInput("Y", "The input tensor Y of Out = (Act(X) * Y) + bias"); + AddInput("ReserveSpace", + R"DOC(A GPU space to fetch + auxiliary data pointer. It is used to pass auxiliary data pointer + for fused_gemm_epilogue_grad op. If not given (empty string), the + auxiliary mode would not be enable.)DOC") + .AsDispensable(); + + AddOutput("DX", "The output grad tensor to X of Out = (Act(X) * Y) + bias.") + .AsDispensable(); + AddOutput("DY", + "The output grad tensor to Y of Out = (Act(X) * Y) + bias."); + AddOutput("DBias", + "The output grad tensor to bias of Out = (Act(X) * Y) + bias.") + .AsDispensable(); + + AddAttr( + "activation_grad", + R"DOC((string, default none), The backward activation function. It could be + one of {none, relu_grad, gelu_grad}. When none is given, The backward Act would + be null operations)DOC") + .SetDefault("none"); + + AddComment(R"DOC( +FusedGemmEpilogueGrad Operator +This operator is used to perform backward of Elementwise_add(Matmul(Activeation(X), Y), bias). +It is equal to Activation (None, ReLU or GeLU) + paddle.nn.Linear. + +Note: +X could be more than two dimension and would be flatten to 2D for computing. +X with shape [d0, d1, d2, d3] -> X_2D with shape [d0*d1*d2, d3] +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fused_gemm_epilogue, ops::FusedGemmEpilogueOp, + ops::FusedGemmEpilogueOpMaker) +REGISTER_OPERATOR(fused_gemm_epilogue_grad, ops::FusedGemmEpilogueGradOp, + ops::FusedGemmEpilogueGradOpMaker) diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu new file mode 100644 index 0000000000000..e16c9e8f483cc --- /dev/null +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu @@ -0,0 +1,376 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/dynload/cublasLt.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class FusedGemmEpilogueKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); + + const Tensor* x = ctx.Input("X"); + const Tensor* y = ctx.Input("Y"); + const Tensor* bias = ctx.Input("Bias"); + + Tensor* out = ctx.Output("Out"); + Tensor* reserve_space = ctx.Output("ReserveSpace"); + + bool trans_x = ctx.Attr("trans_x"); + bool trans_y = ctx.Attr("trans_y"); + + std::string activation = ctx.Attr("activation"); + bool enable_auxiliary = reserve_space == nullptr ? false : true; + + out->mutable_data(ctx.GetPlace()); + auto* out_data = out->data(); + + auto x_mat_dims = + phi::flatten_to_2d(x->dims(), trans_x ? 1 : x->dims().size() - 1); + int64_t M = trans_x ? x_mat_dims[1] : x_mat_dims[0]; + int64_t K = trans_y ? y->dims()[1] : y->dims()[0]; + int64_t N = trans_y ? y->dims()[0] : y->dims()[1]; + + cudaDataType_t mat_type = CUDA_R_32F; + cudaDataType_t scale_type = CUDA_R_32F; + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; + if (std::is_same::value) { + mat_type = CUDA_R_16F; + scale_type = CUDA_R_16F; + } + if (std::is_same::value) { + mat_type = CUDA_R_64F; + scale_type = CUDA_R_64F; + compute_type = CUBLAS_COMPUTE_64F; + } + + cublasLtMatmulDesc_t operation_desc = NULL; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( + &operation_desc, compute_type, scale_type)); + cublasOperation_t transx = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t transy = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N; + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &transx, + sizeof(transx))); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &transy, + sizeof(transy))); + + cublasLtEpilogue_t epiloque_func = + get_epilogue_type_(activation, enable_auxiliary); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epiloque_func, + sizeof(epiloque_func))); + const T* bias_data = bias->data(); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_data, + sizeof(bias_data))); + + if (enable_auxiliary && activation != "none") { + size_t reserve_space_size = 0; + if (activation == "relu") { + // Count in bits. 
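+        // cuBLASLt's RELU_AUX epilogue writes a 1-bit mask per output element,
+        // so the buffer needs product(out_dims) bits, i.e. product / 8 bytes.
+        // The GELU_AUX branch below instead keeps the pre-activation GEMM
+        // result, one value of type T per output element.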
+ reserve_space_size = phi::product(out->dims()) / 8; + } else { + reserve_space_size = phi::product(out->dims()) * sizeof(T); + } + reserve_space->mutable_data(ctx.GetPlace(), out->type(), + reserve_space_size); + void* aux_data = reinterpret_cast(reserve_space->data()); + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + &aux_data, sizeof(aux_data))); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &N, + sizeof(N))); + } + + cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL; + if (trans_x) + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &x_desc, mat_type, M, K, M)); + else + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &x_desc, mat_type, K, M, K)); + if (trans_y) + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &y_desc, mat_type, K, N, K)); + else + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &y_desc, mat_type, N, K, N)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &out_desc, mat_type, N, M, N)); + + cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); + size_t workspace_size = 4 * 1024 * 1024; + const cublasLtMatmulAlgo_t* algo = nullptr; + cudaStream_t stream = dev_ctx.stream(); + memory::allocation::AllocationPtr workspace = + memory::Alloc(dev_ctx, workspace_size); + + double alpha64 = 1.0, beta64 = 0.0; + float alpha32 = 1.0f, beta32 = 0.0f; + void *alpha = nullptr, *beta = nullptr; + if (std::is_same::value) { + alpha = &alpha64; + beta = &beta64; + } else { + alpha = &alpha32; + beta = &beta32; + } + + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul( + lt_handle, operation_desc, alpha, y->data(), y_desc, x->data(), + x_desc, beta, out_data, out_desc, out_data, out_desc, algo, + workspace->ptr(), workspace_size, stream)); + } + + private: + static cublasLtEpilogue_t get_epilogue_type_(const std::string& activation, + bool enable_auxiliary) { + if (activation == "relu") { + return enable_auxiliary ? CUBLASLT_EPILOGUE_RELU_AUX_BIAS + : CUBLASLT_EPILOGUE_RELU_BIAS; + } else if (activation == "gelu") { + return enable_auxiliary ? CUBLASLT_EPILOGUE_GELU_AUX_BIAS + : CUBLASLT_EPILOGUE_GELU_BIAS; + } else if (activation == "none") { + return CUBLASLT_EPILOGUE_BIAS; + } else { + PADDLE_ENFORCE_EQ( + true, false, + platform::errors::InvalidArgument( + "The activation attribute of fused_gemm_epilogue op should be" + " one of {\"none\", \"relu\", \"gelu\"}. But received %s." 
+ "But received activation=%s.", + activation)); + } + } +}; + +template +class FusedGemmEpilogueGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); + + const Tensor* dout = ctx.Input("DOut"); + const Tensor* x = ctx.Input("X"); + const Tensor* y = ctx.Input("Y"); + const Tensor* reserve_space = ctx.Input("ReserveSpace"); + + Tensor* dx = ctx.Output("DX"); + Tensor* dy = ctx.Output("DY"); + Tensor* dbias = ctx.Output("DBias"); + + std::string activation_grad = ctx.Attr("activation_grad"); + + auto dout_mat_dims = + phi::flatten_to_2d(dout->dims(), dout->dims().size() - 1); + auto x_mat_dims = phi::flatten_to_2d(x->dims(), x->dims().size() - 1); + + int64_t M = x_mat_dims[0]; + int64_t K = y->dims()[0]; + int64_t N = y->dims()[1]; + + cudaDataType_t mat_type = CUDA_R_32F; + cudaDataType_t scale_type = CUDA_R_32F; + cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; + if (std::is_same::value) { + mat_type = CUDA_R_16F; + scale_type = CUDA_R_16F; + } + if (std::is_same::value) { + mat_type = CUDA_R_64F; + scale_type = CUDA_R_64F; + compute_type = CUBLAS_COMPUTE_64F; + } + + cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); + size_t workspace_size = 4 * 1024 * 1024; + const cublasLtMatmulAlgo_t* algo = nullptr; + cudaStream_t stream = dev_ctx.stream(); + + double alpha64 = 1.0, beta64 = 0.0; + float alpha32 = 1.0f, beta32 = 0.0f; + void *alpha = nullptr, *beta = nullptr; + if (std::is_same::value) { + alpha = &alpha64; + beta = &beta64; + } else { + alpha = &alpha32; + beta = &beta32; + } + + cublasOperation_t trans_dout = CUBLAS_OP_N; + cublasLtMatrixLayout_t dout_desc = NULL; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &dout_desc, mat_type, N, M, N)); + + if (dx) { + cublasLtMatmulDesc_t dx_operation_desc = NULL; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( + &dx_operation_desc, compute_type, scale_type)); + cublasOperation_t trans_y = CUBLAS_OP_T; + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_dout, + sizeof(trans_dout))); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_y, + sizeof(trans_y))); + cublasLtEpilogue_t epiloque_func_for_dx = + get_epilogue_type_(activation_grad); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, + &epiloque_func_for_dx, sizeof(epiloque_func_for_dx))); + + if (activation_grad != "none") { + auto* aux_data = reserve_space->data(); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + &aux_data, sizeof(aux_data))); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &N, + sizeof(N))); + } + + cublasLtMatrixLayout_t y_desc = NULL, dx_desc = NULL; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &y_desc, mat_type, N, K, N)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &dx_desc, mat_type, K, M, K)); + + memory::allocation::AllocationPtr dx_workspace = + memory::Alloc(dev_ctx, workspace_size); + + dx->mutable_data(ctx.GetPlace()); + auto* dx_data = dx->data(); + 
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul( + lt_handle, dx_operation_desc, alpha, y->data(), y_desc, + dout->data(), dout_desc, beta, dx_data, dx_desc, dx_data, dx_desc, + algo, dx_workspace->ptr(), workspace_size, stream)); + } + + if (dy) { + cublasLtMatmulDesc_t dy_operation_desc = NULL; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( + &dy_operation_desc, compute_type, scale_type)); + cublasOperation_t trans_x = CUBLAS_OP_T; + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_dout, + sizeof(trans_dout))); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_x, + sizeof(trans_x))); + cublasLtEpilogue_t epiloque_func_for_dy = dbias == nullptr + ? CUBLASLT_EPILOGUE_DEFAULT + : CUBLASLT_EPILOGUE_BGRADA; + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dy_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, + &epiloque_func_for_dy, sizeof(epiloque_func_for_dy))); + + if (dbias) { + dbias->mutable_data(ctx.GetPlace()); + auto* dbias_data = dbias->data(); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescSetAttribute( + dy_operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, + &dbias_data, sizeof(dbias_data))); + } + + cublasLtMatrixLayout_t x_desc = NULL, dy_desc = NULL; + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &x_desc, mat_type, K, M, K)); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + &dy_desc, mat_type, N, K, N)); + + memory::allocation::AllocationPtr dy_workspace = + memory::Alloc(dev_ctx, workspace_size); + + dy->mutable_data(ctx.GetPlace()); + auto* dy_data = dy->data(); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul( + lt_handle, dy_operation_desc, alpha, dout->data(), dout_desc, + x->data(), x_desc, beta, dy_data, dy_desc, dy_data, dy_desc, algo, + dy_workspace->ptr(), workspace_size, stream)); + } + } + + private: + static cublasLtEpilogue_t get_epilogue_type_( + const std::string& activation_grad) { + if (activation_grad == "relu_grad") { + return CUBLASLT_EPILOGUE_DRELU; + } else if (activation_grad == "gelu_grad") { + return CUBLASLT_EPILOGUE_DGELU; + } else if (activation_grad == "none") { + return CUBLASLT_EPILOGUE_DEFAULT; + } else { + PADDLE_ENFORCE_EQ( + true, false, + platform::errors::InvalidArgument( + "The activation_grad attribute of fused_gemm_epilogue op should " + "be" + " one of {\"none\", \"relu\", \"gelu\"}. But received %s." 
+ "But received activation_grad=%s.", + activation_grad)); + } + } +}; + +} // namespace operators +} // namespace paddle + +#if CUDA_VERSION >= 11060 +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + fused_gemm_epilogue, + ops::FusedGemmEpilogueKernel, + ops::FusedGemmEpilogueKernel, + ops::FusedGemmEpilogueKernel); + +REGISTER_OP_CUDA_KERNEL( + fused_gemm_epilogue_grad, + ops::FusedGemmEpilogueGradKernel, + ops::FusedGemmEpilogueGradKernel, + ops::FusedGemmEpilogueGradKernel); +#endif diff --git a/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h b/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h index ab7d474c1ac38..a32db3a9921e3 100644 --- a/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h +++ b/paddle/fluid/platform/device/gpu/cuda/cuda_helper.h @@ -19,6 +19,7 @@ #include "paddle/fluid/platform/device/gpu/gpu_types.h" #include "paddle/fluid/platform/dynload/cublas.h" +#include "paddle/fluid/platform/dynload/cublasLt.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/macros.h" @@ -110,5 +111,28 @@ class CublasHandleHolder { mutable std::mutex mtx_; }; +class CublasLtHandleHolder { + public: + CublasLtHandleHolder() { + PADDLE_RETRY_CUDA_SUCCESS(dynload::cublasLtCreate(&handle_)); + } + const cublasLtHandle_t& GetCublasLtHandle() const { return handle_; } + + ~CublasLtHandleHolder() PADDLE_MAY_THROW { + PADDLE_RETRY_CUDA_SUCCESS(dynload::cublasLtDestroy(handle_)); + } + + inline void Call(const std::function& callback) const { + std::lock_guard guard(mtx_); + callback(handle_); + } + + private: + DISABLE_COPY_AND_ASSIGN(CublasLtHandleHolder); + + cublasLtHandle_t handle_; + mutable std::mutex mtx_; +}; + } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/device/gpu/gpu_types.h b/paddle/fluid/platform/device/gpu/gpu_types.h index d7362fe9cbd81..d0b48eca5021b 100644 --- a/paddle/fluid/platform/device/gpu/gpu_types.h +++ b/paddle/fluid/platform/device/gpu/gpu_types.h @@ -1,4 +1,5 @@ // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 NVIDIA Corporation. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -24,6 +25,7 @@ #else #include #include "paddle/fluid/platform/dynload/cublas.h" +#include "paddle/fluid/platform/dynload/cublasLt.h" #include "paddle/fluid/platform/dynload/cudnn.h" #endif @@ -70,6 +72,10 @@ DECLARE_TYPE_FOR_GPU(dnnHandle_t, cudnnHandle_t, miopenHandle_t); DECLARE_TYPE_FOR_GPU(blasHandle_t, cublasHandle_t, rocblas_handle); +// TODO(Ming Huang): Since there is no blasLt handler, +// use rocblas_handle for workround. +DECLARE_TYPE_FOR_GPU(blasLtHandle_t, cublasLtHandle_t, rocblas_handle); + using CUDAGraphID = unsigned long long; // NOLINT #undef DECLARE_TYPE_FOR_GPU diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 6a7956628f804..d5ff8f4ddc683 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -1,4 +1,6 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at @@ -465,6 +467,9 @@ CUDAContext::CUDAContext(const CUDAPlace& place, InitCuBlasContext(); InitCuDNNContext(); #ifndef PADDLE_WITH_HIP +#if CUDA_VERSION >= 11060 + InitCuBlasLtContext(); +#endif InitCuSparseContext(); InitCuSolverContext(); #endif @@ -476,6 +481,9 @@ void CUDAContext::SetStream(gpuStream_t stream) { DestoryCuDNNContext(); DestoryCuBlasContext(); #ifndef PADDLE_WITH_HIP +#if CUDA_VERSION >= 11060 + DestoryCuBlasLtContext(); +#endif DestoryCuSolverContext(); #endif @@ -485,6 +493,9 @@ void CUDAContext::SetStream(gpuStream_t stream) { InitCuBlasContext(); InitCuDNNContext(); #ifndef PADDLE_WITH_HIP +#if CUDA_VERSION >= 11060 + InitCuBlasLtContext(); +#endif InitCuSolverContext(); #endif } @@ -495,6 +506,9 @@ CUDAContext::~CUDAContext() { DestoryCuDNNContext(); DestoryCuBlasContext(); #ifndef PADDLE_WITH_HIP +#if CUDA_VERSION >= 11060 + InitCuBlasLtContext(); +#endif DestoryCuSparseContext(); DestoryCuSolverContext(); #endif @@ -551,6 +565,14 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const { } return phi::GPUContext::cublas_handle(); } +#if CUDA_VERSION >= 11060 +cublasLtHandle_t CUDADeviceContext::cublaslt_handle() const { + if (thread_ctx_.count(this)) { + return context()->CublasLtHandle()->GetCublasLtHandle(); + } + return phi::GPUContext::cublaslt_handle(); +} +#endif cusparseHandle_t CUDADeviceContext::cusparse_handle() const { if (thread_ctx_.count(this)) { return context()->CusparseHandle()->GetCusparseHandle(); diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index e9124dfc1f8a7..513a2a51346a7 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -1,4 +1,6 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Copyright (c) 2022 NVIDIA Corporation. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at @@ -29,6 +31,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/device/gpu/gpu_helper.h" #include "paddle/fluid/platform/dynload/cublas.h" +#include "paddle/fluid/platform/dynload/cublasLt.h" #include "paddle/fluid/platform/dynload/cudnn.h" #include "paddle/fluid/platform/dynload/cusolver.h" #include "paddle/fluid/platform/dynload/cusparse.h" @@ -332,6 +335,12 @@ class CUDAContext { } #ifndef PADDLE_WITH_HIP +#if CUDA_VERSION >= 11060 + const std::unique_ptr& CublasLtHandle() const { + return cublaslt_handle_; + } +#endif + const std::unique_ptr& CusparseHandle() const { return cusparse_handle_; } @@ -348,6 +357,14 @@ class CUDAContext { } #ifndef PADDLE_WITH_HIP +#if CUDA_VERSION >= 11060 + /*! \brief Call cublasLt function safely. */ + inline void CublasLtCall( + const std::function& callback) const { + cublaslt_handle_->Call(callback); + } +#endif + /*! \brief Call cusparse function safely. 
*/ inline void CusparseCall( const std::function& callback) const { @@ -394,6 +411,12 @@ class CUDAContext { #endif #ifndef PADDLE_WITH_HIP +#if CUDA_VERSION >= 11060 + void InitCuBlasLtContext() { + cublaslt_handle_.reset(new CublasLtHandleHolder()); + } +#endif + void InitCuSparseContext() { cusparse_handle_.reset(new CusparseHandleHolder(RawStream())); } @@ -472,6 +495,10 @@ class CUDAContext { } #ifndef PADDLE_WITH_HIP +#if CUDA_VERSION >= 11060 + void DestoryCuBlasLtContext() { cublaslt_handle_.reset(); } +#endif + void DestoryCuSparseContext() { cusparse_handle_.reset(); } #endif @@ -497,6 +524,9 @@ class CUDAContext { std::unique_ptr cublas_tensor_core_handle_; std::unique_ptr cublas_tf32_tensor_core_handle_; #ifndef PADDLE_WITH_HIP +#if CUDA_VERSION >= 11060 + std::unique_ptr cublaslt_handle_; +#endif cusolverDnHandle_t cusolver_dn_handle_; std::unique_ptr cusparse_handle_; #endif @@ -559,6 +589,7 @@ class CUDADeviceContext : public phi::GPUContext { rocblas_handle cublas_handle() const; #else cublasHandle_t cublas_handle() const; + cublasLtHandle_t cublaslt_handle() const; cusparseHandle_t cusparse_handle() const; #endif diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 6e553ad2e60e2..6add6a7033dd8 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1,4 +1,5 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -3445,6 +3446,31 @@ All parameter, weight, gradient are variables in Paddle. build_strategy = static.BuildStrategy() build_strategy.fuse_elewise_add_act_ops = True )DOC") + .def_property( + "fuse_gemm_epilogue", + [](const BuildStrategy &self) { return self.fuse_gemm_epilogue_; }, + [](BuildStrategy &self, bool b) { + PADDLE_ENFORCE_NE(self.IsFinalized(), true, + platform::errors::PreconditionNotMet( + "BuildStrategy has been finlaized, cannot be " + "configured again.")); + self.fuse_gemm_epilogue_ = b; + }, + R"DOC((bool, optional): fuse_gemm_epilogue indicate whether + to fuse matmul_op, elemenewist_add_op and activation_op, + it may make the execution faster. Default is False. + + Examples: + .. code-block:: python + + import paddle + import paddle.static as static + + paddle.enable_static() + + build_strategy = static.BuildStrategy() + build_strategy.fuse_gemm_epilogue = True + )DOC") .def_property( "fuse_bn_act_ops", [](const BuildStrategy &self) { return self.fuse_bn_act_ops_; }, diff --git a/paddle/phi/backends/gpu/forwards.h b/paddle/phi/backends/gpu/forwards.h index d0787159e1e30..33daa2bba6b7d 100644 --- a/paddle/phi/backends/gpu/forwards.h +++ b/paddle/phi/backends/gpu/forwards.h @@ -1,4 +1,5 @@ /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Copyright (c) 2022 NVIDIA Corporation. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -56,6 +57,9 @@ using cudnnFusedOpsPlan_t = struct cudnnFusedOpsPlanStruct *; // Forward declaration of cuBLAS types. using cublasHandle_t = struct cublasContext *; +// Forward declaration of cuBLASLt types. +using cublasLtHandle_t = struct cublasLtContext *; + // Forward declaration of cuSOLVER types. 
using cusolverDnHandle_t = struct cusolverDnContext *; diff --git a/paddle/phi/backends/gpu/gpu_context.cc b/paddle/phi/backends/gpu/gpu_context.cc index dbcc1660c6472..09deb575f2414 100644 --- a/paddle/phi/backends/gpu/gpu_context.cc +++ b/paddle/phi/backends/gpu/gpu_context.cc @@ -1,4 +1,5 @@ /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Copyright (c) 2022 NVIDIA Corporation. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -171,6 +172,7 @@ struct GPUContext::Impl { InitStream(); InitEigenDevice(); InitBlasHandle(); + InitBlasLtHandle(); InitDNNHandle(); InitSolverHandle(); InitSparseHandle(); @@ -183,6 +185,7 @@ struct GPUContext::Impl { InitGpuProperties(); InitStream(); InitBlasHandle(); + InitBlasLtHandle(); InitDNNHandle(); InitSolverHandle(); InitSparseHandle(); @@ -212,6 +215,7 @@ struct GPUContext::Impl { } #endif DestroyInternalBlasHandle(); + DestroyInternalBlasLtHandle(); DestoryInternalStream(); } @@ -418,6 +422,25 @@ struct GPUContext::Impl { void SetBlasHandle(blasHandle_t blas) { blas_handle_ = blas; } + void InitBlasLtHandle() { +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 + phi::dynload::cublasLtCreate(&blaslt_handle_); +#endif + } + + void DestroyInternalBlasLtHandle() { +#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 + phi::dynload::cublasLtDestroy(blaslt_handle_); +#endif + } + + void SetBlasLtHandle(blasLtHandle_t blaslt) { blaslt_handle_ = blaslt; } + + blasLtHandle_t GetBlasLtHandle() const { + PD_CHECK(blaslt_handle_ != nullptr, "the gpu blasLt handle is nullptr."); + return blaslt_handle_; + } + void InitDNNHandle() { if (phi::dynload::HasCUDNN()) { #ifdef PADDLE_WITH_HIP @@ -679,6 +702,7 @@ struct GPUContext::Impl { blasHandle_t blas_handle_{nullptr}; blasHandle_t blas_tensor_core_handle_{nullptr}; blasHandle_t blas_tf32_tensor_core_handle_{nullptr}; + blasLtHandle_t blaslt_handle_{nullptr}; dnnHandle_t dnn_handle_{nullptr}; solverHandle_t solver_handle_{nullptr}; sparseHandle_t sparse_handle_{nullptr}; @@ -725,6 +749,10 @@ blasHandle_t GPUContext::cublas_handle() const { return impl_->GetBlasHandle(); } +blasLtHandle_t GPUContext::cublaslt_handle() const { + return impl_->GetBlasLtHandle(); +} + solverHandle_t GPUContext::cusolver_dn_handle() const { return impl_->GetSolverHandle(); } @@ -815,6 +843,10 @@ void GPUContext::SetBlasHandle(blasHandle_t blas) { impl_->SetBlasHandle(blas); } +void GPUContext::SetBlasLtHandle(blasLtHandle_t blaslt) { + impl_->SetBlasLtHandle(blaslt); +} + void GPUContext::SetDnnHandle(dnnHandle_t handle) { impl_->SetDnnHandle(handle); } diff --git a/paddle/phi/backends/gpu/gpu_context.h b/paddle/phi/backends/gpu/gpu_context.h index b9d843982dc5e..3eb4360ad3538 100644 --- a/paddle/phi/backends/gpu/gpu_context.h +++ b/paddle/phi/backends/gpu/gpu_context.h @@ -1,4 +1,5 @@ /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Copyright (c) 2022 NVIDIA Corporation. All rights reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -93,6 +94,9 @@ class GPUContext : public DeviceContext { /*! \brief Return cublas handle in the device context. */ blasHandle_t cublas_handle() const; + /*! \brief Return cublasLt handle in the device context. */ + blasLtHandle_t cublaslt_handle() const; + /*! \brief Return cusolver handle in the device context. 
*/ solverHandle_t cusolver_dn_handle() const; @@ -193,6 +197,8 @@ class GPUContext : public DeviceContext { void SetBlasHandle(blasHandle_t); + void SetBlasLtHandle(blasLtHandle_t); + void SetDnnHandle(dnnHandle_t); void SetSolverHandle(solverHandle_t); diff --git a/paddle/phi/backends/gpu/gpu_decls.h b/paddle/phi/backends/gpu/gpu_decls.h index 0be24392e1b40..4a6b9d2fd87f1 100644 --- a/paddle/phi/backends/gpu/gpu_decls.h +++ b/paddle/phi/backends/gpu/gpu_decls.h @@ -1,4 +1,5 @@ // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 NVIDIA Corporation. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -59,6 +60,10 @@ DECLARE_TYPE_FOR_GPU(dnnHandle_t, cudnnHandle_t, miopenHandle_t); DECLARE_TYPE_FOR_GPU(blasHandle_t, cublasHandle_t, rocblas_handle); +// TODO(Ming Huang): Since there is no blasLt handler, +// use rocblas_handle for workround. +DECLARE_TYPE_FOR_GPU(blasLtHandle_t, cublasLtHandle_t, rocblas_handle); + DECLARE_TYPE_FOR_GPU(solverHandle_t, cusolverDnHandle_t, rocsolver_handle); DECLARE_TYPE_FOR_GPU(sparseHandle_t, cusparseHandle_t, rocsparse_handle); diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 2361bd2706238..cfba8f7fbf2b6 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -125,6 +125,17 @@ if(NOT WITH_GPU) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api) LIST(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer) + LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op) + LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op) + LIST(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass) +endif() + +if (WITH_GPU) + if (CUDA_VERSION LESS 11.6) + LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op) + LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op) + LIST(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass) + endif() endif() if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) diff --git a/python/paddle/fluid/tests/unittests/test_fuse_gemm_epilogue_pass.py b/python/paddle/fluid/tests/unittests/test_fuse_gemm_epilogue_pass.py new file mode 100644 index 0000000000000..7f3180e21d8c6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fuse_gemm_epilogue_pass.py @@ -0,0 +1,392 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
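For reference, each fused_gemm_epilogue op produced by this pass computes a matmul followed by a bias add and an optional activation. A minimal NumPy sketch of that reference computation (illustrative only; the helper name is not part of the patch and mirrors the get_output helper in test_fused_gemm_epilogue_op.py later in this diff):

import numpy as np

def gemm_epilogue_reference(x, y, bias, activation="none"):
    # activation(x @ y + bias) is the pattern the pass fuses into a single op.
    out = np.dot(x, y) + bias
    if activation == "relu":
        return np.maximum(out, 0.0)
    if activation == "gelu":
        return 0.5 * out * (1.0 + np.tanh(
            np.sqrt(2.0 / np.pi) * (out + 0.044715 * np.power(out, 3))))
    return out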
+"""Test cases for role makers.""" + +from __future__ import print_function +import paddle +import os +import unittest +import numpy as np +import paddle.fluid.core as core + + +def compare(ref, res, atol, rtol): + + ref = np.array(ref).flatten() + res = np.array(res).flatten() + + tmp_ref = ref.astype(np.float) + tol = atol + rtol * abs(tmp_ref) + + diff = abs(res - ref) + + indices = np.transpose(np.where(diff > tol)) + if len(indices) == 0: + return True + return False + + +def verify_node_count(graph, node_name, target_count): + count = 0 + for node in graph.nodes(): + if node.name() == node_name: + count += 1 + return count == target_count + + +class MultiFCLayer(paddle.nn.Layer): + def __init__(self, hidden, Activation): + super(MultiFCLayer, self).__init__() + self.linear1 = paddle.nn.Linear(hidden, hidden) + self.linear2 = paddle.nn.Linear(hidden, hidden) + self.linear3 = paddle.nn.Linear(hidden, hidden) + + self.relu1 = Activation() + self.relu2 = Activation() + self.relu3 = Activation() + + def forward(self, x, matmul_y, ele_y): + output = self.linear1(x) + output = self.relu1(output) + output = self.linear2(output) + + output1 = paddle.matmul(output, matmul_y) + output = self.linear3(output) + output = self.relu2(output) + + output = paddle.matmul(output, matmul_y) + output = paddle.add(output, ele_y) + output = self.relu3(output) + output = paddle.add(output, output1) + return output + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueFWDBase(unittest.TestCase): + def setUp(self): + self.batch = 64 + self.seqlen = 128 + self.hidden = 768 + + paddle.enable_static() + + self.main_prog = paddle.static.Program() + self.startup_prog = paddle.static.Program() + + with paddle.static.program_guard(self.main_prog, self.startup_prog): + data = paddle.static.data( + name="_data", + shape=[-1, self.seqlen, self.hidden], + dtype='float32') + matmul_y = paddle.static.data( + name="_matmul_y", + shape=[1, self.hidden, self.hidden], + dtype='float32') + ele_y = paddle.static.data( + name="_ele_y", shape=[self.hidden, ], dtype='float32') + + multi_layer = MultiFCLayer(self.hidden, self._get_act_type()[0]) + with paddle.static.amp.fp16_guard(): + out = multi_layer(data, matmul_y, ele_y) + self.loss = paddle.mean(out) + + self.data_arr = np.random.random( + (self.batch, self.seqlen, self.hidden)).astype("float32") - 0.5 + self.matmul_y_arr = np.random.random( + (1, self.hidden, self.hidden)).astype("float32") - 0.5 + self.ele_y_arr = np.random.random( + (self.hidden, )).astype("float32") - 0.5 + + self.place = paddle.CUDAPlace(0) + self.exe = paddle.static.Executor(self.place) + self.exe.run(self.startup_prog) + + self._pre_test_hooks() + + self.feed = { + "_data": self.data_arr, + "_matmul_y": self.matmul_y_arr, + "_ele_y": self.ele_y_arr + } + self.reference = self.exe.run(self.main_prog, + feed=self.feed, + fetch_list=[self.loss.name]) + + @unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + def _test_output(self): + build_strategy = paddle.static.BuildStrategy() + build_strategy.fuse_gemm_epilogue = True + program = paddle.static.CompiledProgram(self.main_prog) + program = program.with_data_parallel( + loss_name=self.loss.name, + build_strategy=build_strategy, + places=paddle.static.cuda_places()) + + result = self.exe.run(program, + feed=self.feed, + fetch_list=[self.loss.name]) + self.assertTrue( + compare(self.reference, result, self.atol, self.rtol), + "[{}] outputs are 
miss-matched.".format(type(self).__name__)) + self.assertTrue( + verify_node_count(program._graph, "fused_gemm_epilogue", 3), + "[{}] The number of fused_gemm_epilogue is miss-matched in the computing graph.". + format(type(self).__name__)) + act_fwd_name = self._get_act_type()[1] + self.assertTrue( + verify_node_count(program._graph, act_fwd_name, 1), + "[{}] The number of {} is miss-matched in the computing graph.". + format(type(self).__name__, act_fwd_name)) + + def _pre_test_hooks(self): + self.atol = 1e-4 + self.rtol = 1e-3 + + def _get_act_type(self): + return paddle.nn.ReLU, "relu" + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueReluFWDFP32(TestFuseGemmEpilogueFWDBase): + def _pre_test_hooks(self): + self.atol = 1e-3 + self.rtol = 1e-2 + + def _get_act_type(self): + return paddle.nn.ReLU, "relu" + + def test_output(self): + self._test_output() + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueReluFWDFP16(TestFuseGemmEpilogueReluFWDFP32): + def _pre_test_hooks(self): + self.atol = 1e-3 + self.rtol = 1e-2 + + fp16_var_list = paddle.static.amp.cast_model_to_fp16(self.main_prog) + paddle.static.amp.cast_parameters_to_fp16( + self.place, self.main_prog, to_fp16_var_names=fp16_var_list) + + self.data_arr = self.data_arr.astype("float16") + self.matmul_y_arr = self.matmul_y_arr.astype("float16") + self.ele_y_arr = self.ele_y_arr.astype("float16") + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueGeluFWDFP32(TestFuseGemmEpilogueFWDBase): + def _pre_test_hooks(self): + self.atol = 1e-4 + self.rtol = 1e-3 + + def _get_act_type(self): + return paddle.nn.GELU, "gelu" + + def test_output(self): + self._test_output() + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueGeluFWDFP16(TestFuseGemmEpilogueGeluFWDFP32): + def _pre_test_hooks(self): + self.atol = 1e-3 + self.rtol = 1e-2 + + fp16_var_list = paddle.static.amp.cast_model_to_fp16(self.main_prog) + paddle.static.amp.cast_parameters_to_fp16( + self.place, self.main_prog, to_fp16_var_names=fp16_var_list) + + self.data_arr = self.data_arr.astype("float16") + self.matmul_y_arr = self.matmul_y_arr.astype("float16") + self.ele_y_arr = self.ele_y_arr.astype("float16") + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueBWDBase(unittest.TestCase): + def setUp(self): + self.batch = 64 + self.seqlen = 128 + self.hidden = 768 + + paddle.enable_static() + + self.main_prog = paddle.static.Program() + self.startup_prog = paddle.static.Program() + + with paddle.static.program_guard(self.main_prog, self.startup_prog): + data = paddle.static.data( + name="_data", + shape=[-1, self.seqlen, self.hidden], + dtype='float32') + matmul_y = paddle.static.data( + name="_matmul_y", + shape=[1, self.hidden, self.hidden], + dtype='float32') + ele_y = paddle.static.data( + name="_ele_y", shape=[self.hidden, ], dtype='float32') + + multi_layer = MultiFCLayer(self.hidden, self._get_act_type()[0]) + with paddle.static.amp.fp16_guard(): + out = multi_layer(data, matmul_y, ele_y) + self.loss = paddle.mean(out) + paddle.static.append_backward(loss=self.loss) + + self.data_arr = np.random.random( + (self.batch, self.seqlen, self.hidden)).astype("float32") - 0.5 + self.matmul_y_arr = np.random.random( + (1, self.hidden, 
self.hidden)).astype("float32") - 0.5 + self.ele_y_arr = np.random.random( + (self.hidden, )).astype("float32") - 0.5 + + self.place = paddle.CUDAPlace(0) + self.exe = paddle.static.Executor(self.place) + self.exe.run(self.startup_prog) + + self._pre_test_hooks() + + self.feed = { + "_data": self.data_arr, + "_matmul_y": self.matmul_y_arr, + "_ele_y": self.ele_y_arr + } + + self.fetch = [ + self.loss.name, + '{}.w_0@GRAD'.format(multi_layer.linear1.full_name()), + '{}.b_0@GRAD'.format(multi_layer.linear1.full_name()), + '{}.w_0@GRAD'.format(multi_layer.linear2.full_name()), + '{}.b_0@GRAD'.format(multi_layer.linear2.full_name()), + '{}.w_0@GRAD'.format(multi_layer.linear3.full_name()), + '{}.b_0@GRAD'.format(multi_layer.linear3.full_name()) + ] + self.outs_ref = self.exe.run(self.main_prog, + feed=self.feed, + fetch_list=self.fetch) + + @unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") + def _test_output(self): + build_strategy = paddle.static.BuildStrategy() + build_strategy.fuse_gemm_epilogue = True + program = paddle.static.CompiledProgram(self.main_prog) + program = program.with_data_parallel( + loss_name=self.loss.name, + build_strategy=build_strategy, + places=paddle.static.cuda_places()) + + outs_res = self.exe.run(program, feed=self.feed, fetch_list=self.fetch) + + for ref, res in zip(self.outs_ref, outs_res): + self.assertTrue( + compare(ref, res, self.atol, self.rtol), + "[{}] output is miss-matched.".format(type(self).__name__)) + + self.assertTrue( + verify_node_count(program._graph, "fused_gemm_epilogue", 3), + "[{}] The number of fused_gemm_epilogue is miss-matched in the computing graph.". + format(type(self).__name__)) + self.assertTrue( + verify_node_count(program._graph, "fused_gemm_epilogue_grad", 3), + "[{}] The number of fused_gemm_epilogue_grad is miss-matched in the computing graph.". + format(type(self).__name__)) + _, act_fwd_name, act_bwd_name = self._get_act_type() + self.assertTrue( + verify_node_count(program._graph, act_fwd_name, 1), + "[{}] The number of {} is miss-matched in the computing graph.". + format(type(self).__name__, act_fwd_name)) + self.assertTrue( + verify_node_count(program._graph, act_bwd_name, 2), + "[{}] The number of {} is miss-matched in the computing graph.". 
+ format(type(self).__name__, act_bwd_name)) + + def _pre_test_hooks(self): + self.atol = 1e-4 + self.rtol = 1e-3 + + def _get_act_type(self): + return paddle.nn.ReLU, "relu", "relu_grad" + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueReLUBWDFP32(TestFuseGemmEpilogueBWDBase): + def _pre_test_hooks(self): + self.atol = 1e-4 + self.rtol = 1e-3 + + def _get_act_type(self): + return paddle.nn.ReLU, "relu", "relu_grad" + + def test_output(self): + self._test_output() + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueReLUBWDFP16(TestFuseGemmEpilogueReLUBWDFP32): + def _pre_test_hooks(self): + self.atol = 1e-3 + self.rtol = 1e-2 + + fp16_var_list = paddle.static.amp.cast_model_to_fp16(self.main_prog) + paddle.static.amp.cast_parameters_to_fp16( + self.place, self.main_prog, to_fp16_var_names=fp16_var_list) + + self.data_arr = self.data_arr.astype("float16") + self.matmul_y_arr = self.matmul_y_arr.astype("float16") + self.ele_y_arr = self.ele_y_arr.astype("float16") + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueGeLUBWDFP32(TestFuseGemmEpilogueBWDBase): + def _pre_test_hooks(self): + self.atol = 5e-4 + self.rtol = 1e-3 + + def _get_act_type(self): + return paddle.nn.GELU, "gelu", "gelu_grad" + + def test_output(self): + self._test_output() + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueGeLUBWDFP16(TestFuseGemmEpilogueGeLUBWDFP32): + def _pre_test_hooks(self): + self.atol = 1e-3 + self.rtol = 1e-2 + + fp16_var_list = paddle.static.amp.cast_model_to_fp16(self.main_prog) + paddle.static.amp.cast_parameters_to_fp16( + self.place, self.main_prog, to_fp16_var_names=fp16_var_list) + + self.data_arr = self.data_arr.astype("float16") + self.matmul_y_arr = self.matmul_y_arr.astype("float16") + self.ele_y_arr = self.ele_y_arr.astype("float16") + + +if __name__ == "__main__": + np.random.seed(0) + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_grad_op.py b/python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_grad_op.py new file mode 100644 index 0000000000000..2ea1bf2e9cb81 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_grad_op.py @@ -0,0 +1,239 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
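Since the forward op computes Out = X @ Y + Bias, the chain rule gives DX = DOut @ Y.T, DY = X.T @ DOut, and DBias as the sum of DOut over rows, which is exactly what the get_outputs helper below returns. A small NumPy shape check with the sizes used in these tests (illustrative only, not part of the patch):

import numpy as np

dout = np.ones((8, 128))  # DOut
x = np.ones((8, 4))       # X
y = np.ones((4, 128))     # Y
assert (dout @ y.T).shape == x.shape      # DX has the shape of X: (8, 4)
assert (x.T @ dout).shape == y.shape      # DY has the shape of Y: (4, 128)
assert dout.sum(axis=0).shape == (128, )  # DBias has the shape of Bias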
+ +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +from op_test import OpTest, skip_check_grad_ci + + +def get_outputs(DOut, X, Y): + DX = np.dot(DOut, Y.T) + DY = np.dot(X.T, DOut) + DBias = np.sum(DOut, axis=0) + + return DX, DY, DBias + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueGradOpDXYBiasFP16(OpTest): + def setUp(self): + self.op_type = "fused_gemm_epilogue_grad" + self.place = core.CUDAPlace(0) + self.init_dtype_type() + + self.inputs = { + 'DOut': np.random.random((8, 128)).astype(self.dtype) - 0.5, + 'X': np.random.random((8, 4)).astype(self.dtype) - 0.5, + 'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5 + } + + self.attrs = {"activation": 'none'} + + DX, DY, DBias = get_outputs(self.inputs['DOut'], self.inputs['X'], + self.inputs['Y']) + self.outputs = {'DX': DX, 'DY': DY, 'DBias': DBias} + + def init_dtype_type(self): + self.dtype = np.float16 + self.atol = 1e-3 + + def test_check_output(self): + if self.dtype == np.float16 and not core.is_float16_supported( + self.place): + return + self.check_output_with_place(self.place, atol=self.atol) + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueGradOpDXYBiasFP32( + TestFuseGemmEpilogueGradOpDXYBiasFP16): + def init_dtype_type(self): + self.dtype = np.single + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueGradOpDXYBiasFP64( + TestFuseGemmEpilogueGradOpDXYBiasFP16): + def init_dtype_type(self): + self.dtype = np.double + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueGradOpDYBiasFP16(OpTest): + def setUp(self): + self.op_type = "fused_gemm_epilogue_grad" + self.place = core.CUDAPlace(0) + self.init_dtype_type() + + self.inputs = { + 'DOut': np.random.random((8, 128)).astype(self.dtype) - 0.5, + 'X': np.random.random((8, 4)).astype(self.dtype) - 0.5, + 'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5 + } + + self.attrs = {"activation": 'none'} + + _, DY, DBias = get_outputs(self.inputs['DOut'], self.inputs['X'], + self.inputs['Y']) + self.outputs = {'DY': DY, 'DBias': DBias} + + def init_dtype_type(self): + self.dtype = np.float16 + self.atol = 1e-3 + + def test_check_output(self): + if self.dtype == np.float16 and not core.is_float16_supported( + self.place): + return + self.check_output_with_place(self.place, atol=self.atol) + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueGradOpDYBiasFP32( + TestFuseGemmEpilogueGradOpDYBiasFP16): + def init_dtype_type(self): + self.dtype = np.single + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueGradOpDYBiasFP64( + TestFuseGemmEpilogueGradOpDYBiasFP16): + def init_dtype_type(self): + self.dtype = np.double + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class 
TestFuseGemmEpilogueGradOpDYFP16(OpTest): + def setUp(self): + self.op_type = "fused_gemm_epilogue_grad" + self.place = core.CUDAPlace(0) + self.init_dtype_type() + + self.inputs = { + 'DOut': np.random.random((8, 128)).astype(self.dtype) - 0.5, + 'X': np.random.random((8, 4)).astype(self.dtype) - 0.5, + 'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5 + } + + self.attrs = {"activation": 'none'} + + _, DY, _ = get_outputs(self.inputs['DOut'], self.inputs['X'], + self.inputs['Y']) + self.outputs = {'DY': DY} + + def init_dtype_type(self): + self.dtype = np.float16 + self.atol = 1e-3 + + def test_check_output(self): + if self.dtype == np.float16 and not core.is_float16_supported( + self.place): + return + self.check_output_with_place(self.place, atol=self.atol) + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueGradOpDYFP32(TestFuseGemmEpilogueGradOpDYFP16): + def init_dtype_type(self): + self.dtype = np.single + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueGradOpDYFP64(TestFuseGemmEpilogueGradOpDYFP16): + def init_dtype_type(self): + self.dtype = np.double + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueGradOpDXYFP16(OpTest): + def setUp(self): + self.op_type = "fused_gemm_epilogue_grad" + self.place = core.CUDAPlace(0) + self.init_dtype_type() + + self.inputs = { + 'DOut': np.random.random((8, 128)).astype(self.dtype) - 0.5, + 'X': np.random.random((8, 4)).astype(self.dtype) - 0.5, + 'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5 + } + + self.attrs = {"activation": 'none'} + + DX, DY, _ = get_outputs(self.inputs['DOut'], self.inputs['X'], + self.inputs['Y']) + self.outputs = {'DX': DX, 'DY': DY} + + def init_dtype_type(self): + self.dtype = np.float16 + self.atol = 1e-3 + + def test_check_output(self): + if self.dtype == np.float16 and not core.is_float16_supported( + self.place): + return + self.check_output_with_place(self.place, atol=self.atol) + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueGradOpDXYFP32(TestFuseGemmEpilogueGradOpDXYFP16): + def init_dtype_type(self): + self.dtype = np.single + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueGradOpDXYFP64(TestFuseGemmEpilogueGradOpDXYFP16): + def init_dtype_type(self): + self.dtype = np.double + self.atol = 1e-6 + + +if __name__ == "__main__": + np.random.seed(0) + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_op.py b/python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_op.py new file mode 100644 index 0000000000000..f826898f9e5dd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_gemm_epilogue_op.py @@ -0,0 +1,450 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 NVIDIA Corporation. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +from op_test import OpTest, skip_check_grad_ci + + +def gelu(x): + y_ref = 0.5 * x * ( + 1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) + return y_ref.astype(x.dtype) + + +def relu(x): + mask = x > 0 + return x * mask + + +def get_output(X, Y, bias, act): + out = np.dot(X, Y) + bias + if act == 'relu': + return relu(out) + elif act == 'gelu': + return gelu(out) + else: + return out + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMMFP16(OpTest): + def setUp(self): + self.op_type = "fused_gemm_epilogue" + self.place = core.CUDAPlace(0) + self.init_dtype_type() + + self.inputs = { + 'X': np.random.random((8, 4)).astype(self.dtype) - 0.5, + 'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5, + 'Bias': np.random.random((128, )).astype(self.dtype) - 0.5 + } + self.outputs = { + 'Out': get_output(self.inputs['X'], self.inputs['Y'], + self.inputs['Bias'], 'relu') + } + self.attrs = {"activation": 'relu'} + + def init_dtype_type(self): + self.dtype = np.float16 + self.atol = 1e-3 + + def test_check_output(self): + if self.dtype == np.float16 and not core.is_float16_supported( + self.place): + return + self.check_output_with_place(self.place, atol=self.atol) + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMMFP32(TestFuseGemmEpilogueOpReluMMFP16): + def init_dtype_type(self): + self.dtype = np.single + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMMFP64(TestFuseGemmEpilogueOpReluMMFP16): + def init_dtype_type(self): + self.dtype = np.double + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMTMFP16(OpTest): + def setUp(self): + self.op_type = "fused_gemm_epilogue" + self.place = core.CUDAPlace(0) + self.init_dtype_type() + + self.inputs = { + 'X': np.random.random((4, 8)).astype(self.dtype) - 0.5, + 'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5, + 'Bias': np.random.random((128, )).astype(self.dtype) - 0.5 + } + self.outputs = { + 'Out': get_output(self.inputs['X'].T, self.inputs['Y'], + self.inputs['Bias'], 'relu') + } + self.attrs = {'trans_x': True, "activation": 'relu'} + + def init_dtype_type(self): + self.dtype = np.float16 + self.atol = 1e-3 + + def test_check_output(self): + if self.dtype == np.float16 and not core.is_float16_supported( + self.place): + return + self.check_output_with_place(self.place, atol=self.atol) + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class 
TestFuseGemmEpilogueOpReluMTMFP32(TestFuseGemmEpilogueOpReluMTMFP16): + def init_dtype_type(self): + self.dtype = np.single + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMTMFP64(TestFuseGemmEpilogueOpReluMTMFP16): + def init_dtype_type(self): + self.dtype = np.double + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMMTFP16(OpTest): + def setUp(self): + self.op_type = "fused_gemm_epilogue" + self.place = core.CUDAPlace(0) + self.init_dtype_type() + + self.inputs = { + 'X': np.random.random((8, 4)).astype(self.dtype) - 0.5, + 'Y': np.random.random((128, 4)).astype(self.dtype) - 0.5, + 'Bias': np.random.random((128, )).astype(self.dtype) - 0.5 + } + self.outputs = { + 'Out': get_output(self.inputs['X'], self.inputs['Y'].T, + self.inputs['Bias'], 'relu') + } + self.attrs = {'trans_y': True, "activation": 'relu'} + + def init_dtype_type(self): + self.dtype = np.float16 + self.atol = 1e-3 + + def test_check_output(self): + if self.dtype == np.float16 and not core.is_float16_supported( + self.place): + return + self.check_output_with_place(self.place, atol=self.atol) + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMMTFP32(TestFuseGemmEpilogueOpReluMMTFP16): + def init_dtype_type(self): + self.dtype = np.single + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMMTFP64(TestFuseGemmEpilogueOpReluMMTFP16): + def init_dtype_type(self): + self.dtype = np.double + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMTMTFP16(OpTest): + def setUp(self): + self.op_type = "fused_gemm_epilogue" + self.place = core.CUDAPlace(0) + self.init_dtype_type() + + self.inputs = { + 'X': np.random.random((4, 8)).astype(self.dtype) - 0.5, + 'Y': np.random.random((128, 4)).astype(self.dtype) - 0.5, + 'Bias': np.random.random((128, )).astype(self.dtype) - 0.5 + } + self.outputs = { + 'Out': get_output(self.inputs['X'].T, self.inputs['Y'].T, + self.inputs['Bias'], 'relu') + } + self.attrs = {'trans_x': True, 'trans_y': True, "activation": 'relu'} + + def init_dtype_type(self): + self.dtype = np.float16 + self.atol = 1e-3 + + def test_check_output(self): + if self.dtype == np.float16 and not core.is_float16_supported( + self.place): + return + self.check_output_with_place(self.place, atol=self.atol) + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMTMTFP32(TestFuseGemmEpilogueOpReluMTMTFP16): + def init_dtype_type(self): + self.dtype = np.single + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMTMTFP64(TestFuseGemmEpilogueOpReluMTMTFP16): + def init_dtype_type(self): + self.dtype = np.double + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not 
compiled with CUDA") +class TestFuseGemmEpilogueOpReluMMFP16MultiDimX(OpTest): + def setUp(self): + self.op_type = "fused_gemm_epilogue" + self.place = core.CUDAPlace(0) + self.init_dtype_type() + + self.inputs = { + 'X': np.random.random((2, 2, 8, 4)).astype(self.dtype) - 0.5, + 'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5, + 'Bias': np.random.random((128, )).astype(self.dtype) - 0.5 + } + self.outputs = { + 'Out': get_output(self.inputs['X'].reshape( + (-1, 4)), self.inputs['Y'], self.inputs['Bias'], + 'relu').reshape((2, 2, 8, 128)) + } + self.attrs = {"activation": 'relu'} + + def init_dtype_type(self): + self.dtype = np.float16 + self.atol = 1e-3 + + def test_check_output(self): + if self.dtype == np.float16 and not core.is_float16_supported( + self.place): + return + self.check_output_with_place(self.place, atol=self.atol) + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMMFP32MultiDimX( + TestFuseGemmEpilogueOpReluMMFP16MultiDimX): + def init_dtype_type(self): + self.dtype = np.single + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMMFP64MultiDimX( + TestFuseGemmEpilogueOpReluMMFP16MultiDimX): + def init_dtype_type(self): + self.dtype = np.double + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMTMFP16MultiDimX(OpTest): + def setUp(self): + self.op_type = "fused_gemm_epilogue" + self.place = core.CUDAPlace(0) + self.init_dtype_type() + + self.inputs = { + 'X': np.random.random((4, 2, 2, 8)).astype(self.dtype) - 0.5, + 'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5, + 'Bias': np.random.random((128, )).astype(self.dtype) - 0.5 + } + self.outputs = { + 'Out': get_output(self.inputs['X'].reshape( + (4, -1)).T, self.inputs['Y'], self.inputs['Bias'], + 'relu').reshape((2, 2, 8, 128)) + } + self.attrs = {'trans_x': True, "activation": 'relu'} + + def init_dtype_type(self): + self.dtype = np.float16 + self.atol = 1e-3 + + def test_check_output(self): + if self.dtype == np.float16 and not core.is_float16_supported( + self.place): + return + self.check_output_with_place(self.place, atol=self.atol) + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMTMFP32MultiDimX( + TestFuseGemmEpilogueOpReluMTMFP16MultiDimX): + def init_dtype_type(self): + self.dtype = np.single + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpReluMTMFP64MultiDimX( + TestFuseGemmEpilogueOpReluMTMFP16MultiDimX): + def init_dtype_type(self): + self.dtype = np.double + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpGeluMMFP16(OpTest): + def setUp(self): + self.op_type = "fused_gemm_epilogue" + self.place = core.CUDAPlace(0) + self.init_dtype_type() + + self.inputs = { + 'X': np.random.random((8, 4)).astype(self.dtype) - 0.5, + 'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5, + 'Bias': np.random.random((128, )).astype(self.dtype) - 0.5 + 
} + + self.attrs = {"activation": 'gelu'} + + self.outputs = { + 'Out': get_output(self.inputs['X'], self.inputs['Y'], + self.inputs['Bias'], 'gelu') + } + + def init_dtype_type(self): + self.dtype = np.float16 + self.atol = 1e-3 + + def test_check_output(self): + if self.dtype == np.float16 and not core.is_float16_supported( + self.place): + return + self.check_output_with_place(self.place, atol=self.atol) + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpGeluMMFP32(TestFuseGemmEpilogueOpGeluMMFP16): + def init_dtype_type(self): + self.dtype = np.single + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpGeluMMFP64(TestFuseGemmEpilogueOpGeluMMFP16): + def init_dtype_type(self): + self.dtype = np.double + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpNoneMMFP16(OpTest): + def setUp(self): + self.op_type = "fused_gemm_epilogue" + self.place = core.CUDAPlace(0) + self.init_dtype_type() + + self.inputs = { + 'X': np.random.random((8, 4)).astype(self.dtype) - 0.5, + 'Y': np.random.random((4, 128)).astype(self.dtype) - 0.5, + 'Bias': np.random.random((128, )).astype(self.dtype) - 0.5 + } + + self.attrs = {"activation": 'none'} + + self.outputs = { + 'Out': get_output(self.inputs['X'], self.inputs['Y'], + self.inputs['Bias'], 'none') + } + + def init_dtype_type(self): + self.dtype = np.float16 + self.atol = 1e-3 + + def test_check_output(self): + if self.dtype == np.float16 and not core.is_float16_supported( + self.place): + return + self.check_output_with_place(self.place, atol=self.atol) + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpNoneMMFP32(TestFuseGemmEpilogueOpNoneMMFP16): + def init_dtype_type(self): + self.dtype = np.single + self.atol = 1e-6 + + +@skip_check_grad_ci(reason="no grap op") +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestFuseGemmEpilogueOpNoneMMFP64(TestFuseGemmEpilogueOpNoneMMFP16): + def init_dtype_type(self): + self.dtype = np.double + self.atol = 1e-6 + + +if __name__ == "__main__": + np.random.seed(0) + unittest.main() diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py index 694283264ca8f..de308d761cfc7 100755 --- a/tools/static_mode_white_list.py +++ b/tools/static_mode_white_list.py @@ -1,4 +1,5 @@ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2022 NVIDIA Corporation. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -730,4 +731,6 @@ 'test_lu_op', 'test_margin_cross_entropy_op', 'test_pull_gpups_sparse_op', + 'test_fused_gemm_epilogue_op', + 'test_fused_gemm_epilogue_grad_op', ]
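Taken together, the pieces above can be exercised end to end from Python. The following is a minimal usage sketch, assuming a Paddle build with CUDA 11.6 or newer so the pass and kernels are compiled in; it mirrors the pybind docstring and the unit tests above rather than adding any new API:

import numpy as np
import paddle
import paddle.static as static

paddle.enable_static()

main_prog, startup_prog = static.Program(), static.Program()
with static.program_guard(main_prog, startup_prog):
    x = static.data(name="x", shape=[-1, 768], dtype="float32")
    out = paddle.nn.Linear(768, 768)(x)   # matmul + elementwise_add
    out = paddle.nn.ReLU()(out)           # + activation -> one fused_gemm_epilogue op
    loss = paddle.mean(out)

build_strategy = static.BuildStrategy()
build_strategy.fuse_gemm_epilogue = True  # enable the cublasLt epilogue fusion pass

exe = static.Executor(paddle.CUDAPlace(0))
exe.run(startup_prog)
compiled = static.CompiledProgram(main_prog).with_data_parallel(
    loss_name=loss.name,
    build_strategy=build_strategy,
    places=static.cuda_places())

feed = {"x": np.random.rand(64, 768).astype("float32")}
print(exe.run(compiled, feed=feed, fetch_list=[loss.name]))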