add backward inplace for dygraph (PaddlePaddle#35412)
* add backward inplace for dygraph

* fix bug

* support gradient accumulation
zhiqiu authored and AnnaTrainingG committed Sep 29, 2021
1 parent a7f7961 commit e3498f9
Showing 1 changed file with 66 additions and 0 deletions.
paddle/fluid/imperative/basic_engine.cc
@@ -314,6 +314,68 @@ static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
  return tmp_ins_ptr;
}

// A grad input can be reused in place only if it holds an initialized
// LoDTensor.
static bool IsInputCanInplace(const std::shared_ptr<VariableWrapper>& var) {
  auto* inner_var = var->MutableVar();
  if (inner_var->IsInitialized() &&
      inner_var->IsType<framework::LoDTensor>()) {
    auto tensor = inner_var->GetMutable<framework::LoDTensor>();
    if (tensor->IsInitialized()) {
      return true;
    }
  }
  return false;
}

// For grad ops that declare an inplace inference rule, let a grad output
// reuse the buffer of its matching grad input instead of allocating anew.
static void PerformBackwardInplace(const std::string& op_type,
                                   const NameVarMap<VariableWrapper>& ins,
                                   NameVarMap<VariableWrapper>* outs) {
  auto& infer_inplace =
      paddle::framework::OpInfoMap::Instance().Get(op_type).infer_inplace_;

  if (infer_inplace) {
    auto in_to_outs = infer_inplace(true);
    for (auto& pair : in_to_outs) {
      framework::LoDTensor *in_tensor = nullptr, *out_tensor = nullptr;
      // Locate the input tensor of this inplace pair.
      for (auto& p : ins) {
        if (p.first == pair.first) {
          // has at least one var
          if (p.second.size() > 0 && p.second[0]) {
            auto& in_var = p.second[0];
            VLOG(10) << p.first << " use_count: " << in_var.use_count();
            // the refcount of the var to be inplaced should be 1
            if (in_var.use_count() == 1) {
              if (IsInputCanInplace(in_var)) {
                in_tensor =
                    in_var->MutableVar()->GetMutable<framework::LoDTensor>();
              }
            }
          }
        }
      }
      if (!in_tensor) {
        continue;
      }
      // Locate the matching output tensor of this inplace pair.
      for (auto& p : *outs) {
        if (p.first == pair.second) {
          if (p.second.size() > 0 && p.second[0]) {
            auto& out_var = p.second[0];
            if (out_var->Type() == framework::proto::VarType::LOD_TENSOR) {
              out_tensor =
                  out_var->MutableVar()->GetMutable<framework::LoDTensor>();
            }
          }
        }
      }
      if (!out_tensor) {
        continue;
      }
      // Share the input's buffer with the output: no allocation, no copy.
      out_tensor->ShareBufferWith(*in_tensor);
      out_tensor->Resize(in_tensor->dims());
      VLOG(4) << "Inplace performed in op " << op_type << ": " << pair.second
              << " -> " << pair.first;
    }
  }
}

void BasicEngine::Execute() {
  if (init_nodes_.empty()) {
    return;
@@ -483,6 +545,10 @@ void BasicEngine::Execute() {
       */
      auto tmp_ins_ptr = CallGradientHooks(bwd_ins, cur_op.Type());

      // Hooks did not rewrite the inputs, so their buffers may be reused.
      if (!tmp_ins_ptr) {
        PerformBackwardInplace(cur_op.Type(), bwd_ins, &tmp_outs);
      }

      {
        VLOG(3) << "Start to execute grad op " << cur_op.Type();
        try {
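For readers outside the Paddle codebase, here is a minimal, self-contained sketch of the two ideas the new code combines: the use_count() == 1 guard that proves the grad input is exclusively owned, and buffer sharing that makes the grad output's "copy" free. Tensor, its buffer member, and this ShareBufferWith are toy stand-ins for illustration, not Paddle's real types.

#include <iostream>
#include <memory>
#include <vector>

struct Tensor {
  std::shared_ptr<std::vector<float>> buffer;  // the shared allocation
  std::vector<int> dims;
  // Aliasing instead of copying: both tensors point at the same storage.
  void ShareBufferWith(const Tensor& other) { buffer = other.buffer; }
};

int main() {
  auto grad_in = std::make_shared<Tensor>();
  grad_in->buffer = std::make_shared<std::vector<float>>(4, 1.0f);
  grad_in->dims = {2, 2};

  // The guard from the diff: reuse the input's storage only when this
  // shared_ptr is the sole owner, so no other node can observe the write.
  if (grad_in.use_count() == 1) {
    Tensor grad_out;
    grad_out.ShareBufferWith(*grad_in);  // no allocation, no memcpy
    grad_out.dims = grad_in->dims;
    (*grad_out.buffer)[0] = 42.0f;               // write through the alias...
    std::cout << (*grad_in->buffer)[0] << "\n";  // ...prints 42
  }
  return 0;
}

Because both tensors alias one allocation, a write through grad_out is visible through grad_in, which is exactly why the refcount guard matters: if any other node still held the input, the in-place write could clobber a value it had yet to read.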
