Skip to content

Commit

Permalink
Add a feed op before each input parameter var. (#44499)
Browse files Browse the repository at this point in the history
* Add a feed op before each input parameter var.

* Fix some issues about the unit test build_cinn_pass_test.
  • Loading branch information
wzzju committed Jul 26, 2022
1 parent 33cc0f7 commit 9b662be
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 33 deletions.
30 changes: 5 additions & 25 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,14 @@ int ExtractOpRole(const GraphNodeSet& cluster) {
}
}

// Deal with subgraph's feed input var node:
// Deal with input var nodes of the target subgraph:
// create a new input var node and it's feed op node
void AddFeedOpAndVar(const GraphNodeSet& feed_vars,
void AddFeedOpAndVar(const GraphNodeSet& input_vars,
const GraphNodeSet& cluster,
const GraphNodeMap& old_op2new_op,
const GraphNodeMap& old_var2new_var,
Graph* graph) {
for (auto* old_var : feed_vars) {
for (auto* old_var : input_vars) {
// create feed op
OpDesc desc;
desc.SetType("feed");
Expand All @@ -157,7 +157,7 @@ void AddFeedOpAndVar(const GraphNodeSet& feed_vars,

// get new feed var node
auto* var = old_var2new_var.at(old_var);
VLOG(4) << "Add Feed Op before: " << var->Name();
VLOG(4) << "Add Feed Op before the input var: " << var->Name();

// link feed op and feed var
IR_NODE_LINK_TO(op, var);
Expand All @@ -174,26 +174,6 @@ void AddFeedOpAndVar(const GraphNodeSet& feed_vars,
}
}

// Deal with subgraph's parameter var node:
// create a new input var node, it's data will get by scope,
// so it don't need feed op
void AddParamVar(const GraphNodeSet& param_vars,
const GraphNodeSet& cluster,
const GraphNodeMap& old_op2new_op,
const GraphNodeMap& old_var2new_var,
Graph* graph) {
for (auto* old_var : param_vars) {
auto* var = old_var2new_var.at(old_var);
VLOG(4) << "Add Param Var Node: " << var->Name();

for (auto* old_op : old_var->outputs) {
if (cluster.count(old_op)) {
IR_NODE_LINK_TO(var, old_op2new_op.at(old_op));
}
}
}
}

// Deal with subgraph's outputs var node:
// create a new output var node and it's fetch op
void AddOutputVar(const GraphNodeSet& output_vars,
Expand Down Expand Up @@ -389,7 +369,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,

AddFeedOpAndVar(
need_feed_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get());
AddParamVar(
AddFeedOpAndVar(
param_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get());
AddOutputVar(
output_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get());
Expand Down
17 changes: 9 additions & 8 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);

const auto& subnodes = subgraph.Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(12));
ASSERT_EQ(subnodes.size(), static_cast<size_t>(13));
ASSERT_TRUE(CheckGraphIndependence(subnodes));

ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
ASSERT_EQ(CountNode(subnodes, "feed"), 2);
ASSERT_EQ(CountNode(subnodes, "feed"), 3);
ASSERT_EQ(CountNode(subnodes, "fetch"), 1);

// No-parameter input should has feed op
Expand All @@ -293,9 +293,10 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
ASSERT_EQ(new_v1->inputs[0]->Name(), "feed");
ASSERT_EQ(new_v1->outputs[0]->Name(), "mul");

// Parameter input should not has feed op
// Parameter input should also have the feed op
auto new_v2 = GetNode(subnodes, "var2");
ASSERT_TRUE(new_v2->inputs.empty());
ASSERT_EQ(new_v2->inputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v2->inputs[0]->Name(), "feed");
ASSERT_EQ(new_v2->outputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v2->outputs[0]->Name(), "mul");

Expand Down Expand Up @@ -400,12 +401,12 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) {
const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);

const auto& subnodes = subgraph.Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(8));
ASSERT_EQ(subnodes.size(), static_cast<size_t>(9));
ASSERT_TRUE(CheckGraphIndependence(subnodes));

ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
ASSERT_EQ(CountNode(subnodes, "feed"), 1);
ASSERT_EQ(CountNode(subnodes, "feed"), 2);
ASSERT_EQ(CountNode(subnodes, "fetch"), 1);
}

Expand Down Expand Up @@ -526,10 +527,10 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {

if (CheckNodeExisted(subnodes1, "relu")) {
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(5));
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(6));
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(7));
} else {
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(5));
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(6));
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(7));
}
}

Expand Down

0 comments on commit 9b662be

Please sign in to comment.