[new feature] add local scope for interpretercore (#37379)
zhiqiu committed Nov 22, 2021
1 parent 964e20e commit 1f0512b
Showing 13 changed files with 217 additions and 109 deletions.
63 changes: 50 additions & 13 deletions paddle/fluid/framework/new_executor/interpretercore.cc
@@ -22,6 +22,9 @@

PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace, true,
"Use inplace in new executor");
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_local_scope, true,
"Use local_scope in new executor(especially used "
"in UT), can turn off for better performance");

DECLARE_bool(check_nan_inf);
DECLARE_bool(benchmark);
@@ -48,6 +51,14 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
exception_notifier_ = main_thread_blocker_.RegisterEvent(
kExceptionCaught, [this]() { return exception_holder_.IsCaught(); });

create_local_scope_ = FLAGS_new_executor_use_local_scope;
if (FLAGS_new_executor_use_local_scope) {
auto local_scope = &global_scope->GetMutableScope()->NewScope();
local_scope->AddListener(global_scope->Listener());
local_scope_ = local_scope;
}
VLOG(4) << "create_local_scope_ is " << create_local_scope_;

// prune

// optimize graph pass
@@ -62,10 +73,15 @@ InterpreterCore::~InterpreterCore() {
async_work_queue_.reset(nullptr);
}

void InterpreterCore::SetCopyProgram(std::shared_ptr<ProgramDesc> prog) {
copy_program_ = prog;
}

paddle::framework::FetchList InterpreterCore::Run(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) {
bool is_build = is_build_;
global_scope_->SetLocalScope(local_scope_);
Prepare(feed_names, feed_tensors, is_build);

if (is_build) {
@@ -79,13 +95,27 @@ paddle::framework::FetchList InterpreterCore::Run(

paddle::framework::FetchList InterpreterCore::Run() {
if (!is_build_) {
paddle::framework::interpreter::build_variable_scope(block_, global_scope_);
if (create_local_scope_ &&
global_scope_->GetMutableLocalScope() !=
global_scope_->GetMutableScope() &&
global_scope_->GetMutableLocalScope()) {
VLOG(4) << "Clear previous local scope before run";
VLOG(4) << global_scope_->GetMutableScope() << " "
<< global_scope_->GetMutableLocalScope();
platform::DeviceContextPool::Instance().Get(place_)->Wait();
// TODO(zhiqiu): clear the tensor holder of all vars in previous local
// scope?
}
global_scope_->SetLocalScope(local_scope_);
paddle::framework::interpreter::build_variable_scope(block_, global_scope_,
create_local_scope_);
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list(
place_, block_, &op_func_nodes, global_scope_);
place_, block_, &op_func_nodes, global_scope_, create_local_scope_);
is_build_ = true;
// convert vec func_list to graph
Convert(&op_func_nodes);

} else {
ExecuteInstructionList(vec_instruction_);
}
@@ -300,7 +330,10 @@ void InterpreterCore::BuildSkipShareLoDInfo() {
void InterpreterCore::RunInstruction(const Instruction& instr_node) {
auto* op = instr_node.OpBase();
auto place = instr_node.DeviceContext().GetPlace();
VLOG(4) << "Start run" << place << " " << op->DebugStringEx(global_scope_);
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(global_scope_);
Scope* local_scope = create_local_scope_
? global_scope_->GetMutableLocalScope()
: global_scope_->GetMutableScope();

auto op_with_kernel = dynamic_cast<const framework::OperatorWithKernel*>(op);
{
@@ -325,13 +358,14 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
}
{
platform::RecordEvent compute_event("Compute");
if (op_with_kernel == nullptr)
instr_node.OpBase()->Run(*global_scope_->GetScope(), place_);
else
if (op_with_kernel == nullptr) {
instr_node.OpBase()->Run(*local_scope, place_);
} else {
instr_node.KernelFunc()(*instr_node.InnerExecutionContext().get());
}
}

VLOG(4) << "End run" << place << " " << op->DebugStringEx(global_scope_);
VLOG(4) << "End run " << place << " " << op->DebugStringEx(global_scope_);

/*For profiling/benchmark only*/
if (FLAGS_benchmark) {
@@ -372,8 +406,8 @@ void InterpreterCore::ExecuteInstructionList(
}
}

auto event_id = main_thread_blocker_.WaitEvent();
VLOG(3) << "event_id " << event_id;
auto event_name = main_thread_blocker_.WaitEvent();
VLOG(3) << "event_name: " << event_name;

if (UNLIKELY(exception_holder_.IsCaught())) {
VLOG(4) << "Exception caught " << exception_holder_.Type();
@@ -526,8 +560,9 @@ void InterpreterCore::Prepare(
VLOG(4) << "Feed inputs";
for (size_t i = 0; i < feed_names.size(); ++i) {
auto* feed_var = global_scope_->FindVar(feed_names[i]);
PADDLE_ENFORCE_NOT_NULL(feed_var, platform::errors::NotFound(
"feed_var shall not be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
feed_var, platform::errors::NotFound(
"Variable %s should not be nullptr.", feed_names[i]));

auto feed_tensor = feed_var->GetMutable<framework::LoDTensor>();
feed_tensor->ShareDataWith(feed_tensors[i]);
@@ -536,11 +571,12 @@
};

if (!is_build_) {
paddle::framework::interpreter::build_variable_scope(block_, global_scope_);
paddle::framework::interpreter::build_variable_scope(block_, global_scope_,
create_local_scope_);
FeedInput();
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
paddle::framework::interpreter::build_op_func_list(
place_, block_, &op_func_nodes, global_scope_);
place_, block_, &op_func_nodes, global_scope_, create_local_scope_);
is_build_ = true;
// convert vec func_list to graph
Convert(&op_func_nodes);
@@ -556,6 +592,7 @@
interpreter::CostInfo InterpreterCore::DryRun(
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors) {
global_scope_->SetLocalScope(local_scope_);
Prepare(feed_names, feed_tensors, true);
interpreter::CostInfo cost_info;
{
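The gist of the changes above: persistable variables (e.g. parameters) stay in VariableScope's root scope, while per-run temporaries live in a child scope created via NewScope() and wired to the scope listener, so they can be cleared between runs without touching parameters. Below is a minimal standalone sketch of that parent/local split; the Scope class here is hypothetical and only mimics the shape of the real API (NewScope, Var, FindVar):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical, simplified Scope: each scope owns its variables and its
// child scopes, and variable lookup falls back to the parent chain.
class Scope {
 public:
  Scope() = default;

  Scope* NewScope() {
    kids_.emplace_back(std::unique_ptr<Scope>(new Scope(this)));
    return kids_.back().get();
  }

  // Create-or-get a variable in *this* scope.
  int* Var(const std::string& name) { return &vars_[name]; }

  // Search this scope first, then the ancestors.
  int* FindVar(const std::string& name) {
    auto it = vars_.find(name);
    if (it != vars_.end()) return &it->second;
    return parent_ ? parent_->FindVar(name) : nullptr;
  }

  // Drop all local variables; parent (persistable) ones are untouched.
  void Clear() { vars_.clear(); }

 private:
  explicit Scope(Scope* parent) : parent_(parent) {}

  Scope* parent_{nullptr};
  std::unordered_map<std::string, int> vars_;
  std::vector<std::unique_ptr<Scope>> kids_;
};

int main() {
  Scope global;                      // persistable variables (parameters)
  *global.Var("weight") = 42;

  Scope* local = global.NewScope();  // per-run temporaries
  *local->Var("tmp_out") = 7;

  // Ops run against the local scope but still see persistable variables.
  std::cout << *local->FindVar("weight") << "\n";   // 42 (from the parent)
  std::cout << *local->FindVar("tmp_out") << "\n";  // 7  (local)

  local->Clear();  // roughly what "clear previous local scope" amounts to
  std::cout << (local->FindVar("tmp_out") == nullptr) << "\n";  // 1
  return 0;
}
```

In the diff, FLAGS_new_executor_use_local_scope plays the role of choosing whether RunInstruction executes ops against `local` or directly against `global`.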
12 changes: 11 additions & 1 deletion paddle/fluid/framework/new_executor/interpretercore.h
@@ -55,6 +55,8 @@ class InterpreterCore {
const std::vector<std::string>& feed_names,
const std::vector<framework::LoDTensor>& feed_tensors);

void SetCopyProgram(std::shared_ptr<ProgramDesc> prog);

private:
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);

@@ -85,7 +87,13 @@ class InterpreterCore {
bool is_build_;

const platform::Place& place_;
const BlockDesc& block_; // not owned
const BlockDesc& block_; // not owned
// NOTE(zhiqiu): when adding fetch ops in GetInterpreterCore, we copy a
// new program and block; copy_program_ is kept here to hold that
// program, otherwise block_ may no longer be valid after the new
// program is deleted.
std::shared_ptr<ProgramDesc> copy_program_{nullptr};

VariableScope* global_scope_; // not owned

std::vector<Instruction> vec_instruction_; // deconstruct before OpFuncNode
@@ -102,6 +110,8 @@

std::unique_ptr<InterpreterCoreGarbageCollector> gc_;
std::vector<paddle::platform::DeviceEvent> gc_event_;
bool create_local_scope_{true};
Scope* local_scope_{nullptr}; // not owned
};
} // namespace framework
} // namespace paddle
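The NOTE on copy_program_ describes a lifetime hazard: block_ is a non-owning reference into a ProgramDesc, and when the program is copied to append fetch ops, something must keep that copy alive. A sketch of the pattern under simplified stand-in types (Block, Program, Core below are not the real Paddle classes):

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Simplified stand-ins for BlockDesc/ProgramDesc.
struct Block { std::string name; };
struct Program { std::vector<Block> blocks{{"main"}}; };

class Core {
 public:
  explicit Core(const Block& block) : block_(block) {}

  // Analogous to SetCopyProgram: pin the program that owns block_.
  void SetProgram(std::shared_ptr<Program> prog) { program_ = std::move(prog); }

  void Run() const { std::cout << "running block " << block_.name << "\n"; }

 private:
  const Block& block_;                // not owned; must outlive this Core
  std::shared_ptr<Program> program_;  // shared ownership keeps block_ valid
};

int main() {
  auto prog = std::make_shared<Program>();
  Core core(prog->blocks[0]);
  core.SetProgram(prog);  // Core now co-owns the copied program
  prog.reset();           // caller drops its handle...
  core.Run();             // ...but block_ is still valid
  return 0;
}
```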
79 changes: 52 additions & 27 deletions paddle/fluid/framework/new_executor/interpretercore_util.cc
@@ -132,23 +132,35 @@ std::string get_memcpy_type(const platform::Place& src_place,
}

void build_variable_scope(const framework::BlockDesc& block,
VariableScope* var_scope) {
VariableScope* var_scope, bool use_local_scope) {
VLOG(3) << "Creating Variables";
auto inner_scope = var_scope->GetMutableScope();

// NOTE(zhiqiu): if create_local_scope_ is true, persistable variables are
// created in var_scope.scope_, and all other variables in the local scope.
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();

for (auto& var_desc : block.AllVars()) {
auto var_name = var_desc->Name();
if (var_name == framework::kEmptyVarName) {
continue;
}
if (var_desc->Persistable()) {
auto* ptr = inner_scope->Var(var_name);

if (nullptr == var_scope->FindVar(var_name)) {
var_scope->AddVar(var_desc->Name(), var_desc);
VLOG(3) << "Initialize Variable " << var_name;
InitializeVariable(ptr, var_desc->GetType());
VLOG(3) << "Create Variable " << var_name << " global, which pointer is "
<< ptr << " type is " << static_cast<int>(var_desc->GetType());
} else {
auto* var_desc_tmp = var_scope->VarDesc(var_name);
if (nullptr == var_desc_tmp) {
VLOG(3) << "update var:" << var_name << " desc from nullptr into "
<< var_desc;
var_scope->SetVarDesc(var_name, var_desc);
}
auto* ptr = local_scope->Var(var_name);
InitializeVariable(ptr, var_desc->GetType());
VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
<< ptr << "Variable Type "
<< static_cast<int>(var_desc->GetType());
}
var_scope->SetVarDesc(var_name, var_desc);
}
}

@@ -237,14 +249,14 @@ void apply_device_guard(const OperatorBase* op_base,
void deal_operator_base(const platform::Place& place,
const VariableScope* var_scope,
std::shared_ptr<OperatorBase> op_base,
OpFuncNode* op_func_node) {
OpFuncNode* op_func_node, Scope* local_scope) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
// input, output is prepared. set the other attributes.
op_func_node->operator_base_ = op_base;
op_func_node->type_ = OpFuncType::kQueueSync; // always Sync
op_func_node->kernel_func_ = nullptr;
op_base->Run(*var_scope->GetScope(), place); // Run without data transformer.
op_base->Run(*local_scope, place); // Run without data transformer.

std::unordered_set<int> no_data_transform_index;
for (auto& it : op_func_node->input_index) {
@@ -288,12 +300,21 @@ std::tuple<std::string, OpFuncNode> apply_place_transform_for_var(
const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_key, const platform::Place& place,
const std::string& var_name, const std::string& outer_name,
const OpFuncNode& op_func_node, Variable* var, VariableScope* var_scope) {
const OpFuncNode& op_func_node, Variable* var, VariableScope* var_scope,
bool use_local_scope = true) {
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();

auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
std::string new_var_name =
var_name + "_copy_" + std::to_string(var_scope->VarSize() + 1);
var_scope->AddVar(new_var_name, nullptr);

auto* ptr = local_scope->Var(new_var_name);
InitializeVariable(ptr, static_cast<proto::VarType::Type>(var->Type()));
VLOG(3) << "Create Variable " << var_name << " locally, which pointer is "
<< ptr << "Variable Type " << var->Type();
var_scope->SetVarDesc(var_name, nullptr);

VariableNameMap copy_in_map;
copy_in_map["X"] = {var_name};
@@ -368,7 +389,8 @@ void apply_data_transform(const OpKernelType& expected_kernel_key,
const platform::Place& place,
VariableValueMap* ins_map_temp,
VariableScope* var_scope, OpFuncNode* op_func_node,
std::vector<OpFuncNode>* copy_func_nodes) {
std::vector<OpFuncNode>* copy_func_nodes,
bool use_local_scope = true) {
auto op_base = op_func_node->operator_base_.get();
PADDLE_ENFORCE_NOT_NULL(op_base, platform::errors::PreconditionNotMet(
"op_base is null, please pass a valid "
@@ -402,9 +424,10 @@ void apply_data_transform(const OpKernelType& expected_kernel_key,
std::string new_var_name;
OpFuncNode copy_op_func_node;
std::tie(new_var_name, copy_op_func_node) =
apply_place_transform_for_var(
kernel_type_for_var, expected_kernel_key, place, var_name,
var_name_item.first, *op_func_node, var, var_scope);
apply_place_transform_for_var(kernel_type_for_var,
expected_kernel_key, place, var_name,
var_name_item.first, *op_func_node,
var, var_scope, use_local_scope);
op_func_node->input_index[var_name_item.first][i] =
var_scope->VarId(new_var_name);
copy_func_nodes->emplace_back(copy_op_func_node);
@@ -438,7 +461,9 @@ void apply_data_transform(const OpKernelType& expected_kernel_key,
void build_op_func_list(const platform::Place& place,
const framework::BlockDesc& block,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope) {
VariableScope* var_scope, bool use_local_scope) {
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();
auto& all_op_kernels = OperatorWithKernel::AllOpKernels();
std::vector<std::shared_ptr<OperatorBase>>
ops; // its elements will be moved to vec_func_list
@@ -478,7 +503,7 @@ void build_op_func_list(const platform::Place& place,

if (dynamic_cast<const framework::OperatorWithKernel*>(op) == nullptr) {
// op is not an OperatorWithKernel, so directly run OperatorBase::Run()
deal_operator_base(place, var_scope, ops[i], &op_func_node);
deal_operator_base(place, var_scope, ops[i], &op_func_node, local_scope);
} else {
// construct RuntimeContext and analysis KernelType
RuntimeContext runtime_context({}, {});
Expand Down Expand Up @@ -520,7 +545,7 @@ void build_op_func_list(const platform::Place& place,
// apply_data_transform.
op_func_node.operator_base_ = ops[i];
apply_data_transform(expected_kernel_key, place, &ins_map_temp, var_scope,
&op_func_node, &copy_op_to_insert);
&op_func_node, &copy_op_to_insert, use_local_scope);
for (auto& item : copy_op_to_insert) {
vec_func_list->push_back(item);
}
@@ -631,16 +656,16 @@ std::vector<size_t> merge_vector(const std::vector<size_t>& first,
}

void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
std::map<int, std::list<int>>& var2min_rw_op,
std::map<int, std::list<int>>* var2min_rw_op,
int cur_op, int rw_var) {
// rw_var is an input or output of cur_op.
// This function updates the var2min_rw_op map.
if (var2min_rw_op.find(rw_var) == var2min_rw_op.end())
var2min_rw_op[rw_var] = std::list<int>();
if (var2min_rw_op->find(rw_var) == var2min_rw_op->end())
(*var2min_rw_op)[rw_var] = std::list<int>();
for (auto dep_op : op2dependences.at(cur_op)) {
var2min_rw_op[rw_var].remove(dep_op);
(*var2min_rw_op)[rw_var].remove(dep_op);
}
var2min_rw_op[rw_var].push_back(cur_op);
(*var2min_rw_op)[rw_var].push_back(cur_op);
}

std::map<int, std::list<int>> get_downstream_map(
Expand Down Expand Up @@ -702,7 +727,7 @@ std::map<int, std::list<int>> build_op_downstream_map(
for (auto& item :
vec_instruction[op_idx].Inputs()) { // for all inputs(read only)
for (auto var : item.second) {
update_var_min_rw_op(op2dependences, var2min_rw_op, op_idx, var);
update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
remove_duplicate.insert(var);
}
}
@@ -713,7 +738,7 @@
var2recent_write_op[var] = op_idx;
if (remove_duplicate.count(var) ==
0) { // var in input list and in output list, so remove it.
update_var_min_rw_op(op2dependences, var2min_rw_op, op_idx, var);
update_var_min_rw_op(op2dependences, &var2min_rw_op, op_idx, var);
}
}
}
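Besides threading use_local_scope through this file, update_var_min_rw_op switches its mutable map from a reference to a pointer parameter, making the mutation explicit at the call sites (&var2min_rw_op). Semantically, the map keeps, per variable, only the "minimal" reader/writer ops: any op the current op already depends on is dominated and removed. The function body below is lifted from the diff; the driver in main() and the op/variable ids are illustrative only:

```cpp
#include <iostream>
#include <list>
#include <map>
#include <set>

void update_var_min_rw_op(const std::map<int, std::set<int>>& op2dependences,
                          std::map<int, std::list<int>>* var2min_rw_op,
                          int cur_op, int rw_var) {
  // rw_var is an input or output of cur_op; drop ops that cur_op already
  // depends on (they are dominated), then record cur_op itself.
  if (var2min_rw_op->find(rw_var) == var2min_rw_op->end())
    (*var2min_rw_op)[rw_var] = std::list<int>();
  for (auto dep_op : op2dependences.at(cur_op)) {
    (*var2min_rw_op)[rw_var].remove(dep_op);
  }
  (*var2min_rw_op)[rw_var].push_back(cur_op);
}

int main() {
  // Illustrative graph: op 0 writes var 7; ops 1 and 2 read var 7 and both
  // depend on op 0.
  std::map<int, std::set<int>> op2dependences{{0, {}}, {1, {0}}, {2, {0}}};
  std::map<int, std::list<int>> var2min_rw_op;

  update_var_min_rw_op(op2dependences, &var2min_rw_op, 0, 7);  // [0]
  update_var_min_rw_op(op2dependences, &var2min_rw_op, 1, 7);  // drops 0 -> [1]
  update_var_min_rw_op(op2dependences, &var2min_rw_op, 2, 7);  // [1, 2]

  for (int op : var2min_rw_op[7]) std::cout << op << " ";  // prints: 1 2
  std::cout << "\n";
  return 0;
}
```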
7 changes: 4 additions & 3 deletions paddle/fluid/framework/new_executor/interpretercore_util.h
Original file line number Diff line number Diff line change
@@ -51,7 +51,7 @@ namespace framework {
namespace interpreter {

using AtomicVectorSizeT = std::vector<std::unique_ptr<std::atomic<size_t>>>;
static constexpr char kFetchVarName[] = "fetch_vars";
static constexpr char kFetchVarName[] = "fetch";

class AsyncWorkQueue {
public:
@@ -98,12 +98,13 @@ std::string get_memcpy_type(const platform::Place& src_place,
const platform::Place& dst_place);

void build_variable_scope(const framework::BlockDesc& block,
VariableScope* var_scope);
VariableScope* var_scope,
bool use_local_scope = true);

void build_op_func_list(const platform::Place& place,
const framework::BlockDesc& block,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* var_scope);
VariableScope* var_scope, bool use_local_scope = true);

std::map<int, std::list<int>> build_op_downstream_map(
const std::vector<Instruction>& vec_instruction);
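Both helpers gain a use_local_scope parameter defaulting to true, so every existing call site keeps compiling unchanged while callers can opt out of the local scope explicitly. A trivial sketch of that API-evolution pattern, with a hypothetical function rather than the Paddle one:

```cpp
#include <iostream>

// Hypothetical helper, not the Paddle function: a defaulted parameter lets
// old call sites compile unchanged while new callers opt out.
void build_scope(bool use_local_scope = true) {
  std::cout << "variables go to the "
            << (use_local_scope ? "local" : "global") << " scope\n";
}

int main() {
  build_scope();       // pre-existing call sites: default behavior is on
  build_scope(false);  // new-executor code can disable the local scope
  return 0;
}
```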