Skip to content

Commit

Permalink
transfer block_id to CreateVarNode in multi_devices_graph_pass (Paddl…
Browse files Browse the repository at this point in the history
…ePaddle#44366)

* fix CreateVarNode in multi_devices_graph_pass

* Revert "Fix var duplication bug for graph_to_program_pass (PaddlePaddle#44278)"

This reverts commit a2c4c86.
  • Loading branch information
pangyoki authored and Aurelius84 committed Jul 29, 2022
1 parent 67f4beb commit 20c06dd
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 22 deletions.
15 changes: 1 addition & 14 deletions paddle/fluid/framework/ir/graph_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -579,27 +579,14 @@ void GraphToProgram(const Graph &graph,

VLOG(3) << "Graph to program need convert " << graph.SubGraphsSize()
<< " sub graph";

std::unordered_set<std::string> vars_in_root_block;
for (const proto::VarDesc &var : block->vars()) {
vars_in_root_block.insert(var.name());
}

for (size_t idx = 0; idx < graph.SubGraphsSize(); ++idx) {
// avoid kRootBlockIndex not 0
if (idx == kRootBlockIndex) continue;

block = program_pb.add_blocks();
block->set_idx(idx);
block->set_parent_idx(kRootBlockIndex);

Graph *subgraph = graph.GetSubGraph(idx);
subgraph->SetNotOwned<std::unordered_set<std::string>>(
kGraphToProgramVarsToRemove, &vars_in_root_block);

GraphToBlock(*subgraph, block, sort_kind);

subgraph->Erase(kGraphToProgramVarsToRemove);
GraphToBlock(*graph.GetSubGraph(idx), block, sort_kind);
}
} else {
GraphToBlock(graph, block, sort_kind);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,12 @@ details::VarHandle *CreateOrGetLatestVarHandle(ir::Graph *graph,
details::VarHandle *var = nullptr;
if (var_holder.empty()) {
if (node->Var()) {
var = new details::VarHandle(graph->CreateVarNode(node->Var()),
0,
place_offset,
node->Name(),
place);
var = new details::VarHandle(
graph->CreateVarNode(node->Var(), node->GetVarNodeBlockId()),
0,
place_offset,
node->Name(),
place);
} else {
var = new details::VarHandle(
graph->CreateEmptyNode(node->Name(), ir::Node::Type::kVariable),
Expand Down Expand Up @@ -376,7 +377,8 @@ void MultiDevSSAGraphBuilderBase::CreateOpHandleIOs(ir::Graph *result,
for (ir::Node *output : node->outputs) {
ir::Node *new_node = nullptr;
if (output->Var()) {
new_node = result->CreateVarNode(output->Var());
new_node =
result->CreateVarNode(output->Var(), output->GetVarNodeBlockId());
} else {
new_node =
result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
Expand Down Expand Up @@ -696,7 +698,8 @@ void MultiDevSSAGraphBuilderBase::CreateScaleLossGradOp(

CreateOpOutput(result,
op_handle,
result->CreateVarNode(out_var_node->Var()),
result->CreateVarNode(out_var_node->Var(),
out_var_node->GetVarNodeBlockId()),
places_[i],
i);
}
Expand Down Expand Up @@ -1225,7 +1228,8 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
p = places_[outvar_dev_id];
ir::Node *new_node = nullptr;
if (output->Var()) {
new_node = result->CreateVarNode(output->Var());
new_node =
result->CreateVarNode(output->Var(), output->GetVarNodeBlockId());
} else {
new_node =
result->CreateEmptyNode(output->Name(), ir::Node::Type::kVariable);
Expand Down

0 comments on commit 20c06dd

Please sign in to comment.