-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add some passes which can be applied to Program #34730
Changes from 9 commits
3fa2507
ae7cf21
8f913ac
9606931
6545a62
a705a4f
81b1e82
0a8586c
5329b0f
466d483
94ba44f
2da6a3a
3851114
79525d1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
#include <string> | ||
|
||
#include "glog/logging.h" | ||
#include "paddle/fluid/framework/executor_gc_helper.h" | ||
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_reuse_pass.h" | ||
#include "paddle/fluid/framework/ir/pass.h" | ||
#include "paddle/fluid/platform/enforce.h" | ||
|
@@ -30,6 +31,9 @@ class BufferSharedInplaceOpPass : public MemoryReusePass { | |
std::string ReuseType() const override { return "inplace"; } | ||
|
||
void Run(Graph *graph) const override; | ||
|
||
void ApplyImpl(ProgramDesc *main_program, | ||
ProgramDesc *startup_program) const override; | ||
}; | ||
|
||
void BufferSharedInplaceOpPass::Run(Graph *graph) const { | ||
|
@@ -149,6 +153,141 @@ void BufferSharedInplaceOpPass::Run(Graph *graph) const { | |
} | ||
} | ||
|
||
static std::string GetFirstVarName(const OpDesc &op, const std::string &slot, | ||
bool is_input) { | ||
const auto &name_map = is_input ? op.Inputs() : op.Outputs(); | ||
auto iter = name_map.find(slot); | ||
if (iter != name_map.end() && !iter->second.empty()) { | ||
return iter->second[0]; | ||
} | ||
return kEmptyVarName; | ||
} | ||
|
||
static std::vector<std::vector<std::pair<std::string, std::string>>> | ||
GetInplaceVars(const BlockDesc &block, bool use_cuda, | ||
const std::vector<std::string> &skip_vars) { | ||
PADDLE_ENFORCE_EQ(block.ID(), 0, platform::errors::Unimplemented( | ||
"Inplace can only perform in block 0")); | ||
// only take block 0 gc_vars | ||
const auto all_gc_vars = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Readability: suggest to rename 'all_gc_vars' to 'op_gc_vars' or other name, then it is easier to know why its size is same to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
GetEagerDeletionCleanVars(*block.Program(), skip_vars)[0]; | ||
const auto all_ops = block.AllOps(); | ||
PADDLE_ENFORCE_EQ(all_gc_vars.size(), all_ops.size(), | ||
platform::errors::PermissionDenied( | ||
"GC analysis error: op number not match")); | ||
size_t n = all_ops.size(); | ||
std::unordered_set<std::string> visited_vars; | ||
std::unordered_set<std::string> reused_in_vars(skip_vars.begin(), | ||
skip_vars.end()); | ||
std::unordered_set<std::string> reused_out_vars(skip_vars.begin(), | ||
skip_vars.end()); | ||
for (const auto *op : all_ops) { | ||
if (op->Type() == "share_buffer" || op->Type() == "share_data") { | ||
const auto &inputs = op->Input("X"); | ||
const auto &outputs = op->Output("Out"); | ||
reused_in_vars.insert(inputs.begin(), inputs.end()); | ||
reused_out_vars.insert(outputs.begin(), outputs.end()); | ||
} | ||
} | ||
|
||
std::vector<std::vector<std::pair<std::string, std::string>>> result(n); | ||
for (size_t i = 0; i < n; ++i) { | ||
const auto &op = *all_ops[i]; | ||
const auto &gc_vars = all_gc_vars[i]; | ||
const auto inputs = op.InputArgumentNames(); | ||
const auto outputs = op.OutputArgumentNames(); | ||
visited_vars.insert(inputs.begin(), inputs.end()); | ||
|
||
auto &infer_inplace = OpInfoMap::Instance().Get(op.Type()).infer_inplace_; | ||
if (gc_vars.empty() || !infer_inplace) { | ||
visited_vars.insert(outputs.begin(), outputs.end()); | ||
continue; | ||
} | ||
|
||
const auto var_pair = infer_inplace(use_cuda); | ||
std::unordered_multiset<std::string> input_set(inputs.begin(), | ||
inputs.end()); | ||
std::unordered_multiset<std::string> output_set(outputs.begin(), | ||
outputs.end()); | ||
std::unordered_set<std::string> valid_vars; | ||
for (const auto &var : gc_vars) { | ||
if (var != kEmptyVarName && input_set.count(var) == 1 && | ||
output_set.count(var) == 0 && | ||
block.FindVar(var)->GetType() == proto::VarType::LOD_TENSOR) { | ||
valid_vars.insert(var); | ||
} | ||
} | ||
|
||
if (valid_vars.empty()) { | ||
visited_vars.insert(outputs.begin(), outputs.end()); | ||
continue; | ||
} | ||
|
||
for (const auto &pair : var_pair) { | ||
const auto &input_slot = pair.first; | ||
const auto &output_slot = pair.second; | ||
auto input_var = GetFirstVarName(op, input_slot, true); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is not forward to read the meaning of boolean without looking at the code of GetFirstVarName. Suggest for two options:
2 Change There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
if (input_var == kEmptyVarName || valid_vars.count(input_var) == 0) { | ||
continue; | ||
} | ||
auto output_var = GetFirstVarName(op, output_slot, false); | ||
if (output_var == kEmptyVarName || visited_vars.count(output_var) > 0) { | ||
continue; | ||
} | ||
auto output_var_desc = block.FindVar(output_var); | ||
if (output_var_desc == nullptr || output_var_desc->Persistable() || | ||
output_var_desc->GetType() != proto::VarType::LOD_TENSOR) { | ||
continue; | ||
} | ||
|
||
if (reused_in_vars.count(input_var) > 0 || | ||
reused_out_vars.count(output_var) > 0) { | ||
continue; | ||
} | ||
|
||
// input_var -> output_var is reusable | ||
VLOG(10) << "inplace occurs at op " << i << " " << op.Type() << ": " | ||
<< input_var << " -> " << output_var; | ||
result[i].emplace_back(input_var, output_var); | ||
reused_in_vars.insert(input_var); | ||
reused_out_vars.insert(output_var); | ||
} | ||
visited_vars.insert(outputs.begin(), outputs.end()); | ||
std::sort(result[i].begin(), result[i].end()); | ||
} | ||
return result; | ||
} | ||
|
||
void BufferSharedInplaceOpPass::ApplyImpl(ProgramDesc *main_program, | ||
ProgramDesc *startup_program) const { | ||
bool use_cuda = Get<bool>(kUseCuda); | ||
auto skip_vars = Get<std::vector<std::string>>("mem_opt_skip_vars"); | ||
|
||
auto *block = main_program->MutableBlock(0); | ||
auto inplace_vars = GetInplaceVars(*block, use_cuda, skip_vars); | ||
PADDLE_ENFORCE_EQ(inplace_vars.size(), block->OpSize(), | ||
platform::errors::PermissionDenied( | ||
"Inplace analysis error: op number not match")); | ||
int64_t n = static_cast<int64_t>(inplace_vars.size()); | ||
for (int64_t i = n - 1; i >= 0; --i) { | ||
if (inplace_vars[i].empty()) continue; | ||
auto *op = block->InsertOp(i); | ||
std::vector<std::string> inputs, outputs; | ||
inputs.reserve(inplace_vars[i].size()); | ||
outputs.reserve(inplace_vars[i].size()); | ||
for (const auto &pair : inplace_vars[i]) { | ||
inputs.push_back(pair.first); | ||
outputs.push_back(pair.second); | ||
} | ||
op->SetType("share_buffer"); | ||
op->SetInput("X", inputs); | ||
op->SetOutput("Out", outputs); | ||
op->SetOutput("XOut", inputs); // add necessary dependency | ||
op->SetAttr("share_dims", std::vector<bool>(inputs.size(), false)); | ||
} | ||
block->Flush(); | ||
} | ||
|
||
} // namespace ir | ||
} // namespace framework | ||
} // namespace paddle | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we recommend the first letter of error message sentence capitalized and ends with
.
, other cases are sameThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.