support gradient accumulation
zhiqiu committed Sep 3, 2021
1 parent b43a114 commit 8405640
Showing 1 changed file with 56 additions and 45 deletions.
101 changes: 56 additions & 45 deletions paddle/fluid/imperative/basic_engine.cc
@@ -314,7 +314,7 @@ static std::shared_ptr<NameVarMap<VariableWrapper>> CallGradientHooks(
  return tmp_ins_ptr;
}

bool IsInputCanInplace(const std::shared_ptr<VariableWrapper>& var) {
static bool IsInputCanInplace(const std::shared_ptr<VariableWrapper>& var) {
  auto* inner_var = var->MutableVar();
  if (inner_var->IsInitialized() && inner_var->IsType<framework::LoDTensor>()) {
    auto tensor = inner_var->GetMutable<framework::LoDTensor>();
@@ -325,6 +325,57 @@ bool IsInputCanInplace(const std::shared_ptr<VariableWrapper>& var) {
  return false;
}

static void PerformBackwardInplace(const std::string& op_type,
                                   const NameVarMap<VariableWrapper>& ins,
                                   NameVarMap<VariableWrapper>* outs) {
  auto& infer_inplace =
      paddle::framework::OpInfoMap::Instance().Get(op_type).infer_inplace_;

  if (infer_inplace) {
    auto in_to_outs = infer_inplace(true);
    for (auto& pair : in_to_outs) {
      framework::LoDTensor *in_tensor = nullptr, *out_tensor = nullptr;
      for (auto& p : ins) {
        if (p.first == pair.first) {
          // has at least one var
          if (p.second.size() > 0 && p.second[0]) {
            auto& in_var = p.second[0];
            VLOG(10) << p.first << " use_count: " << in_var.use_count();
            // the refcount of var to be inplaced should be 1
            if (in_var.use_count() == 1) {
              if (IsInputCanInplace(in_var)) {
                in_tensor =
                    in_var->MutableVar()->GetMutable<framework::LoDTensor>();
              }
            }
          }
        }
      }
      if (!in_tensor) {
        continue;
      }
      for (auto& p : *outs) {
        if (p.first == pair.second) {
          if (p.second.size() > 0 && p.second[0]) {
            auto& out_var = p.second[0];
            if (out_var->Type() == framework::proto::VarType::LOD_TENSOR) {
              out_tensor =
                  out_var->MutableVar()->GetMutable<framework::LoDTensor>();
            }
          }
        }
      }
      if (!out_tensor) {
        continue;
      }
      out_tensor->ShareBufferWith(*in_tensor);
      out_tensor->Resize(in_tensor->dims());
      VLOG(4) << "Inplace performed in op " << op_type << ": " << pair.second
              << " -> " << pair.first;
    }
  }
}

void BasicEngine::Execute() {
  if (init_nodes_.empty()) {
    return;
@@ -359,50 +410,6 @@ void BasicEngine::Execute() {
      auto& bwd_ins = cur_op.GetInsMap();
      auto& bwd_outs = cur_op.GetOutsMap();

      auto& infer_inplace = paddle::framework::OpInfoMap::Instance()
                                .Get(cur_op.Type())
                                .infer_inplace_;

      if (infer_inplace) {
        auto in_to_outs = infer_inplace(true);
        for (auto& pair : in_to_outs) {
          framework::LoDTensor *in_tensor = nullptr, *out_tensor = nullptr;
          for (auto& p : cur_op.GetInsMap()) {
            if (p.first == pair.first) {
              // has at least one var
              if (p.second.size() > 0 && p.second[0]) {
                auto& in_var = p.second[0];
                VLOG(10) << p.first << " use_count: " << in_var.use_count();
                // the refcount of var to be inplaced should be 1
                if (in_var.use_count() == 1) {
                  if (IsInputCanInplace(in_var)) {
                    in_tensor = in_var->MutableVar()
                                    ->GetMutable<framework::LoDTensor>();
                  }
                }
              }
            }
          }
          if (!in_tensor) {
            continue;
          }
          for (auto& p : cur_op.GetOutsMap()) {
            if (p.first == pair.second) {
              if (p.second.size() > 0 && p.second[0]) {
                auto& out_var = p.second[0];
                if (out_var->Type() == framework::proto::VarType::LOD_TENSOR) {
                  out_tensor =
                      out_var->MutableVar()->GetMutable<framework::LoDTensor>();
                }
              }
            }
          }
          out_tensor->ShareBufferWith(*in_tensor);
          VLOG(4) << "Inplace performed in op " << cur_op.Type() << ": "
                  << pair.second << " -> " << pair.first;
        }
      }

      /**
       * [ Why need temporary outputs here? ]
       *
@@ -538,6 +545,10 @@ void BasicEngine::Execute() {
       */
      auto tmp_ins_ptr = CallGradientHooks(bwd_ins, cur_op.Type());

      if (!tmp_ins_ptr) {
        PerformBackwardInplace(cur_op.Type(), bwd_ins, &tmp_outs);
      }

      {
        VLOG(3) << "Start to execute grad op " << cur_op.Type();
        try {

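For readers following the change, below is a minimal, self-contained sketch of the buffer-sharing rule that PerformBackwardInplace enforces. It is not Paddle code: SimpleTensor, Variable, and MaybeInplace are hypothetical stand-ins for framework::LoDTensor, VariableWrapper, and the in-engine logic, and the sketch only mirrors the in_var.use_count() == 1 gate plus the ShareBufferWith/Resize calls seen in the diff.

// Hedged illustration only: SimpleTensor/Variable/MaybeInplace are invented
// stand-ins, not Paddle types. The point is the rule in this commit: an output
// may reuse an input's buffer only when the input handle has a single owner,
// so no other consumer can observe the overwrite.
#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

// Stand-in for framework::LoDTensor: owns a buffer that can be shared.
struct SimpleTensor {
  std::shared_ptr<std::vector<float>> buf;
  std::vector<int64_t> dims;
  void ShareBufferWith(const SimpleTensor& other) { buf = other.buf; }
  void Resize(const std::vector<int64_t>& d) { dims = d; }
};

// Stand-in for VariableWrapper: a shared handle to a tensor.
using Variable = std::shared_ptr<SimpleTensor>;

// Mirrors the refcount gate in the diff: inplace only when use_count() == 1.
void MaybeInplace(const Variable& in, const Variable& out) {
  if (!in || in.use_count() != 1 || !in->buf) {
    std::cout << "skip inplace, use_count = " << in.use_count() << "\n";
    return;
  }
  out->ShareBufferWith(*in);  // output reuses the input's allocation
  out->Resize(in->dims);      // and takes over its shape, as in the new code
  std::cout << "inplace performed\n";
}

int main() {
  Variable in = std::make_shared<SimpleTensor>();
  in->buf = std::make_shared<std::vector<float>>(4, 1.0f);
  in->dims = {4};
  Variable out = std::make_shared<SimpleTensor>();

  Variable alias = in;    // a second owner blocks the optimization
  MaybeInplace(in, out);  // prints: skip inplace, use_count = 2
  alias.reset();
  MaybeInplace(in, out);  // prints: inplace performed
  return 0;
}

The real implementation additionally consults the op's infer_inplace_ registry entry to learn which input/output pair may legally alias, and only runs when CallGradientHooks returned no replacement inputs; both of those lookups are omitted from this sketch.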