Skip to content

Commit

Permalink
[pir]Supporting constant_folding_pass for train
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyuqin1998 committed Dec 28, 2023
1 parent 23808ae commit 16261c2
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 19 deletions.
111 changes: 94 additions & 17 deletions paddle/fluid/pir/transforms/constant_folding_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -126,20 +126,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
pir::PatternRewriter& rewriter) const override { // NOLINT
VLOG(4) << "constant_folding_pass applys rewrite on [" << op->name()
<< "] op";
pir::Program new_program(rewriter.ir_context());
auto output_var_names =
BuildProgramFromOperation(op, &new_program, rewriter);

// execute program
for (auto output_var_name : output_var_names) {
exe_config_->skip_gc_vars.insert(output_var_name);
}
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(&new_program, place_);
paddle::framework::InterpreterCore core(
place_, {}, kernel_program->block(), scope_, *exe_config_);

core.Run({});
auto output_var_names = RunOp(op, rewriter, place_);

// ParameterOp and ConstantTensorOp should be created in the top-level block
rewriter.SetInsertionPointToStart(
Expand Down Expand Up @@ -236,6 +223,28 @@ class ConstantFoldingPattern : public pir::RewritePattern {
return true;
}

protected:
// Evaluates `op` in isolation on `place` and returns the names of the
// variables that hold its outputs.
//
// A temporary single-op program is built with BuildProgramFromOperation,
// lowered with PdOpLowerToKernelPass and executed by an InterpreterCore that
// is bound to scope_ and configured by *exe_config_.
//
// @param op        operation to evaluate.
// @param rewriter  supplies the IrContext for the temporary program and is
//                  forwarded to BuildProgramFromOperation.
// @param place     device on which the lowered program is run.
// @return the output variable names produced by BuildProgramFromOperation.
std::vector<std::string> RunOp(
    pir::Operation* op,
    pir::PatternRewriter& rewriter,  // NOLINT
    phi::Place place) const {        // NOLINT
  pir::Program new_program(rewriter.ir_context());
  auto output_var_names =
      BuildProgramFromOperation(op, &new_program, rewriter);

  // Register every output in skip_gc_vars before running (presumably so the
  // interpreter keeps the results alive in scope_ -- confirm). Bind by
  // const reference: the names are only read, so no per-iteration copy.
  for (const auto& output_var_name : output_var_names) {
    exe_config_->skip_gc_vars.insert(output_var_name);
  }

  // execute program
  auto kernel_program =
      paddle::dialect::PdOpLowerToKernelPass(&new_program, place);
  paddle::framework::InterpreterCore core(
      place, {}, kernel_program->block(), scope_, *exe_config_);

  core.Run({});
  return output_var_names;
}

std::vector<std::string> BuildProgramFromOperation(
pir::Operation* op,
pir::Program* new_program,
Expand Down Expand Up @@ -299,14 +308,76 @@ class ConstantFoldingPattern : public pir::RewritePattern {
return output_var_names;
}

private:
protected:
size_t* counter_;
phi::Place place_;
paddle::framework::Scope* scope_;
paddle::framework::interpreter::ExecutionConfig* exe_config_;
std::vector<std::string>* deleted_vars_;
};

// Constant-folding rewrite pattern used when the pass runs in train mode.
//
// On top of ConstantFoldingPattern::Match it requires every operand of the
// candidate op to be produced by a pir::ConstantTensorOp. A matched op is
// evaluated on CPU through RunOp, and each of its typed results is replaced
// by a persistable pir::ConstantTensorOp created in the top-level block.
class ConstantFoldingTrainingPattern : public ConstantFoldingPattern {
 public:
  // All arguments are forwarded unchanged to ConstantFoldingPattern.
  ConstantFoldingTrainingPattern(
      pir::IrContext* context,
      size_t* counter,
      const phi::Place& place,
      paddle::framework::Scope* scope,
      paddle::framework::interpreter::ExecutionConfig* exe_config,
      std::vector<std::string>* deleted_vars)
      : ConstantFoldingPattern(
            context, counter, place, scope, exe_config, deleted_vars) {}

  // Returns true when the base pattern matches `op` and every operand of
  // `op` is defined by a pir::ConstantTensorOp.
  bool Match(pir::Operation* op) const override {
    VLOG(4) << "constant_folding_training_pass applys match on [" << op->name()
            << "] op";
    if (!ConstantFoldingPattern::Match(op)) {
      return false;
    }
    for (uint32_t i = 0; i < op->num_operands(); i++) {
      // every input must come from a constant op
      auto* prev_op = pir::GetDefiningOpForInput(op, i);
      if (!prev_op || !prev_op->isa<pir::ConstantTensorOp>()) {
        return false;
      }
    }
    return true;
  }

  // Evaluates `op` on CPU, replaces each typed result with a persistable
  // ConstantTensorOp named after the computed variable, then erases `op`.
  void Rewrite(pir::Operation* op,
               pir::PatternRewriter& rewriter) const override {  // NOLINT
    // Capture the name up front: `op` must not be dereferenced once it has
    // been handed to EraseOp below, yet the final log line still needs it.
    const std::string op_name = op->name();
    VLOG(4) << "constant_folding_pass applys rewrite on [" << op_name
            << "] op";

    auto output_var_names = RunOp(op, rewriter, phi::CPUPlace{});

    // ConstantTensorOp should be created in the top-level block
    rewriter.SetInsertionPointToStart(
        rewriter.block()->parent_program()->block());

    for (uint32_t i = 0; i < op->num_results(); i++) {
      // Results without a value or a type are left untouched.
      if (!op->result(i) || !op->result(i).type()) {
        continue;
      }
      // NOTE(review): assumes output_var_names holds one entry per result,
      // indexed like op->result(i) -- confirm against
      // BuildProgramFromOperation.
      const auto& output_var_name = output_var_names[i];
      PADDLE_ENFORCE_NOT_NULL(
          scope_->FindVar(output_var_name),
          phi::errors::InvalidArgument("Parameter var [%s] not in scope.",
                                       output_var_name));

      auto constant_op = rewriter.Build<pir::ConstantTensorOp>(
          rewriter.tensor_name_attr(output_var_name), op->result(i).type());
      constant_op->set_attribute(
          kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)}));

      rewriter.ReplaceAllUsesWith(op->result(i), constant_op->result(0));
    }
    rewriter.EraseOp(op);
    VLOG(4) << "constant_folding_pass applied rewrite on [" << op_name
            << "] op";
  }
};

class ConstantFoldingPass : public pir::Pass {
public:
ConstantFoldingPass()
Expand All @@ -332,8 +403,14 @@ class ConstantFoldingPass : public pir::Pass {
scope_, phi::errors::InvalidArgument("scope can not be nullptr"));

pir::RewritePatternSet ps(context);
ps.Add<ConstantFoldingPattern>(
context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);

if (Has("train_mode") && Get<bool>("train_mode")) {
ps.Add<ConstantFoldingTrainingPattern>(
context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);
} else {
ps.Add<ConstantFoldingPattern>(
context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);
}
patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
return true;
}
Expand Down
31 changes: 29 additions & 2 deletions test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,10 @@ void BuildConstantFoldingProgram(pir::Program *program,
paddle::platform::DeviceContextPool::Instance().Get(
paddle::platform::CPUPlace());

auto op1 = builder.Build<pir::ParameterOp>("a", dense_tensor_dtype);
auto op2 = builder.Build<pir::ParameterOp>("b", dense_tensor_dtype);
auto op1 = builder.Build<pir::ConstantTensorOp>(builder.tensor_name_attr("a"),
dense_tensor_dtype);
auto op2 = builder.Build<pir::ConstantTensorOp>(builder.tensor_name_attr("b"),
dense_tensor_dtype);

auto op3 =
builder.Build<paddle::dialect::AddOp>(op1->result(0), op2->result(0));
Expand Down Expand Up @@ -493,6 +495,31 @@ TEST(constant_folding, ConstantFolding) {
EXPECT_EQ(program.block()->size(), 2u);
}

// Runs the constant-folding pass with "train_mode" enabled, followed by dead
// code elimination, on the program produced by BuildConstantFoldingProgram
// and checks how many ops remain in the top-level block.
TEST(constant_folding, ConstantFolding_Train) {
  // Both dialects have to be registered before the program is built.
  auto *ir_ctx = pir::IrContext::Instance();
  ir_ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
  ir_ctx->GetOrRegisterDialect<pir::BuiltinDialect>();

  pir::Program program(ir_ctx);
  paddle::framework::Scope param_scope;
  BuildConstantFoldingProgram(&program, ir_ctx, &param_scope);

  pir::PassManager pass_manager(ir_ctx);
  std::unique_ptr<pir::Pass> folding_pass = pir::CreateConstantFoldingPass();

  // The place and scope are attached as non-owned attributes, so both locals
  // must stay alive until the pass manager has finished running.
  phi::Place cpu_place = phi::CPUPlace();
  folding_pass->SetNotOwned(pir::kPlaceAttr, &cpu_place);
  folding_pass->SetNotOwned(pir::kParamScopeAttr, &param_scope);
  // NOTE(review): presumably Set() takes ownership of the heap-allocated flag
  // (in contrast to SetNotOwned) -- confirm against pir::Pass.
  folding_pass->Set("train_mode", new bool(true));

  pass_manager.AddPass(std::move(folding_pass));
  pass_manager.AddPass(pir::CreateDeadCodeEliminationPass());
  pass_manager.EnableIRPrinting();

  CHECK_EQ(pass_manager.Run(&program), true);
  EXPECT_EQ(program.block()->size(), 4u);
}

void BuildConcatProgram(pir::Program *program, pir::IrContext *ctx) {
pir::Builder builder = pir::Builder(ctx, program->block());
auto x = builder
Expand Down

0 comments on commit 16261c2

Please sign in to comment.