From 7dd5d036ed675a44ed0ddce5ea5f62a8394564f6 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Mon, 27 Mar 2023 09:11:23 +0000 Subject: [PATCH] add IRPrinter --- paddle/infra/CMakeLists.txt | 1 + paddle/infra/Pass/IRPrinting.cc | 183 ++++++++++++++++++++++++++++++++ paddle/infra/Pass/Pass.cc | 2 +- paddle/infra/Pass/Pass.h | 4 +- paddle/infra/Pass/PassManager.h | 81 ++++++++++++-- paddle/infra/test/demo.cc | 1 + 6 files changed, 261 insertions(+), 11 deletions(-) create mode 100644 paddle/infra/Pass/IRPrinting.cc diff --git a/paddle/infra/CMakeLists.txt b/paddle/infra/CMakeLists.txt index 8d907059c30c1..61eff21045169 100644 --- a/paddle/infra/CMakeLists.txt +++ b/paddle/infra/CMakeLists.txt @@ -53,6 +53,7 @@ add_library( pass_infra Pass/Pass.cc Pass/PassRegistry.cc + Pass/IRPrinting.cc IR/PatternMatch.cc Rewrite/FrozenRewritePatternSet.cc Rewrite/PatternApplicator.cc diff --git a/paddle/infra/Pass/IRPrinting.cc b/paddle/infra/Pass/IRPrinting.cc new file mode 100644 index 0000000000000..5497a3109e2e1 --- /dev/null +++ b/paddle/infra/Pass/IRPrinting.cc @@ -0,0 +1,183 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "llvm/Support/raw_ostream.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/SymbolTable.h" + +#include "Pass/PassInstrumentation.h" +#include "Pass/PassManager.h" +// #include "utils/xxhash.h" + +// #include "paddle/phi/core/enforce.h" + +namespace infra { +// A unique fingerprint for a specific operation, and all of it's internal +// operations. +// class IRFingerPrint { +// public: +// IRFingerPrint(mlir::Operation *top_op); + +// IRFingerPrint(const IRFingerPrint &) = default; +// IRFingerPrint &operator=(const IRFingerPrint &) = default; + +// bool operator==(const IRFingerPrint &other) const { +// return hash == other.hash; +// } + +// bool operator!=(const IRFingerPrint &other) const { +// return !(*this == other); +// } + +// private: +// XXH64_hash_t hash; +// }; + +// namespace { +// TODO(liuyuanle): XXH64_update has "Segmentation fault" bug need to be solved! +// template +// void UpdateHash(XXH64_state_t *state, const T &data) { +// XXH64_update(state, &data, sizeof(T)); +// } + +// template +// void UpdateHash(XXH64_state_t *state, T *data) { +// llvm::outs() << "here 21\n"; +// XXH64_update(state, &data, sizeof(T *)); +// } +// } // namespace + +// IRFingerPrint::IRFingerPrint(mlir::Operation *top_op) { +// XXH64_state_t *const state = XXH64_createState(); +// // PADDLE_ENFORCE_NOT_NULL( +// // state, +// // phi::errors::PreconditionNotMet( +// // "xxhash create state failed, maybe a environment error.")); + +// // PADDLE_ENFORCE_NE( +// // XXH64_reset(state, XXH64_hash_t(0)), +// // XXH_ERROR, +// // phi::errors::PreconditionNotMet( +// // "xxhash reset state failed, maybe a environment error.")); + +// // Hash each of the operations based upon their mutable bits: +// top_op->walk([&](mlir::Operation *op) { +// // - Operation pointer +// UpdateHash(state, op); +// // - Attributes +// UpdateHash(state, op->getAttrDictionary()); +// // - Blocks in Regions +// for (auto ®ion : op->getRegions()) { +// for (auto &block : region) { +// UpdateHash(state, &block); +// for (auto arg : block.getArguments()) { +// UpdateHash(state, arg); +// } +// } +// } +// // - Location +// UpdateHash(state, op->getLoc().getAsOpaquePointer()); +// // - Operands +// for (auto operand : op->getOperands()) { +// UpdateHash(state, operand); +// } +// // - Successors +// for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) { +// UpdateHash(state, op->getSuccessor(i)); +// } +// }); +// hash = XXH64_digest(state); +// XXH64_freeState(state); +// } + +namespace { +void PrintIR(mlir::Operation *op, + bool print_module, + llvm::raw_ostream &out, + mlir::OpPrintingFlags flags) { + // Otherwise, check to see if we are not printing at module scope. + if (print_module) { + op->print(out << "\n", flags); + return; + } + + // Otherwise, we are printing at module scope. + out << " ('" << op->getName() << "' operation"; + if (auto symbol_name = op->getAttrOfType( + mlir::SymbolTable::getSymbolAttrName())) + out << ": @" << symbol_name.getValue(); + out << ")\n"; + + // Find the top-level operation. + auto *top_level_op = op; + while (auto *parent_op = top_level_op->getParentOp()) { + top_level_op = parent_op; + } + top_level_op->print(out, flags); +} +} // namespace + +class IRPrinter : public PassInstrumentation { + public: + explicit IRPrinter(std::unique_ptr config) + : config_(std::move(config)){}; + + ~IRPrinter() = default; + + void RunBeforePass(Pass *pass, mlir::Operation *op) override { + if (config_->EnablePrintOnChange()) { + ir_fingerprints_.emplace(pass, op); + } + config_->PrintBeforeIfEnabled(pass, op, [&](llvm::raw_ostream &out) { + out << "// *** IR Dump Before " << pass->GetPassInfo().name << " ***"; + PrintIR( + op, config_->EnablePrintModule(), out, config_->GetOpPrintingFlags()); + out << "\n\n"; + }); + } + + void RunAfterPass(Pass *pass, mlir::Operation *op) override { + if (config_->EnablePrintOnChange()) { + const auto &fingerprint = ir_fingerprints_.at(pass); + if (fingerprint == mlir::OperationFingerPrint(op)) { + ir_fingerprints_.erase(pass); + return; + } + ir_fingerprints_.erase(pass); + } + + config_->PrintBeforeIfEnabled(pass, op, [&](llvm::raw_ostream &out) { + out << "// *** IR Dump After " << pass->GetPassInfo().name << " ***"; + PrintIR( + op, config_->EnablePrintModule(), out, config_->GetOpPrintingFlags()); + out << "\n\n"; + }); + } + + private: + std::unique_ptr config_; + + // TODO(liuyuanle): replace mlir::OperationFingerPrint with IRFingerPrint. + // Pass -> IR fingerprint before pass. + std::unordered_map ir_fingerprints_; +}; + +void PassManager::EnableIRPrinting(std::unique_ptr config) { + AddInstrumentation(std::make_unique(std::move(config))); +} + +} // namespace infra diff --git a/paddle/infra/Pass/Pass.cc b/paddle/infra/Pass/Pass.cc index 60220dabcf875..719480ce1ebcf 100644 --- a/paddle/infra/Pass/Pass.cc +++ b/paddle/infra/Pass/Pass.cc @@ -139,7 +139,7 @@ mlir::LogicalResult PassManager::Run(mlir::Operation* op) { init_key_ = new_init_key; } - // Construct a analysis manager for the pipeline. + // Construct a analysis manager for the pipeline. AnalysisManagerHolder am(op, instrumentor_.get()); bool crash_recovery = false; diff --git a/paddle/infra/Pass/Pass.h b/paddle/infra/Pass/Pass.h index a3ff95f49675e..e52d48ba3cf7a 100644 --- a/paddle/infra/Pass/Pass.h +++ b/paddle/infra/Pass/Pass.h @@ -32,7 +32,7 @@ namespace detail { class AdaptorPass; struct PassExecutionState { - explicit PassExecutionState(mlir::Operation* ir, AnalysisManager am) + explicit PassExecutionState(mlir::Operation* ir, const AnalysisManager& am) : ir(ir), pass_failed(false), am(am) {} mlir::Operation* ir; @@ -71,7 +71,7 @@ class Pass { const std::vector& dependents = {}) : info_(name, opt_level, dependents) {} - PassInfo GetPassInfo() const { return info_; } + const PassInfo& GetPassInfo() const { return info_; } std::unique_ptr Clone() const { return ClonePass(); } diff --git a/paddle/infra/Pass/PassManager.h b/paddle/infra/Pass/PassManager.h index 0b2114ea3acac..1c3b22fde7971 100644 --- a/paddle/infra/Pass/PassManager.h +++ b/paddle/infra/Pass/PassManager.h @@ -37,7 +37,6 @@ namespace infra { class PassInstrumentation; class PassInstrumentor; class AnalysisManager; -class PassManager; namespace detail { class AdaptorPass; @@ -51,7 +50,7 @@ class PassManager { public: ~PassManager(); - explicit PassManager(mlir::MLIRContext* context, int opt_level = 2); + explicit PassManager(mlir::MLIRContext *context, int opt_level = 2); using pass_iterator = llvm::pointee_iterator< llvm::MutableArrayRef>::iterator>; @@ -77,9 +76,9 @@ class PassManager { bool empty() const { return begin() == end(); } - mlir::MLIRContext* GetContext() const { return context_; } + mlir::MLIRContext *GetContext() const { return context_; } - mlir::LogicalResult Run(mlir::Operation* op); + mlir::LogicalResult Run(mlir::Operation *op); void addPass(std::unique_ptr pass) { passes_.emplace_back(std::move(pass)); @@ -87,18 +86,84 @@ class PassManager { void EnableTiming(); + class IRPrinterConfig { + public: + using PrintCallBack = std::function; + + explicit IRPrinterConfig( + const std::function + &enable_print_before = + [](Pass *, mlir::Operation *) { return true; }, + const std::function & + enable_print_after = [](Pass *, mlir::Operation *) { return true; }, + bool print_module = true, + bool print_on_change = true, + llvm::raw_ostream &out = llvm::outs(), + mlir::OpPrintingFlags op_printing_flags = mlir::OpPrintingFlags()) + : enable_print_before_(enable_print_before), + enable_print_after_(enable_print_after), + print_module_(print_module), + print_on_change_(print_on_change), + out_(out), + op_printing_flags_(op_printing_flags) { + assert((enable_print_before_ || enable_print_after_) && + "expected at least one valid filter function"); + } + + ~IRPrinterConfig() = default; + + void PrintBeforeIfEnabled(Pass *pass, + mlir::Operation *op, + const PrintCallBack &print_callback) { + if (enable_print_before_ && enable_print_before_(pass, op)) { + print_callback(out_); + } + } + + void PrintAfterIfEnabled(Pass *pass, + mlir::Operation *op, + const PrintCallBack &print_callback) { + if (enable_print_after_ && enable_print_after_(pass, op)) { + print_callback(out_); + } + } + + bool EnablePrintModule() const { return print_module_; } + + bool EnablePrintOnChange() const { return print_on_change_; } + + mlir::OpPrintingFlags GetOpPrintingFlags() const { + return op_printing_flags_; + } + + private: + std::function enable_print_before_; + std::function enable_print_after_; + + bool print_module_; + bool print_on_change_; + + // TODO(liuyuanle): Replace it with a local implementation. + // The stream to output to. + llvm::raw_ostream &out_; + // Flags to control printing behavior. + mlir::OpPrintingFlags op_printing_flags_; + }; + + void EnableIRPrinting(std::unique_ptr config); + void AddInstrumentation(std::unique_ptr pi); private: - mlir::LogicalResult RunPasses(mlir::Operation* op, AnalysisManager am); + mlir::LogicalResult RunPasses(mlir::Operation *op, AnalysisManager am); - mlir::LogicalResult RunWithCrashRecovery(mlir::Operation* op, + mlir::LogicalResult RunWithCrashRecovery(mlir::Operation *op, AnalysisManager am); - mlir::LogicalResult Initialize(mlir::MLIRContext* context); + mlir::LogicalResult Initialize(mlir::MLIRContext *context); private: - mlir::MLIRContext* context_; + mlir::MLIRContext *context_; std::unique_ptr instrumentor_; diff --git a/paddle/infra/test/demo.cc b/paddle/infra/test/demo.cc index 5ba2f9cb2951c..a027ba642639d 100644 --- a/paddle/infra/test/demo.cc +++ b/paddle/infra/test/demo.cc @@ -194,6 +194,7 @@ int main(int argc, char** argv) { infra::PassManager pm(&context, opt_level); pm.EnableTiming(); + pm.EnableIRPrinting(std::make_unique()); auto pass = std::make_unique(); pm.addPass(std::move(pass));