-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dependence analysis #37231
Dependence analysis #37231
Changes from 6 commits
4328fac
580a3fc
407b0b4
eae28e9
c291ccd
553295a
1602b10
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -77,6 +77,113 @@ paddle::framework::FetchList InterpreterCore::Run( | |||||
return *(fetch_var->GetMutable<framework::FetchList>()); | ||||||
} | ||||||
|
||||||
void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
                          std::map<int, std::list<int>>& var2min_rw_op,
                          int cur_op, int rw_var) {
  // rw_var is an input or output of cur_op. Update var2min_rw_op[rw_var],
  // the minimal set of reader/writer ops of rw_var: any op that cur_op
  // already depends on is subsumed by cur_op, so it is removed before
  // cur_op itself is appended.
  //
  // operator[] default-constructs an empty list on first access, which
  // replaces the explicit find()+insert dance; binding a reference avoids
  // repeated lookups of the same key.
  auto& min_rw_ops = var2min_rw_op[rw_var];
  for (int dep_op : op2dependences.at(cur_op)) {
    min_rw_ops.remove(dep_op);
  }
  min_rw_ops.push_back(cur_op);
}
|
||||||
std::map<int, std::list<int>> get_downstream_map(
    const std::map<int, std::set<int>>& op2dependences) {
  // op2dependences maps op -> the set of ops it depends on. Invert it into
  // op -> [ops], where every listed op must run after the key op (i.e. the
  // key op's downstream instructions).
  std::map<int, std::list<int>> result;
  for (const auto& item : op2dependences) {
    int op = item.first;
    for (int dep_op : item.second) {
      // operator[] creates the empty list on first use.
      result[dep_op].push_back(op);
    }
  }
  // Return the local by value: NRVO / implicit move applies.
  // `return std::move(result)` would pessimize by blocking copy elision.
  return result;
}
|
||||||
void InterpreterCore::BuildOperatorDependences() { | ||||||
// set the dependecy_count_ and Call Schedule | ||||||
// refer to http://agroup.baidu.com/share/md/92946214aa4c4785a2cc4c1f361a023c | ||||||
// for pesudo code | ||||||
auto op_nums = vec_instruction_.size(); | ||||||
auto var2min_rw_op = std::map< | ||||||
int, std::list<int>>(); // # map from variable id to read / write op id. | ||||||
auto var2recent_write_op = | ||||||
std::map<int, int>(); // # map from variable to recent write op. | ||||||
auto op2dependences = | ||||||
std::map<int, std::set<int>>(); //# map from op to the dependence list, | ||||||
// op must run after the dependence. | ||||||
std::set<int> remove_duplicate; | ||||||
|
||||||
// reserve | ||||||
for (size_t op = 0; op < vec_instruction_.size(); ++op) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
op2dependences[op] = std::set<int>(); | ||||||
} | ||||||
dependecy_count_.resize(op_nums); | ||||||
|
||||||
for (size_t op = 0; op < vec_instruction_.size(); ++op) { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
remove_duplicate.clear(); | ||||||
// step1: update the op2dependences structure | ||||||
for (auto& item : | ||||||
vec_instruction_[op].Inputs()) { // for all inputs(read only) | ||||||
for (auto var : item.second) { | ||||||
if (var2recent_write_op.count(var)) | ||||||
op2dependences[op].insert(var2recent_write_op[var]); | ||||||
} | ||||||
} | ||||||
|
||||||
for (auto& item : vec_instruction_[op].Outputs()) { // for all write vars | ||||||
for (auto var : item.second) { | ||||||
if (var2min_rw_op.count(var)) { | ||||||
for (auto dep_op : var2min_rw_op[var]) { | ||||||
op2dependences[op].insert(dep_op); | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
// step2: update 2 var2xxxx data structure | ||||||
for (auto& item : | ||||||
vec_instruction_[op].Inputs()) { // for all inputs(read only) | ||||||
for (auto var : item.second) { | ||||||
update_var_min_rw_op(op2dependences, var2min_rw_op, op, var); | ||||||
remove_duplicate.insert(var); | ||||||
} | ||||||
} | ||||||
|
||||||
for (auto& item : vec_instruction_[op].Outputs()) { // for all write vars | ||||||
for (auto var : item.second) { | ||||||
var2recent_write_op[var] = op; | ||||||
if (remove_duplicate.count(var) == | ||||||
0) { // var in input list and in output list, so remove it. | ||||||
update_var_min_rw_op(op2dependences, var2min_rw_op, op, var); | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
auto op2downstream = get_downstream_map(op2dependences); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 建议把上述逻辑单独放到utils.ccl里 |
||||||
|
||||||
VLOG(5) << "the size of vec_instruction_ : " << vec_instruction_.size(); | ||||||
|
||||||
for (size_t op = 0; op < vec_instruction_.size(); ++op) { | ||||||
VLOG(5) << "the op2downstream : " << op; | ||||||
auto op_list = op2downstream[op]; | ||||||
std::vector<size_t> downsteam_vector(op_list.begin(), op_list.end()); | ||||||
stream_analyzer_.Schedule(downsteam_vector, &vec_instruction_, op); | ||||||
|
||||||
for (auto inst_id : op_list) { | ||||||
VLOG(5) << "\t " << inst_id; | ||||||
dependecy_count_[inst_id]++; | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
void InterpreterCore::Convert( | ||||||
std::vector<paddle::framework::OpFuncNode>* op_func_nodes) { | ||||||
auto& vec_meta_info = global_scope_->MutableVecMetaInfo(); | ||||||
|
@@ -86,7 +193,6 @@ void InterpreterCore::Convert( | |||||
|
||||||
auto op_nums = nodes.size(); | ||||||
vec_instruction_.reserve(op_nums); | ||||||
dependecy_count_.resize(op_nums); | ||||||
|
||||||
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) { | ||||||
auto& op_func_node = nodes[op_idx]; | ||||||
|
@@ -146,30 +252,7 @@ void InterpreterCore::Convert( | |||||
} | ||||||
} | ||||||
|
||||||
for (size_t i = 0; i < vec_instruction_.size(); ++i) { | ||||||
std::vector<size_t> vec_temp; | ||||||
for (auto& item : vec_instruction_[i].Outputs()) { | ||||||
for (auto id : item.second) { | ||||||
vec_temp = interpreter::merge_vector(vec_temp, input_var2op_info_[id]); | ||||||
} | ||||||
} | ||||||
|
||||||
// In Program, op order is a very important information. | ||||||
// Op can only add op after it as next as next ops. | ||||||
std::vector<size_t> filter_next; | ||||||
filter_next.reserve(vec_temp.size()); | ||||||
for (auto item : vec_temp) { | ||||||
if (item > i) { | ||||||
filter_next.push_back(item); | ||||||
} | ||||||
} | ||||||
|
||||||
stream_analyzer_.Schedule(filter_next, &vec_instruction_, i); | ||||||
|
||||||
for (auto inst_id : filter_next) { | ||||||
dependecy_count_[inst_id]++; | ||||||
} | ||||||
} | ||||||
BuildOperatorDependences(); | ||||||
|
||||||
for (size_t i = 0; i < vec_instruction_.size(); ++i) { | ||||||
BuildAndCacheInstructionCtx(&vec_instruction_[i]); | ||||||
|
@@ -289,7 +372,7 @@ void InterpreterCore::BuildSkipShareLoDInfo() { | |||||
void InterpreterCore::RunInstruction(const Instruction& instr_node) { | ||||||
auto* op = instr_node.OpBase(); | ||||||
auto place = instr_node.DeviceContext().GetPlace(); | ||||||
VLOG(4) << place << " " << op->DebugStringEx(global_scope_); | ||||||
VLOG(4) << "Start run" << place << " " << op->DebugStringEx(global_scope_); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op); | ||||||
{ | ||||||
|
@@ -320,7 +403,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) { | |||||
instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get()); | ||||||
} | ||||||
|
||||||
VLOG(3) << place << " " << op->DebugStringEx(global_scope_); | ||||||
VLOG(4) << "End run" << place << " " << op->DebugStringEx(global_scope_); | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
/*For profiling/benchmark only*/ | ||||||
if (FLAGS_benchmark) { | ||||||
|
@@ -494,6 +577,8 @@ void InterpreterCore::CheckGC(const Instruction& instr) { | |||||
continue; | ||||||
} | ||||||
if (is_ready) { | ||||||
VLOG(6) << "Async delete variable with name : " | ||||||
<< var_scope.GetNameById(var_id); | ||||||
gc_->Add(var_scope.Var(var_id), gc_event_.at(instr_id), | ||||||
&instr.DeviceContext()); | ||||||
} | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import sys | ||
import unittest | ||
import paddle | ||
from paddle.fluid import core | ||
from paddle.fluid.core import StandaloneExecutor | ||
import paddle.fluid as fluid | ||
from paddle.fluid.framework import Program, program_guard | ||
import paddle.fluid.layers as layers | ||
|
||
from test_standalone_controlflow import TestCompatibility | ||
import numpy as np | ||
|
||
paddle.enable_static() | ||
|
||
|
||
class TestMultiplyWrite(TestCompatibility):
    """Exercise a program whose output variable is written twice in a row,
    so the dependence analysis must keep the two writes in program order
    (write-after-write dependence).
    """

    def setUp(self):
        self.place = paddle.CPUPlace()
        self.iter_run = 5

    def _get_feed(self):
        """This case needs no feed data."""
        return None

    def build_program(self):
        """Build a program where two assign ops write the same variable.

        Returns the (main_program, startup_program, fetch_var) triple
        expected by the base class runner.
        """
        main_program = paddle.static.default_main_program()
        startup_program = paddle.static.default_startup_program()
        with paddle.static.program_guard(main_program, startup_program):
            out = paddle.full((1, ), 1)
            first_value = paddle.full((1, ), 2)
            second_value = paddle.full((1, ), 3)

            # Two writes to `out`; the second assign must run after the
            # first one.
            paddle.fluid.layers.assign(first_value, out)
            paddle.fluid.layers.assign(second_value, out)
        return main_program, startup_program, out
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不要在代码里添加百度内部的文档链接,可以参看graph.h那样,在这里添加整体的方案思路描述