Skip to content

Commit

Permalink
cosmetic changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Silv3S committed Apr 5, 2022
1 parent 902da8a commit 0313fed
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,21 @@ FCResidualConnectionMKLDNNFusePass::FCResidualConnectionMKLDNNFusePass() {

GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC(
const std::string& name_scope, const GraphWithStats& graph_with_stats,
bool as_x) const {
bool fc_as_x) const {
GraphPatternDetector gpd;
auto pattern = gpd.mutable_pattern();
patterns::FCMKLDNN fc_pattern{pattern, name_scope};
bool fc_has_bias = true;
auto fc_output = fc_pattern(
gpd.mutable_pattern()->NewNode("fc")->AsInput()->assert_is_op_input(
"fc", "Input"),
true);
fc_has_bias);

patterns::ResidualElementwise elementwise_pattern{pattern, name_scope, as_x};
patterns::ResidualElementwise elementwise_pattern{pattern, name_scope,
fc_as_x};
elementwise_pattern(
fc_output, pattern->NewNode(elementwise_pattern.residual_data_repr()),
"elementwise_add", as_x);
"elementwise_add", fc_as_x);
fc_output->AsIntermediate();

int found_fc_count = 0;
Expand All @@ -88,9 +90,7 @@ GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC(
elementwise_pattern);

if (FindFuseOption(*fc_op, *elementwise_op) != FUSE_MKLDNN) return;

if (!IsReachable(g, residual_data, fc_output)) return;

if (HasFusedActivation(fc_op)) return;

if (!IsCompat(subgraph, g)) {
Expand All @@ -114,7 +114,7 @@ GraphWithStats FCResidualConnectionMKLDNNFusePass::FuseFC(
gpd(graph_with_stats.first, handler);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
std::string fusionMode = as_x ? "x" : "y";
std::string fusionMode = fc_as_x ? "x" : "y";
msg_ss << "--- Fused " << found_fc_count << " fc (as " << fusionMode
<< ") + elementwise_add patterns";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class FCResidualConnectionMKLDNNFusePass : public FusePassBase {
private:
GraphWithStats FuseFC(const std::string& name_scope,
const GraphWithStats& graph_with_stats,
bool as_x) const;
bool fc_as_x) const;

public:
FCResidualConnectionMKLDNNFusePass();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
class TestFCElementwiseAddMkldnnFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
axis = draw(st.sampled_from([-1, 0, 1]))
FCAsX = draw(st.sampled_from([True, False]))
fc_as_x = draw(st.sampled_from([True, False]))
fc_in = draw(st.sampled_from([32, 64]))
fc_wei = draw(st.sampled_from([32, 64]))

Expand Down Expand Up @@ -62,7 +62,7 @@ def generate_fc_bias():
"in_num_col_dims": 1,
})

if FCAsX:
if fc_as_x:
inputs = {"X": ["fc_output"], "Y": ["input_data"]}
else:
inputs = {"X": ["input_data"], "Y": ["fc_output"]}
Expand Down

0 comments on commit 0313fed

Please sign in to comment.