From e42de4721fac6a3b3403b66c7aa9438e2f913d22 Mon Sep 17 00:00:00 2001 From: Wang Zhen Date: Thu, 21 Jul 2022 11:53:22 +0800 Subject: [PATCH 1/2] Add a feed op before each input parameter var. --- .../framework/paddle2cinn/build_cinn_pass.cc | 30 ++++--------------- 1 file changed, 5 insertions(+), 25 deletions(-) diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index b25d3a7f3af92..593646164940b 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -141,14 +141,14 @@ int ExtractOpRole(const GraphNodeSet& cluster) { } } -// Deal with subgraph's feed input var node: +// Deal with input var nodes of the target subgraph: // create a new input var node and it's feed op node -void AddFeedOpAndVar(const GraphNodeSet& feed_vars, +void AddFeedOpAndVar(const GraphNodeSet& input_vars, const GraphNodeSet& cluster, const GraphNodeMap& old_op2new_op, const GraphNodeMap& old_var2new_var, Graph* graph) { - for (auto* old_var : feed_vars) { + for (auto* old_var : input_vars) { // create feed op OpDesc desc; desc.SetType("feed"); @@ -157,7 +157,7 @@ void AddFeedOpAndVar(const GraphNodeSet& feed_vars, // get new feed var node auto* var = old_var2new_var.at(old_var); - VLOG(4) << "Add Feed Op before: " << var->Name(); + VLOG(4) << "Add Feed Op before the input var: " << var->Name(); // link feed op and feed var IR_NODE_LINK_TO(op, var); @@ -174,26 +174,6 @@ void AddFeedOpAndVar(const GraphNodeSet& feed_vars, } } -// Deal with subgraph's parameter var node: -// create a new input var node, it's data will get by scope, -// so it don't need feed op -void AddParamVar(const GraphNodeSet& param_vars, - const GraphNodeSet& cluster, - const GraphNodeMap& old_op2new_op, - const GraphNodeMap& old_var2new_var, - Graph* graph) { - for (auto* old_var : param_vars) { - auto* var = old_var2new_var.at(old_var); - VLOG(4) << "Add Param Var Node: " << var->Name(); - - for (auto* old_op : old_var->outputs) { - if (cluster.count(old_op)) { - IR_NODE_LINK_TO(var, old_op2new_op.at(old_op)); - } - } - } -} - // Deal with subgraph's outputs var node: // create a new output var node and it's fetch op void AddOutputVar(const GraphNodeSet& output_vars, @@ -389,7 +369,7 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, AddFeedOpAndVar( need_feed_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get()); - AddParamVar( + AddFeedOpAndVar( param_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get()); AddOutputVar( output_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get()); From b26bc044f06f296c7665da777f5761ac0c6a2627 Mon Sep 17 00:00:00 2001 From: Wang Zhen Date: Mon, 25 Jul 2022 19:43:37 +0800 Subject: [PATCH 2/2] Fix some issues about the unit test build_cinn_pass_test. --- .../paddle2cinn/build_cinn_pass_test.cc | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc index f951a09cfd56a..7d4d856e4cbf2 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc @@ -277,13 +277,13 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]); const auto& subnodes = subgraph.Nodes(); - ASSERT_EQ(subnodes.size(), static_cast(12)); + ASSERT_EQ(subnodes.size(), static_cast(13)); ASSERT_TRUE(CheckGraphIndependence(subnodes)); ASSERT_TRUE(CheckNodeExisted(subnodes, "mul")); ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add")); ASSERT_TRUE(CheckNodeExisted(subnodes, "relu")); - ASSERT_EQ(CountNode(subnodes, "feed"), 2); + ASSERT_EQ(CountNode(subnodes, "feed"), 3); ASSERT_EQ(CountNode(subnodes, "fetch"), 1); // No-parameter input should has feed op @@ -293,9 +293,10 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { ASSERT_EQ(new_v1->inputs[0]->Name(), "feed"); ASSERT_EQ(new_v1->outputs[0]->Name(), "mul"); - // Parameter input should not has feed op + // Parameter input should also have the feed op auto new_v2 = GetNode(subnodes, "var2"); - ASSERT_TRUE(new_v2->inputs.empty()); + ASSERT_EQ(new_v2->inputs.size(), static_cast(1)); + ASSERT_EQ(new_v2->inputs[0]->Name(), "feed"); ASSERT_EQ(new_v2->outputs.size(), static_cast(1)); ASSERT_EQ(new_v2->outputs[0]->Name(), "mul"); @@ -400,12 +401,12 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) { const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]); const auto& subnodes = subgraph.Nodes(); - ASSERT_EQ(subnodes.size(), static_cast(8)); + ASSERT_EQ(subnodes.size(), static_cast(9)); ASSERT_TRUE(CheckGraphIndependence(subnodes)); ASSERT_TRUE(CheckNodeExisted(subnodes, "mul")); ASSERT_TRUE(CheckNodeExisted(subnodes, "relu")); - ASSERT_EQ(CountNode(subnodes, "feed"), 1); + ASSERT_EQ(CountNode(subnodes, "feed"), 2); ASSERT_EQ(CountNode(subnodes, "fetch"), 1); } @@ -526,10 +527,10 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) { if (CheckNodeExisted(subnodes1, "relu")) { ASSERT_EQ(subnodes1.size(), static_cast(5)); - ASSERT_EQ(subnodes2.size(), static_cast(6)); + ASSERT_EQ(subnodes2.size(), static_cast(7)); } else { ASSERT_EQ(subnodes2.size(), static_cast(5)); - ASSERT_EQ(subnodes1.size(), static_cast(6)); + ASSERT_EQ(subnodes1.size(), static_cast(7)); } }