Skip to content

Commit

Permalink
[Paddle Inference] Add conv_elementwise_act. (#43871)
Browse files Browse the repository at this point in the history
* conv_fusion
  • Loading branch information
xiaoxiaohehe001 committed Jul 6, 2022
1 parent 24d07b7 commit 4c269cc
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 4 deletions.
18 changes: 18 additions & 0 deletions paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,22 @@ ConvElementwiseAdd2ActFusePass::ConvElementwiseAdd2ActFusePass() {
.AddOutput("Out")
.IsTensor()
.End();

AddOpCompat(OpCompat("sigmoid"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();

AddOpCompat(OpCompat("tanh"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}

void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
Expand Down Expand Up @@ -188,4 +204,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add2_act_fuse_pass)
.LE("conv2d", 1)
.LE("elementwise_add", 1)
.EQ("relu", 0)
.EQ("sigmoid", 0)
.EQ("tanh", 0)
.EQ("identity", 0));
18 changes: 18 additions & 0 deletions paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,22 @@ ConvElementwiseAddActFusePass::ConvElementwiseAddActFusePass() {
.AddOutput("Out")
.IsTensor()
.End();

AddOpCompat(OpCompat("sigmoid"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();

AddOpCompat(OpCompat("tanh"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}

void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
Expand Down Expand Up @@ -170,4 +186,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass)
.LE("conv2d", 1)
.LE("elementwise_add", 1)
.EQ("relu", 0)
.EQ("sigmoid", 0)
.EQ("tanh", 0)
.EQ("identity", 0));
3 changes: 2 additions & 1 deletion paddle/fluid/framework/ir/graph_pattern_detector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2324,7 +2324,8 @@ PDNode *patterns::PriorBox::operator()() {
return boxes_var;
}

std::unordered_set<std::string> conv_act_set({"identity", "relu"});
std::unordered_set<std::string> conv_act_set(
{"identity", "relu", "sigmoid", "tanh"});

PDNode *patterns::ConvElementwiseaddAct::operator()(PDNode *conv_in) {
conv_in->AsInput();
Expand Down
8 changes: 5 additions & 3 deletions paddle/fluid/operators/fused/conv_fusion_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -544,9 +544,11 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {

namespace ops = paddle::operators;
#if CUDNN_VERSION >= 7100
REGISTER_OP_CUDA_KERNEL(conv2d_fusion,
ops::CUDNNConvFusionOpKernel<float>,
ops::CUDNNConvFusionOpKernel<double>);
REGISTER_OP_CUDA_KERNEL(
conv2d_fusion,
ops::CUDNNConvFusionOpKernel<float>,
ops::CUDNNConvFusionOpKernel<double>,
ops::CUDNNConvFusionOpKernel<paddle::platform::float16>);
#endif
#ifdef PADDLE_WITH_HIP
REGISTER_OP_CUDA_KERNEL(conv2d_fusion, ops::CUDNNConvFusionOpKernel<float>);
Expand Down

0 comments on commit 4c269cc

Please sign in to comment.