Skip to content

Commit

Permalink
Merge pull request #2 from yuanlehome/pass
Browse files Browse the repository at this point in the history
add IRPrinter
  • Loading branch information
jiweibo committed Mar 27, 2023
2 parents 50747e6 + 7dd5d03 commit a45a037
Show file tree
Hide file tree
Showing 6 changed files with 261 additions and 11 deletions.
1 change: 1 addition & 0 deletions paddle/infra/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ add_library(
pass_infra
Pass/Pass.cc
Pass/PassRegistry.cc
Pass/IRPrinting.cc
IR/PatternMatch.cc
Rewrite/FrozenRewritePatternSet.cc
Rewrite/PatternApplicator.cc
Expand Down
183 changes: 183 additions & 0 deletions paddle/infra/Pass/IRPrinting.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <unordered_map>

#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/SymbolTable.h"

#include "Pass/PassInstrumentation.h"
#include "Pass/PassManager.h"
// #include "utils/xxhash.h"

// #include "paddle/phi/core/enforce.h"

namespace infra {
// A unique fingerprint for a specific operation, and all of it's internal
// operations.
// class IRFingerPrint {
// public:
// IRFingerPrint(mlir::Operation *top_op);

// IRFingerPrint(const IRFingerPrint &) = default;
// IRFingerPrint &operator=(const IRFingerPrint &) = default;

// bool operator==(const IRFingerPrint &other) const {
// return hash == other.hash;
// }

// bool operator!=(const IRFingerPrint &other) const {
// return !(*this == other);
// }

// private:
// XXH64_hash_t hash;
// };

// namespace {
// TODO(liuyuanle): XXH64_update has "Segmentation fault" bug need to be solved!
// template <typename T>
// void UpdateHash(XXH64_state_t *state, const T &data) {
// XXH64_update(state, &data, sizeof(T));
// }

// template <typename T>
// void UpdateHash(XXH64_state_t *state, T *data) {
// llvm::outs() << "here 21\n";
// XXH64_update(state, &data, sizeof(T *));
// }
// } // namespace

// IRFingerPrint::IRFingerPrint(mlir::Operation *top_op) {
// XXH64_state_t *const state = XXH64_createState();
// // PADDLE_ENFORCE_NOT_NULL(
// // state,
// // phi::errors::PreconditionNotMet(
// // "xxhash create state failed, maybe a environment error."));

// // PADDLE_ENFORCE_NE(
// // XXH64_reset(state, XXH64_hash_t(0)),
// // XXH_ERROR,
// // phi::errors::PreconditionNotMet(
// // "xxhash reset state failed, maybe a environment error."));

// // Hash each of the operations based upon their mutable bits:
// top_op->walk([&](mlir::Operation *op) {
// // - Operation pointer
// UpdateHash(state, op);
// // - Attributes
// UpdateHash(state, op->getAttrDictionary());
// // - Blocks in Regions
// for (auto &region : op->getRegions()) {
// for (auto &block : region) {
// UpdateHash(state, &block);
// for (auto arg : block.getArguments()) {
// UpdateHash(state, arg);
// }
// }
// }
// // - Location
// UpdateHash(state, op->getLoc().getAsOpaquePointer());
// // - Operands
// for (auto operand : op->getOperands()) {
// UpdateHash(state, operand);
// }
// // - Successors
// for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) {
// UpdateHash(state, op->getSuccessor(i));
// }
// });
// hash = XXH64_digest(state);
// XXH64_freeState(state);
// }

namespace {
void PrintIR(mlir::Operation *op,
bool print_module,
llvm::raw_ostream &out,
mlir::OpPrintingFlags flags) {
// Otherwise, check to see if we are not printing at module scope.
if (print_module) {
op->print(out << "\n", flags);
return;
}

// Otherwise, we are printing at module scope.
out << " ('" << op->getName() << "' operation";
if (auto symbol_name = op->getAttrOfType<mlir::StringAttr>(
mlir::SymbolTable::getSymbolAttrName()))
out << ": @" << symbol_name.getValue();
out << ")\n";

// Find the top-level operation.
auto *top_level_op = op;
while (auto *parent_op = top_level_op->getParentOp()) {
top_level_op = parent_op;
}
top_level_op->print(out, flags);
}
} // namespace

class IRPrinter : public PassInstrumentation {
public:
explicit IRPrinter(std::unique_ptr<PassManager::IRPrinterConfig> config)
: config_(std::move(config)){};

~IRPrinter() = default;

void RunBeforePass(Pass *pass, mlir::Operation *op) override {
if (config_->EnablePrintOnChange()) {
ir_fingerprints_.emplace(pass, op);
}
config_->PrintBeforeIfEnabled(pass, op, [&](llvm::raw_ostream &out) {
out << "// *** IR Dump Before " << pass->GetPassInfo().name << " ***";
PrintIR(
op, config_->EnablePrintModule(), out, config_->GetOpPrintingFlags());
out << "\n\n";
});
}

void RunAfterPass(Pass *pass, mlir::Operation *op) override {
if (config_->EnablePrintOnChange()) {
const auto &fingerprint = ir_fingerprints_.at(pass);
if (fingerprint == mlir::OperationFingerPrint(op)) {
ir_fingerprints_.erase(pass);
return;
}
ir_fingerprints_.erase(pass);
}

config_->PrintBeforeIfEnabled(pass, op, [&](llvm::raw_ostream &out) {
out << "// *** IR Dump After " << pass->GetPassInfo().name << " ***";
PrintIR(
op, config_->EnablePrintModule(), out, config_->GetOpPrintingFlags());
out << "\n\n";
});
}

private:
std::unique_ptr<PassManager::IRPrinterConfig> config_;

// TODO(liuyuanle): replace mlir::OperationFingerPrint with IRFingerPrint.
// Pass -> IR fingerprint before pass.
std::unordered_map<Pass *, mlir::OperationFingerPrint> ir_fingerprints_;
};

void PassManager::EnableIRPrinting(std::unique_ptr<IRPrinterConfig> config) {
AddInstrumentation(std::make_unique<IRPrinter>(std::move(config)));
}

} // namespace infra
2 changes: 1 addition & 1 deletion paddle/infra/Pass/Pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ mlir::LogicalResult PassManager::Run(mlir::Operation* op) {
init_key_ = new_init_key;
}

// Construct a analysis manager for the pipeline.
// Construct a analysis manager for the pipeline.
AnalysisManagerHolder am(op, instrumentor_.get());

bool crash_recovery = false;
Expand Down
4 changes: 2 additions & 2 deletions paddle/infra/Pass/Pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace detail {
class AdaptorPass;

struct PassExecutionState {
explicit PassExecutionState(mlir::Operation* ir, AnalysisManager am)
explicit PassExecutionState(mlir::Operation* ir, const AnalysisManager& am)
: ir(ir), pass_failed(false), am(am) {}

mlir::Operation* ir;
Expand Down Expand Up @@ -71,7 +71,7 @@ class Pass {
const std::vector<std::string>& dependents = {})
: info_(name, opt_level, dependents) {}

PassInfo GetPassInfo() const { return info_; }
const PassInfo& GetPassInfo() const { return info_; }

std::unique_ptr<Pass> Clone() const { return ClonePass(); }

Expand Down
81 changes: 73 additions & 8 deletions paddle/infra/Pass/PassManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ namespace infra {
class PassInstrumentation;
class PassInstrumentor;
class AnalysisManager;
class PassManager;

namespace detail {
class AdaptorPass;
Expand All @@ -51,7 +50,7 @@ class PassManager {
public:
~PassManager();

explicit PassManager(mlir::MLIRContext* context, int opt_level = 2);
explicit PassManager(mlir::MLIRContext *context, int opt_level = 2);

using pass_iterator = llvm::pointee_iterator<
llvm::MutableArrayRef<std::unique_ptr<Pass>>::iterator>;
Expand All @@ -77,28 +76,94 @@ class PassManager {

bool empty() const { return begin() == end(); }

mlir::MLIRContext* GetContext() const { return context_; }
mlir::MLIRContext *GetContext() const { return context_; }

mlir::LogicalResult Run(mlir::Operation* op);
mlir::LogicalResult Run(mlir::Operation *op);

void addPass(std::unique_ptr<Pass> pass) {
passes_.emplace_back(std::move(pass));
}

void EnableTiming();

class IRPrinterConfig {
public:
using PrintCallBack = std::function<void(llvm::raw_ostream &)>;

explicit IRPrinterConfig(
const std::function<bool(Pass *, mlir::Operation *)>
&enable_print_before =
[](Pass *, mlir::Operation *) { return true; },
const std::function<bool(Pass *, mlir::Operation *)> &
enable_print_after = [](Pass *, mlir::Operation *) { return true; },
bool print_module = true,
bool print_on_change = true,
llvm::raw_ostream &out = llvm::outs(),
mlir::OpPrintingFlags op_printing_flags = mlir::OpPrintingFlags())
: enable_print_before_(enable_print_before),
enable_print_after_(enable_print_after),
print_module_(print_module),
print_on_change_(print_on_change),
out_(out),
op_printing_flags_(op_printing_flags) {
assert((enable_print_before_ || enable_print_after_) &&
"expected at least one valid filter function");
}

~IRPrinterConfig() = default;

void PrintBeforeIfEnabled(Pass *pass,
mlir::Operation *op,
const PrintCallBack &print_callback) {
if (enable_print_before_ && enable_print_before_(pass, op)) {
print_callback(out_);
}
}

void PrintAfterIfEnabled(Pass *pass,
mlir::Operation *op,
const PrintCallBack &print_callback) {
if (enable_print_after_ && enable_print_after_(pass, op)) {
print_callback(out_);
}
}

bool EnablePrintModule() const { return print_module_; }

bool EnablePrintOnChange() const { return print_on_change_; }

mlir::OpPrintingFlags GetOpPrintingFlags() const {
return op_printing_flags_;
}

private:
std::function<bool(Pass *, mlir::Operation *)> enable_print_before_;
std::function<bool(Pass *, mlir::Operation *)> enable_print_after_;

bool print_module_;
bool print_on_change_;

// TODO(liuyuanle): Replace it with a local implementation.
// The stream to output to.
llvm::raw_ostream &out_;
// Flags to control printing behavior.
mlir::OpPrintingFlags op_printing_flags_;
};

void EnableIRPrinting(std::unique_ptr<IRPrinterConfig> config);

void AddInstrumentation(std::unique_ptr<PassInstrumentation> pi);

private:
mlir::LogicalResult RunPasses(mlir::Operation* op, AnalysisManager am);
mlir::LogicalResult RunPasses(mlir::Operation *op, AnalysisManager am);

mlir::LogicalResult RunWithCrashRecovery(mlir::Operation* op,
mlir::LogicalResult RunWithCrashRecovery(mlir::Operation *op,
AnalysisManager am);

mlir::LogicalResult Initialize(mlir::MLIRContext* context);
mlir::LogicalResult Initialize(mlir::MLIRContext *context);

private:
mlir::MLIRContext* context_;
mlir::MLIRContext *context_;

std::unique_ptr<PassInstrumentor> instrumentor_;

Expand Down
1 change: 1 addition & 0 deletions paddle/infra/test/demo.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ int main(int argc, char** argv) {

infra::PassManager pm(&context, opt_level);
pm.EnableTiming();
pm.EnableIRPrinting(std::make_unique<infra::PassManager::IRPrinterConfig>());
auto pass = std::make_unique<TestPass>();
pm.addPass(std::move(pass));

Expand Down

0 comments on commit a45a037

Please sign in to comment.