Skip to content

Commit

Permalink
explicitly set residual as output
Browse files Browse the repository at this point in the history
  • Loading branch information
Silv3S committed Apr 8, 2022
1 parent 0313fed commit cbe267f
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC(
return;
}

fc_op->Op()->SetInput("ResidualData", {residual_data->Name()});
fc_op->Op()->SetOutput("ResidualData", {residual_data->Name()});
fc_op->Op()->SetOutput("Out", {elementwise_out->Name()});
fc_op->Op()->SetAttr("fuse_residual_connection", true);

Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/inference/tests/api/analyzer_bert_tester.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ void profile(bool use_mkldnn = false) {
config.EnableMKLDNN();
config.pass_builder()->AppendPass("fc_mkldnn_pass");
config.pass_builder()->AppendPass("fc_act_mkldnn_fuse_pass");
config.pass_builder()->AppendPass("fc_elementwise_add_mkldnn_fuse_pass");
}

std::vector<std::vector<PaddleTensor>> outputs;
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ class FCPrimitiveFactory {
const ExecutionContext& ctx, Tensor* output) {
if (ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection")) {
auto* residual_param = ctx.Input<Tensor>("ResidualData");
auto* residual_param = ctx.Output<Tensor>("ResidualData");

PADDLE_ENFORCE_EQ(
output->dims(), residual_param->dims(),
Expand Down

0 comments on commit cbe267f

Please sign in to comment.