Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a feed op before each input parameter var. #44499

Merged
merged 3 commits into from
Jul 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 5 additions & 25 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,14 @@ int ExtractOpRole(const GraphNodeSet& cluster) {
}
}

// Deal with subgraph's feed input var node:
// Deal with input var nodes of the target subgraph:
// create a new input var node and it's feed op node
void AddFeedOpAndVar(const GraphNodeSet& feed_vars,
void AddFeedOpAndVar(const GraphNodeSet& input_vars,
const GraphNodeSet& cluster,
const GraphNodeMap& old_op2new_op,
const GraphNodeMap& old_var2new_var,
Graph* graph) {
for (auto* old_var : feed_vars) {
for (auto* old_var : input_vars) {
// create feed op
OpDesc desc;
desc.SetType("feed");
Expand All @@ -157,7 +157,7 @@ void AddFeedOpAndVar(const GraphNodeSet& feed_vars,

// get new feed var node
auto* var = old_var2new_var.at(old_var);
VLOG(4) << "Add Feed Op before: " << var->Name();
VLOG(4) << "Add Feed Op before the input var: " << var->Name();

// link feed op and feed var
IR_NODE_LINK_TO(op, var);
Expand All @@ -174,26 +174,6 @@ void AddFeedOpAndVar(const GraphNodeSet& feed_vars,
}
}

// Deal with subgraph's parameter var node:
// create a new input var node, it's data will get by scope,
// so it don't need feed op
void AddParamVar(const GraphNodeSet& param_vars,
const GraphNodeSet& cluster,
const GraphNodeMap& old_op2new_op,
const GraphNodeMap& old_var2new_var,
Graph* graph) {
for (auto* old_var : param_vars) {
auto* var = old_var2new_var.at(old_var);
VLOG(4) << "Add Param Var Node: " << var->Name();

for (auto* old_op : old_var->outputs) {
if (cluster.count(old_op)) {
IR_NODE_LINK_TO(var, old_op2new_op.at(old_op));
}
}
}
}

// Deal with subgraph's outputs var node:
// create a new output var node and it's fetch op
void AddOutputVar(const GraphNodeSet& output_vars,
Expand Down Expand Up @@ -389,7 +369,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,

AddFeedOpAndVar(
need_feed_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get());
AddParamVar(
AddFeedOpAndVar(
param_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get());
AddOutputVar(
output_vars, cluster, old_op2new_op, old_var2new_var, subgraph.get());
Expand Down
17 changes: 9 additions & 8 deletions paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);

const auto& subnodes = subgraph.Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(12));
ASSERT_EQ(subnodes.size(), static_cast<size_t>(13));
ASSERT_TRUE(CheckGraphIndependence(subnodes));

ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
ASSERT_EQ(CountNode(subnodes, "feed"), 2);
ASSERT_EQ(CountNode(subnodes, "feed"), 3);
ASSERT_EQ(CountNode(subnodes, "fetch"), 1);

// No-parameter input should has feed op
Expand All @@ -293,9 +293,10 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
ASSERT_EQ(new_v1->inputs[0]->Name(), "feed");
ASSERT_EQ(new_v1->outputs[0]->Name(), "mul");

// Parameter input should not has feed op
// Parameter input should also have the feed op
auto new_v2 = GetNode(subnodes, "var2");
ASSERT_TRUE(new_v2->inputs.empty());
ASSERT_EQ(new_v2->inputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v2->inputs[0]->Name(), "feed");
ASSERT_EQ(new_v2->outputs.size(), static_cast<size_t>(1));
ASSERT_EQ(new_v2->outputs[0]->Name(), "mul");

Expand Down Expand Up @@ -400,12 +401,12 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) {
const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);

const auto& subnodes = subgraph.Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(8));
ASSERT_EQ(subnodes.size(), static_cast<size_t>(9));
ASSERT_TRUE(CheckGraphIndependence(subnodes));

ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
ASSERT_EQ(CountNode(subnodes, "feed"), 1);
ASSERT_EQ(CountNode(subnodes, "feed"), 2);
ASSERT_EQ(CountNode(subnodes, "fetch"), 1);
}

Expand Down Expand Up @@ -526,10 +527,10 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {

if (CheckNodeExisted(subnodes1, "relu")) {
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(5));
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(6));
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(7));
} else {
ASSERT_EQ(subnodes2.size(), static_cast<size_t>(5));
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(6));
ASSERT_EQ(subnodes1.size(), static_cast<size_t>(7));
}
}

Expand Down