From 16261c221d449007c9cb67c2e0fd564bc3065113 Mon Sep 17 00:00:00 2001
From: zhangyuqin1998 <2368719370@qq.com>
Date: Thu, 28 Dec 2023 04:02:24 +0000
Subject: [PATCH 1/3] [pir]Supporting constant_folding_pass for train

---
 .../pir/transforms/constant_folding_pass.cc  | 111 +++++++++++++++---
 .../pattern_rewrite/pattern_rewrite_test.cc  |  31 ++++-
 2 files changed, 123 insertions(+), 19 deletions(-)

diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc
index 553cf3967dd68..9be3a0edf9324 100644
--- a/paddle/fluid/pir/transforms/constant_folding_pass.cc
+++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc
@@ -126,20 +126,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
                pir::PatternRewriter& rewriter) const override {  // NOLINT
     VLOG(4) << "constant_folding_pass applys rewrite on [" << op->name()
             << "] op";
-    pir::Program new_program(rewriter.ir_context());
-    auto output_var_names =
-        BuildProgramFromOperation(op, &new_program, rewriter);
-
-    // execute program
-    for (auto output_var_name : output_var_names) {
-      exe_config_->skip_gc_vars.insert(output_var_name);
-    }
-    auto kernel_program =
-        paddle::dialect::PdOpLowerToKernelPass(&new_program, place_);
-    paddle::framework::InterpreterCore core(
-        place_, {}, kernel_program->block(), scope_, *exe_config_);
-
-    core.Run({});
+    auto output_var_names = RunOp(op, rewriter, place_);
 
     // ParameterOp and ConstantTensorOp should be created in the top-level block
     rewriter.SetInsertionPointToStart(
@@ -236,6 +223,28 @@ class ConstantFoldingPattern : public pir::RewritePattern {
     return true;
   }
 
+ protected:
+  std::vector<std::string> RunOp(
+      pir::Operation* op,
+      pir::PatternRewriter& rewriter,
+      phi::Place place) const {  // NOLINT
+    pir::Program new_program(rewriter.ir_context());
+    auto output_var_names =
+        BuildProgramFromOperation(op, &new_program, rewriter);
+
+    // execute program
+    for (auto output_var_name : output_var_names) {
+      exe_config_->skip_gc_vars.insert(output_var_name);
+    }
+    auto kernel_program =
+        paddle::dialect::PdOpLowerToKernelPass(&new_program, place);
+    paddle::framework::InterpreterCore core(
+        place, {}, kernel_program->block(), scope_, *exe_config_);
+
+    core.Run({});
+    return output_var_names;
+  }
+
   std::vector<std::string> BuildProgramFromOperation(
       pir::Operation* op,
       pir::Program* new_program,
@@ -299,7 +308,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
     return output_var_names;
   }
 
- private:
+ protected:
   size_t* counter_;
   phi::Place place_;
   paddle::framework::Scope* scope_;
@@ -307,6 +316,68 @@ class ConstantFoldingPattern : public pir::RewritePattern {
   std::vector<std::string>* deleted_vars_;
 };
 
+class ConstantFoldingTrainingPattern : public ConstantFoldingPattern {
+ public:
+  ConstantFoldingTrainingPattern(
+      pir::IrContext* context,
+      size_t* counter,
+      const phi::Place& place,
+      paddle::framework::Scope* scope,
+      paddle::framework::interpreter::ExecutionConfig* exe_config,
+      std::vector<std::string>* deleted_vars)
+      : ConstantFoldingPattern(
+            context, counter, place, scope, exe_config, deleted_vars) {}
+
+  bool Match(pir::Operation* op) const override {
+    VLOG(4) << "constant_folding_training_pass applys match on [" << op->name()
+            << "] op";
+    if (!ConstantFoldingPattern::Match(op)) {
+      return false;
+    }
+    for (uint32_t i = 0; i < op->num_operands(); i++) {
+      // inputs must come from constant op
+      auto* prev_op = pir::GetDefiningOpForInput(op, i);
+      if (!prev_op || !prev_op->isa<pir::ConstantTensorOp>()) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  void Rewrite(pir::Operation* op,
+               pir::PatternRewriter& rewriter) const override {  // NOLINT
+    VLOG(4) << "constant_folding_pass applys rewrite on [" << op->name()
+            << "] op";
+
+    auto output_var_names = RunOp(op, rewriter, phi::CPUPlace{});
+
+    // ConstantTensorOp should be created in the top-level block
+    rewriter.SetInsertionPointToStart(
+        rewriter.block()->parent_program()->block());
+
+    for (uint32_t i = 0; i < op->num_results(); i++) {
+      if (!op->result(i) || !op->result(i).type()) {
+        continue;
+      }
+      std::string output_var_name = output_var_names[i];
+      PADDLE_ENFORCE_NOT_NULL(
+          scope_->FindVar(output_var_name),
+          phi::errors::InvalidArgument("Parameter var [%s] not in scope.",
+                                       output_var_name));
+
+      auto constant_op = rewriter.Build<pir::ConstantTensorOp>(
+          rewriter.tensor_name_attr(output_var_name), op->result(i).type());
+      constant_op->set_attribute(
+          kAttrIsPersisable, rewriter.array_attr({rewriter.bool_attr(true)}));
+
+      rewriter.ReplaceAllUsesWith(op->result(i), constant_op->result(0));
+    }
+    rewriter.EraseOp(op);
+    VLOG(4) << "constant_folding_pass applied rewrite on [" << op->name()
+            << "] op";
+  }
+};
+
 class ConstantFoldingPass : public pir::Pass {
  public:
   ConstantFoldingPass()
@@ -332,8 +403,14 @@ class ConstantFoldingPass : public pir::Pass {
         scope_, phi::errors::InvalidArgument("scope can not be nullptr"));
 
     pir::RewritePatternSet ps(context);
-    ps.Add<ConstantFoldingPattern>(
-        context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);
+
+    if (Has("train_mode") && Get<bool>("train_mode")) {
+      ps.Add<ConstantFoldingTrainingPattern>(
+          context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);
+    } else {
+      ps.Add<ConstantFoldingPattern>(
+          context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);
+    }
     patterns_ = pir::FrozenRewritePatternSet(std::move(ps));
     return true;
   }
diff --git a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
index 93156a9d697ce..1a87247dab35b 100644
--- a/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
+++ b/test/cpp/pir/pattern_rewrite/pattern_rewrite_test.cc
@@ -445,8 +445,10 @@ void BuildConstantFoldingProgram(pir::Program *program,
       paddle::platform::DeviceContextPool::Instance().Get(
          paddle::platform::CPUPlace());
 
-  auto op1 = builder.Build<pir::ParameterOp>("a", dense_tensor_dtype);
-  auto op2 = builder.Build<pir::ParameterOp>("b", dense_tensor_dtype);
+  auto op1 = builder.Build<pir::ConstantTensorOp>(
+      builder.tensor_name_attr("a"), dense_tensor_dtype);
+  auto op2 = builder.Build<pir::ConstantTensorOp>(
+      builder.tensor_name_attr("b"), dense_tensor_dtype);
 
   auto op3 =
       builder.Build(op1->result(0), op2->result(0));
@@ -493,6 +495,31 @@ TEST(constant_folding, ConstantFolding) {
   EXPECT_EQ(program.block()->size(), 2u);
 }
 
+TEST(constant_folding, ConstantFolding_Train) {
+  pir::IrContext *ctx = pir::IrContext::Instance();
+  ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
+  ctx->GetOrRegisterDialect<pir::BuiltinDialect>();
+
+  pir::Program program(ctx);
+  paddle::framework::Scope scope;
+  BuildConstantFoldingProgram(&program, ctx, &scope);
+
+  pir::PassManager pm(ctx);
+  std::unique_ptr<pir::Pass> constant_folding_pass =
+      pir::CreateConstantFoldingPass();
+  phi::Place place = phi::CPUPlace();
+  constant_folding_pass->SetNotOwned(pir::kPlaceAttr, &place);
+  constant_folding_pass->SetNotOwned(pir::kParamScopeAttr, &scope);
+  constant_folding_pass->Set("train_mode", new bool(true));
+
+  pm.AddPass(std::move(constant_folding_pass));
+  pm.AddPass(pir::CreateDeadCodeEliminationPass());
+  pm.EnableIRPrinting();
+
+  CHECK_EQ(pm.Run(&program), true);
+  EXPECT_EQ(program.block()->size(), 4u);
+}
+
 void BuildConcatProgram(pir::Program *program, pir::IrContext *ctx) {
   pir::Builder builder = pir::Builder(ctx, program->block());
   auto x = builder

From b0cd83bda3b8a14607bacf5215bb6113dd40d782 Mon Sep 17 00:00:00 2001
From: zhangyuqin1998 <2368719370@qq.com>
Date: Thu, 28 Dec 2023 13:16:55 +0000
Subject: [PATCH 2/3] fix

---
 .../pir/transforms/constant_folding_pass.cc | 29 +++++++++----------
 1 file changed, 14 insertions(+), 15 deletions(-)

diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc
index 9be3a0edf9324..33c74224dda4f 100644
--- a/paddle/fluid/pir/transforms/constant_folding_pass.cc
+++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc
@@ -126,7 +126,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
                pir::PatternRewriter& rewriter) const override {  // NOLINT
     VLOG(4) << "constant_folding_pass applys rewrite on [" << op->name()
             << "] op";
-    auto output_var_names = RunOp(op, rewriter, place_);
+    auto output_var_names = RunOp(op, rewriter);
 
     // ParameterOp and ConstantTensorOp should be created in the top-level block
     rewriter.SetInsertionPointToStart(
@@ -226,8 +226,7 @@ class ConstantFoldingPattern : public pir::RewritePattern {
  protected:
   std::vector<std::string> RunOp(
       pir::Operation* op,
-      pir::PatternRewriter& rewriter,
-      phi::Place place) const {  // NOLINT
+      pir::PatternRewriter& rewriter) const {  // NOLINT
     pir::Program new_program(rewriter.ir_context());
     auto output_var_names =
         BuildProgramFromOperation(op, &new_program, rewriter);
@@ -237,9 +236,9 @@ class ConstantFoldingPattern : public pir::RewritePattern {
       exe_config_->skip_gc_vars.insert(output_var_name);
     }
     auto kernel_program =
-        paddle::dialect::PdOpLowerToKernelPass(&new_program, place);
+        paddle::dialect::PdOpLowerToKernelPass(&new_program, place_);
     paddle::framework::InterpreterCore core(
-        place, {}, kernel_program->block(), scope_, *exe_config_);
+        place_, {}, kernel_program->block(), scope_, *exe_config_);
 
     core.Run({});
     return output_var_names;
@@ -316,9 +315,9 @@ class ConstantFoldingPattern : public pir::RewritePattern {
   std::vector<std::string>* deleted_vars_;
 };
 
-class ConstantFoldingTrainingPattern : public ConstantFoldingPattern {
+class ConstantFoldingPatternForTrain : public ConstantFoldingPattern {
  public:
-  ConstantFoldingTrainingPattern(
+  ConstantFoldingPatternForTrain(
       pir::IrContext* context,
       size_t* counter,
       const phi::Place& place,
@@ -329,7 +328,7 @@ class ConstantFoldingTrainingPattern : public ConstantFoldingPattern {
             context, counter, place, scope, exe_config, deleted_vars) {}
 
   bool Match(pir::Operation* op) const override {
-    VLOG(4) << "constant_folding_training_pass applys match on [" << op->name()
+    VLOG(4) << "constant_folding_pass applys match on [" << op->name()
             << "] op";
     if (!ConstantFoldingPattern::Match(op)) {
       return false;
     }
@@ -346,10 +345,10 @@ class ConstantFoldingTrainingPattern : public ConstantFoldingPattern {
 
   void Rewrite(pir::Operation* op,
                pir::PatternRewriter& rewriter) const override {  // NOLINT
-    VLOG(4) << "constant_folding_pass applys rewrite on [" << op->name()
-            << "] op";
+    VLOG(4) << "constant_folding_pass for train applys rewrite on ["
+            << op->name() << "] op";
 
-    auto output_var_names = RunOp(op, rewriter, phi::CPUPlace{});
+    auto output_var_names = RunOp(op, rewriter);
 
     // ConstantTensorOp should be created in the top-level block
     rewriter.SetInsertionPointToStart(
@@ -373,8 +372,8 @@ class ConstantFoldingPatternForTrain : public ConstantFoldingPattern {
       rewriter.ReplaceAllUsesWith(op->result(i), constant_op->result(0));
     }
     rewriter.EraseOp(op);
-    VLOG(4) << "constant_folding_pass applied rewrite on [" << op->name()
-            << "] op";
+    VLOG(4) << "constant_folding_pass for train applied rewrite on ["
"constant_folding_pass for traun applied rewrite on [" + << op->name() << "] op"; } }; @@ -405,8 +404,8 @@ class ConstantFoldingPass : public pir::Pass { pir::RewritePatternSet ps(context); if (Has("train_mode") && Get("train_mode")) { - ps.Add( - context, &counter_, place_, scope_, &exe_config_, &deleted_vars_); + ps.Add( + context, &counter_, phi::CPUPlace{}, scope_, &exe_config_, &deleted_vars_); } else { ps.Add( context, &counter_, place_, scope_, &exe_config_, &deleted_vars_); From 5e51af32bf6945a39873417ba9516ef6c798e439 Mon Sep 17 00:00:00 2001 From: zhangyuqin1998 <75946871+zhangyuqin1998@users.noreply.github.com> Date: Thu, 28 Dec 2023 22:36:04 +0800 Subject: [PATCH 3/3] Update constant_folding_pass.cc --- paddle/fluid/pir/transforms/constant_folding_pass.cc | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pir/transforms/constant_folding_pass.cc b/paddle/fluid/pir/transforms/constant_folding_pass.cc index 33c74224dda4f..620a7c1c2fecc 100644 --- a/paddle/fluid/pir/transforms/constant_folding_pass.cc +++ b/paddle/fluid/pir/transforms/constant_folding_pass.cc @@ -404,8 +404,12 @@ class ConstantFoldingPass : public pir::Pass { pir::RewritePatternSet ps(context); if (Has("train_mode") && Get("train_mode")) { - ps.Add( - context, &counter_, phi::CPUPlace{}, scope_, &exe_config_, &deleted_vars_); + ps.Add(context, + &counter_, + phi::CPUPlace{}, + scope_, + &exe_config_, + &deleted_vars_); } else { ps.Add( context, &counter_, place_, scope_, &exe_config_, &deleted_vars_);