Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pir] Supporting constant_folding_pass for train #60355

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 97 additions & 17 deletions paddle/fluid/pir/transforms/constant_folding_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,20 +126,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
pir::PatternRewriter& rewriter) const override { // NOLINT
VLOG(4) << "constant_folding_pass applys rewrite on [" << op->name()
<< "] op";
pir::Program new_program(rewriter.ir_context());
auto output_var_names =
BuildProgramFromOperation(op, &new_program, rewriter);

// execute program
for (auto output_var_name : output_var_names) {
exe_config_->skip_gc_vars.insert(output_var_name);
}
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(&new_program, place_);
paddle::framework::InterpreterCore core(
place_, {}, kernel_program->block(), scope_, *exe_config_);

core.Run({});
auto output_var_names = RunOp(op, rewriter);

// ParameterOp and ConstantTensorOp should be created in the top-level block
rewriter.SetInsertionPointToStart(
Expand Down Expand Up @@ -236,6 +223,27 @@ class ConstantFoldingPattern : public pir::RewritePattern {
return true;
}

protected:
// Constant-folds `op` by cloning it into a fresh temporary program and
// executing that program with an InterpreterCore on `place_`. The folded
// results are left in `scope_` (protected from GC via skip_gc_vars) under
// the returned variable names, so callers can wire them back into the IR.
// NOTE(review): entries added to exe_config_->skip_gc_vars accumulate
// across calls — presumably intentional so folded vars survive later
// interpreter runs; confirm.
std::vector<std::string> RunOp(
pir::Operation* op,
pir::PatternRewriter& rewriter) const { // NOLINT
// Build a one-op program containing `op` and constant ops feeding it.
pir::Program new_program(rewriter.ir_context());
auto output_var_names =
BuildProgramFromOperation(op, &new_program, rewriter);

// execute program
// Keep every output alive in scope_ after the interpreter finishes.
for (auto output_var_name : output_var_names) {
exe_config_->skip_gc_vars.insert(output_var_name);
}
// Lower to the kernel dialect so the interpreter can run it.
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(&new_program, place_);
paddle::framework::InterpreterCore core(
place_, {}, kernel_program->block(), scope_, *exe_config_);

core.Run({});
return output_var_names;
}

std::vector<std::string> BuildProgramFromOperation(
pir::Operation* op,
pir::Program* new_program,
Expand Down Expand Up @@ -299,14 +307,76 @@ class ConstantFoldingPattern : public pir::RewritePattern {
return output_var_names;
}

private:
protected:
size_t* counter_;
phi::Place place_;
paddle::framework::Scope* scope_;
paddle::framework::interpreter::ExecutionConfig* exe_config_;
std::vector<std::string>* deleted_vars_;
};

// Constant-folding pattern restricted to training programs: only ops whose
// every input is produced by a ConstantTensorOp are folded (ParameterOp
// inputs are excluded, since parameters may be updated during training).
class ConstantFoldingPatternForTrain : public ConstantFoldingPattern {
 public:
  ConstantFoldingPatternForTrain(
      pir::IrContext* context,
      size_t* counter,
      const phi::Place& place,
      paddle::framework::Scope* scope,
      paddle::framework::interpreter::ExecutionConfig* exe_config,
      std::vector<std::string>* deleted_vars)
      : ConstantFoldingPattern(
            context, counter, place, scope, exe_config, deleted_vars) {}

  bool Match(pir::Operation* op) const override {
    VLOG(4) << "constant_folding_pass applys match on [" << op->name()
            << "] op";
    if (!ConstantFoldingPattern::Match(op)) {
      return false;
    }
    for (uint32_t i = 0; i < op->num_operands(); i++) {
      // inputs must come from a constant op
      auto* prev_op = pir::GetDefiningOpForInput(op, i);
      if (!prev_op || !prev_op->isa<pir::ConstantTensorOp>()) {
        return false;
      }
    }
    return true;
  }

  void Rewrite(pir::Operation* op,
               pir::PatternRewriter& rewriter) const override {  // NOLINT
    VLOG(4) << "constant_folding_pass for train applys rewrite on ["
            << op->name() << "] op";

    // Execute the op once; folded outputs are placed in scope_ under the
    // returned variable names.
    auto output_var_names = RunOp(op, rewriter);

    // ConstantTensorOp should be created in the top-level block
    rewriter.SetInsertionPointToStart(
        rewriter.block()->parent_program()->block());

    for (uint32_t i = 0; i < op->num_results(); i++) {
      if (!op->result(i) || !op->result(i).type()) {
        continue;
      }
      std::string output_var_name = output_var_names[i];
      PADDLE_ENFORCE_NOT_NULL(
          scope_->FindVar(output_var_name),
          phi::errors::InvalidArgument("Parameter var [%s] not in scope.",
                                       output_var_name));

      // Replace the folded result with a persistable constant tensor.
      auto constant_op = rewriter.Build<pir::ConstantTensorOp>(
          rewriter.tensor_name_attr(output_var_name), op->result(i).type());
      constant_op->set_attribute(
          kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)}));

      rewriter.ReplaceAllUsesWith(op->result(i), constant_op->result(0));
    }
    // Capture the name BEFORE erasing: after EraseOp the operation is
    // destroyed, so dereferencing `op` in the log below would be
    // use-after-free. (Also fixes the "traun" typo in the message.)
    const std::string op_name = op->name();
    rewriter.EraseOp(op);
    VLOG(4) << "constant_folding_pass for train applied rewrite on ["
            << op_name << "] op";
  }
};

class ConstantFoldingPass : public pir::Pass {
public:
ConstantFoldingPass()
Expand All @@ -332,8 +402,18 @@ class ConstantFoldingPass : public pir::Pass {
scope_, phi::errors::InvalidArgument("scope can not be nullptr"));

pir::RewritePatternSet ps(context);
ps.Add<ConstantFoldingPattern>(
context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);

if (Has("train_mode") && Get<bool>("train_mode")) {
ps.Add<ConstantFoldingPatternForTrain>(context,
&counter_,
phi::CPUPlace{},
scope_,
&exe_config_,
&deleted_vars_);
} else {
ps.Add<ConstantFoldingPattern>(
context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);
}
patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
return true;
}
Expand Down
31 changes: 29 additions & 2 deletions test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,10 @@ void BuildConstantFoldingProgram(pir::Program *program,
paddle::platform::DeviceContextPool::Instance().Get(
paddle::platform::CPUPlace());

auto op1 = builder.Build<pir::ParameterOp>("a", dense_tensor_dtype);
auto op2 = builder.Build<pir::ParameterOp>("b", dense_tensor_dtype);
auto op1 = builder.Build<pir::ConstantTensorOp>(builder.tensor_name_attr("a"),
dense_tensor_dtype);
auto op2 = builder.Build<pir::ConstantTensorOp>(builder.tensor_name_attr("b"),
dense_tensor_dtype);

auto op3 =
builder.Build<paddle::dialect::AddOp>(op1->result(0), op2->result(0));
Expand Down Expand Up @@ -493,6 +495,31 @@ TEST(constant_folding, ConstantFolding) {
EXPECT_EQ(program.block()->size(), 2u);
}

// Train-mode constant folding: with "train_mode" set, foldable ops whose
// inputs are all ConstantTensorOp are replaced by persistable constants,
// leaving the expected number of ops in the block.
TEST(constant_folding, ConstantFolding_Train) {
  pir::IrContext *ctx = pir::IrContext::Instance();
  ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
  ctx->GetOrRegisterDialect<pir::BuiltinDialect>();

  paddle::framework::Scope scope;
  pir::Program program(ctx);
  BuildConstantFoldingProgram(&program, ctx, &scope);

  phi::Place place = phi::CPUPlace();
  auto folding_pass = pir::CreateConstantFoldingPass();
  folding_pass->SetNotOwned(pir::kPlaceAttr, &place);
  folding_pass->SetNotOwned(pir::kParamScopeAttr, &scope);
  folding_pass->Set("train_mode", new bool(true));

  pir::PassManager pm(ctx);
  pm.AddPass(std::move(folding_pass));
  pm.AddPass(pir::CreateDeadCodeEliminationPass());
  pm.EnableIRPrinting();

  CHECK_EQ(pm.Run(&program), true);
  EXPECT_EQ(program.block()->size(), 4u);
}

void BuildConcatProgram(pir::Program *program, pir::IrContext *ctx) {
pir::Builder builder = pir::Builder(ctx, program->block());
auto x = builder
Expand Down