From 53437b13d95adeee2a4e40330c426d8401354a0c Mon Sep 17 00:00:00 2001
From: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Date: Tue, 4 Apr 2023 02:33:19 +0000
Subject: [PATCH] adds the dense analysis

---
 .../infra/Analysis/DataFlow/DenseAnalysis.cc  | 141 ++++++++++++++++++
 .../infra/Analysis/DataFlow/DenseAnalysis.h   |  99 ++++++++++++
 paddle/infra/Analysis/DataFlow/Framework.cc   |   7 +-
 paddle/infra/Analysis/DataFlow/Framework.h    | 129 +++++++++++++---
 paddle/infra/CMakeLists.txt                   |   3 +-
 5 files changed, 355 insertions(+), 24 deletions(-)
 create mode 100644 paddle/infra/Analysis/DataFlow/DenseAnalysis.cc
 create mode 100644 paddle/infra/Analysis/DataFlow/DenseAnalysis.h

diff --git a/paddle/infra/Analysis/DataFlow/DenseAnalysis.cc b/paddle/infra/Analysis/DataFlow/DenseAnalysis.cc
new file mode 100644
index 0000000000000..b6c9ecabc7ba2
--- /dev/null
+++ b/paddle/infra/Analysis/DataFlow/DenseAnalysis.cc
@@ -0,0 +1,141 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "Analysis/DataFlow/DenseAnalysis.h"
+
+namespace infra {
+namespace dataflow {
+
+bool AbstractDenseAnalysis::Initialize(Operation* top) {
+  VisitOperation(top);
+  bool ret = true;
+  for (auto& region : top->getRegions()) {
+    for (auto& block : region) {
+      VisitBlock(&block);
+      for (auto& op : block) {
+        ret = ret && Initialize(&op);
+      }
+    }
+  }
+  return ret;
+}
+
+bool AbstractDenseAnalysis::Visit(ProgramPoint point) {
+  if (auto* op = point.dyn_cast<Operation*>()) {
+    VisitOperation(op);
+  } else if (auto* block = point.dyn_cast<Block*>()) {
+    VisitBlock(block);
+  } else {
+    return false;
+  }
+  return true;
+}
+
+void AbstractDenseAnalysis::VisitOperation(Operation* op) {
+  if (auto branch = ::mlir::dyn_cast<::mlir::RegionBranchOpInterface>(op)) {
+    VisitRegionBranchOperation(op, branch);
+  } else if (auto call = ::mlir::dyn_cast<::mlir::CallOpInterface>(op)) {
+    VisitCallOperation(op, call);
+  } else {
+    const AbstractDenseLattice* before;
+    if (auto* prev = op->getPrevNode()) {
+      before = GetLatticeFor(op, prev);
+    } else if (auto* prev = op->getBlock()) {
+      before = GetLatticeFor(op, prev);
+    }
+    VisitOperationImpl(op, *before, GetLattice(op));
+  }
+}
+
+void AbstractDenseAnalysis::VisitBlock(Block* block) {
+  if (block->isEntryBlock()) {
+    if (auto callable =
+            ::mlir::dyn_cast<CallableOpInterface>(block->getParentOp())) {
+      VisitCallableOperation(block, callable);
+    } else if (auto branch = ::mlir::dyn_cast<RegionBranchOpInterface>(
+                   block->getParentOp())) {
+      VisitRegionBranchOperation(block, branch);
+    } else {
+      SetToEntryState(GetLattice(block));
+    }
+  } else {
+    for (auto it = block->pred_begin(); it != block->pred_end(); ++it) {
+      Block* pred = *it;
+      Operation* terminator = pred->getTerminator();
+      Join(GetLattice(block), *GetLatticeFor(block, terminator));
+    }
+  }
+}
+
+void AbstractDenseAnalysis::VisitRegionBranchOperation(
+    ProgramPoint point, RegionBranchOpInterface branch) {
+  auto* after = GetLattice(point);
+  const auto* predecessors = GetOrCreateFor<PredecessorState>(point, point);
+  assert(predecessors->allPredecessorsKnown());
+  for (Operation* op : predecessors->getKnownPredecessors()) {
+    const AbstractDenseLattice* before;
+    if (op == branch) {
+      if (auto* prev = op->getPrevNode()) {
+        before = GetLatticeFor(op, prev);
+      } else if (auto* prev = op->getBlock()) {
+        before = GetLatticeFor(op, prev);
+      }
+    } else {
+      before = GetLatticeFor(point, op);
+    }
+    Join(after, *before);
+  }
+}
+
+void AbstractDenseAnalysis::VisitCallOperation(ProgramPoint op,
+                                               CallOpInterface call) {
+  auto* after = GetLattice(op);
+  const auto* predecessors = GetOrCreateFor<PredecessorState>(op, call);
+  if (!predecessors->allPredecessorsKnown()) {
+    SetToEntryState(after);
+    return;
+  }
+  for (auto* predecessor : predecessors->getKnownPredecessors()) {
+    Join(after, *GetLatticeFor(op, predecessor));
+  }
+}
+
+void AbstractDenseAnalysis::VisitCallableOperation(
+    ProgramPoint block, CallableOpInterface callable) {
+  auto* after = GetLattice(block);
+  assert(callable.getCallableRegion() == block.get<Block*>()->getParent());
+  const auto* callsites = GetOrCreateFor<PredecessorState>(block, callable);
+  if (!callsites->allPredecessorsKnown()) {
+    return SetToEntryState(after);
+  }
+  for (Operation* op : callsites->getKnownPredecessors()) {
+    const AbstractDenseLattice* before;
+    if (auto* prev = op->getPrevNode()) {
+      before = GetLatticeFor(op, prev);
+    } else if (auto* prev = op->getBlock()) {
+      before = GetLatticeFor(op, prev);
+    }
+    Join(after, *before);
+  }
+}
+
+const AbstractDenseLattice* AbstractDenseAnalysis::GetLatticeFor(
+    ProgramPoint dependent, ProgramPoint point) {
+  AbstractDenseLattice* state = GetLattice(point);
+  AddDependency(state, dependent);
+  return state;
+}
+
+}  // namespace dataflow
+}  // namespace infra
diff --git a/paddle/infra/Analysis/DataFlow/DenseAnalysis.h b/paddle/infra/Analysis/DataFlow/DenseAnalysis.h
new file mode 100644
index 0000000000000..3787e5e2eb6b6
--- /dev/null
+++ b/paddle/infra/Analysis/DataFlow/DenseAnalysis.h
@@ -0,0 +1,99 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "Analysis/DataFlow/Framework.h"
+#include "mlir/Interfaces/CallInterfaces.h"
+#include "mlir/Interfaces/ControlFlowInterfaces.h"
+
+namespace infra {
+
+class RegionBranchOpInterface;
+
+namespace dataflow {
+
+// A dense lattice is attached to operations to represent the program
+// state after their execution, and to blocks to represent the program
+// state at the beginning of the block. It is propagated through the analysis.
+class AbstractDenseLattice : public AnalysisState {
+ public:
+  using AnalysisState::AnalysisState;
+
+  virtual ChangeStatus Join(const AbstractDenseLattice& rhs) = 0;
+};
+
+// Base class of dense data-flow analyses. It implements the transfer
+// functions that propagate the lattice between operations and blocks.
+class AbstractDenseAnalysis : public DataFlowAnalysis {
+ public:
+  using DataFlowAnalysis::DataFlowAnalysis;
+  using Operation = ::mlir::Operation;
+  using Block = ::mlir::Block;
+  using RegionBranchOpInterface = ::mlir::RegionBranchOpInterface;
+  using CallOpInterface = ::mlir::CallOpInterface;
+  using CallableOpInterface = ::mlir::CallableOpInterface;
+
+  // Traverses every operation and block and initializes them.
+  bool Initialize(Operation* top) override;
+
+  // Visits a program point and modifies the program state.
+  bool Visit(ProgramPoint point) override;
+
+ protected:
+  virtual void VisitOperationImpl(Operation* op,
+                                  const AbstractDenseLattice& before,
+                                  AbstractDenseLattice* after) = 0;
+
+  virtual AbstractDenseLattice* GetLattice(ProgramPoint point) = 0;
+
+  virtual void SetToEntryState(AbstractDenseLattice* lattice) = 0;
+
+  const AbstractDenseLattice* GetLatticeFor(ProgramPoint dependent,
+                                            ProgramPoint point);
+
+  void Join(AbstractDenseLattice* lhs, const AbstractDenseLattice& rhs) {
+    PropagateIfChanged(lhs, lhs->Join(rhs));
+  }
+
+ protected:
+  // If the operation is a call or carries regions, the state is set by
+  // control flow. Otherwise the transfer function is invoked.
+  virtual void VisitOperation(Operation* op);
+
+  void VisitRegionBranchOperation(ProgramPoint point,
+                                  RegionBranchOpInterface branch);
+
+  void VisitCallOperation(ProgramPoint point, CallOpInterface call);
+
+  void VisitCallableOperation(ProgramPoint point, CallableOpInterface callable);
+
+  void VisitBlock(Block* block);
+};
+
+template <typename LatticeT>
+class DenseAnalysis : public AbstractDenseAnalysis {
+  static_assert(
+      std::is_base_of<AbstractDenseLattice, LatticeT>::value,
+      "The class `LatticeT` must derive from `AbstractDenseLattice`.");
+
+ public:
+  using AbstractDenseAnalysis::AbstractDenseAnalysis;
+
+  virtual void VisitOperation(Operation* op,
+                              const LatticeT& before,
+                              LatticeT* after) = 0;
+};
+
+}  // namespace dataflow
+}  // namespace infra
diff --git a/paddle/infra/Analysis/DataFlow/Framework.cc b/paddle/infra/Analysis/DataFlow/Framework.cc
index 6bf8e373028c9..778013ce2db45 100644
--- a/paddle/infra/Analysis/DataFlow/Framework.cc
+++ b/paddle/infra/Analysis/DataFlow/Framework.cc
@@ -32,13 +32,12 @@ void DataFlowSolver::InitializeAndRun(Operation* top) {
 
 DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver& solver) : solver_{solver} {}
 
-void DataFlowAnalysis::AddDependency(AnalysisState* state,
-                                     DataFlowAnalysis* analysis,
-                                     ProgramPoint point) {
+void DataFlowAnalysis::AddDependency(AnalysisState* state, ProgramPoint point) {
   solver_.AddDependency(state, this, point);
 }
 
-void DataFlowAnalysis::PropagateIfChanged(AnalysisState* state, bool changed) {
+void DataFlowAnalysis::PropagateIfChanged(AnalysisState* state,
+                                          ChangeStatus changed) {
   solver_.PropagateIfChanged(state, changed);
 }
 
diff --git a/paddle/infra/Analysis/DataFlow/Framework.h b/paddle/infra/Analysis/DataFlow/Framework.h
index feea9841f08cc..ad31dd10666c0 100644
--- a/paddle/infra/Analysis/DataFlow/Framework.h
+++ b/paddle/infra/Analysis/DataFlow/Framework.h
@@ -14,35 +14,41 @@
 
 #pragma once
 
+#include
 #include "llvm/ADT/PointerUnion.h"
+#include "llvm/ADT/SetVector.h"
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Operation.h"
-#include "mlir/Support/LogicalResult.h"
+#include "mlir/Support/StorageUniquer.h"
+#include "mlir/Support/TypeID.h"
 
 namespace infra {
 
 class DataFlowAnalysis;
+class AnalysisState;
 
-// The base class of analysis state, which contains the information
-// in the analysis process.
-class AnalysisState {
- public:
-  using ProgramPoint = ::mlir::ProgramPoint;
-  virtual ~AnalysisState() = default;
+enum class ChangeStatus : int8_t {
+  NoChange,
+  Change,
+};
 
-  ProgramPoint GetPoint() const { return point_; }
+inline ChangeStatus operator|(ChangeStatus lhs, ChangeStatus rhs) {
+  return lhs == ChangeStatus::Change ? lhs : rhs;
+}
 
- private:
-  ProgramPoint point_;
-};
+inline ChangeStatus operator&(ChangeStatus lhs, ChangeStatus rhs) {
+  return lhs == ChangeStatus::NoChange ? lhs : rhs;
+}
 
 // Launch the data flow analyses, running the algorithm.
 class DataFlowSolver {
  public:
   using Operation = ::mlir::Operation;
   using ProgramPoint = ::mlir::ProgramPoint;
+  using WorkItem = std::pair<ProgramPoint, DataFlowAnalysis*>;
+  using TypeID = ::mlir::TypeID;
 
   template <typename AnalysisT, typename... Args>
   AnalysisT* Load(Args&&... args);
@@ -56,11 +62,24 @@ class DataFlowSolver {
                      DataFlowAnalysis* analysis,
                      ProgramPoint point);
 
-  void PropagateIfChanged(AnalysisState* state, bool changed);
+  void PropagateIfChanged(AnalysisState* state, ChangeStatus changed);
+
+  template <typename StateT, typename PointT>
+  StateT* GetOrCreateState(PointT point);
+
+  template <typename PointT, typename... Args>
+  PointT* GetOrCreatePoint(Args&&... args) {
+    return PointT::get(uniquer_, std::forward<Args>(args)...);
+  }
 
  private:
-  std::queue<std::pair<ProgramPoint, DataFlowAnalysis*>> worklist_;
+  std::queue<WorkItem> worklist_;
   std::vector<std::unique_ptr<DataFlowAnalysis>> analyses_;
+  ::llvm::DenseMap<std::pair<ProgramPoint, TypeID>,
+                   std::unique_ptr<AnalysisState>>
+      analysis_states_;
+
+  ::mlir::StorageUniquer uniquer_;
 
   friend class DataFlowAnalysis;
 };
@@ -71,6 +90,58 @@ AnalysisT* DataFlowSolver::Load(Args&&... args) {
   return static_cast<AnalysisT*>(analyses_.back().get());
 }
 
+template <typename StateT, typename PointT>
+const StateT* DataFlowSolver::LookupState(PointT point) const {
+  auto it = analysis_states_.find({ProgramPoint(point), TypeID::get<StateT>()});
+  return it == analysis_states_.end()
+             ? nullptr
+             : static_cast<const StateT*>(it->second.get());
+}
+
+template <typename StateT, typename PointT>
+StateT* DataFlowSolver::GetOrCreateState(PointT point) {
+  auto& state = analysis_states_[{ProgramPoint(point), TypeID::get<StateT>()}];
+  if (!state) {
+    state = std::make_unique<StateT>(point);
+  }
+  return static_cast<StateT*>(state.get());
+}
+
+// The base class of analysis state, which contains the information
+// in the analysis process.
+class AnalysisState {
+ public:
+  using ProgramPoint = ::mlir::ProgramPoint;
+  virtual ~AnalysisState() = default;
+
+  explicit AnalysisState(ProgramPoint point) : point_(point) {}
+
+  ProgramPoint GetPoint() const { return point_; }
+
+ protected:
+  virtual void PropagateUpdate(DataFlowSolver* solver) const {}
+
+  ::llvm::SetVector<DataFlowSolver::WorkItem> deps_;
+
+  ProgramPoint point_;
+
+  friend class DataFlowSolver;
+};
+
+class PredecessorState : public AnalysisState {
+ public:
+  using AnalysisState::AnalysisState;
+
+  bool allPredecessorsKnown() const { return all_known_; }
+
+  ::llvm::ArrayRef<::mlir::Operation*> getKnownPredecessors() const {
+    return known_predecessors_;
+  }
+
+ private:
+  bool all_known_{true};
+  ::llvm::ArrayRef<::mlir::Operation*> known_predecessors_;
+};
+
 // Base class of all data flow analyses.
 class DataFlowAnalysis {
  public:
@@ -78,15 +149,35 @@ class DataFlowAnalysis {
   using ProgramPoint = ::mlir::ProgramPoint;
   explicit DataFlowAnalysis(DataFlowSolver& solver);  // NOLINT
 
-  virtual void Initialize(Operation* top) = 0;
-  virtual void Visit(ProgramPoint point) = 0;
+  virtual bool Initialize(Operation* top) = 0;
+  virtual bool Visit(ProgramPoint point) = 0;
 
  protected:
-  void AddDependency(AnalysisState* state,
-                     DataFlowAnalysis* analysis,
-                     ProgramPoint point);
+  void AddDependency(AnalysisState* state, ProgramPoint point);
 
-  void PropagateIfChanged(AnalysisState* state, bool changed);
+  void PropagateIfChanged(AnalysisState* state, ChangeStatus changed);
+
+  template <typename PointT>
+  void RegisterPointKind() {
+    solver_.uniquer_.registerParametricStorageType<PointT>();
+  }
+
+  template <typename PointT, typename... Args>
+  PointT* GetOrCreatePoint(Args&&... args) {
+    return solver_.GetOrCreatePoint<PointT>(std::forward<Args>(args)...);
+  }
+
+  template <typename StateT, typename PointT>
+  StateT* GetOrCreate(PointT point) {
+    return solver_.GetOrCreateState<StateT>(point);
+  }
+
+  template <typename StateT, typename PointT>
+  const StateT* GetOrCreateFor(ProgramPoint dependent, PointT point) {
+    auto* state = GetOrCreate<StateT>(point);
+    AddDependency(state, dependent);
+    return state;
+  }
 
  private:
   DataFlowSolver& solver_;
diff --git a/paddle/infra/CMakeLists.txt b/paddle/infra/CMakeLists.txt
index 61eff21045169..4b1b6c339edeb 100644
--- a/paddle/infra/CMakeLists.txt
+++ b/paddle/infra/CMakeLists.txt
@@ -59,7 +59,8 @@ add_library(
   Rewrite/PatternApplicator.cc
   Transforms/GreedyPatternRewriteDriver.cc
   Pass/PassTiming.cc
-  Analysis/DataFlow/Framework.cc)
+  Analysis/DataFlow/Framework.cc
+  Analysis/DataFlow/DenseAnalysis.cc)
 
 # python and pybind11
 add_subdirectory(${CMAKE_SOURCE_DIR}/third_party/pybind11)
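
Usage sketch (illustrative only, not part of the patch): a concrete client would derive a lattice from AbstractDenseLattice and an analysis from DenseAnalysis<LatticeT>, roughly as below. Everything in the sketch is hypothetical -- the MayWriteLattice/MayWriteAnalysis names, the "foo.write" operation name, and the choice to back the GetLattice/SetToEntryState hooks with the solver's per-point state map are assumptions, not code from this PR.

// Illustrative sketch only: a hypothetical "may write" dense analysis built
// on the interfaces added above.
#include "Analysis/DataFlow/DenseAnalysis.h"

namespace infra {
namespace dataflow {

class MayWriteLattice : public AbstractDenseLattice {
 public:
  using AbstractDenseLattice::AbstractDenseLattice;

  // "May" information only grows, so Join is a boolean union.
  ChangeStatus Join(const AbstractDenseLattice& rhs) override {
    return static_cast<const MayWriteLattice&>(rhs).may_write_
               ? MarkMayWrite()
               : ChangeStatus::NoChange;
  }

  ChangeStatus MarkMayWrite() {
    if (may_write_) return ChangeStatus::NoChange;
    may_write_ = true;
    return ChangeStatus::Change;
  }

  bool MayWrite() const { return may_write_; }

 private:
  bool may_write_{false};
};

class MayWriteAnalysis : public DenseAnalysis<MayWriteLattice> {
 public:
  using DenseAnalysis::DenseAnalysis;

  // Transfer function: propagate the incoming state and set the flag when the
  // (hypothetical) "foo.write" operation is seen.
  void VisitOperation(Operation* op,
                      const MayWriteLattice& before,
                      MayWriteLattice* after) override {
    ChangeStatus changed = after->Join(before);
    if (op->getName().getStringRef() == "foo.write") {
      changed = changed | after->MarkMayWrite();
    }
    PropagateIfChanged(after, changed);
  }

 protected:
  // The storage hooks remain the subclass's job in this patch; here they are
  // backed by the solver's per-point state map via GetOrCreate.
  AbstractDenseLattice* GetLattice(ProgramPoint point) override {
    return GetOrCreate<MayWriteLattice>(point);
  }

  // At unknown entry points, be conservative: assume a write may have happened.
  void SetToEntryState(AbstractDenseLattice* lattice) override {
    auto* entry = static_cast<MayWriteLattice*>(lattice);
    PropagateIfChanged(entry, entry->MarkMayWrite());
  }

  void VisitOperationImpl(Operation* op,
                          const AbstractDenseLattice& before,
                          AbstractDenseLattice* after) override {
    VisitOperation(op,
                   static_cast<const MayWriteLattice&>(before),
                   static_cast<MayWriteLattice*>(after));
  }
};

}  // namespace dataflow
}  // namespace infra

A pass would then presumably register the analysis with a DataFlowSolver via Load, run the solver over the top-level operation, and query the MayWriteLattice attached to each program point afterwards.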