Skip to content

Commit

Permalink
fix format
Browse files Browse the repository at this point in the history
  • Loading branch information
gglin001 committed Jun 22, 2022
1 parent 8a92115 commit 108ccec
Show file tree
Hide file tree
Showing 13 changed files with 1,476 additions and 741 deletions.
28 changes: 19 additions & 9 deletions paddle/fluid/framework/executor_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ void AppendSkipDeletionVars(const std::vector<std::string> &append_vars,
* 2. it is an input var used in backward_op
*/
void ParseSafeEagerDeletionSkipVars(
const ProgramDesc &program, int64_t forward_op_nums,
const ProgramDesc &program,
int64_t forward_op_nums,
const std::vector<std::string> &output_var_names,
std::vector<std::string> *skip_eager_delete_vars) {
auto all_ops = program.Block(0).AllOps();
Expand Down Expand Up @@ -143,8 +144,11 @@ ExecutorInfoCache &ExecutorInfoCache::Instance() {
}

static PEAndGraphPair CreateExecutorInfo(
const ProgramDesc &program_desc, const platform::Place &place,
int64_t start_op_index, int64_t end_op_index, framework::Scope *scope,
const ProgramDesc &program_desc,
const platform::Place &place,
int64_t start_op_index,
int64_t end_op_index,
framework::Scope *scope,
const details::BuildStrategy &build_strategy) {
auto execution_strategy = details::GetExecutionStrategy(place);
auto graph = std::make_shared<framework::ir::Graph>(
Expand All @@ -162,15 +166,17 @@ PEAndGraphPair CreateFixOrderExecutorInfo(const ProgramDesc &program_desc,
framework::Scope *scope) {
details::BuildStrategy build_strategy;
build_strategy.fix_op_run_order_ = true;
auto pe_and_graph = CreateExecutorInfo(program_desc, place, start_op_index,
end_op_index, scope, build_strategy);
auto pe_and_graph = CreateExecutorInfo(
program_desc, place, start_op_index, end_op_index, scope, build_strategy);
return pe_and_graph;
}

CacheInfo GetExecutorInfoFromCache(const ProgramDesc &program_desc,
const platform::Place &place,
int64_t start_op_index, int64_t end_op_index,
bool is_grad, int64_t program_id,
int64_t start_op_index,
int64_t end_op_index,
bool is_grad,
int64_t program_id,
framework::Scope *scope) {
auto &cached_exe_info = framework::ExecutorInfoCache::Instance();

Expand All @@ -186,8 +192,12 @@ CacheInfo GetExecutorInfoFromCache(const ProgramDesc &program_desc,
auto &build_strategy = cached_exe_info.GetBuildStrategy(program_id);

// 2. Construct Graph and ParallelExecutor.
auto pe_and_graph = CreateExecutorInfo(program_desc, place, start_op_index,
end_op_index, scope, build_strategy);
auto pe_and_graph = CreateExecutorInfo(program_desc,
place,
start_op_index,
end_op_index,
scope,
build_strategy);

// 3. Insert value into cached map.
auto &cached_value = cached_exe_info.GetMutable(program_id, is_grad);
Expand Down
30 changes: 20 additions & 10 deletions paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,21 +115,25 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
auto type = std::string{"sgd"};
// auto LearningRate = op->Input("LearningRate");
auto use_nesterov = BOOST_GET_CONST(bool, op->GetAttr("use_nesterov"));
PADDLE_ENFORCE_EQ(use_nesterov, false,
PADDLE_ENFORCE_EQ(use_nesterov,
false,
platform::errors::Unimplemented(
"ipu does not support nesterov mode."));
auto regularization_method =
BOOST_GET_CONST(std::string, op->GetAttr("regularization_method"));
PADDLE_ENFORCE_NE(regularization_method, "l1_decay",
PADDLE_ENFORCE_NE(regularization_method,
"l1_decay",
platform::errors::Unimplemented(
"ipu does not support l1_decay mode."));
auto multi_precision =
BOOST_GET_CONST(bool, op->GetAttr("multi_precision"));
PADDLE_ENFORCE_EQ(multi_precision, false,
PADDLE_ENFORCE_EQ(multi_precision,
false,
platform::errors::Unimplemented(
"ipu does not support multi_precision mode."));
auto rescale_grad = BOOST_GET_CONST(float, op->GetAttr("rescale_grad"));
PADDLE_ENFORCE_EQ(rescale_grad, 1.0,
PADDLE_ENFORCE_EQ(rescale_grad,
1.0,
platform::errors::Unimplemented(
"ipu does not support rescale_grad mode."));
auto regularization_coeff =
Expand All @@ -150,10 +154,12 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
auto lazy_mode = BOOST_GET_CONST(bool, op->GetAttr("lazy_mode"));
auto multi_precision =
BOOST_GET_CONST(bool, op->GetAttr("multi_precision"));
PADDLE_ENFORCE_EQ(lazy_mode, false,
PADDLE_ENFORCE_EQ(lazy_mode,
false,
platform::errors::Unimplemented(
"ipu does not support lazy_mode mode."));
PADDLE_ENFORCE_EQ(multi_precision, false,
PADDLE_ENFORCE_EQ(multi_precision,
false,
platform::errors::Unimplemented(
"ipu does not support multi_precision mode."));
new_op.SetAttr("type", type);
Expand Down Expand Up @@ -268,11 +274,13 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
VLOG(10) << "found loss op type: " << op->Type();
auto outputs = op->Outputs();
PADDLE_ENFORCE_EQ(
outputs.size(), 1,
outputs.size(),
1,
platform::errors::InvalidArgument("Can only support one loss key"));
auto losses = outputs.begin()->second;
PADDLE_ENFORCE_EQ(
losses.size(), 1,
losses.size(),
1,
platform::errors::InvalidArgument("Can only support one loss name"));
auto loss_var = losses.front();
new_op.SetAttr("loss_var", loss_var);
Expand All @@ -282,11 +290,13 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
} else if (op_type == "identity_loss") {
auto outputs = op->Outputs();
PADDLE_ENFORCE_EQ(
outputs.size(), 1,
outputs.size(),
1,
platform::errors::InvalidArgument("Can only support one loss key"));
auto losses = outputs.begin()->second;
PADDLE_ENFORCE_EQ(
losses.size(), 1,
losses.size(),
1,
platform::errors::InvalidArgument("Can only support one loss name"));
auto loss_var = losses.front();
new_op.SetAttr("loss_var", loss_var);
Expand Down
Loading

1 comment on commit 108ccec

@paddle-bot-old
Copy link

@paddle-bot-old paddle-bot-old bot commented on 108ccec Jun 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍 PR: #43770 Commit ID: 108ccec contains failed CI.

🔹 Failed: PR-CI-APPROVAL

Unknown Failed
Unknown Failed

🔹 Failed: PR-CI-Static-Check

Unknown Failed
Unknown Failed

🔹 Failed: PR-CE-Framework

Unknown Failed
Unknown Failed

Please sign in to comment.